Skip to content

Commit

Permalink
Refactored prefill text output logic
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Jan 31, 2025
1 parent 1673576 commit d8ea765
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions llm_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ def __init__(
if supports_pdf:
self.attachment_types.add("application/pdf")

def prefill_text(self, prompt):
    """Return the prefill string to emit with the response.

    Yields the configured prefill only when one is set and the user has
    not asked to hide it; otherwise returns the empty string.
    """
    opts = prompt.options
    if not opts.prefill or opts.hide_prefill:
        return ""
    return opts.prefill

def build_messages(self, prompt, conversation) -> List[dict]:
messages = []
if conversation:
Expand Down Expand Up @@ -285,20 +290,19 @@ class ClaudeMessages(_Shared, llm.Model):
def execute(self, prompt, stream, response, conversation):
    """Run the prompt against the Anthropic Messages API and yield text.

    Generator: yields the (optional) prefill text first, then the model
    output — either incrementally (streaming) or as one chunk.  Records
    the raw API response on ``response.response_json`` and usage via
    ``self.set_usage``.

    NOTE(review): the scraped diff interleaved the pre- and post-commit
    prefill lines, which would have yielded the prefill twice; this is
    the reconstructed post-commit version using ``self.prefill_text``.
    """
    client = Anthropic(api_key=self.get_key())
    kwargs = self.build_kwargs(prompt, conversation)
    # Computed once so both the streaming and non-streaming paths agree.
    prefill_text = self.prefill_text(prompt)
    if stream:
        with client.messages.stream(**kwargs) as stream:
            if prefill_text:
                yield prefill_text
            for text in stream.text_stream:
                yield text
            # This records usage and other data:
            response.response_json = stream.get_final_message().model_dump()
    else:
        completion = client.messages.create(**kwargs)
        text = completion.content[0].text
        # Empty prefill_text is a harmless no-op prefix here.
        yield prefill_text + text
        response.response_json = completion.model_dump()
    self.set_usage(response)

Expand All @@ -312,19 +316,18 @@ class AsyncClaudeMessages(_Shared, llm.AsyncModel):
async def execute(self, prompt, stream, response, conversation):
    """Async variant of ``ClaudeMessages.execute``.

    Async generator: yields the (optional) prefill text first, then the
    model output — incrementally when streaming, otherwise as one chunk.
    Records the raw API response on ``response.response_json`` and usage
    via ``self.set_usage``.

    NOTE(review): the scraped diff interleaved the pre- and post-commit
    prefill lines, which would have yielded the prefill twice; this is
    the reconstructed post-commit version using ``self.prefill_text``.
    """
    client = AsyncAnthropic(api_key=self.get_key())
    kwargs = self.build_kwargs(prompt, conversation)
    # Computed once so both the streaming and non-streaming paths agree.
    prefill_text = self.prefill_text(prompt)
    if stream:
        async with client.messages.stream(**kwargs) as stream_obj:
            if prefill_text:
                yield prefill_text
            async for text in stream_obj.text_stream:
                yield text
            # This records usage and other data:
            response.response_json = (await stream_obj.get_final_message()).model_dump()
    else:
        completion = await client.messages.create(**kwargs)
        text = completion.content[0].text
        # Empty prefill_text is a harmless no-op prefix here.
        yield prefill_text + text
        response.response_json = completion.model_dump()
    self.set_usage(response)

Expand Down

0 comments on commit d8ea765

Please sign in to comment.