diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f0fd7efdef813..aec31c4ad60bb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -483,8 +483,9 @@ def add_request( prompt_token_ids[:prefix_pos], lora_request.lora_int_id if lora_request else 0) if prefix_pos is not None else None - # Defensive copy of SamplingParams, which are used by the sampler - sampling_params = copy.deepcopy(sampling_params) + # Defensive copy of SamplingParams, which are used by the sampler, + # this doesn't deep-copy LogitsProcessor objects + sampling_params = sampling_params.clone() # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 51d39220ca9ca..8103f3c2b24bf 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,4 +1,5 @@ """Sampling parameters for text generation.""" +import copy from enum import IntEnum from functools import cached_property from typing import Callable, List, Optional, Union @@ -237,6 +238,20 @@ def sampling_type(self) -> SamplingType: return SamplingType.RANDOM_SEED return SamplingType.RANDOM + def clone(self) -> "SamplingParams": + """Deep copy excluding LogitsProcessor objects. + + LogitsProcessor objects are excluded because they may contain an + arbitrary, nontrivial amount of data. + See https://github.com/vllm-project/vllm/issues/3087 + """ + + logit_processor_refs = None if self.logits_processors is None else { + id(lp): lp + for lp in self.logits_processors + } + return copy.deepcopy(self, memo=logit_processor_refs) + def __repr__(self) -> str: return ( f"SamplingParams(n={self.n}, "