Update models.transformers to use SequenceGeneratorAdapter and OutlinesLogitsProcessors #966

Merged: 1 commit, Jul 15, 2024
5 changes: 2 additions & 3 deletions README.md
@@ -191,10 +191,9 @@ model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.2")
generator = outlines.generate.json(model, Character)

# Draw a sample
rng = torch.Generator(device="cuda")
rng.manual_seed(789001)
seed = 789001

character = generator("Give me a character description", rng=rng)
character = generator("Give me a character description", seed=seed)

print(repr(character))
# Character(name='Anderson', age=28, armor=<Armor.chainmail: 'chainmail'>, weapon=<Weapon.sword: 'sword'>, strength=8)
53 changes: 52 additions & 1 deletion docs/reference/models/transformers.md
@@ -15,7 +15,7 @@ Outlines provides an integration with the `torch` implementation of causal model
```python
from outlines import models

model = models.transformers("mistralai/Mistral-7B-v0.1", device="cuda")
model = models.transformers("mistralai/Mistral-7B-v0.3", device="cuda")
```

If you need more fine-grained control you can also initialize the model and tokenizer separately:
@@ -30,4 +30,55 @@ tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = models.Transformers(llm, tokenizer)
```

# Using Logits Processors
Member

We'll need to improve the documentation to reach something similar to the llamacpp integration's.

Contributor Author

We should plan a restructuring and cleanup of the documentation in a separate issue. I could share some ideas on a call about how we might approach this.

In this case, a lot of the information documented for llamacpp applies to all other models, including transformers. We shouldn't repeat ourselves. We should explain the behavior of all models generally, highlight the differences between models with a feature table, and document only transformers-specific information on its documentation page.

Member

What I meant was listing the main arguments you can pass when initialising and calling the model, cf. https://outlines-dev.github.io/outlines/reference/models/llamacpp/


There are two ways to use Outlines structured generation with Hugging Face Transformers:

1. Use the Outlines generation wrapper, `outlines.models.transformers`
2. Use `OutlinesLogitsProcessor` with `transformers.AutoModelForCausalLM`

Outlines supports many logits processors for structured generation. In these examples, we will use the `RegexLogitsProcessor`, which guarantees that the generated text matches the specified pattern.

## Example: `outlines.models.transformers`

```python
import outlines

time_regex_pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"

model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct", device="cuda")
generator = outlines.generate.regex(model, time_regex_pattern)

output = generator("The best time to visit a dentist is at ")
print(output)
# 2:30 pm
```
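
To see which strings the time pattern admits, it can be checked with Python's standard `re` module, without loading any model. This is only an illustrative sketch that assumes the whole output must match the pattern; the helper name `is_valid_time` is hypothetical and not part of Outlines.

```python
import re

# Same pattern as in the Outlines example.
time_regex_pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"


def is_valid_time(text: str) -> bool:
    """Return True if the entire string matches the time pattern (hypothetical helper)."""
    return re.fullmatch(time_regex_pattern, text) is not None


print(is_valid_time("2:30 pm"))  # True
print(is_valid_time("12:05am"))  # True
print(is_valid_time("13:00"))    # False: the hour must be 1-12
print(is_valid_time("2:75 pm"))  # False: the minutes must be 00-59
```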

## Example: Direct `transformers` library use

```python
import outlines
import transformers


model_uri = "microsoft/Phi-3-mini-4k-instruct"

# Wrap the Hugging Face tokenizer so the Outlines processor can use it
outlines_tokenizer = outlines.models.TransformerTokenizer(
    transformers.AutoTokenizer.from_pretrained(model_uri)
)
phone_number_logits_processor = outlines.processors.RegexLogitsProcessor(
    "\\+?[1-9][0-9]{7,14}",  # phone number pattern
    outlines_tokenizer,
)

generator = transformers.pipeline('text-generation', model=model_uri)

output = generator(
    "Jenny gave me her number it's ",
    logits_processor=transformers.LogitsProcessorList([phone_number_logits_processor])
)
print(output)
# [{'generated_text': "Jenny gave me her number it's 2125550182"}]
# Not quite the 8675309 we expected, but it is a valid phone number
```
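
A quick check of the phone-number pattern with Python's standard `re` module shows why the sample output is accepted while the seven-digit `8675309` could not be produced: the pattern requires one non-zero digit followed by 7 to 14 more digits, so at least eight digits in total. This is a minimal sketch that assumes full-string matching; `is_valid_phone` is a hypothetical helper, not an Outlines API.

```python
import re

# Same pattern as passed to RegexLogitsProcessor, written as a raw string.
phone_regex_pattern = r"\+?[1-9][0-9]{7,14}"


def is_valid_phone(text: str) -> bool:
    """Return True if the entire string matches the phone pattern (hypothetical helper)."""
    return re.fullmatch(phone_regex_pattern, text) is not None


print(is_valid_phone("2125550182"))    # True
print(is_valid_phone("+12125550182"))  # True: an optional leading "+" is allowed
print(is_valid_phone("8675309"))       # False: only seven digits
print(is_valid_phone("0123456789"))    # False: the first digit must be 1-9
```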

[transformers]: https://github.com/huggingface/transformers
5 changes: 2 additions & 3 deletions docs/reference/text.md
@@ -80,8 +80,7 @@ from outlines import models, generate

model = models.transformers("mistralai/Mistral-7B-v0.1")

rng = torch.Generator(device="cuda")
rng.manual_seed(789001)
seed = 789001

answer = generator("What is 2+2?", rng=rng)
answer = generator("What is 2+2?", seed=seed)
```
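
This change replaces a `torch.Generator` argument with a plain integer seed. The underlying idea, that the same seed reproduces the same draws, can be illustrated without a model using only the standard library. This is a toy sketch and not how Outlines implements sampling; the function `sample_tokens`, the vocabulary, and the weights are made up for illustration.

```python
import random


def sample_tokens(seed: int, n: int = 5) -> list[str]:
    """Draw n tokens from a toy distribution using a seeded generator (hypothetical)."""
    vocab = ["2", "4", "four", "22"]
    weights = [0.1, 0.6, 0.25, 0.05]
    rng = random.Random(seed)  # private generator; global random state is untouched
    return rng.choices(vocab, weights=weights, k=n)


first = sample_tokens(789001)
second = sample_tokens(789001)
print(first == second)  # True: the same seed gives the same draws
```

Passing an integer rather than a generator object keeps the calling code free of any framework-specific import, which matches the removal of `import torch` from the examples in this PR.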
6 changes: 2 additions & 4 deletions examples/llamacpp_example.py
@@ -1,6 +1,5 @@
from enum import Enum

import torch
from pydantic import BaseModel, constr

import outlines
@@ -37,10 +36,9 @@ class Character(BaseModel):
generator = outlines.generate.json(model, Character)

# Draw a sample
rng = torch.Generator(device="cpu")
rng.manual_seed(789005)
seed = 789005

prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:"

sequence = generator(prompt, rng=rng, max_tokens=512)
sequence = generator(prompt, seed=seed, max_tokens=512)
print(sequence)
42 changes: 8 additions & 34 deletions outlines/generate/cfg.py
@@ -1,16 +1,14 @@
from functools import singledispatch

from outlines.fsm.guide import CFGGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.generate.api import SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.mlxlm import MLXLM
from outlines.models.vllm import VLLM
from outlines.samplers import Sampler, multinomial


@singledispatch
def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenerator:
def cfg(
model, cfg_str: str, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
"""Generate text in the language of a Context-Free Grammar

Arguments
@@ -24,40 +22,16 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera

Returns
-------
A `SequenceGenerator` instance that generates text.
A `SequenceGeneratorAdapter` instance that generates text.

"""
fsm = CFGGuide(cfg_str, model.tokenizer)
device = model.device
generator = SequenceGenerator(fsm, model, sampler, device)

return generator


@cfg.register(MLXLM)
@cfg.register(VLLM)
def cfg_unimplemented(
model,
cfg_str: str,
sampler: Sampler = multinomial(),
):
raise NotImplementedError(
f"The CFG Logits processor is not available for {type(model)}."
f"The CFG Logits processor is not available for {type(model)}. "
+ "Please subscribe to https://github.com/outlines-dev/outlines/issues/684"
+ " for updates on the fix."
)


@cfg.register(LlamaCpp)
def cfg_llamacpp(
model: LlamaCpp,
cfg_str: str,
sampler: Sampler = multinomial(),
):
from outlines.integrations.llamacpp import CFGLogitsProcessor

logits_processor = CFGLogitsProcessor(cfg_str, model.model)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@cfg.register(OpenAI)
def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()):
raise NotImplementedError(
6 changes: 4 additions & 2 deletions outlines/generate/regex.py
@@ -5,6 +5,7 @@
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.mlxlm import MLXLM
from outlines.models.transformers import Transformers
from outlines.models.vllm import VLLM
from outlines.samplers import Sampler, multinomial

@@ -39,8 +40,9 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):


@regex.register(MLXLM)
def regex_mlxlm(
model: MLXLM,
@regex.register(Transformers)
Contributor Author

The `_unified` dispatchers will become the default dispatchers as a next step. In this PR they are only used by MLXLM and Transformers.

def regex_unified(
model,
regex_str: str,
sampler: Sampler = multinomial(),
):
5 changes: 3 additions & 2 deletions outlines/generate/text.py
@@ -2,7 +2,7 @@

from outlines.fsm.guide import StopAtEOSGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI, Transformers
from outlines.samplers import Sampler, multinomial


@@ -37,7 +37,8 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:


@text.register(MLXLM)
def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()):
@text.register(Transformers)
def text_unified(model, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)
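
The stacked `register` decorators in this hunk route several model classes to one shared function. A self-contained sketch of that `functools.singledispatch` pattern follows; the classes `ModelA`, `ModelB`, and `OtherModel` and the returned strings are stand-ins invented for illustration, not Outlines code.

```python
from functools import singledispatch


class ModelA:  # stand-in for one registered model class
    pass


class ModelB:  # stand-in for a second registered model class
    pass


class OtherModel:  # stand-in for a class with no registration
    pass


@singledispatch
def text(model) -> str:
    """Fallback used for any type without a registered implementation."""
    return "default generator"


@text.register(ModelA)
@text.register(ModelB)
def text_unified(model) -> str:
    """Shared implementation for both registered types."""
    return "unified generator"


print(text(ModelA()))      # unified generator
print(text(ModelB()))      # unified generator
print(text(OtherModel()))  # default generator
```

Because `register` returns the function unchanged, the decorators can be stacked, and adding another model class to the shared path only takes one more `register` line.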


2 changes: 1 addition & 1 deletion outlines/models/__init__.py
@@ -12,7 +12,7 @@
from .mamba import Mamba, mamba
from .mlxlm import MLXLM, mlxlm
from .openai import OpenAI, azure_openai, openai
from .transformers import Transformers, transformers
from .transformers import Transformers, TransformerTokenizer, transformers
from .vllm import VLLM, vllm

LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba, MLXLM, VLLM]