mark_sharding over a replicated tensor is allowed. #5513

Merged 1 commit on Aug 28, 2023
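What this change means in user code, as a minimal sketch: a tensor annotated only as replicated can now be re-annotated with a tiled sharding directly, without an intervening xs.clear_sharding call. The mark_sharding, clear_sharding, and _get_xla_sharding_spec calls mirror the new test below; the mesh construction (numpy device ids, xr.global_runtime_device_count) is an assumption based on common torch_xla SPMD usage at the time and is not taken from this diff.

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

# Assumed mesh setup (not part of this PR): a single row of all devices.
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (1, num_devices))

t = torch.randn(2, 4).to(xm.xla_device())

# Mark as replicated: a partition_spec of all None replicates the tensor.
xs.mark_sharding(t, mesh, (None, None))

# With this PR the tiled annotation below may overwrite the replicated one
# directly; previously it tripped "Existing annotation must be cleared first."
xs.mark_sharding(t, mesh, (0, 1))

# With more than one device, the spec string no longer reads "replicated".
print(torch_xla._XLAC._get_xla_sharding_spec(t))
```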
10 changes: 10 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -442,6 +442,16 @@ def test_clear_sharding(self):
     xs.clear_sharding(xt)
     self.assertFalse(torch_xla._XLAC._get_xla_sharding_spec(xt))
 
+  def test_replication_with_no_clear_sharding(self):
+    xt = torch.randn(2, 4).to(xm.xla_device())
+    # replication
+    xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (None, None))
+    # sharding annotation over an existing replication sharding is permitted.
+    xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1))
+    if self.n_devices > 1:
+      self.assertFalse(
+          "replicated" in torch_xla._XLAC._get_xla_sharding_spec(xt))
+
   def test_deep_copy(self):
     xt = torch.randn(2, 4, 8, 16).to(xm.xla_device())
     xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
11 changes: 6 additions & 5 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1432,12 +1432,13 @@ void InitXlaModuleBindings(py::module m) {
           cpu_tensor = xtensor->CurrentTensorData().value();
         } else {
           // A new input tensor is not expected to be sharded. But sometimes,
-          // the same input is used sharding annotation, in which case we can
-          // skip if it's the same sharding; however, if it's the same input
-          // with a different sharding then we block & ask the user to clear
-          // the existing sharding first.
+          // the same input is called for sharding annotation over multiple steps,
+          // in which case we can skip if it's the same sharding; however, if it's
+          // the same input with a different sharding then we block & ask the user
+          // to clear the existing sharding first.
           auto current_sharding_spec = xtensor->sharding_spec();
-          if (current_sharding_spec) {
+          if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
+                                        xla::OpSharding::REPLICATED)) {
             XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
                                                        *current_sharding_spec))
                 << "Existing annotation must be cleared first.";

Review thread on the new REPLICATED check:

Collaborator:
More of a general comment, but what do you think about adding a TraceMe for the profiler after this check (at L1448)? That way it's easy to tell if a TransferFromServer call is due to resharding.

@yeounoh (Contributor, Author), Aug 28, 2023:
Good idea :) (for traceability). One problem is that we won't see this trace here, since the resharding in this case is a no-op (skipped). We would probably want an if-statement that emits the TraceMe only when the tensor is already marked replicated. For that level of specificity, maybe punt on it and rely on TransferFromServer traces for now?
3 changes: 2 additions & 1 deletion torch_xla/csrc/tensor.cpp
@@ -237,7 +237,8 @@ void XLATensor::SetShardingSpec(const ShardingSpec& sharding) {
   // Existing annotation must be cleared explicitly. We do not clear and
   // overwrite the existing sharding on the user's behalf. This is a no-op if
   // the same sharding already applied.
-  if (!sharding_spec()) {
+  if (!sharding_spec() ||
+      (sharding_spec()->sharding.type() == xla::OpSharding::REPLICATED)) {
     TORCH_LAZY_COUNTER("SetShardingSpec", 1);
     data()->sharding = std::make_shared<ShardingSpec>(sharding);
   } else {
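For completeness, the guard that both C++ hunks leave in place still applies when the existing annotation is not replicated: overwriting a tiled sharding with a different spec continues to fail with "Existing annotation must be cleared first.", so the explicit clear-then-annotate flow remains. A hedged sketch under the same assumed mesh setup as the earlier example:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

# Same assumed mesh setup as the earlier sketch (not taken from this diff).
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (1, num_devices))

t = torch.randn(2, 4).to(xm.xla_device())
xs.mark_sharding(t, mesh, (0, 1))  # tiled annotation

# Re-annotating with a *different* non-replicated spec is still rejected with
# "Existing annotation must be cleared first.", so clear explicitly instead:
xs.clear_sharding(t)
xs.mark_sharding(t, mesh, (None, None))  # e.g. back to replication
```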