Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Oct 10, 2024
1 parent edffd97 commit 3196a11
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def _transform_if(
return [
itir.IfStmt(
cond=cond,
# recursively _transform_stmt
true_branch=_transform_stmt(
itir.SetAt(target=stmt.target, expr=true_val, domain=stmt.domain),
declarations,
Expand All @@ -62,8 +61,8 @@ def _transform_by_pattern(
predicate=predicate,
uid_generator=eve_utils.UIDGenerator(prefix="__tmp_subexpr"),
# TODO(tehrengruber): extracting the deepest expression first would allow us to fuse
# the extracted expressions resulting in fewer kernel calls, better data-locality.
# Extracting the multiple expressions deepest-first is however not supported right now.
# the extracted expressions resulting in fewer kernel calls & better data-locality.
# Extracting multiple expressions deepest-first is however not supported right now.
# deepest_expr_first=True # noqa: ERA001
)

Expand Down Expand Up @@ -146,14 +145,14 @@ def _transform_stmt(
unprocessed_stmts: list[itir.Stmt] = [stmt]
stmts: list[itir.Stmt] = []

_transform_stmts: list[Callable] = [
# _transform_stmt functional if_ into if-stmt
transforms: list[Callable] = [
# transform `if_` call into `IfStmt`
_transform_if,
# extract applied `as_fieldop` to top-level
functools.partial(
_transform_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr)
),
# extract functional if_ to the top-level
# extract if_ call to the top-level
functools.partial(
_transform_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_")
),
Expand All @@ -162,18 +161,16 @@ def _transform_stmt(
while unprocessed_stmts:
stmt = unprocessed_stmts.pop(0)

did_transform_stmt = False
for _transform_stmt in _transform_stmts:
_transform_stmted_stmts = _transform_stmt(
stmt=stmt, declarations=declarations, uids=uids
)
if _transform_stmted_stmts:
unprocessed_stmts = [*_transform_stmted_stmts, *unprocessed_stmts]
did_transform_stmt = True
did_transform = False
for transform in transforms:
transformed_stmts = transform(stmt=stmt, declarations=declarations, uids=uids)
if transformed_stmts:
unprocessed_stmts = [*transformed_stmts, *unprocessed_stmts]
did_transform = True
break

# no _transform_stmtation occurred
if not did_transform_stmt:
# no transformation occurred
if not did_transform:
stmts.append(stmt)

return stmts
Expand All @@ -185,7 +182,7 @@ def create_global_tmps(
"""
Given an `itir.Program` create temporaries for intermediate values.
This pass looks at all `as_fieldop` calls and _transform_stmts field-typed subexpressions of its
This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its
arguments into temporaries.
"""
program = infer_domain.infer_program(program, offset_provider)
Expand Down

0 comments on commit 3196a11

Please sign in to comment.