diff --git a/llm_anthropic.py b/llm_anthropic.py
index 7a9fadc..e06e418 100644
--- a/llm_anthropic.py
+++ b/llm_anthropic.py
@@ -183,6 +183,11 @@ def __init__(
         if supports_pdf:
             self.attachment_types.add("application/pdf")
 
+    def prefill_text(self, prompt):
+        if prompt.options.prefill and not prompt.options.hide_prefill:
+            return prompt.options.prefill
+        return ""
+
     def build_messages(self, prompt, conversation) -> List[dict]:
         messages = []
         if conversation:
@@ -285,10 +290,11 @@ class ClaudeMessages(_Shared, llm.Model):
     def execute(self, prompt, stream, response, conversation):
         client = Anthropic(api_key=self.get_key())
         kwargs = self.build_kwargs(prompt, conversation)
+        prefill_text = self.prefill_text(prompt)
         if stream:
             with client.messages.stream(**kwargs) as stream:
-                if prompt.options.prefill and not prompt.options.hide_prefill:
-                    yield prompt.options.prefill
+                if prefill_text:
+                    yield prefill_text
                 for text in stream.text_stream:
                     yield text
             # This records usage and other data:
@@ -296,9 +302,7 @@ def execute(self, prompt, stream, response, conversation):
         else:
             completion = client.messages.create(**kwargs)
             text = completion.content[0].text
-            if prompt.options.prefill and not prompt.options.hide_prefill:
-                text = prompt.options.prefill + text
-            yield text
+            yield prefill_text + text
         response.response_json = completion.model_dump()
         self.set_usage(response)
 
@@ -312,19 +316,18 @@ class AsyncClaudeMessages(_Shared, llm.AsyncModel):
     async def execute(self, prompt, stream, response, conversation):
         client = AsyncAnthropic(api_key=self.get_key())
         kwargs = self.build_kwargs(prompt, conversation)
+        prefill_text = self.prefill_text(prompt)
         if stream:
             async with client.messages.stream(**kwargs) as stream_obj:
-                if prompt.options.prefill and not prompt.options.hide_prefill:
-                    yield prompt.options.prefill
+                if prefill_text:
+                    yield prefill_text
                 async for text in stream_obj.text_stream:
                     yield text
             response.response_json = (await stream_obj.get_final_message()).model_dump()
         else:
             completion = await client.messages.create(**kwargs)
             text = completion.content[0].text
-            if prompt.options.prefill and not prompt.options.hide_prefill:
-                text = prompt.options.prefill + text
-            yield text
+            yield prefill_text + text
         response.response_json = completion.model_dump()
         self.set_usage(response)
 
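For reference, a minimal sketch of how the prefill options this patch refactors
behave from the caller's side (the model ID and prompts are illustrative, and
this assumes the plugin is installed alongside the llm library):

    import llm

    model = llm.get_model("claude-3.5-sonnet")

    # Default: the prefill text is echoed back at the start of the response,
    # because prefill_text() returns it and execute() yields it first.
    response = model.prompt("Reply with a JSON object", prefill='{"answer":')
    print(response.text())  # expected to start with '{"answer":'

    # hide_prefill=True: the prefill still steers the model's completion,
    # but prefill_text() returns "" so it is omitted from the output.
    response = model.prompt(
        "Reply with a JSON object", prefill='{"answer":', hide_prefill=True
    )
    print(response.text())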