This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Support loading supernet checkpoint in lightning (#5096)
ultmaster authored Sep 2, 2022
1 parent 79a51d4 commit 5874c27
Showing 6 changed files with 65 additions and 6 deletions.
18 changes: 17 additions & 1 deletion nni/nas/evaluator/pytorch/lightning.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import logging
 import os
 import warnings
 from pathlib import Path
@@ -31,6 +32,8 @@
 # FIXME: hack to make it importable for tests
 ]
 
+_logger = logging.getLogger(__name__)
+
 
 class LightningModule(pl.LightningModule):
     """
@@ -175,14 +178,21 @@ def __eq__(self, other):
     def fit(self, model):
         """
         Fit the model with the provided dataloaders, using the Lightning trainer.
+        If ``train_dataloaders`` is not provided, ``trainer.validate()`` will be called.
 
         Parameters
         ----------
         model : nn.Module
             The model to fit.
         """
         self.module.set_model(model)
-        return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders, **self.fit_kwargs)
+        if self.train_dataloaders is None:
+            _logger.info('Train dataloaders are missing. Skip to validation.')
+            return self.trainer.validate(self.module, self.val_dataloaders, **self.fit_kwargs)
+        else:
+            if self.val_dataloaders is None:
+                _logger.warning('Validation dataloaders are missing.')
+            return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders, **self.fit_kwargs)
 
 
 def _check_dataloader(dataloader):
@@ -265,6 +275,12 @@ def on_validation_epoch_end(self):
             nni.report_intermediate_result(self._get_validation_metrics())
 
     def on_fit_end(self):
+        self._final_report()
+
+    def on_validation_end(self):
+        self._final_report()
+
+    def _final_report(self):
         if self.running_mode == 'multi' and nni.get_current_parameter() is not None:
             nni.report_final_result(self._get_validation_metrics())

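With this change, an evaluator constructed without training data can still score a model: ``fit()`` now falls back to ``trainer.validate()``. A minimal sketch of the new behavior, assuming the ``Classification`` evaluator defined in this file keeps its usual keyword arguments (``val_loader`` and ``model`` are placeholders; dataloader construction is elided):

    from nni.nas.evaluator.pytorch import Classification

    # No train_dataloaders given: fit() logs "Train dataloaders are missing.
    # Skip to validation." and runs trainer.validate() instead of trainer.fit().
    evaluator = Classification(val_dataloaders=val_loader, max_epochs=1)
    evaluator.fit(model)  # returns the validation metrics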
32 changes: 32 additions & 0 deletions nni/nas/fixed.py
@@ -3,6 +3,7 @@

 import json
 import logging
+from contextlib import contextmanager
 from pathlib import Path
 from typing import Union, Dict, Any
 
@@ -41,3 +42,34 @@ def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
     _logger.info('Fixed architecture: %s', fixed_arch)
 
     return ContextStack('fixed', fixed_arch)
+
+
+@contextmanager
+def no_fixed_arch():
+    """
+    Ignore the ``fixed_arch()`` context.
+
+    This is useful in creating a search space within a ``fixed_arch()`` context.
+    Under the hood, it only disables the most recent fixed-arch context, which means that,
+    inside nested ``fixed_arch()`` contexts, multiple nested ``no_fixed_arch()`` contexts are required.
+
+    Examples
+    --------
+    >>> with fixed_arch(arch_dict):
+    ...     with no_fixed_arch():
+    ...         model_space = ModelSpace()
+    """
+    NO_ARCH = '_no_arch_'
+
+    popped_arch = NO_ARCH  # make linter happy
+    try:
+        try:
+            popped_arch = ContextStack.pop('fixed')
+        except IndexError:
+            # no fixed-arch context is currently active
+            popped_arch = NO_ARCH
+        yield
+    finally:
+        if popped_arch is not NO_ARCH:
+            ContextStack.push('fixed', popped_arch)
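A sketch of the nested case mentioned in the docstring: each ``no_fixed_arch()`` hides exactly one level, so escaping two nested ``fixed_arch()`` contexts takes two of them (``arch_a``, ``arch_b`` and ``ModelSpace`` are placeholders):

    from nni.nas.fixed import fixed_arch, no_fixed_arch

    with fixed_arch(arch_a):
        with fixed_arch(arch_b):
            with no_fixed_arch():         # hides arch_b
                with no_fixed_arch():     # hides arch_a
                    space = ModelSpace()  # built as a search space, not a fixed model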
10 changes: 7 additions & 3 deletions nni/nas/hub/pytorch/autoformer.py
@@ -8,6 +8,7 @@

 import nni.nas.nn.pytorch as nn
 from nni.nas import model_wrapper, basic_unit
+from nni.nas.fixed import no_fixed_arch
 from nni.nas.nn.pytorch.choice import ValueChoiceX
 from nni.nas.oneshot.pytorch.supermodule.operation import MixedOperation
 from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options
@@ -432,7 +433,7 @@ def preset(cls, name: str):
     @classmethod
     def load_strategy_checkpoint(cls, name: str, download: bool = True, progress: bool = True):
         """
-        Load the RandomOneShot strategy initialized with supernet weights.
+        Load the related strategy checkpoints.
 
         Parameters
         ----------
@@ -446,15 +447,18 @@ def load_strategy_checkpoint(cls, name: str, download: bool = True, progress: bool = True):
         Returns
         -------
         BaseStrategy
-            The RandomOneShot strategy initialized with supernet weights provided in the official repo.
+            The loaded strategy.
         """
         legal = ['random-one-shot-tiny', 'random-one-shot-small', 'random-one-shot-base']
         if name not in legal:
             raise ValueError(f'Unsupported name: {name}. It should be one of {legal}.')
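         # 'random-one-shot-' is 16 characters; stripping it leaves 'tiny', 'small' or 'base'.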
         name = name[16:]
 
+        # RandomOneShot is the only supported strategy for now.
         from nni.nas.strategy import RandomOneShot
         init_kwargs = cls.preset(name)
-        model_space = cls(**init_kwargs)
+        with no_fixed_arch():
+            model_space = cls(**init_kwargs)
         strategy = RandomOneShot(mutation_hooks=cls.get_extra_mutation_hooks())
         strategy.attach_model(model_space)
         weight_file = load_pretrained_weight(f"autoformer-{name}-supernet", download=download, progress=progress)
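A hedged usage sketch of the fixed code path (the ``AutoformerSpace`` class name and its export from ``nni.nas.hub.pytorch`` are assumptions based on this file's location):

    from nni.nas.hub.pytorch import AutoformerSpace  # assumed export of this file's model space

    # The supernet is now constructed inside no_fixed_arch(), so this call also
    # works when a fixed_arch() context is already active; pretrained supernet
    # weights are then loaded into the RandomOneShot strategy.
    strategy = AutoformerSpace.load_strategy_checkpoint('random-one-shot-tiny')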
6 changes: 6 additions & 0 deletions nni/nas/oneshot/pytorch/base_lightning.py
@@ -519,6 +519,12 @@ def on_train_start(self):
     def on_train_end(self):
         return self.model.on_train_end()
 
+    def on_validation_start(self):
+        return self.model.on_validation_start()
+
+    def on_validation_end(self):
+        return self.model.on_validation_end()
+
     def on_fit_start(self):
         return self.model.on_fit_start()
 
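Together with the new ``on_validation_end`` in lightning.py above, these pass-through hooks ensure that a validate-only run, where ``trainer.fit()`` is never called and ``on_fit_end`` never fires, still reaches ``_final_report()``.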
1 change: 1 addition & 0 deletions nni/nas/oneshot/pytorch/strategy.py
@@ -61,6 +61,7 @@ def attach_model(self, base_model: Union[Model, nn.Module]):
             evaluator_module.running_mode = 'oneshot'
             evaluator_module.set_model(py_model)
         else:
+            # FIXME: this should be an evaluator + model
             from nni.retiarii.evaluator.pytorch.lightning import ClassificationModule
             evaluator_module = ClassificationModule()
             evaluator_module.running_mode = 'oneshot'
4 changes: 2 additions & 2 deletions nni/nas/utils/misc.py
@@ -106,8 +106,8 @@ def push(cls, key: str, value: Any):
         cls._stack[key].append(value)
 
     @classmethod
-    def pop(cls, key: str) -> None:
-        cls._stack[key].pop()
+    def pop(cls, key: str) -> Any:
+        return cls._stack[key].pop()
 
     @classmethod
     def top(cls, key: str) -> Any:
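The value returned by ``pop`` is exactly what the new ``no_fixed_arch()`` relies on: it pops the current 'fixed' entry, yields, and pushes the same value back afterwards.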
