Skip to content

Commit

Permalink
fix get_actual_allocator
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Dec 20, 2024
1 parent 3ee60b8 commit 4fb8bab
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions src/gt4py/next/_allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

# TODO: make this module private

import abc
import dataclasses
import functools
Expand Down Expand Up @@ -407,13 +405,12 @@ def _check_unsupported_device_and_aligned_index(
def _get_actual_allocator_and_device(
allocator: Optional[FieldBufferAllocationUtil], device: Optional[core_defs.Device]
) -> tuple[FieldBufferAllocatorProtocol, core_defs.Device]:
if device is None and allocator is None:
raise ValueError("No 'device' or 'allocator' specified.")
actual_allocator = get_allocator(allocator)
if actual_allocator is None:
assert device is not None # for mypy
actual_allocator = device_allocators[device.device_type]
elif device is None:
if allocator is None and device is not None:
return device_allocators[device.device_type], device

actual_allocator = get_allocator(allocator, default=device_allocators[core_defs.DeviceType.CPU])
assert actual_allocator is not None
if device is None:
device = core_defs.Device(actual_allocator.__gt_device_type__, 0)
elif device.device_type != actual_allocator.__gt_device_type__:
raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.")
Expand Down

0 comments on commit 4fb8bab

Please sign in to comment.