-
-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 10 changed files with 209 additions and 0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import pytest | ||
|
||
from vllm.wde.core.processor.input_processor import TextInputProcessor | ||
from vllm.wde.core.schema.engine_io import (TextOnlyInputs, TextPrompt, | ||
TokensPrompt, ValidationError) | ||
|
||
# Module-level processor instance shared by every test in this file.
input_processor = TextInputProcessor() | ||
|
||
|
||
@pytest.fixture(scope="session")
def request_id():
    """Provide the request id string reused by the tests in this session."""
    shared_id = "0"
    return shared_id
|
||
|
||
def test_input_processor_1(request_id):
    """A bare string prompt ends up as ``{"prompt": <text>}``."""
    text = "test"
    expected = {"prompt": text}

    processed = input_processor(request_id, text)

    assert processed.inputs == expected
|
||
|
||
def test_input_processor_2(request_id):
    """A ``TextPrompt`` is reduced to a dict holding only its prompt."""
    text = "test"
    expected = {"prompt": text}

    processed = input_processor(request_id, TextPrompt(prompt=text))

    assert processed.inputs == expected
|
||
|
||
def test_input_processor_3(request_id):
    """A ``TokensPrompt`` is reduced to a dict holding only its token ids."""
    token_ids = [0]
    expected = {"prompt_token_ids": token_ids}

    processed = input_processor(request_id,
                                TokensPrompt(prompt_token_ids=token_ids))

    assert processed.inputs == expected
|
||
|
||
def test_input_processor_4(request_id):
    """``TextOnlyInputs`` keeps token ids and adds the prompt only if set."""
    text = "test"
    token_ids = [0]

    # Case 1: token ids only -> no "prompt" key in the result.
    tokens_only = TextOnlyInputs(prompt_token_ids=token_ids)
    processed = input_processor(request_id, tokens_only)
    assert processed.inputs == {"prompt_token_ids": token_ids}

    # Case 2: token ids plus the original text -> both keys present.
    tokens_and_text = TextOnlyInputs(prompt_token_ids=token_ids, prompt=text)
    processed = input_processor(request_id, tokens_and_text)
    expected = {"prompt_token_ids": token_ids, "prompt": text}
    assert processed.inputs == expected
|
||
|
||
def test_input_processor_5(request_id):
    """A plain dict with both supported keys passes through unchanged."""
    text = "test"
    token_ids = [0]
    raw_inputs = {"prompt_token_ids": token_ids, "prompt": text}

    processed = input_processor(request_id, raw_inputs)

    assert processed.inputs == raw_inputs
