[Model Compression] admm pruner #4116
Merged

Changes from 5 commits

Commits (7):
- cf05b6e  add admm pruner (J-shang)
- 7db5fa0  update docstr (J-shang)
- f312f7b  Merge branch 'master' into compression_v2_admm (J-shang)
- 7229202  add validation (J-shang)
- aca6cfa  rename row to rho (J-shang)
- 1dbba22  Merge remote-tracking branch 'upstream/master' into compression_v2_admm (J-shang)
- 89fa3f2  update (J-shang)
@@ -135,7 +135,6 @@ def __init__(self, model: Module, config_list: List[Dict]):
            - op_names : Operation names to prune.
            - exclude : If True, the layers selected by op_types and op_names will be excluded from pruning.
        """
        self.mode = 'normal'
        super().__init__(model, config_list)

    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):

@@ -655,3 +654,122 @@ def reset_tools(self):
            self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
        else:
            raise NotImplementedError('Only support mode `normal`, `global` and `dependency_aware`')
class ADMMPruner(OneShotPruner):
    """
    ADMM (Alternating Direction Method of Multipliers) Pruner is based on a mathematical optimization technique.
    The metric used in this pruner is the absolute value of the weight.
    In each iteration, the weights with small magnitudes will be set to zero.
    Only in the final iteration will the mask be generated and applied to the model wrapper.

    The original paper: https://arxiv.org/abs/1804.03294.
    """

    def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
                 optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int, training_epochs: int):
""" | ||
Parameters | ||
---------- | ||
model | ||
Model to be pruned. | ||
config_list | ||
Supported keys: | ||
- sparsity : This is to specify the sparsity for each layer in this config to be compressed. | ||
- sparsity_per_layer : Equals to sparsity. | ||
- rho : Penalty parameters in ADMM algorithm. | ||
- op_types : Operation types to prune. | ||
- op_names : Operation names to prune. | ||
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning. | ||
trainer | ||
A callable function used to train model or just inference. Take model, optimizer, criterion as input. | ||
The model will be trained or inferenced `training_epochs` epochs. | ||
|
||
Example:: | ||
|
||
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]): | ||
training = model.training | ||
model.train(mode=True) | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
for batch_idx, (data, target) in enumerate(train_loader): | ||
data, target = data.to(device), target.to(device) | ||
optimizer.zero_grad() | ||
output = model(data) | ||
loss = criterion(output, target) | ||
loss.backward() | ||
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False. | ||
optimizer.step() | ||
model.train(mode=training) | ||
optimizer | ||
The optimizer instance used in trainer. Note that this optimizer might be patched during collect data, | ||
so do not use this optimizer in other places. | ||
criterion | ||
The criterion function used in trainer. Take model output and target value as input, and return the loss. | ||
iterations | ||
The total iteration number in admm pruning algorithm. | ||
training_epochs | ||
The epoch number for training model in each iteration. | ||
""" | ||
self.trainer = trainer | ||
self.optimizer = optimizer | ||
self.criterion = criterion | ||
self.iterations = iterations | ||
self.training_epochs = training_epochs | ||
super().__init__(model, config_list) | ||
|
||
self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()} | ||
(Inline review comment on the line above) Reviewer: "please use a clear name rather than ..." Author reply: "Got it."
        self.U = {name: torch.zeros_like(z).to(z.device) for name, z in self.Z.items()}

    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
        for schema in schema_list:
            schema.update({SchemaOptional('rho'): And(float, lambda n: n > 0)})
        schema_list.append(deepcopy(EXCLUDE_SCHEMA))
        schema = CompressorSchema(schema_list, model, _logger)
        schema.validate(config_list)
    def criterion_patch(self, origin_criterion: Callable[[Tensor, Tensor], Tensor]):
        def patched_criterion(output: Tensor, target: Tensor):
            penalty = torch.tensor(0.0).to(output.device)
            for name, wrapper in self.get_modules_wrapper().items():
                rho = wrapper.config['rho']
                penalty += (rho / 2) * torch.sqrt(torch.norm(wrapper.module.weight - self.Z[name] + self.U[name]))
            return origin_criterion(output, target) + penalty
        return patched_criterion

    def reset_tools(self):
        if self.data_collector is None:
            self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion,
                                                                  self.training_epochs, criterion_patch=self.criterion_patch)
        else:
            self.data_collector.reset()
        if self.metrics_calculator is None:
            self.metrics_calculator = NormMetricsCalculator()
        if self.sparsity_allocator is None:
            self.sparsity_allocator = NormalSparsityAllocator(self)
    def compress(self) -> Tuple[Module, Dict]:
        """
        Returns
        -------
        Tuple[Module, Dict]
            Return the wrapped model and mask.
        """
        for i in range(self.iterations):
            _logger.info('======= ADMM Iteration %d Start =======', i)
            data = self.data_collector.collect()

            for name, weight in data.items():
                self.Z[name] = weight + self.U[name]
            metrics = self.metrics_calculator.calculate_metrics(self.Z)
            masks = self.sparsity_allocator.generate_sparsity(metrics)

            for name, mask in masks.items():
                self.Z[name] = self.Z[name].mul(mask['weight_mask'])
                self.U[name] = self.U[name] + data[name] - self.Z[name]

        metrics = self.metrics_calculator.calculate_metrics(data)
        masks = self.sparsity_allocator.generate_sparsity(metrics)

        self.load_masks(masks)
        return self.bound_model, masks
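
For orientation (not part of this diff), here is a minimal usage sketch based on the constructor documented above. The import path, MyModel, and get_train_loader are placeholders and assumptions, not names taken from this PR.

    import torch
    import torch.nn.functional as F
    from torch.optim import SGD
    # Assumed import path; the actual location depends on the NNI version and this PR's module layout.
    from nni.algorithms.compression.v2.pytorch.pruning import ADMMPruner

    model = MyModel()                  # hypothetical model
    train_loader = get_train_loader()  # hypothetical data loader

    def trainer(model, optimizer, criterion):
        # One pass over the training data; per the docstring, the model is trained
        # for `training_epochs` epochs in each ADMM iteration.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.train()
        model.to(device)
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5, 'rho': 1e-3}]
    optimizer = SGD(model.parameters(), lr=0.01)
    pruner = ADMMPruner(model, config_list, trainer=trainer, optimizer=optimizer,
                        criterion=F.cross_entropy, iterations=10, training_epochs=1)
    pruned_model, masks = pruner.compress()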
Review comment:
It seems that ADMMPruner has some functionalities of the PruningScheduler, e.g., it performs multiple iterations and keeps track of context/state data like Z and U (I'm not sure whether this is supposed to be handled by the "task" abstraction we discussed). How do we want to integrate ADMMPruner with the scheduler logic?
Reply:
Yes, ADMM has scheduler-like logic, but in fact it does not generate masks during each iteration. ADMM only resets the elements with small magnitudes to zero, and these elements will still be trained in the following iterations. Only in the last iteration does ADMM generate masks. So making ADMM a pruner is probably more reasonable.
Reply:
I see. I just read the ADMM paper, and it seems that the iterative elements are for solving the optimization problem instead of iterative pruning.
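
For reference (not part of the PR discussion), the loop in compress() above corresponds roughly to the standard ADMM updates from the referenced paper; a sketch in LaTeX, where $W$ are the weights, $Z$ the auxiliary sparse variable, $U$ the scaled dual variable, $\mathcal{L}$ the task loss, and $S$ the sparsity constraint set:

$$
\begin{aligned}
W^{k+1} &\approx \arg\min_{W}\ \mathcal{L}(W) + \sum_{i} \frac{\rho_i}{2} \lVert W_i - Z_i^{k} + U_i^{k} \rVert_F^2 && \text{(approximated by a few epochs of training with the patched criterion)} \\
Z^{k+1} &= \Pi_{S}\bigl(W^{k+1} + U^{k}\bigr) && \text{(keep the largest-magnitude entries, zero the rest)} \\
U^{k+1} &= U^{k} + W^{k+1} - Z^{k+1} && \text{(dual update)}
\end{aligned}
$$

The penalty added in criterion_patch is this PR's variant of the quadratic term in the W-update, and masks are generated from the weights only after the last iteration, which matches the explanation above.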