From 2475b3bb677c8685ab9a291c490783ae2ccce5b8 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Wed, 12 Jun 2024 09:22:31 -0700 Subject: [PATCH] LIT: Remove legacy Language Model demo. PiperOrigin-RevId: 642637693 --- lit_nlp/examples/blank_slate_demo.py | 25 -- lit_nlp/examples/datasets/__init__.py | 13 - lit_nlp/examples/datasets/classification.py | 59 --- lit_nlp/examples/datasets/lm.py | 131 ------- lit_nlp/examples/lm_demo.py | 182 ---------- lit_nlp/examples/models/pretrained_lms.py | 335 ------------------ .../models/pretrained_lms_int_test.py | 36 -- lit_nlp/examples/prompt_debugging/datasets.py | 126 ++++++- lit_nlp/examples/prompt_debugging/models.py | 2 +- .../prompt_examples.jsonl | 0 website/sphinx_src/api.md | 24 +- website/sphinx_src/demos.md | 15 - website/sphinx_src/docker.md | 3 +- website/sphinx_src/frontend_development.md | 4 +- 14 files changed, 120 insertions(+), 835 deletions(-) delete mode 100644 lit_nlp/examples/datasets/__init__.py delete mode 100644 lit_nlp/examples/datasets/classification.py delete mode 100644 lit_nlp/examples/datasets/lm.py delete mode 100644 lit_nlp/examples/lm_demo.py delete mode 100644 lit_nlp/examples/models/pretrained_lms.py delete mode 100644 lit_nlp/examples/models/pretrained_lms_int_test.py rename lit_nlp/examples/{datasets => prompt_debugging}/prompt_examples.jsonl (100%) diff --git a/lit_nlp/examples/blank_slate_demo.py b/lit_nlp/examples/blank_slate_demo.py index 470f57a9..a4c66c4d 100644 --- a/lit_nlp/examples/blank_slate_demo.py +++ b/lit_nlp/examples/blank_slate_demo.py @@ -29,11 +29,8 @@ from lit_nlp import app as lit_app from lit_nlp import dev_server from lit_nlp import server_flags -from lit_nlp.examples.datasets import classification -from lit_nlp.examples.datasets import lm from lit_nlp.examples.glue import data as glue_data from lit_nlp.examples.glue import models as glue_models -from lit_nlp.examples.models import pretrained_lms from lit_nlp.examples.penguin import data as penguin_data from lit_nlp.examples.penguin import model as penguin_model @@ -85,16 +82,6 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: penguin_model.PenguinModel.init_spec(), ) - # lm demo model loaders. - model_loaders["bert"] = ( - pretrained_lms.BertMLM, - pretrained_lms.BertMLM.init_spec(), - ) - model_loaders["gpt2"] = ( - pretrained_lms.GPT2LanguageModel, - pretrained_lms.GPT2LanguageModel.init_spec(), - ) - datasets = {} dataset_loaders: lit_app.DatasetLoadersMap = {} @@ -114,18 +101,6 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: glue_data.SST2DataForLM, glue_data.SST2DataForLM.init_spec(), ) - dataset_loaders["imdb (lm)"] = ( - classification.IMDBData, - classification.IMDBData.init_spec(), - ) - dataset_loaders["plain text sentences (lm)"] = ( - lm.PlaintextSents, - lm.PlaintextSents.init_spec(), - ) - dataset_loaders["bwb (lm)"] = ( - lm.BillionWordBenchmark, - lm.BillionWordBenchmark.init_spec(), - ) # Start the LIT server. See server_flags.py for server options. lit_demo = dev_server.Server( diff --git a/lit_nlp/examples/datasets/__init__.py b/lit_nlp/examples/datasets/__init__.py deleted file mode 100644 index c6334245..00000000 --- a/lit_nlp/examples/datasets/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/lit_nlp/examples/datasets/classification.py b/lit_nlp/examples/datasets/classification.py deleted file mode 100644 index 0d1e1df3..00000000 --- a/lit_nlp/examples/datasets/classification.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Text classification datasets, including single- and two-sentence tasks.""" -from typing import Optional - -from lit_nlp.api import dataset as lit_dataset -from lit_nlp.api import types as lit_types -import tensorflow_datasets as tfds - - -def load_tfds(*args, **kw): - """Load from TFDS.""" - # Materialize to NumPy arrays. - # This also ensures compatibility with TF1.x non-eager mode, which doesn't - # support direct iteration over a tf.data.Dataset. - return list( - tfds.as_numpy(tfds.load(*args, download=True, try_gcs=True, **kw))) - - -class IMDBData(lit_dataset.Dataset): - """IMDB reviews dataset; see http://ai.stanford.edu/~amaas/data/sentiment/.""" - - LABELS = ["0", "1"] - AVAILABLE_SPLITS = ["test", "train", "unsupervised"] - - def __init__( - self, split="test", max_seq_len=500, max_examples: Optional[int] = None - ): - """Dataset constructor, loads the data into memory.""" - raw_examples = load_tfds("imdb_reviews", split=split) - self._examples = [] # populate this with data records - for record in raw_examples[:max_examples]: - # format and truncate from the end to max_seq_len tokens. - truncated_text = " ".join( - record["text"] - .decode("utf-8") - .replace("
", "") - .split()[-max_seq_len:] - ) - self._examples.append({ - "text": truncated_text, - "label": self.LABELS[record["label"]], - }) - - @classmethod - def init_spec(cls) -> lit_types.Spec: - return { - "split": lit_types.CategoryLabel(vocab=cls.AVAILABLE_SPLITS), - "max_seq_len": lit_types.Integer(default=500, min_val=1, max_val=1024), - "max_examples": lit_types.Integer( - default=1000, min_val=0, max_val=10_000, required=False - ), - } - - def spec(self) -> lit_types.Spec: - """Dataset spec, which should match the model"s input_spec().""" - return { - "text": lit_types.TextSegment(), - "label": lit_types.CategoryLabel(vocab=self.LABELS), - } - diff --git a/lit_nlp/examples/datasets/lm.py b/lit_nlp/examples/datasets/lm.py deleted file mode 100644 index d2292f44..00000000 --- a/lit_nlp/examples/datasets/lm.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Language modeling datasets.""" - -import copy -import json -import os -import glob -from typing import Optional - -from absl import logging -from lit_nlp.api import dataset as lit_dataset -from lit_nlp.api import types as lit_types -import tensorflow_datasets as tfds - -SAMPLE_DATA_DIR = os.path.dirname(__file__) - - -class PlaintextSents(lit_dataset.Dataset): - """Load sentences from a flat text file.""" - - def __init__( - self, - path_or_glob: str, - skiplines: int = 0, - max_examples: Optional[int] = None, - field_name: str = 'text', - ): - self.field_name = field_name - self._examples = self.load_datapoints(path_or_glob, skiplines=skiplines)[ - :max_examples - ] - - @classmethod - def init_spec(cls) -> lit_types.Spec: - default_path = '' - - return { - 'path_or_glob': lit_types.String( - default=default_path, required=False - ), - 'skiplines': lit_types.Integer(default=0, max_val=25), - 'max_examples': lit_types.Integer( - default=1000, min_val=0, max_val=10_000, required=False - ), - } - - def load_datapoints(self, path_or_glob: str, skiplines: int = 0): - examples = [] - for path in glob.glob(path_or_glob): - with open(path) as fd: - for i, line in enumerate(fd): - if i < skiplines: # skip header lines, if necessary - continue - line = line.strip() - if line: # skip blank lines, these are usually document breaks - examples.append({self.field_name: line}) - return examples - - def load(self, path: str): - return lit_dataset.Dataset(base=self, examples=self.load_datapoints(path)) - - def spec(self) -> lit_types.Spec: - """Should match MLM's input_spec().""" - return {self.field_name: lit_types.TextSegment()} - - -class PromptExamples(lit_dataset.Dataset): - """Prompt examples for modern LMs.""" - - SAMPLE_DATA_PATH = os.path.join(SAMPLE_DATA_DIR, 'prompt_examples.jsonl') - - def load_datapoints(self, path: str): - if not path: - logging.warn( - 'Empty path to PromptExamples.load_datapoints(). Returning empty' - ' dataset.' 
- ) - return [] - - default_ex_values = { - k: copy.deepcopy(field_spec.default) - for k, field_spec in self.spec().items() - } - - examples = [] - with open(path) as fd: - for line in fd: - examples.append(default_ex_values | json.loads(line)) - - return examples - - def __init__(self, path: str): - self._examples = self.load_datapoints(path) - - def spec(self) -> lit_types.Spec: - return { - 'source': lit_types.CategoryLabel(), - 'prompt': lit_types.TextSegment(), - 'target': lit_types.TextSegment(), - } - - def load(self, path: str): - return lit_dataset.Dataset(base=self, examples=self.load_datapoints(path)) - - -class BillionWordBenchmark(lit_dataset.Dataset): - """Billion Word Benchmark (lm1b); see http://www.statmt.org/lm-benchmark/.""" - - AVAILABLE_SPLITS = ['test', 'train'] - - def __init__(self, split: str = 'train', max_examples: Optional[int] = None): - ds = tfds.load('lm1b', split=split) - if max_examples is not None: - # Normally we can just slice the resulting dataset, but lm1b is very large - # so we can use ds.take() to only load a portion of it. - ds = ds.take(max_examples) - raw_examples = list(tfds.as_numpy(ds)) - self._examples = [{ - 'text': ex['text'].decode('utf-8') - } for ex in raw_examples] - - @classmethod - def init_spec(cls) -> lit_types.Spec: - return { - 'split': lit_types.CategoryLabel(vocab=cls.AVAILABLE_SPLITS), - 'max_examples': lit_types.Integer( - default=1000, min_val=0, max_val=10_000, required=False - ), - } - - def spec(self) -> lit_types.Spec: - return {'text': lit_types.TextSegment()} diff --git a/lit_nlp/examples/lm_demo.py b/lit_nlp/examples/lm_demo.py deleted file mode 100644 index 5af2156b..00000000 --- a/lit_nlp/examples/lm_demo.py +++ /dev/null @@ -1,182 +0,0 @@ -r"""Example demo loading pre-trained language models. - -Currently supports the following model types: -- BERT (bert-*) as a masked language model -- GPT-2 (gpt2* or distilgpt2) as a left-to-right language model - -To run locally: - python -m lit_nlp.examples.lm_demo \ - --models=bert-base-uncased --port=5432 - -Then navigate to localhost:5432 to access the demo UI. -""" - -from collections.abc import Sequence -import sys -from typing import Optional - -from absl import app -from absl import flags -from absl import logging -from lit_nlp import app as lit_app -from lit_nlp import dev_server -from lit_nlp import server_flags -from lit_nlp.api import layout -from lit_nlp.components import word_replacer -from lit_nlp.examples.datasets import classification -from lit_nlp.examples.datasets import lm -from lit_nlp.examples.glue import data as glue_data -from lit_nlp.examples.models import pretrained_lms - -# NOTE: additional flags defined in server_flags.py - -FLAGS = flags.FLAGS - -FLAGS.set_default("development_demo", True) - -_MODELS = flags.DEFINE_list( - "models", - [ - "bert-base-uncased:https://storage.googleapis.com/what-if-tool-resources/lit-models/bert-base-uncased.tar.gz", - "gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz", - ], - "Models to load, as :. Currently supports variants of BERT and" - " GPT-2.", -) - -_TOP_K = flags.DEFINE_integer( - "top_k", 10, "Rank to which the output distribution is pruned." -) - -_MAX_EXAMPLES = flags.DEFINE_integer( - "max_examples", - 1000, - ( - "Maximum number of examples to load from each evaluation set. Set to" - " None to load the full set." - ), -) - -_LOAD_BWB = flags.DEFINE_bool( - "load_bwb", - False, - ( - "If true, will load examples from the Billion Word Benchmark dataset." 
- " This may download a lot of data the first time you run it, so disable" - " by default for the quick-start example." - ), -) - -# Custom frontend layout; see api/layout.py -modules = layout.LitModuleName -LM_LAYOUT = layout.LitCanonicalLayout( - upper={ - "Main": [ - modules.EmbeddingsModule, - modules.DataTableModule, - modules.DatapointEditorModule, - ] - }, - lower={ - "Predictions": [ - modules.LanguageModelPredictionModule, - modules.ConfusionMatrixModule, - ], - "Counterfactuals": [modules.GeneratorModule], - }, - description="Custom layout for language models.", -) - -CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {"lm": LM_LAYOUT} - -# You can also change this via URL param e.g. localhost:5432/?layout=default -FLAGS.set_default("default_layout", "lm") - - -def get_wsgi_app() -> Optional[dev_server.LitServerType]: - """Return WSGI app for container-hosted demos.""" - FLAGS.set_default("server_type", "external") - FLAGS.set_default("demo_mode", True) - # Parse flags without calling app.run(main), to avoid conflict with - # gunicorn command line flags. - unused = flags.FLAGS(sys.argv, known_only=True) - if unused: - logging.info("lm_demo:get_wsgi_app() called with unused args: %s", unused) - return main([]) - - -def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: - if len(argv) > 1: - raise app.UsageError("Too many command-line arguments.") - - ## - # Load models, according to the --models flag. - models = {} - for model_string in _MODELS.value: - # Only split on the first ':', because path may be a URL - # containing 'https://' - model_name, path = model_string.split(":", 1) - logging.info("Loading model '%s' from '%s'", model_name, path) - if model_name.startswith("bert-"): - models[model_name] = pretrained_lms.BertMLM(path, top_k=_TOP_K.value) - elif model_name.startswith("gpt2") or model_name in ["distilgpt2"]: - models[model_name] = pretrained_lms.GPT2LanguageModel( - path, top_k=_TOP_K.value - ) - else: - raise ValueError( - f"Unsupported model name '{model_name}' from path '{path}'" - ) - - datasets = { - # Single sentences from movie reviews (SST dev set). - "sst_dev": glue_data.SST2Data("validation").remap({"sentence": "text"}), - # Longer passages from movie reviews (IMDB dataset, test split). - "imdb_train": classification.IMDBData("test"), - # Empty dataset, if you just want to type sentences into the UI. - "blank": lm.PlaintextSents(""), - } - - dataset_loaders: lit_app.DatasetLoadersMap = { - "sst_dev": (glue_data.SST2DataForLM, glue_data.SST2DataForLM.init_spec()), - "imdb_train": ( - classification.IMDBData, - classification.IMDBData.init_spec(), - ), - "plain_text_sentences": ( - lm.PlaintextSents, - lm.PlaintextSents.init_spec(), - ), - } - - # Guard this with a flag, because TFDS will download and process 1.67 GB - # of data if you haven't loaded `lm1b` before. - if _LOAD_BWB.value: - # A few sentences from the Billion Word Benchmark (Chelba et al. 2013). 
- datasets["bwb"] = lm.BillionWordBenchmark( - "train", max_examples=_MAX_EXAMPLES.value - ) - dataset_loaders["bwb"] = ( - lm.BillionWordBenchmark, - lm.BillionWordBenchmark.init_spec(), - ) - - for name in datasets: - datasets[name] = datasets[name].slice[: _MAX_EXAMPLES.value] - logging.info("Dataset: '%s' with %d examples", name, len(datasets[name])) - - generators = {"word_replacer": word_replacer.WordReplacer()} - - lit_demo = dev_server.Server( - models, - datasets, - generators=generators, - layouts=CUSTOM_LAYOUTS, - dataset_loaders=dataset_loaders, - **server_flags.get_flags(), - ) - return lit_demo.serve() - - -if __name__ == "__main__": - app.run(main) diff --git a/lit_nlp/examples/models/pretrained_lms.py b/lit_nlp/examples/models/pretrained_lms.py deleted file mode 100644 index f9194e11..00000000 --- a/lit_nlp/examples/models/pretrained_lms.py +++ /dev/null @@ -1,335 +0,0 @@ -"""Wrapper for HuggingFace models in LIT. - -Includes BERT masked LM, GPT-2, and T5. - -This wrapper loads a model into memory and implements the a number of helper -functions to predict a batch of examples and extract information such as -hidden states and attention. -""" - -import re - -from lit_nlp.api import model as lit_model -from lit_nlp.api import types as lit_types -from lit_nlp.examples.models import model_utils -from lit_nlp.lib import file_cache -from lit_nlp.lib import utils -import numpy as np -import tensorflow as tf -import transformers - - -class BertMLM(lit_model.BatchedModel): - """BERT masked LM using Huggingface Transformers and TensorFlow 2.""" - - MASK_TOKEN = "[MASK]" - - @property - def max_seq_length(self): - return self.model.config.max_position_embeddings - - @classmethod - def init_spec(cls) -> lit_model.Spec: - return { - "model_name_or_path": lit_types.String(default="bert-base-uncased"), - "top_k": lit_types.Integer(default=10, min_val=1, max_val=25), - } - - def __init__(self, model_name_or_path="bert-base-uncased", top_k=10): - super().__init__() - - # Normally path is a directory; if it's an archive file, download and - # extract to the transformers cache. - if model_name_or_path.endswith(".tar.gz"): - model_name_or_path = file_cache.cached_path( - model_name_or_path, extract_compressed_file=True - ) - - self.tokenizer = transformers.AutoTokenizer.from_pretrained( - model_name_or_path, use_fast=False - ) - # TODO(lit-dev): switch to TFBertForPreTraining to get the next-sentence - # prediction head as well. - self.model = model_utils.load_pretrained( - transformers.TFBertForMaskedLM, - model_name_or_path, - output_hidden_states=True, - output_attentions=True, - ) - self.top_k = top_k - - # TODO(lit-dev): break this out as a helper function, write some tests, - # and de-duplicate code with the other text generation functions. - def _get_topk_tokens( - self, scores: np.ndarray - ) -> list[list[tuple[str, float]]]: - """Convert raw scores to top-k token predictions.""" - # scores is [num_tokens, vocab_size] - # Find the vocab indices of top k predictions, at each token. - # np.argpartition is faster than a full argsort for k << V, - # but we need to sort the output after slicing (see below). - index_array = np.argpartition(scores, -self.top_k, axis=1)[:, -self.top_k :] - # These are each [num_tokens, tok_k] - top_tokens = [ - self.tokenizer.convert_ids_to_tokens(idxs) for idxs in index_array - ] - top_scores = np.take_along_axis(scores, index_array, axis=1) - # Convert to a list of lists of (token, score) pairs, - # where inner lists are sorted in descending order of score. 
- return [ - sorted(list(zip(toks, scores)), key=lambda ab: -ab[1]) - for toks, scores in zip(top_tokens, top_scores) - ] - # TODO(lit-dev): consider returning indices and a vocab, since repeating - # strings is slow and redundant. - - def _postprocess(self, output: dict[str, np.ndarray]): - """Postprocess, modifying output dict in-place.""" - # Slice to remove padding, omitting initial [CLS] and final [SEP] - slicer = slice(1, output.pop("ntok") - 1) - output["tokens"] = self.tokenizer.convert_ids_to_tokens( - output.pop("input_ids")[slicer] - ) - probas = output.pop("probas") - - # Predictions at every position, regardless of masking. - output["pred_tokens"] = self._get_topk_tokens(probas[slicer]) # pytype: disable=container-type-mismatch - - return output - - ## - # LIT API implementations - def max_minibatch_size(self) -> int: - # The lit.Model base class handles batching automatically in the - # implementation of predict(), and uses this value as the batch size. - return 8 - - def predict_minibatch(self, inputs): - """Predict on a single minibatch of examples.""" - # If input has a 'tokens' field, use that. Otherwise tokenize the text. - tokenized_texts = [ - ex.get("tokens") or self.tokenizer.tokenize(ex["text"]) for ex in inputs - ] - encoded_input = model_utils.batch_encode_pretokenized( - self.tokenizer, tokenized_texts - ) - - # out.logits is a single tensor - # [batch_size, num_tokens, vocab_size] - # out.hidden_states is a list of num_layers + 1 tensors, each - # [batch_size, num_tokens, h_dim] - out: transformers.modeling_tf_outputs.TFMaskedLMOutput = self.model( - encoded_input - ) - batched_outputs = { - "probas": tf.nn.softmax(out.logits, axis=-1).numpy(), - "input_ids": encoded_input["input_ids"].numpy(), - "ntok": tf.reduce_sum(encoded_input["attention_mask"], axis=1).numpy(), - # last layer, first token - "cls_emb": out.hidden_states[-1][:, 0].numpy(), - } - # List of dicts, one per example. - unbatched_outputs = utils.unbatch_preds(batched_outputs) - # Postprocess to remove padding and decode predictions. - return map(self._postprocess, unbatched_outputs) - - def load(self, model_name_or_path): - """Dynamically load a new BertMLM model given a model name.""" - return BertMLM(model_name_or_path, self.top_k) - - def input_spec(self): - return { - "text": lit_types.TextSegment(), - "tokens": lit_types.Tokens(mask_token="[MASK]", required=False), - } - - def output_spec(self): - return { - "tokens": lit_types.Tokens(parent="text"), - "pred_tokens": lit_types.TokenTopKPreds(align="tokens"), - "cls_emb": lit_types.Embeddings(), - } - - -# TODO(lit-dev): merge with below, inherit from HFBaseModel. -class GPT2LanguageModel(lit_model.BatchedModel): - """Wrapper for a Huggingface Transformers GPT-2 model. - - This class loads a tokenizer and model using the Huggingface library and - provides the LIT-required functions plus additional helper functions to - convert and clean tokens and to compute the top_k predictions from logits. - """ - - @property - def num_layers(self): - return self.model.config.n_layer - - @classmethod - def init_spec(cls) -> lit_model.Spec: - return { - "model_name_or_path": lit_types.String(default="gpt2"), - "top_k": lit_types.Integer(default=10, min_val=1, max_val=25), - } - - def __init__(self, model_name_or_path="gpt2", top_k=10): - """Constructor for GPT2LanguageModel. - - Args: - model_name_or_path: gpt2, gpt2-medium, gpt2-large, gpt2-xl, distilgpt2, - etc. - top_k: How many predictions to prune. 
- """ - super().__init__() - - # Normally path is a directory; if it's an archive file, download and - # extract to the transformers cache. - if model_name_or_path.endswith(".tar.gz"): - model_name_or_path = file_cache.cached_path( - model_name_or_path, extract_compressed_file=True - ) - - self.tokenizer = transformers.AutoTokenizer.from_pretrained( - model_name_or_path, use_fast=False - ) - # Set this after init, as if pad_token= is passed to - # AutoTokenizer.from_pretrained() above it will create a new token with - # with id = max_vocab_length and cause out-of-bounds errors in - # the embedding lookup. - self.tokenizer.pad_token = self.tokenizer.eos_token - self.model = transformers.TFGPT2LMHeadModel.from_pretrained( - model_name_or_path, output_hidden_states=True, output_attentions=True - ) - self.top_k = top_k - - @staticmethod - def clean_bpe_token(tok): - if not tok.startswith("Ġ"): - return "_" + tok - else: - return tok.replace("Ġ", "") - - def ids_to_clean_tokens(self, ids): - tokens = self.tokenizer.convert_ids_to_tokens(ids) - return [self.clean_bpe_token(t) for t in tokens] - - def _pred(self, encoded_inputs): - """Predicts one batch of tokenized text. - - Also performs some batch-level post-processing in TF. - Single-example postprocessing is done in _postprocess(), and operates on - numpy arrays. - - Each prediction has the following returns: - logits: tf.Tensor (batch_size, sequence_length, config.vocab_size). - past: list[tf.Tensor] of length config.n_layers with each tensor shape - (2, batch_size, num_heads, sequence_length, embed_size_per_head)). - states: Tuple of tf.Tensor (one for embeddings + one for each layer), - with shape (batch_size, sequence_length, hidden_size). - attentions: Tuple of tf.Tensor (one for each layer) with shape - (batch_size, num_heads, sequence_length, sequence_length) - Within this function, we combine each Tuple/List into a single Tensor. - - Args: - encoded_inputs: output of self.tokenizer() - - Returns: - payload: Dictionary with items described above, each as single Tensor. - """ - out: transformers.modeling_tf_outputs.TFCausalLMOutputWithPast = self.model( - encoded_inputs["input_ids"] - ) - - model_probs = tf.nn.softmax(out.logits, axis=-1) - top_k = tf.math.top_k(model_probs, k=self.top_k, sorted=True, name=None) - batched_outputs = { - "input_ids": encoded_inputs["input_ids"], - "ntok": tf.reduce_sum(encoded_inputs["attention_mask"], axis=1), - "top_k_indices": top_k.indices, - "top_k_probs": top_k.values, - } - - # Convert representations for each layer from tuples to single Tensor. - for i in range(len(out.attentions)): - batched_outputs[f"layer_{i+1:d}_attention"] = out.attentions[i] - for i in range(len(out.hidden_states)): - batched_outputs[f"layer_{i:d}_avg_embedding"] = tf.math.reduce_mean( - out.hidden_states[i], axis=1 - ) - - return batched_outputs - - def _postprocess(self, preds): - """Post-process single-example preds. Operates on numpy arrays.""" - ntok = preds.pop("ntok") - ids = preds.pop("input_ids")[:ntok] - preds["tokens"] = self.ids_to_clean_tokens(ids) - - # Decode predicted top-k tokens. - # token_topk_preds will be a list[list[(word, prob)]] - # Initialize prediction for 0th token as N/A. 
- token_topk_preds = [[("N/A", 1.0)]] - pred_ids = preds.pop("top_k_indices")[:ntok] # [num_tokens, k] - pred_probs = preds.pop("top_k_probs")[:ntok] # [num_tokens, k] - for token_pred_ids, token_pred_probs in zip(pred_ids, pred_probs): - token_pred_words = self.ids_to_clean_tokens(token_pred_ids) - token_topk_preds.append(list(zip(token_pred_words, token_pred_probs))) - preds["pred_tokens"] = token_topk_preds - - # Process attention. - for key in preds: - if not re.match(r"layer_(\d+)/attention", key): - continue - # Select only real tokens, since most of this matrix is padding. - # [num_heads, max_seq_length, max_seq_length] - # -> [num_heads, num_tokens, num_tokens] - preds[key] = preds[key][:, :ntok, :ntok].transpose((0, 2, 1)) - # Make a copy of this array to avoid memory leaks, since NumPy otherwise - # keeps a pointer around that prevents the source array from being GCed. - preds[key] = preds[key].copy() - - return preds - - ## - # LIT API implementations - def max_minibatch_size(self) -> int: - # The BatchedModel base class handles batching automatically in the - # implementation of predict(), and uses this value as the batch size. - return 6 - - def predict_minibatch(self, inputs): - """Predict on a single minibatch of examples.""" - # Preprocess inputs. - texts = [ex["text"] for ex in inputs] - encoded_inputs = self.tokenizer( - texts, - return_tensors="tf", - add_special_tokens=True, - padding="longest", - truncation="longest_first", - ) - - # Get the predictions. - batched_outputs = self._pred(encoded_inputs) - # Convert to numpy for post-processing. - detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} - # Split up batched outputs, then post-process each example. - unbatched_outputs = utils.unbatch_preds(detached_outputs) - return map(self._postprocess, unbatched_outputs) - - def input_spec(self): - return {"text": lit_types.TextSegment()} - - def output_spec(self): - spec = { - # the "parent" keyword tells LIT which field in the input spec we should - # compare this to when computing metrics. - "pred_tokens": lit_types.TokenTopKPreds(align="tokens"), - "tokens": lit_types.Tokens(parent="text"), # all tokens - } - # Add attention and embeddings from each layer. - for i in range(self.num_layers): - spec[f"layer_{i+1:d}_attention"] = lit_types.AttentionHeads( - align_in="tokens", align_out="tokens" - ) - spec[f"layer_{i:d}_avg_embedding"] = lit_types.Embeddings() - return spec diff --git a/lit_nlp/examples/models/pretrained_lms_int_test.py b/lit_nlp/examples/models/pretrained_lms_int_test.py deleted file mode 100644 index 187a335c..00000000 --- a/lit_nlp/examples/models/pretrained_lms_int_test.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Integration tests for pretrained_lms.""" - -from absl.testing import absltest -from lit_nlp.examples.models import pretrained_lms - - -class PretrainedLmsIntTest(absltest.TestCase): - """Test that model classes can predict.""" - - def test_bertmlm(self): - # Run prediction to ensure no failure. - model_path = "https://storage.googleapis.com/what-if-tool-resources/lit-models/bert-base-uncased.tar.gz" - model = pretrained_lms.BertMLM(model_path) - model_in = [{"text": "test text", "tokens": ["test", "[MASK]"]}] - model_out = list(model.predict(model_in)) - - # Sanity-check entries exist in output. - self.assertLen(model_out, 1) - self.assertIn("pred_tokens", model_out[0]) - self.assertIn("cls_emb", model_out[0]) - - def test_gpt2(self): - # Run prediction to ensure no failure. 
- model_path = "https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz" - model = pretrained_lms.GPT2LanguageModel(model_path) - model_in = [{"text": "test text"}, {"text": "longer test text"}] - model_out = list(model.predict(model_in)) - - # Sanity-check output vs output spec. - self.assertLen(model_out, 2) - for key in model.output_spec().keys(): - self.assertIn(key, model_out[0].keys()) - - -if __name__ == "__main__": - absltest.main() diff --git a/lit_nlp/examples/prompt_debugging/datasets.py b/lit_nlp/examples/prompt_debugging/datasets.py index 4681c951..22157704 100644 --- a/lit_nlp/examples/prompt_debugging/datasets.py +++ b/lit_nlp/examples/prompt_debugging/datasets.py @@ -1,25 +1,115 @@ """Methods for configuring prompt debugging datasets.""" from collections.abc import Mapping, Sequence +import copy import functools +import json +import os import re from typing import Optional from absl import logging from lit_nlp import app as lit_app from lit_nlp.api import dataset as lit_dataset -from lit_nlp.examples.datasets import lm as lm_data +from lit_nlp.api import types as lit_types - -DEFAULT_DATASETS = ["sample_prompts"] +SAMPLE_DATA_DIR = os.path.dirname(__file__) +DEFAULT_DATASETS = ['sample_prompts'] DEFAULT_MAX_EXAMPLES = 1000 + +class PlaintextSents(lit_dataset.Dataset): + """Load sentences from a flat text file.""" + + def __init__( + self, + path_or_glob: str, + skiplines: int = 0, + max_examples: Optional[int] = None, + field_name: str = 'text', + ): + self.field_name = field_name + self._examples = self.load_datapoints(path_or_glob, skiplines=skiplines)[ + :max_examples + ] + + @classmethod + def init_spec(cls) -> lit_types.Spec: + default_path = '' + + return { + 'path_or_glob': lit_types.String(default=default_path, required=False), + 'skiplines': lit_types.Integer(default=0, max_val=25), + 'max_examples': lit_types.Integer( + default=1000, min_val=0, max_val=10_000, required=False + ), + } + + def load_datapoints(self, path_or_glob: str, skiplines: int = 0): + examples = [] + for path in glob.glob(path_or_glob): + with open(path) as fd: + for i, line in enumerate(fd): + if i < skiplines: # skip header lines, if necessary + continue + line = line.strip() + if line: # skip blank lines, these are usually document breaks + examples.append({self.field_name: line}) + return examples + + def load(self, path: str): + return lit_dataset.Dataset(base=self, examples=self.load_datapoints(path)) + + def spec(self) -> lit_types.Spec: + """Should match MLM's input_spec().""" + return {self.field_name: lit_types.TextSegment()} + + +class PromptExamples(lit_dataset.Dataset): + """Prompt examples for modern LMs.""" + + SAMPLE_DATA_PATH = os.path.join(SAMPLE_DATA_DIR, 'prompt_examples.jsonl') + + def load_datapoints(self, path: str): + if not path: + logging.warn( + 'Empty path to PromptExamples.load_datapoints(). Returning empty' + ' dataset.' 
+ ) + return [] + + default_ex_values = { + k: copy.deepcopy(field_spec.default) + for k, field_spec in self.spec().items() + } + + examples = [] + with open(path) as fd: + for line in fd: + examples.append(default_ex_values | json.loads(line)) + + return examples + + def __init__(self, path: str): + self._examples = self.load_datapoints(path) + + def spec(self) -> lit_types.Spec: + return { + 'source': lit_types.CategoryLabel(), + 'prompt': lit_types.TextSegment(), + 'target': lit_types.TextSegment(), + } + + def load(self, path: str): + return lit_dataset.Dataset(base=self, examples=self.load_datapoints(path)) + + _plaintext_prompts = functools.partial( # pylint: disable=invalid-name - lm_data.PlaintextSents, field_name="prompt" + PlaintextSents, field_name='prompt' ) # Hack: normally dataset loaders are a class object which has a __name__, # rather than a functools.partial -_plaintext_prompts.__name__ = "PlaintextSents" +_plaintext_prompts.__name__ = 'PlaintextSents' def get_datasets( @@ -44,22 +134,22 @@ def get_datasets( datasets: dict[str, lit_dataset.Dataset] = {} for dataset_string in datasets_config: - if dataset_string == "sample_prompts": - dataset_name = "sample_prompts" - path = lm_data.PromptExamples.SAMPLE_DATA_PATH + if dataset_string == 'sample_prompts': + dataset_name = 'sample_prompts' + path = PromptExamples.SAMPLE_DATA_PATH else: # Only split on the first ':', because path may be a URL # containing 'https://' - dataset_name, path = dataset_string.split(":", 1) + dataset_name, path = dataset_string.split(':', 1) logging.info("Loading dataset '%s' from '%s'", dataset_name, path) - if path.endswith(".jsonl"): - datasets[dataset_name] = lm_data.PromptExamples(path) + if path.endswith('.jsonl'): + datasets[dataset_name] = PromptExamples(path) # .txt or .txt-#####-of-##### - elif path.endswith(".txt") or re.match(r".*\.txt-\d{5}-of-\d{5}$", path): + elif path.endswith('.txt') or re.match(r'.*\.txt-\d{5}-of-\d{5}$', path): datasets[dataset_name] = _plaintext_prompts(path) else: - raise ValueError(f"Unsupported dataset format for {dataset_string}") + raise ValueError(f'Unsupported dataset format for {dataset_string}') for name in datasets: datasets[name] = datasets[name].slice[:max_examples] @@ -70,12 +160,12 @@ def get_datasets( def get_dataset_loaders() -> lit_app.DatasetLoadersMap: return { - "jsonl_examples": ( - lm_data.PromptExamples, - lm_data.PromptExamples.init_spec(), + 'jsonl_examples': ( + PromptExamples, + PromptExamples.init_spec(), ), - "plaintext_inputs": ( + 'plaintext_inputs': ( _plaintext_prompts, - lm_data.PlaintextSents.init_spec(), + PlaintextSents.init_spec(), ), } diff --git a/lit_nlp/examples/prompt_debugging/models.py b/lit_nlp/examples/prompt_debugging/models.py index 6f2347f0..2c72c9b6 100644 --- a/lit_nlp/examples/prompt_debugging/models.py +++ b/lit_nlp/examples/prompt_debugging/models.py @@ -9,7 +9,7 @@ from lit_nlp.lib import file_cache -DEFAULT_BATCH_SIZE = 4 +DEFAULT_BATCH_SIZE = 1 DEFAULT_DL_FRAMEWORK = "kerasnlp" DEFAULT_DL_RUNTIME = "tensorflow" DEFAULT_MODELS = ["gemma_1.1_instruct_2b_en:gemma_1.1_instruct_2b_en"], diff --git a/lit_nlp/examples/datasets/prompt_examples.jsonl b/lit_nlp/examples/prompt_debugging/prompt_examples.jsonl similarity index 100% rename from lit_nlp/examples/datasets/prompt_examples.jsonl rename to lit_nlp/examples/prompt_debugging/prompt_examples.jsonl diff --git a/website/sphinx_src/api.md b/website/sphinx_src/api.md index e734be39..59c3be20 100644 --- a/website/sphinx_src/api.md +++ b/website/sphinx_src/api.md @@ 
-675,7 +675,7 @@ Each `LitType` subclass encapsulates its own semantics (see * A field that appears in _both_ the model's input and output specs is assumed to represent the same value. This pattern is used for model-based input manipulation. For example, a - [language model](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/models/pretrained_lms.py) + [language model](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/glue/models.py) might output `'tokens': lit_types.Tokens(...)`, and accept as (optional) input `'tokens': lit_types.Tokens(required=False, ...)`. An interpretability component could take output from the former, swap one or more tokens (e.g. @@ -886,22 +886,16 @@ You can specify custom web app layouts from Python via the `layouts=` attribute. The value should be a `Mapping[str, LitCanonicalLayout]`, such as: ```python -LM_LAYOUT = layout.LitCanonicalLayout( +PENGUIN_LAYOUT = layout.LitCanonicalLayout( upper={ - "Main": [ - modules.EmbeddingsModule, + 'Main': [ + modules.DiveModule, modules.DataTableModule, modules.DatapointEditorModule, ] }, - lower={ - "Predictions": [ - modules.LanguageModelPredictionModule, - modules.ConfusionMatrixModule, - ], - "Counterfactuals": [modules.GeneratorModule], - }, - description="Custom layout for language models.", + lower=layout.STANDARD_LAYOUT.lower, + description='Custom layout for the Palmer Penguins demo.', ) ``` @@ -912,14 +906,12 @@ lit_demo = dev_server.Server( models, datasets, # other args... - layouts={"lm": LM_LAYOUT}, + layouts=layout.DEFAULT_LAYOUTS | {'penguins': PENGUIN_LAYOUT}, + default_layout='penguins', **server_flags.get_flags()) return lit_demo.serve() ``` -For a full example, see -[`lm_demo.py`](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/lm_demo.py). - You can see the pre-configured layouts provided by LIT, as well as the list of modules that can be included in your custom layout in [`layout.py`](https://github.com/PAIR-code/lit/blob/main/lit_nlp/api/layout.py). A diff --git a/website/sphinx_src/demos.md b/website/sphinx_src/demos.md index 8617519d..89812af9 100644 --- a/website/sphinx_src/demos.md +++ b/website/sphinx_src/demos.md @@ -82,21 +82,6 @@ Tip: check out the in-depth walkthrough at https://ai.google.dev/responsible/model_behavior, part of the Responsible Generative AI Toolkit. -## Language Modeling - -### BERT and GPT-2 - -**Hosted instance:** https://pair-code.github.io/lit/demos/lm.html \ -**Code:** [examples/lm_demo.py](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/lm_demo.py) - -* Compare multiple BERT and GPT-2 models side-by-side on a variety of - plain-text corpora. -* LM visualization supports different modes: - * BERT masked language model: click-to-mask, and query model at that - position. - * GPT-2 shows left-to-right hypotheses for each target token. -* Embedding projector to show latent space of the model. - -------------------------------------------------------------------------------- ## Structured Prediction diff --git a/website/sphinx_src/docker.md b/website/sphinx_src/docker.md index f83a6f14..c6b501d4 100644 --- a/website/sphinx_src/docker.md +++ b/website/sphinx_src/docker.md @@ -23,8 +23,7 @@ the WSGI app to serve. The options provided to gunicorn for our use-case can be found in [`gunicorn_config.py`](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/gunicorn_config.py). 
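As a rough sketch of what such a reference implementation provides, the `get_wsgi_app()` entry point that gunicorn serves looks approximately like the one this patch removes from `lm_demo.py`; the empty model and dataset dicts below are placeholders for your own demo's components, not part of any shipped example.

```python
# Minimal, illustrative sketch of the get_wsgi_app() pattern used by
# container-hosted LIT demos (mirrors the implementation removed from
# lm_demo.py above; glue/demo.py keeps the same structure).
import sys
from collections.abc import Sequence
from typing import Optional

from absl import flags
from absl import logging
from lit_nlp import dev_server
from lit_nlp import server_flags

FLAGS = flags.FLAGS


def get_wsgi_app() -> Optional[dev_server.LitServerType]:
  """Return WSGI app for container-hosted demos."""
  FLAGS.set_default("server_type", "external")
  FLAGS.set_default("demo_mode", True)
  # Parse flags without calling app.run(main), to avoid conflict with
  # gunicorn command-line flags.
  unused = flags.FLAGS(sys.argv, known_only=True)
  if unused:
    logging.info("get_wsgi_app() called with unused args: %s", unused)
  return main([])


def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
  # Build your models and datasets here; the empty dicts are placeholders.
  lit_demo = dev_server.Server({}, {}, **server_flags.get_flags())
  return lit_demo.serve()
```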
You can find a reference implementation in -[`glue/demo.py`](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/glue/demo.py) or -[`lm_demo.py`](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/lm_demo.py). +[`glue/demo.py`](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/glue/demo.py). Use the following shell commands to build the diff --git a/website/sphinx_src/frontend_development.md b/website/sphinx_src/frontend_development.md index a51aed81..6f399a43 100644 --- a/website/sphinx_src/frontend_development.md +++ b/website/sphinx_src/frontend_development.md @@ -76,8 +76,8 @@ pre-configured layouts in You can also add [custom layouts](./api.md#customizing-the-layout) to your LIT instance by defining one or more `LitCanonicalLayout` instances and passing them -to the server. For an example, see `CUSTOM_LAYOUTS` in -[`lm_demo.py`](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/lm_demo.py). +to the server. For an example, see +[`prompt_debugging/layouts.py`](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/prompt_debugging/layouts.py). Note: The pre-configured layouts are added to every `LitApp` instance using [dictionary updates](https://docs.python.org/3/library/stdtypes.html#dict) where