Skip to content

Commit

Permalink
refactor to an allocation namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Dec 19, 2024
1 parent 1c5acd6 commit 92b8ea6
Show file tree
Hide file tree
Showing 24 changed files with 486 additions and 273 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ classifiers = [
'Topic :: Scientific/Engineering :: Physics'
]
dependencies = [
"array-api-compat>=1.9.1",
"astunparse>=1.6.3;python_version<'3.9'",
'attrs>=21.3',
'black>=22.3',
Expand Down Expand Up @@ -265,6 +266,7 @@ markers = [
'uses_unstructured_shift: tests that use a unstructured connectivity',
'uses_max_over: tests that use the max_over builtin',
'uses_mesh_with_skip_values: tests that use a mesh with skip values',
'slices_out_argument: tests that slice the out argument in a field_operator call',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
]
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
Expand Down
55 changes: 41 additions & 14 deletions src/gt4py/_core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import functools
import math
import numbers
from types import ModuleType

import numpy as np
import numpy.typing as npt
Expand All @@ -28,6 +27,7 @@
Iterator,
Literal,
Protocol,
Self,
Sequence,
Tuple,
Type,
Expand Down Expand Up @@ -339,15 +339,6 @@ def dtype(dtype_like: DTypeLike) -> DType:
return dtype_like if isinstance(dtype_like, DType) else DType(np.dtype(dtype_like).type)


def to_array_api_dtype(xp: ModuleType, dtype_: DTypeLike | None) -> Any:
"""
Converts a GT4Py `DTypeLike` to the dtype object of the given Array API namespace.
Note: For convenience `None` is passed-through as it has a consistent meaning in all Array API implementations.
"""
return None if dtype_ is None else getattr(xp, dtype(dtype_).scalar_type.__name__)


# -- Custom protocols --
class GTDimsInterface(Protocol):
"""
Expand Down Expand Up @@ -415,6 +406,7 @@ class DeviceType(enum.IntEnum):
MetalDeviceTyping,
VPIDeviceTyping,
ROCMDeviceTyping,
covariant=True,
)


Expand Down Expand Up @@ -464,7 +456,7 @@ def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ...

def any(self) -> bool: ...

def __getitem__(self, item: Any) -> NDArrayObject: ...
def __getitem__(self, item: Any) -> Self: ...

def __abs__(self) -> NDArrayObject: ...

Expand Down Expand Up @@ -517,12 +509,47 @@ def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...
def __xor__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...


class MutableNDArrayObject(NDArrayObject, Protocol):
def __setitem__(self, index: Any, value: Any) -> None: ...


class ArrayApiNamespace(Protocol):
@property
def __array_api_version__(self) -> str: ...
def empty(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ...
def zeros(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ...
def ones(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ...
def full(
self, shape: Sequence[int], fill_value: Scalar, *, dtype: Any = None, device: Any = None
) -> Any: ...
def asarray(self, obj: Any, *, dtype: Any = None, copy: Any = None) -> Any: ...

# @property # once all relevant implementations have this attribute
# def __array_api_version__(self) -> str: ... # noqa: ERA001

# TODO(havogt): add relevant methods and attributes or wait for the standard to provide it, see e.g. https://github.com/data-apis/array-api/issues/697


def is_array_api_namespace(obj: Any) -> TypeGuard[ArrayApiNamespace]:
return hasattr(obj, "__array_api_version__")
# return hasattr(obj, "__array_api_version__") # noqa: ERA001 # once all relevant implementations have this attribute
return (
hasattr(obj, "empty")
and hasattr(obj, "zeros")
and hasattr(obj, "ones")
and hasattr(obj, "full")
and hasattr(obj, "asarray")
)


def to_array_api_dtype(xp: ArrayApiNamespace, dtype_: DTypeLike | None) -> Any:
"""
Converts a GT4Py `DTypeLike` to the dtype object of the given Array API namespace.
Note: For convenience `None` is passed-through as it has a consistent meaning in all Array API implementations.
"""
if dtype_ is None:
return None
else:
dtype_ = dtype(dtype_)
assert (
dtype_.tensor_shape == ()
) # TODO(havogt): support tensor shapes (or remove from our DType)
return getattr(xp, dtype_.scalar_type.__name__)
8 changes: 8 additions & 0 deletions src/gt4py/_core/gt_array_namespace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

Loading

0 comments on commit 92b8ea6

Please sign in to comment.