From 81cb07f6425b04beb4c46ffe04dd05e4def2449d Mon Sep 17 00:00:00 2001
From: Mohit Khatwani
Date: Tue, 25 Jul 2023 19:29:34 +0000
Subject: [PATCH] fix tests

---
 test/cpp/test_xla_sharding.cpp          | 8 ++++----
 torch_xla/csrc/init_python_bindings.cpp | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index 6397d17f837b..78f8dacd18ae 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -41,12 +41,12 @@ TEST_F(XLAShardingTest, GetShardShape) {
       {2, 3},
   });
   auto sharding = xla::HloSharding::Tile(mesh).ToProto();
-  auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  auto shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding);
   // For tiled sharding, each dimension should be halved
   EXPECT_EQ(shard_shape, std::vector({4, 4}));
 
   sharding = xla::HloSharding::Replicate().ToProto();
-  shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding);
   // For replicated sharding, each dimension should be preserved
   EXPECT_EQ(shard_shape, std::vector({8, 7}));
 }
@@ -60,7 +60,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       {2, 3},
   });
   auto sharding = xla::HloSharding::Tile(mesh).ToProto();
-  auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  auto shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding);
   auto shard_indices = ShardingUtil::GetShardIndicesForDevices(
       shard_shape, tensor.sizes().vec(), sharding, devices);
   EXPECT_EQ(shard_indices.size(), devices.size());
@@ -84,7 +84,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
   }
 
   sharding = xla::HloSharding::Replicate().ToProto();
-  shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding);
   shard_indices = ShardingUtil::GetShardIndicesForDevices(
       shard_shape, tensor.sizes().vec(), sharding, devices);
   EXPECT_EQ(shard_indices.size(), devices.size());
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index cad6177c6c59..a2dc98b620a0 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -801,10 +801,10 @@ void InitXlaModuleBindings(py::module m) {
         if (minibatch) {
           XLA_CHECK(tile_assignment.size() == num_global_devices)
               << "Sharding of input is only supported along batch dimension";
-        }
-        int batch_dim_shape =
+          int batch_dim_shape =
             tensor.sizes()[0] * num_global_devices / num_local_devices;
-        tensor_shape.set_dimensions(0, batch_dim_shape);
+          tensor_shape.set_dimensions(0, batch_dim_shape);
+        }
         return std::make_shared(
             ShardingUtil::CreateOpSharding(
                 tile_assignment, group_assignment, replication_groups,
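
Below is a minimal, standalone C++ sketch (not torch_xla code) of the shard-shape computation that the updated tests exercise: ShardingUtil::GetShardShape now receives the global shape as a std::vector<int64_t> (tensor.sizes().vec()) together with the sharding proto, rather than the tensor itself. The helper name ShardShapeSketch, its tiles_per_dim parameter, and the ceiling-division behavior for uneven dimensions are illustrative assumptions chosen only to reproduce the values the test expects ({8, 7} tiled on a 2x2 mesh -> {4, 4}; replicated -> {8, 7}); they are not the library's actual implementation.

// Standalone sketch: derive a per-device shard shape from a global shape.
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the shard-shape logic the test checks.
std::vector<int64_t> ShardShapeSketch(const std::vector<int64_t>& global_shape,
                                      const std::vector<int64_t>& tiles_per_dim,
                                      bool replicated) {
  if (replicated) {
    // Replicated sharding keeps every dimension intact.
    return global_shape;
  }
  std::vector<int64_t> shard_shape;
  for (size_t i = 0; i < global_shape.size(); ++i) {
    // Ceiling division so uneven dimensions round up, e.g. 7 over 2 tiles -> 4.
    shard_shape.push_back((global_shape[i] + tiles_per_dim[i] - 1) /
                          tiles_per_dim[i]);
  }
  return shard_shape;
}

int main() {
  std::vector<int64_t> global_shape = {8, 7};  // matches at::ones({8, 7}) in the test
  auto tiled = ShardShapeSketch(global_shape, {2, 2}, /*replicated=*/false);
  auto replicated = ShardShapeSketch(global_shape, {}, /*replicated=*/true);
  std::cout << tiled[0] << "x" << tiled[1] << "\n";            // prints 4x4
  std::cout << replicated[0] << "x" << replicated[1] << "\n";  // prints 8x7
  return 0;
}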