add GPTJ/bloom/llama/opt into model list and enhance the jit support #23291

Merged · 1 commit · May 24, 2023
2 changes: 1 addition & 1 deletion examples/pytorch/text-generation/README.md
@@ -18,7 +18,7 @@ limitations under the License.

Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py).

-Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL, XLNet, CTRL.
+Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, GPTJ, Transformer-XL, XLNet, CTRL, BLOOM, LLAMA, OPT.
A similar script is used for our official demo [Write With Transformer](https://transformer.huggingface.co), where you
can try out the different models available in the library.
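
Note (illustration, not part of this diff): "conditional text generation" with one of the newly listed families boils down to tokenizing a prompt and calling `generate()`. A minimal sketch, assuming an OPT checkpoint (`facebook/opt-125m`) and arbitrary sampling settings:

```python
# Minimal sketch of conditional generation with one of the newly listed families.
# The checkpoint and sampling settings are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("The meaning of life is", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=True, top_p=0.9)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```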

95 changes: 54 additions & 41 deletions examples/pytorch/text-generation/run_generation.py
@@ -19,20 +19,28 @@


import argparse
+import inspect
import logging
from typing import Tuple

import numpy as np
import torch

from transformers import (
+    AutoTokenizer,
+    BloomForCausalLM,
+    BloomTokenizerFast,
    CTRLLMHeadModel,
    CTRLTokenizer,
    GenerationMixin,
    GPT2LMHeadModel,
    GPT2Tokenizer,
+    GPTJForCausalLM,
+    LlamaForCausalLM,
+    LlamaTokenizer,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
+    OPTForCausalLM,
    TransfoXLLMHeadModel,
    TransfoXLTokenizer,
    XLMTokenizer,
@@ -59,6 +67,10 @@
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
+    "gptj": (GPTJForCausalLM, AutoTokenizer),
+    "bloom": (BloomForCausalLM, BloomTokenizerFast),
+    "llama": (LlamaForCausalLM, LlamaTokenizer),
+    "opt": (OPTForCausalLM, GPT2Tokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
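
Note (illustration, not part of this diff): the script resolves `--model_type` through this table and then loads both classes from `--model_name_or_path`. A minimal sketch of that lookup for the new `gptj` entry; the checkpoint name is an assumption:

```python
# Sketch of how a MODEL_CLASSES entry is consumed; the checkpoint is an assumption.
from transformers import AutoTokenizer, GPTJForCausalLM

MODEL_CLASSES = {
    "gptj": (GPTJForCausalLM, AutoTokenizer),
}

model_class, tokenizer_class = MODEL_CLASSES["gptj"]
tokenizer = tokenizer_class.from_pretrained("EleutherAI/gpt-j-6B")
model = model_class.from_pretrained("EleutherAI/gpt-j-6B")
```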
@@ -173,61 +185,61 @@ def sparse_model_config(model_config):
        raise ValueError("Check the model config")

    num_embedding_size_per_head = int(embedding_size / num_head)
-    num_layer = model_config.n_layer
+    if hasattr(model_config, "n_layer"):
+        num_layer = model_config.n_layer
+    elif hasattr(model_config, "num_hidden_layers"):
+        num_layer = model_config.num_hidden_layers
+    else:
+        raise ValueError("Number of hidden layers couldn't be determined from the model config")

    return num_layer, num_head, num_embedding_size_per_head


-def prepare_jit_inputs(inputs, model, tokenizer):
-    num_batch = len(inputs)
-    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
+def generate_past_key_values(model, batch_size, seq_len):
    num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
    if model.config.model_type == "bloom":
        past_key_values = tuple(
            (
-                torch.zeros(int(num_attention_heads * num_batch), num_embedding_size_per_head, 1)
-                .to(model.config.torch_dtype)
+                torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len)
+                .to(model.dtype)
                .to(model.device),
-                torch.zeros(int(num_attention_heads * num_batch), 1, num_embedding_size_per_head)
-                .to(model.config.torch_dtype)
+                torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head)
+                .to(model.dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )
    else:
        past_key_values = tuple(
            (
-                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
-                .to(model.config.torch_dtype)
+                torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
+                .to(model.dtype)
                .to(model.device),
-                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
-                .to(model.config.torch_dtype)
+                torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
+                .to(model.dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )
+    return past_key_values


+def prepare_jit_inputs(inputs, model, tokenizer):
+    batch_size = len(inputs)
+    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
+    dummy_input = dummy_input.to(model.device)
+    if model.config.use_cache:
+        dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1)
    dummy_input["attention_mask"] = torch.cat(
        [
-            torch.zeros(dummy_input["attention_mask"].shape[0], 1).to(dummy_input["attention_mask"].dtype),
+            torch.zeros(dummy_input["attention_mask"].shape[0], 1)
+            .to(dummy_input["attention_mask"].dtype)
+            .to(model.device),
            dummy_input["attention_mask"],
        ],
        -1,
    )
-
-    if model.config.use_cache:
-        jit_inputs = (
-            dummy_input["input_ids"].to(model.device),
-            past_key_values,
-            dummy_input["attention_mask"].to(model.device),
-        )
-    else:
-        jit_inputs = (
-            dummy_input["input_ids"].to(model.device),
-            dummy_input["attention_mask"].to(model.device),
-        )
-
-    return jit_inputs
+    return dummy_input


class _ModelFallbackWrapper(GenerationMixin):
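
Note (illustration, not part of this diff): the two branches of `generate_past_key_values` produce different cache layouts. A self-contained shape sketch with made-up example sizes:

```python
# Shape sketch for the dummy past_key_values built above; sizes are example values.
import torch

batch_size, num_heads, head_dim, num_layers, seq_len = 2, 12, 64, 4, 1

# Standard layout: per layer, key and value are (batch, heads, seq_len, head_dim).
standard_cache = tuple(
    (
        torch.empty(batch_size, num_heads, seq_len, head_dim),
        torch.empty(batch_size, num_heads, seq_len, head_dim),
    )
    for _ in range(num_layers)
)

# BLOOM-style layout: key is (batch * heads, head_dim, seq_len),
# value is (batch * heads, seq_len, head_dim).
bloom_cache = tuple(
    (
        torch.empty(batch_size * num_heads, head_dim, seq_len),
        torch.empty(batch_size * num_heads, seq_len, head_dim),
    )
    for _ in range(num_layers)
)

assert standard_cache[0][0].shape == (batch_size, num_heads, seq_len, head_dim)
assert bloom_cache[0][0].shape == (batch_size * num_heads, head_dim, seq_len)
assert bloom_cache[0][1].shape == (batch_size * num_heads, seq_len, head_dim)
```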
@@ -238,15 +250,13 @@ def __init__(self, optimized, default):
        self._default = default

    def __call__(self, *args, **kwargs):
-        if kwargs["past_key_values"] is None:
-            return self._default(*args, **kwargs)
-        trace_graph_inputs = []
+        if kwargs["past_key_values"] is None and self._default.config.use_cache:
+            kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0)
        kwargs.pop("position_ids", None)
-        for k, v in kwargs.items():
-            if v is not None and not isinstance(v, bool):
-                trace_graph_inputs.append(v)
-        trace_graph_inputs = tuple(trace_graph_inputs)
-        outputs = self._optimized(*trace_graph_inputs)
+        for k in list(kwargs.keys()):
+            if kwargs[k] is None or isinstance(kwargs[k], bool):
+                kwargs.pop(k)
+        outputs = self._optimized(**kwargs)
        lm_logits = outputs[0]
        past_key_values = outputs[1]
        fixed_output = CausalLMOutputWithPast(
@@ -324,9 +334,7 @@ def main():
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
-    parser.add_argument(
-        "--jit", type=bool, default=False, help="Whether or not to use jit trace to accelerate inference"
-    )
+    parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
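
Note on the `--jit` flag change above: `argparse` passes the raw command-line string to `type=bool`, and `bool("False")` is `True`, so any explicit value, including `False`, enabled the flag; `action="store_true"` gives the expected on/off behavior. A standalone illustration:

```python
# Why `type=bool` is replaced by `action="store_true"` for boolean flags.
import argparse

buggy = argparse.ArgumentParser()
buggy.add_argument("--jit", type=bool, default=False)
print(buggy.parse_args(["--jit", "False"]).jit)  # True -- bool("False") is truthy

fixed = argparse.ArgumentParser()
fixed.add_argument("--jit", action="store_true")
print(fixed.parse_args([]).jit)         # False when the flag is omitted
print(fixed.parse_args(["--jit"]).jit)  # True when the flag is passed
```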
@@ -351,8 +359,8 @@ def main():

    if args.fp16:
        model.half()
-
-    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
+    max_seq_length = getattr(model.config, "max_position_embeddings", 0)
+    args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length)
    logger.info(args)

    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
@@ -382,10 +390,15 @@ def main():
        input_ids = encoded_prompt

    if args.jit:
-        jit_input_texts = ["jit"]
+        jit_input_texts = ["enable jit"]
        jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
        torch._C._jit_set_texpr_fuser_enabled(False)
        model.config.return_dict = False
+        if hasattr(model, "forward"):
+            sig = inspect.signature(model.forward)
+        else:
+            sig = inspect.signature(model.__call__)
+        jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
        traced_model = torch.jit.trace(model, jit_inputs, strict=False)
        traced_model = torch.jit.freeze(traced_model.eval())
        traced_model(*jit_inputs)
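
Note (illustration, not part of this diff): the `inspect.signature` filtering above keeps only the inputs the model's `forward` actually accepts, ordered by the signature, before tracing. The same idea in a standalone sketch with a toy callable:

```python
# Standalone sketch of signature-driven input filtering: keep only the keys a
# callable accepts, ordered by its signature, dropping missing/None entries.
import inspect

def toy_forward(input_ids, past_key_values=None, attention_mask=None):
    return input_ids, past_key_values, attention_mask

available = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": None}
sig = inspect.signature(toy_forward)
ordered_inputs = tuple(available[key] for key in sig.parameters if available.get(key, None) is not None)
print(ordered_inputs)  # ([1, 2, 3], [1, 1, 1]) -- 'labels' dropped, order follows the signature
```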