This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add TupleConstruct to allow (a, b)=SomeModule() #3357

Merged
merged 1 commit into from
Feb 3, 2021

tczhangzhi
Contributor

NNI does not currently support the prim::TupleConstruct operation, so the following code fails with an error when forward returns a tuple:

import random

import nni.retiarii.nn.pytorch as nn
import torch.nn.functional as F
from nni.retiarii.experiment import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.strategies import RandomStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer


class Net(nn.Module):
    def __init__(self, hidden_size):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.LayerChoice([
            nn.Linear(4*4*50, hidden_size),
            nn.Linear(4*4*50, hidden_size, bias=False)
        ])
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # =========== tuple output start ===========
        return F.log_softmax(x, dim=1), x
        # =========== tuple output end ===========


if __name__ == '__main__':
    base_model = Net(128)
    # some customized trainer
    trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="MNIST",
                                                dataset_kwargs={"root": "data/mnist", "download": True},
                                                dataloader_kwargs={"batch_size": 32},
                                                optimizer_kwargs={"lr": 1e-3},
                                                trainer_kwargs={"max_epochs": 1})

    simple_strategy = RandomStrategy()

    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'mnist_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.training_service.use_active_gpu = False

    exp.run(exp_config, 7081 + random.randint(0, 100))

The error message is as follows:

Traceback (most recent call last):
  File "test.py", line 50, in <module>
    exp.run(exp_config, 7081 + random.randint(0, 100))
  File "/data/zhangzhi/nni/nni/retiarii/experiment.py", line 173, in run
    super().run(port, debug)
  File "/data/zhangzhi/nni/nni/experiment/experiment.py", line 181, in run
    self.start(port, debug)
  File "/data/zhangzhi/nni/nni/retiarii/experiment.py", line 158, in start
    self._start_strategy()
  File "/data/zhangzhi/nni/nni/retiarii/experiment.py", line 124, in _start_strategy
    base_model_ir = convert_to_graph(script_module, self.base_model)
  File "/data/zhangzhi/nni/nni/retiarii/converter/graph_gen.py", line 522, in convert_to_graph
    GraphConverter().convert_module(script_module, module, module_name, model)
  File "/data/zhangzhi/nni/nni/retiarii/converter/graph_gen.py", line 483, in convert_module
    module_name, ir_model, ir_graph)
  File "/data/zhangzhi/nni/nni/retiarii/converter/graph_gen.py", line 349, in handle_graph_nodes
    handle_single_node(node)
  File "/data/zhangzhi/nni/nni/retiarii/converter/graph_gen.py", line 344, in handle_single_node
    raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
RuntimeError: Unsupported kind: prim::TupleConstruct
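
The node kind in this error comes from TorchScript, not from NNI itself. As a minimal sketch (assuming a working PyTorch install; the TwoOutputs module is a hypothetical stand-in, not part of this PR), scripting any module whose forward returns a tuple should show the same node kind in its graph:

```python
import torch
import torch.nn as nn


class TwoOutputs(nn.Module):
    """Hypothetical toy module whose forward returns a tuple."""

    def forward(self, x):
        return torch.relu(x), x


# Script the module and collect the kind of every node in the forward graph.
scripted = torch.jit.script(TwoOutputs())
kinds = [node.kind() for node in scripted.graph.nodes()]
print(kinds)
```

The printed list should contain prim::TupleConstruct, which is the kind that the converter rejected in the traceback above.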

With this PR, the generated code for the tuple output is as follows:

__TupleConstruct32 = (__log_softmax31, __fc2)
return __TupleConstruct32
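
For illustration only, here is a rough sketch of how a tuple-construct node could be rendered into a line like the one above. The Node class and render function are hypothetical and are not NNI's actual converter code; only the node kind, the variable names, and the error message format are taken from this PR.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for a converted graph node (not an NNI class)."""
    kind: str          # TorchScript-style node kind
    name: str          # variable name assigned to the node's output
    inputs: List[str]  # variable names of the node's inputs


def render(node: Node) -> str:
    """Render one node as a line of Python source."""
    if node.kind == "prim::TupleConstruct":
        items = ", ".join(node.inputs)
        # A trailing comma keeps a single-element tuple a tuple.
        if len(node.inputs) == 1:
            items += ","
        return f"{node.name} = ({items})"
    raise RuntimeError(f"Unsupported kind: {node.kind}")


node = Node("prim::TupleConstruct", "__TupleConstruct32",
            ["__log_softmax31", "__fc2"])
line = render(node)
print(line)  # prints: __TupleConstruct32 = (__log_softmax31, __fc2)
```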

This PR is partly related to #2756 (which requires support for TupleUnpack and TupleConstruct) and #3340 (which works on TupleUnpack).

@tczhangzhi
Contributor Author

BTW, I found this bug while adding regularization terms (to the loss function) that depend on some intermediate variables. I also fixed some other problems, but they are not related to this issue, so I will submit them in separate PRs.

@QuanluZhang QuanluZhang requested a review from ultmaster February 2, 2021 06:47
@ultmaster ultmaster changed the title Add TupleConstruct to permit (a, b)=SomeModule() Add TupleConstruct to allow (a, b)=SomeModule() Feb 2, 2021
@J-shang J-shang merged commit 5946b4a into microsoft:master Feb 3, 2021