Skip to content

Commit

Permalink
Minor edit based on review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 8, 2025
1 parent 65b4dd2 commit 61b06b3
Showing 1 changed file with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -598,9 +598,9 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A
entry_state = nsdfg.add_state("entry", is_start_block=True)
nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge())

if_body = dace.sdfg.state.ControlFlowRegion("if_body", sdfg=nsdfg)
tstate = if_body.add_state("true_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), if_body)
then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg)
tstate = then_body.add_state("true_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), then_body)

else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg)
fstate = else_body.add_state("false_branch", is_start_block=True)
Expand Down Expand Up @@ -629,7 +629,7 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A
assert isinstance(condition_value, ValueExpr)
input_memlets["__cond"] = condition_value

def visit_branch(
def construct_if_branch(
state: dace.SDFGState, expr: gtir.Expr
) -> tuple[
list[DataflowInputEdge],
Expand All @@ -656,7 +656,7 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr:
arg_data = arg_node.data
# SDFG data containers with name prefix '__tmp' are expected to be transients
inner_data = (
arg_data.replace("__tmp", "__input")
arg_data.replace("__tmp", "__input", count=1)
if arg_data.startswith("__tmp")
else arg_data
)
Expand Down Expand Up @@ -698,7 +698,7 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr:
return visit_lambda(nsdfg, state, self.subgraph_builder, lambda_node, lambda_args)

for state, arg in zip([tstate, fstate], node.args[1:3]):
in_edges, out_edge = visit_branch(state, arg)
in_edges, out_edge = construct_if_branch(state, arg)
for edge in in_edges:
edge.connect(map_entry=None)

Expand Down

0 comments on commit 61b06b3

Please sign in to comment.