FSDP oneshot (#1939)
* initial recipe re-loading

* loading for input recipe

* persist structure across recipe loads

* clean up fn names

* clean up duplicated code

* delete extra file

* unit tests

* fix failing test

* quantization edge cases

* quant tests

* fixes for stage name clashes

* clean up documentation

* setup StageRunner class

* running one_shot from text_gen script

* cleanup helper fns

* precision support

* formatting

* WIP for alternating

* fixing device issue

* MVP for alternating flows

* add apply flag during finalization as well

* clarity comments

* clean up docstrings

* fix unit test

* WIP FSDP support

* fix for unwrapping

* WIP for state reloading between stages

* example fsdp config updates

* add finetuning README

* fix for 2nd oneshot stage

* cleaning up stage logic

* mvp for single GPU fsdp

* WIP obcq FSDP

* quality

* sgpt wrapper

* clean up on finalize, improve logging

* cleaning up device args

* merge alternating

* WIP alternating

* fsdp compatible, training loss issue

* fix for loss bug

* fixing checkpoint issue

* fix for quantization saving

* cleanup after merge

* unit test fix

* clean up logging

* move FSDP to helper files

* update docstrings, clean up

* fix circular import

* unmodify example

* fix typo!

* setup FSDP for when starting from oneshot

* update setup and readme

* fix CLI issue, update README

* POC for sequential FSDP OBCQ (#1947)

* fix GHA line lost in merge

* fix calib loading

* fix dependencies

* reverting OBCQ merged changes for now

* restore SparseCausalModel for now

* add progress bar for calibration forward pass (#1950)

Sara Adkins authored Jan 11, 2024
1 parent 691cb49 commit 5007b8c
Showing 32 changed files with 903 additions and 763 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test-check.yaml
@@ -59,6 +59,8 @@ jobs:
id: export-check
run: >
((git diff --name-only origin/main HEAD | grep -E "[src|tests]/sparseml/export|setup.py|.github")
|| (echo $GITHUB_REF | grep -E "refs/heads/[release/|main]"))
&& echo "::set-output name=output::1" || echo "::set-output name=output::0"
- name: "Checking if sparseml.transformers was changed"
id: transformers-check
run: >
@@ -230,7 +232,7 @@ jobs:
- name: "Clean sparsezoo directory"
run: rm -r sparsezoo/
- name: "⚙️ Install dependencies"
run: pip3 install .[dev,torch,transformers]
run: pip3 install .[dev,torch,transformers,onnxruntime]
- name: "🔬 Running transformers tests"
run: make test TARGETS=transformers
export-tests:
@@ -11,11 +11,9 @@ fsdp_config:
fsdp_sharding_strategy: 1
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
2 changes: 2 additions & 0 deletions setup.py
@@ -83,6 +83,7 @@
"scikit-learn",
"seqeval",
"einops",
"onnxruntime>=1.0.0",
"accelerate>=0.20.3",
]
_yolov5_deps = _pytorch_vision_deps + [
@@ -205,6 +206,7 @@ def _setup_entry_points() -> Dict:
"sparseml.transformers.text_generation.train=sparseml.transformers.finetune.text_generation:run_train", # noqa 501
"sparseml.transformers.text_generation.finetune=sparseml.transformers.finetune.text_generation:run_train", # noqa 501
"sparseml.transformers.text_generation.eval=sparseml.transformers.finetune.text_generation:run_eval", # noqa 501
"sparseml.transformers.text_generation.oneshot=sparseml.transformers.finetune.text_generation:run_oneshot", # noqa 501
]
)

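The new `sparseml.transformers.text_generation.oneshot` console script above points at `run_oneshot` in `sparseml.transformers.finetune.text_generation`. As a hedged sketch (standard-library mechanics only, assuming Python 3.10+ for the `importlib.metadata` selection API), this is how such a console-script entry resolves to its target function; `run_oneshot`'s argument signature is not part of this diff, so the call itself is omitted.

```python
# Hedged sketch: resolving the new console script via importlib.metadata.
# Assumes Python 3.10+ and that the package is installed; run_oneshot's
# signature is not shown in this diff, so it is only loaded, not called.
from importlib.metadata import entry_points

scripts = entry_points(group="console_scripts")
ep = scripts["sparseml.transformers.text_generation.oneshot"]
print(ep.value)          # sparseml.transformers.finetune.text_generation:run_oneshot
run_oneshot = ep.load()  # the function registered by setup.py above
```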
45 changes: 18 additions & 27 deletions src/sparseml/modifiers/obcq/base.py
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Optional, Union

from sparseml.core.factory import ModifierFactory
from sparseml.core.state import State
@@ -29,48 +29,39 @@ class SparseGPTModifier(WandaPruningModifier):
"""
Modifier for applying the one-shot OBCQ algorithm to a model
Life-cycle:
- initialze
- compress
- finalize
Lifecycle:
- on_initialize
- initialize_compression()
- compressible_layers()
- LayerCompressor.pre_compress()
- apply_compression()
- run_calibration_forward()
- LayerCompressor.compress()
- LayerCompressor.post_compress()
- on_finalize
- LayerCompressor.revert_layer_wrappers()
:param sparsity: Sparsity to compress model to
:param block_size: Used to determine number of columns to compress in one pass
:param quantize: Whether or not to quantize weights during SparseGPT. Set to
True to quantize using an existing quantization modifier, or pass in the
configuration for a quantization modifier if one does not already exist
in the recipe
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param sequential_update: Whether or not to update weights sequentially by layer,
True saves on GPU memory
:param mask_structure: String to define the structure of the mask to apply.
Must be of the form N:M where N, M are integers that define a custom block
shape. Defaults to 0:0 which represents an unstructured mask.
:param targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model
:param target_ids: list of keys in model output to cache, NOTE: this argument
has been deprecated and will be removed in a future release
"""

block_size: int
quantize: Union[bool, Dict]
dampening_frac: Optional[float] = 0.01
sequential_update: Optional[bool] = True
prunen_: Optional[int] = None
prunem_: Optional[int] = None
target_ids: Optional[List[str]] = None
layer_prefix: Optional[str] = None
quantization_modifier_: Any = None

def __post_init__(self):
if self.target_ids is not None:
_LOGGER.warning(
"`target_ids` param has been deprecated and will be "
"removed in a future release"
)

def on_initialize_structure(self, state: State, **kwargs):
"""
Check the model's quantization state matches that expected by this modifier,
adding a default quantization scheme if needed
:param state: session state storing input model and calibration data
"""
quantization_already_active = state.model.qat_active()
if isinstance(self.quantize, bool):
if not self.quantize and quantization_already_active:
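As a hedged sketch of the parameters documented in the docstring above: the field names below come straight from this diff, keyword construction is assumed from the Pydantic-style field declarations, and `sparsity`, `mask_structure`, and `targets` are inherited from `WandaPruningModifier` per the docstring. The values are illustrative, not recommendations.

```python
# Illustrative sketch only -- field names are taken from the diff above;
# keyword construction is assumed from the Pydantic-style declarations.
from sparseml.modifiers.obcq.base import SparseGPTModifier

modifier = SparseGPTModifier(
    sparsity=0.5,            # target sparsity (inherited from WandaPruningModifier)
    block_size=128,          # columns compressed per pass
    quantize=True,           # reuse an existing quantization modifier from the recipe
    dampening_frac=0.01,     # dampening applied to the Hessian diagonal
    sequential_update=True,  # update weights layer by layer to save GPU memory
    mask_structure="0:0",    # 0:0 => unstructured; N:M for semi-structured masks
    targets="__ALL__",       # compress every layer in the model
)
```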
77 changes: 41 additions & 36 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -12,79 +12,84 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List, Optional

from functools import partial
from typing import Any, Optional

from sparseml.core.model import ModifiableModel
from sparseml.core.state import State
from sparseml.modifiers.obcq.base import SparseGPTModifier
from sparseml.modifiers.obcq.utils.layer_compressor import OBCQLayerCompressor
from sparseml.modifiers.obcq.utils.sgpt_wrapper import SparseGptWrapper
from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch


__all__ = ["SparseGPTModifierPyTorch"]

_LOGGER = logging.getLogger(__name__)


class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier):
"""
Pytorch implementation of SparseGPT
Lifecycle:
- on_initialize
- setup
- compressible_layers
- prune
- compress_bottom
- LayerCompressor.compress
- initialize_compression()
- compressible_layers()
- LayerCompressor.pre_compress()
- apply_compression()
- run_calibration_forward()
- LayerCompressor.compress()
- LayerCompressor.post_compress()
- on_finalize
- LayerCompressor.revert_layer_wrappers()
:param model: Pytorch model to perform OBCQ on, in-place
"""

model: Any = None
device_: str = "cuda:0"
layer_prefix_: Optional[str] = None
layer_compressor_class_ = OBCQLayerCompressor
model: Optional[ModifiableModel] = None
layer_compressors: List = None

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
:param state: session state storing input model and calibration data
"""
self._validate_layerwise_sparsity()

if not self.initialized_structure_:
self.on_initialize_structure(state, **kwargs)
if self.quantization_modifier_:
self.quantization_modifier_.initialize(state, **kwargs)

# attach target_ids to `compress_bottom` for OBCQ
# this must be done before calling super().on_initialize
return super(SparseGPTModifierPyTorch, self).on_initialize(state, **kwargs)

compress_bottom = partial(self.compress_bottom, target_ids=self.target_ids)
def on_finalize(self, state: "State", **kwargs) -> bool:
"""
disable the quantization observers used by the OBCQ algorithm
# we need setattr here because of Pydantic's internal data model
object.__setattr__(self, "compress_bottom", compress_bottom)
return super().on_initialize(state=state, **kwargs)
:param state: session state storing input model and calibration data
"""
if self.quantization_modifier_:
self.quantization_modifier_.finalize(state, **kwargs)

def _get_compression_args(self, layer_sparsity):
return super(SparseGPTModifierPyTorch, self).on_finalize(state, **kwargs)

def _pruning_arguments(self, sparsity):
"""
Gather the parameters needed for root module compression in a dict
:param sparsity: target sparsity
:return: dict of params for pruning
"""
return {
**super()._get_compression_args(layer_sparsity=layer_sparsity),
**{
"blocksize": self.block_size,
"percdamp": self.dampening_frac,
"sequential_update": self.sequential_update,
"quantize": self.quantize,
},
"sparsity": sparsity,
"prunen": self.prunen_,
"prunem": self.prunem_,
"blocksize": self.block_size,
"percdamp": self.dampening_frac,
}

def on_finalize(self, state: State, **kwargs) -> bool:
def _compression_class(self):
"""
disable the observers used by the OBCQ algorithm and set kv-cache configuration
:param state: un-used, for matching spec of Modifier base class
:return: wrapper class used for root modules of this compression class
"""
if self.quantization_modifier_:
self.quantization_modifier_.finalize(state, **kwargs)
return super().on_finalize(state, **kwargs)
return SparseGptWrapper
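
To make the refactor concrete, here is roughly the per-layer argument dict that `_pruning_arguments()` assembles, based only on the return statement shown above; the sparsity and block size are hypothetical, and `prunen_`/`prunem_` are `None` by default per `base.py` (unstructured pruning).

```python
# Illustrative only: the dict _pruning_arguments() returns for one layer,
# with hypothetical sparsity=0.5 and block_size=128; percdamp mirrors the
# dampening_frac default of 0.01 from base.py.
compression_args = {
    "sparsity": 0.5,   # target sparsity for this layer
    "prunen": None,    # N of an N:M mask; None => unstructured
    "prunem": None,    # M of an N:M mask
    "blocksize": 128,  # columns compressed per pass (block_size)
    "percdamp": 0.01,  # Hessian dampening fraction (dampening_frac)
}
```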
80 changes: 0 additions & 80 deletions src/sparseml/modifiers/obcq/utils/layer_compressor.py

This file was deleted.


