diff --git a/iree/turbine/kernel/_support/tracing.py b/iree/turbine/kernel/_support/tracing.py index 857cdb345..7256539e6 100644 --- a/iree/turbine/kernel/_support/tracing.py +++ b/iree/turbine/kernel/_support/tracing.py @@ -8,6 +8,7 @@ Dict, Tuple, ) +from types import FunctionType from ..compiler.ir import Operation @@ -114,6 +115,8 @@ def create_arg(self, a): return a if isinstance(a, IndexMapping): return a + if isinstance(a, FunctionType): + return a return super().create_arg(a) diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 2f5680a00..55c143b63 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -98,6 +98,14 @@ def write( ... +def apply_expr(value: "Register", expr: Callable) -> "Register": + ... + + +def set_symbol(symbol: IndexExpr, value: "Register"): + ... + + def exp2(src: "Register") -> "Register": ... @@ -1380,6 +1388,36 @@ def is_contiguous_vec(self) -> bool: ) +@define_op("apply_expr") +@dataclass +class ApplyExpr(CustomOp): + register_: fx.Proxy + expr: Callable + + @property + def type(self) -> "Register": + return get_custom(self.register_).type + + @property + def indexing_dims(self) -> list[IndexSymbol]: + return get_custom(self.register_).indexing_dims + + +@define_op("set_symbol") +@dataclass +class SetSymbol(CustomOp): + symbol: IndexExpr + register_: fx.Proxy + + @property + def type(self) -> "Register": + return get_custom(self.register_).type + + @property + def indexing_dims(self) -> list[IndexSymbol]: + return get_custom(self.register_).indexing_dims + + @define_py_op(operator.getitem) @define_op("get_result") @dataclass diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index 57084ca62..8a1f3118d 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -15,7 +15,6 @@ from collections import namedtuple from .symbolic_constraints import SymbolicAlias - from ..compiler.ir import ( Attribute, DenseElementsAttr, @@ -49,30 +48,32 @@ # TK infrastructure imports. from iree.turbine.kernel.lang.global_symbols import * from ..ops.wave_ops import ( - write, - broadcast, - register, - mma, - shuffle, - read, - reduction, - exp2, - log2, - reciprocal, + CustomOp, abs, - maximum, - get_custom, - get_result, allocate, - shared_memory_barrier, + apply_expr, + broadcast, + cast, + exp2, extract, extract_slice, - CustomOp, - scheduling_barrier, - scheduling_group_barrier, - cast, + get_custom, + get_result, + log2, + maximum, + mma, permute, + read, + reciprocal, + reduction, + register, reshape, + scheduling_barrier, + scheduling_group_barrier, + set_symbol, + shared_memory_barrier, + shuffle, + write, ) from ..lang.wave_types import IndexMapping, IndexSymbol from ..compiler.base import CodegenError, ValidationError, NDEBUG @@ -96,7 +97,7 @@ from .utils import subs_idxc, find_index_bounds, get_hardware_vector_map # Indexing imports. -from .._support.indexing import IndexingContext, IndexExpr, IndexSequence +from .._support.indexing import IndexingContext, IndexExpr, IndexSequence, index_symbol from .scheduling.resources import get_scheduling_mask @@ -871,6 +872,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): raise ValidationError("codegen expected write to have index attr.") index = node.index + input_shape = _get_symbolic_shape(register) output_shape = _get_symbolic_shape(memory) if get_custom(node).has_identity_mapping(): @@ -907,6 +909,40 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): vector_d.scatter(kb_dest, start_indices, offsets_vec, mask, insert_vector) +@handle_op(apply_expr) +def handle_apply_expr(emitter: WaveEmitter, node: fx.Node): + try: + register, expr = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + + APPLY_EXPR_ARG = index_symbol("$APPLY_EXPR_ARG") + expr = expr(APPLY_EXPR_ARG) + + register = cast_vector(emitter, register, element_type=IndexType.get()) + + subs = add_emitter_subs(emitter) + subs[APPLY_EXPR_ARG] = register + result = gen_sympy_index(subs, expr) + emitter.bind_node_proxy(node, IRProxyValue(result)) + + +@handle_op(set_symbol) +def handle_set_symbol(emitter: WaveEmitter, node: fx.Node): + try: + symbol, register = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + + register = cast_vector(emitter, register, element_type=IndexType.get()) + src_type = register.type + assert ( + src_type.rank == 1 and src_type.shape[0] == 1 + ), f"Only size 1 vectors are supported: got {register.type}" + register = vector_d.extract(register, static_position=[0], dynamic_position=[]) + emitter.dynamic_dims[symbol] = register + + ############################################################################### # Contraction/MMA Ops ############################################################################### diff --git a/iree/turbine/kernel/wave/expansion/expansion_utils.py b/iree/turbine/kernel/wave/expansion/expansion_utils.py index 84df53e65..1986772a6 100644 --- a/iree/turbine/kernel/wave/expansion/expansion_utils.py +++ b/iree/turbine/kernel/wave/expansion/expansion_utils.py @@ -87,7 +87,8 @@ def get_dim_scaling( or (tile_size / wave_count) % vector_size != 0 ): raise ValueError( - "Tile size must be divisible by wave count and vector size" + f"Tile size must be divisible by wave count and vector size, got: " + f"tile_size={tile_size}, wave_count={wave_count}, vector_size={vector_size}" ) dim_scaling[constraint.dim] = tile_size // wave_count // vector_size diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index b994a4224..e69d785d9 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -593,6 +593,133 @@ def test( # CHECK: vector.store %[[D16]], %{{.*}}[%[[D5]], %[[D20]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16> +@run_test +def test_read_write_dynamic_symbol(): + S = tkl.sym.S + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: 1, N: 1, S: 1}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + mapping = tkw.IndexMapping( + num_iterators=2, + inputs={S: S, N: j}, + outputs={S: i, N: j}, + dynamic_val_mappings={S: i, N: j}, + ) + + @tkw.wave(constraints) + def test_dyn_symbol( + a: tkl.Memory[S, N, ADDRESS_SPACE, tkl.f16], + off: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + offset = tkw.read(off, elements_per_thread=1) + tkw.set_symbol(S, offset) + res = tkw.read( + a, + mapping=mapping, + elements_per_thread=1, + ) + tkw.write(res, b, elements_per_thread=1) + + with codegen_test_context( + canonicalize=True, + dynamic_symbols=[S], + additional_symbols={BLOCK_M: 1, BLOCK_N: 1}, + ): + a = torch.randn(16, 16, dtype=torch.float16) + off = torch.randint(16, (16, 16), dtype=torch.int32) + b = torch.zeros(16, 16, dtype=torch.float16) + print(test_dyn_symbol(a, off, b).module_op) + + # CHECK-LABEL: func.func @test_dyn_symbol + # CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %[[ARG1:.*]]: !stream.binding, %[[ARG2:.*]]: !stream.binding, %[[ARG3:.*]]: index) + # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + # CHECK: %[[A2:.*]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<16x16xi32, strided<[16, 1], offset: ?>> + # CHECK: %[[O1:.*]] = vector.load %[[A2]][%[[M:.*]], %[[N:.*]]] : memref<16x16xi32, strided<[16, 1], offset: ?>>, vector<1xi32> + # CHECK: %[[O2:.*]] = arith.index_cast %[[O1]] : vector<1xi32> to vector<1xindex> + # CHECK: %[[O3:.*]] = vector.extract %[[O2]][0] : index from vector<1xindex> + # CHECK: %[[A1:.*]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref>{%arg3} + # CHECK: %[[RES:.*]] = vector.load %[[A1]][%[[O3]], %[[N]]] : memref>, vector<1xf16> + # CHECK: %[[A3:.*]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<16x16xf16, strided<[16, 1], offset: ?>> + # CHECK: vector.store %[[RES]], %[[A3]][%[[M]], %[[N]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<1xf16> + + +@run_test +def test_read_write_dynamic_symbol_expr(): + S = tkl.sym.S + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: 1, N: 1, S: 1}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + mapping = tkw.IndexMapping( + num_iterators=2, + inputs={S: S, N: j}, + outputs={S: i, N: j}, + dynamic_val_mappings={S: i, N: j}, + ) + + @tkw.wave(constraints) + def test_dyn_expr( + a: tkl.Memory[S, N, ADDRESS_SPACE, tkl.f16], + off: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + offset = tkw.read(off, elements_per_thread=1) + offset = tkw.apply_expr(offset, lambda a: M - a - 1) + tkw.set_symbol(S, offset) + res = tkw.read( + a, + mapping=mapping, + elements_per_thread=1, + ) + tkw.write(res, b, elements_per_thread=1) + + with codegen_test_context( + canonicalize=True, + dynamic_symbols=[S], + additional_symbols={BLOCK_M: 1, BLOCK_N: 1}, + ): + a = torch.randn(16, 16, dtype=torch.float16) + off = torch.randint(16, (16, 16), dtype=torch.int32) + b = torch.zeros(16, 16, dtype=torch.float16) + print(test_dyn_expr(a, off, b).module_op) + + # CHECK-LABEL: func.func @test_dyn_expr + # CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %[[ARG1:.*]]: !stream.binding, %[[ARG2:.*]]: !stream.binding, %[[ARG3:.*]]: index) + # CHECK-DAG: %[[CST:.*]] = arith.constant dense<15> : vector<1xindex> + # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + # CHECK: %[[A2:.*]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<16x16xi32, strided<[16, 1], offset: ?>> + # CHECK: %[[O1:.*]] = vector.load %[[A2]][%[[M:.*]], %[[N:.*]]] : memref<16x16xi32, strided<[16, 1], offset: ?>>, vector<1xi32> + # CHECK: %[[O2:.*]] = arith.index_cast %[[O1]] : vector<1xi32> to vector<1xindex> + # CHECK: %[[O3:.*]] = arith.subi %[[CST]], %[[O2]] : vector<1xindex> + # CHECK: %[[O4:.*]] = vector.extract %[[O3]][0] : index from vector<1xindex> + # CHECK: %[[A1:.*]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref>{%arg3} + # CHECK: %[[RES:.*]] = vector.load %[[A1]][%[[O4]], %[[N]]] : memref>, vector<1xf16> + # CHECK: %[[A3:.*]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<16x16xf16, strided<[16, 1], offset: ?>> + # CHECK: vector.store %[[RES]], %[[A3]][%[[M]], %[[N]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<1xf16> + + @run_test def test_dynamic_copy(): constraints: list[tkw.Constraint] = [ diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 563ea0f23..7fd8a9898 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -517,6 +517,179 @@ def test( assert_close(out, out_ref) +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("test_copy")) +def test_set_symbol(shape, request): + run_bench = request.config.getoption("--runperf") + M = tkl.sym.M + N = tkl.sym.N + S = tkl.sym.S + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + # Each workgroup works on single row of input data, and rows are further + # split into blocks of size up to 256. We have single wave per WG, + # and with default wave size of 64, each thread is operating on up to 4 + # elements. + wave_size = 64 + BLOCK_M = 1 + # Tile size cannot be dynamic, so we use a fixed value here. + + # TODO: Only ELEMS_PER_THREAD == 1 + # BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size) + BLOCK_N = wave_size + ELEMS_PER_THREAD = BLOCK_N // wave_size + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=wave_size, + waves_per_block=(1, 1, 1), + vector_shapes={M: BLOCK_M, N: BLOCK_N, S: BLOCK_M}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + mapping = tkw.IndexMapping( + num_iterators=2, + inputs={S: S, N: j}, + outputs={S: i, N: j}, + ) + + dynamic_symbols = [] + dynamic_symbols_map = {} + + dynamic_symbols.append(S) + dynamic_symbols_map[S] = 0 + + @tkw.wave(constraints) + def test( + a: tkl.Memory[S, N, ADDRESS_SPACE, tkl.f16], + off: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + offset = tkw.read(off, elements_per_thread=ELEMS_PER_THREAD) + tkw.set_symbol(S, offset) + res = tkw.read( + a, + mapping=mapping, + elements_per_thread=ELEMS_PER_THREAD, + ) + tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD) + + config = get_default_run_config() + + a = device_randn(shape, dtype=torch.float16) + off = device_randint(shape[0], shape, dtype=torch.int32) + out = device_zeros(shape, dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + dynamic_symbols=dynamic_symbols, + dynamic_symbols_map=dynamic_symbols_map, + ): + test(a, off, out) + out_ref = torch.take_along_dim(a, off.to(torch.long), dim=0) + assert_close(out, out_ref) + + +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("test_copy")) +def test_apply_expr(shape, request): + run_bench = request.config.getoption("--runperf") + M = tkl.sym.M + N = tkl.sym.N + S = tkl.sym.S + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + # Each workgroup works on single row of input data, and rows are further + # split into blocks of size up to 256. We have single wave per WG, + # and with default wave size of 64, each thread is operating on up to 4 + # elements. + wave_size = 64 + BLOCK_M = 1 + # Tile size cannot be dynamic, so we use a fixed value here. + + # TODO: Only ELEMS_PER_THREAD == 1 + # BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size) + BLOCK_N = wave_size + ELEMS_PER_THREAD = BLOCK_N // wave_size + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=wave_size, + waves_per_block=(1, 1, 1), + vector_shapes={M: BLOCK_M, N: BLOCK_N, S: BLOCK_M}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + mapping = tkw.IndexMapping( + num_iterators=2, + inputs={S: S, N: j}, + outputs={S: i, N: j}, + ) + + dynamic_symbols = [] + dynamic_symbols_map = {} + + dynamic_symbols.append(S) + dynamic_symbols_map[S] = 0 + + @tkw.wave(constraints) + def test( + a: tkl.Memory[S, N, ADDRESS_SPACE, tkl.f16], + off: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + offset = tkw.read(off, elements_per_thread=ELEMS_PER_THREAD) + offset = tkw.apply_expr(offset, lambda a: M - a - 1) + tkw.set_symbol(S, offset) + res = tkw.read( + a, + mapping=mapping, + elements_per_thread=ELEMS_PER_THREAD, + ) + tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD) + + config = get_default_run_config() + + a = device_randn(shape, dtype=torch.float16) + off = device_randint(shape[0], shape, dtype=torch.int32) + out = device_zeros(shape, dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + dynamic_symbols=dynamic_symbols, + dynamic_symbols_map=dynamic_symbols_map, + ): + test(a, off, out) + out_ref = torch.take_along_dim(a, (shape[0] - off - 1).to(torch.long), dim=0) + assert_close(out, out_ref) + + @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) def test_offset_write(shape, request):