feat[next][dace]: Symbolic domain without dace array offsets (#1735)
Add support for field operator domains with symbolic shape,
with dimension extents in non-zero-based ranges.
edopao authored Nov 26, 2024
1 parent d7f5552 commit 3fb206e
Showing 6 changed files with 367 additions and 152 deletions.
10 changes: 7 additions & 3 deletions src/gt4py/next/program_processors/runners/dace_common/utility.py
@@ -9,7 +9,7 @@
from __future__ import annotations

import re
from typing import Final, Optional, Sequence
from typing import Final, Literal, Optional, Sequence

import dace

@@ -51,12 +51,16 @@ def connectivity_identifier(name: str) -> str:
return f"connectivity_{name}"


def field_symbol_name(field_name: str, axis: int, sym: Literal["size", "stride"]) -> str:
return f"__{field_name}_{sym}_{axis}"


def field_size_symbol_name(field_name: str, axis: int) -> str:
return f"__{field_name}_size_{axis}"
return field_symbol_name(field_name, axis, "size")


def field_stride_symbol_name(field_name: str, axis: int) -> str:
return f"__{field_name}_stride_{axis}"
return field_symbol_name(field_name, axis, "stride")
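The refactor above funnels both helpers through a single naming scheme. A minimal sketch of that scheme, with "u" as a purely illustrative field name (no dace import is needed, since the helpers only build strings):

```python
# Sketch of the symbol naming scheme used by field_symbol_name and its wrappers.
def field_symbol_name(field_name: str, axis: int, sym: str) -> str:
    return f"__{field_name}_{sym}_{axis}"

assert field_symbol_name("u", 0, "size") == "__u_size_0"      # field_size_symbol_name("u", 0)
assert field_symbol_name("u", 1, "stride") == "__u_stride_1"  # field_stride_symbol_name("u", 1)
```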


def is_field_symbol(name: str) -> bool:
@@ -10,7 +10,7 @@

import abc
import dataclasses
from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, TypeAlias
from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias

import dace
import dace.subsets as sbs
@@ -33,6 +33,34 @@
from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg


def _get_domain_indices(
dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None
) -> sbs.Indices:
"""
Helper function to construct the list of indices for a field domain, applying
an optional offset in each dimension as start index.
Args:
dims: The field dimensions.
offsets: The range start index in each dimension.
Returns:
A list of indices for field access in dace arrays. As this list is returned
as `dace.subsets.Indices`, it should be converted to `dace.subsets.Range` before
being used in memlet subset because ranges are better supported throughout DaCe.
"""
index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims]
if offsets is None:
return sbs.Indices(index_variables)
else:
return sbs.Indices(
[
index - offset if offset != 0 else index
for index, offset in zip(index_variables, offsets, strict=True)
]
)
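A plain-Python sketch of the offset handling above, with strings standing in for `dace.symbolic.SymExpr` and illustrative map-variable and offset names; the real function returns `dace.subsets.Indices`, which callers convert to a `Range` before building memlets:

```python
# Illustrative sketch only: strings replace dace symbolic expressions.
def domain_indices_sketch(index_vars, offsets=None):
    if offsets is None:
        # No offsets given: the map index variables are used directly.
        return list(index_vars)
    # Each index is shifted back by the start of the dimension's range.
    return [
        idx if off == 0 else f"{idx} - {off}"
        for idx, off in zip(index_vars, offsets, strict=True)
    ]

print(domain_indices_sketch(["i_IDim", "i_KDim"], [0, "k_start"]))
# ['i_IDim', 'i_KDim - k_start']
```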


@dataclasses.dataclass(frozen=True)
class FieldopData:
"""
@@ -45,42 +73,59 @@ class FieldopData:
Args:
dc_node: DaCe access node to the data storage.
gt_type: GT4Py type definition, which includes the field domain information.
offset: List of index offsets, one per dimension, used when the dimension range
does not start at zero; a zero offset is assumed if not set.
"""

dc_node: dace.nodes.AccessNode
gt_type: ts.FieldType | ts.ScalarType
offset: Optional[list[dace.symbolic.SymExpr]]

def make_copy(self, data_node: dace.nodes.AccessNode) -> FieldopData:
"""Create a copy of this data descriptor with a different access node."""
assert data_node != self.dc_node
return FieldopData(data_node, self.gt_type, self.offset)

def get_local_view(
self, domain: FieldopDomain
) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr:
"""Helper method to access a field in local view, given a field operator domain."""
"""Helper method to access a field in local view, given the compute domain of a field operator."""
if isinstance(self.gt_type, ts.ScalarType):
return gtir_dataflow.MemletExpr(
dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0])
)

if isinstance(self.gt_type, ts.FieldType):
indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = {
dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE)
for dim, _, _ in domain
domain_dims = [dim for dim, _, _ in domain]
domain_indices = _get_domain_indices(domain_dims)
it_indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = {
dim: gtir_dataflow.SymbolExpr(index, INDEX_DTYPE)
for dim, index in zip(domain_dims, domain_indices)
}
field_domain = [
(dim, dace.symbolic.SymExpr(0) if self.offset is None else self.offset[i])
for i, dim in enumerate(self.gt_type.dims)
]
local_dims = [
dim for dim in self.gt_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL
]

if len(local_dims) == 0:
return gtir_dataflow.IteratorExpr(
self.dc_node, self.gt_type.dtype, self.gt_type.dims, indices
self.dc_node, self.gt_type.dtype, field_domain, it_indices
)

elif len(local_dims) == 1:
field_dtype = itir_ts.ListType(
element_type=self.gt_type.dtype, offset_type=local_dims[0]
)
field_dims = [
dim for dim in self.gt_type.dims if dim.kind != gtx_common.DimensionKind.LOCAL
field_domain = [
(dim, offset)
for dim, offset in field_domain
if dim.kind != gtx_common.DimensionKind.LOCAL
]
return gtir_dataflow.IteratorExpr(self.dc_node, field_dtype, field_dims, indices)
return gtir_dataflow.IteratorExpr(
self.dc_node, field_dtype, field_domain, it_indices
)

else:
raise ValueError(
@@ -155,9 +200,9 @@ def _parse_fieldop_arg(
return arg.get_local_view(domain)


def _get_field_shape(
def _get_field_layout(
domain: FieldopDomain,
) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr]]:
) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]:
"""
Parse the field operator domain and generate the shape of the result field.
@@ -174,11 +219,14 @@ def _get_field_shape(
domain: The field operator domain.
Returns:
A tuple of two lists: the list of field dimensions and the list of dace
array sizes in each dimension.
A tuple of three lists containing:
- the domain dimensions
- the domain offset in each dimension
- the domain size in each dimension
"""
domain_dims, _, domain_ubs = zip(*domain)
return list(domain_dims), list(domain_ubs)
domain_dims, domain_lbs, domain_ubs = zip(*domain)
domain_sizes = [(ub - lb) for lb, ub in zip(domain_lbs, domain_ubs)]
return list(domain_dims), list(domain_lbs), domain_sizes
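A worked example of the layout computation above, using plain integers in place of dace symbolic expressions and made-up dimension names; the (dimension, lower bound, upper bound) triples mirror the `zip(*domain)` unpacking of the field operator domain:

```python
# Illustrative domain: I-range [0, 10), K-range [2, 7).
domain = [("IDim", 0, 10), ("KDim", 2, 7)]
dims, lbs, ubs = zip(*domain)
sizes = [ub - lb for lb, ub in zip(lbs, ubs)]
print(list(dims), list(lbs), sizes)  # ['IDim', 'KDim'] [0, 2] [10, 5]
```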


def _create_temporary_field(
@@ -189,14 +237,15 @@ def _create_temporary_field(
dataflow_output: gtir_dataflow.DataflowOutputEdge,
) -> FieldopData:
"""Helper method to allocate a temporary field where to write the output of a field operator."""
field_dims, field_shape = _get_field_shape(domain)
field_dims, field_offset, field_shape = _get_field_layout(domain)

output_desc = dataflow_output.result.dc_node.desc(sdfg)
if isinstance(output_desc, dace.data.Array):
assert isinstance(node_type.dtype, itir_ts.ListType)
assert isinstance(node_type.dtype.element_type, ts.ScalarType)
assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type)
# extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
field_offset.extend(output_desc.offset)
field_shape.extend(output_desc.shape)
elif isinstance(output_desc, dace.data.Scalar):
assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype)
@@ -215,7 +264,11 @@
assert dataflow_output.result.gt_dtype.offset_type is not None
field_dims.append(dataflow_output.result.gt_dtype.offset_type)

return FieldopData(field_node, ts.FieldType(field_dims, field_dtype))
return FieldopData(
field_node,
ts.FieldType(field_dims, field_dtype),
offset=(field_offset if set(field_offset) != {0} else None),
)
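The `offset=...` argument above stores the offset list only when at least one dimension starts away from zero. A small sketch of that normalization, using plain integer offsets for illustration:

```python
def normalize_offset(field_offset):
    # Keep the offsets only if some dimension has a non-zero start index.
    return field_offset if set(field_offset) != {0} else None

assert normalize_offset([0, 0]) is None    # zero-based domain: no offset stored
assert normalize_offset([0, 2]) == [0, 2]  # non-zero start in the second dimension
```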


def extract_domain(node: gtir.Node) -> FieldopDomain:
@@ -285,7 +338,8 @@ def translate_as_fieldop(

# parse the domain of the field operator
domain = extract_domain(domain_expr)
domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain])
domain_dims, domain_offsets, _ = zip(*domain)
domain_indices = _get_domain_indices(domain_dims, domain_offsets)

# visit the list of arguments to be passed to the lambda expression
stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args]
@@ -350,10 +404,8 @@ def translate_broadcast_scalar(
assert cpm.is_ref_to(stencil_expr, "deref")

domain = extract_domain(domain_expr)
field_dims, field_shape = _get_field_shape(domain)
field_subset = sbs.Range.from_string(
",".join(dace_gtir_utils.get_map_variable(dim) for dim in field_dims)
)
output_dims, output_offset, output_shape = _get_field_layout(domain)
output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset))

assert len(node.args) == 1
scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain)
@@ -369,26 +421,15 @@
assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr)
if len(node.args[0].type.dims) == 0: # zero-dimensional field
input_subset = "0"
elif all(
isinstance(scalar_expr.indices[dim], gtir_dataflow.SymbolExpr)
for dim in scalar_expr.dimensions
if dim not in field_dims
):
input_subset = ",".join(
dace_gtir_utils.get_map_variable(dim)
if dim in field_dims
else scalar_expr.indices[dim].value # type: ignore[union-attr] # catched by exception above
for dim in scalar_expr.dimensions
)
else:
raise ValueError(f"Cannot deref field {scalar_expr.field} in broadcast expression.")
input_subset = scalar_expr.get_memlet_subset(sdfg)

input_node = scalar_expr.field
gt_dtype = node.args[0].type.dtype
else:
raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.")

output, _ = sdfg.add_temp_transient(field_shape, input_node.desc(sdfg).dtype)
output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype)
output_node = state.add_access(output)

sdfg_builder.add_mapped_tasklet(
@@ -400,13 +441,13 @@
},
inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)},
code="__val = __inp",
outputs={"__val": dace.Memlet(data=output_node.data, subset=field_subset)},
outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)},
input_nodes={input_node.data: input_node},
output_nodes={output_node.data: output_node},
external_edges=True,
)

return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype))
return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset)


def translate_if(
@@ -467,7 +508,7 @@ def construct_output(inner_data: FieldopData) -> FieldopData:
outer, _ = sdfg.add_temp_transient_like(inner_desc)
outer_node = state.add_access(outer)

return FieldopData(outer_node, inner_data.gt_type)
return inner_data.make_copy(outer_node)

result_temps = gtx_utils.tree_map(construct_output)(true_br_args)

@@ -513,7 +554,7 @@ def _get_data_nodes(
) -> FieldopResult:
if isinstance(data_type, ts.FieldType):
data_node = state.add_access(data_name)
return FieldopData(data_node, data_type)
return sdfg_builder.make_field(data_node, data_type)

elif isinstance(data_type, ts.ScalarType):
if data_name in sdfg.symbols:
Expand All @@ -522,7 +563,7 @@ def _get_data_nodes(
)
else:
data_node = state.add_access(data_name)
return FieldopData(data_node, data_type)
return sdfg_builder.make_field(data_node, data_type)

elif isinstance(data_type, ts.TupleType):
tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type)
@@ -579,7 +620,7 @@ def translate_literal(
data_type = node.type
data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type)

return FieldopData(data_node, data_type)
return FieldopData(data_node, data_type, offset=None)


def translate_make_tuple(
@@ -708,7 +749,7 @@ def translate_scalar_expr(
dace.Memlet(data=temp_name, subset="0"),
)

return FieldopData(temp_node, node.type)
return FieldopData(temp_node, node.type, offset=None)


def translate_symbol_ref(