From 2e764a9017539adc2ec5c502295528b70e16e97d Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Tue, 25 Feb 2025 23:01:20 -0500 Subject: [PATCH] FlexLLM (part 3) (#106) * update * longer prompts * Inference test fixes (#105) * FlexLLM (part 2) (#104) * init * update * hip fixes * fixes * Update run.sh * save outputs in json format, fix test script * update * update checker script * fix * update * update * update * update * update * shellcheck * rocm * update --- conda/flexflow.yml | 1 + docker/flexflow-environment/Dockerfile | 2 +- include/flexflow/request_manager.h | 1 + inference/python/incr_decoding.py | 2 +- inference/python/spec_infer.py | 4 +- python/flexflow/core/flexflow_cffi.py | 2 +- python/flexflow/serve/serve.py | 3 +- src/c/flexflow_c.cc | 7 +- src/ops/inc_multihead_self_attention.cpp | 175 +++++------ src/ops/inc_multihead_self_attention.cu | 175 +++++------ src/runtime/request_manager.cc | 296 ++++++++++--------- tests/fine_grained_alignment_test.sh | 22 +- tests/inference/generate_inf_test_configs.py | 7 +- tests/inference/huggingface_inference.py | 46 +-- tests/inference/inference_alignment_test.py | 36 ++- tests/inference/test_inference_output.py | 161 +++++----- tests/inference_tests.sh | 22 +- tests/peft/alignment/align_test_utils.py | 14 +- 18 files changed, 526 insertions(+), 450 deletions(-) diff --git a/conda/flexflow.yml b/conda/flexflow.yml index 7568ed648..3f6cd99a4 100644 --- a/conda/flexflow.yml +++ b/conda/flexflow.yml @@ -27,3 +27,4 @@ dependencies: - loralib - triton - peft + - pytest diff --git a/docker/flexflow-environment/Dockerfile b/docker/flexflow-environment/Dockerfile index 92423adf2..88b360074 100644 --- a/docker/flexflow-environment/Dockerfile +++ b/docker/flexflow-environment/Dockerfile @@ -113,7 +113,7 @@ RUN rm /usr/local/bin/install_pytorch.sh RUN pip3 install transformers>=4.47.1 sentencepiece einops RUN pip3 install tensorflow notebook # PEFT-related -RUN pip3 install scipy bitsandbytes datasets accelerate loralib triton peft +RUN pip3 install scipy bitsandbytes datasets accelerate loralib triton peft pytest RUN pip3 install streamlit # Install Rust diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index 768f2d840..98f8dc93a 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -265,6 +265,7 @@ class RequestManager { void record_decoding_req_profiling_info(BatchConfig const &old_fwd_bc, int req_idx); void record_step_profile_info(BatchConfig const &old_bc); + void save_output_to_json(); void save_profiling_info_to_csv(std::string output_folder, std::string dataset_name, std::string llm_model_name, diff --git a/inference/python/incr_decoding.py b/inference/python/incr_decoding.py index bf044670d..968aa65b2 100644 --- a/inference/python/incr_decoding.py +++ b/inference/python/incr_decoding.py @@ -101,7 +101,7 @@ def main(): ) llm.compile( generation_config, - max_requests_per_batch=1, + max_requests_per_batch=4, max_seq_length=256, max_tokens_per_batch=64, ) diff --git a/inference/python/spec_infer.py b/inference/python/spec_infer.py index 38dc6db63..a7652be59 100644 --- a/inference/python/spec_infer.py +++ b/inference/python/spec_infer.py @@ -130,7 +130,7 @@ def main(): for ssm in ssms: ssm.compile( generation_config, - max_requests_per_batch=1, + max_requests_per_batch=4, max_seq_length=256, max_tokens_per_batch=64, ) @@ -138,7 +138,7 @@ def main(): # Compile the LLM for inference and load the weights into memory llm.compile( generation_config, - 
max_requests_per_batch=1, + max_requests_per_batch=4, max_seq_length=256, max_tokens_per_batch=64, ssms=ssms, diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index 6cf4138a8..48c9bf211 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -4718,7 +4718,7 @@ def generate(self, requests_list: List[Request]): ] # entry will be None for finetuning requests c_output_texts = [ ( - ffi.new("char[]", max_sequence_length * 5) + ffi.new("char[]", max_sequence_length * 10) if request.req_type == RequestType.REQ_INFERENCE else ffi.NULL ) diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py index 394869426..6db415aea 100644 --- a/python/flexflow/serve/serve.py +++ b/python/flexflow/serve/serve.py @@ -301,8 +301,9 @@ def download_hf_weights_if_needed(self) -> None: If not, or if the refresh_cache parameter is set to True, download new weights and convert them. """ - # TODO: edit this to download the weights using snapshot_download and convert them to FlexFlow format without loading them to GPU def download_and_convert_llm_weights(model_name): + num_cores = os.cpu_count() -1 if os.cpu_count() > 1 else 1 + snapshot_download(repo_id=model_name, allow_patterns="*.safetensors", max_workers=min(30, num_cores)) hf_model = AutoModelForCausalLM.from_pretrained( model_name, trust_remote_code=True, diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 4c6ac5a09..ae21fd0c5 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1780,10 +1780,11 @@ void flexflow_model_generate(flexflow_model_t handle_, if (max_lengths[i] >= 0) { assert(total_tokens <= max_lengths[i] || num_output_tokens == 0); } - // assert(results[i].output_tokens.size() <= max_seq_lengths[i] || - // results[i].output_tokens.size() == - // results[i].input_tokens.size()); output_length_and_tokens[i][0] = results[i].output_tokens.size(); + assert(results[i].output_tokens.size() <= max_lengths[i] + 100 && + "Exceeding python buffer size for token ids"); + assert(results[i].output_text.length() <= max_lengths[i] * 10 && + "Exceeding python buffer size for output text"); std::copy(results[i].output_tokens.begin(), results[i].output_tokens.end(), output_length_and_tokens[i] + 1); diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index 6acdce039..41268ee4d 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -989,67 +989,57 @@ __global__ void int num_tokens, int num_q_heads, int num_kv_heads) { - CUDA_KERNEL_LOOP(i, num_tokens * num_q_heads * proj_size) { - size_t q_array_size = proj_size * num_q_heads * num_tokens; - int hidden_size = num_q_heads * proj_size; - int total_num_heads = num_q_heads + 2 * num_kv_heads; - // create complex number - bool q_tensor = i < (q_array_size / 2); - int real_i = q_tensor ? i : i - q_array_size / 2; - int token_idx = real_i / (hidden_size / 2); - int idx = real_i % (proj_size / 2); - int head_idx = (real_i - (token_idx * (hidden_size / 2))) / (proj_size / 2); - if (!q_tensor) { - head_idx /= (num_q_heads / num_kv_heads); - } - - int real_part_index = idx + head_idx * proj_size + - token_idx * proj_size * total_num_heads + - hidden_size * (q_tensor ? 
0 : 1); - int complex_part_index = real_part_index + (proj_size / 2); - - complex_input[i] = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; - - // get the freq_cis: shape 1 * (qProjSize/2) = 1 * 64 - // apply a Cartesian coordinate transformation - // multiple with input & /copy back to q/k - - // get position of token - - // size_t pos = id_map[token_idx].token_position; - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - - // float before_real = complex_input[i].x, before_complex = - int pos_i = real_i % (proj_size / 2); - - float freq = - pos * (1.0 / pow(rope_theta, (float)2 * pos_i / proj_size)); // θ_i + int half_proj = proj_size / 2; + int q_proj_work = num_tokens * num_q_heads * half_proj; + int kv_proj_work = num_tokens * num_kv_heads * half_proj; + int tot_num_heads = num_q_heads + 2 * num_kv_heads; + CUDA_KERNEL_LOOP(i, q_proj_work + kv_proj_work) { + bool q_tensor = i < q_proj_work; + int num_heads = q_tensor ? num_q_heads : num_kv_heads; + + int real_i = q_tensor ? i : i - q_proj_work; + int token_idx = real_i / (half_proj * num_heads); + int pair_idx = real_i % half_proj; + int head_idx = (real_i / half_proj) % num_heads; + + // input_ptr: [proj_size, tot_num_heads, num_tokens] + int real_part_index = token_idx * proj_size * tot_num_heads + + (q_tensor ? 0 : proj_size * num_q_heads) + + head_idx * proj_size + pair_idx; + int complex_part_index = real_part_index + half_proj; + complex_input[i] = {(float)input_ptr[real_part_index], + (float)input_ptr[complex_part_index]}; + + float inv_freq = + 1.0 / pow(rope_theta, (float)2.0 * pair_idx / proj_size); // θ_i if (llama3_rope) { float pi = HIP_PI_F; - float wavelen = 2 * pi / freq; + float wavelen = 2 * pi / inv_freq; float low_freq_wavelen = original_max_position_embeddings / low_freq_factor; float high_freq_wavelen = original_max_position_embeddings / high_freq_factor; - if (wavelen < high_freq_wavelen) { - } else if (wavelen > low_freq_wavelen) { - freq = freq / factor; - } else { - assert(low_freq_wavelen != high_freq_wavelen); - float smooth = - (original_max_position_embeddings / wavelen - low_freq_factor) / - (high_freq_factor - low_freq_factor); - freq = ((1 - smooth) * freq / factor + smooth * freq); + if (wavelen > low_freq_wavelen) { + inv_freq = inv_freq / factor; + } + float smooth_factor = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + if (!(wavelen < high_freq_wavelen) && !(wavelen > low_freq_wavelen)) { + inv_freq = ((1 - smooth_factor) * inv_freq / factor + + smooth_factor * inv_freq); } } - hipFloatComplex complex_pos = {cos(freq), sin(freq)}; + int pos = tokenInfos[token_idx].abs_depth_in_request; + inv_freq = inv_freq * pos; + + hipFloatComplex complex_pos = {cos(inv_freq), sin(inv_freq)}; complex_input[i] = hipCmulf(complex_input[i], complex_pos); - input_ptr[real_part_index] = complex_input[i].x; - input_ptr[complex_part_index] = complex_input[i].y; + input_ptr[real_part_index] = (DT)complex_input[i].x; + input_ptr[complex_part_index] = (DT)complex_input[i].y; } } @@ -1066,57 +1056,62 @@ __global__ void int original_max_position_embeddings, int proj_size, int num_tokens, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int num_q_heads, + int num_kv_heads) { + int half_proj = proj_size / 2; + int q_proj_work = num_tokens * num_q_heads * half_proj; + int kv_proj_work = num_tokens * num_kv_heads * half_proj; + // int tot_num_heads = num_q_heads + 2 * num_kv_heads; + CUDA_KERNEL_LOOP(i, q_proj_work + 
kv_proj_work) { // compute indexes to visit first half proj_size of each of q/k tensor. - // devQKVProj has shape [num_tokens, qProjSize, num_heads, 3] in peft_bwd - bool q_tensor = i < (num_tokens * hidden_size / 2); - int real_i = q_tensor ? i : i - num_tokens * hidden_size / 2; - assert(hidden_size % proj_size == 0); - int num_heads = hidden_size / proj_size; + // devQKVProj has shape [num_tokens, proj_size, tot_num_heads] in peft_bwd + bool q_tensor = i < q_proj_work; + int num_heads = q_tensor ? num_q_heads : num_kv_heads; + int real_i = q_tensor ? i : i - q_proj_work; int token_idx = real_i % num_tokens; - int idx = (real_i / num_tokens) % (proj_size / 2); - int head_idx = real_i / (num_tokens * proj_size / 2); + int pair_idx = (real_i / num_tokens) % half_proj; + int head_idx = real_i / (num_tokens * half_proj); assert(head_idx < num_heads); - int complex_part_index = (q_tensor ? 0 : 1) * num_tokens * hidden_size + - head_idx * num_tokens * proj_size + - idx * num_tokens + token_idx; - int real_part_index = complex_part_index + (proj_size / 2) * num_tokens; + int complex_part_index = + (q_tensor ? 0 : num_tokens * proj_size * num_q_heads) + + head_idx * proj_size * num_tokens + pair_idx * num_tokens + token_idx; + int real_part_index = complex_part_index + num_tokens * half_proj; - complex_input[i] = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; + complex_input[i] = {(float)input_ptr[real_part_index], + (float)input_ptr[complex_part_index]}; - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - - float freq = - pos * (1.0 / pow(rope_theta, (float)2 * idx / proj_size)); // θ_i + float inv_freq = + 1.0 / pow(rope_theta, (float)2.0 * pair_idx / proj_size); // θ_i if (llama3_rope) { float pi = HIP_PI_F; - float wavelen = 2 * pi / freq; + float wavelen = 2 * pi / inv_freq; float low_freq_wavelen = original_max_position_embeddings / low_freq_factor; float high_freq_wavelen = original_max_position_embeddings / high_freq_factor; - if (wavelen < high_freq_wavelen) { - } else if (wavelen > low_freq_wavelen) { - freq = freq / factor; - } else { - assert(low_freq_wavelen != high_freq_wavelen); - float smooth = - (original_max_position_embeddings / wavelen - low_freq_factor) / - (high_freq_factor - low_freq_factor); - freq = ((1 - smooth) * freq / factor + smooth * freq); + if (wavelen > low_freq_wavelen) { + inv_freq = inv_freq / factor; + } + float smooth_factor = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + if (!(wavelen < high_freq_wavelen) && !(wavelen > low_freq_wavelen)) { + inv_freq = ((1 - smooth_factor) * inv_freq / factor + + smooth_factor * inv_freq); } } - hipFloatComplex complex_pos = {cos(freq), sin(freq)}; + int pos = tokenInfos[token_idx].abs_depth_in_request; + inv_freq = inv_freq * pos; + + hipFloatComplex complex_pos = {cos(inv_freq), sin(inv_freq)}; complex_input[i] = hipCmulf(complex_input[i], complex_pos); - input_ptr[real_part_index] = complex_input[i].x; - input_ptr[complex_part_index] = complex_input[i].y; + input_ptr[real_part_index] = (DT)complex_input[i].x; + input_ptr[complex_part_index] = (DT)complex_input[i].y; } } @@ -1151,7 +1146,10 @@ void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m, // Step 3: apply rotary embedding if needed if (m->rotary_embedding_meta->apply_rotary_embedding) { /*q&k*/ - parallelism = num_tokens * m->hidden_size; + int half_proj = m->qProjSize / 2; + int q_proj_work = num_tokens * m->num_q_heads * half_proj; + int kv_proj_work = 
num_tokens * m->num_kv_heads * half_proj; + parallelism = q_proj_work + kv_proj_work; hipLaunchKernelGGL( HIP_KERNEL_NAME(apply_rotary_embedding_fwd), GET_BLOCKS(parallelism), @@ -1781,15 +1779,17 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, assert(m->hidden_size == m->qProjSize * m->num_q_heads); assert(m->qProjSize == m->kProjSize); /*q&k*/ - int parallelism = num_tokens * m->hidden_size; - DT *A = static_cast
(m->devQKVProjArray); + int half_proj = m->qProjSize / 2; + int q_proj_work = num_tokens * m->num_q_heads * half_proj; + int kv_proj_work = num_tokens * m->num_kv_heads * half_proj; + int parallelism = q_proj_work + kv_proj_work; hipLaunchKernelGGL( HIP_KERNEL_NAME(apply_rotary_embedding_bwd), GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), 0, stream, - A, + static_cast
(m->devQKVProjArray), m->complex_input, m->peft_token_infos_device, m->rotary_embedding_meta->rope_theta, @@ -1800,7 +1800,8 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, m->rotary_embedding_meta->original_max_position_embeddings, m->qProjSize, num_tokens, - m->hidden_size); + m->num_q_heads, + m->num_kv_heads); DT *C = static_cast
(m->devQKVProjArray); if (m->inference_debugging) { std::string filename = @@ -2092,7 +2093,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( max_tokens_per_batch * BatchConfig::max_sequence_length() * num_q_heads; size_t attn_heads_size = max_tokens_per_batch * num_q_heads * vProjSize; size_t complex_size = (max_tokens_per_batch * (qProjSize * num_q_heads + - kProjSize * num_q_heads)) / + kProjSize * num_kv_heads)) / 2; if (enable_peft_finetuning) { allocated_peft_buffer_size1 = BatchConfig::max_sequence_length() * diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 874e8a02e..1b02b0052 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -968,67 +968,57 @@ __global__ void int num_tokens, int num_q_heads, int num_kv_heads) { - CUDA_KERNEL_LOOP(i, num_tokens * num_q_heads * proj_size) { - size_t q_array_size = proj_size * num_q_heads * num_tokens; - int hidden_size = num_q_heads * proj_size; - int total_num_heads = num_q_heads + 2 * num_kv_heads; - // create complex number - bool q_tensor = i < (q_array_size / 2); - int real_i = q_tensor ? i : i - q_array_size / 2; - int token_idx = real_i / (hidden_size / 2); - int idx = real_i % (proj_size / 2); - int head_idx = (real_i - (token_idx * (hidden_size / 2))) / (proj_size / 2); - if (!q_tensor) { - head_idx /= (num_q_heads / num_kv_heads); - } - - int real_part_index = idx + head_idx * proj_size + - token_idx * proj_size * total_num_heads + - hidden_size * (q_tensor ? 0 : 1); - int complex_part_index = real_part_index + (proj_size / 2); - - complex_input[i] = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; - - // get the freq_cis: shape 1 * (qProjSize/2) = 1 * 64 - // apply a Cartesian coordinate transformation - // multiple with input & /copy back to q/k - - // get position of token - - // size_t pos = id_map[token_idx].token_position; - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - - // float before_real = complex_input[i].x, before_complex = - int pos_i = real_i % (proj_size / 2); - - float freq = - pos * (1.0 / pow(rope_theta, (float)2 * pos_i / proj_size)); // θ_i + int half_proj = proj_size / 2; + int q_proj_work = num_tokens * num_q_heads * half_proj; + int kv_proj_work = num_tokens * num_kv_heads * half_proj; + int tot_num_heads = num_q_heads + 2 * num_kv_heads; + CUDA_KERNEL_LOOP(i, q_proj_work + kv_proj_work) { + bool q_tensor = i < q_proj_work; + int num_heads = q_tensor ? num_q_heads : num_kv_heads; + + int real_i = q_tensor ? i : i - q_proj_work; + int token_idx = real_i / (half_proj * num_heads); + int pair_idx = real_i % half_proj; + int head_idx = (real_i / half_proj) % num_heads; + + // input_ptr: [proj_size, tot_num_heads, num_tokens] + int real_part_index = token_idx * proj_size * tot_num_heads + + (q_tensor ? 
0 : proj_size * num_q_heads) + + head_idx * proj_size + pair_idx; + int complex_part_index = real_part_index + half_proj; + complex_input[i] = {(float)input_ptr[real_part_index], + (float)input_ptr[complex_part_index]}; + + float inv_freq = + 1.0 / pow(rope_theta, (float)2.0 * pair_idx / proj_size); // θ_i if (llama3_rope) { float pi = CUDART_PI_F; - float wavelen = 2 * pi / freq; + float wavelen = 2 * pi / inv_freq; float low_freq_wavelen = original_max_position_embeddings / low_freq_factor; float high_freq_wavelen = original_max_position_embeddings / high_freq_factor; - if (wavelen < high_freq_wavelen) { - } else if (wavelen > low_freq_wavelen) { - freq = freq / factor; - } else { - assert(low_freq_wavelen != high_freq_wavelen); - float smooth = - (original_max_position_embeddings / wavelen - low_freq_factor) / - (high_freq_factor - low_freq_factor); - freq = ((1 - smooth) * freq / factor + smooth * freq); + if (wavelen > low_freq_wavelen) { + inv_freq = inv_freq / factor; + } + float smooth_factor = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + if (!(wavelen < high_freq_wavelen) && !(wavelen > low_freq_wavelen)) { + inv_freq = ((1 - smooth_factor) * inv_freq / factor + + smooth_factor * inv_freq); } } - cuFloatComplex complex_pos = {cos(freq), sin(freq)}; + int pos = tokenInfos[token_idx].abs_depth_in_request; + inv_freq = inv_freq * pos; + + cuFloatComplex complex_pos = {cos(inv_freq), sin(inv_freq)}; complex_input[i] = cuCmulf(complex_input[i], complex_pos); - input_ptr[real_part_index] = complex_input[i].x; - input_ptr[complex_part_index] = complex_input[i].y; + input_ptr[real_part_index] = (DT)complex_input[i].x; + input_ptr[complex_part_index] = (DT)complex_input[i].y; } } @@ -1045,57 +1035,62 @@ __global__ void int original_max_position_embeddings, int proj_size, int num_tokens, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int num_q_heads, + int num_kv_heads) { + int half_proj = proj_size / 2; + int q_proj_work = num_tokens * num_q_heads * half_proj; + int kv_proj_work = num_tokens * num_kv_heads * half_proj; + // int tot_num_heads = num_q_heads + 2 * num_kv_heads; + CUDA_KERNEL_LOOP(i, q_proj_work + kv_proj_work) { // compute indexes to visit first half proj_size of each of q/k tensor. - // devQKVProj has shape [num_tokens, qProjSize, num_heads, 3] in peft_bwd - bool q_tensor = i < (num_tokens * hidden_size / 2); - int real_i = q_tensor ? i : i - num_tokens * hidden_size / 2; - assert(hidden_size % proj_size == 0); - int num_heads = hidden_size / proj_size; + // devQKVProj has shape [num_tokens, proj_size, tot_num_heads] in peft_bwd + bool q_tensor = i < q_proj_work; + int num_heads = q_tensor ? num_q_heads : num_kv_heads; + int real_i = q_tensor ? i : i - q_proj_work; int token_idx = real_i % num_tokens; - int idx = (real_i / num_tokens) % (proj_size / 2); - int head_idx = real_i / (num_tokens * proj_size / 2); + int pair_idx = (real_i / num_tokens) % half_proj; + int head_idx = real_i / (num_tokens * half_proj); assert(head_idx < num_heads); - int complex_part_index = (q_tensor ? 0 : 1) * num_tokens * hidden_size + - head_idx * num_tokens * proj_size + - idx * num_tokens + token_idx; - int real_part_index = complex_part_index + (proj_size / 2) * num_tokens; + int complex_part_index = + (q_tensor ? 
0 : num_tokens * proj_size * num_q_heads) + + head_idx * proj_size * num_tokens + pair_idx * num_tokens + token_idx; + int real_part_index = complex_part_index + num_tokens * half_proj; - complex_input[i] = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; + complex_input[i] = {(float)input_ptr[real_part_index], + (float)input_ptr[complex_part_index]}; - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - - float freq = - pos * (1.0 / pow(rope_theta, (float)2 * idx / proj_size)); // θ_i + float inv_freq = + 1.0 / pow(rope_theta, (float)2.0 * pair_idx / proj_size); // θ_i if (llama3_rope) { float pi = CUDART_PI_F; - float wavelen = 2 * pi / freq; + float wavelen = 2 * pi / inv_freq; float low_freq_wavelen = original_max_position_embeddings / low_freq_factor; float high_freq_wavelen = original_max_position_embeddings / high_freq_factor; - if (wavelen < high_freq_wavelen) { - } else if (wavelen > low_freq_wavelen) { - freq = freq / factor; - } else { - assert(low_freq_wavelen != high_freq_wavelen); - float smooth = - (original_max_position_embeddings / wavelen - low_freq_factor) / - (high_freq_factor - low_freq_factor); - freq = ((1 - smooth) * freq / factor + smooth * freq); + if (wavelen > low_freq_wavelen) { + inv_freq = inv_freq / factor; + } + float smooth_factor = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + if (!(wavelen < high_freq_wavelen) && !(wavelen > low_freq_wavelen)) { + inv_freq = ((1 - smooth_factor) * inv_freq / factor + + smooth_factor * inv_freq); } } - cuFloatComplex complex_pos = {cos(freq), sin(freq)}; + int pos = tokenInfos[token_idx].abs_depth_in_request; + inv_freq = inv_freq * pos; + + cuFloatComplex complex_pos = {cos(inv_freq), sin(inv_freq)}; complex_input[i] = cuCmulf(complex_input[i], complex_pos); - input_ptr[real_part_index] = complex_input[i].x; - input_ptr[complex_part_index] = complex_input[i].y; + input_ptr[real_part_index] = (DT)complex_input[i].x; + input_ptr[complex_part_index] = (DT)complex_input[i].y; } } @@ -1128,7 +1123,10 @@ void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m, // Step 3: apply rotary embedding if needed if (m->rotary_embedding_meta->apply_rotary_embedding) { /*q&k*/ - parallelism = num_tokens * m->hidden_size; + int half_proj = m->qProjSize / 2; + int q_proj_work = num_tokens * m->num_q_heads * half_proj; + int kv_proj_work = num_tokens * m->num_kv_heads * half_proj; + parallelism = q_proj_work + kv_proj_work; apply_rotary_embedding_fwd<<hidden_size == m->qProjSize * m->num_q_heads); assert(m->qProjSize == m->kProjSize); /*q&k*/ - int parallelism = num_tokens * m->hidden_size; - DT *A = static_cast
(m->devQKVProjArray); + int half_proj = m->qProjSize / 2; + int q_proj_work = num_tokens * m->num_q_heads * half_proj; + int kv_proj_work = num_tokens * m->num_kv_heads * half_proj; + int parallelism = q_proj_work + kv_proj_work; apply_rotary_embedding_bwd<<>>( - A, + static_cast
(m->devQKVProjArray), m->complex_input, m->peft_token_infos_device, m->rotary_embedding_meta->rope_theta, @@ -1774,7 +1774,8 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, m->rotary_embedding_meta->original_max_position_embeddings, m->qProjSize, num_tokens, - m->hidden_size); + m->num_q_heads, + m->num_kv_heads); DT *C = static_cast
(m->devQKVProjArray); if (m->inference_debugging) { std::string filename = @@ -2066,7 +2067,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( max_tokens_per_batch * BatchConfig::max_sequence_length() * num_q_heads; size_t attn_heads_size = max_tokens_per_batch * num_q_heads * vProjSize; size_t complex_size = (max_tokens_per_batch * (qProjSize * num_q_heads + - kProjSize * num_q_heads)) / + kProjSize * num_kv_heads)) / 2; if (enable_peft_finetuning) { allocated_peft_buffer_size1 = BatchConfig::max_sequence_length() * diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index b2e615bb6..740d5df4b 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -446,6 +446,10 @@ void RequestManager::register_tokenizer(ModelType type, void RequestManager::register_output_filepath( std::string const &_output_filepath) { this->output_filepath = _output_filepath; + // delete the file if it already exists + if (std::filesystem::exists(output_filepath)) { + std::filesystem::remove(output_filepath); + } } int RequestManager::register_ssm_model(FFModel *model) { @@ -599,6 +603,13 @@ RequestGuid RequestManager::register_new_request(Request const &request_) { gr.input_tokens = request.tokens; gr.output_text = request_.prompt; gr.output_tokens = request.tokens; + if (model_type == ModelType::LLAMA && old_llama_tokenizer && + request.add_special_tokens && request.tokens.at(0) == bos_token_id) { + // Unlike Huggingface, the sentencepiece C++ library automatically removes + // the BOS token + gr.input_text = " " + gr.input_text; + gr.output_text = " " + gr.output_text; + } request_generation_results[request.guid] = gr; ProfileInfo profile_info; @@ -894,81 +905,30 @@ void RequestManager::handle_completed_inf_req(BatchConfig const &old_bc, assert(request.req_type == RequestType::REQ_INFERENCE && "Found misplaced finetuning request"); - std::vector output_tokens = request.tokens; + GenerationResult &gr = request_generation_results[request.guid]; + std::vector output_tokens = std::vector( + request.tokens.begin() + gr.input_tokens.size(), request.tokens.end()); if (is_eos_token(output_tokens.back())) { // remove the EOS token output_tokens.pop_back(); } - std::string output = this->tokenizer_->Decode(output_tokens); - // Unlike Huggingface, the sentencepiece C++ library automatically - // removes the BOS token - if (model_type == ModelType::LLAMA && old_llama_tokenizer && - request.add_special_tokens && output_tokens.at(0) == bos_token_id) { - output = " " + output; - } - { - // update generation result - GenerationResult &gr = request_generation_results[request.guid]; - assert(gr.guid == request.guid); - gr.output_tokens = request.tokens; - gr.output_text = output; - } + std::string output_text = this->tokenizer_->Decode(output_tokens); + // update generation result + assert(gr.guid == request.guid); + gr.output_tokens = output_tokens; + gr.output_text = output_text; request.status = Request::COMPLETED; trigger_request_completion_future(request.guid); log_req_mgr.print("[Done] guid(%zu) initial_len(%d) final_length(%zu)", old_bc.requestsInfo[i].request_guid, request.initial_len, - output_tokens.size()); + gr.input_tokens.size() + gr.output_tokens.size()); // log_req_mgr.print("Final output: %s", output.c_str()); num_processed_requests++; ProfileInfo profile_info = profiling_requests[request.guid]; profile_info.finish_time = Realm::Clock::current_time_in_microseconds(); total_request_run_time += profile_info.finish_time - profile_info.start_time; 
profiling_requests[request.guid] = profile_info; - // log_req_mgr.print("[%s] guid(%zu) llm_decoding_steps(%d) initial_len(%d) - // final_len(%d) latency(%.1lf) ttft(%.1lf)", - // request.warmup ? "Warmup" : "Profile", - // request.guid, - // profile_info.llm_decoding_steps, - // request.initial_len, - // request.tokens.size(), - // (profile_info.finish_time - profile_info.start_time)/1e3, - // (profile_info.first_token_time - - // profile_info.registration_time)/1e3); - // Write output to file if needed: - if (!output_filepath.empty()) { - std::ofstream outputFile(output_filepath, std::ios::app); - if (outputFile.is_open()) { - outputFile << "[" << (request.warmup ? "Warmup" : "Profile") << "] guid(" - << request.guid << ") llm_decoding_steps(" - << profile_info.llm_decoding_steps << ") initial_len(" - << request.initial_len << ") final_len(" - << request.tokens.size() << ") latency(" << std::fixed - << std::setprecision(3) - << (profile_info.finish_time - profile_info.start_time) / 1e3 - << ") ttft(" << std::fixed << std::setprecision(3) - << (profile_info.first_token_time - - profile_info.registration_time) / - 1e3 - << ")\n"; - if (request.benchmarking_tokens <= 0) { - outputFile << "token IDs: "; - for (int i = 0; i < output_tokens.size(); i++) { - outputFile << output_tokens[i]; - if (i < output_tokens.size() - 1) { - outputFile << ","; - } - } - outputFile << std::endl; - outputFile << output; - } - outputFile.close(); - } else { - std::cout << "Unable to open the output file: " << output_filepath - << std::endl; - assert(false); - } - } } void RequestManager::add_continuing_inf_req_to_new_batch( @@ -1168,35 +1128,6 @@ void RequestManager::handle_completed_finetuning_req( request.peft_finetuning_info.completed_training_steps, request.peft_finetuning_info.finetuning_losses.back(), profile_info.finish_time - profile_info.start_time); - // if (!output_filepath.empty()) { - // // std::ofstream outputFile(output_filepath, std::ios::app); - // // if (outputFile.is_open()) { - // // std::string tokens_str = "["; - // // for (size_t i = 0; i < request.finetuning_tokens_per_batch.size(); - // // i++) { - // // tokens_str += - // // std::to_string(request.finetuning_tokens_per_batch[i]); - // // if (i != request.finetuning_tokens_per_batch.size() - 1) { - // // tokens_str += ", "; - // // } - // // } - // // tokens_str += "]"; - // // outputFile << "[" << (request.warmup ? 
"Warmup" : "Finetuning") - // // << "] guid(" << request.guid - // // << ") completed_training_steps(" - // // << request.peft_finetuning_info.completed_training_steps - // // << ") processed_finetuning_tokens(" - // // << request.processed_finetuning_tokens << ") latency(" - // // << std::fixed << std::setprecision(3) - // // << (profile_info.finish_time - profile_info.start_time) - // // << ") tokens_per_batch(" << tokens_str << ")\n"; - // // outputFile.close(); - // } else { - // std::cout << "Unable to open the output file: " << output_filepath - // << std::endl; - // assert(false); - // } - // } } void RequestManager::add_finetuning_req_fwd_batch(BatchConfig &new_bc) { @@ -1610,6 +1541,82 @@ BatchConfig return new_bc; } +void RequestManager::save_output_to_json() { + auto toCSV = [](auto const &arr) -> std::string { + std::ostringstream oss; + for (size_t i = 0; i < arr.size(); i++) { + if constexpr (std::is_floating_point_v) { + oss << std::fixed << std::setprecision(3) << arr[i]; + } else { + oss << arr[i]; + } + if (i != arr.size() - 1) { + oss << ","; + } + } + return oss.str(); + }; + + if (!output_filepath.empty()) { + // Extract keys and sort them in ascending order. + std::vector sortedKeys; + for (auto const &kv : request_generation_results) { + sortedKeys.push_back(kv.first); + } + std::sort(sortedKeys.begin(), sortedKeys.end()); + // Create a JSON array. + json jsonList = json::array(); + // Iterate over the sorted keys and add each dictionary to the JSON array. + for (auto const &key : sortedKeys) { + GenerationResult const &res = request_generation_results[key]; + ProfileInfo profile_info = profiling_requests[key]; + Request &request = all_requests[key]; + if (request.req_type == RequestType::REQ_INFERENCE) { + json entry = { + {"req_idx", key - 1000000}, + {"warmup", request.warmup}, + {"benchmarking tokens", request.benchmarking_tokens}, + {"req_type", "inference"}, + {"prompt_length", res.input_tokens.size()}, + {"response_length", res.output_tokens.size()}, + {"max_length", request.max_length}, + {"input_tokens", toCSV(res.input_tokens)}, + {"output_tokens", toCSV(res.output_tokens)}, + {"prompt", res.input_text}, + {"response", res.output_text}, + {"num_decoding_steps", profile_info.llm_decoding_steps}, + {"latency", + (profile_info.finish_time - profile_info.start_time) / 1e3}, + {"ttft", + (profile_info.first_token_time - profile_info.registration_time) / + 1e3}}; + jsonList.push_back(entry); + } else { + json entry = { + {"req_idx", key - 1000000}, + {"warmup", request.warmup}, + {"benchmarking tokens", request.benchmarking_tokens}, + {"req_type", "finetuning"}, + {"max_length", request.max_length}, + {"dataset_size", request.dataset.size()}, + {"completed_training_steps", + request.peft_finetuning_info.completed_training_steps}, + {"finetuning_losses", + toCSV(request.peft_finetuning_info.finetuning_losses)}, + {"latency", + (profile_info.finish_time - profile_info.start_time) / 1e3}}; + jsonList.push_back(entry); + } + } + + // Append the formatted JSON to the output_file. 
+ std::ofstream outputFile(output_filepath, std::ios::app); + outputFile << jsonList.dump(2) << std::endl; + outputFile.close(); + std::cout << "Output saved to " << output_filepath << std::endl; + } +} + void RequestManager::save_profiling_info_to_csv(std::string output_folder, std::string dataset_name, std::string llm_model_name, @@ -1830,27 +1837,28 @@ BeamSearchBatchConfig request.tokens.push_back(token_pair.first); } } - log_req_mgr.print("[Done] guid(%zu) with final length(%zu)", + GenerationResult &gr = request_generation_results[request.guid]; + std::vector output_tokens = + std::vector(request.tokens.begin() + gr.input_tokens.size(), + request.tokens.end()); + if (is_eos_token(output_tokens.back())) { + // remove the EOS token + output_tokens.pop_back(); + } + std::string output_text = this->tokenizer_->Decode(output_tokens); + log_req_mgr.print("[Done] guid(%zu) initial_len(%d) final_length(%zu)", request.guid, + request.initial_len, request.tokens.size()); - std::string output = this->tokenizer_->Decode(request.tokens); - // Unlike Huggingface, the sentencepiece C++ library automatically - // removes the BOS token - if (model_type == ModelType::LLAMA && old_llama_tokenizer && - request.add_special_tokens && - request.tokens.at(0) == bos_token_id) { - output = " " + output; - } - { - // update generation result - GenerationResult &gr = request_generation_results[request.guid]; - assert(gr.guid == request.guid); - gr.output_tokens = request.tokens; - gr.output_text = output; - } + + // update generation result + assert(gr.guid == request.guid); + gr.output_tokens = output_tokens; + gr.output_text = output_text; + request.status = Request::COMPLETED; trigger_request_completion_future(request.guid); - log_req_mgr.print("Final output: %s", output.c_str()); + log_req_mgr.print("Final output: %s", output_text.c_str()); new_bc.request_completed[i] = true; new_bc.request_running[i] = false; @@ -1872,38 +1880,6 @@ BeamSearchBatchConfig profile_info.finish_time, profile_info.finish_time - profile_info.start_time); - // Write output to file if needed: - if (!output_filepath.empty()) { - std::ofstream outputFile(output_filepath, std::ios::app); - if (outputFile.is_open()) { - outputFile << "[Profile] guid(" << request.guid - << ") llm_decoding_steps(" - << profile_info.llm_decoding_steps << ") latency(" - << std::fixed << std::setprecision(3) - << (profile_info.finish_time - profile_info.start_time) - << ")\n"; - // outputFile << "end-to-end latency: " << std::fixed - // << std::setprecision(3) << total_request_run_time - // << std::endl; - // outputFile << "num decoding steps: " - // << profile_info.llm_decoding_steps << std::endl; - outputFile << "token IDs: "; - for (int i = 0; i < request.tokens.size(); i++) { - outputFile << request.tokens[i]; - if (i < request.tokens.size() - 1) { - outputFile << ","; - } - } - outputFile << std::endl; - outputFile << output; - outputFile.close(); - } else { - std::cout << "Unable to open the output file: " << output_filepath - << std::endl; - assert(false); - } - } - // delete the old input tree from cache dfs_tree_inputs.erase(request.guid); @@ -2148,6 +2124,20 @@ BeamSearchBatchConfig old_bc.print(); new_bc.print(); } + if (new_bc.num_active_tokens() > BatchConfig::max_tokens_per_batch()) { + std::cout << "Error: new_bc.num_active_tokens() > " + "BatchConfig::max_tokens_per_batch()" + << std::endl; + new_bc.print(); + assert(false); + } + if (new_bc.num_active_requests() > BatchConfig::max_requests_per_batch()) { + std::cout << "Error: 
new_bc.num_active_requests() > " + "BatchConfig::max_requests_per_batch()" + << std::endl; + new_bc.print(); + assert(false); + } return new_bc; } @@ -2461,6 +2451,20 @@ BeamSearchBatchConfig old_bc.print(); new_bc.print(); } + if (new_bc.num_active_tokens() > BatchConfig::max_tokens_per_batch()) { + std::cout << "Error: new_bc.num_active_tokens() > " + "BatchConfig::max_tokens_per_batch()" + << std::endl; + new_bc.print(); + assert(false); + } + if (new_bc.num_active_requests() > BatchConfig::max_requests_per_batch()) { + std::cout << "Error: new_bc.num_active_requests() > " + "BatchConfig::max_requests_per_batch()" + << std::endl; + new_bc.print(); + assert(false); + } return new_bc; } @@ -2784,6 +2788,20 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( } } + if (new_bc.num_active_tokens() > BatchConfig::max_tokens_per_batch()) { + std::cout << "Error: new_bc.num_active_tokens() > " + "BatchConfig::max_tokens_per_batch()" + << std::endl; + new_bc.print(); + assert(false); + } + if (new_bc.num_active_requests() > BatchConfig::max_requests_per_batch()) { + std::cout << "Error: new_bc.num_active_requests() > " + "BatchConfig::max_requests_per_batch()" + << std::endl; + new_bc.print(); + assert(false); + } return new_bc; } @@ -3481,6 +3499,7 @@ std::vector results.push_back(rm->get_generation_result(peft_guids[i])); } rm->run_idx++; + rm->save_output_to_json(); return results; } @@ -3555,6 +3574,7 @@ std::vector for (int i = 0; i < peft_guids.size(); i++) { results.push_back(rm->get_generation_result(peft_guids[i])); } + rm->save_output_to_json(); return results; } diff --git a/tests/fine_grained_alignment_test.sh b/tests/fine_grained_alignment_test.sh index 0ef134195..4baaa53ab 100755 --- a/tests/fine_grained_alignment_test.sh +++ b/tests/fine_grained_alignment_test.sh @@ -2,13 +2,21 @@ set -x set -e -MODEL_NAME=${MODEL_NAME:-"JackFram/llama-160m"} +MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-1B-Instruct"} MEMORY_PER_GPU=${MEMORY_PER_GPU:-14000} ZCOPY_MEMORY=${ZCOPY_MEMORY:-40000} TP_DEGREE=${TP_DEGREE:-2} PP_DEGREE=${PP_DEGREE:-2} CACHE_PATH=${FF_CACHE_PATH:-"~/.cache/flexflow"} NUM_STEPS=${NUM_STEPS:-2} +FULL_PRECISION=${FULL_PRECISION:-true} +FUSION=${FUSION:-true} + +# Token to access private huggingface models (e.g. 
LLAMA-2) +HUGGINGFACE_TOKEN=${HUGGINGFACE_TOKEN:-none} +if [[ "$HUGGINGFACE_TOKEN" != "none" ]]; then + huggingface-cli login --token "$HUGGINGFACE_TOKEN" +fi cleanup() { eval rm -rf "${CACHE_PATH}/debug" ./fine_grained_alignment_config.json ./inference/output/fine_grained_alignment_test_ff.txt ./inference/output/fine_grained_alignment_test_hf.txt @@ -30,7 +38,7 @@ mkdir -p ./inference/output # Enable backtrace in case we run into a segfault or assertion failure export LEGION_BACKTRACE=1 export FF_DEBG_NO_WEIGHTS=1 -FUSION=true + # Check if the Python code executed successfully @@ -48,13 +56,13 @@ fi MAX_LENGTH=$((PROMPT_LENGTH + NUM_STEPS + 1)) +if [ "$FULL_PRECISION" = "true" ]; then full_precision_flag="--use-full-precision"; else full_precision_flag=""; fi python ./tests/inference/huggingface_inference.py \ --model-name "${MODEL_NAME}" \ --max-length "${MAX_LENGTH}" \ --prompt-file ../../inference/prompt/test.json \ - --output-file ../../inference/output/fine_grained_alignment_test_hf.txt \ - --use-full-precision \ - --inference-debugging + --output-file ../../inference/output/fine_grained_alignment_test_hf.json \ + "${full_precision_flag}" --inference-debugging NUM_GPUS=$((TP_DEGREE * PP_DEGREE)) json_config=$(cat <<-END @@ -72,10 +80,10 @@ json_config=$(cat <<-END "refresh_cache": false, "llm_model": "${MODEL_NAME}", "cache_path": "${CACHE_PATH}", - "full_precision": true, + "full_precision": ${FULL_PRECISION}, "prompt": "./inference/prompt/test.json", "max_length": $MAX_LENGTH, - "output_file": "./inference/output/fine_grained_alignment_test_ff.txt" + "output_file": "./inference/output/fine_grained_alignment_test_ff.json" } END ) diff --git a/tests/inference/generate_inf_test_configs.py b/tests/inference/generate_inf_test_configs.py index fc0444885..15a1af681 100644 --- a/tests/inference/generate_inf_test_configs.py +++ b/tests/inference/generate_inf_test_configs.py @@ -19,7 +19,6 @@ "use_4bit_quantization": False, "use_8bit_quantization": False, "enable_peft": False, - "peft_activation_reserve_space_size": 1024, # 1GB "profiling": False, "benchmarking": False, "inference_debugging": False, @@ -34,7 +33,7 @@ "full_precision": True, "prompt": "", "output_file": "", - "max_length": 128, + "max_length": 255, } ssm_configs = { "ssms": [ @@ -70,7 +69,7 @@ def gen_incr_dec_configs(prompt_file, output_folder, incr_dec_models, parallelis + f"{tp}_tp_{pp}_pp" ) test_configs_file = os.path.join(config_output_folder, f"{filename}.json") - output_file = os.path.join(output_folder, filename + ".txt") + output_file = os.path.join(output_folder, filename + ".json") ff_init_configs["tensor_parallelism_degree"] = tp ff_init_configs["pipeline_parallelism_degree"] = pp @@ -99,7 +98,7 @@ def gen_spec_configs(prompt_file, output_folder, specinfer_model_pairs, parallel + f"{tp}_tp_{pp}_pp" ) test_configs_file = os.path.join(config_output_folder, f"{filename}.json") - output_file = os.path.join(output_folder, filename + ".txt") + output_file = os.path.join(output_folder, filename + ".json") ff_init_configs["tensor_parallelism_degree"] = tp ff_init_configs["pipeline_parallelism_degree"] = pp diff --git a/tests/inference/huggingface_inference.py b/tests/inference/huggingface_inference.py index 4af36ee22..443f1722f 100644 --- a/tests/inference/huggingface_inference.py +++ b/tests/inference/huggingface_inference.py @@ -23,7 +23,7 @@ def main(): # Parse command line arguments parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, required=True) - parser.add_argument("--max-length", 
type=int, default=128) + parser.add_argument("--max-length", type=int, default=255) parser.add_argument("--prompt-file", type=str, required=True) parser.add_argument("--output-file", type=str, required=True) parser.add_argument( @@ -83,25 +83,33 @@ def main(): ############################################### # Generate output + output_list = [] + for i, prompt in enumerate(prompt_list): + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(model.device) + generated = model.generate( + batch["input_ids"], + max_length=args.max_length, + generation_config=generation_config, + ) + prompt_token_ids = list(batch["input_ids"].cpu().numpy()[0]) + response_token_ids = list(generated[0].cpu().numpy())[len(prompt_token_ids):] + # Remove eos token if present at the end + if response_token_ids[-1] == tokenizer.eos_token_id: + response_token_ids = response_token_ids[:-1] + response = tokenizer.decode(response_token_ids) + output_list.append({ + "req_idx": i, + "req_type": "inference", + "prompt_length": len(prompt_token_ids), + "response_length": len(response_token_ids), + "prompt": prompt, + "response": response, + "input_tokens": ",".join(str(x) for x in prompt_token_ids), + "output_tokens": ",".join(str(x) for x in response_token_ids), + "num_decoding_steps": len(response_token_ids), + }) with open(args.output_file, "w") as f: - for i, prompt in enumerate(prompt_list): - batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(model.device) - generated = model.generate( - batch["input_ids"], - max_length=args.max_length, - generation_config=generation_config, - ) - token_ids = list(generated[0].cpu().numpy()) - # Remove eos token if present at the end - if token_ids[-1] == tokenizer.eos_token_id: - token_ids = token_ids[:-1] - out = tokenizer.decode(generated[0]) - if out.endswith(tokenizer.eos_token): - out = out[:-len(tokenizer.eos_token)] - # Write output to file - out_str = out if i == (len(prompt_list) - 1) else out + "\n" - f.write("token IDs: " + ",".join(str(x) for x in token_ids) + "\n") - f.write(out_str) + json.dump(output_list, f, indent=2) if __name__ == "__main__": diff --git a/tests/inference/inference_alignment_test.py b/tests/inference/inference_alignment_test.py index 1fe2bfbaa..aa417504c 100644 --- a/tests/inference/inference_alignment_test.py +++ b/tests/inference/inference_alignment_test.py @@ -261,37 +261,43 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance torch.testing.assert_close(hf_q_proj_in, hf_k_proj_in) torch.testing.assert_close(hf_k_proj_in, hf_v_proj_in) compare(hf_q_proj_in, ff_qkv_tensor_in, label=f"QKV proj {i} input") + + bz, seq_len, hidden_dim = hf_q_proj_out.shape + head_dim = hidden_dim // self.num_attention_heads + tot_num_heads = self.num_attention_heads + 2*self.num_key_value_heads + ff_qkv_tensor_out = get_ff_tensor( ff_qkv_tensor_name, output_comparison, - torch.Size([hf_q_proj_out.shape[0], hf_q_proj_out.shape[1], 3*hf_q_proj_out.shape[2]]), + torch.Size([bz, seq_len, head_dim*tot_num_heads]), tp_type=TPType.PARTITION ) - head_dim = hf_q_proj_out.shape[2] // self.num_attention_heads - heads_per_shard = self.num_attention_heads // self.tp_degree - chunk_size = head_dim * heads_per_shard + q_heads_per_shard = self.num_attention_heads // self.tp_degree + kv_heads_per_shard = self.num_key_value_heads // self.tp_degree + q_chunk_size = head_dim * q_heads_per_shard + kv_chunk_size = head_dim * kv_heads_per_shard # print(ff_qkv_tensor_out.shape) - ff_qproj_out = 
ff_qkv_tensor_out[:chunk_size, :, :] - ff_kproj_out = ff_qkv_tensor_out[chunk_size:2*chunk_size, :, :] - ff_vproj_out = ff_qkv_tensor_out[2*chunk_size : 3*chunk_size, :, :] - qkv_chunk_size = 3*chunk_size + ff_qproj_out = ff_qkv_tensor_out[: q_chunk_size, :, :] + ff_kproj_out = ff_qkv_tensor_out[q_chunk_size : q_chunk_size + kv_chunk_size, :, :] + ff_vproj_out = ff_qkv_tensor_out[q_chunk_size + kv_chunk_size : q_chunk_size + 2*kv_chunk_size, :, :] + qkv_chunk_size = q_chunk_size + 2*kv_chunk_size for tp_idx in range(1, self.tp_degree): prev_size = tp_idx * qkv_chunk_size - ff_qproj_out_ = ff_qkv_tensor_out[prev_size : prev_size + chunk_size, :, :] - ff_kproj_out_ = ff_qkv_tensor_out[prev_size + chunk_size : prev_size + 2*chunk_size, :, :] - ff_vproj_out_ = ff_qkv_tensor_out[prev_size + 2*chunk_size : prev_size + 3*chunk_size, :, :] + ff_qproj_out_ = ff_qkv_tensor_out[prev_size : prev_size + q_chunk_size, :, :] + ff_kproj_out_ = ff_qkv_tensor_out[prev_size + q_chunk_size : prev_size + q_chunk_size + kv_chunk_size, :, :] + ff_vproj_out_ = ff_qkv_tensor_out[prev_size + q_chunk_size + kv_chunk_size : prev_size + q_chunk_size + 2*kv_chunk_size, :, :] ff_qproj_out = np.concatenate((ff_qproj_out, ff_qproj_out_), axis=0) ff_kproj_out = np.concatenate((ff_kproj_out, ff_kproj_out_), axis=0) ff_vproj_out = np.concatenate((ff_vproj_out, ff_vproj_out_), axis=0) - compare_loaded_tensors(hf_q_proj_out.T, ff_qproj_out) - compare_loaded_tensors(hf_k_proj_out.T, ff_kproj_out) - compare_loaded_tensors(hf_v_proj_out.T, ff_vproj_out) + compare_loaded_tensors(hf_q_proj_out.T, ff_qproj_out, label=f"Q proj {i} output") + compare_loaded_tensors(hf_k_proj_out.T, ff_kproj_out, label=f"K proj {i} output") + compare_loaded_tensors(hf_v_proj_out.T, ff_vproj_out, label=f"V proj {i} output") ff_tensor_name = f"layers.{i}.layers.{i}.self_attn" input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) ff_attn_tensor_in = get_ff_tensor( ff_tensor_name, input_comparison, - torch.Size([hf_q_proj_out.shape[0], hf_q_proj_out.shape[1], 3*hf_q_proj_out.shape[2]]), + torch.Size([bz, seq_len, head_dim*tot_num_heads]), tp_type=TPType.PARTITION ) assert torch.allclose(ff_qkv_tensor_out, ff_attn_tensor_in) diff --git a/tests/inference/test_inference_output.py b/tests/inference/test_inference_output.py index f5021fa23..3fc0a138b 100644 --- a/tests/inference/test_inference_output.py +++ b/tests/inference/test_inference_output.py @@ -1,48 +1,51 @@ import os -import re import glob import pytest +import json OUTPUT_DIR = os.path.join("inference", "output") -def get_line(filepath, line_index): - """ - Returns the specified line (0-based) from the file, or '' if the file - doesn't have that many lines. - """ - with open(filepath, "r", encoding="utf-8") as f: - lines = f.readlines() - return lines[line_index] if len(lines) > line_index else "" - -def compare_single_line(file_a, file_b): - """ - Compare a single line in two files: - - If filename starts with 'spec_infer' or 'incr_dec', compare line index = 1 (2nd line). - - If filename starts with 'huggingface_', compare line index = 0 (1st line). - Raise AssertionError if they differ. 
- """ - base_a = os.path.basename(file_a) - base_b = os.path.basename(file_b) - - if base_a.startswith(("spec_infer", "incr_dec")): - line_a = get_line(file_a, 1) - else: - line_a = get_line(file_a, 0) - - if base_b.startswith(("spec_infer", "incr_dec")): - line_b = get_line(file_b, 1) - else: - line_b = get_line(file_b, 0) - - list_a = line_a[len("token IDs: "):].split(",") - list_b = line_b[len("token IDs: "):].split(",") - - # check if the first 50 elements are equal - for i in range(min(50, len(list_a), len(list_b))): - if list_a[i] != list_b[i]: - raise AssertionError( - f"File contents differ at position {i}:\n {file_a} -> {list_a[i]}\n {file_b} -> {list_b[i]}" - ) +def compare_output_tokens(file1, file2): + """ + Open two JSON files (each containing a list of dictionaries), check that they have the same number + of dictionaries, sort them by 'req_idx', and then for each matching req_idx, compare the first 50 + output_tokens (or all tokens if fewer than 50). The output_tokens are stored as a comma-separated + string of integers under the "output_tokens" key. + """ + # Helper function to convert a comma-separated string of integers into a list of ints. + def parse_tokens(token_str): + return [int(tok.strip()) for tok in token_str.split(',') if tok.strip()] + + # Load both JSON files. + with open(file1, 'r') as f1, open(file2, 'r') as f2: + data1 = json.load(f1) + data2 = json.load(f2) + + # Check that both files have the same number of dictionaries. + if len(data1) != len(data2): + raise ValueError("Error: Files do not have the same number of dictionaries.") + + # Sort both lists by the 'req_idx' key. + data1_sorted = sorted(data1, key=lambda d: d['req_idx']) + data2_sorted = sorted(data2, key=lambda d: d['req_idx']) + + # Compare each pair of dictionaries. + for d1, d2 in zip(data1_sorted, data2_sorted): + req_idx1 = d1.get('req_idx') + req_idx2 = d2.get('req_idx') + + # Verify that req_idx values match. + if req_idx1 != req_idx2: + raise ValueError(f"Mismatch in req_idx: {req_idx1} vs {req_idx2}") + + # Parse the output tokens from the comma-separated strings. + tokens1 = parse_tokens(d1.get('output_tokens', '')) + tokens2 = parse_tokens(d2.get('output_tokens', '')) + + # Determine the number of tokens to compare. + num_to_compare = min(30, len(tokens1), len(tokens2)) + if tokens1[:num_to_compare] != tokens2[:num_to_compare]: + raise ValueError(f"Output tokens mismatch for req_idx {req_idx1} at idx {num_to_compare}/{len(tokens1)}:") def group_model_files(prefix): @@ -68,7 +71,7 @@ def collect_file_comparisons(): """ Yields tuples (file_a, file_b) for all pairwise comparisons among spec_infer or incr_dec files that share a model name, - plus the comparison with huggingface_.txt if it exists. + plus the comparison with huggingface_.json if it exists. """ for prefix in ["spec_infer", "incr_dec"]: grouped = group_model_files(prefix) @@ -77,28 +80,18 @@ def collect_file_comparisons(): for i in range(len(file_group)): for j in range(i+1, len(file_group)): yield file_group[i], file_group[j] - # Compare with huggingface_.txt - hf_file = os.path.join(OUTPUT_DIR, f"huggingface_{model_name}.txt") + # Compare with huggingface_.json + hf_file = os.path.join(OUTPUT_DIR, f"huggingface_{model_name}.json") if os.path.exists(hf_file) and file_group: yield file_group[0], hf_file -def _extract_llm_decoding_steps(line): - """ - Given a string like: - [Profile] guid(26516) llm_decoding_steps(69) latency(123456) - parse and return the integer after llm_decoding_steps(...). - Return None if not found. 
- """ - match = re.search(r'llm_decoding_steps\((\d+)\)', line) - return int(match.group(1)) if match else None - def collect_spec_infer_incr_dec_pairs(): """ Yields (spec_file, incr_file) for files that have the same trailing name after the prefix spec_infer- / incr_dec-. """ - all_files = glob.glob(os.path.join(OUTPUT_DIR, "*.*")) # .txt/.json + all_files = glob.glob(os.path.join(OUTPUT_DIR, "*.*")) # .json/.json spec_infer = {} for f in all_files: base = os.path.basename(f) @@ -118,7 +111,7 @@ def test_output_alignment(file_a, file_b): """ Each file pair is tested and reported separately. """ - compare_single_line(file_a, file_b) + compare_output_tokens(file_a, file_b) @@ -126,26 +119,40 @@ def test_output_alignment(file_a, file_b): ids=lambda f: os.path.basename(f)) def test_decoding_steps(spec_file, incr_file): """ - For each matching pair (same suffix), compare the first line: - "[Profile] guid(...) llm_decoding_steps(...) latency(...)" - Ensure that spec_infer's llm_decoding_steps is <= incr_dec's steps / 1.5. - """ - with open(spec_file, "r", encoding="utf-8") as fs: - spec_line = fs.readline() - with open(incr_file, "r", encoding="utf-8") as fi: - incr_line = fi.readline() - - spec_steps = _extract_llm_decoding_steps(spec_line) - incr_steps = _extract_llm_decoding_steps(incr_line) - - # If we don't have valid numbers in one or both lines, skip - if spec_steps is None or incr_steps is None: - pytest.skip(f"No valid llm_decoding_steps found in {spec_file} or {incr_file}") - - # Check ratio - if not (spec_steps <= incr_steps / 1.5): - raise AssertionError( - f"[{os.path.basename(spec_file)} vs {os.path.basename(incr_file)}] " - f"spec_infer has llm_decoding_steps={spec_steps}, which is not " - f"<= incr_dec steps={incr_steps}/1.5 = {incr_steps/1.5:.1f}" - ) \ No newline at end of file + Open two JSON files (each containing a list of dictionaries), check that they have the same number + of dictionaries, sort them by 'req_idx', and then for each matching req_idx, check that the + value of 'num_decoding_steps' in spec_file is <= the corresponding value in incr_file / 1.5 times. + """ + # Load JSON data from both files. + with open(spec_file, 'r') as f1, open(incr_file, 'r') as f2: + spec_data = json.load(f1) + inc_data = json.load(f2) + + # Verify that both files contain the same number of dictionaries. + if len(spec_data) != len(inc_data): + print("Error: Files do not have the same number of dictionaries.") + return + + # Sort both lists by the 'req_idx' key. + data1_sorted = sorted(spec_data, key=lambda d: d['req_idx']) + data2_sorted = sorted(inc_data, key=lambda d: d['req_idx']) + + # Compare each pair of dictionaries. + for d1, d2 in zip(data1_sorted, data2_sorted): + req_idx_spec = d1.get('req_idx') + req_idx_inc_dec = d2.get('req_idx') + + # Ensure the req_idx values match. + if req_idx_spec != req_idx_inc_dec: + raise ValueError(f"Mismatch in req_idx: {req_idx_spec} vs {req_idx_inc_dec}") + + # Get the num_decoding_steps values. + steps_spec = d1.get('num_decoding_steps') + steps_inc_dec = d2.get('num_decoding_steps') + + if steps_spec is None or steps_inc_dec is None: + raise ValueError(f"Missing 'num_decoding_steps' for req_idx {req_idx_spec}") + + # Check if steps1 is <= 1.5 times steps2. 
+ if not (steps_spec <= steps_inc_dec / 1.5): + raise ValueError(f"req_idx {req_idx_spec}: {steps_spec} speculation steps, which is not <= {steps_inc_dec} / 1.5") diff --git a/tests/inference_tests.sh b/tests/inference_tests.sh index 54c1884e8..120d3a58b 100755 --- a/tests/inference_tests.sh +++ b/tests/inference_tests.sh @@ -19,6 +19,25 @@ rm -rf inference/prompt inference/output inference/inf_test_configs || true # Create test prompt file mkdir -p ./inference/prompt echo '["Three tips for staying healthy are: "]' > ./inference/prompt/test.json +# cat << 'EOF' > ./inference/prompt/test.json +# [ +# "The largest ocean on Earth is", +# "The inventor of the telephone was", +# "The speed of light is", +# "The tallest mountain in the world is", +# "The first man on the moon was" +# ] +# EOF +# cat << 'EOF' > ./inference/prompt/test.json +# [ +# "In the year 2075, artificial intelligence has become deeply integrated into every aspect of human life. Autonomous robots manage infrastructure, AI-powered doctors perform complex surgeries with unmatched precision, and personalized AI assistants anticipate people's needs before they even express them. Despite these advancements, ethical concerns continue to grow. One of the most pressing debates surrounding AI development in this era is whether", +# "The rapid development of space exploration has led humanity to establish permanent settlements beyond Earth. With bases on the Moon and Mars, scientists and engineers work tirelessly to create sustainable ecosystems that can support human life in the long term. However, numerous challenges remain, from radiation exposure to psychological effects of isolation in deep space. One of the most critical issues that must be addressed before humanity can expand further into the solar system is", +# "Throughout history, scientific discoveries have continuously reshaped our understanding of the universe. The shift from a geocentric to a heliocentric model, the theory of relativity, and the advent of quantum mechanics have all challenged previous assumptions and opened new frontiers of knowledge. As we continue to explore the cosmos, scientists are now focused on solving one of the most perplexing mysteries of all: the nature of dark matter and dark energy. If researchers were to uncover definitive proof regarding their existence, it could mean that", +# "The emergence of advanced genetic engineering techniques has revolutionized modern medicine, allowing scientists to edit DNA with unprecedented precision. With technologies like CRISPR, researchers have already corrected genetic mutations that cause severe diseases and are even exploring the potential of enhancing human traits such as intelligence and longevity. However, this progress raises profound ethical concerns, as the ability to manipulate the human genome could lead to unforeseen consequences. One of the major dilemmas in the future of genetic engineering revolves around", +# "Climate change has become the defining challenge of the 21st century, with rising global temperatures, extreme weather events, and melting ice caps threatening ecosystems and human populations worldwide. Scientists and policymakers are racing against time to develop sustainable solutions, from carbon capture technologies to alternative energy sources like nuclear fusion. 
Despite these efforts, one of the biggest obstacles to achieving global climate stability is the fact that" +# ] +# EOF + # Create output folder mkdir -p ./inference/output @@ -60,8 +79,9 @@ for model_name in "${model_names[@]}"; do model_name_=$(echo "$model_name" | cut -d'/' -f2 | tr '[:upper:]' '[:lower:]') python ./tests/inference/huggingface_inference.py \ --model-name "$model_name" \ + --max-length 255 \ --prompt-file "${PWD}/inference/prompt/test.json" \ - --output-file "${PWD}/inference/output/huggingface_$model_name_.txt" + --output-file "${PWD}/inference/output/huggingface_$model_name_.json" done ############## Check alignment between results ############## diff --git a/tests/peft/alignment/align_test_utils.py b/tests/peft/alignment/align_test_utils.py index a8a9be2f3..b531ac89f 100644 --- a/tests/peft/alignment/align_test_utils.py +++ b/tests/peft/alignment/align_test_utils.py @@ -415,7 +415,7 @@ def load_hf_tensor(filename: str): return hf_tensor -def compare_loaded_tensors(hf_tensor, ff_tensor, tolerance=1e-2): +def compare_loaded_tensors(hf_tensor, ff_tensor, tolerance=1e-2, label=""): """Check whether a Huggingface and a FlexFlow tensors, both loaded to memory in the form of a numpy array, are equal Args: @@ -425,13 +425,15 @@ def compare_loaded_tensors(hf_tensor, ff_tensor, tolerance=1e-2): """ assert hf_tensor.shape == ff_tensor.shape mismatches = [] + len_hf_tensor = hf_tensor.flatten().shape[0] if not np.allclose(hf_tensor, ff_tensor, atol=tolerance): - print(f"mismatch between hf_tensor and ff_tensor") - print(f"HF: {hf_tensor}\nFF:{ff_tensor}") - print(np.isclose(hf_tensor, ff_tensor, atol=tolerance)) - mismatches = np.where(~np.isclose(hf_tensor, ff_tensor, atol=tolerance))[0] + mismatches = np.where(~np.isclose(hf_tensor.squeeze(), ff_tensor.squeeze(), atol=tolerance))[0] + label = label + ": " if label else "" + pct_mismatch = len(mismatches) / len_hf_tensor + print(f"{label} {pct_mismatch*100:.3}% mismatch between hf_tensor and ff_tensor") + print(f"HF: {hf_tensor.squeeze()}\nFF: {ff_tensor.squeeze()}") + print(np.isclose(hf_tensor.squeeze(), ff_tensor.squeeze(), atol=tolerance)) # print(mismatches) - len_hf_tensor = hf_tensor.flatten().shape[0] assert len(mismatches) <= 0.05 * len_hf_tensor print("Ok!")
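
The rewritten apply_rotary_embedding_fwd/bwd kernels above split the rotary work into q_proj_work and kv_proj_work so that, with grouped-query attention (num_kv_heads < num_q_heads), each K head is rotated exactly once instead of being addressed through the Q-head count, and they fold in the Llama-3 wavelength-based scaling of the inverse frequencies. Below is a minimal NumPy sketch of the same forward-path math, intended only as a CPU reference for sanity-checking the kernels, not as the kernel itself. It assumes the [num_tokens, tot_num_heads, proj_size] row-major view of the fused QKV buffer described in the kernel comments, and the Llama-3 scaling defaults shown (factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192) are illustrative assumptions, not values taken from this patch.

import numpy as np

def llama3_scale_inv_freq(inv_freq, factor, low_freq_factor, high_freq_factor,
                          original_max_position_embeddings):
    # Mirrors the llama3_rope branch: long-wavelength components are divided
    # by `factor`, short-wavelength ones are left untouched, and the band in
    # between is linearly interpolated via the smooth factor.
    wavelen = 2.0 * np.pi / inv_freq
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    if wavelen > low_freq_wavelen:
        return inv_freq / factor
    if wavelen < high_freq_wavelen:
        return inv_freq
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor)
    return (1.0 - smooth) * inv_freq / factor + smooth * inv_freq

def apply_rotary_reference(qkv, positions, num_q_heads, num_kv_heads, rope_theta,
                           llama3_rope=False, factor=8.0, low_freq_factor=1.0,
                           high_freq_factor=4.0,
                           original_max_position_embeddings=8192):
    # qkv: [num_tokens, num_q_heads + 2*num_kv_heads, proj_size]; flattened in
    # C order this matches the forward kernel's token-major indexing (Q heads
    # first, then K, then V). Only Q and K heads are rotated, which is exactly
    # the q_proj_work + kv_proj_work range the kernel loops over.
    num_tokens, tot_heads, proj_size = qkv.shape
    assert tot_heads == num_q_heads + 2 * num_kv_heads
    half = proj_size // 2
    out = qkv.astype(np.float64).copy()
    for t, pos in enumerate(positions):  # positions: abs_depth_in_request per token
        for h in range(num_q_heads + num_kv_heads):
            for p in range(half):
                inv_freq = 1.0 / (rope_theta ** (2.0 * p / proj_size))
                if llama3_rope:
                    inv_freq = llama3_scale_inv_freq(
                        inv_freq, factor, low_freq_factor, high_freq_factor,
                        original_max_position_embeddings)
                angle = inv_freq * pos
                re, im = out[t, h, p], out[t, h, p + half]
                # complex multiply (re + i*im) * (cos + i*sin), written back in place
                out[t, h, p] = re * np.cos(angle) - im * np.sin(angle)
                out[t, h, p + half] = re * np.sin(angle) + im * np.cos(angle)
    return out

Writing the loop this way also makes the smaller scratch-buffer requirement visible: only (num_q_heads + num_kv_heads) * proj_size / 2 complex values are needed per token, which is why complex_size in IncMultiHeadSelfAttentionMeta now uses kProjSize * num_kv_heads instead of kProjSize * num_q_heads.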
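With this change, RequestManager::save_output_to_json and the updated tests/inference/huggingface_inference.py both write a JSON array of per-request records instead of the old free-form text, and tests/inference/test_inference_output.py joins the files on req_idx to compare output_tokens prefixes and num_decoding_steps. The snippet below is a short sketch of how such a file can be inspected, assuming an inference-only run; the path is a placeholder, and the latency/ttft fields are taken to be milliseconds since the C++ writer divides microsecond timestamps by 1e3.

import json

# Placeholder path: any of the *.json files written under inference/output/ by
# an inference run (incr_dec-*, spec_infer-*, or huggingface_*).
with open("inference/output/example_run.json") as f:
    records = json.load(f)

for r in sorted(records, key=lambda d: d["req_idx"]):
    if r.get("req_type") != "inference":
        continue  # finetuning records carry losses/steps instead of tokens
    # output_tokens is stored as a comma-separated string of token ids
    tokens = [int(t) for t in r["output_tokens"].split(",") if t.strip()]
    print(r["req_idx"], r["num_decoding_steps"], len(tokens), repr(r["response"][:60]))

The shared record shape is also what lets test_decoding_steps require, per request, that spec_infer's num_decoding_steps be at most incr_dec's divided by 1.5.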