
Fix bug of FLOPs counter (#3497)
colorjam authored Apr 21, 2021
1 parent 638da0b commit 7fd0776
Showing 1 changed file with 126 additions and 11 deletions.
nni/compression/pytorch/utils/counter.py

@@ -7,6 +7,7 @@

 import torch
 import torch.nn as nn
+from torch.nn.utils.rnn import PackedSequence
 from nni.compression.pytorch.compressor import PrunerModuleWrapper


@@ -32,21 +33,27 @@ def __init__(self, custom_ops=None, mode='default'):
             for reference, please see ``self.ops``.
         mode:
             the mode of how to collect information. If the mode is set to `default`,
-            only the information of convolution and linear will be collected.
+            only the information of convolution, linear and RNN modules will be collected.
             If the mode is set to `full`, other operations will also be collected.
         """
         self.ops = {
             nn.Conv1d: self._count_convNd,
             nn.Conv2d: self._count_convNd,
             nn.Conv3d: self._count_convNd,
-            nn.Linear: self._count_linear
+            nn.ConvTranspose1d: self._count_convNd,
+            nn.ConvTranspose2d: self._count_convNd,
+            nn.ConvTranspose3d: self._count_convNd,
+            nn.Linear: self._count_linear,
+            nn.RNNCell: self._count_rnn_cell,
+            nn.GRUCell: self._count_gru_cell,
+            nn.LSTMCell: self._count_lstm_cell,
+            nn.RNN: self._count_rnn,
+            nn.GRU: self._count_gru,
+            nn.LSTM: self._count_lstm
         }
         self._count_bias = False
         if mode == 'full':
             self.ops.update({
-                nn.ConvTranspose1d: self._count_convNd,
-                nn.ConvTranspose2d: self._count_convNd,
-                nn.ConvTranspose3d: self._count_convNd,
                 nn.BatchNorm1d: self._count_bn,
                 nn.BatchNorm2d: self._count_bn,
                 nn.BatchNorm3d: self._count_bn,
@@ -86,7 +93,7 @@ def _get_result(self, m, flops):

     def _count_convNd(self, m, x, y):
         cin = m.in_channels
-        kernel_ops = m.weight.size()[2] * m.weight.size()[3]
+        kernel_ops = torch.zeros(m.weight.size()[2:]).numel()
         output_size = torch.zeros(y.size()[2:]).numel()
         cout = y.size()[1]

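The one-line change above is the bug the commit title refers to: the old kernel_ops expression hard-coded two spatial dimensions, so it raised an IndexError for Conv1d and undercounted Conv3d. A minimal sketch of the difference (my example, not part of the commit):

import torch
import torch.nn as nn

m = nn.Conv3d(in_channels=2, out_channels=4, kernel_size=3)
# old: m.weight.size()[2] * m.weight.size()[3] -> 9, and IndexError for Conv1d
# new: product over every spatial dimension of the kernel
kernel_ops = torch.zeros(m.weight.size()[2:]).numel()  # 3 * 3 * 3 = 27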
@@ -156,13 +163,125 @@ def _count_upsample(self, m, x, y):

         return self._get_result(m, total_ops)

+    def _count_cell_flops(self, input_size, hidden_size, cell_type):
+        # h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
+        total_ops = hidden_size * (input_size + hidden_size) + hidden_size
+
+        if self._count_bias:
+            total_ops += hidden_size * 2
+
+        if cell_type == 'rnn':
+            return total_ops
+
+        if cell_type == 'gru':
+            # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
+            # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
+            # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
+            total_ops *= 3
+
+            # r hadamard : r * (~)
+            total_ops += hidden_size
+
+            # h' = (1 - z) * n + z * h
+            # hadamard hadamard add
+            total_ops += hidden_size * 3
+
+        elif cell_type == 'lstm':
+            # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
+            # f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
+            # o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
+            # g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
+            total_ops *= 4
+
+            # c' = f * c + i * g
+            # hadamard hadamard add
+            total_ops += hidden_size * 3
+
+            # h' = o * \tanh(c')
+            total_ops += hidden_size
+
+        return total_ops
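As a quick sanity check of the LSTM branch (my arithmetic, not part of the commit): with input_size=3, hidden_size=4 and biases excluded, the count per timestep and per sample works out as follows.

hidden_size, input_size = 4, 3
base = hidden_size * (input_size + hidden_size) + hidden_size  # 32 ops for one gate
total = base * 4                                               # 128: the i, f, o, g gates
total += hidden_size * 3                                       # +12: c' = f * c + i * g
total += hidden_size                                           # +4:  h' = o * tanh(c')
assert total == 144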


+    def _count_rnn_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'rnn')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+
+        return self._get_result(m, total_ops)
+
+    def _count_gru_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'gru')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+
+        return self._get_result(m, total_ops)
+
+    def _count_lstm_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'lstm')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+
+        return self._get_result(m, total_ops)

+    def _get_bsize_nsteps(self, m, x):
+        if isinstance(x[0], PackedSequence):
+            batch_size = torch.max(x[0].batch_sizes)
+            num_steps = x[0].batch_sizes.size(0)
+        else:
+            if m.batch_first:
+                batch_size = x[0].size(0)
+                num_steps = x[0].size(1)
+            else:
+                batch_size = x[0].size(1)
+                num_steps = x[0].size(0)
+
+        return batch_size, num_steps
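The PackedSequence branch reads the batch and step counts off batch_sizes, which records how many sequences are still active at each time step. A small illustration with assumed shapes (my example, not from the commit):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# two sequences of lengths 3 and 2, time-major: (num_steps=3, batch=2, features=5)
packed = pack_padded_sequence(torch.randn(3, 2, 5), lengths=[3, 2])
print(packed.batch_sizes)  # tensor([2, 2, 1]): max gives batch_size=2, length gives num_steps=3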

+    def _count_rnn_module(self, m, x, y, module_name):
+        input_size = m.input_size
+        hidden_size = m.hidden_size
+        num_layers = m.num_layers
+
+        batch_size, num_steps = self._get_bsize_nsteps(m, x)
+        total_ops = self._count_cell_flops(input_size, hidden_size, module_name)
+
+        for _ in range(num_layers - 1):
+            if m.bidirectional:
+                cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size, module_name) * 2
+            else:
+                cell_flops = self._count_cell_flops(hidden_size, hidden_size, module_name)
+            total_ops += cell_flops
+
+        total_ops *= num_steps
+        total_ops *= batch_size
+        return total_ops
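Concretely (my arithmetic, continuing the numbers above): a 2-layer unidirectional LSTM with input_size=3 and hidden_size=4 gets one (3 -> 4) cell for the first layer plus one (4 -> 4) cell for the second, and the sum is then scaled by num_steps and batch_size.

layer0 = (4 * (3 + 4) + 4) * 4 + 4 * 3 + 4  # 144: the first layer sees the input
layer1 = (4 * (4 + 4) + 4) * 4 + 4 * 3 + 4  # 160: later layers see hidden states
per_step = layer0 + layer1                  # 304, before the num_steps * batch_size scaling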

+    def _count_rnn(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'rnn')
+
+        return self._get_result(m, total_ops)
+
+    def _count_gru(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'gru')
+
+        return self._get_result(m, total_ops)
+
+    def _count_lstm(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'lstm')
+
+        return self._get_result(m, total_ops)


     def count_module(self, m, x, y, name):
         # assume x is tuple of single tensor
         result = self.ops[type(m)](m, x, y)
+        output_size = y[0].size() if isinstance(y, tuple) else y.size()

         total_result = {
             'name': name,
             'input_size': tuple(x[0].size()),
-            'output_size': tuple(y.size()),
+            'output_size': tuple(output_size),
             'module_type': type(m).__name__,
             **result
         }
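The output_size change is needed because recurrent modules return tuples from forward rather than a single tensor, so the old tuple(y.size()) call would raise. For example:

import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=10, hidden_size=20)
y = rnn(torch.randn(5, 3, 10))  # y is (output, (h_n, c_n)), not a tensor
print(y[0].size())              # torch.Size([5, 3, 20]); y.size() would fail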
@@ -279,10 +398,6 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
     model(*x)

     # restore original status
-    for name, m in model.named_modules():
-        if hasattr(m, 'weight_mask'):
-            delattr(m, 'weight_mask')
-
     model.train(training).to(original_device)
     for handler in handler_collection:
         handler.remove()
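With the new handlers registered, the entry point in this last hunk can profile recurrent models end to end. A usage sketch; the three-value return convention is my assumption from NNI's documentation of that era, not something shown in this excerpt:

import torch
import torch.nn as nn
from nni.compression.pytorch.utils.counter import count_flops_params

model = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
# x is passed as a tuple of inputs, matching the model(*x) call above
flops, params, results = count_flops_params(model, (torch.randn(5, 3, 10),), verbose=False)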
