local guided decoding
kevinbu233 committed Jul 28, 2024
1 parent 3eeb148 commit decffe4
Showing 7 changed files with 384 additions and 6 deletions.
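This commit adds a guided_options argument to LLM.generate and a local code path for building guided-decoding logits processors without going through the OpenAI-compatible server. Pieced together from the tests and the llm.py changes below, offline usage looks roughly like the following sketch; the prompt is a placeholder, the regex is the one from the test fixtures, and the accepted keys (guided_regex, guided_json, guided_choice, guided_grammar) are inferred from the tests rather than from documented API, so treat the surface as illustrative:

# Sketch only: option names mirror the tests added in this commit; verify them
# against the installed vLLM version before relying on them.
from vllm import LLM, SamplingParams

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", max_model_len=1024)
params = SamplingParams(temperature=0.8, top_p=0.95)

# Exactly one guided_* key may be set per call; the new code raises otherwise.
outputs = llm.generate(
    prompts=["Give an example IPv4 address"],
    sampling_params=params,
    guided_options=dict(guided_regex=r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
                        r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"),
)
print(outputs[0].outputs[0].text)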
69 changes: 69 additions & 0 deletions tests/entrypoints/llm/conftest.py
@@ -0,0 +1,69 @@
import pytest


@pytest.fixture
def sample_regex():
    return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
            r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")


@pytest.fixture
def sample_json_schema():
    return {
        "type": "object",
        "properties": {
            "name": {
                "type": "string"
            },
            "age": {
                "type": "integer"
            },
            "skills": {
                "type": "array",
                "items": {
                    "type": "string",
                    "maxLength": 10
                },
                "minItems": 3
            },
            "work_history": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "company": {
                            "type": "string"
                        },
                        "duration": {
                            "type": "number"
                        },
                        "position": {
                            "type": "string"
                        }
                    },
                    "required": ["company", "position"]
                }
            }
        },
        "required": ["name", "age", "skills", "work_history"]
    }


@pytest.fixture
def sample_guided_choice():
    return [
        "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
        "Ruby", "Swift", "Kotlin"
    ]


@pytest.fixture
def sample_sql_statements():
    return ("""
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
""")
154 changes: 154 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -0,0 +1,154 @@
import json
import re
import weakref
from typing import List

import jsonschema
import pytest

from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

from ...conftest import cleanup


MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

TOKEN_IDS = [
    [0],
    [0, 1],
    [0, 2, 1],
    [0, 3, 1, 2],
]


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME, max_model_len=1024)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)
        del llm
    cleanup()


@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    outputs = llm.generate(prompts=[
        f"Give an example IPv4 address with this regex: {sample_regex}"
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True,
                           guided_options=dict(guided_regex=sample_regex))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
    )
    outputs = llm.generate(prompts=[
        f"Give an example JSON for an employee profile "
        f"that fits this schema: {sample_json_schema}"
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True,
                           guided_options=dict(guided_json=sample_json_schema))

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options=dict(guided_choice=sample_guided_choice))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert generated_text in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    outputs = llm.generate(
        prompts=("Generate an SQL statement that selects col_1 from "
                 "table_1 where it equals 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options=dict(guided_grammar=sample_sql_statements))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None

        # use Lark to parse the output, and make sure it's a valid parse tree
        from lark import Lark
        parser = Lark(sample_sql_statements)
        parser.parse(generated_text)

        # remove spaces for comparison b/c we removed them in the grammar
        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
47 changes: 45 additions & 2 deletions vllm/entrypoints/llm.py
@@ -1,5 +1,6 @@
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
from typing import (ClassVar, Dict, List, Optional, Sequence, Union, cast,
                    overload)

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -10,6 +11,8 @@
                         parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
    GuidedDecodingFields, get_local_guided_decoding_logits_processor)
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -262,6 +265,7 @@ def generate(
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options: Optional[Union[Dict, GuidedDecodingFields]] = None
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.
Expand Down Expand Up @@ -303,6 +307,13 @@ def generate(
        else:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

        if isinstance(guided_options, dict):
            if len(guided_options) > 1:
                raise ValueError(
                    "Only one guided decoding option can be used, but "
                    f"multiple were specified: {guided_options}")
            guided_options = GuidedDecodingFields(**guided_options)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()
@@ -311,7 +322,8 @@ def generate(
            inputs=inputs,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)
@@ -508,6 +520,7 @@ def _validate_and_add_requests(
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingFields] = None,
    ) -> None:
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
@@ -523,6 +536,18 @@
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")

if isinstance(params, list):
if len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")

params = [
self._add_guided_processor(param, guided_options)
for param in params if isinstance(param, SamplingParams)
]
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)

# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
self._add_request(
@@ -548,6 +573,24 @@ def _add_request(
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

    def _add_guided_processor(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingFields] = None):
        if guided_options:
            if guided_options.guided_decoding_backend is None:
                decoding_config = self.llm_engine.get_decoding_config()
                guided_options.guided_decoding_backend = (
                    decoding_config.guided_decoding_backend)
            guided_logits_processor = get_local_guided_decoding_logits_processor(  # noqa
                guided_options.guided_decoding_backend, guided_options,
                self.get_tokenizer())
            if guided_logits_processor:
                if params.logits_processors is None:
                    params.logits_processors = []
                params.logits_processors.append(guided_logits_processor)
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
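The _add_guided_processor helper above attaches the guided-decoding processor to SamplingParams.logits_processors. In vLLM that list holds plain callables over the previously generated token ids and the logits for the next position; the minimal hand-rolled processor below is not part of this commit and assumes that callable convention from vllm.sampling_params.LogitsProcessor, but it shows the interface the guided processors plug into:

from typing import List

import torch

from vllm import SamplingParams


def ban_token_zero(generated_token_ids: List[int],
                   logits: torch.Tensor) -> torch.Tensor:
    # Assumed signature: receives the tokens generated so far and the raw
    # logits for the next position, and returns the (modified) logits.
    logits[0] = float("-inf")
    return logits


params = SamplingParams(temperature=0.8, logits_processors=[ban_token_zero])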
26 changes: 24 additions & 2 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -3,9 +3,10 @@
from vllm.entrypoints.openai.protocol import (
    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
    CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
    get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingFields)
from vllm.model_executor.guided_decoding.outlines_decoding import (
    get_local_outlines_guided_decoding_logits_processor,
    get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor

@@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor(
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_lm_format_enforcer_guided_decoding_logits_processor)
        return await get_lm_format_enforcer_guided_decoding_logits_processor(
            request, tokenizer)

@@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor(
"Must be one of 'outlines, 'lm-format-enforcer'")


def get_local_guided_decoding_logits_processor(
        guided_decoding_backend: str, guided_options: GuidedDecodingFields,
        tokenizer) -> Optional[LogitsProcessor]:
    # request = _adapt_request_for_tool_use(request)

    if guided_decoding_backend == 'outlines':
        return get_local_outlines_guided_decoding_logits_processor(
            guided_options, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_options, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
                                               ChatCompletionRequest]):
    # the legacy completion API does not support tool use
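For completeness, the new local entry point added in this file can also be exercised directly, outside LLM.generate. The sketch below assumes GuidedDecodingFields accepts the same guided_* keyword arguments that the tests pass through dict(...), with unset fields defaulting to None, and uses a Hugging Face tokenizer; it illustrates the dispatch added here rather than documented usage:

from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding import (
    GuidedDecodingFields, get_local_guided_decoding_logits_processor)

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
# Assumed constructor: keyword fields mirror the guided_* keys used in the tests.
options = GuidedDecodingFields(guided_choice=["Python", "Java"])
# Returns None or a logits processor that can be appended to
# SamplingParams.logits_processors, exactly as _add_guided_processor does.
processor = get_local_guided_decoding_logits_processor(
    "outlines", options, tokenizer)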