Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
fix retiarii experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanluZhang committed Jan 5, 2021
1 parent 9eae8e8 commit cfdef72
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions nni/retiarii/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import colorama
import psutil

from ..experiment import Experiment, TrainingServiceConfig, launcher
import nni.runtime.log

from ..experiment import Experiment, TrainingServiceConfig, launcher, management
from ..experiment.config.base import ConfigBase, PathLike
from ..experiment.config import util
from ..experiment.pipe import Pipe
Expand Down Expand Up @@ -86,6 +88,7 @@ class RetiariiExperiment(Experiment):
def __init__(self, base_model: Model, trainer: BaseTrainer,
applied_mutators: Mutator = None, strategy: BaseStrategy = None):
self.config: RetiariiExeConfig = None
self.id: Optional[str] = None
self.port: Optional[int] = None

self.base_model = base_model
Expand Down Expand Up @@ -161,10 +164,15 @@ def start(self, port: int = 8080, debug: bool = False) -> None:
"""
atexit.register(self.stop)

if debug:
logging.getLogger('nni').setLevel(logging.DEBUG)
self.id = management.generate_experiment_id()

if self.config.experiment_working_directory is not None:
log_dir = Path(self.config.experiment_working_directory, self.id, 'log')
else:
log_dir = Path.home() / f'nni-experiments/{self.id}/log'
nni.runtime.log.start_experiment_log(self.id, log_dir, debug)

self._proc, self._pipe = launcher.start_experiment(self.config, port, debug)
self._proc, self._pipe = launcher.start_experiment(self.id, self.config, port, debug)
assert self._proc is not None
assert self._pipe is not None

Expand All @@ -183,7 +191,7 @@ def start(self, port: int = 8080, debug: bool = False) -> None:
if interface.family == socket.AF_INET:
ips.append(interface.address)
ips = [f'http://{ip}:{port}' for ip in ips if ip]
msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips)
msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
_logger.info(msg)

# TODO: register experiment management metadata
Expand Down

0 comments on commit cfdef72

Please sign in to comment.