From 36b6292828d1c4553f44126bf4d875acda7d4701 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Fri, 28 Jul 2023 17:43:22 +0000 Subject: [PATCH] accept sharding_spec in CreateShardedData --- test/cpp/test_xla_sharding.cpp | 59 +++++++++++++------------ torch_xla/csrc/init_python_bindings.cpp | 16 +++---- torch_xla/csrc/tensor_util.cpp | 26 +++-------- torch_xla/csrc/xla_sharding_util.cpp | 17 +++---- torch_xla/csrc/xla_sharding_util.h | 2 +- 5 files changed, 55 insertions(+), 65 deletions(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 9442436a38c..02de6e1ef89 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -50,8 +50,7 @@ TEST_F(XLAShardingTest, GetShardShape) { // For tiled sharding, each dimension should be halved EXPECT_EQ(shard_shape, std::vector({4, 4})); - sharding = xla::HloSharding::Replicate().ToProto(); - sharding_spec->sharding = sharding; + sharding_spec->sharding = xla::HloSharding::Replicate().ToProto(); shard_shape = ShardingUtil::GetShardShape(sharding_spec); // For replicated sharding, each dimension should be preserved EXPECT_EQ(shard_shape, std::vector({8, 7})); @@ -61,7 +60,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"}; auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array2D mesh({ {0, 1}, {2, 3}, @@ -91,7 +91,6 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { EXPECT_EQ(slice.step(), 1); } } - sharding = xla::HloSharding::Replicate().ToProto(); sharding_spec->sharding = sharding; shard_shape = ShardingUtil::GetShardShape(sharding_spec); @@ -110,7 +109,8 @@ TEST_F(XLAShardingTest, ShardTensor) { // 1D tiled at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::OpSharding sharding = xla::HloSharding::Tile1D( CreateComputationShapeFromTensor(tensor, GetDefaultDevice()), @@ -127,7 +127,7 @@ TEST_F(XLAShardingTest, ShardTensor) { // 2D tiled, The first dim is halved and the last replicated. The last shard // size should be smaller in dim=1 because it's not evenly divisible. tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); - tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array2D mesh({ {0, 1, 2, 3}, {4, 5, 6, 7}, @@ -144,8 +144,7 @@ TEST_F(XLAShardingTest, ShardTensor) { // 3D tiled, the first dim is replicated and the last halved. The last shard // size should be smaller in dim=1 because it's not evenly divisible. xla::Array3D cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}); - sharding = xla::HloSharding::Tile(cube).ToProto(); - sharding_spec->sharding = sharding; + sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto(); shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); @@ -153,8 +152,7 @@ TEST_F(XLAShardingTest, ShardTensor) { EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({8, 1, 2})); // Replicated, all shards should be identical. 
- sharding = xla::HloSharding::Replicate().ToProto(); - sharding_spec->sharding = sharding; + sharding_spec->sharding = xla::HloSharding::Replicate().ToProto(); shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); @@ -165,7 +163,7 @@ TEST_F(XLAShardingTest, ShardTensor) { // last shard size should be smaller in dim=2 because it's not evenly // divisible. tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat)); - tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array4D tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}}); sharding = xla::HloSharding::Tile(tesseract).ToProto(); sharding_spec = @@ -187,7 +185,7 @@ TEST_F(XLAShardingTest, ShardTensor) { // last shard size should be smaller in dim=2 because it's not evenly // divisible. tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat)); - tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array hypercube(std::vector{1, 1, 2, 2, 2}); hypercube.FillIota(0); sharding = xla::HloSharding::Tile(hypercube).ToProto(); @@ -212,7 +210,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { // 2D tiled, The first dim is halved and the last replicated. at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array2D mesh({ {4, 5, 0, 1}, {6, 7, 2, 3}, @@ -234,8 +233,7 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { {0, 1, 4, 5}, {2, 3, 6, 7}, }); - sharding = xla::HloSharding::Tile(mesh).ToProto(); - sharding_spec->sharding = sharding; + sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto(); shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 4); @@ -248,25 +246,25 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) { at::Tensor minibatch_tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); xla::Shape tensor_shape = - CreateComputationShapeFromTensor(minibatch_tensor, nullptr); + CreateComputationShapeFromTensor(minibatch_tensor, GetDefaultDevice()); tensor_shape.set_dimensions( 0, minibatch_tensor.sizes()[0] * 2); // Assuming 2 hosts - xla::Array2D mesh({ - {0}, - {1}, - {2}, - {3}, - {4}, - {5}, - {6}, - {7}, + xla::Array3D mesh({ + {{0}}, + {{1}}, + {{2}}, + {{3}}, + {{4}}, + {{5}}, + {{6}}, + {{7}}, }); auto sharding = xla::HloSharding::Tile(mesh).ToProto(); auto sharding_spec = std::make_shared( sharding, tensor_shape, /*minibatch=*/true); auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec, - devices, /*padded=*/false); + devices, /*padded=*/true); EXPECT_EQ(shards.size(), 4); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({2, 7, 4})); EXPECT_EQ(shards[3].sizes(), c10::ArrayRef({2, 7, 4})); @@ -274,7 +272,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) { TEST_F(XLAShardingTest, EqualShardingSpecs) { auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({ {0, 1, 2, 3}, {4, 5, 6, 7}, @@ -300,7 +299,8 @@ TEST_F(XLAShardingTest, CreateTensorsData) { std::vector 
tensors(2); auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); std::fill_n(tensors.begin(), tensors.size(), tensor); std::vector devices(2); std::fill_n(devices.begin(), devices.size(), GetDefaultDevice()->toString()); @@ -349,7 +349,8 @@ TEST_F(XLAShardingTest, InputHandler) { std::vector tensors(2); auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); std::fill_n(tensors.begin(), tensors.size(), tensor); std::vector devices = {"TPU:0", "TPU:1"}; std::vector shardings = { diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index e6ec99a46cf..edbf037ffb4 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -792,25 +792,25 @@ void InitXlaModuleBindings(py::module m) { const py::list& group_assignment, const py::list& replication_groups, int sharding_type, bool minibatch) { - xla::Shape tensor_shape = + xla::Shape global_shape = CreateComputationShapeFromTensor(tensor, nullptr); if (minibatch) { int num_local_devices = runtime::GetComputationClient()->GetLocalDevices().size(); int num_global_devices = runtime::GetComputationClient()->GetAllDevices().size(); - XLA_CHECK(tile_assignment.size() == num_global_devices) + XLA_CHECK(tile_assignment.size()[0] == num_global_devices) << "Minibatch sharding only supports sharding along the batch " "dimension"; int batch_dim_shape = tensor.sizes()[0] * num_global_devices / num_local_devices; - tensor_shape.set_dimensions(0, batch_dim_shape); + global_shape.set_dimensions(0, batch_dim_shape); } return std::make_shared( ShardingUtil::CreateOpSharding( tile_assignment, group_assignment, replication_groups, ShardingUtil::ShardingType(sharding_type)), - tensor_shape, minibatch); + global_shape, minibatch); })); m.def("_xla_tensors_from_aten", [](const std::vector& tensors, @@ -1519,11 +1519,11 @@ void InitXlaModuleBindings(py::module m) { auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); for (auto shard : shards) { XLA_CHECK(shard.sizes() == shard_shape) - << "Input shard shape must include padding: " << shard.sizes(); - // << " vs " << shard_shape; + << "Input shard shape must include padding: " << shard.sizes() + << " vs " << shard_shape; } - auto xla_data = ShardingUtil::CreateShardedData(shards, devices, - xtensor->shape(), sharding); + auto xla_data = + ShardingUtil::CreateShardedData(shards, devices, sharding_spec); xtensor->SetXlaData(WrapXlaData(xla_data)); }); // This is useful for debugging and generating a partitioned HLO separately diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 22371a50efa..e8a3cf32b81 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -614,9 +614,10 @@ torch::lazy::BackendDataPtr TensorToXlaData( runtime::GetComputationClient()->GetLocalDevices(); auto replicated_data = std::vector(local_devices.size(), tensor); + auto sharding_spec = std::make_shared( + xla::HloSharding::Replicate().ToProto(), shape); return WrapXlaData(ShardingUtil::CreateShardedData( - replicated_data, local_devices, shape, - xla::HloSharding::Replicate().ToProto())); + replicated_data, local_devices, sharding_spec)); } static 
const bool transfer_async = @@ -865,9 +866,10 @@ std::vector CreateTensorsData( auto shape = CreateComputationShapeFromTensor(tensors[i], &device); auto replicated_data = std::vector(local_devices.size(), tensors[i]); + auto sharding_spec = std::make_shared( + xla::HloSharding::Replicate().ToProto(), shape); handles.push_back(ShardingUtil::CreateShardedData( - replicated_data, local_devices, shape, - xla::HloSharding::Replicate().ToProto())); + replicated_data, local_devices, sharding_spec)); } return WrapXlaData(handles); } @@ -936,28 +938,14 @@ std::vector CreateTensorsData( // global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]). std::vector local_devices = runtime::GetComputationClient()->GetLocalDevices(); - xla::OpSharding sharding; - bool minibatch = false; - if (shardings[i] != nullptr) { - sharding = shardings[i]->sharding; - minibatch = shardings[i]->minibatch; - } else { - // If using SPMD and no sharding is attached to the tensor, implicitly - // replicate to all local devices. - sharding = xla::HloSharding::Replicate().ToProto(); - } // Shards the input tensors with padding, to split evenly. // The execution requires consistent shard sizes, and the zero-padded // values should be ignored. std::vector local_shards = ShardingUtil::ShardTensor(tensors[i], shardings[i], local_devices, /*padded=*/true); - if (minibatch) { // change global shape as tensor is already sharded - // accross batch dimesion. - shape = shardings[i]->shape; - } new_handles.push_back(ShardingUtil::CreateShardedData( - local_shards, local_devices, shape, sharding)); + local_shards, local_devices, shardings[i])); } else { auto populate_fn = [&, i, device]( diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 7fb766357d4..32e87565ccd 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -327,9 +327,6 @@ std::vector ShardingUtil::OutputHandler( XLATensor::ShardingSpecPtr sharding = sharding_specs[i]; if (replicated_output && sharding && (sharding->sharding.type() != xla::OpSharding::REPLICATED)) { - // XLA_CHECK(sharding->shape.has_value()) - // << "Sharding or Wrapping data shards in OutputHandler requires " - // "unpartitioned tensor shape."; // Reshards replicated output if `sharding` is present. std::vector tensors = XlaDataToTensors( {WrapXlaData(sharded_results[0][i])}, @@ -370,7 +367,6 @@ std::vector ShardingUtil::GetShardShape( return globalShape; } else if (sharding.type() == xla::OpSharding::OTHER) { auto tile_shape = sharding.tile_assignment_dimensions(); - // `shard_shape[j]` is the size of dimension `j` in the resulting shard. std::vector shard_shape; for (int j = 0; j < tile_shape.size(); j++) { @@ -477,13 +473,16 @@ ShardingUtil::GetShardIndicesForDevices( std::vector ShardingUtil::ShardTensor( const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings, const std::vector& devices, bool padded) { - auto sharding = shardings->sharding; + xla::OpSharding sharding; + if (shardings != nullptr) { + sharding = shardings->sharding; + } bool minibatch = shardings->minibatch; TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() << ")... 
and minibatch = " << minibatch << std::endl; auto device_index = build_index_map(devices); std::vector shards(devices.size()); - if (sharding.type() == xla::OpSharding::REPLICATED) { + if (shardings == nullptr || sharding.type() == xla::OpSharding::REPLICATED) { std::fill_n(shards.begin(), shards.size(), tensor); } else if (sharding.type() == xla::OpSharding::OTHER) { XLA_CHECK(sharding.tile_shape().dimensions_size() <= 2); @@ -505,7 +504,7 @@ std::vector ShardingUtil::ShardTensor( c10::ArrayRef(shard_indices[i])); shards[i] = shard.contiguous(at::MemoryFormat::Contiguous); } - + TF_LOG(INFO) << "check shard shape " << shard_shape; // Zero-pad to the right to ensure the sizes are even if (shards.size() > 0 && padded) { for (size_t i = 0; i < shards.size(); ++i) { @@ -661,10 +660,12 @@ void ShardingUtil::PrepareOutputShardingPropagation( runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( std::vector& local_shards, std::vector& devices, - xla::Shape global_shape, xla::OpSharding sharding) { + const XLATensor::ShardingSpecPtr& sharding_spec) { XLA_CHECK(local_shards.size() == devices.size()) << "A device must be speficied for each shard"; std::vector source_tensors; + auto global_shape = sharding_spec->shape; + auto sharding = sharding_spec->sharding; for (int64_t j = 0; j < devices.size(); ++j) { auto shard_device = ParseDeviceString(devices[j]); auto shard_shape = diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 7b7fa906a28..9129c7ac7de 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -147,7 +147,7 @@ class ShardingUtil { // the PjRtShardedData wrapping the shards. static runtime::ComputationClient::DataPtr CreateShardedData( std::vector& shards, std::vector& devices, - xla::Shape global_shape, xla::OpSharding sharding); + const XLATensor::ShardingSpecPtr& sharding_spec); }; } // namespace torch_xla
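
Editor's note (not part of the patch): the hunks in torch_xla/csrc/tensor_util.cpp above show the intended caller-side pattern for the new CreateShardedData signature, which takes a ShardingSpec instead of a separate global shape and xla::OpSharding. Below is a minimal sketch of that pattern for the replicated case. It is an illustration only: the helper name ReplicateToLocalDevices, the include paths, and the assumption of a linked torch_xla build with an initialized computation client are the editor's, not the patch's. All API names used (CreateComputationShapeFromTensor, GetDefaultDevice, XLATensor::ShardingSpec, ShardingUtil::CreateShardedData, runtime::GetComputationClient) appear in the diff above.

    #include "torch_xla/csrc/tensor_util.h"        // assumed location of CreateComputationShapeFromTensor
    #include "torch_xla/csrc/xla_sharding_util.h"  // ShardingUtil::CreateShardedData (new signature)

    namespace torch_xla {

    // Hypothetical helper: replicate a host tensor to all local devices and
    // wrap the per-device copies as a single sharded data handle.
    //
    // Old call:  CreateShardedData(shards, devices, global_shape, sharding)
    // New call:  CreateShardedData(shards, devices, sharding_spec)
    // The global shape and the xla::OpSharding now travel together inside one
    // XLATensor::ShardingSpec, so minibatch shape adjustments are resolved in
    // one place instead of at every call site.
    runtime::ComputationClient::DataPtr ReplicateToLocalDevices(
        const at::Tensor& tensor) {
      xla::Shape global_shape =
          CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
      auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
          xla::HloSharding::Replicate().ToProto(), global_shape);

      std::vector<std::string> devices =
          runtime::GetComputationClient()->GetLocalDevices();
      // One identical shard per local device; a tiled sharding would instead
      // pass the shards produced by ShardingUtil::ShardTensor.
      std::vector<at::Tensor> shards(devices.size(), tensor);

      return ShardingUtil::CreateShardedData(shards, devices, sharding_spec);
    }

    }  // namespace torch_xla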