|
||
|
||
def test_validation_error(request_id):
    """Unsupported inputs are rejected with ``ValidationError``."""
    # Checked in order: empty dict, dict without a supported key, int, float.
    invalid_inputs = ({}, {"foo": "bar"}, 0, 0.0)

    for bad_inputs in invalid_inputs:
        with pytest.raises(ValidationError):
            input_processor(request_id, bad_inputs)
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import time | ||
from abc import ABC, abstractmethod | ||
from typing import Optional, Union | ||
|
||
from vllm.wde.core.schema.engine_io import (Inputs, Params, PromptInput, | ||
Request, TextOnlyInputs, | ||
TextPrompt, TextRequest, | ||
TokensPrompt, ValidationError) | ||
|
||
|
||
class InputProcessor(ABC):
    """Abstract interface turning raw request arguments into a ``Request``.

    Flow: (request_id, inputs, params, arrival_time) -> InputProcessor
    -> Request
    """

    @abstractmethod
    def __call__(
        self,
        request_id: str,
        inputs: Optional[Union[str, Inputs]] = None,
        params: Optional[Params] = None,
        arrival_time: Optional[float] = None,
    ) -> Request:
        """Build a ``Request``; concrete subclasses must override this."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_engine(cls, engine):
        """Create a processor from ``engine``; subclasses must override."""
        raise NotImplementedError
|
||
|
||
class TextInputProcessor(InputProcessor):
    """Normalize the supported text prompt forms into a ``TextRequest``."""

    def __call__(self,
                 request_id: str,
                 inputs: Optional[PromptInput] = None,
                 params: Optional[Params] = None,
                 arrival_time: Optional[float] = None) -> TextRequest:
        """Convert ``inputs`` to a plain dict and wrap it in a ``TextRequest``.

        Args:
            request_id: Identifier of the request; stored as ``str``.
            inputs: A string, ``TextPrompt``, ``TokensPrompt``,
                ``TextOnlyInputs`` or a dict containing ``"prompt"``
                and/or ``"prompt_token_ids"``.
            params: Accepted for interface compatibility; not used here.
            arrival_time: Timestamp of the request. When ``None`` the
                current ``time.time()`` is used.

        Returns:
            A ``TextRequest`` whose ``inputs`` dict contains only the keys
            ``"prompt"`` and/or ``"prompt_token_ids"``.

        Raises:
            ValidationError: If ``inputs`` is a dict with neither supported
                key, or is of an unsupported type (including ``None``).
        """
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}
        elif isinstance(inputs, TextPrompt):
            inputs = {"prompt": inputs.prompt}
        elif isinstance(inputs, TokensPrompt):
            inputs = {"prompt_token_ids": inputs.prompt_token_ids}
        elif isinstance(inputs, TextOnlyInputs):
            _inputs = {"prompt_token_ids": inputs.prompt_token_ids}

            # The original text is optional; include it only when present.
            if inputs.prompt is not None:
                _inputs["prompt"] = inputs.prompt

            inputs = _inputs
        elif isinstance(inputs, dict):
            if "prompt" not in inputs and "prompt_token_ids" not in inputs:
                raise ValidationError('"prompt" and "prompt_token_ids" '
                                      'have at least one in inputs.')
            # Drop every key other than the two supported ones.
            inputs = {
                k: v
                for k, v in inputs.items()
                if k in {"prompt", "prompt_token_ids"}
            }
        else:
            raise ValidationError(
                f"Input does not support {type(inputs)} data type")

        # Compare against None explicitly: a truthiness test would discard
        # a caller-supplied timestamp of 0 / 0.0 and replace it with "now".
        if arrival_time is None:
            arrival_time = time.time()

        return TextRequest(request_id=str(request_id),
                           inputs=inputs,
                           arrival_time=arrival_time)

    @classmethod
    def from_engine(cls, engine):
        """Return a new processor; ``engine`` is not consulted."""
        return cls()
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from dataclasses import dataclass | ||
from typing import Dict, List, Optional, Union | ||
|
||
|
||
class Params:
    """Marker type for request parameters; declares no attributes."""
|
||
|
||
class Inputs:
    """Common base class of the prompt schemas; declares no attributes."""
|
||
|
||
@dataclass
class TextPrompt(Inputs):
    """Schema for a text prompt.

    Attributes:
        prompt: The input text to be tokenized before passing to the model.
    """

    prompt: str
|
||
|
||
@dataclass
class TokensPrompt(Inputs):
    """Schema for a tokenized prompt.

    Attributes:
        prompt_token_ids: A list of token IDs to pass to the model.
    """

    prompt_token_ids: List[int]
|
||
|
||
@dataclass
class TextOnlyInputs(Inputs):
    """Token ids of a prompt, optionally paired with its original text.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        prompt: The original prompt text corresponding to the token IDs,
            if available.
    """

    prompt_token_ids: List[int]
    prompt: Optional[str] = None
|
||
|
||
# Union of the prompt forms: a raw string, a plain dict, or one of the
# prompt dataclasses defined above.
PromptInput = Union[str, Dict, TextPrompt, TokensPrompt, TextOnlyInputs] | ||
|
||
|
||
@dataclass
class Request:
    """Base request record.

    Attributes:
        request_id: Identifier of the request.
        arrival_time: Timestamp associated with the request.
    """

    request_id: str
    arrival_time: float
|
||
|
||
@dataclass
class TextRequest(Request):
    """A ``Request`` that additionally carries its prompt inputs.

    Attributes:
        inputs: Dict holding the prompt data of the request.
    """

    inputs: Dict
|
||
|
||
class ValidationError(ValueError):
    """Signals invalid request inputs; a ``ValueError`` subclass."""