From 76879dac86da8ba84aa28475e0ca68254d11b2b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Mon, 12 Jun 2023 15:50:28 +0200
Subject: [PATCH] Add `Completion` generation model

---
 outlines/text/sequences/__init__.py     |  1 +
 outlines/text/sequences/completion.py   | 56 +++++++++++++++++++++
 tests/text/sequences/test_completion.py | 67 +++++++++++++++++++++++++
 3 files changed, 124 insertions(+)
 create mode 100644 outlines/text/sequences/__init__.py
 create mode 100644 outlines/text/sequences/completion.py
 create mode 100644 tests/text/sequences/test_completion.py

diff --git a/outlines/text/sequences/__init__.py b/outlines/text/sequences/__init__.py
new file mode 100644
index 000000000..a9bbd59ca
--- /dev/null
+++ b/outlines/text/sequences/__init__.py
@@ -0,0 +1 @@
+from .completion import completion
diff --git a/outlines/text/sequences/completion.py b/outlines/text/sequences/completion.py
new file mode 100644
index 000000000..9c95e2a47
--- /dev/null
+++ b/outlines/text/sequences/completion.py
@@ -0,0 +1,56 @@
+import numpy as np
+
+
+class Completion:
+    """Completion generation model.
+
+    `Completion` instances are unconstrained generation models that stop when
+    the end-of-sequence token has been generated for every sequence.
+
+    >>> import outlines.text as text
+    >>> sequence = text.completion(model)("Say something")
+
+    """
+
+    def __init__(self, model):
+        self.model = model
+
+    def is_finished(self, token_ids):
+        is_finished = np.zeros((token_ids.shape[0],), dtype=np.bool_)
+        is_finished[token_ids[:, -1] == self.model.tokenizer.eos_token_id] = True
+        return is_finished
+
+    def step(self, token_ids: np.ndarray, samples: int = 1, **model_kwargs):
+        probs = self.model(token_ids, samples, **model_kwargs)
+        next_token_ids = np.atleast_2d(
+            vectorized_choice(np.arange(probs.shape[1]), probs)
+        )
+        token_ids = np.hstack([token_ids, next_token_ids.T])
+        return token_ids, probs
+
+    def __call__(self, prompt: str):
+        token_ids = self.model.tokenizer(prompt)
+
+        while True:
+            token_ids, _ = self.step(token_ids)
+            if np.all(self.is_finished(token_ids)):
+                break
+
+        return self.model.tokenizer.decode(token_ids)
+
+
+def completion(model):
+    return Completion(model)
+
+
+def vectorized_choice(items, probability_matrix):
+    """`np.random.choice` does not allow a matrix for `p`.
+
+    `searchsorted` might be better adapted here.
+
+    """
+    cumsum = probability_matrix.cumsum(axis=1)
+    rand = np.random.rand(probability_matrix.shape[0])[:, np.newaxis]
+    idx = (cumsum < rand).sum(axis=1)
+
+    return items[idx]
diff --git a/tests/text/sequences/test_completion.py b/tests/text/sequences/test_completion.py
new file mode 100644
index 000000000..85a1a4741
--- /dev/null
+++ b/tests/text/sequences/test_completion.py
@@ -0,0 +1,67 @@
+import numpy as np
+import pytest
+
+from outlines.models import transformers
+from outlines.text.sequences.completion import completion
+
+
+def test_completion_step_integration_single_prompt():
+    model_name = "hf-internal-testing/tiny-random-GPT2LMHeadModel"
+    model = transformers(model_name)
+    complete = completion(model)
+
+    prompt = "test"
+    tokens = model.tokenizer.encode(
+        prompt, padding=True, add_special_tokens=False, return_tensors="np"
+    )
+    tokens = np.array(tokens)
+
+    token_ids, probs = complete.step(tokens)
+    assert tokens.shape == (1, 2)
+    assert token_ids.shape == (1, 3)
+    assert probs.shape == (1, len(model.tokenizer.get_vocab()))
+    assert np.any(complete.is_finished(token_ids)).item() is False
+
+
+def test_completion_step_integration_array_prompt():
+    model_name = "hf-internal-testing/tiny-random-GPT2LMHeadModel"
+    model = transformers(model_name)
+    complete = completion(model)
+
+    prompts = ["test1", "test2", "test3 very long"]
+    output = model.tokenizer.batch_encode_plus(
+        prompts, padding=True, add_special_tokens=False, return_tensors="np"
+    )
+
+    token_ids, probs = complete.step(
+        output["input_ids"], attention_mask=output["attention_mask"]
+    )
+    assert output["input_ids"].shape == (3, 6)
+    assert token_ids.shape == (3, 7)
+    assert probs.shape == (3, len(model.tokenizer.get_vocab()))
+    assert np.any(complete.is_finished(token_ids)).item() is False
+
+
+@pytest.mark.xfail
+def test_completion_step_integration_samples():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_completion_step_single_prompt():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_completion_array_prompt():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_completion_single_prompt():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_completion_single_prompt_sample():
+    raise NotImplementedError