accept sharding_spec in CreateShardedData
khatwanimohit committed Jul 28, 2023
1 parent 26afe27 commit 36b6292
Showing 5 changed files with 55 additions and 65 deletions.
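
The interface change at the core of this commit: ShardingUtil::CreateShardedData now takes the tensor's ShardingSpec instead of a separate global shape and OpSharding, and reads both from the spec. A minimal before/after sketch of a call site, pieced together from the hunks below (variable names are illustrative, not the exact surrounding code):

    // Before: global shape and sharding proto were passed separately.
    auto old_handle = ShardingUtil::CreateShardedData(
        local_shards, local_devices, global_shape,
        xla::HloSharding::Replicate().ToProto());

    // After: callers build (or reuse) a ShardingSpec and pass it directly;
    // CreateShardedData pulls the global shape and OpSharding out of the spec.
    auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
        xla::HloSharding::Replicate().ToProto(), global_shape);
    auto new_handle = ShardingUtil::CreateShardedData(
        local_shards, local_devices, sharding_spec);
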
59 changes: 30 additions & 29 deletions test/cpp/test_xla_sharding.cpp
@@ -50,8 +50,7 @@ TEST_F(XLAShardingTest, GetShardShape) {
// For tiled sharding, each dimension should be halved
EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));

sharding = xla::HloSharding::Replicate().ToProto();
sharding_spec->sharding = sharding;
sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
shard_shape = ShardingUtil::GetShardShape(sharding_spec);
// For replicated sharding, each dimension should be preserved
EXPECT_EQ(shard_shape, std::vector<int64_t>({8, 7}));
@@ -61,7 +60,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"};

auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr);
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
xla::Array2D<int64_t> mesh({
{0, 1},
{2, 3},
@@ -91,7 +91,6 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
EXPECT_EQ(slice.step(), 1);
}
}

sharding = xla::HloSharding::Replicate().ToProto();
sharding_spec->sharding = sharding;
shard_shape = ShardingUtil::GetShardShape(sharding_spec);
@@ -110,7 +109,8 @@ TEST_F(XLAShardingTest, ShardTensor) {

// 1D tiled
at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr);
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
xla::OpSharding sharding =
xla::HloSharding::Tile1D(
CreateComputationShapeFromTensor(tensor, GetDefaultDevice()),
@@ -127,7 +127,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
// 2D tiled, The first dim is halved and the last replicated. The last shard
// size should be smaller in dim=1 because it's not evenly divisible.
tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr);
tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
xla::Array2D<int64_t> mesh({
{0, 1, 2, 3},
{4, 5, 6, 7},
@@ -144,17 +144,15 @@ TEST_F(XLAShardingTest, ShardTensor) {
// 3D tiled, the first dim is replicated and the last halved. The last shard
// size should be smaller in dim=1 because it's not evenly divisible.
xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
sharding = xla::HloSharding::Tile(cube).ToProto();
sharding_spec->sharding = sharding;
sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto();
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 2, 2}));
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 1, 2}));

// Replicated, all shards should be identical.
sharding = xla::HloSharding::Replicate().ToProto();
sharding_spec->sharding = sharding;
sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
@@ -165,7 +163,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
// last shard size should be smaller in dim=2 because it's not evenly
// divisible.
tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat));
tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr);
tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
sharding = xla::HloSharding::Tile(tesseract).ToProto();
sharding_spec =
@@ -187,7 +185,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
// last shard size should be smaller in dim=2 because it's not evenly
// divisible.
tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat));
tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr);
tensor_shape = CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
hypercube.FillIota(0);
sharding = xla::HloSharding::Tile(hypercube).ToProto();
@@ -212,7 +210,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {

// 2D tiled, The first dim is halved and the last replicated.
at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr);
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
xla::Array2D<int64_t> mesh({
{4, 5, 0, 1},
{6, 7, 2, 3},
@@ -234,8 +233,7 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
{0, 1, 4, 5},
{2, 3, 6, 7},
});
sharding = xla::HloSharding::Tile(mesh).ToProto();
sharding_spec->sharding = sharding;
sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto();
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 4);
@@ -248,33 +246,34 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
at::Tensor minibatch_tensor =
at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(minibatch_tensor, nullptr);
CreateComputationShapeFromTensor(minibatch_tensor, GetDefaultDevice());
tensor_shape.set_dimensions(
0, minibatch_tensor.sizes()[0] * 2); // Assuming 2 hosts
xla::Array2D<int64_t> mesh({
{0},
{1},
{2},
{3},
{4},
{5},
{6},
{7},
xla::Array3D<int64_t> mesh({
{{0}},
{{1}},
{{2}},
{{3}},
{{4}},
{{5}},
{{6}},
{{7}},
});

auto sharding = xla::HloSharding::Tile(mesh).ToProto();
auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, tensor_shape, /*minibatch=*/true);
auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
devices, /*padded=*/false);
devices, /*padded=*/true);
EXPECT_EQ(shards.size(), 4);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({2, 7, 4}));
EXPECT_EQ(shards[3].sizes(), c10::ArrayRef<long>({2, 7, 4}));
}

TEST_F(XLAShardingTest, EqualShardingSpecs) {
auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr);
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
{0, 1, 2, 3},
{4, 5, 6, 7},
@@ -300,7 +299,8 @@ TEST_F(XLAShardingTest, CreateTensorsData) {

std::vector<at::Tensor> tensors(2);
auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr);
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
std::fill_n(tensors.begin(), tensors.size(), tensor);
std::vector<std::string> devices(2);
std::fill_n(devices.begin(), devices.size(), GetDefaultDevice()->toString());
@@ -349,7 +349,8 @@ TEST_F(XLAShardingTest, InputHandler) {

std::vector<at::Tensor> tensors(2);
auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, nullptr);
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
std::fill_n(tensors.begin(), tensors.size(), tensor);
std::vector<std::string> devices = {"TPU:0", "TPU:1"};
std::vector<XLATensor::ShardingSpecPtr> shardings = {
16 changes: 8 additions & 8 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -792,25 +792,25 @@ void InitXlaModuleBindings(py::module m) {
const py::list& group_assignment,
const py::list& replication_groups, int sharding_type,
bool minibatch) {
xla::Shape tensor_shape =
xla::Shape global_shape =
CreateComputationShapeFromTensor(tensor, nullptr);
if (minibatch) {
int num_local_devices =
runtime::GetComputationClient()->GetLocalDevices().size();
int num_global_devices =
runtime::GetComputationClient()->GetAllDevices().size();
XLA_CHECK(tile_assignment.size() == num_global_devices)
XLA_CHECK(tile_assignment.size()[0] == num_global_devices)
<< "Minibatch sharding only supports sharding along the batch "
"dimension";
int batch_dim_shape =
tensor.sizes()[0] * num_global_devices / num_local_devices;
tensor_shape.set_dimensions(0, batch_dim_shape);
global_shape.set_dimensions(0, batch_dim_shape);
}
return std::make_shared<XLATensor::ShardingSpec>(
ShardingUtil::CreateOpSharding(
tile_assignment, group_assignment, replication_groups,
ShardingUtil::ShardingType(sharding_type)),
tensor_shape, minibatch);
global_shape, minibatch);
}));
m.def("_xla_tensors_from_aten",
[](const std::vector<at::Tensor>& tensors,
@@ -1519,11 +1519,11 @@ void InitXlaModuleBindings(py::module m) {
auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
for (auto shard : shards) {
XLA_CHECK(shard.sizes() == shard_shape)
<< "Input shard shape must include padding: " << shard.sizes();
// << " vs " << shard_shape;
<< "Input shard shape must include padding: " << shard.sizes()
<< " vs " << shard_shape;
}
auto xla_data = ShardingUtil::CreateShardedData(shards, devices,
xtensor->shape(), sharding);
auto xla_data =
ShardingUtil::CreateShardedData(shards, devices, sharding_spec);
xtensor->SetXlaData(WrapXlaData(xla_data));
});
// This is useful for debugging and generating a partitioned HLO separately
26 changes: 7 additions & 19 deletions torch_xla/csrc/tensor_util.cpp
@@ -614,9 +614,10 @@ torch::lazy::BackendDataPtr TensorToXlaData(
runtime::GetComputationClient()->GetLocalDevices();
auto replicated_data =
std::vector<at::Tensor>(local_devices.size(), tensor);
auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Replicate().ToProto(), shape);
return WrapXlaData(ShardingUtil::CreateShardedData(
replicated_data, local_devices, shape,
xla::HloSharding::Replicate().ToProto()));
replicated_data, local_devices, sharding_spec));
}

static const bool transfer_async =
@@ -865,9 +866,10 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
auto shape = CreateComputationShapeFromTensor(tensors[i], &device);
auto replicated_data =
std::vector<at::Tensor>(local_devices.size(), tensors[i]);
auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Replicate().ToProto(), shape);
handles.push_back(ShardingUtil::CreateShardedData(
replicated_data, local_devices, shape,
xla::HloSharding::Replicate().ToProto()));
replicated_data, local_devices, sharding_spec));
}
return WrapXlaData(handles);
}
@@ -936,28 +938,14 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
// global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]).
std::vector<std::string> local_devices =
runtime::GetComputationClient()->GetLocalDevices();
xla::OpSharding sharding;
bool minibatch = false;
if (shardings[i] != nullptr) {
sharding = shardings[i]->sharding;
minibatch = shardings[i]->minibatch;
} else {
// If using SPMD and no sharding is attached to the tensor, implicitly
// replicate to all local devices.
sharding = xla::HloSharding::Replicate().ToProto();
}
// Shards the input tensors with padding, to split evenly.
// The execution requires consistent shard sizes, and the zero-padded
// values should be ignored.
std::vector<at::Tensor> local_shards =
ShardingUtil::ShardTensor(tensors[i], shardings[i], local_devices,
/*padded=*/true);
if (minibatch) { // change global shape as tensor is already sharded
// accross batch dimesion.
shape = shardings[i]->shape;
}
new_handles.push_back(ShardingUtil::CreateShardedData(
local_shards, local_devices, shape, sharding));
local_shards, local_devices, shardings[i]));
} else {
auto populate_fn =
[&, i, device](
17 changes: 9 additions & 8 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -327,9 +327,6 @@ std::vector<runtime::ComputationClient::DataPtr> ShardingUtil::OutputHandler(
XLATensor::ShardingSpecPtr sharding = sharding_specs[i];
if (replicated_output && sharding &&
(sharding->sharding.type() != xla::OpSharding::REPLICATED)) {
// XLA_CHECK(sharding->shape.has_value())
// << "Sharding or Wrapping data shards in OutputHandler requires "
// "unpartitioned tensor shape.";
// Reshards replicated output if `sharding` is present.
std::vector<at::Tensor> tensors = XlaDataToTensors(
{WrapXlaData(sharded_results[0][i])},
@@ -370,7 +367,6 @@ std::vector<int64_t> ShardingUtil::GetShardShape(
return globalShape;
} else if (sharding.type() == xla::OpSharding::OTHER) {
auto tile_shape = sharding.tile_assignment_dimensions();

// `shard_shape[j]` is the size of dimension `j` in the resulting shard.
std::vector<int64_t> shard_shape;
for (int j = 0; j < tile_shape.size(); j++) {
@@ -477,13 +473,16 @@ ShardingUtil::GetShardIndicesForDevices(
std::vector<at::Tensor> ShardingUtil::ShardTensor(
const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings,
const std::vector<std::string>& devices, bool padded) {
auto sharding = shardings->sharding;
xla::OpSharding sharding;
if (shardings != nullptr) {
sharding = shardings->sharding;
}
bool minibatch = shardings->minibatch;
TF_LOG(INFO) << "ShardTensor with sharding type(" << sharding.type()
<< ")... and minibatch = " << minibatch << std::endl;
auto device_index = build_index_map(devices);
std::vector<at::Tensor> shards(devices.size());
if (sharding.type() == xla::OpSharding::REPLICATED) {
if (shardings == nullptr || sharding.type() == xla::OpSharding::REPLICATED) {
std::fill_n(shards.begin(), shards.size(), tensor);
} else if (sharding.type() == xla::OpSharding::OTHER) {
XLA_CHECK(sharding.tile_shape().dimensions_size() <= 2);
@@ -505,7 +504,7 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
c10::ArrayRef<at::indexing::TensorIndex>(shard_indices[i]));
shards[i] = shard.contiguous(at::MemoryFormat::Contiguous);
}

TF_LOG(INFO) << "check shard shape " << shard_shape;
// Zero-pad to the right to ensure the sizes are even
if (shards.size() > 0 && padded) {
for (size_t i = 0; i < shards.size(); ++i) {
@@ -661,10 +660,12 @@ void ShardingUtil::PrepareOutputShardingPropagation(

runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
std::vector<at::Tensor>& local_shards, std::vector<std::string>& devices,
xla::Shape global_shape, xla::OpSharding sharding) {
const XLATensor::ShardingSpecPtr& sharding_spec) {
XLA_CHECK(local_shards.size() == devices.size())
<< "A device must be speficied for each shard";
std::vector<runtime::ComputationClient::TensorSource> source_tensors;
auto global_shape = sharding_spec->shape;
auto sharding = sharding_spec->sharding;
for (int64_t j = 0; j < devices.size(); ++j) {
auto shard_device = ParseDeviceString(devices[j]);
auto shard_shape =
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_sharding_util.h
@@ -147,7 +147,7 @@ class ShardingUtil {
// the PjRtShardedData wrapping the shards.
static runtime::ComputationClient::DataPtr CreateShardedData(
std::vector<at::Tensor>& shards, std::vector<std::string>& devices,
xla::Shape global_shape, xla::OpSharding sharding);
const XLATensor::ShardingSpecPtr& sharding_spec);
};

} // namespace torch_xla
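
For reference, the XLATensor::ShardingSpec fields the new signature relies on, as far as they are visible in this diff: sharding (the xla::OpSharding proto), shape (the global xla::Shape, already scaled along the batch dimension when minibatch is true), and minibatch. A rough sketch of constructing a minibatch spec and handing it to CreateShardedData, mirroring the ShardTensorMiniBatch test above; mesh, global_shape, shards, and devices are assumed to be defined by the caller:

    // Sketch only: the spec carries the sharding proto, the global shape,
    // and the minibatch flag, so CreateShardedData needs no extra arguments.
    auto spec = std::make_shared<XLATensor::ShardingSpec>(
        xla::HloSharding::Tile(mesh).ToProto(), global_shape,
        /*minibatch=*/true);
    auto sharded_data =
        ShardingUtil::CreateShardedData(shards, devices, spec);
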