From 10c804a56465bcd828bdb2dbfcf74188749a2733 Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Wed, 5 Jul 2023 12:04:01 +0200
Subject: [PATCH] Perplexity Eval for Text Generation Models (#1073)

* initial commit
* Update src/deepsparse/license.py
* limit to 150mb
* ready to review
* initial commit
* [Codegen][ORT][Static Seq Length] TextGenerationPipeline (#946)
* initial commit
* coreys simplifications
* finishing the second model static
* ready, time for beautification
* ready for review
* moved the code to examples
* fix eos logic
* add argument num_tokens_to_generate
* [CodeGen][Documentation] (#956)
* initial commit
* coreys simplifications
* finishing the second model static
* ready, time for beautification
* ready for review
* moved the code to examples
* fix eos logic
* add argument num_tokens_to_generate
* initial commit
* change order
* Update examples/codegen/README.md

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>

---------

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>

* reimplementation for generative pipelines
* restore text generation from examples
* [CodeGen] ONNX model loading to support >2Gb models / two engines (#991)
* refactor successful
* Pipeline fully refactored, time to test engine support. Note: Sliding window not yet implemented!
* First iteration with Sage
* Apply suggestions from code review
* ORT agrees with the Engine. But they both give not entirely correct result. Hey, this is good news still
* dynamic ORT vs static DS
* pipeline handles OPT multitoken pass
* fixes to get static pipeline a little further along
* adjust shapes and slicing to enable static autoregressive pass - ISSUE: tokens past the base seq len are repeated
* migrate from cache_length to positions input
* got it working for multitoken + single token scenario
* cleanup the pipeline
* further cleanup post merge
* Pipeline working for single-token inference only
* do not load the onnx model with external files twice
* pipeline never redundantly saves the external data + more robust tokenizer
* Stop saving tmp files, otherwise the engine looks for external files in the wrong place
* Left pad support
* cleanup
* cleanup2
* Add in pipeline timing
* add in force tokens logic
* remove input validation for text generation pipelines
* remove multitoken support for now
* remove kv cache engine and other fixes
* nest input shape override
* comment out input shape override
* add non batch override for ORT
* clean up generation pipeline
* initial commit
* Update src/deepsparse/license.py
* limit to 150mb
* ready to review
* fix the erroneous Makefile
* perhaps fixed GHA
* take into consideration that GHA creates four files
* initial commit
* tested with actual model
* remove val_inp argument
* Update README.md
* Apply suggestions from code review
* Update README.md
* [BugFix] Update deepsparse dockerfile (#1069)
* Remove autoinstall triggering commands
* Fix typo
* initial implementation
* working implementation for pipeline input
* [Fix] Fix CLI benchmark errors (#1071)
* initial commit
* ready for review
* Update src/deepsparse/utils/onnx.py
* Clean a typo in the pipeline code
* cleanup the old files
* Update src/deepsparse/transformers/engines/nl_decoder_engine.py
* ready for review
* ready for testing
* assert proper padding on pipeline init
* now also supporting kv cache perplexity. time for cleanup
* ready for review
* correctly print engine info
* work with left padding of the tokenizer
* quality
* fix the multitoken inference

---------

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>
Co-authored-by: Mark Kurtz
Co-authored-by: Benjamin
Co-authored-by: Rahul Tuli
---
 .../transformers/engines/nl_decoder_engine.py |   8 +-
 .../transformers/eval_downstream.py           |  42 ++++++-
 src/deepsparse/transformers/metrics.py        | 103 +++++++++++++++++-
 .../transformers/pipelines/text_generation.py |  70 ++++++++----
 4 files changed, 193 insertions(+), 30 deletions(-)

diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py
index 809befce37..f75264db14 100644
--- a/src/deepsparse/transformers/engines/nl_decoder_engine.py
+++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -154,10 +154,7 @@ def __call__(
         else:
             logits = out[0]
 
-        B, S, V = logits.shape  # batch, sequence, vocab
-        logits = logits[:, -1, :].reshape(B, 1, V)  # only take the last token
-
-        token = self.generate_token(logits=logits)
+        token = self.generate_token(logits=logits[:, -1, :])
 
         return token, logits
 
@@ -253,6 +250,9 @@ def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
 
         return numpy.random.choice(len(probs), p=probs)
 
+    def __str__(self):
+        return f"{self.__class__.__name__}: {self.engine}"
+
     def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:
         # initialize empty kv cache of size
         # (batch_size, num_attention_heads, length, hidden_dims)
diff --git a/src/deepsparse/transformers/eval_downstream.py b/src/deepsparse/transformers/eval_downstream.py
index 01d0861580..6e6fa16b20 100644
--- a/src/deepsparse/transformers/eval_downstream.py
+++ b/src/deepsparse/transformers/eval_downstream.py
@@ -68,14 +68,43 @@
 import numpy
 from tqdm.auto import tqdm
 
-from deepsparse import Pipeline
-from deepsparse.transformers.metrics import PrecisionRecallF1
+from deepsparse import DEEPSPARSE_ENGINE, ORT_ENGINE, Pipeline
+from deepsparse.transformers.metrics import Perplexity, PrecisionRecallF1
 
 from datasets import load_dataset, load_metric  # isort: skip
 
 
-DEEPSPARSE_ENGINE = "deepsparse"
-ORT_ENGINE = "onnxruntime"
+
+def perplexity_eval(args, batch_size=16, dataset_name="openai_humaneval"):
+    dataset = load_dataset(dataset_name)["test"]
+
+    text_generation = Pipeline.create(
+        task="text-generation",
+        model_path=args.model_path,
+        engine_type=args.engine,
+        num_cores=args.num_cores,
+        sequence_length=args.max_sequence_length,
+        prompt_processing_sequence_length=args.max_sequence_length,
+        max_generated_tokens=1,
+        remove_special_tokens_from_prompt=False,
+    )
+    perplexity_metrics = Perplexity(pipeline=text_generation, batch_size=batch_size)
+    active_engines = [
+        engine
+        for engine in [text_generation.engine, text_generation.multitoken_engine]
+        if engine
+    ]
+    print("Engine info: ")
+    [print(f"{engine}\n") for engine in active_engines]
+    predictions = []
+    for idx, sample in _enumerate_progress(dataset, args.max_samples):
+        predictions.append(sample["prompt"] + sample["canonical_solution"])
+        if len(predictions) == batch_size:
+            perplexity_metrics.add_batch(predictions)
+            predictions = []
+        if args.max_samples and idx >= args.max_samples:
+            break
+    return perplexity_metrics
 
 
 def qa_eval(args, dataset_name="squad"):
@@ -443,11 +472,14 @@ def _split_train_val(train_dataset, val_ratio, seed=42):
     "imdb": imdb_eval,
     "conll2003": conll2003_eval,
     "go_emotions": go_emotions_eval,
+    "openai_humaneval": perplexity_eval,
 }
 
 
 def parse_args():
     parser = argparse.ArgumentParser(
+        # TODO: It is not BERT anymore, should we
+        # have another script or modify the existing one?
         description="Evaluate a BERT ONNX model on a downstream dataset"
     )
     parser.add_argument(
@@ -461,9 +493,9 @@ def parse_args():
     parser.add_argument(
         "-d",
         "--dataset",
-        type=str,
         choices=list(SUPPORTED_DATASETS.keys()),
         required=True,
+        type=str,
     )
     parser.add_argument(
         "-v",
diff --git a/src/deepsparse/transformers/metrics.py b/src/deepsparse/transformers/metrics.py
index 407e9b9d6b..ef5dd521eb 100644
--- a/src/deepsparse/transformers/metrics.py
+++ b/src/deepsparse/transformers/metrics.py
@@ -17,18 +17,119 @@
 """
 
 
-from typing import Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import numpy
+from tqdm import tqdm
 
+import torch
+from deepsparse import Pipeline
+from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline
 from sklearn.metrics import precision_recall_fscore_support
 
 
 __all__ = [
     "PrecisionRecallF1",
+    "Perplexity",
 ]
 
 
+class Perplexity:
+    def __init__(self, pipeline: Pipeline, batch_size: int = 16):
+        """
+        Given the pipeline, compute the perplexity of the model
+        on the given text input.
+
+        Code adapted from:
+        https://huggingface.co/spaces/evaluate-metric/perplexity/blob/main/perplexity.py # noqa: E501
+
+        :param pipeline: The pipeline to use for text generation
+        :param batch_size: The batch size to split the input text into
+            non-overlapping batches
+        """
+        if not isinstance(pipeline, TextGenerationPipeline):
+            raise ValueError(
+                "Perplexity can only be computed for text generation pipelines"
+            )
+        self._pipeline = pipeline
+        self._batch_size = batch_size
+        self._sequence_length = pipeline.sequence_length
+        self._loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+
+        self.perplexities = []
+
+    def add_batch(self, predictions: List[str]):
+        """
+        Run the model on the given input sequences and compute the perplexity.
+        The resulting perplexity is appended to the list of perplexities.
+
+        :param predictions: The predictions to compute perplexity on
+        """
+        # tokenize the input text
+        encodings = self._pipeline.tokenizer(
+            predictions,
+            return_attention_mask=True,
+            max_length=self._sequence_length,
+            truncation=True,
+            padding="max_length",
+        )
+
+        encoded_texts = encodings["input_ids"]
+        attention_masks = encodings["attention_mask"]
+
+        for start_index in tqdm(range(0, len(encoded_texts), self._batch_size)):
+            end_index = min(start_index + self._batch_size, len(encoded_texts))
+            encoded_batch = encoded_texts[start_index:end_index]
+            attention_mask = attention_masks[start_index:end_index]
+
+            out = self._pipeline(
+                sequences=predictions, return_logits=True, truncate=True
+            )
+            logits = out.logits
+
+            labels = encoded_batch
+            labels = numpy.stack(labels)
+            attention_mask = numpy.stack(attention_mask)
+
+            # because the tokenizer is left-padded, we need to move the meaningful
+            # part of the logits and labels to the right
+            num_padded_entries = attention_mask.sum(axis=1)
+
+            # shift the values at num_paddings to the top of the array using roll
+            for i, num_padded in enumerate(num_padded_entries):
+                logits[i] = numpy.roll(logits[i], num_padded, axis=0)
+                labels[i] = numpy.roll(labels[i], num_padded, axis=0)
+                attention_mask[i] = numpy.roll(attention_mask[i], num_padded, axis=0)
+
+            # shift logits and labels to create the input and target for the loss
+            shift_logits = logits[:, :-1, :]
+            shift_labels = labels[:, 1:]
+            shift_attention_mask_batch = attention_mask[:, 1:]
+
+            # compute perplexity for this batch
+            perplexity_batch = torch.exp(
+                (
+                    self._loss_fct(
+                        torch.tensor(shift_logits.transpose(0, 2, 1)),
+                        torch.tensor(shift_labels),
+                    )
+                    * torch.tensor(shift_attention_mask_batch)
+                ).sum(1)
+                / torch.tensor(shift_attention_mask_batch).sum(1)
+            )
+            self.perplexities.extend(perplexity_batch.numpy().tolist())
+
+    def compute(self) -> Dict[str, Any]:
+        """
+        :return: A dictionary containing the mean perplexity
+            and the list of perplexities
+        """
+        return {
+            "mean_perplexity": numpy.mean(self.perplexities),
+            "perplexities": self.perplexities,
+        }
+
+
 class PrecisionRecallF1:
     def __init__(self, id_to_label: Optional[Dict[int, str]] = None):
         self._id_to_label = id_to_label
diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py
index 12726f4c06..4a41b8b32d 100644
--- a/src/deepsparse/transformers/pipelines/text_generation.py
+++ b/src/deepsparse/transformers/pipelines/text_generation.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import List, Optional, Tuple, Type, Union
 
 import numpy
 from pydantic import BaseModel, Field
@@ -33,7 +33,8 @@ class TextGenerationInput(BaseModel):
     return_logits: bool = Field(
         default=False,
         description="A flag that indicates whether to return "
-        "the logits for the generated text sequence. ",
+        "the logits for the input text sequence and the "
+        "generated text sequence. ",
     )
     session_id: Optional[str] = Field(
         default=None,
         description="A user may set a string identifier "
         "for the kv cache session. If None, "
         "and the model is using kv cache, it "
         "will be set to a random uuid.",
     )
+    truncate: bool = Field(
+        default=False,
+        description="A flag that indicates whether to truncate "
+        "the input text sequence. Useful when a batch of "
+        "predictions needs to have a consistent length so one "
+        "can compute metrics in a batched fashion. ",
+    )
 
 
 class TextGenerationOutput(BaseModel):
@@ -89,6 +97,8 @@ class TextGenerationPipeline(TransformersPipeline):
         of tokens supplied even if the stop token is reached.
     :param use_deepsparse_cache: if True, the pipeline will use the deepsparse
         kv cache for caching the model outputs.
+    :param remove_special_tokens_from_prompt: if True, the pipeline will remove
+        the special tokens from the prompt before processing it. Defaults to True.
     :param kwargs: kwargs to pass to the TransformersPipeline
     """
 
@@ -101,6 +111,7 @@ def __init__(
         prompt_processing_sequence_length: int = 128,
         force_max_tokens: bool = False,
         use_deepsparse_cache: bool = False,
+        remove_special_tokens_from_prompt: bool = True,
         **kwargs,
     ):
         if use_deepsparse_cache:
@@ -125,11 +136,15 @@ def __init__(
         self.max_generated_tokens = max_generated_tokens
         self.prompt_processing_sequence_length = prompt_processing_sequence_length
         self.force_max_tokens = force_max_tokens
+        self.remove_special_tokens_from_prompt = remove_special_tokens_from_prompt
 
         # override tokenizer to pad to left
         self.tokenizer.padding_side = "left"
+        if not self.tokenizer.pad_token:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
 
         self.engine = None
+
         self.multitoken_engine = NLDecoderEngine(
             onnx_file_path=self.onnx_file_path,
             engine_type=self.engine_type,
@@ -158,6 +173,15 @@ def __init__(
             tokenizer=self.tokenizer,
             use_deepsparse_cache=use_deepsparse_cache,
         )
+        if (
+            not self.multitoken_engine.kv_cache_enabled
+            and self.max_generated_tokens > 1
+        ):
+            raise ValueError(
+                "The model used for inference does not support kv cache. It is "
+                "assumed that it maps from the token sequence to predicted logits. "
+                "Set `max_generated_tokens` to 1 to support that scenario."
+            )
 
     @staticmethod
     def route_input_to_bucket(
@@ -200,13 +224,12 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
         :return: the inputs for the engine
         """
 
-        self.tokenizer.pad_token = self.tokenizer.eos_token
-
         input_tokens = self.tokenizer(
             inputs.sequences,
             return_tensors="np",
             max_length=self.sequence_length,
             padding="max_length",
+            truncation=inputs.truncate,
         )
 
         attention_mask = input_tokens["attention_mask"]
@@ -239,8 +262,11 @@ def process_engine_outputs(
         :return: the output schema for the pipeline
         """
         generated_tokens, generated_logits = engine_outputs
+        if generated_tokens.ndim == 1:
+            # if we have a single dimension, add a batch dimension
+            generated_tokens = generated_tokens[None, :]
         sequences = self.tokenizer.batch_decode(
-            *generated_tokens, skip_special_tokens=True
+            generated_tokens, skip_special_tokens=True
         )
         logits = generated_logits if kwargs.get("return_logits") else None
 
@@ -259,17 +285,12 @@ def engine_forward(
             of logits for each generated token
         """
         if not self.multitoken_engine.kv_cache_enabled:
-            if self.max_generated_tokens != 1:
-                raise ValueError(
-                    "The model used for inference does not support kv cache. It is "
-                    "assumed that it maps from the token sequence to predicted logits."
-                    "Set `max_generated_tokens` to 1 to support that scenario."
-                )
-            tokens, logits = self.multitoken_engine(engine_inputs)
-            tokens = [tokens]
+            tokens, prompt_logits = self.multitoken_engine(engine_inputs)
+            return numpy.array([tokens]), prompt_logits
+
         else:
             # run the prompt through
-            tokens, logits = self.prompt_inference(engine_inputs)
+            tokens, prompt_logits = self.prompt_inference(engine_inputs)
 
         # create the generated output
         max_tokens = (
@@ -279,7 +300,7 @@ def engine_forward(
         )  # set safety for absolute max generation
 
         generated_tokens = [tokens[-1]]
-        generated_logits = [logits]
+        generated_logits = prompt_logits
 
         while len(generated_tokens) < max_tokens:
             (
@@ -293,13 +314,13 @@ def engine_forward(
             if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
                 break
 
-        return numpy.array([[generated_tokens]]), numpy.concatenate(
+        return numpy.array(generated_tokens), numpy.concatenate(
             generated_logits, axis=1
         )
 
     def prompt_inference(
         self, engine_inputs: List[numpy.ndarray]
-    ) -> Tuple[List[int], Dict[str, numpy.ndarray]]:
+    ) -> Tuple[List[int], List[numpy.ndarray]]:
         """
         An inference run that processes the prompt through the
         model to generate the new token and logits
@@ -311,8 +332,14 @@ def prompt_inference(
             - The logits generated from the prompt (with dimensions
             ['batch_size', 'num_tokens', 'vocab_size'])
         """
-        # get tokens by attention mask
-        tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
+        tokens = engine_inputs[0]
+        if self.remove_special_tokens_from_prompt:
+            # get tokens by attention mask
+            tokens = tokens[engine_inputs[1].nonzero()].tolist()
+        else:
+            tokens = tokens[0].tolist()
+
+        prompt_logits = []
 
         new_token = None
         num_tokens_processed = 0
@@ -326,6 +353,7 @@ def prompt_inference(
             ]
             new_token, new_logits = self.multitoken_engine(engine_inputs)
             num_tokens_processed = self.prompt_processing_sequence_length
+            prompt_logits.append(new_logits)
 
         if num_tokens_processed:
             # transfer the cache state from the multi-token engine to the main engine
@@ -333,15 +361,17 @@ def prompt_inference(
 
         # prompt size is small, run autoregressive inference to populate kv cache
         run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed]
+
         for token in tokens[num_tokens_processed:]:
             run_tokens.append(token)
             new_token, new_logits = self.autoregressive_inference(
                 run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
             )
+            prompt_logits.append(new_logits)
 
             tokens.append(new_token)
 
-        return tokens, new_logits
+        return tokens, prompt_logits
 
     def autoregressive_inference(
         self,
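
Reviewer note, not part of the patch: a minimal sketch of how the new Perplexity metric can be wired up by hand, mirroring perplexity_eval() in eval_downstream.py above. The model_path value is a placeholder (any text-generation ONNX deployment directory); the batch size and sample count are arbitrary choices for illustration.

# Sketch only: wire up the Perplexity metric the same way perplexity_eval() does.
from datasets import load_dataset

from deepsparse import Pipeline
from deepsparse.transformers.metrics import Perplexity

pipeline = Pipeline.create(
    task="text-generation",
    model_path="path/to/deployment",  # placeholder: any text-generation ONNX deployment
    sequence_length=128,
    prompt_processing_sequence_length=128,
    max_generated_tokens=1,
    remove_special_tokens_from_prompt=False,
)
metric = Perplexity(pipeline=pipeline, batch_size=16)

# openai_humaneval samples are scored on prompt + canonical_solution,
# exactly as perplexity_eval() does above
dataset = load_dataset("openai_humaneval")["test"]
batch = [s["prompt"] + s["canonical_solution"] for s in list(dataset)[:16]]
metric.add_batch(batch)
print(metric.compute())  # {"mean_perplexity": ..., "perplexities": [...]}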