Skip to content

Commit

Permalink
Add code to construct pipelined loop from schedule
Browse files Browse the repository at this point in the history
This PR adds code to construct the epilogue, kernel
and prologue once we have computed a schedule. We
simulate rotating registers in software and add
visualization tools to show the pipelined graphs.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Sep 25, 2024
1 parent 909411a commit dc31c60
Show file tree
Hide file tree
Showing 14 changed files with 1,646 additions and 28 deletions.
333 changes: 333 additions & 0 deletions lit_tests/kernel/wave/codegen.py

Large diffs are not rendered by default.

227 changes: 227 additions & 0 deletions lit_tests/kernel/wave/scheduling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.wave.promotion import promote_placeholders
from shark_turbine.kernel.wave.hoisting import hoist_allocs
from shark_turbine.kernel.wave.expansion import expand_graph
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel._support.tracing import CapturedTrace
from shark_turbine.kernel._support.indexing import IndexingContext
from shark_turbine.kernel.ops.wave_ops import *
from shark_turbine.kernel.wave.utils import run_test, print_subgraph
from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads
from shark_turbine.kernel.wave.shared_memory_indexing import (
apply_shared_memory_indexing_corrections,
)
from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph


# Symbolic input sizes for the GEMM operands (a: M x K, b: N x K, c: M x N).
M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K

# Workgroup tile sizes along each problem dimension.
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K

# Address spaces for the kernel operands (bound to global/shared below).
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0

# Induction variable for the tiled reduction over dimension K.
ARGK = tkl.sym.ARGK


@tkw.wave_trace_only()
def gemm_pipelined(
    a: tkl.Memory[M, K, ADDRESS_SPACE_0, tkl.f16],
    b: tkl.Memory[N, K, ADDRESS_SPACE_0, tkl.f16],
    c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
):
    """Trace-only GEMM kernel used as input for the scheduling test.

    Given the operand shapes (a is M x K, b is N x K, c is M x N) this
    accumulates the contraction of a and b over K into an f32 register
    tile and writes the result to c. The kernel is only traced (not
    executed), producing the graph that the pipeline passes transform.
    """
    # f32 accumulator tile, initialized to zero.
    c_reg = tkl.Register[M, N, tkl.f32](0.0)

    # Tiled reduction over K: each iteration reads f16 tiles of a and b
    # and folds their MMA product into the running accumulator.
    @tkw.reduction(K, init_args=[c_reg])
    def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
        a_reg = tkw.read(a, elements_per_thread=4)
        b_reg = tkw.read(b, elements_per_thread=4)
        acc = tkw.mma(a_reg, b_reg, acc)
        return acc

    # Store the final accumulated tile to the output memory.
    tkw.write(repeat, c, elements_per_thread=4)


