
prefill, hide_prefill and stop_sequences options
Refs #2
simonw committed Jan 31, 2025
1 parent e50aeaf commit ffd2391
Showing 4 changed files with 167 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
@@ -25,3 +25,6 @@ jobs:
    - name: Run tests
      run: |
        pytest
    - name: Check if cog needs to be run
      run: |
        cog --check README.md
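The `--check` mode only verifies that the cog-generated block in the README is current; if it fails, the file can be regenerated locally in place (a sketch, assuming `cogapp` is installed):

```bash
cog -r README.md
```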
105 changes: 105 additions & 0 deletions README.md
@@ -46,6 +46,111 @@ The plugin sets up `claude-3.5-sonnet` and similar as aliases, usable like this:
llm -m claude-3.5-sonnet 'Fun facts about pelicans'
```

## Model options

The following options can be passed using `-o name value` on the CLI or as `keyword=value` arguments to the Python `model.prompt()` method; examples of both appear below:

<!-- [[[cog
import cog, llm
_type_lookup = {
    "number": "float",
    "integer": "int",
    "string": "str",
    "object": "dict",
}
model = llm.get_model("claude-3.5-sonnet")
output = []
for name, field in model.Options.schema()["properties"].items():
    any_of = field.get("anyOf")
    if any_of is None:
        any_of = [{"type": field["type"]}]
    types = ", ".join(
        [
            _type_lookup.get(item["type"], item["type"])
            for item in any_of
            if item["type"] != "null"
        ]
    )
    bits = ["- **", name, "**: `", types, "`\n"]
    description = field.get("description", "")
    if description:
        bits.append('\n ' + description + '\n\n')
    output.append("".join(bits))
cog.out("".join(output))
]]] -->
- **max_tokens**: `int`

The maximum number of tokens to generate before stopping

- **temperature**: `float`

Amount of randomness injected into the response. Defaults to 1.0. Ranges from 0.0 to 1.0. Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks. Note that even with temperature of 0.0, the results will not be fully deterministic.

- **top_p**: `float`

Use nucleus sampling. In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both. Recommended for advanced use cases only. You usually only need to use temperature.

- **top_k**: `int`

Only sample from the top K options for each subsequent token. Used to remove 'long tail' low probability responses. Recommended for advanced use cases only. You usually only need to use temperature.

- **user_id**: `str`

An external identifier for the user who is associated with the request

- **prefill**: `str`

A prefill to use for the response

- **hide_prefill**: `boolean`

Do not repeat the prefill value at the start of the response

- **stop_sequences**: `array, str`

Custom text sequences that will cause the model to stop generating - pass either a list of strings or a single string

<!-- [[[end]]] -->
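
For example, to cap output length and lower the temperature from the command line (illustrative values):

```bash
llm -m claude-3.5-sonnet 'Fun facts about pelicans' \
  -o max_tokens 200 -o temperature 0.5
```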

The `prefill` option can be used to set the first part of the response. To increase the chance of the model returning JSON, set it to `{`:

```bash
llm -m claude-3.5-sonnet 'Fun data about pelicans' \
  -o prefill '{'
```
If you do not want the prefill token to be echoed in the response, set `hide_prefill` to `true`:

```bash
llm -m claude-3.5-haiku 'Short python function describing a pelican' \
  -o prefill '```python' \
  -o hide_prefill true \
  -o stop_sequences '```'
```
This example sets `` ``` `` as the stop sequence, so the response will be a Python function without the wrapping Markdown code block.

To pass a single stop sequence, send a string:
```bash
llm -m claude-3.5-sonnet 'Fun facts about pelicans' \
  -o stop_sequences "beak"
```
For multiple stop sequences, pass a JSON array:

```bash
llm -m claude-3.5-sonnet 'Fun facts about pelicans' \
  -o stop_sequences '["beak", "feathers"]'
```

When using the Python API, pass either a string or a list of strings:

```python
import llm

model = llm.get_model("claude-3.5-sonnet")
response = model.prompt(
    "Fun facts about pelicans",
    stop_sequences=["beak", "feathers"],
)
```
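
The prefill options can be combined the same way using `model.prompt()` keyword arguments; a minimal sketch mirroring the CLI example above:

```python
import llm

model = llm.get_model("claude-3.5-haiku")
response = model.prompt(
    "Short python function describing a pelican",
    prefill="```python",
    hide_prefill=True,
    stop_sequences="```",
)
print(response.text())
```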

## Development

To set up this plugin locally, first checkout the code. Then create a new virtual environment:
62 changes: 58 additions & 4 deletions llm_anthropic.py
@@ -1,7 +1,8 @@
from anthropic import Anthropic, AsyncAnthropic
import llm
import json
from pydantic import Field, field_validator, model_validator
from typing import Optional, List
from typing import Optional, List, Union


@llm.hookimpl
@@ -73,6 +74,44 @@ class ClaudeOptions(llm.Options):
        default=None,
    )

    prefill: Optional[str] = Field(
        description="A prefill to use for the response",
        default=None,
    )

    hide_prefill: Optional[bool] = Field(
        description="Do not repeat the prefill value at the start of the response",
        default=None,
    )

    stop_sequences: Optional[Union[list, str]] = Field(
        description=(
            "Custom text sequences that will cause the model to stop generating - "
            "pass either a list of strings or a single string"
        ),
        default=None,
    )

    @field_validator("stop_sequences")
    def validate_stop_sequences(cls, stop_sequences):
        error_msg = "stop_sequences must be a list of strings or a single string"
        if isinstance(stop_sequences, str):
            try:
                stop_sequences = json.loads(stop_sequences)
                if not isinstance(stop_sequences, list) or not all(
                    isinstance(seq, str) for seq in stop_sequences
                ):
                    raise ValueError(error_msg)
                return stop_sequences
            except json.JSONDecodeError:
                return [stop_sequences]
        elif isinstance(stop_sequences, list):
            if not all(isinstance(seq, str) for seq in stop_sequences):
                raise ValueError(error_msg)
            return stop_sequences
        else:
            raise ValueError(error_msg)

    @field_validator("max_tokens")
    @classmethod
    def validate_max_tokens(cls, max_tokens):
@@ -129,7 +168,7 @@ def __init__(
        supports_images=True,
        supports_pdf=False,
    ):
        self.model_id = 'anthropic/' + model_id
        self.model_id = "anthropic/" + model_id
        self.claude_model_id = claude_model_id or model_id
        self.attachment_types = set()
        if supports_images:
@@ -201,6 +240,8 @@ def build_messages(self, prompt, conversation) -> List[dict]:
            )
        else:
            messages.append({"role": "user", "content": prompt.prompt})
        if prompt.options.prefill:
            messages.append({"role": "assistant", "content": prompt.options.prefill})
        return messages

    def build_kwargs(self, prompt, conversation):
@@ -223,6 +264,9 @@ def build_kwargs(self, prompt, conversation):
        if prompt.system:
            kwargs["system"] = prompt.system

        if prompt.options.stop_sequences:
            kwargs["stop_sequences"] = prompt.options.stop_sequences

        return kwargs

    def set_usage(self, response):
@@ -243,13 +287,18 @@ def execute(self, prompt, stream, response, conversation):
        kwargs = self.build_kwargs(prompt, conversation)
        if stream:
            with client.messages.stream(**kwargs) as stream:
                if prompt.options.prefill and not prompt.options.hide_prefill:
                    yield prompt.options.prefill
                for text in stream.text_stream:
                    yield text
                # This records usage and other data:
                response.response_json = stream.get_final_message().model_dump()
        else:
            completion = client.messages.create(**kwargs)
            yield completion.content[0].text
            text = completion.content[0].text
            if prompt.options.prefill and not prompt.options.hide_prefill:
                text = prompt.options.prefill + text
            yield text
            response.response_json = completion.model_dump()
        self.set_usage(response)

@@ -265,12 +314,17 @@ async def execute(self, prompt, stream, response, conversation):
        kwargs = self.build_kwargs(prompt, conversation)
        if stream:
            async with client.messages.stream(**kwargs) as stream_obj:
                if prompt.options.prefill and not prompt.options.hide_prefill:
                    yield prompt.options.prefill
                async for text in stream_obj.text_stream:
                    yield text
                response.response_json = (await stream_obj.get_final_message()).model_dump()
        else:
            completion = await client.messages.create(**kwargs)
            yield completion.content[0].text
            text = completion.content[0].text
            if prompt.options.prefill and not prompt.options.hide_prefill:
                text = prompt.options.prefill + text
            yield text
            response.response_json = completion.model_dump()
        self.set_usage(response)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,4 +23,4 @@ CI = "https://github.com/simonw/llm-anthropic/actions"
anthropic = "llm_anthropic"

[project.optional-dependencies]
test = ["pytest", "pytest-recording", "pytest-asyncio"]
test = ["pytest", "pytest-recording", "pytest-asyncio", "cogapp"]
