Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Fix transformer pruning example (#4002)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaowu0162 authored Aug 3, 2021
1 parent dfd853d commit e5b4bf1
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 58 deletions.
33 changes: 19 additions & 14 deletions docs/en_US/Compression/Pruner.rst
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ Transformer Head Pruner is a tool designed for pruning attention heads from the

Typically, each attention layer in the Transformer models consists of four weights: three projection matrices for query, key, value, and an output projection matrix. The outputs of the former three matrices contains the projected results for all heads. Normally, the results are then reshaped so that each head performs that attention computation independently. The final results are concatenated back before fed into the output projection. Therefore, when an attention head is pruned, the same weights corresponding to that heads in the three projection matrices are pruned. Also, the weights in the output projection corresponding to the head's output are pruned. In our implementation, we calculate and apply masks to the four matrices together.

Note: currently, the pruner can only handle models with projection weights written as separate ``Linear`` modules, i.e., it expects four ``Linear`` modules corresponding to query, key, value, and an output projections. Therefore, in the ``config_list``, you should either write ``['Linear']`` for the ``op_types`` field, or write names corresponding to ``Linear`` modules for the ``op_names`` field.
Note: currently, the pruner can only handle models with projection weights written as separate ``Linear`` modules, i.e., it expects four ``Linear`` modules corresponding to query, key, value, and an output projections. Therefore, in the ``config_list``, you should either write ``['Linear']`` for the ``op_types`` field, or write names corresponding to ``Linear`` modules for the ``op_names`` field. For instance, the `Huggingface transformers <https://huggingface.co/transformers/index.html>`_ are supported, but ``torch.nn.Transformer`` is not.

The pruner implements the following algorithm:

Expand All @@ -756,11 +756,9 @@ Currently, the following head sorting criteria are supported:
* "l2_activation": rank heads by the L2-norm of their attention computation output.
* "taylorfo": rank heads by l1 norm of the output of attention computation * gradient for this output. Check more details in `this paper <https://arxiv.org/abs/1905.10650>`__ and `this one <https://arxiv.org/abs/1611.06440>`__.
We support local sorting (i.e., sorting heads within a layer) and global sorting (sorting all heads together), and you can control by setting the ``global_sort`` parameter. Note that if ``global_sort=True`` is passed, all weights must have the same sparsity in the config list. However, this does not mean that each layer will be prune to the same sparsity as specified. This sparsity value will be interpreted as a global sparsity, and each layer is likely to have different sparsity after pruning by global sort.
We support local sorting (i.e., sorting heads within a layer) and global sorting (sorting all heads together), and you can control by setting the ``global_sort`` parameter. Note that if ``global_sort=True`` is passed, all weights must have the same sparsity in the config list. However, this does not mean that each layer will be prune to the same sparsity as specified. This sparsity value will be interpreted as a global sparsity, and each layer is likely to have different sparsity after pruning by global sort. As a reminder, we found that if global sorting is used, it is usually helpful to use an iterative pruning scheme, interleaving pruning with intermediate finetuning, since global sorting often results in non-uniform sparsity distributions, which makes the model more susceptible to forgetting.
In our implementation, we support two ways to group the four weights in the same layer together. You can either pass a nested list containing the names of these modules as the pruner's initialization parameters (usage below), or simply pass a dummy input and the pruner will run ``torch.jit.trace`` to group the weights (experimental feature). However, if you would like to assign different sparsity to each layer, you can only use the first option, i.e., passing names of the weights to the pruner (see usage below). Also note that weights belonging to the same layer must have the same sparsity.
In addition to the following usage guide, we provide a more detailed example of pruning BERT for tasks from the GLUE benchmark. Please find it in this :githublink:`page <examples/model_compress/pruning/transformers>`.
In our implementation, we support two ways to group the four weights in the same layer together. You can either pass a nested list containing the names of these modules as the pruner's initialization parameters (usage below), or simply pass a dummy input instead and the pruner will run ``torch.jit.trace`` to group the weights (experimental feature). However, if you would like to assign different sparsity to each layer, you can only use the first option, i.e., passing names of the weights to the pruner (see usage below). Also, note that we require the weights belonging to the same layer to have the same sparsity.
Usage
^^^^^
Expand All @@ -786,6 +784,7 @@ Suppose we want to prune a BERT with Huggingface implementation, which has the f
["encoder.layer.{}.attention.self.key".format(i) for i in range(12)],
["encoder.layer.{}.attention.self.value".format(i) for i in range(12)],
["encoder.layer.{}.attention.output.dense".format(i) for i in range(12)]))
kwargs = {"ranking_criterion": "l1_weight",
"global_sort": False,
"num_iterations": 1,
Expand All @@ -796,20 +795,26 @@ Suppose we want to prune a BERT with Huggingface implementation, which has the f
"optimizer": optimizer,
"forward_runner": forward_runner
}
config_list = [{
"sparsity": 0.5,
config_list = [{
"sparsity": 0.5,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[:6] for x in layer] # first six layers
},
{
"sparsity": 0.25,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[6:] for x in layer] # last six layers
}
]
}, {
"sparsity": 0.25,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[6:] for x in layer] # last six layers
}]
pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()
In addition to this usage guide, we provide a more detailed example of pruning BERT (Huggingface implementation) for transfer learning on the tasks from the `GLUE benchmark <https://gluebenchmark.com/>`_. Please find it in this :githublink:`page <examples/model_compress/pruning/transformers>`. To run the example, first make sure that you install the package ``transformers`` and ``datasets``. Then, you may start by running the following command:
.. code-block:: bash
./run.sh gpu_id glue_task
By default, the code will download a pretrained BERT language model, and then finetune for several epochs on the downstream GLUE task. Then, the ``TransformerHeadPruner`` will be used to prune out heads from each layer by a certain criterion (by default, the code lets the pruner uses magnitude ranking, and prunes out 50% of the heads in each layer in an one-shot manner). Finally, the pruned model will be finetuned in the downstream task for several epochs. You can check the details of pruning from the logs printed out by the example. You can also experiment with different pruning settings by changing the parameters in ``run.sh``, or directly changing the ``config_list`` in ``transformer_pruning.py``.
User configuration for Transformer Head Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
91 changes: 47 additions & 44 deletions examples/model_compress/pruning/transformers/transformer_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,24 @@

