Commit c702761

Use response.set_usage(), closes #25
simonw committed Nov 20, 2024
1 parent cfe61b5 commit c702761
Showing 3 changed files with 22 additions and 2 deletions.
15 changes: 15 additions & 0 deletions llm_gemini.py
@@ -210,6 +210,19 @@ def process_part(self, part):
            return f'```\n{part["codeExecutionResult"]["output"].strip()}\n```\n'
        return ""

    def set_usage(self, response):
        try:
            usage = response.response_json[-1].pop("usageMetadata")
            input_tokens = usage.pop("promptTokenCount", None)
            output_tokens = usage.pop("candidatesTokenCount", None)
            usage.pop("totalTokenCount", None)
            if input_tokens is not None:
                response.set_usage(
                    input=input_tokens, output=output_tokens, details=usage or None
                )
        except (IndexError, KeyError):
            pass


class GeminiPro(_SharedGemini, llm.Model):
    def execute(self, prompt, stream, response, conversation):
@@ -241,6 +254,7 @@ def execute(self, prompt, stream, response, conversation):
                    gathered.append(event)
                    events.clear()
        response.response_json = gathered
        self.set_usage(response)


class AsyncGeminiPro(_SharedGemini, llm.AsyncModel):
@@ -274,6 +288,7 @@ async def execute(self, prompt, stream, response, conversation):
                    gathered.append(event)
                    events.clear()
        response.response_json = gathered
        self.set_usage(response)


@llm.hookimpl
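For reference, a minimal sketch (not from this commit; the payload values are invented for illustration) of what the set_usage() parsing above does with the usageMetadata block Gemini attaches to the final streamed event:

    # Illustrative only: a fake final streamed event shaped like Gemini's
    # usageMetadata, as referenced in the diff above.
    last_event = {
        "candidates": [{"content": {"parts": [{"text": "Percy"}], "role": "model"}}],
        "usageMetadata": {
            "promptTokenCount": 10,
            "candidatesTokenCount": 2,
            "totalTokenCount": 12,
        },
    }

    usage = last_event.pop("usageMetadata")                   # removed from the stored event
    input_tokens = usage.pop("promptTokenCount", None)        # -> 10
    output_tokens = usage.pop("candidatesTokenCount", None)   # -> 2
    usage.pop("totalTokenCount", None)                        # dropped
    print(input_tokens, output_tokens, usage or None)         # 10 2 None

totalTokenCount is discarded, presumably because it duplicates the other two counts, and anything left over in the dict is forwarded to response.set_usage() as details.
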
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -9,7 +9,7 @@ classifiers = [
    "License :: OSI Approved :: Apache Software License"
]
dependencies = [
-   "llm>=0.18",
+   "llm>=0.19a0",
    "httpx",
    "ijson"
]
7 changes: 6 additions & 1 deletion tests/test_gemini.py
@@ -17,10 +17,15 @@ async def test_prompt():
            "candidates": [
                {"content": {"parts": [{"text": "Percy"}], "role": "model"}}
            ],
            "usageMetadata": {"promptTokenCount": 10, "totalTokenCount": 10},
            "modelVersion": "gemini-1.5-flash-002",
        }
    ]
    assert response.token_details is None
    assert response.input_tokens == 10
    # Not sure why our pytest-recording setup doesn't report output tokens
    # https://github.com/simonw/llm-gemini/issues/25#issuecomment-2487464339
    assert response.output_tokens is None

    # And try it async too
    async_model = llm.get_async_model("gemini-1.5-flash-latest")
    async_model.key = async_model.key or GEMINI_API_KEY
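With usage recorded this way, the token counts become readable on the response object, mirroring the assertions in the test above. A usage sketch, assuming llm 0.19 or later (the reason for the llm>=0.19a0 bump), this plugin installed, and a Gemini API key configured; the printed values are illustrative:

    import llm

    model = llm.get_model("gemini-1.5-flash-latest")
    response = model.prompt("Name for a pet pelican, just the name")
    print(response.text())          # e.g. "Percy"

    # Populated by _SharedGemini.set_usage() from Gemini's usageMetadata
    print(response.input_tokens)    # promptTokenCount, e.g. 10
    print(response.output_tokens)   # candidatesTokenCount (may be None, per the test comment)
    print(response.token_details)   # any leftover usageMetadata fields, or None

response.token_details carries whatever usageMetadata fields remain beyond the two counters (None in the test above).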
