Skip to content

Commit

Permalink
[Misc] add fixture to guided processor tests (#6341)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinbu233 authored Jul 12, 2024
1 parent f9d25c2 commit b039cbb
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 206 deletions.
69 changes: 69 additions & 0 deletions tests/entrypoints/openai/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest


@pytest.fixture
def sample_regex():
    """Regex matching a dotted-quad IPv4 address (each octet 0-255)."""
    # Single octet: 250-255, 200-249, 100-199, 10-99, or 0-9.
    octet = r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
    # Three octets each followed by a dot, then a final octet.
    return r"(" + octet + r"\.){3}" + octet


@pytest.fixture
def sample_json_schema():
    """JSON schema for an employee profile used by guided-decoding tests."""
    # Sub-schema for one prior job entry.
    job_schema = {
        "type": "object",
        "properties": {
            "company": {"type": "string"},
            "duration": {"type": "number"},
            "position": {"type": "string"},
        },
        "required": ["company", "position"],
    }
    # Sub-schema for the list of skills (short strings, at least three).
    skills_schema = {
        "type": "array",
        "items": {"type": "string", "maxLength": 10},
        "minItems": 3,
    }
    return {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
            "skills": skills_schema,
            "work_history": {"type": "array", "items": job_schema},
        },
        "required": ["name", "age", "skills", "work_history"],
    }


@pytest.fixture
def sample_guided_choice():
    """Closed set of programming-language names for guided-choice tests."""
    # None of the names contain whitespace, so a single split() suffices.
    return ("Python Java JavaScript C++ C# PHP TypeScript "
            "Ruby Swift Kotlin").split()


@pytest.fixture
def sample_sql_statements():
    """Lark-style grammar for a tiny SQL SELECT, for guided-grammar tests."""
    # NOTE: the leading newline and indentation are part of the grammar
    # text handed to the backend, so the literal is kept verbatim.
    grammar = """
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
"""
    return grammar
119 changes: 39 additions & 80 deletions tests/entrypoints/openai/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,53 +22,6 @@
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"

TEST_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "string"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work history"]
}

TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

TEST_CHOICE = [
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
"Swift", "Kotlin"
]


@pytest.fixture(scope="module")
def zephyr_lora_files():
Expand Down Expand Up @@ -408,7 +361,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(client: openai.AsyncOpenAI,
guided_decoding_backend: str):
guided_decoding_backend: str,
sample_guided_choice):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -422,10 +376,10 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE,
extra_body=dict(guided_choice=sample_guided_choice,
guided_decoding_backend=guided_decoding_backend))
choice1 = chat_completion.choices[0].message.content
assert choice1 in TEST_CHOICE
assert choice1 in sample_guided_choice

