POC for sequential FSDP OBCQ (#1947)
Sara Adkins authored Jan 10, 2024
1 parent e250572 commit 80983e5
Showing 4 changed files with 19 additions and 8 deletions.

src/sparseml/modifiers/obcq/base.py: 3 changes (0 additions & 3 deletions)

@@ -48,14 +48,11 @@ class SparseGPTModifier(WandaPruningModifier):
         in the recipe
     :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
         diagonal norm
-    :param sequential_update: Whether or not to update weights sequentially by layer,
-        True saves on GPU memory (NOTE: deprecated)
     """

     block_size: int
     quantize: Union[bool, Dict]
     dampening_frac: Optional[float] = 0.01
-    sequential_update: Optional[bool] = False  # deprecated
     quantization_modifier_: Any = None

     def on_initialize_structure(self, state: State, **kwargs):

src/sparseml/modifiers/pruning/wanda/base.py: 3 changes (3 additions & 0 deletions)

@@ -45,12 +45,15 @@ class WandaPruningModifier(Modifier):
     :param mask_structure: String to define the structure of the mask to apply.
         Must be of the form N:M where N, M are integers that define a custom block
         shape. Defaults to 0:0 which represents an unstructured mask.
+    :param sequential_update: Whether or not to update weights sequentially by layer,
+        True saves on GPU memory
     :param targets: list of layer names to compress during OBCQ, or '__ALL__'
         to compress every layer in the model
     """

     sparsity: Union[float, List[float]]
     mask_structure: str = "0:0"
+    sequential_update: Optional[bool] = False
     targets: Union[str, List[str], None] = ALL_TOKEN
     compressible_layers_: Optional[List] = None
     prunen_: Optional[int] = None
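
Together, the first two files relocate sequential_update from SparseGPTModifier into its parent WandaPruningModifier, so the WANDA and SparseGPT paths share a single flag. A minimal sketch of the resulting layout (base class and defaults are simplified stand-ins; only the field placement is taken from the diff):

from typing import Any, Dict, List, Optional, Union

ALL_TOKEN = "__ALL__"  # stand-in for the sparseml constant


class WandaPruningModifier:  # simplified: the real class derives from Modifier
    sparsity: Union[float, List[float]] = 0.0
    mask_structure: str = "0:0"
    sequential_update: Optional[bool] = False  # now declared on the shared base
    targets: Union[str, List[str], None] = ALL_TOKEN


class SparseGPTModifier(WandaPruningModifier):
    # inherits sequential_update instead of redefining it
    block_size: int = 128  # illustrative default, not taken from the diff
    quantize: Union[bool, Dict] = False
    dampening_frac: Optional[float] = 0.01
    quantization_modifier_: Any = None


print(SparseGPTModifier.sequential_update)  # False, resolved via WandaPruningModifier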

src/sparseml/modifiers/pruning/wanda/pytorch.py: 17 changes (14 additions & 3 deletions)

@@ -89,7 +89,9 @@ def initialize_compression(self, model: ModifiableModel):
             args = self._pruning_arguments(layer_sparsity)
             comp_cls = self._compression_class()
             compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args)
-            compressor.pre_compress()
+            if not self.sequential_update:
+                # add all batch processing hooks before the forward pass
+                compressor.pre_compress()
             self.layer_compressors_.append(compressor)

     @torch.no_grad()
@@ -105,7 +107,9 @@ def apply_compression(
         _LOGGER.info(
             f"Running {class_name} calibration with " f"{len(dataloader)} samples..."
         )
-        run_calibration_forward(self.model, dataloader)
+        if not self.sequential_update:
+            # in non-sequential mode we run one forward batch for all modules
+            run_calibration_forward(self.model, dataloader)

         num_layers = len(self.compressible_layers_)
         for idx, layer_compressor in enumerate(self.layer_compressors_):
@@ -115,7 +119,14 @@
                 f"to sparsity {layer_sparsity} ====="
             )

-            # Prune/quantize using WANDA
+            # Prune/quantize using SparseGPT
+            if self.sequential_update:
+                # in sequential mode we run one forward pass for each module we
+                # want to compress, this will be really slow but allows compression in
+                # earlier layers to affect later layers
+                layer_compressor.pre_compress()
+                _LOGGER.info(f"Calibrating {layer_compressor.name}...")
+                run_calibration_forward(self.model, dataloader)
             layer_compressor.compress()
             layer_compressor.post_compress()

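
The pytorch.py hunks carry the behavioral change: in the default (non-sequential) mode every layer's hooks are attached during initialize_compression and a single calibration pass feeds them all, while in sequential mode each layer is hooked, calibrated, and compressed in turn, so compression of earlier layers is reflected in the statistics of later ones. A condensed, self-contained sketch of that control flow (the stand-ins below are illustrative, not the real LayerCompressor or run_calibration_forward, and the two phases are folded into one function for brevity):

class StubLayerCompressor:
    """Illustrative stand-in for LayerCompressor."""

    def __init__(self, name: str):
        self.name = name

    def pre_compress(self):
        print(f"[{self.name}] attach calibration hooks")

    def compress(self):
        print(f"[{self.name}] prune/quantize from collected statistics")

    def post_compress(self):
        print(f"[{self.name}] remove hooks")


def run_calibration(num_batches: int):
    # stand-in for run_calibration_forward(model, dataloader)
    print(f"forward pass over {num_batches} calibration batches")


def apply_compression(compressors, num_batches: int, sequential_update: bool):
    if not sequential_update:
        # one-shot mode: hook every layer up front, then calibrate once
        for compressor in compressors:
            compressor.pre_compress()
        run_calibration(num_batches)

    for compressor in compressors:
        if sequential_update:
            # sequential mode: hook and calibrate one layer at a time, so
            # compression of earlier layers affects later layers' statistics
            compressor.pre_compress()
            run_calibration(num_batches)
        compressor.compress()
        compressor.post_compress()


apply_compression(
    [StubLayerCompressor(f"layer.{i}") for i in range(2)],
    num_batches=4,
    sequential_update=True,
)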

src/sparseml/modifiers/utils/layer_compressor.py: 4 changes (2 additions & 2 deletions)

@@ -88,9 +88,9 @@ def pre_compress(self):

         for name in subset:
             layer = subset[name]
-            with summon_full_params_context(self.layer):
-                wrapper = self.module_compressor_class(name, layer)
             full_name = self._get_full_submodule_name(name)
+            with summon_full_params_context(self.layer):
+                wrapper = self.module_compressor_class(full_name, layer)
             set_layer(full_name, wrapper, self.model)
             self.modules[name] = wrapper

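
The layer_compressor.py change keeps wrapper construction inside the FSDP full-parameter context and hands the wrapper its fully qualified name, so set_layer can locate the submodule on the FSDP-wrapped root model. A rough sketch of what summon_full_params_context is assumed to look like, a thin wrapper over FSDP.summon_full_params (a guess at its shape, not the real sparseml helper):

from contextlib import nullcontext

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def summon_full_params_context(module: torch.nn.Module):
    """Assumed behavior: gather the full (unsharded) parameters of an
    FSDP-wrapped module for the duration of the context, no-op otherwise."""
    if isinstance(module, FSDP):
        return FSDP.summon_full_params(module)
    return nullcontext()


# Usage mirroring pre_compress(): build the wrapper while full params are
# visible, keyed by the layer's fully qualified name within the root model.
layer = torch.nn.Linear(8, 8)  # not FSDP-wrapped here, so the context is a no-op
with summon_full_params_context(layer):
    weight_shape = layer.weight.shape
print(weight_shape)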
