From b8d4dff162b6768af3a59440bf7d6ca80bf4e530 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Thu, 16 Nov 2023 12:12:40 -0500 Subject: [PATCH] Refactor to use WandaLayerCompressor Update WrappedGPT --- src/sparseml/modifiers/pruning/wanda/base.py | 3 +- .../modifiers/pruning/wanda/pytorch.py | 227 +++++++++++------- .../modifiers/pruning/wanda/utils/__init__.py | 2 - .../modifiers/pruning/wanda/utils/helpers.py | 83 ------- .../pruning/wanda/utils/layer_compressor.py | 197 +++++++++++++++ .../pruning/wanda/utils/wrapped_gpt.py | 122 +++++++++- 6 files changed, 453 insertions(+), 181 deletions(-) delete mode 100644 src/sparseml/modifiers/pruning/wanda/utils/helpers.py create mode 100644 src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py diff --git a/src/sparseml/modifiers/pruning/wanda/base.py b/src/sparseml/modifiers/pruning/wanda/base.py index 45efb14b2af..35231d59ec6 100644 --- a/src/sparseml/modifiers/pruning/wanda/base.py +++ b/src/sparseml/modifiers/pruning/wanda/base.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import List, Union +from typing import List, Optional, Union from sparseml.core import Modifier from sparseml.core.model.base import ModifiableModel @@ -45,6 +45,7 @@ class WandaPruningModifier(Modifier): sparsity: Union[float, List[float]] mask_structure: str = "0:0" targets: Union[str, List[str], None] = ALL_TOKEN + compressible_layers_: Optional[List] = None def on_initialize_structure(self, state: State, **kwargs): """ diff --git a/src/sparseml/modifiers/pruning/wanda/pytorch.py b/src/sparseml/modifiers/pruning/wanda/pytorch.py index 4262738b2ec..21544a9699b 100644 --- a/src/sparseml/modifiers/pruning/wanda/pytorch.py +++ b/src/sparseml/modifiers/pruning/wanda/pytorch.py @@ -13,16 +13,15 @@ # limitations under the License. 
import logging +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch +from sparseml.core.model.base import ModifiableModel from sparseml.core.state import State +from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier -from sparseml.modifiers.pruning.wanda.utils.helpers import ( - find_layers, - prepare_calibration_input, -) -from sparseml.modifiers.pruning.wanda.utils.wrapped_gpt import WrappedGPT +from sparseml.modifiers.pruning.wanda.utils.layer_compressor import WandaLayerCompressor _LOGGER = logging.getLogger(__name__) @@ -33,92 +32,148 @@ class WandaPruningModifierPyTorch(WandaPruningModifier): PyTorch implementation of WandaPruningModifier """ + model: Optional[ModifiableModel] = None + device_: str = "cuda:0" + layer_prefix_: Optional[str] = None + prunen_: Optional[int] = None + prunem_: Optional[int] = None + def on_initialize(self, state: State, **kwargs) -> bool: - modifiable_model = state.model - pytorch_model = modifiable_model.model - use_cache = pytorch_model.config.use_cache - - # set use_cache to False to avoid OOM - pytorch_model.config.use_cache = False - - _LOGGER.info("Preparing calibration data") - calibration_dataloader = state.data.calib - device = state.hardware.device - pytorch_model.to(device) - with torch.no_grad(): - inps, outs, attention_mask, position_ids = prepare_calibration_input( - pytorch_model, calibration_dataloader, device - ) + """ + Initialize and run the WANDA algorithm on the current state - layers = pytorch_model.model.layers - for i in range(len(layers)): - layer = layers[i] - subset = find_layers(layer) - wrapped_layers = {} - for name in subset: - wrapped_layers[name] = WrappedGPT(subset[name]) - - def add_batch(name): - def tmp(_, inp, out): - wrapped_layers[name].add_batch(inp[0].data, out.data) - - return tmp - - handles = [] - for name in wrapped_layers: - handles.append(subset[name].register_forward_hook(add_batch(name))) - for j in range(len(calibration_dataloader)): - with torch.no_grad(): - outs[j] = layer( - inps[j].unsqueeze(0), - attention_mask=attention_mask, - position_ids=position_ids, - )[0] - for h in handles: - h.remove() - if self.mask_structure == "unstructured": - prune_n = prune_m = 0 - else: - prune_n, prune_m = tuple(map(int, self.mask_structure.split(":"))) - - for name in subset: - _LOGGER.info(f"pruning layer {i} name {name}") - W_metric = torch.abs(subset[name].weight.data) * torch.sqrt( - wrapped_layers[name].scaler_row.reshape((1, -1)) - ) + :param state: session state storing input model and calibration data + """ + self._validate_layerwise_sparsity() + + self.initialize_wanda(state, **kwargs) - W_mask = ( - torch.zeros_like(W_metric) == 1 - ) # initialize a mask to be all False - if prune_n != 0: - # structured n:m sparsity - for ii in range(W_metric.shape[1]): - if ii % prune_m == 0: - tmp = W_metric[:, ii : (ii + prune_m)].float() - W_mask.scatter_( - 1, - ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1], - True, - ) - else: - sort_res = torch.sort(W_metric, dim=-1, stable=True) - indices = sort_res[1][:, : int(W_metric.shape[1] * self.sparsity)] - W_mask.scatter_(1, indices, True) - - subset[name].weight.data[W_mask] = 0 # set weights to zero - - for j in range(len(calibration_dataloader)): - with torch.no_grad(): - outs[j] = layer( - inps[j].unsqueeze(0), - attention_mask=attention_mask, - position_ids=position_ids, - )[0] - inps, outs = outs, inps - - pytorch_model.config.use_cache = use_cache 
+        # run wanda on calibration data
+        self.apply_wanda(dataloader=state.data.calib)
         torch.cuda.empty_cache()
 
         return True
 
+    def initialize_wanda(self, state: State, **kwargs):
+        """
+        Setup for WANDA: initializes the model, device,
+        and other parameters, and also initializes the
+        compressible layers of the model
+
+        :param state: session state storing input model and calibration data
+        """
+        self.model = state.model
+        self.compressible_layers_ = self.compressible_layers()
+        self._set_device(device=state.hardware.device)
+        self.layer_prefix_ = self.model.layer_prefix
+        self._infer_mask_block_size()
+
+    @torch.no_grad()
+    def apply_wanda(
+        self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None
+    ) -> Dict:
+        """
+        Run Wanda on the loaded model, using dataloader as calibration data
+
+        :param dataloader: calibration data for WANDA
+        """
+        accum_kwargs = {"dataloader": dataloader}
+        pytorch_model = self.model.model
+
+        # Step 0: Pass the calibration data through the (compressed) bottom part of the
+        # network, capturing the outputs which will become the inputs to the first
+        # decoder layer. Also return attention_mask as part of kwargs
+        extras = self.compress_bottom(
+            dev=self.device_,
+            layer_prefix=self.layer_prefix_,
+            **accum_kwargs,
+        )
+        accum_kwargs.update(extras)
+
+        # Step 1: Sequentially prune decoder layers
+        inputs = None
+        num_layers = len(self.compressible_layers_)
+        for idx, layer in enumerate(self.compressible_layers_):
+            if "outputs" not in accum_kwargs:
+                raise RuntimeError(
+                    "The 'outputs' key is expected but not found from the "
+                    "return of the bottom compressor"
+                )
+
+            inputs = accum_kwargs["outputs"]
+            layer_sparsity = (
+                self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity
+            )
+            _LOGGER.info(
+                f"\n===== Compressing layer {idx+1}/{num_layers} "
+                f"to sparsity {layer_sparsity} ====="
+            )
+            args = {
+                "sparsity": layer_sparsity,
+                "prunen": self.prunen_,
+                "prunem": self.prunem_,
+            }
+            # Prune the layer using a WandaLayerCompressor
+            layer_compressor = WandaLayerCompressor(
+                model=pytorch_model,
+                layer=layer,
+                layer_index=idx,
+                inputs=inputs,
+                args=args,
+            )
+            layer_kwargs = layer_compressor.compress(dev=self.device_, **accum_kwargs)
+            accum_kwargs.update(layer_kwargs)
+
+    def compress_bottom(
+        self,
+        dataloader: List = None,
+        nsamples: int = None,
+        dev: str = "cuda:0",
+        layer_prefix: Optional[str] = None,
+    ) -> Dict:
+        """
+        Runs calibration data through the bottom part of the network (everything up
+        to the first decoder layer) and returns the captured outputs
+
+        :param dataloader: calibration data to pass through the model
+        :param nsamples: number of samples to use for calibration, or None to use it all
+        :param dev: device to use
+        :param layer_prefix: name of model attribute that contains the list of layers,
+            i.e.
model.decoder for OPT or just model for Llama + :return: outputs from bottom part of network, attention mask, and kv-cache state + """ + layer_prefix = layer_prefix or self.layer_prefix_ + cached_inputs = cache_attention_inputs( + model=self.model.model, + dataloader=dataloader, + device=dev, + nsamples=nsamples, + target_ids=None, + layer_prefix=layer_prefix, + ) + + outputs = cached_inputs.pop("inputs") + outputs = [o[0] for o in outputs] + cached_inputs.update({"outputs": outputs}) + return cached_inputs + def on_finalize(self, state: State, **kwargs): return True + + def _set_device(self, device: str): + if "cuda" in device and not torch.cuda.is_available(): + self.device_ = "cpu" + else: + self.device_ = device + + def _infer_mask_block_size(self): + """ + Infer the mask block size from the mask structure. + Parses mask_structure of the form N:M where N, M are integers that + define a custom block shape; and sets prunen_ and prunem_ accordingly. + + :post-condition: prunen_ and prunem_ are set + """ + if self.mask_structure is None: + raise ValueError("mask_structure must be defined") + + self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":"))) diff --git a/src/sparseml/modifiers/pruning/wanda/utils/__init__.py b/src/sparseml/modifiers/pruning/wanda/utils/__init__.py index 5301319d194..ebdf28a6d5b 100644 --- a/src/sparseml/modifiers/pruning/wanda/utils/__init__.py +++ b/src/sparseml/modifiers/pruning/wanda/utils/__init__.py @@ -13,5 +13,3 @@ # limitations under the License. # flake8: noqa -from .helpers import * -from .wrapped_gpt import * diff --git a/src/sparseml/modifiers/pruning/wanda/utils/helpers.py b/src/sparseml/modifiers/pruning/wanda/utils/helpers.py deleted file mode 100644 index fe8466b84cd..00000000000 --- a/src/sparseml/modifiers/pruning/wanda/utils/helpers.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch - - -def prepare_calibration_input(model, dataloader, device): - use_cache = model.config.use_cache - model.config.use_cache = False - layers = model.model.layers - - # dev = model.hf_device_map["model.embed_tokens"] - # if "model.embed_tokens" in model.hf_device_map: - # device = model.hf_device_map["model.embed_tokens"] - - dtype = next(iter(model.parameters())).dtype - inps = torch.zeros( - (128, model.seqlen, model.config.hidden_size), dtype=dtype, device=device - ) - inps.requires_grad = False - cache = {"i": 0, "attention_mask": None, "position_ids": None} - - class Catcher(torch.nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, inp, **kwargs): - inps[cache["i"]] = inp - cache["i"] += 1 - cache["attention_mask"] = kwargs["attention_mask"] - cache["position_ids"] = kwargs["position_ids"] - raise ValueError - - layers[0] = Catcher(layers[0]) - for batch in dataloader: - try: - model(batch[0].to(device)) - except ValueError: - pass - layers[0] = layers[0].module - - outs = torch.zeros_like(inps) - attention_mask = cache["attention_mask"] - position_ids = cache["position_ids"] - model.config.use_cache = use_cache - - return inps, outs, attention_mask, position_ids - - -def find_layers(module, layers=[torch.nn.Linear], name=""): - """ - Recursively find the layers of a certain type in a module. - - Args: - module (torch.nn.Module): PyTorch module. - layers (list): List of layer types to find. - name (str): Name of the module. - - Returns: - dict: Dictionary of layers of the given type(s) within the module. - """ - if type(module) in layers: - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update( - find_layers( - child, layers=layers, name=name + "." + name1 if name != "" else name1 - ) - ) - return res diff --git a/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py b/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py new file mode 100644 index 00000000000..08dfe36975c --- /dev/null +++ b/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py @@ -0,0 +1,197 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import logging +from typing import Dict, List + +import torch +from torch.nn import Module + +from sparseml.modifiers.pruning.wanda.utils.wrapped_gpt import WrappedGPT +from sparseml.pytorch.utils.helpers import get_dependency_order +from sparseml.utils.pytorch.module import get_prunable_layers + + +__all__ = ["WandaLayerCompressor"] + +_LOGGER = logging.getLogger(__name__) + + +class WandaLayerCompressor: + """ + Runs the Wanda algorithm on a single layer using calibration data inputs + + Lifecycle: + - compress + - pre_compress_parallel (optional) + - add_batch + - fasterprune + - post_compress + + :param model: model containing the layer we are running compression on + :param layer: layer to run compression on + :param layer_index: index of layer in the model + :param inputs: calibration data to pass through the layer + :param args: additional keyword arguments + """ + + def __init__( + self, model: Module, layer: Module, layer_index: int, inputs: List, args: Dict + ): + self.model = model + self.layer = layer + self.layer_index = layer_index + self.inputs = inputs + self.args = args + + def compressible_modules(self) -> Dict: + """ + Get the list of modules in the layer that can be compressed + + :return: dictionary of compressible modules + """ + compressible_layers = get_prunable_layers(self.layer) + return compressible_layers + + def pre_compress_parallel(self, **kwargs) -> Dict: + """ + Sets up the WrappedGPT objects for each compressible module, + computes the statistics for each using the calibration data. + + :return: WrappedGPT objects for each module + """ + subset = self.compressible_modules() + + gpts = {} + for name in subset: + gpts[name] = WrappedGPT(subset[name]) + + def add_batch(name): + def tmp(_, inp, out): + gpts[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in gpts: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + # Run through the samples in order to compute statistics for each module + nsamples = len(self.inputs) + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + for sample_idx in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][sample_idx] + else: + passed_in_kwargs[arg] = kwargs[arg] + self.layer(self.inputs[sample_idx], **passed_in_kwargs) + for h in handles: + h.remove() + + return {"gpts": gpts} + + def compress(self, dev: str = "cuda:0", **kwargs) -> Dict: + """ + Run WANDA compression on all compressible modules in the layer + + :param dev: device to run computation on + """ + self.layer.to(dev) + self.sequentially_compress(**kwargs) + extras = self.post_compress(**kwargs) + return {"outputs": extras["outputs"]} + + def post_compress(self, **kwargs) -> Dict: + """ + Clean up after compression + + :return: outputs of the layer + """ + nsamples = len(self.inputs) + outputs = [] + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + for j in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][j] + else: + passed_in_kwargs[arg] = kwargs[arg] + outputs.append(self.layer(self.inputs[j], **passed_in_kwargs)[0]) + + self.inputs = None + # once we've finished compressing the layer, move it back to CPU to 
save memory + self.layer.to("cpu") + torch.cuda.empty_cache() + + return {"outputs": outputs} + + def sequentially_compress(self, **kwargs): + """ + Run compression module by module, in dependency order. Unlike in parallel + compression, we compute the statistics layer by layer instead of computing them + all up front. This saves on memory and means compression in earlier layers + affects the inputs to later layers + """ + subset = self.compressible_modules() + + # filter kwargs that are expected as layer inputs + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): # take the first batch + passed_in_kwargs[arg] = kwargs[arg][0] + else: + passed_in_kwargs[arg] = kwargs[arg] + order = get_dependency_order( + self.layer, subset, self.inputs[0], **passed_in_kwargs + ) + + nsamples = len(self.inputs) + for name in order: # create WrappedGPT object for each compressible module + gpts = WrappedGPT(subset[name]) + + def add_batch(name): + def tmp(_, inp, out): + gpts.add_batch(inp[0].data, out.data) + + return tmp + + # add WrappedGPT hook for current module + handle = subset[name].register_forward_hook(add_batch(name)) + for sample_idx in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][sample_idx] + else: + passed_in_kwargs[arg] = kwargs[arg] + # run layer, triggering WrappedGPT add_batch for current module + self.layer(self.inputs[sample_idx], **passed_in_kwargs) + handle.remove() + + _LOGGER.info(f"Compressing module {name} of layer {self.layer_index}") + gpts.fasterprune( # run WrappedGPT algorithm on current module + self.args["sparsity"], + prunen=self.args["prunen"], + prunem=self.args["prunem"], + ) + gpts.free() diff --git a/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py b/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py index 7b30d4f4e69..fda16110c86 100644 --- a/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py +++ b/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py @@ -12,31 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import time + import torch import torch.nn as nn +try: + import transformers +except ImportError as err: + transformers = None + transformers_err = err + __all__ = ["WrappedGPT"] +DEBUG = False +_LOGGER = logging.getLogger(__name__) + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + class WrappedGPT: """ - This class wraps a GPT layer for specific operations. 
+    Runs Wanda on a single module that contains no sub-modules
+
+    Lifecycle:
+        - add_batch
+        - fasterprune
+        - free
+
+
+    :param layer: module to run Wanda on
     """
 
-    def __init__(self, layer, layer_id=0, layer_name="none"):
+    def __init__(self, layer):
+        if transformers is None:
+            raise transformers_err
+
         self.layer = layer
         self.dev = self.layer.weight.device
-        self.rows = layer.weight.data.shape[0]
-        self.columns = layer.weight.data.shape[1]
-
-        self.scaler_row = torch.zeros((self.columns), device=self.dev)
+        W = layer.weight.data.clone()
+        if isinstance(self.layer, nn.Conv2d):
+            W = W.flatten(1)
+        if isinstance(self.layer, transformers.Conv1D):
+            W = W.t()
+        self.rows = W.shape[0]
+        self.columns = W.shape[1]
         self.nsamples = 0
+        self.scaler_row = torch.zeros((self.columns), device=self.dev)
 
-        self.layer_id = layer_id
-        self.layer_name = layer_name
+    def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
+        """
+        Add a batch of layer input and output data to the layer
+        statistics calculation
 
-    def add_batch(self, inp, out):
+        :param inp: tensor containing layer input
+        :param out: tensor containing layer output
+        """
+        if DEBUG:
+            self._inp1 = inp
+            self.out1 = out
         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)
         tmp = inp.shape[0]
@@ -50,3 +88,69 @@ def add_batch(self, inp, out):
 
         inp = inp.type(torch.float32)
         self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples
+
+    def fasterprune(
+        self,
+        sparsity: float,
+        prunen: int = 0,
+        prunem: int = 0,
+    ):
+        """
+        Run pruning on the layer up to the target
+        sparsity value.
+
+        :param sparsity: target sparsity to reach for layer
+        :param prunen: N for N:M pruning
+        :param prunem: M for N:M pruning
+        """
+        W = self.layer.weight.data.clone()
+        if isinstance(self.layer, nn.Conv2d):
+            W = W.flatten(1)
+        if isinstance(self.layer, transformers.Conv1D):
+            W = W.t()
+        W = W.float()
+
+        tick = time.time()
+
+        W_metric = torch.abs(W) * torch.sqrt(self.scaler_row.reshape((1, -1)))
+
+        # initialize a mask to be all False
+        W_mask = torch.zeros_like(W_metric) == 1
+        if prunen != 0:
+            # structured n:m sparsity
+            for ii in range(W_metric.shape[1]):
+                if ii % prunem == 0:
+                    tmp = W_metric[:, ii : (ii + prunem)].float()
+                    W_mask.scatter_(
+                        1,
+                        ii + torch.topk(tmp, prunen, dim=1, largest=False)[1],
+                        True,
+                    )
+        else:
+            sort_res = torch.sort(W_metric, dim=-1, stable=True)
+            indices = sort_res[1][:, : int(W_metric.shape[1] * sparsity)]
+            W_mask.scatter_(1, indices, True)
+
+        W[W_mask] = 0  # set weights to zero
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        _LOGGER.info("time %.2f" % (time.time() - tick))
+
+        if isinstance(self.layer, transformers.Conv1D):
+            W = W.t()
+        self.layer.weight.data = W.reshape(self.layer.weight.shape).to(
+            self.layer.weight.data.dtype
+        )
+        if DEBUG:
+            _LOGGER.debug(torch.sum((self.layer(self._inp1) - self.out1) ** 2))
+
+    def free(self):
+        """
+        Free memory after the layer is complete
+        """
+        if DEBUG:
+            self._inp1 = None
+            self.out1 = None
+        self.scaler_row = None
+        torch.cuda.empty_cache()
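
Usage sketch (illustrative only, not part of the patch): the snippet below exercises the WrappedGPT interface introduced above on a single nn.Linear with random calibration batches. It assumes this branch of sparseml plus the transformers package are installed; the layer size and data shapes are arbitrary.

    import torch
    import torch.nn as nn

    from sparseml.modifiers.pruning.wanda.utils.wrapped_gpt import WrappedGPT

    layer = nn.Linear(64, 32)
    wrapped = WrappedGPT(layer)

    # accumulate per-column activation statistics (scaler_row) from calibration batches
    for _ in range(4):
        x = torch.randn(2, 16, 64)  # (batch, seq_len, hidden); values are arbitrary
        wrapped.add_batch(x, layer(x))

    # prune to 50% unstructured sparsity using the Wanda metric |W| * sqrt(scaler_row)
    wrapped.fasterprune(sparsity=0.5, prunen=0, prunem=0)
    wrapped.free()

    print((layer.weight == 0).float().mean())  # roughly 0.5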
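The N:M branch of fasterprune (and the equivalent loop removed from pytorch.py) zeroes out the prunen lowest-metric weights in every group of prunem consecutive columns. A minimal standalone sketch of that selection, using only torch (nm_prune_mask is a hypothetical helper, not part of sparseml):

    import torch


    def nm_prune_mask(metric: torch.Tensor, n: int, m: int) -> torch.Tensor:
        # mirrors the W_mask construction in WrappedGPT.fasterprune:
        # mark the n lowest-metric weights in every group of m columns for pruning
        mask = torch.zeros_like(metric, dtype=torch.bool)
        for start in range(0, metric.shape[1], m):
            block = metric[:, start : start + m]
            # indices of the n smallest entries per row within this column block
            prune_idx = torch.topk(block, n, dim=1, largest=False)[1]
            mask.scatter_(1, start + prune_idx, True)
        return mask


    mask = nm_prune_mask(torch.rand(3, 8), n=2, m=4)  # prune 2 of every 4 columns
    assert mask.sum(dim=1).eq(4).all()  # 8 columns -> 2 groups -> 4 pruned per row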