Commit e86af62

Narsil committed Sep 4, 2023
1 parent 084ca75 commit e86af62
Showing 4 changed files with 28 additions and 0 deletions.
12 changes: 12 additions & 0 deletions csrc/attention/attention_kernels.cu
@@ -243,6 +243,8 @@ __global__ void single_query_cached_kv_attention_kernel(
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
    const int physical_block_number = block_table[block_idx];
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
@@ -258,6 +260,16 @@
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        if (block_idx == num_blocks - 1) {
          // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
          // we should explicitly zero out the values since they may contain NaNs.
          // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
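Why the explicit zeroing matters: the softmax logits for out-of-context positions are already zero, but IEEE-754 multiplication propagates NaN (0 * NaN == NaN), so a NaN left in an unwritten region of the value cache would still poison dot(logits_vec, v_vec). A minimal host-side C++ sketch of the failure mode (illustrative only, not part of this commit):

#include <cmath>
#include <cstdio>

int main() {
  float logit = 0.0f;  // out-of-context positions get zero attention weight
  float stale = NAN;   // an uninitialized cache slot may hold a NaN
  float acc = logit * stale;  // 0 * NaN == NaN: the accumulator is poisoned
  printf("unmasked: %f (isnan=%d)\n", acc, (int)std::isnan(acc));
  stale = 0.0f;               // what the kernel's masking loop now does
  acc = logit * stale;
  printf("masked:   %f (isnan=%d)\n", acc, (int)std::isnan(acc));
  return 0;
}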
10 changes: 10 additions & 0 deletions csrc/attention/dtype_bfloat16.cuh
@@ -420,4 +420,14 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#endif
}

// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
}

} // namespace vllm
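Design note: the pre-Ampere branch appears to mirror the other __nv_bfloat16 helpers in this header, which likewise compile to assert(false) traps when __CUDA_ARCH__ < 800; on newer architectures the function simply bit-casts the all-zero pattern 0x0000 to bfloat16 (positive zero), the same value the CUDART_ZERO_BF16 constant introduced in CUDA 12.2 expands to.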
1 change: 1 addition & 0 deletions csrc/attention/dtype_float16.cuh
@@ -441,4 +441,5 @@ inline __device__ Float8_ to_float(uint4 u) {
  return tmp;
}


} // namespace vllm
5 changes: 5 additions & 0 deletions csrc/attention/dtype_float32.cuh
@@ -265,4 +265,9 @@ inline __device__ Float8_ to_float(Float8_ u) {
  return u;
}

// Zero-out a variable.
inline __device__ void zero(float& dst) {
  dst = 0.f;
}

} // namespace vllm
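Taken together, the per-dtype zero() overloads give the templated attention kernel a uniform way to materialize a typed zero: the zero(zero_value) call in attention_kernels.cu resolves at compile time to whichever overload matches scalar_t, so the kernel stays generic without requiring each scalar type to be constructible from a literal zero. A condensed CUDA sketch of the pattern (simplified, with hypothetical names; not the kernel itself):

#include <cuda_bf16.h>

// One zero() overload per supported scalar type, mirroring the dtype headers.
inline __device__ void zero(float& dst) { dst = 0.f; }
inline __device__ void zero(__nv_bfloat16& dst) {
  dst = __ushort_as_bfloat16((unsigned short)0x0000U);  // bit pattern 0x0000 is +0.0
}

template <typename scalar_t>
__global__ void mask_tail_sketch(scalar_t* vec, int valid, int n) {
  scalar_t zero_value;
  zero(zero_value);  // overload resolution picks the initializer for scalar_t
  for (int j = 0; j < n; j++) {
    vec[j] = j < valid ? vec[j] : zero_value;  // zero the out-of-context tail
  }
}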
