Commit
Showing 4 changed files with 160 additions and 1 deletion.
New file (the `outlines.models.transformers` module imported by the tests below):
@@ -0,0 +1,91 @@
import math
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
from numpy.typing import NDArray

from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer


__all__ = ["transformers"]


class Transformers:
    """Represents a `transformers` model."""

    def __init__(
        self,
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        device: Optional[str] = None,
    ):
        self.device = device if device is not None else "cpu"
        self.model = model.to(self.device)
        self.tokenizer = tokenizer

    def __call__(
        self, input_ids: NDArray[np.int64], attention_mask: NDArray[np.int64]
    ) -> NDArray[np.float64]:
        import torch

        # `transformers` models only accept `input_ids` with at most 2
        # dimensions. We thus flatten the leading batch dimensions, call the
        # model, and restore the original batch shape on the output.
        batch_shape = input_ids.shape[:-1]
        num_tokens = input_ids.shape[-1]
        input_ids = input_ids.reshape(math.prod(batch_shape), num_tokens)

        input_ids = torch.from_numpy(input_ids).to(self.device)
        attention_mask = torch.from_numpy(attention_mask).to(self.device)

        output = self.model(input_ids, attention_mask=attention_mask)

        next_token_logits = output.logits[:, -1, :]
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1).squeeze()
        probs = torch.atleast_2d(probs)
        numpy_probs = probs.cpu().detach().numpy()

        return numpy_probs.reshape(batch_shape + (-1,))
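
A note on the reshape above: `__call__` accepts `input_ids` with any number of leading batch dimensions, flattens them into a single batch axis before the forward pass, and restores that shape on the returned probabilities. A minimal numpy-only sketch of the round trip (illustrative, with made-up shapes; not part of the diff):

# illustrative sketch, not part of the committed file
import math

import numpy as np

input_ids = np.zeros((2, 2, 3), dtype=np.int64)  # 2 x 2 prompts, 3 tokens each
batch_shape = input_ids.shape[:-1]               # (2, 2)
flat = input_ids.reshape(math.prod(batch_shape), input_ids.shape[-1])  # (4, 3)

vocab_size = 5                                   # hypothetical vocabulary size
fake_probs = np.random.rand(flat.shape[0], vocab_size)  # stand-in for model output
restored = fake_probs.reshape(batch_shape + (-1,))
assert restored.shape == (2, 2, vocab_size)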


class TransformersTokenizer(Tokenizer):
    """Represents a tokenizer for models in the `transformers` library."""

    def __init__(self, model_name: str, **kwargs):
        from transformers import AutoTokenizer

        # Decoder-only models are padded on the left so that every prompt in a
        # batch ends right where generation starts.
        kwargs.setdefault("padding_side", "left")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
        self.eos_token_id = self.tokenizer.eos_token_id
        self.eos_token = self.tokenizer.eos_token

        if not self.tokenizer.pad_token:
            # Fall back to the EOS token when the tokenizer defines no pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.pad_token_id = self.eos_token_id
        else:
            self.pad_token_id = self.tokenizer.pad_token_id

        self.pad_token = self.tokenizer.pad_token

    def encode(
        self, prompt: Union[str, List[str]], **kwargs
    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
        kwargs["padding"] = True
        kwargs["return_tensors"] = "np"
        output = self.tokenizer(prompt, **kwargs)
        return output["input_ids"], output["attention_mask"]

    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
        text = self.tokenizer.batch_decode(token_ids)
        return text
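
For illustration (not part of the diff), this is how the tokenizer wrapper behaves on a ragged batch. It uses the tiny test model referenced in the tests below and downloads it from the Hugging Face Hub on first use; the commented shapes are indicative only, since they depend on that model's vocabulary:

# illustrative sketch, not part of the committed file
from outlines.models.transformers import TransformersTokenizer

tokenizer = TransformersTokenizer("hf-internal-testing/tiny-random-GPTJForCausalLM")

# Prompts of different lengths are padded on the left to a common length.
token_ids, attention_mask = tokenizer.encode(["Test", "A long sentence"])
print(token_ids.shape)       # (2, n): two prompts padded to the same length
print(attention_mask[0])     # starts with 0s where the shorter prompt was padded

print(tokenizer.decode(token_ids))  # list of decoded strings, pad tokens included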


def transformers(model_name: str, device: Optional[str] = None, **model_kwargs):
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    tokenizer = TransformersTokenizer(model_name)

    return Transformers(model, tokenizer, device)
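
To show how the two classes and the factory fit together, here is a minimal greedy-decoding sketch built on top of the wrapper. It is illustrative only, not part of the commit, and again uses the tiny test model from the tests below:

# illustrative sketch, not part of the committed file
import numpy as np

from outlines.models.transformers import transformers

model = transformers("hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu")

# Encode a prompt, then repeatedly append the most likely next token.
token_ids, attention_mask = model.tokenizer.encode("Test")
for _ in range(5):
    probs = model(token_ids, attention_mask)     # shape (batch, vocab_size)
    next_token = probs.argmax(axis=-1)[:, None]  # greedy choice, shape (batch, 1)
    token_ids = np.concatenate([token_ids, next_token], axis=-1)
    attention_mask = np.concatenate(
        [attention_mask, np.ones_like(next_token)], axis=-1
    )

print(model.tokenizer.decode(token_ids))  # list with one decoded continuation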
New file (tests for the `transformers` model integration):
@@ -0,0 +1,67 @@
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from transformers.models.gpt2 import GPT2TokenizerFast

from outlines.models.transformers import TransformersTokenizer, transformers

TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"


def test_tokenizer():
    tokenizer = TransformersTokenizer(TEST_MODEL)
    assert tokenizer.eos_token_id == 0
    assert tokenizer.pad_token_id == 0
    assert isinstance(tokenizer.tokenizer, GPT2TokenizerFast)

    token_ids, attention_mask = tokenizer.encode("Test")
    assert token_ids.ndim == 2
    assert token_ids.shape[0] == 1
    assert isinstance(token_ids, np.ndarray)
    assert token_ids.shape == attention_mask.shape

    token_ids, attention_mask = tokenizer.encode(["Test", "Test"])
    assert token_ids.ndim == 2
    assert token_ids.shape[0] == 2
    assert isinstance(token_ids, np.ndarray)
    assert token_ids.shape == attention_mask.shape

    token_ids, attention_mask = tokenizer.encode(["Test", "A long sentence"])
    assert token_ids.shape == attention_mask.shape
    # "Test" is the shorter prompt, so its row is left-padded; the mask value at
    # a padded position and this model's pad_token_id are both 0.
    assert attention_mask[0][0] == tokenizer.pad_token_id

    text = tokenizer.decode(np.array([[0, 1, 2]]))
    assert isinstance(text, list)
    assert isinstance(text[0], str)

    text = tokenizer.decode(np.array([[0, 1, 2], [3, 4, 5]]))
    assert isinstance(text, list)
    assert isinstance(text[0], str)
    assert isinstance(text[1], str)


def test_model():
    with pytest.raises(RuntimeError, match="Expected one of cpu, cuda"):
        transformers(TEST_MODEL, device="non_existent")

    model = transformers(TEST_MODEL, device="cpu")
    assert isinstance(model.tokenizer, TransformersTokenizer)
    assert model.device == "cpu"

    input_ids = np.array([[0, 1, 2]])
    logits = model(input_ids, np.ones_like(input_ids))
    assert isinstance(logits, np.ndarray)
    assert logits.ndim == 2
    assert logits.shape[0] == 1

    input_ids = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    logits = model(input_ids, np.ones_like(input_ids))
    assert isinstance(logits, np.ndarray)
    assert logits.ndim == 2
    assert logits.shape[0] == 3

    input_ids = np.array([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [0, 1, 2]]])
    logits = model(input_ids, np.ones_like(input_ids))
    assert logits.ndim == 3
    assert logits.shape[0] == 2
    assert logits.shape[1] == 2
    # Rows (0, 0) and (1, 1) contain the same token ids, so their next-token
    # distributions must be identical.
    assert_array_equal(logits[0][0], logits[1][1])
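
One last illustrative check, not part of the commit: because `Transformers.__call__` applies a softmax before returning, each batch row of its output is a probability distribution over the vocabulary rather than raw logits, which is easy to verify directly:

# illustrative sketch, not part of the committed file
import numpy as np

from outlines.models.transformers import transformers

model = transformers("hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu")

input_ids = np.array([[0, 1, 2]])
probs = model(input_ids, np.ones_like(input_ids))

# Softmax output: non-negative entries that sum to one for each batch row.
assert np.all(probs >= 0)
assert np.allclose(probs.sum(axis=-1), 1.0)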