Skip to content

Commit

Permalink
Sharding should be per output of IR Node, instead of per IR Node (#5330)
Browse files Browse the repository at this point in the history
* Sharding should be per output of an IR node, instead of per IR node

* Update sharding_hash method

* Add test for sharding on IR with multiple output

* fix cpu test

* Fix a bug in getSharding
  • Loading branch information
JackCaoG authored Jul 24, 2023
1 parent f5edcb2 commit 901d154
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 56 deletions.
15 changes: 15 additions & 0 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,21 @@ def test_sharded_tensor_aliasing(self):
xm.mark_step()
self.assertEqual(met.metric_data("InputOutputAliasCount")[0], 1)

def test_mark_sharding_ir_with_multiple_output(self):
  """Sharding one output of a multi-output IR node must not leak onto its sibling outputs."""
  partition_spec = (0,)
  xt1 = torch.randn(8, 8).to(xm.xla_device())
  # torch.max returns 2 tensors, `values` and `indices`. They are outputs
  # of the same IR node `MaxInDim`.
  (xt_val, xt_index) = torch.max(xt1, 1)
  # NOTE: the mesh shape must be a 1-tuple. `(self.n_devices)` is just an
  # int in parentheses, so the trailing comma is required.
  xst_val = xs.mark_sharding(xt_val, self._get_mesh((self.n_devices,)),
                             partition_spec)
  # `xt_val` should have a sharding spec now, but `xt_index` should not.
  self.assertNotEqual(torch_xla._XLAC._get_xla_sharding_spec(xt_val), '')
  self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xt_index), '')
  # xt_index's HLO should not carry any sharding annotation.
  self.assertNotIn('convert(s32[8]{0} %get-tuple-element.25), sharding',
                   torch_xla._XLAC._get_xla_tensors_hlo([xt_index]))


if __name__ == '__main__':
test = unittest.main()
Expand Down
83 changes: 47 additions & 36 deletions torch_xla/csrc/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,14 @@ torch::lazy::hash_t XlaNode::GetOpHash(torch::lazy::OpKind op,
return torch::lazy::HashCombine(h, hash_seed);
}

