Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] add fixture to guided processor tests #6341

Merged
merged 2 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/entrypoints/openai/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest


@pytest.fixture
def sample_regex():
    """Return a regex for a dotted-quad IP address.

    The pattern is three ``<octet>.`` groups followed by a final octet,
    where each octet accepts the decimal values 0 through 255.
    """
    # One octet: 250-255, or an optional 20-24 / 10-19 / 1-9 prefix
    # followed by a single digit.
    octet = r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
    dotted_octet_x3 = "(" + octet + r"\.){3}"
    return dotted_octet_x3 + octet


@pytest.fixture
def sample_json_schema():
    """Return a JSON schema for an employee-profile object.

    The object requires ``name`` (string), ``age`` (integer), ``skills``
    (at least three strings of at most ten characters each) and
    ``work_history`` (a list of jobs, each requiring ``company`` and
    ``position``). A new dict is built on every call.
    """
    skills = {
        "type": "array",
        "items": {
            "type": "string",
            "maxLength": 10,
        },
        "minItems": 3,
    }

    # "duration" is declared but intentionally left out of "required".
    job_fields = {
        "company": {"type": "string"},
        "duration": {"type": "number"},
        "position": {"type": "string"},
    }
    work_history = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": job_fields,
            "required": ["company", "position"],
        },
    }

    profile_fields = {
        "name": {"type": "string"},
        "age": {"type": "integer"},
        "skills": skills,
        "work_history": work_history,
    }
    return {
        "type": "object",
        "properties": profile_fields,
        "required": ["name", "age", "skills", "work_history"],
    }


@pytest.fixture
def sample_guided_choice():
    """Return the list of programming-language names used as choices.

    A new list is created on every call, so a test may mutate it freely.
    """
    languages = [
        "Python",
        "Java",
        "JavaScript",
        "C++",
        "C#",
        "PHP",
        "TypeScript",
        "Ruby",
        "Swift",
        "Kotlin",
    ]
    return languages


@pytest.fixture
def sample_sql_statements():
    """Return grammar text for a tiny, fixed-vocabulary SELECT statement.

    The text starts with a blank line and ends with a newline. It has one
    ``name: expansion`` rule per line, with ``start`` as the first rule.
    """
    # Adjacent literals are concatenated at compile time; single quotes
    # keep the grammar's own double-quoted terminals unescaped.
    grammar = (
        "\n"
        "start: select_statement\n"
        'select_statement: "SELECT" column "from" table "where" condition\n'
        'column: "col_1" | "col_2"\n'
        'table: "table_1" | "table_2"\n'
        'condition: column "=" number\n'
        'number: "1" | "2"\n'
    )
    return grammar
119 changes: 39 additions & 80 deletions tests/entrypoints/openai/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,53 +22,6 @@
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"

TEST_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "string"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work history"]
}

TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

TEST_CHOICE = [
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
"Swift", "Kotlin"
]


@pytest.fixture(scope="module")
def zephyr_lora_files():
Expand Down Expand Up @@ -408,7 +361,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(client: openai.AsyncOpenAI,
guided_decoding_backend: str):
guided_decoding_backend: str,
sample_guided_choice):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -422,10 +376,10 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE,
extra_body=dict(guided_choice=sample_guided_choice,
guided_decoding_backend=guided_decoding_backend))
choice1 = chat_completion.choices[0].message.content
assert choice1 in TEST_CHOICE
assert choice1 in sample_guided_choice

