Skip to content

Commit

Permalink
lit test
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
  • Loading branch information
Hardcode84 committed Jan 30, 2025
1 parent abcd3a0 commit fb7b683
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,49 @@ def read_write_masked(
# CHECK-SAME: strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16>


@run_test
def test_read_write_buffer():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 4, N: 4}
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

@tkw.wave(constraints)
def read_write_buffer(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
res = tkw.read(a, elements_per_thread=4)
tkw.write(res, b, elements_per_thread=4)

with tk.gen.TestLaunchContext(
{
M: 1,
N: 3,
BLOCK_M: 4,
BLOCK_N: 4,
ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
},
canonicalize=True,
use_buffer_load_ops=True,
use_buffer_store_ops=True,
):
a = torch.randn(4, 4, dtype=torch.float16)
b = torch.zeros(4, 4, dtype=torch.float16)
print(read_write_buffer(a, b).module_op)

# CHECK-LABEL: func.func @read_write_buffer
# CHECK-COUNT-1: memref.reinterpret_cast
# CHECK-COUNT-4: amdgpu.raw_buffer_load
# CHECK-COUNT-1: memref.reinterpret_cast
# CHECK-COUNT-4: amdgpu.raw_buffer_store


@run_test
def test_read_write_masked_shared():
constraints: list[tkw.Constraint] = [
Expand Down

0 comments on commit fb7b683

Please sign in to comment.