This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Pruning: Add support for torch.split() operation #5143

Open
AL3708 opened this issue Sep 27, 2022 · 4 comments

Comments

@AL3708

AL3708 commented Sep 27, 2022

If there is a torch.split() operation anywhere in the network, an AttributeError is thrown:

AttributeError                            Traceback (most recent call last)
Cell In [92], line 2
      1 pruner._unwrap_model()
----> 2 ModelSpeedup(net, torch.rand(1, 3, 32, 32, device='cpu'), masks).speedup_model()

File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:543, in ModelSpeedup.speedup_model(self)
    540 fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
    542 _logger.info("infer module masks...")
--> 543 self.infer_modules_masks()
    544 _logger.info('resolve the mask conflict')
    546 # load the original stat dict before replace the model

File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:380, in ModelSpeedup.infer_modules_masks(self)
    378 curnode = visit_queue.get()
    379 # forward mask inference for curnode
--> 380 self.update_direct_sparsity(curnode)
    381 successors = self.torch_graph.find_successors(curnode.unique_name)
    382 for successor in successors:

File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:216, in ModelSpeedup.update_direct_sparsity(self, node)
    214 _logger.info('Update mask for %s', module_name)
    215 unique_name = node.unique_name
--> 216 dummy_input, input_debugname = self._prepare_dummy_input(node)
    217 # get the input mask from self.masks
    218 # Note: the input mask of the successor nodes are
    219 # already created by the predecessor node
    220 in_masks = [self.masks[debugname] for debugname in input_debugname]

File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:200, in ModelSpeedup._prepare_dummy_input(self, node)
    197         continue
    198     # The detach operation here is for the in-place operation. We cannot
    199     # directly can the backward on the output tensor of an in-place operator.
--> 200     dummy_input.append(self.internal_result[_input].detach())
    202     debugnames.append(_input)
    204 return dummy_input, debugnames

AttributeError: 'tuple' object has no attribute 'detach'
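The root cause is visible in the last frame: torch.split() returns a tuple of Tensors, so self.internal_result[_input] holds a tuple, and calling .detach() on it fails. A minimal demonstration of the type mismatch, plus a hypothetical guard (detach_maybe_tuple is an illustrative name, not NNI's actual fix):

```python
import torch

x = torch.rand(1, 32, 4, 4)
parts = torch.split(x, 16, dim=1)

# torch.split returns a tuple of Tensors, not a Tensor:
assert isinstance(parts, tuple)
# parts.detach()  # AttributeError: 'tuple' object has no attribute 'detach'

def detach_maybe_tuple(value):
    """Hypothetical guard: detach a Tensor, or each Tensor in a tuple."""
    if isinstance(value, tuple):
        return tuple(t.detach() for t in value)
    return value.detach()

detached = detach_maybe_tuple(parts)
```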

Environment:

  • NNI version: 2.9
  • Training service (local|remote|pai|aml|etc): local
  • Client OS: Windows 10
  • Python version: 3.10.6
  • PyTorch version: 1.12
  • Is conda/virtualenv/venv used?: pipenv
  • Is running in Docker?: No

How to reproduce it?:
Use this code:

import torch
import torch.nn as nn
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.conv_block = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        x = self.stem(x)
        x0, x1 = torch.split(x, 16, dim=1)
        x1 = self.conv_block(x1)
        x = torch.cat([x0, x1], dim=1)
        return self.classifier(x)

net = Network()
config_list = [{
    'total_sparsity': 0.2,
    'op_types': ['Conv2d']
}]
pruner = L1NormPruner(net, config_list)
_, masks = pruner.compress()
pruner._unwrap_model()
# Throws an error
ModelSpeedup(net, torch.rand(1, 3, 32, 32, device='cpu'), masks).speedup_model()
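Until the fix lands, one possible workaround (my suggestion, not verified against NNI internals) is to avoid the tuple-returning op and slice the tensor directly, which yields plain Tensors that the speedup pass can .detach():

```python
import torch
import torch.nn as nn

class NetworkSliced(nn.Module):
    """Same architecture as Network above, but with slicing instead of
    torch.split() -- a hypothetical workaround, assuming NNI handles the
    resulting slice ops."""
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.conv_block = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        x = self.stem(x)
        # Slicing yields Tensors, not a tuple, so each branch's output
        # can be detached during mask inference.
        x0, x1 = x[:, :16], x[:, 16:]
        x1 = self.conv_block(x1)
        x = torch.cat([x0, x1], dim=1)
        return self.classifier(x)
```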
@Louis-J Louis-J self-assigned this Sep 28, 2022
@Louis-J
Contributor

Louis-J commented Sep 28, 2022

It occurred while we were solving other issues. We plan to fix this in #5107.

@wwdok

wwdok commented Oct 31, 2022

I met a similar issue today; it throws this error:

AttributeError                            Traceback (most recent call last)

File c:\Users\user\AppData\Local\Programs\Python\Python38\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:543, in ModelSpeedup.speedup_model(self)
    540 fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
    542 _logger.info("infer module masks...")
--> 543 self.infer_modules_masks()
    544 _logger.info('resolve the mask conflict')
    546 # load the original stat dict before replace the model

File c:\Users\user\AppData\Local\Programs\Python\Python38\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:380, in ModelSpeedup.infer_modules_masks(self)
    378 curnode = visit_queue.get()
    379 # forward mask inference for curnode
--> 380 self.update_direct_sparsity(curnode)
    381 successors = self.torch_graph.find_successors(curnode.unique_name)
    382 for successor in successors:

File c:\Users\user\AppData\Local\Programs\Python\Python38\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:216, in ModelSpeedup.update_direct_sparsity(self, node)
    214 _logger.info('Update mask for %s', module_name)
...
--> 200     dummy_input.append(self.internal_result[_input].detach())
    202     debugnames.append(_input)
    204 return dummy_input, debugnames

AttributeError: 'tuple' object has no attribute 'detach'

My model has a ShuffleNet-like structure, so there is a torch.chunk() in its forward().
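For what it's worth, torch.chunk() returns the same tuple-of-Tensors type as torch.split(), so it would trip the identical .detach() call in _prepare_dummy_input:

```python
import torch

x = torch.rand(1, 32, 4, 4)
chunks = torch.chunk(x, 2, dim=1)  # tuple of 2 Tensors, like torch.split

assert isinstance(chunks, tuple) and len(chunks) == 2
assert not hasattr(chunks, "detach")  # the attribute the speedup pass expects
```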

@Lijiaoa
Contributor

Lijiaoa commented Nov 7, 2022

Hi @wwdok, this bug will be fixed in PR #5107!

> it's occurred when solving other issues. we plan to repair this issue in #5107.

@Lijiaoa
Contributor

Lijiaoa commented May 8, 2023

Sorry for the late reply. Please try the latest NNI version (master) or v3.0. If you still have any problems, feel free to let me know. @AL3708
