Skip to content

Commit

Permalink
[Fix] Don't deep-copy LogitsProcessors when copying SamplingParams (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored Feb 29, 2024
1 parent 2c08ff2 commit 29a8d6a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
5 changes: 3 additions & 2 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,9 @@ def add_request(
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
if lora_request else 0) if prefix_pos is not None else None

# Defensive copy of SamplingParams, which are used by the sampler
sampling_params = copy.deepcopy(sampling_params)
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()

# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
Expand Down
15 changes: 15 additions & 0 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Sampling parameters for text generation."""
import copy
from enum import IntEnum
from functools import cached_property
from typing import Callable, List, Optional, Union
Expand Down Expand Up @@ -237,6 +238,20 @@ def sampling_type(self) -> SamplingType:
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM

def clone(self) -> "SamplingParams":
"""Deep copy excluding LogitsProcessor objects.
LogitsProcessor objects are excluded because they may contain an
arbitrary, nontrivial amount of data.
See https://github.com/vllm-project/vllm/issues/3087
"""

logit_processor_refs = None if self.logits_processors is None else {
id(lp): lp
for lp in self.logits_processors
}
return copy.deepcopy(self, memo=logit_processor_refs)

def __repr__(self) -> str:
return (
f"SamplingParams(n={self.n}, "
Expand Down

0 comments on commit 29a8d6a

Please sign in to comment.