diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index 8af9a08c..704afced 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -134,6 +134,7 @@ def _process_model_outputs(self, skip) = ctx.output_queue.popleft() # Filter out outputs of migrating requests. + server_infos = [] if outputs: new_outputs = [] new_scheduled_seq_groups = [] @@ -145,6 +146,7 @@ def _process_model_outputs(self, new_scheduled_seq_groups.append(scheduled_seq_group) new_seq_group_metadata_list.append(seq_group_meta) new_outputs.append(seq_group_output) + server_infos.append(seq_group.server_info) scheduler_outputs.scheduled_seq_groups = new_scheduled_seq_groups outputs[0].outputs = new_outputs seq_group_metadata_list = new_seq_group_metadata_list @@ -156,11 +158,24 @@ def _process_model_outputs(self, ctx.output_queue.appendleft((outputs, seq_group_metadata_list, scheduler_outputs, is_async, is_last_step, is_first_step_output, skip)) - return super()._process_model_outputs(ctx, request_id) + for server_info in server_infos: + if hasattr(server_info, 'request_timestamps'): + server_info.request_timestamps.engine_process_model_outputs_timestamp_begin = time.time() + + super()._process_model_outputs(ctx, request_id) + + if ctx.request_outputs: + request_outputs, server_infos = zip(*ctx.request_outputs) + for request_output, server_info in zip(request_outputs, server_infos): + if hasattr(server_info, 'request_timestamps'): + request_output.request_timestamps = server_info.request_timestamps + request_output.request_timestamps.engine_process_model_outputs_timestamp_end = time.time() + + return def _process_request_outputs( self, - outputs: List[Tuple[RequestOutput,ServerInfo]], + outputs: List[Tuple[RequestOutput, ServerInfo]], step_begin_time: float ) -> Tuple[List[RequestOutput], List[ServerInfo]]: request_outputs = [] @@ -169,14 +184,16 @@ def _process_request_outputs( request_outputs, server_infos = zip(*outputs) request_outputs = list(request_outputs) server_infos = list(server_infos) - for request_output in request_outputs: - if request_output.finished: - logger.info("engine finished request {}".format(request_output.request_id)) + for request_output in request_outputs: if hasattr(request_output, 'request_timestamps'): request_output.request_timestamps.engine_step_timestamp_begin = step_begin_time request_output.request_timestamps.engine_step_timestamp_end = time.time() + for request_output in request_outputs: + if request_output.finished: + logger.info("engine finished request {}".format(request_output.request_id)) + instance_info: InstanceInfo = self.instance_info instance_info.instance_id = self.instance_id instance_info.step_id = next(self.step_counter)