Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add flops and params counter #2535

Merged
merged 21 commits into from
Jun 30, 2020
Merged

Add flops and params counter #2535

merged 21 commits into from
Jun 30, 2020

Conversation

colorjam
Copy link
Contributor

@colorjam colorjam commented Jun 7, 2020

Support flops and params calculation with mask.

from nni.counter import count_flops_params
flops, params = count_flops_params(model, (1, 1, 28, 28))
print(flops, params)

@ultmaster
Copy link
Contributor

Please remove accidentally pushed pth files.

weight_mask = None
m_type = type(m)
if m_type in custom_ops:
if isinstance(m_list[idx-1], PrunerModuleWrapper):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if idx == 0? What if PrunerModuleWrapper has multiple children?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks~ PrunerModuleWrapper is our customize wrapper and it directly wraps Conv or Linear module, so it won't have multiple children.

@QuanluZhang
Copy link
Contributor

@QuanluZhang QuanluZhang reopened this Jun 20, 2020
if isinstance(prev_m, PrunerModuleWrapper):
weight_mask = prev_m.weight_mask

m.register_buffer('weight_mask', weight_mask)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to register buffer here if the model is not PrunerModuleWrapper?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not directly register buffer on PrunerModuleWrapper. We only register buffer on Conv or Linear (please see custom_ops in Line 96) where the previous module was PrunerModuleWrapper.

except ImportError:
_logger.warning('Please install thop using command: pip install thop')

def count_flops_params(model: nn.Module, input_size=None, verbose=True):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest to remove default value of input_size

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix it. please review it on the latest version~

@chicm-ms
Copy link
Contributor

Please remove the pth files.

@chicm-ms chicm-ms merged commit a3b0bd7 into microsoft:master Jun 30, 2020
@chicm-ms chicm-ms mentioned this pull request Jul 1, 2020
24 tasks
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants