Skip to content

Commit

Permalink
fix tests
Browse files — browse the repository at this point in the history
  • Loading branch information
khatwanimohit committed Jul 25, 2023
1 parent 2042f43 commit 81cb07f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions test/cpp/test_xla_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ TEST_F(XLAShardingTest, GetShardShape) {
{2, 3},
});
auto sharding = xla::HloSharding::Tile(mesh).ToProto();
auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
auto shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding);
// For tiled sharding, each dimension should be halved
EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));

sharding = xla::HloSharding::Replicate().ToProto();
shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding);
// For replicated sharding, each dimension should be preserved
EXPECT_EQ(shard_shape, std::vector<int64_t>({8, 7}));
}
Expand All @@ -60,7 +60,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
{2, 3},
});
auto sharding = xla::HloSharding::Tile(mesh).ToProto();
auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
auto shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding);
auto shard_indices = ShardingUtil::GetShardIndicesForDevices(
shard_shape, tensor.sizes().vec(), sharding, devices);
EXPECT_EQ(shard_indices.size(), devices.size());
Expand All @@ -84,7 +84,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
}

sharding = xla::HloSharding::Replicate().ToProto();
shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
shard_shape = ShardingUtil::GetShardShape(tensor.sizes().vec(), sharding);
shard_indices = ShardingUtil::GetShardIndicesForDevices(
shard_shape, tensor.sizes().vec(), sharding, devices);
EXPECT_EQ(shard_indices.size(), devices.size());
Expand Down
6 changes: 3 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -801,10 +801,10 @@ void InitXlaModuleBindings(py::module m) {
if (minibatch) {
XLA_CHECK(tile_assignment.size() == num_global_devices)
<< "Sharding of input is only supported along batch dimension";
}
int batch_dim_shape =
int batch_dim_shape =
tensor.sizes()[0] * num_global_devices / num_local_devices;
tensor_shape.set_dimensions(0, batch_dim_shape);
tensor_shape.set_dimensions(0, batch_dim_shape);
}
return std::make_shared<XLATensor::ShardingSpec>(
ShardingUtil::CreateOpSharding(
tile_assignment, group_assignment, replication_groups,
Expand Down

0 comments on commit 81cb07f

Please sign in to comment.