Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix loop lifting for trailing increment assignments #1860

Merged
merged 3 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions dace/transformation/interstate/loop_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@ class DetectLoop(transformation.PatternTransformation):
""" Detects a for-loop construct from an SDFG. """

# Always available
loop_begin = transformation.PatternNode(sd.SDFGState)
exit_state = transformation.PatternNode(sd.SDFGState)
loop_begin = transformation.PatternNode(ControlFlowBlock)
exit_state = transformation.PatternNode(ControlFlowBlock)

# Available for natural loops
loop_guard = transformation.PatternNode(sd.SDFGState)
loop_guard = transformation.PatternNode(ControlFlowBlock)

# Available for rotated loops
loop_latch = transformation.PatternNode(sd.SDFGState)
loop_latch = transformation.PatternNode(ControlFlowBlock)

# Available for rotated and self loops
entry_state = transformation.PatternNode(sd.SDFGState)
entry_state = transformation.PatternNode(ControlFlowBlock)

# Available for explicit-latch rotated loops
loop_break = transformation.PatternNode(sd.SDFGState)
loop_break = transformation.PatternNode(ControlFlowBlock)

@classmethod
def expressions(cls):
Expand Down Expand Up @@ -260,6 +260,10 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool,
if latch_outedges[0].data.condition_sympy() != (sp.Not(latch_outedges[1].data.condition_sympy())):
return None

# Make sure the backedge (i.e, one of the condition edges) goes from the latch to the beginning state.
if latch_outedges[0].dst is not self.loop_begin and latch_outedges[1].dst is not self.loop_begin:
return None

# All nodes inside loop must be dominated by loop start
dominators = nx.dominance.immediate_dominators(graph.nx, graph.start_block)
if begin is ltest:
Expand Down
20 changes: 17 additions & 3 deletions dace/transformation/interstate/loop_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dace import properties
from dace.sdfg.sdfg import SDFG, InterstateEdge
from dace.sdfg.state import ControlFlowRegion, LoopRegion
from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion
from dace.transformation import transformation
from dace.transformation.interstate.loop_detection import DetectLoop

Expand Down Expand Up @@ -82,8 +82,22 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG):
added.add(e)
if e is incr_edge:
if left_over_incr_assignments != {}:
dst = loop.add_state(label + '_tail') if not inverted else e.dst
loop.add_edge(e.src, dst, InterstateEdge(assignments=left_over_incr_assignments))
assignments = left_over_incr_assignments
dst = e.dst
if e.dst is first_state:
if not update_before_condition:
left_over_incr_cond_region = ConditionalBlock(label + '_post_incr_conditional')
incr_graph = ControlFlowRegion(label + '_post_incr')
left_over_incr_cond_region.add_branch(cond_edge.data.condition, incr_graph)
incr_graph.add_edge(incr_graph.add_state(label + '_post_incr_start',
is_start_block=True),
incr_graph.add_state(label + '_post_incr_end'),
InterstateEdge(assignments=left_over_incr_assignments))
dst = left_over_incr_cond_region
assignments = {}
else:
dst = loop.add_state(label + '_tail')
loop.add_edge(e.src, dst, InterstateEdge(assignments=assignments))
elif e is cond_edge:
if not inverted:
e.data.condition = properties.CodeBlock('1')
Expand Down
44 changes: 43 additions & 1 deletion tests/transformations/interstate/loop_lifting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_lift_loop_llvm_canonical_while():


def test_do_while():
sdfg = SDFG('regular_for')
sdfg = SDFG('do_while')
N = dace.symbol('N')
sdfg.add_symbol('i', dace.int32)
sdfg.add_symbol('j', dace.int32)
Expand Down Expand Up @@ -209,9 +209,51 @@ def test_do_while():
assert np.allclose(A_valid, A)


def test_inverted_loop_with_additional_increment_assignment():
sdfg = SDFG('inverted_loop_with_additional_increment_assignment')
N = dace.symbol('N')
sdfg.add_scalar('i', dace.int32, transient=True)
sdfg.add_symbol('k', dace.int32)
sdfg.add_array('A', (N,), dace.int32)
a_state = sdfg.add_state('a_state', is_start_block=True)
b_state = sdfg.add_state('b_state')
c_state = sdfg.add_state('c_state')
d_state = sdfg.add_state('d_state')
sdfg.add_edge(a_state, b_state, InterstateEdge(assignments={'k': 0}))
sdfg.add_edge(b_state, c_state, InterstateEdge())
sdfg.add_edge(c_state, b_state, InterstateEdge(condition='i < N', assignments={'k': 'k + 1'}))
sdfg.add_edge(c_state, d_state, InterstateEdge(condition='i >= N'))
a_access = b_state.add_access('A')
w_tasklet = b_state.add_tasklet('t1', {}, {'out'}, 'out = 1')
b_state.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]'))
i_read = c_state.add_access('i')
i_write = c_state.add_access('i')
iw_tasklet = c_state.add_tasklet('t2', {'in1'}, {'out'}, 'out = in1 + 2')
c_state.add_edge(i_read, None, iw_tasklet, 'in1', Memlet('i[0]'))
c_state.add_edge(iw_tasklet, 'out', i_write, None, Memlet('i[0]'))
a_access_2 = d_state.add_access('A')
w_tasklet_2 = d_state.add_tasklet('t1', {}, {'out'}, 'out = k')
d_state.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]'))

N = 30
A = np.zeros((N,)).astype(np.int32)
A_valid = np.zeros((N,)).astype(np.int32)
sdfg(A=A_valid, N=N)

sdfg.apply_transformations_repeated([LoopLifting])

assert sdfg.using_explicit_control_flow == True
assert any(isinstance(x, LoopRegion) for x in sdfg.nodes())

sdfg(A=A, N=N)

assert np.allclose(A_valid, A)


if __name__ == '__main__':
test_lift_regular_for_loop()
test_lift_loop_llvm_canonical(True)
test_lift_loop_llvm_canonical(False)
test_lift_loop_llvm_canonical_while()
test_do_while()
test_inverted_loop_with_additional_increment_assignment()
Loading