
[TKW] set_symbol and apply_expr ops #382

Merged: 12 commits, Jan 15, 2025
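This PR adds two new TKW wave ops: apply_expr, which rewrites a register value with a sympy-style lambda, and set_symbol, which binds a register value to a dynamic index symbol. A minimal usage sketch, condensed from the lit tests added below (the constraints, the index mapping, and the kernel name are elided or hypothetical; symbol names follow those tests):

S = tkl.sym.S  # dynamic dimension symbol, supplied at kernel launch

@tkw.wave(constraints)
def kernel(
    a: tkl.Memory[S, N, ADDRESS_SPACE, tkl.f16],
    off: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32],
    b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
    offset = tkw.read(off, elements_per_thread=1)
    # apply_expr transforms the loaded value with a sympy-style lambda.
    offset = tkw.apply_expr(offset, lambda a: M - a - 1)
    # set_symbol binds the (size-1) register value to the dynamic symbol S,
    # so the following read can use a data-dependent index.
    tkw.set_symbol(S, offset)
    res = tkw.read(a, mapping=mapping, elements_per_thread=1)
    tkw.write(res, b, elements_per_thread=1)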
3 changes: 3 additions & 0 deletions iree/turbine/kernel/_support/tracing.py
@@ -8,6 +8,7 @@
Dict,
Tuple,
)
from types import FunctionType

from ..compiler.ir import Operation

@@ -114,6 +115,8 @@ def create_arg(self, a):
return a
if isinstance(a, IndexMapping):
return a
if isinstance(a, FunctionType):
return a
return super().create_arg(a)


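The FunctionType branch added above lets plain Python callables pass through FX tracing as literal node arguments, presumably so the lambda handed to apply_expr reaches codegen unmodified. A minimal standalone illustration of the same create_arg pattern (hypothetical tracer name, torch.fx only, not the project's actual tracer class):

import torch.fx as fx
from types import FunctionType

class SketchTracer(fx.Tracer):
    def create_arg(self, a):
        # Keep raw callables (e.g. the lambda given to apply_expr) as-is
        # instead of converting them into proxies or constants.
        if isinstance(a, FunctionType):
            return a
        return super().create_arg(a)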
38 changes: 38 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -98,6 +98,14 @@ def write(
...


def apply_expr(value: "Register", expr: Callable) -> "Register":
...


def set_symbol(symbol: IndexExpr, value: "Register"):
...


def exp2(src: "Register") -> "Register":
...

@@ -1380,6 +1388,36 @@ def is_contiguous_vec(self) -> bool:
)


@define_op("apply_expr")
@dataclass
class ApplyExpr(CustomOp):
register_: fx.Proxy
expr: Callable

@property
def type(self) -> "Register":
return get_custom(self.register_).type

@property
def indexing_dims(self) -> list[IndexSymbol]:
return get_custom(self.register_).indexing_dims


@define_op("set_symbol")
@dataclass
class SetSymbol(CustomOp):
symbol: IndexExpr
register_: fx.Proxy
Contributor: Just wondering: why mark these as fx.Proxy instead of fx.Node?

Contributor Author: Half of the ops are using fx.Proxy instead of fx.Node; we should probably align all of them.


@property
def type(self) -> "Register":
return get_custom(self.register_).type

@property
def indexing_dims(self) -> list[IndexSymbol]:
return get_custom(self.register_).indexing_dims


@define_py_op(operator.getitem)
@define_op("get_result")
@dataclass
76 changes: 56 additions & 20 deletions iree/turbine/kernel/wave/codegen.py
@@ -15,7 +15,6 @@
from collections import namedtuple

from .symbolic_constraints import SymbolicAlias

from ..compiler.ir import (
Attribute,
DenseElementsAttr,
@@ -49,30 +48,32 @@
# TK infrastructure imports.
from iree.turbine.kernel.lang.global_symbols import *
from ..ops.wave_ops import (
write,
broadcast,
register,
mma,
shuffle,
read,
reduction,
exp2,
log2,
reciprocal,
CustomOp,
abs,
maximum,
get_custom,
get_result,
allocate,
shared_memory_barrier,
apply_expr,
broadcast,
cast,
exp2,
extract,
extract_slice,
CustomOp,
scheduling_barrier,
scheduling_group_barrier,
cast,
get_custom,
get_result,
log2,
maximum,
mma,
permute,
read,
reciprocal,
reduction,
register,
reshape,
scheduling_barrier,
scheduling_group_barrier,
set_symbol,
shared_memory_barrier,
shuffle,
write,
)
from ..lang.wave_types import IndexMapping, IndexSymbol
from ..compiler.base import CodegenError, ValidationError, NDEBUG
@@ -96,7 +97,7 @@
from .utils import subs_idxc, find_index_bounds, get_hardware_vector_map

# Indexing imports.
from .._support.indexing import IndexingContext, IndexExpr, IndexSequence
from .._support.indexing import IndexingContext, IndexExpr, IndexSequence, index_symbol
from .scheduling.resources import get_scheduling_mask


Expand Down Expand Up @@ -871,6 +872,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
raise ValidationError("codegen expected write to have index attr.")

index = node.index

input_shape = _get_symbolic_shape(register)
output_shape = _get_symbolic_shape(memory)
if get_custom(node).has_identity_mapping():
@@ -907,6 +909,40 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
vector_d.scatter(kb_dest, start_indices, offsets_vec, mask, insert_vector)


@handle_op(apply_expr)
def handle_apply_expr(emitter: WaveEmitter, node: fx.Node):
try:
register, expr = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e

APPLY_EXPR_ARG = index_symbol("$APPLY_EXPR_ARG")
expr = expr(APPLY_EXPR_ARG)

register = cast_vector(emitter, register, element_type=IndexType.get())

subs = add_emitter_subs(emitter)
subs[APPLY_EXPR_ARG] = register
result = gen_sympy_index(subs, expr)
emitter.bind_node_proxy(node, IRProxyValue(result))


@handle_op(set_symbol)
def handle_set_symbol(emitter: WaveEmitter, node: fx.Node):
try:
symbol, register = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e

register = cast_vector(emitter, register, element_type=IndexType.get())
src_type = register.type
assert (
src_type.rank == 1 and src_type.shape[0] == 1
), f"Only size 1 vectors are supported: got {register.type}"
register = vector_d.extract(register, static_position=[0], dynamic_position=[])
emitter.dynamic_dims[symbol] = register


###############################################################################
# Contraction/MMA Ops
###############################################################################
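A standalone sketch (plain sympy, hypothetical variable names) of what the placeholder substitution in handle_apply_expr amounts to: the kernel's lambda is evaluated on the $APPLY_EXPR_ARG placeholder, and gen_sympy_index later materializes the resulting expression with the loaded register value standing in for that placeholder.

import sympy

M = sympy.Symbol("M")
APPLY_EXPR_ARG = sympy.Symbol("$APPLY_EXPR_ARG")

# The lambda from the kernel is applied to the placeholder symbol...
expr = (lambda a: M - a - 1)(APPLY_EXPR_ARG)    # M - $APPLY_EXPR_ARG - 1

# ...and the placeholder is later replaced by the runtime value. With M = 16
# this matches the constant-15 subtraction (arith.subi) in the lit test below.
print(expr.subs({APPLY_EXPR_ARG: 3, M: 16}))    # -> 12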
3 changes: 2 additions & 1 deletion iree/turbine/kernel/wave/expansion/expansion_utils.py
@@ -87,7 +87,8 @@ def get_dim_scaling(
or (tile_size / wave_count) % vector_size != 0
):
raise ValueError(
"Tile size must be divisible by wave count and vector size"
f"Tile size must be divisible by wave count and vector size, got: "
f"tile_size={tile_size}, wave_count={wave_count}, vector_size={vector_size}"
)
dim_scaling[constraint.dim] = tile_size // wave_count // vector_size

127 changes: 127 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -593,6 +593,133 @@ def test(
# CHECK: vector.store %[[D16]], %{{.*}}[%[[D5]], %[[D20]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>


@run_test
def test_read_write_dynamic_symbol():
S = tkl.sym.S
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(1, 1, 1),
vector_shapes={M: 1, N: 1, S: 1},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)
mapping = tkw.IndexMapping(
num_iterators=2,
inputs={S: S, N: j},
outputs={S: i, N: j},
dynamic_val_mappings={S: i, N: j},
)

@tkw.wave(constraints)
def test_dyn_symbol(
a: tkl.Memory[S, N, ADDRESS_SPACE, tkl.f16],
off: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32],
b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
offset = tkw.read(off, elements_per_thread=1)
tkw.set_symbol(S, offset)
res = tkw.read(
a,
mapping=mapping,
elements_per_thread=1,
)
tkw.write(res, b, elements_per_thread=1)

with codegen_test_context(
canonicalize=True,
dynamic_symbols=[S],
additional_symbols={BLOCK_M: 1, BLOCK_N: 1},
):
a = torch.randn(16, 16, dtype=torch.float16)
off = torch.randint(16, (16, 16), dtype=torch.int32)
b = torch.zeros(16, 16, dtype=torch.float16)
print(test_dyn_symbol(a, off, b).module_op)

# CHECK-LABEL: func.func @test_dyn_symbol
# CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %[[ARG1:.*]]: !stream.binding, %[[ARG2:.*]]: !stream.binding, %[[ARG3:.*]]: index)
# CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
# CHECK: %[[A2:.*]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<16x16xi32, strided<[16, 1], offset: ?>>
# CHECK: %[[O1:.*]] = vector.load %[[A2]][%[[M:.*]], %[[N:.*]]] : memref<16x16xi32, strided<[16, 1], offset: ?>>, vector<1xi32>
# CHECK: %[[O2:.*]] = arith.index_cast %[[O1]] : vector<1xi32> to vector<1xindex>
# CHECK: %[[O3:.*]] = vector.extract %[[O2]][0] : index from vector<1xindex>
# CHECK: %[[A1:.*]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<?x16xf16, strided<[16, 1], offset: ?>>{%arg3}
# CHECK: %[[RES:.*]] = vector.load %[[A1]][%[[O3]], %[[N]]] : memref<?x16xf16, strided<[16, 1], offset: ?>>, vector<1xf16>
# CHECK: %[[A3:.*]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<16x16xf16, strided<[16, 1], offset: ?>>
# CHECK: vector.store %[[RES]], %[[A3]][%[[M]], %[[N]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<1xf16>


@run_test
def test_read_write_dynamic_symbol_expr():
S = tkl.sym.S
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(1, 1, 1),
vector_shapes={M: 1, N: 1, S: 1},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)
mapping = tkw.IndexMapping(
num_iterators=2,
inputs={S: S, N: j},
outputs={S: i, N: j},
dynamic_val_mappings={S: i, N: j},
)

@tkw.wave(constraints)
def test_dyn_expr(
a: tkl.Memory[S, N, ADDRESS_SPACE, tkl.f16],
off: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32],
b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
offset = tkw.read(off, elements_per_thread=1)
offset = tkw.apply_expr(offset, lambda a: M - a - 1)
tkw.set_symbol(S, offset)
res = tkw.read(
a,
mapping=mapping,
elements_per_thread=1,
)
tkw.write(res, b, elements_per_thread=1)

with codegen_test_context(
canonicalize=True,
dynamic_symbols=[S],
additional_symbols={BLOCK_M: 1, BLOCK_N: 1},
):
a = torch.randn(16, 16, dtype=torch.float16)
off = torch.randint(16, (16, 16), dtype=torch.int32)
b = torch.zeros(16, 16, dtype=torch.float16)
print(test_dyn_expr(a, off, b).module_op)

# CHECK-LABEL: func.func @test_dyn_expr
# CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %[[ARG1:.*]]: !stream.binding, %[[ARG2:.*]]: !stream.binding, %[[ARG3:.*]]: index)
# CHECK-DAG: %[[CST:.*]] = arith.constant dense<15> : vector<1xindex>
# CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
# CHECK: %[[A2:.*]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<16x16xi32, strided<[16, 1], offset: ?>>
# CHECK: %[[O1:.*]] = vector.load %[[A2]][%[[M:.*]], %[[N:.*]]] : memref<16x16xi32, strided<[16, 1], offset: ?>>, vector<1xi32>
# CHECK: %[[O2:.*]] = arith.index_cast %[[O1]] : vector<1xi32> to vector<1xindex>
# CHECK: %[[O3:.*]] = arith.subi %[[CST]], %[[O2]] : vector<1xindex>
# CHECK: %[[O4:.*]] = vector.extract %[[O3]][0] : index from vector<1xindex>
# CHECK: %[[A1:.*]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<?x16xf16, strided<[16, 1], offset: ?>>{%arg3}
# CHECK: %[[RES:.*]] = vector.load %[[A1]][%[[O4]], %[[N]]] : memref<?x16xf16, strided<[16, 1], offset: ?>>, vector<1xf16>
# CHECK: %[[A3:.*]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<16x16xf16, strided<[16, 1], offset: ?>>
# CHECK: vector.store %[[RES]], %[[A3]][%[[M]], %[[N]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<1xf16>


@run_test
def test_dynamic_copy():
constraints: list[tkw.Constraint] = [