From 9b7cc6eb9b99bd56d19d04959442436360bd8a69 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 15 Nov 2023 07:38:36 -0500 Subject: [PATCH 1/9] Add wanda base --- .../modifiers/pruning/wanda/__init__.py | 13 +++++ src/sparseml/modifiers/pruning/wanda/base.py | 48 +++++++++++++++++++ tests/sparseml/modifiers/pruning/__init__.py | 13 +++++ .../modifiers/pruning/wanda/__init__.py | 13 +++++ .../modifiers/pruning/wanda/test_base.py | 40 ++++++++++++++++ 5 files changed, 127 insertions(+) create mode 100644 src/sparseml/modifiers/pruning/wanda/__init__.py create mode 100644 src/sparseml/modifiers/pruning/wanda/base.py create mode 100644 tests/sparseml/modifiers/pruning/__init__.py create mode 100644 tests/sparseml/modifiers/pruning/wanda/__init__.py create mode 100644 tests/sparseml/modifiers/pruning/wanda/test_base.py diff --git a/src/sparseml/modifiers/pruning/wanda/__init__.py b/src/sparseml/modifiers/pruning/wanda/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/pruning/wanda/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/modifiers/pruning/wanda/base.py b/src/sparseml/modifiers/pruning/wanda/base.py new file mode 100644 index 00000000000..d1934eb4d69 --- /dev/null +++ b/src/sparseml/modifiers/pruning/wanda/base.py @@ -0,0 +1,48 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
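+
+# WANDA ("Pruning by Weights and activations") is a one-shot pruning method:
+# each weight is scored by the product of its magnitude and the norm of the
+# input activation feeding it, and the lowest-scoring weights are zeroed out,
+# with no weight update or retraining required.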
+ + +from typing import List, Union + +from sparseml.core import Modifier +from sparseml.core.state import State +from sparseml.utils import ALL_TOKEN + + +__all__ = ["WandaPruningModifier"] + + +class WandaPruningModifier(Modifier): + """ + Modifier for applying the one-shot WANDA algorithm to a model + from the paper: https://arxiv.org/abs/2306.11695 + """ + + sparsity: Union[float, List[float]] + block_size: int + targets: Union[str, List[str], None] = ALL_TOKEN + mask_structure: str = "unstructured" + + def on_initialize_structure(self, state: State, **kwargs): + pass # nothing needed for this modifier + + def compressible_layers(self) -> List: + """ + Retrieves the modules corresponding to a list of + compressible layer names + + :return: list of Pytorch modules to compress + """ + compressible_dict = self.model.get_layers(self.targets) + return [v for _, v in compressible_dict.items()] diff --git a/tests/sparseml/modifiers/pruning/__init__.py b/tests/sparseml/modifiers/pruning/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/modifiers/pruning/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/modifiers/pruning/wanda/__init__.py b/tests/sparseml/modifiers/pruning/wanda/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/modifiers/pruning/wanda/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/modifiers/pruning/wanda/test_base.py b/tests/sparseml/modifiers/pruning/wanda/test_base.py new file mode 100644 index 00000000000..82fad2f954a --- /dev/null +++ b/tests/sparseml/modifiers/pruning/wanda/test_base.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
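+
+# Illustrative only (placeholder values): the modifier under test corresponds
+# to a direct construction mirroring the factory kwargs used below, i.e.
+#
+#   WandaPruningModifier(
+#       sparsity=0.5,
+#       block_size=128,
+#       targets="__ALL_PRUNABLE__",
+#   )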
+ + +from sparseml.core.factory import ModifierFactory +from sparseml.core.framework import Framework +from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier +from tests.sparseml.modifiers.conf import setup_modifier_factory + + +def test_wanda_is_registered(): + + kwargs = dict( + sparsity=0.5, + block_size=128, + targets="__ALL_PRUNABLE__", + ) + setup_modifier_factory() + type_ = ModifierFactory.create( + type_="WandaPruningModifier", + framework=Framework.general, + allow_experimental=False, + allow_registered=True, + **kwargs, + ) + + assert isinstance( + type_, WandaPruningModifier + ), "PyTorch ConstantPruningModifier not registered" From 87171b7cef776a33aa4b1d200a9c05802d1671ba Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 15 Nov 2023 11:09:22 -0500 Subject: [PATCH 2/9] Initial implementation --- src/sparseml/modifiers/pruning/wanda/base.py | 3 +- .../modifiers/pruning/wanda/pytorch.py | 124 ++++++++++++++++++ .../modifiers/pruning/wanda/utils/__init__.py | 17 +++ .../modifiers/pruning/wanda/utils/helpers.py | 83 ++++++++++++ .../pruning/wanda/utils/wrapped_gpt.py | 52 ++++++++ .../modifiers/pruning/wanda/test_base.py | 1 - .../modifiers/pruning/wanda/test_pytorch.py | 39 ++++++ 7 files changed, 316 insertions(+), 3 deletions(-) create mode 100644 src/sparseml/modifiers/pruning/wanda/pytorch.py create mode 100644 src/sparseml/modifiers/pruning/wanda/utils/__init__.py create mode 100644 src/sparseml/modifiers/pruning/wanda/utils/helpers.py create mode 100644 src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py create mode 100644 tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py diff --git a/src/sparseml/modifiers/pruning/wanda/base.py b/src/sparseml/modifiers/pruning/wanda/base.py index d1934eb4d69..1debc9e5e51 100644 --- a/src/sparseml/modifiers/pruning/wanda/base.py +++ b/src/sparseml/modifiers/pruning/wanda/base.py @@ -28,9 +28,8 @@ class WandaPruningModifier(Modifier): Modifier for applying the one-shot WANDA algorithm to a model from the paper: https://arxiv.org/abs/2306.11695 """ - + sparsity: Union[float, List[float]] - block_size: int targets: Union[str, List[str], None] = ALL_TOKEN mask_structure: str = "unstructured" diff --git a/src/sparseml/modifiers/pruning/wanda/pytorch.py b/src/sparseml/modifiers/pruning/wanda/pytorch.py new file mode 100644 index 00000000000..4262738b2ec --- /dev/null +++ b/src/sparseml/modifiers/pruning/wanda/pytorch.py @@ -0,0 +1,124 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
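+
+# In outline, the implementation below follows the reference WANDA procedure:
+# calibration batches are captured at the first decoder layer, per-column input
+# statistics (scaler_row) are accumulated for every prunable module, and each
+# weight is scored as
+#
+#   W_metric[i, j] = |W[i, j]| * sqrt(scaler_row[j])
+#
+# The lowest-scoring weights in each row (or within each group of M columns for
+# an N:M mask_structure) are then set to zero to reach the requested sparsity.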
+ +import logging + +import torch + +from sparseml.core.state import State +from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier +from sparseml.modifiers.pruning.wanda.utils.helpers import ( + find_layers, + prepare_calibration_input, +) +from sparseml.modifiers.pruning.wanda.utils.wrapped_gpt import WrappedGPT + + +_LOGGER = logging.getLogger(__name__) + + +class WandaPruningModifierPyTorch(WandaPruningModifier): + """ + PyTorch implementation of WandaPruningModifier + """ + + def on_initialize(self, state: State, **kwargs) -> bool: + modifiable_model = state.model + pytorch_model = modifiable_model.model + use_cache = pytorch_model.config.use_cache + + # set use_cache to False to avoid OOM + pytorch_model.config.use_cache = False + + _LOGGER.info("Preparing calibration data") + calibration_dataloader = state.data.calib + device = state.hardware.device + pytorch_model.to(device) + with torch.no_grad(): + inps, outs, attention_mask, position_ids = prepare_calibration_input( + pytorch_model, calibration_dataloader, device + ) + + layers = pytorch_model.model.layers + for i in range(len(layers)): + layer = layers[i] + subset = find_layers(layer) + wrapped_layers = {} + for name in subset: + wrapped_layers[name] = WrappedGPT(subset[name]) + + def add_batch(name): + def tmp(_, inp, out): + wrapped_layers[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in wrapped_layers: + handles.append(subset[name].register_forward_hook(add_batch(name))) + for j in range(len(calibration_dataloader)): + with torch.no_grad(): + outs[j] = layer( + inps[j].unsqueeze(0), + attention_mask=attention_mask, + position_ids=position_ids, + )[0] + for h in handles: + h.remove() + if self.mask_structure == "unstructured": + prune_n = prune_m = 0 + else: + prune_n, prune_m = tuple(map(int, self.mask_structure.split(":"))) + + for name in subset: + _LOGGER.info(f"pruning layer {i} name {name}") + W_metric = torch.abs(subset[name].weight.data) * torch.sqrt( + wrapped_layers[name].scaler_row.reshape((1, -1)) + ) + + W_mask = ( + torch.zeros_like(W_metric) == 1 + ) # initialize a mask to be all False + if prune_n != 0: + # structured n:m sparsity + for ii in range(W_metric.shape[1]): + if ii % prune_m == 0: + tmp = W_metric[:, ii : (ii + prune_m)].float() + W_mask.scatter_( + 1, + ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1], + True, + ) + else: + sort_res = torch.sort(W_metric, dim=-1, stable=True) + indices = sort_res[1][:, : int(W_metric.shape[1] * self.sparsity)] + W_mask.scatter_(1, indices, True) + + subset[name].weight.data[W_mask] = 0 # set weights to zero + + for j in range(len(calibration_dataloader)): + with torch.no_grad(): + outs[j] = layer( + inps[j].unsqueeze(0), + attention_mask=attention_mask, + position_ids=position_ids, + )[0] + inps, outs = outs, inps + + pytorch_model.config.use_cache = use_cache + torch.cuda.empty_cache() + return True + + def on_finalize(self, state: State, **kwargs): + return True diff --git a/src/sparseml/modifiers/pruning/wanda/utils/__init__.py b/src/sparseml/modifiers/pruning/wanda/utils/__init__.py new file mode 100644 index 00000000000..5301319d194 --- /dev/null +++ b/src/sparseml/modifiers/pruning/wanda/utils/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa +from .helpers import * +from .wrapped_gpt import * diff --git a/src/sparseml/modifiers/pruning/wanda/utils/helpers.py b/src/sparseml/modifiers/pruning/wanda/utils/helpers.py new file mode 100644 index 00000000000..fe8466b84cd --- /dev/null +++ b/src/sparseml/modifiers/pruning/wanda/utils/helpers.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def prepare_calibration_input(model, dataloader, device): + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.model.layers + + # dev = model.hf_device_map["model.embed_tokens"] + # if "model.embed_tokens" in model.hf_device_map: + # device = model.hf_device_map["model.embed_tokens"] + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros( + (128, model.seqlen, model.config.hidden_size), dtype=dtype, device=device + ) + inps.requires_grad = False + cache = {"i": 0, "attention_mask": None, "position_ids": None} + + class Catcher(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inp, **kwargs): + inps[cache["i"]] = inp + cache["i"] += 1 + cache["attention_mask"] = kwargs["attention_mask"] + cache["position_ids"] = kwargs["position_ids"] + raise ValueError + + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0].to(device)) + except ValueError: + pass + layers[0] = layers[0].module + + outs = torch.zeros_like(inps) + attention_mask = cache["attention_mask"] + position_ids = cache["position_ids"] + model.config.use_cache = use_cache + + return inps, outs, attention_mask, position_ids + + +def find_layers(module, layers=[torch.nn.Linear], name=""): + """ + Recursively find the layers of a certain type in a module. + + Args: + module (torch.nn.Module): PyTorch module. + layers (list): List of layer types to find. + name (str): Name of the module. + + Returns: + dict: Dictionary of layers of the given type(s) within the module. + """ + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update( + find_layers( + child, layers=layers, name=name + "." 
+ name1 if name != "" else name1 + ) + ) + return res diff --git a/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py b/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py new file mode 100644 index 00000000000..7b30d4f4e69 --- /dev/null +++ b/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +__all__ = ["WrappedGPT"] + + +class WrappedGPT: + """ + This class wraps a GPT layer for specific operations. + """ + + def __init__(self, layer, layer_id=0, layer_name="none"): + self.layer = layer + self.dev = self.layer.weight.device + self.rows = layer.weight.data.shape[0] + self.columns = layer.weight.data.shape[1] + + self.scaler_row = torch.zeros((self.columns), device=self.dev) + self.nsamples = 0 + + self.layer_id = layer_id + self.layer_name = layer_name + + def add_batch(self, inp, out): + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + self.scaler_row *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + + inp = inp.type(torch.float32) + self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples diff --git a/tests/sparseml/modifiers/pruning/wanda/test_base.py b/tests/sparseml/modifiers/pruning/wanda/test_base.py index 82fad2f954a..8dcb682020d 100644 --- a/tests/sparseml/modifiers/pruning/wanda/test_base.py +++ b/tests/sparseml/modifiers/pruning/wanda/test_base.py @@ -23,7 +23,6 @@ def test_wanda_is_registered(): kwargs = dict( sparsity=0.5, - block_size=128, targets="__ALL_PRUNABLE__", ) setup_modifier_factory() diff --git a/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py new file mode 100644 index 00000000000..2bdca703951 --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from sparseml.core.factory import ModifierFactory +from sparseml.core.framework import Framework +from tests.sparseml.modifiers.conf import setup_modifier_factory + + +def test_wanda_pytorch_is_registered(): + from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch + + kwargs = dict( + sparsity=0.5, + targets="__ALL_PRUNABLE__", + ) + setup_modifier_factory() + type_ = ModifierFactory.create( + type_="WandaPruningModifier", + framework=Framework.pytorch, + allow_experimental=False, + allow_registered=True, + **kwargs, + ) + + assert isinstance( + type_, WandaPruningModifierPyTorch + ), "PyTorch ConstantPruningModifier not registered" From cf65bb537e4bd6951c51a79485215562084e3a76 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Thu, 16 Nov 2023 08:42:02 -0500 Subject: [PATCH 3/9] Update Wanda Base --- src/sparseml/modifiers/pruning/wanda/base.py | 59 +++++++++++++++++++- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/src/sparseml/modifiers/pruning/wanda/base.py b/src/sparseml/modifiers/pruning/wanda/base.py index 1debc9e5e51..45efb14b2af 100644 --- a/src/sparseml/modifiers/pruning/wanda/base.py +++ b/src/sparseml/modifiers/pruning/wanda/base.py @@ -16,6 +16,7 @@ from typing import List, Union from sparseml.core import Modifier +from sparseml.core.model.base import ModifiableModel from sparseml.core.state import State from sparseml.utils import ALL_TOKEN @@ -27,21 +28,73 @@ class WandaPruningModifier(Modifier): """ Modifier for applying the one-shot WANDA algorithm to a model from the paper: https://arxiv.org/abs/2306.11695 + + Life-cycle: + - initialze + - compress + - finalize + + :param sparsity: Sparsity to compress model to + :param mask_structure: String to define the structure of the mask to apply. + Must be of the form N:M where N, M are integers that define a custom block + shape. Defaults to 0:0 which represents an unstructured mask. + :param targets: list of layer names to compress during OBCQ, or '__ALL__' + to compress every layer in the model """ sparsity: Union[float, List[float]] + mask_structure: str = "0:0" targets: Union[str, List[str], None] = ALL_TOKEN - mask_structure: str = "unstructured" def on_initialize_structure(self, state: State, **kwargs): - pass # nothing needed for this modifier + """ + This modifier does not alter the model structure. + This method is a no-op. + + :param state: Unused, kept to conform to the parent method signature + :param kwargs: Unused, kept to conform to the parent method signature + """ def compressible_layers(self) -> List: """ Retrieves the modules corresponding to a list of compressible layer names - :return: list of Pytorch modules to compress + :precondition: self.model is set and is a `ModifiableModel` + :precondition: The `ModifiableModel` implements a `get_layers` + method + :return: list of modules to compress """ + if not isinstance(self.model, ModifiableModel): + raise ValueError( + "`self.model` must be a ModifiableModel to use " + f"the WANDA modifier but got {type(self.model)} instead" + ) + compressible_dict = self.model.get_layers(self.targets) return [v for _, v in compressible_dict.items()] + + def _validate_layerwise_sparsity(self): + if isinstance(self.sparsity, float): + # single sparsity will be applied to all layers + return + + if not isinstance(self.targets, List): + raise ValueError( + "Layer targets must be a list when specifying layer-wise" + f" sparsity. 
Got {type(self.targets)}" + ) + + if len(self.targets) != len(self.sparsity): + raise ValueError( + "Number of layer targets must match the number of " + f"sparsities. Got {len(self.targets)} layers and " + f"{len(self.sparsity)} sparsities" + ) + + for layer_name in self.targets: + if layer_name.startswith("re:"): + raise ValueError( + "Using regular expressions for layer-wise sparsity " + f"profiles is not permitted. Found {layer_name}" + ) From b8d4dff162b6768af3a59440bf7d6ca80bf4e530 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Thu, 16 Nov 2023 12:12:40 -0500 Subject: [PATCH 4/9] Refactor to use WandaLayerCompressor Update WrappedGPT --- src/sparseml/modifiers/pruning/wanda/base.py | 3 +- .../modifiers/pruning/wanda/pytorch.py | 227 +++++++++++------- .../modifiers/pruning/wanda/utils/__init__.py | 2 - .../modifiers/pruning/wanda/utils/helpers.py | 83 ------- .../pruning/wanda/utils/layer_compressor.py | 197 +++++++++++++++ .../pruning/wanda/utils/wrapped_gpt.py | 122 +++++++++- 6 files changed, 453 insertions(+), 181 deletions(-) delete mode 100644 src/sparseml/modifiers/pruning/wanda/utils/helpers.py create mode 100644 src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py diff --git a/src/sparseml/modifiers/pruning/wanda/base.py b/src/sparseml/modifiers/pruning/wanda/base.py index 45efb14b2af..35231d59ec6 100644 --- a/src/sparseml/modifiers/pruning/wanda/base.py +++ b/src/sparseml/modifiers/pruning/wanda/base.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import List, Union +from typing import List, Optional, Union from sparseml.core import Modifier from sparseml.core.model.base import ModifiableModel @@ -45,6 +45,7 @@ class WandaPruningModifier(Modifier): sparsity: Union[float, List[float]] mask_structure: str = "0:0" targets: Union[str, List[str], None] = ALL_TOKEN + compressible_layers_: Optional[List] = None def on_initialize_structure(self, state: State, **kwargs): """ diff --git a/src/sparseml/modifiers/pruning/wanda/pytorch.py b/src/sparseml/modifiers/pruning/wanda/pytorch.py index 4262738b2ec..21544a9699b 100644 --- a/src/sparseml/modifiers/pruning/wanda/pytorch.py +++ b/src/sparseml/modifiers/pruning/wanda/pytorch.py @@ -13,16 +13,15 @@ # limitations under the License. 
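+
+# `sparsity` may be a single float applied to every compressible layer, or a
+# list aligned one-to-one with an explicit list of layer `targets`; regex
+# ("re:...") targets are rejected for layer-wise profiles. An illustrative
+# (placeholder) layer-wise configuration:
+#
+#   sparsity: [0.4, 0.5, 0.6]
+#   targets: ["model.layers.0", "model.layers.1", "model.layers.2"]
+#   mask_structure: "2:4"   # "0:0" selects an unstructured mask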
import logging +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch +from sparseml.core.model.base import ModifiableModel from sparseml.core.state import State +from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier -from sparseml.modifiers.pruning.wanda.utils.helpers import ( - find_layers, - prepare_calibration_input, -) -from sparseml.modifiers.pruning.wanda.utils.wrapped_gpt import WrappedGPT +from sparseml.modifiers.pruning.wanda.utils.layer_compressor import WandaLayerCompressor _LOGGER = logging.getLogger(__name__) @@ -33,92 +32,148 @@ class WandaPruningModifierPyTorch(WandaPruningModifier): PyTorch implementation of WandaPruningModifier """ + model: Optional[ModifiableModel] = None + device_: str = "cuda:0" + layer_prefix_: Optional[str] = None + prunen_: Optional[int] = None + prunem_: Optional[int] = None + def on_initialize(self, state: State, **kwargs) -> bool: - modifiable_model = state.model - pytorch_model = modifiable_model.model - use_cache = pytorch_model.config.use_cache - - # set use_cache to False to avoid OOM - pytorch_model.config.use_cache = False - - _LOGGER.info("Preparing calibration data") - calibration_dataloader = state.data.calib - device = state.hardware.device - pytorch_model.to(device) - with torch.no_grad(): - inps, outs, attention_mask, position_ids = prepare_calibration_input( - pytorch_model, calibration_dataloader, device - ) + """ + Initialize and run the WANDA algorithm on the current state - layers = pytorch_model.model.layers - for i in range(len(layers)): - layer = layers[i] - subset = find_layers(layer) - wrapped_layers = {} - for name in subset: - wrapped_layers[name] = WrappedGPT(subset[name]) - - def add_batch(name): - def tmp(_, inp, out): - wrapped_layers[name].add_batch(inp[0].data, out.data) - - return tmp - - handles = [] - for name in wrapped_layers: - handles.append(subset[name].register_forward_hook(add_batch(name))) - for j in range(len(calibration_dataloader)): - with torch.no_grad(): - outs[j] = layer( - inps[j].unsqueeze(0), - attention_mask=attention_mask, - position_ids=position_ids, - )[0] - for h in handles: - h.remove() - if self.mask_structure == "unstructured": - prune_n = prune_m = 0 - else: - prune_n, prune_m = tuple(map(int, self.mask_structure.split(":"))) - - for name in subset: - _LOGGER.info(f"pruning layer {i} name {name}") - W_metric = torch.abs(subset[name].weight.data) * torch.sqrt( - wrapped_layers[name].scaler_row.reshape((1, -1)) - ) + :param state: session state storing input model and calibration data + """ + self._validate_layerwise_sparsity() + + self.initialize_wanda(state, **kwargs) - W_mask = ( - torch.zeros_like(W_metric) == 1 - ) # initialize a mask to be all False - if prune_n != 0: - # structured n:m sparsity - for ii in range(W_metric.shape[1]): - if ii % prune_m == 0: - tmp = W_metric[:, ii : (ii + prune_m)].float() - W_mask.scatter_( - 1, - ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1], - True, - ) - else: - sort_res = torch.sort(W_metric, dim=-1, stable=True) - indices = sort_res[1][:, : int(W_metric.shape[1] * self.sparsity)] - W_mask.scatter_(1, indices, True) - - subset[name].weight.data[W_mask] = 0 # set weights to zero - - for j in range(len(calibration_dataloader)): - with torch.no_grad(): - outs[j] = layer( - inps[j].unsqueeze(0), - attention_mask=attention_mask, - position_ids=position_ids, - )[0] - inps, outs = outs, inps - - pytorch_model.config.use_cache = use_cache 
+ # run wanda on calibration data + self.apply_wanda(dataloader=state.data.calib) torch.cuda.empty_cache() return True + def initialize_wanda(self, state: State, **kwargs): + """ + Setup for WANDA, initializes the model, device, + and other parameters, also initilializes the + compressible layers of model, and sets the device + + :param state: session state storing input model and calibration data + """ + self.model = state.model + self.compressible_layers_ = self.compressible_layers() + self.device_ = self._set_device(device=state.hardware.device) + self.layer_prefix_ = self.model.layer_prefix + self._infer_mask_block_size() + + @torch.no_grad() + def apply_wanda( + self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None + ) -> Dict: + """ + Run Wanda on the loaded model, using dataloader as calibration data + + :param dataloader: calibration data for WANDA + """ + accum_kwargs = {"dataloader": dataloader} + pytorch_model = self.model.model + + # Step 0: Pass the calibration data through the (compressed) bottom part of the + # network, capturing the outputs which will become the inputs to the first + # decoder layer. Also return attention_mask as part of kwargs + extras = self.compress_bottom( + dev=self.device_, + layer_prefix=self.layer_prefix_, + **accum_kwargs, + ) + accum_kwargs.update(extras) + + # Step 1: Sequentially prune decoder layers + inputs = None + num_layers = len(self.compressible_layers_) + for idx, layer in enumerate(self.compressible_layers_): + if "outputs" not in accum_kwargs: + raise RuntimeError( + "The 'outputs' key is expected but not found from the " + "return of the bottom compressor" + ) + + inputs = accum_kwargs["outputs"] + layer_sparsity = ( + self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity + ) + _LOGGER.info( + f"\n===== Compressing layer {idx+1}/{num_layers} " + f"to sparsity {layer_sparsity} =====" + ) + args = { + "sparsity": layer_sparsity, + "prunen": self.prunen_, + "prunem": self.prunem_, + } + # Prune using WandaGPT + layer_compressor = WandaLayerCompressor( + model=pytorch_model, + layer=layer, + layer_index=idx, + inputs=inputs, + args=args, + ) + layer_kwargs = layer_compressor.compress(dev=self.device_, **accum_kwargs) + accum_kwargs.update(layer_kwargs) + + def compress_bottom( + self, + dataloader: List = None, + nsamples: int = None, + dev: str = "cuda:0", + layer_prefix: Optional[str] = None, + ) -> Dict: + """ + Runs calibration data through the bottom part of the network (everything up + to the first decoder layer) and return the captured outputs + + :param dataloader: calibration data to pass through the model + :param nsamples: number of samples to use for calibration, or None to use it all + :param dev: device to use + :param layer_prefix: name of model attribute that contains the list of layers, + i.e. 
model.decoder for OPT or just model for Llama + :return: outputs from bottom part of network, attention mask, and kv-cache state + """ + layer_prefix = layer_prefix or self.layer_prefix_ + cached_inputs = cache_attention_inputs( + model=self.model.model, + dataloader=dataloader, + device=dev, + nsamples=nsamples, + target_ids=None, + layer_prefix=layer_prefix, + ) + + outputs = cached_inputs.pop("inputs") + outputs = [o[0] for o in outputs] + cached_inputs.update({"outputs": outputs}) + return cached_inputs + def on_finalize(self, state: State, **kwargs): return True + + def _set_device(self, device: str): + if "cuda" in device and not torch.cuda.is_available(): + self.device_ = "cpu" + else: + self.device_ = device + + def _infer_mask_block_size(self): + """ + Infer the mask block size from the mask structure. + Parses mask_structure of the form N:M where N, M are integers that + define a custom block shape; and sets prunen_ and prunem_ accordingly. + + :post-condition: prunen_ and prunem_ are set + """ + if self.mask_structure is None: + raise ValueError("mask_structure must be defined") + + self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":"))) diff --git a/src/sparseml/modifiers/pruning/wanda/utils/__init__.py b/src/sparseml/modifiers/pruning/wanda/utils/__init__.py index 5301319d194..ebdf28a6d5b 100644 --- a/src/sparseml/modifiers/pruning/wanda/utils/__init__.py +++ b/src/sparseml/modifiers/pruning/wanda/utils/__init__.py @@ -13,5 +13,3 @@ # limitations under the License. # flake8: noqa -from .helpers import * -from .wrapped_gpt import * diff --git a/src/sparseml/modifiers/pruning/wanda/utils/helpers.py b/src/sparseml/modifiers/pruning/wanda/utils/helpers.py deleted file mode 100644 index fe8466b84cd..00000000000 --- a/src/sparseml/modifiers/pruning/wanda/utils/helpers.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch - - -def prepare_calibration_input(model, dataloader, device): - use_cache = model.config.use_cache - model.config.use_cache = False - layers = model.model.layers - - # dev = model.hf_device_map["model.embed_tokens"] - # if "model.embed_tokens" in model.hf_device_map: - # device = model.hf_device_map["model.embed_tokens"] - - dtype = next(iter(model.parameters())).dtype - inps = torch.zeros( - (128, model.seqlen, model.config.hidden_size), dtype=dtype, device=device - ) - inps.requires_grad = False - cache = {"i": 0, "attention_mask": None, "position_ids": None} - - class Catcher(torch.nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, inp, **kwargs): - inps[cache["i"]] = inp - cache["i"] += 1 - cache["attention_mask"] = kwargs["attention_mask"] - cache["position_ids"] = kwargs["position_ids"] - raise ValueError - - layers[0] = Catcher(layers[0]) - for batch in dataloader: - try: - model(batch[0].to(device)) - except ValueError: - pass - layers[0] = layers[0].module - - outs = torch.zeros_like(inps) - attention_mask = cache["attention_mask"] - position_ids = cache["position_ids"] - model.config.use_cache = use_cache - - return inps, outs, attention_mask, position_ids - - -def find_layers(module, layers=[torch.nn.Linear], name=""): - """ - Recursively find the layers of a certain type in a module. - - Args: - module (torch.nn.Module): PyTorch module. - layers (list): List of layer types to find. - name (str): Name of the module. - - Returns: - dict: Dictionary of layers of the given type(s) within the module. - """ - if type(module) in layers: - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update( - find_layers( - child, layers=layers, name=name + "." + name1 if name != "" else name1 - ) - ) - return res diff --git a/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py b/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py new file mode 100644 index 00000000000..08dfe36975c --- /dev/null +++ b/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py @@ -0,0 +1,197 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
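+
+# Driven by WandaPruningModifierPyTorch roughly as follows (sketch):
+#
+#   compressor = WandaLayerCompressor(model, layer, idx, inputs, args)
+#   extras = compressor.compress(dev="cuda:0", **layer_kwargs)
+#   inputs_for_next_layer = extras["outputs"]
+#
+# where `args` carries the "sparsity", "prunen" and "prunem" values resolved
+# for the layer being compressed.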
+ +import inspect +import logging +from typing import Dict, List + +import torch +from torch.nn import Module + +from sparseml.modifiers.pruning.wanda.utils.wrapped_gpt import WrappedGPT +from sparseml.pytorch.utils.helpers import get_dependency_order +from sparseml.utils.pytorch.module import get_prunable_layers + + +__all__ = ["WandaLayerCompressor"] + +_LOGGER = logging.getLogger(__name__) + + +class WandaLayerCompressor: + """ + Runs the Wanda algorithm on a single layer using calibration data inputs + + Lifecycle: + - compress + - pre_compress_parallel (optional) + - add_batch + - fasterprune + - post_compress + + :param model: model containing the layer we are running compression on + :param layer: layer to run compression on + :param layer_index: index of layer in the model + :param inputs: calibration data to pass through the layer + :param args: additional keyword arguments + """ + + def __init__( + self, model: Module, layer: Module, layer_index: int, inputs: List, args: Dict + ): + self.model = model + self.layer = layer + self.layer_index = layer_index + self.inputs = inputs + self.args = args + + def compressible_modules(self) -> Dict: + """ + Get the list of modules in the layer that can be compressed + + :return: dictionary of compressible modules + """ + compressible_layers = get_prunable_layers(self.layer) + return compressible_layers + + def pre_compress_parallel(self, **kwargs) -> Dict: + """ + Sets up the WrappedGPT objects for each compressible module, + computes the statistics for each using the calibration data. + + :return: WrappedGPT objects for each module + """ + subset = self.compressible_modules() + + gpts = {} + for name in subset: + gpts[name] = WrappedGPT(subset[name]) + + def add_batch(name): + def tmp(_, inp, out): + gpts[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in gpts: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + # Run through the samples in order to compute statistics for each module + nsamples = len(self.inputs) + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + for sample_idx in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][sample_idx] + else: + passed_in_kwargs[arg] = kwargs[arg] + self.layer(self.inputs[sample_idx], **passed_in_kwargs) + for h in handles: + h.remove() + + return {"gpts": gpts} + + def compress(self, dev: str = "cuda:0", **kwargs) -> Dict: + """ + Run WANDA compression on all compressible modules in the layer + + :param dev: device to run computation on + """ + self.layer.to(dev) + self.sequentially_compress(**kwargs) + extras = self.post_compress(**kwargs) + return {"outputs": extras["outputs"]} + + def post_compress(self, **kwargs) -> Dict: + """ + Clean up after compression + + :return: outputs of the layer + """ + nsamples = len(self.inputs) + outputs = [] + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + for j in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][j] + else: + passed_in_kwargs[arg] = kwargs[arg] + outputs.append(self.layer(self.inputs[j], **passed_in_kwargs)[0]) + + self.inputs = None + # once we've finished compressing the layer, move it back to CPU to 
save memory + self.layer.to("cpu") + torch.cuda.empty_cache() + + return {"outputs": outputs} + + def sequentially_compress(self, **kwargs): + """ + Run compression module by module, in dependency order. Unlike in parallel + compression, we compute the statistics layer by layer instead of computing them + all up front. This saves on memory and means compression in earlier layers + affects the inputs to later layers + """ + subset = self.compressible_modules() + + # filter kwargs that are expected as layer inputs + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): # take the first batch + passed_in_kwargs[arg] = kwargs[arg][0] + else: + passed_in_kwargs[arg] = kwargs[arg] + order = get_dependency_order( + self.layer, subset, self.inputs[0], **passed_in_kwargs + ) + + nsamples = len(self.inputs) + for name in order: # create WrappedGPT object for each compressible module + gpts = WrappedGPT(subset[name]) + + def add_batch(name): + def tmp(_, inp, out): + gpts.add_batch(inp[0].data, out.data) + + return tmp + + # add WrappedGPT hook for current module + handle = subset[name].register_forward_hook(add_batch(name)) + for sample_idx in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][sample_idx] + else: + passed_in_kwargs[arg] = kwargs[arg] + # run layer, triggering WrappedGPT add_batch for current module + self.layer(self.inputs[sample_idx], **passed_in_kwargs) + handle.remove() + + _LOGGER.info(f"Compressing module {name} of layer {self.layer_index}") + gpts.fasterprune( # run WrappedGPT algorithm on current module + self.args["sparsity"], + prunen=self.args["prunen"], + prunem=self.args["prunem"], + ) + gpts.free() diff --git a/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py b/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py index 7b30d4f4e69..fda16110c86 100644 --- a/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py +++ b/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py @@ -12,31 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import time + import torch import torch.nn as nn +try: + import transformers +except ImportError as err: + transformers = None + transformers_err = err + __all__ = ["WrappedGPT"] +DEBUG = False +_LOGGER = logging.getLogger(__name__) + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + class WrappedGPT: """ - This class wraps a GPT layer for specific operations. 
+ Runs Wanda on a single module that contains no sub-modules + + Lifecycle: + - add_batch + - fasterprune + - free + + + :param layer: module to run Wanda on """ - def __init__(self, layer, layer_id=0, layer_name="none"): + def __init__(self, layer): + if transformers is None: + raise transformers_err + self.layer = layer self.dev = self.layer.weight.device - self.rows = layer.weight.data.shape[0] - self.columns = layer.weight.data.shape[1] - - self.scaler_row = torch.zeros((self.columns), device=self.dev) + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] self.nsamples = 0 + self.scaler_row = torch.zeros((self.columns), device=self.dev) - self.layer_id = layer_id - self.layer_name = layer_name + def add_batch(self, inp: torch.Tensor, out: torch.Tensor): + """ + Add a batch of layer input and output data to the layer + statistics calculation - def add_batch(self, inp, out): + :param inp: tensor containing layer input + :param out: tensor containing layer output + """ + if DEBUG: + self._inp1 = inp + self.out1 = out if len(inp.shape) == 2: inp = inp.unsqueeze(0) tmp = inp.shape[0] @@ -50,3 +88,69 @@ def add_batch(self, inp, out): inp = inp.type(torch.float32) self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples + + def fasterprune( + self, + sparsity: float, + prunen: int = 0, + prunem: int = 0, + ): + """ + Run pruning and on the layer up to the target + sparsity value. + + :param sparsity: target sparsity to reach for layer + :param prunen: N for N:M pruning + :param prunem: M for N:M pruning + """ + W = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + W_metric = torch.abs(W) * torch.sqrt(self.scaler_row.reshape((1, -1))) + + # initialize a mask to be all False + W_mask = torch.zeros_like(W_metric) == 1 + if prunen != 0: + # structured n:m sparsity + for ii in range(W_metric.shape[1]): + if ii % prunem == 0: + tmp = W_metric[:, ii : (ii + prunem)].float() + W_mask.scatter_( + 1, + ii + torch.topk(tmp, prunen, dim=1, largest=False)[1], + True, + ) + else: + sort_res = torch.sort(W_metric, dim=-1, stable=True) + indices = sort_res[1][:, : int(W_metric.shape[1] * sparsity)] + W_mask.scatter_(1, indices, True) + + W[W_mask] = 0 # set weights to zero + + if torch.cuda.is_available(): + torch.cuda.synchronize() + _LOGGER.info("time %.2f" % (time.time() - tick)) + + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.layer.weight.data = W.reshape(self.layer.weight.shape).to( + self.layer.weight.data.dtype + ) + if DEBUG: + _LOGGER.debug(torch.sum((self.layer(self._inp1) - self.out1) ** 2)) + + def free(self): + """ + Free memory after the layer is complete + """ + if DEBUG: + self._inp1 = None + self.out1 = None + self.scaler_row = None + torch.cuda.empty_cache() From 563efd7388a3beeca476cdbddecbb72089e5fce0 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Thu, 16 Nov 2023 13:39:13 -0500 Subject: [PATCH 5/9] Rename WrappedGPT to WandaGPT --- .../modifiers/pruning/wanda/utils/layer_compressor.py | 6 +++--- src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py b/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py index 
08dfe36975c..eb663808e99 100644 --- a/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py +++ b/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py @@ -19,7 +19,7 @@ import torch from torch.nn import Module -from sparseml.modifiers.pruning.wanda.utils.wrapped_gpt import WrappedGPT +from sparseml.modifiers.pruning.wanda.utils.wrapped_gpt import WandaGPT from sparseml.pytorch.utils.helpers import get_dependency_order from sparseml.utils.pytorch.module import get_prunable_layers @@ -76,7 +76,7 @@ def pre_compress_parallel(self, **kwargs) -> Dict: gpts = {} for name in subset: - gpts[name] = WrappedGPT(subset[name]) + gpts[name] = WandaGPT(subset[name]) def add_batch(name): def tmp(_, inp, out): @@ -167,7 +167,7 @@ def sequentially_compress(self, **kwargs): nsamples = len(self.inputs) for name in order: # create WrappedGPT object for each compressible module - gpts = WrappedGPT(subset[name]) + gpts = WandaGPT(subset[name]) def add_batch(name): def tmp(_, inp, out): diff --git a/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py b/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py index fda16110c86..04d453e8c20 100644 --- a/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py +++ b/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py @@ -25,7 +25,7 @@ transformers = None transformers_err = err -__all__ = ["WrappedGPT"] +__all__ = ["WandaGPT"] DEBUG = False @@ -35,7 +35,7 @@ torch.backends.cudnn.allow_tf32 = False -class WrappedGPT: +class WandaGPT: """ Runs Wanda on a single module that contains no sub-modules From 5f24ff9126d1f045780dcb5677538c548b8f536e Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Mon, 18 Dec 2023 09:55:32 -0500 Subject: [PATCH 6/9] [Wanda Refactor] Wanda/OBCQ Modifier Refactor (#1887) * Define GPT contract * rename tmp -> batch_size * Define LayerCompressor Contract * Rename gpt_helpers to gpts Fix some docstrings * add named argument to function call * Wanda/OBCQ refactor * propagate target-ids * Address review comments from * #1885 * #1886 --- src/sparseml/modifiers/obcq/base.py | 42 +--- src/sparseml/modifiers/obcq/pytorch.py | 166 ++------------ .../modifiers/obcq/utils/layer_compressor.py | 170 ++------------ .../modifiers/obcq/utils/sparsegpt.py | 37 +-- src/sparseml/modifiers/pruning/wanda/base.py | 3 +- .../modifiers/pruning/wanda/pytorch.py | 51 +++-- .../pruning/wanda/utils/layer_compressor.py | 160 ++----------- .../{wrapped_gpt.py => module_compressor.py} | 47 +--- .../modifiers/utils/layer_compressor.py | 210 ++++++++++++++++++ .../modifiers/utils/module_compressor.py | 89 ++++++++ 10 files changed, 414 insertions(+), 561 deletions(-) rename src/sparseml/modifiers/pruning/wanda/utils/{wrapped_gpt.py => module_compressor.py} (79%) create mode 100644 src/sparseml/modifiers/utils/layer_compressor.py create mode 100644 src/sparseml/modifiers/utils/module_compressor.py diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py index 33e3de1d75c..20a543bdb36 100644 --- a/src/sparseml/modifiers/obcq/base.py +++ b/src/sparseml/modifiers/obcq/base.py @@ -15,10 +15,9 @@ import logging from typing import Any, Dict, List, Optional, Union -from sparseml.core import Modifier from sparseml.core.factory import ModifierFactory from sparseml.core.state import State -from sparseml.utils import ALL_TOKEN +from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier __all__ = ["SparseGPTModifier"] @@ -26,7 +25,7 @@ _LOGGER = logging.getLogger(__name__) -class SparseGPTModifier(Modifier): +class 
SparseGPTModifier(WandaPruningModifier): """ Modifier for applying the one-shot OBCQ algorithm to a model @@ -54,18 +53,14 @@ class SparseGPTModifier(Modifier): has been deprecated and will be removed in a future release """ - sparsity: Union[float, List[float]] block_size: int quantize: Union[bool, Dict] dampening_frac: Optional[float] = 0.01 sequential_update: Optional[bool] = True - mask_structure: str = "0:0" prunen_: Optional[int] = None prunem_: Optional[int] = None - targets: Union[str, List[str], None] = ALL_TOKEN target_ids: Optional[List[str]] = None layer_prefix: Optional[str] = None - compressible_layers_: Optional[List] = None quantization_modifier_: Any = None def __post_init__(self): @@ -75,15 +70,6 @@ def __post_init__(self): "removed in a future release" ) - def compressible_layers(self) -> List: - """ - Retrieves the modules corresponding to a list of compressible layer names - - :return: list of Pytorch modules to compress - """ - compressible_dict = self.model.get_layers(self.targets) - return [v for _, v in compressible_dict.items()] - def on_initialize_structure(self, state: State, **kwargs): quantization_already_active = state.model.qat_active() if isinstance(self.quantize, bool): @@ -143,27 +129,3 @@ def _build_quant_modifier_from_dict(self, quant_config, framework): allow_experimental=True, **modifier_args, ) - - def _validate_layerwise_sparsity(self): - if isinstance(self.sparsity, float): - return # single sparsity will be applied to all layers - - if not isinstance(self.targets, List): - raise ValueError( - "Layer targets must be a list when specifying layer-wise" - f" sparsity. Got {self.targets}" - ) - - if len(self.targets) != len(self.sparsity): - raise ValueError( - "Number of layer targets must match the number of " - f"sparsities. Got {len(self.targets)} layers and " - f"{len(self.sparsity)} sparsities" - ) - - for layer_name in self.targets: - if layer_name.startswith("re:"): - raise ValueError( - "Using regular expressions for layer-wise sparsity " - f"profiles is not permitted. Found {layer_name}" - ) diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py index d6d3de4594c..d11135731d9 100644 --- a/src/sparseml/modifiers/obcq/pytorch.py +++ b/src/sparseml/modifiers/obcq/pytorch.py @@ -13,30 +13,27 @@ # limitations under the License. 
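+
+# SparseGPTModifierPyTorch builds on the shared WANDA PyTorch pruning flow:
+# it reuses the calibration, layer-iteration, and bottom-compression logic and
+# contributes the OBCQ-specific compression arguments (blocksize, percdamp,
+# sequential_update, quantize) plus handling of the optional quantization
+# modifier during initialize/finalize.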
-import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +from functools import partial +from typing import Any, Optional -import torch - -from sparseml.core.model import ModifiableModel from sparseml.core.state import State from sparseml.modifiers.obcq.base import SparseGPTModifier -from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs -from sparseml.modifiers.obcq.utils.layer_compressor import LayerCompressor +from sparseml.modifiers.obcq.utils.layer_compressor import OBCQLayerCompressor +from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch -_LOGGER = logging.getLogger(__name__) +__all__ = ["SparseGPTModifierPyTorch"] -class SparseGPTModifierPyTorch(SparseGPTModifier): +class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier): """ Pytorch implementation of SparseGPT Lifecycle: - on_initialize - - initialize_obcq + - setup - compressible_layers - - apply_obcq + - prune - compress_bottom - LayerCompressor.compress - on_finalize @@ -47,6 +44,7 @@ class SparseGPTModifierPyTorch(SparseGPTModifier): model: Any = None device_: str = "cuda:0" layer_prefix_: Optional[str] = None + layer_compressor_class_ = OBCQLayerCompressor def on_initialize(self, state: "State", **kwargs) -> bool: """ @@ -60,152 +58,30 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.on_initialize_structure(state, **kwargs) if self.quantization_modifier_: self.quantization_modifier_.initialize(state, **kwargs) - modifiable_model = state.model - calibration_dataloader = state.data.calib - device = state.hardware.device - - self.initialize_obcq(modifiable_model, device) - self.apply_obcq(calibration_dataloader) - return True + # attach target_ids to `compress_bottom` for OBCQ + # this must be done before calling super().on_initialize - def initialize_obcq( - self, - model: "ModifiableModel", - device: Optional[str] = "cuda:0", - ): - """ - Setup for SparseGPT, initialize the the compressible layers of model, and set - the device + self.compress_bottom = partial(self.compress_bottom, target_ids=self.target_ids) + return super().on_initialize(state=state, **kwargs) - :param model: PyTorch model to sparsify - :param device: device to run sparsification on, preferably a GPU - """ - self.model = model - self.compressible_layers_ = self.compressible_layers() - self.layer_prefix_ = model.layer_prefix - self.model = self.model.model - self._set_device(device) - self._infer_mask_block_size() - - @torch.no_grad() - def apply_obcq( - self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None - ) -> Dict: - """ - Run OBCQ on the loaded model, using dataloader as calibration data - - :param dataloader: calibration data for OBCQ - """ - accum_kwargs = {"dataloader": dataloader} - - # Step 0: Pass the calibration data through the (compressed) bottom part of the - # network, capturing the outputs which will become the inputs to the first - # decoder layer. 
Also return attention_mask as part of kwargs - extras = self.compress_bottom( - dev=self.device_, - target_ids=self.target_ids, - layer_prefix=self.layer_prefix_, - **accum_kwargs, - ) - accum_kwargs.update(extras) - - # Step 1: Sequentially prune/quantize decoder layers - inputs = None - num_layers = len(self.compressible_layers_) - for idx, layer in enumerate(self.compressible_layers_): - if "outputs" not in accum_kwargs: - raise RuntimeError( - "The 'outputs' key is expected but not found from the " - "return of the bottom compressor" - ) - - inputs = accum_kwargs["outputs"] - layer_sparsity = ( - self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity - ) - _LOGGER.info( - f"\n===== Compressing layer {idx+1}/{num_layers} " - f"to sparsity {layer_sparsity} =====" - ) - args = { - "sparsity": layer_sparsity, - "prunen": self.prunen_, - "prunem": self.prunem_, + def _get_compression_args(self, layer_sparsity): + return { + **super()._get_compression_args(layer_sparsity=layer_sparsity), + **{ "blocksize": self.block_size, "percdamp": self.dampening_frac, "sequential_update": self.sequential_update, "quantize": self.quantize, - } - layer_compressor = LayerCompressor(self.model, layer, idx, inputs, args) + }, + } - # Prune/quantize using SparseGPT - layer_kwargs = layer_compressor.compress(dev=self.device_, **accum_kwargs) - accum_kwargs.update(layer_kwargs) - - def on_finalize(self, state: "State", **kwargs) -> bool: + def on_finalize(self, state: State, **kwargs) -> bool: """ disable the observers used by the OBCQ algorithm and set kv-cache configuration :param state: un-used, for matching spec of Modifier base class """ - if self.quantization_modifier_: self.quantization_modifier_.finalize(state, **kwargs) - - return True - - def compress_bottom( - self, - dataloader: List = None, - nsamples: int = None, - dev: str = "cuda:0", - target_ids: List[str] = None, - layer_prefix: Optional[str] = None, - ) -> Dict: - """ - Runs calibration data through the bottom part of the network (everything up - to the first decoder layer) and return the captured outputs - - :param dataloader: calibration data to pass through the model - :param nsamples: number of samples to use for calibration, or None to use it all - :param dev: device to use - :param target_ids: list of keys in model output to cache, NOTE: this argument - has been deprecated and will be removed in a future release - :param layer_prefix: name of model attribute that contains the list of layers, - i.e. model.decoder for OPT or just model for Llama - :return: outputs from bottom part of network, attention mask, and kv-cache state - """ - layer_prefix = layer_prefix or self.layer_prefix_ - cached_inputs = cache_attention_inputs( - model=self.model, - dataloader=dataloader, - device=dev, - nsamples=nsamples, - target_ids=target_ids, - layer_prefix=layer_prefix, - ) - - outputs = cached_inputs.pop("inputs") - outputs = [o[0] for o in outputs] - cached_inputs.update({"outputs": outputs}) - return cached_inputs - - def _set_device(self, device: str): - if "cuda" in device and not torch.cuda.is_available(): - self.device_ = "cpu" - else: - self.device_ = device - - def _infer_mask_block_size(self): - """ - Infer the mask block size from the mask structure. - Parses mask_structure of the form N:M where N, M are integers that - define a custom block shape; and sets prunen_ and prunem_ accordingly. 
- - :post-condition: prunen_ and prunem_ are set - """ - if self.mask_structure is None: - raise ValueError("mask_structure must be defined") - - self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":"))) + return super().on_finalize(state, **kwargs) diff --git a/src/sparseml/modifiers/obcq/utils/layer_compressor.py b/src/sparseml/modifiers/obcq/utils/layer_compressor.py index 7dd1b1885cd..887b651dad4 100644 --- a/src/sparseml/modifiers/obcq/utils/layer_compressor.py +++ b/src/sparseml/modifiers/obcq/utils/layer_compressor.py @@ -12,22 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -import logging -from typing import Dict, List -import torch -from torch.nn import Module +import logging +from typing import Dict from sparseml.modifiers.obcq.utils.sparsegpt import SparseGPT -from sparseml.pytorch.utils.helpers import get_dependency_order -from sparseml.utils.pytorch.module import get_prunable_layers +from sparseml.modifiers.utils.layer_compressor import LayerCompressor +from sparseml.modifiers.utils.module_compressor import ModuleCompressor + +__all__ = ["OBCQLayerCompressor"] _LOGGER = logging.getLogger(__name__) -class LayerCompressor: +class OBCQLayerCompressor(LayerCompressor): """ Runs the SparseGPT algorithm on a single layer using calibration data inputs @@ -45,63 +44,7 @@ class LayerCompressor: :param args: additional keyword arguments """ - def __init__( - self, model: Module, layer: Module, layer_index: int, inputs: List, args: Dict - ): - self.model = model - self.layer = layer - self.layer_index = layer_index - self.inputs = inputs - self.args = args - - def compressible_modules(self) -> Dict: - """ - Get the list of modules in the layer that can be compressed - - :return: dictionary of compressible modules - """ - compressible_layers = get_prunable_layers(self.layer) - return compressible_layers - - def pre_compress_parallel(self, **kwargs) -> Dict: - """ - Sets up the SparseGPT objects for each compressible module, computes the Hessian - for each using the calibration data. 
- - :return: SparseGPT objects for each module - """ - subset = self.compressible_modules() - - gpts = {} - for name in subset: - gpts[name] = SparseGPT(subset[name]) - - def add_batch(name): - def tmp(_, inp, out): - gpts[name].add_batch(inp[0].data, out.data) - - return tmp - - handles = [] - for name in gpts: - handles.append(subset[name].register_forward_hook(add_batch(name))) - - # Run through the samples in order to compute Hessian matrix for each module - nsamples = len(self.inputs) - forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) - passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] - for sample_idx in range(nsamples): - passed_in_kwargs = {} - for arg in passed_in_args: - if isinstance(kwargs[arg], List): - passed_in_kwargs[arg] = kwargs[arg][sample_idx] - else: - passed_in_kwargs[arg] = kwargs[arg] - self.layer(self.inputs[sample_idx], **passed_in_kwargs) - for h in handles: - h.remove() - - return {"gpts": gpts} + module_compressor_class: ModuleCompressor = SparseGPT def compress(self, dev: str = "cuda:0", **kwargs) -> Dict: """ @@ -116,14 +59,7 @@ def compress(self, dev: str = "cuda:0", **kwargs) -> Dict: gpts = extras["gpts"] for name in gpts: _LOGGER.info(f"Compressing {name}...") - sparsity = self.args["sparsity"] - gpts[name].fasterprune( - sparsity, - prunen=self.args["prunen"], - prunem=self.args["prunem"], - percdamp=self.args["percdamp"], - blocksize=self.args["blocksize"], - ) + self.invoke_fasterprune(module_compressor=gpts[name]) gpts[name].free() else: # Hessians computed layer by layer @@ -133,84 +69,12 @@ def compress(self, dev: str = "cuda:0", **kwargs) -> Dict: return {"outputs": extras["outputs"]} - def post_compress(self, **kwargs) -> Dict: - """ - Clean up after compression - - :return: outputs of the layer - """ - nsamples = len(self.inputs) - outputs = [] - forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) - passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] - for j in range(nsamples): - passed_in_kwargs = {} - for arg in passed_in_args: - if isinstance(kwargs[arg], List): - passed_in_kwargs[arg] = kwargs[arg][j] - else: - passed_in_kwargs[arg] = kwargs[arg] - outputs.append(self.layer(self.inputs[j], **passed_in_kwargs)[0]) - - self.inputs = None - # once we've finished compressing the layer, move it back to CPU to save memory - self.layer.to("cpu") - torch.cuda.empty_cache() - - return {"outputs": outputs} - - def sequentially_compress(self, **kwargs): - """ - Run compression module by module, in dependency order. Unlike in parallel - compression, we compute the Hessians layer by layer instead of computing them - all up front. 
This saves on memory and means compression in earlier layers - affects the inputs to later layers - """ - subset = self.compressible_modules() - - # filter kwargs that are expected as layer inputs - forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) - passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] - - passed_in_kwargs = {} - for arg in passed_in_args: - if isinstance(kwargs[arg], List): # take the first batch - passed_in_kwargs[arg] = kwargs[arg][0] - else: - passed_in_kwargs[arg] = kwargs[arg] - order = get_dependency_order( - self.layer, subset, self.inputs[0], **passed_in_kwargs + def invoke_fasterprune(self, module_compressor: SparseGPT): + # run SparseGPT algorithm on current module + module_compressor.fasterprune( + sparsity=self.args["sparsity"], + prunen=self.args["prunen"], + prunem=self.args["prunem"], + percdamp=self.args["percdamp"], + blocksize=self.args["blocksize"], ) - - nsamples = len(self.inputs) - for name in order: # create SparseGPT object for each compressible module - gpts = SparseGPT(subset[name]) - - def add_batch(name): - def tmp(_, inp, out): - gpts.add_batch(inp[0].data, out.data) - - return tmp - - # add SparseGPT hook for current module - handle = subset[name].register_forward_hook(add_batch(name)) - for sample_idx in range(nsamples): - passed_in_kwargs = {} - for arg in passed_in_args: - if isinstance(kwargs[arg], List): - passed_in_kwargs[arg] = kwargs[arg][sample_idx] - else: - passed_in_kwargs[arg] = kwargs[arg] - # run layer, triggering SparseGPT add_batch for current module - self.layer(self.inputs[sample_idx], **passed_in_kwargs) - handle.remove() - - _LOGGER.info(f"Compressing module {name} of layer {self.layer_index}") - gpts.fasterprune( # run SparseGPT algorithm on current module - self.args["sparsity"], - prunen=self.args["prunen"], - prunem=self.args["prunem"], - percdamp=self.args["percdamp"], - blocksize=self.args["blocksize"], - ) - gpts.free() diff --git a/src/sparseml/modifiers/obcq/utils/sparsegpt.py b/src/sparseml/modifiers/obcq/utils/sparsegpt.py index 41569bf88df..1123a7e7752 100644 --- a/src/sparseml/modifiers/obcq/utils/sparsegpt.py +++ b/src/sparseml/modifiers/obcq/utils/sparsegpt.py @@ -19,6 +19,8 @@ import torch import torch.nn as nn +from sparseml.modifiers.utils.module_compressor import ModuleCompressor + try: import transformers @@ -26,6 +28,8 @@ transformers = None transformers_err = err +__all__ = ["SparseGPT"] + DEBUG = False _LOGGER = logging.getLogger(__name__) @@ -34,7 +38,7 @@ torch.backends.cudnn.allow_tf32 = False -class SparseGPT: +class SparseGPT(ModuleCompressor): """ Runs SparseGPT on a single module that contains no sub-modules @@ -48,20 +52,8 @@ class SparseGPT: """ def __init__(self, layer): - if transformers is None: - raise transformers_err - - self.layer = layer - self.dev = self.layer.weight.device - W = layer.weight.data.clone() - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - if isinstance(self.layer, transformers.Conv1D): - W = W.t() - self.rows = W.shape[0] - self.columns = W.shape[1] + super().__init__(layer=layer) self.H = torch.zeros((self.columns, self.columns), device=self.dev) - self.nsamples = 0 def add_batch(self, inp: torch.Tensor, out: torch.Tensor): """ @@ -70,20 +62,18 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor): :param inp: tensor containing layer input :param out: tensor containing layer our """ - if DEBUG: - self._inp1 = inp - self.out1 = out + self.store_inps_outs_for_debugging(inp, out) if len(inp.shape) == 2: inp 
= inp.unsqueeze(0) - tmp = inp.shape[0] + batch_size = inp.shape[0] if isinstance(self.layer, nn.Linear) or isinstance( self.layer, transformers.Conv1D ): if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() - self.H *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp + self.H *= self.nsamples / (self.nsamples + batch_size) + self.nsamples += batch_size inp = math.sqrt(2 / self.nsamples) * inp.float() self.H += inp.matmul(inp.t()) @@ -104,7 +94,7 @@ def fasterprune( :param prunem: M for N:M pruning :param blocksize: Number of columns to compress in one pass :param percdamp: Amount of dampening to apply to H, as a fraction of the - diagonal norm + diagonal norm """ W = self.layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d): @@ -216,8 +206,5 @@ def free(self): """ Free the Hessian memory after the layer is complete """ - if DEBUG: - self._inp1 = None - self.out1 = None self.H = None - torch.cuda.empty_cache() + super().free() diff --git a/src/sparseml/modifiers/pruning/wanda/base.py b/src/sparseml/modifiers/pruning/wanda/base.py index 35231d59ec6..253642d78e7 100644 --- a/src/sparseml/modifiers/pruning/wanda/base.py +++ b/src/sparseml/modifiers/pruning/wanda/base.py @@ -69,7 +69,8 @@ def compressible_layers(self) -> List: if not isinstance(self.model, ModifiableModel): raise ValueError( "`self.model` must be a ModifiableModel to use " - f"the WANDA modifier but got {type(self.model)} instead" + f"the {self.__class__.__qualname__} modifier but got " + f"{type(self.model)} instead" ) compressible_dict = self.model.get_layers(self.targets) diff --git a/src/sparseml/modifiers/pruning/wanda/pytorch.py b/src/sparseml/modifiers/pruning/wanda/pytorch.py index 21544a9699b..c51dc6f34fb 100644 --- a/src/sparseml/modifiers/pruning/wanda/pytorch.py +++ b/src/sparseml/modifiers/pruning/wanda/pytorch.py @@ -29,7 +29,18 @@ class WandaPruningModifierPyTorch(WandaPruningModifier): """ - PyTorch implementation of WandaPruningModifier + Pytorch implementation of WandaPruningModifier + + Lifecycle: + - on_initialize + - setup + - compressible_layers + - prune + - compress_bottom + - LayerCompressor.compress + - on_finalize + + :param model: `ModifiableModel` to perform wanda on, in-place """ model: Optional[ModifiableModel] = None @@ -37,6 +48,7 @@ class WandaPruningModifierPyTorch(WandaPruningModifier): layer_prefix_: Optional[str] = None prunen_: Optional[int] = None prunem_: Optional[int] = None + layer_compressor_class_ = WandaLayerCompressor def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -45,15 +57,14 @@ def on_initialize(self, state: State, **kwargs) -> bool: :param state: session state storing input model and calibration data """ self._validate_layerwise_sparsity() + self.setup(state=state, **kwargs) - self.initialize_wanda(state, **kwargs) - - # run wanda on calibration data - self.apply_wanda(dataloader=state.data.calib) + # run on calibration data + self.prune(dataloader=state.data.calib) torch.cuda.empty_cache() return True - def initialize_wanda(self, state: State, **kwargs): + def setup(self, state: State, **kwargs): """ Setup for WANDA, initializes the model, device, and other parameters, also initilializes the @@ -68,7 +79,7 @@ def initialize_wanda(self, state: State, **kwargs): self._infer_mask_block_size() @torch.no_grad() - def apply_wanda( + def prune( self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None ) -> Dict: """ @@ -107,13 +118,9 @@ def apply_wanda( f"\n===== Compressing layer {idx+1}/{num_layers} " 
f"to sparsity {layer_sparsity} =====" ) - args = { - "sparsity": layer_sparsity, - "prunen": self.prunen_, - "prunem": self.prunem_, - } - # Prune using WandaGPT - layer_compressor = WandaLayerCompressor( + args = self._get_compression_args(layer_sparsity=layer_sparsity) + # Prune using GPT + layer_compressor = self.layer_compressor_class_( model=pytorch_model, layer=layer, layer_index=idx, @@ -123,12 +130,20 @@ def apply_wanda( layer_kwargs = layer_compressor.compress(dev=self.device_, **accum_kwargs) accum_kwargs.update(layer_kwargs) + def _get_compression_args(self, layer_sparsity): + return { + "sparsity": layer_sparsity, + "prunen": self.prunen_, + "prunem": self.prunem_, + } + def compress_bottom( self, dataloader: List = None, nsamples: int = None, dev: str = "cuda:0", layer_prefix: Optional[str] = None, + target_ids: Optional[List[int]] = None, ) -> Dict: """ Runs calibration data through the bottom part of the network (everything up @@ -139,15 +154,19 @@ def compress_bottom( :param dev: device to use :param layer_prefix: name of model attribute that contains the list of layers, i.e. model.decoder for OPT or just model for Llama + :param target_ids: list of keys in model output to cache, NOTE: this argument + has been deprecated and will be removed in a future release, also must be + set to None for Wanda :return: outputs from bottom part of network, attention mask, and kv-cache state """ layer_prefix = layer_prefix or self.layer_prefix_ + pytorch_model = self.model.model cached_inputs = cache_attention_inputs( - model=self.model.model, + model=pytorch_model, dataloader=dataloader, device=dev, nsamples=nsamples, - target_ids=None, + target_ids=target_ids, layer_prefix=layer_prefix, ) diff --git a/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py b/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py index eb663808e99..fe59e620b6f 100644 --- a/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py +++ b/src/sparseml/modifiers/pruning/wanda/utils/layer_compressor.py @@ -12,24 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect -import logging -from typing import Dict, List -import torch -from torch.nn import Module +from typing import Dict -from sparseml.modifiers.pruning.wanda.utils.wrapped_gpt import WandaGPT -from sparseml.pytorch.utils.helpers import get_dependency_order -from sparseml.utils.pytorch.module import get_prunable_layers +from sparseml.modifiers.pruning.wanda.utils.module_compressor import ( + WandaModuleCompressor, +) +from sparseml.modifiers.utils.layer_compressor import LayerCompressor +from sparseml.modifiers.utils.module_compressor import ModuleCompressor __all__ = ["WandaLayerCompressor"] -_LOGGER = logging.getLogger(__name__) - -class WandaLayerCompressor: +class WandaLayerCompressor(LayerCompressor): """ Runs the Wanda algorithm on a single layer using calibration data inputs @@ -47,63 +43,7 @@ class WandaLayerCompressor: :param args: additional keyword arguments """ - def __init__( - self, model: Module, layer: Module, layer_index: int, inputs: List, args: Dict - ): - self.model = model - self.layer = layer - self.layer_index = layer_index - self.inputs = inputs - self.args = args - - def compressible_modules(self) -> Dict: - """ - Get the list of modules in the layer that can be compressed - - :return: dictionary of compressible modules - """ - compressible_layers = get_prunable_layers(self.layer) - return compressible_layers - - def pre_compress_parallel(self, **kwargs) -> Dict: - """ - Sets up the WrappedGPT objects for each compressible module, - computes the statistics for each using the calibration data. - - :return: WrappedGPT objects for each module - """ - subset = self.compressible_modules() - - gpts = {} - for name in subset: - gpts[name] = WandaGPT(subset[name]) - - def add_batch(name): - def tmp(_, inp, out): - gpts[name].add_batch(inp[0].data, out.data) - - return tmp - - handles = [] - for name in gpts: - handles.append(subset[name].register_forward_hook(add_batch(name))) - - # Run through the samples in order to compute statistics for each module - nsamples = len(self.inputs) - forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) - passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] - for sample_idx in range(nsamples): - passed_in_kwargs = {} - for arg in passed_in_args: - if isinstance(kwargs[arg], List): - passed_in_kwargs[arg] = kwargs[arg][sample_idx] - else: - passed_in_kwargs[arg] = kwargs[arg] - self.layer(self.inputs[sample_idx], **passed_in_kwargs) - for h in handles: - h.remove() - - return {"gpts": gpts} + module_compressor_class: ModuleCompressor = WandaModuleCompressor def compress(self, dev: str = "cuda:0", **kwargs) -> Dict: """ @@ -116,82 +56,10 @@ def compress(self, dev: str = "cuda:0", **kwargs) -> Dict: extras = self.post_compress(**kwargs) return {"outputs": extras["outputs"]} - def post_compress(self, **kwargs) -> Dict: - """ - Clean up after compression - - :return: outputs of the layer - """ - nsamples = len(self.inputs) - outputs = [] - forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) - passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] - for j in range(nsamples): - passed_in_kwargs = {} - for arg in passed_in_args: - if isinstance(kwargs[arg], List): - passed_in_kwargs[arg] = kwargs[arg][j] - else: - passed_in_kwargs[arg] = kwargs[arg] - outputs.append(self.layer(self.inputs[j], **passed_in_kwargs)[0]) - - self.inputs = None - # once we've finished compressing the layer, move it back to CPU to save memory - self.layer.to("cpu") - 
torch.cuda.empty_cache() - - return {"outputs": outputs} - - def sequentially_compress(self, **kwargs): - """ - Run compression module by module, in dependency order. Unlike in parallel - compression, we compute the statistics layer by layer instead of computing them - all up front. This saves on memory and means compression in earlier layers - affects the inputs to later layers - """ - subset = self.compressible_modules() - - # filter kwargs that are expected as layer inputs - forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) - passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] - - passed_in_kwargs = {} - for arg in passed_in_args: - if isinstance(kwargs[arg], List): # take the first batch - passed_in_kwargs[arg] = kwargs[arg][0] - else: - passed_in_kwargs[arg] = kwargs[arg] - order = get_dependency_order( - self.layer, subset, self.inputs[0], **passed_in_kwargs + def invoke_fasterprune(self, module_compressor: "WandaModuleCompressor"): + # run WandaGPT algorithm on current module + module_compressor.fasterprune( + self.args["sparsity"], + prunen=self.args["prunen"], + prunem=self.args["prunem"], ) - - nsamples = len(self.inputs) - for name in order: # create WrappedGPT object for each compressible module - gpts = WandaGPT(subset[name]) - - def add_batch(name): - def tmp(_, inp, out): - gpts.add_batch(inp[0].data, out.data) - - return tmp - - # add WrappedGPT hook for current module - handle = subset[name].register_forward_hook(add_batch(name)) - for sample_idx in range(nsamples): - passed_in_kwargs = {} - for arg in passed_in_args: - if isinstance(kwargs[arg], List): - passed_in_kwargs[arg] = kwargs[arg][sample_idx] - else: - passed_in_kwargs[arg] = kwargs[arg] - # run layer, triggering WrappedGPT add_batch for current module - self.layer(self.inputs[sample_idx], **passed_in_kwargs) - handle.remove() - - _LOGGER.info(f"Compressing module {name} of layer {self.layer_index}") - gpts.fasterprune( # run WrappedGPT algorithm on current module - self.args["sparsity"], - prunen=self.args["prunen"], - prunem=self.args["prunem"], - ) - gpts.free() diff --git a/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py b/src/sparseml/modifiers/pruning/wanda/utils/module_compressor.py similarity index 79% rename from src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py rename to src/sparseml/modifiers/pruning/wanda/utils/module_compressor.py index 04d453e8c20..bb77d182298 100644 --- a/src/sparseml/modifiers/pruning/wanda/utils/wrapped_gpt.py +++ b/src/sparseml/modifiers/pruning/wanda/utils/module_compressor.py @@ -18,6 +18,8 @@ import torch import torch.nn as nn +from sparseml.modifiers.utils.module_compressor import ModuleCompressor + try: import transformers @@ -25,7 +27,7 @@ transformers = None transformers_err = err -__all__ = ["WandaGPT"] +__all__ = ["WandaModuleCompressor"] DEBUG = False @@ -35,33 +37,14 @@ torch.backends.cudnn.allow_tf32 = False -class WandaGPT: +class WandaModuleCompressor(ModuleCompressor): """ - Runs Wanda on a single module that contains no sub-modules - - Lifecycle: - - add_batch - - fasterprune - - free - - - :param layer: module to run Wanda on + Runs WANDA on a single module that contains no sub-modules + see https://arxiv.org/abs/2306.11695 """ def __init__(self, layer): - if transformers is None: - raise transformers_err - - self.layer = layer - self.dev = self.layer.weight.device - W = layer.weight.data.clone() - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - if isinstance(self.layer, 
transformers.Conv1D): - W = W.t() - self.rows = W.shape[0] - self.columns = W.shape[1] - self.nsamples = 0 + super().__init__(layer=layer) self.scaler_row = torch.zeros((self.columns), device=self.dev) def add_batch(self, inp: torch.Tensor, out: torch.Tensor): @@ -72,20 +55,17 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor): :param inp: tensor containing layer input :param out: tensor containing layer output """ - if DEBUG: - self._inp1 = inp - self.out1 = out + self.store_inps_outs_for_debugging(inp, out) if len(inp.shape) == 2: inp = inp.unsqueeze(0) - tmp = inp.shape[0] + batch_size = inp.shape[0] if isinstance(self.layer, nn.Linear): if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() - self.scaler_row *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp - + self.scaler_row *= self.nsamples / (self.nsamples + batch_size) + self.nsamples += batch_size inp = inp.type(torch.float32) self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples @@ -149,8 +129,5 @@ def free(self): """ Free memory after the layer is complete """ - if DEBUG: - self._inp1 = None - self.out1 = None self.scaler_row = None - torch.cuda.empty_cache() + super().free() diff --git a/src/sparseml/modifiers/utils/layer_compressor.py b/src/sparseml/modifiers/utils/layer_compressor.py new file mode 100644 index 00000000000..c07e4d98cb2 --- /dev/null +++ b/src/sparseml/modifiers/utils/layer_compressor.py @@ -0,0 +1,210 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import logging +from abc import ABC +from typing import Dict, List + +import torch +from torch.nn import Module + +from sparseml.modifiers.utils.module_compressor import ModuleCompressor +from sparseml.pytorch.utils.helpers import get_dependency_order +from sparseml.utils.pytorch.module import get_prunable_layers + + +__all__ = ["LayerCompressor"] +_LOGGER = logging.getLogger(__name__) + + +class LayerCompressor(ABC): + """ + Defines the contract to run GPT algorithms on a + single layer using calibration data inputs + + Example Lifecycle: + - compress + - pre_compress_parallel (optional) + - add_batch + - fasterprune + - post_compress + + Note: inheriting classes must define the module_compressor_class attribute + and implement the invoke_fasterprune and compress methods.
+ + :param model: model containing the layer we are running compression on + :param layer: layer to run compression on + :param layer_index: index of layer in the model + :param inputs: calibration data to pass through the layer + :param args: additional keyword arguments + """ + + module_compressor_class: ModuleCompressor + + def __init__( + self, model: Module, layer: Module, layer_index: int, inputs: List, args: Dict + ): + self.model = model + self.layer = layer + self.layer_index = layer_index + self.inputs = inputs + self.args = args + + def compress(self, dev: str = "cuda:0", **kwargs) -> Dict: + """ + Run GPT compression on all compressible modules in the layer + + :param dev: device to run computation on + """ + raise NotImplementedError() + + def invoke_fasterprune(self, module_compressor: ModuleCompressor): + """ + Invoke the fasterprune method on the given module compressor + + :param module_compressor: instantiated ModuleCompressor object + :raises NotImplementedError: inheritor must provide an + implementation for this method + """ + raise NotImplementedError() + + def compressible_modules(self) -> Dict: + """ + Get the list of modules in the layer that can be compressed + + :return: dictionary of compressible modules + """ + compressible_layers = get_prunable_layers(self.layer) + return compressible_layers + + def pre_compress_parallel(self, **kwargs) -> Dict: + """ + Sets up the module compressor objects for each compressible module, computes + the statistics for each using the calibration data. + + :return: module compressor objects for each module + """ + subset = self.compressible_modules() + + gpts = {} + for name in subset: + gpts[name] = self.module_compressor_class(subset[name]) + + def add_batch(name): + def tmp(_, inp, out): + gpts[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in gpts: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + # Run through the samples in order to compute statistics for each module + nsamples = len(self.inputs) + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + for sample_idx in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][sample_idx] + else: + passed_in_kwargs[arg] = kwargs[arg] + self.layer(self.inputs[sample_idx], **passed_in_kwargs) + for h in handles: + h.remove() + + return {"gpts": gpts} + + def post_compress(self, **kwargs) -> Dict: + """ + Clean up after compression + + :return: outputs of the layer + """ + nsamples = len(self.inputs) + outputs = [] + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + for j in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][j] + else: + passed_in_kwargs[arg] = kwargs[arg] + outputs.append(self.layer(self.inputs[j], **passed_in_kwargs)[0]) + + self.inputs = None + # once we've finished compressing the layer, move it back to CPU to save memory + self.layer.to("cpu") + torch.cuda.empty_cache() + + return {"outputs": outputs} + + def sequentially_compress(self, **kwargs): + """ + Run compression module by module, in dependency order. Unlike in parallel + compression, we compute the statistics layer by layer instead of computing them + all up front.
This saves on memory and means compression in earlier layers + affects the inputs to later layers + """ + subset = self.compressible_modules() + + # filter kwargs that are expected as layer inputs + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): # take the first batch + passed_in_kwargs[arg] = kwargs[arg][0] + else: + passed_in_kwargs[arg] = kwargs[arg] + order = get_dependency_order( + self.layer, subset, self.inputs[0], **passed_in_kwargs + ) + + nsamples = len(self.inputs) + for ( + name + ) in order: # create ModuleCompressor object for each compressible module + gpts: ModuleCompressor = self.module_compressor_class(subset[name]) + + def add_batch(name): + def tmp(_, inp, out): + gpts.add_batch(inp[0].data, out.data) + + return tmp + + # add ModuleCompressor hook for current module + handle = subset[name].register_forward_hook(add_batch(name)) + for sample_idx in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][sample_idx] + else: + passed_in_kwargs[arg] = kwargs[arg] + # run layer, triggering SparseGPT add_batch for current module + self.layer(self.inputs[sample_idx], **passed_in_kwargs) + handle.remove() + + _LOGGER.info(f"Compressing module {name} of layer {self.layer_index}") + + # run compression algorithm on current module + self.invoke_fasterprune(module_compressor=gpts) + gpts.free() diff --git a/src/sparseml/modifiers/utils/module_compressor.py b/src/sparseml/modifiers/utils/module_compressor.py new file mode 100644 index 00000000000..f53be2ade56 --- /dev/null +++ b/src/sparseml/modifiers/utils/module_compressor.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + from abc import ABC + + import torch + import torch.nn as nn + + + try: + import transformers + except ImportError as err: + transformers = None + transformers_err = err + + __all__ = ["ModuleCompressor"] + + + DEBUG = False + + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + + class ModuleCompressor(ABC): + """ + Abstract base class for pruning or quantizing a single module + with no sub-modules, using information from input/output + statistics + + :param layer: module to run compression on + """ + + def __init__(self, layer): + if transformers is None: + raise transformers_err + + self.layer = layer + self.dev = self.layer.weight.device + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.nsamples = 0 + + def store_inps_outs_for_debugging(self, inp, out): + if DEBUG: + self._inp1 = inp + self.out1 = out + + def free(self): + """ + Free memory after the layer is complete; + calls torch.cuda.empty_cache() to defragment GPU memory + """ + if DEBUG: + if hasattr(self, "_inp1"): + self._inp1 = None + if hasattr(self, "out1"): + self.out1 = None + torch.cuda.empty_cache() + + def add_batch(self, *args, **kwargs): + """ + Add a batch of layer input and output data to the layer + statistics calculation + """ + raise NotImplementedError("Child class must implement `add_batch`") + + def fasterprune(self, *args, **kwargs): + """ + Run pruning on the layer up to the target + sparsity + """ + raise NotImplementedError("Child class must implement `fasterprune`") From 149f2970dcfd6249d0a3476e0cdd15a1b0bf86fe Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Mon, 18 Dec 2023 10:08:45 -0500 Subject: [PATCH 7/9] Fix typo --- src/sparseml/modifiers/pruning/wanda/utils/module_compressor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/modifiers/pruning/wanda/utils/module_compressor.py b/src/sparseml/modifiers/pruning/wanda/utils/module_compressor.py index bb77d182298..fcd71882bae 100644 --- a/src/sparseml/modifiers/pruning/wanda/utils/module_compressor.py +++ b/src/sparseml/modifiers/pruning/wanda/utils/module_compressor.py @@ -76,7 +76,7 @@ def fasterprune( prunem: int = 0, ): """ - Run pruning and on the layer up to the target + Run pruning on the layer up to the target sparsity value.
:param sparsity: target sparsity to reach for layer From 084ded5bc1be8ec76cf2eb024602b05e23ebc213 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Mon, 18 Dec 2023 11:05:58 -0500 Subject: [PATCH 8/9] Update test --- src/sparseml/modifiers/obcq/pytorch.py | 5 ++++- tests/sparseml/transformers/obcq/test_obcq.py | 9 +++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py index d11135731d9..07bf9e5d4e3 100644 --- a/src/sparseml/modifiers/obcq/pytorch.py +++ b/src/sparseml/modifiers/obcq/pytorch.py @@ -62,7 +62,10 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # attach target_ids to `compress_bottom` for OBCQ # this must be done before calling super().on_initialize - self.compress_bottom = partial(self.compress_bottom, target_ids=self.target_ids) + compress_bottom = partial(self.compress_bottom, target_ids=self.target_ids) + + # we need setattr here because of Pydantic's internal data model + object.__setattr__(self, "compress_bottom", compress_bottom) return super().on_initialize(state=state, **kwargs) def _get_compression_args(self, layer_sparsity): diff --git a/tests/sparseml/transformers/obcq/test_obcq.py b/tests/sparseml/transformers/obcq/test_obcq.py index e7c419705e3..06585d2f717 100644 --- a/tests/sparseml/transformers/obcq/test_obcq.py +++ b/tests/sparseml/transformers/obcq/test_obcq.py @@ -18,7 +18,7 @@ import torch from sparseml.core.framework import Framework -from sparseml.core.model import ModifiableModel +from sparseml.core.state import State from sparseml.modifiers.obcq import SparseGPTModifier from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general from sparseml.pytorch.utils.helpers import tensor_sparsity @@ -76,7 +76,6 @@ def test_lm_head_target(): device = "cpu" model = SparseCausalLM.auto_model_from_pretrained(tiny_model_path) - modifiable_model = ModifiableModel(model=model, framework=Framework.pytorch) kwargs = { "sparsity": 0.5, @@ -95,11 +94,13 @@ def test_lm_head_target(): sparsegpt_modifier_no_head = SparseGPTModifier( framework=Framework.pytorch, **kwargs ) - sparsegpt_modifier_no_head.initialize_obcq(model=modifiable_model, device=device) + state = State(framework=Framework.pytorch) + state.update(model=model, device=device) + sparsegpt_modifier_no_head.setup(state) kwargs["targets"].append("lm_head") sparsegpt_modifier_head = SparseGPTModifier(framework=Framework.pytorch, **kwargs) - sparsegpt_modifier_head.initialize_obcq(model=modifiable_model, device=device) + sparsegpt_modifier_head.setup(state) # check we pick up the lm_head layer layers_no_head = len(sparsegpt_modifier_no_head.compressible_layers_) From 66dc2e550b68dee031ebf1362f31367a5adc5b05 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Mon, 18 Dec 2023 12:06:07 -0500 Subject: [PATCH 9/9] Fix regression --- src/sparseml/modifiers/pruning/wanda/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/modifiers/pruning/wanda/pytorch.py b/src/sparseml/modifiers/pruning/wanda/pytorch.py index c51dc6f34fb..6c2115fec9c 100644 --- a/src/sparseml/modifiers/pruning/wanda/pytorch.py +++ b/src/sparseml/modifiers/pruning/wanda/pytorch.py @@ -74,7 +74,7 @@ def setup(self, state: State, **kwargs): """ self.model = state.model self.compressible_layers_ = self.compressible_layers() - self.device_ = self._set_device(device=state.hardware.device) + self._set_device(device=state.hardware.device) self.layer_prefix_ = self.model.layer_prefix self._infer_mask_block_size()
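
For illustration, a new pruning algorithm can plug into the refactored abstractions by subclassing the ModuleCompressor and LayerCompressor base classes introduced above. The sketch below is a minimal, hypothetical example (simple magnitude pruning) and assumes this patch series is applied as-is; MagnitudeModuleCompressor and MagnitudeLayerCompressor are illustrative names only and are not part of these patches.

import torch

from sparseml.modifiers.utils.layer_compressor import LayerCompressor
from sparseml.modifiers.utils.module_compressor import ModuleCompressor


class MagnitudeModuleCompressor(ModuleCompressor):
    """Hypothetical compressor: prunes a single module by weight magnitude"""

    def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
        # magnitude pruning needs no activation statistics; the forward hook
        # only counts samples so the base-class bookkeeping stays consistent
        self.nsamples += 1

    def fasterprune(self, sparsity: float, prunen: int = 0, prunem: int = 0):
        # zero out the smallest-magnitude weights until `sparsity` is reached
        W = self.layer.weight.data
        num_zeros = int(sparsity * W.numel())
        if num_zeros > 0:
            threshold = W.abs().flatten().kthvalue(num_zeros).values
            W[W.abs() <= threshold] = 0.0


class MagnitudeLayerCompressor(LayerCompressor):
    """Hypothetical layer compressor wiring in the module compressor above"""

    module_compressor_class: ModuleCompressor = MagnitudeModuleCompressor

    def compress(self, dev: str = "cuda:0", **kwargs) -> dict:
        # compress module by module, then capture the layer outputs that feed
        # the next layer; `dev` is accepted for interface parity, and this
        # sketch assumes layer and calibration inputs already share a device
        self.sequentially_compress(**kwargs)
        extras = self.post_compress(**kwargs)
        return {"outputs": extras["outputs"]}

    def invoke_fasterprune(self, module_compressor: MagnitudeModuleCompressor):
        module_compressor.fasterprune(sparsity=self.args["sparsity"])

A matching modifier would then point layer_compressor_class_ at MagnitudeLayerCompressor, mirroring how SparseGPTModifierPyTorch selects OBCQLayerCompressor and WandaPruningModifierPyTorch selects WandaLayerCompressor in the diffs above.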