Skip to content

Commit

Permalink
InputProcessor & TextInputProcessor
Browse files Browse the repository at this point in the history
  • Loading branch information
noooop committed Sep 30, 2024
1 parent f13a07b commit 82197c8
Show file tree
Hide file tree
Showing 10 changed files with 209 additions and 0 deletions.
Empty file added tests/wde/__init__.py
Empty file.
Empty file added tests/wde/core/__init__.py
Empty file.
Empty file.
80 changes: 80 additions & 0 deletions tests/wde/core/processor/test_input_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest

from vllm.wde.core.processor.input_processor import TextInputProcessor
from vllm.wde.core.schema.engine_io import (TextOnlyInputs, TextPrompt,
TokensPrompt, ValidationError)

# Shared processor instance; every test in this module calls it directly.
input_processor = TextInputProcessor()


@pytest.fixture(scope="session")
def request_id():
    """Provide the constant request id ``"0"`` used by every test."""
    fixed_id = "0"
    return fixed_id


def test_input_processor_1(request_id):
    """A bare string is wrapped as ``{"prompt": <text>}``."""
    text = "test"
    expected = {"prompt": text}

    result = input_processor(request_id, text)

    assert result.inputs == expected


def test_input_processor_2(request_id):
    """A ``TextPrompt`` is reduced to a dict holding only its prompt."""
    text = "test"
    expected = {"prompt": text}

    result = input_processor(request_id, TextPrompt(prompt=text))

    assert result.inputs == expected


def test_input_processor_3(request_id):
    """A ``TokensPrompt`` is reduced to a dict holding only its token ids."""
    token_ids = [0]
    expected = {"prompt_token_ids": token_ids}

    result = input_processor(request_id,
                             TokensPrompt(prompt_token_ids=token_ids))

    assert result.inputs == expected


def test_input_processor_4(request_id):
    """``TextOnlyInputs`` keeps token ids and adds the prompt only if set."""
    text = "test"
    token_ids = [0]

    # Without a prompt, only the token ids appear in the request inputs.
    ids_only = TextOnlyInputs(prompt_token_ids=token_ids)
    result = input_processor(request_id, ids_only)
    assert result.inputs == {"prompt_token_ids": token_ids}

    # With a prompt, both keys appear.
    ids_and_text = TextOnlyInputs(prompt_token_ids=token_ids, prompt=text)
    result = input_processor(request_id, ids_and_text)
    expected = {"prompt_token_ids": token_ids, "prompt": text}
    assert result.inputs == expected


def test_input_processor_5(request_id):
    """A plain dict with both recognised keys compares equal after processing."""
    text = "test"
    token_ids = [0]
    raw_inputs = {"prompt_token_ids": token_ids, "prompt": text}

    result = input_processor(request_id, raw_inputs)

    assert result.inputs == raw_inputs


def test_validation_error(request_id):
    """Inputs without a usable prompt raise ``ValidationError``."""
    # In order: an empty dict, a dict with neither recognised key,
    # an int, and a float.
    rejected = ({}, {"foo": "bar"}, 0, 0.0)

    for bad_inputs in rejected:
        with pytest.raises(ValidationError):
            input_processor(request_id, bad_inputs)
Empty file added vllm/wde/__init__.py
Empty file.
Empty file added vllm/wde/core/__init__.py
Empty file.
Empty file.
74 changes: 74 additions & 0 deletions vllm/wde/core/processor/input_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import time
from abc import ABC, abstractmethod
from typing import Optional, Union

from vllm.wde.core.schema.engine_io import (Inputs, Params, PromptInput,
Request, TextOnlyInputs,
TextPrompt, TextRequest,
TokensPrompt, ValidationError)


