
Commit

feat:support any num_heads for get_alibi_slope (flashinfer-ai#200)
While using flashinfer, I found that some models have a number of heads
that is not a power of 2. Following
**flashinfer/python/tests/alibi_reference.py**, I modified this part of
the C++ code to handle that case.
yz-tang authored Apr 11, 2024
1 parent a22aeb6 commit b217a6f
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions include/flashinfer/pos_enc.cuh
@@ -56,8 +56,9 @@ inline std::string PosEncodingModeToString(const PosEncodingMode& pos_encoding_m
}

__device__ __forceinline__ float get_alibi_slope(uint32_t head_idx, uint32_t num_heads) {
-  // NOTE(Zihao): here we assume that num_heads is a power of 2
-  return math::ptx_exp2(-8. * float(head_idx + 1) / float(num_heads));
+  int n = math::ptx_exp2((int)math::ptx_log2(num_heads));
+  return head_idx < n ? math::ptx_exp2(-8. * float(head_idx + 1) / float(n))
+                      : math::ptx_exp2(-4. * float((head_idx + 1 - n) * 2 - 1) / float(n));
}

/*!
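
For checking the new branch by hand, here is a minimal host-side sketch of the same slope rule in plain C++. It is not part of flashinfer: the names `floor_pow2` and `alibi_slope_ref` are hypothetical, and it uses `std::exp2` and an integer loop in place of `math::ptx_exp2` and `math::ptx_log2`. With `n` the largest power of 2 not exceeding `num_heads`, the first `n` heads keep the original schedule `2^(-8 * (i + 1) / n)`, and each extra head `j = head_idx - n` takes the slope `2^(-4 * (2j + 1) / n)`, which is the odd-position entry of a `2n`-head schedule.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper (not in flashinfer): largest power of two <= x, for x >= 1.
// Integer arithmetic avoids relying on the rounding of an approximate log2.
static uint32_t floor_pow2(uint32_t x) {
  uint32_t n = 1;
  while (n <= x / 2) n *= 2;
  return n;
}

// Hypothetical host-side reference for the slope rule in the diff above.
static float alibi_slope_ref(uint32_t head_idx, uint32_t num_heads) {
  assert(num_heads >= 1 && head_idx < num_heads);
  const uint32_t n = floor_pow2(num_heads);
  if (head_idx < n) {
    // Heads 0..n-1: same formula as the power-of-2 case, with n in place of num_heads.
    return std::exp2(-8.0f * float(head_idx + 1) / float(n));
  }
  // Extra heads: (head_idx + 1 - n) * 2 - 1 in the diff equals 2 * j + 1 here.
  const uint32_t j = head_idx - n;
  return std::exp2(-4.0f * float(2 * j + 1) / float(n));
}
```

For example, with `num_heads = 12` the sketch uses `n = 8`: heads 0 to 7 get slopes `2^-1` through `2^-8`, and heads 8 to 11 get `2^-0.5`, `2^-1.5`, `2^-2.5`, and `2^-3.5`.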
