From 4a059d722b6d2898aadb08e200ddb385c9da513c Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 19 Nov 2024 18:11:52 -0800 Subject: [PATCH] Log --async responses to DB, closes #641 Refs #507 --- llm/cli.py | 14 ++++++--- llm/models.py | 15 ++++++++++ tests/test_cli_openai_models.py | 51 +++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 4 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index 5a9f20b4..c75e0e3e 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -6,6 +6,7 @@ import json from llm import ( Attachment, + AsyncResponse, Collection, Conversation, Response, @@ -376,6 +377,7 @@ def read_prompt(): validated_options["stream"] = False prompt = read_prompt() + response = None prompt_method = model.prompt if conversation: @@ -386,12 +388,13 @@ def read_prompt(): async def inner(): if should_stream: - async for chunk in prompt_method( + response = prompt_method( prompt, attachments=resolved_attachments, system=system, **validated_options, - ): + ) + async for chunk in response: print(chunk, end="") sys.stdout.flush() print("") @@ -403,8 +406,9 @@ async def inner(): **validated_options, ) print(await response.text()) + return response - asyncio.run(inner()) + response = asyncio.run(inner()) else: response = prompt_method( prompt, @@ -423,11 +427,13 @@ async def inner(): raise click.ClickException(str(ex)) # Log to the database - if (logs_on() or log) and not no_log and not async_: + if (logs_on() or log) and not no_log: log_path = logs_db_path() (log_path.parent).mkdir(parents=True, exist_ok=True) db = sqlite_utils.Database(log_path) migrate(db) + if isinstance(response, AsyncResponse): + response = asyncio.run(response.to_sync_response()) response.log_to_db(db) diff --git a/llm/models.py b/llm/models.py index 70d19377..c160798b 100644 --- a/llm/models.py +++ b/llm/models.py @@ -426,6 +426,21 @@ async def datetime_utc(self) -> str: def __await__(self): return self._force().__await__() + async def to_sync_response(self) -> Response: + await self._force() + response = Response( + self.prompt, + self.model, + self.stream, + conversation=self.conversation, + ) + response._chunks = self._chunks + response._done = True + response._end = self._end + response._start = self._start + response._start_utcnow = self._start_utcnow + return response + @classmethod def fake( cls, diff --git a/tests/test_cli_openai_models.py b/tests/test_cli_openai_models.py index 7cbab726..b65ad078 100644 --- a/tests/test_cli_openai_models.py +++ b/tests/test_cli_openai_models.py @@ -1,6 +1,7 @@ from click.testing import CliRunner from llm.cli import cli import pytest +import sqlite_utils @pytest.fixture @@ -143,3 +144,53 @@ def test_only_gpt4_audio_preview_allows_mp3_or_wav(httpx_mock, model, filetype): assert ( f"This model does not support attachments of type '{long}'" in result.output ) + + +@pytest.mark.parametrize("async_", (False, True)) +def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_): + user_path = tmpdir / "user_dir" + log_db = user_path / "logs.db" + monkeypatch.setenv("LLM_USER_PATH", str(user_path)) + assert not log_db.exists() + httpx_mock.add_response( + method="POST", + # chat completion request + url="https://api.openai.com/v1/chat/completions", + json={ + "id": "chatcmpl-AQT9a30kxEaM1bqxRPepQsPlCyGJh", + "object": "chat.completion", + "created": 1730871958, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Ho ho ho", + "refusal": None, + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 2, + "total_tokens": 12, + }, + "system_fingerprint": "fp_49254d0e9b", + }, + headers={"Content-Type": "application/json"}, + ) + runner = CliRunner() + args = ["-m", "gpt-4o-mini", "--key", "x", "--no-stream"] + if async_: + args.append("--async") + result = runner.invoke(cli, args, catch_exceptions=False) + assert result.exit_code == 0 + assert result.output == "Ho ho ho\n" + # Confirm it was correctly logged + assert log_db.exists() + db = sqlite_utils.Database(str(log_db)) + assert db["responses"].count == 1 + row = next(db["responses"].rows) + assert row["response"] == "Ho ho ho"