multi batch fix
goliaro committed Mar 5, 2025
1 parent 82d6760 commit 826b89f
Showing 9 changed files with 38 additions and 20 deletions.
7 changes: 6 additions & 1 deletion include/flexflow/batch_config.h
@@ -66,11 +66,16 @@ class BatchConfig {
int num_active_requests() const;
// returns number of inference and finetuning FWD tokens
int num_active_tokens() const;


// returns number of inference-only tokens
int num_inference_tokens() const;
int num_inference_requests() const;

// return the index where the finetuning request would be stored (i.e. last slot of the batch)
int finetuning_request_index() const;
// returns the number of finetuning FWD requests, or 0 if there is none
int num_finetuning_fwd_requests() const;

int num_finetuning_fwd_tokens() const;
int num_finetuning_bwd_requests() const;
int num_finetuning_bwd_tokens() const;
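
A minimal sketch (not part of this commit) of how a host-side wrapper can use the new accessors to work on the inference portion of a mixed inference/finetuning batch. launch_inference_only is a hypothetical helper; it assumes only the declarations above and that BatchConfig lives in the FlexFlow namespace.

#include "flexflow/batch_config.h"

void launch_inference_only(FlexFlow::BatchConfig const *bc) {
  if (bc->num_inference_tokens() == 0) {
    return; // the batch carries only finetuning work
  }
  // the finetuning request, if present, occupies the last slot of the batch
  int ft_slot = (bc->num_finetuning_fwd_requests() > 0)
                    ? bc->finetuning_request_index()
                    : -1;
  for (int req = 0; req < bc->max_requests_per_batch(); req++) {
    if (req == ft_slot) {
      continue; // finetuning tokens use a separate KV cache
    }
    // ... process inference request `req` (a real kernel would also skip
    // inactive slots) ...
  }
}
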
1 change: 1 addition & 0 deletions python/flexflow/serve/models/falcon.py
@@ -106,6 +106,7 @@ def build_model(self):
self.max_tokens_per_batch = self.rm.get_max_tokens_per_batch()
if is_spec:
self.max_tokens_per_batch += self.rm.get_max_spec_tree_token_num()
self.max_sequence_length += self.rm.get_max_spec_tree_token_num()

ffmodel.set_num_kv_cache_pages(
compute_num_kv_cache_pages_needed(
1 change: 1 addition & 0 deletions python/flexflow/serve/models/llama.py
@@ -104,6 +104,7 @@ def build_model(self):
self.max_tokens_per_batch = self.rm.get_max_tokens_per_batch()
if is_spec:
self.max_tokens_per_batch += self.rm.get_max_spec_tree_token_num()
self.max_sequence_length += self.rm.get_max_spec_tree_token_num()

ffmodel.set_num_kv_cache_pages(
compute_num_kv_cache_pages_needed(
1 change: 1 addition & 0 deletions python/flexflow/serve/models/mpt.py
@@ -76,6 +76,7 @@ def build_model(self):
self.max_tokens_per_batch = self.rm.get_max_tokens_per_batch()
if is_spec:
self.max_tokens_per_batch += self.rm.get_max_spec_tree_token_num()
self.max_sequence_length += self.rm.get_max_spec_tree_token_num()

ffmodel.set_num_kv_cache_pages(
compute_num_kv_cache_pages_needed(
1 change: 1 addition & 0 deletions python/flexflow/serve/models/opt.py
@@ -84,6 +84,7 @@ def build_model(self):
self.max_tokens_per_batch = self.rm.get_max_tokens_per_batch()
if is_spec:
self.max_tokens_per_batch += self.rm.get_max_spec_tree_token_num()
self.max_sequence_length += self.rm.get_max_spec_tree_token_num()

ffmodel.set_num_kv_cache_pages(
compute_num_kv_cache_pages_needed(
1 change: 1 addition & 0 deletions python/flexflow/serve/models/starcoder.py
@@ -87,6 +87,7 @@ def build_model(self):
self.max_tokens_per_batch = self.rm.get_max_tokens_per_batch()
if is_spec:
self.max_tokens_per_batch += self.rm.get_max_spec_tree_token_num()
self.max_sequence_length += self.rm.get_max_spec_tree_token_num()

ffmodel.set_num_kv_cache_pages(
compute_num_kv_cache_pages_needed(
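
The same one-line addition appears in falcon.py, llama.py, mpt.py, opt.py, and starcoder.py: with speculative decoding enabled, each sequence may temporarily hold up to get_max_spec_tree_token_num() extra draft tokens, so the KV-cache page budget has to be computed from the padded sequence length as well as the padded batch size. The sketch below shows the kind of page arithmetic this feeds into; the real helper is compute_num_kv_cache_pages_needed (its signature is not shown here), so the formula and numbers are illustrative assumptions.

#include <cassert>

// hypothetical stand-in for the page-count computation
int pages_needed(int max_sequence_length, int max_requests_per_batch,
                 int tokens_per_page) {
  assert(tokens_per_page > 0);
  // each request may cache up to max_sequence_length tokens
  int pages_per_request =
      (max_sequence_length + tokens_per_page - 1) / tokens_per_page; // ceil
  return pages_per_request * max_requests_per_batch;
}

// e.g. a 2048-token limit padded by a 64-token speculative tree, 64-token pages:
//   pages_needed(2048, 8, 64)      == 32 * 8 == 256
//   pages_needed(2048 + 64, 8, 64) == 33 * 8 == 264  (the extra pages the fix accounts for)
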
36 changes: 19 additions & 17 deletions src/ops/inc_multihead_self_attention.cu
@@ -906,10 +906,7 @@ __global__ void update_kv_cache_kernel_flashinfer_kernel(
int token_abs_idx = tokenInfos[token_idx].abs_depth_in_request;
int const req_idx = tokenInfos[token_idx].request_index;
if (req_idx == peft_req_idx) {
// peft requests use separate kv cache
return;
}
assert (req_idx != peft_req_idx && "Attempting to use inference KV cache for PEFT tokens");
int req_idx_compact = 0;
for (int j = 0; j < req_idx; j++) {
@@ -1019,13 +1016,13 @@ void update_kv_cache_kernel_flashinfer(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream) {
// printf("entered update_qkv_in_batch_verify\n");
int num_new_tokens = bc->num_active_tokens();
int num_new_tokens = bc->num_inference_tokens();
if (num_new_tokens == 0) {
return;
}
int tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads;
int parallelism = m->qProjSize * tot_num_heads * num_new_tokens;
int peft_req_idx = bc->finetuning_request_index();
int peft_req_idx = (bc->num_finetuning_fwd_tokens() > 0) ? bc->finetuning_request_index() : -1;
int32_t *kv_indptr = m->handle.incr_attention_metadata->kv_indptr;
int32_t *kv_indices = m->handle.incr_attention_metadata->kv_indices;
update_kv_cache_kernel_flashinfer_kernel<<<GET_BLOCKS(parallelism),
@@ -1093,7 +1090,7 @@ void produce_output(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
DT *output_ptr,
cudaStream_t stream) {
int const num_tokens = bc->num_active_tokens();
int const num_tokens = bc->num_inference_tokens();
if (num_tokens == 0) {
return;
}
Expand All @@ -1116,14 +1113,14 @@ void flashinfer_incr_attention(IncMultiHeadSelfAttentionMeta *m,
uint32_t const num_q_heads = m->num_q_heads;
uint32_t const num_kv_heads = m->num_kv_heads;
uint32_t const head_dim = m->qProjSize;
uint32_t const batch_size = bc->num_active_requests();
uint32_t const batch_size = bc->num_inference_requests();
float const sm_scale =
(*m->qk_prod_scaling) ? 1.0f / sqrt(m->qProjSize) : 1.0f;
assert(batch_size > 0);
assert(num_q_heads > 0);
assert(num_kv_heads > 0);
assert(head_dim > 0);
assert(bc->num_active_tokens() > 0);
assert(bc->num_inference_tokens() > 0);
half *q = static_cast<half *>(m->queryTmp),
*kv = static_cast<half *>(m->kvCache),
@@ -1151,6 +1148,7 @@ void flashinfer_incr_attention(IncMultiHeadSelfAttentionMeta *m,
m->handle.incr_attention_metadata->kv_last_page_len);
if (m->inference_debugging) {
bc->save_to_file(get_fwd_dbg_folder(m, shard_id) + ".batch_config");
std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".q_indptr";
save_tensor(
static_cast<int32_t *>(m->handle.incr_attention_metadata->q_indptr),
@@ -1162,14 +1160,20 @@ void flashinfer_incr_attention(IncMultiHeadSelfAttentionMeta *m,
batch_size + 1,
fpath.c_str());
fpath = get_fwd_dbg_folder(m, shard_id) + ".kv_indices";
int num_pages;
checkCUDA(cudaMemcpy(&num_pages,
m->handle.incr_attention_metadata->kv_indptr + batch_size,
sizeof(int),
cudaMemcpyDeviceToHost));
save_tensor(
static_cast<int32_t *>(m->handle.incr_attention_metadata->kv_indices),
batch_size + 1,
num_pages,
fpath.c_str());
fpath = get_fwd_dbg_folder(m, shard_id) + ".kv_last_page_len";
save_tensor(static_cast<int32_t *>(
m->handle.incr_attention_metadata->kv_last_page_len),
batch_size + 1,
batch_size,
fpath.c_str());
}
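
The three size changes in the debug dump above follow from the paged-KV index layout used by FlashInfer-style attention; the summary and the total_kv_pages helper below are illustrative assumptions, not code from this repository.

// Assumed layout (a CSR-style index over KV-cache pages):
//   kv_indptr        : batch_size + 1 entries, running page count per request
//   kv_indices       : kv_indptr[batch_size] entries, one physical page id per page
//   kv_last_page_len : batch_size entries, tokens used in each request's last page
// e.g. two requests holding 3 and 2 pages:
//   kv_indptr        = {0, 3, 5}
//   kv_indices       = {p0, p1, p2, p3, p4}   // 5 == kv_indptr[2] entries
//   kv_last_page_len = {17, 64}               // one entry per request
inline int total_kv_pages(int const *kv_indptr_host, int batch_size) {
  return kv_indptr_host[batch_size]; // what the new cudaMemcpy reads as num_pages
}
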
@@ -1287,9 +1291,7 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m,
// flashinfer sdpa
assert(bc->num_finetuning_fwd_tokens() >= 0 &&
bc->num_finetuning_bwd_tokens() >= 0);
if (bc->num_active_tokens() - bc->num_finetuning_fwd_tokens() -
bc->num_finetuning_bwd_tokens() >
0) {
if (bc->num_inference_tokens() > 0) {
update_kv_cache_kernel_flashinfer<DT>(m, bc, inf_stream);
flashinfer_incr_attention<DT>(m, bc, shard_id, output_ptr, inf_stream);
}
@@ -2025,7 +2027,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
*position_bias = _position_bias;
num_kv_cache_pages = _num_kv_cache_pages;
assert(num_kv_cache_pages > 0);
assert(num_kv_cache_pages >= 0);
// spec decoding and peft finetuning are mutually exclusive
if (enable_peft_finetuning) {
@@ -2246,8 +2248,8 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
gpu_mem_allocator.reserved_allocated_size);
// set attention constants
std::cerr << "Enabling incr attention metadata for handler incr meta: "
<< handler.incr_attention_metadata << std::endl;
// std::cerr << "Enabling incr attention metadata for handler incr meta: "
// << handler.incr_attention_metadata << std::endl;
handler.incr_attention_metadata->set_enabled(true);
handler.incr_attention_metadata->set_num_q_heads(num_q_heads);
handler.incr_attention_metadata->set_num_kv_heads(num_kv_heads);
4 changes: 4 additions & 0 deletions src/runtime/batch_config.cc
@@ -122,6 +122,10 @@ int BatchConfig::num_inference_tokens() const {
return num_tokens - num_ft_fwd_tokens;
}

int BatchConfig::num_inference_requests() const {
return num_active_requests() - num_finetuning_fwd_requests() - num_finetuning_bwd_requests();
}

int BatchConfig::finetuning_request_index() const {
assert(max_requests_per_batch() > 0);
return max_requests_per_batch() - 1;
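
A worked example of the accounting the new method encodes. The values are illustrative and check_request_accounting is a hypothetical helper, assuming BatchConfig sits in the FlexFlow namespace.

#include <cassert>
#include "flexflow/batch_config.h"

// e.g. a batch holding two inference requests plus one finetuning FWD request:
//   num_active_requests()         == 3
//   num_finetuning_fwd_requests() == 1
//   num_finetuning_bwd_requests() == 0
//   num_inference_requests()      == 3 - 1 - 0 == 2
void check_request_accounting(FlexFlow::BatchConfig const &bc) {
  // restates the identity introduced above as a sanity check
  assert(bc.num_inference_requests() ==
         bc.num_active_requests() - bc.num_finetuning_fwd_requests() -
             bc.num_finetuning_bwd_requests());
}
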
6 changes: 4 additions & 2 deletions src/runtime/request_manager.cu
@@ -118,7 +118,7 @@ void prepare_inference_params_kernel_h(

// q_indptr: first token offset in batch, plus one token at the end
// representing the total number of tokens in batch
q_indptr_h.push_back(
q_indptr_h.push_back(q_indptr_h.back() +
batch_config->requestsInfo[req_idx].num_tokens_in_batch);

// kv_indptr: starting index of KV cache pages for each request in logical
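
The change above makes q_indptr_h accumulate a running offset instead of storing raw per-request token counts, which appears to be what the commit title refers to. A minimal, self-contained sketch of the corrected construction, with build_q_indptr as a hypothetical stand-in for the per-request loop in prepare_inference_params_kernel_h:

#include <cstdint>
#include <vector>

// q_indptr_h[i] is the offset of request i's first token in the batch;
// q_indptr_h.back() is the total number of tokens in the batch.
std::vector<int32_t> build_q_indptr(std::vector<int> const &tokens_per_request) {
  std::vector<int32_t> q_indptr_h{0};
  for (int n : tokens_per_request) {
    // accumulate so offsets stay correct when several requests share one batch
    q_indptr_h.push_back(q_indptr_h.back() + n);
  }
  return q_indptr_h;
}

// e.g. token counts {5, 7, 3} -> q_indptr = {0, 5, 12, 15}.
// Without the "q_indptr_h.back() +" term the result would be {0, 5, 7, 3},
// which only happens to be valid for a single-request batch.
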
@@ -286,9 +286,11 @@ void RequestManager::load_batch_config_task(
BatchPrefillHandler *handler = static_cast<BatchPrefillHandler *>(
handle.incr_attention_metadata->prompt_handler_collections[batch_size]);
handler->SetCUDAStream(stream);
// static int step=0;
PageManager *pm = PageManager::get_page_manager();
// printf("BatchPrefillHandler %p\n", handler);
// std::cout << *pm << std::endl;
// std::cout << "STEP " << step << ": " << *pm << std::endl;
// step+=1;
// std::cout << "batch_config: " << *batch_config << std::endl;
// std::cout << "q_indptr_h: ";
// for (int i = 0; i < q_indptr_h.size(); i++) {
