Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix launch from Python log #3263

Merged
merged 5 commits into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions nni/experiment/experiment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import atexit
import logging
from pathlib import Path
import socket
from subprocess import Popen
from threading import Thread
Expand All @@ -15,6 +16,7 @@

from .config import ExperimentConfig
from . import launcher
from . import management
from .pipe import Pipe
from . import rest
from ..tools.nnictl.command_utils import kill_command
Expand Down Expand Up @@ -76,6 +78,7 @@ def __init__(self, tuner: Tuner, training_service: str) -> None:

def __init__(self, tuner: Tuner, config=None, training_service=None):
self.config: ExperimentConfig
self.id: Optional[str] = None
self.port: Optional[int] = None
self.tuner: Tuner = tuner
self._proc: Optional[Popen] = None
Expand Down Expand Up @@ -108,10 +111,15 @@ def start(self, port: int = 8080, debug: bool = False) -> None:
"""
atexit.register(self.stop)

if debug:
logging.getLogger('nni').setLevel(logging.DEBUG)
self.id = management.generate_experiment_id()

self._proc, self._pipe = launcher.start_experiment(self.config, port, debug)
if self.config.experiment_working_directory is not None:
log_dir = Path(self.config.experiment_working_directory, self.id, 'log')
else:
log_dir = Path.home() / f'nni-experiments/{self.id}/log'
nni.runtime.log.start_experiment_log(self.id, log_dir, debug)

self._proc, self._pipe = launcher.start_experiment(self.id, self.config, port, debug)
assert self._proc is not None
assert self._pipe is not None

Expand All @@ -129,11 +137,9 @@ def start(self, port: int = 8080, debug: bool = False) -> None:
if interface.family == socket.AF_INET:
ips.append(interface.address)
ips = [f'http://{ip}:{port}' for ip in ips if ip]
msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips)
msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
_logger.info(msg)

# TODO: register experiment management metadata


def stop(self) -> None:
"""
Expand All @@ -142,6 +148,8 @@ def stop(self) -> None:
_logger.info('Stopping experiment, please wait...')
atexit.unregister(self.stop)

if self.id is not None:
nni.runtime.log.stop_experiment_log(self.id)
if self._proc is not None:
kill_command(self._proc.pid)
if self._pipe is not None:
Expand All @@ -150,6 +158,7 @@ def stop(self) -> None:
self._dispatcher.stopping = True
self._dispatcher_thread.join(timeout=1)

self.id = None
self.port = None
self._proc = None
self._pipe = None
Expand Down
6 changes: 2 additions & 4 deletions nni/experiment/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,24 @@

from .config import ExperimentConfig
from .config import convert
from . import management
from .pipe import Pipe
from . import rest
from ..tools.nnictl.config_utils import Experiments

_logger = logging.getLogger('nni.experiment')


def start_experiment(config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]:
def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]:
pipe = None
proc = None

config.validate(initialized_tuner=True)
_ensure_port_idle(port)
if config.training_service.platform == 'openpai':
_ensure_port_idle(port + 1, 'OpenPAI requires an additional port')
exp_id = management.generate_experiment_id()

