diff --git a/test/run_tests.sh b/test/run_tests.sh
index ddcf89ea60e9..490352eb0c0b 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -201,6 +201,7 @@ function run_xla_op_tests {
   run_test "$CDIR/pjrt/test_ddp.py"
   run_test "$CDIR/pjrt/test_mesh_service.py"
   run_test "$CDIR/spmd/test_xla_sharding.py"
+  run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
   run_test "$CDIR/spmd/test_xla_virtual_device.py"
   run_test "$CDIR/spmd/test_dynamo_spmd.py"
   run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 1ded8445e5ff..8fc907cce540 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -682,17 +682,6 @@ def test_mark_sharding_ir(self):
 
     self.assertTrue(torch.allclose(expected, actual.cpu()))
 
-  @patch.dict(os.environ, {"XLA_DUMP_POST_OPTIMIZATIONS": "1"})
-  def test_xla_sharded_hlo_dump_post_optimizations(self):
-    t1 = torch.randn(1, 128).to(xm.xla_device())
-    t2 = torch.randn(128, 1).to(xm.xla_device())
-    xs.mark_sharding(t1, self._get_mesh((1, self.n_devices)), (0, 1))
-
-    t3 = t1 @ t2
-    hlo = torch_xla._XLAC._get_xla_tensors_hlo([t3])
-    if self.n_devices > 1:
-      self.assertIn('all-reduce', hlo)
-
   def test_sharded_tensor_aliasing(self):
     met.clear_all()
     partition_spec = (0, 1)
diff --git a/test/spmd/test_xla_sharding_hlo.py b/test/spmd/test_xla_sharding_hlo.py
new file mode 100644
index 000000000000..3a39a9062614
--- /dev/null
+++ b/test/spmd/test_xla_sharding_hlo.py
@@ -0,0 +1,38 @@
+import copy
+
+import unittest
+from unittest.mock import patch
+import os
+import sys
+
+import torch
+import torch_xla
+import torch_xla.runtime as xr
+import torch_xla.core.xla_model as xm
+import torch_xla.experimental.xla_sharding as xs
+
+import test_xla_sharding_base
+
+
+class BasicShardingTest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    xr.use_spmd()
+    super().setUpClass()
+
+  @patch.dict(os.environ, {"XLA_DUMP_POST_OPTIMIZATIONS": "1"})
+  def test_xla_sharded_hlo_dump_post_optimizations(self):
+    t1 = torch.randn(1, 128).to(xm.xla_device())
+    t2 = torch.randn(128, 1).to(xm.xla_device())
+    xs.mark_sharding(t1, self._get_mesh((1, self.n_devices)), (0, 1))
+
+    t3 = t1 @ t2
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo([t3])
+    if self.n_devices > 1:
+      self.assertIn('all-reduce', hlo)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)