messages.append({"role": "assistant", "content": choice1})
messages.append({
Expand All @@ -436,18 +390,19 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE,
extra_body=dict(guided_choice=sample_guided_choice,
guided_decoding_backend=guided_decoding_backend))
choice2 = chat_completion.choices[0].message.content
assert choice2 in TEST_CHOICE
assert choice2 in sample_guided_choice
assert choice1 != choice2


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_chat(client: openai.AsyncOpenAI,
guided_decoding_backend: str):
guided_decoding_backend: str,
sample_json_schema):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -456,18 +411,18 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
f"fits this schema: {sample_json_schema}"
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA,
extra_body=dict(guided_json=sample_json_schema,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json1 = json.loads(message.content)
jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
jsonschema.validate(instance=json1, schema=sample_json_schema)

messages.append({"role": "assistant", "content": message.content})
messages.append({
Expand All @@ -480,12 +435,12 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA,
extra_body=dict(guided_json=sample_json_schema,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json2 = json.loads(message.content)
jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
jsonschema.validate(instance=json2, schema=sample_json_schema)
assert json1["name"] != json2["name"]
assert json1["age"] != json2["age"]

Expand All @@ -494,37 +449,37 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_chat(client: openai.AsyncOpenAI,
guided_decoding_backend: str):
guided_decoding_backend: str, sample_regex):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example IP address with this regex: {TEST_REGEX}"
f"Give an example IP address with this regex: {sample_regex}"
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX,
extra_body=dict(guided_regex=sample_regex,
guided_decoding_backend=guided_decoding_backend))
ip1 = chat_completion.choices[0].message.content
assert ip1 is not None
assert re.fullmatch(TEST_REGEX, ip1) is not None
assert re.fullmatch(sample_regex, ip1) is not None

messages.append({"role": "assistant", "content": ip1})
messages.append({"role": "user", "content": "Give me a different one"})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX,
extra_body=dict(guided_regex=sample_regex,
guided_decoding_backend=guided_decoding_backend))
ip2 = chat_completion.choices[0].message.content
assert ip2 is not None
assert re.fullmatch(TEST_REGEX, ip2) is not None
assert re.fullmatch(sample_regex, ip2) is not None
assert ip1 != ip2


Expand Down Expand Up @@ -553,7 +508,8 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
guided_decoding_backend: str):
guided_decoding_backend: str,
sample_guided_choice):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -569,7 +525,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
max_tokens=10,
logprobs=True,
top_logprobs=5,
extra_body=dict(guided_choice=TEST_CHOICE,
extra_body=dict(guided_choice=sample_guided_choice,
guided_decoding_backend=guided_decoding_backend))

assert chat_completion.choices[0].logprobs is not None
Expand All @@ -585,7 +541,8 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_named_tool_use(client: openai.AsyncOpenAI,
guided_decoding_backend: str):
guided_decoding_backend: str,
sample_json_schema):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -594,7 +551,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
f"fits this schema: {sample_json_schema}"
}]

# non-streaming
Expand All @@ -608,7 +565,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
"parameters": sample_json_schema
}
}],
tool_choice={
Expand All @@ -621,7 +578,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
assert len(message.content) == 0
json_string = message.tool_calls[0].function.arguments
json1 = json.loads(json_string)
jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
jsonschema.validate(instance=json1, schema=sample_json_schema)

messages.append({"role": "assistant", "content": json_string})
messages.append({
Expand All @@ -642,7 +599,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
"parameters": sample_json_schema
}
}],
tool_choice={
Expand All @@ -667,15 +624,16 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
# finish reason should only return in last block
assert finish_reason_count == 1
json2 = json.loads("".join(output))
jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
jsonschema.validate(instance=json2, schema=sample_json_schema)
assert json1["name"] != json2["name"]
assert json1["age"] != json2["age"]


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_required_tool_use_not_yet_supported(
client: openai.AsyncOpenAI, guided_decoding_backend: str):
client: openai.AsyncOpenAI, guided_decoding_backend: str,
sample_json_schema):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -684,7 +642,7 @@ async def test_required_tool_use_not_yet_supported(
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
f"fits this schema: {sample_json_schema}"
}]

with pytest.raises(openai.BadRequestError):
Expand All @@ -697,7 +655,7 @@ async def test_required_tool_use_not_yet_supported(
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
"parameters": sample_json_schema
}
}],
tool_choice="required")
Expand All @@ -712,16 +670,17 @@ async def test_required_tool_use_not_yet_supported(
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
"parameters": sample_json_schema
}
}],
tool_choice="auto")


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_inconsistent_tool_choice_and_tools(
client: openai.AsyncOpenAI, guided_decoding_backend: str):
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_json_schema):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -730,7 +689,7 @@ async def test_inconsistent_tool_choice_and_tools(
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
f"fits this schema: {sample_json_schema}"
}]

with pytest.raises(openai.BadRequestError):
Expand All @@ -755,7 +714,7 @@ async def test_inconsistent_tool_choice_and_tools(
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
"parameters": sample_json_schema
}
}],
tool_choice={
Expand Down
Loading
Loading