import argparse
import logging
import math
import os
import random

import torch
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm

import nni
from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.utils.counter import count_flops_params
from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner


import datasets
from datasets import load_dataset, load_metric
import transformers
from transformers import (
AdamW,
AutoConfig,
AutoModel,
AutoModelForPreTraining,
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
PretrainedConfig,
default_data_collator,
get_scheduler,
)

Expand All @@ -38,7 +29,8 @@


def parse_args():
parser = argparse.ArgumentParser(description="Example: prune a Huggingface transformer and finetune on GLUE tasks.")
parser = argparse.ArgumentParser(
description="Example: prune a Huggingface transformer and finetune on GLUE tasks.")

parser.add_argument("--model_name", type=str, required=True,
help="Pretrained model architecture.")
Expand All @@ -53,7 +45,8 @@ def parse_args():
help="Rank the heads globally and prune the heads with lowest scores. If set to False, the "
"heads are only ranked within one layer")
parser.add_argument("--ranking_criterion", type=str, default="l1_weight",
choices=["l1_weight", "l2_weight", "l1_activation", "l2_activation", "taylorfo"],
choices=["l1_weight", "l2_weight",
"l1_activation", "l2_activation", "taylorfo"],
help="Criterion by which the attention heads are ranked.")
parser.add_argument("--num_iterations", type=int, default=1,
help="Number of pruning iterations (1 for one-shot pruning).")
Expand Down Expand Up @@ -93,7 +86,8 @@ def get_raw_dataset(task_name):
"""
raw_dataset = load_dataset("glue", task_name)
is_regression = task_name == "stsb"
num_labels = 1 if is_regression else len(raw_dataset["train"].features["label"].names)
num_labels = 1 if is_regression else len(
raw_dataset["train"].features["label"].names)

return raw_dataset, is_regression, num_labels

Expand All @@ -105,29 +99,32 @@ def preprocess(args, tokenizer, raw_dataset):
assert args.task_name is not None

task_to_keys = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
sentence1_key, sentence2_key = task_to_keys[args.task_name]

def tokenize(data):
texts = (
(data[sentence1_key],) if sentence2_key is None else (data[sentence1_key], data[sentence2_key])
(data[sentence1_key],) if sentence2_key is None else (
data[sentence1_key], data[sentence2_key])
)
result = tokenizer(*texts, padding=False, max_length=args.max_length, truncation=True)
result = tokenizer(*texts, padding=False,
max_length=args.max_length, truncation=True)

if "label" in data:
result["labels"] = data["label"]
return result

processed_datasets = raw_dataset.map(tokenize, batched=True, remove_columns=raw_dataset["train"].column_names)
processed_datasets = raw_dataset.map(
tokenize, batched=True, remove_columns=raw_dataset["train"].column_names)
return processed_datasets


Expand Down Expand Up @@ -168,7 +165,8 @@ def train_model(args, model, is_regression, train_dataloader, eval_dataloader, o
for field in batch.keys():
batch[field] = batch[field].to(device)
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
predictions = outputs.logits.argmax(dim=-1) if not is_regression \
else outputs.logits.squeeze()
metric.add_batch(predictions=predictions, references=batch["labels"])

eval_metric = metric.compute()
Expand All @@ -183,7 +181,7 @@ def trainer_helper(model, train_dataloader, optimizer, device):
"""
logger.info("Training for 1 epoch...")
progress_bar = tqdm(range(len(train_dataloader)), position=0, leave=True)

