fix: Fixed support of modules that are used multiple times during forward (#24)

* chore: Updated requirements

* fix: Fixed hooking mechanism edge case

When a module is called several times during a forward pass, removing the hook after its first call raises an error; a sketch of the bookkeeping pattern adopted instead follows the torchscan/crawler.py diff below.

* feat: Added support of nn.Flatten

* test: Added unittests for torch.nn.Flatten

* feat: Added support of receptive field for Flatten

* test: Updated unittest

* fix: Fixed support of reused modules

* refactor: Simplified hooking system

* fix: Fixed name & depth for reused modules

* refactor: Refactored name & depth resolution into primary hooking

* test: Fixed unittests
frgfm authored Aug 5, 2020
1 parent 900eb16 commit 4db02c1
Showing 10 changed files with 146 additions and 92 deletions.
2 changes: 1 addition & 1 deletion .conda/meta.yaml
@@ -19,7 +19,7 @@ requirements:
- python>=3.6

run:
- pytorch >=1.1.0, <=1.4.0
- pytorch >=1.5.0

test:
# Python imports
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1 +1 @@
torch>=1.1.0
torch>=1.5.0
2 changes: 1 addition & 1 deletion setup.py
@@ -40,7 +40,7 @@ def write_version_file():
readme = f.read()

requirements = [
'torch>=1.1.0'
'torch>=1.5.0'
]

setup(
6 changes: 3 additions & 3 deletions test/test_crawler.py
@@ -15,9 +15,9 @@ def test_apply(self):
mod = nn.Sequential(nn.Conv2d(3, 16, 3), multi_convs)

# Tag module attributes
def tag_name(mod, depth, name):
mod.__depth__ = depth
mod.__name__ = name
def tag_name(mod, name):
mod.__depth__ = len(name.split('.')) - 1
mod.__name__ = name.rpartition('.')[-1]

crawler.apply(mod, tag_name)

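The updated `tag_name` above derives depth and display name from the dotted path that the new `crawler.apply` passes to its callback. Below is a minimal, illustrative sketch of that resolution; the model and the `show` callback are hypothetical, and the `from torchscan import crawler` import is assumed to match what the tests use.

```python
import torch.nn as nn
from torchscan import crawler

# Arbitrary example model with one nested container
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3)))

def show(mod, name):
    # Depth is the number of dots in the fully-qualified name,
    # and the display name is the last component of that path.
    depth = len(name.split('.')) - 1
    print(depth, name, name.rpartition('.')[-1])

crawler.apply(model, show)
# 0 sequential sequential
# 1 sequential.0 0
# 1 sequential.1 1
# 2 sequential.1.0 0
# 2 sequential.1.1 1
```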
3 changes: 3 additions & 0 deletions test/test_modules.py
@@ -27,6 +27,7 @@ def test_module_flops(self):
4 * (2 * 8 - 1))
# Activations
self.assertEqual(modules.module_flops(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), 0)
self.assertEqual(modules.module_flops(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), 0)
self.assertEqual(modules.module_flops(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 8)
self.assertEqual(modules.module_flops(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 48)
self.assertEqual(modules.module_flops(nn.LeakyReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 32)
@@ -130,6 +131,7 @@ def test_module_dmas(self):
4 * (8 + 1) + 8 + 4)
# Activation
self.assertEqual(modules.module_dmas(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), 8)
self.assertEqual(modules.module_dmas(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), 16)
self.assertEqual(modules.module_dmas(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 8 * 2)
self.assertEqual(modules.module_dmas(nn.ReLU(inplace=True), torch.zeros((1, 8)), None), 8)
self.assertEqual(modules.module_dmas(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 17)
@@ -176,6 +178,7 @@ def test_module_rf(self):
(1, 1, 0))
# Activation
self.assertEqual(modules.module_rf(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
self.assertEqual(modules.module_rf(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
self.assertEqual(modules.module_rf(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
self.assertEqual(modules.module_rf(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
self.assertEqual(modules.module_rf(nn.Sigmoid(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
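The new assertions above encode the cost model this commit adds for `nn.Flatten`: no floating-point work, one read plus one write for memory traffic, and no effect on the receptive field. A small usage sketch, assuming the same `modules` import as in the tests:

```python
import torch
from torch import nn
from torchscan import modules

flat = nn.Flatten()
x = torch.zeros((1, 8))
y = flat(x)

print(modules.module_flops(flat, x, y))  # 0 -> flattening performs no arithmetic
print(modules.module_dmas(flat, x, y))   # 16 -> one read + one write of the 8 elements
print(modules.module_rf(flat, x, y))     # (1, 1, 0) -> receptive field unchanged
```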
208 changes: 125 additions & 83 deletions torchscan/crawler.py
@@ -15,21 +15,20 @@
__all__ = ['crawl_module', 'summary']


def apply(module, fn, depth=0, name=None):
def apply(module, fn, name=None):
"""Modified version of `torch.nn.Module.apply` method
Args:
module (torch.nn.Module): target module
fn (callable): function to apply to each module
depth (int, optional): current depth of `module`
name (str, optional): name of the current module
"""

if name is None:
name = module.__class__.__name__.lower()
fn(module, depth, name)
fn(module, name)
for n, m in module.named_children():
apply(m, fn, depth + 1, n)
apply(m, fn, f"{name}.{n}")


def crawl_module(module, input_shape, dtype=None, max_depth=None):
@@ -72,94 +71,131 @@ def crawl_module(module, input_shape, dtype=None, max_depth=None):
input_ts = [torch.rand(1, *in_shape).to(dtype=_dtype, device=device)
for in_shape, _dtype in zip(input_shape, dtype)]

pre_fw_handles, post_fw_handles = [], []
pre_hook_tracker = {}
post_hook_tracker = {}

# Hook definition
def _hook_info(module, depth, name):
def _hook_info(module, name):

def _pre_hook(module, input):
"""Pre-forward hook"""

# Params
grad_params, nograd_params, param_size = 0, 0, 0
num_buffers, buffer_size = 0, 0
is_shared = False
if not any(module.children()):
# Parameters
for p in module.parameters():
if id(p) not in param_ids:
if p.requires_grad:
grad_params += p.data.numel()
# Check that another hook has not been triggered at this forward stage
if not pre_hook_tracker[id(module)]['is_used'] and \
(pre_hook_tracker[id(module)]['target'] == pre_hook_tracker[id(module)]['current']):
# Add information
# Params
grad_params, nograd_params, param_size = 0, 0, 0
num_buffers, buffer_size = 0, 0
is_shared = False
if not any(module.children()):
# Parameters
for p in module.parameters():
if id(p) not in param_ids:
if p.requires_grad:
grad_params += p.data.numel()
else:
nograd_params += p.data.numel()
param_size += p.data.numel() * p.data.element_size()
param_ids.append(id(p))
else:
nograd_params += p.data.numel()
param_size += p.data.numel() * p.data.element_size()
param_ids.append(id(p))
else:
is_shared = True
# Buffers
for b in module.buffers():
if id(b) not in param_ids:
num_buffers += b.numel()
buffer_size += b.numel() * b.element_size()
param_ids.append(id(b))
else:
is_shared = True

call_idxs[id(module)] = len(info)

info.append(dict(name=name,
depth=depth,
type=module.__class__.__name__,
input_shape=(-1, *input[0][0].shape[1:]),
output_shape=None,
grad_params=grad_params,
nograd_params=nograd_params,
param_size=param_size,
num_buffers=num_buffers,
buffer_size=buffer_size,
flops=0,
macs=0,
dmas=0,
rf=1,
s=1,
p=0,
is_shared=is_shared,
is_leaf=not any(module.children())))

# Remove the hook by using its handle
pre_fw_handle.remove()
is_shared = True
# Buffers
for b in module.buffers():
if id(b) not in param_ids:
num_buffers += b.numel()
buffer_size += b.numel() * b.element_size()
param_ids.append(id(b))
else:
is_shared = True

if call_idxs.get(id(module)) is None:
call_idxs[id(module)] = [len(info)]
else:
call_idxs[id(module)].append(len(info))

info.append(dict(name=name.rpartition('.')[-1],
depth=len(name.split('.')) - 1,
type=module.__class__.__name__,
input_shape=(-1, *input[0][0].shape[1:]),
output_shape=None,
grad_params=grad_params,
nograd_params=nograd_params,
param_size=param_size,
num_buffers=num_buffers,
buffer_size=buffer_size,
flops=0,
macs=0,
dmas=0,
rf=1,
s=1,
p=0,
is_shared=is_shared,
is_leaf=not any(module.children())))
# Mark the next hook for execution
pre_hook_tracker[id(module)]['target'] += 1
# Current pass already used one of the hooks
pre_hook_tracker[id(module)]['is_used'] = True
pre_hook_tracker[id(module)]['current'] += 1
# All the hooks have been checked, reset the temporary values
if pre_hook_tracker[id(module)]['current'] == len(module._forward_pre_hooks):
pre_hook_tracker[id(module)]['current'] = 0
pre_hook_tracker[id(module)]['is_used'] = False

def _fwd_hook(module, input, output):
"""Post-forward hook"""

# Retrieve forward index
fw_idx = call_idxs[id(module)]

if any(module.children()):
tot_flops, tot_macs, tot_dmas = 0, 0, 0
current_rf, current_stride, current_padding = 1, 1, 0
else:
# Compute stats for standalone layers
tot_flops = module_flops(module, input[0], output)
tot_macs = module_macs(module, input[0], output)
tot_dmas = module_dmas(module, input[0], output)
current_rf, current_stride, current_padding = module_rf(module, input[0], output)

# Update layer information
info[fw_idx]['output_shape'] = (-1, *output.shape[1:])
# Add them, since some modules can be used several times
info[fw_idx]['flops'] = tot_flops
info[fw_idx]['macs'] = tot_macs
info[fw_idx]['dmas'] = tot_dmas
# Compute receptive field
info[fw_idx]['rf'] = current_rf
info[fw_idx]['s'] = current_stride
info[fw_idx]['p'] = current_padding

# Remove the hook by using its handle
post_fw_handle.remove()

# Hook only leaf children
pre_fw_handle = module.register_forward_pre_hook(_pre_hook)
post_fw_handle = module.register_forward_hook(_fwd_hook)
# Check that another hook has not been triggered at this forward stage
if not post_hook_tracker[id(module)]['is_used'] and \
(post_hook_tracker[id(module)]['target'] == post_hook_tracker[id(module)]['current']):
# Write information
# Retrieve forward index
fw_idx = call_idxs[id(module)]
if len(fw_idx) == 1:
fw_idx = fw_idx[0]
else:
# The first dictionary with output_shape=None is the correct one
for _idx in fw_idx:
if info[_idx]['output_shape'] is None:
fw_idx = _idx
break

if any(module.children()):
tot_flops, tot_macs, tot_dmas = 0, 0, 0
current_rf, current_stride, current_padding = 1, 1, 0
else:
# Compute stats for standalone layers
tot_flops = module_flops(module, input[0], output)
tot_macs = module_macs(module, input[0], output)
tot_dmas = module_dmas(module, input[0], output)
current_rf, current_stride, current_padding = module_rf(module, input[0], output)

# Update layer information
info[fw_idx]['output_shape'] = (-1, *output.shape[1:])
# Add them, since some modules can be used several times
info[fw_idx]['flops'] = tot_flops
info[fw_idx]['macs'] = tot_macs
info[fw_idx]['dmas'] = tot_dmas
# Compute receptive field
info[fw_idx]['rf'] = current_rf
info[fw_idx]['s'] = current_stride
info[fw_idx]['p'] = current_padding

# Mark the next hook for execution
post_hook_tracker[id(module)]['target'] += 1
# Current pass already used one of the hooks
post_hook_tracker[id(module)]['is_used'] = True
post_hook_tracker[id(module)]['current'] += 1
# All the hooks have been checked, reset the temporary values
if post_hook_tracker[id(module)]['current'] == len(module._forward_pre_hooks):
post_hook_tracker[id(module)]['current'] = 0
post_hook_tracker[id(module)]['is_used'] = False

pre_fw_handles.append(module.register_forward_pre_hook(_pre_hook))
post_fw_handles.append(module.register_forward_hook(_fwd_hook))
# Handle modules that are used multiple times (with several hooks)
pre_hook_tracker[id(module)] = dict(current=0, target=0, is_used=False)
post_hook_tracker[id(module)] = dict(current=0, target=0, is_used=False)

# Hook model
info = []
@@ -171,6 +207,12 @@ def _fwd_hook(module, input, output):
with torch.no_grad():
module(*input_ts)

# Removes all hooks using their handles
for handle in pre_fw_handles:
handle.remove()
for handle in post_fw_handles:
handle.remove()

reserved_ram, diff_ram = 0, 0
if torch.cuda.is_available():
reserved_ram = torch.cuda.memory_reserved() / 1024 ** 2
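A minimal standalone sketch of the bookkeeping pattern the rewritten crawler adopts (illustrative only, not the torchscan code itself): hook handles are kept in a list and removed once the forward pass has finished, so a module that is called several times can fire its hook on every call instead of removing itself after the first one. The module names and helper functions here are hypothetical.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
calls = {}     # how many times each module's pre-hook fired
handles = []   # kept so hooks can be removed after the forward pass

def make_pre_hook(name):
    def _pre_hook(module, inputs):
        # Record the call instead of removing the handle inside the hook,
        # which is what broke when a module was reused within one forward.
        calls[name] = calls.get(name, 0) + 1
    return _pre_hook

for name, mod in model.named_modules():
    handles.append(mod.register_forward_pre_hook(make_pre_hook(name or "model")))

with torch.no_grad():
    out = model(torch.rand(1, 3, 8, 8))
    out = model[1](out)  # reuse the ReLU module a second time

for handle in handles:  # all hooks are removed only once the forward is done
    handle.remove()

print(calls)  # {'model': 1, '0': 1, '1': 2} -> the reused ReLU fired twice
```

On top of this, the actual crawler keeps per-module tracker dictionaries (`pre_hook_tracker` / `post_hook_tracker` in the diff above) to decide which registered hook should record information on a given call.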
2 changes: 1 addition & 1 deletion torchscan/modules/flops.py
@@ -28,7 +28,7 @@ def module_flops(module, input, output):
int: number of FLOPs
"""

if isinstance(module, nn.Identity):
if isinstance(module, (nn.Identity, nn.Flatten)):
return 0
elif isinstance(module, nn.Linear):
return flops_linear(module, input, output)
2 changes: 1 addition & 1 deletion torchscan/modules/macs.py
@@ -29,7 +29,7 @@ def module_macs(module, input, output):
"""
if isinstance(module, nn.Linear):
return macs_linear(module, input, output)
elif isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid)):
elif isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, nn.Flatten)):
return 0
elif isinstance(module, _ConvTransposeNd):
return macs_convtransposend(module, input, output)
8 changes: 8 additions & 0 deletions torchscan/modules/memory.py
@@ -31,6 +31,8 @@ def module_dmas(module, input, output):

if isinstance(module, nn.Identity):
return dmas_identity(module, input, output)
elif isinstance(module, nn.Flatten):
return dmas_flatten(module, input, output)
elif isinstance(module, nn.Linear):
return dmas_linear(module, input, output)
elif isinstance(module, (nn.ReLU, nn.ReLU6)):
@@ -76,6 +78,12 @@ def dmas_identity(module, input, output):
return input.numel()


def dmas_flatten(module, input, output):
"""DMAs estimation for `torch.nn.Flatten`"""

return 2 * input.numel()


def dmas_linear(module, input, output):
"""DMAs estimation for `torch.nn.Linear`"""

3 changes: 2 additions & 1 deletion torchscan/modules/receptive.py
@@ -27,7 +27,8 @@ def module_rf(module, input, output):
int: effective stride
int: effective padding
"""
if isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, _BatchNorm)):
if isinstance(module, (nn.Identity, nn.Flatten, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid,
_BatchNorm)):
return 1, 1, 0
elif isinstance(module, _ConvTransposeNd):
k = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size
