From dc31c606bc20e5ef7fe90b47ace7eea2c57c05da Mon Sep 17 00:00:00 2001
From: Harsh Menon <harsh@nod-labs.com>
Date: Wed, 11 Sep 2024 17:48:42 -0700
Subject: [PATCH] Add code to construct pipelined loop from schedule

This PR adds code to construct the prologue, kernel and
epilogue of the pipelined loop once a schedule has been
computed. We simulate rotating registers in software and add
visualization tools to show the pipelined graphs.
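
For illustration only (not part of the change itself), the
rotating-register simulation amounts to keeping a small FIFO of SSA
values per multi-stage value and rotating it each time the producing
instruction is re-emitted. A minimal standalone sketch, with
hypothetical names:

    from collections import deque

    # One queue of registers per value that lives across pipeline stages.
    rotating = {"acc": deque(["acc_it0", "acc_it1"])}

    def rotate(name: str, new_value: str) -> str:
        regs = rotating[name]
        regs.append(new_value)   # newest copy, produced this iteration
        return regs.popleft()    # oldest copy, retired from the window

    assert rotate("acc", "acc_it2") == "acc_it0"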

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
---
 lit_tests/kernel/wave/codegen.py              | 333 +++++++++++
 lit_tests/kernel/wave/scheduling.py           | 227 +++++++
 shark_turbine/kernel/_support/tracing.py      |   3 +
 shark_turbine/kernel/ops/wave_ops.py          |   8 +-
 shark_turbine/kernel/wave/codegen.py          |  22 +-
 .../kernel/wave/scheduling/graph_utils.py     |   3 +-
 .../wave/scheduling/loop_reconstruction.py    | 556 ++++++++++++++++++
 .../scheduling/loop_reconstruction_utils.py   | 285 +++++++++
 .../wave/scheduling/modulo_scheduling.py      |   9 +
 .../kernel/wave/scheduling/schedule.py        |  43 +-
 shark_turbine/kernel/wave/utils.py            |  58 +-
 shark_turbine/kernel/wave/visualization.py    |  95 ++-
 shark_turbine/kernel/wave/wave.py             |   7 +-
 tests/kernel/wave/wave_gemm_test.py           |  25 +-
 14 files changed, 1646 insertions(+), 28 deletions(-)
 create mode 100644 lit_tests/kernel/wave/scheduling.py
 create mode 100644 shark_turbine/kernel/wave/scheduling/loop_reconstruction.py
 create mode 100644 shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py

diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index b84cc271..8d67d42e 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -606,6 +606,339 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
         # CHECK:            return
 
 
+@run_test
+def test_gemm_pipelined():
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(2, 2, 1),
+            mma_type=tkw.MMAType.F32_16x16x16_F16,
+        )
+    ]
+
+    @tkw.wave(constraints)
+    def gemm_pipelined(
+        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
+        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
+        c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
+    ):
+        c_reg = tkl.Register[M, N, tkl.f32](0.0)
+
+        @tkw.reduction(K, init_args=[c_reg])
+        def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
+            a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            acc = tkw.mma(a_reg, b_reg, acc)
+            return acc
+
+        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
+
+    with tk.gen.TestLaunchContext(
+        {
+            M: 128,
+            N: 128,
+            K: 128,
+            BLOCK_M: 64,
+            BLOCK_N: 64,
+            BLOCK_K: 32,
+            LOAD_ELEMS_PER_THREAD: 4,
+            STORE_ELEMS_PER_THREAD: 4,
+            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
+            READ_SHARED_DELAY: 1,
+            WRITE_SHARED_DELAY: 1,
+            READ_GLOBAL_DELAY: 2,
+            WRITE_GLOBAL_DELAY: 2,
+            MMA_DELAY: 1,
+            SHARED_MEMORY_UNITS: 4,
+            GLOBAL_MEMORY_UNITS: 4,
+            MMA_UNITS: 4,
+        },
+        canonicalize=True,
+        schedule=True,
+    ):
+        a = torch.randn(64, 32, dtype=torch.float16)
+        b = torch.randn(128, 32, dtype=torch.float16)
+        c = torch.zeros(64, 128, dtype=torch.float32)
+        print(gemm_pipelined(a, b, c).module_op)
+
+        # CHECK:          func.func @gemm_pipelined(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]:
+        # CHECK-SAME:       !stream.binding, %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info =
+        # CHECK-SAME:       #[[TRANSLATION:.+]]} {
+        # CHECK-DAG:        %[[C19:.+]] = arith.constant 19 : index
+        # CHECK-DAG:        %[[C18:.+]] = arith.constant 18 : index
+        # CHECK-DAG:        %[[C17:.+]] = arith.constant 17 : index
+        # CHECK-DAG:        %[[C3:.+]] = arith.constant 3 : index
+        # CHECK-DAG:        %[[C2:.+]] = arith.constant 2 : index
+        # CHECK-DAG:        %[[C16:.+]] = arith.constant 16 : index
+        # CHECK-DAG:        %[[C8:.+]] = arith.constant 8 : index
+        # CHECK-DAG:        %[[C4:.+]] = arith.constant 4 : index
+        # CHECK-DAG:        %[[C1:.+]] = arith.constant 1 : index
+        # CHECK-DAG:        %[[C32:.+]] = arith.constant 32 : index
+        # CHECK-DAG:        %[[C64:.+]] = arith.constant 64 : index
+        # CHECK-DAG:        %[[C0:.+]] = arith.constant 0 : index
+        # CHECK-DAG:        %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+        # CHECK:            %[[WORKGROUP_ID_0:.+]] = stream.dispatch.workgroup.id[0] : index
+        # CHECK:            %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index
+        # CHECK-DAG:        %[[THREAD_ID_X:.+]] = gpu.thread_id  x
+        # CHECK-DAG:        %[[THREAD_ID_Y:.+]] = gpu.thread_id  y
+        # CHECK:            %[[ALLOC:.+]] = memref.alloc() : memref<64x32xf16, #[[GPU:.+]].address_space<workgroup>>
+        # CHECK:            %[[ALLOC_0:.+]] = memref.alloc() : memref<64x32xf16, #[[GPU]].address_space<workgroup>>
+        # CHECK:            %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<128x128xf16,
+        # CHECK-SAME:         strided<[128, 1], offset: ?>>
+        # CHECK:            %[[D1:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C64]] : index
+        # CHECK:            %[[D2:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C32]] : index
+        # CHECK:            %[[D3:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C4]] : index
+        # CHECK:            %[[D4:.+]] = arith.addi %[[D3]], %[[D2]] : index
+        # CHECK:            %[[D5:.+]] = arith.remsi %[[D4]], %[[C64]] : index
+        # CHECK:            %[[D6:.+]] = arith.addi %[[D5]], %[[D1]] : index
+        # CHECK:            %[[D7:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C4]] : index
+        # CHECK:            %[[D8:.+]] = arith.muli %[[D7]], %[[C8]] : index
+        # CHECK:            %[[D9:.+]] = vector.load %[[D0]][%[[D6]], %[[D8]]] : memref<128x128xf16, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<8xf16>
+        # CHECK:            %[[D10:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x128xf16,
+        # CHECK-SAME:         strided<[128, 1], offset: ?>>
+        # CHECK:            %[[D11:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C64]] : index
+        # CHECK:            %[[D12:.+]] = arith.addi %[[D5]], %[[D11]] : index
+        # CHECK:            %[[D13:.+]] = vector.load %[[D10]][%[[D12]], %[[D8]]] : memref<128x128xf16, strided<[128, 1],
+        # CHECK-SAME:         offset: ?>>, vector<8xf16>
+        # CHECK:            vector.store %[[D9]], %[[ALLOC]][%[[D5]], %[[D8]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<8xf16>
+        # CHECK:            vector.store %[[D13]], %[[ALLOC_0]][%[[D5]], %[[D8]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<8xf16>
+        # CHECK:            amdgpu.lds_barrier
+        # CHECK:            %[[D14:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index
+        # CHECK:            %[[D15:.+]] = arith.muli %[[D14]], %[[C32]] : index
+        # CHECK:            %[[D16:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index
+        # CHECK:            %[[D17:.+]] = arith.addi %[[D16]], %[[D15]] : index
+        # CHECK:            %[[D18:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index
+        # CHECK:            %[[D19:.+]] = arith.divsi %[[D18]], %[[C16]] : index
+        # CHECK:            %[[D20:.+]] = arith.muli %[[D19]], %[[C4]] : index
+        # CHECK:            %[[D21:.+]] = arith.addi %[[D20]], %[[C16]] : index
+        # CHECK:            %[[D22:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D23:.+]] = arith.addi %[[D16]], %[[D2]] : index
+        # CHECK:            %[[D24:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D25:.+]] = arith.addi %[[D23]], %[[C16]] : index
+        # CHECK:            %[[D26:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D27:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D28:.+]] = arith.addi %[[D8]], %[[C32]] : index
+        # CHECK:            %[[D29:.+]] = vector.load %[[D0]][%[[D6]], %[[D28]]] : memref<128x128xf16, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<8xf16>
+        # CHECK:            %[[D30:.+]] = vector.load %[[D10]][%[[D12]], %[[D28]]] : memref<128x128xf16, strided<[128, 1],
+        # CHECK-SAME:         offset: ?>>, vector<8xf16>
+        # CHECK:            %[[D31:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D32:.+]] = arith.addi %[[D17]], %[[C16]] : index
+        # CHECK:            %[[D33:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D34:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D35:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D36:.+]] = amdgpu.mfma %[[D31]] * %[[D35]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            %[[D37:.+]] = amdgpu.mfma %[[D33]] * %[[D26]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            %[[D38:.+]] = amdgpu.mfma %[[D33]] * %[[D35]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            %[[D39:.+]] = amdgpu.mfma %[[D31]] * %[[D26]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            amdgpu.lds_barrier
+        # CHECK:            vector.store %[[D29]], %[[ALLOC]][%[[D5]], %[[D8]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<8xf16>
+        # CHECK:            vector.store %[[D30]], %[[ALLOC_0]][%[[D5]], %[[D8]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<8xf16>
+        # CHECK:            %[[D40:.+]]:8 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
+        # CHECK-SAME:         iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D22]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D34]],
+        # CHECK-SAME:         %[[ARG6:[a-zA-Z0-9_]+]] = %[[D24]], %[[ARG7:[a-zA-Z0-9_]+]] = %[[D27]], %[[ARG8:[a-zA-Z0-9_]+]] =
+        # CHECK-SAME:         %[[D36]], %[[ARG9:[a-zA-Z0-9_]+]] = %[[D37]], %[[ARG10:[a-zA-Z0-9_]+]] = %[[D38]],
+        # CHECK-SAME:         %[[ARG11:[a-zA-Z0-9_]+]] = %[[D39]]) -> (vector<4xf16>, vector<4xf16>, vector<4xf16>,
+        # CHECK-SAME:         vector<4xf16>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+        # CHECK:              %[[D90:.+]] = amdgpu.mfma %[[ARG4]] * %[[ARG6]] + %[[ARG8]] {blocks = 1 : i32, k = 16 : i32, m =
+        # CHECK-SAME:           16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:              %[[D91:.+]] = amdgpu.mfma %[[ARG5]] * %[[ARG7]] + %[[ARG9]] {blocks = 1 : i32, k = 16 : i32, m =
+        # CHECK-SAME:           16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:              %[[D92:.+]] = amdgpu.mfma %[[ARG5]] * %[[ARG6]] + %[[ARG10]] {blocks = 1 : i32, k = 16 : i32, m =
+        # CHECK-SAME:           16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:              %[[D93:.+]] = amdgpu.mfma %[[ARG4]] * %[[ARG7]] + %[[ARG11]] {blocks = 1 : i32, k = 16 : i32, m =
+        # CHECK-SAME:           16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:              amdgpu.lds_barrier
+        # CHECK:              %[[D94:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:           #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:              %[[D95:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:           #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:              %[[D96:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:           #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:              %[[D97:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:           #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:              %[[D98:.+]] = arith.muli %[[ARG3]], %[[C32]] : index
+        # CHECK:              %[[D99:.+]] = arith.addi %[[D98]], %[[D8]] : index
+        # CHECK:              %[[D100:.+]] = arith.addi %[[D99]], %[[C64]] : index
+        # CHECK:              %[[D101:.+]] = vector.load %[[D0]][%[[D6]], %[[D100]]] : memref<128x128xf16, strided<[128, 1],
+        # CHECK-SAME:           offset: ?>>, vector<8xf16>
+        # CHECK:              %[[D102:.+]] = vector.load %[[D10]][%[[D12]], %[[D100]]] : memref<128x128xf16, strided<[128, 1],
+        # CHECK-SAME:           offset: ?>>, vector<8xf16>
+        # CHECK:              %[[D103:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:           #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:              %[[D104:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:           #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:              %[[D105:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:           #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:              %[[D106:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:           #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:              %[[D107:.+]] = amdgpu.mfma %[[D103]] * %[[D106]] + %[[D90]] {blocks = 1 : i32, k = 16 : i32, m =
+        # CHECK-SAME:           16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:              %[[D108:.+]] = amdgpu.mfma %[[D104]] * %[[D96]] + %[[D91]] {blocks = 1 : i32, k = 16 : i32, m = 16
+        # CHECK-SAME:           : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:              %[[D109:.+]] = amdgpu.mfma %[[D104]] * %[[D106]] + %[[D92]] {blocks = 1 : i32, k = 16 : i32, m =
+        # CHECK-SAME:           16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:              %[[D110:.+]] = amdgpu.mfma %[[D103]] * %[[D96]] + %[[D93]] {blocks = 1 : i32, k = 16 : i32, m = 16
+        # CHECK-SAME:           : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:              amdgpu.lds_barrier
+        # CHECK:              vector.store %[[D101]], %[[ALLOC]][%[[D5]], %[[D8]]] : memref<64x32xf16,
+        # CHECK-SAME:           #[[GPU]].address_space<workgroup>>, vector<8xf16>
+        # CHECK:              vector.store %[[D102]], %[[ALLOC_0]][%[[D5]], %[[D8]]] : memref<64x32xf16,
+        # CHECK-SAME:           #[[GPU]].address_space<workgroup>>, vector<8xf16>
+        # CHECK:              scf.yield %[[D94]], %[[D105]], %[[D95]], %[[D97]], %[[D107]], %[[D108]], %[[D109]], %[[D110]] :
+        # CHECK-SAME:           vector<4xf16>, vector<4xf16>, vector<4xf16>, vector<4xf16>, vector<4xf32>, vector<4xf32>,
+        # CHECK-SAME:           vector<4xf32>, vector<4xf32>
+        # CHECK:            }
+        # CHECK:            %[[D41:.+]] = amdgpu.mfma %[[D40]]#[[D0:.+]] * %[[D40]]#[[D2:.+]] + %[[D40]]#[[D4:.+]] {blocks = 1 :
+        # CHECK-SAME:         i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>,
+        # CHECK-SAME:         vector<4xf32>
+        # CHECK:            %[[D42:.+]] = amdgpu.mfma %[[D40]]#[[D1:.+]] * %[[D40]]#[[D3:.+]] + %[[D40]]#[[D5:.+]] {blocks = 1 :
+        # CHECK-SAME:         i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>,
+        # CHECK-SAME:         vector<4xf32>
+        # CHECK:            %[[D43:.+]] = amdgpu.mfma %[[D40]]#[[D1]] * %[[D40]]#[[D2]] + %[[D40]]#[[D6:.+]] {blocks = 1 : i32,
+        # CHECK-SAME:         k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>,
+        # CHECK-SAME:         vector<4xf32>
+        # CHECK:            %[[D44:.+]] = amdgpu.mfma %[[D40]]#[[D0]] * %[[D40]]#[[D3]] + %[[D40]]#[[D7:.+]] {blocks = 1 : i32,
+        # CHECK-SAME:         k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>,
+        # CHECK-SAME:         vector<4xf32>
+        # CHECK:            amdgpu.lds_barrier
+        # CHECK:            %[[D45:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D46:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D47:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D48:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D49:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D50:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D51:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D21]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D52:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D20]]] : memref<64x32xf16,
+        # CHECK-SAME:         #[[GPU]].address_space<workgroup>>, vector<4xf16>
+        # CHECK:            %[[D53:.+]] = amdgpu.mfma %[[D49]] * %[[D52]] + %[[D41]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            %[[D54:.+]] = amdgpu.mfma %[[D50]] * %[[D47]] + %[[D42]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            %[[D55:.+]] = amdgpu.mfma %[[D50]] * %[[D52]] + %[[D43]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            %[[D56:.+]] = amdgpu.mfma %[[D49]] * %[[D47]] + %[[D44]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            %[[D57:.+]] = amdgpu.mfma %[[D45]] * %[[D46]] + %[[D53]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            %[[D58:.+]] = amdgpu.mfma %[[D51]] * %[[D48]] + %[[D54]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            %[[D59:.+]] = amdgpu.mfma %[[D51]] * %[[D46]] + %[[D55]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            %[[D60:.+]] = amdgpu.mfma %[[D45]] * %[[D48]] + %[[D56]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
+        # CHECK-SAME:         i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK:            %[[D61:.+]] = vector.extract_strided_slice %[[D57]] {offsets = [0], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            %[[D62:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<128x128xf32,
+        # CHECK-SAME:         strided<[128, 1], offset: ?>>
+        # CHECK:            %[[D63:.+]] = arith.addi %[[D1]], %[[D15]] : index
+        # CHECK:            %[[D64:.+]] = arith.addi %[[D63]], %[[D20]] : index
+        # CHECK:            %[[D65:.+]] = arith.addi %[[D16]], %[[D11]] : index
+        # CHECK:            %[[D66:.+]] = arith.addi %[[D65]], %[[D2]] : index
+        # CHECK:            vector.store %[[D61]], %[[D62]][%[[D64]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D67:.+]] = vector.extract_strided_slice %[[D57]] {offsets = [1], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            %[[D68:.+]] = arith.addi %[[D64]], %[[C1]] : index
+        # CHECK:            vector.store %[[D67]], %[[D62]][%[[D68]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D69:.+]] = vector.extract_strided_slice %[[D57]] {offsets = [2], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            %[[D70:.+]] = arith.addi %[[D64]], %[[C2]] : index
+        # CHECK:            vector.store %[[D69]], %[[D62]][%[[D70]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D71:.+]] = vector.extract_strided_slice %[[D57]] {offsets = [3], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            %[[D72:.+]] = arith.addi %[[D64]], %[[C3]] : index
+        # CHECK:            vector.store %[[D71]], %[[D62]][%[[D72]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D73:.+]] = vector.extract_strided_slice %[[D58]] {offsets = [0], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            %[[D74:.+]] = arith.addi %[[D64]], %[[C16]] : index
+        # CHECK:            %[[D75:.+]] = arith.addi %[[D66]], %[[C16]] : index
+        # CHECK:            vector.store %[[D73]], %[[D62]][%[[D74]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D76:.+]] = vector.extract_strided_slice %[[D58]] {offsets = [1], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            %[[D77:.+]] = arith.addi %[[D64]], %[[C17]] : index
+        # CHECK:            vector.store %[[D76]], %[[D62]][%[[D77]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D78:.+]] = vector.extract_strided_slice %[[D58]] {offsets = [2], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            %[[D79:.+]] = arith.addi %[[D64]], %[[C18]] : index
+        # CHECK:            vector.store %[[D78]], %[[D62]][%[[D79]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D80:.+]] = vector.extract_strided_slice %[[D58]] {offsets = [3], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            %[[D81:.+]] = arith.addi %[[D64]], %[[C19]] : index
+        # CHECK:            vector.store %[[D80]], %[[D62]][%[[D81]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D82:.+]] = vector.extract_strided_slice %[[D59]] {offsets = [0], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            vector.store %[[D82]], %[[D62]][%[[D74]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D83:.+]] = vector.extract_strided_slice %[[D59]] {offsets = [1], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            vector.store %[[D83]], %[[D62]][%[[D77]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D84:.+]] = vector.extract_strided_slice %[[D59]] {offsets = [2], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            vector.store %[[D84]], %[[D62]][%[[D79]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D85:.+]] = vector.extract_strided_slice %[[D59]] {offsets = [3], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            vector.store %[[D85]], %[[D62]][%[[D81]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D86:.+]] = vector.extract_strided_slice %[[D60]] {offsets = [0], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            vector.store %[[D86]], %[[D62]][%[[D64]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D87:.+]] = vector.extract_strided_slice %[[D60]] {offsets = [1], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            vector.store %[[D87]], %[[D62]][%[[D68]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D88:.+]] = vector.extract_strided_slice %[[D60]] {offsets = [2], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            vector.store %[[D88]], %[[D62]][%[[D70]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            %[[D89:.+]] = vector.extract_strided_slice %[[D60]] {offsets = [3], sizes = [1], strides = [1]} :
+        # CHECK-SAME:         vector<4xf32> to vector<1xf32>
+        # CHECK:            vector.store %[[D89]], %[[D62]][%[[D72]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset:
+        # CHECK-SAME:         ?>>, vector<1xf32>
+        # CHECK:            return
+
+
 @run_test
 def test_add_float():
     constraints: list[tkw.Constraint] = [
diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py
new file mode 100644
index 00000000..eafabb27
--- /dev/null
+++ b/lit_tests/kernel/wave/scheduling.py
@@ -0,0 +1,227 @@
+# RUN: python %s | FileCheck %s
+
+import logging
+import unittest
+import shark_turbine.kernel as tk
+import shark_turbine.kernel.lang as tkl
+import shark_turbine.kernel.wave as tkw
+from shark_turbine.kernel.wave.promotion import promote_placeholders
+from shark_turbine.kernel.wave.hoisting import hoist_allocs
+from shark_turbine.kernel.wave.expansion import expand_graph
+from shark_turbine.kernel.lang.global_symbols import *
+from shark_turbine.kernel._support.tracing import CapturedTrace
+from shark_turbine.kernel._support.indexing import IndexingContext
+from shark_turbine.kernel.ops.wave_ops import *
+from shark_turbine.kernel.wave.utils import run_test, print_subgraph
+from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads
+from shark_turbine.kernel.wave.shared_memory_indexing import (
+    apply_shared_memory_indexing_corrections,
+)
+from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph
+
+
+# Input sizes
+M = tkl.sym.M
+N = tkl.sym.N
+K = tkl.sym.K
+
+# Workgroup tile sizes
+BLOCK_M = tkl.sym.BLOCK_M
+BLOCK_N = tkl.sym.BLOCK_N
+BLOCK_K = tkl.sym.BLOCK_K
+
+# Address space
+ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
+
+# Induction variable for dimension K
+ARGK = tkl.sym.ARGK
+
+
+@tkw.wave_trace_only()
+def gemm_pipelined(
+    a: tkl.Memory[M, K, ADDRESS_SPACE_0, tkl.f16],
+    b: tkl.Memory[N, K, ADDRESS_SPACE_0, tkl.f16],
+    c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
+):
+    c_reg = tkl.Register[M, N, tkl.f32](0.0)
+
+    @tkw.reduction(K, init_args=[c_reg])
+    def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
+        a_reg = tkw.read(a, elements_per_thread=4)
+        b_reg = tkw.read(b, elements_per_thread=4)
+        acc = tkw.mma(a_reg, b_reg, acc)
+        return acc
+
+    tkw.write(repeat, c, elements_per_thread=4)
+
+
+@run_test
+def test_gemm_pipelined():
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, 0)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, 1)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
+    constraints += [
+        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
+    ]
+    with tk.gen.TestLaunchContext(
+        {
+            M: 128,
+            N: 256,
+            K: 128,
+            BLOCK_M: 64,
+            BLOCK_N: 64,
+            BLOCK_K: 32,
+            ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
+            ADDRESS_SPACE_0: SHARED_ADDRESS_SPACE,
+            READ_SHARED_DELAY: 1,
+            WRITE_SHARED_DELAY: 1,
+            READ_GLOBAL_DELAY: 2,
+            WRITE_GLOBAL_DELAY: 2,
+            MMA_DELAY: 1,
+            SHARED_MEMORY_UNITS: 2,
+            GLOBAL_MEMORY_UNITS: 2,
+            MMA_UNITS: 2,
+        }
+    ):
+        trace: CapturedTrace = gemm_pipelined()
+        IndexingContext.current().finalize()
+        promote_placeholders(trace, constraints)
+        hoist_allocs(trace)
+        expand_graph(trace, constraints)
+        minimize_global_loads(trace, constraints)
+        apply_shared_memory_indexing_corrections(trace, constraints)
+        schedule_graph(trace, constraints)
+
+        print_subgraph(trace, "pipelined_reduction", False)
+        # CHECK: %acc_0_0_0
+        # CHECK-NEXT: %acc_0_1_0
+        # CHECK-NEXT: %acc_1_0_0
+        # CHECK-NEXT: %acc_1_1_0
+        # CHECK-NEXT: %rotating_reg_0
+        # CHECK-NEXT: %rotating_reg_1
+        # CHECK-NEXT: %rotating_reg_2
+        # CHECK-NEXT: %rotating_reg_3
+        # CHECK-NEXT: %rotating_reg_4
+        # CHECK-NEXT: %rotating_reg_5
+        # CHECK-NEXT: %rotating_reg_6
+        # CHECK-NEXT: %mma_1_1_1
+        # CHECK-SAME: (%rotating_reg_1, %rotating_reg_4, %rotating_reg_6)
+        # CHECK-NEXT: %read_shared_0_0_0
+        # CHECK-NEXT: %read_shared_0_0_1
+        # CHECK-NEXT: %read_4
+        # CHECK-NEXT: %read_5
+        # CHECK-NEXT: %read_shared_1_0_0
+        # CHECK-NEXT: %read_shared_1_0_1
+        # CHECK-NEXT: %mma_0_0_0
+        # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_1, %acc_0_0_0)
+        # CHECK-NEXT: %mma_0_1_0
+        # CHECK-SAME: (%read_shared_0_0_0, %rotating_reg_3, %acc_0_1_0)
+        # CHECK-NEXT: %mma_0_0_1
+        # CHECK-SAME: (%rotating_reg_0, %rotating_reg_2, %mma_0_0_0)
+        # CHECK-NEXT: %mma_1_0_0
+        # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_1, %acc_1_0_0)
+        # CHECK-NEXT: %write_2
+        # CHECK-NEXT: %write_3
+        # CHECK-NEXT: %mma_1_0_1
+        # CHECK-SAME: (%read_shared_1_0_1, %rotating_reg_2, %mma_1_0_0)
+        # CHECK-NEXT: %mma_0_1_1
+        # CHECK-SAME: (%rotating_reg_0, %rotating_reg_5, %mma_0_1_0)
+        # CHECK-NEXT: %read_shared_0_1_0
+        # CHECK-NEXT: %read_shared_0_1_1
+        # CHECK-NEXT: %mma_1_1_0
+        # CHECK-SAME: (%read_shared_1_0_0, %rotating_reg_3, %mma_1_1_1)
+        # CHECK-NEXT: %read_shared_0_0_2
+        # CHECK-NEXT: %read_shared_0_0_3
+        # CHECK-NEXT: [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1, read_shared_0_0_2, read_shared_1_0_1, read_shared_0_0_3, read_shared_0_1_0, rotating_reg_5, read_shared_0_1_1, mma_1_1_0]
+
+        print_subgraph(trace, "region_1", False)
+        # CHECK: %a
+        # CHECK-NEXT: %b
+        # CHECK-NEXT: %c
+        # CHECK-NEXT: %register_0_0_0
+        # CHECK-NEXT: %register_1_1_0
+        # CHECK-NEXT: %register_1_0_0
+        # CHECK-NEXT: %register_0_1_0
+        # CHECK-NEXT: %allocate
+        # CHECK-NEXT: %allocate_1
+        # CHECK-NEXT: %read_4
+        # CHECK-NEXT: %read_5
+        # CHECK-NEXT: %write_2
+        # CHECK-NEXT: %write_3
+        # CHECK-NEXT: %read_shared_0_1_0
+        # CHECK-NEXT: %read_shared_0_1_1
+        # CHECK-NEXT: %read_shared_0_0_1
+        # CHECK-NEXT: %read_shared_0_0_2
+        # CHECK-NEXT: %read_shared_0_0_0
+        # CHECK-NEXT: %read_shared_0_0_3
+        # CHECK-NEXT: %read_6
+        # CHECK-NEXT: %read_7
+        # CHECK-NEXT: %read_shared_1_0_0
+        # CHECK-NEXT: %read_shared_1_0_1
+        # CHECK-NEXT: %mma_0_0_0
+        # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_3, %register_0_0_0)
+        # CHECK-NEXT: %mma_0_1_0
+        # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_1_0, %register_0_1_0)
+        # CHECK-NEXT: %mma_0_0_1
+        # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_0_2, %mma_0_0_0)
+        # CHECK-NEXT: %mma_1_0_0
+        # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_3, %register_1_0_0)
+        # CHECK-NEXT: %write_4
+        # CHECK-NEXT: %write_5
+        # CHECK-NEXT: %mma_1_0_1
+        # CHECK-SAME: (%read_shared_1_0_1, %read_shared_0_0_2, %mma_1_0_0)
+        # CHECK-NEXT: %mma_0_1_1
+        # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_1_1, %mma_0_1_0)
+        # CHECK-NEXT: %read_shared_0_1_2
+        # CHECK-NEXT: %read_shared_0_1_3
+        # CHECK-NEXT: %mma_1_1_0
+        # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_1_0, %register_1_1_0)
+        # CHECK-NEXT: %read_shared_0_0_4
+        # CHECK-NEXT: %read_shared_0_0_5
+        # CHECK-NEXT: %reduction_1
+        # CHECK-NEXT: %getresult_1_1_0
+        # CHECK-NEXT: %getresult_1_0_0
+        # CHECK-NEXT: %getresult_0_1_0
+        # CHECK-NEXT: %getresult_0_0_0
+        # CHECK-NEXT: %get_result_4
+        # CHECK-NEXT: %get_result_5
+        # CHECK-NEXT: %get_result_6
+        # CHECK-NEXT: %get_result_7
+        # CHECK-NEXT: %get_result_8
+        # CHECK-NEXT: %get_result_9
+        # CHECK-NEXT: %get_result_10
+        # CHECK-NEXT: %mma_1_1_1
+        # CHECK-SAME: (%get_result_5, %get_result_9, %get_result_10)
+        # CHECK-NEXT: %read_shared_0_0_6
+        # CHECK-NEXT: %read_shared_0_0_7
+        # CHECK-NEXT: %read_shared_1_0_2
+        # CHECK-NEXT: %read_shared_1_0_3
+        # CHECK-NEXT: %mma_0_0_2
+        # CHECK-SAME: (%read_shared_0_0_6, %read_shared_0_0_7, %getresult_0_0_0)
+        # CHECK-NEXT: %mma_0_1_2
+        # CHECK-SAME: (%read_shared_0_0_6, %get_result_7, %getresult_0_1_0)
+        # CHECK-NEXT: %mma_0_0_3
+        # CHECK-SAME: (%get_result_4, %get_result_6, %mma_0_0_2)
+        # CHECK-NEXT: %mma_1_0_2
+        # CHECK-SAME: (%read_shared_1_0_2, %read_shared_0_0_7, %getresult_1_0_0)
+        # CHECK-NEXT: %mma_1_0_3
+        # CHECK-SAME: (%read_shared_1_0_3, %get_result_6, %mma_1_0_2)
+        # CHECK-NEXT: %mma_0_1_3
+        # CHECK-SAME: (%get_result_4, %get_result_9, %mma_0_1_2)
+        # CHECK-NEXT: %mma_1_1_2
+        # CHECK-SAME: (%read_shared_1_0_2, %get_result_7, %mma_1_1_1)
+        # CHECK-NEXT: %mma_1_1_3
+        # CHECK-SAME: (%read_shared_1_0_3, %get_result_9, %mma_1_1_2)
+        # CHECK-NEXT: %write_0_0_0
+        # CHECK-NEXT: %write_1_1_0
+        # CHECK-NEXT: %write_1_0_0
+        # CHECK-NEXT: %write_0_1_0
+        # CHECK-NEXT: return None
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()
diff --git a/shark_turbine/kernel/_support/tracing.py b/shark_turbine/kernel/_support/tracing.py
index 42424257..857cdb34 100644
--- a/shark_turbine/kernel/_support/tracing.py
+++ b/shark_turbine/kernel/_support/tracing.py
@@ -129,6 +129,9 @@ def __init__(self, region_graph: RegionGraph, root_graph: str):
     def get_subgraph(self, name: str) -> fx.Graph:
         return self.region_graph.subgraphs[name]
 
+    def add_subgraph(self, name: str, graph: fx.Graph):
+        self.region_graph.subgraphs[name] = graph
+
     def get_root_graph(self) -> fx.Graph:
         return self.get_subgraph(self.root_graph)
 
diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py
index ebadf0c4..442a876a 100644
--- a/shark_turbine/kernel/ops/wave_ops.py
+++ b/shark_turbine/kernel/ops/wave_ops.py
@@ -453,6 +453,8 @@ def index(self, value: Any):
             self.fx_node.index = {}
             for dim, key in value.items():
                 self.fx_node.index[dim] = key
+        elif isinstance(value, list):
+            self.fx_node.index = list(value)
         else:
             raise ValueError("Index must be a dict")
 
@@ -692,7 +694,7 @@ def is_barrier_between(self, src: fx.Node, dst: fx.Node) -> bool:
             prev_node, found_src = prev_node.prev, prev_node == src
         if not found_src:
             return False
-        while next_node and not found_dst:
+        while next_node.next.op != "root" and not found_dst:
             next_node, found_dst = next_node.next, next_node == dst
         return found_dst
 
@@ -921,6 +923,10 @@ def index(self) -> list[dict[IndexSymbol, IndexSequence]]:
                     else None
                 )
 
+    @index.setter
+    def index(self, value: Any):
+        CustomOp.index.fset(self, value)
+
 
 @define_op("write")
 @dataclass
diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py
index e4a8cf72..1534895e 100644
--- a/shark_turbine/kernel/wave/codegen.py
+++ b/shark_turbine/kernel/wave/codegen.py
@@ -90,6 +90,7 @@ class WaveEmitter:
     root_sig: BoundKernelSignature
     trace: CapturedTrace
     constraints: list[Constraint]
+    scheduling_metadata: dict[fx.Node, int]
     ip: InsertionPoint = None
     OP_HANDLERS: ClassVar[dict[str, Callable[["WaveEmitter", fx.Node], None]]] = {}
     _node_values: ClassVar[dict[fx.Node, List[IRProxyValue]]] = {}
@@ -209,13 +210,14 @@ def _get_div(mul, add, denominator):
 
     induction_var_syms = []
     induction_vars = []
-    for constraint in emitter.constraints:
-        if isinstance(constraint, TilingConstraint):
-            assert (
-                constraint.dim in emitter.induction_vars
-            ), f"Could not find induction var for {constraint.dim} dimension"
-            induction_var_syms.append(constraint.induction_var)
-            induction_vars.append(emitter.induction_vars[constraint.dim])
+    if emitter.induction_vars:
+        for constraint in emitter.constraints:
+            if isinstance(constraint, TilingConstraint):
+                assert (
+                    constraint.dim in emitter.induction_vars
+                ), f"Could not find induction var for {constraint.dim} dimension"
+                induction_var_syms.append(constraint.induction_var)
+                induction_vars.append(emitter.induction_vars[constraint.dim])
 
     # TODO: factor this out
     all_symbols = emitter.thread_ids + emitter.workgroup_ids + induction_vars
@@ -910,7 +912,6 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node):
     flat_init_args, _ = pytree.tree_flatten((init_args))
     flat_init_args = [cast_py_value(emitter, arg) for arg in flat_init_args]
 
-    # Without scheduling, we assume that we always start at 0.
     start = arith_d.constant(IndexType.get(), int(0))
 
     count = None
@@ -921,7 +922,10 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node):
 
     # For now, we assume that dimensions that have tiling constraints on them,
     # do not have any other constraints.
-    end = arith_d.constant(IndexType.get(), int(count))
+    end_value = int(count)
+    if node in emitter.scheduling_metadata:
+        end_value = emitter.scheduling_metadata[node]
+    end = arith_d.constant(IndexType.get(), end_value)
 
     # Since we divide the end by the tile size, we need to make sure that the
     # step is 1.
diff --git a/shark_turbine/kernel/wave/scheduling/graph_utils.py b/shark_turbine/kernel/wave/scheduling/graph_utils.py
index e625b666..af398af3 100644
--- a/shark_turbine/kernel/wave/scheduling/graph_utils.py
+++ b/shark_turbine/kernel/wave/scheduling/graph_utils.py
@@ -213,12 +213,13 @@ def topological_sort_nodes(
     Perform a topological sort on the nodes in the strongly connected component that have an edge in edges, excluding
     certain nodes.
     """
-    scc_nodes = set(scc) - set(exclude)
+    scc_nodes = set(scc)
     filtered_nodes = set()
     for edge in edges:
         if edge._from in scc_nodes and edge._to in scc_nodes:
             filtered_nodes.add(edge._to)
             filtered_nodes.add(edge._from)
+    filtered_nodes -= set(exclude) if exclude is not None else set()
     sorted_nodes = sorted(filtered_nodes, key=lambda x: x.f)
     return sorted_nodes
 
diff --git a/shark_turbine/kernel/wave/scheduling/loop_reconstruction.py b/shark_turbine/kernel/wave/scheduling/loop_reconstruction.py
new file mode 100644
index 00000000..52f205b1
--- /dev/null
+++ b/shark_turbine/kernel/wave/scheduling/loop_reconstruction.py
@@ -0,0 +1,556 @@
+from ..constraints import Constraint, TilingConstraint
+from ..._support.indexing import IndexSymbol
+from ..._support.tracing import CapturedTrace
+from ...ops.wave_ops import (
+    Reduction,
+    IterArg,
+    Placeholder,
+    Allocate,
+    Output,
+    Write,
+    GetResult,
+    get_custom,
+)
+from .modulo_scheduling import ModuloScheduler
+from ..utils import (
+    graph_copy,
+    erase_graph,
+    get_induction_variable,
+    replace_uses_in,
+)
+from ..utils import subs_idxc
+import torch.fx as fx
+import math
+from collections import deque
+from ..visualization import visualize_mapped_graphs, visualize_graph
+from ....support.logging import get_logger
+from ...lang.global_symbols import SHARED_ADDRESS_SPACE
+import random
+from typing import Optional
+from .loop_reconstruction_utils import (
+    ArgumentContext,
+    create_fill_stage_schedule,
+    create_drain_stage_schedule,
+    liveness_analysis,
+    partition_graph_by_stage,
+    interleave_instructions,
+)
+from enum import Enum
+
+logger = get_logger("turbine.wave.scheduling.loop_reconstruction")
+
+
+class PipelineStage(Enum):
+    PROLOGUE = 0
+    KERNEL = 1
+    EPILOGUE = 2
+
+
+def add_nodes_by_schedule(
+    reduction_graph: fx.Graph,
+    partitioned_graph: list[dict[int, list[fx.Node]]],
+    arg_context: ArgumentContext,
+    stages: list[int],
+    initiation_interval: int,
+    induction_variable: IndexSymbol,
+    current_induction_variables: list[int],
+    rotating_registers: dict[fx.Node, deque[fx.Node]],
+    pipelining_stage: PipelineStage = PipelineStage.KERNEL,
+):
+    """
+    Interleave the instructions in the partitioned graph by stage
+    for a single initiation interval, updating the argument context
+    per stage and substituting the provided induction variable values
+    for each iteration.
+    """
+    fill_or_drain = pipelining_stage in [PipelineStage.PROLOGUE, PipelineStage.EPILOGUE]
+    fill = pipelining_stage == PipelineStage.PROLOGUE
+    drain = pipelining_stage == PipelineStage.EPILOGUE
+
+    for cycle in range(initiation_interval):
+        logger.debug(f"Cycle: {cycle}")
+        # Interleave the instructions that are scheduled at the same cycle.
+        interleaved_instructions = []
+        for iteration, stage in enumerate(stages):
+            if stage is None:
+                continue
+            if cycle not in partitioned_graph[stage]:
+                continue
+            for node in partitioned_graph[stage][cycle]:
+                interleaved_instructions.append((iteration, stage, node))
+        interleave_instructions(interleaved_instructions)
+
+        for iteration, stage, node in interleaved_instructions:
+            logger.debug(f"Node: {node}, Stage: {stage}, Iteration: {iteration}")
+            custom_node = get_custom(node)
+            logger.debug(f"Node args: {node.args}")
+            for arg in node.args:
+                if arg_context.contains_in_iteration(iteration, arg):
+                    logger.debug(
+                        f"Found arg: {arg} in partitioned argument map. Using {arg_context.get_from_iteration(iteration, arg)}."
+                    )
+                    continue
+            new_node = custom_node.copy(
+                new_graph=reduction_graph,
+                arg_transform=lambda x: (
+                    arg_context.get_from_iteration(iteration, x)
+                    if arg_context.contains_in_iteration(iteration, x)
+                    else x
+                ),
+            )
+            # Update the argument context.
+            arg_context[(iteration, stage, node)] = new_node.fx_node
+            logger.debug(
+                f"Copying Node: {node}, Stage: {stage}, Iteration: {iteration} -> {new_node.fx_node}"
+            )
+            # Set the index for the new node by substituting the induction variable
+            # for the current iteration.
+            new_node.index = node.index
+            for dim in new_node.index:
+                new_node.index[dim] = new_node.index[dim].subs(
+                    {induction_variable: current_induction_variables[iteration]}
+                )
+            # Add scheduling parameters for debugging.
+            new_node.scheduling_parameters = node.scheduling_parameters
+            # Update the rotating registers and argument context for the current node (if applicable).
+            if node in rotating_registers:
+                rotating_registers[node].append(new_node.fx_node)
+                rotating_registers[node].popleft()
+                # If filling or draining, override the rotating registers and update the argument context.
+                if fill_or_drain:
+                    for next_stage in range(stage + 1, len(stages)):
+                        arg_context[(iteration, next_stage, node)] = new_node.fx_node
+
+            # Update the init args in the argument context whenever a result is computed.
+            if node in arg_context.results:
+                if (
+                    pipelining_stage == PipelineStage.KERNEL
+                    or pipelining_stage == PipelineStage.EPILOGUE
+                ):
+                    logger.debug(
+                        f"Updating result: {node} -> {arg_context.result_to_iter_arg[node]} to {new_node.fx_node}."
+                    )
+                    arg_context.map_arg_all(
+                        arg_context.result_to_iter_arg[node], new_node.fx_node
+                    )
+                if pipelining_stage == PipelineStage.PROLOGUE:
+                    logger.debug(
+                        f"Updating result: {node} -> {arg_context.result_to_init_arg[node]} to {new_node.fx_node}."
+                    )
+                    arg_context.map_arg_all(
+                        arg_context.result_to_init_arg[node], new_node.fx_node
+                    )
+
+
+def push_placeholders(
+    implicit_captures: list[fx.Node],
+    reduction_subgraph: fx.Graph,
+    arg_context: ArgumentContext,
+):
+    """
+    Push placeholders into the argument context for the reduction graph.
+    """
+    for node in reduction_subgraph.nodes:
+        custom = get_custom(node)
+        if isinstance(custom, Placeholder) and not isinstance(custom, IterArg):
+            matching = [x for x in implicit_captures if x.name == node.name]
+            assert matching, f"No implicit capture found for placeholder {node.name}"
+            arg_context.map_arg_all(node, matching[0])
+
+
+def construct_prologue(
+    reduction_subgraph: fx.Graph,
+    reduction: Reduction,
+    partitioned_graph: list[dict[int, list[fx.Node]]],
+    scheduler: ModuloScheduler,
+    rotating_registers: dict[fx.Node, list[fx.Node]],
+    induction_variable: IndexSymbol,
+    new_induction_variables: list[int],
+    stages: list[list[int]],
+):
+    """
+    Construct the prologue of the pipelined loop.
+    For this, we need to copy nodes from the reduction_graph and insert them
+    before the reduction operator in the root graph in the appropriate order.
+    We also need to initialize the rotating registers and update the indices
+    of the nodes to use the appropriate values of the induction variable.
+    """
+    logger.debug("=====================================")
+    logger.debug("Constructing prologue.")
+    logger.debug("=====================================")
+
+    arg_context = ArgumentContext(
+        reduction.outputs(reduction_subgraph),
+        reduction.iter_args(reduction_subgraph),
+        reduction.init_args,
+        scheduler.num_stages,
+    )
+
+    # Map iter args to init args in the prologue.
+    for iter_arg, init_arg in zip(
+        reduction.iter_args(reduction_subgraph), reduction.init_args
+    ):
+        arg_context.map_arg_all(iter_arg, init_arg)
+
+    push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context)
+    with reduction.graph.inserting_before(reduction.fx_node):
+        for i in range(scheduler.num_stages - 1):
+            add_nodes_by_schedule(
+                reduction.graph,
+                partitioned_graph,
+                arg_context,
+                stages[i],
+                scheduler.initiation_interval,
+                induction_variable,
+                new_induction_variables,
+                rotating_registers,
+                PipelineStage.PROLOGUE,
+            )
+
+    # During the prologue, we may have computed results that need to be passed as init args
+    # to the kernel.
+    new_init_args: list[fx.Node] = []
+    for init_arg in reduction.init_args:
+        mapped_init_arg = arg_context.lookup(init_arg)
+        if mapped_init_arg is None:
+            mapped_init_arg = init_arg
+        new_init_args.append(mapped_init_arg)
+    reduction.init_args = new_init_args
+
+
+def flatten_dict_values(
+    rotating_registers: dict[fx.Node, list[fx.Node]]
+) -> list[fx.Node]:
+    """
+    Flatten the values of the rotating registers into a list.
+    """
+    return [
+        register for registers in rotating_registers.values() for register in registers
+    ]
+
+
+def unflatten_dict_values(
+    rotating_registers_shapes: dict[fx.Node, int], values: list[fx.Node]
+) -> dict[fx.Node, deque[fx.Node]]:
+    """
+    Unflatten the values of the rotating registers into a dictionary
+    using the provided shapes.
+    """
+    rotating_registers = {}
+    count = 0
+    for node, shape in rotating_registers_shapes.items():
+        rotating_registers[node] = deque(values[count : count + shape])
+        count += shape
+    assert count == sum(rotating_registers_shapes.values())
+    return rotating_registers
+
+
+def push_rotating_registers(
+    arg_context: ArgumentContext,
+    rotating_registers: dict[fx.Node, list[fx.Node]],
+    graph: fx.Graph,
+    node_map: dict[fx.Node, fx.Node],
+    create_new_nodes: bool = False,
+) -> dict[fx.Node, deque[fx.Node]]:
+    """
+    Pushes the rotating registers into the argument map
+    at the appropriate stages. Create new nodes in the
+    specified graph if requested.
+
+    For each rotating register, we evaluate which stage its producer
+    belongs to and update the argument context for the next stage and
+    the n - 1 stages after it, where n is the number of rotating
+    registers for that value.
+    If var a has [a, b, c] as rotating registers, then in a 3-stage schedule
+        a is used in stage 2, (iteration 0)
+        b in stage 1, (iteration 1)
+        c in stage 0. (iteration 2)
+    """
+    new_rotating_registers: dict[fx.Node, deque[fx.Node]] = {}
+    count = 0
+    for node, registers in rotating_registers.items():
+        new_registers: deque[fx.Node] = deque()
+        custom = get_custom(node)
+        stage = custom.scheduling_parameters["stage"]
+        iteration = arg_context.get_kernel_iteration(stage)
+        arg_context[(iteration, stage, node)] = registers[-1]
+        for i, register in enumerate(registers):
+            mapped_stage = stage + len(registers) - i
+            mapped_iteration = arg_context.get_kernel_iteration(mapped_stage)
+            if create_new_nodes:
+                iter_arg = IterArg(f"rotating_reg_{count}").add_to_graph(graph)
+                iter_arg.type = get_custom(node).type
+                iter_arg.index = get_custom(node).index
+                arg_context[(mapped_iteration, mapped_stage, node)] = iter_arg
+                new_registers.append(iter_arg)
+                logger.debug(
+                    f"Mapped orig: {node_map[node]} / mapped: {iter_arg} to stage {mapped_stage}."
+                )
+            else:
+                arg_context[(mapped_iteration, mapped_stage, node)] = register
+                logger.debug(
+                    f"Mapped orig: {node_map[node]} / mapped: {register} to stage {mapped_stage}."
+                )
+            count += 1
+        if new_registers:
+            new_rotating_registers[node] = new_registers
+    return new_rotating_registers
+
+
+def construct_kernel(
+    reduction_subgraph: fx.Graph,
+    reduction: Reduction,
+    partitioned_graph: list[dict[int, list[fx.Node]]],
+    scheduler: ModuloScheduler,
+    rotating_registers: dict[fx.Node, list[fx.Node]],
+    induction_variable: IndexSymbol,
+    new_induction_variables: list[int],
+    node_map: dict[fx.Node, fx.Node],
+    visualize: bool = False,
+) -> tuple[Reduction, fx.Graph]:
+    """
+    Construct the kernel of the pipelined loop.
+    First, we construct a new reduction op with an empty graph.
+    Then, we set the init args, construct the iter args and add the ops.
+    Finally, we create the output node with the return values.
+    The iter args/results of the pipelined reduction are always:
+    [result0, result1, ..., resultN, rotating_reg0, rotating_reg1, ..., rotating_regN]
+    """
+    logger.debug("=====================================")
+    logger.debug("Constructing kernel.")
+    logger.debug("=====================================")
+
+    with reduction.graph.inserting_before(reduction.fx_node):
+        pipelined_reduction = Reduction(
+            reduction.axis,
+            init_args=reduction.init_args + flatten_dict_values(rotating_registers),
+            subgraph_name="pipelined_reduction",
+            implicit_captures=reduction.implicit_captures,
+        ).add_to_graph(reduction.graph)
+        pipelined_reduction.index = reduction.index
+        pipelined_reduction_graph = fx.Graph()
+        reduction.graph.subgraphs["pipelined_reduction"] = pipelined_reduction_graph
+
+        # Update the argument map for the new reduction.
+        arg_context = ArgumentContext(
+            reduction.outputs(reduction_subgraph),
+            reduction.iter_args(reduction_subgraph),
+            reduction.init_args,
+            scheduler.num_stages,
+        )
+        push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context)
+
+        # For the original iter args, we just map the old ones to the new ones.
+        # Do this for all stages, since the original iter args are "dummy" nodes
+        # during scheduling.
+        for node in arg_context.iter_args:
+            iter_arg = IterArg(node.name).add_to_graph(pipelined_reduction_graph)
+            iter_arg.type = get_custom(node).type
+            iter_arg.index = get_custom(node).index
+            arg_context.map_arg_all(node, iter_arg)
+
+        # Push the rotating registers into the argument context.
+        new_rotating_registers: dict[fx.Node, deque[fx.Node]] = push_rotating_registers(
+            arg_context,
+            rotating_registers,
+            pipelined_reduction_graph,
+            node_map,
+            create_new_nodes=True,
+        )
+
+        add_nodes_by_schedule(
+            pipelined_reduction_graph,
+            partitioned_graph,
+            arg_context,
+            list(reversed(range(scheduler.num_stages))),
+            scheduler.initiation_interval,
+            induction_variable,
+            new_induction_variables,
+            new_rotating_registers,
+            PipelineStage.KERNEL,
+        )
+
+        # Create output node (last node in the graph).
+        return_vals: list[fx.Node] = arg_context.get_kernel_results()
+        for registers in new_rotating_registers.values():
+            return_vals.extend(registers)
+
+        Output(return_vals).add_to_graph(pipelined_reduction_graph)
+        reduction.replace_all_uses_with(pipelined_reduction)
+
+        if visualize:
+            visualize_mapped_graphs(
+                pipelined_reduction_graph,
+                new_rotating_registers,
+                arg_context.argument_map,
+                "kernel.png",
+            )
+
+        return pipelined_reduction, pipelined_reduction_graph
+
+
+def construct_epilogue(
+    reduction_subgraph: fx.Graph,
+    reduction: Reduction,
+    pipelined_reduction: Reduction,
+    partitioned_graph: list[dict[int, fx.Node]],
+    scheduler: ModuloScheduler,
+    rotating_registers: dict[fx.Node, list[fx.Node]],
+    induction_variable: IndexSymbol,
+    new_induction_variables: list[int],
+    stages: list[int],
+    num_rotating_registers: dict[fx.Node, int],
+    node_map: dict[fx.Node, fx.Node],
+    visualize: bool = False,
+):
+    """
+    Construct the epilogue of the pipelined loop.
+    The difference from the prologue is that we need to map the results
+    of the pipelined reduction to the remaining stages. (In the prologue,
+    no iteration is ever completed, so we don't compute the final results.)
+    We emit GetResult nodes for the rotating registers and map them to
+    the different epilogue stages.
+    """
+    logger.debug("=====================================")
+    logger.debug("Constructing epilogue.")
+    logger.debug("=====================================")
+
+    arg_context = ArgumentContext(
+        reduction.outputs(reduction_subgraph),
+        reduction.iter_args(reduction_subgraph),
+        reduction.init_args,
+        scheduler.num_stages,
+    )
+
+    existing_get_results: list[GetResult] = sorted(
+        [x for x in pipelined_reduction.users if isinstance(x, GetResult)],
+        key=lambda x: x.res_idx,
+    )
+    existing_users = {x: x.users for x in existing_get_results}
+
+    # Map the results from the kernel to the init args (for stages).
+    for iter_arg, get_result in zip(
+        reduction.iter_args(reduction_subgraph), existing_get_results
+    ):
+        arg_context.map_arg_all(iter_arg, get_result.fx_node)
+
+    with pipelined_reduction.graph.inserting_before(
+        existing_get_results[0].fx_node.next
+    ):
+        # Add get result nodes for the rotating registers and update the
+        # argument map with them.
+        rotating_registers_get_results = []
+        offset = len(existing_get_results)
+        for i in range(len(flatten_dict_values(rotating_registers))):
+            rotating_registers_get_results.append(
+                GetResult(pipelined_reduction.fx_node, i + offset).add_to_graph(
+                    pipelined_reduction.graph
+                )
+            )
+        rotating_registers = unflatten_dict_values(
+            num_rotating_registers, rotating_registers_get_results
+        )
+
+        # Push the rotating registers onto the argument map.
+        push_rotating_registers(arg_context, rotating_registers, None, node_map, False)
+        push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context)
+
+        for i in range(scheduler.num_stages - 1):
+            add_nodes_by_schedule(
+                pipelined_reduction.graph,
+                partitioned_graph,
+                arg_context,
+                stages[i],
+                scheduler.initiation_interval,
+                induction_variable,
+                new_induction_variables,
+                rotating_registers,
+                PipelineStage.EPILOGUE,
+            )
+
+        # Replace the existing uses with the new results.
+        new_results = arg_context.get_mapped_results(existing_get_results)
+        assert len(new_results) == len(existing_get_results)
+        for i, get_result in enumerate(existing_get_results):
+            replace_uses_in(existing_users, get_result, new_results[i])
+
+        if visualize:
+            visualize_mapped_graphs(
+                pipelined_reduction.graph,
+                rotating_registers,
+                arg_context.argument_map,
+                "epilogue.png",
+            )
+
+
+def construct_pipelined_loop(
+    trace: CapturedTrace,
+    reduction: Reduction,
+    graph: fx.Graph,
+    constraints: list[Constraint],
+    scheduler: ModuloScheduler,
+    node_map: dict[fx.Node, fx.Node],
+    max_induction_variable: int,
+    visualize: bool = False,
+) -> fx.Node:
+    """
+    Given a graph annotated with scheduling parameters, construct a pipelined loop
+    with a prologue, kernel and epilogue.
+    """
+    induction_variable = get_induction_variable(reduction, constraints)
+    num_rotating_registers = liveness_analysis(graph, constraints, scheduler)
+    rotating_registers: dict[fx.Node, deque[fx.Node]] = {
+        k: deque([None for _ in range(v)]) for k, v in num_rotating_registers.items()
+    }
+    partitioned_graph = partition_graph_by_stage(graph, scheduler)
+    # Construct prologue.
+    construct_prologue(
+        graph,
+        reduction,
+        partitioned_graph,
+        scheduler,
+        rotating_registers,
+        induction_variable,
+        list(range(scheduler.num_stages)),
+        create_fill_stage_schedule(scheduler.num_stages),
+    )
+    # Construct kernel.
+    pipelined_reduction, pipelined_reduction_graph = construct_kernel(
+        graph,
+        reduction,
+        partitioned_graph,
+        scheduler,
+        rotating_registers,
+        induction_variable,
+        [induction_variable + i for i in range(scheduler.num_stages)],
+        node_map,
+        visualize,
+    )
+    trace.add_subgraph(
+        get_custom(pipelined_reduction).subgraph_name, pipelined_reduction_graph
+    )
+    # Construct epilogue.
+    construct_epilogue(
+        graph,
+        reduction,
+        get_custom(pipelined_reduction),
+        partitioned_graph,
+        scheduler,
+        rotating_registers,
+        induction_variable,
+        [
+            max_induction_variable - scheduler.num_stages + i
+            for i in range(scheduler.num_stages)
+        ],
+        create_drain_stage_schedule(scheduler.num_stages),
+        num_rotating_registers,
+        node_map,
+        visualize,
+    )
+
+    # Remove the unpipelined reduction.
+    reduction.graph.erase_node(reduction.fx_node)
+
+    if visualize:
+        visualize_graph(pipelined_reduction.graph, "pipelined.png")
+
+    return pipelined_reduction
diff --git a/shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py b/shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py
new file mode 100644
index 00000000..b6993a21
--- /dev/null
+++ b/shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py
@@ -0,0 +1,285 @@
+from ..constraints import Constraint, TilingConstraint
+from ..._support.indexing import IndexSymbol
+from ..._support.tracing import CapturedTrace
+from ...ops.wave_ops import Reduction, IterArg, Output, Write, GetResult, get_custom
+from .modulo_scheduling import ModuloScheduler
+from ..utils import graph_copy, erase_graph
+from ..utils import subs_idxc
+import torch.fx as fx
+import math
+from collections import defaultdict, deque, ChainMap
+from ..visualization import visualize_mapped_graphs
+from ....support.logging import get_logger
+from ...lang.global_symbols import SHARED_ADDRESS_SPACE
+import random
+from typing import Optional
+
+logger = get_logger("turbine.wave.scheduling.loop_reconstruction_utils")
+
+
+class ArgumentContext:
+    """
+    The argument context stores the mapping of arguments for each
+    iteration and stage of the modulo-pipelined loop.
+    """
+
+    def __init__(
+        self,
+        results: list[fx.Node],
+        iter_args: list[fx.Node],
+        init_args: list[fx.Node],
+        num_stages: int,
+    ) -> None:
+        self.argument_map: list[list[dict[fx.Node, fx.Node]]] = [
+            [{} for _ in range(num_stages)] for _ in range(num_stages)
+        ]
+        self.results = results
+        self.iter_args = iter_args
+        self.init_args = init_args
+        self.num_stages = num_stages
+        self.num_iterations = num_stages
+        self.result_to_iter_arg: dict[fx.Node, fx.Node] = {}
+        self.result_to_init_arg: dict[fx.Node, fx.Node] = {}
+
+        for result, iter_arg in zip(results, iter_args):
+            self.result_to_iter_arg[result] = iter_arg
+        for result, init_arg in zip(results, init_args):
+            self.result_to_init_arg[result] = init_arg
+
+    def map_arg_all(self, from_: fx.Node, to_: fx.Node) -> None:
+        """
+        Maps the given argument from one to another into the argument context for all stages
+        and for all iterations.
+        """
+        for iteration in range(self.num_iterations):
+            for stage in range(self.num_stages):
+                self.argument_map[iteration][stage][from_] = to_
+
+    def map_arg_all_iterations(self, stage: int, from_: fx.Node, to_: fx.Node) -> None:
+        """
+        Maps the given argument from one to another into the argument context for all stages
+        and for all iterations.
+        """
+        for iteration in range(self.num_iterations):
+            self.argument_map[iteration][stage][from_] = to_
+
+    def get_mapped_results(self, get_results: list[GetResult]) -> list[fx.Node]:
+        """
+        Gets the mapped results from the last iteration. If the result is not
+        in the last iteration, then get it from the get result nodes.
+        """
+        mapped_results = []
+        for result, get_result in zip(self.results, get_results):
+            stage = result.scheduling_parameters["stage"]
+            if result not in self.argument_map[self.num_iterations - 1][stage]:
+                mapped_results.append(get_result.fx_node)
+            else:
+                mapped_results.append(
+                    self.argument_map[self.num_iterations - 1][stage][result]
+                )
+        return mapped_results
+
+    def get_kernel_iteration(self, stage: int) -> int:
+        """
+        Get the iteration from the stage for the kernel.
+        """
+        return self.num_stages - 1 - stage
+
+    def get_kernel_results(self) -> list[fx.Node]:
+        """
+        Gets the mapped results for the kernel. Here there
+        exists a fixed relationship between the iteration and stage.
+        """
+        mapped_results = []
+        for result in self.results:
+            stage = result.scheduling_parameters["stage"]
+            iteration = self.get_kernel_iteration(stage)
+            mapped_results.append(self.argument_map[iteration][stage][result])
+        return mapped_results
+
+    def __setitem__(self, key: tuple[int, int, fx.Node], value: fx.Node) -> None:
+        """
+        Sets the argument mapping for the given iteration and stage.
+        """
+        assert isinstance(key, tuple), "Argument context key must be a tuple"
+        iteration, stage, from_ = key
+        assert iteration < len(
+            self.argument_map
+        ), f"Iteration {iteration} not yet initialized"
+        assert stage < self.num_stages, f"Stage {stage} not yet initialized"
+        self.argument_map[iteration][stage][from_] = value
+
+    def __getitem__(self, value: tuple[int, int, fx.Node]) -> Optional[fx.Node]:
+        """
+        Gets the argument mapping for the given iteration and stage.
+        """
+        assert isinstance(value, tuple), "Argument context key must be a tuple"
+        iteration, stage, key = value
+        assert iteration < len(
+            self.argument_map
+        ), f"Iteration {iteration} not yet initialized"
+        assert stage < self.num_stages, f"Stage {stage} not yet initialized"
+        return self.argument_map[iteration][stage].get(key, None)
+
+    def __contains__(self, key: fx.Node | tuple[int, int, fx.Node]) -> bool:
+        """
+        Checks if the argument context contains the given node at a specified
+        iteration and stage or at all iterations and stages.
+        """
+        if isinstance(key, tuple):
+            iteration, stage, key = key
+            return key in self.argument_map[iteration][stage]
+        return any(
+            key in self.argument_map[iteration][stage]
+            for iteration in range(self.num_iterations)
+            for stage in range(self.num_stages)
+        )
+
+    def lookup(self, key: fx.Node) -> Optional[fx.Node]:
+        """
+        Looks up the argument mapping for the given node.
+        """
+        for iteration in range(self.num_iterations):
+            for stage in range(self.num_stages):
+                if key in self.argument_map[iteration][stage]:
+                    return self.argument_map[iteration][stage][key]
+        return None
+
+    def contains_in_iteration(self, iteration: int, key: fx.Node) -> bool:
+        """
+        Checks if the argument context contains the given node at a specified
+        iteration.
+        """
+        return any(
+            key in self.argument_map[iteration][stage]
+            for stage in range(self.num_stages)
+        )
+
+    def get_from_iteration(self, iteration: int, key: fx.Node) -> Optional[fx.Node]:
+        """
+        Gets the argument mapping for the given node at the specified iteration.
+        """
+        for stage in range(self.num_stages):
+            if key in self.argument_map[iteration][stage]:
+                return self.argument_map[iteration][stage][key]
+        return None
+
+    def dump(self):
+        """
+        Dump the argument context to the logger.
+        """
+        for iteration in range(self.num_iterations):
+            for stage in range(self.num_stages):
+                logger.debug(f"Iteration: {iteration}, Stage: {stage}")
+                for key, value in self.argument_map[iteration][stage].items():
+                    logger.debug(f"  {key} -> {value}")
+
+
+def create_fill_stage_schedule(n: int) -> list[list[int]]:
+    """
+    Create the schedule of which stages need to be interleaved for the prologue (fill).
+    For a 4-stage schedule, this looks like:
+    [0 None None None]
+    [1    0 None None]
+    [2    1    0 None]
+    """
+    schedule = []
+    for i in range(n - 1):
+        row = list(range(i, -1, -1))
+        row.extend([None] * (n - i - 1))
+        schedule.append(row)
+    return schedule
+
+
+def create_drain_stage_schedule(n: int) -> list[list[int]]:
+    """
+    Create the schedule of which stages need to be interleaved for the epilogue (drain).
+    For a 4-stage schedule, this looks like:
+    [None    3    2    1]
+    [None None    3    2]
+    [None None None    3]
+    """
+    schedule = []
+    for i in range(n - 1):
+        row = [None] * (i + 1)
+        row.extend(range(n - 1, i, -1))
+        schedule.append(row)
+    return schedule
+
+
+def liveness_analysis(
+    graph: fx.Graph, constraints: list[Constraint], scheduler: ModuloScheduler
+) -> dict[fx.Node, int]:
+    """
+    Perform liveness analysis on the graph to determine the live ranges of
+    variables and use that to deduce how many rotating registers we need.
+    """
+    lifetime: dict[fx.Node, int] = {}
+    for node in graph.nodes:
+        custom = get_custom(node)
+        if custom.scheduling_parameters is None:
+            continue
+        if node not in lifetime:
+            lifetime[node] = 0
+        for user in custom.users:
+            if user.scheduling_parameters is None:
+                continue
+            logger.debug(
+                f"Node: {node}, User: {user.fx_node}, lifetime: {user.scheduling_parameters['stage'] - custom.scheduling_parameters['stage']}"
+            )
+            lifetime[node] = max(
+                user.scheduling_parameters["stage"]
+                - custom.scheduling_parameters["stage"],
+                lifetime[node],
+            )
+
+    # Determine how many copies we need for each node. If the lifetime of a node
+    # is l clocks and the initiation interval is T, then only ceil(l/T) values
+    # of the node can be live at the same time. Since the lifetime computed above
+    # is already measured in stages, the number of copies is simply the lifetime,
+    # and we only create copies for nodes that are live across more than one stage.
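+    # For example, a value defined in stage 0 whose last use is in stage 2 has a
+    # lifetime of 2 stages and therefore needs 2 rotating copies.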
+    num_rotating_registers: dict[fx.Node, int] = {}
+    for node, l in lifetime.items():
+        if node in num_rotating_registers:
+            continue
+        custom = get_custom(node)
+        if (
+            isinstance(custom, Write)
+            and custom.memory_type.address_space == SHARED_ADDRESS_SPACE
+        ):
+            continue
+        if l > 0:
+            num_rotating_registers[node] = l
+
+    return num_rotating_registers
+
+
+def partition_graph_by_stage(
+    graph: fx.Graph, scheduler: ModuloScheduler
+) -> list[dict[int, list[fx.Node]]]:
+    """
+    Partition the graph into stages based on the scheduling parameters.
+    """
+    partitioned_graph: list[dict[int, list[fx.Node]]] = [
+        defaultdict(list) for _ in range(scheduler.num_stages)
+    ]
+    for stage in range(scheduler.num_stages):
+        for node in graph.nodes:
+            custom = get_custom(node)
+            if custom.scheduling_parameters is None:
+                continue
+            if isinstance(custom, IterArg):
+                continue
+            if custom.scheduling_parameters["stage"] == stage:
+                cycle = custom.scheduling_parameters["cycle"]
+                partitioned_graph[stage][cycle].append(node)
+    return partitioned_graph
+
+
+def interleave_instructions(instructions: list[tuple[int, int, fx.Node]]):
+    """
+    Interleave the instructions that are scheduled in the same cycle.
+    Currently, we just randomly shuffle them, but we could also sort
+    them based on some criteria.
+    """
+    rng = random.Random(0)
+    # rng.shuffle(instructions)
diff --git a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py b/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py
index f2abbd13..82940113 100644
--- a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py
+++ b/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py
@@ -18,6 +18,7 @@
 )
 from typing import Callable
 import numpy as np
+import math
 
 logger = get_logger("turbine.wave.modulo_scheduling")
 
@@ -263,3 +264,11 @@ def resource_reservations(self) -> np.array:
         Returns the resource reservations of the schedule.
         """
         return self.RT
+
+    @property
+    def num_stages(self) -> int:
+        """
+        Returns the number of stages in the kernel of the pipelined loop.
+        """
+        max_cycle = max(self.schedule.values())
+        return math.ceil(max_cycle / self.initiation_interval)
diff --git a/shark_turbine/kernel/wave/scheduling/schedule.py b/shark_turbine/kernel/wave/scheduling/schedule.py
index a03ad082..8740fa08 100644
--- a/shark_turbine/kernel/wave/scheduling/schedule.py
+++ b/shark_turbine/kernel/wave/scheduling/schedule.py
@@ -11,8 +11,12 @@
 from .graph_utils import create_scheduling_edges, Edge
 from .resources import get_available_resources, annotate_resource_usage
 from ..visualization import visualize_edges, visualize_graph, visualize_schedule
-from ..utils import subs_idxc, graph_copy, erase_graph
+from .loop_reconstruction import construct_pipelined_loop
+from ..utils import graph_copy, erase_graph, get_tiling_constraint, subs_idxc
 import torch.fx as fx
+from ....support.logging import get_logger
+
+logger = get_logger("turbine.wave.scheduling.schedule")
 
 
 def visualize_scheduling_graph(edges: list[Edge]):
@@ -21,7 +25,7 @@ def visualize_scheduling_graph(edges: list[Edge]):
 
 def schedule_reduction(
     reduction: Reduction, trace: CapturedTrace, constraints: list[Constraint]
-):
+) -> dict[fx.Node, int]:
     """
     Clones the reduction graph and does the following:
     1. Annotates resource usage for each node.
@@ -68,8 +72,35 @@ def schedule_reduction(
 
     erase_graph(graph)
 
+    # After scheduling has completed, we have enough information to decide
+    # whether to pipeline the loop. For pipelining to be possible, we need
+    # at least N iterations of the loop where N > num_stages - 1 (because
+    # we will be peeling off num_stages - 1 iterations from the loop).
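+    # Illustrative numbers: a reduction dimension of 128 tiled by 32 gives 4 loop
+    # iterations; with a 3-stage schedule, num_stages - 1 = 2 iterations are
+    # peeled off, leaving 2 iterations for the pipelined kernel.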
+    tiling_constraint = get_tiling_constraint(reduction, constraints)
+    max_induction_variable = int(
+        subs_idxc(tiling_constraint.dim) // subs_idxc(tiling_constraint.tile_size)
+    )
+    if max_induction_variable <= scheduler.num_stages - 1:
+        logger.warning("Not enough iterations to pipeline the loop. Skipping pipelining.")
+        return {}
+
+    new_reduction = construct_pipelined_loop(
+        trace,
+        reduction,
+        reduction_graph,
+        constraints,
+        scheduler,
+        node_map,
+        max_induction_variable,
+        visualize,
+    )
+
+    return {new_reduction: max_induction_variable - (scheduler.num_stages - 1)}
+
 
-def schedule_graph(trace: CapturedTrace, constraints: list[Constraint]):
+def schedule_graph(
+    trace: CapturedTrace, constraints: list[Constraint]
+) -> dict[fx.Node, int]:
     """
     Given a graph, pipelines the reductions in the graph.
     """
@@ -81,5 +112,9 @@ def is_reduction(node: fx.Node) -> bool:
     if not reduction_nodes:
         return
 
+    scheduling_metadata = {}
     for reduction_node in reduction_nodes:
-        schedule_reduction(get_custom(reduction_node), trace, constraints)
+        scheduling_metadata.update(
+            schedule_reduction(get_custom(reduction_node), trace, constraints)
+        )
+    return scheduling_metadata
diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py
index affd5fef..7b5aa757 100644
--- a/shark_turbine/kernel/wave/utils.py
+++ b/shark_turbine/kernel/wave/utils.py
@@ -16,8 +16,8 @@
 from .._support.tracing import CapturedTrace
 from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence
 from ..lang.global_symbols import *
-from ..ops.wave_ops import get_custom, Output, Write, MMA
-from .constraints import Constraint, HardwareConstraint, TilingConstraint
+from ..ops.wave_ops import get_custom, Output, Write, Reduction, MMA, CustomOp
+from .constraints import HardwareConstraint, TilingConstraint, Constraint
 import torch.fx as fx
 import shark_turbine.kernel.lang as tkl
 
@@ -90,6 +90,21 @@ def print_trace(trace: CapturedTrace, custom_print: bool = True):
                 print(get_custom(node))
 
 
+def print_subgraph(trace: CapturedTrace, subgraph_name: str, custom_print: bool = True):
+    """
+    Prints a specific subgraphs of a trace.
+    The graphs are printed first in the torch printing format and
+    then using our custom node format.
+    """
+    # The root graph is at the back so we print the subgraphs in reverse order
+    for name, subgraph in trace.region_graph.subgraphs.items():
+        if name == subgraph_name:
+            print(subgraph)
+            if custom_print:
+                for node in subgraph.nodes:
+                    print(get_custom(node))
+
+
 def DCE(trace: CapturedTrace):
     """
     Removes all operators that are not used in the graph,
@@ -378,3 +393,42 @@ def erase_graph(graph: fx.Graph):
         for user in node.users:
             graph.erase_node(user)
         graph.erase_node(node)
+
+
+def get_induction_variable(
+    reduction: Reduction, constraints: list[Constraint]
+) -> IndexSymbol:
+    induction_var = None
+    for constraint in constraints:
+        if (
+            isinstance(constraint, TilingConstraint)
+            and reduction.axis == constraint.dim
+        ):
+            induction_var = constraint.induction_var
+            break
+    else:
+        raise ValueError(f"Could not find induction variable for reduction {reduction}")
+    return induction_var
+
+
+def get_tiling_constraint(
+    reduction: Reduction, constraints: list[Constraint]
+) -> TilingConstraint:
+    for constraint in constraints:
+        if (
+            isinstance(constraint, TilingConstraint)
+            and reduction.axis == constraint.dim
+        ):
+            return constraint
+    else:
+        raise ValueError(f"Could not find tiling constraint for reduction {reduction}")
+
+
+def replace_uses_in(users: dict[fx.Node, list[CustomOp]], old: CustomOp, new: fx.Node):
+    """
+    Replace all uses of `old` with `new` in the list of users.
+    """
+    for user in users[old]:
+        for i, arg in enumerate(user.fx_node.args):
+            if arg == old.fx_node:
+                user.update_arg(i, new)
diff --git a/shark_turbine/kernel/wave/visualization.py b/shark_turbine/kernel/wave/visualization.py
index 924c36bd..d6438bfc 100644
--- a/shark_turbine/kernel/wave/visualization.py
+++ b/shark_turbine/kernel/wave/visualization.py
@@ -11,6 +11,8 @@
     graphviz_disabled = True
 from torch import fx
 from .scheduling.graph_utils import Edge
+from ..ops.wave_ops import Output, Placeholder, IterArg, get_custom
+from collections import ChainMap
 import math
 
 
@@ -27,6 +29,9 @@ def visualize_graph(graph: fx.Graph, file_name: str):
         G.add_node(node_numbering[id(node)], label=node.name)
     for node in graph.nodes:
         for user in node.users.keys():
+            # Handle scenario where nodes are shared across graphs.
+            if user not in graph.nodes:
+                continue
             G.add_edge(node_numbering[id(node)], node_numbering[id(user)])
     G.layout(prog="dot")
     G.draw(file_name)
@@ -71,7 +76,7 @@ def visualize_schedule(
         for key, value in schedule.items():
             table[value + stage * initiation_interval][stage] += f"{key}<br>"
 
-    df = pd.DataFrame(table, columns=[f"Stage {i}" for i in range(cols)])
+    df = pd.DataFrame(table, columns=[f"Iteration {i}" for i in range(cols)])
     s = df.style.set_properties(**{"text-align": "center"})
     s = s.set_table_styles(
         [
@@ -95,3 +100,91 @@ def visualize_schedule(
     ).to_html()
     with open(f"{file_name}", "w") as f:
         f.write(output)
+
+
+def visualize_mapped_graphs(
+    second: fx.Graph,
+    rotating_registers: dict[fx.Node, list[fx.Node]],
+    mappings: list[list[dict[fx.Node, fx.Node]]],
+    file_name: str,
+):
+    """
+    Given the pipelined graph and a list of mappings of nodes from the original
+    graph to the pipelined graph (per stage), visualize the pipelined graph (with their original labels)
+
+    """
+
+    if graphviz_disabled:
+        raise ImportError("pygraphviz not installed, cannot visualize graph")
+    second_numbering = number_nodes(second)
+
+    flat_inverse_map: dict[fx.Node, fx.Node] = {}
+    flat_map: dict[fx.Node, fx.Node] = {}
+    for iteration_mapping in mappings:
+        for mapping in iteration_mapping:
+            flat_inverse_map.update({v: k for k, v in mapping.items()})
+            flat_map.update(mapping)
+    flat_inverse_map = ChainMap(flat_inverse_map)
+    flat_map = ChainMap(flat_map)
+
+    # Draw nodes and edges in the pipelined graph.
+    G = pgv.AGraph(directed=True)
+    G0 = G.add_subgraph(name="pipelined")
+    stage: dict[fx.Node, int] = {}
+    for node in second.nodes:
+        if hasattr(node, "scheduling_parameters"):
+            if node in flat_inverse_map:
+                name = flat_inverse_map[node].name
+            else:
+                name = node.name
+        else:
+            name = node.name
+        G0.add_node(
+            second_numbering[id(node)],
+            label=name,
+            color="lightblue",
+            style="filled",
+        )
+        for user in node.users.keys():
+            if user not in second.nodes:
+                continue
+            if isinstance(get_custom(user), Output):
+                continue
+            G0.add_edge(
+                second_numbering[id(node)],
+                second_numbering[id(user)],
+                color="black",
+            )
+
+    # Draw nodes and edges in the original graph.
+    colors = ["red", "green", "orange", "purple", "orange", "cyan", "magenta"]
+    max_stage = len(mappings)
+    for node, mapped_node in flat_map.items():
+        for user in node.users.keys():
+            if user not in flat_map:
+                continue
+            mapped_user = flat_map[user]
+            if mapped_user not in second.nodes or mapped_node not in second.nodes:
+                continue
+            stage = ""
+            if hasattr(user, "scheduling_parameters"):
+                stage = user.scheduling_parameters["stage"]
+            G.add_edge(
+                second_numbering[id(mapped_node)],
+                second_numbering[id(mapped_user)],
+                label=f"{stage}",
+                color=colors[stage % max_stage],
+            )
+
+    # Draw edges between rotating registers for the same variable.
+    for node in rotating_registers:
+        all_registers = [k for k, v in flat_inverse_map.items() if v == node]
+        for prev_reg, next_reg in zip(all_registers[:-1], all_registers[1:]):
+            G.add_edge(
+                second_numbering[id(next_reg)],
+                second_numbering[id(prev_reg)],
+                color="blue",
+            )
+
+    G.layout(prog="dot")
+    G.draw(file_name)
diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py
index eb6003de..ef1dbd7f 100644
--- a/shark_turbine/kernel/wave/wave.py
+++ b/shark_turbine/kernel/wave/wave.py
@@ -221,8 +221,9 @@ def _trace_and_get_kernel_signature(
         decompose_reduce_ops(graph, self.constraints, idxc.subs)
 
         # Schedule the reduction ops.
+        scheduling_metadata = {}
         if kwargs.get("schedule", False):
-            schedule_graph(graph, self.constraints)
+            scheduling_metadata = schedule_graph(graph, self.constraints)
 
         # Add shared memory barriers.
         add_shared_memory_barriers(graph)
@@ -250,7 +251,9 @@ def _trace_and_get_kernel_signature(
             entrypoint_name, kernel_sig, grid, workgroup_size, subgroup_size
         )
 
-        emitter = WaveEmitter(dispatch_entrypoint, graph, self.constraints)
+        emitter = WaveEmitter(
+            dispatch_entrypoint, graph, self.constraints, scheduling_metadata
+        )
         emitter.emit(graph.get_root_graph())
         emitter.finish()
 
diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py
index 344032a4..2386ebd9 100644
--- a/tests/kernel/wave/wave_gemm_test.py
+++ b/tests/kernel/wave/wave_gemm_test.py
@@ -15,6 +15,7 @@
 from shark_turbine.kernel.wave.iree_utils import generate_iree_ref
 import os
 import json
+from torch.testing import assert_close
 
 _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
 require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled")
@@ -40,7 +41,8 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]:
 
 @require_e2e
 @pytest.mark.parametrize("shape", get_test_shapes("test_gemm"))
-def testGemm(shape: tuple[int]):
+@pytest.mark.parametrize("enable_scheduling", [False, True])
+def testGemm(shape: tuple[int], enable_scheduling: bool):
 
     # Input sizes
     M = tkl.sym.M
@@ -106,10 +108,22 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
         M: shape[0],
         N: shape[1],
         K: shape[2],
+        READ_SHARED_DELAY: 1,
+        WRITE_SHARED_DELAY: 1,
+        READ_GLOBAL_DELAY: 2,
+        WRITE_GLOBAL_DELAY: 2,
+        MMA_DELAY: 1,
+        SHARED_MEMORY_UNITS: 4,
+        GLOBAL_MEMORY_UNITS: 4,
+        MMA_UNITS: 4,
     }
     config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
     with tk.gen.TestLaunchContext(
-        hyperparams, canonicalize=True, run=True, run_config=config
+        hyperparams,
+        canonicalize=True,
+        run=True,
+        run_config=config,
+        schedule=enable_scheduling,
     ):
         a = torch.randn(shape[0], shape[2], dtype=torch.float16)
         b = torch.randn(shape[1], shape[2], dtype=torch.float16)
@@ -123,9 +137,4 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
 
         iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.float32)
         generate_iree_ref("mmt", [a, b], [iree_ref], config)
-        assert torch.equal(c, iree_ref)
-
-
-if __name__ == "__main__":
-    logging.basicConfig(level=logging.DEBUG)
-    unittest.main()
+        assert_close(c, iree_ref)