[SPMD] Multi-host batch sharded data loading #5331

Merged: 9 commits from mohit/dataloading into master on Aug 4, 2023

Conversation

@khatwanimohit khatwanimohit commented Jul 21, 2023

This PR changes the way we do multi-host data loading. Here are the changes a user script will need:

1. Add a sampler:

       sampler = torch.utils.data.DistributedSampler(
           dataset, num_replicas=xr.process_count(), rank=xr.process_index())

2. Use the sampler in torch.utils.data.DataLoader:

       loader = torch.utils.data.DataLoader(
           dataset,
           batch_size=FLAGS.batch_size // xr.process_count(),
           sampler=sampler,
           drop_last=FLAGS.drop_last,
           shuffle=False)

3. Change the input sharding spec in MpDeviceLoader:

       loader = pl.MpDeviceLoader(
           loader,
           device,
           input_sharding=xs.ShardingSpec(input_mesh, (0, 1, 2, 3), minibatch=True))

The minibatch flag denotes that the input is already sharded along the batch axis. Multi-host data loading currently only supports sharding along the batch dimension.
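
Putting the three steps together, a minimal end-to-end sketch is shown below. The import paths, the mesh construction, the xr.global_runtime_device_count() helper, and the dataset / global_batch_size placeholders are assumptions added for illustration; only the sampler, DataLoader, and MpDeviceLoader arguments come from this PR. The (num_devices, 1, 1, 1) mesh with partition spec (0, 1, 2, 3) assumes 4D inputs sharded only on the batch axis.

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.experimental.xla_sharding as xs
    import torch_xla.runtime as xr

    device = xm.xla_device()
    num_devices = xr.global_runtime_device_count()
    # 4D input mesh matching the (0, 1, 2, 3) partition spec; batch on axis 0.
    input_mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1, 1, 1))

    # 1. Each host samples only its own 1/process_count() slice of the dataset.
    sampler = torch.utils.data.DistributedSampler(
        dataset, num_replicas=xr.process_count(), rank=xr.process_index())

    # 2. Each host loads its per-host share of the global batch.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=global_batch_size // xr.process_count(),
        sampler=sampler,
        drop_last=True,
        shuffle=False)

    # 3. minibatch=True tells the input sharding that the data arriving on each
    #    host is already the batch-sharded per-host slice.
    train_loader = pl.MpDeviceLoader(
        loader,
        device,
        input_sharding=xs.ShardingSpec(input_mesh, (0, 1, 2, 3), minibatch=True))

Each host then iterates train_loader as usual, and its per-host minibatches are treated as shards of the global batch.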

@khatwanimohit khatwanimohit requested a review from jonb377 July 21, 2023 19:04
@khatwanimohit khatwanimohit force-pushed the mohit/dataloading branch 2 times, most recently from 165c739 to ddf982c on July 21, 2023 19:13
torch_xla/csrc/xla_sharding_util.cpp (outdated thread, resolved)
return std::make_shared<XLATensor::ShardingSpec>(
ShardingUtil::CreateOpSharding(
tile_assignment, group_assignment, replication_groups,
ShardingUtil::ShardingType(sharding_type)),
CreateComputationShapeFromTensor(tensor, nullptr));
CreateComputationShapeFromTensor(tensor, nullptr), minibatch);

We need to attach the global shape to the ShardingSpec here. CreateComputationShapeFromTensor(tensor, nullptr) will return the shape of the minibatch.

I think we can get the correct xla::Shape by using CreateComputationShapeFromTensor on the minibatch to get the local shape and then using set_dimensions to set dim 0 to tensor.sizes[0] * global_device_count / local_device_count.

We'll also need to enforce that all devices in the tiling assignment are on the batch axis.
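
As a rough illustration of the arithmetic suggested above (hypothetical device counts, not the C++ implementation):

    # Hypothetical counts: 2 hosts with 4 devices each.
    local_batch = 8            # dim 0 of the minibatch seen by this host
    local_device_count = 4     # devices attached to this host
    global_device_count = 8    # devices across all hosts
    global_batch = local_batch * global_device_count // local_device_count
    assert global_batch == 16  # dim 0 of the global shape attached to the ShardingSpec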

@@ -937,8 +937,10 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
std::vector<std::string> local_devices =
runtime::GetComputationClient()->GetLocalDevices();
xla::OpSharding sharding;
bool minibatch;

We should initialize it: bool minibatch = false;

torch_xla/csrc/tensor_util.cpp (outdated thread, resolved)
if (device_index.find(core) == device_index.end()) {
// Skip any shards whose device is not part of the `devices` list.
continue;
if (minibatch) {

We can simplify the logic some here:

  for (int i = 0; i < devices.size(); i++) {
    // Start with a full slice on every dimension, then restrict dim 0 (the
    // batch dimension) to this device's rows.
    auto indices = std::vector<at::indexing::TensorIndex>(tensor_shape.size(),
                                                          at::indexing::Slice());
    indices[0] = at::indexing::Slice(i * shard_shape[0], (i + 1) * shard_shape[0]);
    shard_indices.push_back(indices);
  }
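
For concreteness, here is a short Python trace of what that loop computes, using hypothetical sizes (a host minibatch of shape (8, 7, 4) split across 4 local devices). It is only an illustration of the indexing, not the C++ implementation:

    tensor_shape = (8, 7, 4)   # host-local minibatch shape
    shard_shape = (2, 7, 4)    # per-device shard: batch dim split across 4 devices
    num_local_devices = 4

    shard_indices = []
    for i in range(num_local_devices):
        # Full slice on every dimension, then restrict dim 0 to this device's rows.
        indices = [slice(None)] * len(tensor_shape)
        indices[0] = slice(i * shard_shape[0], (i + 1) * shard_shape[0])
        shard_indices.append(tuple(indices))
    # shard_indices[0] selects rows 0:2, shard_indices[3] selects rows 6:8.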

torch_xla/csrc/xla_sharding_util.cpp (resolved)
torch_xla/csrc/xla_sharding_util.h (outdated thread, resolved)
@alanwaketan

Do we have a design doc for this? It seems like quite a large change to the user experience. Trying to understand more about this before reviewing the actual code.

@khatwanimohit khatwanimohit force-pushed the mohit/dataloading branch 2 times, most recently from 81cb07f to 09e9410 on July 25, 2023 19:35
@khatwanimohit khatwanimohit requested a review from jonb377 July 25, 2023 21:49
@khatwanimohit khatwanimohit marked this pull request as ready for review July 26, 2023 00:22
runtime::GetComputationClient()->GetAllDevices().size();
if (minibatch) {
XLA_CHECK(tile_assignment.size() == num_global_devices)
<< "Sharding of input is only supported along batch dimension";

"Minibatch sharding only supports sharding along the batch dimension"

return shard_shape;
} else {
TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
}
}

std::vector<std::vector<at::indexing::TensorIndex>>
ShardingUtil::GetShardIndicesForBatchShardedTensor(

Minor: can we rename this GetShardIndicesForMinibatchTensor? Just to be explicit that this is a helper for the minibatch case

shard_shape, tensor.sizes().vec(), sharding, devices);
auto shard_shape = GetShardShape(tensor.sizes().vec(), sharding);
if (minibatch) {
shard_shape[0] = tensor.sizes().vec()[0] / devices.size();

Can we call GetShardShape with the global shape instead of modifying the shard shape here? I think you'd just need to change the sharding parameter from xla::OpSharding to ShardingSpecPtr; then the global shape can come from the ShardingSpec.

We could then also get rid of the minibatch parameter

int num_local_devices =
runtime::GetComputationClient()->GetLocalDevices().size();
int num_global_devices =
runtime::GetComputationClient()->GetAllDevices().size();

These can go in the if (minibatch) block

@khatwanimohit khatwanimohit requested a review from jonb377 July 27, 2023 18:57
Comment on lines 53 to 54
sharding = xla::HloSharding::Replicate().ToProto();
shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
sharding_spec->sharding = sharding;

Small nit - it would be a bit easier to read if we use sharding_spec->sharding as the source of truth, i.e. remove the sharding variable and update like sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();

@@ -55,12 +61,15 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"};

auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr);

Why nullptr instead of GetDefaultDevice() like on L40?

CreateComputationShapeFromTensor(minibatch_tensor, nullptr);
tensor_shape.set_dimensions(
0, minibatch_tensor.sizes()[0] * 2); // Assuming 2 hosts
xla::Array2D<int64_t> mesh({

Can we make the tensor 2D or the mesh 3D? I'm surprised that this is working with a 3D tensor over a 2D mesh

auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, tensor_shape, /*minibatch=*/true);
auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
devices, /*padded=*/false);

Why false on the padding? Do we need to disable padding in the minibatch case generally?

const py::list& replication_groups, int sharding_type) {
const py::list& replication_groups, int sharding_type,
bool minibatch) {
xla::Shape tensor_shape =

Can we rename this global_shape to be more explicit?

<< "Input shard shape must include padding: " << shard.sizes()
<< " vs " << shard_shape;
<< "Input shard shape must include padding: " << shard.sizes();
// << " vs " << shard_shape;

This should be okay to add back in since GetShardShape returns a vector

std::vector<at::Tensor> local_shards = ShardingUtil::ShardTensor(
tensors[i], sharding, local_devices, /*padded=*/true);
std::vector<at::Tensor> local_shards =
ShardingUtil::ShardTensor(tensors[i], shardings[i], local_devices,

shardings[i] could be nullptr, in which case we want to pass in a REPLICATED sharding spec to ShardTensor. Or, we can handle a null input in this if block:

if (sharding.type() == xla::OpSharding::REPLICATED) {


I'm also curious about the deletion of the null check.


This is handled within ShardingUtil::ShardTensor at line 485.

Comment on lines 959 to 960
new_handles.push_back(ShardingUtil::CreateShardedData(
local_shards, local_devices, shape, sharding));

This is turning into a larger refactor, but can we also just pass the ShardingSpecPtr into CreateShardedData?

Then we won't need L955-958 and L939-948

torch_xla/csrc/xla_sharding_util.cpp (outdated thread, resolved)
@khatwanimohit khatwanimohit force-pushed the mohit/dataloading branch 2 times, most recently from 36b6292 to da04c4c on July 28, 2023 19:30

xla::OpSharding sharding;
// Optional source tensor shape unpartitioned.
std::optional<xla::Shape> shape;
xla::Shape shape;

Why is this not optional anymore? Then I guess you need to delete the above comment as well.


Now that global tensor's shape doesn't reflect the truth.

for (int j = 0; j < tensor_shape.size(); j++) {
indices.push_back(at::indexing::Slice(0, tensor_shape[j]));
}
indices[0] =

Why overwrite the 0th index?


We only care about index 0 because that's the batch dim; the remaining dimensions are just full copies.


@khatwanimohit Can you add a comment to clarify this?

@alanwaketan alanwaketan left a comment

LGTM. Please address @jonb377's comments as well.

// Returns the indices for the shards. Supports `OTHER` sharding types and
// called when input is sharded along the batch axis.
static std::vector<std::vector<at::indexing::TensorIndex>>
GetShardIndicesForMinibatchTensor(const std::vector<int64_t>& shard_shape,

shard_shape is the local shard shape for the host, tensor_shape is the global tensor shape, and devices is the list of TPU devices on the local host.
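
A concrete example with hypothetical numbers (two hosts, four TPU devices each, batch-only sharding), just to make the three parameters explicit:

    tensor_shape = (16, 7, 4)                       # global tensor shape
    devices = ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]  # this host's local devices
    global_device_count = 8
    shard_shape = (tensor_shape[0] // global_device_count,) + tensor_shape[1:]
    # shard_shape == (2, 7, 4); the host-local minibatch has
    # shard_shape[0] * len(devices) == 8 rows, one (2, 7, 4) shard per device.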

@jonb377 jonb377 left a comment

Mostly LGTM, thanks Mohit!

std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
at::Tensor minibatch_tensor =
at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =

nit: global_shape instead of tensor_shape


std::vector<std::vector<at::indexing::TensorIndex>>
ShardingUtil::GetShardIndicesForMinibatchTensor(
const std::vector<int64_t>& shard_shape,
const std::vector<int64_t>& tensor_shape, const xla::OpSharding sharding,

Taking in tensor_shape is ambiguous - should it be the global shape or minibatch shape? It would be easier to follow if we remove this as a parameter and use shard_shape for all dimensions except batch. You can add another XLA_CHECK on the sharding->tile_shape or sharding->tile_assignment_devices to ensure that the shard_shape corresponds to pure batch sharding.

@jonb377 jonb377 left a comment

LGTM, great stuff Mohit! 🚀

@khatwanimohit khatwanimohit merged commit f2f877f into master Aug 4, 2023
will-cromar added a commit that referenced this pull request Sep 15, 2023
* Sharding should be per output of IR Node, instead of per IR Node (#5330)

* sharding should be per output of IR Node, instead of per IR Node

* Update sharding_hash method

* Add test for sharding on IR with multiple output

* fix cpu test

* Fix a bug in getSharding

* Update Python device API for SPMD (#5129)

* Make python Api to respect the virtual device when SPMD is enabled

* fix typo

* Check out the release branch instead of origin/master in ansible (#5344)

* Also dump output sharding on HLO file (#5339)

* Also dump output sharding on HLO file

* only dump output sharding if dump format is HLO

* add test

* fix typo

* Make all-reduce a no-op when world size is 1 (#5342)

* Make all-reduce a no-op when world size is 1

* Fix torch.distributed test

* add fs linker flag (#5347)

* Add py3.10 whl path to doc, refactor whl table (#5354)

* fix amp dtype setting for GPU (#5337)

* fix amp dtype setting for GPU.

* fix ut

* fix lint.

* minor.

* Add python test for SPMD+Runtime Python API (#5349)

* Add python test for SPMD+Runtime Python API

* replace test name

* Update test_xla_spmd_python_api_interaction.py

* Check the actual device instead of query env var for virtual device (#5352)

* Check the actual device instead of query env var for virtual device

* revert unneeded change

* minor changes

* [BE] use self.assertEquals instead of str equality in test_zero1.py (#5364)

* Revert "[BE] use self.assertEquals instead of str equality in test_zero1.py (#5364)" (#5366)

This reverts commit 8ada333.

* [Dynamo|TPU] Tweak `atol` and `rtol` for `test_dynamo.py` (#5363)

* tweak `atol` and `rtol`

* [Dynamo|TPU] Skip`DynamoTrainingBasicTest.test_resnet18` on TPU (#5362)

*  Skip`DynamoTrainingBasicTest.test_resnet18` on TPU

* Add a script for running stablehlo tests. (#5360)

* Add kokoro presubmit for stablehlo tests

* Don't rewrite index hints in global save planning (#5348)

* [Dynamo|TPU] Skip `DynamoInferenceBasicTest.test_resnet18` on TPU (#5361)


* Skip `DynamoInferenceBasicTest.test_resnet18` on TPU

* [BE] use self.assertEquals instead of str equality in test_zero1.py (#5367)

* [BE] use self.assertEquals instead of str equality in test_zero1.py

* Use our own assertEqual

* Remove print statements

* Fix ReplicateShardedData for int type (#5374)

* Fix ReplicateShardedData for int type

* add test

* Update dynamo.md (#5378)

Update dynamo.md to remove note about fallback ops since they're supported now

* Revert "Fix ReplicateShardedData for int type (#5374)" (#5380)

This reverts commit 7fb7dfe.

* Remove the mention of XRT_TPU_CONFIG in the CONTRIBUTING.md (#5379)

* [Dynamo|TPU] Tweak `atol` and `rtol` for `test_simple_model_with_different_input_shape` on TPU (#5373)

* tweak `atol` and `rtol` for `test_simple_model_with_different_input_shape` on TPU

* Rectify test_zero1.py once optim.load_state_dict doesn't guarantee immutability (#5382)

* [TEST ONLY] print statements for test_zero1.py to debug

* Try fix

* Rectify test_zero1.py to account for state_dict modification

* Fix lint

* Add gpu doc for how to build PyTorch/XLA from source with GPU support. (#5384)

* Add gpu doc for how to build PyTorch/XLA from source with GPU support.

* fix typo

* fix comments

* fix comments

* clear pending ir should also clear the cc op tokens (#5385)

* Port resnet data loading optimizations to SPMD test script (#5386)

* Add support for in-place ops with self tensors in dynamo bridge (#5309)

* Add more support for in-place ops in dynamo bridge

Run linter

* Add check to explicitly sync self tensors

Remove debugging lines

Update unit tests to a model

* Clean up some code

Surround  in an if-statement

Update metrics for fallback related dynamo tests

Update cloned args logic

Revert "Update metrics for fallback related dynamo tests"

This reverts commit 3855f43.

* Update single_node flag back to False

* Add dynamo test in TPU CI (#5381)

Add dynamo test in TPU CI

* Add manual seed in multihost checkpoint (#5392)

* Fix change_id type in coverage uploading (#5394)

* Update dynamo cpu fallback op to aten::_foobar (#5393)

* Run single host multi GPU tests in the CI. (#5387)

* Add gpu doc for how to build PyTorch/XLA from source with GPU support.

* Run single host multi GPU tests.

* fix linter

* fix linter

* fix error

* fix test

* [PJRT] Separate collective ops test from TPU runtime test. (#5396)

* [PJRT] Separate collective ops test from TPU runtime test.

* formatting

* Fix ReplicateShardedData for int type (#5404)

* Update the dynamo backend name to `openxla` (#5402)

* Replace aot backend with openxla

* Update the inference backend except the fallback tests

* handle the fallback tests

* update remaining test

* update doc

* add torch pin

* Delete .torcch_pin

* linter

* [SPMD] Multi-host batch sharded data loading (#5331)

* Refactor to share code between export_torch_model and save_as_stablehlo (#5388)

* Refactor to share code between export_torch_model and save_as_stablehlo

* Fix TPU collective ops test for multi-host TPUs (#5408)

* Fix TPU collective ops test for multi-host TPUs

* formatting

* Partially replicate lower-rank tensors (#5409)

* Partially replicate lower-rank tensors

* Fix unit test

* Remove unnecessary device count check

* Fix unordered partition spec test

* yapf

* Revert "Partially replicate lower-rank tensors (#5409)" (#5412)

This reverts commit 56a6a02.

* SPMD cross slice-replication using partial_replication sharding (#5411)

* Revert "Support unordered sharding spec for partial replication (#5316)"
* Update test_2d_tensor_3d_mesh unit test to surface a bug
* Use partial replication for 2D tensor over 3D mesh sharding

* Fix the incorect clone arg condition in dynamo bridge (#5414)

* [SPMD] named partition spec support (#5415)

[SPMD] named partition spec

* [PJRT|TPU] Update `test_xla_devices_single_process_all_chips` for expected device number (#5421)

Update `test_xla_devices_single_process_all_chips` for expected device number

* Add repo for libcudnn8=8.7.0.84 and CUDA 11.8 (#5425)

* Update fix_includes.sh (#5441)

Without this patch I cannot get torch_xla to build outside of the docker. This should fix it.

* [PJRT] Support `torchrun` with `pjrt://` `init_method` (#5438)

* Support torchrun with `pjrt://` `init_method`

* move import

* fix error

* Fix NameError

* Fix path

* Remove from TPU CI

* Bugfix + add more test for llama (#5439)

Bugfix details:
1. When the graph have mutations the exported graph will have additional
   inputs. For now we are dropping them.
2. We should trace with args instead of final_args.

* Move the C++ test build to CI build job instead of test job (#5442)

* Update gcc to 10. (#5445)

* Update gcc to 10,

And use unversioned clang-format (so it's installation will succeed)
in both debian bullseye and buster

* gcc10 to ansible

* Update the random seed for every dynamo execution (#5444)

* Revert "Update gcc to 10. (#5445)" (#5449)

This reverts commit 454e916.

Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com>

* Install gcc-10 (#5450)

* Revert "Install gcc-10 (#5450)" (#5452)

This reverts commit 65b7639.

* parallelize SPMD inputhandler and GetDataShards (#5447)

* parallelize SPMD inputhandler and GetDataShards

* add output handler trace

* Remove base image override from TPU CI build (#5453)

* Update to GCC 10 (#5451)

* Cache sharded placeholder for dynamo execution (#5446)

* Cache the output sharding spec for dynamo

* address review comments

* add test

* remove dead code

* add missing wait deivce ops

* Update xla_graph_executor.cpp

* linter

* Remove Docker image override from dev image (#5456)

* hack: implement (unimplement?) GetDataShard for XRT

* skip flaky test (#5459)

* Neuron import hook (#5429)

* Enable Neuron import hook for calling initialization functions if using AWS Neuron

* removing copy/paste error

* moving aws init call and removing comment

* Add missing includes (#5434)

* Add missing includes

Currently this is included indirectly through PyTorch includes, but when I remove
the include from PyTorch's headers, the xla build fails.

* [TESTING] Pin PyTorch PR

* Retrigger CI after timeout

* Remove .torch_pin

* [GPU]Update README.md with wheel/docker for CUDA12.0 and deprecate CUDA11.7 (#5443)

* [GPU]Update README.md with wheel and docker support CUDA12.0 and deprecate CUDA 11.7

* Update README.md with docker support CUDA 12.0 and python 3.8

* Update README.md

* Update README.md

* update remote cache key in ansible (#5463)

* Fix data type in Pow with Scalar base and Tensor exponent (#5467)

* fix dtype inference

* fix linter

* bump the timeout for CI (#5470)

* Fix the input sharding for dynamo (#5469)

* Enabling sharding device data IR (#5475)

* Allow shard device data IR

* Handle XLATensor that is DeviceData IR and does not have XLAData

* fix typo

* Introduce `torch_xla.runtime.use_spmd()` (#5474)

Introduce torch_xla.runtime.use_spmd() and torch_xla.runtime.is_spmd()

* Enable PJRT C API Client and other changes for Neuron (#5428)

* Enable PJRT C API Client and other changes for Neuron

* keeping quotes consistent

* fixing device type call

* refactoring neuron initialization with spawn

* updating replication setting only for torchrun

* removing set replication in xla backed was added to rendezvous handler

* removing workaround for world_size/master_port for neuron

* fixing linter issues

* Don't move full tensor to device in deferred_init (#4819)

* [SPMD] Fix HybridMesh ordering (#5478)

Summary:
In xs.HybridMesh, it assumes the xr.global_runtime_device_attributes() will return the attributes according to the PyTorch/XLA's logical global ordinals. However, it turns out not to be the case.

To fix this, we pass the logical global ordinal as one of the attributes and xs.HybridMesh will sort the attributes according to this new attribute before using the array.

Test Plan:
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding.py -v -k test_hybrid_mesh

* [SPMD] Properly skip tests on TPU V2 (#5479)

Summary:
Some of the tests only fail on TPU V2 but were skipped for all TPUs.
Let's fix that.

Test Plan:
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding.py

* Add @yeounoh to .github CODEOWNERS (#5482)

* Add Python API to execute StableHLO bytecode (#5476)

* [SPMD] Fix TPU CI after #5478 (#5487)

* [SPMD] Fix TPU CI after #5478

Summary:
Let's fix all TPU CI failures after #5478.

Test Plan:
TPU CI

* Fix linters

* [SPMD] Fix XLA_DUMP_POST_OPTIMIZATIONS test (#5485)

Summary:
XLA_DUMP_POST_OPTIMIZATIONS was set as static which means that the value will be fixed during the whole test run for a particular test suite.

Therefore, let's make a separate file.

Test Plan:
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding.py
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding_hlo.py

* [Dist] Refactor ZeRO-1 (#5145)

* refactor

* fix

* fix

* add padding

* more robust save/load

* Update artifacts.auto.tfvars for 2.1 release (#5483)

* Update artifacts.auto.tfvars for 2.1 release

Update artifacts.auto.tfvars for 2.1 release

* Remove cuda version 11.7 and add 12.0 for 2.1 triggers

* Add 3.10 tpu version

* Add ShardingSpec to XLATensor when it is created with a PJRTShardedData (#5489)

* Add ShardingSpec to XLATensor when it is created with a PJRTShardedData

* add test

* Add topological sorting to dynamo partitions (#5472)

* Add topological sorting to dynamo partitions

* Run linter

* Update unit tests to include more in-place ops

* [SPMD] Patch nn.Linear (#5491)

Summary:
This pull request introduces a patched version of torch.nn.functional.linear that uses einsum instead of torch.matmul which will flatten the tensors to 2D and collide the sharded dimensions. The torch.matmul default behavior makes it very hard for XLA compiler to propagate the sharding annotation.

Test Plan:
PJRT_DEVICE=CPU python test/test_operations.py -v -k test_patched_linear

* [original author: mrnikwaws] Neuron operator support (#5471)

* adding glu operator support

* adding glu operator

* fixing yaml

* fixing linter issues

* fixing linter issues

* fixing spacing

* fixing spacing

* fixing spacing

* fixing spacing

* fixing shape helper

* fixing spacing

* [SPMD] Make IR sharding custom sharding op (#5433)

Summary:
This pull request changes the syntax of IR sharding by making it a new node instead of just attaching the sharding spec to the tensor. On the same time, we will still attach a sharding spec to the newly created XLATensor which will hold the new IR node.

This new IR node will be a CustomSharding node and in hlo:
%annotate = f32[6,3] custom-call(%copy), custom_call_target="Sharding", sharding={devices=[2,1]0,1}

Test Plan:
PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py -v -k test_mark_sharding_ir
PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py -v -k test_inplace_add_with_sharding

* Support input sharding changed after first dynamo tracing (#5477)

* Support input sharding changed after first dynamo tracing

* fix linter

* Handle the different input for dynamo sharding change

* update counter

* only get sharding specs when spmd is enabled

* add option to skip checking input sharding after x runs

* handle the cpu test

* make XLA_DYNAMO_INPUT_SHARDING_CHECK_THREASHOLD configable

* fix review comments

* Always use ExecuteReplicated with SPMD (#5494)

* Always use ExecuteReplicated with SPMD

* Add unit test

* Skip a couple tests on TPU due to precision issue (#5496)

* Refactor stablehlo API and put them in official location. (#5493)

Changes include:

* make end point in torch_xla/init.py for exposed APIs torch_xla.save_as_stablehlo and torch_xla.save_torch_model_as_stablehlo.
* All tf related integration to its own file.
* Remove args as argument (because it will spear inside of ExportedProgram) but allow user to override it (which we use for now.

* Support tuples in partition spec (#5488)

* Support tuples in partition spec

* Add unit test for partial replication

* yapf

* Support higher-rank tensors over lower-rank mesh

* Fix test & yapf

* Don't use partition_spec when creating group assignment

* Update documentation

* More documentation

* Translate named specs in ShardingSpec

* Add a API to explictly init runtime (#5500)

* Add explict error message when tensor is on CPU for dynamo backend (#5499)

* remove torchvision in stablehlo.py (#5501)

* Fix tupled partition spec test on v3 (#5503)

* Update dynamo doc (#5506)

* Update dynamo.md (#5509)

fixing typo

* Get original_traced_args as example_inputs. (#5511)

Change due to changing name in pytorch/pytorch#107978

* mark_sharding over a replicated tensor is allowed. (#5513)

* [SPMD] Propagate replicated output (#5508)

Summary:
During the LLaMA2 experiments, I discovered that manually marking 1D tensors as replicated can save a lot of memory. Then I discovered that an explicitly replicated spec gets dropped after mark_step. That is caused by PrepareOutputShardingPropagation, which explicitly clears the sharding spec for replicated output. So I went ahead and fixed that.

Further, I did some experiments on propagating replicated output, and that removes the need to manually replicate 1D tensors. Hence, I made this change.

I'm still not quite sure why, will follow up later.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py

* Disable cxx abi in ansible when building pt/xla for branch r2.0 (#5332)

* Update pytorch git tag for r2.1 (#5529)

Update more places

Add torch_pin

* Enable megacore_dense by default (#5520) (#5531)

Summary:
This change enables megacore_dense by default to allow asynchronous cc ops, especially for GSPMD.

Test Plan:
CI

Co-authored-by: Jiewen Tan <jwtan@google.com>

* Add option to unbundle libtpu (#5534) (#5536)

* Add optiona to unbundle libtpu

* Add clarifying note

* Revert 2.1 terraform changes (#5537)

* Fix FSDP for Models with Frozen Weights (#5484) (#5539)

* Fix fsdp not freeing forzen full params

* add test

* formatting

* remove unnecessary env var in test

Co-authored-by: Liyang90 <liyanglu@google.com>

* Update r2.1 wheel to be compatible with PyPI (#5550)

* Update project metadata and remove useless files

* Update README

* Add manylinux platform tag

* formatting

* Add resnet50-weight-quant colab notebook (#5407) (#5556)

* Add resnet50-weight-only-quant colab notebook

* update notebook with llama blog link

Co-authored-by: Siyuan Liu <lsiyuan@google.com>

* hack: add placeholders for `HasSharding` and `GetSharding` to XRT

* formatting

* hack: always return false from `HasSharding`

* Update torch pin to current RC for CI testing

* Cherry pick `pjrt://` init method rename and doc updates (#5562)

* Change `pjrt://` init method to `xla://` (#5560)

* Update PJRT documentation for the 2.1 release (#5557)

* Update PJRT documentation for the 2.1 release

* clarify plugins

* clarify PJRT doc

* Update `pjrt://` to `xla://`

* Use new cache silo and skip test build

* hack: disable missing test

* hack: alter cache silo name

* formatting

---------

Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Co-authored-by: iefgnoix <isaacwxf23@gmail.com>
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Baole Ai <baoleai01@gmail.com>
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Co-authored-by: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Co-authored-by: qihqi <hanq@google.com>
Co-authored-by: jonb377 <jonbolin@google.com>
Co-authored-by: Wonjoo Lee <wonjoo@google.com>
Co-authored-by: Mohit Khatwani <118776932+khatwanimohit@users.noreply.github.com>
Co-authored-by: Yeounoh Chung <yeounoh@google.com>
Co-authored-by: Mateusz Lewko <mateusz.lewko@gmail.com>
Co-authored-by: Alisson Azzolini <37222419+aazzolini@users.noreply.github.com>
Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>
Co-authored-by: peterbell10 <peterbell10@live.co.uk>
Co-authored-by: Zach Zheng <zczheng@amazon.com>
Co-authored-by: Jiewen Tan <jwtan@google.com>
Co-authored-by: Huang, Guangtai <guangtai@amazon.com>
Co-authored-by: Shauheen <shauheen@users.noreply.github.com>
Co-authored-by: Liyang90 <liyanglu@google.com>
EXPECT_EQ(shards.size(), 4);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
EXPECT_EQ(shards[3].sizes(), c10::ArrayRef<long>({4, 1, 4}));
}

TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
@miladm miladm commented Oct 24, 2024

@khatwanimohit did we test minibatch=True on SPMD 2D sharding?

cc @tengyifei

return shard_shape;
} else {
TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
}
}

std::vector<std::vector<at::indexing::TensorIndex>>
ShardingUtil::GetShardIndicesForMinibatchTensor(
const std::vector<int64_t>& shard_shape,

@khatwanimohit I wonder if we can revisit the minibatch logic for the 2D model sharding scheme.

More specifically, I am curious to learn if the logic around shard_shape properly handles 2D sharding.

cc @tengyifei

@miladm miladm commented Oct 25, 2024

It looks like 2D sharding + minibatch is disallowed; if so, @khatwanimohit, can you help confirm why the following check does not always throw an error?

https://github.com/pytorch/xla/blob/2f5dc1f10e41afbf1c73497ea426a79728b2a1bd/torch_xla/csrc/runtime/tf_logging.h
