From fdc0e5f7f0fa1ed3b769ad3444c844628a108455 Mon Sep 17 00:00:00 2001 From: liuzhe Date: Tue, 5 Jan 2021 22:24:56 +0800 Subject: [PATCH] fix retiarii --- nni/experiment/experiment.py | 5 ++++- nni/retiarii/experiment.py | 29 +++-------------------------- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/nni/experiment/experiment.py b/nni/experiment/experiment.py index c7cb5ecf97..d521b57506 100644 --- a/nni/experiment/experiment.py +++ b/nni/experiment/experiment.py @@ -127,7 +127,7 @@ def start(self, port: int = 8080, debug: bool = False) -> None: # dispatcher must be launched after pipe initialized # the logic to launch dispatcher in background should be refactored into dispatcher api - self._dispatcher = MsgDispatcher(self.tuner, None) + self._dispatcher = self._create_dispatcher() self._dispatcher_thread = Thread(target=self._dispatcher.run) self._dispatcher_thread.start() @@ -140,6 +140,9 @@ def start(self, port: int = 8080, debug: bool = False) -> None: msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL _logger.info(msg) + def _create_dispatcher(self): # overrided by retiarii, temporary solution + return MsgDispatcher(self.tuner, None) + def stop(self) -> None: """ diff --git a/nni/retiarii/experiment.py b/nni/retiarii/experiment.py index 8bfb5c303d..7ebe5f1c96 100644 --- a/nni/retiarii/experiment.py +++ b/nni/retiarii/experiment.py @@ -159,34 +159,11 @@ def start(self, port: int = 8080, debug: bool = False) -> None: debug Whether to start in debug mode. """ - atexit.register(self.stop) - - if debug: - logging.getLogger('nni').setLevel(logging.DEBUG) - - self._proc, self._pipe = launcher.start_experiment(self.config, port, debug) - assert self._proc is not None - assert self._pipe is not None - - self.port = port # port will be None if start up failed - - # dispatcher must be created after pipe initialized - # the logic to launch dispatcher in background should be refactored into dispatcher api - self._dispatcher_thread = Thread(target=self._dispatcher.run) - self._dispatcher_thread.start() - + super().start(port, debug) self._start_strategy() - ips = [self.config.nni_manager_ip] - for interfaces in psutil.net_if_addrs().values(): - for interface in interfaces: - if interface.family == socket.AF_INET: - ips.append(interface.address) - ips = [f'http://{ip}:{port}' for ip in ips if ip] - msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) - _logger.info(msg) - - # TODO: register experiment management metadata + def _create_dispatcher(self): + return self._dispatcher def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str: """