class InputProcessor(ABC):
    """Abstract interface for turning raw call arguments into a ``Request``.

    Data flow:
        (request_id, inputs, params, arrival_time) -> InputProcessor -> Request

    Subclasses must implement both ``__call__`` and ``from_engine``.
    """

    @abstractmethod
    def __call__(self,
                 request_id: str,
                 inputs: Optional[Union[str, Inputs]] = None,
                 params: Optional[Params] = None,
                 arrival_time: Optional[float] = None) -> Request:
        """Build a ``Request`` from the given arguments.

        Args:
            request_id: Identifier of the request.
            inputs: A raw string or an ``Inputs`` object; may be omitted.
            params: Optional ``Params`` object.
            arrival_time: Optional arrival timestamp.

        Returns:
            The constructed ``Request``.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_engine(cls, engine):
        """Alternate constructor that builds a processor from ``engine``."""
        raise NotImplementedError


class TextInputProcessor(InputProcessor):
    """Normalise the supported text prompt forms into a ``TextRequest``.

    Whatever form the caller supplies, the resulting ``TextRequest.inputs``
    is a plain dict containing ``"prompt"`` and/or ``"prompt_token_ids"``.
    """

    def __call__(self,
                 request_id: str,
                 inputs: Optional[PromptInput] = None,
                 params: Optional[Params] = None,
                 arrival_time: Optional[float] = None) -> TextRequest:
        """Convert ``inputs`` to a dict and wrap it in a ``TextRequest``.

        Args:
            request_id: Request identifier; passed through ``str()`` before
                it is stored.
            inputs: One of
                * ``str`` -> ``{"prompt": inputs}``
                * ``TextPrompt`` -> ``{"prompt": ...}``
                * ``TokensPrompt`` -> ``{"prompt_token_ids": ...}``
                * ``TextOnlyInputs`` -> ``{"prompt_token_ids": ...}`` plus
                  ``"prompt"`` when it is not ``None``
                * ``dict`` -> a copy restricted to the keys ``"prompt"``
                  and ``"prompt_token_ids"``; at least one must be present.
            params: Accepted for interface compatibility; this
                implementation does not read it.
            arrival_time: Arrival timestamp. When ``None`` the current
                ``time.time()`` value is used.

        Returns:
            A ``TextRequest`` holding the normalised inputs.

        Raises:
            ValidationError: If ``inputs`` is a dict with neither recognised
                key, or is of an unsupported type (including ``None``).
        """
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}
        elif isinstance(inputs, TextPrompt):
            inputs = {"prompt": inputs.prompt}
        elif isinstance(inputs, TokensPrompt):
            inputs = {"prompt_token_ids": inputs.prompt_token_ids}
        elif isinstance(inputs, TextOnlyInputs):
            _inputs = {"prompt_token_ids": inputs.prompt_token_ids}

            # The prompt text is optional; include it only when provided.
            if inputs.prompt is not None:
                _inputs["prompt"] = inputs.prompt

            inputs = _inputs

        elif isinstance(inputs, dict):
            if "prompt" not in inputs and "prompt_token_ids" not in inputs:
                raise ValidationError('"prompt" and "prompt_token_ids" '
                                      'have at least one in inputs.')
            # Drop any keys other than the two recognised ones.
            inputs = {
                k: v
                for k, v in inputs.items()
                if k in {"prompt", "prompt_token_ids"}
            }
        else:
            raise ValidationError(
                f"Input does not support {type(inputs)} data type")

        # Compare against None explicitly: a caller-supplied timestamp of 0
        # is falsy but is still a timestamp and must not be overwritten.
        if arrival_time is None:
            arrival_time = time.time()
        request = TextRequest(request_id=str(request_id),
                              inputs=inputs,
                              arrival_time=arrival_time)
        return request

    @classmethod
    def from_engine(cls, engine):
        """Return a new processor; ``engine`` is not consulted."""
        return cls()
Empty file.
55 changes: 55 additions & 0 deletions vllm/wde/core/schema/engine_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Union


class Params:
    """Empty base class for request parameter objects."""


class Inputs:
    """Empty base class for structured input objects."""


@dataclass
class TextPrompt(Inputs):
    """Schema for a text prompt."""

    # The input text to be tokenized before passing to the model.
    prompt: str


@dataclass
class TokensPrompt(Inputs):
    """Schema for a tokenized prompt."""

    # A list of token IDs to pass to the model.
    prompt_token_ids: List[int]


@dataclass
class TextOnlyInputs(Inputs):
    """Schema for token IDs with the optional original prompt text."""

    # The token IDs of the prompt.
    prompt_token_ids: List[int]

    # The original prompt text corresponding to the token IDs, if available.
    prompt: Optional[str] = None


# Union of the accepted prompt forms: a raw string, a plain dict, or one of
# the prompt dataclasses defined above.
PromptInput = Union[str, Dict, TextPrompt, TokensPrompt, TextOnlyInputs]


@dataclass
class Request:
    """Base record shared by all request types."""

    # Identifier of the request.
    request_id: str

    # Arrival timestamp of the request, stored as a float.
    arrival_time: float


@dataclass
class TextRequest(Request):
    """A ``Request`` that additionally carries its inputs as a dict."""

    # Input payload of the request.
    inputs: Dict


class ValidationError(ValueError):
    """Error type for rejected inputs; a ``ValueError`` subclass."""

0 comments on commit 82197c8

Please sign in to comment.