From 80983e54ac836b90ed440efd1f0e9c4e8d9418fb Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 10 Jan 2024 15:32:16 -0500
Subject: [PATCH] POC for sequential FSDP OBCQ (#1947)

---
 src/sparseml/modifiers/obcq/base.py             |  3 ---
 src/sparseml/modifiers/pruning/wanda/base.py    |  3 +++
 src/sparseml/modifiers/pruning/wanda/pytorch.py | 17 ++++++++++++++---
 .../modifiers/utils/layer_compressor.py         |  4 ++--
 4 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py
index 15facf36e68..582c12f0230 100644
--- a/src/sparseml/modifiers/obcq/base.py
+++ b/src/sparseml/modifiers/obcq/base.py
@@ -48,14 +48,11 @@ class SparseGPTModifier(WandaPruningModifier):
         in the recipe
     :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
         diagonal norm
-    :param sequential_update: Whether or not to update weights sequentially by layer,
-        True saves on GPU memory (NOTE: deprecated)
     """
 
     block_size: int
     quantize: Union[bool, Dict]
     dampening_frac: Optional[float] = 0.01
-    sequential_update: Optional[bool] = False  # deprecated
     quantization_modifier_: Any = None
 
     def on_initialize_structure(self, state: State, **kwargs):
diff --git a/src/sparseml/modifiers/pruning/wanda/base.py b/src/sparseml/modifiers/pruning/wanda/base.py
index b5386945a5a..07ccdc8b8e1 100644
--- a/src/sparseml/modifiers/pruning/wanda/base.py
+++ b/src/sparseml/modifiers/pruning/wanda/base.py
@@ -45,12 +45,15 @@ class WandaPruningModifier(Modifier):
     :param mask_structure: String to define the structure of the mask to apply.
         Must be of the form N:M where N, M are integers that define a custom block
         shape. Defaults to 0:0 which represents an unstructured mask.
+    :param sequential_update: Whether or not to update weights sequentially by layer,
+        True saves on GPU memory
     :param targets: list of layer names to compress during OBCQ, or '__ALL__'
         to compress every layer in the model
     """
 
     sparsity: Union[float, List[float]]
     mask_structure: str = "0:0"
+    sequential_update: Optional[bool] = False
     targets: Union[str, List[str], None] = ALL_TOKEN
     compressible_layers_: Optional[List] = None
     prunen_: Optional[int] = None
diff --git a/src/sparseml/modifiers/pruning/wanda/pytorch.py b/src/sparseml/modifiers/pruning/wanda/pytorch.py
index 0ccd7ea0c5b..ae6d118beba 100644
--- a/src/sparseml/modifiers/pruning/wanda/pytorch.py
+++ b/src/sparseml/modifiers/pruning/wanda/pytorch.py
@@ -89,7 +89,9 @@ def initialize_compression(self, model: ModifiableModel):
             args = self._pruning_arguments(layer_sparsity)
             comp_cls = self._compression_class()
             compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args)
-            compressor.pre_compress()
+            if not self.sequential_update:
+                # add all batch processing hooks before the forward pass
+                compressor.pre_compress()
             self.layer_compressors_.append(compressor)
 
     @torch.no_grad()
@@ -105,7 +107,9 @@ def apply_compression(
         _LOGGER.info(
             f"Running {class_name} calibration with " f"{len(dataloader)} samples..."
         )
-        run_calibration_forward(self.model, dataloader)
+        if not self.sequential_update:
+            # in non-sequential mode we run one forward batch for all modules
+            run_calibration_forward(self.model, dataloader)
 
         num_layers = len(self.compressible_layers_)
         for idx, layer_compressor in enumerate(self.layer_compressors_):
@@ -115,7 +119,14 @@ def apply_compression(
                 f"to sparsity {layer_sparsity} ====="
             )
 
-            # Prune/quantize using WANDA
+            # Prune/quantize using SparseGPT
+            if self.sequential_update:
+                # in sequential mode we run one forward pass for each module we
+                # want to compress, this will be really slow but allows compression in
+                # earlier layers to affect later layers
+                layer_compressor.pre_compress()
+                _LOGGER.info(f"Calibrating {layer_compressor.name}...")
+                run_calibration_forward(self.model, dataloader)
             layer_compressor.compress()
             layer_compressor.post_compress()
 
diff --git a/src/sparseml/modifiers/utils/layer_compressor.py b/src/sparseml/modifiers/utils/layer_compressor.py
index 46b0de0d3a1..9c9167b15ab 100644
--- a/src/sparseml/modifiers/utils/layer_compressor.py
+++ b/src/sparseml/modifiers/utils/layer_compressor.py
@@ -88,9 +88,9 @@ def pre_compress(self):
 
         for name in subset:
             layer = subset[name]
-            with summon_full_params_context(self.layer):
-                wrapper = self.module_compressor_class(name, layer)
             full_name = self._get_full_submodule_name(name)
+            with summon_full_params_context(self.layer):
+                wrapper = self.module_compressor_class(full_name, layer)
             set_layer(full_name, wrapper, self.model)
             self.modules[name] = wrapper
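For reference, below is a minimal, self-contained sketch of the calibration flow this patch introduces. It uses stub classes and a stand-in run_calibration_forward rather than the real SparseML objects (the actual logic lives in the WandaPruningModifier PyTorch implementation): non-sequential mode hooks every layer up front and runs a single calibration pass, while sequential mode hooks and calibrates one layer at a time so that compression of earlier layers is reflected in the inputs seen by later layers.

# Minimal sketch only: StubLayerCompressor and run_calibration_forward are
# stand-ins, not the SparseML API.
class StubLayerCompressor:
    def __init__(self, name):
        self.name = name

    def pre_compress(self):
        print(f"adding calibration hooks to {self.name}")

    def compress(self):
        print(f"pruning/quantizing {self.name}")

    def post_compress(self):
        print(f"removing hooks from {self.name}")


def run_calibration_forward(model, dataloader):
    # stand-in for the calibration forward pass over the model
    print(f"calibration forward pass over {len(dataloader)} samples")


def apply_compression(compressors, model, dataloader, sequential_update=False):
    if not sequential_update:
        # non-sequential: hook every layer, then one forward pass total
        for compressor in compressors:
            compressor.pre_compress()
        run_calibration_forward(model, dataloader)

    for compressor in compressors:
        if sequential_update:
            # sequential: hook and calibrate one layer at a time (slower, but
            # compression of earlier layers affects later layers' inputs)
            compressor.pre_compress()
            run_calibration_forward(model, dataloader)
        compressor.compress()
        compressor.post_compress()


apply_compression(
    [StubLayerCompressor("model.layers.0"), StubLayerCompressor("model.layers.1")],
    model=None,
    dataloader=range(8),
    sequential_update=True,
)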