Update Python device API for SPMD #5129

Merged 2 commits on Jul 25, 2023
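For context, the sketch below (mine, not part of the PR) contrasts the two device views this change separates under SPMD: the user-facing xla_model APIs collapse to a single virtual device, while the renamed *_runtime_* APIs in torch_xla.runtime keep reporting the physical devices and are what mesh construction should use. It assumes XLA_USE_SPMD=1 with a PJRT backend; the ('data', 'model') axis names are invented for illustration.

import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

# User-facing device APIs see one virtual device under SPMD.
assert len(xm.get_xla_supported_devices()) == 1
assert xm.xrt_world_size() == 1

# The runtime-device APIs keep reporting physical devices, e.g. to build a mesh.
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
mesh = xs.Mesh(device_ids, (num_devices, 1), ('data', 'model'))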
16 changes: 8 additions & 8 deletions test/pjrt/test_runtime_tpu.py
@@ -188,22 +188,22 @@ def test_spawn_threads(self):
{i: torch.device(f'xla:{i}') for i in range(self.num_devices)})

@staticmethod
def _device_attributes():
return xr.device_attributes(str(xm.xla_device()))
def _runtime_device_attributes():
return xr.runtime_device_attributes(str(xm.xla_device()))

def test_device_attributes(self):
result = pjrt.run_multiprocess(self._device_attributes)
def test_runtime_device_attributes(self):
result = pjrt.run_multiprocess(self._runtime_device_attributes)
for device in result.values():
self.assertCountEqual(['coords', 'core_on_chip'], list(device.keys()))
self.assertIsInstance(device['coords'], list)
self.assertIsInstance(device['core_on_chip'], int)

@staticmethod
def _global_device_attributes():
return xr.global_device_attributes()
def _global_runtime_device_attributes():
return xr.global_runtime_device_attributes()

def test_global_device_attributes(self):
results = pjrt.run_multiprocess(self._global_device_attributes)
def test_global_runtime_device_attributes(self):
results = pjrt.run_multiprocess(self._global_runtime_device_attributes)
for result in results.values():
for device in result:
self.assertCountEqual(['coords', 'core_on_chip'], list(device.keys()))
59 changes: 59 additions & 0 deletions test/spmd/test_spmd_xla_model_api.py
@@ -0,0 +1,59 @@
import unittest
import os
import sys

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import test_xla_sharding_base


class BasicXMAPITest(test_xla_sharding_base.XlaShardingTest):

  @classmethod
  def setUpClass(cls):
    os.environ["XLA_USE_SPMD"] = "1"
    super().setUpClass()

  def test_get_xla_supported_devices(self):
    device_type = os.environ['PJRT_DEVICE']
    devices = xm.get_xla_supported_devices(device_type)
    self.assertEqual(len(devices), 1)

  def test_world_size(self):
    self.assertEqual(xm.xrt_world_size(), 1)

  def test_get_ordinal(self):
    self.assertEqual(xm.get_ordinal(), 0)

  def test_get_local_ordinal(self):
    self.assertEqual(xm.get_local_ordinal(), 0)

  def test_is_master_ordinal(self):
    self.assertTrue(xm.is_master_ordinal())

  def test_xla_device(self):
    device = xm.xla_device()
    self.assertEqual(device, torch.device('xla:0'))

  def test_xla_real_devices(self):
    device = xm.xla_device()
    device_type = os.environ['PJRT_DEVICE']
    self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0'])

  def test_xla_device_hw(self):
    device = xm.xla_device()
    device_type = os.environ['PJRT_DEVICE']
    replication_devices = xm.xla_replication_devices([device])
    self.assertEqual(xm.xla_device_hw(device), device_type)

  def test_xla_replication_devices(self):
    device = xm.xla_device()
    device_type = os.environ['PJRT_DEVICE']
    replication_devices = xm.xla_replication_devices([device])
    self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0'])


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
2 changes: 1 addition & 1 deletion test/spmd/test_train_spmd_imagenet.py
@@ -197,7 +197,7 @@ def train_imagenet():

input_mesh = None
if FLAGS.sharding:
num_devices = xr.global_device_count()
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
# Model sharding
if 'conv' in FLAGS.sharding:
3 changes: 2 additions & 1 deletion test/spmd/test_train_spmd_linear_model.py
@@ -4,6 +4,7 @@
from torch import nn
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs
@@ -66,7 +67,7 @@ def train():
torch.manual_seed(42)
model = SimpleLinear().to(device)

num_devices = len(xm.get_xla_supported_devices())
num_devices = xr.global_runtime_device_count()
print(f'num_devices: {num_devices}')
# Define a mesh with all devices along one axis
mesh_shape = (num_devices, 1)
6 changes: 3 additions & 3 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -112,7 +112,7 @@ def test_sharded_to_unsharded(self):

# TODO(jonbolin): Enable tests for resharding into coarser meshes
@unittest.skip("View assignment with virtual device is not yet supported")
@unittest.skipIf(xr.global_device_count() == 1,
@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed to change mesh")
def test_different_device_mesh(self):
dim = self.n_devices // 2
@@ -170,7 +170,7 @@ def test_local_load_plan(self):
# If unsharded, there should be a single ReadItem per model parameter
self.assertEqual(parameter_count, len(plan.items))

@unittest.skipIf(xr.global_device_count() == 1,
@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices required to shard tensors")
def test_resolve_and_commit_sharded_tensor(self):
model = self._get_sharded_model()
@@ -261,7 +261,7 @@ def _write_item_assertions(plan, n_devices, parameter_count):
parameter_count = len(list(model.parameters()))
_write_item_assertions(plan, self.n_devices, parameter_count)

@unittest.skipIf(xr.global_device_count() == 1,
@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices required to shard tensors")
def test_resolve_shard_data(self):
model = self._get_sharded_model()
7 changes: 5 additions & 2 deletions test/spmd/test_xla_sharding.py
@@ -18,6 +18,9 @@
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
import test_xla_sharding_base

import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu


class BasicShardingTest(test_xla_sharding_base.XlaShardingTest):

@@ -649,7 +652,7 @@ def test_2d_tensor_3d_mesh(self):

@unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU v2")
@unittest.skipUnless(
xm.get_xla_supported_devices("TPU"),
xu.getenv_as(xenv.PJRT_DEVICE, str) == "TPU",
f"Requires PJRT_DEVICE set to `TPU`.")
def test_hybrid_mesh_shape(self):
mesh = self._get_mesh((1, self.n_devices))
@@ -659,7 +662,7 @@ def test_hybrid_mesh_shape(self):
hybrid_mesh.get_logical_mesh().shape)

@unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU v2")
@patch('torch_xla.runtime.global_device_attributes')
@patch('torch_xla.runtime.global_runtime_device_attributes')
@patch('torch_xla.core.xla_model.xla_device_hw')
def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock):
# mock device attributes for 2 slices of v4-8
7 changes: 5 additions & 2 deletions test/spmd/test_xla_sharding_base.py
@@ -5,9 +5,12 @@
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr
import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu


@unittest.skipIf(not xr.using_pjrt() or xm.get_xla_supported_devices("GPU"),
@unittest.skipIf(not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) == "GPU",
f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
class XlaShardingTest(unittest.TestCase):

@@ -29,7 +32,7 @@ def forward(self, x):

@classmethod
def setUpClass(cls):
cls.n_devices = len(xm.get_xla_supported_devices())
cls.n_devices = xr.global_runtime_device_count()
cls.device_ids = np.array(range(cls.n_devices))

def _get_mesh(self, mesh_shape, device_ids=None):
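A small sketch (mine, not from the PR) of the environment-variable gating used in the updated skip condition above, which replaces xm.get_xla_supported_devices("GPU") because the supported-devices list only reflects the single virtual device under SPMD:

import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu

# Gate on the configured PJRT device type rather than the virtual device list.
if xu.getenv_as(xenv.PJRT_DEVICE, str) == "TPU":
  print("PJRT_DEVICE is TPU; run the TPU-only path")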
1 change: 1 addition & 0 deletions test/tpu/xla_test_job.yaml
@@ -48,6 +48,7 @@ spec:
python3 /src/pytorch/xla/test/spmd/test_xla_sharding.py
python3 /src/pytorch/xla/test/spmd/test_xla_virtual_device.py
python3 /src/pytorch/xla/test/spmd/test_train_spmd_linear_model.py
python3 /src/pytorch/xla/test/spmd/test_spmd_xla_model_api.py
XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shape_models.py -v
XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shapes.py -v
python3 /src/pytorch/xla/test/test_autocast.py
45 changes: 40 additions & 5 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -836,12 +836,47 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& tensor) { return GetTensorViewAliasId(tensor); });
m.def("_xla_get_tensor_id",
[](const at::Tensor& tensor) { return GetTensorId(tensor); });
m.def("_xla_get_devices",
m.def("_xla_get_devices", []() {
if (UseVirtualDevice()) {
// Under the SPMD context, there is only one virtual device from the
// user's perspective.
std::vector<std::string> all_devices =
runtime::GetComputationClient()->GetAllDevices();
all_devices.resize(1);
return all_devices;
} else {
return runtime::GetComputationClient()->GetLocalDevices();
}
});
m.def("_xla_num_devices", []() -> int64_t {
if (UseVirtualDevice()) {
return 1;
} else {
return runtime::GetComputationClient()->GetNumDevices();
}
});
m.def("_xla_get_all_devices", []() {
std::vector<std::string> all_devices =
runtime::GetComputationClient()->GetAllDevices();
if (UseVirtualDevice()) {
// Under the SPMD context, there is only one virtual device from the
// user's perspective.
std::vector<std::string> devices = {all_devices[0]};
return devices;
} else {
return all_devices;
}
});
m.def("_xla_get_runtime_devices",
[]() { return runtime::GetComputationClient()->GetLocalDevices(); });
m.def("_xla_num_devices",
[]() { return runtime::GetComputationClient()->GetNumDevices(); });
m.def("_xla_get_all_devices",
[]() { return runtime::GetComputationClient()->GetAllDevices(); });
m.def("_xla_num_runtime_devices", []() -> int64_t {
return runtime::GetComputationClient()->GetNumDevices();
});
m.def("_xla_get_all_runtime_devices", []() {
std::vector<std::string> all_devices =
runtime::GetComputationClient()->GetAllDevices();
return all_devices;
});
m.def("_xla_real_devices", [](const std::vector<std::string>& devices) {
std::vector<std::string> xla_devices;
{
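To make the split concrete, here is a hypothetical comparison (device counts invented for a single 4-chip host with XLA_USE_SPMD=1) of what the virtual-device and runtime-device bindings above would report:

import torch_xla

print(torch_xla._XLAC._xla_num_devices())              # 1: the single SPMD virtual device
print(torch_xla._XLAC._xla_num_runtime_devices())      # 4: physical PJRT devices
print(torch_xla._XLAC._xla_get_all_runtime_devices())  # e.g. ['TPU:0', 'TPU:1', 'TPU:2', 'TPU:3']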
5 changes: 3 additions & 2 deletions torch_xla/experimental/pjrt.py
@@ -7,9 +7,7 @@

aliases = [
runtime.addressable_device_count,
runtime.device_attributes,
runtime.device_type,
runtime.global_device_attributes,
runtime.global_device_count,
runtime.global_ordinal,
runtime.local_device_count,
@@ -28,6 +26,9 @@
]

rendezvous = deprecated(this_module, xm.xla_rendezvous)
device_attributes = deprecated(this_module, runtime.runtime_device_attributes)
global_device_attributes = deprecated(this_module,
runtime.global_runtime_device_attributes)

for alias in aliases:
register_deprecated(this_module, alias)
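
For callers still importing from torch_xla.experimental.pjrt, the deprecated wrappers above forward to the renamed runtime functions; a rough migration sketch (mine, not from the PR):

import torch_xla.core.xla_model as xm
import torch_xla.experimental.pjrt as pjrt
import torch_xla.runtime as xr

device = str(xm.xla_device())
old_attrs = pjrt.device_attributes(device)        # deprecated alias, still works
new_attrs = xr.runtime_device_attributes(device)  # preferred replacement
assert old_attrs == new_attrs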
10 changes: 5 additions & 5 deletions torch_xla/experimental/xla_sharding.py
@@ -105,7 +105,7 @@ def __init__(self,
dcn_mesh_shape = tuple([1] * len(ici_mesh_shape))
assert len(ici_mesh_shape) == len(dcn_mesh_shape)
mesh_shape = tuple([x * y for x, y in zip(ici_mesh_shape, dcn_mesh_shape)])
self.device_attributes = xr.global_device_attributes()
self.device_attributes = xr.global_runtime_device_attributes()
if 'slice_index' in self.device_attributes[0] and np.prod(
dcn_mesh_shape) == 1:
raise ValueError('Provide dcn_mesh_shape to create a mesh for multislice')
@@ -248,7 +248,7 @@ def _create_device_mesh(self,
"""

if devices is None:
devices = np.arange(xr.global_device_count())
devices = np.arange(xr.global_runtime_device_count())
if np.prod(mesh_shape) != len(devices):
raise ValueError(
f'Number of devices {len(devices)} must equal the product '
@@ -384,7 +384,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
Examples
—------------------------------
mesh_shape = (4, 2)
num_devices = xr.global_device_count()
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

@@ -396,7 +396,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
linear = nn.Linear(32, 10).to(xm.xla_device())
xs.mark_sharding(linear.weight, mesh, (None, 1))
"""
num_devices = xr.global_device_count()
num_devices = xr.global_runtime_device_count()
assert num_devices > 0, "This requires XLA supported device(s)."
assert mesh.size() == num_devices, \
f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
@@ -480,7 +480,7 @@ def __post_init__(self):
partition_spec, mesh = self.partition_spec, self.mesh
self._tile_assignment = _get_tile_assignment(mesh, partition_spec)
self._sharding_type = _get_sharding_type(partition_spec,
xr.global_device_count())
xr.global_runtime_device_count())
self._group_assignment, self._replication_groups = _get_group_assignment(
self._sharding_type, partition_spec, self._tile_assignment)

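The HybridMesh change above pulls per-device attributes from xr.global_runtime_device_attributes() and requires dcn_mesh_shape when devices report a slice_index (i.e. multislice). A hypothetical sketch, with shapes invented for two slices of four chips each:

import torch_xla.experimental.xla_sharding as xs

hybrid_mesh = xs.HybridMesh(
    ici_mesh_shape=(1, 4),  # device layout within a single slice (ICI-connected)
    dcn_mesh_shape=(2, 1))  # slice layout across the data-center network
print(hybrid_mesh.get_logical_mesh().shape)  # expected: (2, 4)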
27 changes: 20 additions & 7 deletions torch_xla/runtime.py
@@ -175,19 +175,32 @@ def process_count() -> int:


@requires_pjrt
def device_attributes(device: str) -> Dict[str, object]:
def host_index() -> int:
if device_type() == 'TPU':
return tpu.worker_id()

# TODO: Update this when we support multi-host GPU
return 0


# The APIs below query physical device attributes.
@requires_pjrt
def runtime_device_attributes(device: str) -> Dict[str, object]:
return torch_xla._XLAC._xla_get_device_attributes(device)


@requires_pjrt
def global_device_attributes() -> List[Dict[str, object]]:
def global_runtime_device_attributes() -> List[Dict[str, object]]:
return torch_xla._XLAC._xla_get_all_device_attributes()


@requires_pjrt
def host_index() -> int:
if device_type() == 'TPU':
return tpu.worker_id()
def global_runtime_device_count() -> int:
"""Returns the total number of runtime devices across all processes/hosts."""
return len(torch_xla._XLAC._xla_get_all_runtime_devices())

# TODO: Update this when we support multi-host GPU
return 0

@requires_pjrt
def addressable_runtime_device_count() -> int:
"""Returns the number of devices visible to this process."""
return torch_xla._XLAC._xla_num_runtime_devices()
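
A short usage sketch of the renamed physical-device queries defined above; the 'coords' and 'core_on_chip' keys mirror what test/pjrt/test_runtime_tpu.py asserts on TPU, and the example values are invented:

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

attrs = xr.runtime_device_attributes(str(xm.xla_device()))
# e.g. {'coords': [0, 0, 0], 'core_on_chip': 0} on a TPU host

all_attrs = xr.global_runtime_device_attributes()  # one dict per physical device
print(xr.global_runtime_device_count())            # runtime devices across all hosts
print(xr.addressable_runtime_device_count())       # runtime devices visible to this process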