Skip to content

Commit

Permalink
propagate target-ids
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Dec 7, 2023
1 parent 872ebb3 commit 24535d1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 29 deletions.
37 changes: 19 additions & 18 deletions src/sparseml/modifiers/obcq/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,16 @@
# limitations under the License.


import logging
from typing import Any, Optional
from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch

from functools import partial
from typing import Any, Optional

from sparseml.core.state import State
from sparseml.experimental.sparsegpt.layer_compressor import LayerCompressor
from sparseml.modifiers.obcq.base import SparseGPTModifier
from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs
from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch
from sparseml.modifiers.utils.layer_compressors import OBCQLayerCompressor




class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier):
"""
Pytorch implementation of SparseGPT
Expand All @@ -45,7 +42,7 @@ class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier):
model: Any = None
device_: str = "cuda:0"
layer_prefix_: Optional[str] = None
layer_compressor_class_: Any = OBCQLayerCompressor
layer_compressor_class_: LayerCompressor = OBCQLayerCompressor

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Expand All @@ -59,17 +56,22 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
self.on_initialize_structure(state, **kwargs)
if self.quantization_modifier_:
self.quantization_modifier_.initialize(state, **kwargs)
return super().on_initialize(state, **kwargs)

# attach target_ids to `compress_bottom` for OBCQ
# this must be done before calling super().on_initialize

self.compress_bottom = partial(self.compress_bottom, target_ids=self.target_ids)
return super().on_initialize(state=state, **kwargs)

def _get_compression_args(self, layer_sparsity):
    """
    Build the kwargs forwarded to the layer compressor for one layer.

    Extends the parent (Wanda) compression args with the SparseGPT/OBCQ
    specific settings configured on this modifier.

    :param layer_sparsity: target sparsity for the layer being compressed
    :return: dict of compression arguments for the layer compressor
    """
    # The nested `**{...}` spread was redundant; the extra keys can be
    # merged directly into the dict literal.
    return {
        **super()._get_compression_args(layer_sparsity=layer_sparsity),
        "blocksize": self.block_size,
        "percdamp": self.dampening_frac,
        "sequential_update": self.sequential_update,
        "quantize": self.quantize,
    }

def on_finalize(self, state: State, **kwargs) -> bool:
Expand All @@ -81,4 +83,3 @@ def on_finalize(self, state: State, **kwargs) -> bool:
if self.quantization_modifier_:
self.quantization_modifier_.finalize(state, **kwargs)
return super().on_finalize(state, **kwargs)

30 changes: 19 additions & 11 deletions src/sparseml/modifiers/pruning/wanda/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple
from sparseml.modifiers.utils import layer_compressors

import torch

from sparseml.core.model.base import ModifiableModel
from sparseml.core.state import State
from sparseml.experimental.sparsegpt.layer_compressor import LayerCompressor
from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs
from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier
from sparseml.modifiers.utils.layer_compressors import WandaLayerCompressor
Expand All @@ -30,17 +30,26 @@

class WandaPruningModifierPyTorch(WandaPruningModifier):
"""
PyTorch implementation of WandaPruningModifier
Pytorch implementation of WandaPruningModifier
Lifecycle:
- on_initialize
- setup
- compressible_layers
- prune
- compress_bottom
- LayerCompressor.compress
- on_finalize
:param model: `ModifiableModel` to perform wanda on, in-place
"""

model: Optional[ModifiableModel] = None
device_: str = "cuda:0"
layer_prefix_: Optional[str] = None
prunen_: Optional[int] = None
prunem_: Optional[int] = None
layer_compressor_class_: layer_compressors = WandaLayerCompressor
layer_compressor_class_: LayerCompressor = WandaLayerCompressor

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Expand All @@ -49,9 +58,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
:param state: session state storing input model and calibration data
"""
self._validate_layerwise_sparsity()

self.setup(state=state, **kwargs)

# run on calibration data
self.prune(dataloader=state.data.calib)
torch.cuda.empty_cache()
Expand Down Expand Up @@ -125,10 +133,10 @@ def prune(

def _get_compression_args(self, layer_sparsity):
    """
    Build the kwargs forwarded to the layer compressor for one layer.

    :param layer_sparsity: target sparsity for the layer being pruned
    :return: dict with the layer sparsity and the n:m pruning
        parameters (``prunen_``/``prunem_``) set on this modifier
    """
    return {
        "sparsity": layer_sparsity,
        "prunen": self.prunen_,
        "prunem": self.prunem_,
    }

def compress_bottom(
self,
Expand Down

0 comments on commit 24535d1

Please sign in to comment.