test: placement tests and shared xla_device simulation for tests
Showing 3 changed files with 213 additions and 70 deletions.
conftest.py (new file, +52 lines):

```python
# conftest.py

import importlib
import os

import pytest


@pytest.fixture(
    scope="session",
    params=[2, 4],
)
def xla_device_count(request):
    device_count = request.param
    os.environ["XLA_FORCE_HOST_PLATFORM_DEVICE_COUNT"] = str(device_count)

    # Verify that JAX is installed. Note: the environment variable must be set
    # before JAX initializes its backend; import_module does not re-initialize
    # a module that has already been imported.
    try:
        importlib.import_module("jax")
        importlib.import_module("jaxlib")
    except ImportError as err:
        raise ImportError("JAX and JAXLIB must be installed to run tests.") from err

    import jax

    devices = jax.devices()
    yield devices

    # Cleanup if necessary
    del os.environ["XLA_FORCE_HOST_PLATFORM_DEVICE_COUNT"]


# Alternative fixture: the device count is supplied per test via indirect
# parametrization instead of fixed session-level params.
@pytest.fixture(scope="session")
def simulated_xla_devices(request):
    device_count = request.param
    os.environ["XLA_FORCE_HOST_PLATFORM_DEVICE_COUNT"] = str(device_count)

    # Verify that JAX is installed. As above, the environment variable must be
    # set before JAX initializes its backend.
    try:
        importlib.import_module("jax")
        importlib.import_module("jaxlib")
    except ImportError as err:
        raise ImportError("JAX and JAXLIB must be installed to run tests.") from err

    import jax

    devices = jax.devices()
    yield devices

    # Cleanup if necessary
    del os.environ["XLA_FORCE_HOST_PLATFORM_DEVICE_COUNT"]
```
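For reference, a test opts into a simulated device count by parametrizing the `simulated_xla_devices` fixture indirectly, exactly as the placement tests below do. A minimal, hypothetical example (the test name is illustrative):

```python
import pytest


# Hypothetical usage: the value 4 reaches the fixture as request.param, so
# JAX is configured (before its first import) to expose 4 host devices.
@pytest.mark.parametrize("simulated_xla_devices", [4], indirect=True)
def test_simulated_device_count(simulated_xla_devices):
    assert len(simulated_xla_devices) == 4
```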
Placement tests (new file, +141 lines):

```python
import jax
import numpy as np
import pytest
from jax.sharding import PartitionSpec

from loadax.sharding.partition_spec import DataPartitionType
from loadax.sharding.placement import (
    global_to_host_array,
    host_to_global_device_array,
    with_sharding_constraint,
)


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_host_to_global_device_array(simulated_xla_devices):
    mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
    with mesh:
        host_array = np.array([[1, 2], [3, 4]])
        global_array = host_to_global_device_array(
            host_array, partition=DataPartitionType.FULL
        )

        assert isinstance(global_array, jax.Array)
        assert global_array.shape == host_array.shape
        assert np.array_equal(np.array(global_array), host_array)


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_global_to_host_array(simulated_xla_devices):
    mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
    with mesh:
        global_array = jax.numpy.array([[1, 2], [3, 4]])
        host_array = global_to_host_array(
            global_array, partition=DataPartitionType.FULL
        )

        assert isinstance(host_array, np.ndarray)
        assert host_array.shape == global_array.shape
        assert np.array_equal(host_array, np.array(global_array))


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_with_sharding_constraint(simulated_xla_devices):
    mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
    with mesh:
        x = jax.numpy.array([1, 2, 3, 4])
        sharded_x = with_sharding_constraint(x, PartitionSpec("data"))

        assert isinstance(sharded_x, jax.Array)
        assert np.array_equal(np.array(sharded_x), np.array(x))


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
@pytest.mark.parametrize(
    "partition", [DataPartitionType.FULL, DataPartitionType.REPLICATED]
)
def test_host_to_global_device_array_partition_types(simulated_xla_devices, partition):
    mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
    device_count = len(simulated_xla_devices)
    with mesh:
        host_array = np.array([[i, i + 1] for i in range(0, device_count * 2, 2)])
        global_array = host_to_global_device_array(host_array, partition=partition)

        assert isinstance(global_array, jax.Array)
        assert global_array.shape == host_array.shape
        assert np.array_equal(np.array(global_array), host_array)

        if partition == DataPartitionType.FULL:
            assert len(global_array.sharding.device_set) == device_count
        elif partition == DataPartitionType.REPLICATED:
            assert len(global_array.sharding.device_set) == 1


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_host_to_global_device_array_nested(simulated_xla_devices):
    mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
    with mesh:
        host_nested = {"a": np.array([1, 2]), "b": {"c": np.array([3, 4])}}
        global_nested = host_to_global_device_array(
            host_nested, partition=DataPartitionType.FULL
        )

        assert isinstance(global_nested, dict)
        assert isinstance(global_nested["a"], jax.Array)
        assert isinstance(global_nested["b"]["c"], jax.Array)
        assert np.array_equal(np.array(global_nested["a"]), host_nested["a"])
        assert np.array_equal(np.array(global_nested["b"]["c"]), host_nested["b"]["c"])


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_global_to_host_array_nested(simulated_xla_devices):
    mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))
    with mesh:
        global_nested = {
            "a": jax.numpy.array([1, 2]),
            "b": {"c": jax.numpy.array([3, 4])},
        }
        host_nested = global_to_host_array(
            global_nested, partition=DataPartitionType.FULL
        )

        assert isinstance(host_nested, dict)
        assert isinstance(host_nested["a"], np.ndarray)
        assert isinstance(host_nested["b"]["c"], np.ndarray)
        assert np.array_equal(host_nested["a"], np.array(global_nested["a"]))
        assert np.array_equal(host_nested["b"]["c"], np.array(global_nested["b"]["c"]))


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_host_to_global_device_array_multi_device(simulated_xla_devices):
    device_count = len(simulated_xla_devices)
    mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))

    with mesh:
        host_array = np.array([[i, i + 1] for i in range(0, device_count * 2, 2)])
        global_array = host_to_global_device_array(
            host_array, partition=DataPartitionType.FULL
        )

        assert isinstance(global_array, jax.Array)
        assert global_array.shape == host_array.shape
        assert np.array_equal(np.array(global_array), host_array)
        assert len(global_array.sharding.device_set) == device_count


@pytest.mark.parametrize("simulated_xla_devices", [4, 8], indirect=True)
def test_global_to_host_array_multi_device(simulated_xla_devices):
    device_count = len(simulated_xla_devices)
    mesh = jax.sharding.Mesh(simulated_xla_devices, ("data",))

    with mesh:
        global_array = jax.numpy.array(
            [[i, i + 1] for i in range(0, device_count * 2, 2)]
        )
        host_array = global_to_host_array(
            global_array, partition=DataPartitionType.FULL
        )

        assert isinstance(host_array, np.ndarray)
        assert host_array.shape == global_array.shape
        assert np.array_equal(host_array, np.array(global_array))
```
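For context on what these tests exercise, here is a minimal sketch of what a placement helper in the spirit of `host_to_global_device_array` could do, consistent with the assertions above: `DataPartitionType.FULL` shards across every device on the "data" mesh axis, while `REPLICATED` ends up on a single device (`device_set == 1`). This is an illustration under those assumptions, not loadax's actual implementation; the helper name `put_on_mesh` is hypothetical, and the FULL branch additionally assumes the leading array dimension divides the device count.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def put_on_mesh(host_tree, mesh: Mesh, full: bool = True):
    """Sketch: place a (possibly nested) tree of host arrays on a mesh."""

    def put(x: np.ndarray) -> jax.Array:
        if full:
            # Shard the leading axis across the "data" mesh axis; assumes the
            # leading dimension divides the number of devices in the mesh.
            return jax.device_put(x, NamedSharding(mesh, PartitionSpec("data")))
        # Single-device placement, matching the device_set == 1 assertion for
        # DataPartitionType.REPLICATED in the tests above.
        return jax.device_put(x, mesh.devices.flat[0])

    return jax.tree_util.tree_map(put, host_tree)
```

The `tree_map` traversal mirrors the nested-dict behavior exercised by `test_host_to_global_device_array_nested` above.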