Skip to content

Commit

Permalink
Add Continuation generation model
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jun 28, 2023
1 parent 12a67b8 commit 7f4387d
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 1 deletion.
1 change: 1 addition & 0 deletions outlines/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .functions import function
from .generate import continuation
from .prompts import prompt, render
1 change: 1 addition & 0 deletions outlines/text/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .continuation import continuation
52 changes: 52 additions & 0 deletions outlines/text/generate/continuation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import List, Optional

import numpy as np
from numpy.typing import NDArray

from outlines.text.generate.sequence import Sequence


class Continuation(Sequence):
"""Represents a completion generation model.
`Completion` instances are unconstrained generation models that stop when an EOS token
has been found or when the maximum number of tokens has been reached.
>> import outlines.text as text
>> sequence = text.sequence(model)("Say something")
"""

def __init__(self, model, max_tokens: Optional[int]):
super().__init__(model, max_tokens)

def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]:
"""Determine whether the sequences reached maximum length of end with
and EOS token.
In practice, `Sequence`'s `__call__` methods only passed the `token_ids`
of the sequences that haven't been marked as finished already, which is
why we only need to look for the EOS token in the last element rather
than in the whole sequence.
Parameters
----------
token_ids
The input sequences.
"""
is_finished = np.zeros((token_ids.shape[0],), dtype=np.bool_)
is_finished[token_ids[:, -1] == self.model.tokenizer.eos_token_id] = True

return is_finished

def postprocess_completions(self, completions: List[str]) -> List[str]:
"""Remove the EOS token from the completion."""
return [
completion.replace(self.model.tokenizer.eos_token, "")
for completion in completions
]


def continuation(model, max_tokens: Optional[int] = None):
return Continuation(model, max_tokens)
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]:
"`Sequence.is_finished` must be implemented by subclasses."
)

def postprocess_completions(self, completions: List[str]) -> List[str]:
return completions

def step(
self,
rng: Generator,
Expand Down Expand Up @@ -204,6 +207,7 @@ def __call__(
is_finished[~is_finished] = self.is_finished(token_ids_unfinished).flatten()

result = self.model.tokenizer.decode(token_ids)
result = self.postprocess_completions(result)

if len(result) == 1:
return result[0]
Expand Down
42 changes: 42 additions & 0 deletions tests/text/generate/test_continuation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
from numpy.testing import assert_array_equal

from outlines.text.generate.continuation import Continuation, continuation


class Tokenizer:
eos_token = "<EOS>"
eos_token_id = 0
pad_token_ids = -1


class Model:
tokenizer = Tokenizer()


def test_continuation_is_finished():
model = continuation(Model(), 10)
assert isinstance(model, Continuation)

token_ids = np.array([[3, 2]])
result = model.is_finished(token_ids)
assert_array_equal(result, [False])

token_ids = np.array([[3, 2, 0]])
result = model.is_finished(token_ids)
assert_array_equal(result, [True])

token_ids = np.array([[3, 2, 1], [3, 2, 0]])
result = model.is_finished(token_ids)
assert_array_equal(result, [False, True])

token_ids = np.array([[3, 2, 1, 0], [3, 2, 0, -1]])
result = model.is_finished(token_ids)
assert_array_equal(result, [True, False])


def test_continuation_postprocess():
model = continuation(Model())
result = model.postprocess_completions(["Here<EOS>"])
assert len(result) == 1
assert result[0] == "Here"
18 changes: 18 additions & 0 deletions tests/text/generate/test_integration_transfomers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np

import outlines.models as models
from outlines.text.generate.continuation import continuation

TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"


def test_transformers_integration_completion():
rng = np.random.default_rng(0)

model = models.transformers(TEST_MODEL, device="cpu")
sequence = continuation(model)("prompt", rng=rng)
assert isinstance(sequence, str)
assert model.tokenizer.eos_token not in sequence

sequence = continuation(model, max_tokens=10)("prompt", rng=rng)
assert isinstance(sequence, str)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from numpy.testing import assert_array_equal

from outlines.text.sequences.sequence import Sequence, vectorized_random_choice
from outlines.text.generate.sequence import Sequence, vectorized_random_choice


def test_vectorized_random_choice():
Expand Down

0 comments on commit 7f4387d

Please sign in to comment.