Commit 6fcce3f: callback added

megha95 committed Aug 5, 2024
1 parent: 0ee10ca
Showing 8 changed files with 33 additions and 23 deletions.
29 changes: 14 additions & 15 deletions vllm/engine/llm_engine.py
@@ -43,12 +43,12 @@
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
+from vllm import utils
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5


def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
config = try_get_generation_config(
model_config.model,
@@ -357,6 +357,7 @@ def __init__(
self.previous_output = None
self.previous_scheduler_outputs = None
self.previous_seq_group_metadata_list = None
+        self.request_outputs = None

def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@@ -803,11 +804,7 @@ def _process_sequence_group_outputs(
return

def _process_model_outputs(
-        self,
-        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
-        scheduled_seq_groups: List[ScheduledSequenceGroup],
-        ignored_seq_groups: List[SequenceGroup],
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        self
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Apply the model output to the sequences in the scheduled seq groups.
@@ -816,6 +813,11 @@

now = time.time()

+        scheduled_seq_groups = self.previous_scheduler_outputs.scheduled_seq_groups
+        ignored_seq_groups = self.previous_scheduler_outputs.ignored_seq_groups
+        output = self.previous_output
+        seq_group_metadata_list = self.previous_seq_group_metadata_list

# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group(
@@ -851,7 +853,8 @@
for seq_group in ignored_seq_groups:
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
-        return request_outputs
+        self.request_outputs = request_outputs
+        return

def _advance_to_next_step(
self,
@@ -931,12 +934,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
raise NotImplementedError(
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise.")
-        request_outputs = None
-        if (self.previous_output) and (len(self.previous_output) > 0):
-            request_outputs = self._process_model_outputs(
-                self.previous_output, self.previous_scheduler_outputs.scheduled_seq_groups,
-                self.previous_scheduler_outputs.ignored_seq_groups, self.previous_seq_group_metadata_list)
-
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()

@@ -952,10 +949,12 @@
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids)
output = self.model_executor.execute_model(
-                execute_model_req=execute_model_req)
+                execute_model_req=execute_model_req, callback_fn=self._process_model_outputs)
else:
output = []

+        # hack to avoid callback function for first step
+        utils.flag_for_callback_fn = True
self.previous_output = output
self.previous_scheduler_outputs = scheduler_outputs
self.previous_seq_group_metadata_list = seq_group_metadata_list
@@ -978,7 +977,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# queued control plane messages, such as add/remove lora adapters.
self.model_executor.stop_remote_worker_execution_loop()

-        return request_outputs
+        return self.request_outputs

def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
if logger_name in self.stat_loggers:
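Read together, the llm_engine.py changes defer output processing by one step: `_process_model_outputs` now takes no arguments and pulls the previous step's raw output and scheduler state from attributes that `step()` stashes, it writes its result to `self.request_outputs` instead of returning it, and `step()` hands the bound method to the executor as `callback_fn`. Below is a minimal, self-contained sketch of that control flow; every class and name in it is an illustrative stand-in rather than the real vLLM API.

```python
# Sketch only: toy stand-ins for the engine and executor, showing the
# deferred output-processing pattern this commit introduces.
from typing import Callable, List, Optional


class ToyExecutor:
    """Stand-in for the model executor; it invokes the callback during execution."""

    def execute_model(self, scheduled: List[str],
                      callback_fn: Optional[Callable[[], None]] = None) -> List[str]:
        # In the real code the model runner fires the callback between the
        # forward pass and sampling; the sketch just calls it here.
        if callback_fn is not None:
            callback_fn()
        return [f"raw-output({req})" for req in scheduled]


class ToyEngine:
    def __init__(self) -> None:
        self.executor = ToyExecutor()
        self.previous_output: Optional[List[str]] = None
        self.request_outputs: Optional[List[str]] = None

    def _process_model_outputs(self) -> None:
        # Mirrors the new no-argument signature: everything comes from state
        # saved by the previous step, and the result lands on the engine.
        if self.previous_output:
            self.request_outputs = [out.upper() for out in self.previous_output]

    def step(self, scheduled: List[str]) -> Optional[List[str]]:
        # The previous step's outputs are processed via the callback while the
        # current step's model execution is underway.
        output = self.executor.execute_model(
            scheduled, callback_fn=self._process_model_outputs)
        self.previous_output = output
        return self.request_outputs


engine = ToyEngine()
print(engine.step(["req-1"]))  # None: nothing to process on the first step
print(engine.step(["req-2"]))  # processed outputs of step 1, one step late
```

The apparent intent is to overlap CPU-side output processing (detokenization, building RequestOutput objects) for step N with the model execution of step N+1, at the visible cost that `step()` now returns outputs that lag the current schedule by one iteration.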
5 changes: 3 additions & 2 deletions vllm/executor/distributed_gpu_executor.py
@@ -65,15 +65,16 @@ def initialize_cache(self, num_gpu_blocks: int,

def execute_model(
self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        execute_model_req: ExecuteModelRequest,
+        callback_fn = None) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_tensor_parallel_workers_only=True,
**self.extra_execute_model_run_workers_kwargs)

# Only the driver worker returns the sampling results.
-        driver_outputs = self._driver_execute_model(execute_model_req)
+        driver_outputs = self._driver_execute_model(execute_model_req, callback_fn)
assert driver_outputs is not None
return driver_outputs

3 changes: 2 additions & 1 deletion vllm/executor/executor_base.py
@@ -75,7 +75,8 @@ def initialize_cache(self, num_gpu_blocks: int,

@abstractmethod
def execute_model(
-        self, execute_model_req: ExecuteModelRequest
+        self, execute_model_req: ExecuteModelRequest,
+        callback_fn = None
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences."""
raise NotImplementedError
5 changes: 3 additions & 2 deletions vllm/executor/multiproc_gpu_executor.py
@@ -145,14 +145,15 @@ def shutdown(self):
worker_monitor.close()

def _driver_execute_model(
-        self, execute_model_req: Optional[ExecuteModelRequest],
+        self, execute_model_req: Optional[ExecuteModelRequest],
+        callback_fn = None
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
-        return self.driver_worker.execute_model(execute_model_req)
+        return self.driver_worker.execute_model(execute_model_req, callback_fn)

def _run_workers(
self,
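The executor-side edits (distributed_gpu_executor.py, executor_base.py and multiproc_gpu_executor.py above) are plumbing: each `execute_model` / `_driver_execute_model` grows an optional `callback_fn` parameter, defaulting to `None`, and forwards it one layer down until it reaches the driver worker. A hedged sketch of that pattern, with invented layer names standing in for the real classes:

```python
# Sketch of threading an optional callback through call layers (names invented).
from typing import Callable, Optional


def driver_execute(request: str,
                   callback_fn: Optional[Callable[[], None]] = None) -> str:
    # The lowest layer is the one that knows when to fire the callback.
    if callback_fn is not None:
        callback_fn()
    return f"result({request})"


def executor_execute(request: str,
                     callback_fn: Optional[Callable[[], None]] = None) -> str:
    # Upper layers only forward the callable; defaulting to None keeps
    # callers that never pass a callback working unchanged.
    return driver_execute(request, callback_fn)


print(executor_execute("r1"))                                   # no callback
print(executor_execute("r2", callback_fn=lambda: print("cb")))  # fires in driver
```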
3 changes: 3 additions & 0 deletions vllm/utils.py
@@ -29,6 +29,9 @@
from vllm import _custom_ops as ops
from vllm.logger import enable_trace_function_call, init_logger

+global flag_for_callback_fn
+flag_for_callback_fn = False

logger = init_logger(__name__)

STR_DTYPE_TO_TORCH_DTYPE = {
5 changes: 4 additions & 1 deletion vllm/worker/model_runner.py
@@ -53,6 +53,7 @@
from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists,
get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available)
+from vllm import utils
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@@ -137,7 +138,6 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
# Used for speculative decoding. We do not broadcast it because it is only
# used by the driver worker.
is_prompt: Optional[bool] = None
-
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
@@ -1292,6 +1292,7 @@ def execute_model(
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
+        callback_fn = None
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
@@ -1380,6 +1381,8 @@ def execute_model(
if not self.is_driver_worker:
return []

+        if utils.flag_for_callback_fn and callback_fn is not None:
+            callback_fn()
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
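The guard added to `execute_model` pairs with the module-level `flag_for_callback_fn` introduced in vllm/utils.py: the flag starts out `False`, so the callback is skipped on the very first step (when there is no previous output to process), and `step()` flips it to `True` afterwards, per its own "hack" comment. A small sketch of that gating, using a local namespace object as a stand-in for the shared flag module:

```python
# Sketch of the first-step gating; `flags` stands in for vllm.utils here.
import types

flags = types.SimpleNamespace(flag_for_callback_fn=False)


def model_runner_execute(callback_fn=None) -> str:
    # The callback is skipped until the engine has completed at least one step.
    if flags.flag_for_callback_fn and callback_fn is not None:
        callback_fn()
    return "sampler_output"


def engine_step(callback_fn) -> str:
    out = model_runner_execute(callback_fn)
    flags.flag_for_callback_fn = True  # first step done; callbacks allowed now
    return out


engine_step(lambda: print("processing previous outputs"))  # silent on step 1
engine_step(lambda: print("processing previous outputs"))  # fires from step 2 on
```

A mutable module-level flag is the lightest way to share this bit of state across files, though it is process-global; the commit's own comment calls it a hack, so it reads as a stopgap rather than a final design.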
1 change: 1 addition & 0 deletions vllm/worker/model_runner_base.py
@@ -174,6 +174,7 @@ def execute_model(
kv_caches: Optional[List[torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors],
num_steps: int = 1,
+        callback_fn = None
) -> Optional[List[SamplerOutput]]:
"""
Execute the model on the given input.
Expand Down
5 changes: 3 additions & 2 deletions vllm/worker/worker_base.py
@@ -215,7 +215,8 @@ def execute_worker(self, worker_input: WorkerInput) -> None:

def execute_model(
self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
+        execute_model_req: Optional[ExecuteModelRequest] = None,
+        callback_fn = None
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
@@ -273,7 +274,7 @@ def execute_model(
output = self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None, intermediate_tensors,
-            num_steps)
+            num_steps, callback_fn)

if not get_pp_group().is_last_rank:
# output is IntermediateTensors
