Fix TKW gemm tests (#36)
* Use the TKW `create_vmfb_file` option to get the vmfb file directly
* Get scheduling params from TK itself
* Disable scheduling for large K, as compilation takes too long
* Support the case when `out_type` != `accumulator_type`
* Update `exec_args`, as the code generated by TK expects the out param to be present (an example invocation is sketched below)
* Gracefully handle TK errors instead of terminating the entire script

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Hardcode84 authored Dec 18, 2024
1 parent 6498e61 commit c3bdf8e
Showing 3 changed files with 58 additions and 21 deletions.
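
For orientation, a minimal sketch of the benchmark invocation these changes produce for a TK kernel follows. Only the flags and their order come from the diff below; the runner binary name, device string, and tensor shapes are assumptions.

# Assumed example of the assembled benchmark arguments for a TK GEMM.
# Runner name, device, and shapes are illustrative; only the flags appear in the diff.
vmfb_filename = "gemm_tk.vmfb"                  # hypothetical module path
inp1, inp2 = "2048x8192xf16", "2048x8192xf16"   # hypothetical operand shapes
out_shape = "2048x2048xf16"                     # MxNxresult_element_type, i.e. config.get_out()

exec_args = [
    "iree-benchmark-module",                    # assumed runner binary
    "--device=hip",
    "--device_allocator=caching",
    f"--module={vmfb_filename}",
    f"--input={inp1}",
    f"--input={inp2}",
    "--benchmark_repetitions=3",
    f"--input={out_shape}",                     # TK-generated code expects the output buffer too
    "--function=isolated_benchmark",
]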
gemmbench/gemm_bench.py (3 additions, 1 deletion)
@@ -158,12 +158,14 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
         f"--device={device}",
         "--device_allocator=caching",
         f"--module={vmfb_filename}",
-        "--benchmark_repetitions=3",
         f"--input={inp1}",
         f"--input={inp2}",
+        "--benchmark_repetitions=3",
     ]

     if tk:
+        out_shape = config.get_out()
+        exec_args.append(f"--input={out_shape}")
         exec_args += ["--function=isolated_benchmark"]
     else:
         exec_args += ["--function=main"]
gemmbench/gemm_utils.py (54 additions, 20 deletions)
@@ -6,7 +6,12 @@
 import iree.turbine.kernel.lang as tkl
 import iree.turbine.kernel.wave as tkw
 from iree.turbine.kernel.lang.global_symbols import *
+from iree.turbine.kernel.wave.utils import (
+    get_default_run_config,
+    get_default_scheduling_params,
+)
 import torch
+import traceback


 @dataclass
@@ -38,6 +43,9 @@ def get_inp2(self) -> str:
             return f"{self.N}x{self.K}x{self.operand_element_type}"
         return f"{self.K}x{self.N}x{self.operand_element_type}"

+    def get_out(self) -> str:
+        return f"{self.M}x{self.N}x{self.result_element_type}"
+
     def get_byte_count(self) -> int:
         dtype_to_bytes = {
             "f32": 4,
@@ -153,13 +161,29 @@ def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
     # Default config
     return TkTunedConfig(64, 64, 32, 2, 2, 1, 2, 2, 2, 1, 1, 2)

+
+def _convert_dtype(dtype: str):
+    dtypes = {
+        "i8": tkl.i8,
+        "i16": tkl.i16,
+        "i32": tkl.i32,
+        "i64": tkl.i64,
+        "f16": tkl.f16,
+        "f32": tkl.f32,
+        "f64": tkl.f64,
+        "bf16": tkl.bf16,
+    }
+    return dtypes[dtype]
+
+
-def generate_tk_mlir(config: GemmConfig):
+def generate_tk_mlir(config: GemmConfig, vmfb_file: Path):
     # TODO: Enable waves_per_eu
     # TODO: Use scheduling barriers with LLVM patch
     tc = get_tk_tuned_config(config)
     assert config.operand_element_type == 'f16', "Unsupported problem"
     assert config.accumulator_element_type == 'f32', "Unsupported problem"

+    res_dtype = _convert_dtype(config.result_element_type)
     # Input sizes
     M = tkl.sym.M
     N = tkl.sym.N
@@ -197,7 +221,7 @@ def generate_tk_mlir(config: GemmConfig):
     def gemm(
         a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
         b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
-        c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
+        c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, res_dtype],
     ):
         c_reg = tkl.Register[M, N, tkl.f32](0.0)

@@ -213,14 +237,14 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
             acc = tkw.mma(a_reg, b_reg, acc)
             return acc

+        if res_dtype != tkl.f32:
+            repeat = tkw.cast(repeat, res_dtype)
+
         # repeat represents the results of the loop
         tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

     shape = [config.M, config.N, config.K]
-    operand_element_type_map = {
-        "f16": torch.float16,
-    }
-    operand_element_type = operand_element_type_map[config.operand_element_type]
+    schedule = (config.K < 4096)

     hyperparams = {
         ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
@@ -232,23 +256,21 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
         M: shape[0],
         N: shape[1],
         K: shape[2],
-        READ_SHARED_DELAY: tc.DELAY_SHARED,
-        WRITE_SHARED_DELAY: tc.DELAY_SHARED,
-        READ_GLOBAL_DELAY: tc.DELAY_GLOBAL,
-        WRITE_GLOBAL_DELAY: tc.DELAY_GLOBAL,
-        MMA_DELAY: tc.DELAY_MMA,
-        SHARED_MEMORY_UNITS: tc.SHARED_UNITS,
-        GLOBAL_MEMORY_UNITS: tc.GLOBAL_UNITS,
-        MMA_UNITS: tc.MMA_UNITS,
     }
+    hyperparams.update(get_default_scheduling_params())
+    # config = get_default_run_config() TODO: detects device as CPU for some reason
     config = {"backend": "rocm", "device": "hip", "target": "gfx942"}


+    # TODO: Scheduling is taking too long time with large K.
     with tk.gen.TestLaunchContext(
-        hyperparams, canonicalize=True, run=True, run_config=config, schedule=True,
+        hyperparams,
+        canonicalize=True,
+        create_vmfb_file=vmfb_file,
+        run_config=config,
+        schedule=schedule,
     ):
-        a = torch.randn(shape[0], shape[2], dtype=operand_element_type)
-        b = torch.randn(shape[1], shape[2], dtype=operand_element_type)
-        c = torch.zeros(shape[0], shape[1], dtype=torch.float32)
-        mb = gemm(a, b, c)
+        mb = gemm()

     return mb.module_op.get_asm()
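
As a rough illustration of the updated `generate_tk_mlir` flow above (scheduling gated on K, vmfb emitted through `create_vmfb_file`), here is a hedged driver sketch. The `GemmConfig` constructor call is hypothetical: only M, N, K and the element-type fields appear in this diff, and the real dataclass may take additional arguments.

from pathlib import Path

# Hypothetical driver for generate_tk_mlir; the GemmConfig arguments are illustrative.
cfg = GemmConfig(M=2048, N=2048, K=8192,
                 operand_element_type="f16",
                 accumulator_element_type="f32",
                 result_element_type="bf16")

# K >= 4096, so schedule evaluates to False and the slow scheduling pass is skipped.
# result_element_type differs from the f32 accumulator, so the kernel casts before the store.
# TestLaunchContext writes the compiled module to the create_vmfb_file path, and the
# MLIR text is still returned for writing next to it.
mlir_text = generate_tk_mlir(cfg, Path("vmfbs/gemm_2048x2048x8192_tk.vmfb"))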

@@ -265,14 +287,26 @@ def compile_gemm_config(

     # Generate mlir content
     if tk:
-        mlir_content = generate_tk_mlir(config)
+        try:
+            mlir_content = generate_tk_mlir(config, vmfb_file)
+        except Exception as e:
+            traceback.print_exc()
+            error_file = vmfb_dir / (config.get_name() + "_error.txt")
+            print(f"Failed to compile {config.get_name()}. Error dumped in {error_file}")
+            with open(error_file, "w") as f:
+                f.write(str(e))
+                f.write(traceback.format_exc())
+            return mlir_file, None
     else:
         mlir_content = generate_mlir(config)

     # Write MLIR content to file
     with open(mlir_file, "w") as f:
         f.write(mlir_content)

+    if tk:
+        return mlir_file, vmfb_file
+
     # Compile MLIR to VMFB
     exec_args = [
         "iree-compile",
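
Because `compile_gemm_config` now returns `(mlir_file, None)` on a TK failure instead of aborting, callers can skip failed configs and keep going. A hedged sketch of that pattern follows; the loop and the exact argument list are assumptions, not taken from this diff.

from pathlib import Path

# Hypothetical caller; compile_gemm_config's full signature is not visible here,
# so the positional arguments are illustrative placeholders.
kernel_dir, vmfb_dir = Path("kernels"), Path("vmfbs")
configs: list = []  # assume populated with GemmConfig objects elsewhere
compiled = []
for cfg in configs:
    mlir_file, vmfb_file = compile_gemm_config(cfg, kernel_dir, vmfb_dir,
                                               "gfx942", [], tk=True)
    if vmfb_file is None:
        # TK compilation failed; the traceback was already dumped to <name>_error.txt
        # and the run continues with the remaining configs.
        continue
    compiled.append((cfg, vmfb_file))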
requirements.txt (1 addition, 0 deletions)
@@ -1,3 +1,4 @@
 numpy
 tqdm
 matplotlib
+torch>=2.3.0
