Address Martin's comments
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Jul 13, 2024
1 parent b8c105f commit bb4918a
Showing 4 changed files with 84 additions and 115 deletions.
112 changes: 32 additions & 80 deletions lit_tests/kernel/wave/promotion.py
@@ -24,8 +24,8 @@ def run(func: Callable[[], None]) -> Callable[[], None]:


def get_read_nodes(graph: fx.Graph) -> list[fx.Node]:
nodes: list[fx.Node] = list(graph.nodes)
return [node for node in nodes if hasattr(node, "tkw_op") and node.tkw_op == Read]
custom_nodes: list[CustomOp] = [get_custom(node) for node in graph.nodes]
return [node for node in custom_nodes if isinstance(node, Read)]


def print_trace(trace: CapturedTrace):
@@ -85,26 +85,20 @@ def test_read_write_equal_sizes():
graph: fx.Graph = trace.get_root_graph()
read_node = get_read_nodes(graph)[0]
IndexingContext.current().finalize()
promote_node(read_node, graph, SHARED_ADDRESS_SPACE)
promote_node(read_node, SHARED_ADDRESS_SPACE)
print_trace(trace)
# CHECK: %a
# CHECK-NEXT: %c
# CHECK-NEXT: %read_0_0
# CHECK-NEXT: %read
# CHECK-SAME: (%a, 4)
# CHECK-NEXT: %read_1_1
# CHECK-SAME: (%a, 4)
# CHECK-NEXT: %read_1_0
# CHECK-SAME: (%a, 4)
# CHECK-NEXT: %read_0_1
# CHECK-SAME: (%a, 4)
# CHECK-NEXT: %write_0_0
# CHECK-SAME: (%read, %c, 4)
# CHECK-NEXT: %write_1_1
# CHECK-SAME: (%read_1_1, %c, 4)
# CHECK-NEXT: %write_1_0
# CHECK-SAME: (%read_1_0, %c, 4)
# CHECK-NEXT: %write_0_1
# CHECK-SAME: (%read_0_1, %c, 4)
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, N), f16, SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %write_1
# CHECK-SAME: (%read, %allocate, 4)
# CHECK-NEXT: %read_1
# CHECK-SAME: (%allocate, 4)
# CHECK-NEXT: %write
# CHECK-SAME: (%read_1, %c, 4)

# CHECK: -----

@@ -146,79 +140,37 @@ def test_gemm():
graph: fx.Graph = trace.get_subgraph("region_0")
read_nodes = get_read_nodes(graph)
for read_node in read_nodes:
promote_node(read_node, graph, SHARED_ADDRESS_SPACE)
promote_node(read_node, SHARED_ADDRESS_SPACE)
hoist_allocs(trace)
IndexingContext.current().finalize()
print_trace(trace)
# Root graph:
# CHECK: %a
# CHECK-NEXT: %b
# CHECK-NEXT: %c
# CHECK-NEXT: %register_0_0_0
# CHECK-NEXT: %register_1_1_0
# CHECK-NEXT: %register_1_0_0
# CHECK-NEXT: %register_0_1_0
# CHECK-NEXT: %reduction
# CHECK-SAME: %register_0_0_0, %register_0_1_0, %register_1_0_0, %register_1_1_0
# CHECK-NEXT: %getresult_1_1_0
# CHECK-NEXT: %getresult_1_0_0
# CHECK-NEXT: %getresult_0_1_0
# CHECK-NEXT: %getresult_0_0_0
# CHECK-NEXT: %write_0_0_0
# CHECK-SAME: (%get_result_0_0_0, %c, 4)
# CHECK-NEXT: %write_1_1_0
# CHECK-SAME: (%get_result_1_1_0, %c, 4)
# CHECK-NEXT: %write_1_0_0
# CHECK-SAME: (%get_result_1_0_0, %c, 4)
# CHECK-NEXT: %write_0_1_0
# CHECK-SAME: (%get_result_0_1_0, %c, 4)
# CHECK-NEXT: %register
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, K), f16, SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %allocate_1
# CHECK-SAME: ((N, K), f16, SHARED_ADDRESS_SPACE)
# CHECK-NEXT: reduction

# Reduction subgraph:

# CHECK: %acc_0_0_0
# CHECK-NEXT: %acc_1_1_0
# CHECK-NEXT: %acc_1_0_0
# CHECK-NEXT: %acc_0_1_0

# CHECK: %acc
# CHECK-NEXT: %a
# CHECK-NEXT: %read_0_0_0
# CHECK-SAME: (%a, 4)
# CHECK-NEXT: %read_0_0_1
# CHECK-SAME: (%a, 4)
# CHECK-NEXT: %read_1_0_0
# CHECK-SAME: (%a, 4)
# CHECK-NEXT: %read_1_0_1
# CHECK-SAME: (%a, 4)

# CHECK-NEXT: %read
# CHECK-NEXT: %write
# CHECK-SAME: (%read, %allocate, 4)
# CHECK-NEXT: %read_2
# CHECK-SAME: (%allocate, 4)
# CHECK-NEXT: %b
# CHECK-NEXT: %read_0_0_0
# CHECK-SAME: (%b, 4)
# CHECK-NEXT: %read_0_0_1
# CHECK-SAME: (%b, 4)
# CHECK-NEXT: %read_0_1_0
# CHECK-SAME: (%b, 4)
# CHECK-NEXT: %read_0_1_1
# CHECK-SAME: (%b, 4)

# CHECK-NEXT: %mma_0_0_0
# CHECK-SAME: (%read_0_0_0, %read_0_0_0, %acc)
# CHECK-NEXT: %mma_0_0_1
# CHECK-SAME: (%read_0_0_1, %read_0_0_1, %mma_0_0_0)
# CHECK-NEXT: %mma_1_1_0
# CHECK-SAME: (%read_1_0_0, %read_0_1_0, %acc_1_1_0)
# CHECK-NEXT: %mma_1_1_1
# CHECK-SAME: (%read_1_0_1, %read_0_1_1, %mma_1_1_0)
# CHECK-NEXT: %mma_1_0_0
# CHECK-SAME: (%read_1_0_0, %read_0_0_0, %acc_1_0_0)
# CHECK-NEXT: %mma_1_0_1
# CHECK-SAME: (%read_1_0_1, %read_0_0_1, %mma_1_0_0)
# CHECK-NEXT: %mma_0_1_0
# CHECK-SAME: (%read_0_0_0, %read_0_1_0, %acc_0_1_0)
# CHECK-NEXT: %mma_0_1_1
# CHECK-SAME: (%read_0_0_1, %read_0_1_1, %mma_0_1_0)
# CHECK-NEXT: return [mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1]

# CHECK-NEXT: -----
# CHECK-NEXT: %read_1
# CHECK-NEXT: %write_1
# CHECK-SAME: (%read_1, %allocate_1, 4)
# CHECK-NEXT: %read_3
# CHECK-SAME: (%allocate, 4)
# CHECK-NEXT: %mma
# CHECK-SAME: (%read_2, %read_3, %acc)


if __name__ == "__main__":
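The rewritten get_read_nodes helper above reflects a pattern this commit applies throughout: instead of probing node.tkw_op attributes, each fx.Node is wrapped back into its CustomOp subclass via get_custom and filtered with isinstance. A minimal, hypothetical generalization of that pattern (get_nodes_of_type is not part of this commit, and the import path is assumed):

from shark_turbine.kernel.ops.wave_ops import get_custom, CustomOp, Read, Allocate

def get_nodes_of_type(graph, op_type) -> list[CustomOp]:
    # Wrap every fx.Node in its CustomOp view and keep only the requested op type.
    return [c for node in graph.nodes if isinstance((c := get_custom(node)), op_type)]

# Given an fx.Graph named graph:
# get_nodes_of_type(graph, Read) matches get_read_nodes(graph) above, and
# get_nodes_of_type(graph, Allocate) matches get_allocs_ in hoisting.py below.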
39 changes: 31 additions & 8 deletions shark_turbine/kernel/ops/wave_ops.py
@@ -34,32 +34,36 @@

def allocate(
shape: tuple[IndexExpr], dtype: DataType, address_space: IndexSymbol
) -> "Memory": ...
) -> "Memory":
...


def read(
memory: "Memory", elements_per_thread: Optional[IndexExpr] = None
) -> "Register": ...
) -> "Register":
...


def reduction(
axis: IndexExpr, args: Sequence["Register"]
) -> Callable[[Callable[[AccT], AccT]], AccT]: ...
) -> Callable[[Callable[[AccT], AccT]], AccT]:
...


def register(
shape: tuple[IndexExpr, ...], dtype: DataType, value: float
) -> "Register": ...
def register(shape: tuple[IndexExpr, ...], dtype: DataType, value: float) -> "Register":
...


def mma(lhs: "Register", rhs: "Register", acc: "Register") -> "Register": ...
def mma(lhs: "Register", rhs: "Register", acc: "Register") -> "Register":
...


def write(
register_: "Register",
memory: "Memory",
elements_per_thread: Optional[IndexExpr | int] = None,
): ...
):
...


def define_op(op_name: str) -> Callable[[T], T]:
@@ -202,6 +206,25 @@ def copy(self, new_name: Optional[str] = None) -> Self:
new_node.name = new_name
return get_custom(new_node)

def copy_to_new_graph(
self, new_graph: fx.Graph, new_name: Optional[str] = None
) -> Self:
"""Returns a duplicate of this node."""
new_node = new_graph.node_copy(self.fx_node)
new_node.tkw_op = self
if new_name:
new_node.name = new_name
return get_custom(new_node)

def replace_all_uses_with(self, new_node: CustomOp):
"""Replace all uses of the current node with the new node."""
for user in self.users:
user.update_arg(user.node_args.index(self), new_node)

def erase(self):
"""Erase the current node from the graph where it exists."""
self.graph.erase_node(self.fx_node)

@classmethod
def handle(cls, graph, *args, **kwargs) -> fx.Node:
node = cls(*args, **kwargs)
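The three helpers added to CustomOp (copy_to_new_graph, replace_all_uses_with, and erase) give passes a node-level API so they no longer have to reach into the underlying fx.Graph directly. A minimal sketch of how they compose, assuming custom is an existing CustomOp (for example an Allocate) and root_graph is the fx.Graph it should move into; this mirrors what the updated hoist_allocs pass in the next file does:

new_custom = custom.copy_to_new_graph(root_graph)  # duplicate the node into the target graph
custom.replace_all_uses_with(new_custom)           # point each user's matching argument at the copy
custom.erase()                                     # drop the original from the graph it lives in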
17 changes: 8 additions & 9 deletions shark_turbine/kernel/wave/hoisting.py
@@ -3,17 +3,16 @@
import torch.fx as fx
from ..ops.wave_ops import *
from .address_spaces import *
import shark_turbine.kernel.lang as tkl

logger = get_logger("turbine.wave.hoisting")


def get_allocs_(graph: fx.Graph) -> list[fx.Node]:
allocs = []
for node in graph.nodes:
if hasattr(node, "tkw_op") and node.tkw_op == Allocate:
allocs.append(node)
return allocs
def get_allocs_(graph: fx.Graph) -> list[CustomOp]:
return [
custom_node
for node in graph.nodes
if isinstance((custom_node := get_custom(node)), Allocate)
]


def hoist_allocs(trace: CapturedTrace):
@@ -27,6 +26,6 @@ def hoist_allocs(trace: CapturedTrace):
subgraph = trace.get_subgraph(custom_node.subgraph_name)
allocs = get_allocs_(subgraph)
for alloc in allocs:
new_alloc = root_graph.node_copy(alloc)
new_alloc = alloc.copy_to_new_graph(root_graph)
alloc.replace_all_uses_with(new_alloc)
subgraph.erase_node(alloc)
alloc.erase()
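With promote_node now taking a CustomOp and hoist_allocs built on the helpers above, a typical driver follows the updated test_gemm flow: promote every read in the reduction subgraph to shared memory, then hoist the resulting allocations into the root graph. A sketch using names from the lit test (trace construction is omitted):

graph = trace.get_subgraph("region_0")
for read_node in get_read_nodes(graph):  # CustomOp Read wrappers, not raw fx.Nodes
    promote_node(read_node, SHARED_ADDRESS_SPACE)
hoist_allocs(trace)  # move the shared-memory allocates into the root graph
IndexingContext.current().finalize()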
31 changes: 13 additions & 18 deletions shark_turbine/kernel/wave/promotion.py
@@ -8,26 +8,22 @@
logger = get_logger("turbine.wave.promotion")


def apply_promotion_pattern_(
custom_node: Read | Write, allocate_node: Allocate, graph: fx.Graph
) -> list[fx.Node]:
promoted_nodes = []
def apply_promotion_pattern_(custom_node: Read | Write, allocate_node: Allocate):
match custom_node:
case Read(
memory, elements_per_thread
) if memory.type.address_space != allocate_node.address_space:
promoted_write = Write(
custom_node.fx_node, allocate_node.fx_node, elements_per_thread
).add_to_graph(graph)
promoted_read = Read(
allocate_node.fx_node, elements_per_thread
).add_to_graph(graph)
promoted_nodes = [promoted_write, promoted_read]
custom_node.fx_node.replace_all_uses_with(promoted_read)
return promoted_nodes
).add_to_graph(custom_node.graph)
custom_node.replace_all_uses_with(promoted_read)
with custom_node.graph.inserting_before(promoted_read):
Write(
custom_node.fx_node, allocate_node.fx_node, elements_per_thread
).add_to_graph(custom_node.graph)


def promote_node(node: fx.Node, graph: fx.Graph, address_space: IndexSymbol):
def promote_node(node: CustomOp, address_space: IndexSymbol):
"""Promotes the given operand in the provided graph
to the specified address space.
@@ -36,11 +32,10 @@ def promote_node(node: fx.Node, graph: fx.Graph, address_space: IndexSymbol):
memory location and subsequent uses reading from there.
"""

custom_node = get_custom(node)
assert isinstance(custom_node, Read) or isinstance(custom_node, Write)
with graph.inserting_before(node.next):
assert isinstance(node, Read) or isinstance(node, Write)
with node.graph.inserting_before(node.fx_node.next):
allocate_node = Allocate(
custom_node.type.symbolic_shape, custom_node.type.dtype, address_space
node.type.symbolic_shape, node.type.dtype, address_space
)
allocate_node.add_to_graph(graph)
apply_promotion_pattern_(custom_node, allocate_node, graph)
allocate_node.add_to_graph(node.graph)
apply_promotion_pattern_(node, allocate_node)
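Schematically, the rewritten promotion stages a promoted read through shared memory: an Allocate is placed right after the read, a Write into the allocation is inserted just before the new shared-memory Read, and every original user is rewired to that new Read. The before/after below is illustrative, based on the updated CHECK lines in test_read_write_equal_sizes (node names are schematic):

# Before promotion:
#   %read  = read(%a, 4)
#   %write = write(%read, %c, 4)
#
# After promote_node(read, SHARED_ADDRESS_SPACE):
#   %read     = read(%a, 4)
#   %allocate = allocate((M, N), f16, SHARED_ADDRESS_SPACE)
#   %write_1  = write(%read, %allocate, 4)  # stage the value into shared memory
#   %read_1   = read(%allocate, 4)          # read it back from shared memory
#   %write    = write(%read_1, %c, 4)       # the original user now consumes %read_1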
