From 78a0613038b83b628d182c3ea2cd65e710ae58a8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 11 Oct 2024 17:18:37 -0700 Subject: [PATCH 01/16] try to remove seq group inside core --- vllm/engine/llm_engine.py | 92 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 563e52a37d935..fd9e90043a1a3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,3 +1,4 @@ +import dataclasses import time from collections import deque from contextlib import contextmanager @@ -61,6 +62,40 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5 +@dataclasses.dataclass +class SeqGroupHolder: + group_id: str # the original request id before splitting + + # all the request ids that are part of this group + seq_ids: Set[str] = dataclasses.field(default_factory=set) + + # all the finished requests + finished_reqs: List[SequenceGroup] = dataclasses.field( + default_factory=list) + + def maybe_finish_and_assemble( + self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: + self.seq_ids.remove(seq_group.request_id) + self.finished_reqs.append(seq_group) + if len(self.seq_ids) == 0: + assembled_seq_group = SequenceGroup( + request_id=self.group_id, + seqs=[x.seqs[0] for x in self.finished_reqs], + sampling_params=self.finished_reqs[0].sampling_params, + arrival_time=self.finished_reqs[0].arrival_time, + lora_request=self.finished_reqs[0].lora_request, + trace_headers=self.finished_reqs[0].trace_headers, + prompt_adapter_request=self.finished_reqs[0]. + prompt_adapter_request, + priority=self.finished_reqs[0].priority, + embeddings=self.finished_reqs[0].embeddings, + pooling_params=self.finished_reqs[0].pooling_params, + ) + return assembled_seq_group + else: + return None + + def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: config = try_get_generation_config( model_config.model, @@ -474,6 +509,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ), )) + self.group_id_to_holders: Dict[str, SeqGroupHolder] = {} + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -788,6 +825,27 @@ def add_request( >>> # continue the request processing >>> ... 
""" + + if isinstance(params, SamplingParams) and params.n > 1: + n = params.n + params.n = 1 + holder = SeqGroupHolder(request_id) + for i in range(n): + request_id_i = f"{request_id}_{i}" + holder.seq_ids.add(request_id_i) + self.add_request( + request_id_i, + prompt=prompt, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + inputs=inputs, + ) # type: ignore + self.group_id_to_holders[request_id_i] = holder + if inputs is not None: prompt = inputs assert prompt is not None and params is not None @@ -1131,6 +1189,15 @@ def _process_model_outputs(self, scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] seq_group = scheduled_seq_group.seq_group + assembled_seq_group = self._maybe_finish_seq_in_group(seq_group) + if assembled_seq_group is not None: + # change seq_group for later code + seq_group = assembled_seq_group + # change scheduled_seq_group + scheduled_seq_group.seq_group = seq_group + else: + continue + seq_group.maybe_set_first_token_time(now) request_output = RequestOutputFactory.create( seq_group, use_cache=self.use_cached_outputs) @@ -1171,6 +1238,15 @@ def _process_model_outputs(self, scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] seq_group = scheduled_seq_group.seq_group + assembled_seq_group = self._maybe_finish_seq_in_group(seq_group) + if assembled_seq_group is not None: + # change seq_group for later code + seq_group = assembled_seq_group + # change scheduled_seq_group + scheduled_seq_group.seq_group = seq_group + else: + continue + seq_group.maybe_set_first_token_time(now) request_output = RequestOutputFactory.create( seq_group, use_cache=self.use_cached_outputs) @@ -1191,6 +1267,13 @@ def _process_model_outputs(self, RequestOutputKind.DELTA) and not seq_group.is_finished(): continue + assembled_seq_group = self._maybe_finish_seq_in_group(seq_group) + if assembled_seq_group is not None: + # change seq_group for later code + seq_group = assembled_seq_group + else: + continue + request_output = RequestOutputFactory.create( seq_group, use_cache=self.use_cached_outputs) if request_output: @@ -1215,6 +1298,15 @@ def _process_model_outputs(self, return None + def _maybe_finish_seq_in_group( + self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: + if seq_group.request_id in self.group_id_to_holders: + holder = self.group_id_to_holders[seq_group.request_id] + del self.group_id_to_holders[seq_group.request_id] + assembled_seq_group = holder.maybe_finish_and_assemble(seq_group) + return assembled_seq_group + return None + def _advance_to_next_step( self, output: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], From 734df780a01f5c9782d1391ee08dad8debd54e71 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 11 Oct 2024 21:36:01 -0700 Subject: [PATCH 02/16] fix --- vllm/engine/llm_engine.py | 87 ++++++--------------------------------- vllm/outputs.py | 28 ++++++++++--- vllm/sequence.py | 38 ++++++++++++++++- 3 files changed, 72 insertions(+), 81 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fd9e90043a1a3..16a1975ab4b28 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,4 +1,3 @@ -import dataclasses import time from collections import deque from contextlib import contextmanager @@ -44,8 +43,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams from 
vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - Sequence, SequenceGroup, SequenceGroupMetadata, - SequenceStatus) + SeqGroupHolder, Sequence, SequenceGroup, + SequenceGroupMetadata, SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -62,40 +61,6 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5 -@dataclasses.dataclass -class SeqGroupHolder: - group_id: str # the original request id before splitting - - # all the request ids that are part of this group - seq_ids: Set[str] = dataclasses.field(default_factory=set) - - # all the finished requests - finished_reqs: List[SequenceGroup] = dataclasses.field( - default_factory=list) - - def maybe_finish_and_assemble( - self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - self.seq_ids.remove(seq_group.request_id) - self.finished_reqs.append(seq_group) - if len(self.seq_ids) == 0: - assembled_seq_group = SequenceGroup( - request_id=self.group_id, - seqs=[x.seqs[0] for x in self.finished_reqs], - sampling_params=self.finished_reqs[0].sampling_params, - arrival_time=self.finished_reqs[0].arrival_time, - lora_request=self.finished_reqs[0].lora_request, - trace_headers=self.finished_reqs[0].trace_headers, - prompt_adapter_request=self.finished_reqs[0]. - prompt_adapter_request, - priority=self.finished_reqs[0].priority, - embeddings=self.finished_reqs[0].embeddings, - pooling_params=self.finished_reqs[0].pooling_params, - ) - return assembled_seq_group - else: - return None - - def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: config = try_get_generation_config( model_config.model, @@ -829,6 +794,7 @@ def add_request( if isinstance(params, SamplingParams) and params.n > 1: n = params.n params.n = 1 + params.output_kind = RequestOutputKind.FINAL_ONLY holder = SeqGroupHolder(request_id) for i in range(n): request_id_i = f"{request_id}_{i}" @@ -1189,18 +1155,11 @@ def _process_model_outputs(self, scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] seq_group = scheduled_seq_group.seq_group - assembled_seq_group = self._maybe_finish_seq_in_group(seq_group) - if assembled_seq_group is not None: - # change seq_group for later code - seq_group = assembled_seq_group - # change scheduled_seq_group - scheduled_seq_group.seq_group = seq_group - else: - continue - seq_group.maybe_set_first_token_time(now) request_output = RequestOutputFactory.create( - seq_group, use_cache=self.use_cached_outputs) + seq_group, + self.group_id_to_holders, + use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) @@ -1238,18 +1197,11 @@ def _process_model_outputs(self, scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] seq_group = scheduled_seq_group.seq_group - assembled_seq_group = self._maybe_finish_seq_in_group(seq_group) - if assembled_seq_group is not None: - # change seq_group for later code - seq_group = assembled_seq_group - # change scheduled_seq_group - scheduled_seq_group.seq_group = seq_group - else: - continue - seq_group.maybe_set_first_token_time(now) request_output = RequestOutputFactory.create( - seq_group, use_cache=self.use_cached_outputs) + seq_group, + self.group_id_to_holders, + use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) @@ -1267,15 +1219,11 @@ def _process_model_outputs(self, RequestOutputKind.DELTA) and not seq_group.is_finished(): continue - assembled_seq_group = 
self._maybe_finish_seq_in_group(seq_group) - if assembled_seq_group is not None: - # change seq_group for later code - seq_group = assembled_seq_group - else: - continue - request_output = RequestOutputFactory.create( - seq_group, use_cache=self.use_cached_outputs) + seq_group, + self.group_id_to_holders, + use_cache=self.use_cached_outputs, + ) if request_output: ctx.request_outputs.append(request_output) @@ -1298,15 +1246,6 @@ def _process_model_outputs(self, return None - def _maybe_finish_seq_in_group( - self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - if seq_group.request_id in self.group_id_to_holders: - holder = self.group_id_to_holders[seq_group.request_id] - del self.group_id_to_holders[seq_group.request_id] - assembled_seq_group = holder.maybe_finish_and_assemble(seq_group) - return assembled_seq_group - return None - def _advance_to_next_step( self, output: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/outputs.py b/vllm/outputs.py index 07650241cb638..8b6367a8dfbfa 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,13 +1,13 @@ import time from dataclasses import dataclass -from typing import List, Optional +from typing import Dict, List, Optional from typing import Sequence as GenericSequence from typing import Union from vllm.lora.request import LoRARequest from vllm.sampling_params import RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, - SequenceGroup, SequenceStatus) + SeqGroupHolder, SequenceGroup, SequenceStatus) @dataclass @@ -114,8 +114,10 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod - def from_seq_group(cls, seq_group: SequenceGroup, - use_cache: bool) -> Optional["RequestOutput"]: + def from_seq_group( + cls, seq_group: SequenceGroup, use_cache: bool, + group_id_to_holders: Dict[str, SeqGroupHolder] + ) -> Optional["RequestOutput"]: sampling_params = seq_group.sampling_params if sampling_params is None: raise ValueError( @@ -126,6 +128,17 @@ def from_seq_group(cls, seq_group: SequenceGroup, not finished): return None + if finished and seq_group.request_id in group_id_to_holders: + # parallel sampling can not stream the output + assert sampling_params.output_kind == RequestOutputKind.FINAL_ONLY + holder: SeqGroupHolder = group_id_to_holders[seq_group.request_id] + del group_id_to_holders[seq_group.request_id] + assembled_seq_group = holder.maybe_finish_and_assemble(seq_group) + # only part of the request is finished + if assembled_seq_group is None: + return None + seq_group = assembled_seq_group + # Init cache (if needed) if use_cache and seq_group.cached_request_output is None: seq_group.cached_request_output = RequestOutput( # type: ignore @@ -309,10 +322,13 @@ def __repr__(self): class RequestOutputFactory: @staticmethod - def create(seq_group: SequenceGroup, use_cache: bool = False): + def create(seq_group: SequenceGroup, + group_id_to_holders: Dict[str, SeqGroupHolder], + use_cache: bool = False): # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: return EmbeddingRequestOutput.from_seq_group(seq_group) else: - return RequestOutput.from_seq_group(seq_group, use_cache) + return RequestOutput.from_seq_group(seq_group, use_cache, + group_id_to_holders) diff --git a/vllm/sequence.py b/vllm/sequence.py index 3bb35ea955c8c..8a94c0c8d13ad 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod 
from array import array from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import cached_property, reduce from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional from typing import Sequence as GenericSequence @@ -1378,3 +1378,39 @@ def clone( last_sampled_token_ids=self.last_sampled_token_ids.clone() if self.last_sampled_token_ids is not None else None, async_callback=self.async_callback) + + +@dataclass +class SeqGroupHolder: + group_id: str # the original request id before splitting + + # all the request ids that are part of this group + seq_ids: Set[str] = field(default_factory=set) + + # all the finished requests + finished_reqs: List[SequenceGroup] = field(default_factory=list) + + def maybe_finish_and_assemble( + self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: + self.seq_ids.remove(seq_group.request_id) + self.finished_reqs.append(seq_group) + if len(self.seq_ids) == 0: + params = self.finished_reqs[0].sampling_params + assert params is not None + params.n = len(self.finished_reqs) + assembled_seq_group = SequenceGroup( + request_id=self.group_id, + seqs=[x.seqs[0] for x in self.finished_reqs], + sampling_params=params, + arrival_time=self.finished_reqs[0].arrival_time, + lora_request=self.finished_reqs[0].lora_request, + trace_headers=self.finished_reqs[0].trace_headers, + prompt_adapter_request=self.finished_reqs[0]. + prompt_adapter_request, + priority=self.finished_reqs[0].priority, + embeddings=self.finished_reqs[0].embeddings, + pooling_params=self.finished_reqs[0].pooling_params, + ) + return assembled_seq_group + else: + return None From 8dc746a38abcbbbc071fe2c894d3a927c8e7b7b9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 11 Oct 2024 22:19:57 -0700 Subject: [PATCH 03/16] fix engine --- vllm/engine/llm_engine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 16a1975ab4b28..760fe2639c54b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,3 +1,4 @@ +import copy import time from collections import deque from contextlib import contextmanager @@ -792,12 +793,13 @@ def add_request( """ if isinstance(params, SamplingParams) and params.n > 1: + params = copy.deepcopy(params) n = params.n params.n = 1 params.output_kind = RequestOutputKind.FINAL_ONLY holder = SeqGroupHolder(request_id) for i in range(n): - request_id_i = f"{request_id}_{i}" + request_id_i = f"{request_id}_parallel_sample_{i}" holder.seq_ids.add(request_id_i) self.add_request( request_id_i, @@ -811,6 +813,7 @@ def add_request( inputs=inputs, ) # type: ignore self.group_id_to_holders[request_id_i] = holder + return if inputs is not None: prompt = inputs From 94634916e405044a6a85243a7d8ffa46945afd61 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 11 Oct 2024 22:22:45 -0700 Subject: [PATCH 04/16] fix output --- vllm/outputs.py | 3 ++- vllm/sequence.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 8b6367a8dfbfa..768af3a62d574 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -137,7 +137,8 @@ def from_seq_group( # only part of the request is finished if assembled_seq_group is None: return None - seq_group = assembled_seq_group + return cls.from_seq_group(assembled_seq_group, use_cache, + group_id_to_holders) # Init cache (if needed) if use_cache and seq_group.cached_request_output is None: diff --git a/vllm/sequence.py 
b/vllm/sequence.py index 8a94c0c8d13ad..d04bbb7160d48 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1411,6 +1411,8 @@ def maybe_finish_and_assemble( embeddings=self.finished_reqs[0].embeddings, pooling_params=self.finished_reqs[0].pooling_params, ) + assembled_seq_group.cached_request_output = self.finished_reqs[ + 0].cached_request_output return assembled_seq_group else: return None From 10f7cd9f84183de0753ace26a3ecc4fd853385cf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 11 Oct 2024 22:28:07 -0700 Subject: [PATCH 05/16] fix server --- vllm/entrypoints/openai/protocol.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 6f1135f8093ba..3ce89e00fca3c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -327,6 +327,9 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: backend=self.guided_decoding_backend, whitespace_pattern=self.guided_whitespace_pattern) + output_kind = RequestOutputKind.FINAL_ONLY + if self.stream and self.n == 1 and self.best_of is None: + output_kind = RequestOutputKind.DELTA return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -349,8 +352,7 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA if self.stream \ - else RequestOutputKind.FINAL_ONLY, + output_kind=output_kind, guided_decoding=guided_decoding, logit_bias=self.logit_bias) @@ -620,7 +622,9 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: json_object=guided_json_object, backend=self.guided_decoding_backend, whitespace_pattern=self.guided_whitespace_pattern) - + output_kind = RequestOutputKind.FINAL_ONLY + if self.stream and self.n == 1 and self.best_of is None: + output_kind = RequestOutputKind.DELTA return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -643,8 +647,7 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA if self.stream \ - else RequestOutputKind.FINAL_ONLY, + output_kind=output_kind, guided_decoding=guided_decoding, logit_bias=self.logit_bias, allowed_token_ids=self.allowed_token_ids) From 245920f17ae10675d98a4e5fc92a0e326cb3fc93 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 21 Oct 2024 09:52:35 -0700 Subject: [PATCH 06/16] support streaming for parallel sampling --- vllm/engine/llm_engine.py | 44 ++++++++---------- vllm/outputs.py | 39 +++++++--------- vllm/sequence.py | 96 ++++++++++++++++++++++++++++----------- 3 files changed, 104 insertions(+), 75 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 760fe2639c54b..fc2828b93ba97 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,4 +1,3 @@ -import copy import time from collections import deque from contextlib import contextmanager @@ -44,7 +43,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - 
SeqGroupHolder, Sequence, SequenceGroup, + ParallelSampleSequenceGroup, Sequence, + SequenceGroup, SequenceGroupBase, SequenceGroupMetadata, SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) @@ -475,7 +475,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ), )) - self.group_id_to_holders: Dict[str, SeqGroupHolder] = {} + self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -793,26 +793,18 @@ def add_request( """ if isinstance(params, SamplingParams) and params.n > 1: - params = copy.deepcopy(params) - n = params.n - params.n = 1 - params.output_kind = RequestOutputKind.FINAL_ONLY - holder = SeqGroupHolder(request_id) - for i in range(n): - request_id_i = f"{request_id}_parallel_sample_{i}" - holder.seq_ids.add(request_id_i) - self.add_request( - request_id_i, - prompt=prompt, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - inputs=inputs, - ) # type: ignore - self.group_id_to_holders[request_id_i] = holder + ParallelSampleSequenceGroup.add_request( + request_id, + self, + params, + prompt=prompt, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + inputs=inputs, + ) return if inputs is not None: @@ -1161,7 +1153,7 @@ def _process_model_outputs(self, seq_group.maybe_set_first_token_time(now) request_output = RequestOutputFactory.create( seq_group, - self.group_id_to_holders, + self.seq_id_to_seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) @@ -1203,7 +1195,7 @@ def _process_model_outputs(self, seq_group.maybe_set_first_token_time(now) request_output = RequestOutputFactory.create( seq_group, - self.group_id_to_holders, + self.seq_id_to_seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) @@ -1224,7 +1216,7 @@ def _process_model_outputs(self, request_output = RequestOutputFactory.create( seq_group, - self.group_id_to_holders, + self.seq_id_to_seq_group, use_cache=self.use_cached_outputs, ) if request_output: diff --git a/vllm/outputs.py b/vllm/outputs.py index 768af3a62d574..5eb7841bc9ae9 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -7,7 +7,7 @@ from vllm.lora.request import LoRARequest from vllm.sampling_params import RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, - SeqGroupHolder, SequenceGroup, SequenceStatus) + SequenceGroup, SequenceGroupBase, SequenceStatus) @dataclass @@ -116,7 +116,7 @@ def __init__( @classmethod def from_seq_group( cls, seq_group: SequenceGroup, use_cache: bool, - group_id_to_holders: Dict[str, SeqGroupHolder] + seq_id_to_seq_group: Dict[str, SequenceGroupBase] ) -> Optional["RequestOutput"]: sampling_params = seq_group.sampling_params if sampling_params is None: @@ -128,17 +128,16 @@ def from_seq_group( not finished): return None - if finished and seq_group.request_id in group_id_to_holders: - # parallel sampling can not stream the output - assert sampling_params.output_kind == RequestOutputKind.FINAL_ONLY - holder: SeqGroupHolder = group_id_to_holders[seq_group.request_id] - del group_id_to_holders[seq_group.request_id] - assembled_seq_group = holder.maybe_finish_and_assemble(seq_group) + if finished 
and seq_group.request_id in seq_id_to_seq_group: + group: SequenceGroupBase = seq_id_to_seq_group[ + seq_group.request_id] + group.finish_seq(seq_group) + assembled_seq_group = group.maybe_assemble_group() # only part of the request is finished if assembled_seq_group is None: return None return cls.from_seq_group(assembled_seq_group, use_cache, - group_id_to_holders) + seq_id_to_seq_group) # Init cache (if needed) if use_cache and seq_group.cached_request_output is None: @@ -150,15 +149,7 @@ def from_seq_group( outputs=[], finished=False) - seqs = seq_group.get_seqs() - if len(seqs) == 1: - top_n_seqs = seqs - else: - # Get the top-n sequences. - n = sampling_params._real_n or sampling_params.n - sorting_key = lambda seq: seq.get_cumulative_logprob() - sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) - top_n_seqs = sorted_seqs[:n] + top_n_seqs = seq_group.get_seqs() # Create the outputs. # NOTE: We need omit logprobs here explicitly because the sequence @@ -221,9 +212,13 @@ def from_seq_group( output.stop_reason = seq.stop_reason else: + index = i + if not finished and seq_group.request_id in seq_id_to_seq_group: + group = seq_id_to_seq_group[seq_group.request_id] + index = group.seq_id_to_index[seq_group.request_id] output = CompletionOutput( - seqs.index(seq), output_text, [output_token_ids] - if isinstance(output_token_ids, int) else output_token_ids, + index, output_text, [output_token_ids] if isinstance( + output_token_ids, int) else output_token_ids, seq.get_cumulative_logprob() if include_logprobs else None, output_logprobs, SequenceStatus.get_finished_reason(seq.status), @@ -324,7 +319,7 @@ class RequestOutputFactory: @staticmethod def create(seq_group: SequenceGroup, - group_id_to_holders: Dict[str, SeqGroupHolder], + seq_id_to_seq_group: Dict[str, SequenceGroupBase], use_cache: bool = False): # Determine the type based on a condition, for example: if hasattr(seq_group, @@ -332,4 +327,4 @@ def create(seq_group: SequenceGroup, return EmbeddingRequestOutput.from_seq_group(seq_group) else: return RequestOutput.from_seq_group(seq_group, use_cache, - group_id_to_holders) + seq_id_to_seq_group) diff --git a/vllm/sequence.py b/vllm/sequence.py index d04bbb7160d48..525a225cdb6b9 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1381,38 +1381,80 @@ def clone( @dataclass -class SeqGroupHolder: +class SequenceGroupBase: group_id: str # the original request id before splitting - # all the request ids that are part of this group - seq_ids: Set[str] = field(default_factory=set) + # seq id to a unique index inside this group + seq_id_to_index: Dict[str, int] = field(default_factory=dict) - # all the finished requests - finished_reqs: List[SequenceGroup] = field(default_factory=list) + # seq ids to be finished + to_be_finished: Set[str] = field(default_factory=set) - def maybe_finish_and_assemble( - self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - self.seq_ids.remove(seq_group.request_id) - self.finished_reqs.append(seq_group) - if len(self.seq_ids) == 0: - params = self.finished_reqs[0].sampling_params + # seq id to finished sequences + finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict) + + @staticmethod + def add_request(request_id: str, engine, params, *args, **kwargs): + """When we are ready to add a request with request_id and params + into the engine, we can split the request into multiple requests. + """ + raise NotImplementedError + + def dealta_output(self, seq: SequenceGroup): + """The sequence `seq` needs to produce delta output. 
+ We need to restore the `index` information. + """ + raise NotImplementedError + + def finish_seq(self, seq: SequenceGroup): + """The sequence `seq` finishes, we should record the information. + """ + self.to_be_finished.remove(seq.request_id) + self.finished_reqs[seq.request_id] = seq + + def maybe_assemble_group(self) -> Optional[SequenceGroup]: + """Assemble the sequence group, for producing the final + output, or adding request in the engine again. + """ + raise NotImplementedError + + +class ParallelSampleSequenceGroup(SequenceGroupBase): + + @staticmethod + def add_request(request_id: str, engine, params, **kwargs): + params = copy.deepcopy(params) + n = params.n + params.n = 1 + group = SequenceGroupBase(request_id) + for i in range(n): + request_id_i = f"{request_id}_parallel_sample_{i}" + group.seq_id_to_index[request_id_i] = i + group.to_be_finished.add(request_id_i) + engine.add_request( + request_id_i, + params=params, + **kwargs, + ) # type: ignore + engine.seq_id_to_seq_group[request_id_i] = group + + def maybe_assemble_group(self) -> Optional[SequenceGroup]: + if len(self.to_be_finished) == 0: + finished_reqs = list(self.finished_reqs.values()) + params = finished_reqs[0].sampling_params assert params is not None - params.n = len(self.finished_reqs) - assembled_seq_group = SequenceGroup( - request_id=self.group_id, - seqs=[x.seqs[0] for x in self.finished_reqs], - sampling_params=params, - arrival_time=self.finished_reqs[0].arrival_time, - lora_request=self.finished_reqs[0].lora_request, - trace_headers=self.finished_reqs[0].trace_headers, - prompt_adapter_request=self.finished_reqs[0]. - prompt_adapter_request, - priority=self.finished_reqs[0].priority, - embeddings=self.finished_reqs[0].embeddings, - pooling_params=self.finished_reqs[0].pooling_params, - ) - assembled_seq_group.cached_request_output = self.finished_reqs[ - 0].cached_request_output + params.n = len(finished_reqs) + assembled_seq_group = copy.deepcopy(finished_reqs[0]) + assembled_seq_group.request_id = self.group_id + + # Get the top-n sequences. 
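+            # (rank the finished candidate sequences by cumulative
+            # log-probability and keep only the n best for the final output)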
+ n = params._real_n or params.n + seqs = [x.seqs[0] for x in finished_reqs] + sorting_key = lambda seq: seq.get_cumulative_logprob() + sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) + top_n_seqs = sorted_seqs[:n] + assembled_seq_group.seqs = top_n_seqs + assembled_seq_group.sampling_params = params return assembled_seq_group else: return None From 5ac8a1160788a79710e318b1f395b949671b2283 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 21 Oct 2024 09:53:25 -0700 Subject: [PATCH 07/16] support streaming for parallel sampling --- vllm/entrypoints/openai/protocol.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3ce89e00fca3c..6f1135f8093ba 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -327,9 +327,6 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: backend=self.guided_decoding_backend, whitespace_pattern=self.guided_whitespace_pattern) - output_kind = RequestOutputKind.FINAL_ONLY - if self.stream and self.n == 1 and self.best_of is None: - output_kind = RequestOutputKind.DELTA return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -352,7 +349,8 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=output_kind, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias) @@ -622,9 +620,7 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: json_object=guided_json_object, backend=self.guided_decoding_backend, whitespace_pattern=self.guided_whitespace_pattern) - output_kind = RequestOutputKind.FINAL_ONLY - if self.stream and self.n == 1 and self.best_of is None: - output_kind = RequestOutputKind.DELTA + return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -647,7 +643,8 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=output_kind, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias, allowed_token_ids=self.allowed_token_ids) From acc8f4c463627a51bbebaa996f183e0d84dac1ad Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 21 Oct 2024 10:00:21 -0700 Subject: [PATCH 08/16] fix --- vllm/sequence.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 525a225cdb6b9..eb8bcc66d504e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1400,12 +1400,6 @@ def add_request(request_id: str, engine, params, *args, **kwargs): """ raise NotImplementedError - def dealta_output(self, seq: SequenceGroup): - """The sequence `seq` needs to produce delta output. - We need to restore the `index` information. - """ - raise NotImplementedError - def finish_seq(self, seq: SequenceGroup): """The sequence `seq` finishes, we should record the information. 
""" @@ -1426,7 +1420,7 @@ def add_request(request_id: str, engine, params, **kwargs): params = copy.deepcopy(params) n = params.n params.n = 1 - group = SequenceGroupBase(request_id) + group = ParallelSampleSequenceGroup(request_id) for i in range(n): request_id_i = f"{request_id}_parallel_sample_{i}" group.seq_id_to_index[request_id_i] = i From b09801492b5e122b55ffe2eeb29f994688922e3e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 21 Oct 2024 10:06:43 -0700 Subject: [PATCH 09/16] fix lint --- vllm/engine/llm_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2f6b677dba84f..c1bc1828a7bda 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -46,7 +46,8 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, ParallelSampleSequenceGroup, Sequence, SequenceGroup, SequenceGroupBase, - SequenceGroupOutput, SequenceGroupMetadata, SequenceStatus) + SequenceGroupMetadata, SequenceGroupOutput, + SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config From df86d6819f4329d7e5f0f0baeb4008038510e791 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 21 Oct 2024 13:59:28 -0700 Subject: [PATCH 10/16] support streaming --- vllm/engine/llm_engine.py | 17 ++++++++++----- vllm/outputs.py | 23 ++++++++++---------- vllm/sequence.py | 46 +++++++++++++++++++++++---------------- 3 files changed, 50 insertions(+), 36 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c1bc1828a7bda..25c4e76d9b159 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -646,7 +646,10 @@ def _add_processed_request( prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> None: + ) -> SequenceGroup: + """Add a processed request to the engine's request pool. + return the created sequence group. + """ self._validate_model_inputs(processed_inputs) # Create the sequences. block_size = self.cache_config.block_size @@ -700,6 +703,8 @@ def _add_processed_request( min_cost_scheduler = self.scheduler[costs.index(min(costs))] min_cost_scheduler.add_seq_group(seq_group) + return seq_group + def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() @@ -715,7 +720,7 @@ def add_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> None: + ) -> Optional[SequenceGroup]: ... @overload @@ -729,7 +734,7 @@ def add_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> None: + ) -> Optional[SequenceGroup]: ... @deprecate_kwargs( @@ -748,7 +753,7 @@ def add_request( priority: int = 0, *, inputs: Optional[PromptType] = None, # DEPRECATED - ) -> None: + ) -> Optional[SequenceGroup]: """Add a request to the engine's request pool. 
The request is added to the request pool and will be processed by the @@ -806,7 +811,7 @@ def add_request( priority=priority, inputs=inputs, ) - return + return None if inputs is not None: prompt = inputs @@ -838,7 +843,7 @@ def add_request( processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get( "mm_processor_kwargs") - self._add_processed_request( + return self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, params=params, diff --git a/vllm/outputs.py b/vllm/outputs.py index 5eb7841bc9ae9..4a4ecfa45ab42 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -118,27 +118,28 @@ def from_seq_group( cls, seq_group: SequenceGroup, use_cache: bool, seq_id_to_seq_group: Dict[str, SequenceGroupBase] ) -> Optional["RequestOutput"]: - sampling_params = seq_group.sampling_params - if sampling_params is None: - raise ValueError( - "Sampling parameters are missing for a CompletionRequest.") - finished = seq_group.is_finished() - if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( - not finished): - return None - if finished and seq_group.request_id in seq_id_to_seq_group: + if seq_group.request_id in seq_id_to_seq_group: group: SequenceGroupBase = seq_id_to_seq_group[ seq_group.request_id] - group.finish_seq(seq_group) + if finished: + group.finish_seq(seq_group) assembled_seq_group = group.maybe_assemble_group() - # only part of the request is finished if assembled_seq_group is None: return None return cls.from_seq_group(assembled_seq_group, use_cache, seq_id_to_seq_group) + sampling_params = seq_group.sampling_params + if sampling_params is None: + raise ValueError( + "Sampling parameters are missing for a CompletionRequest.") + + if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( + not finished): + return None + # Init cache (if needed) if use_cache and seq_group.cached_request_output is None: seq_group.cached_request_output = RequestOutput( # type: ignore diff --git a/vllm/sequence.py b/vllm/sequence.py index ff43a047fe542..5bc27bc934739 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1407,11 +1407,13 @@ def clone( class SequenceGroupBase: group_id: str # the original request id before splitting + assembled_seq_group: Optional[SequenceGroup] = None + # seq id to a unique index inside this group seq_id_to_index: Dict[str, int] = field(default_factory=dict) # seq ids to be finished - to_be_finished: Set[str] = field(default_factory=set) + to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict) # seq id to finished sequences finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict) @@ -1426,7 +1428,7 @@ def add_request(request_id: str, engine, params, *args, **kwargs): def finish_seq(self, seq: SequenceGroup): """The sequence `seq` finishes, we should record the information. 
""" - self.to_be_finished.remove(seq.request_id) + del self.to_be_finished[seq.request_id] self.finished_reqs[seq.request_id] = seq def maybe_assemble_group(self) -> Optional[SequenceGroup]: @@ -1440,38 +1442,44 @@ class ParallelSampleSequenceGroup(SequenceGroupBase): @staticmethod def add_request(request_id: str, engine, params, **kwargs): - params = copy.deepcopy(params) + original_params = params + params = copy.deepcopy(original_params) n = params.n params.n = 1 group = ParallelSampleSequenceGroup(request_id) + seqs = [] for i in range(n): request_id_i = f"{request_id}_parallel_sample_{i}" group.seq_id_to_index[request_id_i] = i - group.to_be_finished.add(request_id_i) - engine.add_request( + seq_group = engine.add_request( request_id_i, params=params, **kwargs, ) # type: ignore engine.seq_id_to_seq_group[request_id_i] = group + group.to_be_finished[request_id_i] = seq_group + seqs.append(seq_group.seqs[0]) + + # for parallel sampling, the `assembled_seq_group` is always + # available, since we have all the sequences ready, and they + # will not change. + group.assembled_seq_group = SequenceGroup( + request_id=request_id, + seqs=seqs, + sampling_params=original_params, + **kwargs, + ) def maybe_assemble_group(self) -> Optional[SequenceGroup]: - if len(self.to_be_finished) == 0: - finished_reqs = list(self.finished_reqs.values()) - params = finished_reqs[0].sampling_params - assert params is not None - params.n = len(finished_reqs) - assembled_seq_group = copy.deepcopy(finished_reqs[0]) - assembled_seq_group.request_id = self.group_id - + assert self.assembled_seq_group is not None + params = self.assembled_seq_group.sampling_params + assert isinstance(params, SamplingParams) + if len(self.to_be_finished) == 0 and params._real_n is not None: # Get the top-n sequences. 
n = params._real_n or params.n - seqs = [x.seqs[0] for x in finished_reqs] + seqs = self.assembled_seq_group.seqs sorting_key = lambda seq: seq.get_cumulative_logprob() sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) top_n_seqs = sorted_seqs[:n] - assembled_seq_group.seqs = top_n_seqs - assembled_seq_group.sampling_params = params - return assembled_seq_group - else: - return None + self.assembled_seq_group.seqs = top_n_seqs + return self.assembled_seq_group From f841b053f439933fd7dae453ac11fbfdbe421742 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 21 Oct 2024 14:12:10 -0700 Subject: [PATCH 11/16] fix kwargs --- vllm/sequence.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 5bc27bc934739..2c2d8f5c17a4f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1456,6 +1456,7 @@ def add_request(request_id: str, engine, params, **kwargs): params=params, **kwargs, ) # type: ignore + assert seq_group is not None engine.seq_id_to_seq_group[request_id_i] = group group.to_be_finished[request_id_i] = seq_group seqs.append(seq_group.seqs[0]) @@ -1466,8 +1467,15 @@ def add_request(request_id: str, engine, params, **kwargs): group.assembled_seq_group = SequenceGroup( request_id=request_id, seqs=seqs, + arrival_time=seq_group.arrival_time, sampling_params=original_params, - **kwargs, + lora_request=seq_group.lora_request, + embeddings=seq_group.embeddings, + pooling_params=seq_group.pooling_params, + encoder_seq=seq_group.encoder_seq, + trace_headers=seq_group.trace_headers, + prompt_adapter_request=seq_group.prompt_adapter_request, + priority=seq_group.priority, ) def maybe_assemble_group(self) -> Optional[SequenceGroup]: From 2c3e0012135b212360d53f6e6ab5954b2c051828 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 21 Oct 2024 14:25:59 -0700 Subject: [PATCH 12/16] fix streaming --- vllm/sequence.py | 45 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 2c2d8f5c17a4f..5cd1d3fdaf99a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -17,7 +17,7 @@ from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if TYPE_CHECKING: @@ -1418,6 +1418,10 @@ class SequenceGroupBase: # seq id to finished sequences finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict) + streaming: bool = False + + output_produced: bool = False + @staticmethod def add_request(request_id: str, engine, params, *args, **kwargs): """When we are ready to add a request with request_id and params @@ -1478,16 +1482,37 @@ def add_request(request_id: str, engine, params, **kwargs): priority=seq_group.priority, ) + group.streaming = params.output_kind == RequestOutputKind.DELTA + group.output_produced = False + def maybe_assemble_group(self) -> Optional[SequenceGroup]: + + # in the streaming mode, we will always return the assembled sequence + # this is because streaming will flatten the responses into a single + # stream + if self.streaming: + return self.assembled_seq_group + + # in the non-streaming mode, we will return the assembled sequence + # once after all sequences finish, and then return None for the + # rest of the time + + if len(self.to_be_finished) > 0: + 
return None + assert self.assembled_seq_group is not None params = self.assembled_seq_group.sampling_params assert isinstance(params, SamplingParams) - if len(self.to_be_finished) == 0 and params._real_n is not None: - # Get the top-n sequences. - n = params._real_n or params.n - seqs = self.assembled_seq_group.seqs - sorting_key = lambda seq: seq.get_cumulative_logprob() - sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) - top_n_seqs = sorted_seqs[:n] - self.assembled_seq_group.seqs = top_n_seqs - return self.assembled_seq_group + if not self.output_produced: + self.output_produced = True + if params._real_n is not None: + # Get the top-n sequences. + n = params._real_n or params.n + seqs = self.assembled_seq_group.seqs + sorting_key = lambda seq: seq.get_cumulative_logprob() + sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) + top_n_seqs = sorted_seqs[:n] + self.assembled_seq_group.seqs = top_n_seqs + return self.assembled_seq_group + if self.output_produced: + return None From a87b9a6687bbccf08a751686f2d63dfdcda6fa66 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 21 Oct 2024 14:32:07 -0700 Subject: [PATCH 13/16] fix streaming --- vllm/outputs.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 4a4ecfa45ab42..2fcf80dc36353 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -213,13 +213,9 @@ def from_seq_group( output.stop_reason = seq.stop_reason else: - index = i - if not finished and seq_group.request_id in seq_id_to_seq_group: - group = seq_id_to_seq_group[seq_group.request_id] - index = group.seq_id_to_index[seq_group.request_id] output = CompletionOutput( - index, output_text, [output_token_ids] if isinstance( - output_token_ids, int) else output_token_ids, + top_n_seqs.index(seq), output_text, [output_token_ids] + if isinstance(output_token_ids, int) else output_token_ids, seq.get_cumulative_logprob() if include_logprobs else None, output_logprobs, SequenceStatus.get_finished_reason(seq.status), From 4627bc36864d495be0bc4e3b71b4720b5c420875 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 21 Oct 2024 14:35:17 -0700 Subject: [PATCH 14/16] polish code --- vllm/sequence.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 5cd1d3fdaf99a..68832b597fb0e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1448,11 +1448,10 @@ class ParallelSampleSequenceGroup(SequenceGroupBase): def add_request(request_id: str, engine, params, **kwargs): original_params = params params = copy.deepcopy(original_params) - n = params.n params.n = 1 group = ParallelSampleSequenceGroup(request_id) seqs = [] - for i in range(n): + for i in range(original_params.n): request_id_i = f"{request_id}_parallel_sample_{i}" group.seq_id_to_index[request_id_i] = i seq_group = engine.add_request( From f04c703e373df52d45550ba8c4780dbc7b0aac2b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 21 Oct 2024 14:50:26 -0700 Subject: [PATCH 15/16] improve streaming --- vllm/outputs.py | 2 +- vllm/sequence.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 2fcf80dc36353..951976310e7ae 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -125,7 +125,7 @@ def from_seq_group( seq_group.request_id] if finished: group.finish_seq(seq_group) - assembled_seq_group = group.maybe_assemble_group() + assembled_seq_group = group.maybe_assemble_group(seq_group) if assembled_seq_group is None: return None 
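            # an assembled group is ready: build the RequestOutput for it
            # with a recursive call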
return cls.from_seq_group(assembled_seq_group, use_cache, diff --git a/vllm/sequence.py b/vllm/sequence.py index 68832b597fb0e..93f58f00ef77b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1435,7 +1435,8 @@ def finish_seq(self, seq: SequenceGroup): del self.to_be_finished[seq.request_id] self.finished_reqs[seq.request_id] = seq - def maybe_assemble_group(self) -> Optional[SequenceGroup]: + def maybe_assemble_group( + self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: """Assemble the sequence group, for producing the final output, or adding request in the engine again. """ @@ -1484,13 +1485,16 @@ def add_request(request_id: str, engine, params, **kwargs): group.streaming = params.output_kind == RequestOutputKind.DELTA group.output_produced = False - def maybe_assemble_group(self) -> Optional[SequenceGroup]: + def maybe_assemble_group( + self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - # in the streaming mode, we will always return the assembled sequence - # this is because streaming will flatten the responses into a single - # stream + # in the streaming mode, we will return the assembled sequence + # for the first sequence, and then return None for the rest of + # sequences if self.streaming: - return self.assembled_seq_group + if self.seq_id_to_index[seq_group.request_id] == 0: + return self.assembled_seq_group + return None # in the non-streaming mode, we will return the assembled sequence # once after all sequences finish, and then return None for the From adc6be4a27a7057abb2c186a3ba26f1ef390b9d2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 21 Oct 2024 15:21:23 -0700 Subject: [PATCH 16/16] add tests for parallel streaming --- tests/entrypoints/openai/test_completion.py | 34 +++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index cc72a49ebbbda..f03bdb045f640 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, assert "".join(chunks) == single_output +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): + """Streaming for parallel sampling. + The tokens from multiple samples, are flattened into a single stream, + with an index to indicate which sample the token belongs to. + """ + + prompt = "What is an LLM?" + n = 3 + max_tokens = 5 + + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=max_tokens, + n=n, + stream=True) + chunks: List[List[str]] = [[] for i in range(n)] + finish_reason_count = 0 + async for chunk in stream: + index = chunk.choices[0].index + text = chunk.choices[0].text + chunks[index].append(text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == n + for chunk in chunks: + assert len(chunk) == max_tokens + print("".join(chunk)) + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name",