diff --git a/.conda/meta.yaml b/.conda/meta.yaml
index 52d520d..d3e67ad 100644
--- a/.conda/meta.yaml
+++ b/.conda/meta.yaml
@@ -19,7 +19,7 @@ requirements:
     - python>=3.6
 
   run:
-    - pytorch >=1.1.0, <=1.4.0
+    - pytorch >=1.5.0
 
 test:
   # Python imports
diff --git a/requirements.txt b/requirements.txt
index 5e73e53..26cb806 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1 @@
-torch>=1.1.0
+torch>=1.5.0
diff --git a/setup.py b/setup.py
index 9c3ee96..c7eae4f 100644
--- a/setup.py
+++ b/setup.py
@@ -40,7 +40,7 @@ def write_version_file():
     readme = f.read()
 
 requirements = [
-    'torch>=1.1.0'
+    'torch>=1.5.0'
 ]
 
 setup(
diff --git a/test/test_crawler.py b/test/test_crawler.py
index d4c56d6..dbb32a7 100644
--- a/test/test_crawler.py
+++ b/test/test_crawler.py
@@ -15,9 +15,9 @@ def test_apply(self):
         mod = nn.Sequential(nn.Conv2d(3, 16, 3), multi_convs)
 
         # Tag module attributes
-        def tag_name(mod, depth, name):
-            mod.__depth__ = depth
-            mod.__name__ = name
+        def tag_name(mod, name):
+            mod.__depth__ = len(name.split('.')) - 1
+            mod.__name__ = name.rpartition('.')[-1]
 
         crawler.apply(mod, tag_name)
 
diff --git a/test/test_modules.py b/test/test_modules.py
index 60391f7..e0485d9 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -27,6 +27,7 @@ def test_module_flops(self):
                          4 * (2 * 8 - 1))
         # Activations
         self.assertEqual(modules.module_flops(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), 0)
+        self.assertEqual(modules.module_flops(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), 0)
         self.assertEqual(modules.module_flops(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 8)
         self.assertEqual(modules.module_flops(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 48)
         self.assertEqual(modules.module_flops(nn.LeakyReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 32)
@@ -130,6 +131,7 @@ def test_module_dmas(self):
                          4 * (8 + 1) + 8 + 4)
         # Activation
         self.assertEqual(modules.module_dmas(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), 8)
+        self.assertEqual(modules.module_dmas(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), 16)
         self.assertEqual(modules.module_dmas(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 8 * 2)
         self.assertEqual(modules.module_dmas(nn.ReLU(inplace=True), torch.zeros((1, 8)), None), 8)
         self.assertEqual(modules.module_dmas(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 17)
@@ -176,6 +178,7 @@ def test_module_rf(self):
                          (1, 1, 0))
         # Activation
         self.assertEqual(modules.module_rf(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
+        self.assertEqual(modules.module_rf(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
         self.assertEqual(modules.module_rf(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
         self.assertEqual(modules.module_rf(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
         self.assertEqual(modules.module_rf(nn.Sigmoid(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
diff --git a/torchscan/crawler.py b/torchscan/crawler.py
index 4c80ffe..5344964 100644
--- a/torchscan/crawler.py
+++ b/torchscan/crawler.py
@@ -15,21 +15,20 @@
 __all__ = ['crawl_module', 'summary']
 
 
-def apply(module, fn, depth=0, name=None):
+def apply(module, fn, name=None):
     """Modified version of `torch.nn.Module.apply` method
 
     Args:
         module (torch.nn.Module): target module
        fn (callable): function to apply to each module
-        depth (int, optional): current depth of `module`
         name (str, optional): name of the current module
     """
 
     if name is None:
         name = module.__class__.__name__.lower()
-    fn(module, depth, name)
+    fn(module, name)
     for n, m in module.named_children():
-        apply(m, fn, depth + 1, n)
+        apply(m, fn, f"{name}.{n}")
 
 
 def crawl_module(module, input_shape, dtype=None, max_depth=None):
@@ -72,94 +71,131 @@
     input_ts = [torch.rand(1, *in_shape).to(dtype=_dtype, device=device)
                 for in_shape, _dtype in zip(input_shape, dtype)]
 
+    pre_fw_handles, post_fw_handles = [], []
+    pre_hook_tracker = {}
+    post_hook_tracker = {}
+
     # Hook definition
-    def _hook_info(module, depth, name):
+    def _hook_info(module, name):
         def _pre_hook(module, input):
             """Pre-forward hook"""
-
-            # Params
-            grad_params, nograd_params, param_size = 0, 0, 0
-            num_buffers, buffer_size = 0, 0
-            is_shared = False
-            if not any(module.children()):
-                # Parameters
-                for p in module.parameters():
-                    if id(p) not in param_ids:
-                        if p.requires_grad:
-                            grad_params += p.data.numel()
+            # Check that another hook has not been triggered at this forward stage
+            if not pre_hook_tracker[id(module)]['is_used'] and \
+                    (pre_hook_tracker[id(module)]['target'] == pre_hook_tracker[id(module)]['current']):
+                # Add information
+                # Params
+                grad_params, nograd_params, param_size = 0, 0, 0
+                num_buffers, buffer_size = 0, 0
+                is_shared = False
+                if not any(module.children()):
+                    # Parameters
+                    for p in module.parameters():
+                        if id(p) not in param_ids:
+                            if p.requires_grad:
+                                grad_params += p.data.numel()
+                            else:
+                                nograd_params += p.data.numel()
+                            param_size += p.data.numel() * p.data.element_size()
+                            param_ids.append(id(p))
                         else:
-                            nograd_params += p.data.numel()
-                        param_size += p.data.numel() * p.data.element_size()
-                        param_ids.append(id(p))
-                    else:
-                        is_shared = True
-                # Buffers
-                for b in module.buffers():
-                    if id(b) not in param_ids:
-                        num_buffers += b.numel()
-                        buffer_size += b.numel() * b.element_size()
-                        param_ids.append(id(b))
-                    else:
-                        is_shared = True
-
-            call_idxs[id(module)] = len(info)
-
-            info.append(dict(name=name,
-                             depth=depth,
-                             type=module.__class__.__name__,
-                             input_shape=(-1, *input[0][0].shape[1:]),
-                             output_shape=None,
-                             grad_params=grad_params,
-                             nograd_params=nograd_params,
-                             param_size=param_size,
-                             num_buffers=num_buffers,
-                             buffer_size=buffer_size,
-                             flops=0,
-                             macs=0,
-                             dmas=0,
-                             rf=1,
-                             s=1,
-                             p=0,
-                             is_shared=is_shared,
-                             is_leaf=not any(module.children())))
-
-            # Remove the hook by using its handle
-            pre_fw_handle.remove()
+                            is_shared = True
+                    # Buffers
+                    for b in module.buffers():
+                        if id(b) not in param_ids:
+                            num_buffers += b.numel()
+                            buffer_size += b.numel() * b.element_size()
+                            param_ids.append(id(b))
+                        else:
+                            is_shared = True
+
+                if call_idxs.get(id(module)) is None:
+                    call_idxs[id(module)] = [len(info)]
+                else:
+                    call_idxs[id(module)].append(len(info))
+
+                info.append(dict(name=name.rpartition('.')[-1],
+                                 depth=len(name.split('.')) - 1,
+                                 type=module.__class__.__name__,
+                                 input_shape=(-1, *input[0][0].shape[1:]),
+                                 output_shape=None,
+                                 grad_params=grad_params,
+                                 nograd_params=nograd_params,
+                                 param_size=param_size,
+                                 num_buffers=num_buffers,
+                                 buffer_size=buffer_size,
+                                 flops=0,
+                                 macs=0,
+                                 dmas=0,
+                                 rf=1,
+                                 s=1,
+                                 p=0,
+                                 is_shared=is_shared,
+                                 is_leaf=not any(module.children())))
+                # Mark the next hook for execution
+                pre_hook_tracker[id(module)]['target'] += 1
+            # Current pass already used one of the hooks
+            pre_hook_tracker[id(module)]['is_used'] = True
+            pre_hook_tracker[id(module)]['current'] += 1
+            # All the hooks have been checked, reset the temporary values
+            if pre_hook_tracker[id(module)]['current'] == len(module._forward_pre_hooks):
+                pre_hook_tracker[id(module)]['current'] = 0
+                pre_hook_tracker[id(module)]['is_used'] = False
 
         def _fwd_hook(module, input, output):
             """Post-forward hook"""
 
-            # Retrieve forward index
-            fw_idx = call_idxs[id(module)]
-
-            if any(module.children()):
-                tot_flops, tot_macs, tot_dmas = 0, 0, 0
-                current_rf, current_stride, current_padding = 1, 1, 0
-            else:
-                # Compute stats for standalone layers
-                tot_flops = module_flops(module, input[0], output)
-                tot_macs = module_macs(module, input[0], output)
-                tot_dmas = module_dmas(module, input[0], output)
-                current_rf, current_stride, current_padding = module_rf(module, input[0], output)
-
-            # Update layer information
-            info[fw_idx]['output_shape'] = (-1, *output.shape[1:])
-            # Add them, since some modules can be used several times
-            info[fw_idx]['flops'] = tot_flops
-            info[fw_idx]['macs'] = tot_macs
-            info[fw_idx]['dmas'] = tot_dmas
-            # Compute receptive field
-            info[fw_idx]['rf'] = current_rf
-            info[fw_idx]['s'] = current_stride
-            info[fw_idx]['p'] = current_padding
-
-            # Remove the hook by using its handle
-            post_fw_handle.remove()
-
-        # Hook only leaf children
-        pre_fw_handle = module.register_forward_pre_hook(_pre_hook)
-        post_fw_handle = module.register_forward_hook(_fwd_hook)
+            # Check that another hook has not been triggered at this forward stage
+            if not post_hook_tracker[id(module)]['is_used'] and \
+                    (post_hook_tracker[id(module)]['target'] == post_hook_tracker[id(module)]['current']):
+                # Write information
+                # Retrieve forward index
+                fw_idx = call_idxs[id(module)]
+                if len(fw_idx) == 1:
+                    fw_idx = fw_idx[0]
+                else:
+                    # The first dictionary with output_shape=None is the correct one
+                    for _idx in fw_idx:
+                        if info[_idx]['output_shape'] is None:
+                            fw_idx = _idx
+                            break
+
+                if any(module.children()):
+                    tot_flops, tot_macs, tot_dmas = 0, 0, 0
+                    current_rf, current_stride, current_padding = 1, 1, 0
+                else:
+                    # Compute stats for standalone layers
+                    tot_flops = module_flops(module, input[0], output)
+                    tot_macs = module_macs(module, input[0], output)
+                    tot_dmas = module_dmas(module, input[0], output)
+                    current_rf, current_stride, current_padding = module_rf(module, input[0], output)
+
+                # Update layer information
+                info[fw_idx]['output_shape'] = (-1, *output.shape[1:])
+                # Add them, since some modules can be used several times
+                info[fw_idx]['flops'] = tot_flops
+                info[fw_idx]['macs'] = tot_macs
+                info[fw_idx]['dmas'] = tot_dmas
+                # Compute receptive field
+                info[fw_idx]['rf'] = current_rf
+                info[fw_idx]['s'] = current_stride
+                info[fw_idx]['p'] = current_padding
+
+                # Mark the next hook for execution
+                post_hook_tracker[id(module)]['target'] += 1
+            # Current pass already used one of the hooks
+            post_hook_tracker[id(module)]['is_used'] = True
+            post_hook_tracker[id(module)]['current'] += 1
+            # All the hooks have been checked, reset the temporary values
+            if post_hook_tracker[id(module)]['current'] == len(module._forward_pre_hooks):
+                post_hook_tracker[id(module)]['current'] = 0
+                post_hook_tracker[id(module)]['is_used'] = False
+
+        pre_fw_handles.append(module.register_forward_pre_hook(_pre_hook))
+        post_fw_handles.append(module.register_forward_hook(_fwd_hook))
+        # Handle modules that are used multiple times (with several hooks)
+        pre_hook_tracker[id(module)] = dict(current=0, target=0, is_used=False)
+        post_hook_tracker[id(module)] = dict(current=0, target=0, is_used=False)
 
     # Hook model
     info = []
@@ -171,6 +207,12 @@ def _fwd_hook(module, input, output):
     with torch.no_grad():
         module(*input_ts)
 
+    # Removes all hooks using their handles
+    for handle in pre_fw_handles:
+        handle.remove()
+    for handle in post_fw_handles:
+        handle.remove()
+
     reserved_ram, diff_ram = 0, 0
     if torch.cuda.is_available():
         reserved_ram = torch.cuda.memory_reserved() / 1024 ** 2
diff --git a/torchscan/modules/flops.py b/torchscan/modules/flops.py
index 84cf316..57fd757 100644
--- a/torchscan/modules/flops.py
+++ b/torchscan/modules/flops.py
@@ -28,7 +28,7 @@ def module_flops(module, input, output):
         int: number of FLOPs
     """
 
-    if isinstance(module, nn.Identity):
+    if isinstance(module, (nn.Identity, nn.Flatten)):
         return 0
     elif isinstance(module, nn.Linear):
         return flops_linear(module, input, output)
diff --git a/torchscan/modules/macs.py b/torchscan/modules/macs.py
index 526027c..5cc7990 100644
--- a/torchscan/modules/macs.py
+++ b/torchscan/modules/macs.py
@@ -29,7 +29,7 @@ def module_macs(module, input, output):
     """
     if isinstance(module, nn.Linear):
         return macs_linear(module, input, output)
-    elif isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid)):
+    elif isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, nn.Flatten)):
         return 0
     elif isinstance(module, _ConvTransposeNd):
         return macs_convtransposend(module, input, output)
diff --git a/torchscan/modules/memory.py b/torchscan/modules/memory.py
index fa19a24..85299b1 100644
--- a/torchscan/modules/memory.py
+++ b/torchscan/modules/memory.py
@@ -31,6 +31,8 @@ def module_dmas(module, input, output):
 
     if isinstance(module, nn.Identity):
         return dmas_identity(module, input, output)
+    elif isinstance(module, nn.Flatten):
+        return dmas_flatten(module, input, output)
     elif isinstance(module, nn.Linear):
         return dmas_linear(module, input, output)
     elif isinstance(module, (nn.ReLU, nn.ReLU6)):
@@ -76,6 +78,12 @@ def dmas_identity(module, input, output):
     return input.numel()
 
 
+def dmas_flatten(module, input, output):
+    """DMAs estimation for `torch.nn.Flatten`"""
+
+    return 2 * input.numel()
+
+
 def dmas_linear(module, input, output):
     """DMAs estimation for `torch.nn.Linear`"""
 
diff --git a/torchscan/modules/receptive.py b/torchscan/modules/receptive.py
index 38c48af..e394402 100644
--- a/torchscan/modules/receptive.py
+++ b/torchscan/modules/receptive.py
@@ -27,7 +27,8 @@ def module_rf(module, input, output):
         int: effective stride
         int: effective padding
     """
-    if isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, _BatchNorm)):
+    if isinstance(module, (nn.Identity, nn.Flatten, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid,
+                           _BatchNorm)):
        return 1, 1, 0
     elif isinstance(module, _ConvTransposeNd):
         k = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size
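Below is a small usage sketch, not part of the patch, of the situation the reworked hooks are meant to handle; the model and input shape are made up for illustration, and `summary` is the package's public entry point listed in `__all__` above. The same `nn.Conv2d` instance is wired into the model twice, so `apply` now registers two hook pairs on it and the `pre_hook_tracker`/`post_hook_tracker` bookkeeping fills one summary row per call site while counting the shared parameters only once; `nn.Flatten` is now reported with 0 FLOPs/MACs and DMAs equal to twice its input size.

import torch.nn as nn

from torchscan import summary

# The same module instance appears twice in the model
shared = nn.Conv2d(16, 16, 3, padding=1)
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    shared,
    nn.ReLU(inplace=True),
    shared,        # second use of the same instance
    nn.Flatten(),  # 0 FLOPs / 0 MACs, DMAs = 2 x number of input elements
)

# One row per call site of `shared`; its parameters are counted once and the
# duplicate rows are flagged as shared
summary(model, (3, 32, 32))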