Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes committed Oct 25, 2023
1 parent 50f0f85 commit cdc9853
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 20 deletions.
68 changes: 68 additions & 0 deletions src/gt4py/eve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
Callable,
Collection,
Dict,
Final,
Generic,
Iterable,
Iterator,
Expand Down Expand Up @@ -227,6 +228,73 @@ def itemgetter_(key: Any, default: Any = NOTHING) -> Callable[[Any], Any]:
return lambda obj: getitem_(obj, key, default=default)


_C = TypeVar("_C")
_V = TypeVar("_V")

_dataclass: Final[Callable] = (
functools.partial(dataclasses.dataclass, slots=True)
if sys.version_info >= (3, 10)
else dataclasses.dataclass
)


@_dataclass(frozen=True, slots=True)
class ForwardDescriptor(xtyping.NonDataDescriptor[_C, _V]):
"""
Descriptor to forward attribute access to another member of the object.
Args:
source_member: name of the member to forward the attribute access to.
attribute_name: name of the attribute to be forwarded. If `None`,
the name of the descriptor in the owner class is used.
Examples:
>>> class A:
... def __init__(self, value):
... self.value = value
...
>>> class B:
... def __init__(self, a):
... self.a = a
...
... value = ForwardDescriptor('a')
...
>>> a = A(10)
>>> b = B(a)
>>> b.value
10
"""

source_member: str
attribute_name: Optional[str] = None

def __set_name__(self, _owner_type: _C, _name: str) -> None:
if self.attribute_name is None:
object.__setattr__(self, "attribute_name", _name)

@overload
def __get__(
self, _instance: Literal[None], _owner_type: Optional[Type[_C]] = None
) -> ForwardDescriptor[_C, _V]:
...

@overload
def __get__( # noqa: F811 # redefinion of unused member
self, _instance: _C, _owner_type: Optional[Type[_C]] = None
) -> _V:
...

def __get__( # noqa: F811 # redefinion of unused member
self, _instance: Optional[_C], _owner_type: Optional[Type[_C]] = None
) -> _V | ForwardDescriptor[_C, _V]:
assert self.attribute_name is not None
return (
getattr(getattr(_instance, self.source_member), self.attribute_name)
if _instance is not None
else self
)


_P = ParamSpec("_P")


Expand Down
35 changes: 21 additions & 14 deletions src/gt4py/next/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,18 @@ def pos_of_kind(kind: common.DimensionKind) -> list[int]:

assert core_allocators.is_valid_nplike_allocation_ns(np)

DefaultCPUAllocator: Final[FieldAllocatorInterface] = FieldAllocator(
device_type=core_defs.DeviceType.CPU,
array_ns=np,
layout_mapper=horizontal_first_layout_mapper,
byte_alignment=64,
)
device_allocators[core_defs.DeviceType.CPU] = DefaultCPUAllocator

class DefaultCPUAllocator(FieldAllocator):
def __init__(self) -> None:
super().__init__(
device_type=core_defs.DeviceType.CPU,
array_ns=np,
layout_mapper=horizontal_first_layout_mapper,
byte_alignment=64,
)


device_allocators[core_defs.DeviceType.CPU] = DefaultCPUAllocator()

if cp:
cp_device_type = (
Expand All @@ -136,14 +141,16 @@ def pos_of_kind(kind: common.DimensionKind) -> list[int]:

assert core_allocators.is_valid_nplike_allocation_ns(cp)

DefaultGPUAllocator: Final[FieldAllocatorInterface] = FieldAllocator(
device_type=core_defs.DeviceType.CPU,
array_ns=np,
layout_mapper=horizontal_first_layout_mapper,
byte_alignment=128,
)
class DefaultGPUAllocator(FieldAllocator):
def __init__(self) -> None:
super().__init__(
device_type=core_defs.DeviceType.CPU,
array_ns=np,
layout_mapper=horizontal_first_layout_mapper,
byte_alignment=128,
)

device_allocators[cp_device_type] = DefaultGPUAllocator
device_allocators[cp_device_type] = DefaultGPUAllocator()
else:
DefaultGPUAllocator: Final[Optional[FieldAllocatorInterface]] = None

Expand Down
7 changes: 5 additions & 2 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from devtools import debug

from gt4py._core import definitions as core_defs
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Any, Optional
from gt4py.eve.utils import UIDGenerator
from gt4py.next.common import Dimension, DimensionKind, GridType
from gt4py.next.ffront import (
dialect_ast_enums,
Expand Down Expand Up @@ -216,6 +216,9 @@ def __post_init__(self):
if self.backend is not None and hasattr(self.backend, "__gt_allocate__"):
object.__setattr__(self, "__gt_allocate__", self.backend.__gt_allocate__)

__gt_device_type__ = eve_utils.ForwardDescriptor("backend")
__gt_allocate__ = eve_utils.ForwardDescriptor("backend")

def with_backend(self, backend: ppi.ProgramExecutor) -> Program:
return dataclasses.replace(self, backend=backend)

Expand Down Expand Up @@ -611,7 +614,7 @@ def as_program(
# with the out argument of the program we generate here.

loc = self.foast_node.location
param_sym_uids = UIDGenerator() # use a new UID generator to allow caching
param_sym_uids = eve_utils.UIDGenerator() # use a new UID generator to allow caching

type_ = self.__gt_type__()
params_decl: list[past.Symbol] = [
Expand Down
5 changes: 1 addition & 4 deletions src/gt4py/next/program_processors/runners/roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,9 @@ def execute_roundtrip(
return fencil(*args, **new_kwargs)


class RoundtripExecutor(ppi.ProgramExecutor):
class RoundtripExecutor(ppi.ProgramExecutor, next_allocators.DefaultCPUAllocator):
def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> None:
execute_roundtrip(program, *args, **kwargs)

__gt_device_type__ = next_allocators.DefaultCPUAllocator.__gt_device_type__
__gt_allocate__ = next_allocators.DefaultCPUAllocator.__gt_allocate__


executor = RoundtripExecutor()

0 comments on commit cdc9853

Please sign in to comment.