Commit decffe4 (1 parent: 3eeb148)
Showing 7 changed files with 384 additions and 6 deletions.
@@ -0,0 +1,69 @@
import pytest


@pytest.fixture
def sample_regex():
    return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
            r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")


@pytest.fixture
def sample_json_schema():
    return {
        "type": "object",
        "properties": {
            "name": {
                "type": "string"
            },
            "age": {
                "type": "integer"
            },
            "skills": {
                "type": "array",
                "items": {
                    "type": "string",
                    "maxLength": 10
                },
                "minItems": 3
            },
            "work_history": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "company": {
                            "type": "string"
                        },
                        "duration": {
                            "type": "number"
                        },
                        "position": {
                            "type": "string"
                        }
                    },
                    "required": ["company", "position"]
                }
            }
        },
        "required": ["name", "age", "skills", "work_history"]
    }


@pytest.fixture
def sample_guided_choice():
    return [
        "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
        "Ruby", "Swift", "Kotlin"
    ]


@pytest.fixture
def sample_sql_statements():
    return ("""
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
""")
@@ -0,0 +1,154 @@
import json
import re
import weakref
from typing import List

import jsonschema
import pytest

from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

from ...conftest import cleanup


MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

TOKEN_IDS = [
    [0],
    [0, 1],
    [0, 2, 1],
    [0, 3, 1, 2],
]


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME, max_model_len=1024)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)
        del llm
    cleanup()


@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    outputs = llm.generate(
        prompts=[
            f"Give an example IPv4 address with this regex: {sample_regex}"
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options=dict(guided_regex=sample_regex))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
    )
    outputs = llm.generate(
        prompts=[
            f"Give an example JSON for an employee profile "
            f"that fits this schema: {sample_json_schema}"
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options=dict(guided_json=sample_json_schema))

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options=dict(guided_choice=sample_guided_choice))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert generated_text in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    outputs = llm.generate(
        prompts=("Generate a sql state that select col_1 from "
                 "table_1 where it is equals to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options=dict(guided_grammar=sample_sql_statements))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None

        # use Lark to parse the output, and make sure it's a valid parse tree
        from lark import Lark
        parser = Lark(sample_sql_statements)
        parser.parse(generated_text)

        # remove spaces for comparison b/c we removed them in the grammar
        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")