Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add separate function to apply shared memory indexing #129

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def mma(
# CHECK: %[[D5:.+]] = vector.load %[[D0]][%[[D4]], %[[C0]]] : memref<64x16xf16, strided<[16, 1], offset: ?>>,
# CHECK-SAME: vector<4xf16>
# CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space<workgroup>>
# CHECK: vector.store %[[D5]], %[[ALLOC]][%[[D4]], %[[C0]]] : memref<32x16xf16,
# CHECK: vector.store %[[D5]], %[[ALLOC]][%[[D2]], %[[C0]]] : memref<32x16xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: amdgpu.lds_barrier
# CHECK: %[[D6:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index
Expand All @@ -366,7 +366,7 @@ def mma(
# CHECK: %[[D16:.+]] = vector.load %[[D12]][%[[D15]], %[[C0]]] : memref<128x16xf16, strided<[16, 1], offset:
# CHECK-SAME: ?>>, vector<4xf16>
# CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space<workgroup>>
# CHECK: vector.store %[[D16]], %[[ALLOC_0]][%[[D15]], %[[C0]]] : memref<32x16xf16,
# CHECK: vector.store %[[D16]], %[[ALLOC_0]][%[[D13]], %[[C0]]] : memref<32x16xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: amdgpu.lds_barrier
# CHECK: %[[D17:.+]] = arith.addi %[[D6]], %[[D13]] : index
Expand Down Expand Up @@ -483,7 +483,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK: %[[D28:.+]] = arith.muli %[[ARG3]], %[[C16]] : index
# CHECK: %[[D29:.+]] = vector.load %[[D22]][%[[D27]], %[[D28]]] : memref<64x64xf16, strided<[64, 1],
# CHECK-SAME: offset: ?>>, vector<4xf16>
# CHECK: vector.store %[[D29]], %[[ALLOC]][%[[D27]], %[[D28]]] : memref<32x16xf16,
# CHECK: vector.store %[[D29]], %[[ALLOC]][%[[D25]], %[[C0]]] : memref<32x16xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: amdgpu.lds_barrier
# CHECK: %[[D30:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index
Expand All @@ -498,7 +498,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK: %[[D38:.+]] = arith.addi %[[D37]], %[[D36]] : index
# CHECK: %[[D39:.+]] = vector.load %[[D23]][%[[D38]], %[[D28]]] : memref<128x64xf16, strided<[64, 1],
# CHECK-SAME: offset: ?>>, vector<4xf16>
# CHECK: vector.store %[[D39]], %[[ALLOC_0]][%[[D38]], %[[D28]]] : memref<32x16xf16,
# CHECK: vector.store %[[D39]], %[[ALLOC_0]][%[[D36]], %[[C0]]] : memref<32x16xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: amdgpu.lds_barrier
# CHECK: %[[D40:.+]] = arith.addi %[[D30]], %[[D36]] : index
Expand Down
4 changes: 4 additions & 0 deletions lit_tests/kernel/wave/index_sequence_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from shark_turbine.kernel.ops.wave_ops import *
from shark_turbine.kernel.wave.utils import run_test, print_trace
from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads
from shark_turbine.kernel.wave.shared_memory_indexing import (
apply_shared_memory_indexing_corrections,
)
from shark_turbine.kernel.wave.index_sequence_analysis import (
partition_strided_operators,
)
Expand Down Expand Up @@ -83,6 +86,7 @@ def test_gemm():
hoist_allocs(trace)
expand_graph(trace, constraints)
minimize_global_loads(trace, constraints)
apply_shared_memory_indexing_corrections(trace, constraints)
partition_strided_operators(trace, constraints)
print_trace(trace)
# Root graph:
Expand Down
4 changes: 4 additions & 0 deletions lit_tests/kernel/wave/minimize_global_loads.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from shark_turbine.kernel.wave.utils import run_test, print_trace
from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads
from shark_turbine.kernel.wave.visualization import visualize_graph
from shark_turbine.kernel.wave.shared_memory_indexing import (
apply_shared_memory_indexing_corrections,
)


# Input sizes
Expand Down Expand Up @@ -86,6 +89,7 @@ def test_gemm():
if visualize:
visualize_graph(trace.get_subgraph("region_0"), "before.png")
minimize_global_loads(trace, constraints)
apply_shared_memory_indexing_corrections(trace, constraints)
if visualize:
visualize_graph(trace.get_subgraph("region_0"), "after.png")
add_shared_memory_barriers(trace)
Expand Down
31 changes: 7 additions & 24 deletions shark_turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from .._support.dtype import DataType
from .._support.regions import RegionGraph
from .base import OpDispatcher
from ..lang.global_symbols import MMA_ACC, MMA_LHS, MMA_RHS

if TYPE_CHECKING:
from ..wave.constraints import Constraint
Expand Down Expand Up @@ -439,12 +438,11 @@ def index(self, value: Any):
if value is None:
return
if isinstance(value, dict):
assert all(
isinstance(v, IndexSequence) for v in value.values()
), f"Index must be a dict with values of type IndexSequence"
self.fx_node.index = {}
for dim, key in value.items():
assert isinstance(
key, IndexSequence
), f"Expected IndexSequence, got {key}"
if not hasattr(self.fx_node, "index") or self.fx_node.index is None:
self.fx_node.index = {}
self.fx_node.index[dim] = key
else:
raise ValueError("Index must be a dict")
Expand Down Expand Up @@ -757,21 +755,10 @@ def post_expansion(self, constraints: list["Constraint"]) -> None:
ensuring that the LHS and RHS indices are consistent with their
corresponding address spaces.
"""
from ..wave.constraints import TilingConstraint
from ..wave.utils import remove_global_indexing

tiling_constraints = [c for c in constraints if isinstance(c, TilingConstraint)]
self.lhs.index = self.lhs_index
self.rhs.index = self.rhs_index
self.acc.index = self.acc_index

        # TODO: this is really the wrong place for this; it relies on the specific
        # structure of the generated kernel. If the MMA input does not come from a load, things will break.
if get_custom(self.lhs).memory_type.address_space == SHARED_ADDRESS_SPACE:
self.lhs.index = remove_global_indexing(self.lhs_index, tiling_constraints)
if get_custom(self.rhs).memory_type.address_space == SHARED_ADDRESS_SPACE:
self.rhs.index = remove_global_indexing(self.rhs_index, tiling_constraints)


@define_op("read")
@dataclass
Expand Down Expand Up @@ -924,13 +911,9 @@ def memory_type(self) -> "Memory":
return get_custom(self.memory).type

@property
def index(self) -> dict[IndexSymbol, IndexSequence]:
register_index = get_custom(self.register_).index
return register_index if register_index is not None else super().index

@index.setter
def index(self, value: dict[IndexSymbol, IndexSequence]):
CustomOp.index.fset(self, value)
def register_index(self) -> dict[IndexSymbol, IndexSequence]:
custom = get_custom(self.register_)
return custom.index


@define_op("get_result")
Expand Down
10 changes: 7 additions & 3 deletions shark_turbine/kernel/wave/index_sequence_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ def has_strided_access(node: fx.Node) -> bool:
"""
custom = get_custom(node)
if isinstance(custom, Write) and len(custom.type.symbolic_shape) == 2:
strides = [simplify_index(custom.index[dim]).stride for dim in custom.index]
strides = [
simplify_index(custom.register_index[dim]).stride
for dim in custom.register_index
]
elements_per_thread = [
simplify_index(custom.index[dim]).size for dim in custom.index
simplify_index(custom.register_index[dim]).size
for dim in custom.register_index
]
strides = [x for x, y in zip(strides, elements_per_thread) if y > 1]
num_strided_accesses = sum(1 for stride in strides if stride > 1)
Expand All @@ -58,7 +62,7 @@ def has_strided_access(node: fx.Node) -> bool:
for operator in strided_operators:
custom = get_custom(operator)
simplified_index = {
dim: simplify_index(custom.index[dim]) for dim in custom.index
dim: simplify_index(custom.register_index[dim]) for dim in custom.index
}
max_stride = int(max(simplified_index[dim].stride for dim in simplified_index))
shape = get_vector_shape(trace, hw_constraint, custom.type.symbolic_shape)
Expand Down
10 changes: 3 additions & 7 deletions shark_turbine/kernel/wave/minimize_global_loads.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
)
from .._support.tracing import CapturedTrace
from .._support.indexing import IndexingContext, IndexSequence, IndexSymbol, IndexExpr
from ..ops.wave_ops import Read, Write, Output, get_custom
from ..ops.wave_ops import Read, Write, get_custom
from ..lang.global_symbols import *
from .utils import delinearize_index, DCE, remove_global_indexing, subs_idxc
from .utils import delinearize_index, DCE, subs_idxc
from math import prod
import torch.fx as fx
from collections import defaultdict
Expand Down Expand Up @@ -106,7 +106,6 @@ def add_optimized_nodes(
optimizable_loads: dict[fx.Node, tuple[int, Read]],
constraint_tile_size: dict[IndexSymbol, int],
hardware_constraint: HardwareConstraint,
tilingConstraints: list[TilingConstraint],
max_elements_per_load: int,
load_elems_per_thread: int,
) -> list[fx.Node]:
Expand Down Expand Up @@ -140,9 +139,7 @@ def add_optimized_nodes(
write = Write(
read, custom_user.memory, load_elems_per_thread
).add_to_graph(custom.graph)
write.index = remove_global_indexing(
read.index, tilingConstraints
)
write.index = read.index
optimized_writes[custom_user.memory].append(write)
break
return optimized_writes
Expand Down Expand Up @@ -212,7 +209,6 @@ def minimize_global_loads(trace: CapturedTrace, constraints: list[Constraint]):
optimizable_loads,
constraint_tile_size,
hardware_constraint,
[c for c in constraints if isinstance(c, TilingConstraint)],
max_elements_per_load,
load_elems_per_thread,
)
Expand Down
26 changes: 26 additions & 0 deletions shark_turbine/kernel/wave/shared_memory_indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from .._support.tracing import CapturedTrace
from ..ops.wave_ops import Read, Write, get_custom
from ..lang.global_symbols import *
from .utils import remove_global_indexing
from .constraints import Constraint, TilingConstraint
import torch.fx as fx


def apply_shared_memory_indexing_corrections(
    trace: CapturedTrace, constraints: list[Constraint]
):
    """
    Strip global indexing from shared memory reads and writes.

    Global indexing here means the index contributions that arise from
    Workgroup constraints and Tiling constraints; accesses to shared
    memory must not carry them, so they are removed in place.
    """
    tiling_constraints = [c for c in constraints if isinstance(c, TilingConstraint)]

    def correct_node(node: fx.Node) -> bool:
        # Mutates qualifying nodes in place; always returns False so
        # trace.walk visits every node without collecting any of them.
        custom = get_custom(node)
        is_shared_access = (
            isinstance(custom, (Read, Write))
            and custom.memory_type.address_space == SHARED_ADDRESS_SPACE
        )
        if is_shared_access:
            custom.index = remove_global_indexing(custom.index, tiling_constraints)
        return False

    trace.walk(correct_node)
4 changes: 4 additions & 0 deletions shark_turbine/kernel/wave/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..ops import wave_ops
from ..ops.wave_ops import Reduction, CustomOp, get_custom
from .index_sequence_analysis import partition_strided_operators
from .shared_memory_indexing import apply_shared_memory_indexing_corrections
from .register_analysis import determine_register_shape
from .._support.indexing import IndexingContext, IndexExpr
import shark_turbine.kernel.lang as tkl
Expand Down Expand Up @@ -203,6 +204,9 @@ def _trace_and_get_kernel_signature(
# Optimizations.
minimize_global_loads(graph, self.constraints)

# Apply shared memory indexing corrections.
apply_shared_memory_indexing_corrections(graph, self.constraints)

# Partition strided operators.
partition_strided_operators(graph, self.constraints)

Expand Down
Loading