Erase all address spaces and get inlined ukernels (#19646)
The `LLVMGPUCastAddressSpaceFunction` pass was selectively erasing the
shared memory address space from pointers around Call ops to achieve
inlining. This PR generalizes that to erasing all address spaces after
checking with its original author that there wasn't anything intentional
here:
[discord](https://discord.com/channels/689900678990135345/1282818085153407038/1326577591557296272)

This has the intended effect of allowing AMDGPU ukernels to get inlined
into their callers.

There is a side benefit in no longer having to duplicate ukernels for the
various combinations of address spaces of their pointer parameters. This
benefit will be partly rolled back if and when we do assembly ukernels,
as those will need to know the address spaces in order to emit different
instructions; but at least for C ukernels it is nice.

It was counter-intuitive to me that erasing address spaces was possible
at all. The key is that these ukernels only get compiled to LLVM IR, not
to ISA, and the resulting IR gets inlined into a caller where the
addrspacecast was done and where the actual address space is known.
After inlining, the compiler is still able to propagate the actual
address spaces all the way into the inlined ukernel code.
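
To make the mechanism concrete, here is a minimal sketch (a hypothetical function, not code from this PR); the pointer is generic inside the standalone ukernel, and the concrete address space is reintroduced by the compiler after inlining:

```c
#include <stdint.h>

// A ukernel that is only ever compiled to LLVM bitcode can take plain
// (generic / "flat") pointers; it never needs to name an address space.
static inline void uk_write_result(int32_t *out, int32_t value) {
  *out = value;  // flat store while the ukernel is standalone IR
}
```

On the MLIR side, the caller holds the buffer in workgroup memory (addrspace(3)); the pass erases the address space around the call, an addrspacecast is emitted, and once the call is inlined the backend (typically via LLVM's InferAddressSpaces pass) rewrites the flat access into a concrete LDS access, as if the ukernel had been written with the address space spelled out.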

For the current `multi_mma` ukernel there was no immediate problem. The
changes to it in this PR reap the benefits of inlining: the `unroll_*`
parameters become compile-time constants after inlining, so we get to
simply declare our accumulator tile as a VLA and let it get specialized
into a normal fixed-size array. There is no longer any need to use an
arbitrary fixed-size array and guard it with assertions.
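
The same specialization can be seen in a plain C sketch (illustrative names, not the PR's code):

```c
#include <stdint.h>

// While standalone, `n` is a runtime value, so `acc` is a genuine VLA.
static inline void uk_accumulate(int32_t *out, const int32_t *in, int n) {
  int32_t acc[n];  // VLA
  for (int i = 0; i < n; ++i) acc[i] = in[i];
  for (int i = 0; i < n; ++i) out[i] += acc[i];
}

void caller(int32_t *out, const int32_t *in) {
  // After inlining, n == 4 is a compile-time constant, so `acc` becomes an
  // ordinary int32_t[4] with no runtime stack sizing.
  uk_accumulate(out, in, 4);
}
```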

For the existing `argmax` ukernels, inlining revealed a preexisting
issue: these ukernels reduce to a single scalar, and instead of
returning it by value, they write it to an output buffer
(which happens to be LDS memory, but the address space doesn't matter).
The problem was that there was no synchronization between the thread
writing the value in the ukernel and the threads reading the value in
the caller. This is solved by adding a `__threadfence_block()`, which
compiles to almost nothing in ISA (an s_waitcnt, which we have anyway
around memory accesses) but prevents IR rewrites from removing the loads
from the output buffer.
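
Schematically, the hazard and the fix look like this (a hedged sketch reusing names from the diffs below and assuming common.h's `__threadfence_block`; the reduction body is elided):

```c
// Hedged sketch: the tail of an argmax-style ukernel.
static inline void uk_reduction_tail(int32_t *outputBuffer,
                                     int64_t output_offset, int laneID,
                                     int32_t laneResult) {
  if (laneID == 0) {
    outputBuffer[output_offset] = laneResult;  // store to the LDS output buffer
  }
  // Without a fence, nothing tells the optimizer that other threads will read
  // outputBuffer after the (inlined) call returns, so this store and the
  // caller's subsequent loads could be reordered or the loads folded away.
  __threadfence_block();  // workgroup-scoped fence; nearly free in ISA
}
```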

I added `__threadfence_block()` to common.h, copied from AMD device
library headers, along with a few other synchronization functions which
we anticipate will be useful in other ukernels. `__syncthreads` is not
used in this PR.
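
For reference, a hedged sketch of the kind of use we anticipate for `__syncthreads` in a future ukernel (hypothetical helper; nothing in this PR calls it):

```c
#include <stdint.h>

// Stage per-thread values through an LDS tile: every thread writes its slot,
// the whole workgroup synchronizes, then each thread reads a neighbor's slot.
static inline int32_t uk_exchange(int32_t my_value, int32_t *lds_tile,
                                  int tid, int num_threads) {
  lds_tile[tid] = my_value;
  __syncthreads();  // all lds_tile writes are visible to the workgroup here
  return lds_tile[(tid + 1) % num_threads];
}
```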

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
bjacob authored Jan 9, 2025
1 parent a7bac5d commit bb1c561
Showing 7 changed files with 79 additions and 53 deletions.
compiler/plugins/target/ROCM/builtins/ukernel/common.h (25 additions, 0 deletions)

```diff
@@ -81,6 +81,31 @@ _Float16 __ockl_wfred_max_f16(_Float16);
 int64_t __ockl_wfred_min_i64(int64_t);
 int32_t __ockl_wfred_min_i32(int32_t);
 
+#define __CLK_LOCAL_MEM_FENCE 0x01
+typedef unsigned __cl_mem_fence_flags;
+
+static inline void __threadfence_block() {
+  __builtin_amdgcn_fence(__ATOMIC_SEQ_CST, "workgroup");
+}
+
+static inline void __work_group_barrier(__cl_mem_fence_flags flags) {
+  if (flags) {
+    __builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");
+    __builtin_amdgcn_s_barrier();
+    __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");
+  } else {
+    __builtin_amdgcn_s_barrier();
+  }
+}
+
+static inline void __barrier(int n) {
+  __work_group_barrier((__cl_mem_fence_flags)n);
+}
+
+[[clang::convergent]] static inline void __syncthreads() {
+  __barrier(__CLK_LOCAL_MEM_FENCE);
+}
+
 //===----------------------------------------------------------------------===//
 // Local replacements for HIP headers
 //===----------------------------------------------------------------------===//
```
```diff
@@ -35,15 +35,18 @@ iree_uk_amdgpu_argmax_f16i32(const _Float16 *inputBuffer, int64_t input_offset,
   uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
   // if there is only one max value holder, write and exit.
   if (__builtin_popcountll(laneHasMaxValmask) == 1) {
-    if (wgMax == laneMax)
+    if (wgMax == laneMax) {
       outputBuffer[output_offset] = laneResult;
-    return;
+    }
+  } else {
+    // if there are multiple max value holder, find smallest index (argmax
+    // semantics).
+    int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__;
+    laneResult = __ockl_wfred_min_i32(indexVal);
+    if (laneID == 0) {
+      outputBuffer[output_offset] = laneResult;
+    }
   }
-
-  // if there are multiple max value holder, find smallest index (argmax
-  // semantics).
-  int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__;
-  laneResult = __ockl_wfred_min_i32(indexVal);
-  if (laneID == 0)
-    outputBuffer[output_offset] = laneResult;
+  // TODO(bjacob): this fence should be on the caller side. Move to TileAndFuse?
+  __threadfence_block();
 }
```
```diff
@@ -36,14 +36,18 @@ iree_uk_amdgpu_argmax_f16i64(const _Float16 *inputBuffer, int64_t input_offset,
   uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
   // if there is only one max value holder, write and exit.
   if (__builtin_popcountll(laneHasMaxValmask) == 1) {
-    if (wgMax == laneMax)
+    if (wgMax == laneMax) {
       outputBuffer[output_offset] = laneResult;
-    return;
+    }
+  } else {
+    // if there are multiple max value holder, find smallest index (argmax
+    // semantics).
+    int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX;
+    laneResult = __ockl_wfred_min_i64(indexVal);
+    if (laneID == 0) {
+      outputBuffer[output_offset] = laneResult;
+    }
   }
-  // if there are multiple max value holder, find smallest index (argmax
-  // semantics).
-  int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX;
-  laneResult = __ockl_wfred_min_i64(indexVal);
-  if (laneID == 0)
-    outputBuffer[output_offset] = laneResult;
+  // TODO(bjacob): this fence should be on the caller side. Move to TileAndFuse?
+  __threadfence_block();
 }
```
```diff
@@ -41,14 +41,18 @@ iree_uk_amdgpu_argmax_f32i32(const float *inputBuffer, int64_t input_offset,
   uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
   // if there is only one max value holder, write and exit.
   if (__builtin_popcountll(laneHasMaxValmask) == 1) {
-    if (wgMax == laneMax)
+    if (wgMax == laneMax) {
       outputBuffer[output_offset] = laneResult;
-    return;
+    }
+  } else {
+    // if there are multiple max value holder, find smallest index (argmax
+    // semantics).
+    int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__;
+    laneResult = __ockl_wfred_min_i32(indexVal);
+    if (laneID == 0) {
+      outputBuffer[output_offset] = laneResult;
+    }
   }
-  // if there are multiple max value holder, find smallest index (argmax
-  // semantics).
-  int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__;
-  laneResult = __ockl_wfred_min_i32(indexVal);
-  if (laneID == 0)
-    outputBuffer[output_offset] = laneResult;
+  // TODO(bjacob): this fence should be on the caller side. Move to TileAndFuse?
+  __threadfence_block();
 }
```
```diff
@@ -41,14 +41,18 @@ iree_uk_amdgpu_argmax_f32i64(const float *inputBuffer, int64_t input_offset,
   uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
   // if there is only one max value holder, write and exit.
   if (__builtin_popcountll(laneHasMaxValmask) == 1) {
-    if (wgMax == laneMax)
+    if (wgMax == laneMax) {
       outputBuffer[output_offset] = laneResult;
-    return;
+    }
+  } else {
+    // if there are multiple max value holder, find smallest index (argmax
+    // semantics).
+    int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX;
+    laneResult = __ockl_wfred_min_i64(indexVal);
+    if (laneID == 0) {
+      outputBuffer[output_offset] = laneResult;
+    }
   }
-  // if there are multiple max value holder, find smallest index (argmax
-  // semantics).
-  int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX;
-  laneResult = __ockl_wfred_min_i64(indexVal);
-  if (laneID == 0)
-    outputBuffer[output_offset] = laneResult;
+  // TODO(bjacob): this fence should be on the caller side. Move to TileAndFuse?
+  __threadfence_block();
 }
```
```diff
@@ -7,31 +7,17 @@
 #include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"
 
 // Very naive kernel. TODO(bjacob):
-// 1. Inlining: the `always_inline` attribute here is correctly preserved in
-//    the bitcode, but isn't having the intended effect of inlining calls to
-//    this function. Making that work is key as various function parameters
-//    (e.g. `unroll_m`) are meant to be constants.
-// 2. Shared memory: can't allocate it within the microkernel (which is just a
+// 1. Shared memory: can't allocate it within the microkernel (which is just a
 //    helper device function, not the actual amdgpu_kernel). Need to get it
-//    passed down here as a `T [[clang::address_space(3)]] *` parameter.
-// 3. Better scheduling via either barrier intrinsics or inline assemby.
-// 4. Subgroups1x4 being asymmetric is a historical accident... should be 2x2.
+//    passed down here as additional parameters.
+// 2. Better scheduling via either barrier intrinsics or inline assemby.
 [[clang::always_inline]] void iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8(
     const int8_t *a_buffer, int64_t a_offset, const int8_t *b_buffer,
     int64_t b_offset, int32_t *c_buffer, int64_t c_offset, int32_t k_size,
     int32_t unroll_m, int32_t subgroups_m, int32_t unroll_n,
     int32_t subgroups_n, int32_t unroll_k) {
-  /*
-  TODO(bjacob): reenable this once inlining works.
-  // Load existing accumulators. This is a VLA, but should become fixed-size
-  // once this function is inlined and unroll_* factors become constants.
-  int32x4_t c[unroll_m][unroll_n];
-  */
-  // Load existing accumulators.
-  if (unroll_m > 8 || unroll_n > 2) {
-    __builtin_trap();
-  }
-  int32x4_t c[8][2];
+  // Load existing accumulators. The VLA becomes a normal array after inlining.
+  int32x4_t c[unroll_m][unroll_n];
   int32x4_t *c_global = (int32x4_t *)(c_buffer + c_offset);
   for (int m = 0; m < unroll_m; ++m) {
     for (int n = 0; n < unroll_n; ++n) {
```
```diff
@@ -41,7 +41,7 @@ struct LLVMGPUCastAddressSpaceFunctionPass final
   bool anyCasted = false;
   for (auto operand : operands) {
     if (auto memrefType = dyn_cast<mlir::MemRefType>(operand.getType())) {
-      if (hasSharedMemoryAddressSpace(memrefType)) {
+      if (memrefType.getMemorySpace()) {
        mlir::MemRefType new_memrefType = mlir::MemRefType::get(
            memrefType.getShape(), memrefType.getElementType(),
            memrefType.getLayout());
```
