diff --git a/README.md b/README.md index 7e96c4df..4b0b0c9c 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ This is a small wrapper around common array libraries that is compatible with the [Array API standard](https://data-apis.org/array-api/latest/). Currently, -NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support -for other array libraries, or if you encounter any issues, please [open an -issue](https://github.com/data-apis/array-api-compat/issues). +NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want +support for other array libraries, or if you encounter any issues, please [open +an issue](https://github.com/data-apis/array-api-compat/issues). See the documentation for more details https://data-apis.org/array-api-compat/ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 32fb0e70..b55b16e2 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -48,6 +48,7 @@ def is_numpy_array(x): is_array_api_obj is_cupy_array is_torch_array + is_ndonnx_array is_dask_array is_jax_array is_pydata_sparse_array @@ -78,11 +79,12 @@ def is_cupy_array(x): is_array_api_obj is_numpy_array is_torch_array + is_ndonnx_array is_dask_array is_jax_array is_pydata_sparse_array """ - # Avoid importing NumPy if it isn't already + # Avoid importing CuPy if it isn't already if 'cupy' not in sys.modules: return False @@ -118,6 +120,33 @@ def is_torch_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, torch.Tensor) +def is_ndonnx_array(x): + """ + Return True if `x` is a ndonnx Array. + + This function does not import ndonnx if it has not already been imported + and is therefore cheap to use. + + See Also + -------- + + array_namespace + is_array_api_obj + is_numpy_array + is_cupy_array + is_ndonnx_array + is_dask_array + is_jax_array + is_pydata_sparse_array + """ + # Avoid importing torch if it isn't already + if 'ndonnx' not in sys.modules: + return False + + import ndonnx as ndx + + return isinstance(x, ndx.Array) + def is_dask_array(x): """ Return True if `x` is a dask.array Array. @@ -133,6 +162,7 @@ def is_dask_array(x): is_numpy_array is_cupy_array is_torch_array + is_ndonnx_array is_jax_array is_pydata_sparse_array """ @@ -160,6 +190,7 @@ def is_jax_array(x): is_numpy_array is_cupy_array is_torch_array + is_ndonnx_array is_dask_array is_pydata_sparse_array """ @@ -188,6 +219,7 @@ def is_pydata_sparse_array(x) -> bool: is_numpy_array is_cupy_array is_torch_array + is_ndonnx_array is_dask_array is_jax_array """ @@ -211,6 +243,7 @@ def is_array_api_obj(x): is_numpy_array is_cupy_array is_torch_array + is_ndonnx_array is_dask_array is_jax_array """ @@ -613,6 +646,7 @@ def size(x): "is_jax_array", "is_numpy_array", "is_torch_array", + "is_ndonnx_array", "is_pydata_sparse_array", "size", "to_device", diff --git a/requirements-dev.txt b/requirements-dev.txt index d06de300..c9d10f71 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,3 +5,4 @@ numpy pytest torch sparse >=0.15.1 +ndonnx