
Fix a few bugs in Retiarii and upgrade Dockerfile #3713

Merged
6 commits merged on Jun 3, 2021
13 changes: 6 additions & 7 deletions Dockerfile
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

- FROM nvidia/cuda:9.2-cudnn7-runtime-ubuntu18.04
+ FROM nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04

ARG NNI_RELEASE

@@ -44,23 +44,22 @@ RUN ln -s python3 /usr/bin/python
RUN python3 -m pip install --upgrade pip==20.2.4 setuptools==50.3.2

# numpy 1.14.3 scipy 1.1.0
- RUN python3 -m pip --no-cache-dir install numpy==1.14.3 scipy==1.1.0
+ RUN python3 -m pip --no-cache-dir install numpy==1.19.5 scipy==1.6.3

#
# TensorFlow
#
RUN python3 -m pip --no-cache-dir install tensorflow==2.3.1

#
- # Keras 2.1.6
+ # Keras
#
- RUN python3 -m pip --no-cache-dir install Keras==2.1.6
+ RUN python3 -m pip --no-cache-dir install Keras==2.4.0

#
# PyTorch
#
- RUN python3 -m pip --no-cache-dir install torch==1.6.0
- RUN python3 -m pip install torchvision==0.7.0
+ RUN python3 -m pip --no-cache-dir install torch==1.7.1 torchvision==0.8.2 pytorch-lightning==1.3.3

#
# sklearn 0.24.1
@@ -70,7 +69,7 @@ RUN python3 -m pip --no-cache-dir install scikit-learn==0.24.1
#
# pandas==0.23.4 lightgbm==2.2.2
#
- RUN python3 -m pip --no-cache-dir install pandas==0.23.4 lightgbm==2.2.2
+ RUN python3 -m pip --no-cache-dir install pandas==1.1 lightgbm==2.2.2

#
# Install NNI
5 changes: 4 additions & 1 deletion docs/en_US/NAS/retiarii/Advanced.rst
@@ -8,10 +8,13 @@ If you are experiencing issues with TorchScript, or the generated model code by

This will become the default execution engine in a future version of Retiarii.

- Two steps are needed to enable this engine now.
+ Three steps are needed to enable this engine now.

1. Add ``@nni.retiarii.model_wrapper`` decorator outside the whole PyTorch model.
2. Add ``config.execution_engine = 'py'`` to ``RetiariiExeConfig``.
3. If you need to export top models, formatter needs to be set to ``dict``. Exporting ``code`` won't work with this engine.

.. note:: You should always use ``super().__init__()`` instead of ``super(MyNetwork, self).__init__()`` in the PyTorch model, because the latter has issues with the model wrapper.
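
As a quick illustration, here is a minimal sketch of the three steps above, following the MNIST example updated in this PR. The model body is elided, and `trainer` and `search_strategy` are assumed to be the evaluator and strategy set up as in `examples/nas/multi-trial/mnist/search.py`:

```python
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment

@model_wrapper                        # step 1: wrap the whole PyTorch model
class Net(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()            # use super().__init__(), not super(Net, self).__init__()
        ...                           # layers and LayerChoice as in the MNIST example

base_model = Net(hidden_size=64)

exp = RetiariiExperiment(base_model, trainer, [], search_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.execution_engine = 'py'    # step 2: switch to the pure-python execution engine
exp.run(exp_config, 8081)

for mutation in exp.export_top_models(formatter='dict'):   # step 3: export with the dict formatter
    print(mutation)
```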

``@basic_unit`` and ``serializer``
----------------------------------
16 changes: 11 additions & 5 deletions examples/nas/multi-trial/mnist/search.py
@@ -4,22 +4,23 @@
import nni.retiarii.strategy as strategy
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn.functional as F
- from nni.retiarii import serialize
+ from nni.retiarii import serialize, model_wrapper
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment, debug_mutated_model
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


# uncomment this for python execution engine
# @model_wrapper
class Net(nn.Module):
def __init__(self, hidden_size):
- super(Net, self).__init__()
+ super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.LayerChoice([
nn.Linear(4*4*50, hidden_size),
nn.Linear(4*4*50, hidden_size, bias=False)
- ])
+ ], label='fc1_choice')
self.fc2 = nn.Linear(hidden_size, 10)

def forward(self, x):
@@ -55,8 +56,13 @@ def forward(self, x):
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 2
exp_config.training_service.use_active_gpu = False
export_formatter = 'code'

# uncomment this for python execution engine
# exp_config.execution_engine = 'py'
# export_formatter = 'dict'

exp.run(exp_config, 8081 + random.randint(0, 100))
print('Final model:')
- for model_code in exp.export_top_models():
+ for model_code in exp.export_top_models(formatter=export_formatter):
print(model_code)
2 changes: 2 additions & 0 deletions nni/retiarii/codegen/pytorch.py
@@ -162,6 +162,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
import torch.nn.functional as F
import torch.optim as optim

import nni.retiarii.nn.pytorch

{}

{}
6 changes: 5 additions & 1 deletion nni/retiarii/execution/python.py
@@ -30,7 +30,7 @@ def load(data) -> 'PythonGraphData':
class PurePythonExecutionEngine(BaseExecutionEngine):
@classmethod
def pack_model_data(cls, model: Model) -> Any:
- mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
+ mutation = get_mutation_dict(model)
graph_data = PythonGraphData(get_importable_name(model.python_class, relocate_module=True),
model.python_init_params, mutation, model.evaluator)
return graph_data
@@ -51,3 +51,7 @@ def _unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele


def get_mutation_dict(model: Model):
return {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
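
As a rough illustration of the new helper, the dictionary it returns maps each recorded mutator label to the value(s) sampled for that model. The label and value below are hypothetical and only mirror the ``fc1_choice`` LayerChoice from the MNIST example:

```python
from nni.retiarii.execution.python import get_mutation_dict

# `model` is assumed to be an nni.retiarii.graph.Model whose history holds one
# entry per applied mutator; each entry contributes label -> sampled value(s).
mutation = get_mutation_dict(model)
print(mutation)   # e.g. {'fc1_choice': 0}  (illustrative only)
```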
13 changes: 10 additions & 3 deletions nni/retiarii/experiment/pytorch.py
@@ -29,6 +29,7 @@
from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph
from ..execution import list_models, set_execution_engine
from ..execution.python import get_mutation_dict
from ..graph import Model, Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
@@ -317,7 +318,7 @@ def export_top_models(self, top_k: int = 1, optimize_mode: str = 'maximize', for
"""
Export several top performing models.

- For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` asnd ``formater`` is
+ For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
available for customization.

top_k : int
@@ -326,18 +327,24 @@
``maximize`` or ``minimize``. Not supported by one-shot algorithms.
``optimize_mode`` is likely to be removed and defined in strategy in future.
formatter : str
- Only model code is supported for now. Not supported by one-shot algorithms.
+ Supports ``code`` and ``dict``. Not supported by one-shot algorithms.
+ If ``code``, the Python code of the model will be returned.
+ If ``dict``, the mutation history will be returned.
"""
if formatter == 'code':
assert self.config.execution_engine != 'py', 'You should use `dict` formatter when using Python execution engine.'
if isinstance(self.trainer, BaseOneShotTrainer):
assert top_k == 1, 'Only support top_k is 1 for now.'
return self.trainer.export()
else:
all_models = filter(lambda m: m.metric is not None, list_models())
assert optimize_mode in ['maximize', 'minimize']
all_models = sorted(all_models, key=lambda m: m.metric, reverse=optimize_mode == 'maximize')
- assert formatter == 'code', 'Export formatter other than "code" is not supported yet.'
+ assert formatter in ['code', 'dict'], 'Export formatter other than "code" and "dict" is not supported yet.'
if formatter == 'code':
return [model_to_pytorch_script(model) for model in all_models[:top_k]]
elif formatter == 'dict':
return [get_mutation_dict(model) for model in all_models[:top_k]]

def retrain_model(self, model):
"""
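A short, hedged usage sketch of the two formatters, assuming ``exp`` is a finished ``RetiariiExperiment``; the printed values are illustrative:

```python
# Default graph-based engine: the generated PyTorch source of each top model.
for code in exp.export_top_models(top_k=1, formatter='code'):
    print(code)

# Pure-python engine (execution_engine = 'py'): 'code' now trips the assert above,
# so export the mutation history instead.
for mutation in exp.export_top_models(top_k=1, formatter='dict'):
    print(mutation)   # e.g. {'fc1_choice': 0}
```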
39 changes: 38 additions & 1 deletion nni/retiarii/strategy/_rl_impl.py
@@ -1,6 +1,7 @@
# This file might cause import error for those who didn't install RL-related dependencies

import logging
from multiprocessing.pool import ThreadPool

import gym
import numpy as np
@@ -9,6 +10,7 @@

from gym import spaces
from tianshou.data import to_torch
from tianshou.env.worker import EnvWorker

from .utils import get_targeted_model
from ..graph import ModelStatus
@@ -18,6 +20,41 @@
_logger = logging.getLogger(__name__)


class MultiThreadEnvWorker(EnvWorker):
def __init__(self, env_fn):
self.env = env_fn()
self.pool = ThreadPool(processes=1)
super().__init__(env_fn)

def __getattr__(self, key):
return getattr(self.env, key)

def reset(self):
return self.env.reset()

@staticmethod
def wait(*args, **kwargs):
raise NotImplementedError('Async collect is not supported yet.')

def send_action(self, action) -> None:
# self.result is actually a handle
self.result = self.pool.apply_async(self.env.step, (action,))

def get_result(self):
return self.result.get()

def seed(self, seed):
super().seed(seed)
return self.env.seed(seed)

def render(self, **kwargs):
return self.env.render(**kwargs)

def close_env(self) -> None:
self.pool.terminate()
return self.env.close()


class ModelEvaluationEnv(gym.Env):
def __init__(self, base_model, mutators, search_space):
self.base_model = base_model
@@ -107,7 +144,7 @@ def forward(self, obs, **kwargs):
# to take care of choices with different number of options
mask = torch.arange(self.action_dim).expand(len(out), self.action_dim) >= obs['action_dim'].unsqueeze(1)
out[mask.to(out.device)] = float('-inf')
- return nn.functional.softmax(out), kwargs.get('state', None)
+ return nn.functional.softmax(out, dim=-1), kwargs.get('state', None)


class Critic(nn.Module):
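For context, a brief sketch of how the new worker plugs into tianshou's vectorized-environment API; this mirrors the change in nni/retiarii/strategy/rl.py below, and ``env_fn`` and ``policy`` are assumed to come from ``ModelEvaluationEnv`` and the strategy's policy function:

```python
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import BaseVectorEnv

from nni.retiarii.strategy._rl_impl import MultiThreadEnvWorker

# Each worker runs env.step in a single-thread pool instead of a subprocess, which
# lets the model-evaluation environments run in parallel on platforms where
# subprocess-based vectorized envs were problematic (Windows/macOS).
env = BaseVectorEnv([env_fn for _ in range(4)], MultiThreadEnvWorker)
collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))
```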
26 changes: 6 additions & 20 deletions nni/retiarii/strategy/rl.py
@@ -8,10 +8,10 @@
try:
has_tianshou = True
import torch
- from tianshou.data import AsyncCollector, Collector, VectorReplayBuffer
- from tianshou.env import SubprocVectorEnv
+ from tianshou.data import Collector, VectorReplayBuffer
+ from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy, PPOPolicy # pylint: disable=unused-import
- from ._rl_impl import ModelEvaluationEnv, Preprocessor, Actor, Critic
+ from ._rl_impl import ModelEvaluationEnv, MultiThreadEnvWorker, Preprocessor, Actor, Critic
except ImportError:
has_tianshou = False

@@ -25,8 +25,6 @@ class PolicyBasedRL(BaseStrategy):
This is a wrapper of algorithms provided in tianshou (PPO by default),
and can be easily customized with other algorithms that inherit ``BasePolicy`` (e.g., REINFORCE [1]_).

- Note that RL algorithms are known to have issues on Windows and MacOS. They will be supported in future.

Parameters
----------
max_collect : int
@@ -36,12 +34,6 @@
After each collect, trainer will sample batch from replay buffer and do the update. Default: 20.
policy_fn : function
Takes ``ModelEvaluationEnv`` as input and return a policy. See ``_default_policy_fn`` for an example.
- asynchronous : bool
-     If true, in each step, collector won't wait for all the envs to complete.
-     This should generally not affect the result, but might affect the efficiency. Note that a slightly more trials
-     than expected might be collected if this is enabled.
-     If asynchronous is false, collector will wait for all parallel environments to complete in each step.
-     See ``tianshou.data.AsyncCollector`` for more details.

References
----------
@@ -51,15 +43,14 @@
"""

def __init__(self, max_collect: int = 100, trial_per_collect = 20,
- policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None, asynchronous: bool = True):
+ policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None):
if not has_tianshou:
raise ImportError('`tianshou` is required to run RL-based strategy. '
'Please use "pip install tianshou" to install it beforehand.')

self.policy_fn = policy_fn or self._default_policy_fn
self.max_collect = max_collect
self.trial_per_collect = trial_per_collect
- self.asynchronous = asynchronous

@staticmethod
def _default_policy_fn(env):
@@ -77,13 +68,8 @@ def run(self, base_model, applied_mutators):
env_fn = lambda: ModelEvaluationEnv(base_model, applied_mutators, search_space)
policy = self.policy_fn(env_fn())

- if self.asynchronous:
-     # wait for half of the env complete in each step
-     env = SubprocVectorEnv([env_fn for _ in range(concurrency)], wait_num=int(concurrency * 0.5))
-     collector = AsyncCollector(policy, env, VectorReplayBuffer(20000, len(env)))
- else:
-     env = SubprocVectorEnv([env_fn for _ in range(concurrency)])
-     collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))
+ env = BaseVectorEnv([env_fn for _ in range(concurrency)], MultiThreadEnvWorker)
+ collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))

for cur_collect in range(1, self.max_collect + 1):
_logger.info('Collect [%d] Running...', cur_collect)
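Finally, a minimal sketch of selecting this strategy for an experiment. The constructor arguments mirror the updated signature, while the experiment wiring (``base_model``, ``trainer``, ``mutators``) is assumed to follow the MNIST example above:

```python
import nni.retiarii.strategy as strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment

# `asynchronous` is gone from the signature; each collect now waits for all envs.
rl_strategy = strategy.PolicyBasedRL(max_collect=100, trial_per_collect=20)

exp = RetiariiExperiment(base_model, trainer, mutators, rl_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.trial_concurrency = 2
exp.run(exp_config, 8081)
```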
2 changes: 2 additions & 0 deletions test/ut/retiarii/debug_mnist_pytorch.py
@@ -3,6 +3,8 @@
import torch.nn.functional as F
import torch.optim as optim

import nni.retiarii.nn.pytorch

import torch


3 changes: 1 addition & 2 deletions test/ut/retiarii/test_strategy.py
@@ -141,7 +141,6 @@ def test_evolution():
_reset_execution_engine()


- @pytest.mark.skipif(sys.platform in ('win32', 'darwin'), reason='Does not run on Windows and MacOS')
def test_rl():
rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
engine = MockExecutionEngine(failure_prob=0.2)
@@ -150,7 +149,7 @@ def test_rl():
wait_models(*engine.models)
_reset_execution_engine()

- rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10, asynchronous=False)
+ rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
engine = MockExecutionEngine(failure_prob=0.2)
_reset_execution_engine(engine)
rl.run(*_get_model_and_mutators())