Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: convert Google Gemini tests to VCR #118

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
interactions:
- request:
body: '{"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [{"role": "user", "parts": [{"text": "Hello"}]}]}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '139'
content-type:
- application/json
host:
- generativelanguage.googleapis.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key
response:
body:
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
[\n {\n \"text\": \"Hello! \U0001F44B How can I help
you today? \U0001F60A \\n\"\n }\n ],\n \"role\": \"model\"\n
\ },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\":
[\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n
\ },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n
\ }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\":
8,\n \"candidatesTokenCount\": 12,\n \"totalTokenCount\": 20\n }\n}\n"
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Cache-Control:
- private
Content-Type:
- application/json; charset=UTF-8
Date:
- Wed, 02 Oct 2024 01:06:50 GMT
Server:
- scaffolding on HTTPServer2
Server-Timing:
- gfet4t7; dur=426
Transfer-Encoding:
- chunked
Vary:
- Origin
- X-Origin
- Referer
X-Content-Type-Options:
- nosniff
X-Frame-Options:
- SAMEORIGIN
X-XSS-Protection:
- '0'
content-length:
- '855'
status:
code: 200
message: OK
version: 1
73 changes: 73 additions & 0 deletions packages/exchange/tests/providers/cassettes/test_google_tools.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
interactions:
- request:
body: '{"system_instruction": {"parts": [{"text": "You are a helpful assistant.
Expect to need to read a file using read_file."}]}, "contents": [{"role": "user",
"parts": [{"text": "What are the contents of this file? test.txt"}]}], "tools":
{"functionDeclarations": [{"name": "read_file", "description": "Read the contents
of the file.", "parameters": {"type": "object", "properties": {"filename": {"type":
"string", "description": "The path to the file, which can be relative or\nabsolute.
If it is a plain filename, it is assumed to be in the\ncurrent working directory."}},
"required": ["filename"]}}]}}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '600'
content-type:
- application/json
host:
- generativelanguage.googleapis.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key
response:
body:
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
[\n {\n \"functionCall\": {\n \"name\": \"read_file\",\n
\ \"args\": {\n \"filename\": \"test.txt\"\n }\n
\ }\n }\n ],\n \"role\": \"model\"\n },\n
\ \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\":
[\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n
\ },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n
\ }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\":
101,\n \"candidatesTokenCount\": 17,\n \"totalTokenCount\": 118\n }\n}\n"
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Cache-Control:
- private
Content-Type:
- application/json; charset=UTF-8
Date:
- Wed, 02 Oct 2024 01:06:51 GMT
Server:
- scaffolding on HTTPServer2
Server-Timing:
- gfet4t7; dur=449
Transfer-Encoding:
- chunked
Vary:
- Origin
- X-Origin
- Referer
X-Content-Type-Options:
- nosniff
X-Frame-Options:
- SAMEORIGIN
X-XSS-Protection:
- '0'
content-length:
- '947'
status:
code: 200
message: OK
version: 1
29 changes: 25 additions & 4 deletions packages/exchange/tests/providers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ def default_azure_env(monkeypatch):
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY)


GOOGLE_API_KEY = "test_google_api_key"


@pytest.fixture
def default_google_env(monkeypatch):
    """
    Supply a placeholder GOOGLE_API_KEY when the environment does not define one.

    GoogleProvider.from_env() errs on missing environment variables; this
    fixture keeps that from happening in VCR tests.

    When recording a cassette (first run, or after deleting a recording), set
    the required environment variables yourself so the real requests succeed.
    Later runs replay the recorded data and do not need them.
    """
    if "GOOGLE_API_KEY" in os.environ:
        # A key is already configured (e.g. while recording); leave it untouched.
        return
    monkeypatch.setenv("GOOGLE_API_KEY", GOOGLE_API_KEY)


