Support multihost SPMD execution #4573

Merged: 1 commit merged into master from jonbolin-multihost-spmd on Feb 11, 2023

Conversation

@jonb377 jonb377 (Collaborator) commented Feb 6, 2023

The main change needed to support multihost execution is to restrict the shards generated in ShardTensor to those that belong to addressable devices.
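For illustration, here is a minimal C++ sketch of that idea, not the actual ShardTensor code: shard records are only produced for devices that this host can address. The `Shard` struct, the `ShardForAddressableDevices` name, and the flat float payload are all made up for the example.

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical shard record: the device it lives on plus its payload.
struct Shard {
  std::string device;
  std::vector<float> data;
};

// Sketch: walk the global device list for the sharding, but only
// materialize shards whose device is addressable by this host.
std::vector<Shard> ShardForAddressableDevices(
    const std::vector<float>& tensor_data,
    const std::vector<std::string>& global_devices,
    const std::unordered_set<std::string>& addressable_devices) {
  std::vector<Shard> shards;
  for (const std::string& device : global_devices) {
    if (addressable_devices.count(device) == 0) {
      continue;  // This shard belongs to another host; skip it.
    }
    // The real code slices tensor_data according to the sharding spec;
    // copying the full payload here is just a placeholder.
    shards.push_back(Shard{device, tensor_data});
  }
  return shards;
}
```

The point of the restriction is that in a multi-host setup each process only needs to produce and upload data for its own devices.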

@jonb377 jonb377 requested a review from yeounoh February 6, 2023 18:16
@jonb377 jonb377 force-pushed the jonbolin-multihost-spmd branch 2 times, most recently from b86aaec to 9e5fb70 on February 7, 2023 19:36
@@ -88,7 +89,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
 Examples
 —------------------------------
 mesh_shape = (4, 2)
-num_devices = len(xm.get_xla_supported_devices())
+num_devices = pjrt.global_device_count()
Collaborator:

Great :)

    const std::vector<std::string>& devices) {
  std::unordered_map<int, int> device_index;
  for (int i = 0; i < devices.size(); ++i) {
    int global_ordinal = ParseDeviceString(devices[i]).ordinal();
Collaborator:

The first global device gets the local index 0, so the order of the input devices list is important. Is this a correct understanding? Can we add some comments on this?

Collaborator Author:

The first device in the list gets local index 0, but the order of the global ordinals within devices doesn't matter. I'll add some more documentation around this.
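To make the ordering discussion concrete, here is a small sketch of such a global-ordinal-to-local-index map. It assumes device strings end in their global ordinal (e.g. "TPU:5"); `ParseOrdinal` and `BuildDeviceIndex` are illustrative stand-ins, not the PR's actual helpers.

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for ParseDeviceString(device).ordinal(): assumes the device
// string ends in its global ordinal, e.g. "TPU:5" -> 5.
int ParseOrdinal(const std::string& device) {
  return std::stoi(device.substr(device.find_last_of(':') + 1));
}

// The i-th entry of `devices` gets local index i, so the position in the
// list matters, but the global ordinals themselves can appear in any order.
std::unordered_map<int, int> BuildDeviceIndex(
    const std::vector<std::string>& devices) {
  std::unordered_map<int, int> device_index;
  for (std::size_t i = 0; i < devices.size(); ++i) {
    device_index[ParseOrdinal(devices[i])] = static_cast<int>(i);
  }
  return device_index;
}

// Example: {"TPU:4", "TPU:7", "TPU:5"} yields {4: 0, 7: 1, 5: 2}.
```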

@yeounoh yeounoh (Collaborator) left a comment:

LGTM, a minor comment.

@jonb377 jonb377 force-pushed the jonbolin-multihost-spmd branch 3 times, most recently from ba5ed74 to e999a95 on February 10, 2023 19:12
@@ -931,30 +931,24 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
  std::vector<xla::ComputationClient::DataPtr> new_handles;  // out
  if (shardings[i] != nullptr) {
    xla::OpSharding sharding = shardings[i]->sharding;
    // TODO(yeounoh) PJRT runs a process per host for SPMD and without cross
    // host communications. This means that we may need to manually shard
    // across global devices for multi-host training.
    std::vector<std::string> local_devices =
Collaborator:

Does GetLocalDevices() return local devices with global ordinals? If so, let's leave a comment.

@yeounoh yeounoh (Collaborator) left a comment:

LGTM, I have 2 nit comments.

@jonb377 jonb377 merged commit ced6456 into master Feb 11, 2023
@jonb377 jonb377 deleted the jonbolin-multihost-spmd branch February 11, 2023 22:19
mateuszlewko pushed a commit that referenced this pull request Mar 15, 2023