Skip to content

Commit

Permalink
Add Completion generation model
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jun 27, 2023
1 parent 96a029b commit 86b7dd0
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 0 deletions.
1 change: 1 addition & 0 deletions outlines/text/sequences/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .completion import completion
52 changes: 52 additions & 0 deletions outlines/text/sequences/completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import List, Optional

import numpy as np
from numpy.typing import NDArray

from outlines.text.sequences.sequence import Sequence


class Completion(Sequence):
"""Represents a completion generation model.
`Completion` instances are unconstrained generation models that stop when an EOS token
has been found or when the maximum number of tokens has been reached.
>> import outlines.text as text
>> sequence = text.sequence(model)("Say something")
"""

def __init__(self, model, max_tokens: Optional[int]):
super().__init__(model, max_tokens)

def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]:
"""Determine whether the sequences reached maximum length of end with
and EOS token.
In practice, `Sequence`'s `__call__` methods only passed the `token_ids`
of the sequences that haven't been marked as finished already, which is
why we only need to look for the EOS token in the last element rather
than in the whole sequence.
Parameters
----------
token_ids
The input sequences.
"""
is_finished = np.zeros((token_ids.shape[0],), dtype=np.bool_)
is_finished[token_ids[:, -1] == self.model.tokenizer.eos_token_id] = True

return is_finished

def postprocess_completions(self, completions: List[str]) -> List[str]:
"""Remove the EOS token from the completion."""
return [
completion.replace(self.model.tokenizer.eos_token, "")
for completion in completions
]


def completion(model, max_tokens: Optional[int] = None):
return Completion(model, max_tokens)
4 changes: 4 additions & 0 deletions outlines/text/sequences/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]:
"`Sequence.is_finished` must be implemented by subclasses."
)

def postprocess_completions(self, completions: List[str]) -> List[str]:
return completions

def step(
self,
rng: Generator,
Expand Down Expand Up @@ -204,6 +207,7 @@ def __call__(
is_finished[~is_finished] = self.is_finished(token_ids_unfinished).flatten()

result = self.model.tokenizer.decode(token_ids)
result = self.postprocess_completions(result)

if len(result) == 1:
return result[0]
Expand Down
42 changes: 42 additions & 0 deletions tests/text/sequences/test_completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
from numpy.testing import assert_array_equal

from outlines.text.sequences.completion import Completion, completion


class Tokenizer:
eos_token = "<EOS>"
eos_token_id = 0
pad_token_ids = -1


class Model:
tokenizer = Tokenizer()


def test_completion_is_finished():
model = completion(Model(), 10)
assert isinstance(model, Completion)

token_ids = np.array([[3, 2]])
result = model.is_finished(token_ids)
assert_array_equal(result, [False])

token_ids = np.array([[3, 2, 0]])
result = model.is_finished(token_ids)
assert_array_equal(result, [True])

token_ids = np.array([[3, 2, 1], [3, 2, 0]])
result = model.is_finished(token_ids)
assert_array_equal(result, [False, True])

token_ids = np.array([[3, 2, 1, 0], [3, 2, 0, -1]])
result = model.is_finished(token_ids)
assert_array_equal(result, [True, False])


def test_completion_postprocess():
model = completion(Model())
result = model.postprocess_completions(["Here<EOS>"])
assert len(result) == 1
assert result[0] == "Here"
18 changes: 18 additions & 0 deletions tests/text/sequences/test_integration_transfomers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np

import outlines.models as models
from outlines.text.sequences.completion import completion

TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"


def test_transformers_integration_completion():
rng = np.random.default_rng(0)

model = models.transformers(TEST_MODEL, device="cpu")
sequence = completion(model)("prompt", rng=rng)
assert isinstance(sequence, str)
assert model.tokenizer.eos_token not in sequence

sequence = completion(model, max_tokens=10)("prompt", rng=rng)
assert isinstance(sequence, str)

0 comments on commit 86b7dd0

Please sign in to comment.