train_epoch = 1
for epoch in range(train_epoch):
for step, batch in enumerate(train_dataloader):
Expand Down Expand Up @@ -213,7 +211,7 @@ def forward_runner_helper(model, train_dataloader, device):
_ = model(**batch)
# note: no loss.backward or optimizer.step() is performed here
progress_bar.update(1)


def final_eval_for_mnli(args, model, processed_datasets, metric, data_collator):
"""
Expand Down Expand Up @@ -248,15 +246,18 @@ def main():
transformers.utils.logging.set_verbosity_info()

# Load dataset and tokenizer, and then preprocess the dataset
raw_dataset, is_regression, num_labels = get_raw_dataset(args.task_name)
raw_dataset, is_regression, num_labels = get_raw_dataset(args.task_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
processed_datasets = preprocess(args, tokenizer, raw_dataset)
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
eval_dataset = processed_datasets["validation_matched" if args.task_name ==
"mnli" else "validation"]

# Load pretrained model
config = AutoConfig.from_pretrained(args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config)
config = AutoConfig.from_pretrained(
args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name, config=config)
model.to(device)

#########################################################################
Expand All @@ -269,9 +270,10 @@ def main():
lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer, num_warmup_steps=args.num_warmup_steps,
num_training_steps=train_steps)
metric = load_metric("glue", args.task_name)

logger.info("================= Finetuning before pruning =================")
train_model(args, model, is_regression, train_dataloader, eval_dataloader, optimizer, lr_scheduler, metric, device)
train_model(args, model, is_regression, train_dataloader,
eval_dataloader, optimizer, lr_scheduler, metric, device)

if args.output_dir is not None:
torch.save(model.state_dict(), args.output_dir + "/model_before_pruning.pt")
Expand Down Expand Up @@ -316,13 +318,11 @@ def forward_runner(model):
"sparsity": args.sparsity,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[:6] for x in layer]
},
{
"sparsity": args.sparsity / 2,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[6:] for x in layer]
}
]
}, {
"sparsity": args.sparsity / 2,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[6:] for x in layer]
}]

pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()
Expand Down Expand Up @@ -360,21 +360,24 @@ def forward_runner(model):
# After pruning, finetune again on the target task
# Get the metric function
metric = load_metric("glue", args.task_name)

# re-initialize the optimizer and the scheduler
optimizer, _, _, data_collator = get_dataloader_and_optimizer(args, tokenizer, model, train_dataset,
eval_dataset)
lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer, num_warmup_steps=args.num_warmup_steps,
num_training_steps=train_steps)

logger.info("================= Finetuning after Pruning =================")
train_model(args, model, is_regression, train_dataloader, eval_dataloader, optimizer, lr_scheduler, metric, device)
train_model(args, model, is_regression, train_dataloader,
eval_dataloader, optimizer, lr_scheduler, metric, device)

if args.output_dir is not None:
torch.save(model.state_dict(), args.output_dir + "/model_after_pruning.pt")
torch.save(model.state_dict(), args.output_dir +
"/model_after_pruning.pt")

if args.task_name == "mnli":
final_eval_for_mnli(args, model, processed_datasets, metric, data_collator)
final_eval_for_mnli(args, model, processed_datasets,
metric, data_collator)

flops, params, results = count_flops_params(model, dummy_input)
print(f"Final model FLOPs {flops / 1e6:.2f} M, #Params: {params / 1e6:.2f}M")
Expand Down

0 comments on commit e5b4bf1

Please sign in to comment.