[SPMD] Fix XLA_DUMP_POST_OPTIMIZATIONS test (#5485)
Summary:
XLA_DUMP_POST_OPTIMIZATIONS is read into a static value, which means it stays fixed for the whole test run of a particular test suite, so patching the environment from inside an individual test has no effect once the flag has been read.

Therefore, let's move this test into its own file.
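
As a minimal illustration of the pitfall (purely a sketch; the hypothetical dump_post_optimizations_enabled helper below is not the actual XLA flag-reading code), a value cached on first read ignores later os.environ patches made in the same process:

import os
from unittest.mock import patch

_CACHED_FLAG = None  # hypothetical module-level ("static") cache


def dump_post_optimizations_enabled():
  # Reads the env var only once; later changes to os.environ are ignored.
  global _CACHED_FLAG
  if _CACHED_FLAG is None:
    _CACHED_FLAG = os.environ.get("XLA_DUMP_POST_OPTIMIZATIONS", "0") == "1"
  return _CACHED_FLAG


dump_post_optimizations_enabled()  # first call caches False (flag unset)

with patch.dict(os.environ, {"XLA_DUMP_POST_OPTIMIZATIONS": "1"}):
  print(dump_post_optimizations_enabled())  # still False: the cached value wins

Running the patched test in a fresh process (its own test file) is what allows the flag to take effect.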

Test Plan:
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding.py
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding_hlo.py
alanwaketan authored and will-cromar committed Sep 14, 2023
1 parent 8e4b543 commit 2bb5ff2
Showing 3 changed files with 39 additions and 11 deletions.
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -201,6 +201,7 @@ function run_xla_op_tests {
run_test "$CDIR/pjrt/test_ddp.py"
run_test "$CDIR/pjrt/test_mesh_service.py"
run_test "$CDIR/spmd/test_xla_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
run_test "$CDIR/spmd/test_xla_virtual_device.py"
run_test "$CDIR/spmd/test_dynamo_spmd.py"
run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
11 changes: 0 additions & 11 deletions test/spmd/test_xla_sharding.py
@@ -682,17 +682,6 @@ def test_mark_sharding_ir(self):

self.assertTrue(torch.allclose(expected, actual.cpu()))

@patch.dict(os.environ, {"XLA_DUMP_POST_OPTIMIZATIONS": "1"})
def test_xla_sharded_hlo_dump_post_optimizations(self):
t1 = torch.randn(1, 128).to(xm.xla_device())
t2 = torch.randn(128, 1).to(xm.xla_device())
xs.mark_sharding(t1, self._get_mesh((1, self.n_devices)), (0, 1))

t3 = t1 @ t2
hlo = torch_xla._XLAC._get_xla_tensors_hlo([t3])
if self.n_devices > 1:
self.assertIn('all-reduce', hlo)

def test_sharded_tensor_aliasing(self):
met.clear_all()
partition_spec = (0, 1)
38 changes: 38 additions & 0 deletions test/spmd/test_xla_sharding_hlo.py
@@ -0,0 +1,38 @@
import copy

import unittest
from unittest.mock import patch
import os
import sys

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

import test_xla_sharding_base


class BasicShardingTest(test_xla_sharding_base.XlaShardingTest):

  @classmethod
  def setUpClass(cls):
    xr.use_spmd()
    super().setUpClass()

  @patch.dict(os.environ, {"XLA_DUMP_POST_OPTIMIZATIONS": "1"})
  def test_xla_sharded_hlo_dump_post_optimizations(self):
    t1 = torch.randn(1, 128).to(xm.xla_device())
    t2 = torch.randn(128, 1).to(xm.xla_device())
    xs.mark_sharding(t1, self._get_mesh((1, self.n_devices)), (0, 1))

    t3 = t1 @ t2
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([t3])
    if self.n_devices > 1:
      self.assertIn('all-reduce', hlo)


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
