Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG committed Jun 23, 2023
1 parent 0da54cc commit cfe3ae7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
16 changes: 8 additions & 8 deletions test/pjrt/test_runtime_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,22 +195,22 @@ def test_spawn_threads(self):
{i: torch.device(f'xla:{i}') for i in range(self.num_devices)})

@staticmethod
def _device_attributes():
return xr.device_attributes(str(xm.xla_device()))
def _physical_device_attributes():
return xr.physical_device_attributes(str(xm.xla_device()))

def test_device_attributes(self):
result = pjrt.run_multiprocess(self._device_attributes)
def test_physical_device_attributes(self):
result = pjrt.run_multiprocess(self._physical_device_attributes)
for device in result.values():
self.assertCountEqual(['coords', 'core_on_chip'], list(device.keys()))
self.assertIsInstance(device['coords'], list)
self.assertIsInstance(device['core_on_chip'], int)

@staticmethod
def _global_device_attributes():
return xr.global_device_attributes()
def _global_physical_device_attributes():
return xr.global_physical_device_attributes()

def test_global_device_attributes(self):
results = pjrt.run_multiprocess(self._global_device_attributes)
def test_global_physical_device_attributes(self):
results = pjrt.run_multiprocess(self._global_physical_device_attributes)
for result in results.values():
for device in result:
self.assertCountEqual(['coords', 'core_on_chip'], list(device.keys()))
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/experimental/pjrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

aliases = [
runtime.addressable_device_count,
runtime.device_attributes,
runtime.physical_device_attributes,
runtime.device_type,
runtime.global_device_attributes,
runtime.global_physical_device_attributes,
runtime.global_device_count,
runtime.global_ordinal,
runtime.local_device_count,
Expand Down

0 comments on commit cfe3ae7

Please sign in to comment.