@pytest.fixture(scope="module")
def vcr_config():
"""
Expand Down Expand Up @@ -85,6 +102,8 @@ def scrub_request_url(request):
request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri)
request.headers["host"] = AZURE_ENDPOINT.replace("https://", "")
request.headers["api-key"] = AZURE_API_KEY
elif "generativelanguage.googleapis.com" in request.uri:
request.uri = re.sub(r"([?&])key=[^&]+", r"\1key=" + GOOGLE_API_KEY, request.uri)

return request

Expand All @@ -93,16 +112,18 @@ def scrub_response_headers(response):
"""
This scrubs sensitive response headers. Note they are case-sensitive!
"""
response["headers"]["openai-organization"] = OPENAI_ORG_ID
response["headers"]["Set-Cookie"] = "test_set_cookie"
if "openai-organization" in response["headers"]:
response["headers"]["openai-organization"] = OPENAI_ORG_ID
if "Set-Cookie" in response["headers"]:
response["headers"]["Set-Cookie"] = "test_set_cookie"
return response


def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)
return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs)


def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
Expand All @@ -128,4 +149,4 @@ def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message,
content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')],
),
]
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)
return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a lot of IDE warnings at the call sites we use, because the parameter is currently annotated as tools: Tuple[Tool]. Passing () instead of None clears the warnings about passing None. Though I wonder if it shouldn't be tools: Tuple[Tool, ...]? That seems more explicit about possibly having none, and would also clear some other warnings. If you feel this is worthwhile, I can do it in a follow-up or a different PR.

Screenshot 2024-10-05 at 7 32 24 AM

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

80 changes: 31 additions & 49 deletions packages/exchange/tests/providers/test_google.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
from unittest.mock import patch

import httpx
import pytest
from exchange import Message, Text
from exchange.content import ToolResult, ToolUse
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.providers.google import GoogleProvider
from exchange.tool import Tool
from .conftest import complete, tools

GOOGLE_MODEL = os.getenv("GOOGLE_MODEL", "gemini-1.5-flash")


def example_fn(param: str) -> None:
Expand All @@ -30,12 +32,6 @@ def test_from_env_throw_error_when_missing_api_key():
assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message


@pytest.fixture
@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"})
def google_provider():
return GoogleProvider.from_env()


def test_google_response_to_text_message() -> None:
response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]}
message = GoogleProvider.google_response_to_message(response)
Expand Down Expand Up @@ -105,54 +101,40 @@ def test_messages_to_google_spec() -> None:
assert actual_spec == expected_spec


@patch("httpx.Client.post")
@patch("logging.warning")
@patch("logging.error")
def test_google_completion(mock_error, mock_warning, mock_post, google_provider):
mock_response = {
"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}],
"usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13},
}
@pytest.mark.vcr()
def test_google_complete(default_google_env):
reply_message, reply_usage = complete(GoogleProvider, GOOGLE_MODEL)

# First attempts fail with status code 429, 2nd succeeds
def create_response(status_code, json_data=None):
response = httpx.Response(status_code)
response._content = httpx._content.json_dumps(json_data or {}).encode()
response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/")
return response
assert reply_message.content == [Text("Hello! 👋 How can I help you today? 😊 \n")]
assert reply_usage.total_tokens == 20

mock_post.side_effect = [
create_response(429), # 1st attempt
create_response(200, mock_response), # Final success
]

model = "gemini-1.5-flash"
system = "You are a helpful assistant."
messages = [Message.user("Hello, Gemini")]
@pytest.mark.integration
def test_google_complete_integration():
reply = complete(GoogleProvider, GOOGLE_MODEL)

reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages)
assert reply[0].content is not None
print("Completion content from Google:", reply[0].content)

assert reply_message.content == [Text(text="Hello from Gemini!")]
assert reply_usage.total_tokens == 13
assert mock_post.call_count == 2
mock_post.assert_any_call(
"models/gemini-1.5-flash:generateContent",
json={
"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}],
},
)

@pytest.mark.vcr()
def test_google_tools(default_google_env):
reply_message, reply_usage = tools(GoogleProvider, GOOGLE_MODEL)

@pytest.mark.integration
def test_google_integration():
provider = GoogleProvider.from_env()
model = "gemini-1.5-flash" # updated model to a known valid model
system = "You are a helpful assistant."
messages = [Message.user("Hello, Gemini")]
tool_use = reply_message.content[0]
assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}"
assert tool_use.id == "read_file"
assert tool_use.name == "read_file"
assert tool_use.parameters == {"filename": "test.txt"}
assert reply_usage.total_tokens == 118

# Run the completion
reply = provider.complete(model=model, system=system, messages=messages)

assert reply[0].content is not None
print("Completion content from Google:", reply[0].content)
@pytest.mark.integration
def test_google_tools_integration():
    """
    Integration-marked check that a tools request against GOOGLE_MODEL comes
    back with a ToolUse asking to read test.txt via read_file.
    """
    message = tools(GoogleProvider, GOOGLE_MODEL)[0]

    first_part = message.content[0]
    assert isinstance(first_part, ToolUse), f"Expected ToolUse, but was {type(first_part).__name__}"
    # Only the presence of an id is checked here, not its exact value.
    assert first_part.id is not None
    assert first_part.name == "read_file"
    assert first_part.parameters == {"filename": "test.txt"}