Skip to content

Commit

Permalink
Use LogitsProcessors for models.transformers, apply in outlines.gener…
Browse files Browse the repository at this point in the history
…ate.*
  • Loading branch information
lapp0 committed Jul 3, 2024
1 parent 25b6bcd commit 7d43bbd
Show file tree
Hide file tree
Showing 12 changed files with 358 additions and 288 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,9 @@ model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.2")
generator = outlines.generate.json(model, Character)

# Draw a sample
rng = torch.Generator(device="cuda")
rng.manual_seed(789001)
seed = 789001

character = generator("Give me a character description", rng=rng)
character = generator("Give me a character description", seed=seed)

print(repr(character))
# Character(name='Anderson', age=28, armor=<Armor.chainmail: 'chainmail'>, weapon=<Weapon.sword: 'sword'>, strength=8)
Expand Down
53 changes: 52 additions & 1 deletion docs/reference/models/transformers.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Outlines provides an integration with the `torch` implementation of causal model
```python
from outlines import models

model = models.transformers("mistralai/Mistral-7B-v0.1", device="cuda")
model = models.transformers("mistralai/Mistral-7B-v0.3", device="cuda")
```

If you need more fine-grained control you can also initialize the model and tokenizer separately:
Expand All @@ -30,4 +30,55 @@ tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = models.Transformers(llm, tokenizer)
```

# Using Logits Processors

There are two ways to use Outlines Structured Generation with HuggingFace Transformers:
- 1) Use Outlines generation wrapper, `outlines.models.transformers`
- 2) Use `OutlinesLogitsProcessor` with `transformers.AutoModelForCausalLM`

Outlines supports a myriad of logits processors for structured generation. In these example, we will use the `RegexLogitsProcessor` which guarantees generated text matches the specified pattern.

## Example: `outlines.models.transformers`

```
import outlines
time_regex_pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct", device="cuda")
generator = outlines.generate.regex(model, time_regex_pattern)
output = generator("The the best time to visit a dentist is at ")
print(output)
# 2:30 pm
```

## Example: Direct `transformers` library use

```
import outlines
import transformers
model_uri = "microsoft/Phi-3-mini-4k-instruct"
outlines_tokenizer = outlines.models.TransformerTokenizer(
transformers.AutoTokenizer.from_pretrained(model_uri)
)
phone_number_logits_processor = outlines.processors.RegexLogitsProcessor(
"\\+?[1-9][0-9]{7,14}", # phone number pattern
outlines_tokenizer,
)
generator = transformers.pipeline('text-generation', model=model_uri)
output = generator(
"Jenny gave me her number it's ",
logits_processor=transformers.LogitsProcessorList([phone_number_logits_processor])
)
print(output)
# [{'generated_text': "Jenny gave me her number it's 2125550182"}]
# not quite 8675309 what we expected, but it is a valid phone number
```

[transformers]: https://github.com/huggingface/transformers
5 changes: 2 additions & 3 deletions docs/reference/text.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ from outlines import models, generate

model = models.transformers("mistralai/Mistral-7B-v0.1")

rng = torch.Generator(device="cuda")
rng.manual_seed(789001)
seed = 789001

answer = generator("What is 2+2?", rng=rng)
answer = generator("What is 2+2?", seed=seed)
```
6 changes: 2 additions & 4 deletions examples/llamacpp_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from enum import Enum

import torch
from pydantic import BaseModel, constr

import outlines
Expand Down Expand Up @@ -37,10 +36,9 @@ class Character(BaseModel):
generator = outlines.generate.json(model, Character)

# Draw a sample
rng = torch.Generator(device="cpu")
rng.manual_seed(789005)
seed = 789005

prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:"

sequence = generator(prompt, rng=rng, max_tokens=512)
sequence = generator(prompt, seed=seed, max_tokens=512)
print(sequence)
42 changes: 8 additions & 34 deletions outlines/generate/cfg.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from functools import singledispatch

from outlines.fsm.guide import CFGGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.generate.api import SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.mlxlm import MLXLM
from outlines.models.vllm import VLLM
from outlines.samplers import Sampler, multinomial


@singledispatch
def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenerator:
def cfg(
model, cfg_str: str, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
"""Generate text in the language of a Context-Free Grammar
Arguments
Expand All @@ -24,40 +22,16 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera
Returns
-------
A `SequenceGenerator` instance that generates text.
A `SequenceGeneratorAdapter` instance that generates text.
"""
fsm = CFGGuide(cfg_str, model.tokenizer)
device = model.device
generator = SequenceGenerator(fsm, model, sampler, device)

return generator


@cfg.register(MLXLM)
@cfg.register(VLLM)
def cfg_unimplemented(
model,
cfg_str: str,
sampler: Sampler = multinomial(),
):
raise NotImplementedError(
f"The CFG Logits processor is not available for {type(model)}."
f"The CFG Logits processor is not available for {type(model)}. "
+ "Please subscribe to https://github.com/outlines-dev/outlines/issues/684"
+ " for updates on the fix."
)


@cfg.register(LlamaCpp)
def cfg_llamacpp(
model: LlamaCpp,
cfg_str: str,
sampler: Sampler = multinomial(),
):
from outlines.integrations.llamacpp import CFGLogitsProcessor

logits_processor = CFGLogitsProcessor(cfg_str, model.model)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@cfg.register(OpenAI)
def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()):
raise NotImplementedError(
Expand Down
6 changes: 4 additions & 2 deletions outlines/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.mlxlm import MLXLM
from outlines.models.transformers import Transformers
from outlines.models.vllm import VLLM
from outlines.samplers import Sampler, multinomial

Expand Down Expand Up @@ -39,8 +40,9 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):


@regex.register(MLXLM)
def regex_mlxlm(
model: MLXLM,
@regex.register(Transformers)
def regex_unified(
model,
regex_str: str,
sampler: Sampler = multinomial(),
):
Expand Down
5 changes: 3 additions & 2 deletions outlines/generate/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from outlines.fsm.guide import StopAtEOSGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI, Transformers
from outlines.samplers import Sampler, multinomial


Expand Down Expand Up @@ -37,7 +37,8 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:


@text.register(MLXLM)
def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()):
@text.register(Transformers)
def text_unified(model, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)


Expand Down
2 changes: 1 addition & 1 deletion outlines/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .mamba import Mamba, mamba
from .mlxlm import MLXLM, mlxlm
from .openai import OpenAI, azure_openai, openai
from .transformers import Transformers, transformers
from .transformers import Transformers, TransformerTokenizer, transformers
from .vllm import VLLM, vllm

LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba, MLXLM, VLLM]
Loading

0 comments on commit 7d43bbd

Please sign in to comment.