diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 434e7d8..80239c0 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from typing import Optional, Union, Any - from ._typing import Array, Device + from ._typing import Array, Device, Namespace import sys import math @@ -439,7 +439,7 @@ def _check_api_version(api_version: str) -> None: raise ValueError("Only the 2023.12 version of the array API specification is currently supported") -def array_namespace(*xs, api_version=None, use_compat=None): +def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index 1f916cd..d8acdef 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -5,6 +5,7 @@ "SupportsBufferProtocol", ] +from types import ModuleType from typing import ( Any, TypeVar, @@ -22,3 +23,4 @@ def __len__(self, /) -> int: ... Array = Any Device = Any DType = Any +Namespace = ModuleType diff --git a/array_api_compat/py.typed b/array_api_compat/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py index d0a2840..2368ccc 100644 --- a/setup.py +++ b/setup.py @@ -34,4 +34,7 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], + package_data={ + "array_api_compat": ["py.typed"], + }, )