@run_test
def test_gemm_pipelined():
    """Run the wave compilation pipeline through modulo scheduling and
    FileCheck the structure of the resulting pipelined subgraphs.

    The CHECK blocks below pin the node order and MMA operand wiring of
    (1) the pipelined kernel body ("pipelined_reduction") and (2) the
    surrounding region with prologue/epilogue ("region_1"), including the
    software-rotated registers introduced by the scheduler.
    """
    # Distribute M and N across workgroups and waves; tile K with ARGK
    # as the induction variable.
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, 0)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, 1)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
    constraints += [
        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
    ]
    with tk.gen.TestLaunchContext(
        {
            # Concrete problem and tile sizes for this test.
            M: 128,
            N: 256,
            K: 128,
            BLOCK_M: 64,
            BLOCK_N: 64,
            BLOCK_K: 32,
            ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
            ADDRESS_SPACE_0: SHARED_ADDRESS_SPACE,
            # Scheduling parameters: per-op latencies and resource counts
            # consumed by the modulo scheduler.
            READ_SHARED_DELAY: 1,
            WRITE_SHARED_DELAY: 1,
            READ_GLOBAL_DELAY: 2,
            WRITE_GLOBAL_DELAY: 2,
            MMA_DELAY: 1,
            SHARED_MEMORY_UNITS: 2,
            GLOBAL_MEMORY_UNITS: 2,
            MMA_UNITS: 2,
        }
    ):
        # Run the compilation pipeline up to and including scheduling.
        trace: CapturedTrace = gemm_pipelined()
        IndexingContext.current().finalize()
        promote_placeholders(trace, constraints)
        hoist_allocs(trace)
        expand_graph(trace, constraints)
        minimize_global_loads(trace, constraints)
        apply_shared_memory_indexing_corrections(trace, constraints)
        schedule_graph(trace, constraints)

        # Pipelined kernel body: accumulators first, then the rotating
        # registers that carry values across iterations, then the
        # interleaved reads/writes/MMAs of one steady-state iteration.
        print_subgraph(trace, "pipelined_reduction", False)
        # CHECK: %acc_0_0_0
        # CHECK-NEXT: %acc_0_1_0
        # CHECK-NEXT: %acc_1_0_0
        # CHECK-NEXT: %acc_1_1_0
        # CHECK-NEXT: %rotating_reg_0
        # CHECK-NEXT: %rotating_reg_1
        # CHECK-NEXT: %rotating_reg_2
        # CHECK-NEXT: %rotating_reg_3
        # CHECK-NEXT: %rotating_reg_4
        # CHECK-NEXT: %rotating_reg_5
        # CHECK-NEXT: %rotating_reg_6
        # CHECK-NEXT: %mma_1_1_1
        # CHECK-SAME: (%rotating_reg_1, %rotating_reg_4, %rotating_reg_6)
        # CHECK-NEXT: %read_shared_0_0_0
        # CHECK-NEXT: %read_shared_0_0_1
        # CHECK-NEXT: %read_4
        # CHECK-NEXT: %read_5
        # CHECK-NEXT: %read_shared_1_0_0
        # CHECK-NEXT: %read_shared_1_0_1
        # CHECK-NEXT: %mma_0_0_0
        # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_1, %acc_0_0_0)
        # CHECK-NEXT: %mma_0_1_0
        # CHECK-SAME: (%read_shared_0_0_0, %rotating_reg_3, %acc_0_1_0)
        # CHECK-NEXT: %mma_0_0_1
        # CHECK-SAME: (%rotating_reg_0, %rotating_reg_2, %mma_0_0_0)
        # CHECK-NEXT: %mma_1_0_0
        # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_1, %acc_1_0_0)
        # CHECK-NEXT: %write_2
        # CHECK-NEXT: %write_3
        # CHECK-NEXT: %mma_1_0_1
        # CHECK-SAME: (%read_shared_1_0_1, %rotating_reg_2, %mma_1_0_0)
        # CHECK-NEXT: %mma_0_1_1
        # CHECK-SAME: (%rotating_reg_0, %rotating_reg_5, %mma_0_1_0)
        # CHECK-NEXT: %read_shared_0_1_0
        # CHECK-NEXT: %read_shared_0_1_1
        # CHECK-NEXT: %mma_1_1_0
        # CHECK-SAME: (%read_shared_1_0_0, %rotating_reg_3, %mma_1_1_1)
        # CHECK-NEXT: %read_shared_0_0_2
        # CHECK-NEXT: %read_shared_0_0_3
        # CHECK-NEXT: [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1, read_shared_0_0_2, read_shared_1_0_1, read_shared_0_0_3, read_shared_0_1_0, rotating_reg_5, read_shared_0_1_1, mma_1_1_0]

        # Enclosing region: prologue (peeled first iterations), the
        # pipelined reduction, and the epilogue that drains the final
        # iterations from the reduction's results.
        print_subgraph(trace, "region_1", False)
        # CHECK: %a
        # CHECK-NEXT: %b
        # CHECK-NEXT: %c
        # CHECK-NEXT: %register_0_0_0
        # CHECK-NEXT: %register_1_1_0
        # CHECK-NEXT: %register_1_0_0
        # CHECK-NEXT: %register_0_1_0
        # CHECK-NEXT: %allocate
        # CHECK-NEXT: %allocate_1
        # CHECK-NEXT: %read_4
        # CHECK-NEXT: %read_5
        # CHECK-NEXT: %write_2
        # CHECK-NEXT: %write_3
        # CHECK-NEXT: %read_shared_0_1_0
        # CHECK-NEXT: %read_shared_0_1_1
        # CHECK-NEXT: %read_shared_0_0_1
        # CHECK-NEXT: %read_shared_0_0_2
        # CHECK-NEXT: %read_shared_0_0_0
        # CHECK-NEXT: %read_shared_0_0_3
        # CHECK-NEXT: %read_6
        # CHECK-NEXT: %read_7
        # CHECK-NEXT: %read_shared_1_0_0
        # CHECK-NEXT: %read_shared_1_0_1
        # CHECK-NEXT: %mma_0_0_0
        # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_3, %register_0_0_0)
        # CHECK-NEXT: %mma_0_1_0
        # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_1_0, %register_0_1_0)
        # CHECK-NEXT: %mma_0_0_1
        # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_0_2, %mma_0_0_0)
        # CHECK-NEXT: %mma_1_0_0
        # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_3, %register_1_0_0)
        # CHECK-NEXT: %write_4
        # CHECK-NEXT: %write_5
        # CHECK-NEXT: %mma_1_0_1
        # CHECK-SAME: (%read_shared_1_0_1, %read_shared_0_0_2, %mma_1_0_0)
        # CHECK-NEXT: %mma_0_1_1
        # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_1_1, %mma_0_1_0)
        # CHECK-NEXT: %read_shared_0_1_2
        # CHECK-NEXT: %read_shared_0_1_3
        # CHECK-NEXT: %mma_1_1_0
        # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_1_0, %register_1_1_0)
        # CHECK-NEXT: %read_shared_0_0_4
        # CHECK-NEXT: %read_shared_0_0_5
        # CHECK-NEXT: %reduction_1
        # CHECK-NEXT: %getresult_1_1_0
        # CHECK-NEXT: %getresult_1_0_0
        # CHECK-NEXT: %getresult_0_1_0
        # CHECK-NEXT: %getresult_0_0_0
        # CHECK-NEXT: %get_result_4
        # CHECK-NEXT: %get_result_5
        # CHECK-NEXT: %get_result_6
        # CHECK-NEXT: %get_result_7
        # CHECK-NEXT: %get_result_8
        # CHECK-NEXT: %get_result_9
        # CHECK-NEXT: %get_result_10
        # CHECK-NEXT: %mma_1_1_1
        # CHECK-SAME: (%get_result_5, %get_result_9, %get_result_10)
        # CHECK-NEXT: %read_shared_0_0_6
        # CHECK-NEXT: %read_shared_0_0_7
        # CHECK-NEXT: %read_shared_1_0_2
        # CHECK-NEXT: %read_shared_1_0_3
        # CHECK-NEXT: %mma_0_0_2
        # CHECK-SAME: (%read_shared_0_0_6, %read_shared_0_0_7, %getresult_0_0_0)
        # CHECK-NEXT: %mma_0_1_2
        # CHECK-SAME: (%read_shared_0_0_6, %get_result_7, %getresult_0_1_0)
        # CHECK-NEXT: %mma_0_0_3
        # CHECK-SAME: (%get_result_4, %get_result_6, %mma_0_0_2)
        # CHECK-NEXT: %mma_1_0_2
        # CHECK-SAME: (%read_shared_1_0_2, %read_shared_0_0_7, %getresult_1_0_0)
        # CHECK-NEXT: %mma_1_0_3
        # CHECK-SAME: (%read_shared_1_0_3, %get_result_6, %mma_1_0_2)
        # CHECK-NEXT: %mma_0_1_3
        # CHECK-SAME: (%get_result_4, %get_result_9, %mma_0_1_2)
        # CHECK-NEXT: %mma_1_1_2
        # CHECK-SAME: (%read_shared_1_0_2, %get_result_7, %mma_1_1_1)
        # CHECK-NEXT: %mma_1_1_3
        # CHECK-SAME: (%read_shared_1_0_3, %get_result_9, %mma_1_1_2)
        # CHECK-NEXT: %write_0_0_0
        # CHECK-NEXT: %write_1_1_0
        # CHECK-NEXT: %write_1_0_0
        # CHECK-NEXT: %write_0_1_0
        # CHECK-NEXT: return None


