-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #82 from flexflow/paged_attention_new
Paged attention new
- Loading branch information
Showing
61 changed files
with
1,145 additions
and
295 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
#pragma once | ||
|
||
#include "flexflow/batch_config.h" | ||
#include "flexflow/config.h" | ||
#include "flexflow/inference.h" | ||
#include "flexflow/model.h" | ||
#include "flexflow/utils/file_loader.h" | ||
#include <deque> | ||
#include <future> | ||
#include <mutex> | ||
#include <tokenizers_cpp.h> | ||
|
||
namespace FlexFlow { | ||
|
||
using TokenId = BatchConfig::TokenId; | ||
|
||
/** | ||
* @class LogicalTokenBlock | ||
* @brief A class to represent a sequence of tokens for each request | ||
*/ | ||
class LogicalTokenBlock { | ||
public: | ||
using TokenId = BatchConfig::TokenId; | ||
|
||
// Constructor | ||
LogicalTokenBlock(int block_number, uint32_t block_size); | ||
|
||
// Method to check if the block is empty | ||
bool is_empty() const; | ||
|
||
// Method to check if the block is full | ||
bool is_full() const; | ||
|
||
// Method to get the number of empty slots | ||
int get_num_empty_slots() const; | ||
|
||
// Method to get the number of allocated slots | ||
int get_num_alloc_slots() const; | ||
|
||
// Used to clean up the spec tokens in a block since these spec tokens may not | ||
// be committed after use | ||
void reset_num_spec_tokens(); | ||
|
||
// Method to append tokens | ||
void append_tokens(std::vector<TokenId> const &token_ids_to_append, | ||
bool committed); | ||
|
||
int get_num_tokens() const { | ||
return num_tokens; | ||
} | ||
int get_num_commit_tokens() const { | ||
return num_commit_tokens; | ||
} | ||
int get_num_spec_tokens() const { | ||
return num_spec_tokens; | ||
} | ||
|
||
std::vector<TokenId> get_token_ids() const; | ||
|
||
private: | ||
int block_number; // the index of the logical token block | ||
int block_size; // the size of the block | ||
int num_tokens; // the number of tokens currently stored in the block | ||
int num_commit_tokens; // the number of tokens inside this block that are | ||
// already committed | ||
int num_spec_tokens; // the number of tokens inside this block that are | ||
// speculative tokens, which is stored temporarily | ||
std::vector<TokenId> token_ids; // store the token ids in a order that | ||
// corresponds to the inference sequence | ||
}; | ||
|
||
/** | ||
* @class PhysicalTokenBlock | ||
* @brief A class to represent a physical block of tokens similar to physical | ||
* memory address It keeps track of the location of the tokens stored on GPU | ||
* memory | ||
*/ | ||
class PhysicalTokenBlock { | ||
public: | ||
// Constructor | ||
PhysicalTokenBlock(int block_number, int block_size); | ||
|
||
// Method to get the block number | ||
int get_block_number() const { | ||
return block_number; | ||
} | ||
void incr_ref_count() { | ||
ref_count++; | ||
} | ||
void decr_ref_count() { | ||
ref_count--; | ||
} | ||
int ref_count; // reference count, TODO: move to private | ||
|
||
private: | ||
int block_number; // the index of the physical token block | ||
int block_size; // the size of the block | ||
}; | ||
|
||
/** | ||
* @class BlockAllocator | ||
* @brief A Block Manager that is reponsible for maintaining a pool of free | ||
* blocks | ||
*/ | ||
class BlockAllocator { | ||
public: | ||
// Constructor | ||
BlockAllocator(int block_size, int num_total_blocks); | ||
|
||
// Allocate a block | ||
PhysicalTokenBlock allocate(); | ||
|
||
// Free a block | ||
void free(PhysicalTokenBlock &block); | ||
|
||
// Get the number of free blocks | ||
int get_num_free_blocks() const; | ||
|
||
private: | ||
int block_size; | ||
size_t num_total_blocks; | ||
std::deque<PhysicalTokenBlock> free_blocks; | ||
}; | ||
|
||
/* | ||
* @class PageManager | ||
* @brief A wrapper class that manages the kv cache allocation status | ||
* notice that all the layers of model will share the same page manager because | ||
* the position of kv cache will be the same | ||
*/ | ||
class PageManager { | ||
public: | ||
// Get the singleton instance of the PageManager as it will be shared in | ||
// multiple places | ||
static PageManager *get_page_manager(); | ||
static PageManager *get_page_manager(FFModel *ff, size_t kv_cache_size); | ||
size_t get_kv_cache_size_per_layer(); | ||
using BlockTable = std::vector<PhysicalTokenBlock>; | ||
using RequestGuid = BatchConfig::RequestGuid; | ||
PageManager(int block_size, size_t num_total_blocks); | ||
int allocate_one_block(RequestGuid const &request_guid); | ||
void free_request(RequestGuid const &request_guid); | ||
// used for the case that we want to free the last num_blocks that stores spec | ||
// tokens(which are the tokens are not yet committed) | ||
void free_multiple_blocks(RequestGuid const &request_guid, int num_blocks); | ||
std::vector<int> | ||
get_block_table_indices(RequestGuid const &request_guid) const; | ||
|
||
void free_block_table(BlockTable &block_table); | ||
|
||
private: | ||
size_t kv_cache_size_per_layer; | ||
int block_size; // the size of the block | ||
int num_total_blocks; // the total number of blocks | ||
BlockAllocator block_allocator; | ||
std::unordered_map<RequestGuid, BlockTable> block_tables; | ||
|
||
int get_num_total_free_blocks() const; | ||
int get_num_allocated_blocks(RequestGuid const &request_guid) const; | ||
}; | ||
|
||
}; // namespace FlexFlow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.