diff --git a/src/gt4py/next/_allocators.py b/src/gt4py/next/_allocators.py index 440aa6c0de..3b08056653 100644 --- a/src/gt4py/next/_allocators.py +++ b/src/gt4py/next/_allocators.py @@ -63,7 +63,6 @@ def empty( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: ... @@ -72,7 +71,6 @@ def zeros( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: ... @@ -81,7 +79,6 @@ def ones( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: ... @@ -91,7 +88,6 @@ def full( fill_value: core_defs.Scalar, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: ... @@ -101,7 +97,6 @@ def asarray( *, domain: common.DomainLike, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, copy: Optional[bool] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: ... @@ -362,9 +357,9 @@ def allocate( *, domain: common.Domain, dtype: core_defs.DType[core_defs.ScalarT], + allocator: FieldBufferAllocatorProtocol, + device: core_defs.Device, aligned_index: Optional[Sequence[common.NamedIndex]] = None, - allocator: Optional[FieldBufferAllocationUtil] = None, - device: Optional[core_defs.Device] = None, ) -> core_defs.MutableNDArrayObject: """ TODO: docstring @@ -390,18 +385,8 @@ def allocate( If illegal or inconsistent arguments are specified. """ - if device is None and allocator is None: - raise ValueError("No 'device' or 'allocator' specified.") - actual_allocator = get_allocator(allocator) - if actual_allocator is None: - assert device is not None # for mypy - actual_allocator = device_allocators[device.device_type] - elif device is None: - device = core_defs.Device(actual_allocator.__gt_device_type__, 0) - elif device.device_type != actual_allocator.__gt_device_type__: - raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.") - return actual_allocator.__gt_allocate__( + return allocator.__gt_allocate__( domain=domain, dtype=dtype, device_id=device.device_id, @@ -419,11 +404,26 @@ def _check_unsupported_device_and_aligned_index( raise NotImplementedError("Device specification is not yet supported.") +def _get_actual_allocator_and_device( + allocator: Optional[FieldBufferAllocationUtil], device: Optional[core_defs.Device] +) -> tuple[FieldBufferAllocatorProtocol, core_defs.Device]: + if device is None and allocator is None: + raise ValueError("No 'device' or 'allocator' specified.") + actual_allocator = get_allocator(allocator) + if actual_allocator is None: + assert device is not None # for mypy + actual_allocator = device_allocators[device.device_type] + elif device is None: + device = core_defs.Device(actual_allocator.__gt_device_type__, 0) + elif device.device_type != actual_allocator.__gt_device_type__: + raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.") + return actual_allocator, device + + def get_array_allocation_namespace( - allocator: FieldBufferAllocationUtil | core_defs.ArrayApiNamespace | None, + allocator: Optional[FieldBufferAllocationUtil | core_defs.ArrayApiNamespace], + device: Optional[core_defs.Device] = None, ) -> GTArrayAllocationNamespace: - if allocator is None: - allocator = StandardCPUFieldBufferAllocator() if core_defs.is_array_api_namespace(allocator): assert core_defs.is_array_api_namespace(allocator) array_ns = array_api_compat.array_namespace(allocator.empty([0])) @@ -434,7 +434,6 @@ def empty( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: _check_unsupported_device_and_aligned_index(device, aligned_index) @@ -448,7 +447,6 @@ def zeros( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: _check_unsupported_device_and_aligned_index(device, aligned_index) @@ -462,7 +460,6 @@ def ones( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: _check_unsupported_device_and_aligned_index(device, aligned_index) @@ -477,7 +474,6 @@ def full( fill_value: core_defs.Scalar, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: _check_unsupported_device_and_aligned_index(device, aligned_index) @@ -493,7 +489,6 @@ def asarray( *, domain: common.DomainLike, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, copy: Optional[bool] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: @@ -509,106 +504,97 @@ def asarray( return _ArrayNamespaceWrapper - else: - - class _CustomAllocationArrayNamespace: - @staticmethod - def empty( - domain: common.DomainLike, - *, - dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, - aligned_index: Optional[Sequence[common.NamedIndex]] = None, - ) -> core_defs.NDArrayObject: - assert is_field_allocation_tool(allocator) - return allocate( - domain=common.domain(domain), - dtype=core_defs.dtype(dtype), - aligned_index=aligned_index, - allocator=allocator, - device=device, - ) - - @staticmethod - def zeros( - domain: common.DomainLike, - *, - dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, - aligned_index: Optional[Sequence[common.NamedIndex]] = None, - ) -> core_defs.NDArrayObject: - assert is_field_allocation_tool(allocator) - buffer = allocate( - domain=common.domain(domain), - dtype=core_defs.dtype(dtype), - aligned_index=aligned_index, - allocator=allocator, - device=device, - ) - buffer[...] = 0 - return buffer - - @staticmethod - def ones( - domain: common.DomainLike, - *, - dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, - aligned_index: Optional[Sequence[common.NamedIndex]] = None, - ) -> core_defs.NDArrayObject: - assert is_field_allocation_tool(allocator) - buffer = allocate( - domain=common.domain(domain), - dtype=core_defs.dtype(dtype), - aligned_index=aligned_index, - allocator=allocator, - device=device, - ) - buffer[...] = 1 - return buffer - - @staticmethod - def full( - domain: common.DomainLike, - fill_value: core_defs.Scalar, - *, - dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, - aligned_index: Optional[Sequence[common.NamedIndex]] = None, - ) -> core_defs.NDArrayObject: - assert is_field_allocation_tool(allocator) - buffer = allocate( - domain=common.domain(domain), - dtype=core_defs.dtype(dtype), # TODO check all dtypes - aligned_index=aligned_index, - allocator=allocator, - device=device, - ) - buffer[...] = fill_value - return buffer + assert is_field_allocation_tool(allocator) or allocator is None + actual_allocator, actual_device = _get_actual_allocator_and_device(allocator, device) + + class _CustomAllocationArrayNamespace: + @staticmethod + def empty( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + return allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), + aligned_index=aligned_index, + allocator=actual_allocator, + device=actual_device, + ) - @staticmethod - def asarray( - data: core_defs.NDArrayObject, - *, - domain: common.DomainLike, - dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, - copy: Optional[bool] = None, - aligned_index: Optional[Sequence[common.NamedIndex]] = None, - ) -> core_defs.NDArrayObject: - assert is_field_allocation_tool(allocator) - if not copy: - raise NotImplementedError("Zero-copy construction is not yet supported.") - dtype = core_defs.dtype(data.dtype) if dtype is None else core_defs.dtype(dtype) - buffer = allocate( - domain=common.domain(domain), - dtype=dtype, - aligned_index=aligned_index, - allocator=allocator, - device=device, - ) - buffer[...] = array_api_compat.array_namespace(buffer).asarray(data) - return buffer + @staticmethod + def zeros( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + buffer = allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), + aligned_index=aligned_index, + allocator=actual_allocator, + device=actual_device, + ) + buffer[...] = 0 + return buffer + + @staticmethod + def ones( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + buffer = allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), + aligned_index=aligned_index, + allocator=actual_allocator, + device=actual_device, + ) + buffer[...] = 1 + return buffer + + @staticmethod + def full( + domain: common.DomainLike, + fill_value: core_defs.Scalar, + *, + dtype: Optional[core_defs.DTypeLike] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + buffer = allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), # TODO check all dtypes + aligned_index=aligned_index, + allocator=actual_allocator, + device=actual_device, + ) + buffer[...] = fill_value + return buffer + + @staticmethod + def asarray( + data: core_defs.NDArrayObject, + *, + domain: common.DomainLike, + dtype: Optional[core_defs.DTypeLike] = None, + copy: Optional[bool] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + if not copy: + raise NotImplementedError("Zero-copy construction is not yet supported.") + dtype = core_defs.dtype(data.dtype) if dtype is None else core_defs.dtype(dtype) + buffer = allocate( + domain=common.domain(domain), + dtype=dtype, + aligned_index=aligned_index, + allocator=actual_allocator, + device=actual_device, + ) + buffer[...] = array_api_compat.array_namespace(buffer).asarray(data) + return buffer - return _CustomAllocationArrayNamespace + return _CustomAllocationArrayNamespace diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index bc3dce3acb..21b4b63636 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -77,10 +77,8 @@ def empty( >>> b.shape (3, 3) """ - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) - buffer = gtarray_namespace.empty( - domain, device=device, dtype=dtype, aligned_index=aligned_index - ) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) + buffer = gtarray_namespace.empty(domain, dtype=dtype, aligned_index=aligned_index) res = common._field(buffer, domain=domain, allocation_ns=gtarray_namespace) assert isinstance(res, common.MutableField) assert isinstance(res, nd_array_field.NdArrayField) @@ -107,10 +105,8 @@ def zeros( >>> gtx.zeros({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([0., 0., 0., 0., 0., 0., 0.]) """ - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) - buffer = gtarray_namespace.zeros( - domain, device=device, dtype=dtype, aligned_index=aligned_index - ) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) + buffer = gtarray_namespace.zeros(domain, dtype=dtype, aligned_index=aligned_index) res = common._field(buffer, domain=domain, allocation_ns=gtarray_namespace) assert isinstance(res, common.MutableField) assert isinstance(res, nd_array_field.NdArrayField) @@ -137,8 +133,8 @@ def ones( >>> gtx.ones({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([1., 1., 1., 1., 1., 1., 1.]) """ - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) - buffer = gtarray_namespace.ones(domain, device=device, dtype=dtype, aligned_index=aligned_index) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) + buffer = gtarray_namespace.ones(domain, dtype=dtype, aligned_index=aligned_index) res = common._field(buffer, domain=domain, allocation_ns=gtarray_namespace) assert isinstance(res, common.MutableField) assert isinstance(res, nd_array_field.NdArrayField) @@ -171,11 +167,10 @@ def full( >>> gtx.full({IDim: 3}, 5, allocator=gtx.itir_python).ndarray array([5, 5, 5]) """ - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) buffer = gtarray_namespace.full( domain, fill_value, - device=device, dtype=dtype if dtype is not None else core_defs.dtype(type(fill_value)), aligned_index=aligned_index, ) @@ -284,12 +279,11 @@ def as_field( if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) buffer = gtarray_namespace.asarray( data, domain=actual_domain, dtype=dtype, - device=device, copy=True, # TODO(havogt) add support for zero-copy construction aligned_index=aligned_index, ) @@ -367,12 +361,11 @@ def as_connectivity( if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) buffer = gtarray_namespace.asarray( data, domain=actual_domain, dtype=dtype, - device=device, copy=True, # TODO(havogt) add support for zero-copy construction ) connectivity_field = common._connectivity( diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 182a1af1e2..06eeecb66a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -436,8 +436,10 @@ def _slice( def __copy__(self) -> NdArrayField: # Note: `copy` copies the data, following NumPy behavior - allocation_ns = self._allocation_ns or _allocators.get_array_allocation_namespace( - self.array_ns + allocation_ns = ( + self._allocation_ns + if self._allocation_ns is not None + else _allocators.get_array_allocation_namespace(self.array_ns) ) ndarray_copy = allocation_ns.asarray( self.ndarray, diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index 984d87fd00..7e592a687d 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -184,9 +184,6 @@ def test_field_wrong_origin(): with pytest.raises(ValueError, match=(r"Origin keys {'J'} not in domain")): gtx.as_field([I], np.random.rand(sizes[I]).astype(gtx.float32), origin={"J": 0}) - with pytest.raises(ValueError, match=(r"Cannot specify origin for domain I")): - gtx.as_field("I", np.random.rand(sizes[J]).astype(gtx.float32), origin={"J": 0}) - @pytest.mark.xfail(reason="aligned_index not supported yet") def test_aligned_index():