@@ -655,3 +654,122 @@ def reset_tools(self):
            self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
        else:
            raise NotImplementedError('Only support mode `normal`, `global` and `dependency_aware`')


class ADMMPruner(OneShotPruner):
It seems that ADMMPruner duplicates some of the functionality of the PruningScheduler, e.g., it performs multiple iterations and keeps track of context/state data like Z and U (I'm not sure whether this is supposed to be handled by the "task" abstraction we discussed). How do we want to integrate ADMMPruner with the scheduler logic?
Yes, ADMM has scheduler-like logic, but it does not actually generate masks during each iteration. In each iteration ADMM only resets the elements with small magnitudes to zero, and those elements continue to be trained in the following iterations. Only in the last iteration does ADMM generate masks, so it seems more reasonable to make ADMM a pruner.
I see. I just read the ADMM paper, and it seems the iterations are for solving the optimization problem rather than for iterative pruning.
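The per-iteration behavior described above (zeroing small magnitudes while keeping the weights trainable, with masks derived only at the end) can be sketched as one ADMM-style round. This is a hedged NumPy illustration of the general pattern, not the actual NNI implementation; `project_to_sparsity` is a hypothetical helper name.

```python
import numpy as np

def project_to_sparsity(w, sparsity):
    """Euclidean projection onto the sparsity constraint: keep the
    largest-magnitude elements, zero out the rest (hypothetical helper)."""
    z = w.copy().ravel()
    k = int(z.size * sparsity)           # number of elements to zero out
    if k > 0:
        idx = np.argsort(np.abs(z))[:k]  # indices of the smallest magnitudes
        z[idx] = 0.0
    return z.reshape(w.shape)

# One ADMM-style round: Z is the projected (sparse) copy of the weights,
# U accumulates the scaled dual residual. The weights W stay dense and
# keep training; a mask would only be derived from Z in the final round.
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 4))              # stand-in for a layer's weights
U = np.zeros_like(W)                     # scaled dual variable
Z = project_to_sparsity(W + U, sparsity=0.5)  # Z-update (projection step)
U = U + W - Z                                 # U-update (dual ascent step)
```

Note that `W` itself is never masked here; the sparsity lives only in `Z`, which matches the point that intermediate iterations solve the optimization problem rather than prune.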
    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
        for schema in schema_list:
            schema.update({SchemaOptional('row'): And(float, lambda n: n > 0)})
Recommend using 'rho' instead, since 'row' is confusing. But this involves a change to the original API, so we might not want to do that.
But I do think "row" is very confusing.
Yes, it's a good suggestion; `rho` is better, I will modify it.
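For reference, the check that the schema entry above expresses (`And(float, lambda n: n > 0)`) amounts to "if the key is present, it must be a positive float". A minimal plain-Python sketch of that validation, with a hypothetical default value and function name:

```python
def validate_rho(config: dict) -> float:
    """Validate the optional ADMM penalty parameter `rho`: if present it
    must be a positive float. Mirrors And(float, lambda n: n > 0) from the
    schema entry; the 1e-4 default here is illustrative, not NNI's."""
    rho = config.get('rho', 1e-4)
    if not isinstance(rho, float) or rho <= 0:
        raise ValueError(f'`rho` must be a positive float, got {rho!r}')
    return rho
```

Because the check requires `float`, an integer such as `rho=1` would be rejected, just as with the `schema` package's `And(float, ...)`.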
        self.training_epochs = training_epochs
        super().__init__(model, config_list)

        self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()}
Please use clearer names rather than `Z` and `U`.
Got it.
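The line under discussion initializes the auxiliary variable as an independent per-layer copy of the weights (hence `clone().detach()`). A hedged NumPy sketch of the same pattern, with a hypothetical `weights` dict standing in for `get_modules_wrapper()`:

```python
import numpy as np

# Hypothetical stand-in for get_modules_wrapper(): layer name -> weight array.
weights = {'conv1': np.ones((3, 3)), 'fc': np.zeros((2, 4))}

# Initialize the ADMM auxiliary variable (Z in the code) as an independent
# copy of each layer's weights, and the scaled dual variable (U) as zeros
# of the same shape -- mirroring the clone().detach() idiom in PyTorch.
Z = {name: w.copy() for name, w in weights.items()}
U = {name: np.zeros_like(w) for name, w in weights.items()}

# Mutating Z must not alter the original weights; that is why it is copied
# rather than aliased.
Z['conv1'][0, 0] = 42.0
```

The copy matters because the Z-update overwrites these arrays every iteration while the real weights continue training.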