From 9e79f4efefb5ba6407dce0f6087eb770e8ad4992 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 27 Nov 2024 12:35:03 +0100 Subject: [PATCH] [TKW] Move IGEMM conv impl to common place. (#295) Move TKW IGEMM conv impl from the test folder to some common place to allow it to be reused outside the tests (e.g. in iree-kernel-benchmark). Not sure what the proper place for it, suggestions are welcome. --------- Signed-off-by: Ivan Butygin --- .../turbine/kernel/wave/templates/__init__.py | 5 + iree/turbine/kernel/wave/templates/conv.py | 167 ++++++++++++++++++ tests/kernel/wave/wave_e2e_test.py | 137 ++------------ 3 files changed, 187 insertions(+), 122 deletions(-) create mode 100644 iree/turbine/kernel/wave/templates/__init__.py create mode 100644 iree/turbine/kernel/wave/templates/conv.py diff --git a/iree/turbine/kernel/wave/templates/__init__.py b/iree/turbine/kernel/wave/templates/__init__.py new file mode 100644 index 000000000..c68e0440e --- /dev/null +++ b/iree/turbine/kernel/wave/templates/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/iree/turbine/kernel/wave/templates/conv.py b/iree/turbine/kernel/wave/templates/conv.py new file mode 100644 index 000000000..a5887a294 --- /dev/null +++ b/iree/turbine/kernel/wave/templates/conv.py @@ -0,0 +1,167 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from typing import Any, Optional +from iree.turbine.kernel.lang.global_symbols import * + + +def get_igemm_conv2d( + layout: str, + n: int, + h: int, + w: int, + c: int, + hf: int, + wf: int, + nf: int, + stride: int, + mem_space: tkl.IndexSymbol = SHARED_ADDRESS_SPACE, + block_m: Optional[int] = None, + block_n: Optional[int] = None, + block_k: Optional[int] = None, + ratio_m: Optional[int] = None, + ratio_n: Optional[int] = None, +) -> tuple["LaunchableWave", dict[tkl.IndexSymbol, Any]]: + cf = c + padding = 0 # TODO: only pad=0 is supported for now + + sym = tkl.sym + N, C, H, W = sym.N, sym.C, sym.H, sym.W + NF, HF, WF = sym.NF, sym.HF, sym.WF + + H_OUT = (H + 2 * padding - HF) // stride + 1 + W_OUT = (W + 2 * padding - WF) // stride + 1 + SZ_OUT = H_OUT * W_OUT + + K = HF * WF * C + M = SZ_OUT * N + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + + # Align C dim reading pattern to be contiguous for nhwc_hwcf pattern. 
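+    # GEMM view of the convolution: iterator i enumerates the flattened output
+    # positions M = SZ_OUT * N, and iterator j enumerates the flattened reduction
+    # dimension K = HF * WF * C. C is kept as the innermost component of j, so
+    # consecutive j values read contiguous channels in the nhwc_hwcf layout.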
+ x_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={ + N: i // SZ_OUT, + C: j % C, + H: (i % SZ_OUT) % W_OUT * stride + (j // C) % WF, + W: (i % SZ_OUT) // W_OUT * stride + (j // C) // WF, + }, + outputs={M: i, K: j}, + ) + w_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={NF: i % NF, C: j % C, HF: (j // C) % WF, WF: (j // C) // WF}, + outputs={NF: i, K: j}, + ) + out_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, NF: j}, + outputs={ + N: i // SZ_OUT, + NF: j, + H_OUT: (i % SZ_OUT) % W_OUT, + W_OUT: (i % SZ_OUT) // W_OUT, + }, + ) + + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD + + if layout == "nchw_fchw": + x_type = tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16] + we_type = tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16] + out_type = tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32] + elif layout == "nhwc_hwcf": + x_type = tkl.Memory[N, H, W, C, ADDRESS_SPACE, tkl.f16] + we_type = tkl.Memory[HF, WF, C, NF, ADDRESS_SPACE, tkl.f16] + out_type = tkl.Memory[N, H_OUT, W_OUT, NF, GLOBAL_ADDRESS_SPACE, tkl.f32] + else: + raise ValueError(f"Unsupported layout: {layout}") + + if block_m is None: + block_m = 64 + + if block_n is None: + block_n = 128 + + if block_k is None: + block_k = 32 + + if ratio_m is None: + ratio_m = 2 + + if ratio_n is None: + ratio_n = 2 + + # Expose user-constraints + constraints: list[tkw.Constraint] = [] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)] + constraints += [tkw.WaveConstraint(NF, BLOCK_N / ratio_n)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(ratio_n, ratio_m, 1), + ) + ] + + @tkw.wave(constraints) + def conv( + x: x_type, + we: we_type, + out: out_type, + ): + c_reg = tkl.Register[M, NF, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: + a_reg = tkw.read( + x, + mapping=x_mapping, + elements_per_thread=ELEMS_PER_THREAD, + ) + b_reg = tkw.read( + we, + mapping=w_mapping, + elements_per_thread=ELEMS_PER_THREAD, + ) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write( + repeat, out, mapping=out_mapping, elements_per_thread=ELEMS_PER_THREAD + ) + + symbols = { + N: n, + C: c, + W: w, + H: h, + NF: nf, + WF: wf, + HF: hf, + BLOCK_M: block_m, + BLOCK_N: block_n, + BLOCK_K: block_k, + ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: mem_space, + } + + return conv, symbols diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 9c462be62..127dc6947 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -8,6 +8,7 @@ import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.wave.wave_sim import wave_sim +from iree.turbine.kernel.wave.templates.conv import get_igemm_conv2d from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel.wave.iree_utils import generate_iree_ref from iree.turbine.kernel.wave.utils import ( @@ -912,114 +913,27 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request): convRef.weight = torch.nn.Parameter(we) 
out_ref = convRef(x).detach().to(torch.float32) - sym = tkl.sym - N, C, H, W = sym.N, sym.C, sym.H, sym.W - NF, HF, WF = sym.NF, sym.HF, sym.WF - - H_OUT = (H + 2 * padding - HF) // stride + 1 - W_OUT = (W + 2 * padding - WF) // stride + 1 - SZ_OUT = H_OUT * W_OUT - - K = HF * WF * C - M = SZ_OUT * N - - i = tkw.IndexMapping.iterator(0) - j = tkw.IndexMapping.iterator(1) - - # Align C dim reading pattern to be contiguous for nhwc_hwcf pattern. - x_mapping = tkw.IndexMapping( - num_iterators=2, - inputs={ - N: i // SZ_OUT, - C: j % C, - H: (i % SZ_OUT) % W_OUT * stride + (j // C) % WF, - W: (i % SZ_OUT) // W_OUT * stride + (j // C) // WF, - }, - outputs={M: i, K: j}, - ) - w_mapping = tkw.IndexMapping( - num_iterators=2, - inputs={NF: i % NF, C: j % C, HF: (j // C) % WF, WF: (j // C) // WF}, - outputs={NF: i, K: j}, - ) - out_mapping = tkw.IndexMapping( - num_iterators=2, - inputs={M: i, NF: j}, - outputs={ - N: i // SZ_OUT, - NF: j, - H_OUT: (i % SZ_OUT) % W_OUT, - W_OUT: (i % SZ_OUT) // W_OUT, - }, - ) - - # Workgroup tile sizes - BLOCK_M = tkl.sym.BLOCK_M - BLOCK_N = tkl.sym.BLOCK_N - BLOCK_K = tkl.sym.BLOCK_K - # Address space (for GPU, shared(1) or global(0)) - ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE - # Other hyperparameters - ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD - if layout == "nchw_fchw": - x_type = tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16] - we_type = tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16] - out_type = tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32] + pass # Nothing elif layout == "nhwc_hwcf": - x_type = tkl.Memory[N, H, W, C, ADDRESS_SPACE, tkl.f16] - we_type = tkl.Memory[HF, WF, C, NF, ADDRESS_SPACE, tkl.f16] - out_type = tkl.Memory[N, H_OUT, W_OUT, NF, GLOBAL_ADDRESS_SPACE, tkl.f32] x = torch.permute(x, (0, 2, 3, 1)).contiguous() we = torch.permute(we, (2, 3, 1, 0)).contiguous() out_ref = torch.permute(out_ref, (0, 2, 3, 1)).contiguous() else: raise ValueError(f"Invalid layout: {layout}") - ratio_m = 2 - ratio_n = 2 - - # Expose user-constraints - constraints: list[tkw.Constraint] = [] - constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] - constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 0)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)] - constraints += [tkw.WaveConstraint(NF, BLOCK_N / ratio_n)] - constraints += [tkw.TilingConstraint(K, BLOCK_K)] - - constraints += [ - tkw.HardwareConstraint( - threads_per_wave=64, - waves_per_block=(ratio_n, ratio_m, 1), - ) - ] - - @tkw.wave(constraints) - def conv( - x: x_type, - we: we_type, - out: out_type, - ): - c_reg = tkl.Register[M, NF, tkl.f32](0.0) - - @tkw.reduction(K, init_args=[c_reg]) - def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: - a_reg = tkw.read( - x, - mapping=x_mapping, - elements_per_thread=ELEMS_PER_THREAD, - ) - b_reg = tkw.read( - we, - mapping=w_mapping, - elements_per_thread=ELEMS_PER_THREAD, - ) - acc = tkw.mma(a_reg, b_reg, acc) - return acc - - tkw.write( - repeat, out, mapping=out_mapping, elements_per_thread=ELEMS_PER_THREAD - ) + conv, symbols = get_igemm_conv2d( + layout=layout, + n=n, + h=h, + w=w, + c=c, + hf=hf, + wf=wf, + nf=nf, + stride=stride, + mem_space=mem_space, + ) config = get_default_run_config() @@ -1037,28 +951,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: ) with tk.gen.TestLaunchContext( - { - N: n, - C: c, - W: w, - H: h, - NF: nf, - WF: wf, - HF: hf, - BLOCK_M: 64, - BLOCK_N: 128, - BLOCK_K: 32, - ELEMS_PER_THREAD: 4, - ADDRESS_SPACE: mem_space, - 
READ_SHARED_DELAY: 1, - WRITE_SHARED_DELAY: 1, - READ_GLOBAL_DELAY: 2, - WRITE_GLOBAL_DELAY: 2, - MMA_DELAY: 1, - SHARED_MEMORY_UNITS: 4, - GLOBAL_MEMORY_UNITS: 4, - MMA_UNITS: 4, - }, + symbols, canonicalize=True, run=True, run_bench=run_bench,
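
Example of reusing the relocated template outside the test suite (e.g. from a benchmark driver such as iree-kernel-benchmark). This is a minimal sketch: the tensor shapes, the nhwc_hwcf layout choice, the host-tensor setup, and the direct conv(x, we, out) invocation are illustrative assumptions; only get_igemm_conv2d and the TestLaunchContext(symbols, ...) pattern come from this patch.

    import torch
    import iree.turbine.kernel as tk
    from iree.turbine.kernel.lang.global_symbols import SHARED_ADDRESS_SPACE
    from iree.turbine.kernel.wave.templates.conv import get_igemm_conv2d

    # Build the kernel and its symbol bindings from the template (sizes are illustrative).
    n, h, w, c, hf, wf, nf, stride = 2, 64, 64, 64, 3, 3, 64, 1
    conv, symbols = get_igemm_conv2d(
        layout="nhwc_hwcf",
        n=n, h=h, w=w, c=c,
        hf=hf, wf=wf, nf=nf,
        stride=stride,
        mem_space=SHARED_ADDRESS_SPACE,
    )

    # f16 inputs and f32 output, matching the Memory types declared in the template;
    # padding is fixed to 0, so the output spatial size is (h - hf) // stride + 1.
    h_out = (h - hf) // stride + 1
    w_out = (w - wf) // stride + 1
    x = torch.randn(n, h, w, c, dtype=torch.float16)
    we = torch.randn(hf, wf, c, nf, dtype=torch.float16)
    out = torch.zeros(n, h_out, w_out, nf, dtype=torch.float32)

    # Run-config and device plumbing are elided here; the test above obtains a
    # run config via get_default_run_config() for actual execution on a device.
    with tk.gen.TestLaunchContext(symbols, canonicalize=True, run=True):
        conv(x, we, out)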