diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 434e7d87..4d808e4b 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -776,7 +776,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] device : Hardware device the array data resides on. """ - if is_numpy_array(x): + if is_numpy_array(x) or is_ndonnx_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") if device == 'cpu': diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md index 1af9f3dc..4519c4ac 100644 --- a/docs/supported-array-libraries.md +++ b/docs/supported-array-libraries.md @@ -138,6 +138,11 @@ The minimum supported Dask version is 2023.12.0. Similar to JAX, `sparse` Array API support is contained directly in `sparse`. +(ndonnx-support)= +## [ndonnx](https://github.com/quantco/ndonnx) + +Similar to JAX, `ndonnx` Array API support is contained directly in `ndonnx`. + (array-api-strict-support)= ## [array-api-strict](https://data-apis.org/array-api-strict/) diff --git a/tests/_helpers.py b/tests/_helpers.py index 5b79aa46..2c8f314b 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -3,11 +3,12 @@ import pytest wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"] -all_libraries = wrapped_libraries + ["array_api_strict", "jax.numpy", "sparse"] - +all_libraries = wrapped_libraries + [ + "array_api_strict", "jax.numpy", "ndonnx", "sparse" +] def import_(library, wrapper=False): - if library == 'cupy': + if library in ('cupy', 'ndonnx'): pytest.importorskip(library) if wrapper: if 'jax' in library: @@ -20,3 +21,14 @@ def import_(library, wrapper=False): library = 'array_api_compat.' + library return import_module(library) + + +def xfail(request: pytest.FixtureRequest, reason: str) -> None: + """ + XFAIL the currently running test. + + Unlike ``pytest.xfail``, allow rest of test to execute instead of immediately + halting it, so that it may result in a XPASS. + xref https://github.com/pandas-dev/pandas/issues/38902 + """ + request.node.add_marker(pytest.mark.xfail(reason=reason)) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index a66a64d9..605c69a1 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -22,6 +22,9 @@ def test_array_namespace(library, api_version, use_compat): if use_compat and library not in wrapped_libraries: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return + if library == "ndonnx" and api_version in ("2021.12", "2022.12"): + pytest.skip("Unsupported API version") + namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: diff --git a/tests/test_common.py b/tests/test_common.py index e95e305e..32876e69 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -8,15 +8,17 @@ from array_api_compat import ( # noqa: F401 is_numpy_array, is_cupy_array, is_torch_array, is_dask_array, is_jax_array, is_pydata_sparse_array, + is_ndonnx_array, is_numpy_namespace, is_cupy_namespace, is_torch_namespace, is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, - is_array_api_strict_namespace, + is_array_api_strict_namespace, is_ndonnx_namespace, ) from array_api_compat import ( device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device ) -from ._helpers import import_, wrapped_libraries, all_libraries +from ._helpers import all_libraries, import_, wrapped_libraries, xfail + is_array_functions = { 'numpy': 'is_numpy_array', @@ -25,6 +27,7 @@ 'dask.array': 'is_dask_array', 'jax.numpy': 'is_jax_array', 'sparse': 'is_pydata_sparse_array', + 'ndonnx': 'is_ndonnx_array', } is_namespace_functions = { @@ -35,6 +38,7 @@ 'jax.numpy': 'is_jax_namespace', 'sparse': 'is_pydata_sparse_namespace', 'array_api_strict': 'is_array_api_strict_namespace', + 'ndonnx': 'is_ndonnx_namespace', } @@ -185,7 +189,10 @@ class C: @pytest.mark.parametrize("library", all_libraries) -def test_device(library): +def test_device(library, request): + if library == "ndonnx": + xfail(request, reason="Needs ndonnx >=0.9.4") + xp = import_(library, wrapper=True) # We can't test much for device() and to_device() other than that @@ -223,17 +230,19 @@ def test_to_device_host(library): @pytest.mark.parametrize("target_library", is_array_functions.keys()) @pytest.mark.parametrize("source_library", is_array_functions.keys()) def test_asarray_cross_library(source_library, target_library, request): - def _xfail(reason: str) -> None: - # Allow rest of test to execute instead of immediately xfailing - # xref https://github.com/pandas-dev/pandas/issues/38902 - request.node.add_marker(pytest.mark.xfail(reason=reason)) - if source_library == "dask.array" and target_library == "torch": # TODO: remove xfail once # https://github.com/dask/dask/issues/8260 is resolved - _xfail(reason="Bug in dask raising error on conversion") + xfail(request, reason="Bug in dask raising error on conversion") + elif ( + source_library == "ndonnx" + and target_library not in ("array_api_strict", "ndonnx", "numpy") + ): + xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown") + elif source_library == "ndonnx" and target_library == "numpy": + xfail(request, reason="produces numpy array of ndonnx scalar arrays") elif source_library == "jax.numpy" and target_library == "torch": - _xfail(reason="casts int to float") + xfail(request, reason="casts int to float") elif source_library == "cupy" and target_library != "cupy": # cupy explicitly disallows implicit conversions to CPU pytest.skip(reason="cupy does not support implicit conversion to CPU")