Skip to content

Commit

Permalink
Add Completion generation model
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jun 15, 2023
1 parent 5b0cbb8 commit 76879da
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 0 deletions.
1 change: 1 addition & 0 deletions outlines/text/sequences/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .completion import completion
56 changes: 56 additions & 0 deletions outlines/text/sequences/completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import numpy as np


class Completion:
    """Completion generation model.

    `Completion` instances are unconstrained generation models that stop when
    an EOS token has been generated in every sequence of the batch.

    >> import outlines.text as text
    >> answer = text.completion(model)("Say something")
    """

    def __init__(self, model):
        self.model = model

    def is_finished(self, token_ids: np.ndarray) -> np.ndarray:
        """Return a boolean mask marking sequences whose last token is EOS.

        Parameters
        ----------
        token_ids
            2D array of token ids, one row per sequence in the batch.
        """
        is_finished = np.zeros((token_ids.shape[0],), dtype=np.bool_)
        is_finished[token_ids[:, -1] == self.model.tokenizer.eos_token_id] = True
        return is_finished

    def step(self, token_ids: np.ndarray, samples: int = 1, **model_kwargs):
        """Sample one new token for each sequence of the batch.

        Returns the token ids extended with the newly sampled token, along
        with the next-token probabilities returned by the model.
        """
        probs = self.model(token_ids, samples, **model_kwargs)
        next_token_ids = np.atleast_2d(
            vectorized_choice(np.arange(probs.shape[1]), probs)
        )
        # Transpose so each sampled token is appended to its own row.
        token_ids = np.hstack([token_ids, next_token_ids.T])
        return token_ids, probs

    def __call__(self, prompt: str) -> str:
        """Generate a completion of `prompt` until every sequence has ended."""
        # NOTE(review): assumes the tokenizer call returns a 2D array of token
        # ids — confirm against the model's tokenizer interface.
        token_ids = self.model.tokenizer(prompt)

        while True:
            # BUG FIX: the original called `self.model(token_ids)` directly,
            # skipping sampling and assigning the model's return value (a
            # probability output, not ids) back to `token_ids`.
            token_ids, _ = self.step(token_ids)
            # BUG FIX: `if self.is_finished(...)` evaluated the truth value of
            # a boolean array, which raises for batches of more than one
            # sequence; stop only once *all* sequences have finished.
            if np.all(self.is_finished(token_ids)):
                break

        # BUG FIX: decode the generated token ids, not the input prompt.
        return self.model.tokenizer.decode(token_ids)


def completion(model):
    """Build an unconstrained completion generator wrapping `model`."""
    generator = Completion(model)
    return generator


def vectorized_choice(items, probability_matrix):
    """Sample one item per row of `probability_matrix`.

    `np.random.choice` does not allow a matrix for `p`, so we invert each
    row's CDF manually. `searchsorted` might be better adapted here.

    Parameters
    ----------
    items
        1D array of values to sample from; its length must equal the number
        of columns of `probability_matrix`.
    probability_matrix
        2D array whose rows are probability distributions over `items`.
    """
    cumsum = probability_matrix.cumsum(axis=1)
    # BUG FIX: draw one uniform number per *row* and broadcast it across the
    # columns of the CDF. The original drew `shape[1]` numbers (one per
    # column), so each CDF entry was compared against a different uniform
    # draw and the result did not follow the row distributions.
    rand = np.random.rand(probability_matrix.shape[0], 1)
    idx = (cumsum < rand).sum(axis=1)

    return items[idx]
67 changes: 67 additions & 0 deletions tests/text/sequences/test_completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
import pytest

from outlines.models import transformers
from outlines.text.sequences.completion import completion


def test_completion_step_integration_single_prompt():
    """One generation step on a single tokenized prompt extends it by one token."""
    model = transformers("hf-internal-testing/tiny-random-GPT2LMHeadModel")
    complete = completion(model)

    prompt = "test"
    prompt_token_ids = np.array(
        model.tokenizer.encode(
            prompt, padding=True, add_special_tokens=False, return_tensors="np"
        )
    )

    new_token_ids, probs = complete.step(prompt_token_ids)

    vocab_size = len(model.tokenizer.get_vocab())
    assert prompt_token_ids.shape == (1, 2)
    assert new_token_ids.shape == (1, 3)
    assert probs.shape == (1, vocab_size)
    assert np.any(complete.is_finished(new_token_ids)).item() is False


def test_completion_step_integration_array_prompt():
    """One generation step on a padded batch of three prompts."""
    model = transformers("hf-internal-testing/tiny-random-GPT2LMHeadModel")
    complete = completion(model)

    encoded = model.tokenizer.batch_encode_plus(
        ["test1", "test2", "test3 very long"],
        padding=True,
        add_special_tokens=False,
        return_tensors="np",
    )

    new_token_ids, probs = complete.step(
        encoded["input_ids"], attention_mask=encoded["attention_mask"]
    )

    assert encoded["input_ids"].shape == (3, 6)
    assert new_token_ids.shape == (3, 7)
    assert probs.shape == (3, len(model.tokenizer.get_vocab()))
    assert np.any(complete.is_finished(new_token_ids)).item() is False


# The tests below are deliberately-failing (`xfail`) placeholders for
# planned coverage; each raises until its test is written.
@pytest.mark.xfail
def test_completion_step_integration_samples():
    # TODO: planned test — not yet implemented.
    raise NotImplementedError


@pytest.mark.xfail
def test_completion_step_single_prompt():
    # TODO: planned test — not yet implemented.
    raise NotImplementedError


@pytest.mark.xfail
def test_completion_array_prompt():
    # TODO: planned test — not yet implemented.
    raise NotImplementedError


@pytest.mark.xfail
def test_completion_single_prompt():
    # TODO: planned test — not yet implemented.
    raise NotImplementedError


@pytest.mark.xfail
def test_completion_single_prompt_sample():
    # TODO: planned test — not yet implemented.
    raise NotImplementedError

0 comments on commit 76879da

Please sign in to comment.