test: placement tests and shared xla_device simulation for tests
walln committed Oct 2, 2024
1 parent 033459a commit e62e309
Showing 3 changed files with 213 additions and 70 deletions.
52 changes: 52 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,52 @@
# conftest.py

import importlib
import os

import pytest


@pytest.fixture(
    scope="session",
    params=[2, 4],
)
def xla_device_count(request):
    """Simulate a multi-device host by forcing the XLA CPU device count.

    Despite the name, this yields the simulated device list for each
    parametrized device count.
    """
    device_count = request.param
    previous_flags = os.environ.get("XLA_FLAGS")
    os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={device_count}"

    # The flag only takes effect if it is set before JAX initializes its
    # backend, so this fixture must run before any JAX computation. The
    # imports below just verify that JAX is available.
    try:
        importlib.import_module("jax")
        importlib.import_module("jaxlib")
    except ImportError as err:
        raise ImportError("JAX and jaxlib must be installed to run tests.") from err

    import jax

    devices = jax.devices()
    yield devices

    # Restore whatever XLA_FLAGS value was set before the fixture ran
    if previous_flags is None:
        os.environ.pop("XLA_FLAGS", None)
    else:
        os.environ["XLA_FLAGS"] = previous_flags


# Indirectly parametrized variant: tests select the device count via
# @pytest.mark.parametrize("simulated_xla_devices", [...], indirect=True).
@pytest.fixture(scope="session")
def simulated_xla_devices(request):
    device_count = request.param
    previous_flags = os.environ.get("XLA_FLAGS")
    os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={device_count}"

    # As above, the flag only applies if set before JAX initializes its
    # backend; the imports just verify that JAX is available.
    try:
        importlib.import_module("jax")
        importlib.import_module("jaxlib")
    except ImportError as err:
        raise ImportError("JAX and jaxlib must be installed to run tests.") from err

    import jax

    devices = jax.devices()
    yield devices

    # Restore whatever XLA_FLAGS value was set before the fixture ran
    if previous_flags is None:
        os.environ.pop("XLA_FLAGS", None)
    else:
        os.environ["XLA_FLAGS"] = previous_flags
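
Usage note: a test opts into the simulated devices by requesting simulated_xla_devices and selecting a device count through indirect parametrization, as the new test files below do. A minimal illustrative sketch (this test is not part of the commit):

import pytest


@pytest.mark.parametrize("simulated_xla_devices", [4], indirect=True)
def test_sees_four_devices(simulated_xla_devices):
    # The fixture yields the simulated jax.Device list.
    assert len(simulated_xla_devices) == 4
    assert all(d.platform == "cpu" for d in simulated_xla_devices)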
90 changes: 20 additions & 70 deletions tests/sharding/test_mesh_shape.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
-import os
-import re
-
import jax
import pytest

@@ -40,37 +37,28 @@ def test_mesh_config_initialization(mesh_shape, mesh_axis_names, batch_axis_name
assert config.batch_axis_names == batch_axis_names


-@pytest.fixture(params=[4, 8])
-def set_xla_flags(request):
-    device_count = request.param
-    os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={device_count}"
-    yield device_count
-    del os.environ["XLA_FLAGS"]
-
-
-def test_mesh_config_create_device_mesh():
-    device_count = 4
-    os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={device_count}"
-    try:
-        mesh_shape = (1, 4)
-        config = MeshConfig(mesh_shape=mesh_shape, mesh_axis_names=("data", "model"))
-        mesh = config.create_device_mesh()
-
-        assert isinstance(mesh, jax.sharding.Mesh)
-        assert dict(zip(config.mesh_axis_names, mesh_shape, strict=False)) == dict(
-            mesh.shape
-        )
-        assert mesh.axis_names == ("data", "model")
-        assert len(mesh.devices.flatten()) == device_count
-
-        # Check if the mesh shape is correct for the given device count
-        if isinstance(config.mesh_shape, HybridMeshShape):
-            assert (
-                config.mesh_shape.ici_mesh_shape[0]
-                * config.mesh_shape.ici_mesh_shape[1]
-                == device_count
-            )
-        else:
-            assert config.mesh_shape[0] * config.mesh_shape[1] == device_count
-    finally:
-        del os.environ["XLA_FLAGS"]
+@pytest.mark.parametrize("simulated_xla_devices", [4], indirect=True)
+def test_mesh_config_create_device_mesh(simulated_xla_devices):
+    device_count = len(simulated_xla_devices)
+    mesh_shape = (1, device_count)
+    config = MeshConfig(mesh_shape=mesh_shape, mesh_axis_names=("data", "model"))
+    mesh = config.create_device_mesh()
+
+    assert isinstance(mesh, jax.sharding.Mesh)
+    assert dict(zip(config.mesh_axis_names, mesh_shape, strict=False)) == dict(
+        mesh.shape
+    )
+    assert mesh.axis_names == ("data", "model")
+    assert len(mesh.devices.flatten()) == device_count
+
+    # Check if the mesh shape is correct for the given device count
+    if isinstance(config.mesh_shape, HybridMeshShape):
+        assert (
+            config.mesh_shape.ici_mesh_shape[0] * config.mesh_shape.ici_mesh_shape[1]
+            == device_count
+        )
+    else:
+        assert config.mesh_shape[0] * config.mesh_shape[1] == device_count


def test_mesh_config_hosts_and_host_id():
@@ -83,39 +68,4 @@ def test_mesh_config_hosts_and_host_id():
assert config.host_id == jax.process_index()


-@pytest.mark.parametrize(
-    ("mesh_rules", "mesh_selector", "expected_shape", "device_count"),
-    [
-        ([("cpu", (2, 2))], "cpu", (2, 2), 4),
-        ([("cpu", (4, 1))], "cpu", (4, 1), 4),
-        ([("cpu", (1, 4))], "cpu", (1, 4), 4),
-        ([("gpu", (2, 2)), ("cpu", (1, 4))], "cpu", (1, 4), 4),
-        ([("gpu", (2, 2)), ("cpu", (1, 4))], "gpu", (2, 2), 4),
-    ],
-)
-def test_mesh_config_with_rules(
-    mesh_rules, mesh_selector, expected_shape, device_count, monkeypatch
-):
-    monkeypatch.setenv(
-        "XLA_FLAGS", f"--xla_force_host_platform_device_count={device_count}"
-    )
-    config = MeshConfig(
-        mesh_shape=(1, device_count),  # Default shape
-        mesh_axis_names=("data", "model"),
-        mesh_rules=mesh_rules,
-    )
-
-    # Override the mesh_shape based on the selected rule
-    for rule, shape in mesh_rules:
-        if re.match(rule, mesh_selector):
-            config.mesh_shape = shape
-            break
-
-    mesh = config.create_device_mesh()
-    assert dict(zip(config.mesh_axis_names, expected_shape, strict=False)) == dict(
-        mesh.shape
-    )
-    assert len(mesh.devices.flatten()) == device_count
-
-    # Verify that all devices are of the expected type (CPU in this case)
-    assert all(device.platform.lower() == "cpu" for device in mesh.devices.flatten())
+# TODO: walln - tests for mesh_rules if they are going to stay in the API
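
For intuition, a (1, N) data/model mesh over the N simulated CPU devices corresponds to the following plain-JAX construction (a sketch of what MeshConfig.create_device_mesh is expected to produce; the loadax internals are not part of this diff):

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

device_count = len(jax.devices())
devices = mesh_utils.create_device_mesh((1, device_count))
mesh = Mesh(devices, axis_names=("data", "model"))
assert dict(mesh.shape) == {"data": 1, "model": device_count}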
141 changes: 141 additions & 0 deletions tests/sharding/test_placement.py
@@ -0,0 +1,141 @@
import jax
import numpy as np
import pytest
from jax.sharding import PartitionSpec

from loadax.sharding.partition_spec import DataPartitionType
from loadax.sharding.placement import (
global_to_host_array,
host_to_global_device_array,
with_sharding_constraint,
)


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_host_to_global_device_array(simulated_xla_devices):
mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
with mesh:
host_array = np.array([[1, 2], [3, 4]])
global_array = host_to_global_device_array(
host_array, partition=DataPartitionType.FULL
)

assert isinstance(global_array, jax.Array)
assert global_array.shape == host_array.shape
assert np.array_equal(np.array(global_array), host_array)


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_global_to_host_array(simulated_xla_devices):
mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
with mesh:
global_array = jax.numpy.array([[1, 2], [3, 4]])
host_array = global_to_host_array(
global_array, partition=DataPartitionType.FULL
)

assert isinstance(host_array, np.ndarray)
assert host_array.shape == global_array.shape
assert np.array_equal(host_array, np.array(global_array))


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_with_sharding_constraint(simulated_xla_devices):
mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
with mesh:
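        # with_sharding_constraint resolves the PartitionSpec axis names
        # against the active mesh, so it must run inside this mesh context.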
x = jax.numpy.array([1, 2, 3, 4])
sharded_x = with_sharding_constraint(x, PartitionSpec("data"))

assert isinstance(sharded_x, jax.Array)
assert np.array_equal(np.array(sharded_x), np.array(x))


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
@pytest.mark.parametrize(
"partition", [DataPartitionType.FULL, DataPartitionType.REPLICATED]
)
def test_host_to_global_device_array_partition_types(simulated_xla_devices, partition):
mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
device_count = len(simulated_xla_devices)
with mesh:
host_array = np.array([[i, i + 1] for i in range(0, device_count * 2, 2)])
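        # The array above has one row per simulated device: shape (device_count, 2).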
global_array = host_to_global_device_array(host_array, partition=partition)

assert isinstance(global_array, jax.Array)
assert global_array.shape == host_array.shape
assert np.array_equal(np.array(global_array), host_array)

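        # FULL is expected to place shards across every simulated device,
        # while REPLICATED is expected to resolve to a single-device placement.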
if partition == DataPartitionType.FULL:
assert len(global_array.sharding.device_set) == device_count
elif partition == DataPartitionType.REPLICATED:
assert len(global_array.sharding.device_set) == 1


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_host_to_global_device_array_nested(simulated_xla_devices):
mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
with mesh:
host_nested = {"a": np.array([1, 2]), "b": {"c": np.array([3, 4])}}
global_nested = host_to_global_device_array(
host_nested, partition=DataPartitionType.FULL
)

assert isinstance(global_nested, dict)
assert isinstance(global_nested["a"], jax.Array)
assert isinstance(global_nested["b"]["c"], jax.Array)
assert np.array_equal(np.array(global_nested["a"]), host_nested["a"])
assert np.array_equal(np.array(global_nested["b"]["c"]), host_nested["b"]["c"])


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_global_to_host_array_nested(simulated_xla_devices):
mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
with mesh:
global_nested = {
"a": jax.numpy.array([1, 2]),
"b": {"c": jax.numpy.array([3, 4])},
}
host_nested = global_to_host_array(
global_nested, partition=DataPartitionType.FULL
)

assert isinstance(host_nested, dict)
assert isinstance(host_nested["a"], np.ndarray)
assert isinstance(host_nested["b"]["c"], np.ndarray)
assert np.array_equal(host_nested["a"], np.array(global_nested["a"]))
assert np.array_equal(host_nested["b"]["c"], np.array(global_nested["b"]["c"]))


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_host_to_global_device_array_multi_device(simulated_xla_devices):
device_count = len(simulated_xla_devices)
mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))

with mesh:
host_array = np.array([[i, i + 1] for i in range(0, device_count * 2, 2)])
global_array = host_to_global_device_array(
host_array, partition=DataPartitionType.FULL
)

assert isinstance(global_array, jax.Array)
assert global_array.shape == host_array.shape
assert np.array_equal(np.array(global_array), host_array)
assert len(global_array.sharding.device_set) == device_count


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_global_to_host_array_multi_device(simulated_xla_devices):
device_count = len(simulated_xla_devices)
mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))

with mesh:
global_array = jax.numpy.array(
[[i, i + 1] for i in range(0, device_count * 2, 2)]
)
host_array = global_to_host_array(
global_array, partition=DataPartitionType.FULL
)

assert isinstance(host_array, np.ndarray)
assert host_array.shape == global_array.shape
assert np.array_equal(host_array, np.array(global_array))

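Taken together, these tests exercise the placement helpers as a host-to-global-to-host round trip. A condensed sketch under the same simulated-device setup (illustrative only, using the same loadax imports as above):

import jax
import numpy as np

from loadax.sharding.partition_spec import DataPartitionType
from loadax.sharding.placement import global_to_host_array, host_to_global_device_array

devices = jax.devices()
with jax.sharding.Mesh(devices, ("data",)):
    batch = np.arange(len(devices) * 2).reshape(len(devices), 2)
    global_batch = host_to_global_device_array(batch, partition=DataPartitionType.FULL)
    restored = global_to_host_array(global_batch, partition=DataPartitionType.FULL)
    assert np.array_equal(restored, batch)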