Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ndonnx #154

Merged
merged 2 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ jobs:
else
PIP_EXTRA='numpy==1.26.*'
fi

if [ "${{ matrix.python-version }}" == "3.9" ]; then
sed -i '/^ndonnx/d' requirements-dev.txt
fi

python -m pip install -r requirements-dev.txt $PIP_EXTRA

- name: Run Tests
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

This is a small wrapper around common array libraries that is compatible with
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
for other array libraries, or if you encounter any issues, please [open an
issue](https://github.com/data-apis/array-api-compat/issues).
NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
support for other array libraries, or if you encounter any issues, please [open
an issue](https://github.com/data-apis/array-api-compat/issues).

See the documentation for more details https://data-apis.org/array-api-compat/
36 changes: 35 additions & 1 deletion array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def is_numpy_array(x):
is_array_api_obj
is_cupy_array
is_torch_array
is_ndonnx_array
is_dask_array
is_jax_array
is_pydata_sparse_array
Expand Down Expand Up @@ -78,11 +79,12 @@ def is_cupy_array(x):
is_array_api_obj
is_numpy_array
is_torch_array
is_ndonnx_array
is_dask_array
is_jax_array
is_pydata_sparse_array
"""
# Avoid importing NumPy if it isn't already
# Avoid importing CuPy if it isn't already
if 'cupy' not in sys.modules:
return False

Expand Down Expand Up @@ -118,6 +120,33 @@ def is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def is_ndonnx_array(x):
"""
Return True if `x` is a ndonnx Array.

This function does not import ndonnx if it has not already been imported
and is therefore cheap to use.

See Also
--------

array_namespace
is_array_api_obj
is_numpy_array
is_cupy_array
is_ndonnx_array
is_dask_array
is_jax_array
is_pydata_sparse_array
"""
# Avoid importing torch if it isn't already
if 'ndonnx' not in sys.modules:
return False

import ndonnx as ndx

return isinstance(x, ndx.Array)

def is_dask_array(x):
"""
Return True if `x` is a dask.array Array.
Expand All @@ -133,6 +162,7 @@ def is_dask_array(x):
is_numpy_array
is_cupy_array
is_torch_array
is_ndonnx_array
is_jax_array
is_pydata_sparse_array
"""
Expand Down Expand Up @@ -160,6 +190,7 @@ def is_jax_array(x):
is_numpy_array
is_cupy_array
is_torch_array
is_ndonnx_array
is_dask_array
is_pydata_sparse_array
"""
Expand Down Expand Up @@ -188,6 +219,7 @@ def is_pydata_sparse_array(x) -> bool:
is_numpy_array
is_cupy_array
is_torch_array
is_ndonnx_array
is_dask_array
is_jax_array
"""
Expand All @@ -211,6 +243,7 @@ def is_array_api_obj(x):
is_numpy_array
is_cupy_array
is_torch_array
is_ndonnx_array
is_dask_array
is_jax_array
"""
Expand Down Expand Up @@ -613,6 +646,7 @@ def size(x):
"is_jax_array",
"is_numpy_array",
"is_torch_array",
"is_ndonnx_array",
"is_pydata_sparse_array",
"size",
"to_device",
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ numpy
pytest
torch
sparse >=0.15.1
ndonnx
Loading