Add the Transformers model
rlouf committed Jun 27, 2023
1 parent a713819 commit 5860f10
Showing 4 changed files with 160 additions and 1 deletion.
1 change: 1 addition & 0 deletions outlines/models/__init__.py
@@ -9,3 +9,4 @@
from .hf_diffusers import HuggingFaceDiffuser
from .hf_transformers import HuggingFaceCompletion
from .openai import OpenAICompletion, OpenAIEmbeddings, OpenAIImageGeneration
from .transformers import transformers
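
This re-export makes the new constructor importable from the `outlines.models` namespace. A minimal sketch of the assumed usage ("gpt2" is only an illustrative checkpoint name, not part of the commit):

import outlines.models as models

# Assumed usage enabled by the re-export above; "gpt2" is an illustrative checkpoint.
model = models.transformers("gpt2", device="cpu")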
91 changes: 91 additions & 0 deletions outlines/models/transformers.py
@@ -0,0 +1,91 @@
import math
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
from numpy.typing import NDArray

from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer


__all__ = ["transformers"]


class Transformers:
"""Represents a `transformers` model."""

def __init__(
self,
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
device: Optional[str] = None,
):
self.device = device if device is not None else "cpu"
self.model = model.to(self.device)
self.tokenizer = tokenizer

def __call__(
self, input_ids: NDArray[np.int64], attention_mask: NDArray[np.int64]
) -> NDArray[np.float64]:
import torch

        # Models in `transformers` only accept `input_ids` tensors of dimension
        # at most 2 (batch size x sequence length). We thus flatten the leading
        # batch dimensions, call the model, and reshape the output logits back.
batch_shape = input_ids.shape[:-1]
num_tokens = input_ids.shape[-1]
input_ids = input_ids.reshape(math.prod(batch_shape), num_tokens)

input_ids = torch.from_numpy(input_ids).to(self.device)
attention_mask = torch.from_numpy(attention_mask).to(self.device)

output = self.model(input_ids, attention_mask=attention_mask)

next_token_logits = output.logits[:, -1, :]
probs = torch.nn.functional.softmax(next_token_logits, dim=-1).squeeze()
probs = torch.atleast_2d(probs)
numpy_probs = probs.cpu().detach().numpy()

return numpy_probs.reshape(batch_shape + (-1,))


class TransformersTokenizer(Tokenizer):
"""Represents a tokenizer for models in the `transformers` library."""

def __init__(self, model_name: str, **kwargs):
from transformers import AutoTokenizer

kwargs.setdefault("padding_side", "left")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
self.eos_token_id = self.tokenizer.eos_token_id
self.eos_token = self.tokenizer.eos_token

if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.pad_token_id = self.eos_token_id
else:
self.pad_token_id = self.tokenizer.pad_token_id
self.pad_token = self.tokenizer.pad_token

def encode(
self, prompt: Union[str, List[str]], **kwargs
) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
kwargs["padding"] = True
kwargs["return_tensors"] = "np"
output = self.tokenizer(prompt, **kwargs)
return output["input_ids"], output["attention_mask"]

def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
text = self.tokenizer.batch_decode(token_ids)
return text


def transformers(model_name: str, device: Optional[str] = None, **model_kwargs):
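    """Instantiate a model from the `transformers` library along with its tokenizer."""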
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
tokenizer = TransformersTokenizer(model_name)

return Transformers(model, tokenizer, device)
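
Putting the pieces together, a minimal end-to-end sketch of the assumed usage of this module ("gpt2" and the prompts are illustrative; the forward pass returns next-token probabilities, one row per prompt):

from outlines.models.transformers import transformers

# Assumed usage sketch; "gpt2" is an illustrative checkpoint, not part of the commit.
model = transformers("gpt2", device="cpu")

# Encode a batch of prompts; the shorter prompt is left-padded.
token_ids, attention_mask = model.tokenizer.encode(["Hello", "A longer prompt"])

# One forward pass returns next-token probabilities with shape (batch, vocab_size).
probs = model(token_ids, attention_mask)
print(probs.shape)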
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -86,7 +86,7 @@ module = [
"tenacity.*",
"tiktoken.*",
"torch",
"transformers",
"transformers.*",
]
ignore_missing_imports = true

67 changes: 67 additions & 0 deletions tests/models/test_transformers.py
@@ -0,0 +1,67 @@
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from transformers.models.gpt2 import GPT2TokenizerFast

from outlines.models.transformers import TransformersTokenizer, transformers

TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"


def test_tokenizer():
tokenizer = TransformersTokenizer(TEST_MODEL)
assert tokenizer.eos_token_id == 0
assert tokenizer.pad_token_id == 0
assert isinstance(tokenizer.tokenizer, GPT2TokenizerFast)

token_ids, attention_mask = tokenizer.encode("Test")
assert token_ids.ndim == 2
assert token_ids.shape[0] == 1
assert isinstance(token_ids, np.ndarray)
assert token_ids.shape == attention_mask.shape

token_ids, attention_mask = tokenizer.encode(["Test", "Test"])
assert token_ids.ndim == 2
assert token_ids.shape[0] == 2
assert isinstance(token_ids, np.ndarray)
assert token_ids.shape == attention_mask.shape

token_ids, attention_mask = tokenizer.encode(["Test", "A long sentence"])
assert token_ids.shape == attention_mask.shape
    # The shorter prompt is left-padded, so its first attention mask entry is 0.
    assert attention_mask[0][0] == 0

    text = tokenizer.decode(np.array([[0, 1, 2]]))
    assert isinstance(text, list)
    assert isinstance(text[0], str)

    text = tokenizer.decode(np.array([[0, 1, 2], [3, 4, 5]]))
    assert isinstance(text, list)
    assert isinstance(text[0], str)
    assert isinstance(text[1], str)


def test_model():
with pytest.raises(RuntimeError, match="Expected one of cpu, cuda"):
transformers(TEST_MODEL, device="non_existent")

model = transformers(TEST_MODEL, device="cpu")
assert isinstance(model.tokenizer, TransformersTokenizer)
assert model.device == "cpu"

input_ids = np.array([[0, 1, 2]])
logits = model(input_ids, np.ones_like(input_ids))
assert isinstance(logits, np.ndarray)
assert logits.ndim == 2
assert logits.shape[0] == 1

input_ids = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
logits = model(input_ids, np.ones_like(input_ids))
assert isinstance(logits, np.ndarray)
assert logits.ndim == 2
assert logits.shape[0] == 3

input_ids = np.array([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [0, 1, 2]]])
logits = model(input_ids, np.ones_like(input_ids))
assert logits.ndim == 3
assert logits.shape[0] == 2
assert logits.shape[1] == 2
assert_array_equal(logits[0][0], logits[1][1])
