Paged attention new #82

Merged
merged 62 commits on Jan 31, 2025
Commits (62)
71d8a7b
add page_manager and request_manager functions
Bob-Chen222 Oct 3, 2024
0eaca39
add batch_config
Bob-Chen222 Oct 3, 2024
151872f
request manager h and request manger cc to be continued
Bob-Chen222 Oct 8, 2024
904364d
Merge remote-tracking branch 'origin/specscheduler' into paged_attent…
chenzhuofu Oct 8, 2024
e3abef8
refactored the interface of block manager but may not be bug free
chenzhuofu Oct 8, 2024
73dc699
ckpt before build
chenzhuofu Oct 10, 2024
de0b803
some fix
Bob-Chen222 Oct 10, 2024
0e405c1
ready for sanity check
Bob-Chen222 Oct 10, 2024
dec2266
Merge remote-tracking branch 'origin/specscheduler' into paged_attent…
Bob-Chen222 Oct 10, 2024
6b4777e
fix last commit index
Bob-Chen222 Oct 10, 2024
8394f15
fix request id error
Bob-Chen222 Oct 11, 2024
2ec8b5b
fix spec token num
chenzhuofu Oct 11, 2024
b12df8c
fix small error in free_multiple_blocks
chenzhuofu Oct 11, 2024
6298f2a
ckpt single request
Bob-Chen222 Oct 11, 2024
c00ddec
add cleanup
Bob-Chen222 Oct 11, 2024
b1ff323
ckpt before index error in prepare_parameters
Bob-Chen222 Oct 11, 2024
8a3975a
fix token error in prepare_batch_config
Bob-Chen222 Oct 11, 2024
f4e73ea
ckpt, something wrong in the prefilling
Bob-Chen222 Oct 11, 2024
4eeb021
ckpt
Bob-Chen222 Oct 12, 2024
12fafa3
Merge remote-tracking branch 'origin/specscheduler' into paged_attent…
Bob-Chen222 Oct 12, 2024
3ad0ca5
update
Bob-Chen222 Oct 12, 2024
945dee9
Merge remote-tracking branch 'origin/specscheduler' into paged_attent…
Bob-Chen222 Oct 20, 2024
19e41d6
add some docuementation and delete print
Bob-Chen222 Oct 21, 2024
b1793fb
add additional flag max-kv-cache-size
Bob-Chen222 Oct 21, 2024
fd65a90
Merge remote-tracking branch 'origin/specscheduler' into paged_attent…
Bob-Chen222 Nov 4, 2024
832f5cb
fix for merge
Bob-Chen222 Nov 4, 2024
4a7162f
init page manager at request manager init and clean the format
Bob-Chen222 Nov 4, 2024
6b74f93
ckpt
Bob-Chen222 Nov 5, 2024
20cb714
refactor and add kv cache flag via page manager
Bob-Chen222 Nov 5, 2024
311c450
ckpt for performance issue
Bob-Chen222 Nov 5, 2024
a493f2a
first attempt in incr decoding with page attention
Bob-Chen222 Nov 5, 2024
5250a3b
ckpt for nothing
Bob-Chen222 Nov 6, 2024
810983e
fix compilation error
Bob-Chen222 Nov 7, 2024
f7656be
all good for spec, now test incr
Bob-Chen222 Nov 7, 2024
8c203ec
typo
Bob-Chen222 Nov 7, 2024
3c158f8
workable incrdecoding!
Bob-Chen222 Nov 7, 2024
3b34a5b
Merge remote-tracking branch 'origin/specscheduler' into paged_attent…
Bob-Chen222 Nov 7, 2024
7d612f7
refactor
Bob-Chen222 Nov 8, 2024
07ec33e
some format
Bob-Chen222 Nov 8, 2024
dad3d0f
Update request_manager.h
Bob-Chen222 Nov 8, 2024
1693455
Update llama.cc
Bob-Chen222 Nov 8, 2024
a17c130
Update spec_infer.cc
Bob-Chen222 Nov 8, 2024
0f16daf
Update trace_generator.cc
Bob-Chen222 Nov 8, 2024
ff7de09
Update tree_inc_multihead_self_attention.cu
Bob-Chen222 Nov 8, 2024
e3815a9
Update tree_inc_multihead_self_attention.cu
Bob-Chen222 Nov 8, 2024
38f6ef8
Update tree_inc_multihead_self_attention.cu
Bob-Chen222 Nov 8, 2024
80ea225
Update page_manager.cc
Bob-Chen222 Nov 8, 2024
5fe3a8a
Update request_manager.cc
Bob-Chen222 Nov 8, 2024
a721926
Update request_manager.cc
Bob-Chen222 Nov 8, 2024
1e7e2ec
Update request_manager.cc
Bob-Chen222 Nov 8, 2024
1792981
Update request_manager.cc
Bob-Chen222 Nov 8, 2024
95023e6
final update
Bob-Chen222 Nov 8, 2024
bc67e97
:Merge branch 'specscheduler' of https://github.com/flexflow/flexflow…
chenzhuofu Jan 24, 2025
a5b7de6
fix: minor
goliaro Jan 26, 2025
76c23c0
feat: merge misc. from `page_attention_new`
chenzhuofu Jan 26, 2025
9c042f5
fix: merge page_manager, also fix some issues
chenzhuofu Jan 26, 2025
2a751fd
style: format code
chenzhuofu Jan 26, 2025
3ed67e4
fix: minor
chenzhuofu Jan 26, 2025
69b9f72
fix: merge page_manager, also fix some issues
chenzhuofu Jan 26, 2025
e0eca51
style: format code
chenzhuofu Jan 26, 2025
0f13a92
Merge branch 'paged_attention_new' of https://github.com/flexflow/fle…
chenzhuofu Jan 26, 2025
e2d6fc6
chore: remove outdated comments
chenzhuofu Jan 31, 2025
11 changes: 10 additions & 1 deletion include/flexflow/batch_config.h
@@ -77,6 +77,7 @@ class BatchConfig {
static int max_spec_tree_token_num();
static int max_sequence_length();
static int max_output_length();
static size_t max_kv_cache_size();
static bool streaming_cache();
static int get_max_tree_depth();
friend std::ostream &operator<<(std::ostream &os, BatchConfig const &bc);
@@ -113,7 +114,15 @@ class BatchConfig {
int first_token_index_in_request = -1;
int first_token_offset_in_batch = -1;
int num_tokens_in_batch = 0;
int padding = 0; // Padding for memory pointer alignment
RequestGuid request_guid;

static constexpr size_t request_guid_size = sizeof(RequestGuid);
static constexpr size_t alignment = 16;
static constexpr size_t padding_size =
(alignment - (sizeof(int) * 3 + request_guid_size) % alignment) %
alignment;
static constexpr size_t padding_length = padding_size / sizeof(int);
int padding[padding_length] = {}; // Padding for memory pointer alignment
};

struct PerTokenInfo {
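For reference, a minimal standalone sketch of what the new compile-time padding computation works out to, assuming RequestGuid is an 8-byte integer (its actual definition is not shown in this hunk): three 4-byte ints plus an 8-byte guid occupy 20 bytes, so 12 bytes of padding (three ints) round the struct body up to the next 16-byte boundary.

#include <cstddef>
#include <cstdint>

// Assumption for this sketch only; the real RequestGuid is defined elsewhere.
using RequestGuid = std::uint64_t;

constexpr std::size_t alignment = 16;
constexpr std::size_t payload = sizeof(int) * 3 + sizeof(RequestGuid);              // 12 + 8 = 20 bytes
constexpr std::size_t padding_size = (alignment - payload % alignment) % alignment; // 12 bytes
constexpr std::size_t padding_length = padding_size / sizeof(int);                  // 3 ints of padding

static_assert((payload + padding_size) % alignment == 0,
              "padded struct body is a multiple of 16 bytes");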
3 changes: 3 additions & 0 deletions include/flexflow/flexflow_c.h
@@ -1025,6 +1025,9 @@ void flexflow_request_manager_set_max_sequence_length(
void flexflow_request_manager_set_max_output_length(
flexflow_request_manager_t handle_, int max_output_length);

void flexflow_request_manager_set_max_kv_cache_size(
flexflow_request_manager_t handle_, int max_kv_cache_size);

void flexflow_request_manager_register_tokenizer(
flexflow_request_manager_t handle_,
enum ModelType model_type,
8 changes: 8 additions & 0 deletions include/flexflow/model.h
@@ -1082,6 +1082,10 @@ class FFModel {
CompMode comp_mode = COMP_MODE_TRAINING);
void compile_inference();
void set_transformer_layer_id(int id);
void set_num_transformer_layers(int num_layers);
void set_num_kv_heads(int num_heads);
void set_qkv_dim(int qkv_dim);
void set_size_dt(int size_dt);
void set_position_offset(int offset);
void graph_optimize(size_t budget,
bool only_data_parallel,
@@ -1143,6 +1147,10 @@ class FFModel {
size_t tensor_global_guid, parallel_tensor_global_guid, node_global_guid;
size_t current_transformer_layer_id;
// positional embedding start offset
int num_transformer_layers;
int num_kv_heads;
int qkv_dim;
int size_dt;
int position_offset;
FFConfig config;
FFIterationConfig iter_config;
@@ -15,6 +15,29 @@ namespace Kernels {
namespace IncMultiHeadAttention {

// kv layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
__device__ __forceinline__ size_t
get_k_entry_offset_verify(int const token_idx,
int const page_idx,
int const num_heads,
int const head_dim) {
size_t index = ((page_idx)*kPagesize * 2 + (token_idx % kPagesize)) *
head_dim * num_heads;
return index;
}

// kv layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
__device__ __forceinline__ size_t
get_v_entry_offset_verify(int const token_idx,
int const page_idx,
int const num_heads,
int const head_dim) {
size_t index =
((page_idx)*kPagesize * 2 + kPagesize + (token_idx % kPagesize)) *
head_dim * num_heads;
return index;
}

// // kv layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
__device__ __forceinline__ size_t get_k_entry_offset(int const req_idx,
int const token_idx,
int const max_num_pages,
@@ -89,6 +112,12 @@ void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream);

template <typename DT>
void update_qkv_in_batch_paged(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream,
bool is_spec);

// [For the tokens in streaming cache]
// Convert the out-of-order cache to in-order relative position.
// Source: pre-pos-encoding kv values in the streaming cache.
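The two offset helpers above encode the [num_pages, 2, page_size, num_kv_heads, head_dim] layout: K occupies the first kPagesize token slots of every page and V the second, so the V offset adds one extra kPagesize stride. A small host-side sketch of the same arithmetic (kPagesize and the sample head counts are assumptions, not values from this PR):

#include <cstddef>
#include <cstdio>

constexpr int kPagesize = 64; // assumed tokens per page

// Offset (in elements) of token `token_idx`'s K entry inside page `page_idx`.
std::size_t k_entry_offset(int token_idx, int page_idx, int num_heads, int head_dim) {
  return (static_cast<std::size_t>(page_idx) * kPagesize * 2 + token_idx % kPagesize) *
         head_dim * num_heads;
}

// Same for the V entry: skip the K half of the page first.
std::size_t v_entry_offset(int token_idx, int page_idx, int num_heads, int head_dim) {
  return (static_cast<std::size_t>(page_idx) * kPagesize * 2 + kPagesize +
          token_idx % kPagesize) *
         head_dim * num_heads;
}

int main() {
  int num_heads = 8, head_dim = 128; // example sizes
  int token_idx = 70;                // lands in page 1, slot 6
  int page_idx = token_idx / kPagesize;
  std::printf("K offset: %zu, V offset: %zu\n",
              k_entry_offset(token_idx, page_idx, num_heads, head_dim),
              v_entry_offset(token_idx, page_idx, num_heads, head_dim));
  return 0;
}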
162 changes: 162 additions & 0 deletions include/flexflow/page_manager.h
@@ -0,0 +1,162 @@
#pragma once

#include "flexflow/batch_config.h"
#include "flexflow/config.h"
#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/utils/file_loader.h"
#include <deque>
#include <future>
#include <mutex>
#include <tokenizers_cpp.h>

namespace FlexFlow {

using TokenId = BatchConfig::TokenId;

/**
* @class LogicalTokenBlock
* @brief A class to represent a sequence of tokens for each request
*/
class LogicalTokenBlock {
public:
using TokenId = BatchConfig::TokenId;

// Constructor
LogicalTokenBlock(int block_number, uint32_t block_size);

// Method to check if the block is empty
bool is_empty() const;

// Method to check if the block is full
bool is_full() const;

// Method to get the number of empty slots
int get_num_empty_slots() const;

// Method to get the number of allocated slots
int get_num_alloc_slots() const;

// Used to clean up the spec tokens in a block since these spec tokens may not
// be committed after use
void reset_num_spec_tokens();

// Method to append tokens
void append_tokens(std::vector<TokenId> const &token_ids_to_append,
bool committed);

int get_num_tokens() const {
return num_tokens;
}
int get_num_commit_tokens() const {
return num_commit_tokens;
}
int get_num_spec_tokens() const {
return num_spec_tokens;
}

std::vector<TokenId> get_token_ids() const;

private:
int block_number; // the index of the logical token block
int block_size; // the size of the block
int num_tokens; // the number of tokens currently stored in the block
int num_commit_tokens; // the number of tokens inside this block that are
// already committed
int num_spec_tokens; // the number of tokens inside this block that are
// speculative tokens, which is stored temporarily
std::vector<TokenId> token_ids; // store the token ids in an order that
// corresponds to the inference sequence
};

/**
* @class PhysicalTokenBlock
* @brief A class to represent a physical block of tokens, similar to a
* physical memory address. It keeps track of the location of the tokens
* stored in GPU memory
*/
class PhysicalTokenBlock {
public:
// Constructor
PhysicalTokenBlock(int block_number, int block_size);

// Method to get the block number
int get_block_number() const {
return block_number;
}
void incr_ref_count() {
ref_count++;
}
void decr_ref_count() {
ref_count--;
}
int ref_count; // reference count, TODO: move to private

private:
int block_number; // the index of the physical token block
int block_size; // the size of the block
};

/**
* @class BlockAllocator
* @brief A Block Manager that is responsible for maintaining a pool of free
* blocks
*/
class BlockAllocator {
public:
// Constructor
BlockAllocator(int block_size, int num_total_blocks);

// Allocate a block
PhysicalTokenBlock allocate();

// Free a block
void free(PhysicalTokenBlock &block);

// Get the number of free blocks
int get_num_free_blocks() const;

private:
int block_size;
size_t num_total_blocks;
std::deque<PhysicalTokenBlock> free_blocks;
};

/**
 * @class PageManager
 * @brief A wrapper class that manages the kv cache allocation status.
 * Note that all layers of the model share the same page manager because
 * the kv cache positions are the same across layers.
*/
class PageManager {
public:
// Get the singleton instance of the PageManager as it will be shared in
// multiple places
static PageManager *get_page_manager();
static PageManager *get_page_manager(FFModel *ff, size_t kv_cache_size);
size_t get_kv_cache_size_per_layer();
using BlockTable = std::vector<PhysicalTokenBlock>;
using RequestGuid = BatchConfig::RequestGuid;
PageManager(int block_size, size_t num_total_blocks);
int allocate_one_block(RequestGuid const &request_guid);
void free_request(RequestGuid const &request_guid);
// used when we want to free the last num_blocks blocks that store spec
// tokens (which are the tokens that are not yet committed)
void free_multiple_blocks(RequestGuid const &request_guid, int num_blocks);
std::vector<int>
get_block_table_indices(RequestGuid const &request_guid) const;

void free_block_table(BlockTable &block_table);

private:
size_t kv_cache_size_per_layer;
int block_size; // the size of the block
int num_total_blocks; // the total number of blocks
BlockAllocator block_allocator;
std::unordered_map<RequestGuid, BlockTable> block_tables;

int get_num_total_free_blocks() const;
int get_num_allocated_blocks(RequestGuid const &request_guid) const;
};

}; // namespace FlexFlow
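To make the bookkeeping above concrete, here is a minimal standalone sketch (not the PR's code; the block size, guid, and token counts are made up) of the mapping PageManager maintains: each request grows fixed-size logical blocks, every logical block is backed by one physical block drawn from a shared free pool, and the block table records the physical block indices used by the attention kernels.

#include <cstdio>
#include <deque>
#include <unordered_map>
#include <vector>

constexpr int kBlockSize = 64; // tokens per block (assumed)

int main() {
  // Shared pool of physical block ids, as a BlockAllocator would hold.
  std::deque<int> free_blocks;
  for (int i = 0; i < 8; i++) {
    free_blocks.push_back(i);
  }
  // Per-request block table: request guid -> physical block ids.
  std::unordered_map<int, std::vector<int>> block_tables;

  int guid = 42;        // made-up request guid
  int num_tokens = 150; // made-up prompt length
  for (int t = 0; t < num_tokens; t++) {
    if (t % kBlockSize == 0) { // previous block is full, grab a new one
      block_tables[guid].push_back(free_blocks.front());
      free_blocks.pop_front();
    }
  }
  // 150 tokens with 64-token blocks -> 3 physical blocks.
  std::printf("blocks used: %zu, free blocks left: %zu\n",
              block_tables[guid].size(), free_blocks.size());

  // free_request(): return every block of this request to the pool.
  for (int b : block_tables[guid]) {
    free_blocks.push_back(b);
  }
  block_tables.erase(guid);
  return 0;
}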
20 changes: 20 additions & 0 deletions include/flexflow/request_manager.h
100644 → 100755
@@ -18,6 +18,7 @@
#include "flexflow/batch_config.h"
#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/page_manager.h"
#include "flexflow/utils/file_loader.h"
#include <condition_variable>
#include <future>
@@ -148,6 +149,12 @@ struct Request {
Status status = PENDING;
std::vector<BatchConfig::TokenId> tokens;

// page attention, page_last_committed should be -1 because there are no
// blocks at the beginning
int page_last_committed = -1;
std::vector<LogicalTokenBlock> blocks;

// TokenTree speculative_token_tree;
std::vector<TokenTree> speculative_token_trees;
// To make request manager stateful, we need to store the causal mask here
BatchConfig::BitMask causal_mask;
@@ -316,6 +323,8 @@ class RequestManager {
int get_max_spec_tree_token_num();
void set_max_sequence_length(int max_seq_length);
int get_max_sequence_length();
void set_max_kv_cache_size(size_t max_kv_cache_size);
size_t get_max_kv_cache_size();
void set_max_output_length(int max_output_length);
int get_max_output_length();
void set_decoding_mode(DecodingMode mode);
@@ -449,6 +458,7 @@ class RequestManager {
int max_spec_tree_token_num;
int max_sequence_length;
int max_output_length;
size_t max_kv_cache_size;
int max_tree_depth;
int max_tree_width;
int k;
@@ -573,6 +583,16 @@ class RequestManager {
void init_bitmask_spec(RequestGuid guid);
BatchConfig::BitMask create_llm_bitmask(RequestGuid guid);

// Page Attention related
int get_num_blocks_allocated(Request &request) const;
int get_len_last_block(Request &request) const;
int get_idx_last_logical_token(Request &request) const;
int idx_logical_to_physical(Request &request, int idx_logical);
void _append_block_to_request(Request &request, bool is_commit);
int append_token_to_block(Request &request, TokenId token, bool is_commit);
void reset_block_table(Request &request);
void print_num_tokens(Request &request);

// Token tree related
void init_token_tree(RequestGuid guid);
void add_root_to_spec_token_tree(RequestGuid guid,
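The page-attention helpers added here are private bookkeeping, but the core step is translating a token's logical index inside a request into a slot in the paged kv cache. A hedged sketch of what a routine like idx_logical_to_physical() typically computes (assumed behavior, not copied from this PR): the physical block comes from the request's block table and the within-block offset is preserved.

#include <vector>

// block_table_indices: physical block id for each logical block of the request
// (as returned by PageManager::get_block_table_indices in the header above).
int logical_to_physical(std::vector<int> const &block_table_indices,
                        int idx_logical,
                        int block_size) {
  int physical_block = block_table_indices[idx_logical / block_size];
  int offset_in_block = idx_logical % block_size;
  return physical_block * block_size + offset_in_block;
}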
8 changes: 8 additions & 0 deletions inference/incr_decoding/incr_decoding.cc
@@ -51,6 +51,7 @@ void parse_input_args(char **argv,
int &max_tokens_per_prefilling_batch,
int &max_sequence_length,
int &max_output_length,
size_t &max_kv_cache_size,
int &sampling_seed,
bool &streaming_cache,
bool &slo_attainment_early_termination,
@@ -138,6 +139,10 @@ void parse_input_args(char **argv,
max_output_length = std::stoi(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--max-kv-cache-size")) {
max_kv_cache_size = std::stoi(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--sampling-seed")) {
sampling_seed = std::stoi(argv[++i]);
continue;
@@ -224,6 +229,7 @@ void FlexFlow::top_level_task(Task const *task,
int max_tokens_per_prefilling_batch = -1;
int max_sequence_length = 256;
int max_output_length = 512;
size_t max_kv_cache_size = 0; // if 0, then use the default value
RequestManager::DecodingMode decoding_mode =
RequestManager::INCREMENTAL_DECODING;
int sampling_seed = 0;
@@ -258,6 +264,7 @@ void FlexFlow::top_level_task(Task const *task,
max_tokens_per_prefilling_batch,
max_sequence_length,
max_output_length,
max_kv_cache_size,
sampling_seed,
streaming_cache,
slo_attainment_early_termination,
@@ -356,6 +363,7 @@ void FlexFlow::top_level_task(Task const *task,
rm->set_max_tokens_per_prefilling_batch(max_tokens_per_prefilling_batch);
rm->set_max_sequence_length(max_sequence_length);
rm->set_max_output_length(max_output_length);
rm->set_max_kv_cache_size(max_kv_cache_size);
rm->set_decoding_mode(decoding_mode);
rm->set_slo_violation_early_termination(slo_attainment_early_termination);
rm->set_baseline_latency(baseline_latency_ms);
5 changes: 5 additions & 0 deletions inference/models/falcon.cc
@@ -63,6 +63,11 @@ void FALCON::create_falcon_model(FFModel &ff,
Tensor mha = nullptr, mlp_output = nullptr;
Tensor res_ln_outputs[2] = {nullptr, nullptr};

ff.set_num_transformer_layers(falcon_config.n_layer);
ff.set_num_kv_heads(falcon_config.n_head_kv);
ff.set_qkv_dim(falcon_config.hidden_size / falcon_config.n_head * 2);
ff.set_size_dt(data_type_size(input->data_type));

for (int i = 0; i < falcon_config.n_layer; i++) {
// set transformer layer id
ff.set_transformer_layer_id(i);
1 change: 1 addition & 0 deletions inference/models/falcon.h
@@ -16,6 +16,7 @@

// #include "file_loader.h"
#include "flexflow/batch_config.h"
#include "flexflow/ffconst_utils.h"
#include "flexflow/inference.h"
#include "flexflow/request_manager.h"
#include <nlohmann/json.hpp>
7 changes: 7 additions & 0 deletions inference/models/llama.cc
@@ -64,6 +64,13 @@ void LLAMA::create_llama_model(FFModel &ff,

Tensor w2 = nullptr;

// metadata that needs to be sent to page manager in order to calculate the kv
// cache per layer
ff.set_num_transformer_layers(llama_config.num_hidden_layers);
ff.set_num_kv_heads(llama_config.num_key_value_heads);
int qkv_dim = llama_config.hidden_size / llama_config.num_attention_heads * 2;
ff.set_qkv_dim(qkv_dim);
ff.set_size_dt(data_type_size(input->data_type));
for (int i = 0; i < llama_config.num_hidden_layers; i++) {
// set transformer layer id
ff.set_transformer_layer_id(i);
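A back-of-the-envelope sketch (my assumptions, not the PR's exact formula) of why these four pieces of metadata are enough for the page manager to size the kv cache: bytes per token per layer come out to roughly num_kv_heads * qkv_dim * size_dt, where qkv_dim already carries the factor of two for K and V (head_dim * 2 above).

#include <cstddef>
#include <cstdio>

int main() {
  // Example numbers in the style of a LLaMA-7B-like config (illustrative only).
  int num_hidden_layers = 32;
  int num_key_value_heads = 32;
  int hidden_size = 4096;
  int num_attention_heads = 32;
  int size_dt = 2; // bytes per element, e.g. fp16

  int qkv_dim = hidden_size / num_attention_heads * 2; // 128 * 2 = 256
  std::size_t per_token_per_layer =
      static_cast<std::size_t>(num_key_value_heads) * qkv_dim * size_dt; // 16 KiB
  std::size_t per_token_all_layers = per_token_per_layer * num_hidden_layers; // 512 KiB

  std::printf("KV cache per token: %zu bytes per layer, %zu bytes total\n",
              per_token_per_layer, per_token_all_layers);
  return 0;
}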