From cfdef72e54bbfd7adbf2a8faab753d47a51006ae Mon Sep 17 00:00:00 2001 From: quzha Date: Tue, 5 Jan 2021 22:22:13 +0800 Subject: [PATCH] fix retiarii experiment --- nni/retiarii/experiment.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/nni/retiarii/experiment.py b/nni/retiarii/experiment.py index 8bfb5c303db..7f7e9e4dc92 100644 --- a/nni/retiarii/experiment.py +++ b/nni/retiarii/experiment.py @@ -11,7 +11,9 @@ import colorama import psutil -from ..experiment import Experiment, TrainingServiceConfig, launcher +import nni.runtime.log + +from ..experiment import Experiment, TrainingServiceConfig, launcher, management from ..experiment.config.base import ConfigBase, PathLike from ..experiment.config import util from ..experiment.pipe import Pipe @@ -86,6 +88,7 @@ class RetiariiExperiment(Experiment): def __init__(self, base_model: Model, trainer: BaseTrainer, applied_mutators: Mutator = None, strategy: BaseStrategy = None): self.config: RetiariiExeConfig = None + self.id: Optional[str] = None self.port: Optional[int] = None self.base_model = base_model @@ -161,10 +164,15 @@ def start(self, port: int = 8080, debug: bool = False) -> None: """ atexit.register(self.stop) - if debug: - logging.getLogger('nni').setLevel(logging.DEBUG) + self.id = management.generate_experiment_id() + + if self.config.experiment_working_directory is not None: + log_dir = Path(self.config.experiment_working_directory, self.id, 'log') + else: + log_dir = Path.home() / f'nni-experiments/{self.id}/log' + nni.runtime.log.start_experiment_log(self.id, log_dir, debug) - self._proc, self._pipe = launcher.start_experiment(self.config, port, debug) + self._proc, self._pipe = launcher.start_experiment(self.id, self.config, port, debug) assert self._proc is not None assert self._pipe is not None @@ -183,7 +191,7 @@ def start(self, port: int = 8080, debug: bool = False) -> None: if interface.family == socket.AF_INET: ips.append(interface.address) ips = [f'http://{ip}:{port}' for ip in ips if ip] - msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL _logger.info(msg) # TODO: register experiment management metadata