From 87b5bd50f24337d6dd57996e5f5dbf27a2ee4b8a Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 10 Jan 2025 16:55:06 +0100 Subject: [PATCH] add tuple_get --- .../dace_fieldview/gtir_python_codegen.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 4bdb602f5f..956a5c6435 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -73,18 +73,12 @@ } -def builtin_cast(*args: Any) -> str: - val, target_type = args +def builtin_cast(val: str, target_type: str) -> str: assert target_type in gtir.TYPEBUILTINS return MATH_BUILTINS_MAPPING[target_type].format(val) -def builtin_if(*args: Any) -> str: - cond, true_val, false_val = args - return f"{true_val} if {cond} else {false_val}" - - -def make_const_list(arg: str) -> str: +def builtin_const_list(arg: str) -> str: """ Takes a single scalar argument and broadcasts this value on the local dimension of map expression. In a dataflow, we represent it as a tasklet that writes @@ -93,10 +87,19 @@ def make_const_list(arg: str) -> str: return arg -GENERAL_BUILTIN_MAPPING: dict[str, Callable[[Any], str]] = { +def builtin_if(cond: str, true_val: str, false_val: str) -> str: + return f"{true_val} if {cond} else {false_val}" + + +def builtin_tuple_get(index: str, tuple_name: str) -> str: + return f"{tuple_name}_{index}" + + +GENERAL_BUILTIN_MAPPING: dict[str, Callable[..., str]] = { "cast_": builtin_cast, "if_": builtin_if, - "make_const_list": make_const_list, + "make_const_list": builtin_const_list, + "tuple_get": builtin_tuple_get, }