FlexLLM (part 2) (#104)
* init

* update

* hip fixes
goliaro authored Feb 24, 2025
1 parent 91bcf2d commit 2488463
Showing 11 changed files with 2,439 additions and 1,636 deletions.
4 changes: 4 additions & 0 deletions include/flexflow/ops/inc_multihead_self_attention.h
@@ -197,6 +197,10 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta {
// typedef hipFloatComplex attFloatComplex;
hipFloatComplex *complex_input;
#endif
// GQA
void **d_A_array, **d_B_array, **d_C_array;
void **d_A_array2, **d_B_array2, **d_C_array2;
size_t gqa_ptr_array_size;
// PEFT specific fields
void *softmax_activation_buffer;
void *query_activation_buffer;
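The new d_A_array/d_B_array/d_C_array pointer tables and gqa_ptr_array_size support grouped-query attention (GQA), where several query heads share a single key/value head. A minimal Python sketch of the head-group mapping such pointer tables would encode (illustrative only; the helper name is hypothetical, not FlexFlow code):

def kv_head_for_q_head(q_head, num_q_heads, num_kv_heads):
    # Each KV head is shared by a contiguous group of query heads.
    assert num_q_heads % num_kv_heads == 0
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

# Example: 32 query heads sharing 8 KV heads.
print([kv_head_for_q_head(h, 32, 8) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]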
91 changes: 59 additions & 32 deletions include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h
@@ -42,41 +42,68 @@ __global__ void apply_position_bias_qkprd(DT *input_ptr,

#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
template <typename DT>
__global__ void
apply_rotary_embedding(DT *input_ptr,
cuFloatComplex *complex_input,
BatchConfig::PerTokenInfo const *tokenInfos,
int qProjSize,
int kProjSize,
int num_heads,
int num_tokens,
int num_kv_heads,
int q_block_size,
int k_block_size,
int q_array_size,
bool q_tensor);
#elif defined(FF_USE_HIP_ROCM)
void run_batched_matmul(IncMultiHeadSelfAttentionMeta const *meta,
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
void const *alpha,
const DT *A,
cudaDataType Atype,
int lda,
long long int strideA,
const DT *B,
cudaDataType Btype,
int ldb,
long long int strideB,
void const *beta,
DT *C,
cudaDataType Ctype,
int ldc,
long long int strideC,
int batchCount,
cudaDataType computeType,
cublasGemmAlgo_t algo,
cudaStream_t stream,
int batch_ratio_a = 1,
int batch_ratio_b = 1,
int batch_ratio_c = 1,
bool bwd = false);
#else
template <typename DT>
__global__ void
apply_rotary_embedding(DT *input_ptr,
hipFloatComplex *complex_input,
BatchConfig::PerTokenInfo const *tokenInfos,
int qProjSize,
int kProjSize,
int num_heads,
int num_tokens,
int num_kv_heads,
int q_block_size,
int k_block_size,
int q_array_size,
bool q_tensor);
void run_batched_matmul(IncMultiHeadSelfAttentionMeta const *meta,
hipblasHandle_t handle,
hipblasOperation_t transa,
hipblasOperation_t transb,
int m,
int n,
int k,
void const *alpha,
const DT *A,
hipblasDatatype_t Atype,
int lda,
long long int strideA,
const DT *B,
hipblasDatatype_t Btype,
int ldb,
long long int strideB,
void const *beta,
DT *C,
hipblasDatatype_t Ctype,
int ldc,
long long int strideC,
int batchCount,
hipblasDatatype_t computeType,
hipblasGemmAlgo_t algo,
hipStream_t stream,
int batch_ratio_a = 1,
int batch_ratio_b = 1,
int batch_ratio_c = 1,
bool bwd = false);
#endif

template <typename DT>
void pre_build_weight_kernel(IncMultiHeadSelfAttentionMeta const *m,
GenericTensorAccessorR const weight,
DataType data_type,
ffStream_t stream);
} // namespace IncMultiHeadAttention
} // namespace Kernels
} // namespace FlexFlow
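The batch_ratio_a/b/c arguments added to run_batched_matmul appear to let individual operands carry fewer batches than batchCount, which is what grouped-query attention needs when K/V are batched per KV head while Q and the output are batched per query head. A hedged NumPy sketch of the equivalent math under that assumption (illustrative only, not the FlexFlow kernel):

import numpy as np

num_q_heads, num_kv_heads, seq_len, head_dim = 8, 2, 4, 16
group = num_q_heads // num_kv_heads  # query heads per KV head

Q = np.random.randn(num_q_heads, seq_len, head_dim)   # one batch per query head
K = np.random.randn(num_kv_heads, seq_len, head_dim)  # one batch per KV head

# batchCount = num_q_heads; operand K contributes only num_kv_heads distinct
# batches, each reused by `group` consecutive query heads (a batch ratio of
# `group` on the B operand).
scores = np.stack([Q[h] @ K[h // group].T for h in range(num_q_heads)])
print(scores.shape)  # (8, 4, 4)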
9 changes: 6 additions & 3 deletions python/flexflow/serve/models/falcon.py
@@ -143,10 +143,13 @@ def build_model(self, max_tokens_per_batch):
self.falcon_config.layer_norm_epsilon,
name=f"layers.{i}.input_layernorm",
)

assert(self.falcon_config.hidden_size % self.falcon_config.n_head == 0)
head_dim = self.falcon_config.hidden_size // self.falcon_config.n_head

qkv_proj = ffmodel.dense(
att_norm,
3 * self.falcon_config.hidden_size,
head_dim * (self.falcon_config.n_head + 2*self.falcon_config.n_head_kv),
ActiMode.AC_MODE_NONE,
False,
name=f"layers.{i}.self_attention.qkv_proj",
@@ -158,8 +161,8 @@ def build_model(self, max_tokens_per_batch):
self.falcon_config.hidden_size,
self.falcon_config.n_head,
self.falcon_config.n_head_kv,
self.falcon_config.hidden_size // self.falcon_config.n_head,
self.falcon_config.hidden_size // self.falcon_config.n_head,
head_dim,
head_dim,
0.0, # dropout
False, # add_zero_attn
DataType.DT_NONE, # data_type
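The fused QKV projection is now sized as head_dim * (n_head + 2 * n_head_kv) instead of 3 * hidden_size; the two only coincide when n_head_kv == n_head. A worked example using Falcon-7B's published configuration (hidden_size=4544, n_head=71, n_head_kv=1), quoted here purely for illustration:

hidden_size, n_head, n_head_kv = 4544, 71, 1
head_dim = hidden_size // n_head               # 64
qkv_out = head_dim * (n_head + 2 * n_head_kv)  # 64 * 73 = 4672
print(head_dim, qkv_out)                       # the old formula 3 * hidden_size would give 13632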
11 changes: 6 additions & 5 deletions python/flexflow/serve/models/llama.py
@@ -134,9 +134,12 @@ def build_model(self, max_tokens_per_batch):
name=f"layers.{i}.input_layernorm",
)

assert( self.llama_config.hidden_size % self.llama_config.num_attention_heads == 0 )
head_dim = self.llama_config.hidden_size // self.llama_config.num_attention_heads

qkv_proj = ffmodel.dense(
attn_norm,
3 * self.llama_config.hidden_size,
head_dim * (self.llama_config.num_attention_heads + 2 * self.llama_config.num_key_value_heads),
ActiMode.AC_MODE_NONE,
False,
name=f"layers.{i}.self_attn.qkv_proj",
@@ -148,10 +151,8 @@ def build_model(self, max_tokens_per_batch):
self.llama_config.hidden_size,
self.llama_config.num_attention_heads,
self.llama_config.num_key_value_heads,
self.llama_config.hidden_size
// self.llama_config.num_attention_heads,
self.llama_config.hidden_size
// self.llama_config.num_attention_heads,
head_dim,
head_dim,
0.0, # dropout
False, # add_zero_attn
DataType.DT_NONE, # data_type
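The same change applies here: for grouped-query LLaMA variants the fused QKV width shrinks from 3 * hidden_size to head_dim * (num_attention_heads + 2 * num_key_value_heads). A worked example with Llama-2-70B's published configuration (hidden_size=8192, 64 attention heads, 8 KV heads), used only for illustration:

hidden_size, n_heads, n_kv_heads = 8192, 64, 8
head_dim = hidden_size // n_heads               # 128
qkv_out = head_dim * (n_heads + 2 * n_kv_heads) # 128 * 80 = 10240
print(head_dim, qkv_out)                        # vs. 3 * hidden_size = 24576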
(Diffs for the remaining changed files are not shown.)
