Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 13, 2025
1 parent eaaee4e commit d4599d2
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def decorated_program(
args = (*args, out)
flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args))
if len(sdfg.arg_names) > len(flat_args):
# The Ahead-of-Time (AOT) workflow for FieldView programs requires domain size arguments.
flat_args = (*flat_args, *arguments.iter_size_args(args))

if sdfg_program._lastargs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,72 @@ def _get_field_layout(
return list(domain_dims), list(domain_lbs), domain_sizes


def _create_field_operator_impl(
sdfg: dace.SDFG,
state: dace.SDFGState,
domain: FieldopDomain,
output_edge: gtir_dataflow.DataflowOutputEdge,
output_type: ts.FieldType,
map_exit: dace.nodes.MapExit,
) -> FieldopData:
"""
Helper method to allocate a temporary array that stores one field computed by a field operator.
This method is called by `_create_field_operator()`.
Args:
sdfg: The SDFG that represents the scope of the field data.
state: The SDFG state where to create an access node to the field data.
domain: The domain of the field operator that computes the field.
output_edge: The dataflow write edge representing the output data.
output_type: The GT4Py field type descriptor.
map_exit: The `MapExit` node of the field operator map scope.
Returns:
The field data descriptor, which includes the field access node in the given `state`
and the field domain offset.
"""
dataflow_output_desc = output_edge.result.dc_node.desc(sdfg)

domain_dims, domain_offset, domain_shape = _get_field_layout(domain)
domain_indices = _get_domain_indices(domain_dims, domain_offset)
domain_subset = dace_subsets.Range.from_indices(domain_indices)

if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
assert output_edge.result.gt_dtype == output_type.dtype
field_dtype = output_edge.result.gt_dtype
field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset)
assert isinstance(dataflow_output_desc, dace.data.Scalar)
field_subset = domain_subset
else:
assert isinstance(output_type.dtype, ts.ListType)
assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType)
assert output_edge.result.gt_dtype.element_type == output_type.dtype.element_type
field_dtype = output_edge.result.gt_dtype.element_type
assert isinstance(dataflow_output_desc, dace.data.Array)
assert len(dataflow_output_desc.shape) == 1
# extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
assert output_edge.result.gt_dtype.offset_type is not None
field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type]
field_shape = [*domain_shape, dataflow_output_desc.shape[0]]
field_offset = [*domain_offset, dataflow_output_desc.offset[0]]
field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc)

# allocate local temporary storage
assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype)
field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype)
field_node = state.add_access(field_name)

# and here the edge writing the dataflow result data through the map exit node
output_edge.connect(map_exit, field_node, field_subset)

return FieldopData(
field_node,
ts.FieldType(field_dims, field_dtype),
offset=(field_offset if set(field_offset) != {0} else None),
)


def _create_field_operator(
sdfg: dace.SDFG,
state: dace.SDFGState,
Expand All @@ -275,7 +341,10 @@ def _create_field_operator(
| tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...],
) -> FieldopResult:
"""
Helper method to allocate a temporary field to store the output of a field operator.
Helper method to build the output of a field operator, which can consist of
a single field or, in case of scan, a tuple of fields.
The scan field operator typically computes multiple fields for each K-level:
for each field, this method will call `_create_field_operator_impl()`.
Args:
sdfg: The SDFG that represents the scope of the field data.
Expand All @@ -287,55 +356,12 @@ def _create_field_operator(
output_edges: Single edge or tuple of edges representing the dataflow output data.
Returns:
The field data descriptor, which includes the field access node in the given `state`
and the field domain offset.
The descriptor of the field operator result, which can be either a single field
or a tuple fields.
"""

def _create_field_operator_impl(
output_edge: gtir_dataflow.DataflowOutputEdge, mx: dace.nodes.MapExit, sym: gtir.Sym
) -> FieldopData:
assert isinstance(sym.type, ts.FieldType)
dataflow_output_desc = output_edge.result.dc_node.desc(sdfg)
if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
assert output_edge.result.gt_dtype == sym.type.dtype
field_dtype = output_edge.result.gt_dtype
field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset)
assert isinstance(dataflow_output_desc, dace.data.Scalar)
field_subset = domain_subset
else:
assert isinstance(sym.type.dtype, ts.ListType)
assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType)
assert output_edge.result.gt_dtype.element_type == sym.type.dtype.element_type
field_dtype = output_edge.result.gt_dtype.element_type
assert isinstance(dataflow_output_desc, dace.data.Array)
assert len(dataflow_output_desc.shape) == 1
# extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
assert output_edge.result.gt_dtype.offset_type is not None
field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type]
field_shape = [*domain_shape, dataflow_output_desc.shape[0]]
field_offset = [*domain_offset, dataflow_output_desc.offset[0]]
field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc)

# allocate local temporary storage
assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype)
field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype)
field_node = state.add_access(field_name)

# and here the edge writing the dataflow result data through the map exit node
output_edge.connect(mx, field_node, field_subset)

return FieldopData(
field_node,
ts.FieldType(field_dims, field_dtype),
offset=(field_offset if set(field_offset) != {0} else None),
)

domain_dims, domain_offset, domain_shape = _get_field_layout(domain)
domain_indices = _get_domain_indices(domain_dims, domain_offset)
domain_subset = dace_subsets.Range.from_indices(domain_indices)

# create map range corresponding to the field operator domain
me, mx = sdfg_builder.add_map(
map_entry, map_exit = sdfg_builder.add_map(
"fieldop",
state,
ndrange={
Expand All @@ -346,17 +372,19 @@ def _create_field_operator_impl(

# here we setup the edges passing through the map entry node
for edge in input_edges:
edge.connect(me)
edge.connect(map_entry)

if isinstance(output_edges, gtir_dataflow.DataflowOutputEdge):
assert isinstance(node_type, ts.FieldType)
return _create_field_operator_impl(output_edges, mx, im.sym("x", node_type))
if isinstance(node_type, ts.FieldType):
assert isinstance(output_edges, gtir_dataflow.DataflowOutputEdge)
return _create_field_operator_impl(sdfg, state, domain, output_edges, node_type, map_exit)
else:
# handle tuples of fields
assert isinstance(node_type, ts.TupleType)
tuple_syms = dace_gtir_utils.make_symbol_tuple("x", node_type)
return gtx_utils.tree_map(
lambda output_edge, sym: _create_field_operator_impl(output_edge, mx, sym)
)(output_edges, dace_gtir_utils.make_symbol_tuple("x", node_type))
lambda output_edge, output_sym: _create_field_operator_impl(
sdfg, state, domain, output_edge, output_sym.type, map_exit
)
)(output_edges, tuple_syms)


def extract_domain(node: gtir.Node) -> FieldopDomain:
Expand Down Expand Up @@ -444,7 +472,7 @@ def translate_as_fieldop(
fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args]

# represent the field operator as a mapped tasklet graph, which will range over the field domain
input_edges, output_edges = gtir_dataflow.apply(
input_edges, output_edges = gtir_dataflow.translate_lambda_to_dataflow(
sdfg, state, sdfg_builder, stencil_expr, fieldop_args
)

Expand Down
Loading

0 comments on commit d4599d2

Please sign in to comment.