From 0a889aeb2c735d1ebc283bc949fd83785112d596 Mon Sep 17 00:00:00 2001
From: Gabriele Oliaro
Date: Sat, 8 Mar 2025 23:04:52 +0000
Subject: [PATCH] fix bugs

---
 include/flexflow/request_manager.h           |  3 +-
 inference/flexllm/peft_train.cc              | 34 +++++-----
 inference/models/llama.cc                    |  3 +-
 inference/peft/peft.cc                       | 22 +++----
 inference/python/ff_peft.py                  |  2 +-
 inference/python/peft_demo/demo.ipynb        |  8 +--
 inference/python/peft_demo/demo.py           |  6 +-
 python/flexflow/core/flexflow_cffi.py        |  4 +-
 src/c/flexflow_c.cc                          |  2 +-
 src/ops/inc_multihead_self_attention.cu      |  2 +-
 src/ops/kernels/residual_rms_norm_kernels.cu |  2 +-
 src/ops/kernels/rms_norm_kernels.cu          |  4 +-
 src/runtime/request_manager.cc               | 66 +++++++++++++-------
 13 files changed, 92 insertions(+), 66 deletions(-)

diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h
index e7809aa3b..5cfb8e485 100644
--- a/include/flexflow/request_manager.h
+++ b/include/flexflow/request_manager.h
@@ -98,7 +98,7 @@ struct Request {
   struct PeftFinetuningInfo {
     FinetuningStatus status = FORWARD_PHASE;
     std::string dataset_filepath;
-    int max_training_steps = 1;
+    int max_training_epochs = 1;
     // overall state
     int completed_training_steps = 0;
     // fwd state
@@ -456,6 +456,7 @@ class RequestManager {
     double start_time, finish_time;
     double registration_time, first_token_time;
     bool first_token_time_set = false;
+    int num_evictions = 0;
   };
   std::unordered_map profiling_requests;
   double total_request_run_time;
diff --git a/inference/flexllm/peft_train.cc b/inference/flexllm/peft_train.cc
index 689af5417..f1e39b58e 100644
--- a/inference/flexllm/peft_train.cc
+++ b/inference/flexllm/peft_train.cc
@@ -53,7 +53,7 @@ void parse_input_args(char **argv,
                       int &max_tokens_per_batch,
                       int &max_sequence_length,
                       int &num_kv_cache_slots,
-                      int &max_training_steps,
+                      int &max_training_epochs,
                       int &num_layers_per_finetuning_step,
                       bool &run_warmup) {
   for (int i = 1; i < argc; i++) {
@@ -144,7 +144,7 @@ void parse_input_args(char **argv,
       continue;
     }
     if (!strcmp(argv[i], "--max-training-steps")) {
-      max_training_steps = std::stoi(argv[++i]);
+      max_training_epochs = std::stoi(argv[++i]);
       continue;
     }
     if (!strcmp(argv[i], "--num-layers-per-finetuning-step")) {
@@ -183,7 +183,8 @@ std::vector make_warmup_requests(int num_inf_request,
   finetuning_req.warmup = true;
   finetuning_req.peft_model_id =
       (peft_model_id != nullptr) ? *peft_model_id : PEFTModelID::NO_ID;
-  finetuning_req.peft_finetuning_info.max_training_steps = num_finetuning_steps;
+  finetuning_req.peft_finetuning_info.max_training_epochs =
+      num_finetuning_steps;
   warmup_requests.push_back(finetuning_req);
   return warmup_requests;
 }
@@ -229,7 +230,12 @@ std::vector load_trace(nlohmann::ordered_json prompt_json,
 std::vector load_requests(std::string prompt_file_path,
                           int max_length_if_needed) {
   std::ifstream file_handle(prompt_file_path);
-  assert(!file_handle.good() && "Error opening prompt file!");
+  if (!file_handle.good()) {
+    std::cerr << "Error opening prompt file " << prompt_file_path << std::endl;
+    std::cerr << "Current working directory: "
+              << std::filesystem::current_path() << std::endl;
+    assert(!file_handle.good() && "Error opening prompt file!");
+  }
   nlohmann::ordered_json prompt_json;
   try {
     prompt_json = nlohmann::ordered_json::parse(file_handle,
@@ -248,9 +254,9 @@ std::vector load_requests(std::string prompt_file_path,
     std::cerr << "Error: JSON file is null!" << std::endl;
     assert(false);
   } else if (prompt_json.is_array()) {
-    return load_prompt_list(prompt_file_path, max_length_if_needed);
+    return load_prompt_list(prompt_json, max_length_if_needed);
   } else if (prompt_json.is_object()) {
-    return load_trace(prompt_file_path);
+    return load_trace(prompt_json);
   } else {
     std::cerr << "JSON is neither an array nor an object!" << std::endl;
     assert(false);
   }
@@ -277,7 +283,7 @@ void FlexFlow::top_level_task(Task const *task,
   int max_requests_per_batch = 1;
   int max_tokens_per_batch = 128;
   int max_sequence_length = 256;
-  int max_training_steps = 2;
+  int max_training_epochs = 2;
   bool enable_peft_finetuning = true;
   int num_layers_per_finetuning_step = -1;
   bool run_warmup = false;
@@ -301,7 +307,7 @@ void FlexFlow::top_level_task(Task const *task,
                    max_tokens_per_batch,
                    max_sequence_length,
                    num_kv_cache_slots,
-                   max_training_steps,
+                   max_training_epochs,
                    num_layers_per_finetuning_step,
                    run_warmup);
   enable_peft_finetuning = file_paths.dataset_file_path.empty() ? false : true;
@@ -386,8 +392,7 @@ void FlexFlow::top_level_task(Task const *task,
   // load PEFT config
   int rank = 16;
   LoraOptimizerConfig *optim_config = new LoraSGDOptimizerConfig(0.001f);
-  std::vector target_modules = {
-      "qkv_proj", "o_proj", "gate_proj", "down_proj", "up_proj"};
+  std::vector target_modules = {"down_proj"};
   LoraLinearConfig peft_config_finetuning(file_paths.cache_folder_path,
                                           peft_model_name,
                                           true /*trainable*/,
@@ -480,20 +485,17 @@ void FlexFlow::top_level_task(Task const *task,
       load_requests(file_paths.prompt_file_path, 128);
 
   // Add fine-tuning request
-  assert(!file_paths.dataset_file_path.empty() && "Dataset file path is required for fine-tuning.");
   printf("Finetuning request with dataset %s\n",
          file_paths.dataset_file_path.c_str());
 
   Request fine_tuning_req;
   fine_tuning_req.req_type = RequestType::REQ_FINETUNING;
-  fine_tuning_req.peft_model_id = (peft_model_id_finetuning != nullptr)
-                                      ? *peft_model_id_finetuning
-                                      : PEFTModelID::NO_ID;
+  fine_tuning_req.peft_model_id = *peft_model_id_finetuning;
   fine_tuning_req.peft_finetuning_info.dataset_filepath =
       file_paths.dataset_file_path;
-  fine_tuning_req.peft_finetuning_info.max_training_steps =
-      max_training_steps;
+  fine_tuning_req.peft_finetuning_info.max_training_epochs =
+      max_training_epochs;
   requests.push_back(fine_tuning_req);
 
   std::cout << "----------inference started--------------" << std::endl;
diff --git a/inference/models/llama.cc b/inference/models/llama.cc
index 408ad4e38..37d6f444b 100644
--- a/inference/models/llama.cc
+++ b/inference/models/llama.cc
@@ -36,7 +36,8 @@ void LLAMA::create_llama_model(FFModel &ff,
     assert(false && "The number of attention heads is smaller, or it is not "
                     "divisible by the tensor parallelism degree");
   }
-
+  std::cout << "Creating llama model with ff.config.enable_peft_finetuning="
+            << ff.config.enable_peft_finetuning << std::endl;
   assert(llama_config.hidden_size % llama_config.num_attention_heads == 0 &&
          "Hidden size not divisible by number of attention heads");
   int head_dim = llama_config.hidden_size / llama_config.num_attention_heads;
diff --git a/inference/peft/peft.cc b/inference/peft/peft.cc
index 9f52650d9..3f23ecfb4 100644
--- a/inference/peft/peft.cc
+++ b/inference/peft/peft.cc
@@ -53,7 +53,7 @@ void parse_input_args(char **argv,
                       int &max_tokens_per_batch,
                       int &max_sequence_length,
                       int &num_kv_cache_slots,
-                      int &max_training_steps,
+                      int &max_training_epochs,
                       int &num_layers_per_finetuning_step,
                       bool &run_warmup) {
   for (int i = 1; i < argc; i++) {
@@ -87,7 +87,7 @@ void parse_input_args(char **argv,
       continue;
     }
     // dataset for finetuning
-    if (!strcmp(argv[i], "")) {
+    if (!strcmp(argv[i], "-finetuning-dataset")) {
       paths.dataset_file_path = std::string(argv[++i]);
       continue;
     }
@@ -144,7 +144,7 @@ void parse_input_args(char **argv,
       continue;
     }
     if (!strcmp(argv[i], "--max-training-steps")) {
-      max_training_steps = std::stoi(argv[++i]);
+      max_training_epochs = std::stoi(argv[++i]);
       continue;
     }
     if (!strcmp(argv[i], "--num-layers-per-finetuning-step")) {
@@ -183,7 +183,8 @@ std::vector make_warmup_requests(int num_inf_request,
   finetuning_req.warmup = true;
   finetuning_req.peft_model_id =
       (peft_model_id != nullptr) ? *peft_model_id : PEFTModelID::NO_ID;
-  finetuning_req.peft_finetuning_info.max_training_steps = num_finetuning_steps;
+  finetuning_req.peft_finetuning_info.max_training_epochs =
+      num_finetuning_steps;
   warmup_requests.push_back(finetuning_req);
   return warmup_requests;
 }
@@ -207,7 +208,7 @@ void FlexFlow::top_level_task(Task const *task,
   int max_requests_per_batch = 1;
   int max_tokens_per_batch = 128;
   int max_sequence_length = 256;
-  int max_training_steps = 2;
+  int max_training_epochs = 2;
   bool enable_peft_finetuning = true;
   int num_layers_per_finetuning_step = -1;
   bool run_warmup = false;
@@ -231,7 +232,7 @@ void FlexFlow::top_level_task(Task const *task,
                    max_tokens_per_batch,
                    max_sequence_length,
                    num_kv_cache_slots,
-                   max_training_steps,
+                   max_training_epochs,
                    num_layers_per_finetuning_step,
                    run_warmup);
 
@@ -357,6 +358,8 @@ void FlexFlow::top_level_task(Task const *task,
   rm->set_enable_peft_finetuning(enable_peft_finetuning);
 
   FFModel model(ffconfig, ffconfig.cpu_offload);
+  model.set_num_kv_cache_pages(compute_num_kv_cache_pages_needed(
+      max_sequence_length, max_requests_per_batch, false));
   if (model_type == ModelType::LLAMA) {
     LLAMA::create_llama_model(model,
                               config_filepath,
@@ -394,9 +397,6 @@ void FlexFlow::top_level_task(Task const *task,
     assert(false && "unknow model type");
   }
 
-  model.set_num_kv_cache_pages(compute_num_kv_cache_pages_needed(
-      max_sequence_length, max_requests_per_batch, false));
-
   rm->set_num_transformer_layers(model.current_transformer_layer_id + 1);
   if (num_layers_per_finetuning_step > 0) {
     rm->set_num_layers_per_finetuning_step(num_layers_per_finetuning_step);
@@ -464,8 +464,8 @@ void FlexFlow::top_level_task(Task const *task,
                                         : PEFTModelID::NO_ID;
     fine_tuning_req.peft_finetuning_info.dataset_filepath =
         file_paths.dataset_file_path;
-    fine_tuning_req.peft_finetuning_info.max_training_steps =
-        max_training_steps;
+    fine_tuning_req.peft_finetuning_info.max_training_epochs =
+        max_training_epochs;
     requests.push_back(fine_tuning_req);
   }
   std::vector result = model.generate(requests);
diff --git a/inference/python/ff_peft.py b/inference/python/ff_peft.py
index 2adf5367b..0ad6900b7 100644
--- a/inference/python/ff_peft.py
+++ b/inference/python/ff_peft.py
@@ -165,7 +165,7 @@ def main():
         ff.RequestType.REQ_FINETUNING,
         peft_model_id=llm.get_ff_peft_id(lora_finetuning_config),
         dataset_filepath=configs.finetuning_dataset,
-        max_training_steps=2,
+        max_training_epochs=2,
     )
     requests.append(finetuning_request)
 
diff --git a/inference/python/peft_demo/demo.ipynb b/inference/python/peft_demo/demo.ipynb
index ea2b8417b..f3cd113a4 100644
--- a/inference/python/peft_demo/demo.ipynb
+++ b/inference/python/peft_demo/demo.ipynb
@@ -97,7 +97,7 @@
     " \"max_requests_per_batch\": 1,\n",
     " \"max_sequence_length\": 128,\n",
     " \"max_tokens_per_batch\": 128,\n",
-    " \"max_training_steps\": 100,\n",
+    " \"max_training_epochs\": 100,\n",
     " \"seed\": 42,\n",
     "}\n",
     "model_configs = {\n",
@@ -1082,7 +1082,7 @@
     " max_sequence_length=configs.max_sequence_length,\n",
     " peft_model_id=llm.get_ff_peft_id(lora_finetuning_config),\n",
     " dataset_filepath=os.path.join(os.getcwd(), configs.finetuning_dataset),\n",
-    " max_training_steps=configs.max_training_steps,\n",
+    " max_training_epochs=configs.max_training_epochs,\n",
     ")\n",
     "ft_res = llm.generate([finetuning_request])"
    ]
  },
@@ -1104,7 +1104,7 @@
    }
   ],
   "source": [
-    "epochs = list(range(configs_dict[\"max_training_steps\"]))\n",
+    "epochs = list(range(configs_dict[\"max_training_epochs\"]))\n",
     "loss_values = ft_res[0].finetuning_losses\n",
     "\n",
"plt.figure(figsize=(10, 6))\n", @@ -1778,7 +1778,7 @@ " \"max_requests_per_batch\": 1,\n", " \"max_sequence_length\": 128,\n", " \"max_tokens_per_batch\": 128,\n", - " \"max_training_steps\": 100,\n", + " \"max_training_epochs\": 100,\n", " \"seed\": 42,\n", "}\n", "model_configs = {\n", diff --git a/inference/python/peft_demo/demo.py b/inference/python/peft_demo/demo.py index da64336e7..fb63d89f7 100644 --- a/inference/python/peft_demo/demo.py +++ b/inference/python/peft_demo/demo.py @@ -52,7 +52,7 @@ def create_datasets(finetune_dataset_size=2, inference_file_path='inference_data "max_requests_per_batch": 1, "max_sequence_length": 128, "max_tokens_per_batch": 128, - "max_training_steps": 100, + "max_training_epochs": 100, "seed": 42, } model_configs = { @@ -185,7 +185,7 @@ def create_datasets(finetune_dataset_size=2, inference_file_path='inference_data max_sequence_length=configs.max_sequence_length, peft_model_id=llm.get_ff_peft_id(lora_finetuning_config), dataset_filepath=os.path.join(os.getcwd(), configs.finetuning_dataset), - max_training_steps=configs.max_training_steps, + max_training_epochs=configs.max_training_epochs, ) ft_res = llm.generate([finetuning_request]) for res in ft_res: @@ -231,7 +231,7 @@ def create_datasets(finetune_dataset_size=2, inference_file_path='inference_data print("==Inference result after finetuning: ", inf_req_res_2[0].output_text) -epochs = list(range(configs_dict["max_training_steps"])) +epochs = list(range(configs_dict["max_training_epochs"])) loss_values = ft_res[0].finetuning_losses plt.figure(figsize=(10, 6)) diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index 469453c75..6f5bb685b 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -2136,7 +2136,7 @@ class Request: add_special_tokens: bool = True peft_model_id: Optional[PEFTModelID] = None dataset_filepath: Optional[str] = None - max_training_steps: int = 1 + max_training_epochs: int = 1 # ----------------------------------------------------------------------- @@ -4492,7 +4492,7 @@ def generate(self, requests_list: List[Request]): dataset_filepaths = [ get_c_name(request.dataset_filepath) for request in requests_list ] - training_steps = [request.max_training_steps for request in requests_list] + training_steps = [request.max_training_epochs for request in requests_list] num_finetuning_losses = ffi.new("int *") # c_finetuning_losses = ffi.new("float**") # TODO: set this value automatically diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 5abde0579..5a28d1680 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1604,7 +1604,7 @@ void flexflow_model_generate(flexflow_model_t handle_, } std::string const dataset_fp(dataset_filepaths[i]); fine_tuning_req.peft_finetuning_info.dataset_filepath = dataset_fp; - fine_tuning_req.peft_finetuning_info.max_training_steps = + fine_tuning_req.peft_finetuning_info.max_training_epochs = training_steps[i]; requests.push_back(fine_tuning_req); DEBUG_PRINT("[Model] finetune[%d] %p %s %i %i %i %i", diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 584083083..27c9d1fec 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -1388,7 +1388,7 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, assert(num_tokens == num_total_tokens); assert(num_total_tokens == bc->requestsInfo[i].max_length); assert(m->qProjSize == m->kProjSize && m->kProjSize == 
-    assert(bc->requestsInfo[i].first_token_offset_in_batch == 0);
+    // assert(bc->requestsInfo[i].first_token_offset_in_batch == 0);
 
     if (m->inference_debugging) {
       // save result to file for checking
diff --git a/src/ops/kernels/residual_rms_norm_kernels.cu b/src/ops/kernels/residual_rms_norm_kernels.cu
index a576bad33..5cb7a0239 100644
--- a/src/ops/kernels/residual_rms_norm_kernels.cu
+++ b/src/ops/kernels/residual_rms_norm_kernels.cu
@@ -429,7 +429,7 @@ void peft_bwd_kernel(ResidualRMSNormMeta const *m,
       bc->peft_bwd_applies_to_this_layer(m->layer_guid.transformer_layer_id));
   int i = bc->finetuning_request_index();
 
-  int M = bc->requestsInfo[i].num_tokens_in_batch;
+  int M = bc->num_finetuning_bwd_tokens();
   int N = m->in_dim;
 
   T const *residual_output_rms_input_ptr =
diff --git a/src/ops/kernels/rms_norm_kernels.cu b/src/ops/kernels/rms_norm_kernels.cu
index 928770616..1e1936b23 100644
--- a/src/ops/kernels/rms_norm_kernels.cu
+++ b/src/ops/kernels/rms_norm_kernels.cu
@@ -448,8 +448,8 @@ void peft_bwd_kernel(RMSNormMeta const *m,
       bc->peft_bwd_applies_to_this_layer(m->layer_guid.transformer_layer_id));
   int i = bc->finetuning_request_index();
 
-  int M = bc->requestsInfo[i].num_tokens_in_batch;
-  int N = m->num_elements;
+  int M = bc->num_finetuning_bwd_tokens();
+  int N = m->in_dim;
   ComputeInternalGradientsCUDAKernel
       <<>>(
           N,
diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index beef76497..e403d40ef 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -104,9 +104,10 @@ bool RequestManager::load_request_token_ids(Request &request) {
       request.max_length = request.tokens.size() + request.max_new_tokens;
     }
     if (request.max_length >= get_max_sequence_length()) {
-      std::cout << "Error: max_length (" << request.max_length
+      std::cerr << "Error: max_length (" << request.max_length
                 << ") exceeds max sequence length of "
                 << get_max_sequence_length() << ".\n";
+      assert(false);
       return false;
     }
   } else {
@@ -118,9 +119,10 @@ bool RequestManager::load_request_token_ids(Request &request) {
     // check that max sequence length is not exceeded
     // 1. prompt itself should be less than max sequence length
     if (tokens.size() >= get_max_sequence_length()) {
-      std::cout << "Error: prompt (" << tokens.size()
+      std::cerr << "Error: prompt (" << tokens.size()
                << " tokens) exceeds max sequence length of "
                << get_max_sequence_length() << ".\n";
+      assert(false);
       return false;
     }
     // 2. max_length should not exceed the max_sequence_length
@@ -128,6 +130,7 @@ bool RequestManager::load_request_token_ids(Request &request) {
       std::cout << "Error: max_length (" << request.max_length
                 << ") exceeds max sequence length of "
                 << get_max_sequence_length() << ".\n";
+      assert(false);
       return false;
     }
     // for (int i = 0; i < tokens.size(); i++) {
@@ -170,13 +173,13 @@ bool RequestManager::load_request_token_ids(Request &request) {
               << "Creating dataset with benchmarking tokens. Size of dataset: "
Size of dataset: " << request.dataset.size() << std::endl; } else { - using json = nlohmann::json; std::ifstream file_handle(request.peft_finetuning_info.dataset_filepath); assert(file_handle.good() && "Dataset file does not exist."); - json dataset_json = json::parse(file_handle, - /*parser_callback_t */ nullptr, - /*allow_exceptions */ true, - /*ignore_comments */ true); + nlohmann::ordered_json dataset_json = + nlohmann::ordered_json::parse(file_handle, + /*parser_callback_t */ nullptr, + /*allow_exceptions */ true, + /*ignore_comments */ true); for (auto &prompt : dataset_json) { std::string text = prompt.get(); @@ -187,10 +190,11 @@ bool RequestManager::load_request_token_ids(Request &request) { input_tokens.insert(input_tokens.begin(), bos_token_id); } if (input_tokens.size() > get_max_sequence_length()) { - std::cout << "Error: sample in training dataset is " + std::cerr << "Error: sample in training dataset is " << input_tokens.size() << " tokens long, exceeding the maximum sequence length of " << get_max_sequence_length() << " tokens.\n"; + assert(false); return false; } else { request.dataset.push_back(input_tokens); @@ -207,7 +211,8 @@ bool RequestManager::load_request_token_ids(Request &request) { assert(request.peft_finetuning_info.gradient_accumulation_steps > 0 && "Invalid gradient accumulation steps"); assert(request.peft_finetuning_info.gradient_accumulation_steps <= - request.peft_finetuning_info.max_training_steps && + request.peft_finetuning_info.max_training_epochs * + request.dataset.size() && "Gradient accumulation steps should be less than or equal to max " "training steps"); assert(get_num_ssms() == 0 && "Small speculative models not supported for " @@ -245,8 +250,8 @@ std::ostream &operator<<(std::ostream &os, Request const &req) { os << " status: " << req.peft_finetuning_info.status << "\n"; os << " dataset_filepath: " << req.peft_finetuning_info.dataset_filepath << "\n"; - os << " max_training_steps: " - << req.peft_finetuning_info.max_training_steps << "\n"; + os << " max_training_epochs: " + << req.peft_finetuning_info.max_training_epochs << "\n"; os << " completed_training_steps: " << req.peft_finetuning_info.completed_training_steps << "\n"; os << " dataset_entry_processed_tokens: " @@ -934,8 +939,8 @@ void RequestManager::process_inf_req_progress(BatchConfig const &old_fwd_bc, record_decoding_req_profiling_info(old_fwd_bc, req_idx); if (inf_req_completed(old_fwd_bc, req_idx)) { - printf("Request %zu completed...\n", - old_fwd_bc.requestsInfo[req_idx].request_guid); + // printf("Request %zu completed...\n", + // old_fwd_bc.requestsInfo[req_idx].request_guid); handle_completed_inf_req(old_fwd_bc, req_idx); } } @@ -967,16 +972,19 @@ void RequestManager::handle_completed_inf_req(BatchConfig const &old_bc, gr.output_text = output_text; request.status = Request::COMPLETED; trigger_request_completion_future(request.guid); - log_req_mgr.print("[Done] guid(%zu) initial_len(%d) final_length(%zu)", - old_bc.requestsInfo[i].request_guid, - request.initial_len, - gr.input_tokens.size() + gr.output_tokens.size()); // log_req_mgr.print("Final output: %s", output.c_str()); num_processed_requests++; ProfileInfo profile_info = profiling_requests[request.guid]; profile_info.finish_time = Realm::Clock::current_time_in_microseconds(); total_request_run_time += profile_info.finish_time - profile_info.start_time; profiling_requests[request.guid] = profile_info; + log_req_mgr.print("Inference req %zu completed. Length: %lu (prompt=%lu, " + "response=%lu). 
Evicted %i times.", + old_bc.requestsInfo[i].request_guid, + gr.input_tokens.size() + gr.output_tokens.size(), + gr.input_tokens.size(), + gr.output_tokens.size(), + profile_info.num_evictions); } void RequestManager::evict_requests_if_needed(BatchConfig const &old_bc, @@ -1057,6 +1065,7 @@ void RequestManager::evict_requests_if_needed(BatchConfig const &old_bc, // printf("Pending infr request queue size: %zu -> %zu\n", before, after); // Remove the evicted request from planned_tokens_per_request // before = planned_tokens_per_request.size(); + profiling_requests[request_to_evict].num_evictions += 1; planned_tokens_per_request.erase( std::remove_if( planned_tokens_per_request.begin(), @@ -1345,7 +1354,8 @@ void RequestManager::add_finetuning_req_fwd_batch(BatchConfig &new_bc) { // set_optimizer_tasks( // new_bc.requestsInfo[inference_batch_size].optimizer_tasks, - // request.peft_finetuning_info.max_training_steps, + // request.peft_finetuning_info.max_training_epochs * + // request.dataset.size(), // request.peft_finetuning_info.completed_training_steps, // request.peft_finetuning_info.gradient_accumulation_steps); @@ -1433,7 +1443,8 @@ void RequestManager::add_finetuning_req_bwd_batch(BatchConfig &new_bc) { new_bc.requestsInfo[inference_batch_size].peft_bwd_first_layer); set_optimizer_tasks(new_bc.requestsInfo[inference_batch_size].optimizer_tasks, - request.peft_finetuning_info.max_training_steps, + request.peft_finetuning_info.max_training_epochs * + request.dataset.size(), request.peft_finetuning_info.completed_training_steps, request.peft_finetuning_info.gradient_accumulation_steps); @@ -1485,7 +1496,8 @@ void RequestManager::process_finetuning_req_fwd_progress( assert(request.req_type == RequestType::REQ_FINETUNING && "Found misplaced inference request"); assert(request.peft_finetuning_info.completed_training_steps <= - request.peft_finetuning_info.max_training_steps); + request.peft_finetuning_info.max_training_epochs * + request.dataset.size()); assert(request.guid == old_bc.requestsInfo[inference_batch_size].request_guid && "Request GUID mismatch"); @@ -1560,8 +1572,18 @@ void RequestManager::process_finetuning_req_bwd_progress( request.peft_finetuning_info.status = Request::FORWARD_PHASE; request.peft_finetuning_info.last_processed_bwd_layer = INT_MAX; } - if (request.peft_finetuning_info.completed_training_steps == - request.peft_finetuning_info.max_training_steps || + // print status update after each epoch + int tot_steps = + request.peft_finetuning_info.max_training_epochs * request.dataset.size(); + if (request.peft_finetuning_info.completed_training_steps % + ((int)request.dataset.size()) == + 0) { + log_req_mgr.print("Completed finetuning epoch %i/%i", + request.peft_finetuning_info.completed_training_steps / + ((int)request.dataset.size()), + request.peft_finetuning_info.max_training_epochs); + } + if (request.peft_finetuning_info.completed_training_steps == tot_steps || inference_finished) { handle_completed_finetuning_req(old_bc); }