Merge pull request #82 from flexflow/paged_attention_new
Paged attention new
chenzhuofu authored Jan 31, 2025
2 parents a21f9fb + e2d6fc6 commit 918356d
Showing 61 changed files with 1,145 additions and 295 deletions.
11 changes: 10 additions & 1 deletion include/flexflow/batch_config.h
@@ -77,6 +77,7 @@ class BatchConfig {
static int max_spec_tree_token_num();
static int max_sequence_length();
static int max_output_length();
static size_t max_kv_cache_size();
static bool streaming_cache();
static int get_max_tree_depth();
friend std::ostream &operator<<(std::ostream &os, BatchConfig const &bc);
@@ -113,7 +114,15 @@ class BatchConfig {
int first_token_index_in_request = -1;
int first_token_offset_in_batch = -1;
int num_tokens_in_batch = 0;
int padding = 0; // Padding for memory pointer alignment
RequestGuid request_guid;

static constexpr size_t request_guid_size = sizeof(RequestGuid);
static constexpr size_t alignment = 16;
static constexpr size_t padding_size =
(alignment - (sizeof(int) * 3 + request_guid_size) % alignment) %
alignment;
static constexpr size_t padding_length = padding_size / sizeof(int);
int padding[padding_length] = {}; // Padding for memory pointer alignment
};
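As a quick illustration of the padding math above (assuming a 4-byte int and an 8-byte RequestGuid, neither of which is stated in this hunk), the three int fields plus the guid occupy 20 bytes, so 12 bytes of padding round the struct up to the next 16-byte boundary:

// Hedged sketch: checks the arithmetic under the assumption that
// sizeof(int) == 4 and sizeof(RequestGuid) == 8.
// payload        = 3 * 4 + 8 = 20 bytes
// padding_size   = (16 - 20 % 16) % 16 = 12 bytes
// padding_length = 12 / 4 = 3 ints, so the struct rounds up to 32 bytes
static_assert((16 - (3 * 4 + 8) % 16) % 16 == 12, "12 bytes of padding");
static_assert(((3 * 4 + 8) + 12) % 16 == 0, "struct size is a multiple of 16");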

struct PerTokenInfo {
3 changes: 3 additions & 0 deletions include/flexflow/flexflow_c.h
@@ -1025,6 +1025,9 @@ void flexflow_request_manager_set_max_sequence_length(
void flexflow_request_manager_set_max_output_length(
flexflow_request_manager_t handle_, int max_output_length);

void flexflow_request_manager_set_max_kv_cache_size(
flexflow_request_manager_t handle_, int max_kv_cache_size);

void flexflow_request_manager_register_tokenizer(
flexflow_request_manager_t handle_,
enum ModelType model_type,
8 changes: 8 additions & 0 deletions include/flexflow/model.h
@@ -1082,6 +1082,10 @@ class FFModel {
CompMode comp_mode = COMP_MODE_TRAINING);
void compile_inference();
void set_transformer_layer_id(int id);
void set_num_transformer_layers(int num_layers);
void set_num_kv_heads(int num_heads);
void set_qkv_dim(int qkv_dim);
void set_size_dt(int size_dt);
void set_position_offset(int offset);
void graph_optimize(size_t budget,
bool only_data_parallel,
@@ -1143,6 +1147,10 @@ class FFModel {
size_t tensor_global_guid, parallel_tensor_global_guid, node_global_guid;
size_t current_transformer_layer_id;
// positional embedding start offset
int num_transformer_layers;
int num_kv_heads;
int qkv_dim;
int size_dt;
int position_offset;
FFConfig config;
FFIterationConfig iter_config;
@@ -15,6 +15,29 @@ namespace Kernels {
namespace IncMultiHeadAttention {

// kv layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
__device__ __forceinline__ size_t
get_k_entry_offset_verify(int const token_idx,
int const page_idx,
int const num_heads,
int const head_dim) {
size_t index = ((page_idx)*kPagesize * 2 + (token_idx % kPagesize)) *
head_dim * num_heads;
return index;
}

// kv layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
__device__ __forceinline__ size_t
get_v_entry_offset_verify(int const token_idx,
int const page_idx,
int const num_heads,
int const head_dim) {
size_t index =
((page_idx)*kPagesize * 2 + kPagesize + (token_idx % kPagesize)) *
head_dim * num_heads;
return index;
}
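To make the layout comment concrete, the following host-side sketch reproduces the same offset arithmetic (kPagesize is assumed to be 64 here purely for illustration; the K half of each page precedes its V half, hence the factor of 2):

#include <cstddef>

constexpr int kPagesizeSketch = 64; // assumed page size, not taken from this diff

// Offset (in elements) of token `token_idx` on page `page_idx` in the K half.
constexpr size_t k_offset_sketch(int token_idx, int page_idx,
                                 int num_kv_heads, int head_dim) {
  return (static_cast<size_t>(page_idx) * kPagesizeSketch * 2 +
          token_idx % kPagesizeSketch) *
         head_dim * num_kv_heads;
}

// Offset of the same token in the V half, which starts kPagesize rows later.
constexpr size_t v_offset_sketch(int token_idx, int page_idx,
                                 int num_kv_heads, int head_dim) {
  return (static_cast<size_t>(page_idx) * kPagesizeSketch * 2 + kPagesizeSketch +
          token_idx % kPagesizeSketch) *
         head_dim * num_kv_heads;
}

// Example: token 70 lives on page 1 at slot 6; with 8 kv heads of dim 128 its
// K entry starts (1*64*2 + 6) * 128 * 8 = 137216 elements into the pool.
static_assert(k_offset_sketch(70, 1, 8, 128) == 137216, "example K offset");
static_assert(v_offset_sketch(70, 1, 8, 128) == 137216 + 64ull * 128 * 8,
              "V half starts one page worth of rows after the K half");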

// kv layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
__device__ __forceinline__ size_t get_k_entry_offset(int const req_idx,
int const token_idx,
int const max_num_pages,
@@ -89,6 +112,12 @@ void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream);

template <typename DT>
void update_qkv_in_batch_paged(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream,
bool is_spec);

// [For the tokens in streaming cache]
// Convert the out-of-order cache to in-order relative position.
// Source: pre-pos-encoding kv values in the streaming cache.
162 changes: 162 additions & 0 deletions include/flexflow/page_manager.h
@@ -0,0 +1,162 @@
#pragma once

#include "flexflow/batch_config.h"
#include "flexflow/config.h"
#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/utils/file_loader.h"
#include <deque>
#include <future>
#include <mutex>
#include <tokenizers_cpp.h>

namespace FlexFlow {

using TokenId = BatchConfig::TokenId;

/**
* @class LogicalTokenBlock
* @brief A class to represent a sequence of tokens for each request
*/
class LogicalTokenBlock {
public:
using TokenId = BatchConfig::TokenId;

// Constructor
LogicalTokenBlock(int block_number, uint32_t block_size);

// Method to check if the block is empty
bool is_empty() const;

// Method to check if the block is full
bool is_full() const;

// Method to get the number of empty slots
int get_num_empty_slots() const;

// Method to get the number of allocated slots
int get_num_alloc_slots() const;

// Used to clean up the spec tokens in a block since these spec tokens may not
// be committed after use
void reset_num_spec_tokens();

// Method to append tokens
void append_tokens(std::vector<TokenId> const &token_ids_to_append,
bool committed);

int get_num_tokens() const {
return num_tokens;
}
int get_num_commit_tokens() const {
return num_commit_tokens;
}
int get_num_spec_tokens() const {
return num_spec_tokens;
}

std::vector<TokenId> get_token_ids() const;

private:
int block_number; // the index of the logical token block
int block_size; // the size of the block
int num_tokens; // the number of tokens currently stored in the block
int num_commit_tokens; // the number of tokens inside this block that are
// already committed
int num_spec_tokens; // the number of tokens inside this block that are
// speculative tokens, which are stored temporarily
std::vector<TokenId> token_ids; // store the token ids in an order that
// corresponds to the inference sequence
};
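A minimal usage sketch of this interface (not taken from the commit; the block size and token ids are made up, and the comments state assumed behavior):

#include <vector>
#include "flexflow/page_manager.h"

void logical_block_sketch() {
  using FlexFlow::LogicalTokenBlock;
  using TokenId = FlexFlow::BatchConfig::TokenId;

  LogicalTokenBlock block(/*block_number=*/0, /*block_size=*/64);

  // Prompt tokens are appended as committed.
  block.append_tokens(std::vector<TokenId>{11, 12, 13}, /*committed=*/true);
  // Draft tokens from the speculation tree are appended as uncommitted.
  block.append_tokens(std::vector<TokenId>{14, 15}, /*committed=*/false);

  // Presumably: get_num_tokens() == 5, get_num_commit_tokens() == 3,
  // get_num_spec_tokens() == 2 at this point.
  // After verification, the speculative tail can be discarded.
  block.reset_num_spec_tokens();
}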

/**
* @class PhysicalTokenBlock
* @brief A class to represent a physical block of tokens, similar to a physical
* memory address. It keeps track of the location of the tokens stored in GPU
* memory
*/
class PhysicalTokenBlock {
public:
// Constructor
PhysicalTokenBlock(int block_number, int block_size);

// Method to get the block number
int get_block_number() const {
return block_number;
}
void incr_ref_count() {
ref_count++;
}
void decr_ref_count() {
ref_count--;
}
int ref_count; // reference count, TODO: move to private

private:
int block_number; // the index of the physical token block
int block_size; // the size of the block
};

/**
* @class BlockAllocator
* @brief A Block Manager that is responsible for maintaining a pool of free
* blocks
*/
class BlockAllocator {
public:
// Constructor
BlockAllocator(int block_size, int num_total_blocks);

// Allocate a block
PhysicalTokenBlock allocate();

// Free a block
void free(PhysicalTokenBlock &block);

// Get the number of free blocks
int get_num_free_blocks() const;

private:
int block_size;
size_t num_total_blocks;
std::deque<PhysicalTokenBlock> free_blocks;
};
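Only the interface is declared here; one plausible implementation, sketched under the assumption that the allocator recycles blocks through a FIFO free list and that the PhysicalTokenBlock constructor initializes ref_count to zero (the real definitions live in the .cc file and may differ):

#include <deque>
#include <stdexcept>
#include "flexflow/page_manager.h"

namespace sketch {

class BlockAllocatorSketch {
public:
  BlockAllocatorSketch(int block_size, int num_total_blocks) {
    // Pre-populate the free list with every physical block.
    for (int i = 0; i < num_total_blocks; i++) {
      free_blocks.emplace_back(i, block_size);
    }
  }

  FlexFlow::PhysicalTokenBlock allocate() {
    if (free_blocks.empty()) {
      throw std::runtime_error("out of kv-cache blocks");
    }
    FlexFlow::PhysicalTokenBlock block = free_blocks.front();
    free_blocks.pop_front();
    block.incr_ref_count();
    return block;
  }

  void free(FlexFlow::PhysicalTokenBlock &block) {
    block.decr_ref_count();
    if (block.ref_count == 0) {
      free_blocks.push_back(block); // return the block to the pool
    }
  }

  int get_num_free_blocks() const {
    return static_cast<int>(free_blocks.size());
  }

private:
  std::deque<FlexFlow::PhysicalTokenBlock> free_blocks;
};

} // namespace sketch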

/*
* @class PageManager
* @brief A wrapper class that manages the kv cache allocation status.
* Note that all layers of the model share the same page manager, because
* the kv cache positions are the same across layers
*/
class PageManager {
public:
// Get the singleton instance of the PageManager as it will be shared in
// multiple places
static PageManager *get_page_manager();
static PageManager *get_page_manager(FFModel *ff, size_t kv_cache_size);
size_t get_kv_cache_size_per_layer();
using BlockTable = std::vector<PhysicalTokenBlock>;
using RequestGuid = BatchConfig::RequestGuid;
PageManager(int block_size, size_t num_total_blocks);
int allocate_one_block(RequestGuid const &request_guid);
void free_request(RequestGuid const &request_guid);
// used when we want to free the last num_blocks blocks, which store spec
// tokens (tokens that are not yet committed)
void free_multiple_blocks(RequestGuid const &request_guid, int num_blocks);
std::vector<int>
get_block_table_indices(RequestGuid const &request_guid) const;

void free_block_table(BlockTable &block_table);

private:
size_t kv_cache_size_per_layer;
int block_size; // the size of the block
int num_total_blocks; // the total number of blocks
BlockAllocator block_allocator;
std::unordered_map<RequestGuid, BlockTable> block_tables;

int get_num_total_free_blocks() const;
int get_num_allocated_blocks(RequestGuid const &request_guid) const;
};

}; // namespace FlexFlow
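A sketch of how a request might flow through this interface, under assumed paged-attention semantics (the actual call sites are in the request manager and are not part of this diff):

#include <vector>
#include "flexflow/page_manager.h"

void page_manager_lifecycle_sketch(FlexFlow::BatchConfig::RequestGuid guid) {
  using FlexFlow::PageManager;

  // The singleton is shared by all transformer layers.
  PageManager *pm = PageManager::get_page_manager();

  // Whenever the request's last logical block fills up, back it with a new
  // physical block.
  int new_block_index = pm->allocate_one_block(guid);
  (void)new_block_index;

  // The attention kernels index the paged kv cache through the block table.
  std::vector<int> block_table = pm->get_block_table_indices(guid);
  (void)block_table;

  // Roll back blocks that only held speculative tokens which failed
  // verification (one block here, purely as an example).
  pm->free_multiple_blocks(guid, /*num_blocks=*/1);

  // Release everything when the request finishes.
  pm->free_request(guid);
}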
20 changes: 20 additions & 0 deletions include/flexflow/request_manager.h
100644 → 100755
@@ -18,6 +18,7 @@
#include "flexflow/batch_config.h"
#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/page_manager.h"
#include "flexflow/utils/file_loader.h"
#include <condition_variable>
#include <future>
@@ -148,6 +149,12 @@ struct Request {
Status status = PENDING;
std::vector<BatchConfig::TokenId> tokens;

// paged attention: page_last_committed starts at -1 because there are no
// blocks at the beginning
int page_last_committed = -1;
std::vector<LogicalTokenBlock> blocks;

// TokenTree speculative_token_tree;
std::vector<TokenTree> speculative_token_trees;
// To make request manager stateful, we need to store the causal mask here
BatchConfig::BitMask causal_mask;
@@ -316,6 +323,8 @@ class RequestManager {
int get_max_spec_tree_token_num();
void set_max_sequence_length(int max_seq_length);
int get_max_sequence_length();
void set_max_kv_cache_size(size_t max_kv_cache_size);
size_t get_max_kv_cache_size();
void set_max_output_length(int max_output_length);
int get_max_output_length();
void set_decoding_mode(DecodingMode mode);
@@ -449,6 +458,7 @@ class RequestManager {
int max_spec_tree_token_num;
int max_sequence_length;
int max_output_length;
size_t max_kv_cache_size;
int max_tree_depth;
int max_tree_width;
int k;
@@ -573,6 +583,16 @@ class RequestManager {
void init_bitmask_spec(RequestGuid guid);
BatchConfig::BitMask create_llm_bitmask(RequestGuid guid);

// Page Attention related
int get_num_blocks_allocated(Request &request) const;
int get_len_last_block(Request &request) const;
int get_idx_last_logical_token(Request &request) const;
int idx_logical_to_physical(Request &request, int idx_logical);
void _append_block_to_request(Request &request, bool is_commit);
int append_token_to_block(Request &request, TokenId token, bool is_commit);
void reset_block_table(Request &request);
void print_num_tokens(Request &request);

// Token tree related
void init_token_tree(RequestGuid guid);
void add_root_to_spec_token_tree(RequestGuid guid,
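The paged-attention helpers declared in this header (append_token_to_block and friends) suggest the following shape; this is a hypothetical sketch with an assumed block size and return convention, not the implementation from request_manager.cc:

#include "flexflow/request_manager.h"

constexpr int kSketchBlockSize = 64; // assumed page/block size

int append_token_to_block_sketch(FlexFlow::Request &request,
                                 FlexFlow::BatchConfig::TokenId token,
                                 bool is_commit) {
  // Open a new logical block if there is none yet or the last one is full.
  if (request.blocks.empty() || request.blocks.back().is_full()) {
    // A matching physical block would be requested from the PageManager here.
    request.blocks.emplace_back((int)request.blocks.size(), kSketchBlockSize);
  }
  request.blocks.back().append_tokens({token}, /*committed=*/is_commit);

  // Return the logical index of the appended token (an assumed convention).
  return (int)(request.blocks.size() - 1) * kSketchBlockSize +
         request.blocks.back().get_num_tokens() - 1;
}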
8 changes: 8 additions & 0 deletions inference/incr_decoding/incr_decoding.cc
@@ -51,6 +51,7 @@ void parse_input_args(char **argv,
int &max_tokens_per_prefilling_batch,
int &max_sequence_length,
int &max_output_length,
size_t &max_kv_cache_size,
int &sampling_seed,
bool &streaming_cache,
bool &slo_attainment_early_termination,
@@ -138,6 +139,10 @@ void parse_input_args(char **argv,
max_output_length = std::stoi(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--max-kv-cache-size")) {
max_kv_cache_size = std::stoi(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--sampling-seed")) {
sampling_seed = std::stoi(argv[++i]);
continue;
@@ -224,6 +229,7 @@ void FlexFlow::top_level_task(Task const *task,
int max_tokens_per_prefilling_batch = -1;
int max_sequence_length = 256;
int max_output_length = 512;
size_t max_kv_cache_size = 0; // 0 means use the default value
RequestManager::DecodingMode decoding_mode =
RequestManager::INCREMENTAL_DECODING;
int sampling_seed = 0;
@@ -258,6 +264,7 @@ void FlexFlow::top_level_task(Task const *task,
max_tokens_per_prefilling_batch,
max_sequence_length,
max_output_length,
max_kv_cache_size,
sampling_seed,
streaming_cache,
slo_attainment_early_termination,
@@ -356,6 +363,7 @@ void FlexFlow::top_level_task(Task const *task,
rm->set_max_tokens_per_prefilling_batch(max_tokens_per_prefilling_batch);
rm->set_max_sequence_length(max_sequence_length);
rm->set_max_output_length(max_output_length);
rm->set_max_kv_cache_size(max_kv_cache_size);
rm->set_decoding_mode(decoding_mode);
rm->set_slo_violation_early_termination(slo_attainment_early_termination);
rm->set_baseline_latency(baseline_latency_ms);
5 changes: 5 additions & 0 deletions inference/models/falcon.cc
@@ -63,6 +63,11 @@ void FALCON::create_falcon_model(FFModel &ff,
Tensor mha = nullptr, mlp_output = nullptr;
Tensor res_ln_outputs[2] = {nullptr, nullptr};

ff.set_num_transformer_layers(falcon_config.n_layer);
ff.set_num_kv_heads(falcon_config.n_head_kv);
ff.set_qkv_dim(falcon_config.hidden_size / falcon_config.n_head * 2);
ff.set_size_dt(data_type_size(input->data_type));

for (int i = 0; i < falcon_config.n_layer; i++) {
// set transformer layer id
ff.set_transformer_layer_id(i);
1 change: 1 addition & 0 deletions inference/models/falcon.h
@@ -16,6 +16,7 @@

// #include "file_loader.h"
#include "flexflow/batch_config.h"
#include "flexflow/ffconst_utils.h"
#include "flexflow/inference.h"
#include "flexflow/request_manager.h"
#include <nlohmann/json.hpp>
7 changes: 7 additions & 0 deletions inference/models/llama.cc
@@ -64,6 +64,13 @@ void LLAMA::create_llama_model(FFModel &ff,

Tensor w2 = nullptr;

// metadata that needs to be sent to the page manager in order to calculate the
// kv cache size per layer
ff.set_num_transformer_layers(llama_config.num_hidden_layers);
ff.set_num_kv_heads(llama_config.num_key_value_heads);
int qkv_dim = llama_config.hidden_size / llama_config.num_attention_heads * 2;
ff.set_qkv_dim(qkv_dim);
ff.set_size_dt(data_type_size(input->data_type));
for (int i = 0; i < llama_config.num_hidden_layers; i++) {
// set transformer layer id
ff.set_transformer_layer_id(i);
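For a sense of scale, plugging in hypothetical LLaMA-7B-like numbers (hidden_size 4096, 32 attention heads, 32 kv heads, 32 layers, fp16), none of which come from this diff, and assuming the page manager multiplies these four values by the number of cached tokens:

#include <cstddef>

// Hedged sizing sketch; the exact formula lives in the page manager, not here.
constexpr int num_layers = 32;          // llama_config.num_hidden_layers (assumed)
constexpr int num_kv_heads = 32;        // llama_config.num_key_value_heads (assumed)
constexpr int hidden_size = 4096;       // assumed
constexpr int num_attention_heads = 32; // assumed
constexpr int qkv_dim = hidden_size / num_attention_heads * 2; // 128 * 2 (K and V)
constexpr int size_dt = 2;              // bytes per fp16 element

// One token's K+V entries in a single layer: 32 heads * 256 elements * 2 bytes.
constexpr size_t kv_bytes_per_token_per_layer =
    (size_t)num_kv_heads * qkv_dim * size_dt;
static_assert(kv_bytes_per_token_per_layer == 16384, "16 KiB per token per layer");
// Across all 32 layers, one cached token costs 512 KiB of kv cache.
static_assert(kv_bytes_per_token_per_layer * num_layers == 512ull * 1024,
              "512 KiB per token across the whole model");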