[SPMD] Multi-host batch sharded data loading #5331

Merged: 9 commits, Aug 4, 2023.
Changes from all commits.
142 changes: 103 additions & 39 deletions test/cpp/test_xla_sharding.cpp
@@ -36,17 +36,22 @@ class XLAShardingTest : public AtenXlaTensorTestBase {};

TEST_F(XLAShardingTest, GetShardShape) {
  auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
  xla::Array2D<int64_t> mesh({
      {0, 1},
      {2, 3},
  });
  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
-  auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  auto sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+
+  auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
  // For tiled sharding, each dimension should be halved
  EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));

-  sharding = xla::HloSharding::Replicate().ToProto();
-  shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  shard_shape = ShardingUtil::GetShardShape(sharding_spec);
  // For replicated sharding, each dimension should be preserved
  EXPECT_EQ(shard_shape, std::vector<int64_t>({8, 7}));
}
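For intuition, the tiled expectation above follows from ceiling division: each tensor dimension is divided by the mesh extent along that axis, so {8, 7} on a 2x2 mesh yields {4, 4}. A minimal standalone sketch of that rule (a hypothetical helper, not the actual ShardingUtil implementation):

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the expected tiled shard shape: each
// dimension is ceil-divided by the number of mesh tiles along that axis.
// Assumes dims and mesh have equal rank, e.g. {8, 7} on a 2x2 mesh -> {4, 4}.
std::vector<int64_t> ExpectedTiledShardShape(const std::vector<int64_t>& dims,
                                             const std::vector<int64_t>& mesh) {
  std::vector<int64_t> shard_dims;
  for (size_t i = 0; i < dims.size(); ++i) {
    shard_dims.push_back((dims[i] + mesh[i] - 1) / mesh[i]);
  }
  return shard_dims;
}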
@@ -55,12 +60,16 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
  std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"};

  auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
  xla::Array2D<int64_t> mesh({
      {0, 1},
      {2, 3},
  });
  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
-  auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  auto sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
  auto shard_indices = ShardingUtil::GetShardIndicesForDevices(
      shard_shape, tensor.sizes().vec(), sharding, devices);
  EXPECT_EQ(shard_indices.size(), devices.size());
@@ -82,9 +91,9 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
      EXPECT_EQ(slice.step(), 1);
    }
  }
-
  sharding = xla::HloSharding::Replicate().ToProto();
-  shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+  sharding_spec->sharding = sharding;
+  shard_shape = ShardingUtil::GetShardShape(sharding_spec);
  shard_indices = ShardingUtil::GetShardIndicesForDevices(
      shard_shape, tensor.sizes().vec(), sharding, devices);
  EXPECT_EQ(shard_indices.size(), devices.size());
@@ -100,45 +109,52 @@ TEST_F(XLAShardingTest, ShardTensor) {

  // 1D tiled
  at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
  xla::OpSharding sharding =
      xla::HloSharding::Tile1D(
          CreateComputationShapeFromTensor(tensor, GetDefaultDevice()),
          devices.size())
          .ToProto();
-  auto shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  auto sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                          /*padded=*/false);
  EXPECT_EQ(shards.size(), 8);
  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1}));
  EXPECT_EQ(shards[1].sizes(), c10::ArrayRef<long>({1}));

  // 2D tiled: the first dim is halved and the last replicated. The last shard
  // size should be smaller in dim=1 because it's not evenly divisible.
  tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
  xla::Array2D<int64_t> mesh({
      {0, 1, 2, 3},
      {4, 5, 6, 7},
  });
  sharding = xla::HloSharding::Tile(mesh).ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
  EXPECT_EQ(shards.size(), 8);
  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({4, 1, 4}));

  // 3D tiled: the first dim is replicated and the last halved. The last shard
  // size should be smaller in dim=1 because it's not evenly divisible.
  xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
-  sharding = xla::HloSharding::Tile(cube).ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto();
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
  EXPECT_EQ(shards.size(), 8);
  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 2, 2}));
  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 1, 2}));

  // Replicated, all shards should be identical.
-  sharding = xla::HloSharding::Replicate().ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
  EXPECT_EQ(shards.size(), 8);
  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 7, 4}));
  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 7, 4}));
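The unpadded last-shard expectations above ({4, 1, 4} and {8, 1, 2}) come from the remainder of that ceiling division: a dimension of size 7 split over 4 tiles gives extents 2, 2, 2, 1, and padding rounds the final shard back up to 2. A toy sketch under the same assumption (hypothetical helper, not the ShardingUtil code):

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper: unpadded shard extents along a single axis. For
// dim = 7 over tiles = 4, the extent is ceil(7/4) = 2 and the result is
// {2, 2, 2, 1}; with padded=true every shard would be 2.
std::vector<int64_t> UnpaddedSplitSizes(int64_t dim, int64_t tiles) {
  int64_t extent = (dim + tiles - 1) / tiles;  // ceil division
  std::vector<int64_t> sizes;
  for (int64_t start = 0; start < dim; start += extent) {
    sizes.push_back(std::min(extent, dim - start));
  }
  return sizes;
}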
@@ -147,17 +163,20 @@ TEST_F(XLAShardingTest, ShardTensor) {
  // last shard size should be smaller in dim=2 because it's not evenly
  // divisible.
  tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
  xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
  sharding = xla::HloSharding::Tile(tesseract).ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
  EXPECT_EQ(shards.size(), 8);
  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({1, 8, 1, 2}));

  // 4D tiled and padded, all shard sizes should be identical.
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/true);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/true);
  EXPECT_EQ(shards.size(), 8);
  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
@@ -166,18 +185,21 @@ TEST_F(XLAShardingTest, ShardTensor) {
  // last shard size should be smaller in dim=2 because it's not evenly
  // divisible.
  tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
  xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
  hypercube.FillIota(0);
  sharding = xla::HloSharding::Tile(hypercube).ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
  EXPECT_EQ(shards.size(), 8);
  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 3, 2}));

  // 5D tiled and padded, all shard sizes should be identical.
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/true);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/true);
  EXPECT_EQ(shards.size(), 8);
  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
@@ -188,16 +210,19 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {

  // 2D tiled: the first dim is halved and the last replicated.
  at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
  xla::Array2D<int64_t> mesh({
      {4, 5, 0, 1},
      {6, 7, 2, 3},
  });
  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
-
+  auto sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
  // For devices at the start of the mesh, all shards should have the same
  // unpadded shape.
-  auto shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                          /*padded=*/false);
  EXPECT_EQ(shards.size(), 4);
  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
  EXPECT_EQ(shards[3].sizes(), c10::ArrayRef<long>({4, 2, 4}));
@@ -208,23 +233,58 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
      {0, 1, 4, 5},
      {2, 3, 6, 7},
  });
-  sharding = xla::HloSharding::Tile(mesh).ToProto();
-  shards =
-      ShardingUtil::ShardTensor(tensor, sharding, devices, /*padded=*/false);
+  sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto();
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
  EXPECT_EQ(shards.size(), 4);
  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
  EXPECT_EQ(shards[3].sizes(), c10::ArrayRef<long>({4, 1, 4}));
}

+TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
+  std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
+  at::Tensor minibatch_tensor =
+      at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  xla::Shape global_shape =
+      CreateComputationShapeFromTensor(minibatch_tensor, GetDefaultDevice());
+  global_shape.set_dimensions(
+      0, minibatch_tensor.sizes()[0] * 2);  // Assuming 2 hosts
+  xla::Array3D<int64_t> mesh({
+      {{0}},
+      {{1}},
+      {{2}},
+      {{3}},
+      {{4}},
+      {{5}},
+      {{6}},
+      {{7}},
+  });
+
+  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
+      sharding, global_shape, /*minibatch=*/true);
+  auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
+                                          devices, /*padded=*/true);
+  EXPECT_EQ(shards.size(), 4);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({2, 7, 4}));
+  EXPECT_EQ(shards[3].sizes(), c10::ArrayRef<long>({2, 7, 4}));
+}

Review comment from @miladm (Collaborator), Oct 24, 2024, on ShardTensorMiniBatch:
@khatwanimohit did we test minibatch=True on SPMD 2D sharding?
cc @tengyifei
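The expected shapes follow from the minibatch contract: each host holds only its local minibatch while the spec advertises the global batch. With the assumed 2 hosts, the global batch is 8 * 2 = 16 over 8 global devices, so each of the 4 local devices gets a batch slice of 16 / 8 = 2, hence {2, 7, 4}. A sketch of that arithmetic (values mirror the test's assumptions, meant to sit inline in a test body):

// Sketch of the minibatch arithmetic assumed by the test above.
int64_t per_host_batch = 8;       // minibatch_tensor dim 0
int64_t num_hosts = 2;            // assumed: 2 hosts
int64_t num_global_devices = 8;   // batch axis of the 8x1x1 device mesh
int64_t global_batch = per_host_batch * num_hosts;             // 16
int64_t per_device_batch = global_batch / num_global_devices;  // 2
// per_device_batch matches the expected shard dim 0 of {2, 7, 4}.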

TEST_F(XLAShardingTest, EqualShardingSpecs) {
  auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
  XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
                                       {0, 1, 2, 3},
                                       {4, 5, 6, 7},
                                   })
-                                      .ToProto());
+                                      .ToProto(),
+                                  tensor_shape);
  XLATensor::ShardingSpec tiled_3d(
-      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto());
-  XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto());
+      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(),
+      tensor_shape);
+  XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(),
+                                     tensor_shape);
  EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
  EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
  EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated));
@@ -238,13 +298,15 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
  }

  std::vector<at::Tensor> tensors(2);
-  std::fill_n(tensors.begin(), tensors.size(),
-              at::ones({8, 8}, at::TensorOptions(at::kFloat)));
+  auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
+  std::fill_n(tensors.begin(), tensors.size(), tensor);
  std::vector<std::string> devices(2);
  std::fill_n(devices.begin(), devices.size(), GetDefaultDevice()->toString());
  std::vector<XLATensor::ShardingSpecPtr> shardings = {
      nullptr, std::make_shared<XLATensor::ShardingSpec>(
-                   xla::HloSharding::Replicate().ToProto())};
+                   xla::HloSharding::Replicate().ToProto(), tensor_shape)};
  std::vector<torch::lazy::BackendDataPtr> tensors_data =
      CreateTensorsData(tensors, shardings, devices);
@@ -286,12 +348,14 @@ TEST_F(XLAShardingTest, InputHandler) {
  }

  std::vector<at::Tensor> tensors(2);
-  std::fill_n(tensors.begin(), tensors.size(),
-              at::ones({8, 8}, at::TensorOptions(at::kFloat)));
+  auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
+  std::fill_n(tensors.begin(), tensors.size(), tensor);
  std::vector<std::string> devices = {"TPU:0", "TPU:1"};
  std::vector<XLATensor::ShardingSpecPtr> shardings = {
      nullptr, std::make_shared<XLATensor::ShardingSpec>(
-                   xla::HloSharding::Replicate().ToProto())};
+                   xla::HloSharding::Replicate().ToProto(), tensor_shape)};
  std::vector<torch::lazy::BackendDataPtr> tensors_data =
      CreateTensorsData(tensors, shardings, devices);
30 changes: 23 additions & 7 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -790,12 +790,27 @@ void InitXlaModuleBindings(py::module m) {
          m, "XlaShardingSpec")
      .def(py::init([](at::Tensor tensor, const py::list& tile_assignment,
                       const py::list& group_assignment,
-                       const py::list& replication_groups, int sharding_type) {
+                       const py::list& replication_groups, int sharding_type,
+                       bool minibatch) {
+        xla::Shape global_shape =
+            CreateComputationShapeFromTensor(tensor, nullptr);
+        if (minibatch) {
+          int num_local_devices =
+              runtime::GetComputationClient()->GetLocalDevices().size();
+          int num_global_devices =
+              runtime::GetComputationClient()->GetAllDevices().size();
+          XLA_CHECK(tile_assignment.size() == num_global_devices)
+              << "Minibatch sharding only supports sharding along the batch "
+                 "dimension";
+          int batch_dim_shape =
+              tensor.sizes()[0] * num_global_devices / num_local_devices;
+          global_shape.set_dimensions(0, batch_dim_shape);
+        }
        return std::make_shared<XLATensor::ShardingSpec>(
            ShardingUtil::CreateOpSharding(
                tile_assignment, group_assignment, replication_groups,
                ShardingUtil::ShardingType(sharding_type)),
-            CreateComputationShapeFromTensor(tensor, nullptr));
+            global_shape, minibatch);
      }));
  m.def("_xla_tensors_from_aten",
        [](const std::vector<at::Tensor>& tensors,
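In the minibatch branch above, the inferred per-host shape is promoted to a global shape by scaling dim 0 by num_global_devices / num_local_devices. A minimal sketch of that adjustment as a hypothetical standalone helper (not part of the PR):

// Hypothetical helper: promote a host-local shape to the global shape stored
// in the ShardingSpec when minibatch=true. Assumes dim 0 is the batch axis.
xla::Shape ToGlobalMinibatchShape(xla::Shape local_shape,
                                  int64_t num_global_devices,
                                  int64_t num_local_devices) {
  int64_t global_batch =
      local_shape.dimensions(0) * num_global_devices / num_local_devices;
  local_shape.set_dimensions(0, global_batch);  // e.g. 8 * 8 / 4 = 16
  return local_shape;
}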
@@ -1490,9 +1505,9 @@ void InitXlaModuleBindings(py::module m) {
          for (auto& shard : shards) {
            shard_devices.push_back(shard->device());
          }
-
+          auto sharding_spec = xtensor->sharding_spec();
          auto sharding = xtensor->sharding_spec()->sharding;
-          auto shard_shape = ShardingUtil::GetShardShape(input, sharding);
+          auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
          auto indices = ShardingUtil::GetShardIndicesForDevices(
              shard_shape, input.sizes().vec(), sharding, shard_devices);
@@ -1532,17 +1547,18 @@ void InitXlaModuleBindings(py::module m) {
                      runtime::GetComputationClient()->GetLocalDevices().size())
            << "Shards must be provided for all local devices";
        auto sharding = xtensor->sharding_spec()->sharding;
+        auto sharding_spec = xtensor->sharding_spec();
        XLA_CHECK(sharding.type() != xla::OpSharding::REPLICATED)
            << "Replicated tensor should not be loaded from _load_local_shards - "
               "use copy_";
-        auto shard_shape = ShardingUtil::GetShardShape(tensor, sharding);
+        auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
        for (auto shard : shards) {
          XLA_CHECK(shard.sizes() == shard_shape)
              << "Input shard shape must include padding: " << shard.sizes()
              << " vs " << shard_shape;
        }
-        auto xla_data = ShardingUtil::CreateShardedData(shards, devices,
-                                                        xtensor->shape(), sharding);
+        auto xla_data =
+            ShardingUtil::CreateShardedData(shards, devices, sharding_spec);
        xtensor->SetXlaData(WrapXlaData(xla_data));
      });
  // This is useful for debugging and generating a partitioned HLO separately
8 changes: 6 additions & 2 deletions torch_xla/csrc/tensor.h
@@ -253,13 +253,17 @@ class XLATensor : public torch::lazy::LazyTensor {
  // XLA SPMD sharding spec annotation. The XLA tensor uses this to create
  // HloSharding for replication, manual and tile shardings.
  struct ShardingSpec {
-    ShardingSpec(const xla::OpSharding& sharding) : sharding(sharding) {}
+    ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape)
+        : sharding(sharding), shape(shape) {}
+    ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape,
+                 const bool& minibatch)
+        : sharding(sharding), shape(shape), minibatch(minibatch) {}

    xla::OpSharding sharding;
    // Optional source tensor shape unpartitioned.
-    std::optional<xla::Shape> shape;
+    xla::Shape shape;
+    // Indicates whether the input is sharded along the batch axis (minibatch).
+    bool minibatch = false;
  };

  // Annotate the IR value with ShardingSpec.

Review comments on the `xla::Shape shape;` change (Collaborators):
- Why is this not optional anymore? Then I guess you need to delete the above
  comment as well.
- Now the global tensor's shape doesn't reflect the truth.
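Taken together, the two constructors give callers either a plain annotation (sharding plus unpartitioned shape) or a minibatch-aware one. A minimal construction sketch, assuming sharding, tensor_shape, and global_shape are already in scope:

// Sketch: the two ShardingSpec construction paths added by this PR.
auto replicated_spec = std::make_shared<XLATensor::ShardingSpec>(
    xla::HloSharding::Replicate().ToProto(), tensor_shape);
// Minibatch variant: global_shape must be the global batch shape, i.e. the
// per-host batch scaled by the number of hosts.
auto minibatch_spec = std::make_shared<XLATensor::ShardingSpec>(
    sharding, global_shape, /*minibatch=*/true);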