From 99e33856dd6d0793738595d7169a7d5f131ab352 Mon Sep 17 00:00:00 2001
From: Mohit Khatwani
Date: Fri, 21 Jul 2023 19:08:12 +0000
Subject: [PATCH 1/9] batch sharded

---
 torch_xla/csrc/init_python_bindings.cpp |  11 ++-
 torch_xla/csrc/tensor.h                 |   5 ++
 torch_xla/csrc/tensor_util.cpp          |  11 ++-
 torch_xla/csrc/xla_sharding_util.cpp    | 106 +++++++++++++++---------
 torch_xla/csrc/xla_sharding_util.h      |  10 ++-
 torch_xla/experimental/xla_sharding.py  |   4 +-
 6 files changed, 96 insertions(+), 51 deletions(-)

diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index a7b4cb1ebca..e004f538e54 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -790,12 +790,13 @@ void InitXlaModuleBindings(py::module m) {
       m, "XlaShardingSpec")
       .def(py::init([](at::Tensor tensor, const py::list& tile_assignment,
                        const py::list& group_assignment,
-                       const py::list& replication_groups, int sharding_type) {
+                       const py::list& replication_groups, int sharding_type,
+                       bool minibatch) {
         return std::make_shared(
             ShardingUtil::CreateOpSharding(
                 tile_assignment, group_assignment, replication_groups,
                 ShardingUtil::ShardingType(sharding_type)),
-            CreateComputationShapeFromTensor(tensor, nullptr));
+            CreateComputationShapeFromTensor(tensor, nullptr), minibatch);
       }));
   m.def("_xla_tensors_from_aten",
         [](const std::vector& tensors,
@@ -1492,7 +1493,8 @@ void InitXlaModuleBindings(py::module m) {
         }
 
         auto sharding = xtensor->sharding_spec()->sharding;
-        auto shard_shape = ShardingUtil::GetShardShape(input, sharding);
+        auto shard_shape =
+            ShardingUtil::GetShardShape(input.sizes().vec(), sharding);
         auto indices = ShardingUtil::GetShardIndicesForDevices(
             shard_shape, input.sizes().vec(), sharding, shard_devices);
 
@@ -1535,7 +1537,8 @@ void InitXlaModuleBindings(py::module m) {
        XLA_CHECK(sharding.type() != xla::OpSharding::REPLICATED)
            << "Replicated tensor should not be loaded from _load_local_shards - "
               "use copy_";
-       auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+       auto shard_shape =
+           ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding);
        for (auto shard : shards) {
          XLA_CHECK(shard.sizes() == shard_shape)
              << "Input shard shape must include padding: " << shard.sizes()
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index cfac3712e21..562fb5945d7 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -256,10 +256,15 @@ class XLATensor : public torch::lazy::LazyTensor {
     ShardingSpec(const xla::OpSharding& sharding) : sharding(sharding) {}
     ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape)
         : sharding(sharding), shape(shape) {}
+    ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape,
+                 const bool& minibatch)
+        : sharding(sharding), shape(shape), minibatch(minibatch) {}
 
     xla::OpSharding sharding;
     // Optional source tensor shape unpartitioned.
     std::optional shape;
+    // Indicates whether the input batch is sharded along the batch axis.
+    bool minibatch = false;
   };
 
   // Annotate the IR value with ShardingSpec.
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index 3650962cf16..2153d574ccc 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -942,8 +942,10 @@ std::vector CreateTensorsData(
       std::vector local_devices =
           runtime::GetComputationClient()->GetLocalDevices();
       xla::OpSharding sharding;
+      bool minibatch;
       if (shardings[i] != nullptr) {
         sharding = shardings[i]->sharding;
+        minibatch = shardings[i]->minibatch;
       } else {
         // If using SPMD and no sharding is attached to the tensor, implicitly
         // replicate to all local devices.
@@ -952,8 +954,13 @@ std::vector CreateTensorsData(
       // Shards the input tensors with padding, to split evenly.
       // The execution requires consistent shard sizes, and the zero-padded
       // values should be ignored.
-      std::vector local_shards = ShardingUtil::ShardTensor(
-          tensors[i], sharding, local_devices, /*padded=*/true);
+      std::vector local_shards =
+          ShardingUtil::ShardTensor(tensors[i], sharding, local_devices,
+                                    /*padded=*/true, /*minibatch=*/minibatch);
+      if (minibatch) {  // change global shape as tensor is already sharded
+                        // across the batch dimension.
+        continue;
+      }
       new_handles.push_back(ShardingUtil::CreateShardedData(
           local_shards, local_devices, shape, sharding));
     } else {
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 4f431e16795..bba16540f48 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -373,9 +373,9 @@ std::vector ShardingUtil::OutputHandler(
 }
 
 std::vector ShardingUtil::GetShardShape(
-    const at::Tensor& tensor, const xla::OpSharding sharding) {
+    const std::vector& tensor_shape, const xla::OpSharding sharding) {
   if (sharding.type() == xla::OpSharding::REPLICATED) {
-    return tensor.sizes().vec();
+    return tensor_shape;
   } else if (sharding.type() == xla::OpSharding::OTHER) {
     auto tile_shape = sharding.tile_assignment_dimensions();
 
@@ -385,9 +385,10 @@ std::vector ShardingUtil::GetShardShape(
       if (sharding.replicate_on_last_tile_dim() &&
           j == tile_shape.size() - 1) {
         continue;
       }
-      shard_shape.push_back(tensor.sizes()[j] / tile_shape[j] +
-                            (tensor.sizes()[j] % tile_shape[j] != 0));
+      shard_shape.push_back(tensor_shape[j] / tile_shape[j] +
+                            (tensor_shape[j] % tile_shape[j] != 0));
     }
+    return shard_shape;
   } else {
     TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
@@ -398,7 +399,7 @@ std::vector>
 ShardingUtil::GetShardIndicesForDevices(
     const std::vector& shard_shape,
     const std::vector& tensor_shape, const xla::OpSharding sharding,
-    const std::vector& devices) {
+    const std::vector& devices, const bool minibatch) {
   // `shard_indices[dev][dim]` represents the index slice for dimension `dim`
   // that belongs on device `devices[dev]` if the tensor is sharded.
If // `sharding` is REPLICATED, `shard_indices[dev]` will only have a single @@ -414,7 +415,27 @@ ShardingUtil::GetShardIndicesForDevices( std::fill_n(shard_indices.begin(), shard_indices.size(), indices); } else if (sharding.type() == xla::OpSharding::OTHER) { auto device_index = build_index_map(devices); - std::vector tile_assignment_devices( + if (minibatch) { + // shard tensor local to host + int start = 0; + for (int i = 0; i < devices.size(); i++) { + std::vector indices; + for (int j = tile_shape.size() - 1; j >= 0; j--) { + if (sharding.replicate_on_last_tile_dim() && + j == tile_shape.size() - 1) { + continue; + } + auto slice = at::indexing::Slice(0, shard_shape[j]); + if (j == 0) { // batch axis + slice = at::indexing::Slice(start, start + shard_shape[j]); + } + indices.push_back(slice); + } + std::reverse(indices.begin(), indices.end()); + shard_indices[i] = indices; + } + } else { + std::vector tile_assignment_devices( sharding.tile_assignment_devices().begin(), sharding.tile_assignment_devices().end()); if (!sharding.iota_reshape_dims().empty()) { @@ -425,41 +446,43 @@ ShardingUtil::GetShardIndicesForDevices( tileAssignment.array().begin(), tileAssignment.array().end()); } for (size_t i = 0; i < tile_assignment_devices.size(); i++) { - int64_t core = tile_assignment_devices[i]; - if (device_index.find(core) == device_index.end()) { - // Skip any shards whose device is not part of the `devices` list. - continue; - } + int64_t core = tile_assignment_devices[i]; + if (device_index.find(core) == device_index.end()) { + // Skip any shards whose device is not part of the `devices` list. + continue; + } - // Given the shard's row-major index `i`, we need to calculate shard's - // coordinates (n_0, ..., n_d) in the tiling to generate the index slices. - // Using `N_j = tile_shape[j]` and `0 <= n_j < N_j`, the following - // equation needs to be solved for all n_j: - // `i = n_d + N_d * (n_{d-1} + N_{d-1} * (... + (N_1 * n_0)))` - // Let `offset_j = n_j + N_j * (n_{j-1} + N_{j-1} * (... + (N_1 * n_0)))`. - // Then `offset_d = i`, `n_j = offset_j % N_j`, and `offset_{j-1} = - // offset_j / N_j`. - int offset = i; - std::vector indices; - for (int j = tile_shape.size() - 1; j >= 0; j--) { - if (sharding.replicate_on_last_tile_dim() && - j == tile_shape.size() - 1) { - // the last tile assignment dimension is replicated, which implies - // that the consecutive `tile_shape[j]` devices hold the replicated. + // Given the shard's row-major index `i`, we need to calculate shard's + // coordinates (n_0, ..., n_d) in the tiling to generate the index + // slices. Using `N_j = tile_shape[j]` and `0 <= n_j < N_j`, the + // following equation needs to be solved for all n_j: + // `i = n_d + N_d * (n_{d-1} + N_{d-1} * (... + (N_1 * + // n_0)))` + // Let `offset_j = n_j + N_j * (n_{j-1} + N_{j-1} * (... + (N_1 * + // n_0)))`. Then `offset_d = i`, `n_j = offset_j % N_j`, and + // `offset_{j-1} = offset_j / N_j`. + int offset = i; + std::vector indices; + for (int j = tile_shape.size() - 1; j >= 0; j--) { + if (sharding.replicate_on_last_tile_dim() && + j == tile_shape.size() - 1) { + // the last tile assignment dimension is replicated, which implies + // that the consecutive `tile_shape[j]` devices hold the replicated. + offset /= tile_shape[j]; + continue; + } + int64_t n_j = offset % tile_shape[j]; + // Clamp the slice bounds to the tensor shape to accurately reflect + // the shard size without padding. 
+ int start = std::min(n_j * shard_shape[j], tensor_shape[j]); + int end = std::min((n_j + 1) * shard_shape[j], tensor_shape[j]); + auto slice = at::indexing::Slice(start, end); + indices.push_back(at::indexing::TensorIndex(slice)); offset /= tile_shape[j]; - continue; } - int64_t n_j = offset % tile_shape[j]; - // Clamp the slice bounds to the tensor shape to accurately reflect - // the shard size without padding. - int start = std::min(n_j * shard_shape[j], tensor_shape[j]); - int end = std::min((n_j + 1) * shard_shape[j], tensor_shape[j]); - auto slice = at::indexing::Slice(start, end); - indices.push_back(at::indexing::TensorIndex(slice)); - offset /= tile_shape[j]; + std::reverse(indices.begin(), indices.end()); + shard_indices[device_index[core]] = indices; } - std::reverse(indices.begin(), indices.end()); - shard_indices[device_index[core]] = indices; } } else { TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type(); @@ -469,7 +492,7 @@ ShardingUtil::GetShardIndicesForDevices( std::vector ShardingUtil::ShardTensor( const at::Tensor& tensor, const xla::OpSharding sharding, - const std::vector& devices, bool padded) { + const std::vector& devices, bool padded, bool minibatch) { TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() << ")..." << std::endl; auto device_index = build_index_map(devices); @@ -480,9 +503,12 @@ std::vector ShardingUtil::ShardTensor( XLA_CHECK(sharding.tile_shape().dimensions_size() <= 2); XLA_CHECK(tensor.sizes().size() >= sharding.tile_shape().dimensions_size()); - auto shard_shape = GetShardShape(tensor, sharding); + auto shard_shape = GetShardShape(tensor.sizes().vec(), sharding); + if (minibatch) { + shard_shape[0] = tensor.sizes().vec()[0] / devices.size(); + } auto shard_indices = GetShardIndicesForDevices( - shard_shape, tensor.sizes().vec(), sharding, devices); + shard_shape, tensor.sizes().vec(), sharding, devices, minibatch); for (size_t i = 0; i < shard_indices.size(); i++) { at::Tensor shard = tensor.index( diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index f335085656f..fb50fafab74 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -87,8 +87,8 @@ class ShardingUtil { // Returns the shape of the resulting shards of `tensor` after applying // `sharding`. This assumes the shards will be padded to ensure they all // have the same shape. - static std::vector GetShardShape(const at::Tensor& tensor, - const xla::OpSharding sharding); + static std::vector GetShardShape( + const std::vector& tensor_shape, const xla::OpSharding sharding); // Uses the provided `sharding` spec and expected shard shape to determine the // index slices for the shards which belong on `devices`. Only supports @@ -97,7 +97,8 @@ class ShardingUtil { GetShardIndicesForDevices(const std::vector& shard_shape, const std::vector& tensor_shape, const xla::OpSharding sharding, - const std::vector& devices); + const std::vector& devices, + const bool minibatch = false); // Shards a tensor and returns the sharded tensors which belong on `devices` // based on the `sharding` spec. REPLICATED sharding should result in shards @@ -110,7 +111,8 @@ class ShardingUtil { // vector, so the `i`th result will belong on the `i`th device. 
static std::vector ShardTensor( const at::Tensor& tensor, const xla::OpSharding sharding, - const std::vector& devices, bool padded = true); + const std::vector& devices, bool padded = true, + bool minibatch = false); // Prepares output sharding propagation by extracting output parameter // ShardingSpec into `sharding_specs` from the SPMD compiled `computation` and diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 43151653ae8..0ac313c34ad 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -468,6 +468,7 @@ def wrap_if_sharded(x: Any) -> Any: class ShardingSpec: mesh: Mesh partition_spec: Tuple[Union[int, None]] + minibatch: Optional[bool] = False # Derived fields _tile_assignment: List[int] = field(init=False) @@ -494,7 +495,8 @@ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment, self._group_assignment, self._replication_groups, - int(self._sharding_type)) + int(self._sharding_type), + self.minibatch) def can_apply(self, t: torch.Tensor) -> bool: """ From 1ac26a3f9e56d5c7e672ccfa4dc55b3648d7c518 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Mon, 24 Jul 2023 23:18:03 +0000 Subject: [PATCH 2/9] change get_indices logic --- torch_xla/csrc/init_python_bindings.cpp | 15 ++- torch_xla/csrc/tensor_util.cpp | 4 +- torch_xla/csrc/xla_sharding_util.cpp | 130 +++++++++++++----------- torch_xla/csrc/xla_sharding_util.h | 11 +- 4 files changed, 94 insertions(+), 66 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index e004f538e54..f497f906e14 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -792,11 +792,24 @@ void InitXlaModuleBindings(py::module m) { const py::list& group_assignment, const py::list& replication_groups, int sharding_type, bool minibatch) { + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, nullptr); + int num_local_devices = + runtime::GetComputationClient()->GetLocalDevices().size(); + int num_global_devices = + runtime::GetComputationClient()->GetAllDevices().size(); + if (minibatch) { + XLA_CHECK(tile_assignment.size() == num_global_devices) + << "Sharding of input is only supported along batch dimension"; + } + int batch_dim_shape = + tensor.sizes()[0] * num_global_devices / num_local_devices; + tensor_shape.set_dimensions(0, batch_dim_shape); return std::make_shared( ShardingUtil::CreateOpSharding( tile_assignment, group_assignment, replication_groups, ShardingUtil::ShardingType(sharding_type)), - CreateComputationShapeFromTensor(tensor, nullptr), minibatch); + tensor_shape, minibatch); })); m.def("_xla_tensors_from_aten", [](const std::vector& tensors, diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 2153d574ccc..ea7e0b0d714 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -942,7 +942,7 @@ std::vector CreateTensorsData( std::vector local_devices = runtime::GetComputationClient()->GetLocalDevices(); xla::OpSharding sharding; - bool minibatch; + bool minibatch = false; if (shardings[i] != nullptr) { sharding = shardings[i]->sharding; minibatch = shardings[i]->minibatch; @@ -959,7 +959,7 @@ std::vector CreateTensorsData( /*padded=*/true, /*minibatch=*/minibatch); if (minibatch) { // change global shape as tensor is already sharded // accross batch dimesion. 
- continue; + shape = shardings[i]->shape.value(); } new_handles.push_back(ShardingUtil::CreateShardedData( local_shards, local_devices, shape, sharding)); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index bba16540f48..8c6b57b3cef 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -373,9 +373,9 @@ std::vector ShardingUtil::OutputHandler( } std::vector ShardingUtil::GetShardShape( - const std::vector& tensor_shape, const xla::OpSharding sharding) { + const std::vector& global_shape, const xla::OpSharding sharding) { if (sharding.type() == xla::OpSharding::REPLICATED) { - return tensor_shape; + return global_shape; } else if (sharding.type() == xla::OpSharding::OTHER) { auto tile_shape = sharding.tile_assignment_dimensions(); @@ -385,8 +385,8 @@ std::vector ShardingUtil::GetShardShape( if (sharding.replicate_on_last_tile_dim() && j == tile_shape.size() - 1) { continue; } - shard_shape.push_back(tensor_shape[j] / tile_shape[j] + - (tensor_shape[j] % tile_shape[j] != 0)); + shard_shape.push_back(global_shape[j] / tile_shape[j] + + (global_shape[j] % tile_shape[j] != 0)); } return shard_shape; @@ -395,11 +395,34 @@ std::vector ShardingUtil::GetShardShape( } } +std::vector> +ShardingUtil::GetShardIndicesForBatchShardedTensor( + const std::vector& shard_shape, + const std::vector& tensor_shape, const xla::OpSharding sharding, + const std::vector& devices) { + std::vector> shard_indices( + devices.size()); + if (sharding.type() == xla::OpSharding::OTHER) { + for (int i = 0; i < devices.size(); i++) { + std::vector indices; + for (int j = 0; j < tensor_shape.size(); j++) { + indices.push_back(at::indexing::Slice(0, tensor_shape[j])); + } + indices[0] = + at::indexing::Slice(i * shard_shape[0], (i + 1) * shard_shape[0]); + shard_indices[i] = indices; + } + } else { + TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type(); + } + return shard_indices; +} + std::vector> ShardingUtil::GetShardIndicesForDevices( const std::vector& shard_shape, const std::vector& tensor_shape, const xla::OpSharding sharding, - const std::vector& devices, const bool minibatch) { + const std::vector& devices) { // `shard_indices[dev][dim]` represents the index slice for dimension `dim` // that belongs on device `devices[dev]` if the tensor is sharded. 
If // `sharding` is REPLICATED, `shard_indices[dev]` will only have a single @@ -415,27 +438,7 @@ ShardingUtil::GetShardIndicesForDevices( std::fill_n(shard_indices.begin(), shard_indices.size(), indices); } else if (sharding.type() == xla::OpSharding::OTHER) { auto device_index = build_index_map(devices); - if (minibatch) { - // shard tensor local to host - int start = 0; - for (int i = 0; i < devices.size(); i++) { - std::vector indices; - for (int j = tile_shape.size() - 1; j >= 0; j--) { - if (sharding.replicate_on_last_tile_dim() && - j == tile_shape.size() - 1) { - continue; - } - auto slice = at::indexing::Slice(0, shard_shape[j]); - if (j == 0) { // batch axis - slice = at::indexing::Slice(start, start + shard_shape[j]); - } - indices.push_back(slice); - } - std::reverse(indices.begin(), indices.end()); - shard_indices[i] = indices; - } - } else { - std::vector tile_assignment_devices( + std::vector tile_assignment_devices( sharding.tile_assignment_devices().begin(), sharding.tile_assignment_devices().end()); if (!sharding.iota_reshape_dims().empty()) { @@ -446,43 +449,42 @@ ShardingUtil::GetShardIndicesForDevices( tileAssignment.array().begin(), tileAssignment.array().end()); } for (size_t i = 0; i < tile_assignment_devices.size(); i++) { - int64_t core = tile_assignment_devices[i]; - if (device_index.find(core) == device_index.end()) { - // Skip any shards whose device is not part of the `devices` list. - continue; - } + int64_t core = tile_assignment_devices[i]; + if (device_index.find(core) == device_index.end()) { + // Skip any shards whose device is not part of the `devices` list. + continue; + } - // Given the shard's row-major index `i`, we need to calculate shard's - // coordinates (n_0, ..., n_d) in the tiling to generate the index - // slices. Using `N_j = tile_shape[j]` and `0 <= n_j < N_j`, the - // following equation needs to be solved for all n_j: - // `i = n_d + N_d * (n_{d-1} + N_{d-1} * (... + (N_1 * - // n_0)))` - // Let `offset_j = n_j + N_j * (n_{j-1} + N_{j-1} * (... + (N_1 * - // n_0)))`. Then `offset_d = i`, `n_j = offset_j % N_j`, and - // `offset_{j-1} = offset_j / N_j`. - int offset = i; - std::vector indices; - for (int j = tile_shape.size() - 1; j >= 0; j--) { - if (sharding.replicate_on_last_tile_dim() && - j == tile_shape.size() - 1) { - // the last tile assignment dimension is replicated, which implies - // that the consecutive `tile_shape[j]` devices hold the replicated. - offset /= tile_shape[j]; - continue; - } - int64_t n_j = offset % tile_shape[j]; - // Clamp the slice bounds to the tensor shape to accurately reflect - // the shard size without padding. - int start = std::min(n_j * shard_shape[j], tensor_shape[j]); - int end = std::min((n_j + 1) * shard_shape[j], tensor_shape[j]); - auto slice = at::indexing::Slice(start, end); - indices.push_back(at::indexing::TensorIndex(slice)); + // Given the shard's row-major index `i`, we need to calculate shard's + // coordinates (n_0, ..., n_d) in the tiling to generate the index + // slices. Using `N_j = tile_shape[j]` and `0 <= n_j < N_j`, the + // following equation needs to be solved for all n_j: + // `i = n_d + N_d * (n_{d-1} + N_{d-1} * (... + (N_1 * + // n_0)))` + // Let `offset_j = n_j + N_j * (n_{j-1} + N_{j-1} * (... + (N_1 * + // n_0)))`. Then `offset_d = i`, `n_j = offset_j % N_j`, and + // `offset_{j-1} = offset_j / N_j`. 
+ int offset = i; + std::vector indices; + for (int j = tile_shape.size() - 1; j >= 0; j--) { + if (sharding.replicate_on_last_tile_dim() && + j == tile_shape.size() - 1) { + // the last tile assignment dimension is replicated, which implies + // that the consecutive `tile_shape[j]` devices hold the replicated. offset /= tile_shape[j]; + continue; } - std::reverse(indices.begin(), indices.end()); - shard_indices[device_index[core]] = indices; + int64_t n_j = offset % tile_shape[j]; + // Clamp the slice bounds to the tensor shape to accurately reflect + // the shard size without padding. + int start = std::min(n_j * shard_shape[j], tensor_shape[j]); + int end = std::min((n_j + 1) * shard_shape[j], tensor_shape[j]); + auto slice = at::indexing::Slice(start, end); + indices.push_back(at::indexing::TensorIndex(slice)); + offset /= tile_shape[j]; } + std::reverse(indices.begin(), indices.end()); + shard_indices[device_index[core]] = indices; } } else { TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type(); @@ -507,8 +509,14 @@ std::vector ShardingUtil::ShardTensor( if (minibatch) { shard_shape[0] = tensor.sizes().vec()[0] / devices.size(); } - auto shard_indices = GetShardIndicesForDevices( - shard_shape, tensor.sizes().vec(), sharding, devices, minibatch); + std::vector> shard_indices; + if (minibatch) { + shard_indices = GetShardIndicesForBatchShardedTensor( + shard_shape, tensor.sizes().vec(), sharding, devices); + } else { + shard_indices = GetShardIndicesForDevices( + shard_shape, tensor.sizes().vec(), sharding, devices); + } for (size_t i = 0; i < shard_indices.size(); i++) { at::Tensor shard = tensor.index( diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index fb50fafab74..c2098769349 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -97,8 +97,15 @@ class ShardingUtil { GetShardIndicesForDevices(const std::vector& shard_shape, const std::vector& tensor_shape, const xla::OpSharding sharding, - const std::vector& devices, - const bool minibatch = false); + const std::vector& devices); + + // Returns the indices for the shards. Supports `OTHER` sharding types and + // called when input is sharded along the batch axis. + static std::vector> + GetShardIndicesForBatchShardedTensor(const std::vector& shard_shape, + const std::vector& tensor_shape, + const xla::OpSharding sharding, + const std::vector& devices); // Shards a tensor and returns the sharded tensors which belong on `devices` // based on the `sharding` spec. 
REPLICATED sharding should result in shards From 7688be40a13fe6c3c2e2c57fb4d2237f0a5b868f Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Tue, 25 Jul 2023 19:29:34 +0000 Subject: [PATCH 3/9] fix tests --- test/cpp/test_xla_sharding.cpp | 10 ++++++---- torch_xla/csrc/init_python_bindings.cpp | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 6397d17f837..c4c610dedab 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -41,12 +41,13 @@ TEST_F(XLAShardingTest, GetShardShape) { {2, 3}, }); auto sharding = xla::HloSharding::Tile(mesh).ToProto(); - auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding); + auto shard_shape = + ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding); // For tiled sharding, each dimension should be halved EXPECT_EQ(shard_shape, std::vector({4, 4})); sharding = xla::HloSharding::Replicate().ToProto(); - shard_shape = ShardingUtil::GetShardShape(tensor, sharding); + shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding); // For replicated sharding, each dimension should be preserved EXPECT_EQ(shard_shape, std::vector({8, 7})); } @@ -60,7 +61,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { {2, 3}, }); auto sharding = xla::HloSharding::Tile(mesh).ToProto(); - auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding); + auto shard_shape = + ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding); auto shard_indices = ShardingUtil::GetShardIndicesForDevices( shard_shape, tensor.sizes().vec(), sharding, devices); EXPECT_EQ(shard_indices.size(), devices.size()); @@ -84,7 +86,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { } sharding = xla::HloSharding::Replicate().ToProto(); - shard_shape = ShardingUtil::GetShardShape(tensor, sharding); + shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding); shard_indices = ShardingUtil::GetShardIndicesForDevices( shard_shape, tensor.sizes().vec(), sharding, devices); EXPECT_EQ(shard_indices.size(), devices.size()); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index f497f906e14..eec48f30132 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -801,10 +801,10 @@ void InitXlaModuleBindings(py::module m) { if (minibatch) { XLA_CHECK(tile_assignment.size() == num_global_devices) << "Sharding of input is only supported along batch dimension"; + int batch_dim_shape = + tensor.sizes()[0] * num_global_devices / num_local_devices; + tensor_shape.set_dimensions(0, batch_dim_shape); } - int batch_dim_shape = - tensor.sizes()[0] * num_global_devices / num_local_devices; - tensor_shape.set_dimensions(0, batch_dim_shape); return std::make_shared( ShardingUtil::CreateOpSharding( tile_assignment, group_assignment, replication_groups, From 93aaa550aa2bea3088eb931b383c3498f64ea088 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Wed, 26 Jul 2023 19:01:57 +0000 Subject: [PATCH 4/9] accept sharding_spec as argument --- test/cpp/test_xla_sharding.cpp | 15 ++++++++++ torch_xla/csrc/init_python_bindings.cpp | 20 ++++++------- torch_xla/csrc/tensor.h | 3 +- torch_xla/csrc/tensor_util.cpp | 6 ++-- torch_xla/csrc/xla_sharding_util.cpp | 39 ++++++++++++++----------- torch_xla/csrc/xla_sharding_util.h | 15 +++++----- 6 files changed, 58 insertions(+), 40 deletions(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 
c4c610dedab..c0af062d8a1 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -96,6 +96,21 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { } } +// TEST_F(XLAShardingTest, GetShardIndicesForBatchShardedTensor) { +// std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"}; + +// auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); +// xla::Array2D mesh({ +// {0}, +// {1}, +// {2}, +// {3}, +// }); +// auto sharding = xla::HloSharding::Tile(mesh).ToProto(); +// auto shard_shape = +// ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding); +// } + TEST_F(XLAShardingTest, ShardTensor) { std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5", "TPU:6", "TPU:7"}; diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index eec48f30132..4c843625efb 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -794,13 +794,14 @@ void InitXlaModuleBindings(py::module m) { bool minibatch) { xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); - int num_local_devices = - runtime::GetComputationClient()->GetLocalDevices().size(); - int num_global_devices = - runtime::GetComputationClient()->GetAllDevices().size(); if (minibatch) { + int num_local_devices = + runtime::GetComputationClient()->GetLocalDevices().size(); + int num_global_devices = + runtime::GetComputationClient()->GetAllDevices().size(); XLA_CHECK(tile_assignment.size() == num_global_devices) - << "Sharding of input is only supported along batch dimension"; + << "Minibatch sharding only supports sharding along the batch " + "dimension"; int batch_dim_shape = tensor.sizes()[0] * num_global_devices / num_local_devices; tensor_shape.set_dimensions(0, batch_dim_shape); @@ -1504,10 +1505,9 @@ void InitXlaModuleBindings(py::module m) { for (auto& shard : shards) { shard_devices.push_back(shard->device()); } - + auto sharding_spec = xtensor->sharding_spec(); auto sharding = xtensor->sharding_spec()->sharding; - auto shard_shape = - ShardingUtil::GetShardShape(input.sizes().vec(), sharding); + auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); auto indices = ShardingUtil::GetShardIndicesForDevices( shard_shape, input.sizes().vec(), sharding, shard_devices); @@ -1547,11 +1547,11 @@ void InitXlaModuleBindings(py::module m) { runtime::GetComputationClient()->GetLocalDevices().size()) << "Shards must be provided for all local devices"; auto sharding = xtensor->sharding_spec()->sharding; + auto sharding_spec = xtensor->sharding_spec(); XLA_CHECK(sharding.type() != xla::OpSharding::REPLICATED) << "Replicated tensor should not be loaded from _load_local_shards - " "use copy_"; - auto shard_shape = - ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding); + auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); for (auto shard : shards) { XLA_CHECK(shard.sizes() == shard_shape) << "Input shard shape must include padding: " << shard.sizes() diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 562fb5945d7..f2d12b2325f 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -253,7 +253,6 @@ class XLATensor : public torch::lazy::LazyTensor { // XLA SPMD sharding spec annoation. The XLA tensor uses this to create // HloSharding for replication, manual and tile shardings. 
struct ShardingSpec { - ShardingSpec(const xla::OpSharding& sharding) : sharding(sharding) {} ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape) : sharding(sharding), shape(shape) {} ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape, @@ -262,7 +261,7 @@ class XLATensor : public torch::lazy::LazyTensor { xla::OpSharding sharding; // Optional source tensor shape unpartitioned. - std::optional shape; + xla::Shape shape; // Parameter for represent input batch in sharded along batch axes bool minibatch = false; }; diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index ea7e0b0d714..b1d7436e83f 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -955,11 +955,11 @@ std::vector CreateTensorsData( // The execution requires consistent shard sizes, and the zero-padded // values should be ignored. std::vector local_shards = - ShardingUtil::ShardTensor(tensors[i], sharding, local_devices, - /*padded=*/true, /*minibatch=*/minibatch); + ShardingUtil::ShardTensor(tensors[i], shardings[i], local_devices, + /*padded=*/true); if (minibatch) { // change global shape as tensor is already sharded // accross batch dimesion. - shape = shardings[i]->shape.value(); + shape = shardings[i]->shape; } new_handles.push_back(ShardingUtil::CreateShardedData( local_shards, local_devices, shape, sharding)); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 8c6b57b3cef..f46744d6b0a 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -7,6 +7,7 @@ #include "torch/csrc/lazy/core/ir_util.h" #include "torch_xla/csrc/device.h" +#include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/tensor.h" @@ -339,13 +340,13 @@ std::vector ShardingUtil::OutputHandler( XLATensor::ShardingSpecPtr sharding = sharding_specs[i]; if (replicated_output && sharding && (sharding->sharding.type() != xla::OpSharding::REPLICATED)) { - XLA_CHECK(sharding->shape.has_value()) - << "Sharding or Wrapping data shards in OutputHandler requires " - "unpartitioned tensor shape."; + // XLA_CHECK(sharding->shape.has_value()) + // << "Sharding or Wrapping data shards in OutputHandler requires " + // "unpartitioned tensor shape."; // Reshards replicated output if `sharding` is present. 
std::vector tensors = XlaDataToTensors( {WrapXlaData(sharded_results[0][i])}, - TensorTypeFromXlaType(sharding->shape.value().element_type())); + TensorTypeFromXlaType(sharding->shape.element_type())); outputs.push_back(UnwrapXlaData(CreateTensorsData( tensors, {sharding}, std::vector{GetVirtualDevice().toString()})[0])); @@ -365,7 +366,7 @@ std::vector ShardingUtil::OutputHandler( sharded_results[0][i]->shape()); } outputs.push_back(runtime::GetComputationClient()->WrapDataShards( - shards, GetVirtualDevice().toString(), sharding->shape.value(), + shards, GetVirtualDevice().toString(), sharding->shape, sharding->sharding)); } } @@ -373,7 +374,10 @@ std::vector ShardingUtil::OutputHandler( } std::vector ShardingUtil::GetShardShape( - const std::vector& global_shape, const xla::OpSharding sharding) { + const XLATensor::ShardingSpecPtr shardings) { + auto sharding = shardings->sharding; + auto global_shape = XlaHelpers::GetAllDimensions(shardings->shape); + TF_LOG(ERROR) << "Print global shape" << global_shape; if (sharding.type() == xla::OpSharding::REPLICATED) { return global_shape; } else if (sharding.type() == xla::OpSharding::OTHER) { @@ -396,7 +400,7 @@ std::vector ShardingUtil::GetShardShape( } std::vector> -ShardingUtil::GetShardIndicesForBatchShardedTensor( +ShardingUtil::GetShardIndicesForMinibatchTensor( const std::vector& shard_shape, const std::vector& tensor_shape, const xla::OpSharding sharding, const std::vector& devices) { @@ -493,8 +497,10 @@ ShardingUtil::GetShardIndicesForDevices( } std::vector ShardingUtil::ShardTensor( - const at::Tensor& tensor, const xla::OpSharding sharding, - const std::vector& devices, bool padded, bool minibatch) { + const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings, + const std::vector& devices, bool padded) { + auto sharding = shardings->sharding; + bool minibatch = shardings->minibatch; TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() << ")..." << std::endl; auto device_index = build_index_map(devices); @@ -505,13 +511,13 @@ std::vector ShardingUtil::ShardTensor( XLA_CHECK(sharding.tile_shape().dimensions_size() <= 2); XLA_CHECK(tensor.sizes().size() >= sharding.tile_shape().dimensions_size()); - auto shard_shape = GetShardShape(tensor.sizes().vec(), sharding); - if (minibatch) { - shard_shape[0] = tensor.sizes().vec()[0] / devices.size(); - } + auto shard_shape = GetShardShape(shardings); + // if (minibatch) { + // shard_shape[0] = tensor.sizes().vec()[0] / devices.size(); + // } std::vector> shard_indices; if (minibatch) { - shard_indices = GetShardIndicesForBatchShardedTensor( + shard_indices = GetShardIndicesForMinibatchTensor( shard_shape, tensor.sizes().vec(), sharding, devices); } else { shard_indices = GetShardIndicesForDevices( @@ -606,8 +612,7 @@ void ShardingUtil::PrepareOutputShardingPropagation( // replication. auto sharded_data_placeholder = WrapXlaData(runtime::GetComputationClient()->WrapDataShards( - {}, GetVirtualDevice().toString(), - (*sharding_specs)[i]->shape.value(), + {}, GetVirtualDevice().toString(), (*sharding_specs)[i]->shape, (*sharding_specs)[i]->sharding)); // Register the sharded data placeholder to the tensor and its node. @@ -670,7 +675,7 @@ void ShardingUtil::PrepareOutputShardingPropagation( // replication. 
auto sharded_data_placeholder = WrapXlaData(runtime::GetComputationClient()->WrapDataShards( - {}, GetVirtualDevice().toString(), sharding_specs[i]->shape.value(), + {}, GetVirtualDevice().toString(), sharding_specs[i]->shape, sharding_specs[i]->sharding)); // Register the sharded data placeholder to the tensor and its node. diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index c2098769349..7b7fa906a28 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -88,7 +88,7 @@ class ShardingUtil { // `sharding`. This assumes the shards will be padded to ensure they all // have the same shape. static std::vector GetShardShape( - const std::vector& tensor_shape, const xla::OpSharding sharding); + const XLATensor::ShardingSpecPtr shardings); // Uses the provided `sharding` spec and expected shard shape to determine the // index slices for the shards which belong on `devices`. Only supports @@ -102,10 +102,10 @@ class ShardingUtil { // Returns the indices for the shards. Supports `OTHER` sharding types and // called when input is sharded along the batch axis. static std::vector> - GetShardIndicesForBatchShardedTensor(const std::vector& shard_shape, - const std::vector& tensor_shape, - const xla::OpSharding sharding, - const std::vector& devices); + GetShardIndicesForMinibatchTensor(const std::vector& shard_shape, + const std::vector& tensor_shape, + const xla::OpSharding sharding, + const std::vector& devices); // Shards a tensor and returns the sharded tensors which belong on `devices` // based on the `sharding` spec. REPLICATED sharding should result in shards @@ -117,9 +117,8 @@ class ShardingUtil { // The the returned tensors will be in 1:1 correspondence with the `devices` // vector, so the `i`th result will belong on the `i`th device. 
static std::vector ShardTensor( - const at::Tensor& tensor, const xla::OpSharding sharding, - const std::vector& devices, bool padded = true, - bool minibatch = false); + const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings, + const std::vector& devices, bool padded = true); // Prepares output sharding propagation by extracting output parameter // ShardingSpec into `sharding_specs` from the SPMD compiled `computation` and From 2d21fff724d6d9d28a388427f5ea0bc2ab9f4fc2 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Wed, 26 Jul 2023 23:46:01 +0000 Subject: [PATCH 5/9] Add test --- test/cpp/test_xla_sharding.cpp | 105 +++++++++++++++--------- torch_xla/csrc/init_python_bindings.cpp | 4 +- torch_xla/csrc/xla_sharding_util.cpp | 13 ++- 3 files changed, 75 insertions(+), 47 deletions(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index c0af062d8a1..d1664729f5c 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -36,18 +36,22 @@ class XLAShardingTest : public AtenXlaTensorTestBase {}; TEST_F(XLAShardingTest, GetShardShape) { auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array2D mesh({ {0, 1}, {2, 3}, }); auto sharding = xla::HloSharding::Tile(mesh).ToProto(); + auto sharding_spec = std::make_shared(sharding, tensor_shape); + auto shard_shape = - ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding); + ShardingUtil::GetShardShape(sharding_spec); // For tiled sharding, each dimension should be halved EXPECT_EQ(shard_shape, std::vector({4, 4})); sharding = xla::HloSharding::Replicate().ToProto(); - shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding); + sharding_spec->sharding = sharding; + shard_shape = ShardingUtil::GetShardShape(sharding_spec); // For replicated sharding, each dimension should be preserved EXPECT_EQ(shard_shape, std::vector({8, 7})); } @@ -56,13 +60,15 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"}; auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); xla::Array2D mesh({ {0, 1}, {2, 3}, }); auto sharding = xla::HloSharding::Tile(mesh).ToProto(); + auto sharding_spec = std::make_shared(sharding, tensor_shape); auto shard_shape = - ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding); + ShardingUtil::GetShardShape(sharding_spec); auto shard_indices = ShardingUtil::GetShardIndicesForDevices( shard_shape, tensor.sizes().vec(), sharding, devices); EXPECT_EQ(shard_indices.size(), devices.size()); @@ -86,7 +92,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { } sharding = xla::HloSharding::Replicate().ToProto(); - shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding); + sharding_spec->sharding = sharding; + shard_shape = ShardingUtil::GetShardShape(sharding_spec); shard_indices = ShardingUtil::GetShardIndicesForDevices( shard_shape, tensor.sizes().vec(), sharding, devices); EXPECT_EQ(shard_indices.size(), devices.size()); @@ -96,20 +103,6 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { } } -// TEST_F(XLAShardingTest, GetShardIndicesForBatchShardedTensor) { -// std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"}; - -// auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); -// xla::Array2D mesh({ -// {0}, -// {1}, -// {2}, -// {3}, -// }); -// auto sharding 
= xla::HloSharding::Tile(mesh).ToProto(); -// auto shard_shape = -// ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding); -// } TEST_F(XLAShardingTest, ShardTensor) { std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3", @@ -117,13 +110,15 @@ TEST_F(XLAShardingTest, ShardTensor) { // 1D tiled at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); xla::OpSharding sharding = xla::HloSharding::Tile1D( CreateComputationShapeFromTensor(tensor, GetDefaultDevice()), devices.size()) .ToProto(); + auto sharding_spec = std::make_shared(sharding, tensor_shape); auto shards = - ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false); + ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({1})); EXPECT_EQ(shards[1].sizes(), c10::ArrayRef({1})); @@ -131,13 +126,15 @@ TEST_F(XLAShardingTest, ShardTensor) { // 2D tiled, The first dim is halved and the last replicated. The last shard // size should be smaller in dim=1 because it's not evenly divisible. tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); + tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); xla::Array2D mesh({ {0, 1, 2, 3}, {4, 5, 6, 7}, }); sharding = xla::HloSharding::Tile(mesh).ToProto(); + sharding_spec = std::make_shared(sharding, tensor_shape); shards = - ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false); + ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({4, 2, 4})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({4, 1, 4})); @@ -146,16 +143,18 @@ TEST_F(XLAShardingTest, ShardTensor) { // size should be smaller in dim=1 because it's not evenly divisible. xla::Array3D cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}); sharding = xla::HloSharding::Tile(cube).ToProto(); + sharding_spec->sharding = sharding; shards = - ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false); + ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({8, 2, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({8, 1, 2})); // Replicated, all shards should be identical. sharding = xla::HloSharding::Replicate().ToProto(); + sharding_spec->sharding = sharding; shards = - ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false); + ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({8, 7, 4})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({8, 7, 4})); @@ -164,17 +163,19 @@ TEST_F(XLAShardingTest, ShardTensor) { // last shard size should be smaller in dim=2 because it's not evenly // divisible. 
tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat)); + tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); xla::Array4D tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}}); sharding = xla::HloSharding::Tile(tesseract).ToProto(); + sharding_spec = std::make_shared(sharding, tensor_shape); shards = - ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false); + ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({1, 8, 2, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({1, 8, 1, 2})); // 4D tiled and padded, all shard sizes should be idential. shards = - ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/true); + ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/true); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({1, 8, 2, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({1, 8, 2, 2})); @@ -183,18 +184,20 @@ TEST_F(XLAShardingTest, ShardTensor) { // last shard size should be smaller in dim=2 because it's not evenly // divisible. tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat)); + tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); xla::Array hypercube(std::vector{1, 1, 2, 2, 2}); hypercube.FillIota(0); sharding = xla::HloSharding::Tile(hypercube).ToProto(); + sharding_spec = std::make_shared(sharding, tensor_shape); shards = - ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false); + ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({10, 1, 4, 3, 2})); // 5D tiled and padded, all shard sizes should be identical. shards = - ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/true); + ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/true); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); @@ -205,16 +208,17 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { // 2D tiled, The first dim is halved and the last replicated. at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); xla::Array2D mesh({ {4, 5, 0, 1}, {6, 7, 2, 3}, }); auto sharding = xla::HloSharding::Tile(mesh).ToProto(); - + auto sharding_spec = std::make_shared(sharding, tensor_shape); // For devices at the start of the mesh, all shards should have the same // unpadded shape. 
auto shards = - ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false); + ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 4); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({4, 2, 4})); EXPECT_EQ(shards[3].sizes(), c10::ArrayRef({4, 2, 4})); @@ -226,22 +230,45 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { {2, 3, 6, 7}, }); sharding = xla::HloSharding::Tile(mesh).ToProto(); + sharding_spec->sharding = sharding; shards = - ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false); + ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 4); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({4, 2, 4})); EXPECT_EQ(shards[3].sizes(), c10::ArrayRef({4, 1, 4})); } +TEST_F(XLAShardingTest, ShardTensorMiniBatch) { + std::vector devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"}; + at::Tensor minibatch_tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = CreateComputationShapeFromTensor(minibatch_tensor, nullptr); + tensor_shape.set_dimensions(0,minibatch_tensor.sizes()[0] * 2); // Assuming 2 hosts + xla::Array2D mesh({ + {0}, {1}, {2}, {3}, + {4}, {5}, {6}, {7}, + }); + + auto sharding = xla::HloSharding::Tile(mesh).ToProto(); + auto sharding_spec = std::make_shared(sharding, tensor_shape, /*minibatch=*/true); + auto shards = + ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec, devices, /*padded=*/false); + EXPECT_EQ(shards.size(), 4); + EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({2, 7, 4})); + EXPECT_EQ(shards[3].sizes(), c10::ArrayRef({2, 7, 4})); + +} + TEST_F(XLAShardingTest, EqualShardingSpecs) { + auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({ {0, 1, 2, 3}, {4, 5, 6, 7}, }) - .ToProto()); + .ToProto(), tensor_shape); XLATensor::ShardingSpec tiled_3d( - xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto()); - XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto()); + xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(), tensor_shape); + XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(), tensor_shape); EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d)); EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d)); EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated)); @@ -255,13 +282,14 @@ TEST_F(XLAShardingTest, CreateTensorsData) { } std::vector tensors(2); - std::fill_n(tensors.begin(), tensors.size(), - at::ones({8, 8}, at::TensorOptions(at::kFloat))); + auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + std::fill_n(tensors.begin(), tensors.size(), tensor); std::vector devices(2); std::fill_n(devices.begin(), devices.size(), GetDefaultDevice()->toString()); std::vector shardings = { nullptr, std::make_shared( - xla::HloSharding::Replicate().ToProto())}; + xla::HloSharding::Replicate().ToProto(), tensor_shape)}; std::vector tensors_data = CreateTensorsData(tensors, shardings, devices); @@ -303,12 +331,13 @@ TEST_F(XLAShardingTest, InputHandler) { } std::vector tensors(2); - std::fill_n(tensors.begin(), tensors.size(), - at::ones({8, 8}, at::TensorOptions(at::kFloat))); + auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = 
CreateComputationShapeFromTensor(tensor, nullptr); + std::fill_n(tensors.begin(), tensors.size(),tensor); std::vector devices = {"TPU:0", "TPU:1"}; std::vector shardings = { nullptr, std::make_shared( - xla::HloSharding::Replicate().ToProto())}; + xla::HloSharding::Replicate().ToProto(), tensor_shape)}; std::vector tensors_data = CreateTensorsData(tensors, shardings, devices); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 4c843625efb..e91f0644b76 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1554,8 +1554,8 @@ void InitXlaModuleBindings(py::module m) { auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); for (auto shard : shards) { XLA_CHECK(shard.sizes() == shard_shape) - << "Input shard shape must include padding: " << shard.sizes() - << " vs " << shard_shape; + << "Input shard shape must include padding: " << shard.sizes(); + // << " vs " << shard_shape; } auto xla_data = ShardingUtil::CreateShardedData(shards, devices, xtensor->shape(), sharding); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index f46744d6b0a..9401a9252e1 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -376,10 +376,11 @@ std::vector ShardingUtil::OutputHandler( std::vector ShardingUtil::GetShardShape( const XLATensor::ShardingSpecPtr shardings) { auto sharding = shardings->sharding; - auto global_shape = XlaHelpers::GetAllDimensions(shardings->shape); - TF_LOG(ERROR) << "Print global shape" << global_shape; + auto global_shape = shardings->shape.dimensions(); if (sharding.type() == xla::OpSharding::REPLICATED) { - return global_shape; + std::vector globalShape; + globalShape.assign(global_shape.begin(), global_shape.end()); + return globalShape; } else if (sharding.type() == xla::OpSharding::OTHER) { auto tile_shape = sharding.tile_assignment_dimensions(); @@ -501,7 +502,7 @@ std::vector ShardingUtil::ShardTensor( const std::vector& devices, bool padded) { auto sharding = shardings->sharding; bool minibatch = shardings->minibatch; - TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() << ")..." + TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() << ")... 
and minibatch = " << minibatch << std::endl; auto device_index = build_index_map(devices); std::vector shards(devices.size()); @@ -512,9 +513,7 @@ std::vector ShardingUtil::ShardTensor( XLA_CHECK(tensor.sizes().size() >= sharding.tile_shape().dimensions_size()); auto shard_shape = GetShardShape(shardings); - // if (minibatch) { - // shard_shape[0] = tensor.sizes().vec()[0] / devices.size(); - // } + std::vector> shard_indices; if (minibatch) { shard_indices = GetShardIndicesForMinibatchTensor( From 4fe4210af668cc252ca9944d5903071edf93f01d Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Wed, 26 Jul 2023 23:48:06 +0000 Subject: [PATCH 6/9] lint fix --- test/cpp/test_xla_sharding.cpp | 109 ++++++++++++++---------- torch_xla/csrc/init_python_bindings.cpp | 2 +- torch_xla/csrc/xla_sharding_util.cpp | 4 +- 3 files changed, 66 insertions(+), 49 deletions(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index d1664729f5c..9442436a38c 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -36,16 +36,17 @@ class XLAShardingTest : public AtenXlaTensorTestBase {}; TEST_F(XLAShardingTest, GetShardShape) { auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array2D mesh({ {0, 1}, {2, 3}, }); auto sharding = xla::HloSharding::Tile(mesh).ToProto(); - auto sharding_spec = std::make_shared(sharding, tensor_shape); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); - auto shard_shape = - ShardingUtil::GetShardShape(sharding_spec); + auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); // For tiled sharding, each dimension should be halved EXPECT_EQ(shard_shape, std::vector({4, 4})); @@ -66,9 +67,9 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { {2, 3}, }); auto sharding = xla::HloSharding::Tile(mesh).ToProto(); - auto sharding_spec = std::make_shared(sharding, tensor_shape); - auto shard_shape = - ShardingUtil::GetShardShape(sharding_spec); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); auto shard_indices = ShardingUtil::GetShardIndicesForDevices( shard_shape, tensor.sizes().vec(), sharding, devices); EXPECT_EQ(shard_indices.size(), devices.size()); @@ -103,7 +104,6 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { } } - TEST_F(XLAShardingTest, ShardTensor) { std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5", "TPU:6", "TPU:7"}; @@ -116,9 +116,10 @@ TEST_F(XLAShardingTest, ShardTensor) { CreateComputationShapeFromTensor(tensor, GetDefaultDevice()), devices.size()) .ToProto(); - auto sharding_spec = std::make_shared(sharding, tensor_shape); - auto shards = - ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({1})); EXPECT_EQ(shards[1].sizes(), c10::ArrayRef({1})); @@ -132,9 +133,10 @@ TEST_F(XLAShardingTest, ShardTensor) { {4, 5, 6, 7}, }); sharding = xla::HloSharding::Tile(mesh).ToProto(); - sharding_spec = std::make_shared(sharding, tensor_shape); - shards = - ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); + 
sharding_spec = + std::make_shared(sharding, tensor_shape); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({4, 2, 4})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({4, 1, 4})); @@ -144,8 +146,8 @@ TEST_F(XLAShardingTest, ShardTensor) { xla::Array3D cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}); sharding = xla::HloSharding::Tile(cube).ToProto(); sharding_spec->sharding = sharding; - shards = - ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({8, 2, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({8, 1, 2})); @@ -153,8 +155,8 @@ TEST_F(XLAShardingTest, ShardTensor) { // Replicated, all shards should be identical. sharding = xla::HloSharding::Replicate().ToProto(); sharding_spec->sharding = sharding; - shards = - ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({8, 7, 4})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({8, 7, 4})); @@ -166,16 +168,17 @@ TEST_F(XLAShardingTest, ShardTensor) { tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); xla::Array4D tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}}); sharding = xla::HloSharding::Tile(tesseract).ToProto(); - sharding_spec = std::make_shared(sharding, tensor_shape); - shards = - ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); + sharding_spec = + std::make_shared(sharding, tensor_shape); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({1, 8, 2, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({1, 8, 1, 2})); // 4D tiled and padded, all shard sizes should be identical. - shards = - ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/true); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/true); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({1, 8, 2, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({1, 8, 2, 2})); @@ -188,16 +191,17 @@ TEST_F(XLAShardingTest, ShardTensor) { xla::Array hypercube(std::vector{1, 1, 2, 2, 2}); hypercube.FillIota(0); sharding = xla::HloSharding::Tile(hypercube).ToProto(); - sharding_spec = std::make_shared(sharding, tensor_shape); - shards = - ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); + sharding_spec = + std::make_shared(sharding, tensor_shape); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({10, 1, 4, 3, 2})); // 5D tiled and padded, all shard sizes should be identical. 
- shards = - ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/true); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/true); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); @@ -214,11 +218,12 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { {6, 7, 2, 3}, }); auto sharding = xla::HloSharding::Tile(mesh).ToProto(); - auto sharding_spec = std::make_shared(sharding, tensor_shape); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); // For devices at the start of the mesh, all shards should have the same // unpadded shape. - auto shards = - ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); + auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 4); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({4, 2, 4})); EXPECT_EQ(shards[3].sizes(), c10::ArrayRef({4, 2, 4})); @@ -231,8 +236,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { }); sharding = xla::HloSharding::Tile(mesh).ToProto(); sharding_spec->sharding = sharding; - shards = - ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 4); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({4, 2, 4})); EXPECT_EQ(shards[3].sizes(), c10::ArrayRef({4, 1, 4})); @@ -240,22 +245,31 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { TEST_F(XLAShardingTest, ShardTensorMiniBatch) { std::vector devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"}; - at::Tensor minibatch_tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(minibatch_tensor, nullptr); - tensor_shape.set_dimensions(0,minibatch_tensor.sizes()[0] * 2); // Assuming 2 hosts + at::Tensor minibatch_tensor = + at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(minibatch_tensor, nullptr); + tensor_shape.set_dimensions( + 0, minibatch_tensor.sizes()[0] * 2); // Assuming 2 hosts xla::Array2D mesh({ - {0}, {1}, {2}, {3}, - {4}, {5}, {6}, {7}, + {0}, + {1}, + {2}, + {3}, + {4}, + {5}, + {6}, + {7}, }); auto sharding = xla::HloSharding::Tile(mesh).ToProto(); - auto sharding_spec = std::make_shared(sharding, tensor_shape, /*minibatch=*/true); - auto shards = - ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec, devices, /*padded=*/false); + auto sharding_spec = std::make_shared( + sharding, tensor_shape, /*minibatch=*/true); + auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec, + devices, /*padded=*/false); EXPECT_EQ(shards.size(), 4); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({2, 7, 4})); EXPECT_EQ(shards[3].sizes(), c10::ArrayRef({2, 7, 4})); - } TEST_F(XLAShardingTest, EqualShardingSpecs) { @@ -265,10 +279,13 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) { {0, 1, 2, 3}, {4, 5, 6, 7}, }) - .ToProto(), tensor_shape); + .ToProto(), + tensor_shape); XLATensor::ShardingSpec tiled_3d( - xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(), tensor_shape); - XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(), tensor_shape); + xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(), + tensor_shape); + XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(), + tensor_shape); 
EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d)); EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d)); EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated)); @@ -333,7 +350,7 @@ TEST_F(XLAShardingTest, InputHandler) { std::vector tensors(2); auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat)); xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); - std::fill_n(tensors.begin(), tensors.size(),tensor); + std::fill_n(tensors.begin(), tensors.size(), tensor); std::vector devices = {"TPU:0", "TPU:1"}; std::vector shardings = { nullptr, std::make_shared( diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index e91f0644b76..4e6343aa90a 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1555,7 +1555,7 @@ void InitXlaModuleBindings(py::module m) { for (auto shard : shards) { XLA_CHECK(shard.sizes() == shard_shape) << "Input shard shape must include padding: " << shard.sizes(); - // << " vs " << shard_shape; + // << " vs " << shard_shape; } auto xla_data = ShardingUtil::CreateShardedData(shards, devices, xtensor->shape(), sharding); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 9401a9252e1..65891a1bc03 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -502,8 +502,8 @@ std::vector ShardingUtil::ShardTensor( const std::vector& devices, bool padded) { auto sharding = shardings->sharding; bool minibatch = shardings->minibatch; - TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() << ")... and minibatch = " << minibatch - << std::endl; + TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() + << ")... 
and minibatch = " << minibatch << std::endl; auto device_index = build_index_map(devices); std::vector shards(devices.size()); if (sharding.type() == xla::OpSharding::REPLICATED) { From b9ce99fe968766bb50b192a41e3a2314d861aa89 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Fri, 28 Jul 2023 17:43:22 +0000 Subject: [PATCH 7/9] accept sharding_spec in CreateShardedData --- test/cpp/test_xla_sharding.cpp | 59 +++++++++++++------------ torch_xla/csrc/init_python_bindings.cpp | 14 +++--- torch_xla/csrc/tensor_util.cpp | 26 +++-------- torch_xla/csrc/xla_sharding_util.cpp | 16 +++---- torch_xla/csrc/xla_sharding_util.h | 2 +- 5 files changed, 53 insertions(+), 64 deletions(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 9442436a38c..02de6e1ef89 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -50,8 +50,7 @@ TEST_F(XLAShardingTest, GetShardShape) { // For tiled sharding, each dimension should be halved EXPECT_EQ(shard_shape, std::vector({4, 4})); - sharding = xla::HloSharding::Replicate().ToProto(); - sharding_spec->sharding = sharding; + sharding_spec->sharding = xla::HloSharding::Replicate().ToProto(); shard_shape = ShardingUtil::GetShardShape(sharding_spec); // For replicated sharding, each dimension should be preserved EXPECT_EQ(shard_shape, std::vector({8, 7})); @@ -61,7 +60,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"}; auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array2D mesh({ {0, 1}, {2, 3}, @@ -91,7 +91,6 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { EXPECT_EQ(slice.step(), 1); } } - sharding = xla::HloSharding::Replicate().ToProto(); sharding_spec->sharding = sharding; shard_shape = ShardingUtil::GetShardShape(sharding_spec); @@ -110,7 +109,8 @@ TEST_F(XLAShardingTest, ShardTensor) { // 1D tiled at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::OpSharding sharding = xla::HloSharding::Tile1D( CreateComputationShapeFromTensor(tensor, GetDefaultDevice()), @@ -127,7 +127,7 @@ TEST_F(XLAShardingTest, ShardTensor) { // 2D tiled, The first dim is halved and the last replicated. The last shard // size should be smaller in dim=1 because it's not evenly divisible. tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); - tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array2D mesh({ {0, 1, 2, 3}, {4, 5, 6, 7}, @@ -144,8 +144,7 @@ TEST_F(XLAShardingTest, ShardTensor) { // 3D tiled, the first dim is replicated and the last halved. The last shard // size should be smaller in dim=1 because it's not evenly divisible. 
xla::Array3D cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}); - sharding = xla::HloSharding::Tile(cube).ToProto(); - sharding_spec->sharding = sharding; + sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto(); shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); @@ -153,8 +152,7 @@ TEST_F(XLAShardingTest, ShardTensor) { EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({8, 1, 2})); // Replicated, all shards should be identical. - sharding = xla::HloSharding::Replicate().ToProto(); - sharding_spec->sharding = sharding; + sharding_spec->sharding = xla::HloSharding::Replicate().ToProto(); shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); @@ -165,7 +163,7 @@ TEST_F(XLAShardingTest, ShardTensor) { // last shard size should be smaller in dim=2 because it's not evenly // divisible. tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat)); - tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array4D tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}}); sharding = xla::HloSharding::Tile(tesseract).ToProto(); sharding_spec = @@ -187,7 +185,7 @@ TEST_F(XLAShardingTest, ShardTensor) { // last shard size should be smaller in dim=2 because it's not evenly // divisible. tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat)); - tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array hypercube(std::vector{1, 1, 2, 2, 2}); hypercube.FillIota(0); sharding = xla::HloSharding::Tile(hypercube).ToProto(); @@ -212,7 +210,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { // 2D tiled, The first dim is halved and the last replicated. 
at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); xla::Array2D mesh({ {4, 5, 0, 1}, {6, 7, 2, 3}, @@ -234,8 +233,7 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { {0, 1, 4, 5}, {2, 3, 6, 7}, }); - sharding = xla::HloSharding::Tile(mesh).ToProto(); - sharding_spec->sharding = sharding; + sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto(); shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 4); @@ -248,25 +246,25 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) { at::Tensor minibatch_tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); xla::Shape tensor_shape = - CreateComputationShapeFromTensor(minibatch_tensor, nullptr); + CreateComputationShapeFromTensor(minibatch_tensor, GetDefaultDevice()); tensor_shape.set_dimensions( 0, minibatch_tensor.sizes()[0] * 2); // Assuming 2 hosts - xla::Array2D mesh({ - {0}, - {1}, - {2}, - {3}, - {4}, - {5}, - {6}, - {7}, + xla::Array3D mesh({ + {{0}}, + {{1}}, + {{2}}, + {{3}}, + {{4}}, + {{5}}, + {{6}}, + {{7}}, }); auto sharding = xla::HloSharding::Tile(mesh).ToProto(); auto sharding_spec = std::make_shared( sharding, tensor_shape, /*minibatch=*/true); auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec, - devices, /*padded=*/false); + devices, /*padded=*/true); EXPECT_EQ(shards.size(), 4); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({2, 7, 4})); EXPECT_EQ(shards[3].sizes(), c10::ArrayRef({2, 7, 4})); @@ -274,7 +272,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) { TEST_F(XLAShardingTest, EqualShardingSpecs) { auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({ {0, 1, 2, 3}, {4, 5, 6, 7}, @@ -300,7 +299,8 @@ TEST_F(XLAShardingTest, CreateTensorsData) { std::vector tensors(2); auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); std::fill_n(tensors.begin(), tensors.size(), tensor); std::vector devices(2); std::fill_n(devices.begin(), devices.size(), GetDefaultDevice()->toString()); @@ -349,7 +349,8 @@ TEST_F(XLAShardingTest, InputHandler) { std::vector tensors(2); auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, GetDefaultDevice()); std::fill_n(tensors.begin(), tensors.size(), tensor); std::vector devices = {"TPU:0", "TPU:1"}; std::vector shardings = { diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 4e6343aa90a..419a281691c 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -792,7 +792,7 @@ void InitXlaModuleBindings(py::module m) { const py::list& group_assignment, const py::list& replication_groups, int sharding_type, bool minibatch) { - xla::Shape tensor_shape = + xla::Shape global_shape = CreateComputationShapeFromTensor(tensor, nullptr); if (minibatch) { int num_local_devices = @@ 
-804,13 +804,13 @@ void InitXlaModuleBindings(py::module m) { "dimension"; int batch_dim_shape = tensor.sizes()[0] * num_global_devices / num_local_devices; - tensor_shape.set_dimensions(0, batch_dim_shape); + global_shape.set_dimensions(0, batch_dim_shape); } return std::make_shared( ShardingUtil::CreateOpSharding( tile_assignment, group_assignment, replication_groups, ShardingUtil::ShardingType(sharding_type)), - tensor_shape, minibatch); + global_shape, minibatch); })); m.def("_xla_tensors_from_aten", [](const std::vector& tensors, @@ -1554,11 +1554,11 @@ void InitXlaModuleBindings(py::module m) { auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); for (auto shard : shards) { XLA_CHECK(shard.sizes() == shard_shape) - << "Input shard shape must include padding: " << shard.sizes(); - // << " vs " << shard_shape; + << "Input shard shape must include padding: " << shard.sizes() + << " vs " << shard_shape; } - auto xla_data = ShardingUtil::CreateShardedData(shards, devices, - xtensor->shape(), sharding); + auto xla_data = + ShardingUtil::CreateShardedData(shards, devices, sharding_spec); xtensor->SetXlaData(WrapXlaData(xla_data)); }); // This is useful for debugging and generating a partitioned HLO separately diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index b1d7436e83f..e2d76982759 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -614,9 +614,10 @@ torch::lazy::BackendDataPtr TensorToXlaData( runtime::GetComputationClient()->GetLocalDevices(); auto replicated_data = std::vector(local_devices.size(), tensor); + auto sharding_spec = std::make_shared( + xla::HloSharding::Replicate().ToProto(), shape); return WrapXlaData(ShardingUtil::CreateShardedData( - replicated_data, local_devices, shape, - xla::HloSharding::Replicate().ToProto())); + replicated_data, local_devices, sharding_spec)); } static const bool transfer_async = @@ -870,9 +871,10 @@ std::vector CreateTensorsData( auto shape = CreateComputationShapeFromTensor(tensors[i], &device); auto replicated_data = std::vector(local_devices.size(), tensors[i]); + auto sharding_spec = std::make_shared( + xla::HloSharding::Replicate().ToProto(), shape); handles.push_back(ShardingUtil::CreateShardedData( - replicated_data, local_devices, shape, - xla::HloSharding::Replicate().ToProto())); + replicated_data, local_devices, sharding_spec)); } return WrapXlaData(handles); } @@ -941,28 +943,14 @@ std::vector CreateTensorsData( // global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]). std::vector local_devices = runtime::GetComputationClient()->GetLocalDevices(); - xla::OpSharding sharding; - bool minibatch = false; - if (shardings[i] != nullptr) { - sharding = shardings[i]->sharding; - minibatch = shardings[i]->minibatch; - } else { - // If using SPMD and no sharding is attached to the tensor, implicitly - // replicate to all local devices. - sharding = xla::HloSharding::Replicate().ToProto(); - } // Shards the input tensors with padding, to split evenly. // The execution requires consistent shard sizes, and the zero-padded // values should be ignored. std::vector local_shards = ShardingUtil::ShardTensor(tensors[i], shardings[i], local_devices, /*padded=*/true); - if (minibatch) { // change global shape as tensor is already sharded - // accross batch dimesion. 
- shape = shardings[i]->shape; - } new_handles.push_back(ShardingUtil::CreateShardedData( - local_shards, local_devices, shape, sharding)); + local_shards, local_devices, shardings[i])); } else { auto populate_fn = [&, i, device]( diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 65891a1bc03..b2c6120c9d2 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -340,9 +340,6 @@ std::vector ShardingUtil::OutputHandler( XLATensor::ShardingSpecPtr sharding = sharding_specs[i]; if (replicated_output && sharding && (sharding->sharding.type() != xla::OpSharding::REPLICATED)) { - // XLA_CHECK(sharding->shape.has_value()) - // << "Sharding or Wrapping data shards in OutputHandler requires " - // "unpartitioned tensor shape."; // Reshards replicated output if `sharding` is present. std::vector tensors = XlaDataToTensors( {WrapXlaData(sharded_results[0][i])}, @@ -383,7 +380,6 @@ std::vector ShardingUtil::GetShardShape( return globalShape; } else if (sharding.type() == xla::OpSharding::OTHER) { auto tile_shape = sharding.tile_assignment_dimensions(); - // `shard_shape[j]` is the size of dimension `j` in the resulting shard. std::vector shard_shape; for (int j = 0; j < tile_shape.size(); j++) { @@ -500,13 +496,16 @@ ShardingUtil::GetShardIndicesForDevices( std::vector ShardingUtil::ShardTensor( const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings, const std::vector& devices, bool padded) { - auto sharding = shardings->sharding; + xla::OpSharding sharding; + if (shardings != nullptr) { + sharding = shardings->sharding; + } bool minibatch = shardings->minibatch; TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() << ")... and minibatch = " << minibatch << std::endl; auto device_index = build_index_map(devices); std::vector shards(devices.size()); - if (sharding.type() == xla::OpSharding::REPLICATED) { + if (shardings == nullptr || sharding.type() == xla::OpSharding::REPLICATED) { std::fill_n(shards.begin(), shards.size(), tensor); } else if (sharding.type() == xla::OpSharding::OTHER) { XLA_CHECK(sharding.tile_shape().dimensions_size() <= 2); @@ -528,7 +527,6 @@ std::vector ShardingUtil::ShardTensor( c10::ArrayRef(shard_indices[i])); shards[i] = shard.contiguous(at::MemoryFormat::Contiguous); } - // Zero-pad to the right to ensure the sizes are even if (shards.size() > 0 && padded) { for (size_t i = 0; i < shards.size(); ++i) { @@ -684,10 +682,12 @@ void ShardingUtil::PrepareOutputShardingPropagation( runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( std::vector& local_shards, std::vector& devices, - xla::Shape global_shape, xla::OpSharding sharding) { + const XLATensor::ShardingSpecPtr& sharding_spec) { XLA_CHECK(local_shards.size() == devices.size()) << "A device must be specified for each shard"; std::vector source_tensors; + auto global_shape = sharding_spec->shape; + auto sharding = sharding_spec->sharding; for (int64_t j = 0; j < devices.size(); ++j) { auto shard_device = ParseDeviceString(devices[j]); auto shard_shape = diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 7b7fa906a28..9129c7ac7de 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -147,7 +147,7 @@ class ShardingUtil { // the PjRtShardedData wrapping the shards. 
static runtime::ComputationClient::DataPtr CreateShardedData( std::vector& shards, std::vector& devices, - xla::Shape global_shape, xla::OpSharding sharding); + const XLATensor::ShardingSpecPtr& sharding_spec); }; } // namespace torch_xla From 48f12c5e0fce6d425b1b5330eef9338507437e30 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Mon, 31 Jul 2023 22:49:36 +0000 Subject: [PATCH 8/9] remove tensor_shape argument --- test/cpp/test_xla_sharding.cpp | 6 +++--- torch_xla/csrc/xla_sharding_util.cpp | 26 ++++++++++++-------------- torch_xla/csrc/xla_sharding_util.h | 2 -- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 02de6e1ef89..646a45b92f3 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -245,9 +245,9 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) { std::vector devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"}; at::Tensor minibatch_tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); - xla::Shape tensor_shape = + xla::Shape global_shape = CreateComputationShapeFromTensor(minibatch_tensor, GetDefaultDevice()); - tensor_shape.set_dimensions( + global_shape.set_dimensions( 0, minibatch_tensor.sizes()[0] * 2); // Assuming 2 hosts xla::Array3D mesh({ {{0}}, @@ -262,7 +262,7 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) { auto sharding = xla::HloSharding::Tile(mesh).ToProto(); auto sharding_spec = std::make_shared( - sharding, tensor_shape, /*minibatch=*/true); + sharding, global_shape, /*minibatch=*/true); auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec, devices, /*padded=*/true); EXPECT_EQ(shards.size(), 4); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index b2c6120c9d2..d5f6cb5143c 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -399,22 +399,21 @@ std::vector ShardingUtil::GetShardShape( std::vector> ShardingUtil::GetShardIndicesForMinibatchTensor( const std::vector& shard_shape, - const std::vector& tensor_shape, const xla::OpSharding sharding, const std::vector& devices) { std::vector> shard_indices( devices.size()); - if (sharding.type() == xla::OpSharding::OTHER) { - for (int i = 0; i < devices.size(); i++) { - std::vector indices; - for (int j = 0; j < tensor_shape.size(); j++) { - indices.push_back(at::indexing::Slice(0, tensor_shape[j])); - } - indices[0] = - at::indexing::Slice(i * shard_shape[0], (i + 1) * shard_shape[0]); - shard_indices[i] = indices; + for (int i = 0; i < devices.size(); i++) { + std::vector indices; + // For batch dimension sharding we just change shard indices on first axis + // and copy all indices for all remaining axes. + for (int j = 0; j < shard_shape.size(); j++) { + indices.push_back(at::indexing::Slice(0, shard_shape[j])); } - } else { - TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type(); + // As the tensor is batch sharded we just care about the first dimension + // to calculate shard indices. 
+ indices[0] = at::indexing::Slice(i * shard_shape[0], (i + 1) * shard_shape[0]); shard_indices[i] = indices; } return shard_indices; } @@ -515,8 +514,7 @@ std::vector ShardingUtil::ShardTensor( std::vector> shard_indices; if (minibatch) { - shard_indices = GetShardIndicesForMinibatchTensor( - shard_shape, tensor.sizes().vec(), sharding, devices); + shard_indices = GetShardIndicesForMinibatchTensor(shard_shape, devices); } else { shard_indices = GetShardIndicesForDevices( shard_shape, tensor.sizes().vec(), sharding, devices); diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 9129c7ac7de..b22cc7594f5 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -103,8 +103,6 @@ class ShardingUtil { // called when input is sharded along the batch axis. static std::vector> GetShardIndicesForMinibatchTensor(const std::vector& shard_shape, - const std::vector& tensor_shape, - const xla::OpSharding sharding, const std::vector& devices); // Shards a tensor and returns the sharded tensors which belong on `devices` From c7fcdfbc0c5b718f0fc282babc1f09f0a54d6167 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Wed, 2 Aug 2023 22:59:06 +0000 Subject: [PATCH 9/9] multihost changes --- torch_xla/csrc/xla_sharding_util.cpp | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index d5f6cb5143c..433520e7796 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -496,10 +496,11 @@ std::vector ShardingUtil::ShardTensor( const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings, const std::vector& devices, bool padded) { xla::OpSharding sharding; + bool minibatch = false; if (shardings != nullptr) { sharding = shardings->sharding; + minibatch = shardings->minibatch; } - bool minibatch = shardings->minibatch; TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type() << ")... and minibatch = " << minibatch << std::endl; auto device_index = build_index_map(devices); @@ -684,8 +685,18 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( XLA_CHECK(local_shards.size() == devices.size()) << "A device must be specified for each shard"; std::vector source_tensors; - auto global_shape = sharding_spec->shape; - auto sharding = sharding_spec->sharding; + xla::Shape global_shape; + xla::OpSharding sharding; + if (sharding_spec == nullptr) { + // A null sharding spec implies replication, so global_shape is just the + // shape of the local tensor. + auto first_device = ParseDeviceString(devices[0]); + global_shape = + CreateComputationShapeFromTensor(local_shards[0], &first_device); + sharding = xla::HloSharding::Replicate().ToProto(); + } else { + global_shape = sharding_spec->shape; + sharding = sharding_spec->sharding; + } for (int64_t j = 0; j < devices.size(); ++j) { auto shard_device = ParseDeviceString(devices[j]); auto shard_shape =
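For reference, below is a minimal sketch of how the minibatch input-sharding path touched by this series might be driven from Python. It assumes the ShardingSpec wrapper in torch_xla/experimental/xla_sharding.py exposes the new minibatch keyword and that MpDeviceLoader accepts an input_sharding argument; the device count, dataset, batch size, and the XLA_USE_SPMD toggle are illustrative assumptions, not taken from these patches.

import os

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

os.environ["XLA_USE_SPMD"] = "1"  # assumed SPMD toggle for this torch_xla version

num_local_devices = 4  # assumed number of addressable devices on this host
# 1D device mesh over the batch axis; each local device holds one batch slice.
mesh = Mesh(np.arange(num_local_devices), (num_local_devices, 1))

# Per-host loader: each host only yields its local slice of the global batch.
dataset = torch.utils.data.TensorDataset(
    torch.randn(64, 128), torch.zeros(64, dtype=torch.long))
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

# minibatch=True (assumed keyword added by this series) marks the loaded tensor
# as a per-host shard of the batch dimension, so the recorded global batch size
# becomes local_batch * num_global_devices / num_local_devices, mirroring the
# global_shape fix-up in the XlaShardingSpec binding.
input_sharding = xs.ShardingSpec(mesh, (0, None), minibatch=True)

device = xm.xla_device()
train_loader = pl.MpDeviceLoader(loader, device, input_sharding=input_sharding)
for data, target in train_loader:
    pass  # data arrives already sharded along the batch dimension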