Skip to content

Commit

Permalink
minor edit
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Dec 20, 2024
1 parent f2396c4 commit a0c37cb
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,19 @@ def get_local_view(

def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType:
"""
Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes.
Compute the `ts.TupleType` corresponding to the structure of a tuple of `FieldopResult`.
"""
return ts.TupleType(
types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data]
)


def flatten_tuples(name: str, arg: FieldopResult) -> list[tuple[str, FieldopData]]:
"""
Visit a `FieldopResult`, potentially containing nested tuples, and construct a list
of pairs `(str, FieldopData)` containing the symbol name of each tuple field and
the corresponding `FieldopData`.
"""
if isinstance(arg, tuple):
tuple_type = get_tuple_type(arg)
tuple_symbols = dace_gtir_utils.flatten_tuple_fields(name, tuple_type)
Expand Down Expand Up @@ -337,9 +342,6 @@ def _create_field_operator(

# here we setup the edges passing through the map entry node
for edge in input_edges:
if isinstance(edge, gtir_dataflow.EmptyInputEdge) and me is None:
# cannot create empty edge from MapEntry node, if this is not present
continue
edge.connect(me)

def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) -> FieldopData:
Expand Down Expand Up @@ -402,6 +404,7 @@ def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) -
assert isinstance(node_type, ts.FieldType)
return create_field(output_edges, im.sym("x", node_type))
else:
# handle tuples of fields
assert isinstance(node_type, ts.TupleType)
return gtx_utils.tree_map(create_field)(
output_edges, dace_gtir_utils.make_symbol_tuple("x", node_type)
Expand Down Expand Up @@ -888,8 +891,8 @@ def translate_scan(
stencil_expr = scan_expr.args[0]
assert isinstance(stencil_expr, gtir.Lambda)

# params[0]: the lambda parameter to propagate the scan state on the vertical dimension
scan_state = str(stencil_expr.params[0].id)
# params[0]: the lambda parameter to propagate the scan carry on the vertical dimension
scan_carry = str(stencil_expr.params[0].id)

# params[1]: boolean flag for forward/backward scan
assert isinstance(scan_expr.args[1], gtir.Literal) and ti.is_logical(scan_expr.args[1].type)
Expand All @@ -908,13 +911,13 @@ def scan_output_name(input_name: str) -> str:
# visit the initialization value of the scan expression
init_data = sdfg_builder.visit(init_value, sdfg=sdfg, head_state=state)

# extract type definition of the scan state
scan_state_type = (
# extract type definition of the scan carry
scan_carry_type = (
init_data.gt_type if isinstance(init_data, FieldopData) else get_tuple_type(init_data)
)

# create list of params to the lambda function with associated node type
lambda_symbols = {scan_state: scan_state_type} | {
lambda_symbols = {scan_carry: scan_carry_type} | {
str(p.id): arg.type
for p, arg in zip(stencil_expr.params[1:], node.args, strict=True)
if isinstance(arg.type, ts.DataType)
Expand All @@ -925,7 +928,7 @@ def scan_output_name(input_name: str) -> str:
# the data descriptor with the correct field domain offsets for field arguments
lambda_args = [sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) for arg in node.args]
lambda_args_mapping = {
scan_input_name(scan_state): init_data,
scan_input_name(scan_carry): init_data,
} | {
str(param.id): arg for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True)
}
Expand All @@ -941,11 +944,11 @@ def scan_output_name(input_name: str) -> str:
{
str(sym.id): sym.type
for sym in dace_gtir_utils.flatten_tuple_fields(
scan_output_name(scan_state), scan_state_type
scan_output_name(scan_carry), scan_carry_type
)
}
if isinstance(scan_state_type, ts.TupleType)
else {scan_output_name(scan_state): scan_state_type}
if isinstance(scan_carry_type, ts.TupleType)
else {scan_output_name(scan_carry): scan_carry_type}
)

# the scan operator is implemented as an nested SDFG implementing the lambda expression
Expand Down Expand Up @@ -998,14 +1001,14 @@ def scan_output_name(input_name: str) -> str:
nsdfg, compute_state, stencil_builder, stencil_expr, args=stencil_args
)

# now initialize the scan state
scan_state_input = (
dace_gtir_utils.make_symbol_tuple(scan_state, scan_state_type)
if isinstance(scan_state_type, ts.TupleType)
else im.sym(scan_state, scan_state_type)
# now initialize the scan carry
scan_carry_input = (
dace_gtir_utils.make_symbol_tuple(scan_carry, scan_carry_type)
if isinstance(scan_carry_type, ts.TupleType)
else im.sym(scan_carry, scan_carry_type)
)

def init_scan_state(sym: gtir.Sym) -> None:
def init_scan_carry(sym: gtir.Sym) -> None:
scan_state = str(sym.id)
scan_state_desc = nsdfg.data(scan_state)
input_state = scan_input_name(scan_state)
Expand All @@ -1018,17 +1021,17 @@ def init_scan_state(sym: gtir.Sym) -> None:
nsdfg.make_array_memlet(input_state),
)

if isinstance(scan_state_input, tuple):
gtx_utils.tree_map(init_scan_state)(scan_state_input)
if isinstance(scan_carry_input, tuple):
gtx_utils.tree_map(init_scan_carry)(scan_carry_input)
else:
init_scan_state(scan_state_input)
init_scan_carry(scan_carry_input)

# connect the dataflow input directly to the source data nodes, without passing through a map node;
# the reason is that the map for horizontal domain is outside the scan loop region
for edge in input_edges:
edge.connect(map_entry=None)

# connect the dataflow result nodes to the variables that carry the scan state along the column axis
# connect the dataflow result nodes to the carry variables
def connect_scan_output(
scan_output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym
) -> FieldopData:
Expand Down Expand Up @@ -1057,12 +1060,12 @@ def connect_scan_output(
return FieldopData(output_node, output_type, scan_output_offset)

lambda_output = (
gtx_utils.tree_map(connect_scan_output)(result, scan_state_input)
if (isinstance(result, tuple) and isinstance(scan_state_input, tuple))
else connect_scan_output(result, scan_state_input)
gtx_utils.tree_map(connect_scan_output)(result, scan_carry_input)
if (isinstance(result, tuple) and isinstance(scan_carry_input, tuple))
else connect_scan_output(result, scan_carry_input)
if (
isinstance(result, gtir_dataflow.DataflowOutputEdge)
and isinstance(scan_state_input, gtir.Sym)
and isinstance(scan_carry_input, gtir.Sym)
)
else None
)
Expand All @@ -1075,8 +1078,8 @@ def connect_scan_output(
if (compute_state.degree(data_node) == 0) and (
(not data_desc.transient)
or data_node.data.startswith(
scan_state
) # exceptional case where the state is not used, not a scan indeed
scan_carry
) # exceptional case where the carry variable is not used, not a scan indeed
):
# isolated node, connect it to a transient to avoid SDFG validation errors
temp, temp_desc = nsdfg.add_temp_transient_like(data_desc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class EmptyInputEdge(DataflowInputEdge):
node: dace.nodes.Tasklet

def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None:
# cannot create empty edge from MapEntry node, if this is not present
if map_entry is not None:
self.state.add_nedge(map_entry, self.node, dace.Memlet())

Expand Down Expand Up @@ -564,7 +565,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr:
def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]:
assert len(node.args) == 3

# TODO(edopao): enable once DaCe supports it in next release
# TODO(edopao): enable once supported in next DaCe release
use_conditional_block: Final[bool] = False

condition_value = self.visit(node.args[0])
Expand Down Expand Up @@ -690,6 +691,7 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr:
lambda_args.append(inner_arg)
lambda_params.append(im.sym(p))

# visit each branch of the if-statement as it was a Lambda node
lambda_node = gtir.Lambda(params=lambda_params, expr=expr)
return visit_lambda(nsdfg, state, self.subgraph_builder, lambda_node, lambda_args)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_map_variable(dim: gtx_common.Dimension) -> str:
def make_symbol_tuple(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.Sym, ...]:
"""
Creates a tuple representation of the symbols corresponding to the tuple fields.
The constructed tuple presrves the nested nature of the type, is any.
The constructed tuple preserves the nested nature of the type, if any.
Examples
--------
Expand All @@ -53,7 +53,7 @@ def make_symbol_tuple(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.S

def flatten_tuple_fields(tuple_name: str, tuple_type: ts.TupleType) -> list[gtir.Sym]:
"""
Creates a list of names with the corresponding data type for all elements of the given tuple.
Creates a list of symbols, annotated with the data type, for all elements of the given tuple.
Examples
--------
Expand Down

0 comments on commit a0c37cb

Please sign in to comment.