[SPMD] Fix XLA_DUMP_POST_OPTIMIZATIONS test (#5485)

Summary: XLA_DUMP_POST_OPTIMIZATIONS was read as a static flag, so its value stays fixed for the entire run of a given test suite. Therefore, move this test into a separate file.

Test Plan:
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding.py
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding_hlo.py
1 parent 8e4b543, commit 2bb5ff2
Showing 3 changed files with 39 additions and 11 deletions.
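Why a separate file is needed: a flag that is read once and cached as a static value keeps its first value for the rest of the process, so patching os.environ in a later test of the same suite has no effect. A minimal Python sketch of that failure mode (the names below are hypothetical, for illustration only):

import os
from unittest.mock import patch

# Hypothetical module-level cache, mimicking a flag that is read once as a static value.
_DUMP_POST_OPTIMIZATIONS = None


def dump_post_optimizations():
  # The environment variable is read only on the first call and then cached, so later
  # changes to os.environ in the same process are never observed.
  global _DUMP_POST_OPTIMIZATIONS
  if _DUMP_POST_OPTIMIZATIONS is None:
    _DUMP_POST_OPTIMIZATIONS = os.environ.get('XLA_DUMP_POST_OPTIMIZATIONS', '0') == '1'
  return _DUMP_POST_OPTIMIZATIONS


dump_post_optimizations()  # first read happens with the variable unset, so False is cached

with patch.dict(os.environ, {'XLA_DUMP_POST_OPTIMIZATIONS': '1'}):
  print(dump_post_optimizations())  # still False: the cached value wins

# Putting the test in its own file (and its own process) ensures the patched value is in
# place before the flag is first read.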
@@ -0,0 +1,38 @@
import copy

import unittest
from unittest.mock import patch
import os
import sys

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

import test_xla_sharding_base


class BasicShardingTest(test_xla_sharding_base.XlaShardingTest):

  @classmethod
  def setUpClass(cls):
    xr.use_spmd()
    super().setUpClass()

  @patch.dict(os.environ, {"XLA_DUMP_POST_OPTIMIZATIONS": "1"})
  def test_xla_sharded_hlo_dump_post_optimizations(self):
    t1 = torch.randn(1, 128).to(xm.xla_device())
    t2 = torch.randn(128, 1).to(xm.xla_device())
    xs.mark_sharding(t1, self._get_mesh((1, self.n_devices)), (0, 1))

    t3 = t1 @ t2
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([t3])
    if self.n_devices > 1:
      self.assertIn('all-reduce', hlo)


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
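For intuition on the assertion: t1 has shape (1, 128) and is sharded along its second dimension across the (1, n_devices) mesh, so the contraction dimension of the matmul is split across devices; each device produces a partial (1, 1) product, and the partials are combined with an all-reduce, which is the instruction the test looks for in the dumped HLO. A rough NumPy sketch of that reduction pattern (illustrative only, not the actual SPMD partitioner):

import numpy as np

n_devices = 4
t1 = np.random.randn(1, 128)
t2 = np.random.randn(128, 1)

# Split the contraction dimension across "devices".
t1_shards = np.split(t1, n_devices, axis=1)
t2_shards = np.split(t2, n_devices, axis=0)

# Each device computes a partial (1, 1) product from its slice...
partials = [a @ b for a, b in zip(t1_shards, t2_shards)]

# ...and the partials are summed across devices, which is the role the all-reduce
# plays in the sharded HLO that the test asserts on.
result = sum(partials)
assert np.allclose(result, t1 @ t2)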