diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 0d07ce75e53b..2aad3be523b4 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -442,6 +442,16 @@ def test_clear_sharding(self):
     xs.clear_sharding(xt)
     self.assertFalse(torch_xla._XLAC._get_xla_sharding_spec(xt))
 
+  def test_replication_with_no_clear_sharding(self):
+    xt = torch.randn(2, 4).to(xm.xla_device())
+    # replication
+    xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (None, None))
+    # sharding annotation over an existing replication sharding is permitted.
+    xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1))
+    if self.n_devices > 1:
+      self.assertFalse(
+          "replicated" in torch_xla._XLAC._get_xla_sharding_spec(xt))
+
   def test_deep_copy(self):
     xt = torch.randn(2, 4, 8, 16).to(xm.xla_device())
     xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index c19645e43529..5b39cefcd798 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1432,12 +1432,13 @@ void InitXlaModuleBindings(py::module m) {
       cpu_tensor = xtensor->CurrentTensorData().value();
     } else {
       // A new input tensor is not expected to be sharded. But sometimes,
-      // the same input is used sharding annotation, in which case we can
-      // skip if it's the same sharding; however, if it's the same input
-      // with a different sharding then we block & ask the user to clear
-      // the existing sharding first.
+      // the same input is called for sharding annotation over multiple steps,
+      // in which case we can skip if it's the same sharding; however, if it's
+      // the same input with a different sharding then we block & ask the user
+      // to clear the existing sharding first.
       auto current_sharding_spec = xtensor->sharding_spec();
-      if (current_sharding_spec) {
+      if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
+                                    xla::OpSharding::REPLICATED)) {
         XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
                                                    *current_sharding_spec))
             << "Existing annotation must be cleared first.";
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 2ccdd9fca829..da7a5e20a635 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -237,7 +237,8 @@ void XLATensor::SetShardingSpec(const ShardingSpec& sharding) {
   // Existing annotation must be cleared explicitly. We do not clear and
   // overwrite the existing sharding on the user's behalf. This is a no-op if
   // the same sharding already applied.
-  if (!sharding_spec()) {
+  if (!sharding_spec() ||
+      (sharding_spec()->sharding.type() == xla::OpSharding::REPLICATED)) {
     TORCH_LAZY_COUNTER("SetShardingSpec", 1);
     data()->sharding = std::make_shared<ShardingSpec>(sharding);
   } else {
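
Usage note (not part of the patch): a minimal sketch of the flow this change permits, based on the new test above. The device count and mesh construction below are illustrative assumptions, not taken from the patch, and import paths follow the experimental SPMD module layout of this era, which may differ across releases. With this change, a replicated annotation can be overwritten by a real sharding without calling xs.clear_sharding first; any other existing sharding still fails with "Existing annotation must be cleared first."

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.xla_sharding as xs

    # Illustrative device count; use however many devices are attached.
    n = 8
    mesh = xs.Mesh(np.arange(n), (1, n))

    t = torch.randn(2, 4).to(xm.xla_device())
    xs.mark_sharding(t, mesh, (None, None))  # annotate as replicated
    xs.mark_sharding(t, mesh, (0, 1))        # now allowed: overrides the replication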