
feat: remove logits_processor inspect overhead
Signed-off-by: imkero <kerorek@outlook.com>
imkero committed Feb 3, 2025
1 parent ce96700 commit 9d8d188
Showing 5 changed files with 34 additions and 10 deletions.
3 changes: 2 additions & 1 deletion tests/v1/sample/test_logits_processors.py
@@ -6,7 +6,7 @@
 import pytest
 import torch
 
-from vllm.sampling_params import LogitsProcessor
+from vllm.logits_process import LogitsProcessor, normalize_logits_processor
 from vllm.utils import make_tensor_with_pad
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import Sampler
@@ -230,6 +230,7 @@ def test_sampler_logits_processors(
     sampling_metadata.logits_processors = {}
     processors = processors_and_validator[0]
     if processors:
+        processors = [normalize_logits_processor(p) for p in processors]
         if batch_size > 1:
             # leave the last but non-first seq untouched
             sampling_metadata.logits_processors = {
15 changes: 15 additions & 0 deletions vllm/logits_process.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import inspect
 from typing import Callable, List, Tuple, Union
 
 import torch
@@ -16,6 +17,20 @@
 to sample from."""
 
 
+def normalize_logits_processor(
+        logits_processor: LogitsProcessor) -> LogitsProcessor:
+    """Ensure the given logits_processor accepts 3 arguments."""
+    parameters = inspect.signature(logits_processor).parameters
+    if len(parameters) == 3:
+        return logits_processor
+
+    def wrapper(prompt_token_ids: List[int], output_token_ids: List[int],
+                logits: torch.Tensor):
+        return logits_processor(output_token_ids, logits)
+
+    return wrapper
+
+
 def get_bad_words_logits_processors(
         bad_words: List[str],
         tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
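For context, a minimal sketch of how the new helper behaves (it assumes a vLLM build containing this commit; the ban_first_token processor is hypothetical, not part of the change): a two-argument processor is wrapped once so that every processor can afterwards be invoked with the uniform three-argument signature.

import torch

from vllm.logits_process import normalize_logits_processor

# A user-supplied processor written against the two-argument signature.
def ban_first_token(output_token_ids, logits):
    logits[0] = float("-inf")
    return logits

# Wrapping happens once, up front; the result always takes three arguments.
processor = normalize_logits_processor(ban_first_token)
logits = processor([1, 2, 3], [4, 5], torch.zeros(8))
assert logits[0] == float("-inf")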
9 changes: 2 additions & 7 deletions vllm/v1/sample/sampler.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 """A layer that samples the next tokens from the model's outputs."""
-import inspect
 from typing import Tuple
 
 import torch
@@ -154,12 +153,8 @@ def apply_logits_processors(
             prompt_tokens_ids = sampling_metadata.prompt_token_ids_cpu[seq_idx]
 
             for logits_processor in seq_logits_processors:
-                parameters = inspect.signature(logits_processor).parameters
-                if len(parameters) == 3:
-                    logits_row = logits_processor(prompt_tokens_ids,
-                                                  past_tokens_ids, logits_row)
-                else:
-                    logits_row = logits_processor(past_tokens_ids, logits_row)
+                logits_row = logits_processor(prompt_tokens_ids,
+                                              past_tokens_ids, logits_row)
 
             if logits_row is not original_logits_row:
                 logits[seq_idx] = logits_row
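The motivation for this change: inspect.signature is relatively expensive, and the old code paid that cost for every processor on every decoding step. A rough standalone sketch of the per-call overhead being removed (numbers will vary by machine):

import inspect
import timeit

def two_arg_processor(output_token_ids, logits):
    return logits

# Old hot path: inspect the signature on every invocation.
with_inspect = timeit.timeit(
    lambda: len(inspect.signature(two_arg_processor).parameters) == 3,
    number=100_000)

# New hot path: a plain call through the pre-normalized processor.
plain_call = timeit.timeit(lambda: two_arg_processor([], None),
                           number=100_000)

print(f"signature inspection: {with_inspect:.3f}s "
      f"vs plain call: {plain_call:.3f}s over 100k iterations")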
8 changes: 6 additions & 2 deletions vllm/v1/worker/gpu_input_batch.py
@@ -36,6 +36,10 @@ class CachedRequestState:
     mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[int] = None
 
+    normalized_logits_processors: Optional[List[LogitsProcessor]] = None
+    """Copy of sampling_params.logits_processors in which every
+    processor is guaranteed to accept 3 arguments."""
+
     @property
     def num_tokens(self) -> int:
         return len(self.prompt_token_ids) + len(self.output_token_ids)
@@ -228,9 +232,9 @@ def add_request(
             self.repetition_penalties_reqs.add(req_id)
         self.min_tokens[req_index] = sampling_params.min_tokens
         self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids
-        if sampling_params.logits_processors:
+        if request.normalized_logits_processors:
             self.logits_processors[req_index] = \
-                sampling_params.logits_processors
+                request.normalized_logits_processors
         self.prompt_token_ids_cpu[req_index] = request.prompt_token_ids
 
         # NOTE(woosuk): self.generators should not include the requests that
9 changes: 9 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -16,6 +16,7 @@
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
+from vllm.logits_process import normalize_logits_processor
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
@@ -303,6 +304,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                     spatial_merge_size,
                 )
 
+            # normalize logits_processors
+            if self.requests[req_id].sampling_params.logits_processors:
+                self.requests[req_id].normalized_logits_processors = [
+                    normalize_logits_processor(logits_processor)
+                    for logits_processor in
+                    self.requests[req_id].sampling_params.logits_processors
+                ]
+
             req_ids_to_add.append(req_id)
 
         # Update the cached states of the resumed requests.
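Taken together, the commit moves normalization to request admission: each request's processors are wrapped once in _update_states, cached on CachedRequestState.normalized_logits_processors, copied into the batch by add_request, and then called directly by the sampler with no per-step introspection. A minimal standalone sketch of this normalize-once, call-many pattern (names are illustrative, not vLLM's API):

import inspect
from typing import Callable, List

def normalize(fn: Callable) -> Callable:
    # Wrap two-argument processors so every cached entry takes three.
    if len(inspect.signature(fn).parameters) == 3:
        return fn
    return lambda prompt_ids, output_ids, logits: fn(output_ids, logits)

class Request:
    def __init__(self, processors: List[Callable]):
        # Paid once at admission time...
        self.normalized = [normalize(p) for p in processors]

def decode_step(request: Request, prompt_ids, output_ids, logits):
    # ...so the per-step hot loop is just direct calls.
    for processor in request.normalized:
        logits = processor(prompt_ids, output_ids, logits)
    return logits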
