diff --git a/nni/experiment/experiment.py b/nni/experiment/experiment.py index c7cb5ecf97..d521b57506 100644 --- a/nni/experiment/experiment.py +++ b/nni/experiment/experiment.py @@ -127,7 +127,7 @@ def start(self, port: int = 8080, debug: bool = False) -> None: # dispatcher must be launched after pipe initialized # the logic to launch dispatcher in background should be refactored into dispatcher api - self._dispatcher = MsgDispatcher(self.tuner, None) + self._dispatcher = self._create_dispatcher() self._dispatcher_thread = Thread(target=self._dispatcher.run) self._dispatcher_thread.start() @@ -140,6 +140,9 @@ def start(self, port: int = 8080, debug: bool = False) -> None: msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL _logger.info(msg) + def _create_dispatcher(self): # overrided by retiarii, temporary solution + return MsgDispatcher(self.tuner, None) + def stop(self) -> None: """ diff --git a/nni/retiarii/experiment.py b/nni/retiarii/experiment.py index 8bfb5c303d..7ebe5f1c96 100644 --- a/nni/retiarii/experiment.py +++ b/nni/retiarii/experiment.py @@ -159,34 +159,11 @@ def start(self, port: int = 8080, debug: bool = False) -> None: debug Whether to start in debug mode. """ - atexit.register(self.stop) - - if debug: - logging.getLogger('nni').setLevel(logging.DEBUG) - - self._proc, self._pipe = launcher.start_experiment(self.config, port, debug) - assert self._proc is not None - assert self._pipe is not None - - self.port = port # port will be None if start up failed - - # dispatcher must be created after pipe initialized - # the logic to launch dispatcher in background should be refactored into dispatcher api - self._dispatcher_thread = Thread(target=self._dispatcher.run) - self._dispatcher_thread.start() - + super().start(port, debug) self._start_strategy() - ips = [self.config.nni_manager_ip] - for interfaces in psutil.net_if_addrs().values(): - for interface in interfaces: - if interface.family == socket.AF_INET: - ips.append(interface.address) - ips = [f'http://{ip}:{port}' for ip in ips if ip] - msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) - _logger.info(msg) - - # TODO: register experiment management metadata + def _create_dispatcher(self): + return self._dispatcher def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str: """