rocm fix
goliaro committed Mar 6, 2025
1 parent 06f87aa commit 35d40d5
Showing 4 changed files with 32 additions and 48 deletions.
@@ -68,11 +68,11 @@ __device__ __forceinline__ size_t get_v_entry_offset(int const req_idx,
// Note that the q&k here are the value after applying with position encoding.
void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream);
ffStream_t stream);
template <typename DT>
void update_kv_cache_kernel_flashinfer(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream);
ffStream_t stream);
template <typename DT>
void produce_output(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
@@ -126,7 +126,7 @@ void run_batched_matmul(IncMultiHeadSelfAttentionMeta const *meta,
int batchCount,
cudaDataType computeType,
cublasGemmAlgo_t algo,
cudaStream_t stream,
ffStream_t stream,
int batch_ratio_a = 1,
int batch_ratio_b = 1,
int batch_ratio_c = 1,
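The header changes above swap the CUDA-only cudaStream_t for the portable ffStream_t alias, so the same declarations build under both CUDA and ROCm. Below is a minimal sketch of how such an alias is commonly defined; the guard macro name and header choices are assumptions, not taken from this commit.

// Sketch only: a portable stream alias in the usual CUDA/HIP dual-build style.
// FF_USE_HIP_ROCM is an assumed build flag, not confirmed by this diff.
#if defined(FF_USE_HIP_ROCM)
#include <hip/hip_runtime.h>
typedef hipStream_t ffStream_t;
#else
#include <cuda_runtime.h>
typedef cudaStream_t ffStream_t;
#endif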
52 changes: 24 additions & 28 deletions src/ops/inc_multihead_self_attention.cpp
@@ -1178,9 +1178,7 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m,
hipStream_t stream) {
int num_tokens = bc->num_active_tokens();
int tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads;
assert(m->hidden_size % m->num_q_heads == 0);
int head_dim = m->hidden_size / m->num_q_heads;
assert(head_dim == m->qProjSize);
int head_dim = m->qProjSize;
if (num_tokens > 0) {
int parallelism = head_dim * tot_num_heads * num_tokens;
// devQKVProj has shape [qProjSize, tot_num_heads, num_new_tokens]
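In this hunk, head_dim is now read directly from m->qProjSize, and the launch size head_dim * tot_num_heads * num_tokens flattens the [qProjSize, tot_num_heads, num_new_tokens] QKV buffer. The toy kernel below is only a hedged illustration of how one flat thread index maps back onto that layout; it is not the actual FlexFlow cache-update kernel.

// Illustrative only: decompose a flat index over a [head_dim, tot_num_heads,
// num_tokens] buffer laid out with head_dim fastest-varying.
__global__ void qkv_layout_demo(float const *devQKVProj, float *dst,
                                int head_dim, int tot_num_heads,
                                int num_tokens) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= head_dim * tot_num_heads * num_tokens) {
    return;
  }
  int offset = i % head_dim;                      // element within one head
  int head_idx = (i / head_dim) % tot_num_heads;  // which q/k/v head
  int token_idx = i / (head_dim * tot_num_heads); // which new token
  dst[i] = devQKVProj[(size_t)token_idx * tot_num_heads * head_dim +
                      (size_t)head_idx * head_dim + offset];
}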
@@ -1776,7 +1774,7 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
m->peft_token_infos_size,
hipMemcpyHostToDevice,
stream));
assert(m->hidden_size == m->qProjSize * m->num_q_heads);

assert(m->qProjSize == m->kProjSize);
/*q&k*/
int half_proj = m->qProjSize / 2;
@@ -1835,10 +1833,10 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
// matrix B's layout: [num_tokens, qProjsize * num_heads, 3]
DT const *B = static_cast<DT *>(m->devQKVProjArray);
// matrix C: gradients w.r.t. input
// matrix C's layout: [m->qSize, num_tokens]
DT *C = input_grad_ptr +
bc->requestsInfo[i].first_token_offset_in_batch * m->qSize;
// int m_ = m->qSize;
// matrix C's layout: [m->qProjSize * m->num_q_heads, num_tokens]
DT *C = input_grad_ptr + bc->requestsInfo[i].first_token_offset_in_batch *
m->qProjSize * m->num_q_heads;
// int m_ = m->qProjSize * m->num_q_heads;
int n_ = num_tokens;
int k_ = m->qProjSize * (m->num_q_heads + 2 * m->num_kv_heads);

@@ -1851,7 +1849,8 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
if (m->inference_debugging) {
std::string filename =
get_peft_dbg_folder(m, shard_id) + ".self_attn.input_gradient_0";
save_tensor(C, num_tokens * m->qSize, filename.c_str());
save_tensor(
C, num_tokens * m->qProjSize * m->num_q_heads, filename.c_str());
}
}
}
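With the input-gradient matrix C now laid out as [m->qProjSize * m->num_q_heads, num_tokens] instead of [m->qSize, num_tokens], the per-request base pointer advances by whole columns of that width. A small sketch of the offset arithmetic this implies; the helper name is hypothetical.

// Hypothetical helper: base pointer of a request's gradient block when C is
// stored column-major with in_dim = qProjSize * num_q_heads rows per token.
template <typename DT>
DT *request_grad_base(DT *input_grad_ptr, int first_token_offset_in_batch,
                      int qProjSize, int num_q_heads) {
  size_t in_dim = (size_t)qProjSize * num_q_heads;
  return input_grad_ptr + (size_t)first_token_offset_in_batch * in_dim;
}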
@@ -1954,15 +1953,11 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
FFHandler handler,
IncMultiHeadSelfAttention const *attn,
MemoryAllocator &gpu_mem_allocator,
int num_samples,
int _num_q_heads,
int _num_kv_heads)
: IncMultiHeadSelfAttentionMeta(handler,
INC_DECODING_MODE,
attn,
attn->qSize,
attn->kSize,
attn->vSize,
attn->qProjSize,
attn->kProjSize,
attn->vProjSize,
@@ -1973,21 +1968,18 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
attn->position_bias,
attn->scaling_factor,
gpu_mem_allocator,
num_samples,
attn->num_q_heads,
attn->num_kv_heads,
_num_q_heads,
_num_kv_heads,
attn->num_kv_cache_pages,
attn->quantization_type,
attn->offload) {}

IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
FFHandler handler,
InferenceMode infer_mode,
Op const *attn,
int _qSize,
int _kSize,
int _vSize,
int _qProjSize,
int _kProjSize,
int _vProjSize,
@@ -1998,24 +1990,19 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
bool _position_bias,
float _scaling_factor,
MemoryAllocator &gpu_mem_allocator,
int num_samples,
int _global_num_q_heads,
int _global_num_kv_heads,
int _num_q_heads,
int _num_kv_heads,
int _num_kv_cache_pages,
DataType _quantization_type,
bool _offload)
: OpMeta(handler, attn) {
hipStream_t stream;
checkCUDA(get_legion_stream(&stream));
checkCUDNN(miopenSetStream(handler.dnn, stream));
checkCUDNN(miopenCreateTensorDescriptor(&qk_tensor));
qSize = _qSize;
kSize = _kSize;
vSize = _vSize;
// assume dimensions match for now
assert(qSize == kSize);
assert(kSize == vSize);
qProjSize = _qProjSize;
kProjSize = _kProjSize;
assert(qProjSize == kProjSize); // required for attention QK.T matmul
@@ -2029,7 +2016,6 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
global_num_kv_heads = _global_num_kv_heads;
num_q_heads = _num_q_heads;
num_kv_heads = _num_kv_heads;
hidden_size = num_q_heads * qProjSize;

rotary_embedding_meta =
(RotaryEmbeddingMeta *)calloc(1, sizeof(RotaryEmbeddingMeta));
@@ -2042,6 +2028,14 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
position_bias = (bool *)calloc(1, sizeof(bool));
*position_bias = _position_bias;

num_kv_cache_pages = _num_kv_cache_pages;
assert(num_kv_cache_pages > 0 || enable_peft_finetuning);

// spec decoding and peft finetuning are mutually exclusive
if (enable_peft_finetuning) {
assert(infer_mode == INC_DECODING_MODE);
}

assert(num_q_heads % num_kv_heads == 0 &&
"num_q_heads must be divisible by num_kv_heads");
if (num_q_heads > num_kv_heads) {
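The divisibility assert above reflects grouped-query attention: each KV head serves an equal-sized group of query heads. A tiny illustration of the standard mapping, not quoted from the repository:

// Standard grouped-query-attention mapping (illustrative):
// query head q is served by kv head q / (num_q_heads / num_kv_heads).
int kv_head_for_query(int q_head, int num_q_heads, int num_kv_heads) {
  int group_size = num_q_heads / num_kv_heads; // valid when divisible
  return q_head / group_size;
}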
@@ -2197,17 +2191,19 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
size_of_dt);
qk_prods_softmax = gpu_mem_allocator.allocate_reserved_untyped(
qk_prod_size * size_of_dt);
attn_heads = gpu_mem_allocator.allocate_reserved_untyped(attn_heads_size *
size_of_dt);
// attn_heads =
// gpu_mem_allocator.allocate_reserved_untyped(attn_heads_size *
// size_of_dt);
complex_input =
gpu_mem_allocator.allocate_reserved<hipFloatComplex>(complex_size);
} else {
qk_prods = gpu_mem_allocator.allocate_instance_untyped(qk_prod_size *
size_of_dt);
qk_prods_softmax = gpu_mem_allocator.allocate_instance_untyped(
qk_prod_size * size_of_dt);
attn_heads = gpu_mem_allocator.allocate_instance_untyped(attn_heads_size *
size_of_dt);
// attn_heads =
// gpu_mem_allocator.allocate_instance_untyped(attn_heads_size *
// size_of_dt);
complex_input =
gpu_mem_allocator.allocate_instance<hipFloatComplex>(complex_size);
}
10 changes: 2 additions & 8 deletions src/ops/spec_inc_multihead_self_attention.cpp
@@ -394,9 +394,7 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
hipStream_t stream) {
int num_tokens = bc->num_active_tokens();
int tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads;
assert(m->hidden_size % m->num_q_heads == 0);
int head_dim = m->hidden_size / m->num_q_heads;
assert(head_dim == m->qProjSize);
int head_dim = m->qProjSize;
int curr_depth = bc->beamRequestsInfo[0].current_depth;
if (num_tokens > 0) {
int parallelism = head_dim * tot_num_heads * num_tokens;
@@ -843,15 +841,11 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta(
FFHandler handler,
SpecIncMultiHeadSelfAttention const *attn,
MemoryAllocator &gpu_mem_allocator,
int num_samples,
int _num_q_heads,
int _num_kv_heads)
: IncMultiHeadSelfAttentionMeta(handler,
BEAM_SEARCH_MODE,
attn,
attn->qSize,
attn->kSize,
attn->vSize,
attn->qProjSize,
attn->kProjSize,
attn->vProjSize,
@@ -862,11 +856,11 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta(
attn->position_bias,
attn->scaling_factor,
gpu_mem_allocator,
num_samples,
attn->num_q_heads,
attn->num_kv_heads,
_num_q_heads,
_num_kv_heads,
attn->num_kv_cache_pages,
DT_NONE,
false) {
hipStream_t stream;
12 changes: 3 additions & 9 deletions src/ops/tree_inc_multihead_self_attention.cpp
@@ -405,8 +405,7 @@ template <typename DT>
void commit_tokens(TreeIncMultiHeadSelfAttentionMeta const *m,
TreeVerifyBatchConfig const *bc,
hipStream_t stream) {
int head_dim = m->hidden_size / m->num_q_heads;
assert(head_dim == m->qProjSize);
int head_dim = m->qProjSize;
// int tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads;
int num_tokens_to_commit = bc->num_tokens_to_commit;
if (num_tokens_to_commit > 0) {
@@ -534,8 +533,7 @@ void compute_attention_kernel_fused(TreeIncMultiHeadSelfAttentionMeta const *m,

// update the kv cache
// update K-V cache
int head_dim = m->hidden_size / m->num_q_heads;
assert(head_dim == m->qProjSize);
int head_dim = m->qProjSize;
int num_new_tokens = bc->num_active_tokens();
int parallelism = head_dim * m->num_kv_heads * num_new_tokens;
hipLaunchKernelGGL(HIP_KERNEL_NAME(update_tree_branch_kv_cache_fused),
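The fused tree-verify path launches its cache-update kernel through hipLaunchKernelGGL with a grid sized from parallelism. A self-contained sketch of that launch macro's argument order, using a toy kernel rather than FlexFlow's:

#include <hip/hip_runtime.h>

// Toy kernel: scales n floats in place.
__global__ void scale_kernel(float *data, int n, float alpha) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= alpha;
  }
}

// hipLaunchKernelGGL(kernel, gridDim, blockDim, sharedMemBytes, stream, args...)
void launch_scale(float *data, int n, float alpha, hipStream_t stream) {
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  hipLaunchKernelGGL(scale_kernel, dim3(blocks), dim3(threads), 0, stream,
                     data, n, alpha);
}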
@@ -660,15 +658,11 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta(
FFHandler handler,
TreeIncMultiHeadSelfAttention const *attn,
MemoryAllocator &gpu_mem_allocator,
int num_samples,
int _num_q_heads,
int _num_kv_heads)
: IncMultiHeadSelfAttentionMeta(handler,
TREE_VERIFY_MODE,
attn,
attn->qSize,
attn->kSize,
attn->vSize,
attn->qProjSize,
attn->kProjSize,
attn->vProjSize,
@@ -679,11 +673,11 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta(
attn->position_bias,
attn->scaling_factor,
gpu_mem_allocator,
num_samples,
attn->num_q_heads,
attn->num_kv_heads,
_num_q_heads,
_num_kv_heads,
attn->num_kv_cache_pages,
attn->quantization_type,
attn->offload),
num_active_infr_tokens(0) {