messages.append({"role": "assistant", "content": choice1})
messages.append({
Expand All @@ -436,18 +390,19 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE,
extra_body=dict(guided_choice=sample_guided_choice,
guided_decoding_backend=guided_decoding_backend))
choice2 = chat_completion.choices[0].message.content
assert choice2 in TEST_CHOICE
assert choice2 in sample_guided_choice
assert choice1 != choice2


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_chat(client: openai.AsyncOpenAI,
guided_decoding_backend: str):
guided_decoding_backend: str,
sample_json_schema):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -456,18 +411,18 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
f"fits this schema: {sample_json_schema}"
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA,
extra_body=dict(guided_json=sample_json_schema,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json1 = json.loads(message.content)
jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
jsonschema.validate(instance=json1, schema=sample_json_schema)

messages.append({"role": "assistant", "content": message.content})
messages.append({
Expand All @@ -480,12 +435,12 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA,
extra_body=dict(guided_json=sample_json_schema,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json2 = json.loads(message.content)
jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
jsonschema.validate(instance=json2, schema=sample_json_schema)
assert json1["name"] != json2["name"]
assert json1["age"] != json2["age"]

Expand All @@ -494,37 +449,37 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_chat(client: openai.AsyncOpenAI,
guided_decoding_backend: str):
guided_decoding_backend: str, sample_regex):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example IP address with this regex: {TEST_REGEX}"
f"Give an example IP address with this regex: {sample_regex}"
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX,
extra_body=dict(guided_regex=sample_regex,
guided_decoding_backend=guided_decoding_backend))
ip1 = chat_completion.choices[0].message.content
assert ip1 is not None
assert re.fullmatch(TEST_REGEX, ip1) is not None
assert re.fullmatch(sample_regex, ip1) is not None

messages.append({"role": "assistant", "content": ip1})
messages.append({"role": "user", "content": "Give me a different one"})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX,
extra_body=dict(guided_regex=sample_regex,
guided_decoding_backend=guided_decoding_backend))
ip2 = chat_completion.choices[0].message.content
assert ip2 is not None
assert re.fullmatch(TEST_REGEX, ip2) is not None
assert re.fullmatch(sample_regex, ip2) is not None
assert ip1 != ip2


Expand Down Expand Up @@ -553,7 +508,8 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
guided_decoding_backend: str):
guided_decoding_backend: str,
sample_guided_choice):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -569,7 +525,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
max_tokens=10,
logprobs=True,
top_logprobs=5,
extra_body=dict(guided_choice=TEST_CHOICE,
extra_body=dict(guided_choice=sample_guided_choice,
guided_decoding_backend=guided_decoding_backend))

assert chat_completion.choices[0].logprobs is not None
Expand All @@ -585,7 +541,8 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_named_tool_use(client: openai.AsyncOpenAI,
guided_decoding_backend: str):
guided_decoding_backend: str,
sample_json_schema):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -594,7 +551,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
f"fits this schema: {sample_json_schema}"
}]

# non-streaming
Expand All @@ -608,7 +565,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
"parameters": sample_json_schema
}
}],
tool_choice={
Expand All @@ -621,7 +578,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
assert len(message.content) == 0
json_string = message.tool_calls[0].function.arguments
json1 = json.loads(json_string)
jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
jsonschema.validate(instance=json1, schema=sample_json_schema)

messages.append({"role": "assistant", "content": json_string})
messages.append({
Expand All @@ -642,7 +599,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
"parameters": sample_json_schema
}
}],
tool_choice={
Expand All @@ -667,15 +624,16 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
# finish reason should only return in last block
assert finish_reason_count == 1
json2 = json.loads("".join(output))
jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
jsonschema.validate(instance=json2, schema=sample_json_schema)
assert json1["name"] != json2["name"]
assert json1["age"] != json2["age"]


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_required_tool_use_not_yet_supported(
client: openai.AsyncOpenAI, guided_decoding_backend: str):
client: openai.AsyncOpenAI, guided_decoding_backend: str,
sample_json_schema):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -684,7 +642,7 @@ async def test_required_tool_use_not_yet_supported(
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
f"fits this schema: {sample_json_schema}"
}]

with pytest.raises(openai.BadRequestError):
Expand All @@ -697,7 +655,7 @@ async def test_required_tool_use_not_yet_supported(
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
"parameters": sample_json_schema
}
}],
tool_choice="required")
Expand All @@ -712,16 +670,17 @@ async def test_required_tool_use_not_yet_supported(
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
"parameters": sample_json_schema
}
}],
tool_choice="auto")


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_inconsistent_tool_choice_and_tools(
client: openai.AsyncOpenAI, guided_decoding_backend: str):
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_json_schema):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand All @@ -730,7 +689,7 @@ async def test_inconsistent_tool_choice_and_tools(
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
f"fits this schema: {sample_json_schema}"
}]

with pytest.raises(openai.BadRequestError):
Expand All @@ -755,7 +714,7 @@ async def test_inconsistent_tool_choice_and_tools(
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
"parameters": sample_json_schema
}
}],
tool_choice={
Expand Down
Loading

0 comments on commit b039cbb

Please sign in to comment.