Commit 85227c0

fix gqa, comment out flashinfer

goliaro committed Feb 7, 2025
1 parent 87b5403
Showing 18 changed files with 351 additions and 735 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -200,6 +200,9 @@ include(variant)
# optional
include(optional)

# flashinfer
list(APPEND FLEXFLOW_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/deps/flashinfer/include)

if (FF_GPU_BACKEND STREQUAL "cuda")
list(APPEND FF_CC_FLAGS
-DFF_USE_CUDA)
8 changes: 8 additions & 0 deletions include/flexflow/config.h
@@ -17,12 +17,16 @@
#define _FLEXFLOW_CONFIG_H_
#include "ffconst.h"
#include "flexflow/batch_config.h"
#ifdef USE_FLASHINFER
#include "flexflow/attention_config.h"
#include "flexflow/ops/kernels/gemm_impl.h"
#endif
#include "legion.h"
#include <cstring>
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
#ifdef USE_FLASHINFER
#include <cublasLt.h>
#endif
#include <cublas_v2.h>
#include <cudnn.h>
#elif defined(FF_USE_HIP_ROCM)
@@ -92,18 +96,22 @@ struct FFHandler {
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
cudnnHandle_t dnn;
cublasHandle_t blas;
#ifdef USE_FLASHINFER
cublasLtHandle_t blasLt;
// Internal::GemmEngine *gemm_engine;
#endif
#else
miopenHandle_t dnn;
hipblasHandle_t blas;
#endif
void *workSpace;
size_t workSpaceSize;
CombinedBatchConfigMetaStruct *batch_config_metadata;
#ifdef USE_FLASHINFER
AttentionMetaData *incr_attention_metadata;
AttentionMetaData *tree_search_attention_metadata;
AttentionMetaData *tree_verify_attention_metadata;
#endif

// request info + token info + topology mask info
size_t batch_config_metadata_size = sizeof(CombinedBatchConfigMetaStruct);
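
For context, the cublasLt handle and the flashinfer attention metadata above are only compiled in when USE_FLASHINFER is defined. Below is a minimal sketch, not code from this commit, of how the per-device handles might be created; error checking is omitted and the helper name create_handles is illustrative only.

#include <cublas_v2.h>
#include <cudnn.h>
#ifdef USE_FLASHINFER
#include <cublasLt.h>
#endif

// Create the handles stored in FFHandler. The cuBLASLt handle is only
// needed by the FlashInfer / GemmEngine path guarded by USE_FLASHINFER.
void create_handles(cudnnHandle_t *dnn, cublasHandle_t *blas
#ifdef USE_FLASHINFER
                    , cublasLtHandle_t *blasLt
#endif
) {
  cudnnCreate(dnn);   // cuDNN handle shared by DNN operators
  cublasCreate(blas); // cuBLAS handle for dense GEMMs
#ifdef USE_FLASHINFER
  cublasLtCreate(blasLt); // extra handle for cublasLt-based GEMMs
#endif
}
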

@@ -14,6 +14,7 @@ namespace FlexFlow {
namespace Kernels {
namespace IncMultiHeadAttention {

#ifdef USE_FLASHINFER
// kv layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
__device__ __forceinline__ size_t get_k_entry_offset(int const req_idx,
int const token_idx,
@@ -44,6 +45,7 @@ return ((req_idx * max_num_pages + token_idx / kPagesize) * kPagesize +
num_heads *
head_dim;
}
#endif

template <typename DT>
void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m,
@@ -57,12 +59,13 @@ void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m,
ffStream_t stream);

template <typename DT>
void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
int shard_id,
DT *output_ptr,
ffStream_t stream);

#ifdef USE_FLASHINFER
// [For the tokens in batch]
// Update the kv cache, and compact the q array.
// Source: qkv projection array of tokens in the batch.
@@ -79,6 +82,8 @@ void produce_output(IncMultiHeadSelfAttentionMeta const *m,
DT *output_ptr,
ffStream_t stream);

#endif

template <typename DT>
__global__ void apply_position_bias_qkprd(DT *input_ptr,
int num_tokens,
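
The get_k_entry_offset / get_v_entry_offset helpers above index a paged KV cache whose layout is documented as [num_pages, 2, page_size, num_kv_heads, head_dim], where the second dimension selects keys (0) or values (1). The following is a standalone sketch of row-major indexing into that layout; the function and parameter names (kv_element_offset, slot_in_page, ...) are illustrative, not identifiers from this commit.

#include <cstddef>

// Offset, in elements, of one scalar inside a KV pool laid out as
// [num_pages, 2, page_size, num_kv_heads, head_dim].
size_t kv_element_offset(size_t page_idx,     // page in the shared pool
                         int kv,              // 0 = key, 1 = value
                         size_t slot_in_page, // token slot within the page
                         size_t head_idx,     // KV head
                         size_t dim_idx,      // position within the head
                         size_t page_size,
                         size_t num_kv_heads,
                         size_t head_dim) {
  // Walk the five dimensions in row-major order, innermost last.
  return (((page_idx * 2 + kv) * page_size + slot_in_page) * num_kv_heads +
          head_idx) *
             head_dim +
         dim_idx;
}

A token at position t of a request then lands in page (first_page_of_request + t / page_size) at slot (t % page_size), which mirrors the token_idx / kPagesize arithmetic visible in the helpers above.
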
5 changes: 3 additions & 2 deletions inference/python/incr_decoding.py
@@ -116,11 +116,12 @@ def main():
results = llm.generate(prompts, max_length=configs.max_length)
else:
if "max_length" not in configs_dict:
result = llm.generate("Three tips for staying healthy are: ")
results = llm.generate("Three tips for staying healthy are: ")
else:
result = llm.generate(
results = llm.generate(
"Three tips for staying healthy are: ", max_length=configs.max_length
)
print("Final output:", results[0].output_text)

llm.stop_server()

8 changes: 8 additions & 0 deletions python/flexflow/serve/__init__.py
@@ -58,6 +58,7 @@ def init(
benchmarking: Optional[bool] = None,
inference_debugging: Optional[bool] = None,
fusion: Optional[bool] = None,
log_instance_cration: Optional[bool] = None,
):
"""
Configure FlexFlow Serve and start the runtime.
@@ -87,6 +88,7 @@
- benchmarking: whether to run benchmarking only, without loading real weights, defaults to False
- inference_debugging: whether to run inference in debugging mode, saving all inputs/outputs/weights to file, defaults to False
- fusion: whether to enable the FlexFlow operator fusion optimization, defaults to True
- log_instance_creation: whether to log the creation of FlexFlow instances, defaults to False
The configurations are passed down to the FlexFlow runtime (implemented in C++) via command line arguments.
@@ -127,6 +129,8 @@ def init(
:type inference_debugging: Optional[bool], optional
:param fusion: whether to enable the FlexFlow operator fusion optimization, defaults to True
:type fusion: Optional[bool], optional
:param log_instance_cration: whether to log the creation of FlexFlow instances, defaults to False
:type log_instance_cration: Optional[bool], optional
:raises ValueError: this function will raise an exception if the user passes both a configs_dict and some named parameters
:raises TypeError: this function will raise an exception if the configs_dict is not a dictionary
@@ -153,6 +157,7 @@
benchmarking is not None,
inference_debugging is not None,
fusion is not None,
log_instance_cration is not None,
]
):
raise ValueError("Cannot pass both configs_dict and individual args")
@@ -180,6 +185,7 @@
"benchmarking": benchmarking,
"inference_debugging": inference_debugging,
"fusion": fusion,
"log_instance_cration": log_instance_cration,
}

# Check that mandatory configs are present
@@ -230,5 +236,7 @@
configs_dict["inference_debugging"] = False
if configs_dict.get("fusion", None) is None:
configs_dict["fusion"] = True
if configs_dict.get("log_instance_cration", None) is None:
configs_dict["log_instance_cration"] = False

init_flexflow_runtime(configs_dict)
8 changes: 4 additions & 4 deletions src/ops/inc_multihead_self_attention.cpp
@@ -842,7 +842,7 @@ __global__ void
}

template <typename DT>
void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
int shard_id,
DT *output_ptr,
@@ -999,7 +999,7 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m,
stream);

// phase 1: Implement kernel to apply rotary embedding and scaling
compute_qkv_kernel(
apply_scaling_and_rotary(
m, bc, shard_id, static_cast<DT *>(m->devQKVProjArray), stream);
update_kv_cache_kernel<DT>(m, bc, stream);

@@ -1874,14 +1874,14 @@ template void
half *output_ptr,
hipStream_t stream);

template void Kernels::IncMultiHeadAttention::compute_qkv_kernel<float>(
template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary<float>(
IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
int shard_id,
float *output_ptr,
hipStream_t stream);

template void Kernels::IncMultiHeadAttention::compute_qkv_kernel<half>(
template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary<half>(
IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
int shard_id,
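
On the "fix gqa" side of the commit: under grouped-query attention, several query heads share one KV head, so KV-cache reads and writes in the kernels above must use the mapped KV head rather than the query head index. A sketch of that mapping with an illustrative helper name (not a function from this commit):

#include <cassert>

// Map a query head to the KV head it shares under grouped-query attention.
int kv_head_for_query_head(int q_head, int num_q_heads, int num_kv_heads) {
  assert(num_q_heads % num_kv_heads == 0);
  int const group_size = num_q_heads / num_kv_heads; // query heads per KV head
  return q_head / group_size;
}

For example, with 32 query heads and 8 KV heads, query heads 0-3 read and write KV head 0, heads 4-7 use KV head 1, and so on; standard multi-head attention is the special case num_kv_heads == num_q_heads.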