if __name__ == "__main__":
    # DEBUG logging surfaces pass diagnostics when the file is run directly
    # (the lit/FileCheck invocation at the top of the file uses stdout only).
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
3 changes: 3 additions & 0 deletions shark_turbine/kernel/_support/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def __init__(self, region_graph: RegionGraph, root_graph: str):
def get_subgraph(self, name: str) -> fx.Graph:
return self.region_graph.subgraphs[name]

def add_subgraph(self, name: str, graph: fx.Graph) -> None:
    """Register `graph` under `name` in the region graph's subgraph table.

    Overwrites any existing subgraph registered under the same name; the
    graph then becomes retrievable via `get_subgraph(name)`.
    """
    self.region_graph.subgraphs[name] = graph

def get_root_graph(self) -> fx.Graph:
return self.get_subgraph(self.root_graph)

Expand Down
8 changes: 7 additions & 1 deletion shark_turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ def index(self, value: Any):
self.fx_node.index = {}
for dim, key in value.items():
self.fx_node.index[dim] = key
elif isinstance(value, list):
self.fx_node.index = list(value)
else:
raise ValueError("Index must be a dict")

Expand Down Expand Up @@ -692,7 +694,7 @@ def is_barrier_between(self, src: fx.Node, dst: fx.Node) -> bool:
prev_node, found_src = prev_node.prev, prev_node == src
if not found_src:
return False
while next_node and not found_dst:
while next_node.next.op != "root" and not found_dst:
next_node, found_dst = next_node.next, next_node == dst
return found_dst

