Skip to content

Commit

Permalink
Merge pull request #232 from crusaderky/ndonnx
Browse files Browse the repository at this point in the history
ENH: ndonnx device() support; TST: better ndonnx test coverage
  • Loading branch information
ev-br authored Feb 5, 2025
2 parents 7948ac0 + 8434019 commit cb6a3ec
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 14 deletions.
2 changes: 1 addition & 1 deletion array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
device : Hardware device the array data resides on.
"""
if is_numpy_array(x):
if is_numpy_array(x) or is_ndonnx_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
if device == 'cpu':
Expand Down
5 changes: 5 additions & 0 deletions docs/supported-array-libraries.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ The minimum supported Dask version is 2023.12.0.

Similar to JAX, `sparse` Array API support is contained directly in `sparse`.

(ndonnx-support)=
## [ndonnx](https://github.com/quantco/ndonnx)

Similar to JAX, `ndonnx` Array API support is contained directly in `ndonnx`.

(array-api-strict-support)=
## [array-api-strict](https://data-apis.org/array-api-strict/)

Expand Down
18 changes: 15 additions & 3 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import pytest

wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
all_libraries = wrapped_libraries + ["array_api_strict", "jax.numpy", "sparse"]

all_libraries = wrapped_libraries + [
"array_api_strict", "jax.numpy", "ndonnx", "sparse"
]

def import_(library, wrapper=False):
if library == 'cupy':
if library in ('cupy', 'ndonnx'):
pytest.importorskip(library)
if wrapper:
if 'jax' in library:
Expand All @@ -20,3 +21,14 @@ def import_(library, wrapper=False):
library = 'array_api_compat.' + library

return import_module(library)


def xfail(request: pytest.FixtureRequest, reason: str) -> None:
"""
XFAIL the currently running test.
Unlike ``pytest.xfail``, allow the rest of the test to execute instead of
immediately halting it, so that it may result in an XPASS.
xref https://github.com/pandas-dev/pandas/issues/38902
"""
request.node.add_marker(pytest.mark.xfail(reason=reason))
3 changes: 3 additions & 0 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def test_array_namespace(library, api_version, use_compat):
if use_compat and library not in wrapped_libraries:
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
return
if library == "ndonnx" and api_version in ("2021.12", "2022.12"):
pytest.skip("Unsupported API version")

namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)

if use_compat is False or use_compat is None and library not in wrapped_libraries:
Expand Down
29 changes: 19 additions & 10 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@
from array_api_compat import ( # noqa: F401
is_numpy_array, is_cupy_array, is_torch_array,
is_dask_array, is_jax_array, is_pydata_sparse_array,
is_ndonnx_array,
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
is_array_api_strict_namespace,
is_array_api_strict_namespace, is_ndonnx_namespace,
)

from array_api_compat import (
device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
)
from ._helpers import import_, wrapped_libraries, all_libraries
from ._helpers import all_libraries, import_, wrapped_libraries, xfail


is_array_functions = {
'numpy': 'is_numpy_array',
Expand All @@ -25,6 +27,7 @@
'dask.array': 'is_dask_array',
'jax.numpy': 'is_jax_array',
'sparse': 'is_pydata_sparse_array',
'ndonnx': 'is_ndonnx_array',
}

is_namespace_functions = {
Expand All @@ -35,6 +38,7 @@
'jax.numpy': 'is_jax_namespace',
'sparse': 'is_pydata_sparse_namespace',
'array_api_strict': 'is_array_api_strict_namespace',
'ndonnx': 'is_ndonnx_namespace',
}


Expand Down Expand Up @@ -185,7 +189,10 @@ class C:


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
def test_device(library, request):
if library == "ndonnx":
xfail(request, reason="Needs ndonnx >=0.9.4")

xp = import_(library, wrapper=True)

# We can't test much for device() and to_device() other than that
Expand Down Expand Up @@ -223,17 +230,19 @@ def test_to_device_host(library):
@pytest.mark.parametrize("target_library", is_array_functions.keys())
@pytest.mark.parametrize("source_library", is_array_functions.keys())
def test_asarray_cross_library(source_library, target_library, request):
def _xfail(reason: str) -> None:
# Allow rest of test to execute instead of immediately xfailing
# xref https://github.com/pandas-dev/pandas/issues/38902
request.node.add_marker(pytest.mark.xfail(reason=reason))

if source_library == "dask.array" and target_library == "torch":
# TODO: remove xfail once
# https://github.com/dask/dask/issues/8260 is resolved
_xfail(reason="Bug in dask raising error on conversion")
xfail(request, reason="Bug in dask raising error on conversion")
elif (
source_library == "ndonnx"
and target_library not in ("array_api_strict", "ndonnx", "numpy")
):
xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown")
elif source_library == "ndonnx" and target_library == "numpy":
xfail(request, reason="produces numpy array of ndonnx scalar arrays")
elif source_library == "jax.numpy" and target_library == "torch":
_xfail(reason="casts int to float")
xfail(request, reason="casts int to float")
elif source_library == "cupy" and target_library != "cupy":
# cupy explicitly disallows implicit conversions to CPU
pytest.skip(reason="cupy does not support implicit conversion to CPU")
Expand Down

0 comments on commit cb6a3ec

Please sign in to comment.