This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Model Compression] admm pruner #4116

Merged: 7 commits, Sep 22, 2021
120 changes: 119 additions & 1 deletion nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
@@ -135,7 +135,6 @@ def __init__(self, model: Module, config_list: List[Dict]):
- op_names : Operation names to prune.
- exclude : Set True to exclude the layers selected by op_types and op_names from pruning.
"""
self.mode = 'normal'
super().__init__(model, config_list)

def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
@@ -655,3 +654,122 @@ def reset_tools(self):
self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
else:
raise NotImplementedError('Only support mode `normal`, `global` and `dependency_aware`')


class ADMMPruner(BasicPruner):
"""
ADMM (Alternating Direction Method of Multipliers) Pruner formulates pruning as a mathematical optimization problem and solves it with ADMM.
The metric used in this pruner is the absolute value of the weight.
In each iteration, the weights with small magnitudes are set to zero.
Only in the final iteration is the mask generated and applied to the model wrapper.
Original paper: https://arxiv.org/abs/1804.03294.
"""

def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int, training_epochs: int):
"""
Parameters
----------
model
Model to be pruned.
config_list
Supported keys:
- sparsity : Specify the sparsity ratio for each layer in this config to be compressed.
- sparsity_per_layer : Equivalent to sparsity.
- rho : Penalty parameter in the ADMM algorithm.
- op_types : Operation types to prune.
- op_names : Operation names to prune.
- exclude : Set True to exclude the layers selected by op_types and op_names from pruning.
trainer
A callable used to train the model or just run inference. It takes the model, optimizer, and criterion as input.
The model will be trained or evaluated for `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you only want to run inference and not update the model, skip `optimizer.step()` and keep the model in eval mode.
optimizer.step()
model.train(mode=training)
optimizer
The optimizer instance used in the trainer. Note that this optimizer might be patched during data collection,
so do not use this optimizer elsewhere.
criterion
The criterion function used in the trainer. It takes the model output and target value as input and returns the loss.
iterations
The total number of iterations in the ADMM pruning algorithm.
training_epochs
The number of epochs for training the model in each iteration.
"""
self.trainer = trainer
self.optimizer = optimizer
self.criterion = criterion
self.iterations = iterations
self.training_epochs = training_epochs
super().__init__(model, config_list)

self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()}
Contributor:

please use a clear name rather than Z and U.

Contributor Author:

U could be named scaled_dual_variable, but I have no good name for Z. The author rewrites the original problem as a two-block optimization, and Z can be seen as another copy of the weight variable. I think Z is used because in ADMM the second optimization variable is usually denoted z.
[image: generic two-block ADMM formulation]

And in compression, the problem is:
[image: ADMM formulation of the weight-pruning problem]
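A reconstruction of what the two images most likely showed, based on standard ADMM notation and the formulation in the referenced paper (https://arxiv.org/abs/1804.03294); treat the exact symbols as an assumption. The generic two-block ADMM problem:

$$\min_{x, z} \; f(x) + g(z) \quad \text{s.t.} \quad Ax + Bz = c$$

And the weight-pruning formulation, where $g_i$ is the indicator function of the per-layer sparsity constraint set $S_i$:

$$\min_{\{W_i\}, \{Z_i\}} \; f(\{W_i\}) + \sum_{i=1}^{N} g_i(Z_i) \quad \text{s.t.} \quad W_i = Z_i, \; i = 1, \dots, N$$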

Contributor:

Got it.

self.U = {name: torch.zeros_like(z).to(z.device) for name, z in self.Z.items()}

def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
for schema in schema_list:
schema.update({SchemaOptional('rho'): And(float, lambda n: n > 0)})
schema_list.append(deepcopy(EXCLUDE_SCHEMA))
schema = CompressorSchema(schema_list, model, _logger)
schema.validate(config_list)

def criterion_patch(self, origin_criterion: Callable[[Tensor, Tensor], Tensor]):
def patched_criterion(output: Tensor, target: Tensor):
penalty = torch.tensor(0.0).to(output.device)
for name, wrapper in self.get_modules_wrapper().items():
rho = wrapper.config['rho']
# Augmented-Lagrangian penalty in scaled form: (rho / 2) * ||W - Z + U||_F^2
penalty += (rho / 2) * torch.pow(torch.norm(wrapper.module.weight - self.Z[name] + self.U[name]), 2)
return origin_criterion(output, target) + penalty
return patched_criterion
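For reference, the patched criterion above adds the scaled-form augmented-Lagrangian penalty to the original loss (a sketch of the math the code computes; the per-layer notation is an assumption):

$$\mathcal{L}(W) = \ell(\text{output}, \text{target}) + \sum_{i} \frac{\rho_i}{2} \lVert W_i - Z_i + U_i \rVert_F^2$$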

def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion,
self.training_epochs, criterion_patch=self.criterion_patch)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = NormMetricsCalculator()
if self.sparsity_allocator is None:
self.sparsity_allocator = NormalSparsityAllocator(self)

def compress(self) -> Tuple[Module, Dict]:
"""
Returns
-------
Tuple[Module, Dict]
Return the wrapped model and mask.
"""
for i in range(self.iterations):
_logger.info('======= ADMM Iteration %d Start =======', i)
data = self.data_collector.collect()

for name, weight in data.items():
self.Z[name] = weight + self.U[name]
metrics = self.metrics_calculator.calculate_metrics(self.Z)
masks = self.sparsity_allocator.generate_sparsity(metrics)

for name, mask in masks.items():
self.Z[name] = self.Z[name].mul(mask['weight'])
self.U[name] = self.U[name] + data[name] - self.Z[name]

metrics = self.metrics_calculator.calculate_metrics(data)
masks = self.sparsity_allocator.generate_sparsity(metrics)

self.load_masks(masks)
return self.bound_model, masks
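Each iteration of `compress()` corresponds to the three scaled-form ADMM updates (standard notation; the mapping to the code is an interpretation, not taken from the source):

$$W^{k+1} = \arg\min_{W} \; \ell(W) + \sum_i \frac{\rho_i}{2} \lVert W_i - Z_i^k + U_i^k \rVert_F^2 \quad \text{(training with the patched criterion)}$$

$$Z_i^{k+1} = \Pi_{S_i}\!\left(W_i^{k+1} + U_i^k\right) \quad \text{(masking the smallest-magnitude entries)}$$

$$U_i^{k+1} = U_i^k + W_i^{k+1} - Z_i^{k+1} \quad \text{(dual variable update)}$$

A minimal usage sketch, assuming `model` and `train_loader` are defined elsewhere, the `rho` value is illustrative, and the import path follows the file changed in this diff:

```python
import torch.nn.functional as F
from torch.optim import SGD

from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import ADMMPruner

# Prune 50% of the weights in every Conv2d layer; `rho` is the ADMM
# penalty parameter (an illustrative value, not a recommended default).
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d'], 'rho': 1e-3}]

def trainer(model, optimizer, criterion):
    # One epoch of training; `train_loader` is assumed to exist.
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

optimizer = SGD(model.parameters(), lr=0.01)
pruner = ADMMPruner(model, config_list, trainer, optimizer, F.cross_entropy,
                    iterations=5, training_epochs=2)
pruned_model, masks = pruner.compress()
```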