void XlaNode::SetSharding(const xla::OpSharding& sharding) {
output_sharding_ = std::make_shared<xla::OpSharding>(sharding);
sharding_hash_ = CreateShardingHash(output_sharding_, node_hash_);
// Attaches `sharding` to output `index` of this node and refreshes the
// node's sharding hash.
void XlaNode::SetSharding(const xla::OpSharding& sharding, size_t index) {
  // Lazily size the per-output vector so nodes that are never sharded pay
  // no cost; one (initially null) slot per IR output.
  if (output_shardings_.empty()) {
    output_shardings_.resize(num_outputs());
  }
  output_shardings_[index] = std::make_shared<xla::OpSharding>(sharding);
  // TODO(JackCaoG): fix this hashing
  UpdateShardingHash();
}

xla::Shape XlaNode::GetOpShape(
Expand All @@ -179,40 +184,46 @@ const xla::Shape& GetXlaShape(const torch::lazy::Value& value) {

// The sharding hash is only based on relevant fields from the xla::OpSharding
// object. We skip the field that's irrelevant, which is the layout.
torch::lazy::hash_t XlaNode::CreateShardingHash(
std::shared_ptr<xla::OpSharding> sharding, torch::lazy::hash_t hash_seed) {
torch::lazy::hash_t sharding_hash = hash_seed;
for (const auto& tile_assignment_dimension :
sharding->tile_assignment_dimensions()) {
sharding_hash = torch::lazy::HashCombine(
sharding_hash, (uint32_t)tile_assignment_dimension);
}
for (const auto& tile_assignment_device :
sharding->tile_assignment_devices()) {
sharding_hash = torch::lazy::HashCombine(sharding_hash,
(uint32_t)tile_assignment_device);
}
for (const auto& last_tile_dim : sharding->last_tile_dims()) {
sharding_hash =
torch::lazy::HashCombine(sharding_hash, (uint32_t)last_tile_dim);
}
sharding_hash =
torch::lazy::HashCombine(sharding_hash, (uint32_t)sharding->type());
sharding_hash = torch::lazy::HashCombine(
sharding_hash, (uint32_t)sharding->replicate_on_last_tile_dim());

xla::ShapeProto shape_proto = sharding->tile_shape();
sharding_hash = torch::lazy::HashCombine(
sharding_hash, (uint32_t)shape_proto.element_type());
for (const auto& dim : shape_proto.dimensions()) {
sharding_hash = torch::lazy::HashCombine(sharding_hash, (uint32_t)dim);
}
for (const auto& is_dyn_dim : shape_proto.is_dynamic_dimension()) {
sharding_hash =
torch::lazy::HashCombine(sharding_hash, (uint32_t)is_dyn_dim);
// Recomputes sharding_hash_ from scratch over every per-output sharding.
// Only the semantically relevant fields of xla::OpSharding feed the hash;
// the layout field is intentionally skipped.
void XlaNode::UpdateShardingHash() {
  sharding_hash_ = node_hash_;
  for (size_t i = 0; i < output_shardings_.size(); i++) {
    // Keep the output index as part of the hash so identical shardings on
    // different outputs hash differently.
    sharding_hash_ = torch::lazy::HashCombine(sharding_hash_, (uint32_t)i);
    std::shared_ptr<xla::OpSharding> sharding = output_shardings_[i];
    // Skip the hash compute for outputs with no sharding annotation.
    if (!sharding) {
      continue;
    }
    for (const auto& tile_assignment_dimension :
         sharding->tile_assignment_dimensions()) {
      sharding_hash_ = torch::lazy::HashCombine(
          sharding_hash_, (uint32_t)tile_assignment_dimension);
    }
    for (const auto& tile_assignment_device :
         sharding->tile_assignment_devices()) {
      sharding_hash_ = torch::lazy::HashCombine(
          sharding_hash_, (uint32_t)tile_assignment_device);
    }
    for (const auto& last_tile_dim : sharding->last_tile_dims()) {
      sharding_hash_ =
          torch::lazy::HashCombine(sharding_hash_, (uint32_t)last_tile_dim);
    }
    sharding_hash_ =
        torch::lazy::HashCombine(sharding_hash_, (uint32_t)sharding->type());
    sharding_hash_ = torch::lazy::HashCombine(
        sharding_hash_, (uint32_t)sharding->replicate_on_last_tile_dim());

    // Fold in the tile shape (element type + dimensions), again skipping
    // layout information.
    xla::ShapeProto shape_proto = sharding->tile_shape();
    sharding_hash_ = torch::lazy::HashCombine(
        sharding_hash_, (uint32_t)shape_proto.element_type());
    for (const auto& dim : shape_proto.dimensions()) {
      sharding_hash_ = torch::lazy::HashCombine(sharding_hash_, (uint32_t)dim);
    }
    for (const auto& is_dyn_dim : shape_proto.is_dynamic_dimension()) {
      sharding_hash_ =
          torch::lazy::HashCombine(sharding_hash_, (uint32_t)is_dyn_dim);
    }
  }
}

} // namespace torch_xla
20 changes: 10 additions & 10 deletions torch_xla/csrc/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,17 @@ class XlaNode : public torch::lazy::Node {
// Combined hash of the node hash and all per-output sharding annotations;
// 0 until a sharding is set, and reset to 0 by ClearSharding().
torch::lazy::hash_t shardingHash() const { return sharding_hash_; }

// The node's outputs get assigned the same HLO sharding
// TODO: test multi-output example.
const std::shared_ptr<xla::OpSharding> GetSharding() const {
return output_sharding_;
const std::shared_ptr<xla::OpSharding> GetSharding(size_t index) const {
if (output_shardings_.size() == 0) {
return nullptr;
}
return output_shardings_[index];
}

void SetSharding(const xla::OpSharding& sharding);
void SetSharding(const xla::OpSharding& sharding, size_t index);

// Drops every per-output sharding annotation and resets the sharding hash
// to its unsharded value.
void ClearSharding() {
  output_shardings_.clear();
  sharding_hash_ = 0;
}

Expand All @@ -145,17 +147,15 @@ class XlaNode : public torch::lazy::Node {

static std::vector<torch::lazy::SourceLocation> GetFrameInfo();

static torch::lazy::hash_t CreateShardingHash(
std::shared_ptr<xla::OpSharding> sharding, torch::lazy::hash_t hash_seed);
void UpdateShardingHash();

xla::Shape xla_shape_;
torch::lazy::hash_t node_hash_ = 0;
torch::lazy::hash_t dag_hash_;
torch::lazy::hash_t sharding_hash_ = 0;

// Experimental sharding annotation attached to the IR node.
// TODO(yeounoh): make sure that view update doesn't reset this.
std::shared_ptr<xla::OpSharding> output_sharding_ = nullptr;
// Experimental sharding annotations attached to the IR node.
std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;
};

inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/ops/device_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ DeviceData::DeviceData(std::shared_ptr<torch::lazy::BackendData> data)
torch_xla::runtime::GetComputationClient()->GetDataSharding(
UnwrapXlaData(data_));
if (op_sharding.has_value()) {
SetSharding(op_sharding.value());
// DeviceData Node only has 1 output.
SetSharding(op_sharding.value(), 0);
}
}

Expand Down
13 changes: 7 additions & 6 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ XLATensor::XLATensor(torch::lazy::Value ir_value,
// Preserve sharding if a new tensor is created from a sharded IR node.
if (CurrentIrValue()) {
auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
if (xla_node->GetSharding()) {
if (xla_node->GetSharding(CurrentIrValue().index)) {
ShardingSpec sharding =
ShardingSpec{*xla_node->GetSharding(), xla_node->xla_shape()};
ShardingSpec{*xla_node->GetSharding(CurrentIrValue().index),
xla_node->xla_shape()};
SetShardingSpec(sharding);
}
}
Expand Down Expand Up @@ -239,7 +240,7 @@ void XLATensor::SetShardingSpec(const ShardingSpec& sharding) {
<< sharding.sharding.DebugString();
}
dynamic_cast<XlaNode*>(GetIrValue().node.get())
->SetSharding(sharding_spec()->sharding);
->SetSharding(sharding_spec()->sharding, GetIrValue().index);
}
void XLATensor::ClearShardingSpec() {
data()->sharding = nullptr;
Expand All @@ -256,10 +257,10 @@ XLATensor::ShardingSpecPtr XLATensor::sharding_spec() const {
if (sharding && ir_value) {
// The copy of sharding annotation on the IR node should be the same.
auto* xla_node = dynamic_cast<XlaNode*>(ir_value.node.get());
if (xla_node->GetSharding()) {
if (xla_node->GetSharding(ir_value.index)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(
*sharding,
ShardingSpec{*xla_node->GetSharding(), xla_node->xla_shape()}));
*sharding, ShardingSpec{*xla_node->GetSharding(ir_value.index),
xla_node->xla_shape()}));
}
}
return sharding;
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
XLATensor::ShardingSpecPtr sharding = tensors[i]->sharding_spec();
if (sharding) {
dynamic_cast<XlaNode*>(ir_value.node.get())
->SetSharding(sharding->sharding);
->SetSharding(sharding->sharding, ir_value.index);
}
}
} else if (config.force_ltc_data) {
Expand Down
5 changes: 3 additions & 2 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) {
const torch::lazy::Node* node = elem.first.node;
const XlaNode* xla_node = dynamic_cast<const XlaNode*>(node);
auto instruction = XlaBuilderFriend::GetInstruction(elem.second);
if (xla_node->GetSharding() != nullptr) {
*instruction->mutable_sharding() = *xla_node->GetSharding();
if (xla_node->GetSharding(elem.first.index) != nullptr) {
*instruction->mutable_sharding() =
*xla_node->GetSharding(elem.first.index);
is_sharded = true;
}
}
Expand Down

0 comments on commit 901d154

Please sign in to comment.