From 85227c0a08c9c106bdfe2c1e546a2243008dc0b4 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Fri, 7 Feb 2025 06:53:40 +0000 Subject: [PATCH] fix gqa, comment out flashinfer --- CMakeLists.txt | 3 + include/flexflow/config.h | 8 + .../inc_multihead_self_attention_kernels.h | 7 +- inference/python/incr_decoding.py | 5 +- python/flexflow/serve/__init__.py | 8 + src/ops/inc_multihead_self_attention.cpp | 8 +- src/ops/inc_multihead_self_attention.cu | 963 +++++------------- src/ops/kernels/gemm_impl.cu | 3 +- .../inc_multihead_self_attention_kernels.cu | 2 + src/ops/spec_inc_multihead_self_attention.cpp | 2 +- src/ops/spec_inc_multihead_self_attention.cu | 2 +- src/ops/tree_inc_multihead_self_attention.cpp | 2 +- src/ops/tree_inc_multihead_self_attention.cu | 2 +- src/runtime/model.cu | 35 +- src/runtime/request_manager.cc | 2 +- src/runtime/request_manager.cu | 16 +- tests/fine_grained_alignment_test.sh | 16 +- tests/inference/huggingface_inference.py | 2 +- 18 files changed, 351 insertions(+), 735 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c13e00942..b9eec07a1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -200,6 +200,9 @@ include(variant) # optional include(optional) +# flashinfer +list(APPEND FLEXFLOW_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/deps/flashinfer/include) + if (FF_GPU_BACKEND STREQUAL "cuda") list(APPEND FF_CC_FLAGS -DFF_USE_CUDA) diff --git a/include/flexflow/config.h b/include/flexflow/config.h index cce4f51ed..3246da1af 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -17,12 +17,16 @@ #define _FLEXFLOW_CONFIG_H_ #include "ffconst.h" #include "flexflow/batch_config.h" +#ifdef USE_FLASHINFER #include "flexflow/attention_config.h" #include "flexflow/ops/kernels/gemm_impl.h" +#endif #include "legion.h" #include #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) +#ifdef USE_FLASHINFER #include +#endif #include #include #elif defined(FF_USE_HIP_ROCM) @@ -92,8 +96,10 @@ struct FFHandler { #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) cudnnHandle_t dnn; cublasHandle_t blas; +#ifdef USE_FLASHINFER cublasLtHandle_t blasLt; // Internal::GemmEngine *gemm_engine; +#endif #else miopenHandle_t dnn; hipblasHandle_t blas; @@ -101,9 +107,11 @@ struct FFHandler { void *workSpace; size_t workSpaceSize; CombinedBatchConfigMetaStruct *batch_config_metadata; +#ifdef USE_FLASHINFER AttentionMetaData *incr_attention_metadata; AttentionMetaData *tree_search_attention_metadata; AttentionMetaData *tree_verify_attention_metadata; +#endif // request info + token info + topolopgy mask info size_t batch_config_metadata_size = sizeof(CombinedBatchConfigMetaStruct); diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h index 0a2638026..33fce61c0 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h @@ -14,6 +14,7 @@ namespace FlexFlow { namespace Kernels { namespace IncMultiHeadAttention { +#ifdef USE_FLASHINFER // kv layout: [num_pages, 2, page_size, num_kv_heads, head_dim] __device__ __forceinline__ size_t get_k_entry_offset(int const req_idx, int const token_idx, @@ -44,6 +45,7 @@ return ((req_idx * max_num_pages + token_idx / kPagesize) * kPagesize + num_heads * head_dim; } +#endif template void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, @@ -57,12 +59,13 @@ void 
compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, ffStream_t stream); template -void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, +void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, DT *output_ptr, ffStream_t stream); +#ifdef USE_FLASHINFER // [For the tokens in batch] // Update the kv cache, and compact the q array. // Source: qkv projeciton array of tokens in the batch. @@ -79,6 +82,8 @@ void produce_output(IncMultiHeadSelfAttentionMeta const *m, DT *output_ptr, ffStream_t stream); +#endif + template __global__ void apply_position_bias_qkprd(DT *input_ptr, int num_tokens, diff --git a/inference/python/incr_decoding.py b/inference/python/incr_decoding.py index 009d275b8..87ee03e2d 100644 --- a/inference/python/incr_decoding.py +++ b/inference/python/incr_decoding.py @@ -116,11 +116,12 @@ def main(): results = llm.generate(prompts, max_length=configs.max_length) else: if "max_length" not in configs_dict: - result = llm.generate("Three tips for staying healthy are: ") + results = llm.generate("Three tips for staying healthy are: ") else: - result = llm.generate( + results = llm.generate( "Three tips for staying healthy are: ", max_length=configs.max_length ) + print("Final output:", results[0].output_text) llm.stop_server() diff --git a/python/flexflow/serve/__init__.py b/python/flexflow/serve/__init__.py index f8ca8f817..408c2ef90 100644 --- a/python/flexflow/serve/__init__.py +++ b/python/flexflow/serve/__init__.py @@ -58,6 +58,7 @@ def init( benchmarking: Optional[bool] = None, inference_debugging: Optional[bool] = None, fusion: Optional[bool] = None, + log_instance_cration: Optional[bool] = None, ): """ Configure FlexFlow Serve and start the runtime. @@ -87,6 +88,7 @@ def init( - benchmarking: whether to run benchmaking only, without loading real weights, defaults to False - inference_debugging: whether to run inference in debugging mode, saving all inputs/outputs/weights to file, defaults to False - fusion: whether to enable the FlexFlow operator fusion optimization, defaults to True + - log_instance_creation: whether to log the creation of FlexFlow instances, defaults to False The configurations are passed down to the FlexFlow runtime (implemented in C++) via command line arguments. 
@@ -127,6 +129,8 @@ def init( :type inference_debugging: Optional[bool], optional :param fusion: whether to enable the FlexFlow operator fusion optimization, defaults to True :type fusion: Optional[bool], optional + :param log_instance_cration: whether to log the creation of FlexFlow instances, defaults to False + :type log_instance_cration: Optional[bool], optional :raises ValueError: this function will raise an exception if the user passes both a configs_dict and some named parameters :raises TypeError: this function will raise an exception if the configs_dict is not a dictionary @@ -153,6 +157,7 @@ def init( benchmarking is not None, inference_debugging is not None, fusion is not None, + log_instance_cration is not None, ] ): raise ValueError("Cannot pass both configs_dict and individual args") @@ -180,6 +185,7 @@ def init( "benchmarking": benchmarking, "inference_debugging": inference_debugging, "fusion": fusion, + "log_instance_cration": log_instance_cration, } # Check that mandatory configs are present @@ -230,5 +236,7 @@ def init( configs_dict["inference_debugging"] = False if configs_dict.get("fusion", None) is None: configs_dict["fusion"] = True + if configs_dict.get("log_instance_cration", None) is None: + configs_dict["log_instance_cration"] = False init_flexflow_runtime(configs_dict) diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index c0c17870c..5c39c4cbb 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -842,7 +842,7 @@ __global__ void } template -void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, +void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, DT *output_ptr, @@ -999,7 +999,7 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, stream); // phase 1: Implement kernel to apply rotary embedding and scaling - compute_qkv_kernel( + apply_scaling_and_rotary( m, bc, shard_id, static_cast
<DT>(m->devQKVProjArray), stream); update_kv_cache_kernel<DT>
(m, bc, stream); @@ -1874,14 +1874,14 @@ template void half *output_ptr, hipStream_t stream); -template void Kernels::IncMultiHeadAttention::compute_qkv_kernel( +template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, float *output_ptr, hipStream_t stream); -template void Kernels::IncMultiHeadAttention::compute_qkv_kernel( +template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index eb47d699b..8bfd3afb5 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -13,8 +13,10 @@ * limitations under the License. */ #include "cuComplex.h" +#ifdef USE_FLASHINFER #include "flashinfer/decode_attention_decl.cuh" #include "flashinfer/prefill_attention_decl.cuh" +#endif #include "flexflow/ffconst_utils.h" #include "flexflow/ops/inc_multihead_self_attention.h" #include "flexflow/ops/kernels/decompress_kernels.h" @@ -34,6 +36,7 @@ using Legion::Memory; namespace Kernels { namespace IncMultiHeadAttention { +#ifdef USE_FLASHINFER using flashinfer::BatchDecodeHandler; using flashinfer::BatchDecodeWithPagedKVCacheWrapperDispatched; using flashinfer::BatchPrefillHandler; @@ -44,6 +47,22 @@ using flashinfer::paged_kv_t; using flashinfer::PageStorage; using flashinfer::PosEncodingMode; using flashinfer::QKVLayout; +#endif + +std::string get_fwd_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, + int shard_id) { + std::string op_name_without_uid = + IncMultiHeadSelfAttention::get_op_name_without_uid(m); + fs::path dst_filepath = get_dst_folder("fwd", m->decoding_step, shard_id); + if (m->layer_guid.model_id > 0) { + assert(false && "Model ID > 0 not supported yet"); + } + std::string layername = "layers." + + std::to_string(m->layer_guid.transformer_layer_id) + + "." + op_name_without_uid; + dst_filepath /= layername; + return dst_filepath.string(); +} template __global__ void store_kv_cache(DT const *devQKVProjArray, @@ -56,6 +75,11 @@ __global__ void store_kv_cache(DT const *devQKVProjArray, int num_q_heads, int num_kv_heads) { CUDA_KERNEL_LOOP(i, num_tokens * head_dim * num_kv_heads) { + // devQKVProjArray: [head_dim, tot_num_heads, num_tokens] + // kCache_ptr: [head_dim, num_kv_heads, max_seq_len, max_batch_size] + // vCache_ptr: [head_dim, num_kv_heads, max_seq_len, max_batch_size] + + // i is iterating over one set of key/val projections from the input int token_idx = i / (head_dim * num_kv_heads); int head_idx = (i / head_dim) % num_kv_heads; int offset = i % head_dim; @@ -168,11 +192,6 @@ void run_batched_matmul(const IncMultiHeadSelfAttentionMeta *meta, computeType, algo)); } else { - - lda = (transa == CUBLAS_OP_N) ? m : k; - ldb = (transb == CUBLAS_OP_N) ? 
k : n; - ldc = m; - const DT **h_A_array = new const DT*[batchCount]; const DT **h_B_array = new const DT*[batchCount]; DT **h_C_array = new DT*[batchCount]; @@ -222,6 +241,7 @@ void run_batched_matmul(const IncMultiHeadSelfAttentionMeta *meta, template void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, BatchConfig const *bc, + DT *attn_heads, int shard_id, cudaStream_t stream) { checkCUDA(cublasSetStream(m->handle.blas, stream)); @@ -231,27 +251,20 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, assert(data_type_size(m->output_type[0]) == sizeof(DT)); cudaDataType_t compute_type = cublas_data_type; - int num_tokens = bc->num_active_tokens(); - int tokens_previous_requests = 0; - int q_block_size = m->qProjSize; - int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_kv_heads * BatchConfig::max_sequence_length(); - int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_kv_heads * BatchConfig::max_sequence_length(); - assert(m->qProjSize == m->kProjSize); - - for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i] || is_decoding_request(bc, i) || is_finetuning_bwd_request(bc, i)) { + assert(m->qProjSize == m->kProjSize && m->kProjSize == m->vProjSize); + int tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; + + int num_processed_prompt_tokens = 0; + for (int req_idx = 0; req_idx < bc->max_requests_per_batch(); req_idx++) { + if (bc->request_completed[req_idx] || is_decoding_request(bc, req_idx) || is_finetuning_bwd_request(bc, req_idx)) { continue; } - int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch; + int num_new_tokens = bc->requestsInfo[req_idx].num_tokens_in_batch; + int total_tokens = bc->requestsInfo[req_idx].first_token_depth_in_request + + bc->requestsInfo[req_idx].num_tokens_in_batch; // Copy query to m->query_activation_buffer if we need to compute // PEFT backward - if (bc->requestsInfo[i].finetuning_request && !bc->requestsInfo[i].finetuning_backward_phase) { + if (bc->requestsInfo[req_idx].finetuning_request && !bc->requestsInfo[req_idx].finetuning_backward_phase) { int max_peft_tokens = BatchConfig::max_finetuning_sequence_length(); size_t activation_size_needed = sizeof(DT) * max_peft_tokens * m->num_q_heads * m->qProjSize; @@ -267,16 +280,16 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, std::cout << "sizeof(DT)" << sizeof(DT) << std::endl; } assert(activation_size_needed == m->allocated_peft_buffer_size1); - int parallelism = m->qProjSize * m->num_q_heads * bc->requestsInfo[i].num_tokens_in_batch; + int parallelism = m->qProjSize * m->num_q_heads * bc->requestsInfo[req_idx].num_tokens_in_batch; store_query_cache<<>>( static_cast
<DT>(m->devQKVProjArray), static_cast<DT>
(m->query_activation_buffer), - bc->requestsInfo[i].num_tokens_in_batch, - bc->requestsInfo[i].first_token_offset_in_batch, - bc->requestsInfo[i].first_token_depth_in_request, + bc->requestsInfo[req_idx].num_tokens_in_batch, + bc->requestsInfo[req_idx].first_token_offset_in_batch, + bc->requestsInfo[req_idx].first_token_depth_in_request, m->qProjSize, m->num_q_heads, m->num_kv_heads); @@ -293,27 +306,27 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, int n = total_tokens; int k = m->qProjSize; // before transpositions - int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; + int lda = m->qProjSize * tot_num_heads; + int ldb = m->kProjSize * m->num_kv_heads; + int ldc = num_new_tokens; // N.B. strides are applied before transpose operations - int strideA = q_block_size; - int strideB = kt_block_size; + int strideA = m->qProjSize; + int strideB = m->kProjSize; int strideC = num_new_tokens * total_tokens; // matrix A: devQKVProjArray - // matrix A's layout: [qProjSize*num_q_heads + 2*kvProjSize*num_kv_heads, num_new_tokens] + // matrix A's layout: [qProjSize, tot_num_heads, num_new_tokens] // To get query projection, skip over Q entries from previous requests DT const *A = static_cast
(m->devQKVProjArray) + - bc->requestsInfo[i].first_token_offset_in_batch * + bc->requestsInfo[req_idx].first_token_offset_in_batch * m->qProjSize * (m->num_q_heads + 2 * m->num_kv_heads); // matrix B: key cache - // matrix B's layout: [kProjSize * num_kv_heads, total_tokens] + // matrix B's layout: [kProjSize, num_kv_heads, total_tokens] // To get B, skip over K entries from previous requests (all heads + // padding) - DT const *B = static_cast
<DT>(m->keyCache) + i * kt_req_block_size; - // matrix C: qk_prods + DT const *B = static_cast<DT>
(m->keyCache) + req_idx * (m->kProjSize * m->num_kv_heads * BatchConfig::max_sequence_length()); + // matrix C: qk_prods (current req only) // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] - // To get C, skip over QK.T products from previous requests DT *C = static_cast
<DT>(m->qk_prods); run_batched_matmul<DT>
(m, m->handle.blas, CUBLAS_OP_T, CUBLAS_OP_N, @@ -330,40 +343,50 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, 1, m->num_q_heads/m->num_kv_heads, 1); + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".qk_prods"; + save_tensor(static_cast
(m->qk_prods), num_new_tokens*total_tokens*m->num_q_heads, fpath.c_str()); } // Step 2: Add alibi position bias to qk production - // matrix C: qk_prods - // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] - // To get C, skip over QK.T products from previous requests - DT *C = static_cast
(m->qk_prods); - if (*m->position_bias) { - size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; - apply_position_bias_qkprd<<>>(C, - num_new_tokens, - total_tokens, - m->num_q_heads, - m->global_num_q_heads, - shard_id); - } - - // Step 3: Apply causal mask. Fill all elements above diagonal in qk prods - // with -inf to force causal attention. - assert(num_new_tokens <= total_tokens); - size_t entries_above_diagonal = num_new_tokens * (num_new_tokens - 1) / 2; - if (entries_above_diagonal > 0) { - size_t parallelism = m->num_q_heads * entries_above_diagonal; - fill_entries_above_diagonal<<(m->qk_prods); + if (*m->position_bias) { + size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; + apply_position_bias_qkprd<<>>(C, num_new_tokens, total_tokens, m->num_q_heads, - entries_above_diagonal, - static_cast
(-INFINITY)); + m->global_num_q_heads, + shard_id); + } + } + + // Step 3: Apply causal mask. Fill all elements above diagonal in qk prods + // with -inf to force causal attention. + { + assert(num_new_tokens <= total_tokens); + size_t entries_above_diagonal = num_new_tokens * (num_new_tokens - 1) / 2; + if (entries_above_diagonal > 0) { + // matrix C: qk_prods (current req only) + // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] + DT *C = static_cast
(m->qk_prods); + size_t parallelism = m->num_q_heads * entries_above_diagonal; + fill_entries_above_diagonal<<>>(C, + num_new_tokens, + total_tokens, + m->num_q_heads, + entries_above_diagonal, + static_cast
<DT>(-INFINITY)); + } + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".qk_prods.masked"; + save_tensor(static_cast<DT>
(m->qk_prods), num_new_tokens*total_tokens*m->num_q_heads, fpath.c_str()); } // Step 4: Compute Softmax(QK.T/sqrt(d_k)) @@ -389,6 +412,11 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, h_param, w_param)); float softmax_alpha = 1.0f, softmax_beta = 0.0f; + // matrix C: qk_prods (current req only) + // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] + DT *C = static_cast
<DT>(m->qk_prods); + // matrix C_softmax: qk_prods_softmax (current req only) + // matrix C_softmax's layout: [num_new_tokens, total_tokens, num_q_heads] DT *C_softmax = static_cast<DT>
(m->qk_prods_softmax); // The softmax operation below is executed according to the // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The @@ -403,22 +431,25 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, &softmax_beta, m->qk_tensor, C_softmax)); + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".qk_prods_softmax"; + save_tensor(static_cast
(m->qk_prods_softmax), num_new_tokens*total_tokens*m->num_q_heads, fpath.c_str()); + // Copy C_softmax to m->softmax_activation_buffer if we need to compute + // PEFT backward + if (bc->requestsInfo[req_idx].finetuning_request) { + int max_peft_tokens = BatchConfig::max_finetuning_sequence_length(); + DT *C_softmax = static_cast
(m->qk_prods_softmax); + size_t activation_size_needed = + sizeof(DT) * max_peft_tokens * max_peft_tokens * m->num_q_heads; + assert(activation_size_needed == m->allocated_peft_buffer_size2); + checkCUDA(cudaMemcpyAsync(m->softmax_activation_buffer, + C_softmax, + sizeof(DT) * total_tokens * num_new_tokens * + m->num_q_heads, + cudaMemcpyDeviceToDevice, + stream)); + } } - // Copy C_softmax to m->softmax_activation_buffer if we need to compute - // PEFT backward - if (bc->requestsInfo[i].finetuning_request) { - int max_peft_tokens = BatchConfig::max_finetuning_sequence_length(); - DT *C_softmax = static_cast
(m->qk_prods_softmax); - size_t activation_size_needed = - sizeof(DT) * max_peft_tokens * max_peft_tokens * m->num_q_heads; - assert(activation_size_needed == m->allocated_peft_buffer_size2); - checkCUDA(cudaMemcpyAsync(m->softmax_activation_buffer, - C_softmax, - sizeof(DT) * total_tokens * num_new_tokens * - m->num_q_heads, - cudaMemcpyDeviceToDevice, - stream)); - } + // Step 5: Matmul softmax(QK.T/sqrt(d_k)) by V. Implemented as V @ // softmax(QK.T/sqrt(d_k)).T { @@ -432,15 +463,15 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, int ldb = n; int ldc = m_ * m->num_q_heads; // N.B. strides are applied before transpose operations - int strideA = vt_block_size; + int strideA = m->vProjSize; int strideB = num_new_tokens * total_tokens; int strideC = m->vProjSize; // matrix A: value cache // matrix A's layout: [vProjSize, num_kv_heads, total_tokens] // To get A, skip over V.T entries from previous requests (all heads + // padding) - DT *A = static_cast
<DT>(m->valueCache) + i * vt_req_block_size; - // matrix B: qk_prods_softmax + DT *A = static_cast<DT>
(m->valueCache) + req_idx * (m->vProjSize * m->num_kv_heads * BatchConfig::max_sequence_length()); + // matrix B: qk_prods_softmax (current req only) // matrix B's layout: [num_new_tokens, total_tokens, num_q_heads] // To get B, skip over softmax(QK.T/sqrt(d_k)) entries from previous // requests (all heads) @@ -450,8 +481,8 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, // To get C, skip over softmax(QK.T/sqrt(d_k))V products from previous // requests // store the result attn heads, also skip the genration tokens - DT *C = static_cast
<DT>(m->attn_heads) + - (bc->requestsInfo[i].first_token_offset_in_batch) * + DT *C = static_cast<DT>
(attn_heads) + + (bc->requestsInfo[req_idx].first_token_offset_in_batch) * m->num_q_heads * m->vProjSize; run_batched_matmul<DT>
(m, m->handle.blas, CUBLAS_OP_N, CUBLAS_OP_T, @@ -468,16 +499,18 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, m->num_q_heads/m->num_kv_heads, 1, 1); + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".qk_prods_softmax"; + save_tensor(static_cast
(attn_heads), num_new_tokens * m->num_q_heads * m->vProjSize, fpath.c_str()); } - tokens_previous_requests += num_new_tokens; + num_processed_prompt_tokens += num_new_tokens; } - if (tokens_previous_requests != (num_tokens - bc->num_generation_tokens)) { + if (num_processed_prompt_tokens != (bc->num_active_tokens() - bc->num_generation_tokens)) { bc->print(); - printf("tokens_previous_requests: %i\n", tokens_previous_requests); - printf("num_tokens: %i\n", num_tokens); + printf("num_processed_prompt_tokens: %i\n", num_processed_prompt_tokens); + printf("bc->num_active_tokens(): %i\n", bc->num_active_tokens()); printf("bc->num_generation_tokens: %i\n", bc->num_generation_tokens); } - assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens)); + assert(num_processed_prompt_tokens == (bc->num_active_tokens() - bc->num_generation_tokens)); } // gridDim = num_heads @@ -746,30 +779,32 @@ __global__ void compute_attention_kernel_generation_kernel( } } -// // only used by MPT model. https://arxiv.org/abs/2108.12409 -// template -// __global__ void apply_position_bias_qkprd(DT *input_ptr, -// int num_tokens, -// int num_total_tokens, -// int num_heads, -// int global_num_q_heads, -// int shard_id) { -// CUDA_KERNEL_LOOP(i, num_tokens * num_total_tokens * num_heads) { -// // get head_idx, -// int head_idx = i / (num_tokens * num_total_tokens) + (num_heads * shard_id); -// int position_idx = (i / num_tokens) % num_total_tokens; -// position_idx = position_idx + 1 - num_total_tokens; -// // 8 is alibi_bias_max in -// // https://huggingface.co/mosaicml/mpt-30b/blob/main/config.json -// float base = (float)(head_idx + 1) * 8 / global_num_q_heads; -// float slopes = 1.0 / pow(2, base); -// // if(i == 0){ -// // printf("see position: %d, %f, %f, %f\n", position_idx, base, slopes, -// // position_idx * slopes); -// // } -// input_ptr[i] += static_cast
(position_idx * slopes); -// } -// } +#ifndef USE_FLASHINFER +// only used by MPT model. https://arxiv.org/abs/2108.12409 +template +__global__ void apply_position_bias_qkprd(DT *input_ptr, + int num_tokens, + int num_total_tokens, + int num_heads, + int global_num_q_heads, + int shard_id) { + CUDA_KERNEL_LOOP(i, num_tokens * num_total_tokens * num_heads) { + // get head_idx, + int head_idx = i / (num_tokens * num_total_tokens) + (num_heads * shard_id); + int position_idx = (i / num_tokens) % num_total_tokens; + position_idx = position_idx + 1 - num_total_tokens; + // 8 is alibi_bias_max in + // https://huggingface.co/mosaicml/mpt-30b/blob/main/config.json + float base = (float)(head_idx + 1) * 8 / global_num_q_heads; + float slopes = 1.0 / pow(2, base); + // if(i == 0){ + // printf("see position: %d, %f, %f, %f\n", position_idx, base, slopes, + // position_idx * slopes); + // } + input_ptr[i] += static_cast
(position_idx * slopes); + } +} +#endif template __global__ void scaling_query_kernel(DT *input_ptr, @@ -935,7 +970,7 @@ __global__ void } template -void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, +void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, DT *output_ptr, @@ -993,8 +1028,10 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m, int tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; assert(m->hidden_size % m->num_q_heads == 0); int head_dim = m->hidden_size / m->num_q_heads; + assert(head_dim == m->qProjSize); if (num_tokens > 0) { int parallelism = head_dim*tot_num_heads * num_tokens; + // devQKVProj has shape [qProjSize, tot_num_heads, num_new_tokens] store_kv_cache<<decoding_step, shard_id); - if (m->layer_guid.model_id > 0) { - assert(false && "Model ID > 0 not supported yet"); - } - std::string layername = "layers." + - std::to_string(m->layer_guid.transformer_layer_id) + - "." + op_name_without_uid; - dst_filepath /= layername; - return dst_filepath.string(); -} +#ifdef USE_FLASHINFER template void incr_attention(IncMultiHeadSelfAttentionMeta *m, BatchConfig const *bc, @@ -1151,7 +1175,7 @@ void incr_attention(IncMultiHeadSelfAttentionMeta *m, template -void inference_kernel(IncMultiHeadSelfAttentionMeta *m, +void inference_kernel_flashinfer(IncMultiHeadSelfAttentionMeta *m, BatchConfig const *bc, int shard_id, DT const *qkv_ptr, @@ -1182,7 +1206,7 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, // } // phase 1: Implement kernel to apply rotary embedding and scaling - compute_qkv_kernel( + apply_scaling_and_rotary( m, bc, shard_id, static_cast
(m->devQKVProjArray), stream); // if (std::string(m->op_name).find("layers.0.self_attn") != std::string::npos) { @@ -1249,6 +1273,69 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, // stream); } +#endif + +template +void inference_kernel(IncMultiHeadSelfAttentionMeta *m, + BatchConfig const *bc, + int shard_id, + DT const *qkv_ptr, + DT *output_ptr, + cudaStream_t stream) { + + // phase 0: copy calculated qkv into devQKVProjArray + // [qProjSize, tot_num_heads, num_new_tokens] + assert(m->qProjSize == m->kProjSize && m->qProjSize == m->vProjSize); + size_t tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; + size_t qkv_proj_size = m->qProjSize * tot_num_heads * bc->num_active_tokens(); + + cudaMemcpyAsync(m->devQKVProjArray, + qkv_ptr, + qkv_proj_size * sizeof(DT), + cudaMemcpyDeviceToDevice, + stream); + + if (m->inference_debugging) { + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".devQKVProjArray"; + std::cout << "save tensor to: " << fpath << std::endl; + std::cout << "qkv_proj_size: " << qkv_proj_size << " (m->qProjSize=" << m->qProjSize << ", tot_num_heads=" << tot_num_heads << ", bc->num_active_tokens()=" << bc->num_active_tokens() << ")" << std::endl; + save_tensor(static_cast
<DT>(m->devQKVProjArray), qkv_proj_size, fpath.c_str()); + } + + + // phase 1: Implement kernel to apply rotary embedding and scaling + apply_scaling_and_rotary( + m, bc, shard_id, static_cast<DT>
(m->devQKVProjArray), stream); + + if (m->inference_debugging) { + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".post_rope"; + save_tensor(static_cast<DT>
(m->devQKVProjArray), qkv_proj_size, fpath.c_str()); + } + + update_kv_cache_kernel<DT>
(m, bc, stream); + + if (m->inference_debugging) { + size_t key_cache_size = m->kProjSize * m->num_kv_heads * BatchConfig::max_sequence_length()*BatchConfig::max_requests_per_batch(); + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".key_cache"; + save_tensor(static_cast<DT>
(m->keyCache), key_cache_size, fpath.c_str()); + fpath = get_fwd_dbg_folder(m, shard_id) + ".value_cache"; + save_tensor(static_cast<DT>
(m->valueCache), key_cache_size, fpath.c_str()); + } + + if (bc->num_generation_tokens > 0) { + // phase 3: Compute attention score for generation tokens + compute_attention_kernel_generation<DT>
( + m, bc, output_ptr, stream); + } + + if (bc->num_tokens > bc->num_generation_tokens) { + // phase 4: Compute attention score for prompt tokens; + compute_attention_kernel_prompt<DT>
(m, bc, output_ptr, shard_id, stream); + } + +} + + std::string get_peft_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, int shard_id) { std::string op_name_without_uid = @@ -1382,7 +1469,7 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, float alpha = 1.0f, beta = 0.0f; // matrix A: qk_prods_softmax // matrix A's layout: [num_new_tokens, total_tokens, num_q_heads] - DT const *A = static_cast
<DT>(m->softmax_activation_buffer); + DT const *A = static_cast<DT>
(m->qk_prods_softmax); // matrix B: attn_heads gradients // matrix B's layout: [vProjSize * num_q_heads, num_new_tokens] DT const *B = static_cast<DT>
(m->handle.workSpace); @@ -1743,560 +1830,38 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, } } -template -void peft_bwd_kernel2(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - int shard_id, - DT *input_grad_ptr, - DT const *output_grad_ptr, - cudaStream_t stream) { - assert(!m->offload); - checkCUDA(cublasSetStream(m->handle.blas, stream)); - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); - cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); - assert(data_type_size(m->output_type[0]) == sizeof(DT)); - cudaDataType_t compute_type = cublas_data_type; +} // namespace IncMultiHeadAttention +} // namespace Kernels - assert( - bc->peft_bwd_applies_to_this_layer(m->layer_guid.transformer_layer_id)); - int i = bc->finetuning_request_index(); - int num_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int num_total_tokens = bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch; - // Currently assume we are calculating gradients for all tokens - // of a request - assert(num_tokens == num_total_tokens); - int q_block_size = m->qProjSize; - int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_kv_heads * BatchConfig::max_finetuning_sequence_length(); - int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_kv_heads * BatchConfig::max_finetuning_sequence_length(); - assert(m->qProjSize == m->kProjSize && m->kProjSize == m->vProjSize); - // Step 1: copy gradient before final projection into workspace - { - int m_ = m->vProjSize * m->num_q_heads; - int n_ = num_tokens; - DT *C = static_cast
(m->handle.workSpace); - cudaMemcpyAsync(C, - output_grad_ptr + - bc->requestsInfo[i].first_token_offset_in_batch * - m->oProjSize, - m_ * n_ * sizeof(DT), - cudaMemcpyDeviceToDevice, - stream); - if (m->inference_debugging) { - // save result to file for checking - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".o_proj.input_gradient_0"; - save_tensor(C, m_ * n_, filename.c_str()); - } +using namespace Kernels::IncMultiHeadAttention; + +/*static*/ +void IncMultiHeadSelfAttention::inference_kernel_wrapper( + IncMultiHeadSelfAttentionMeta *m, + BatchConfig const *bc, + int shard_id, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + + cudaEvent_t t_start, t_end; + if (m->profiling) { + cudaEventCreate(&t_start); + cudaEventCreate(&t_end); + cudaEventRecord(t_start, stream); } - // Step 1.5: recompute m->softmax_activation_buffer - { - // compute query-key product QK.T/sqrt(d_k) - DT alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = static_cast
(1.0f / sqrt(m->kProjSize)); - } - // after transpositions - int m_ = num_tokens; - int n = num_total_tokens; - assert(m_ == n); - int k = m->qProjSize; - // before transpositions - int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; - // N.B. strides are applied before transpose operations - int strideA = q_block_size; - int strideB = kt_block_size; - int strideC = num_tokens * num_total_tokens; - - // matrix A: query_activation_buffer - // matrix A's layout: [qProjSize*num_q_heads, num_tokens] - // To get query projection, skip over Q entries from previous requests - DT const *A = static_cast
(m->query_activation_buffer); - // matrix B: key cache - // matrix B's layout: [kProjSize * num_kv_heads, num_total_tokens] - // To get B, skip over K entries from previous requests (all heads + - // padding) - DT const *B = static_cast
(m->keyCache) + i * kt_req_block_size; - // matrix C: qk_prods - // matrix C's layout: [num_tokens, num_total_tokens, num_q_heads] - // To get C, skip over QK.T products from previous requests - DT *C = static_cast
(m->qk_prods); - run_batched_matmul
(m, m->handle.blas, - CUBLAS_OP_T, CUBLAS_OP_N, - m_, n, k, - &alpha, - A, cublas_data_type, lda, strideA, - B, cublas_data_type, ldb, strideB, - &beta, - C, cublas_data_type, ldc, strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP, - stream, - 1, - m->num_q_heads/m->num_kv_heads, - 1); - // Add alibi position bias to qk production - if (*m->position_bias) { - size_t parallelism = m->num_q_heads * num_total_tokens * num_tokens; - apply_position_bias_qkprd<<>>(C, - num_tokens, - num_total_tokens, - m->num_q_heads, - m->global_num_q_heads, - shard_id); - } - // apply causal mask - assert(num_tokens <= num_total_tokens); - size_t entries_above_diagonal = num_tokens * (num_tokens - 1) / 2; - if (entries_above_diagonal > 0) { - size_t parallelism = m->num_q_heads * entries_above_diagonal; - fill_entries_above_diagonal<<>>(C, - num_tokens, - num_total_tokens, - m->num_q_heads, - entries_above_diagonal, - static_cast
(-INFINITY)); - } - // Compute Softmax(QK.T/sqrt(d_k)) - int n_param = m->num_q_heads; - int c_param = num_total_tokens; - int h_param = 1; - int w_param = num_tokens; - checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor, - CUDNN_TENSOR_NCHW, - cudnn_data_type, - n_param, - c_param, - h_param, - w_param)); - float softmax_alpha = 1.0f, softmax_beta = 0.0f; - DT *C_softmax = static_cast
(m->softmax_activation_buffer); - // The softmax operation below is executed according to the - // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The - // softmax operation is computed per spatial location (H,W) per image (N) - // across dimension C. - checkCUDNN(cudnnSoftmaxForward(m->handle.dnn, - CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - &softmax_alpha, - m->qk_tensor, - C, - &softmax_beta, - m->qk_tensor, - C_softmax)); - } - - // Step 2: compute gradients w.r.t. value - { - float alpha = 1.0f, beta = 0.0f; - // matrix A: qk_prods_softmax - // matrix A's layout: [num_new_tokens, total_tokens, num_q_heads] - DT const *A = static_cast
(m->softmax_activation_buffer); - // matrix B: attn_heads gradients - // matrix B's layout: [vProjSize * num_q_heads, num_new_tokens] - DT const *B = static_cast
(m->handle.workSpace); - // matrix C: gradients for value (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_q_heads, 3] - // note that we first need to compute the gradients wrt each q_heads, then we can sum - // the gradients corresponding to each group of q_heads to obtain the gradients wrt each - // value head - DT *C = static_cast
(m->devQKVProjArray) + - 2 * num_tokens * - (m->qProjSize * m->num_q_heads); // skip over regions reserved - // for Q and K gradients - // after transpositions - int m_ = num_tokens; // total_tokens - int n_ = m->vProjSize; // num_new_tokens - int k_ = num_tokens; // num_new_tokens - // before transpositions - int lda = num_tokens; // num_new_tokens - int ldb = m->vProjSize * m->num_q_heads; - int ldc = num_tokens; // total_tokens - // N.B. strides are applied before transpose operations - int strideA = num_tokens * num_tokens; // num_new_tokens * total_tokens - int strideB = m->vProjSize; - int strideC = num_tokens * m->vProjSize; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_T, - m_, - n_, - k_, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // save result to file for checking - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".v_proj.input_gradient_0"; - save_tensor(C, m_ * n_ * m->num_q_heads, filename.c_str()); - std::string filename2 = - get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax"; - save_tensor(A, m_ * k_ * m->num_q_heads, filename2.c_str()); - } - } - // Step 3: compute gradients w.r.t. the qk_prods_softmax tensor - { - float alpha = 1.0f, beta = 0.0f; - // matrix A: attn_heads gradients - // matrix A's layout: [vProjSize * num_q_heads, num_new_tokens] - DT const *A = static_cast
(m->handle.workSpace); - // matrix B: value cache - // matrix B's layout: [vProjSize * num_kv_heads, max_num_tokens, num_req] - DT const *B = static_cast
(m->valueCache) + i * vt_req_block_size; - // matrix C: qk_prods_softmax gradients - // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] - DT *C = static_cast
(m->qk_prods_softmax); - // after transposition & striding - int m_ = num_tokens; // num_new_tokens - int n_ = num_tokens; - int k_ = m->vProjSize; - // before transposition and striding - int lda = m->vProjSize * m->num_q_heads; - int ldb = m->vProjSize * m->num_kv_heads; - int ldc = num_tokens; // num_new_tokens - int strideA = m->vProjSize; - int strideB = m->vProjSize; - int strideC = num_tokens * num_tokens; // num_new_tokens * total_tokens - - run_batched_matmul
(m, m->handle.blas, - CUBLAS_OP_T, CUBLAS_OP_N, - m_, n_, k_, - &alpha, - A, cublas_data_type, lda, strideA, - B, cublas_data_type, ldb, strideB, - &beta, - C, cublas_data_type, ldc, strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP, - stream, - 1, - m->num_q_heads/m->num_kv_heads, - 1, - true); - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax_grad"; - save_tensor( - C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); - std::string filename2 = get_peft_dbg_folder(m, shard_id) + ".vcache"; - save_tensor( - B, m->vProjSize * m->num_q_heads * num_tokens, filename2.c_str()); - } - } - // Step 4: softmax backpropagation - { - float alpha = 1.0f, beta = 0.0f; - int n_param = m->num_q_heads; - int c_param = num_tokens; - int h_param = 1; - int w_param = num_tokens; - checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor, - CUDNN_TENSOR_NCHW, - cudnn_data_type, - n_param, - c_param, - h_param, - w_param)); - checkCUDNN(cudnnSoftmaxBackward(m->handle.dnn, - CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - &alpha, - m->qk_tensor, - m->softmax_activation_buffer, - m->qk_tensor, - m->qk_prods_softmax, - &beta, - m->qk_tensor, - m->qk_prods)); - - if (m->inference_debugging) { - DT *C = static_cast
(m->qk_prods); - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax_grad_in"; - save_tensor( - C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); - } - - // TODO: fill all elements above diagonal to force causal attention - size_t entries_above_diagonal = num_tokens * (num_tokens - 1) / 2; - if (entries_above_diagonal > 0) { - size_t parallelism = m->num_q_heads * entries_above_diagonal; - fill_entries_above_diagonal<<>>(static_cast
(m->qk_prods), - num_tokens, - num_tokens, - m->num_q_heads, - entries_above_diagonal, - DT(0.0f)); - } - if (m->inference_debugging) { - DT *C = static_cast
(m->qk_prods); - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax_grad_in.masked"; - save_tensor( - C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); - } - } - // Step 5: compute gradients w.r.t. key - { - float alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = 1.0f / sqrt(m->kProjSize); - } - // matrix A: gradients w.r.t. qk_prods - // matrix A's layout: [num_new_tokens, num_tokens, num_q_heads] - DT const *A = static_cast
(m->qk_prods); - // matrix B: query activation (in query_activation_buffer) - // matrix B's layout: [m->qProjSize * num_q_heads, num_new_tokens] - DT const *B = static_cast
(m->query_activation_buffer); - // matrix C: gradients for key (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_q_heads, 3] - // note that we first need to compute the gradients wrt each q_heads, then we can sum - // the gradients corresponding to each group of q_heads to obtain the gradients wrt each - // key head - DT *C = static_cast
(m->devQKVProjArray) + - num_tokens * - (m->qProjSize * - m->num_q_heads); // skip over regions reserved for Q gradients - // after transposition & striding - int m_ = num_tokens; - int n_ = m->kProjSize; - int k_ = num_tokens; // num_new_tokens - // before transposition and striding - int lda = num_tokens; // num_new_tokens - int ldb = m->kProjSize * m->num_q_heads; - int ldc = num_tokens; - int strideA = num_tokens * num_tokens; - int strideB = m->kProjSize; - int strideC = num_tokens * m->kProjSize; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_T, - m_, - n_, - k_, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".query_activation"; - save_tensor( - B, m->qProjSize * m->num_q_heads * num_tokens, filename.c_str()); - std::string filename2 = - get_peft_dbg_folder(m, shard_id) + ".devkproj_pre"; - save_tensor( - C, num_tokens * (m->qProjSize * m->num_q_heads), filename2.c_str()); - } - } - // Step 6: compute gradients w.r.t query - { - float alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = 1.0f / sqrt(m->kProjSize); - } - // matrix A: gradients w.r.t. qk_prods - // matrix A's layout: [num_new_tokens, num_tokens, num_q_heads] - DT const *A = static_cast
(m->qk_prods); - // matrix B: key cache - // matrix B's layout: [vProjSize * num_kv_heads, max_num_tokens, num_req] - DT const *B = static_cast
(m->keyCache) + i * kt_req_block_size; - // matrix C: gradients for query (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_q_heads, 3] - DT *C = static_cast
(m->devQKVProjArray); - // after transposition & striding - int m_ = num_tokens; // num_new_tokens - int n_ = m->qProjSize; - int k_ = num_tokens; - // before transposition and striding - int lda = num_tokens; // num_new_tokens - int ldb = m->qProjSize * m->num_kv_heads; - int ldc = num_tokens; - int strideA = num_tokens * num_tokens; - int strideB = m->qProjSize; - int strideC = num_tokens * m->qProjSize; - - run_batched_matmul
(m, m->handle.blas, - CUBLAS_OP_N, CUBLAS_OP_T, - m_, n_, k_, - &alpha, - A, cublas_data_type, lda, strideA, - B, cublas_data_type, ldb, strideB, - &beta, - C, cublas_data_type, ldc, strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP, - stream, - 1, - m->num_q_heads/m->num_kv_heads, - 1, - true); - - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".devQKVPRojArray_pre"; - save_tensor( - C, num_tokens * m->qProjSize * m->num_q_heads * 3, filename.c_str()); - } - } - - // Step 7: perform rotary position embeddings (RoPE) bwd - // todo: first sum the gradients wrt each q_head to obtain the gradients wrt each key head - { - if (m->rotary_embedding_meta->apply_rotary_embedding) { - assert(m->hidden_size == m->qProjSize * m->num_q_heads); - assert(m->qProjSize == m->kProjSize); - /*q&k*/ - int parallelism = num_tokens * m->hidden_size; - DT *A = static_cast
(m->devQKVProjArray); - apply_rotary_embedding_bwd<<>>( - A, - m->complex_input, - m->token_infos, - m->rotary_embedding_meta->rope_theta, - (m->rotary_embedding_meta->rope_type == "llama3"), - m->rotary_embedding_meta->factor, - m->rotary_embedding_meta->low_freq_factor, - m->rotary_embedding_meta->high_freq_factor, - m->rotary_embedding_meta->original_max_position_embeddings, - m->qProjSize, - num_tokens, - m->hidden_size); - DT *C = static_cast
(m->devQKVProjArray); - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".devQKVPRojArray"; - save_tensor(C, - num_tokens * m->qProjSize * m->num_q_heads * 3, - filename.c_str()); - } - } - - // matrix C: gradients for key (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] - DT *C = static_cast
(m->devQKVProjArray) + - num_tokens * - (m->qProjSize * - m->num_q_heads); // skip over regions reserved for Q gradients - if (m->inference_debugging) { - std::string filename = get_peft_dbg_folder(m, shard_id) + ".devkproj"; - save_tensor( - C, num_tokens * (m->qProjSize * m->num_q_heads), filename.c_str()); - } - } - - // Step 8: compute gradients w.r.t. input - { - float alpha = 1.0f, beta = 0.0f; - if (!m->reset_input_grads[0]) { - beta = 1.0f; - } - // matrix B: gradients w.r.t. QKV (concatenated in devQKVArray) - // matrix B's layout: [num_tokens, qProjsize * num_heads, 3] - DT const *B = static_cast
(m->devQKVProjArray); - // matrix C: gradients w.r.t. input - // matrix C's layout: [m->qSize, num_tokens] - DT *C = input_grad_ptr + - bc->requestsInfo[i].first_token_offset_in_batch * m->qSize; - // int m_ = m->qSize; - int n_ = num_tokens; - int k_ = m->qProjSize * (m->num_q_heads + 2 * m->num_kv_heads); - - // The original version uses existing result and attention's projection to - // do further calculation in a way different than the usual dense layer, - // they are off by a transpose. So an explicit transpose is needed here. - // The add here is just for gradient accumulation. - transposeAdd(C, B, n_, k_, alpha, beta, stream); - - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".self_attn.input_gradient_0"; - save_tensor(C, num_tokens * m->qSize, filename.c_str()); - } - } -} - -} // namespace IncMultiHeadAttention -} // namespace Kernels - -using namespace Kernels::IncMultiHeadAttention; - -/*static*/ -void IncMultiHeadSelfAttention::inference_kernel_wrapper( - IncMultiHeadSelfAttentionMeta *m, - BatchConfig const *bc, - int shard_id, - GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output) { - cudaStream_t stream; - checkCUDA(get_legion_stream(&stream)); - - cudaEvent_t t_start, t_end; - if (m->profiling) { - cudaEventCreate(&t_start); - cudaEventCreate(&t_end); - cudaEventRecord(t_start, stream); - } - - assert(input.data_type == output.data_type); - - if (input.data_type == DT_HALF) { - Kernels::IncMultiHeadAttention::inference_kernel( - m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); - } else if (input.data_type == DT_FLOAT) { - Kernels::IncMultiHeadAttention::inference_kernel( - m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); - } else { - assert(false && "Unspported data type"); + + assert(input.data_type == output.data_type); + + if (input.data_type == DT_HALF) { + Kernels::IncMultiHeadAttention::inference_kernel( + m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); + } else if (input.data_type == DT_FLOAT) { + Kernels::IncMultiHeadAttention::inference_kernel( + m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); + } else { + assert(false && "Unspported data type"); } if (m->profiling) { @@ -2331,7 +1896,7 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( if (input_grad.data_type == DT_HALF) { assert(!m->offload); - Kernels::IncMultiHeadAttention::peft_bwd_kernel2(m, + Kernels::IncMultiHeadAttention::peft_bwd_kernel(m, bc, shard_id, input_grad.get_half_ptr(), @@ -2339,7 +1904,7 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( stream); } else if (input_grad.data_type == DT_FLOAT) { assert(!m->offload); - Kernels::IncMultiHeadAttention::peft_bwd_kernel2(m, + Kernels::IncMultiHeadAttention::peft_bwd_kernel(m, bc, shard_id, input_grad.get_float_ptr(), @@ -2469,7 +2034,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( size_t query_tmp_size = 0, key_cache_size = 0, value_cache_size = 0; switch (infer_mode) { case INC_DECODING_MODE: { +#ifdef USE_FLASHINFER query_tmp_size = num_q_heads * qProjSize * max_tokens_per_batch; +#endif key_cache_size = num_kv_heads * kProjSize * BatchConfig::max_requests_per_batch() * BatchConfig::max_sequence_length(); @@ -2480,7 +2047,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( } case BEAM_SEARCH_MODE: case TREE_VERIFY_MODE: { +#ifdef USE_FLASHINFER query_tmp_size = num_q_heads * qProjSize * max_tokens_per_batch; +#endif // a K-ary tree max 
node is (k^n - 1) / 2 key_cache_size = num_kv_heads * kProjSize * BeamSearchBatchConfig::max_requests_per_batch() * @@ -2500,7 +2069,11 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( size_t qk_prod_size = max_tokens_per_batch * BatchConfig::max_sequence_length() * num_q_heads; size_t attn_heads_size = max_tokens_per_batch * num_q_heads * vProjSize; +#ifdef USE_FLASHINFER size_t output_tmp_size = max_tokens_per_batch * num_q_heads * vProjSize; +#else + size_t output_tmp_size = 0; +#endif size_t complex_size = (max_tokens_per_batch * (qProjSize * num_q_heads + kProjSize * num_q_heads)) / 2; @@ -2559,20 +2132,22 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( // offset += qkv_max_proj_size * size_of_dt; } +#ifdef USE_FLASHINFER if (query_tmp_size > 0) { queryTmp = gpu_mem_allocator.allocate_instance_untyped(query_tmp_size * size_of_dt); } - +#endif // use key value cache in all mode. keyCache = gpu_mem_allocator.allocate_instance_untyped(key_cache_size * size_of_dt); valueCache = gpu_mem_allocator.allocate_instance_untyped(value_cache_size * size_of_dt); - +#ifdef USE_FLASHINFER outputTmp = gpu_mem_allocator.allocate_instance_untyped(output_tmp_size * size_of_dt); +#endif // gqa pointers if (num_q_heads > num_kv_heads) { @@ -2629,11 +2204,13 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( } } +#ifdef USE_FLASHINFER // set attention constants handler.incr_attention_metadata->set_enabled(true); handler.incr_attention_metadata->set_num_q_heads(num_q_heads); handler.incr_attention_metadata->set_num_kv_heads(num_kv_heads); handler.incr_attention_metadata->set_head_dim(qProjSize); +#endif cudaStreamSynchronize(stream); } @@ -2658,30 +2235,18 @@ template void half *output_ptr, cudaStream_t stream); -template void Kernels::IncMultiHeadAttention::compute_qkv_kernel( +template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, float *output_ptr, cudaStream_t stream); -template void Kernels::IncMultiHeadAttention::compute_qkv_kernel( +template void Kernels::IncMultiHeadAttention::apply_scaling_and_rotary( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, half *output_ptr, cudaStream_t stream); -// template void Kernels::IncMultiHeadAttention::produce_output( -// IncMultiHeadSelfAttentionMeta const *m, -// BatchConfig const *bc, -// float *output_ptr, -// cudaStream_t stream); - -// template void Kernels::IncMultiHeadAttention::produce_output( -// IncMultiHeadSelfAttentionMeta const *m, -// BatchConfig const *bc, -// half *output_ptr, -// cudaStream_t stream); - }; // namespace FlexFlow diff --git a/src/ops/kernels/gemm_impl.cu b/src/ops/kernels/gemm_impl.cu index 939eaeb3b..9f15fda4c 100644 --- a/src/ops/kernels/gemm_impl.cu +++ b/src/ops/kernels/gemm_impl.cu @@ -12,7 +12,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - +#ifdef USE_FLASHINFER #include "flexflow/ops/kernels/gemm_impl.h" #include "flexflow/utils/cuda_helper.h" #include @@ -557,3 +557,4 @@ void GemmEngine::gemm_internal(cublasOperation_t transa, #endif } } // namespace Internal +#endif diff --git a/src/ops/kernels/inc_multihead_self_attention_kernels.cu b/src/ops/kernels/inc_multihead_self_attention_kernels.cu index 80cbaa0af..b062b8114 100644 --- a/src/ops/kernels/inc_multihead_self_attention_kernels.cu +++ b/src/ops/kernels/inc_multihead_self_attention_kernels.cu @@ -12,6 +12,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifdef USE_FLASHINFER #include "flexflow/batch_config.h" #include #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) @@ -276,3 +277,4 @@ template void Kernels::IncMultiHeadAttention::produce_output( }; // namespace FlexFlow +#endif \ No newline at end of file diff --git a/src/ops/spec_inc_multihead_self_attention.cpp b/src/ops/spec_inc_multihead_self_attention.cpp index 59b0206ed..da52ccea7 100644 --- a/src/ops/spec_inc_multihead_self_attention.cpp +++ b/src/ops/spec_inc_multihead_self_attention.cpp @@ -719,7 +719,7 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, // phase 1: Implement kernel to compute KQV for input tokens // TODO WARNING: this is commented out only because we are fixing the inc_attn // first - compute_qkv_kernel( + apply_scaling_and_rotary( m, bc, shard_id, static_cast
<DT>(m->devQKVProjArray), stream); // phase 2: Update key/val cache update_kv_cache_kernel<DT>
(m, bc, stream); diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 703bf1989..75d2c5fd7 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -700,7 +700,7 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, // phase 1: Implement kernel to compute KQV for input tokens // TODO WARNING: this is commented out only because we are fixing the inc_attn // first - compute_qkv_kernel( + apply_scaling_and_rotary( m, bc, shard_id, static_cast
<DT>(m->devQKVProjArray), stream); // phase 2: Update key/val cache update_kv_cache_kernel<DT>
(m, bc, stream); diff --git a/src/ops/tree_inc_multihead_self_attention.cpp b/src/ops/tree_inc_multihead_self_attention.cpp index 55698803d..6e24326dd 100644 --- a/src/ops/tree_inc_multihead_self_attention.cpp +++ b/src/ops/tree_inc_multihead_self_attention.cpp @@ -632,7 +632,7 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // phase 1: Implement kernel to compute KQV for input tokens // TODO WARNING: this is commented out only because we are fixing the inc_attn // first - compute_qkv_kernel( + apply_scaling_and_rotary( m, bc, shard_id, static_cast
(m->devQKVProjArray), stream); // phase 2: No need to update key/val cache diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 9ae39ddc4..9915d5f4a 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -608,7 +608,7 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // phase 1: Implement kernel to compute KQV for input tokens // TODO WARNING: this is commented out only because we are fixing the inc_attn // first - compute_qkv_kernel( + apply_scaling_and_rotary( m, bc, shard_id, static_cast
(m->devQKVProjArray), stream); // phase 2: No need to update key/val cache diff --git a/src/runtime/model.cu b/src/runtime/model.cu index ee0a41434..463e3c92d 100644 --- a/src/runtime/model.cu +++ b/src/runtime/model.cu @@ -91,6 +91,7 @@ FFHandler handle.offload_reserve_space_size = info->offload_reserve_space_size; handle.quantization_type = info->quantization_type; handle.allowTensorOpMathConversion = info->allowTensorOpMathConversion; +#ifdef USE_FLASHINFER handle.incr_attention_metadata = new AttentionMetaData(); handle.tree_search_attention_metadata = new AttentionMetaData(); handle.tree_verify_attention_metadata = new AttentionMetaData(); @@ -100,21 +101,12 @@ FFHandler "Attention metadata must be allocated"); assert(handle.tree_verify_attention_metadata != nullptr && "Attention metadata must be allocated"); +#endif checkCUDA(cublasCreate(&handle.blas)); - // checkCUDA(cublasLtCreate(&handle.blasLt)); if (handle.allowTensorOpMathConversion) { checkCUDA(cublasSetMathMode(handle.blas, CUBLAS_TENSOR_OP_MATH)); } checkCUDNN(cudnnCreate(&handle.dnn)); - // handle.num_devices = 0; - // handle.device_id = 0; - // handle.gemm_engine = new Internal::GemmEngine(handle.blas, handle.blasLt); - // // We may not use all devices, physical_device may not be successive, so we - // // explicitly get the physical device id - // int physical_device; - // checkCUDA(cudaGetDevice(&physical_device)); - // checkCUDA(cudaGetDeviceProperties(handle.gemm_engine->device_prop, - // physical_device)); // #ifdef FF_USE_NCCL // checkNCCL(ncclCommInitRank(&handle.nccl, info->allRanks, info->ncclId, // info->myRank)); fprintf(stderr, "handle.nccl(%p)\n", handle.nccl); @@ -166,6 +158,28 @@ FFHandler } else { handle.offload_reserve_space = nullptr; } + if (handle.batch_config_metadata_size > 0) { + // allocate memory for offload reserve space + Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc); + Realm::Rect<1, coord_t> bounds( + Realm::Point<1, coord_t>(0), + Realm::Point<1, coord_t>(handle.batch_config_metadata_size - 1)); + std::vector field_sizes; + field_sizes.push_back(sizeof(char)); + Realm::RegionInstance workspaceInst; + Realm::RegionInstance::create_instance(workspaceInst, + gpu_mem, + bounds, + field_sizes, + 0, + Realm::ProfilingRequestSet()) + .wait(); + handle.batch_config_metadata = static_cast( + workspaceInst.pointer_untyped(0, sizeof(char))); + } else { + handle.batch_config_metadata = nullptr; + } +#ifdef USE_FLASHINFER // std::cout << "handle.batch_config_metadata_size: " // << handle.batch_config_metadata_size << std::endl; // std::cout << "handle.incr_attention_metadata->mem_size(): " @@ -226,6 +240,7 @@ FFHandler handle.tree_verify_attention_metadata->assign_address(nullptr, 0); // handle.gemm_engine->assign_workspace(nullptr, 0); } +#endif // checkCUDA(cudaMalloc(&handle.workSpace, handle.workSpaceSize)); #ifdef FF_USE_NCCL handle.ncclComm = NULL; diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 82c35e233..e71fbe469 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -205,9 +205,9 @@ bool RequestManager::load_request_token_ids(Request &request) { request.peft_finetuning_info.max_training_steps && "Gradient accumulation steps should be less than or equal to max " "training steps"); - } assert(get_num_ssms() == 0 && "Small speculative models not supported for " "PEFT finetuning requests"); + } return true; } diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index 2fe912284..33c513a85 
diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu
index 2fe912284..33c513a85 100644
--- a/src/runtime/request_manager.cu
+++ b/src/runtime/request_manager.cu
@@ -13,8 +13,10 @@
  * limitations under the License.
  */
 
+#ifdef USE_FLASHINFER
 #include "flashinfer/decode_attention_decl.cuh"
 #include "flashinfer/prefill_attention_decl.cuh"
+#endif
 #include "flexflow/request_manager.h"
 #include "flexflow/utils/cuda_helper.h"
 
@@ -23,6 +25,7 @@ namespace FlexFlow {
 
 using namespace Legion;
 
+#ifdef USE_FLASHINFER
 using flashinfer::BatchDecodeHandler;
 using flashinfer::BatchPrefillHandler;
 using flashinfer::LogitsPostHook;
@@ -30,6 +33,7 @@ using flashinfer::paged_kv_t;
 using flashinfer::PageStorage;
 using flashinfer::PosEncodingMode;
 using flashinfer::QKVLayout;
+#endif
 
 void RequestManager::load_tokens_task(
     Task const *task,
@@ -89,6 +93,7 @@ void RequestManager::load_tokens_task(
                              stream));
 }
 
+#ifdef USE_FLASHINFER
 // q_indptr: the start offset of q in the batch for each request,
 // the length is `num_requests + 1`: [0, num_q_0, num_q_0 + num_q_1,
 // ..., num_q_0 + num_q_1 + ... + num_q_{num_requests - 1}]
@@ -147,6 +152,7 @@ __global__ void
   kv_last_page_len[request_idx] = (kv_len - 1) % kPagesize + 1;
   qk_indptr[request_idx + 1] = qk_lens;
 }
+#endif
 
 void RequestManager::load_batch_config_task(
     Task const *task,
@@ -227,7 +233,7 @@ void RequestManager::load_batch_config_task(
                              stream));
   }
-
+#ifdef USE_FLASHINFER
   // load attention metadata
   if (batch_config->get_mode() == INC_DECODING_MODE) {
     int batch_size = batch_config->num_active_requests();
@@ -256,8 +262,8 @@ void RequestManager::load_batch_config_task(
     // prepare attention forward handler
     {
       static int32_t q_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1],
-          kv_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1],
-          kv_last_page_len_h[BatchConfig::MAX_NUM_REQUESTS];
+          kv_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1];
+      // kv_last_page_len_h[BatchConfig::MAX_NUM_REQUESTS];
       q_indptr_h[0] = 0;
       kv_indptr_h[0] = 0;
       for (int req_idx = 0, indptr_idx = 0;
           req_idx < batch_config->max_requests_per_batch();
           req_idx++) {
@@ -266,7 +272,7 @@ void RequestManager::load_batch_config_task(
         int kv_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch +
                      batch_config->requestsInfo[req_idx].first_token_depth_in_request;
         q_indptr_h[indptr_idx + 1] = q_indptr_h[indptr_idx] + q_len;
         kv_indptr_h[indptr_idx + 1] = kv_indptr_h[indptr_idx] + round_up_pages(kv_len);
-        kv_last_page_len_h[indptr_idx] = (kv_len - 1) % kPagesize + 1;
+        // kv_last_page_len_h[indptr_idx] = (kv_len - 1) % kPagesize + 1;
         indptr_idx++;
       }
     }
@@ -298,6 +304,8 @@ void RequestManager::load_batch_config_task(
   } else {
     assert(false && "Not implemented");
   }
+#endif
+
 }
 
 void RequestManager::load_positions_task(
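For context on the request_manager.cu hunk above: when the flashinfer path is compiled in, load_batch_config_task builds CSR-style offset arrays on the host, q_indptr over query tokens and kv_indptr over KV-cache pages, while the per-request kv_last_page_len bookkeeping is commented out by this patch. The sketch below restates that arithmetic with ordinary vectors; build_indptrs and round_up_pages_ex are illustrative names, not FlexFlow functions.

// Host-side sketch of the indptr bookkeeping shown in the hunk above.
// Assumes a fixed KV page size; all function names here are hypothetical.
#include <vector>

static int round_up_pages_ex(int len, int page_size) {
  // Number of pages needed to cover `len` tokens.
  return (len + page_size - 1) / page_size;
}

static void build_indptrs(std::vector<int> const &q_lens,  // query tokens per request
                          std::vector<int> const &kv_lens, // KV length per request
                          int page_size,
                          std::vector<int> &q_indptr,
                          std::vector<int> &kv_indptr,
                          std::vector<int> &kv_last_page_len) {
  q_indptr.assign(1, 0);
  kv_indptr.assign(1, 0);
  kv_last_page_len.clear();
  for (size_t i = 0; i < q_lens.size(); i++) {
    // Prefix sums: request i's queries live in [q_indptr[i], q_indptr[i+1]),
    // and its KV pages in [kv_indptr[i], kv_indptr[i+1]).
    q_indptr.push_back(q_indptr.back() + q_lens[i]);
    kv_indptr.push_back(kv_indptr.back() +
                        round_up_pages_ex(kv_lens[i], page_size));
    // Occupancy of the last (possibly partial) page, in [1, page_size].
    kv_last_page_len.push_back((kv_lens[i] - 1) % page_size + 1);
  }
}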
diff --git a/tests/fine_grained_alignment_test.sh b/tests/fine_grained_alignment_test.sh
index 84d3a2a28..2c2e2cc0f 100755
--- a/tests/fine_grained_alignment_test.sh
+++ b/tests/fine_grained_alignment_test.sh
@@ -2,13 +2,14 @@
 set -x
 set -e
 
-MODEL_NAME=${MODEL_NAME:-"JackFram/llama-160m"}
+MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-1B-Instruct"}
 MEMORY_PER_GPU=${MEMORY_PER_GPU:-14000}
 ZCOPY_MEMORY=${ZCOPY_MEMORY:-40000}
-TP_DEGREE=${TP_DEGREE:-2}
-PP_DEGREE=${PP_DEGREE:-2}
+TP_DEGREE=${TP_DEGREE:-1}
+PP_DEGREE=${PP_DEGREE:-1}
 CACHE_PATH=${FF_CACHE_PATH:-"~/.cache/flexflow"}
 NUM_STEPS=${NUM_STEPS:-2}
+FULL_PRECISION=${FULL_PRECISION:-true}
 
 cleanup() {
     eval rm -rf "${CACHE_PATH}/debug" ./fine_grained_alignment_config.json ./inference/output/fine_grained_alignment_test_ff.txt ./inference/output/fine_grained_alignment_test_hf.txt
@@ -29,7 +30,7 @@ mkdir -p ./inference/output
 
 # Enable backtrace in case we run into a segfault or assertion failure
 export LEGION_BACKTRACE=1
-export FF_DEBG_NO_WEIGHTS=0
+export FF_DEBG_NO_WEIGHTS=1
 
 FUSION=true
 
@@ -53,8 +54,7 @@ python ./tests/inference/huggingface_inference.py \
     --max-length "${MAX_LENGTH}" \
     --prompt-file ../../inference/prompt/test.json \
     --output-file ../../inference/output/fine_grained_alignment_test_hf.txt \
-    --use-full-precision \
-    --inference-debugging
+    --inference-debugging $([ "${FULL_PRECISION}" = "true" ] && echo "--use-full-precision")
 
 NUM_GPUS=$((TP_DEGREE * PP_DEGREE))
 
@@ -67,12 +67,12 @@ json_config=$(cat <<-END
     "data_parallelism_degree": 1,
     "tensor_parallelism_degree": ${TP_DEGREE},
     "pipeline_parallelism_degree": ${PP_DEGREE},
     "inference_debugging": true,
     "fusion": ${FUSION},
     "refresh_cache": false,
     "llm_model": "${MODEL_NAME}",
     "cache_path": "${CACHE_PATH}",
-    "full_precision": true,
+    "full_precision": ${FULL_PRECISION},
    "prompt": "./inference/prompt/test.json",
    "max_length": $MAX_LENGTH,
    "output_file": "./inference/output/fine_grained_alignment_test_ff.txt"
diff --git a/tests/inference/huggingface_inference.py b/tests/inference/huggingface_inference.py
index 5777cae00..15ff0aed4 100644
--- a/tests/inference/huggingface_inference.py
+++ b/tests/inference/huggingface_inference.py
@@ -64,7 +64,7 @@ def main():
     cuda_availble = torch.cuda.is_available()
     device = "cuda" if args.gpu and cuda_availble else "cpu"
     # Get Model
-    model = AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True).to(device)
+    model = AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True, attn_implementation="eager").to(device)
     # Get Tokenizer
     hf_config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True)
     tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)