Fix TKW gemm tests #36

Merged: 13 commits, Dec 18, 2024
Changes from 12 commits
4 changes: 3 additions & 1 deletion gemmbench/gemm_bench.py
@@ -158,12 +158,14 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
f"--device={device}",
"--device_allocator=caching",
f"--module={vmfb_filename}",
"--benchmark_repetitions=3",
f"--input={inp1}",
f"--input={inp2}",
"--benchmark_repetitions=3",
]

if tk:
out_shape = config.get_out()
exec_args.append(f"--input={out_shape}")
exec_args += ["--function=isolated_benchmark"]
else:
exec_args += ["--function=main"]
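
For a TK kernel, with a hypothetical 4096x4096x4096 f16 GEMM producing an f32 result, the branch above ends up with benchmark arguments roughly like the sketch below; the device and module path are illustrative, and the extra --input carries the output shape from config.get_out():

exec_args = [
    "--device=hip://0",                              # illustrative device
    "--device_allocator=caching",
    "--module=vmfbs/gemm_4096x4096x4096_f16.vmfb",   # illustrative path
    "--benchmark_repetitions=3",
    "--input=4096x4096xf16",                         # inp1
    "--input=4096x4096xf16",                         # inp2
    "--input=4096x4096xf32",                         # out_shape: TK kernels take the output buffer explicitly
    "--function=isolated_benchmark",
]
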
74 changes: 54 additions & 20 deletions gemmbench/gemm_utils.py
@@ -6,7 +6,12 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.utils import (
get_default_run_config,
get_default_scheduling_params,
)
import torch
import traceback


@dataclass
@@ -38,6 +43,9 @@ def get_inp2(self) -> str:
return f"{self.N}x{self.K}x{self.operand_element_type}"
return f"{self.K}x{self.N}x{self.operand_element_type}"

def get_out(self) -> str:
return f"{self.M}x{self.N}x{self.result_element_type}"

def get_byte_count(self) -> int:
dtype_to_bytes = {
"f32": 4,
Expand Down Expand Up @@ -153,13 +161,29 @@ def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
# Default config
return TkTunedConfig(64, 64, 32, 2, 2, 1, 2, 2, 2, 1, 1, 2)

def _convert_dtype(dtype: str):
dtypes = {
"i8": tkl.i8,
"i16": tkl.i16,
"i32": tkl.i32,
"i64": tkl.i64,
"f16": tkl.f16,
"f32": tkl.f32,
"f64": tkl.f64,
"bf16": tkl.bf16,
}
return dtypes[dtype]
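# Illustrative usage: _convert_dtype("f16") -> tkl.f16, _convert_dtype("bf16") -> tkl.bf16;
# any dtype string missing from the table above raises a KeyError.
# generate_tk_mlir below only needs this for the result element type.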


def generate_tk_mlir(config: GemmConfig):

def generate_tk_mlir(config: GemmConfig, vmfb_file: Path):
# TODO: Enable waves_per_eu
# TODO: Use scheduling barriers with LLVM patch
tc = get_tk_tuned_config(config)
assert config.operand_element_type == 'f16', "Unsupported problem"
assert config.accumulator_element_type == 'f32', "Unsupported problem"

res_dtype = _convert_dtype(config.result_element_type)
# Input sizes
M = tkl.sym.M
N = tkl.sym.N
@@ -197,7 +221,7 @@ def generate_tk_mlir(config: GemmConfig):
def gemm(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, res_dtype],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@@ -213,14 +237,14 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
acc = tkw.mma(a_reg, b_reg, acc)
return acc

if res_dtype != tkl.f32:
repeat = tkw.cast(repeat, res_dtype)

# repeat represents the results of the loop
tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

shape = [config.M, config.N, config.K]
operand_element_type_map = {
"f16": torch.float16,
}
operand_element_type = operand_element_type_map[config.operand_element_type]
schedule = (config.K < 4096)

hyperparams = {
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
@@
M: shape[0],
N: shape[1],
K: shape[2],
READ_SHARED_DELAY: tc.DELAY_SHARED,
WRITE_SHARED_DELAY: tc.DELAY_SHARED,
READ_GLOBAL_DELAY: tc.DELAY_GLOBAL,
WRITE_GLOBAL_DELAY: tc.DELAY_GLOBAL,
MMA_DELAY: tc.DELAY_MMA,
SHARED_MEMORY_UNITS: tc.SHARED_UNITS,
GLOBAL_MEMORY_UNITS: tc.GLOBAL_UNITS,
MMA_UNITS: tc.MMA_UNITS,
}
hyperparams.update(get_default_scheduling_params())
# config = get_default_run_config() TODO: detects device as CPU for some reason
config = {"backend": "rocm", "device": "hip", "target": "gfx942"}


# TODO: Scheduling takes too long with large K.
with tk.gen.TestLaunchContext(
hyperparams, canonicalize=True, run=True, run_config=config, schedule=True,
hyperparams,
canonicalize=True,
create_vmfb_file=vmfb_file,
run_config=config,
schedule=schedule,
):
a = torch.randn(shape[0], shape[2], dtype=operand_element_type)
b = torch.randn(shape[1], shape[2], dtype=operand_element_type)
c = torch.zeros(shape[0], shape[1], dtype=torch.float32)
mb = gemm(a, b, c)
mb = gemm()

return mb.module_op.get_asm()
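
With these changes, generate_tk_mlir both returns the kernel's MLIR and, through create_vmfb_file, writes the compiled module, and software scheduling is only requested for K < 4096 since scheduling large-K problems takes too long. A minimal sketch of a call, with an illustrative output location:

from pathlib import Path

vmfb_file = Path("vmfbs") / (config.get_name() + ".vmfb")  # illustrative path; config is a GemmConfig
mlir_asm = generate_tk_mlir(config, vmfb_file)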

@@ -265,14 +287,26 @@ def compile_gemm_config(

# Generate mlir content
if tk:
mlir_content = generate_tk_mlir(config)
try:
mlir_content = generate_tk_mlir(config, vmfb_file)
except Exception as e:
traceback.print_exc()
error_file = vmfb_dir / (config.get_name() + "_error.txt")
print(f"Failed to compile {config.get_name()}. Error dumped in {error_file}")
with open(error_file, "w") as f:
f.write(str(e))
f.write(traceback.format_exc())
return mlir_file, None
else:
mlir_content = generate_mlir(config)

# Write MLIR content to file
with open(mlir_file, "w") as f:
f.write(mlir_content)

if tk:
return mlir_file, vmfb_file

# Compile MLIR to VMFB
exec_args = [
"iree-compile",
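
On the caller side, a failed TK compile now surfaces as (mlir_file, None), with the traceback written to <name>_error.txt in the vmfb directory, so a sweep can skip that config instead of aborting. A minimal caller-side sketch; the full compile_gemm_config signature is not shown in this diff, so the argument list and the run_benchmark step are illustrative:

for config in configs:  # hypothetical iterable of GemmConfig
    mlir_file, vmfb_file = compile_gemm_config(
        config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk=True
    )
    if vmfb_file is None:
        continue  # compilation failed; traceback already dumped to <name>_error.txt
    run_benchmark(config, vmfb_file)  # hypothetical downstream step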