Skip to content

Commit

Permalink
attachments= keyword argument, tests pass again - refs #587
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Oct 28, 2024
1 parent a9bc1c7 commit 286cf9f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
2 changes: 1 addition & 1 deletion llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def chat(
raise click.ClickException(str(ex))
if prompt.strip() in ("exit", "quit"):
break
response = conversation.prompt(prompt, system, **validated_options)
response = conversation.prompt(prompt, system=system, **validated_options)
# System prompt only sent for the first message:
system = None
for chunk in response:
Expand Down
19 changes: 14 additions & 5 deletions llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,18 @@ class Prompt:
options: "Options" = field(default_factory=dict)

def __init__(
self, prompt, model, attachments, system=None, prompt_json=None, options=None
self,
prompt,
model,
*,
attachments=None,
system=None,
prompt_json=None,
options=None
):
self.prompt = prompt
self.model = model
self.attachments = list(attachments)
self.attachments = list(attachments or [])
self.system = system
self.prompt_json = prompt_json
self.options = options or {}
Expand All @@ -105,7 +112,8 @@ class Conversation:
def prompt(
self,
prompt: Optional[str],
*attachments: Attachment,
*,
attachments: Attachment = None,
system: Optional[str] = None,
stream: bool = True,
**options
Expand Down Expand Up @@ -386,7 +394,8 @@ def execute(
def prompt(
self,
prompt: str,
*attachments: Attachment,
*,
attachments: Attachment = None,
system: Optional[str] = None,
stream: bool = True,
**options
Expand All @@ -396,7 +405,7 @@ def prompt(
raise ValueError(
"This model does not support attachments, but some were provided"
)
for attachment in attachments:
for attachment in attachments or []:
attachment_type = attachment.resolve_type()
if attachment_type not in self.attachment_types:
raise ValueError(
Expand Down
5 changes: 4 additions & 1 deletion tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ def test_chat_basic(mock_model, logs_db):
mock_model.enqueue(["one world"])
mock_model.enqueue(["one again"])
result = runner.invoke(
llm.cli.cli, ["chat", "-m", "mock"], input="Hi\nHi two\nquit\n"
llm.cli.cli,
["chat", "-m", "mock"],
input="Hi\nHi two\nquit\n",
catch_exceptions=False,
)
assert result.exit_code == 0
assert result.output == (
Expand Down

0 comments on commit 286cf9f

Please sign in to comment.