Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Domain fo call #1291

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
41 changes: 38 additions & 3 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,14 @@ def format_itir(
)

def _validate_args(self, *args, **kwargs) -> None:
val_kwargs = {**kwargs}
arg_types = [type_translation.from_value(arg) for arg in args]
kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()}
kwarg_types = {}
for kwarg in kwargs:
if isinstance(kwargs[kwarg], dict):
kwarg_types[kwarg] = kwargs[kwarg]
val_kwargs.pop(kwarg)
kwarg_types = {k: type_translation.from_value(v) for k, v in val_kwargs.items()}

try:
type_info.accepts_args(
Expand Down Expand Up @@ -317,6 +323,8 @@ def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[s
" tuple) need to have the same shape and dimensions."
)
size_args.extend(shape if shape else [None] * len(dims))
if "domain" in kwargs.keys():
kwargs.pop("domain")
return tuple(rewritten_args), tuple(size_args), kwargs

@functools.cached_property
Expand Down Expand Up @@ -482,6 +490,24 @@ def __gt_itir__(self) -> itir.FunctionDefinition:
def __gt_closure_vars__(self) -> dict[str, Any]:
return self.closure_vars

def _construct_domain(self, kwarg_types: dict, location: Any) -> past.Dict:
domain_keys = []
domain_values = []
for key_ls, vals_tup in list(kwarg_types["domain"].items()):
new_past_name = past.Name(
id=key_ls.value,
location=location,
type=type_translation.from_value(key_ls),
)
elts_vals = [
past.Constant(value=val, type=type_translation.from_value(val), location=location)
for val in vals_tup
]
domain_keys.append(new_past_name)
domain_values.append(past.TupleExpr(elts=elts_vals, location=location))
domain_ref = past.Dict(keys_=domain_keys, values_=domain_values, location=location)
return domain_ref

def as_program(
self, arg_types: list[ts.TypeSpec], kwarg_types: dict[str, ts.TypeSpec]
) -> Program:
Expand Down Expand Up @@ -511,6 +537,15 @@ def as_program(
location=loc,
)
out_ref = past.Name(id="out", location=loc)
domain_sym: past.Symbol = past.DataSymbol(
id="domain",
type=ts.DeferredType(constraint=ts.DimensionType),
namespace=dialect_ast_enums.Namespace.LOCAL,
location=loc,
)
kwargs_dict = {"out": out_ref}
if "domain" in kwarg_types.keys():
kwargs_dict["domain"] = self._construct_domain(kwarg_types, loc)

if self.foast_node.id in self.closure_vars:
raise RuntimeError("A closure variable has the same name as the field operator itself.")
Expand All @@ -527,12 +562,12 @@ def as_program(
untyped_past_node = past.Program(
id=f"__field_operator_{self.foast_node.id}",
type=ts.DeferredType(constraint=ts_ffront.ProgramType),
params=params_decl + [out_sym],
params=params_decl + [out_sym] + [domain_sym],
body=[
past.Call(
func=past.Name(id=self.foast_node.id, location=loc),
args=params_ref,
kwargs={"out": out_ref},
kwargs=kwargs_dict,
location=loc,
)
],
Expand Down
26 changes: 22 additions & 4 deletions src/gt4py/next/ffront/past_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Optional, cast

from gt4py.eve import NodeTranslator, traits
from gt4py.next import errors
from gt4py.next import Dimension, errors
from gt4py.next.ffront import (
dialect_ast_enums,
program_ast as past,
Expand Down Expand Up @@ -85,11 +85,19 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict):
raise ValueError(
f"Only 2 values allowed in domain range, but got `{len(domain_values.elts)}`."
)
if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar(
domain_values.elts[1]
if not (
_is_integral_scalar(domain_values.elts[0])
or isinstance(domain_values.elts[0], (past.BinOp, past.Name))
):
raise ValueError(
f"Only integer values allowed in domain range, but got {domain_values.elts[0].type} and {domain_values.elts[1].type}."
f"Only integer values allowed in domain range, but got {domain_values.elts[0].type}."
)
if not (
_is_integral_scalar(domain_values.elts[1])
or isinstance(domain_values.elts[1], (past.BinOp, past.Name))
):
raise ValueError(
f"Only integer values allowed in domain range, but got {domain_values.elts[1].type}."
)


Expand Down Expand Up @@ -241,6 +249,16 @@ def visit_Call(self, node: past.Call, **kwargs):
location=node.location,
)

def visit_Dict(self, node: past.Dict, **kwargs) -> past.Dict:
assert all(isinstance(key, past.Name) for key in node.keys_)
new_keys = [
past.Name(
id=key.id, type=ts.DimensionType(dim=Dimension(value=key.id)), location=key.location
)
for key in node.keys_
]
return past.Dict(keys_=new_keys, values_=node.values_, location=node.location)

def visit_Name(self, node: past.Name, **kwargs) -> past.Name:
symtable = kwargs["symtable"]
if node.id not in symtable or symtable[node.id].type is None:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def visit_Program(
# containing the size of all fields. The caller of a program is (e.g.
# program decorator) is required to pass these arguments.

params = self.visit(node.params)
params = list(filter(lambda param: param.id != "domain", self.visit(node.params)))

if any("domain" not in body_entry.kwargs for body_entry in node.body):
params = params + self._gen_size_params_from_program(node)
Expand Down
8 changes: 5 additions & 3 deletions src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,8 @@ def canonicalize_function_arguments(
ignore_errors=False,
use_signature_ordering=False,
) -> tuple[list, dict]:
if "domain" in func_type.pos_or_kw_args.keys():
func_type.pos_or_kw_args.pop("domain")
num_pos_params = len(func_type.pos_only_args) + len(func_type.pos_or_kw_args)
cargs = [UNDEFINED_ARG] * max(num_pos_params, len(args))
ckwargs = {**kwargs}
Expand All @@ -581,7 +583,7 @@ def canonicalize_function_arguments(
)

a, b = set(func_type.kw_only_args.keys()), set(ckwargs.keys())
invalid_kw_args = (a - b) | (b - a)
invalid_kw_args = (a - b) | (b - a) - {"domain"}
if invalid_kw_args and (not ignore_errors or use_signature_ordering):
# this error can not be ignored as otherwise the invariant that no arguments are dropped
# is invalidated.
Expand Down Expand Up @@ -638,10 +640,10 @@ def structural_function_signature_incompatibilities(
yield f"Missing {len(missing_positional_args)} required positional argument{'s' if len(missing_positional_args) != 1 else ''}: {', '.join(missing_positional_args)}"

# check for missing or extra keyword arguments
kw_a_m_b = set(func_type.kw_only_args.keys()) - set(kwargs.keys())
kw_a_m_b = set(func_type.kw_only_args.keys()) - set(kwargs.keys()) - {"domain"}
if len(kw_a_m_b) > 0:
yield f"Missing required keyword argument{'s' if len(kw_a_m_b) != 1 else ''} `{'`, `'.join(kw_a_m_b)}`."
kw_b_m_a = set(kwargs.keys()) - set(func_type.kw_only_args.keys())
kw_b_m_a = set(kwargs.keys()) - set(func_type.kw_only_args.keys()) - {"domain"}
if len(kw_b_m_a) > 0:
yield f"Got unexpected keyword argument{'s' if len(kw_b_m_a) != 1 else ''} `{'`, `'.join(kw_b_m_a)}`."

Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/type_system/type_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def from_value(value: Any) -> ts.TypeSpec:
f"Value `{value}` is out of range to be representable as `INT32` or `INT64`."
)
return candidate_type
elif isinstance(value, dict):
return value
elif isinstance(value, common.Dimension):
symbol_type = ts.DimensionType(dim=value)
elif isinstance(value, LocatedField):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pytest

from gt4py.next import errors
from gt4py.next.common import Field
import gt4py.next as gtx
from gt4py.next.ffront.decorator import field_operator, program, scan_operator
from gt4py.next.ffront.fbuiltins import int32, int64
from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu
Expand Down Expand Up @@ -56,7 +56,7 @@ def _generate_arg_permutations(

@pytest.mark.parametrize("arg_spec", _generate_arg_permutations(("a", "b", "c")))
def test_call_field_operator_from_python(cartesian_case, arg_spec: tuple[tuple[str], tuple[str]]):
@field_operator
@gtx.field_operator
def testee(a: IField, b: IField, c: IField) -> IField:
return a * 2 * b - c

Expand All @@ -79,11 +79,11 @@ def testee(a: IField, b: IField, c: IField) -> IField:

@pytest.mark.parametrize("arg_spec", _generate_arg_permutations(("a", "b", "out")))
def test_call_program_from_python(cartesian_case, arg_spec):
@field_operator
@gtx.field_operator
def foo(a: IField, b: IField) -> IField:
return a + 2 * b

@program
@gtx.program
def testee(a: IField, b: IField, out: IField):
foo(a, b, out=out)

Expand All @@ -104,11 +104,11 @@ def testee(a: IField, b: IField, out: IField):


def test_call_field_operator_from_field_operator(cartesian_case):
@field_operator
@gtx.field_operator
def foo(x: IField, y: IField, z: IField):
return x + 2 * y + 3 * z

@field_operator
@gtx.field_operator
def testee(a: IField, b: IField, c: IField) -> IField:
return foo(a, b, c) + 5 * foo(a, y=b, z=c) + 7 * foo(a, z=c, y=b) + 11 * foo(a, b, z=c)

Expand All @@ -127,11 +127,11 @@ def testee_np(a, b, c):


def test_call_field_operator_from_program(cartesian_case):
@field_operator
@gtx.field_operator
def foo(x: IField, y: IField, z: IField) -> IField:
return x + 2 * y + 3 * z

@program
@gtx.program
def testee(
a: IField,
b: IField,
Expand Down Expand Up @@ -175,11 +175,11 @@ def test_call_scan_operator_from_field_operator(cartesian_case):
]:
pytest.xfail("Calling scan from field operator not fully supported.")

@scan_operator(axis=KDim, forward=True, init=0.0)
@gtx.scan_operator(axis=KDim, forward=True, init=0.0)
def testee_scan(state: float, x: float, y: float) -> float:
return state + x + 2.0 * y

@field_operator
@gtx.field_operator
def testee(a: IJKFloatField, b: IJKFloatField) -> IJKFloatField:
return (
testee_scan(a, b)
Expand All @@ -199,11 +199,11 @@ def testee(a: IJKFloatField, b: IJKFloatField) -> IJKFloatField:


def test_call_scan_operator_from_program(cartesian_case):
@scan_operator(axis=KDim, forward=True, init=0.0)
@gtx.scan_operator(axis=KDim, forward=True, init=0.0)
def testee_scan(state: float, x: float, y: float) -> float:
return state + x + 2.0 * y

@program
@gtx.program
def testee(
a: IJKFloatField,
b: IJKFloatField,
Expand Down Expand Up @@ -243,13 +243,13 @@ def test_scan_wrong_return_type(cartesian_case):
match=(r"Argument `init` to scan operator `testee_scan` must have same type as its return"),
):

@scan_operator(axis=KDim, forward=True, init=0)
@gtx.scan_operator(axis=KDim, forward=True, init=0)
def testee_scan(
state: int32,
) -> float:
return 1.0

@program
@gtx.program
def testee(qc: cases.IKFloatField, param_1: int32, param_2: float, scalar: float):
testee_scan(qc, param_1, param_2, scalar, out=(qc, param_1, param_2))

Expand All @@ -262,12 +262,26 @@ def test_scan_wrong_state_type(cartesian_case):
),
):

@scan_operator(axis=KDim, forward=True, init=0)
@gtx.scan_operator(axis=KDim, forward=True, init=0)
def testee_scan(
state: float,
) -> int32:
return 1

@program
@gtx.program
def testee(qc: cases.IKFloatField, param_1: int32, param_2: float, scalar: float):
testee_scan(qc, param_1, param_2, scalar, out=(qc, param_1, param_2))


def test_call_domain_from_field_operator(cartesian_case):
@gtx.field_operator(backend=cartesian_case.backend)
def fieldop_domain(a: cases.IField) -> cases.IField:
return a + a

a = cases.allocate(cartesian_case, fieldop_domain, "a")()
out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)()
fieldop_domain(a, out=out, offset_provider={}, domain={IDim: (1, 9)})
ref = a.array()[1:9] * 2
return_out = out.array()[1:9]

assert np.allclose(ref, return_out)