Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 17, 2025
1 parent d9717ac commit ca246d6
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,12 @@ def _create_scan_field_operator_impl(
domain_subset = dace_subsets.Range.from_indices(domain_indices)

if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
assert output_edge.result.gt_dtype == output_type.dtype
field_dtype = output_edge.result.gt_dtype
field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset)
# the scan field operator produces a 1D vertical field
assert isinstance(dataflow_output_desc, dace.data.Array)
assert len(dataflow_output_desc.shape) == 1
assert output_edge.result.gt_dtype == output_type.dtype
field_dtype = output_edge.result.gt_dtype
field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset)
# the vertical dimension should not belong to the field operator domain
# but we need to write it to the output field
scan_dim_index = domain_dims.index(scan_dim)
Expand Down Expand Up @@ -952,11 +952,10 @@ def translate_scan(
scan_domain = [
(dim, lower_bound, upper_bound)
for dim, lower_bound, upper_bound in domain
if dim.kind == gtx_common.DimensionKind.VERTICAL
if sdfg_builder.is_column_axis(dim)
]
assert len(scan_domain) == 1
scan_dim, scan_lower_bound, scan_upper_bound = scan_domain[0]
assert sdfg_builder.is_column_axis(scan_dim)

# parse scan parameters
assert len(scan_expr.args) == 3
Expand All @@ -970,8 +969,14 @@ def translate_scan(
assert isinstance(scan_expr.args[1], gtir.Literal) and ti.is_logical(scan_expr.args[1].type)
scan_forward = scan_expr.args[1].value == "True"

# params[2]: the value for scan initialization
init_value = scan_expr.args[2]
# params[2]: the expression that computes the value for scan initialization
init_expr = scan_expr.args[2]
# visit the initialization value of the scan expression
init_data = sdfg_builder.visit(init_expr, sdfg=sdfg, head_state=state)
# extract type definition of the scan carry
scan_carry_type = (
init_data.gt_type if isinstance(init_data, FieldopData) else get_tuple_type(init_data)
)

# make naming consistent throughout this function scope
def scan_input_name(input_name: str) -> str:
Expand All @@ -980,20 +985,14 @@ def scan_input_name(input_name: str) -> str:
def scan_output_name(input_name: str) -> str:
    """Return the mangled name used for the scan output data corresponding to `input_name`."""
    return "__gtir_scan_output_" + input_name

# visit the initialization value of the scan expression
init_data = sdfg_builder.visit(init_value, sdfg=sdfg, head_state=state)

# extract type definition of the scan carry
scan_carry_type = (
init_data.gt_type if isinstance(init_data, FieldopData) else get_tuple_type(init_data)
)

# define the set of symbols available in the lambda context, which consists of
# the carry argument and all lambda function arguments
lambda_symbols = {scan_carry: scan_carry_type} | {
str(p.id): arg.type
for p, arg in zip(stencil_expr.params[1:], node.args, strict=True)
if isinstance(arg.type, ts.DataType)
lambda_arg_types = [scan_carry_type] + [
arg.type for arg in node.args if isinstance(arg.type, ts.DataType)
]
lambda_symbols = {
str(p.id): arg_type
for p, arg_type in zip(stencil_expr.params, lambda_arg_types, strict=True)
}

# visit the arguments to be passed to the lambda expression
Expand All @@ -1008,9 +1007,10 @@ def scan_output_name(input_name: str) -> str:

# parse the dataflow input and output symbols
lambda_flat_args: dict[str, FieldopData] = {}
# the field offset is set to `None` when it is zero in all dimensions
lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {}
for param, arg in lambda_args_mapping.items():
tuple_fields = flatten_tuples(param, arg)
for param, outer_arg in lambda_args_mapping.items():
tuple_fields = flatten_tuples(param, outer_arg)
lambda_field_offsets |= {tsym: tfield.offset for tsym, tfield in tuple_fields}
lambda_flat_args |= dict(tuple_fields)
if isinstance(scan_carry_type, ts.TupleType):
Expand All @@ -1030,9 +1030,26 @@ def scan_output_name(input_name: str) -> str:
stencil_expr, nsdfg, lambda_symbols, lambda_field_offsets
)

# in case the scan operator computes a list (not a scalar), we need to add an extra dimension
def get_scan_output_shape(scan_init_data: FieldopData) -> list[dace.symbolic.SymExpr]:
    """Compute the shape of the scan output column for one carry element.

    A scalar carry produces a 1D column over the scan range; a list carry
    produces a 2D array whose second dimension is the neighbor-list size
    of the carried offset type.
    """
    column_size = scan_upper_bound - scan_lower_bound
    if isinstance(scan_init_data.gt_type, ts.ScalarType):
        # one scalar value per vertical level
        return [column_size]
    # list carry: look up the neighbor-table width from the offset provider
    assert isinstance(scan_init_data.gt_type, ts.ListType)
    offset_type = scan_init_data.gt_type.offset_type
    assert offset_type
    provider_type = sdfg_builder.get_offset_provider_type(offset_type.value)
    assert isinstance(provider_type, gtx_common.NeighborConnectivityType)
    return [column_size, dace.symbolic.SymExpr(provider_type.max_neighbors)]

if isinstance(init_data, tuple):
lambda_result_shape = gtx_utils.tree_map(get_scan_output_shape)(init_data)
else:
lambda_result_shape = get_scan_output_shape(init_data)

# extract the scan loop range
scan_loop_var = dace_gtir_utils.get_map_variable(scan_dim)
_, scan_output_offset, scan_output_shape = _get_field_layout(scan_domain)

# create a loop region for lambda call over the scan dimension
if scan_forward:
Expand All @@ -1058,8 +1075,7 @@ def scan_output_name(input_name: str) -> str:
init_state = nsdfg.add_state("scan_init", is_start_block=True)
nsdfg.add_edge(init_state, scan_loop, dace.InterstateEdge())
compute_state = scan_loop.add_state("scan_compute")
update_state = scan_loop.add_state("scan_update")
scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge())
update_state = scan_loop.add_state_after(compute_state, "scan_update")

# visit the list of arguments to be passed to the scan expression
stencil_args = [
Expand All @@ -1070,7 +1086,7 @@ def scan_output_name(input_name: str) -> str:
]

# generate the dataflow representing the scan field operator
input_edges, result = gtir_dataflow.translate_lambda_to_dataflow(
lambda_input_edges, lambda_result = gtir_dataflow.translate_lambda_to_dataflow(
nsdfg, compute_state, lambda_translator, stencil_expr, args=stencil_args
)

Expand Down Expand Up @@ -1101,25 +1117,40 @@ def init_scan_carry(sym: gtir.Sym) -> None:

# connect the dataflow input directly to the source data nodes, without passing through a map node;
# the reason is that the map for horizontal domain is outside the scan loop region
for edge in input_edges:
for edge in lambda_input_edges:
edge.connect(map_entry=None)

# connect the dataflow result nodes to the carry variables
output_column_index = dace.symbolic.pystr_to_symbolic(scan_loop_var) - scan_lower_bound

def connect_scan_output(
scan_output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym
scan_output_edge: gtir_dataflow.DataflowOutputEdge,
scan_output_shape: list[dace.symbolic.SymExpr],
sym: gtir.Sym,
) -> FieldopData:
scan_result = scan_output_edge.result
assert isinstance(scan_result.gt_dtype, ts.ScalarType)
if isinstance(scan_result.gt_dtype, ts.ScalarType):
assert scan_result.gt_dtype == sym.type
# the scan field operator computes a column of scalar values
assert len(scan_output_shape) == 1
output_subset = dace_subsets.Range.from_string(str(output_column_index))
else:
assert isinstance(sym.type, ts.ListType)
assert scan_result.gt_dtype.element_type == sym.type.element_type
# the scan field operator computes a list of scalar values for each column level
assert len(scan_output_shape) == 2
output_subset = dace_subsets.Range.from_string(
f"{output_column_index}, 0:{scan_output_shape[1]}"
)
scan_result_data = scan_result.dc_node.data
scan_result_desc = scan_result.dc_node.desc(nsdfg)

# `sym` represents the global output data, that is the lambda output connector
# `sym` represents the global output data, that is the nested-SDFG output connector
lambda_output = str(sym.id)
output = scan_output_name(lambda_output)
assert scan_result.gt_dtype == sym.type
nsdfg.add_array(output, scan_output_shape, scan_result_desc.dtype)
output_node = compute_state.add_access(output)
output_subset = str(dace.symbolic.pystr_to_symbolic(scan_loop_var) - scan_lower_bound)

compute_state.add_nedge(
scan_result.dc_node, output_node, dace.Memlet(data=output, subset=output_subset)
)
Expand All @@ -1131,14 +1162,18 @@ def connect_scan_output(
)

output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype)
return FieldopData(output_node, output_type, scan_output_offset)
return FieldopData(output_node, output_type, offset=scan_lower_bound)

if isinstance(result, tuple):
assert isinstance(scan_carry_input, tuple)
lambda_output = gtx_utils.tree_map(connect_scan_output)(result, scan_carry_input)
if isinstance(scan_carry_input, tuple):
assert isinstance(lambda_result, tuple)
assert isinstance(lambda_result_shape, tuple)
lambda_output = gtx_utils.tree_map(connect_scan_output)(
lambda_result, lambda_result_shape, scan_carry_input
)
else:
assert isinstance(scan_carry_input, gtir.Sym)
lambda_output = connect_scan_output(result, scan_carry_input)
assert isinstance(lambda_result, gtir_dataflow.DataflowOutputEdge)
assert isinstance(lambda_result_shape, list)
lambda_output = connect_scan_output(lambda_result, lambda_result_shape, scan_carry_input)

# in case tuples are passed as argument, isolated access nodes might be left in the state,
# because not all tuple fields are necessarily accessed inside the lambda scope
Expand All @@ -1154,9 +1189,9 @@ def connect_scan_output(

# build the mapping of symbols from nested SDFG to parent SDFG
nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols}
for inner, arg in lambda_flat_args.items():
inner_desc = nsdfg.data(inner)
outer_desc = arg.dc_node.desc(sdfg)
for inner_dataname, outer_arg in lambda_flat_args.items():
inner_desc = nsdfg.data(inner_dataname)
outer_desc = outer_arg.dc_node.desc(sdfg)
nsdfg_symbols_mapping |= {
str(nested_symbol): parent_symbol
for nested_symbol, parent_symbol in zip(
Expand All @@ -1176,14 +1211,14 @@ def connect_scan_output(
symbol_mapping=nsdfg_symbols_mapping,
)

input_edges = []
for input_connector, arg in lambda_flat_args.items():
arg_desc = arg.dc_node.desc(sdfg)
lambda_input_edges = []
for input_connector, outer_arg in lambda_flat_args.items():
arg_desc = outer_arg.dc_node.desc(sdfg)
input_subset = dace_subsets.Range.from_array(arg_desc)
input_edge = gtir_dataflow.MemletInputEdge(
state, arg.dc_node, input_subset, nsdfg_node, input_connector
state, outer_arg.dc_node, input_subset, nsdfg_node, input_connector
)
input_edges.append(input_edge)
lambda_input_edges.append(input_edge)

def construct_output_edge(scan_data: FieldopData) -> gtir_dataflow.DataflowOutputEdge:
assert isinstance(scan_data.gt_type, ts.FieldType)
Expand All @@ -1208,7 +1243,7 @@ def construct_output_edge(scan_data: FieldopData) -> gtir_dataflow.DataflowOutpu
)

return _create_field_operator(
sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edges, scan_dim
sdfg, state, domain, node.type, sdfg_builder, lambda_input_edges, output_edges, scan_dim
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,16 @@ def nested_context(
field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]],
) -> SDFGBuilder:
"""
Create a new context for lowering of an expression in a nested SDFG.
Create an SDFG context to translate a nested expression, independent
from the current context where the parent expression is being translated.
This method will setup the global symbols, that correspond to the parameters
of the expression to be lowered, as well as the set of symbolic arguments,
that is scalar values used in internal domain expressions.
Args:
expr: The GTIR expression to be lowered.
sdfg: The SDFG where to lower the expression.
expr: The nested expression to be lowered.
sdfg: The SDFG where to lower the nested expression.
global_symbols: Mapping from symbol name to GTIR data type.
field_offsets: Mapping from symbol name to field origin, `None` if field origin is 0 in all dimensions.
Expand Down

0 comments on commit ca246d6

Please sign in to comment.