diff --git a/examples/pytorch/text-generation/README.md b/examples/pytorch/text-generation/README.md
index 2177c45c3b88..fce4aef86b14 100644
--- a/examples/pytorch/text-generation/README.md
+++ b/examples/pytorch/text-generation/README.md
@@ -18,7 +18,7 @@ limitations under the License.
 
 Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py).
 
-Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL, XLNet, CTRL.
+Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, GPTJ, Transformer-XL, XLNet, CTRL, BLOOM, LLAMA, OPT.
 A similar script is used for our official demo [Write With Transfomer](https://transformer.huggingface.co), where you
 can try out the different models available in the library.
 
diff --git a/examples/pytorch/text-generation/run_generation.py b/examples/pytorch/text-generation/run_generation.py
index e0dda0ec0c2f..75221934da85 100755
--- a/examples/pytorch/text-generation/run_generation.py
+++ b/examples/pytorch/text-generation/run_generation.py
@@ -19,6 +19,7 @@
 
 
 import argparse
+import inspect
 import logging
 from typing import Tuple
 
@@ -26,13 +27,20 @@
 import torch
 
 from transformers import (
+    AutoTokenizer,
+    BloomForCausalLM,
+    BloomTokenizerFast,
     CTRLLMHeadModel,
     CTRLTokenizer,
     GenerationMixin,
     GPT2LMHeadModel,
     GPT2Tokenizer,
+    GPTJForCausalLM,
+    LlamaForCausalLM,
+    LlamaTokenizer,
     OpenAIGPTLMHeadModel,
     OpenAIGPTTokenizer,
+    OPTForCausalLM,
     TransfoXLLMHeadModel,
     TransfoXLTokenizer,
     XLMTokenizer,
@@ -59,6 +67,10 @@
     "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
     "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
     "xlm": (XLMWithLMHeadModel, XLMTokenizer),
+    "gptj": (GPTJForCausalLM, AutoTokenizer),
+    "bloom": (BloomForCausalLM, BloomTokenizerFast),
+    "llama": (LlamaForCausalLM, LlamaTokenizer),
+    "opt": (OPTForCausalLM, GPT2Tokenizer),
 }
 
 # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
@@ -173,23 +185,26 @@ def sparse_model_config(model_config):
         raise ValueError("Check the model config")
 
     num_embedding_size_per_head = int(embedding_size / num_head)
-    num_layer = model_config.n_layer
+    if hasattr(model_config, "n_layer"):
+        num_layer = model_config.n_layer
+    elif hasattr(model_config, "num_hidden_layers"):
+        num_layer = model_config.num_hidden_layers
+    else:
+        raise ValueError("Number of hidden layers couldn't be determined from the model config")
 
     return num_layer, num_head, num_embedding_size_per_head
 
 
-def prepare_jit_inputs(inputs, model, tokenizer):
-    num_batch = len(inputs)
-    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
+def generate_past_key_values(model, batch_size, seq_len):
     num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
     if model.config.model_type == "bloom":
         past_key_values = tuple(
             (
-                torch.zeros(int(num_attention_heads * num_batch), num_embedding_size_per_head, 1)
-                .to(model.config.torch_dtype)
+                torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len)
+                .to(model.dtype)
                 .to(model.device),
-                torch.zeros(int(num_attention_heads * num_batch), 1, num_embedding_size_per_head)
-                .to(model.config.torch_dtype)
+                torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head)
+                .to(model.dtype)
                 .to(model.device),
             )
             for _ in range(num_block_layers)
@@ -197,37 +212,34 @@ def prepare_jit_inputs(inputs, model, tokenizer):
     else:
         past_key_values = tuple(
             (
-                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
-                .to(model.config.torch_dtype)
+                torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
+                .to(model.dtype)
                 .to(model.device),
-                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
-                .to(model.config.torch_dtype)
+                torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
+                .to(model.dtype)
                 .to(model.device),
             )
             for _ in range(num_block_layers)
         )
+    return past_key_values
+
+def prepare_jit_inputs(inputs, model, tokenizer):
+    batch_size = len(inputs)
+    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
+    dummy_input = dummy_input.to(model.device)
+    if model.config.use_cache:
+        dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1)
 
     dummy_input["attention_mask"] = torch.cat(
         [
-            torch.zeros(dummy_input["attention_mask"].shape[0], 1).to(dummy_input["attention_mask"].dtype),
+            torch.zeros(dummy_input["attention_mask"].shape[0], 1)
+            .to(dummy_input["attention_mask"].dtype)
+            .to(model.device),
             dummy_input["attention_mask"],
         ],
         -1,
     )
-
-    if model.config.use_cache:
-        jit_inputs = (
-            dummy_input["input_ids"].to(model.device),
-            past_key_values,
-            dummy_input["attention_mask"].to(model.device),
-        )
-    else:
-        jit_inputs = (
-            dummy_input["input_ids"].to(model.device),
-            dummy_input["attention_mask"].to(model.device),
-        )
-
-    return jit_inputs
+    return dummy_input
 
 
 class _ModelFallbackWrapper(GenerationMixin):
@@ -238,15 +250,13 @@ def __init__(self, optimized, default):
         self._default = default
 
     def __call__(self, *args, **kwargs):
-        if kwargs["past_key_values"] is None:
-            return self._default(*args, **kwargs)
-        trace_graph_inputs = []
+        if kwargs["past_key_values"] is None and self._default.config.use_cache:
+            kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0)
         kwargs.pop("position_ids", None)
-        for k, v in kwargs.items():
-            if v is not None and not isinstance(v, bool):
-                trace_graph_inputs.append(v)
-        trace_graph_inputs = tuple(trace_graph_inputs)
-        outputs = self._optimized(*trace_graph_inputs)
+        for k in list(kwargs.keys()):
+            if kwargs[k] is None or isinstance(kwargs[k], bool):
+                kwargs.pop(k)
+        outputs = self._optimized(**kwargs)
         lm_logits = outputs[0]
         past_key_values = outputs[1]
         fixed_output = CausalLMOutputWithPast(
@@ -324,9 +334,7 @@ def main():
         action="store_true",
         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
     )
-    parser.add_argument(
-        "--jit", type=bool, default=False, help="Whether or not to use jit trace to accelerate inference"
-    )
+    parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
     args = parser.parse_args()
 
     args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@@ -351,8 +359,8 @@ def main():
 
     if args.fp16:
         model.half()
-
-    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
+    max_seq_length = getattr(model.config, "max_position_embeddings", 0)
+    args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length)
     logger.info(args)
 
     prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
@@ -382,10 +390,15 @@ def main():
         input_ids = encoded_prompt
 
     if args.jit:
-        jit_input_texts = ["jit"]
+        jit_input_texts = ["enable jit"]
         jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
         torch._C._jit_set_texpr_fuser_enabled(False)
         model.config.return_dict = False
+        if hasattr(model, "forward"):
+            sig = inspect.signature(model.forward)
+        else:
+            sig = inspect.signature(model.__call__)
+        jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
         traced_model = torch.jit.trace(model, jit_inputs, strict=False)
         traced_model = torch.jit.freeze(traced_model.eval())
         traced_model(*jit_inputs)
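
Note on the new `generate_past_key_values` helper: it encodes two per-layer cache layouts, a BLOOM-specific one that folds the attention heads into the leading batch dimension and stores the key cache transposed relative to the value cache, and the conventional `[batch, heads, seq, head_dim]` layout for every other supported model. A small illustration of the two shapes, using made-up sizes and reading the layouts directly off the patch:

```python
import torch

# Hypothetical sizes, only to show the shapes generate_past_key_values builds.
batch_size, num_attention_heads, num_embedding_size_per_head, seq_len = 2, 4, 8, 1

# BLOOM branch: heads folded into the leading dimension; key is [head_dim, seq],
# value is [seq, head_dim].
bloom_key = torch.empty(num_attention_heads * batch_size, num_embedding_size_per_head, seq_len)
bloom_value = torch.empty(num_attention_heads * batch_size, seq_len, num_embedding_size_per_head)

# Default branch: the usual [batch, heads, seq, head_dim] layout for both key and value.
key = torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
value = torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)

print(tuple(bloom_key.shape), tuple(bloom_value.shape))  # (8, 8, 1) (8, 1, 8)
print(tuple(key.shape), tuple(value.shape))              # (2, 4, 1, 8) (2, 4, 1, 8)
```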
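For context on how the reworked pieces fit together, here is a condensed sketch of the JIT path that `main()` now follows, pulled out of the script. It assumes `run_generation.py` is importable from the working directory (so `prepare_jit_inputs` and `_ModelFallbackWrapper` can be reused) and uses the small `gpt2` checkpoint purely as a stand-in for any entry in `MODEL_CLASSES`; it is an illustration of the traced-generation flow, not part of the patch:

```python
import inspect

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from run_generation import _ModelFallbackWrapper, prepare_jit_inputs

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

# Dummy inputs now come back as a dict: input_ids, attention_mask and, because
# use_cache is on by default, a length-1 dummy past_key_values entry.
jit_inputs = prepare_jit_inputs(["enable jit"], model, tokenizer)

torch._C._jit_set_texpr_fuser_enabled(False)
model.config.return_dict = False

# Keep only the tensors that forward() actually accepts, in signature order,
# mirroring the filtering main() performs before tracing.
sig = inspect.signature(model.forward)
jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)

traced_model = torch.jit.trace(model, jit_inputs, strict=False)
traced_model = torch.jit.freeze(traced_model.eval())
traced_model(*jit_inputs)  # warm-up pass so the frozen graph finishes optimizing

# Route generate() through the traced graph, keeping the eager model as fallback.
model = _ModelFallbackWrapper(traced_model, model)

input_ids = tokenizer("The new JIT path", return_tensors="pt").input_ids
output = model.generate(input_ids=input_ids, max_length=24, do_sample=False)
print(tokenizer.decode(output[0]))
```

The same flow is what `--jit` (now a plain `store_true` flag) triggers inside the script; the wrapper's `__call__` fills in a zero-length `past_key_values` on the first decoding step so the traced graph always receives a cache argument.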