Skip to content

Commit

Permalink
move device to allocation namespace construction
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Dec 20, 2024
1 parent 6709b04 commit 3ee60b8
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 148 deletions.
240 changes: 113 additions & 127 deletions src/gt4py/next/_allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def empty(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject: ...

Expand All @@ -72,7 +71,6 @@ def zeros(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject: ...

Expand All @@ -81,7 +79,6 @@ def ones(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject: ...

Expand All @@ -91,7 +88,6 @@ def full(
fill_value: core_defs.Scalar,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject: ...

Expand All @@ -101,7 +97,6 @@ def asarray(
*,
domain: common.DomainLike,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
copy: Optional[bool] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject: ...
Expand Down Expand Up @@ -362,9 +357,9 @@ def allocate(
*,
domain: common.Domain,
dtype: core_defs.DType[core_defs.ScalarT],
allocator: FieldBufferAllocatorProtocol,
device: core_defs.Device,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
allocator: Optional[FieldBufferAllocationUtil] = None,
device: Optional[core_defs.Device] = None,
) -> core_defs.MutableNDArrayObject:
"""
TODO: docstring
Expand All @@ -390,18 +385,8 @@ def allocate(
If illegal or inconsistent arguments are specified.
"""
if device is None and allocator is None:
raise ValueError("No 'device' or 'allocator' specified.")
actual_allocator = get_allocator(allocator)
if actual_allocator is None:
assert device is not None # for mypy
actual_allocator = device_allocators[device.device_type]
elif device is None:
device = core_defs.Device(actual_allocator.__gt_device_type__, 0)
elif device.device_type != actual_allocator.__gt_device_type__:
raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.")

return actual_allocator.__gt_allocate__(
return allocator.__gt_allocate__(
domain=domain,
dtype=dtype,
device_id=device.device_id,
Expand All @@ -419,11 +404,26 @@ def _check_unsupported_device_and_aligned_index(
raise NotImplementedError("Device specification is not yet supported.")


def _get_actual_allocator_and_device(
allocator: Optional[FieldBufferAllocationUtil], device: Optional[core_defs.Device]
) -> tuple[FieldBufferAllocatorProtocol, core_defs.Device]:
if device is None and allocator is None:
raise ValueError("No 'device' or 'allocator' specified.")
actual_allocator = get_allocator(allocator)
if actual_allocator is None:
assert device is not None # for mypy
actual_allocator = device_allocators[device.device_type]
elif device is None:
device = core_defs.Device(actual_allocator.__gt_device_type__, 0)
elif device.device_type != actual_allocator.__gt_device_type__:
raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.")
return actual_allocator, device


def get_array_allocation_namespace(
allocator: FieldBufferAllocationUtil | core_defs.ArrayApiNamespace | None,
allocator: Optional[FieldBufferAllocationUtil | core_defs.ArrayApiNamespace],
device: Optional[core_defs.Device] = None,
) -> GTArrayAllocationNamespace:
if allocator is None:
allocator = StandardCPUFieldBufferAllocator()
if core_defs.is_array_api_namespace(allocator):
assert core_defs.is_array_api_namespace(allocator)
array_ns = array_api_compat.array_namespace(allocator.empty([0]))
Expand All @@ -434,7 +434,6 @@ def empty(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
_check_unsupported_device_and_aligned_index(device, aligned_index)
Expand All @@ -448,7 +447,6 @@ def zeros(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
_check_unsupported_device_and_aligned_index(device, aligned_index)
Expand All @@ -462,7 +460,6 @@ def ones(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
_check_unsupported_device_and_aligned_index(device, aligned_index)
Expand All @@ -477,7 +474,6 @@ def full(
fill_value: core_defs.Scalar,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
_check_unsupported_device_and_aligned_index(device, aligned_index)
Expand All @@ -493,7 +489,6 @@ def asarray(
*,
domain: common.DomainLike,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
copy: Optional[bool] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
Expand All @@ -509,106 +504,97 @@ def asarray(

return _ArrayNamespaceWrapper

else:

class _CustomAllocationArrayNamespace:
@staticmethod
def empty(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
assert is_field_allocation_tool(allocator)
return allocate(
domain=common.domain(domain),
dtype=core_defs.dtype(dtype),
aligned_index=aligned_index,
allocator=allocator,
device=device,
)

@staticmethod
def zeros(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
assert is_field_allocation_tool(allocator)
buffer = allocate(
domain=common.domain(domain),
dtype=core_defs.dtype(dtype),
aligned_index=aligned_index,
allocator=allocator,
device=device,
)
buffer[...] = 0
return buffer

@staticmethod
def ones(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
assert is_field_allocation_tool(allocator)
buffer = allocate(
domain=common.domain(domain),
dtype=core_defs.dtype(dtype),
aligned_index=aligned_index,
allocator=allocator,
device=device,
)
buffer[...] = 1
return buffer

@staticmethod
def full(
domain: common.DomainLike,
fill_value: core_defs.Scalar,
*,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
assert is_field_allocation_tool(allocator)
buffer = allocate(
domain=common.domain(domain),
dtype=core_defs.dtype(dtype), # TODO check all dtypes
aligned_index=aligned_index,
allocator=allocator,
device=device,
)
buffer[...] = fill_value
return buffer
assert is_field_allocation_tool(allocator) or allocator is None
actual_allocator, actual_device = _get_actual_allocator_and_device(allocator, device)

class _CustomAllocationArrayNamespace:
@staticmethod
def empty(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
return allocate(
domain=common.domain(domain),
dtype=core_defs.dtype(dtype),
aligned_index=aligned_index,
allocator=actual_allocator,
device=actual_device,
)

@staticmethod
def asarray(
data: core_defs.NDArrayObject,
*,
domain: common.DomainLike,
dtype: Optional[core_defs.DTypeLike] = None,
device: Optional[core_defs.Device] = None,
copy: Optional[bool] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
assert is_field_allocation_tool(allocator)
if not copy:
raise NotImplementedError("Zero-copy construction is not yet supported.")
dtype = core_defs.dtype(data.dtype) if dtype is None else core_defs.dtype(dtype)
buffer = allocate(
domain=common.domain(domain),
dtype=dtype,
aligned_index=aligned_index,
allocator=allocator,
device=device,
)
buffer[...] = array_api_compat.array_namespace(buffer).asarray(data)
return buffer
@staticmethod
def zeros(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
buffer = allocate(
domain=common.domain(domain),
dtype=core_defs.dtype(dtype),
aligned_index=aligned_index,
allocator=actual_allocator,
device=actual_device,
)
buffer[...] = 0
return buffer

@staticmethod
def ones(
domain: common.DomainLike,
*,
dtype: Optional[core_defs.DTypeLike] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
buffer = allocate(
domain=common.domain(domain),
dtype=core_defs.dtype(dtype),
aligned_index=aligned_index,
allocator=actual_allocator,
device=actual_device,
)
buffer[...] = 1
return buffer

@staticmethod
def full(
domain: common.DomainLike,
fill_value: core_defs.Scalar,
*,
dtype: Optional[core_defs.DTypeLike] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
buffer = allocate(
domain=common.domain(domain),
dtype=core_defs.dtype(dtype), # TODO check all dtypes
aligned_index=aligned_index,
allocator=actual_allocator,
device=actual_device,
)
buffer[...] = fill_value
return buffer

@staticmethod
def asarray(
data: core_defs.NDArrayObject,
*,
domain: common.DomainLike,
dtype: Optional[core_defs.DTypeLike] = None,
copy: Optional[bool] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
) -> core_defs.NDArrayObject:
if not copy:
raise NotImplementedError("Zero-copy construction is not yet supported.")
dtype = core_defs.dtype(data.dtype) if dtype is None else core_defs.dtype(dtype)
buffer = allocate(
domain=common.domain(domain),
dtype=dtype,
aligned_index=aligned_index,
allocator=actual_allocator,
device=actual_device,
)
buffer[...] = array_api_compat.array_namespace(buffer).asarray(data)
return buffer

return _CustomAllocationArrayNamespace
return _CustomAllocationArrayNamespace
Loading

0 comments on commit 3ee60b8

Please sign in to comment.