diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index 9b588dfa05bc..56223121158d 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -195,22 +195,22 @@ def test_spawn_threads(self): {i: torch.device(f'xla:{i}') for i in range(self.num_devices)}) @staticmethod - def _device_attributes(): - return xr.device_attributes(str(xm.xla_device())) + def _physical_device_attributes(): + return xr.physical_device_attributes(str(xm.xla_device())) - def test_device_attributes(self): - result = pjrt.run_multiprocess(self._device_attributes) + def test_physical_device_attributes(self): + result = pjrt.run_multiprocess(self._physical_device_attributes) for device in result.values(): self.assertCountEqual(['coords', 'core_on_chip'], list(device.keys())) self.assertIsInstance(device['coords'], list) self.assertIsInstance(device['core_on_chip'], int) @staticmethod - def _global_device_attributes(): - return xr.global_device_attributes() + def _global_physical_device_attributes(): + return xr.global_physical_device_attributes() - def test_global_device_attributes(self): - results = pjrt.run_multiprocess(self._global_device_attributes) + def test_global_physical_device_attributes(self): + results = pjrt.run_multiprocess(self._global_physical_device_attributes) for result in results.values(): for device in result: self.assertCountEqual(['coords', 'core_on_chip'], list(device.keys())) diff --git a/torch_xla/experimental/pjrt.py b/torch_xla/experimental/pjrt.py index d57cccb791b5..4b05c199794a 100644 --- a/torch_xla/experimental/pjrt.py +++ b/torch_xla/experimental/pjrt.py @@ -7,9 +7,9 @@ aliases = [ runtime.addressable_device_count, - runtime.device_attributes, + runtime.physical_device_attributes, runtime.device_type, - runtime.global_device_attributes, + runtime.global_physical_device_attributes, runtime.global_device_count, runtime.global_ordinal, runtime.local_device_count,