From 0bcf22f70cb74fe80ed64b14c219b1ab55ba5235 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Tue, 18 Jul 2023 14:19:21 +0200 Subject: [PATCH 01/10] some completed files --- src/gt4py/next/ffront/field_operator_ast.py | 5 +++++ src/gt4py/next/ffront/func_to_foast.py | 7 +++++++ src/gt4py/next/ffront/past_passes/type_deduction.py | 4 ++++ src/gt4py/next/type_system/type_translation.py | 2 ++ 4 files changed, 18 insertions(+) diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 6b772227b2..5852d30667 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -82,6 +82,11 @@ class Name(Expr): id: Coerced[SymbolRef] # noqa: A003 # shadowing a python builtin +class Dict(Expr): + keys_: list[Name] + values_: list[TupleExpr] + + class Constant(Expr): value: Any # TODO: be more specific diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 7ef4f597ab..f6b9f00d2f 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -323,6 +323,13 @@ def visit_Expr(self, node: ast.Expr) -> foast.Expr: def visit_Name(self, node: ast.Name, **kwargs) -> foast.Name: return foast.Name(id=node.id, location=self._make_loc(node)) + def visit_Dict(self, node: ast.Dict) -> foast.Dict: + return foast.Dict( + keys_=[self.visit(cast(ast.AST, param)) for param in node.keys], + values_=[self.visit(param) for param in node.values], + location=self._make_loc(node), + ) + def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs) -> foast.UnaryOp: return foast.UnaryOp( op=self.visit(node.op), operand=self.visit(node.operand), location=self._make_loc(node) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 4beb5dd8da..bca76da616 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -243,6 +243,10 @@ def visit_Call(self, node: past.Call, **kwargs): location=node.location, ) + def visit_Dict(self, node: past.Dict, **kwargs) -> past.Dict: + assert all(isinstance(key, past.Name) for key in node.keys_) + return past.Dict(keys_=node.keys_, values_=self.visit(node.values_), location=node.location) + def visit_Name(self, node: past.Name, **kwargs) -> past.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 3d054c0746..bf9d961a9a 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -183,6 +183,8 @@ def from_value(value: Any) -> ts.TypeSpec: f"Value `{value}` is out of range to be representable as `INT32` or `INT64`." ) return candidate_type + elif isinstance(value, dict): + return value elif isinstance(value, common.Dimension): symbol_type = ts.DimensionType(dim=value) elif isinstance(value, LocatedField): From 311ab954eccd1af2965eeb903db598ea10fda213 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 19 Jul 2023 09:18:27 +0200 Subject: [PATCH 02/10] some completed files --- src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/program_processors/runners/roundtrip.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 44a408e9f7..8972b38794 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -113,7 +113,7 @@ def visit_Program( # containing the size of all fields. The caller of a program is (e.g. # program decorator) is required to pass these arguments. - params = self.visit(node.params) + params = list(filter(lambda param: param.id != "domain", self.visit(node.params))) if any("domain" not in body_entry.kwargs for body_entry in node.body): params = params + self._gen_size_params_from_program(node) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 79b2ae1831..7b50e63507 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -201,6 +201,7 @@ def execute_roundtrip( *args, column_axis: Optional[common.Dimension] = None, offset_provider: dict[str, embedded.NeighborTableOffsetProvider], + domain: Optional[dict[common.Dimension, tuple]] = None, debug: bool = False, lift_mode: LiftMode = LiftMode.FORCE_INLINE, dispatch_backend: Optional[str] = None, From 5fc4c08818ef2280ed4b6f4209335013e7390c09 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 20 Jul 2023 09:27:23 +0200 Subject: [PATCH 03/10] removed edit for backend --- src/gt4py/next/program_processors/runners/roundtrip.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 7b50e63507..79b2ae1831 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -201,7 +201,6 @@ def execute_roundtrip( *args, column_axis: Optional[common.Dimension] = None, offset_provider: dict[str, embedded.NeighborTableOffsetProvider], - domain: Optional[dict[common.Dimension, tuple]] = None, debug: bool = False, lift_mode: LiftMode = LiftMode.FORCE_INLINE, dispatch_backend: Optional[str] = None, From 791c451369003e7190b4a5151dfa97486bb53bf8 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 20 Jul 2023 09:28:23 +0200 Subject: [PATCH 04/10] edits to type_info --- src/gt4py/next/type_system/type_info.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index ebc0921efe..d2c30b5634 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -565,6 +565,8 @@ def canonicalize_function_arguments( ignore_errors=False, use_signature_ordering=False, ) -> tuple[list, dict]: + if "domain" in func_type.pos_or_kw_args.keys(): + func_type.pos_or_kw_args.pop("domain") num_pos_params = len(func_type.pos_only_args) + len(func_type.pos_or_kw_args) cargs = [UNDEFINED_ARG] * max(num_pos_params, len(args)) ckwargs = {**kwargs} @@ -583,7 +585,7 @@ def canonicalize_function_arguments( ) a, b = set(func_type.kw_only_args.keys()), set(ckwargs.keys()) - invalid_kw_args = (a - b) | (b - a) + invalid_kw_args = (a - b) | (b - a) - {"domain"} if invalid_kw_args and (not ignore_errors or use_signature_ordering): # this error can not be ignored as otherwise the invariant that no arguments are dropped # is invalidated. @@ -640,10 +642,10 @@ def structural_function_signature_incompatibilities( yield f"Missing {len(missing_positional_args)} required positional argument{'s' if len(missing_positional_args) != 1 else ''}: {', '.join(missing_positional_args)}" # check for missing or extra keyword arguments - kw_a_m_b = set(func_type.kw_only_args.keys()) - set(kwargs.keys()) + kw_a_m_b = set(func_type.kw_only_args.keys()) - set(kwargs.keys()) - {"domain"} if len(kw_a_m_b) > 0: yield f"Missing required keyword argument{'s' if len(kw_a_m_b) != 1 else ''} `{'`, `'.join(kw_a_m_b)}`." - kw_b_m_a = set(kwargs.keys()) - set(func_type.kw_only_args.keys()) + kw_b_m_a = set(kwargs.keys()) - set(func_type.kw_only_args.keys()) - {"domain"} if len(kw_b_m_a) > 0: yield f"Got unexpected keyword argument{'s' if len(kw_b_m_a) != 1 else ''} `{'`, `'.join(kw_b_m_a)}`." From b8bdf36dfe12134bacf3f51f19c684792160b6be Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 20 Jul 2023 10:22:31 +0200 Subject: [PATCH 05/10] last edited files --- src/gt4py/next/ffront/decorator.py | 45 ++++++++++- .../ffront_tests/test_arg_call_interface.py | 78 +++++++++++++++---- 2 files changed, 106 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a4efd6c168..496f8892b4 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -280,7 +280,13 @@ def format_itir( def _validate_args(self, *args, **kwargs) -> None: arg_types = [type_translation.from_value(arg) for arg in args] - kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()} + kwarg_types = {} + for kwarg in kwargs: + if kwarg == "domain": + kwarg_types[kwarg] = kwargs[kwarg] + else: + for k, v in kwargs.items(): + kwarg_types[k] = type_translation.from_value(v) try: type_info.accepts_args( @@ -319,6 +325,8 @@ def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[s " tuple) need to have the same shape and dimensions." ) size_args.extend(shape if shape else [None] * len(dims)) + if "domain" in kwargs.keys(): + kwargs.pop("domain") return tuple(rewritten_args), tuple(size_args), kwargs @functools.cached_property @@ -484,6 +492,28 @@ def __gt_itir__(self) -> itir.FunctionDefinition: def __gt_closure_vars__(self) -> dict[str, Any]: return self.closure_vars + def _construct_domain(self, kwarg_types: dict, location: Any) -> past.Dict: + domain_keys = [] + domain_values = [] + for key in list(kwarg_types["domain"].keys()): + new_past_name = past.Name( + id=key.value, + location=location, + type=ts.DimensionType(dim=Dimension(value=key.value)), + ) + domain_keys.append(new_past_name) + for value in list(kwarg_types["domain"].values()): + value_0 = past.Constant( + value=value[0], type=ts.ScalarType(kind=ts.ScalarKind.INT64), location=location + ) + value_1 = past.Constant( + value=value[1], type=ts.ScalarType(kind=ts.ScalarKind.INT64), location=location + ) + new_past_tuple = past.TupleExpr(elts=[value_0, value_1], location=location) + domain_values.append(new_past_tuple) + domain_ref = past.Dict(keys_=domain_keys, values_=domain_values, location=location) + return domain_ref + def as_program( self, arg_types: list[ts.TypeSpec], kwarg_types: dict[str, ts.TypeSpec] ) -> Program: @@ -513,6 +543,15 @@ def as_program( location=loc, ) out_ref = past.Name(id="out", location=loc) + domain_sym: past.Symbol = past.DataSymbol( + id="domain", + type=ts.DeferredType(constraint=ts.DimensionType), + namespace=dialect_ast_enums.Namespace.LOCAL, + location=loc, + ) + kwargs_dict = {"out": out_ref} + if "domain" in kwarg_types.keys(): + kwargs_dict["domain"] = self._construct_domain(kwarg_types, loc) if self.foast_node.id in self.closure_vars: raise RuntimeError("A closure variable has the same name as the field operator itself.") @@ -529,12 +568,12 @@ def as_program( untyped_past_node = past.Program( id=f"__field_operator_{self.foast_node.id}", type=ts.DeferredType(constraint=ts_ffront.ProgramType), - params=params_decl + [out_sym], + params=params_decl + [out_sym] + [domain_sym], body=[ past.Call( func=past.Name(id=self.foast_node.id, location=loc), args=params_ref, - kwargs={"out": out_ref}, + kwargs=kwargs_dict, location=loc, ) ], diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index a42456b8d0..d586c55441 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -18,9 +18,9 @@ import numpy as np import pytest -from gt4py.next.common import Field -from gt4py.next.ffront.decorator import field_operator, program, scan_operator -from gt4py.next.ffront.fbuiltins import int64 +import gt4py.next as gtx +from gt4py.next.ffront.fbuiltins import int32 +from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu from next_tests.integration_tests import cases @@ -55,7 +55,7 @@ def _generate_arg_permutations( @pytest.mark.parametrize("arg_spec", _generate_arg_permutations(("a", "b", "c"))) def test_call_field_operator_from_python(cartesian_case, arg_spec: tuple[tuple[str], tuple[str]]): - @field_operator + @gtx.field_operator def testee(a: IField, b: IField, c: IField) -> IField: return a * 2 * b - c @@ -78,11 +78,11 @@ def testee(a: IField, b: IField, c: IField) -> IField: @pytest.mark.parametrize("arg_spec", _generate_arg_permutations(("a", "b", "out"))) def test_call_program_from_python(cartesian_case, arg_spec): - @field_operator + @gtx.field_operator def foo(a: IField, b: IField) -> IField: return a + 2 * b - @program + @gtx.program def testee(a: IField, b: IField, out: IField): foo(a, b, out=out) @@ -103,11 +103,11 @@ def testee(a: IField, b: IField, out: IField): def test_call_field_operator_from_field_operator(cartesian_case): - @field_operator + @gtx.field_operator def foo(x: IField, y: IField, z: IField): return x + 2 * y + 3 * z - @field_operator + @gtx.field_operator def testee(a: IField, b: IField, c: IField) -> IField: return foo(a, b, c) + 5 * foo(a, y=b, z=c) + 7 * foo(a, z=c, y=b) + 11 * foo(a, b, z=c) @@ -126,11 +126,11 @@ def testee_np(a, b, c): def test_call_field_operator_from_program(cartesian_case): - @field_operator + @gtx.field_operator def foo(x: IField, y: IField, z: IField) -> IField: return x + 2 * y + 3 * z - @program + @gtx.program def testee( a: IField, b: IField, @@ -172,11 +172,11 @@ def test_call_scan_operator_from_field_operator(cartesian_case): if cartesian_case.backend == dace_iterator.run_dace_iterator: pytest.xfail("Not supported in DaCe backend: scans") - @scan_operator(axis=KDim, forward=True, init=0.0) + @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan(state: float, x: float, y: float) -> float: return state + x + 2.0 * y - @field_operator + @gtx.field_operator def testee(a: IJKFloatField, b: IJKFloatField) -> IJKFloatField: return ( testee_scan(a, b) @@ -199,11 +199,11 @@ def test_call_scan_operator_from_program(cartesian_case): if cartesian_case.backend == dace_iterator.run_dace_iterator: pytest.xfail("Not supported in DaCe backend: scans") - @scan_operator(axis=KDim, forward=True, init=0.0) + @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan(state: float, x: float, y: float) -> float: return state + x + 2.0 * y - @program + @gtx.program def testee( a: IJKFloatField, b: IJKFloatField, @@ -235,3 +235,53 @@ def testee( ref=(ref, ref, ref, ref), comparison=lambda out, ref: all(map(np.allclose, zip(out, ref))), ) + + +def test_scan_wrong_return_type(cartesian_case): + with pytest.raises( + FieldOperatorTypeDeductionError, + match=(r"Argument `init` to scan operator `testee_scan` must have same type as its return"), + ): + + @gtx.scan_operator(axis=KDim, forward=True, init=0) + def testee_scan( + state: int32, + ) -> float: + return 1.0 + + @gtx.program + def testee(qc: cases.IKFloatField, param_1: int32, param_2: float, scalar: float): + testee_scan(qc, param_1, param_2, scalar, out=(qc, param_1, param_2)) + + +def test_scan_wrong_state_type(cartesian_case): + with pytest.raises( + FieldOperatorTypeDeductionError, + match=( + r"Argument `init` to scan operator `testee_scan` must have same type as `state` argument" + ), + ): + + @gtx.scan_operator(axis=KDim, forward=True, init=0) + def testee_scan( + state: float, + ) -> int32: + return 1 + + @gtx.program + def testee(qc: cases.IKFloatField, param_1: int32, param_2: float, scalar: float): + testee_scan(qc, param_1, param_2, scalar, out=(qc, param_1, param_2)) + + +def test_call_domain_from_field_operator(cartesian_case): + @gtx.field_operator(backend=cartesian_case.backend) + def fieldop_domain(a: cases.IField) -> cases.IField: + return a + a + + a = cases.allocate(cartesian_case, fieldop_domain, "a")() + out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)() + fieldop_domain(a, out=out, offset_provider={}, domain={IDim: (1, 9)}) + ref = a.array()[1:9] * 2 + return_out = out.array()[1:9] + + assert np.allclose(ref, return_out) From d5847d2960d54d0a36d433b0c184b66e7e30df16 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 20 Jul 2023 14:33:55 +0200 Subject: [PATCH 06/10] edited type_deduction.py --- .../next/ffront/past_passes/type_deduction.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index bca76da616..66eddad73c 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -15,7 +15,7 @@ from typing import Optional, cast from gt4py.eve import NodeTranslator, traits -from gt4py.next.common import GTTypeError +from gt4py.next.common import Dimension, GTTypeError from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, @@ -85,11 +85,19 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict): raise GTTypeError( f"Only 2 values allowed in domain range, but got `{len(domain_values.elts)}`." ) - if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( - domain_values.elts[1] + if not ( + _is_integral_scalar(domain_values.elts[0]) + or isinstance(domain_values.elts[0], (past.BinOp, past.Name)) ): raise GTTypeError( - f"Only integer values allowed in domain range, but got {domain_values.elts[0].type} and {domain_values.elts[1].type}." + f"Only integer values allowed in domain range, but got {domain_values.elts[0].type}." + ) + if not ( + _is_integral_scalar(domain_values.elts[1]) + or isinstance(domain_values.elts[1], (past.BinOp, past.Name)) + ): + raise GTTypeError( + f"Only integer values allowed in domain range, but got {domain_values.elts[1].type}." ) @@ -245,7 +253,17 @@ def visit_Call(self, node: past.Call, **kwargs): def visit_Dict(self, node: past.Dict, **kwargs) -> past.Dict: assert all(isinstance(key, past.Name) for key in node.keys_) - return past.Dict(keys_=node.keys_, values_=self.visit(node.values_), location=node.location) + new_keys = [ + past.Name( + id=key.id, type=ts.DimensionType(dim=Dimension(value=key.id)), location=key.location + ) + for key in node.keys_ + ] + new_values_ = [] + # for value in node.values_: + # values_elts = [elt(type=ts.ScalarType(kind=ts.ScalarKind.INT64)) for elt in value.elts] + # new_values_.append(past.TupleExpr(elts=values_elts, location=value.location)) + return past.Dict(keys_=new_keys, values_=node.values_, location=node.location) def visit_Name(self, node: past.Name, **kwargs) -> past.Name: symtable = kwargs["symtable"] From 69127d61ece6b654a6e36d1d8d03727bd12e74a4 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 20 Jul 2023 14:35:02 +0200 Subject: [PATCH 07/10] edited type_deduction.py --- src/gt4py/next/ffront/past_passes/type_deduction.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 66eddad73c..4aec88d93c 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -259,10 +259,6 @@ def visit_Dict(self, node: past.Dict, **kwargs) -> past.Dict: ) for key in node.keys_ ] - new_values_ = [] - # for value in node.values_: - # values_elts = [elt(type=ts.ScalarType(kind=ts.ScalarKind.INT64)) for elt in value.elts] - # new_values_.append(past.TupleExpr(elts=values_elts, location=value.location)) return past.Dict(keys_=new_keys, values_=node.values_, location=node.location) def visit_Name(self, node: past.Name, **kwargs) -> past.Name: From fccf08929270cc4e6effcf044e91595553d4e287 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Mon, 7 Aug 2023 14:30:14 +0200 Subject: [PATCH 08/10] small changes --- src/gt4py/next/ffront/decorator.py | 30 +++++++++---------- .../next/ffront/past_passes/type_deduction.py | 10 +++---- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index e356c3854d..5fb6d25745 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -279,14 +279,14 @@ def format_itir( ) def _validate_args(self, *args, **kwargs) -> None: + val_kwargs = {**kwargs} arg_types = [type_translation.from_value(arg) for arg in args] kwarg_types = {} for kwarg in kwargs: - if kwarg == "domain": + if isinstance(kwargs[kwarg], dict): kwarg_types[kwarg] = kwargs[kwarg] - else: - for k, v in kwargs.items(): - kwarg_types[k] = type_translation.from_value(v) + val_kwargs.pop(kwarg) + kwarg_types = {k: type_translation.from_value(v) for k, v in val_kwargs.items()} try: type_info.accepts_args( @@ -493,22 +493,20 @@ def __gt_closure_vars__(self) -> dict[str, Any]: def _construct_domain(self, kwarg_types: dict, location: Any) -> past.Dict: domain_keys = [] domain_values = [] - for key in list(kwarg_types["domain"].keys()): + for key_ls, vals_tup in list(kwarg_types["domain"].items()): new_past_name = past.Name( - id=key.value, + id=key_ls.value, location=location, - type=ts.DimensionType(dim=Dimension(value=key.value)), + type=ts.DimensionType(dim=Dimension(value=key_ls.value)), ) + elts_vals = [ + past.Constant( + value=val, type=ts.ScalarType(kind=ts.ScalarKind.INT64), location=location + ) + for val in vals_tup + ] domain_keys.append(new_past_name) - for value in list(kwarg_types["domain"].values()): - value_0 = past.Constant( - value=value[0], type=ts.ScalarType(kind=ts.ScalarKind.INT64), location=location - ) - value_1 = past.Constant( - value=value[1], type=ts.ScalarType(kind=ts.ScalarKind.INT64), location=location - ) - new_past_tuple = past.TupleExpr(elts=[value_0, value_1], location=location) - domain_values.append(new_past_tuple) + domain_values.append(past.TupleExpr(elts=elts_vals, location=location)) domain_ref = past.Dict(keys_=domain_keys, values_=domain_values, location=location) return domain_ref diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index e0830d55f8..ef57653cd8 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -15,7 +15,7 @@ from typing import Optional, cast from gt4py.eve import NodeTranslator, traits -from gt4py.next import errors +from gt4py.next import Dimension, errors from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, @@ -86,15 +86,15 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict): f"Only 2 values allowed in domain range, but got `{len(domain_values.elts)}`." ) if not ( - _is_integral_scalar(domain_values.elts[0]) - or isinstance(domain_values.elts[0], (past.BinOp, past.Name)) + _is_integral_scalar(domain_values.elts[0]) + or isinstance(domain_values.elts[0], (past.BinOp, past.Name)) ): raise ValueError( f"Only integer values allowed in domain range, but got {domain_values.elts[0].type}." ) if not ( - _is_integral_scalar(domain_values.elts[1]) - or isinstance(domain_values.elts[1], (past.BinOp, past.Name)) + _is_integral_scalar(domain_values.elts[1]) + or isinstance(domain_values.elts[1], (past.BinOp, past.Name)) ): raise ValueError( f"Only integer values allowed in domain range, but got {domain_values.elts[1].type}." From 5b21eca519bafc529ef4f96b2caae6442a667017 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Mon, 7 Aug 2023 14:33:34 +0200 Subject: [PATCH 09/10] removed unnecessary code --- src/gt4py/next/ffront/field_operator_ast.py | 5 ----- src/gt4py/next/ffront/func_to_foast.py | 7 ------- 2 files changed, 12 deletions(-) diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 5852d30667..6b772227b2 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -82,11 +82,6 @@ class Name(Expr): id: Coerced[SymbolRef] # noqa: A003 # shadowing a python builtin -class Dict(Expr): - keys_: list[Name] - values_: list[TupleExpr] - - class Constant(Expr): value: Any # TODO: be more specific diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 1ad237c167..082939c938 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -317,13 +317,6 @@ def visit_Expr(self, node: ast.Expr) -> foast.Expr: def visit_Name(self, node: ast.Name, **kwargs) -> foast.Name: return foast.Name(id=node.id, location=self.get_location(node)) - def visit_Dict(self, node: ast.Dict) -> foast.Dict: - return foast.Dict( - keys_=[self.visit(cast(ast.AST, param)) for param in node.keys], - values_=[self.visit(param) for param in node.values], - location=self._make_loc(node), - ) - def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs) -> foast.UnaryOp: return foast.UnaryOp( op=self.visit(node.op), From 946009a1a36b49ff9374091d4602e76427cb16ae Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Mon, 7 Aug 2023 15:10:41 +0200 Subject: [PATCH 10/10] additional decorator changes --- src/gt4py/next/ffront/decorator.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 5fb6d25745..2b97cb2761 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -497,12 +497,10 @@ def _construct_domain(self, kwarg_types: dict, location: Any) -> past.Dict: new_past_name = past.Name( id=key_ls.value, location=location, - type=ts.DimensionType(dim=Dimension(value=key_ls.value)), + type=type_translation.from_value(key_ls), ) elts_vals = [ - past.Constant( - value=val, type=ts.ScalarType(kind=ts.ScalarKind.INT64), location=location - ) + past.Constant(value=val, type=type_translation.from_value(val), location=location) for val in vals_tup ] domain_keys.append(new_past_name)