diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index 6397d17f837..646a45b92f3 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -36,17 +36,22 @@ class XLAShardingTest : public AtenXlaTensorTestBase {};
 
 TEST_F(XLAShardingTest, GetShardShape) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
   xla::Array2D<int64_t> mesh({
       {0, 1},
       {2, 3},
   });
   auto sharding = xla::HloSharding::Tile(mesh).ToProto();
-  auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  auto sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+
+  auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For tiled sharding, each dimension should be halved
   EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));
 
-  sharding = xla::HloSharding::Replicate().ToProto();
-  shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For replicated sharding, each dimension should be preserved
   EXPECT_EQ(shard_shape, std::vector<int64_t>({8, 7}));
 }
@@ -55,12 +60,16 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
   std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"};
 
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
   xla::Array2D<int64_t> mesh({
       {0, 1},
       {2, 3},
   });
   auto sharding = xla::HloSharding::Tile(mesh).ToProto();
-  auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  auto sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   auto shard_indices = ShardingUtil::GetShardIndicesForDevices(
       shard_shape, tensor.sizes().vec(), sharding, devices);
   EXPECT_EQ(shard_indices.size(), devices.size());
@@ -82,9 +91,9 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       EXPECT_EQ(slice.step(), 1);
     }
   }
 
-  sharding = xla::HloSharding::Replicate().ToProto();
-  shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  sharding_spec->sharding = sharding;
+  shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   shard_indices = ShardingUtil::GetShardIndicesForDevices(
       shard_shape, tensor.sizes().vec(), sharding, devices);
   EXPECT_EQ(shard_indices.size(), devices.size());
@@ -100,13 +109,17 @@ TEST_F(XLAShardingTest, ShardTensor) {
 
   // 1D tiled
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
   xla::OpSharding sharding =
       xla::HloSharding::Tile1D(
           CreateComputationShapeFromTensor(tensor, GetDefaultDevice()),
          devices.size())
          .ToProto();
-  auto shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  auto sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                          /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1}));
   EXPECT_EQ(shards[1].sizes(), c10::ArrayRef<long>({1}));
@@ -114,13 +127,16 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // 2D tiled, The first dim is halved and the last replicated. The last shard
   // size should be smaller in dim=1 because it's not evenly divisible.
   tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
   xla::Array2D<int64_t> mesh({
       {0, 1, 2, 3},
       {4, 5, 6, 7},
   });
   sharding = xla::HloSharding::Tile(mesh).ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({4, 1, 4}));
@@ -128,17 +144,17 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // 3D tiled, the first dim is replicated and the last halved. The last shard
   // size should be smaller in dim=1 because it's not evenly divisible.
   xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
-  sharding = xla::HloSharding::Tile(cube).ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto();
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 2, 2}));
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 1, 2}));
 
   // Replicated, all shards should be identical.
-  sharding = xla::HloSharding::Replicate().ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 7, 4}));
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 7, 4}));
@@ -147,17 +163,20 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // last shard size should be smaller in dim=2 because it's not evenly
   // divisible.
   tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
   sharding = xla::HloSharding::Tile(tesseract).ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({1, 8, 1, 2}));
 
   // 4D tiled and padded, all shard sizes should be idential.
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/true);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/true);
   EXPECT_EQ(shards.size(), 8);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
@@ -166,18 +185,21 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // last shard size should be smaller in dim=2 because it's not evenly
   // divisible.
   tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
   sharding = xla::HloSharding::Tile(hypercube).ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 3, 2}));
 
   // 5D tiled and padded, all shard sizes should be identical.
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/true);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/true);
   EXPECT_EQ(shards.size(), 8);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
@@ -188,16 +210,19 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
 
   // 2D tiled, The first dim is halved and the last replicated.
   at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
   xla::Array2D<int64_t> mesh({
       {4, 5, 0, 1},
       {6, 7, 2, 3},
   });
   auto sharding = xla::HloSharding::Tile(mesh).ToProto();
-
+  auto sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   // For devices at the start of the mesh, all shards should have the same
   // unpadded shape.
-  auto shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                          /*padded=*/false);
   EXPECT_EQ(shards.size(), 4);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
   EXPECT_EQ(shards[3].sizes(), c10::ArrayRef<long>({4, 2, 4}));
@@ -208,23 +233,58 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {0, 1, 4, 5},
       {2, 3, 6, 7},
   });
-  sharding = xla::HloSharding::Tile(mesh).ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto();
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
   EXPECT_EQ(shards.size(), 4);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
   EXPECT_EQ(shards[3].sizes(), c10::ArrayRef<long>({4, 1, 4}));
 }
 
+TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
+  std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
+  at::Tensor minibatch_tensor =
+      at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  xla::Shape global_shape =
+      CreateComputationShapeFromTensor(minibatch_tensor, GetDefaultDevice());
+  global_shape.set_dimensions(
+      0, minibatch_tensor.sizes()[0] * 2);  // Assuming 2 hosts
+  xla::Array3D<int64_t> mesh({
+      {{0}},
+      {{1}},
+      {{2}},
+      {{3}},
+      {{4}},
+      {{5}},
+      {{6}},
+      {{7}},
+  });
+
+  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
+      sharding, global_shape, /*minibatch=*/true);
+  auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
+                                          devices, /*padded=*/true);
+  EXPECT_EQ(shards.size(), 4);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({2, 7, 4}));
+  EXPECT_EQ(shards[3].sizes(), c10::ArrayRef<long>({2, 7, 4}));
+}
+
 TEST_F(XLAShardingTest, EqualShardingSpecs) {
+  auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
   XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
                                        {0, 1, 2, 3},
                                        {4, 5, 6, 7},
                                    })
-                                       .ToProto());
+                                       .ToProto(),
+                                   tensor_shape);
   XLATensor::ShardingSpec tiled_3d(
-      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto());
-  XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto());
+      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(),
+      tensor_shape);
+  XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(),
+                                     tensor_shape);
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
   EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated));
@@ -238,13 +298,15 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
   }
 
   std::vector<at::Tensor> tensors(2);
-  std::fill_n(tensors.begin(), tensors.size(),
-              at::ones({8, 8}, at::TensorOptions(at::kFloat)));
+  auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
+  std::fill_n(tensors.begin(), tensors.size(), tensor);
   std::vector<std::string> devices(2);
   std::fill_n(devices.begin(), devices.size(), GetDefaultDevice()->toString());
   std::vector<XLATensor::ShardingSpecPtr> shardings = {
       nullptr, std::make_shared<XLATensor::ShardingSpec>(
-                   xla::HloSharding::Replicate().ToProto())};
+                   xla::HloSharding::Replicate().ToProto(), tensor_shape)};
   std::vector<torch::lazy::BackendDataPtr> tensors_data =
       CreateTensorsData(tensors, shardings, devices);
 
@@ -286,12 +348,14 @@ TEST_F(XLAShardingTest, InputHandler) {
   }
 
   std::vector<at::Tensor> tensors(2);
-  std::fill_n(tensors.begin(), tensors.size(),
-              at::ones({8, 8}, at::TensorOptions(at::kFloat)));
+  auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
+  std::fill_n(tensors.begin(), tensors.size(), tensor);
   std::vector<std::string> devices = {"TPU:0", "TPU:1"};
   std::vector<XLATensor::ShardingSpecPtr> shardings = {
       nullptr, std::make_shared<XLATensor::ShardingSpec>(
-                   xla::HloSharding::Replicate().ToProto())};
+                   xla::HloSharding::Replicate().ToProto(), tensor_shape)};
   std::vector<torch::lazy::BackendDataPtr> tensors_data =
       CreateTensorsData(tensors, shardings, devices);
 
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index a7b4cb1ebca..419a281691c 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -790,12 +790,27 @@ void InitXlaModuleBindings(py::module m) {
       m, "XlaShardingSpec")
       .def(py::init([](at::Tensor tensor, const py::list& tile_assignment,
                        const py::list& group_assignment,
-                       const py::list& replication_groups, int sharding_type) {
+                       const py::list& replication_groups, int sharding_type,
+                       bool minibatch) {
+        xla::Shape global_shape =
+            CreateComputationShapeFromTensor(tensor, nullptr);
+        if (minibatch) {
+          int num_local_devices =
+              runtime::GetComputationClient()->GetLocalDevices().size();
+          int num_global_devices =
+              runtime::GetComputationClient()->GetAllDevices().size();
+          XLA_CHECK(tile_assignment.size() == num_global_devices)
+              << "Minibatch sharding only supports sharding along the batch "
+                 "dimension";
+          int batch_dim_shape =
+              tensor.sizes()[0] * num_global_devices / num_local_devices;
+          global_shape.set_dimensions(0, batch_dim_shape);
+        }
         return std::make_shared<XLATensor::ShardingSpec>(
             ShardingUtil::CreateOpSharding(
                 tile_assignment, group_assignment, replication_groups,
                 ShardingUtil::ShardingType(sharding_type)),
-            CreateComputationShapeFromTensor(tensor, nullptr));
+            global_shape, minibatch);
       }));
   m.def("_xla_tensors_from_aten",
         [](const std::vector<at::Tensor>& tensors,
@@ -1490,9 +1505,9 @@ void InitXlaModuleBindings(py::module m) {
           for (auto& shard : shards) {
            shard_devices.push_back(shard->device());
           }
-
+          auto sharding_spec = xtensor->sharding_spec();
           auto sharding = xtensor->sharding_spec()->sharding;
-          auto shard_shape = ShardingUtil::GetShardShape(input, sharding);
+          auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
           auto indices = ShardingUtil::GetShardIndicesForDevices(
               shard_shape, input.sizes().vec(), sharding, shard_devices);
 
@@ -1532,17 +1547,18 @@ void InitXlaModuleBindings(py::module m) {
                       runtime::GetComputationClient()->GetLocalDevices().size())
             << "Shards must be provided for all local devices";
         auto sharding = xtensor->sharding_spec()->sharding;
+        auto sharding_spec = xtensor->sharding_spec();
         XLA_CHECK(sharding.type() != xla::OpSharding::REPLICATED)
             << "Replicated tensor should not be loaded from _load_local_shards - "
               "use copy_";
-        auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+        auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
         for (auto shard : shards) {
           XLA_CHECK(shard.sizes() == shard_shape)
              << "Input shard shape must include padding: " << shard.sizes()
              << " vs " << shard_shape;
         }
-        auto xla_data = ShardingUtil::CreateShardedData(shards, devices,
-                                                        xtensor->shape(), sharding);
+        auto xla_data =
+            ShardingUtil::CreateShardedData(shards, devices, sharding_spec);
         xtensor->SetXlaData(WrapXlaData(xla_data));
       });
   // This is useful for debugging and generating a partitioned HLO separately
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index cfac3712e21..f2d12b2325f 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -253,13 +253,17 @@ class XLATensor : public torch::lazy::LazyTensor {
   // XLA SPMD sharding spec annoation. The XLA tensor uses this to create
   // HloSharding for replication, manual and tile shardings.
   struct ShardingSpec {
-    ShardingSpec(const xla::OpSharding& sharding) : sharding(sharding) {}
     ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape)
         : sharding(sharding), shape(shape) {}
+    ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape,
+                 const bool& minibatch)
+        : sharding(sharding), shape(shape), minibatch(minibatch) {}
 
     xla::OpSharding sharding;
     // Optional source tensor shape unpartitioned.
-    std::optional<xla::Shape> shape;
+    xla::Shape shape;
+    // Parameter to indicate that the input is sharded along the batch axis.
+    bool minibatch = false;
   };
 
   // Annotate the IR value with ShardingSpec.
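
For orientation (not part of the patch): the new three-argument `ShardingSpec` constructor above carries a `minibatch` flag, and the `XlaShardingSpec` binding earlier in this diff recovers the global shape by scaling the batch dimension by `num_global_devices / num_local_devices`. A small self-contained sketch of that arithmetic, with the device counts and shapes assumed purely for illustration:

# Sketch of the minibatch global-shape arithmetic; all numbers are assumptions.
num_global_devices = 8  # e.g. 2 hosts x 4 local devices
num_local_devices = 4

# Per-host ("minibatch") input shape, as in the ShardTensorMiniBatch test.
local_shape = [8, 7, 4]

# Only the batch dimension (dim 0) is scaled up to obtain the global shape.
global_shape = list(local_shape)
global_shape[0] = local_shape[0] * num_global_devices // num_local_devices

# Each shard along the batch axis then covers global_batch / num_global_devices rows.
shard_batch = global_shape[0] // num_global_devices
print(global_shape[0], shard_batch)  # 16 2 -> each batch shard is {2, 7, 4}

With 2 hosts of 4 devices each, a per-host batch of 8 becomes a global batch of 16, and each of the 8 batch shards has batch size 2, matching the `ShardTensorMiniBatch` expectations above.
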
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index 3650962cf16..e2d76982759 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -614,9 +614,10 @@ torch::lazy::BackendDataPtr TensorToXlaData(
         runtime::GetComputationClient()->GetLocalDevices();
     auto replicated_data =
         std::vector<at::Tensor>(local_devices.size(), tensor);
+    auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
+        xla::HloSharding::Replicate().ToProto(), shape);
     return WrapXlaData(ShardingUtil::CreateShardedData(
-        replicated_data, local_devices, shape,
-        xla::HloSharding::Replicate().ToProto()));
+        replicated_data, local_devices, sharding_spec));
   }
 
   static const bool transfer_async =
@@ -870,9 +871,10 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
       auto shape = CreateComputationShapeFromTensor(tensors[i], &device);
       auto replicated_data =
          std::vector<at::Tensor>(local_devices.size(), tensors[i]);
+      auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
+          xla::HloSharding::Replicate().ToProto(), shape);
       handles.push_back(ShardingUtil::CreateShardedData(
-          replicated_data, local_devices, shape,
-          xla::HloSharding::Replicate().ToProto()));
+          replicated_data, local_devices, sharding_spec));
     }
     return WrapXlaData(handles);
   }
@@ -941,21 +943,14 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
       // global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]).
      std::vector<std::string> local_devices =
          runtime::GetComputationClient()->GetLocalDevices();
-      xla::OpSharding sharding;
-      if (shardings[i] != nullptr) {
-        sharding = shardings[i]->sharding;
-      } else {
-        // If using SPMD and no sharding is attached to the tensor, implicitly
-        // replicate to all local devices.
-        sharding = xla::HloSharding::Replicate().ToProto();
-      }
       // Shards the input tensors with padding, to split evenly.
       // The execution requires consistent shard sizes, and the zero-padded
      // values should be ignored.
-      std::vector<at::Tensor> local_shards = ShardingUtil::ShardTensor(
-          tensors[i], sharding, local_devices, /*padded=*/true);
+      std::vector<at::Tensor> local_shards =
+          ShardingUtil::ShardTensor(tensors[i], shardings[i], local_devices,
+                                    /*padded=*/true);
       new_handles.push_back(ShardingUtil::CreateShardedData(
-          local_shards, local_devices, shape, sharding));
+          local_shards, local_devices, shardings[i]));
     } else {
       auto populate_fn =
          [&, i, device](
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 4f431e16795..433520e7796 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -7,6 +7,7 @@
 #include "torch/csrc/lazy/core/ir_util.h"
 
 #include "torch_xla/csrc/device.h"
+#include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/runtime/runtime.h"
 #include "torch_xla/csrc/tensor.h"
@@ -339,13 +340,10 @@ std::vector<runtime::ComputationClient::DataPtr> ShardingUtil::OutputHandler(
       XLATensor::ShardingSpecPtr sharding = sharding_specs[i];
       if (replicated_output && sharding &&
           (sharding->sharding.type() != xla::OpSharding::REPLICATED)) {
-        XLA_CHECK(sharding->shape.has_value())
-            << "Sharding or Wrapping data shards in OutputHandler requires "
-               "unpartitioned tensor shape.";
         // Reshards replicated output if `sharding` is present.
         std::vector<at::Tensor> tensors = XlaDataToTensors(
             {WrapXlaData(sharded_results[0][i])},
-            TensorTypeFromXlaType(sharding->shape.value().element_type()));
+            TensorTypeFromXlaType(sharding->shape.element_type()));
         outputs.push_back(UnwrapXlaData(CreateTensorsData(
             tensors, {sharding},
             std::vector<std::string>{GetVirtualDevice().toString()})[0]));
@@ -365,7 +363,7 @@ std::vector<runtime::ComputationClient::DataPtr> ShardingUtil::OutputHandler(
               sharded_results[0][i]->shape());
         }
         outputs.push_back(runtime::GetComputationClient()->WrapDataShards(
-            shards, GetVirtualDevice().toString(), sharding->shape.value(),
+            shards, GetVirtualDevice().toString(), sharding->shape,
             sharding->sharding));
       }
     }
@@ -373,27 +371,53 @@ std::vector<runtime::ComputationClient::DataPtr> ShardingUtil::OutputHandler(
 }
 
 std::vector<int64_t> ShardingUtil::GetShardShape(
-    const at::Tensor& tensor, const xla::OpSharding sharding) {
+    const XLATensor::ShardingSpecPtr shardings) {
+  auto sharding = shardings->sharding;
+  auto global_shape = shardings->shape.dimensions();
   if (sharding.type() == xla::OpSharding::REPLICATED) {
-    return tensor.sizes().vec();
+    std::vector<int64_t> globalShape;
+    globalShape.assign(global_shape.begin(), global_shape.end());
+    return globalShape;
   } else if (sharding.type() == xla::OpSharding::OTHER) {
     auto tile_shape = sharding.tile_assignment_dimensions();
-    // `shard_shape[j]` is the size of dimension `j` in the resulting shard.
     std::vector<int64_t> shard_shape;
     for (int j = 0; j < tile_shape.size(); j++) {
       if (sharding.replicate_on_last_tile_dim() &&
          j == tile_shape.size() - 1) {
         continue;
       }
-      shard_shape.push_back(tensor.sizes()[j] / tile_shape[j] +
-                            (tensor.sizes()[j] % tile_shape[j] != 0));
+      shard_shape.push_back(global_shape[j] / tile_shape[j] +
+                            (global_shape[j] % tile_shape[j] != 0));
     }
+    return shard_shape;
   } else {
     TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
   }
 }
 
+std::vector<std::vector<at::indexing::TensorIndex>>
+ShardingUtil::GetShardIndicesForMinibatchTensor(
+    const std::vector<int64_t>& shard_shape,
+    const std::vector<std::string>& devices) {
+  std::vector<std::vector<at::indexing::TensorIndex>> shard_indices(
+      devices.size());
+  for (int i = 0; i < devices.size(); i++) {
+    std::vector<at::indexing::TensorIndex> indices;
+    // For batch dimension sharding, we only change the shard indices on the
+    // first axis and copy the full index range of every remaining axis.
+    for (int j = 0; j < shard_shape.size(); j++) {
+      indices.push_back(at::indexing::Slice(0, shard_shape[j]));
+    }
+    // As the tensor is batch-sharded, only the first dimension matters when
+    // calculating the shard indices.
+    indices[0] =
+        at::indexing::Slice(i * shard_shape[0], (i + 1) * shard_shape[0]);
+    shard_indices[i] = indices;
+  }
+  return shard_indices;
+}
+
 std::vector<std::vector<at::indexing::TensorIndex>>
 ShardingUtil::GetShardIndicesForDevices(
     const std::vector<int64_t>& shard_shape,
@@ -432,13 +456,14 @@ ShardingUtil::GetShardIndicesForDevices(
     }
 
     // Given the shard's row-major index `i`, we need to calculate shard's
-    // coordinates (n_0, ..., n_d) in the tiling to generate the index slices.
-    // Using `N_j = tile_shape[j]` and `0 <= n_j < N_j`, the following
-    // equation needs to be solved for all n_j:
-    //            `i = n_d + N_d * (n_{d-1} + N_{d-1} * (... + (N_1 * n_0)))`
-    // Let `offset_j = n_j + N_j * (n_{j-1} + N_{j-1} * (... + (N_1 * n_0)))`.
-    // Then `offset_d = i`, `n_j = offset_j % N_j`, and `offset_{j-1} =
-    // offset_j / N_j`.
+    // coordinates (n_0, ..., n_d) in the tiling to generate the index
+    // slices. Using `N_j = tile_shape[j]` and `0 <= n_j < N_j`, the
+    // following equation needs to be solved for all n_j:
+    //            `i = n_d + N_d * (n_{d-1} + N_{d-1} * (... + (N_1 *
+    //            n_0)))`
+    // Let `offset_j = n_j + N_j * (n_{j-1} + N_{j-1} * (... + (N_1 *
+    // n_0)))`. Then `offset_d = i`, `n_j = offset_j % N_j`, and
+    // `offset_{j-1} = offset_j / N_j`.
     int offset = i;
     std::vector<at::indexing::TensorIndex> indices;
     for (int j = tile_shape.size() - 1; j >= 0; j--) {
@@ -468,28 +493,39 @@ ShardingUtil::GetShardIndicesForDevices(
 }
 
 std::vector<at::Tensor> ShardingUtil::ShardTensor(
-    const at::Tensor& tensor, const xla::OpSharding sharding,
+    const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings,
     const std::vector<std::string>& devices, bool padded) {
-  TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() << ")..."
-               << std::endl;
+  xla::OpSharding sharding;
+  bool minibatch = false;
+  if (shardings != nullptr) {
+    sharding = shardings->sharding;
+    minibatch = shardings->minibatch;
+  }
+  TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type()
+               << ")... and minibatch = " << minibatch << std::endl;
   auto device_index = build_index_map(devices);
   std::vector<at::Tensor> shards(devices.size());
-  if (sharding.type() == xla::OpSharding::REPLICATED) {
+  if (shardings == nullptr || sharding.type() == xla::OpSharding::REPLICATED) {
     std::fill_n(shards.begin(), shards.size(), tensor);
   } else if (sharding.type() == xla::OpSharding::OTHER) {
     XLA_CHECK(sharding.tile_shape().dimensions_size() <= 2);
     XLA_CHECK(tensor.sizes().size() >=
              sharding.tile_shape().dimensions_size());
-    auto shard_shape = GetShardShape(tensor, sharding);
-    auto shard_indices = GetShardIndicesForDevices(
-        shard_shape, tensor.sizes().vec(), sharding, devices);
+    auto shard_shape = GetShardShape(shardings);
+
+    std::vector<std::vector<at::indexing::TensorIndex>> shard_indices;
+    if (minibatch) {
+      shard_indices = GetShardIndicesForMinibatchTensor(shard_shape, devices);
+    } else {
+      shard_indices = GetShardIndicesForDevices(
+          shard_shape, tensor.sizes().vec(), sharding, devices);
+    }
     for (size_t i = 0; i < shard_indices.size(); i++) {
       at::Tensor shard = tensor.index(
          c10::ArrayRef<at::indexing::TensorIndex>(shard_indices[i]));
       shards[i] = shard.contiguous(at::MemoryFormat::Contiguous);
     }
-
     // Zero-pad to the right to ensure the sizes are even
     if (shards.size() > 0 && padded) {
       for (size_t i = 0; i < shards.size(); ++i) {
@@ -572,8 +608,7 @@ void ShardingUtil::PrepareOutputShardingPropagation(
       // replication.
       auto sharded_data_placeholder =
          WrapXlaData(runtime::GetComputationClient()->WrapDataShards(
-              {}, GetVirtualDevice().toString(),
-              (*sharding_specs)[i]->shape.value(),
+              {}, GetVirtualDevice().toString(), (*sharding_specs)[i]->shape,
              (*sharding_specs)[i]->sharding));
 
       // Register the sharded data placeholder to the tensor and its node.
@@ -636,7 +671,7 @@ void ShardingUtil::PrepareOutputShardingPropagation(
       // replication.
       auto sharded_data_placeholder =
          WrapXlaData(runtime::GetComputationClient()->WrapDataShards(
-              {}, GetVirtualDevice().toString(), sharding_specs[i]->shape.value(),
+              {}, GetVirtualDevice().toString(), sharding_specs[i]->shape,
              sharding_specs[i]->sharding));
 
       // Register the sharded data placeholder to the tensor and its node.
@@ -646,10 +681,22 @@ void ShardingUtil::PrepareOutputShardingPropagation(
 
 runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
     std::vector<at::Tensor>& local_shards, std::vector<std::string>& devices,
-    xla::Shape global_shape, xla::OpSharding sharding) {
+    const XLATensor::ShardingSpecPtr& sharding_spec) {
   XLA_CHECK(local_shards.size() == devices.size())
      << "A device must be speficied for each shard";
   std::vector<runtime::ComputationClient::TensorSource> source_tensors;
+  xla::Shape global_shape;
+  xla::OpSharding sharding;
+  if (sharding_spec == nullptr) {
+    // If the sharding is replicated, the global shape is the shape of the tensor.
+    auto first_device = ParseDeviceString(devices[0]);
+    global_shape =
+        CreateComputationShapeFromTensor(local_shards[0], &first_device);
+    sharding = xla::HloSharding::Replicate().ToProto();
+  } else {
+    global_shape = sharding_spec->shape;
+    sharding = sharding_spec->sharding;
+  }
   for (int64_t j = 0; j < devices.size(); ++j) {
     auto shard_device = ParseDeviceString(devices[j]);
     auto shard_shape =
diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h
index f335085656f..b22cc7594f5 100644
--- a/torch_xla/csrc/xla_sharding_util.h
+++ b/torch_xla/csrc/xla_sharding_util.h
@@ -87,8 +87,8 @@ class ShardingUtil {
   // Returns the shape of the resulting shards of `tensor` after applying
   // `sharding`. This assumes the shards will be padded to ensure they all
   // have the same shape.
-  static std::vector<int64_t> GetShardShape(const at::Tensor& tensor,
-                                            const xla::OpSharding sharding);
+  static std::vector<int64_t> GetShardShape(
+      const XLATensor::ShardingSpecPtr shardings);
 
   // Uses the provided `sharding` spec and expected shard shape to determine the
   // index slices for the shards which belong on `devices`. Only supports
@@ -99,6 +99,12 @@
                             const xla::OpSharding sharding,
                             const std::vector<std::string>& devices);
 
+  // Returns the indices for the shards. Supports `OTHER` sharding types and is
+  // called when the input is sharded along the batch axis.
+  static std::vector<std::vector<at::indexing::TensorIndex>>
+  GetShardIndicesForMinibatchTensor(const std::vector<int64_t>& shard_shape,
+                                    const std::vector<std::string>& devices);
+
   // Shards a tensor and returns the sharded tensors which belong on `devices`
   // based on the `sharding` spec. REPLICATED sharding should result in shards
   // identical to the input; OTHERS (tiled) sharding result in shards where
@@ -109,7 +115,7 @@
   // The the returned tensors will be in 1:1 correspondence with the `devices`
   // vector, so the `i`th result will belong on the `i`th device.
   static std::vector<at::Tensor> ShardTensor(
-      const at::Tensor& tensor, const xla::OpSharding sharding,
+      const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings,
       const std::vector<std::string>& devices, bool padded = true);
 
   // Prepares output sharding propagation by extracting output parameter
@@ -139,7 +145,7 @@
   // the PjRtShardedData wrapping the shards.
   static runtime::ComputationClient::DataPtr CreateShardedData(
       std::vector<at::Tensor>& shards, std::vector<std::string>& devices,
-      xla::Shape global_shape, xla::OpSharding sharding);
+      const XLATensor::ShardingSpecPtr& sharding_spec);
 };
 
 }  // namespace torch_xla
diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py
index 43151653ae8..0ac313c34ad 100644
--- a/torch_xla/experimental/xla_sharding.py
+++ b/torch_xla/experimental/xla_sharding.py
@@ -468,6 +468,7 @@ def wrap_if_sharded(x: Any) -> Any:
 class ShardingSpec:
   mesh: Mesh
   partition_spec: Tuple[Union[int, None]]
+  minibatch: Optional[bool] = False
 
   # Derived fields
   _tile_assignment: List[int] = field(init=False)
@@ -494,7 +495,8 @@ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
     return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment,
                                            self._group_assignment,
                                            self._replication_groups,
-                                           int(self._sharding_type))
+                                           int(self._sharding_type),
+                                           self.minibatch)
 
   def can_apply(self, t: torch.Tensor) -> bool:
     """
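
For illustration (not part of the diff): a minimal sketch of how the new `minibatch` flag could be used from Python, assuming 8 global devices laid out only along the batch axis; the mesh layout, tensor sizes, and device counts below are assumptions, not something this patch prescribes.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# Assumed topology: 8 global devices, mesh spread only over the batch axis.
num_devices = 8
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1, 1))

# Shard dim 0 (batch) across the mesh; minibatch=True declares that each host
# only holds its local slice of the global batch.
spec = xs.ShardingSpec(mesh, (0, None, None), minibatch=True)

t = torch.ones(8, 7, 4, device=xm.xla_device())
if spec.can_apply(t):
  xla_spec = spec.xla_spec(t)  # forwards minibatch into _XLAC.XlaShardingSpec

The sketch only exercises the `ShardingSpec` dataclass and `xla_spec` path shown in this diff; how the resulting spec is attached to input tensors (for example by a data loader) is outside the scope of this patch.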