try:
_logger.info('Creating experiment %s%s', colorama.Fore.CYAN, exp_id)
_logger.info('Creating experiment %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
pipe = Pipe(exp_id)
start_time, proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
_logger.info('Connecting IPC pipe...')
Expand Down
57 changes: 29 additions & 28 deletions nni/runtime/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
from .env_vars import dispatcher_env_vars, trial_env_vars


# Log handlers installed by this module, keyed by tag. `_register_handler()`
# uses the tag '_default_' unless told otherwise; `start_experiment_log()`
# registers its file handler under the experiment ID so that
# `stop_experiment_log()` can find and remove it again.
handlers = {}

# Record layout and timestamp layout shared by the formatters built in this module.
log_format = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
time_format = '%Y-%m-%d %H:%M:%S'
# Plain (non-colored) formatter; `_register_handler()` attaches it to every handler it registers.
formatter = Formatter(log_format, time_format)


def init_logger() -> None:
"""
This function will (and should only) get invoked on the first time of importing nni (no matter which submodule).
Expand All @@ -37,25 +44,28 @@ def init_logger() -> None:

_init_logger_standalone()

logging.getLogger('filelock').setLevel(logging.WARNING)


def init_logger_experiment() -> None:
"""
Initialize logger for `nni.experiment.Experiment`.

This function will get invoked after `init_logger()`.
"""
formatter.format = _colorful_format
colorful_formatter = Formatter(log_format, time_format)
colorful_formatter.format = _colorful_format
handlers['_default_'].setFormatter(colorful_formatter)

log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log'
_setup_root_logger(FileHandler(log_path), logging.DEBUG)
def start_experiment_log(experiment_id: str, log_directory: Path, debug: bool) -> None:
    """
    Start writing log records to the ``dispatcher.log`` file of one experiment.

    The log directory is prepared through `_prepare_log_dir()`, and a
    `FileHandler` on ``<log_directory>/dispatcher.log`` is registered through
    `_register_handler()` with the experiment ID as its tag, which is the key
    `stop_experiment_log()` later uses to detach it.

    Parameters
    ----------
    experiment_id
        ID of the experiment; used as the handler's tag.
    log_directory
        Directory that will contain ``dispatcher.log``.
    debug
        If true the handler is registered at DEBUG level, otherwise at INFO.
    """
    if debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    dispatcher_log = _prepare_log_dir(log_directory) / 'dispatcher.log'
    file_handler = FileHandler(dispatcher_log)
    _register_handler(file_handler, level, experiment_id)

def stop_experiment_log(experiment_id: str) -> None:
    """
    Stop the log handler that was registered for an experiment.

    The handler stored in `handlers` under `experiment_id` is forgotten,
    detached from the root logger and closed. Calling this for an ID that has
    no registered handler (never started, or already stopped) is a no-op.

    Parameters
    ----------
    experiment_id
        ID of the experiment whose log handler should be removed.
    """
    handler = handlers.pop(experiment_id, None)
    if handler is None:
        return
    logging.getLogger().removeHandler(handler)
    # removeHandler() only detaches the handler; without close() the underlying
    # log file would stay open for the rest of the process.
    handler.close()

time_format = '%Y-%m-%d %H:%M:%S'

formatter = Formatter(
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s',
time_format
)

def _init_logger_dispatcher() -> None:
log_level_map = {
Expand All @@ -69,26 +79,20 @@ def _init_logger_dispatcher() -> None:

log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log'
log_level = log_level_map.get(dispatcher_env_vars.NNI_LOG_LEVEL, logging.INFO)
_setup_root_logger(FileHandler(log_path), log_level)
_register_handler(FileHandler(log_path), log_level)


def _init_logger_trial() -> None:
log_path = _prepare_log_dir(trial_env_vars.NNI_OUTPUT_DIR) / 'trial.log'
log_file = open(log_path, 'w')
_setup_root_logger(StreamHandler(log_file), logging.INFO)
_register_handler(StreamHandler(log_file), logging.INFO)

if trial_env_vars.NNI_PLATFORM == 'local':
sys.stdout = _LogFileWrapper(log_file)


def _init_logger_standalone() -> None:
_setup_nni_logger(StreamHandler(sys.stdout), logging.INFO)

# Following line does not affect NNI loggers, but without this user's logger won't
# print log even it's level is set to INFO, so we do it for user's convenience.
# If this causes any issue in future, remove it and use `logging.info()` instead of
# `logging.getLogger('xxx').info()` in all examples.
logging.basicConfig()
_register_handler(StreamHandler(sys.stdout), logging.INFO)


def _prepare_log_dir(path: Optional[str]) -> Path:
Expand All @@ -98,20 +102,18 @@ def _prepare_log_dir(path: Optional[str]) -> Path:
ret.mkdir(parents=True, exist_ok=True)
return ret

def _setup_root_logger(handler: Handler, level: int) -> None:
_setup_logger('', handler, level)

def _setup_nni_logger(handler: Handler, level: int) -> None:
_setup_logger('nni', handler, level)

def _setup_logger(name: str, handler: Handler, level: int) -> None:
def _register_handler(handler: Handler, level: int, tag: str = '_default_') -> None:
assert tag not in handlers
handlers[tag] = handler
handler.setFormatter(formatter)
logger = logging.getLogger(name)
logger = logging.getLogger()
logger.addHandler(handler)
logger.setLevel(level)
logger.propagate = False

def _colorful_format(record):
time = formatter.formatTime(record, time_format)
if not record.name.startswith('nni.'):
return '[{}] ({}) {}'.format(time, record.name, record.msg % record.args)
if record.levelno >= logging.ERROR:
color = colorama.Fore.RED
elif record.levelno >= logging.WARNING:
Expand All @@ -121,7 +123,6 @@ def _colorful_format(record):
else:
color = colorama.Fore.BLUE
msg = color + (record.msg % record.args) + colorama.Style.RESET_ALL
time = formatter.formatTime(record, time_format)
if record.levelno < logging.INFO:
return '[{}] {}:{} {}'.format(time, record.threadName, record.name, msg)
else:
Expand Down