
[Distributed] Make xm.all_gather a single graph in Dynamo #4922

Merged
merged 10 commits into master on Apr 22, 2023

Conversation

@alanwaketan (Collaborator):

Summary:
This pull request makes xm.all_gather (specifically the _all_gather_using_all_reduce path) compile into a single graph in Dynamo. To do that, it:

  1. removes a hardware type check; specializing for CPU doesn't seem to be worth it.
  2. caches the ordinal and xrt_world_size values (a sketch of the caching pattern is shown below).
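For context, a minimal sketch of the caching pattern applied to xrt_world_size (simplified; the WORLD_SIZE environment variable is only a hypothetical stand-in for the real runtime query in torch_xla/core/xla_model.py):

import os

# Module-level cache, populated lazily on first call so that later calls are
# plain constant reads that Dynamo can trace without a graph break.
g_xrt_world_size = None


def xrt_world_size(defval=1):
  """Sketch: return (and cache) the number of participating replicas.

  The real torch_xla implementation queries the XRT/PJRT runtime; the
  environment variable below is just a stand-in for that query.
  """
  global g_xrt_world_size
  if g_xrt_world_size is None:
    g_xrt_world_size = int(os.environ.get("WORLD_SIZE", defval))
  return g_xrt_world_size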

Test Plan:
PJRT_DEVICE=TPU python test/test_mp_all_gather.py

@alanwaketan changed the base branch from master to alanwaketan/all_reduce_d on April 21, 2023 at 05:04
@@ -78,7 +78,7 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
if kind_devices:
return kind_devices[:max_devices] if max_devices else kind_devices


g_xrt_world_size = None
def xrt_world_size(defval=1):
Collaborator Author:

@wconstab This is the Python function that I want to use with allow_in_graph.

Collaborator:

Hmm, if you are going to manually cache the value of this anyway, then I think just using allow_in_graph without the caching amounts to the same thing.

The issue with allow_in_graph is that if you expect the value to be updated on later iterations, allow_in_graph will prevent that from working. But if you expect the value to be constant for the whole execution, then allow_in_graph will capture the value during compilation and reuse it later (i.e., cache it).
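For illustration only, a minimal sketch of the "constant for the whole execution" case described above (hypothetical values, not the PR's code): when the value is a plain Python int read outside the traced function, Dynamo specializes on it at compile time rather than re-evaluating it on every call.

import torch

# Hypothetical constant read once at startup; Dynamo specializes on it
# during compilation instead of calling back into Python at runtime.
WORLD_SIZE = 8


@torch.compile
def scale_by_world_size(x):
  # WORLD_SIZE is a plain Python int, so it is folded into the compiled graph.
  return x / WORLD_SIZE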

Collaborator Author:

I tried to use allow_in_graph. However, it looks like the function I pass into allow_in_graph needs to return a tensor type. If the function returns a bool or an int, is there a workaround?

Collaborator Author:

Here is how I use allow_in_graph:

ptxla@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla$ git diff
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 6ff4a5a5..a07ff472 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -6,6 +6,7 @@ import time
 from typing import List, Optional
 import torch
 import torch.distributed._functional_collectives
+from torch._dynamo import allow_in_graph
 import torch.nn.functional as F
 import torch_xla
 from torch_xla.experimental import pjrt
@@ -1088,3 +1089,6 @@ def optimization_barrier_(tensors):
     tensors (List[torch.Tensor]): List of `torch.Tensor` to add barrier to.
   """
   torch_xla._XLAC._xla_optimization_barrier_(tensors)
+
+
+allow_in_graph(xrt_world_size)

And here is the error:

root@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU python test/test_mp_all_gather.py 
concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 245, in _run_thread_per_device
    replica_results = list(
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/usr/local/lib/python3.8/concurrent/futures/thread.py", line 57, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 238, in _thread_fn
    return fn()
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 341, in __call__
    self.fn(global_ordinal(), *self.args, **self.kwargs)
  File "/workspaces/work/pytorch/xla/test/test_mp_all_gather.py", line 32, in _mp_fn
    result = compiled_all_gather(ordinal_tensor, dim=0)
  File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 405, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 122, in _fn
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 331, in _convert_frame_assert
    return _compile(
  File "/workspaces/work/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 401, in _compile
    out_code = transform_code_object(code, transform)
  File "/workspaces/work/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 386, in transform
    tracer.run()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1972, in run
    super().run()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
    and self.step()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
    getattr(self, inst.opname)(inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
    return inner_fn(self, inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1138, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 269, in call_function
    return super().call_function(tx, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 102, in call_function
    return tx.inline_user_function_return(
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 557, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2077, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2155, in inline_call_
    tracer.run()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
    and self.step()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
    getattr(self, inst.opname)(inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
    return inner_fn(self, inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1138, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 269, in call_function
    return super().call_function(tx, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 102, in call_function
    return tx.inline_user_function_return(
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 557, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2077, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2155, in inline_call_
    tracer.run()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
    and self.step()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
    getattr(self, inst.opname)(inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
    return inner_fn(self, inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1086, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspaces/work/pytorch/torch/_dynamo/variables/torch.py", line 603, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/workspaces/work/pytorch/torch/_dynamo/variables/builder.py", line 923, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/workspaces/work/pytorch/torch/_dynamo/variables/builder.py", line 1098, in wrap_fx_proxy_cls
    unimplemented(
  File "/workspaces/work/pytorch/torch/_dynamo/exc.py", line 107, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function <function xrt_world_size at 0x7fb184e94ca0>

from user code:
   File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 550, in all_gather
    return _all_gather_using_all_reduce(
  File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 511, in _all_gather_using_all_reduce
    left, right = ordinal, xrt_world_size() - 1 - ordinal

Set torch._dynamo.config.verbose=True or TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "test/test_mp_all_gather.py", line 66, in <module>
    xmp.spawn(_mp_fn, args=())
  File "/workspaces/work/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py", line 367, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 365, in spawn
    _run_multiprocess(spawn_fn, start_method=start_method)
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
    replica_results = list(
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function <function xrt_world_size at 0x7fb184e94ca0>

from user code:
   File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 550, in all_gather
    return _all_gather_using_all_reduce(
  File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 511, in _all_gather_using_all_reduce
    left, right = ordinal, xrt_world_size() - 1 - ordinal

Set torch._dynamo.config.verbose=True or TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

root@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla# 


g_ordinal = None
def get_ordinal(defval=0):
Collaborator Author:

@wconstab This is the Python function that I want to use with allow_in_graph.

@alanwaketan changed the base branch from alanwaketan/all_reduce_d to alanwaketan/all_reduce_t on April 21, 2023 at 16:19
@@ -109,10 +114,15 @@ def get_ordinal(defval=0):
Returns:
The replication ordinal of the current thread.
"""
if pjrt.using_pjrt():
return pjrt.global_ordinal()
global g_ordinal
Collaborator:

I think this will break the PJRT + v3 cases; the implementation we had checks the device in:

  m.def("_xla_get_default_device_ordinal", []() {
    std::string device_str = GetCurrentThreadDevice();
    torch::lazy::BackendDevice device =
        bridge::AtenDeviceToXlaDevice(device_str);
    return device.ordinal();
  });

Collaborator Author:

Hmm, this is confusing. That call is in the C++ layer, so allow_in_graph won't work here.

But we can work around it by caching a map...
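A rough, hypothetical sketch of that map-based workaround (illustrative names, not the PR's code): key the cache on the device string so a process that drives two devices (e.g. TPU v3 under PJRT) still gets the right value back for each device.

# Hypothetical per-device cache; query_fn stands in for the real runtime query.
g_ordinal_by_device = {}


def get_ordinal_cached(device_str, query_fn):
  # Cache one ordinal per device string, so a process that owns two devices
  # (e.g. a TPU v3 core pair) returns the correct ordinal for each of them.
  if device_str not in g_ordinal_by_device:
    g_ordinal_by_device[device_str] = query_fn()
  return g_ordinal_by_device[device_str]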

Collaborator:

I am not sure, actually; effectively this function won't return a constant in the v3 cases because there are two devices per process. This is a bit tricky.

Collaborator:

I wonder if we can bypass the v3 cases for now. What happens if you add a condition here to skip this cached value if we are on v3 + PJRT?

Collaborator Author:

It will introduce graph breaks in Dynamo.

Collaborator:

Should this use thread-local storage instead?

Collaborator Author:

That's cool. I was not aware Python has this feature. Let me work on it.
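A minimal sketch of the thread-local variant being suggested (hypothetical names; as the next comment notes, this ended up not working because Dynamo does not compile on the same thread as the user code):

import threading

# Hypothetical sketch: cache the ordinal per thread instead of per process.
_tls = threading.local()


def get_ordinal_thread_local(query_fn, defval=0):
  # The first call on a given thread queries the runtime; later calls on the
  # same thread reuse the cached value.
  if not hasattr(_tls, "ordinal"):
    _tls.ordinal = query_fn() if query_fn is not None else defval
  return _tls.ordinal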

Collaborator Author:

Dynamo doesn't seem to compile on the same thread as the user code, so threading.local doesn't work here.

@@ -533,8 +553,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
A tensor which has, in the ``dim`` dimension, all the values from the
participating replicas.
"""
if pin_layout and xla_device_hw(
value.device) in ('TPU', 'GPU', 'XPU') and output == None:
Collaborator:

I think we had it because CPU was not supported at some point. Do you need to remove it because it would break Dynamo?

Collaborator Author:

Yea.

@JackCaoG (Collaborator) left a comment:

Thanks!

@alanwaketan changed the base branch from alanwaketan/all_reduce_t to master on April 21, 2023 at 20:35
@alanwaketan (Collaborator, Author):

Thanks Jack for approving.

@alanwaketan merged commit 486b32e into master on Apr 22, 2023