Showing 3 changed files with 124 additions and 0 deletions.
@@ -0,0 +1 @@
from .completion import completion
@@ -0,0 +1,56 @@
import numpy as np


class Completion:
    """Completion generation model.

    `Completion` instances are unconstrained generation models that stop when
    a token or a sequence of tokens has been generated.

    >>> import outlines.text as text
    >>> sequence = text.completion(model)("Say something")

    """

    def __init__(self, model):
        self.model = model

    def is_finished(self, token_ids):
        """Flag the sequences whose last token is the end-of-sequence token."""
        is_finished = np.zeros((token_ids.shape[0],), dtype=np.bool_)
        is_finished[token_ids[:, -1] == self.model.tokenizer.eos_token_id] = True
        return is_finished

    def step(self, token_ids: np.ndarray, samples: int = 1, **model_kwargs):
        """Sample one next token per sequence and append it to the sequences."""
        probs = self.model(token_ids, samples, **model_kwargs)
        next_token_ids = np.atleast_2d(
            vectorized_choice(np.arange(probs.shape[1]), probs)
        )
        token_ids = np.hstack([token_ids, next_token_ids.T])
        return token_ids, probs

    def __call__(self, prompt: str):
        token_ids = self.model.tokenizer(prompt)

        while True:
            token_ids, _ = self.step(token_ids)
            if self.is_finished(token_ids).all():
                break

        return self.model.tokenizer.decode(token_ids)


def completion(model):
    return Completion(model)


def vectorized_choice(items, probability_matrix):
    """Sample one item per row of `probability_matrix`.

    `np.random.choice` does not accept a matrix for `p`. `searchsorted` might
    be better adapted here.
    """
    cumsum = probability_matrix.cumsum(axis=1)
    # Draw one uniform number per row and invert the CDF by counting the
    # cumulative probabilities that fall strictly below it.
    rand = np.random.rand(probability_matrix.shape[0])[:, np.newaxis]
    idx = (cumsum < rand).sum(axis=1)

    return items[idx]
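
The docstring above notes that `searchsorted` might be better adapted here. A minimal sketch of that variant, under the same assumptions (a `(batch, vocab)` probability matrix whose rows sum to one); it is not part of this commit:

import numpy as np


def vectorized_choice_searchsorted(items, probability_matrix):
    # Inverse-CDF sampling: for each row, binary-search the position where a
    # uniform draw would be inserted in the cumulative distribution.
    cumsum = probability_matrix.cumsum(axis=1)
    rand = np.random.rand(probability_matrix.shape[0])
    idx = np.array([np.searchsorted(row, r) for row, r in zip(cumsum, rand)])
    return items[idx]

This returns the same index as the `(cumsum < rand).sum(axis=1)` scan, but with a binary search per row instead of a full pass over the vocabulary.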
@@ -0,0 +1,67 @@
import numpy as np
import pytest

from outlines.models import transformers
from outlines.text.sequences.completion import completion


def test_completion_step_integration_single_prompt():
    model_name = "hf-internal-testing/tiny-random-GPT2LMHeadModel"
    model = transformers(model_name)
    complete = completion(model)

    prompt = "test"
    tokens = model.tokenizer.encode(
        prompt, padding=True, add_special_tokens=False, return_tensors="np"
    )
    tokens = np.array(tokens)

    token_ids, probs = complete.step(tokens)
    assert tokens.shape == (1, 2)
    assert token_ids.shape == (1, 3)
    assert probs.shape == (1, len(model.tokenizer.get_vocab()))
    assert np.any(complete.is_finished(token_ids)).item() is False


def test_completion_step_integration_array_prompt():
    model_name = "hf-internal-testing/tiny-random-GPT2LMHeadModel"
    model = transformers(model_name)
    complete = completion(model)

    prompts = ["test1", "test2", "test3 very long"]
    output = model.tokenizer.batch_encode_plus(
        prompts, padding=True, add_special_tokens=False, return_tensors="np"
    )

    token_ids, probs = complete.step(
        output["input_ids"], attention_mask=output["attention_mask"]
    )
    assert output["input_ids"].shape == (3, 6)
    assert token_ids.shape == (3, 7)
    assert probs.shape == (3, len(model.tokenizer.get_vocab()))
    assert np.any(complete.is_finished(token_ids)).item() is False


@pytest.mark.xfail
def test_completion_step_integration_samples():
    raise NotImplementedError


@pytest.mark.xfail
def test_completion_step_single_prompt():
    raise NotImplementedError


@pytest.mark.xfail
def test_completion_array_prompt():
    raise NotImplementedError


@pytest.mark.xfail
def test_completion_single_prompt():
    raise NotImplementedError


@pytest.mark.xfail
def test_completion_single_prompt_sample():
    raise NotImplementedError
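
The `xfail` placeholders above could be complemented by a direct unit test of `vectorized_choice`. A minimal sketch, assuming a fixed RNG seed; the test name and probability values are hypothetical, not part of this commit:

import numpy as np

from outlines.text.sequences.completion import vectorized_choice


def test_vectorized_choice_picks_high_probability_items():
    # With near-degenerate rows, the sampled index should coincide with the
    # argmax of each row for almost every draw; the seed makes it exact here.
    items = np.arange(3)
    probs = np.array([[0.999, 0.0005, 0.0005], [0.0005, 0.999, 0.0005]])
    np.random.seed(0)
    choices = vectorized_choice(items, probs)
    assert choices.shape == (2,)
    assert (choices == probs.argmax(axis=1)).all()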