Skip to content

Commit

Permalink
Merge branch 'main' into field_arg_with_non_zero_domain_start
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber authored Jan 10, 2025
2 parents 30a4911 + 8040178 commit 052c54b
Show file tree
Hide file tree
Showing 12 changed files with 1,384 additions and 112 deletions.
7 changes: 0 additions & 7 deletions .github/workflows/daily-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@ jobs:
shell: bash
run: |
sudo apt install libboost-dev
wget https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.gz
echo 7bd7ddceec1a1dfdcbdb3e609b60d01739c38390a5f956385a12f3122049f0ca boost_1_76_0.tar.gz > boost_hash.txt
sha256sum -c boost_hash.txt
tar xzf boost_1_76_0.tar.gz
mkdir -p boost/include
mv boost_1_76_0/boost boost/include/
echo "BOOST_ROOT=${PWD}/boost" >> $GITHUB_ENV
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
Expand Down
10 changes: 2 additions & 8 deletions .github/workflows/test-cartesian.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,10 @@ jobs:
tox-factor: [internal, dace]
steps:
- uses: actions/checkout@v4
- name: Install boost
- name: Install C++ libraries
shell: bash
run: |
wget https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.gz
echo 7bd7ddceec1a1dfdcbdb3e609b60d01739c38390a5f956385a12f3122049f0ca boost_1_76_0.tar.gz > boost_hash.txt
sha256sum -c boost_hash.txt
tar xzf boost_1_76_0.tar.gz
mkdir -p boost/include
mv boost_1_76_0/boost boost/include/
echo "BOOST_ROOT=${PWD}/boost" >> $GITHUB_ENV
sudo apt install libboost-dev
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/field_operator_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class IfStmt(Stmt):
@datamodels.root_validator
@classmethod
def _collect_common_symbols(cls: type[IfStmt], instance: IfStmt) -> None:
common_symbol_names = (
common_symbol_names = sorted( # sort is required to get stable results across runs
instance.true_branch.annex.symtable.keys() & instance.false_branch.annex.symtable.keys()
)
instance.annex.propagated_symbols = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias

import dace
import dace.subsets as sbs
from dace import subsets as dace_subsets

from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.ffront import fbuiltins as gtx_fbuiltins
Expand All @@ -30,7 +30,7 @@
gtir_python_codegen,
utility as dace_gtir_utils,
)
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.type_system import type_info as ti, type_specifications as ts


if TYPE_CHECKING:
Expand All @@ -39,7 +39,7 @@

def _get_domain_indices(
dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None
) -> sbs.Indices:
) -> dace_subsets.Indices:
"""
Helper function to construct the list of indices for a field domain, applying
an optional offset in each dimension as start index.
Expand All @@ -55,9 +55,9 @@ def _get_domain_indices(
"""
index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims]
if offsets is None:
return sbs.Indices(index_variables)
return dace_subsets.Indices(index_variables)
else:
return sbs.Indices(
return dace_subsets.Indices(
[
index - offset if offset != 0 else index
for index, offset in zip(index_variables, offsets, strict=True)
Expand Down Expand Up @@ -96,7 +96,7 @@ def get_local_view(
"""Helper method to access a field in local view, given the compute domain of a field operator."""
if isinstance(self.gt_type, ts.ScalarType):
return gtir_dataflow.MemletExpr(
dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0])
dc_node=self.dc_node, gt_dtype=self.gt_type, subset=dace_subsets.Indices([0])
)

if isinstance(self.gt_type, ts.FieldType):
Expand Down Expand Up @@ -263,7 +263,7 @@ def _create_field_operator(

dataflow_output_desc = output_edge.result.dc_node.desc(sdfg)

field_subset = sbs.Range.from_indices(field_indices)
field_subset = dace_subsets.Range.from_indices(field_indices)
if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
assert output_edge.result.gt_dtype == node_type.dtype
assert isinstance(dataflow_output_desc, dace.data.Scalar)
Expand All @@ -280,7 +280,7 @@ def _create_field_operator(
field_dims.append(output_edge.result.gt_dtype.offset_type)
field_shape.extend(dataflow_output_desc.shape)
field_offset.extend(dataflow_output_desc.offset)
field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc)
field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc)

# allocate local temporary storage
field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype)
Expand Down Expand Up @@ -366,36 +366,37 @@ def translate_as_fieldop(
"""
assert isinstance(node, gtir.FunCall)
assert cpm.is_call_to(node.fun, "as_fieldop")
assert isinstance(node.type, ts.FieldType)

fun_node = node.fun
assert len(fun_node.args) == 2
stencil_expr, domain_expr = fun_node.args
fieldop_expr, domain_expr = fun_node.args

if isinstance(stencil_expr, gtir.Lambda):
# Default case, handled below: the argument expression is a lambda function
# representing the stencil operation to be computed over the field domain.
pass
elif cpm.is_ref_to(stencil_expr, "deref"):
assert isinstance(node.type, ts.FieldType)
if cpm.is_ref_to(fieldop_expr, "deref"):
# Special usage of 'deref' as argument to fieldop expression, to pass a scalar
# value to 'as_fieldop' function. It results in broadcasting the scalar value
# over the field domain.
stencil_expr = im.lambda_("a")(im.deref("a"))
stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined]
stencil_expr.expr.type = node.type.dtype
elif isinstance(fieldop_expr, gtir.Lambda):
# Default case, handled below: the argument expression is a lambda function
# representing the stencil operation to be computed over the field domain.
stencil_expr = fieldop_expr
else:
raise NotImplementedError(
f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node."
f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node."
)

# parse the domain of the field operator
domain = extract_domain(domain_expr)

# visit the list of arguments to be passed to the lambda expression
stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args]
fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args]

# represent the field operator as a mapped tasklet graph, which will range over the field domain
taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder)
input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args)
input_edges, output_edge = gtir_dataflow.visit_lambda(
sdfg, state, sdfg_builder, stencil_expr, fieldop_args
)

return _create_field_operator(
sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge
Expand Down Expand Up @@ -654,7 +655,7 @@ def translate_tuple_get(

if not isinstance(node.args[0], gtir.Literal):
raise ValueError("Tuple can only be subscripted with compile-time constants.")
assert node.args[0].type == dace_utils.as_itir_type(INDEX_DTYPE)
assert ti.is_integral(node.args[0].type)
index = int(node.args[0].value)

data_nodes = sdfg_builder.visit(
Expand Down
Loading

0 comments on commit 052c54b

Please sign in to comment.