Skip to content

Commit

Permalink
refactor[catesian]: Type hints and code redability improvements (#1724)
Browse files Browse the repository at this point in the history
## Description

This PR is split off the work for the new GT4Py - DaCe bridge, which
should allow to expose control flow statements (`if` and `while`) to
DaCe to better use DaCe's analytics capabilities. This PR is concerned
with adding type hints and generally improving code readability. Main
parts are

- `daceir_builder.py`: early returns and renamed variable
- `sdfg_builder.py`: type hints and early returns
- `tasklet_codegen.py`: type hints and early returns

`TaskletCodegen` was given `sdfg_ctx`, which wasn't used. That parameter
was thus removed.

Parent issue: GEOS-ESM/NDSL#53

## Requirements

- [x] All fixes and/or new features come with corresponding tests.
  Assumed to be covered by existing tests.
- [ ] Important design decisions have been documented in the approriate
ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md)
folder.
  N/A

---------

Co-authored-by: Roman Cattaneo <>
  • Loading branch information
romanc authored Nov 15, 2024
1 parent 998f279 commit aeff1e3
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 101 deletions.
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ class CartesianOffset(eve.Node):
k: int

@classmethod
def zero(cls) -> "CartesianOffset":
def zero(cls) -> CartesianOffset:
return cls(i=0, j=0, k=0)

def to_dict(self) -> Dict[str, int]:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ class Literal(common.Literal, Expr):


class ScalarAccess(common.ScalarAccess, Expr):
name: eve.Coerced[eve.SymbolRef]
pass


