This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

add function for calculating the real model size #2401

Closed
wants to merge 10 commits into from

marsggbo
Contributor

@marsggbo marsggbo commented May 1, 2020

Refers to #1947.

@ultmaster
Contributor

ultmaster commented May 5, 2020

I'd like to propose another implementation that seems more general to me: even if you discard the unused modules inside a layer choice, there could still be other unused modules that are not part of any layer choice.

import logging

import torch
import torch.nn as nn

__all__ = ["flops_counter"]


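def count_convNd(m, _, y):
    # Weight shape is (out_channels, in_channels // groups, *kernel_size), so
    # weight[0, 0] holds exactly one element per kernel position. This works for
    # Conv1d/2d/3d alike, unlike indexing dimensions 2 and 3 directly.
    cin = m.in_channels
    kernel_ops = m.weight[0, 0].numel()
    output_elements = y.nelement()  # batch x cout x spatial output size
    # Each output element reads (cin // groups) * kernel_ops inputs.
    total_ops = cin * output_elements * kernel_ops // m.groups
    m.total_ops = torch.Tensor([int(total_ops)])
    m.module_used = torch.tensor([1])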


def count_linear(m, _, __):
    total_ops = m.in_features * m.out_features
    m.total_ops = torch.Tensor([int(total_ops)])
    m.module_used = torch.tensor([1])


def count_naive(m, _, __):
    m.module_used = torch.tensor([1])


register_hooks = {
    nn.Conv1d: count_convNd,
    nn.Conv2d: count_convNd,
    nn.Conv3d: count_convNd,
    nn.Linear: count_linear,
}


def flops_counter(model, input_size):
    handler_collection = []
    logger = logging.getLogger(__name__)

    def add_hooks(m_):
        if len(list(m_.children())) > 0:
            return

        m_.register_buffer('total_ops', torch.zeros(1))
        m_.register_buffer('total_params', torch.zeros(1))
        m_.register_buffer('module_used', torch.zeros(1))

        for p in m_.parameters():
            m_.total_params += torch.Tensor([p.numel()])

        m_type = type(m_)
        fn = register_hooks.get(m_type, count_naive)

        if fn is not None:
            _handler = m_.register_forward_hook(fn)
            handler_collection.append(_handler)

    def remove_buffer(m_):
        if len(list(m_.children())) > 0:
            return

        del m_.total_ops, m_.total_params, m_.module_used

    original_device = next(model.parameters()).device
    training = model.training

    model.eval()
    model.apply(add_hooks)

    assert isinstance(input_size, tuple)
    if torch.is_tensor(input_size[0]):
        x = (t.to(original_device) for t in input_size)
    else:
        x = (torch.zeros(input_size).to(original_device), )
    with torch.no_grad():
        model(*x)

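    # Start from tensors so the .item() calls below also work when no leaf module was used.
    total_ops = torch.zeros(1)
    total_params = torch.zeros(1)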
    for name, m in model.named_modules():
        if len(list(m.children())) > 0:  # skip for non-leaf module
            continue
        if not m.module_used:
            continue
        total_ops += m.total_ops
        total_params += m.total_params
        logger.debug("%s: %.2f %.2f", name, m.total_ops.item(), m.total_params.item())

    total_ops = total_ops.item()
    total_params = total_params.item()

    model.train(training).to(original_device)
    for handler in handler_collection:
        handler.remove()
    model.apply(remove_buffer)

    return total_ops, total_params
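
The core idea above (a forward hook marks each leaf module that actually runs, and only marked modules are counted) can be illustrated with a much smaller sketch. This is not the implementation proposed above: `Net` and `count_used_params` are hypothetical names made up for illustration, and the sketch counts parameters only, not operations.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical toy model: `unused` is registered but never called."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)      # 4*2 weights + 2 biases = 10 params
        self.unused = nn.Linear(4, 100)  # 4*100 + 100 = 500 params, never run

    def forward(self, x):
        return self.used(x)


def count_used_params(model, example_input):
    """Sum the parameters of leaf modules whose forward hook fired."""
    used = set()
    handles = []
    for module in model.modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(
                module.register_forward_hook(lambda m, inp, out: used.add(m))
            )
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        # Always detach the hooks, even if the forward pass raises.
        for handle in handles:
            handle.remove()
    return sum(p.numel() for m in used for p in m.parameters())


model = Net()
used_params = count_used_params(model, torch.zeros(1, 4))
all_params = sum(p.numel() for p in model.parameters())
print(used_params, all_params)
```

Here the naive count over `model.parameters()` includes the 500 parameters of the layer that never runs, while the hook-based count does not. Collecting the used modules in a set, rather than in buffers on the modules, is a simplification chosen for this sketch.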

@ultmaster
Contributor

ultmaster commented May 5, 2020

I'm also curious about the purpose of get_real_model_size. Is it for calculating and showing the model size after apply_fixed_architecture, or do you want to use the model size to guide your search strategy during the search phase? In the former case, substituting each layer choice with the selected module is already a planned item for release v1.6. In the latter case, it's usually up to the mutator to decide how the parameter size of those MixedOps is defined.
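
To make the second case concrete, here is a minimal sketch of one possible definition a search strategy might use: the expected parameter count of a mixed op under the current choice probabilities. This is purely an assumption for illustration, not how any NNI mutator defines it; `expected_param_count`, `candidates`, and `probs` are hypothetical names.

```python
import torch
import torch.nn as nn


def param_count(module):
    """Number of parameters in a single candidate op."""
    return sum(p.numel() for p in module.parameters())


def expected_param_count(candidates, probs):
    """Probability-weighted parameter count over the candidate ops.

    Assumption: `probs` is a 1-D tensor of choice probabilities, one per candidate.
    """
    counts = torch.tensor([param_count(op) for op in candidates], dtype=torch.float64)
    return float((counts * probs.to(torch.float64)).sum())


# Hypothetical candidates of one layer choice.
candidates = [
    nn.Conv2d(4, 4, kernel_size=1, bias=False),  # 4*4*1*1 = 16 params
    nn.Conv2d(4, 4, kernel_size=3, bias=False),  # 4*4*3*3 = 144 params
    nn.Identity(),                               # 0 params
]

# During search: soft probabilities over the candidates.
soft = expected_param_count(candidates, torch.tensor([0.5, 0.25, 0.25]))

# After fixing the architecture: a one-hot choice recovers the size of the selected op.
fixed = expected_param_count(candidates, torch.tensor([0.0, 1.0, 0.0]))
print(soft, fixed)
```

With a one-hot choice the two cases coincide, which is why the distinction only matters while the search is still running.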

@QuanluZhang
Contributor

@marsggbo Thanks for your contribution. Please check whether PR #2420 solves your problem. Closing this PR for now.

3 participants