Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Refactor flops counter (#3048)
Browse files Browse the repository at this point in the history
  • Loading branch information
colorjam authored Nov 23, 2020
1 parent 291bbbb commit b6233e5
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 105 deletions.
20 changes: 17 additions & 3 deletions docs/en_US/Compression/CompressionUtils.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,28 @@ fixed_mask = fix_mask_conflict('./resnet18_mask', net, data)
```

## Model FLOPs/Parameters Counter
We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating FLOPs/parameters of a normal model without masks, it can also calculates FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for sturctured pruning, we only identify the remained filters according to its mask, which not taking the pruned input channels into consideration, so the calculated FLOPs will be larger than real number (i.e., the number calculated after Model Speedup).
We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating FLOPs/parameters of a normal model without masks, it can also calculates FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for sturctured pruning, we only identify the remained filters according to its mask, which not taking the pruned input channels into consideration, so the calculated FLOPs will be larger than real number (i.e., the number calculated after Model Speedup).

We support two modes to collect information of modules. The first mode is `default`, which only collect the information of convolution and linear. The second mode is `full`, which also collect the information of other operations. Users can easily use our collected `results` for futher analysis.

### Usage
```
from nni.compression.pytorch.utils.counter import count_flops_params
# Given input size (1, 1, 28, 28)
flops, params = count_flops_params(model, (1, 1, 28, 28))
# Given input size (1, 1, 28, 28)
flops, params, results = count_flops_params(model, (1, 1, 28, 28))
# Given input tensor with size (1, 1, 28, 28) and switch to full mode
x = torch.randn(1, 1, 28, 28)
flops, params, results = count_flops_params(model, (x,) mode='full') # tuple of tensor as input
# Format output size to M (i.e., 10^6)
print(f'FLOPs: {flops/1e6:.3f}M, Params: {params/1e6:.3f}M)
print(results)
{
'conv': {'flops': [60], 'params': [20], 'weight_size': [(5, 3, 1, 1)], 'input_size': [(1, 3, 2, 2)], 'output_size': [(1, 5, 2, 2)], 'module_type': ['Conv2d']},
'conv2': {'flops': [100], 'params': [30], 'weight_size': [(5, 5, 1, 1)], 'input_size': [(1, 5, 2, 2)], 'output_size': [(1, 5, 2, 2)], 'module_type': ['Conv2d']}
}
```
Loading

0 comments on commit b6233e5

Please sign in to comment.