class VariableKOffset(common.VariableKOffset[Expr]):
Expand Down
84 changes: 42 additions & 42 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def _get_tasklet_inout_memlets(
*,
get_outputs: bool,
global_ctx: DaCeIRBuilder.GlobalContext,
**kwargs,
):
**kwargs: Any,
) -> List[dcir.Memlet]:
access_infos = compute_dcir_access_infos(
node,
block_extents=global_ctx.library_node.get_extents,
Expand All @@ -85,7 +85,7 @@ def _get_tasklet_inout_memlets(
**kwargs,
)

res = list()
memlets: List[dcir.Memlet] = []
for name, offset, tasklet_symbol in _access_iter(node, get_outputs=get_outputs):
access_info = access_infos[name]
if not access_info.variable_offset_axes:
Expand All @@ -95,26 +95,27 @@ def _get_tasklet_inout_memlets(
axis, extent=(offset_dict[axis.lower()], offset_dict[axis.lower()])
)

memlet = dcir.Memlet(
field=name,
connector=tasklet_symbol,
access_info=access_info,
is_read=not get_outputs,
is_write=get_outputs,
memlets.append(
dcir.Memlet(
field=name,
connector=tasklet_symbol,
access_info=access_info,
is_read=not get_outputs,
is_write=get_outputs,
)
)
res.append(memlet)
return res
return memlets


def _all_stmts_same_region(scope_nodes, axis: dcir.Axis, interval):
def all_statements_in_region(scope_nodes):
def _all_stmts_same_region(scope_nodes, axis: dcir.Axis, interval: Any) -> bool:
def all_statements_in_region(scope_nodes: List[eve.Node]) -> bool:
return all(
isinstance(stmt, dcir.HorizontalRestriction)
for tasklet in eve.walk_values(scope_nodes).if_isinstance(dcir.Tasklet)
for stmt in tasklet.stmts
)

def all_regions_same(scope_nodes):
def all_regions_same(scope_nodes: List[eve.Node]) -> bool:
return (
len(
set(
Expand Down Expand Up @@ -179,11 +180,11 @@ def _get_dcir_decl(
oir_decl: oir.Decl = self.library_node.declarations[field]
assert isinstance(oir_decl, oir.FieldDecl)
dace_array = self.arrays[field]
for s in dace_array.strides:
for sym in dace.symbolic.symlist(s).values():
symbol_collector.add_symbol(str(sym))
for sym in access_info.grid_subset.free_symbols:
symbol_collector.add_symbol(sym)
for stride in dace_array.strides:
for symbol in dace.symbolic.symlist(stride).values():
symbol_collector.add_symbol(str(symbol))
for symbol in access_info.grid_subset.free_symbols:
symbol_collector.add_symbol(symbol)

return dcir.FieldDecl(
name=field,
Expand Down Expand Up @@ -236,11 +237,7 @@ def push_expansion_item(self, item: Union[Map, Loop]) -> DaCeIRBuilder.Iteration
if not isinstance(item, (Map, Loop)):
raise ValueError

if isinstance(item, Map):
iterations = item.iterations
else:
iterations = [item]

iterations = item.iterations if isinstance(item, Map) else [item]
grid_subset = self.grid_subset
for it in iterations:
axis = it.axis
Expand All @@ -267,13 +264,13 @@ def pop(self) -> DaCeIRBuilder.IterationContext:
class SymbolCollector:
symbol_decls: Dict[str, dcir.SymbolDecl] = dataclasses.field(default_factory=dict)

def add_symbol(self, name: str, dtype: common.DataType = common.DataType.INT32):
def add_symbol(self, name: str, dtype: common.DataType = common.DataType.INT32) -> None:
if name not in self.symbol_decls:
self.symbol_decls[name] = dcir.SymbolDecl(name=name, dtype=dtype)
else:
assert self.symbol_decls[name].dtype == dtype

def remove_symbol(self, name: eve.SymbolRef):
def remove_symbol(self, name: eve.SymbolRef) -> None:
if name in self.symbol_decls:
del self.symbol_decls[name]

Expand Down Expand Up @@ -304,11 +301,14 @@ def visit_HorizontalRestriction(
symbol_collector.add_symbol(axis.iteration_symbol())
if bound.level == common.LevelMarker.END:
symbol_collector.add_symbol(axis.domain_symbol())

return dcir.HorizontalRestriction(
mask=node.mask, body=self.visit(node.body, symbol_collector=symbol_collector, **kwargs)
)

def visit_VariableKOffset(self, node: oir.VariableKOffset, **kwargs):
def visit_VariableKOffset(
self, node: oir.VariableKOffset, **kwargs: Any
) -> dcir.VariableKOffset:
return dcir.VariableKOffset(k=self.visit(node.k, **kwargs))

def visit_LocalScalar(self, node: oir.LocalScalar, **kwargs: Any) -> dcir.LocalScalarDecl:
Expand Down Expand Up @@ -419,7 +419,7 @@ def visit_HorizontalExecution(
symbol_collector: DaCeIRBuilder.SymbolCollector,
loop_order,
k_interval,
**kwargs,
**kwargs: Any,
):
# skip type checking due to https://github.com/python/mypy/issues/5485
extent = global_ctx.library_node.get_extents(node) # type: ignore
Expand Down Expand Up @@ -581,7 +581,7 @@ def to_dataflow(
nodes = flatten_list(nodes)
if all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes):
return nodes
elif not all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes):
if not all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes):
raise ValueError("Can't mix dataflow and state nodes on same level.")

read_memlets, write_memlets, field_memlets = union_inout_memlets(nodes)
Expand Down Expand Up @@ -615,10 +615,10 @@ def to_state(self, nodes, *, grid_subset: dcir.GridSubset):
nodes = flatten_list(nodes)
if all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes):
return nodes
elif all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes):
if all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes):
return [dcir.ComputationState(computations=nodes, grid_subset=grid_subset)]
else:
raise ValueError("Can't mix dataflow and state nodes on same level.")

raise ValueError("Can't mix dataflow and state nodes on same level.")

def _process_map_item(
self,
Expand All @@ -628,8 +628,8 @@ def _process_map_item(
global_ctx: DaCeIRBuilder.GlobalContext,
iteration_ctx: DaCeIRBuilder.IterationContext,
symbol_collector: DaCeIRBuilder.SymbolCollector,
**kwargs,
):
**kwargs: Any,
) -> List[dcir.DomainMap]:
grid_subset = iteration_ctx.grid_subset
read_memlets, write_memlets, _ = union_inout_memlets(list(scope_nodes))
scope_nodes = self.to_dataflow(
Expand Down Expand Up @@ -723,11 +723,11 @@ def _process_loop_item(
scope_nodes,
item: Loop,
*,
global_ctx,
global_ctx: DaCeIRBuilder.GlobalContext,
iteration_ctx: DaCeIRBuilder.IterationContext,
symbol_collector: DaCeIRBuilder.SymbolCollector,
**kwargs,
):
**kwargs: Any,
) -> List[dcir.DomainLoop]:
grid_subset = union_node_grid_subsets(list(scope_nodes))
read_memlets, write_memlets, _ = union_inout_memlets(list(scope_nodes))
scope_nodes = self.to_state(scope_nodes, grid_subset=grid_subset)
Expand Down Expand Up @@ -793,14 +793,14 @@ def _process_loop_item(
def _process_iteration_item(self, scope, item, **kwargs):
if isinstance(item, Map):
return self._process_map_item(scope, item, **kwargs)
elif isinstance(item, Loop):
if isinstance(item, Loop):
return self._process_loop_item(scope, item, **kwargs)
else:
raise ValueError("Invalid expansion specification set.")

raise ValueError("Invalid expansion specification set.")

def visit_VerticalLoop(
self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs
):
self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs: Any
) -> dcir.NestedSDFG:
overall_extent = Extent.zeros(2)
for he in node.walk_values().if_isinstance(oir.HorizontalExecution):
overall_extent = overall_extent.union(global_ctx.library_node.get_extents(he))
Expand Down
36 changes: 17 additions & 19 deletions src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def visit_Memlet(
scope_node: dcir.ComputationNode,
sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext,
node_ctx: StencilComputationSDFGBuilder.NodeContext,
connector_prefix="",
connector_prefix: str = "",
symtable: ChainMap[eve.SymbolRef, dcir.Decl],
) -> None:
field_decl = symtable[node.field]
Expand Down Expand Up @@ -139,13 +139,12 @@ def visit_Tasklet(
sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext,
node_ctx: StencilComputationSDFGBuilder.NodeContext,
symtable: ChainMap[eve.SymbolRef, dcir.Decl],
**kwargs,
**kwargs: Any,
) -> None:
code = TaskletCodegen.apply_codegen(
node,
read_memlets=node.read_memlets,
write_memlets=node.write_memlets,
sdfg_ctx=sdfg_ctx,
symtable=symtable,
)

Expand Down Expand Up @@ -177,7 +176,7 @@ def visit_Tasklet(
tasklet, tasklet, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx
)

def visit_Range(self, node: dcir.Range, **kwargs) -> Dict[str, str]:
def visit_Range(self, node: dcir.Range, **kwargs: Any) -> Dict[str, str]:
start, end = node.interval.to_dace_symbolic()
return {node.var: str(dace.subsets.Range([(start, end - 1, node.stride)]))}

Expand All @@ -187,7 +186,7 @@ def visit_DomainMap(
*,
node_ctx: StencilComputationSDFGBuilder.NodeContext,
sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext,
**kwargs,
**kwargs: Any,
) -> None:
ndranges = {
k: v
Expand Down Expand Up @@ -248,7 +247,7 @@ def visit_DomainLoop(
node: dcir.DomainLoop,
*,
sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext,
**kwargs,
**kwargs: Any,
) -> None:
sdfg_ctx = sdfg_ctx.add_loop(node.index_range)
self.visit(node.loop_states, sdfg_ctx=sdfg_ctx, **kwargs)
Expand All @@ -259,7 +258,7 @@ def visit_ComputationState(
node: dcir.ComputationState,
*,
sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext,
**kwargs,
**kwargs: Any,
) -> None:
sdfg_ctx.add_state()
read_acc_and_conn: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {}
Expand Down Expand Up @@ -289,7 +288,7 @@ def visit_FieldDecl(
*,
sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext,
non_transients: Set[eve.SymbolRef],
**kwargs,
**kwargs: Any,
) -> None:
assert len(node.strides) == len(node.shape)
sdfg_ctx.sdfg.add_array(
Expand All @@ -307,7 +306,7 @@ def visit_SymbolDecl(
node: dcir.SymbolDecl,
*,
sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext,
**kwargs,
**kwargs: Any,
) -> None:
if node.name not in sdfg_ctx.sdfg.symbols:
sdfg_ctx.sdfg.add_symbol(node.name, stype=data_type_to_dace_typeclass(node.dtype))
Expand All @@ -319,7 +318,7 @@ def visit_NestedSDFG(
sdfg_ctx: Optional[StencilComputationSDFGBuilder.SDFGContext] = None,
node_ctx: Optional[StencilComputationSDFGBuilder.NodeContext] = None,
symtable: ChainMap[eve.SymbolRef, Any],
**kwargs,
**kwargs: Any,
) -> dace.nodes.NestedSDFG:
sdfg = dace.SDFG(node.label)
inner_sdfg_ctx = StencilComputationSDFGBuilder.SDFGContext(
Expand Down Expand Up @@ -365,13 +364,12 @@ def visit_NestedSDFG(
StencilComputationSDFGBuilder._add_empty_edges(
nsdfg, nsdfg, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx
)
else:
nsdfg = dace.nodes.NestedSDFG(
label=sdfg.label,
sdfg=sdfg,
inputs={memlet.connector for memlet in node.read_memlets},
outputs={memlet.connector for memlet in node.write_memlets},
symbol_mapping=symbol_mapping,
)
return nsdfg

return nsdfg
return dace.nodes.NestedSDFG(
label=sdfg.label,
sdfg=sdfg,
inputs={memlet.connector for memlet in node.read_memlets},
outputs={memlet.connector for memlet in node.write_memlets},
symbol_mapping=symbol_mapping,
)
Loading

0 comments on commit aeff1e3

Please sign in to comment.