[core] move parallel sampling out from vllm core #9302

Merged · 17 commits · Oct 22, 2024
44 changes: 39 additions & 5 deletions vllm/engine/llm_engine.py
@@ -1,3 +1,4 @@
import copy
import time
from collections import deque
from contextlib import contextmanager
@@ -43,8 +44,8 @@
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
SeqGroupHolder, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
@@ -474,6 +475,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
),
))

self.group_id_to_holders: Dict[str, SeqGroupHolder] = {}

def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).

@@ -788,6 +791,30 @@ def add_request(
>>> # continue the request processing
>>> ...
"""

if isinstance(params, SamplingParams) and params.n > 1:
params = copy.deepcopy(params)
n = params.n
params.n = 1
params.output_kind = RequestOutputKind.FINAL_ONLY
holder = SeqGroupHolder(request_id)
for i in range(n):
request_id_i = f"{request_id}_parallel_sample_{i}"
holder.seq_ids.add(request_id_i)
self.add_request(
request_id_i,
prompt=prompt,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
inputs=inputs,
) # type: ignore
self.group_id_to_holders[request_id_i] = holder
return

if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
@@ -1133,7 +1160,9 @@ def _process_model_outputs(self,
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
seq_group,
self.group_id_to_holders,
use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)

@@ -1173,7 +1202,9 @@ def _process_model_outputs(self,
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
seq_group,
self.group_id_to_holders,
use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)

@@ -1192,7 +1223,10 @@ def _process_model_outputs(self,
continue

request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
seq_group,
self.group_id_to_holders,
use_cache=self.use_cached_outputs,
)
if request_output:
ctx.request_outputs.append(request_output)

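Note (not part of the diff): the core of the `add_request` change is a fan-out. A request with `params.n > 1` is split into `n` single-sample sub-requests, all registered under one shared holder keyed by each sub-request id. The sketch below illustrates that idea with a toy engine; `ToyEngine` and `Holder` are illustrative stand-ins, not vLLM classes.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class Holder:
    """Tracks the sub-request ids created for one parallel-sampling request."""
    group_id: str
    seq_ids: Set[str] = field(default_factory=set)


class ToyEngine:
    """Illustrative stand-in for LLMEngine; not the real vLLM class."""

    def __init__(self) -> None:
        self.group_id_to_holders: Dict[str, Holder] = {}
        self.submitted: List[str] = []

    def add_request(self, request_id: str, prompt: str, n: int = 1) -> None:
        if n > 1:
            # Fan out: one sub-request per sample, each carrying n=1.
            holder = Holder(request_id)
            for i in range(n):
                sub_id = f"{request_id}_parallel_sample_{i}"
                holder.seq_ids.add(sub_id)
                self.add_request(sub_id, prompt, n=1)
                self.group_id_to_holders[sub_id] = holder
            return
        # Base case: a plain single-sample request enters the engine core.
        self.submitted.append(request_id)


engine = ToyEngine()
engine.add_request("req-0", "Hello", n=3)
print(engine.submitted)
# ['req-0_parallel_sample_0', 'req-0_parallel_sample_1', 'req-0_parallel_sample_2']
```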
13 changes: 8 additions & 5 deletions vllm/entrypoints/openai/protocol.py
@@ -327,6 +327,9 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)

output_kind = RequestOutputKind.FINAL_ONLY
if self.stream and self.n == 1 and self.best_of is None:
output_kind = RequestOutputKind.DELTA
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
@@ -349,8 +352,7 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
output_kind=output_kind,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias)

@@ -620,7 +622,9 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)

output_kind = RequestOutputKind.FINAL_ONLY
if self.stream and self.n == 1 and self.best_of is None:
output_kind = RequestOutputKind.DELTA
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
@@ -643,8 +647,7 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
output_kind=output_kind,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids)
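Note (not part of the diff): the protocol change keeps `DELTA` streaming only for the `n == 1`, no-`best_of` case, since parallel-sampled requests are now assembled at the end and can only be reported as a final output. A minimal sketch of that selection logic, using a stand-in enum rather than the real `RequestOutputKind`:

```python
from enum import Enum, auto
from typing import Optional


class OutputKind(Enum):
    # Stand-in for vllm.sampling_params.RequestOutputKind in this sketch.
    DELTA = auto()
    FINAL_ONLY = auto()


def choose_output_kind(stream: bool, n: int,
                       best_of: Optional[int]) -> OutputKind:
    """Streaming deltas only make sense for a single sample; with n > 1 or
    best_of set, the sub-requests are assembled at the end, so only the
    final output can be returned."""
    if stream and n == 1 and best_of is None:
        return OutputKind.DELTA
    return OutputKind.FINAL_ONLY


assert choose_output_kind(stream=True, n=1, best_of=None) is OutputKind.DELTA
assert choose_output_kind(stream=True, n=4, best_of=None) is OutputKind.FINAL_ONLY
assert choose_output_kind(stream=False, n=1, best_of=None) is OutputKind.FINAL_ONLY
```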
29 changes: 23 additions & 6 deletions vllm/outputs.py
@@ -1,13 +1,13 @@
import time
from dataclasses import dataclass
from typing import List, Optional
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Union

from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus)
SeqGroupHolder, SequenceGroup, SequenceStatus)


@dataclass
@@ -114,8 +114,10 @@ def __init__(
self.encoder_prompt_token_ids = encoder_prompt_token_ids

@classmethod
def from_seq_group(cls, seq_group: SequenceGroup,
use_cache: bool) -> Optional["RequestOutput"]:
def from_seq_group(
cls, seq_group: SequenceGroup, use_cache: bool,
group_id_to_holders: Dict[str, SeqGroupHolder]
) -> Optional["RequestOutput"]:
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError(
@@ -126,6 +128,18 @@ def from_seq_group(cls, seq_group: SequenceGroup,
not finished):
return None

if finished and seq_group.request_id in group_id_to_holders:
# parallel sampling can not stream the output
assert sampling_params.output_kind == RequestOutputKind.FINAL_ONLY
holder: SeqGroupHolder = group_id_to_holders[seq_group.request_id]
del group_id_to_holders[seq_group.request_id]
assembled_seq_group = holder.maybe_finish_and_assemble(seq_group)
# only part of the request is finished
if assembled_seq_group is None:
return None
return cls.from_seq_group(assembled_seq_group, use_cache,
group_id_to_holders)

# Init cache (if needed)
if use_cache and seq_group.cached_request_output is None:
seq_group.cached_request_output = RequestOutput( # type: ignore
@@ -309,10 +323,13 @@ def __repr__(self):
class RequestOutputFactory:

@staticmethod
def create(seq_group: SequenceGroup, use_cache: bool = False):
def create(seq_group: SequenceGroup,
group_id_to_holders: Dict[str, SeqGroupHolder],
use_cache: bool = False):
# Determine the type based on a condition, for example:
if hasattr(seq_group,
'embeddings') and seq_group.embeddings is not None:
return EmbeddingRequestOutput.from_seq_group(seq_group)
else:
return RequestOutput.from_seq_group(seq_group, use_cache)
return RequestOutput.from_seq_group(seq_group, use_cache,
group_id_to_holders)
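Note (not part of the diff): in `from_seq_group`, a finished sub-request that belongs to a holder is handed to the holder and produces no output on its own; only when the last sibling arrives is the assembled group emitted under the original request id. A small sketch of that dispatch, using hypothetical dictionaries instead of the real `SequenceGroup`/`SeqGroupHolder` types:

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical buffers for this sketch: sub-request id -> (group id, expected
# sample count), plus the samples collected so far per group.
pending: Dict[str, Tuple[str, int]] = {
    "req-0_parallel_sample_0": ("req-0", 2),
    "req-0_parallel_sample_1": ("req-0", 2),
}
collected: Dict[str, List[str]] = {}


def make_output(request_id: str, text: str) -> Optional[Tuple[str, List[str]]]:
    """Return (group_id, all_samples) once every sibling has finished,
    otherwise None -- the caller simply skips emitting anything."""
    if request_id not in pending:
        return (request_id, [text])          # ordinary, non-parallel request
    group_id, expected = pending.pop(request_id)
    collected.setdefault(group_id, []).append(text)
    if len(collected[group_id]) < expected:
        return None                          # only part of the group is done
    return (group_id, collected.pop(group_id))


print(make_output("req-0_parallel_sample_0", "A"))  # None
print(make_output("req-0_parallel_sample_1", "B"))  # ('req-0', ['A', 'B'])
```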
40 changes: 39 additions & 1 deletion vllm/sequence.py
@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import cached_property, reduce
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
@@ -1378,3 +1378,41 @@ def clone(
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None,
async_callback=self.async_callback)


@dataclass
class SeqGroupHolder:
group_id: str # the original request id before splitting

# all the request ids that are part of this group
seq_ids: Set[str] = field(default_factory=set)

# all the finished requests
finished_reqs: List[SequenceGroup] = field(default_factory=list)

def maybe_finish_and_assemble(
self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
self.seq_ids.remove(seq_group.request_id)
self.finished_reqs.append(seq_group)
if len(self.seq_ids) == 0:
params = self.finished_reqs[0].sampling_params
assert params is not None
params.n = len(self.finished_reqs)
assembled_seq_group = SequenceGroup(
request_id=self.group_id,
seqs=[x.seqs[0] for x in self.finished_reqs],
sampling_params=params,
arrival_time=self.finished_reqs[0].arrival_time,
lora_request=self.finished_reqs[0].lora_request,
trace_headers=self.finished_reqs[0].trace_headers,
prompt_adapter_request=self.finished_reqs[0].
prompt_adapter_request,
priority=self.finished_reqs[0].priority,
embeddings=self.finished_reqs[0].embeddings,
pooling_params=self.finished_reqs[0].pooling_params,
)
assembled_seq_group.cached_request_output = self.finished_reqs[
0].cached_request_output
return assembled_seq_group
else:
return None
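Note (not part of the diff): `SeqGroupHolder.maybe_finish_and_assemble` removes each finished sub-request id from `seq_ids`, and once the set is empty it rebuilds a single `SequenceGroup` under the original request id with `params.n` restored to the number of collected samples. The simplified model below (with hypothetical `Mini*` types) exercises that accounting:

```python
from dataclasses import dataclass, field
from typing import List, Optional, Set


@dataclass
class MiniParams:
    n: int = 1


@dataclass
class MiniGroup:
    request_id: str
    params: MiniParams
    samples: List[str] = field(default_factory=list)


@dataclass
class MiniHolder:
    """Simplified model of the SeqGroupHolder bookkeeping above."""
    group_id: str
    seq_ids: Set[str] = field(default_factory=set)
    finished_reqs: List[MiniGroup] = field(default_factory=list)

    def maybe_finish_and_assemble(self,
                                  group: MiniGroup) -> Optional[MiniGroup]:
        self.seq_ids.remove(group.request_id)   # this sub-request is done
        self.finished_reqs.append(group)
        if self.seq_ids:
            return None                         # siblings still pending
        # All sub-requests finished: restore the user-visible n and merge
        # every sample back under the original request id.
        params = MiniParams(n=len(self.finished_reqs))
        samples = [s for g in self.finished_reqs for s in g.samples]
        return MiniGroup(self.group_id, params, samples)


holder = MiniHolder("req-0", seq_ids={"req-0_parallel_sample_0",
                                      "req-0_parallel_sample_1"})
assert holder.maybe_finish_and_assemble(
    MiniGroup("req-0_parallel_sample_0", MiniParams(), ["A"])) is None
merged = holder.maybe_finish_and_assemble(
    MiniGroup("req-0_parallel_sample_1", MiniParams(), ["B"]))
assert merged is not None
assert merged.params.n == 2 and merged.samples == ["A", "B"]
```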