Skip to content

Commit

Permalink
try hack in missing hmax2 functions (+1 squashed commits)
Browse files Browse the repository at this point in the history
Squashed commits:

[9ba8599] try hack in missing hmax2 functions (+2 squashed commit)

Squashed commit:

[be49749] try hack in missing hmax2 functions

[159ee4c] bypass missing hmax functions on old cuda
  • Loading branch information
LostRuins committed May 1, 2024
1 parent b48ea96 commit c98d0ab
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

//hack: polyfill hmax and hmax2 for older cuda version
#if CUDART_VERSION < CUDART_HMAX
__device__ __inline__ __half __hmax(const __half a, const __half b) {
const float fa = __half2float(a);
const float fb = __half2float(b);
return __float2half(fa > fb ? fa : fb);
}
__device__ __inline__ __half2 __hmax2(const __half2 a, const __half2 b) {
__half2 result;
result.x = __hmax(a.x, b.x);
result.y = __hmax(a.y, b.y);
return result;
}
#endif


template<int D, int parallel_blocks> // D == head size
__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
static __global__ void flash_attn_vec_ext_f16(
Expand Down

0 comments on commit c98d0ab

Please sign in to comment.