Expand Down Expand Up @@ -921,6 +923,10 @@ def index(self) -> list[dict[IndexSymbol, IndexSequence]]:
else None
)

@index.setter
def index(self, value: Any):
    # Delegate to the base CustomOp index setter so this subclass reuses
    # its handling of dict/list values (and its ValueError on other types).
    CustomOp.index.fset(self, value)


@define_op("write")
@dataclass
Expand Down
22 changes: 13 additions & 9 deletions shark_turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class WaveEmitter:
root_sig: BoundKernelSignature
trace: CapturedTrace
constraints: list[Constraint]
scheduling_metadata: dict[fx.Node, int]
ip: InsertionPoint = None
OP_HANDLERS: ClassVar[dict[str, Callable[["WaveEmitter", fx.Node], None]]] = {}
_node_values: ClassVar[dict[fx.Node, List[IRProxyValue]]] = {}
Expand Down Expand Up @@ -209,13 +210,14 @@ def _get_div(mul, add, denominator):

induction_var_syms = []
induction_vars = []
for constraint in emitter.constraints:
if isinstance(constraint, TilingConstraint):
assert (
constraint.dim in emitter.induction_vars
), f"Could not find induction var for {constraint.dim} dimension"
induction_var_syms.append(constraint.induction_var)
induction_vars.append(emitter.induction_vars[constraint.dim])
if emitter.induction_vars:
for constraint in emitter.constraints:
if isinstance(constraint, TilingConstraint):
assert (
constraint.dim in emitter.induction_vars
), f"Could not find induction var for {constraint.dim} dimension"
induction_var_syms.append(constraint.induction_var)
induction_vars.append(emitter.induction_vars[constraint.dim])

# TODO: factor this out
all_symbols = emitter.thread_ids + emitter.workgroup_ids + induction_vars
Expand Down Expand Up @@ -910,7 +912,6 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node):
flat_init_args, _ = pytree.tree_flatten((init_args))
flat_init_args = [cast_py_value(emitter, arg) for arg in flat_init_args]

# Without scheduling, we assume that we always start at 0.
start = arith_d.constant(IndexType.get(), int(0))

count = None
Expand All @@ -921,7 +922,10 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node):

# For now, we assume that dimensions that have tiling constraints on them,
# do not have any other constraints.
end = arith_d.constant(IndexType.get(), int(count))
end_value = int(count)
if node in emitter.scheduling_metadata:
end_value = emitter.scheduling_metadata[node]
end = arith_d.constant(IndexType.get(), end_value)

# Since we divide the end by the tile size, we need to make sure that the
# step is 1.
Expand Down
3 changes: 2 additions & 1 deletion shark_turbine/kernel/wave/scheduling/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,13 @@ def topological_sort_nodes(
Perform a topological sort on the nodes in the strongly connected component that have an edge in edges, excluding
certain nodes.
"""
scc_nodes = set(scc) - set(exclude)
scc_nodes = set(scc)
filtered_nodes = set()
for edge in edges:
if edge._from in scc_nodes and edge._to in scc_nodes:
filtered_nodes.add(edge._to)
filtered_nodes.add(edge._from)
filtered_nodes -= set(exclude) if exclude is not None else set()
sorted_nodes = sorted(filtered_nodes, key=lambda x: x.f)
return sorted_nodes

Expand Down
Loading

0 comments on commit dc31c60

Please sign in to comment.