From 083aff188cf7ccb285083d74392f2d72ba05768d Mon Sep 17 00:00:00 2001
From: ravi-mosaicml <87037432+ravi-mosaicml@users.noreply.github.com>
Date: Tue, 18 Jan 2022 16:16:28 -0800
Subject: [PATCH] Fixed loggers and callbacks (#240)

1. Removed rank zero callbacks and loggers, since they hid complexity and led to code that could block indefinitely when calling distributed functions. Closes #239.
2. Incremented `state.timer` _before_ calling `.eval()` in the trainer, so the batch count printed in the logs is consistent for both batch-wise and epoch-wise evaluators.
3. Fixed the TQDM logger so it works properly with gradient accumulation.
4. Removed `LogLevel.ALGORITHM`, `LogLevel.MICROBATCH`, and `LogLevel.VERBOSE`, since they were rarely used. Instead, the built-in Python logger should be used for anything verbose (it is not a useful metric), microbatch metrics should use BATCH (a microbatch behaves like another GPU), and algorithm metrics should use BATCH or EPOCH, depending on where the algorithm runs.
5. Updated the file logger to take a `log_interval` instead of `every_n_epochs` and `every_n_batches`, and a `flush_interval` instead of `flush_every_n_batches`.
6. Switched the default logger in all YAMLs to TQDM.
---
 composer/core/callback.py | 22 -----
 composer/core/engine.py | 11 ++-
 composer/core/logging/__init__.py | 1 -
 composer/core/logging/base_backend.py | 78 +---------------
 composer/core/logging/logger.py | 12 ---
 composer/loggers/__init__.py | 1 -
 composer/loggers/file_logger.py | 92 +++++++++++--------
 composer/loggers/logger_hparams.py | 20 ++--
 composer/loggers/mosaicml_logger.py | 12 ++-
 composer/loggers/tqdm_logger.py | 85 ++++++++++-------
 composer/loggers/wandb_logger.py | 53 +++++++----
 composer/trainer/trainer.py | 14 +--
 composer/yamls/models/bert-base.yaml | 8 +-
 composer/yamls/models/glue/cola.yaml | 8 +-
 composer/yamls/models/glue/mnli-m.yaml | 8 +-
 composer/yamls/models/glue/mnli.yaml | 8 +-
 composer/yamls/models/glue/mrpc.yaml | 8 +-
 composer/yamls/models/glue/qnli.yaml | 8 +-
 composer/yamls/models/glue/qqp.yaml | 8 +-
 composer/yamls/models/glue/rte.yaml | 8 +-
 composer/yamls/models/glue/sst-2.yaml | 8 +-
 composer/yamls/models/glue/stsb.yaml | 8 +-
 composer/yamls/models/gpt2_1,3b.yaml | 8 +-
 composer/yamls/models/gpt2_125m.yaml | 8 +-
 composer/yamls/models/gpt2_13b.yaml | 8 +-
 composer/yamls/models/gpt2_2,7b.yaml | 8 +-
 composer/yamls/models/gpt2_350m.yaml | 8 +-
 composer/yamls/models/gpt2_52m.yaml | 8 +-
 composer/yamls/models/gpt2_6,7b.yaml | 8 +-
 composer/yamls/models/gpt2_760m.yaml | 8 +-
 composer/yamls/models/gpt2_83m.yaml | 8 +-
 composer/yamls/models/resnet56_cifar10.yaml | 8 +-
 .../models/resnet56_cifar10_synthetic.yaml | 8 +-
 composer/yamls/models/resnet9_cifar10.yaml | 8 +-
 docs/source/core/callback.rst | 1 -
 docs/source/core/logger.rst | 1 -
 tests/test_logger.py | 61 ++++++------
 tests/test_mosaicml_logger.py | 2 +-
 tests/trainer/test_ddp.py | 3 +-
 39 files changed, 231 insertions(+), 414 deletions(-)
 mode change 100644 => 100755 composer/loggers/__init__.py
 mode change 100644 => 100755 composer/loggers/mosaicml_logger.py
 mode change 100644 => 100755 composer/loggers/tqdm_logger.py
diff --git a/composer/core/callback.py b/composer/core/callback.py index 299d2bd33b..ba329bc9c4 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING from composer.core.serializable import Serializable -from composer.utils import dist try: from typing import
final @@ -315,24 +314,3 @@ def post_close(self) -> None: callbacks during :meth:`close`. """ pass - - -class RankZeroCallback(Callback, abc.ABC): - """Base class for callbacks that only run on the local rank zero process. - - Callbacks can be implemented in two ways: - - #. Override the individual methods named for each :class:`Event`. (See - the parent class, :class:`Callback`.) - - #. Override :meth:`_run_event` (**not** :meth:`run_event`) to run in response - to all events. If this method is overridden, then the individual methods - corresponding to each event name will not be automatically called (however, - the subclass implementation can invoke these methods as it wishes.) - """ - - @final - def run_event(self, event: Event, state: State, logger: Logger) -> None: - if dist.get_local_rank() != 0: - return - return self._run_event(event, state, logger) diff --git a/composer/core/engine.py b/composer/core/engine.py index 9fe69d67e5..429f866e9a 100644 --- a/composer/core/engine.py +++ b/composer/core/engine.py @@ -9,6 +9,7 @@ from composer.core.callback import Callback from composer.core.event import Event from composer.core.logging import Logger +from composer.core.logging.logger import LogLevel from composer.core.state import State log = logging.getLogger(__name__) @@ -126,7 +127,15 @@ def _run_algorithms( trace[trace_key] = Trace(exit_code=exit_code, order=order, run=True) if self.logger is not None: - self.logger.metric_verbose(data={key: 1 if tr.run else 0 for key, tr in trace.items()}) + if event in (Event.INIT, Event.TRAINING_START, Event.TRAINING_END): + log_level = LogLevel.FIT + if event in (Event.EPOCH_START, Event.EPOCH_END): + log_level = LogLevel.EPOCH + else: + # algs don't run on eval events, so don't have to worry about + # batch-frequency vs epoch-frequency evaluators + log_level = LogLevel.BATCH + self.logger.metric(log_level=log_level, data={key: 1 if tr.run else 0 for key, tr in trace.items()}) return trace diff --git a/composer/core/logging/__init__.py b/composer/core/logging/__init__.py index 24698cbb19..4858f9b3c9 100644 --- a/composer/core/logging/__init__.py +++ b/composer/core/logging/__init__.py @@ -1,7 +1,6 @@ # Copyright 2021 MosaicML. All Rights Reserved. from composer.core.logging.base_backend import BaseLoggerBackend as BaseLoggerBackend -from composer.core.logging.base_backend import RankZeroLoggerBackend as RankZeroLoggerBackend from composer.core.logging.logger import Logger as Logger from composer.core.logging.logger import LogLevel as LogLevel from composer.core.logging.logger import TLogData as TLogData diff --git a/composer/core/logging/base_backend.py b/composer/core/logging/base_backend.py index 7285009ae6..130923ced5 100644 --- a/composer/core/logging/base_backend.py +++ b/composer/core/logging/base_backend.py @@ -5,18 +5,12 @@ from abc import ABC from typing import TYPE_CHECKING -from composer.core.callback import Callback, RankZeroCallback -from composer.utils import dist +from composer.core.callback import Callback if TYPE_CHECKING: from composer.core.logging.logger import LogLevel, TLogData from composer.core.state import State -try: - from typing import final -except ImportError: - final = lambda x: x # final is not available in python 3.7 - class BaseLoggerBackend(Callback, ABC): """Base class for logging backends. 
@@ -59,73 +53,3 @@ def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) """ del epoch, step, log_level, data # unused pass - - -class RankZeroLoggerBackend(BaseLoggerBackend, RankZeroCallback, ABC): - """Base class for logging backends that run only on the rank zero process. - - In a multi-process training setup (e.g. when using DistributedDataParallel), - some logging backends require that only the rank zero process log data. - For example, when logging to a file, only the main process should open the file - and save data. - - When using this class, override - :func:`_will_log` and :func:`_log_metric`` instead of - :func:`will_log` and :func:`log_metric`, respectively. - - This class ensures that :func:`_will_log` and :func:`_log_metric` - are invoked only on the rank zero process. - - .. automethod:: _will_log - .. automethod:: _log_metric - """ - - def __init__(self) -> None: - super().__init__() - - def _will_log(self, state: State, log_level: LogLevel) -> bool: - """Called by the :class:`~composer.core.logging.logger.Logger` - to determine whether the logging backend will log a metric. - - By default, it always returns ``True``, but this method - can be overridden. - - Args: - state (State): The global state object. - log_level (LogLevel): The log level. - - Returns: - bool: Whether to log a metric call, given the - :class:`~composer.core.state.State` and - :class:`~composer.core.logging.logger.LogLevel`. - """ - del state, log_level # Unused - return True - - @final - def will_log(self, state: State, log_level: LogLevel) -> bool: - if dist.get_local_rank() != 0: - return False - return self._will_log(state, log_level) - - def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None: - """Called by the :class:`~composer.core.logging.logger.Logger` - for metrics where :func:`will_log` returned ``True``. - - The logging backend should override this function to log the data - (e.g. write it to a file, send it to a server, etc...). - - Args: - epoch (int): The epoch for the logged data. - step (int): The global step for the logged data. - log_level (LogLevel). The log level. - data (TLogData): The metric to log. - """ - del epoch, step, log_level, data # Unused - pass - - @final - def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None: - if dist.get_local_rank() != 0: - return - return self._log_metric(epoch, step, log_level, data) diff --git a/composer/core/logging/logger.py b/composer/core/logging/logger.py index d155b0f335..27ba3c6d9c 100644 --- a/composer/core/logging/logger.py +++ b/composer/core/logging/logger.py @@ -30,14 +30,10 @@ class LogLevel(IntEnum): FIT: Logged once per training run. EPOCH: Logged once per epoch. BATCH: Logged once per batch. - MICROBATCH: Logged once per microbatch (e.g. forward pass). - VERBOSE: Logged for debugging. 
""" FIT = 1 EPOCH = 2 BATCH = 3 - MICROBATCH = 4 - VERBOSE = 5 class Logger: @@ -109,14 +105,6 @@ def metric_batch(self, data: Union[TLogData, Callable[[], TLogData]]) -> None: """Helper function for ``self.metric(LogLevel.BATCH, data)``""" self.metric(LogLevel.BATCH, data) - def metric_microbatch(self, data: Union[TLogData, Callable[[], TLogData]]) -> None: - """Helper function for ``self.metric(LogLevel.MICROBATCH, data)``""" - self.metric(LogLevel.MICROBATCH, data) - - def metric_verbose(self, data: Union[TLogData, Callable[[], TLogData]]) -> None: - """Helper function for ``self.metric(LogLevel.VERBOSE, data)``""" - self.metric(LogLevel.VERBOSE, data) - def format_log_data_value(data: TLogDataValue) -> str: """Recursively formats a given log data value into a string. diff --git a/composer/loggers/__init__.py b/composer/loggers/__init__.py old mode 100644 new mode 100755 index 43f4f2498c..b6fbb73252 --- a/composer/loggers/__init__.py +++ b/composer/loggers/__init__.py @@ -1,7 +1,6 @@ # Copyright 2021 MosaicML. All Rights Reserved. from composer.core.logging.base_backend import BaseLoggerBackend as BaseLoggerBackend -from composer.core.logging.base_backend import RankZeroLoggerBackend as RankZeroLoggerBackend from composer.core.logging.logger import Logger as Logger from composer.core.logging.logger import LogLevel as LogLevel from composer.core.logging.logger import TLogData as TLogData diff --git a/composer/loggers/file_logger.py b/composer/loggers/file_logger.py index 2854b9ba7b..10b5952ec7 100644 --- a/composer/loggers/file_logger.py +++ b/composer/loggers/file_logger.py @@ -8,13 +8,12 @@ import yaml -from composer.core.logging import Logger, LogLevel, RankZeroLoggerBackend, TLogData, format_log_data_value +from composer.core.logging import BaseLoggerBackend, Logger, LogLevel, TLogData, format_log_data_value from composer.core.state import State -from composer.loggers.logger_hparams import FileLoggerBackendHparams from composer.utils import run_directory -class FileLoggerBackend(RankZeroLoggerBackend): +class FileLoggerBackend(BaseLoggerBackend): """Logs to a file or to the terminal. Example output:: @@ -33,14 +32,17 @@ class FileLoggerBackend(RankZeroLoggerBackend): log_level (LogLevel, optional): Maximum :class:`~composer.core.logging.logger.LogLevel`. to record. (default: :attr:`~composer.core.logging.logger.LogLevel.EPOCH`) - every_n_epochs (int, optional): - Frequency to print :attr:`~composer.core.logging.logger.LogLevel.EPOCH` logs. - (default: ``1``) - every_n_batches (int, optional): - Frequency to print :attr:`~composer.core.logging.logger.LogLevel.BATCH` logs. - (default: ``1``) - flush_every_n_batches (int, optional): How frequently to flush the log to the file. - (default: ``1``) + log_interval (int, optional): + Frequency to print logs. If ``log_level` is :attr:`~composer.core.logging.logger.LogLevel.EPOCH`, + logs will only be recorded every n epochs. If ``log_level` is + :attr:`~composer.core.logging.logger.LogLevel.BATCH`, logs will be printed every n batches. + Otherwise, if ``log_level` is :attr:`~composer.core.logging.logger.LogLevel.FIT`, this parameter is + ignored, as calls at the fit log level are always recorded. (default: ``1``) + flush_interval (int, optional): How frequently to flush the log to the file, relative to the ``log_level``. + For example, if the ``log_level`` is :attr:`~composer.core.logging.logger.LogLevel.EPOCH`, + then the logfile will be flushed every n epochs. 
+ If the ``log_level`` is :attr:`~composer.core.logging.logger.LogLevel.BATCH`, then the logfile will be flushed + every n batches. (default: ``100``) """ def __init__( @@ -49,50 +51,54 @@ def __init__( *, buffer_size: int = 1, log_level: LogLevel = LogLevel.EPOCH, - every_n_epochs: int = 1, - every_n_batches: int = 1, - flush_every_n_batches: int = 1, + log_interval: int = 1, + flush_interval: int = 100, config: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() - self.hparams = FileLoggerBackendHparams( - filename=filename, - buffer_size=buffer_size, - log_level=log_level, - every_n_epochs=every_n_epochs, - every_n_batches=every_n_batches, - flush_every_n_batches=flush_every_n_batches, - ) + self.filename = filename + self.buffer_size = buffer_size + self.log_level = log_level + self.log_interval = log_interval + self.flush_interval = flush_interval self.file: Optional[TextIO] = None self.config = config - def _will_log(self, state: State, log_level: LogLevel) -> bool: - if log_level > self.hparams.log_level: - return False - if log_level >= LogLevel.EPOCH and state.epoch % self.hparams.every_n_epochs != 0: - return False - if log_level >= LogLevel.BATCH and (state.step + 1) % self.hparams.every_n_batches != 0: - return False - return True - - def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData): + def will_log(self, state: State, log_level: LogLevel) -> bool: + if log_level == LogLevel.FIT: + return True # fit is always logged + if log_level == LogLevel.EPOCH: + if self.log_level < LogLevel.EPOCH: + return False + if self.log_level > LogLevel.EPOCH: + return True + return (int(state.timer.epoch) + 1) % self.log_interval == 0 + if log_level == LogLevel.BATCH: + if self.log_level < LogLevel.BATCH: + return False + if self.log_level > LogLevel.BATCH: + return True + return (int(state.timer.batch) + 1) % self.log_interval == 0 + raise ValueError(f"Unknown log level: {log_level}") + + def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData): data_str = format_log_data_value(data) if self.file is None: raise RuntimeError("Attempted to log before self.init() or after self.close()") - print(f"[{log_level.name}][step={step}]: {data_str}", file=self.file) + print(f"[{log_level.name}][step={step}]: {data_str}", file=self.file, flush=False) def init(self, state: State, logger: Logger) -> None: del state, logger # unused if self.file is not None: raise RuntimeError("The file logger is already initialized") - if self.hparams.filename == "stdout": + if self.filename == "stdout": self.file = sys.stdout - elif self.hparams.filename == "stderr": + elif self.filename == "stderr": self.file = sys.stderr else: - self.file = open(os.path.join(run_directory.get_run_directory(), self.hparams.filename), + self.file = open(os.path.join(run_directory.get_run_directory(), self.filename), "x+", - buffering=self.hparams.buffer_size) + buffering=self.buffer_size) if self.config is not None: print("Config", file=self.file) print("-" * 30, file=self.file) @@ -100,15 +106,21 @@ def init(self, state: State, logger: Logger) -> None: print("-" * 30, file=self.file) print(file=self.file) + def training_start(self, state: State, logger: Logger) -> None: + del state, logger # unused + self._flush_file() + def batch_end(self, state: State, logger: Logger) -> None: del logger # unused assert self.file is not None - if (state.step + 1) % self.hparams.flush_every_n_batches == 0: + if self.log_level == LogLevel.BATCH and (int(state.timer.batch) + 1) % 
self.flush_interval == 0: self._flush_file() def epoch_end(self, state: State, logger: Logger) -> None: - del state, logger # unused - self._flush_file() + del logger # unused + if self.log_level > LogLevel.EPOCH or self.log_level == LogLevel.EPOCH and (int(state.timer.epoch) + + 1) % self.flush_interval == 0: + self._flush_file() def training_end(self, state: State, logger: Logger) -> None: self._flush_file() diff --git a/composer/loggers/logger_hparams.py b/composer/loggers/logger_hparams.py index a3221d225f..5a559e770c 100644 --- a/composer/loggers/logger_hparams.py +++ b/composer/loggers/logger_hparams.py @@ -56,17 +56,13 @@ class FileLoggerBackendHparams(BaseLoggerBackendHparams): buffer_size: int = hp.optional("Number of bytes to buffer. Defaults to 1 for line-buffering. " "See https://docs.python.org/3/library/functions.html#open", default=1) # line buffering. Python's default is -1. - flush_every_n_batches: int = hp.optional( - "Even if the buffer is not full, write to the file after this many steps. " - "Defaults to 1 (every step).", - default=1) - every_n_epochs: int = hp.optional( - "Frequency of logging messages for messages of LogLevel.EPOCH and higher." - "Defaults to 1 (every epoch).", - default=1) - every_n_batches: int = hp.optional( - "Frequency of logging messages for messages of LogLevel.BATCH and higher." - "Defaults to 1 (every batch).", + flush_interval: int = hp.optional( + "Frequency to flush the file, relative to the ``log_level``. " + "Defaults to 100 of the unit of ``log_level``.", + default=100) + log_interval: int = hp.optional( + "Frequency to record log messages, relative to the ``log_level``." + "Defaults to 1 (record all messages).", default=1) def initialize_object(self, config: Optional[Dict[str, Any]] = None) -> FileLoggerBackend: @@ -98,6 +94,7 @@ class WandBLoggerBackendHparams(BaseLoggerBackendHparams): tags: str = hp.optional(doc="wandb tags comma separated", default="") log_artifacts: bool = hp.optional(doc="Whether to log artifacts", default=False) log_artifacts_every_n_batches: int = hp.optional(doc="interval, in batches, to log artifacts", default=100) + rank_zero_only: bool = hp.optional("Whether to log on rank zero only", default=False) extra_init_params: Dict[str, JSON] = hp.optional(doc="wandb parameters", default_factory=dict) def initialize_object(self, config: Optional[Dict[str, Any]] = None) -> WandBLoggerBackend: @@ -202,6 +199,7 @@ def get_flattened_dict(data: Dict[str, Any], _prefix: List[str] = []) -> Dict[st from composer.loggers.wandb_logger import WandBLoggerBackend return WandBLoggerBackend( log_artifacts=self.log_artifacts, + rank_zero_only=self.rank_zero_only, log_artifacts_every_n_batches=self.log_artifacts_every_n_batches, init_params=init_params, ) diff --git a/composer/loggers/mosaicml_logger.py b/composer/loggers/mosaicml_logger.py old mode 100644 new mode 100755 index fe4a07272b..7f7a787a9a --- a/composer/loggers/mosaicml_logger.py +++ b/composer/loggers/mosaicml_logger.py @@ -13,9 +13,11 @@ import requests -from composer.core.logging import LogLevel, RankZeroLoggerBackend, TLogData +from composer.core.logging import LogLevel, TLogData +from composer.core.logging.base_backend import BaseLoggerBackend from composer.core.logging.logger import format_log_data_as_json from composer.core.types import JSON, Logger, State, StateDict +from composer.utils import dist from composer.utils.string_enum import StringEnum _MOSAICML_API_KEY_ENV = "MOSAICML_LOGGER_API_KEY" @@ -92,7 +94,7 @@ def _upsert_run(run_id: str, sys.exit(1) -class 
MosaicMLLoggerBackend(RankZeroLoggerBackend): +class MosaicMLLoggerBackend(BaseLoggerBackend): """Log to the MosaicML backend. Args: @@ -125,7 +127,7 @@ def __init__(self, config: Optional[Dict[str, JSON]] = None) -> None: super().__init__() - self.skip_logging = False + self.skip_logging = dist.get_global_rank() != 0 self.log_level = log_level self.run_name = run_name self.run_type = run_type @@ -153,11 +155,11 @@ def __init__(self, self.queue = Queue() self.thread = Thread(target=self._listen_to_queue, daemon=True, name="mosaicml-logger-thread") - def _will_log(self, state: State, log_level: LogLevel) -> bool: + def will_log(self, state: State, log_level: LogLevel) -> bool: del state # unused return log_level <= self.log_level - def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData): + def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData): del log_level # unused if self.skip_logging: diff --git a/composer/loggers/tqdm_logger.py b/composer/loggers/tqdm_logger.py old mode 100644 new mode 100755 index 10c1c5011c..15d29e65e4 --- a/composer/loggers/tqdm_logger.py +++ b/composer/loggers/tqdm_logger.py @@ -3,15 +3,17 @@ from __future__ import annotations import sys -from dataclasses import asdict, dataclass, field -from typing import TYPE_CHECKING, Any, Dict, Optional +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional import tqdm import yaml -from composer.core.logging import LogLevel, RankZeroLoggerBackend, TLogData, TLogDataValue, format_log_data_value +from composer.core.logging import LogLevel, TLogData, TLogDataValue, format_log_data_value +from composer.core.logging.base_backend import BaseLoggerBackend from composer.core.state import State from composer.core.types import StateDict +from composer.utils import dist if TYPE_CHECKING: from composer.core.logging import Logger @@ -21,39 +23,26 @@ @dataclass class _TQDMLoggerInstanceState: - total: int - epoch: int - is_train: bool + total: Optional[int] + description: str + position: int + keys_to_log: List[str] n: int - epoch_metrics: Dict[str, TLogDataValue] = field(default_factory=dict) + epoch_metrics: Dict[str, TLogDataValue] class _TQDMLoggerInstance: - def __init__(self, - total: int, - epoch: int, - is_train: bool, - n: int = 0, - epoch_metrics: Optional[Dict[str, TLogDataValue]] = None) -> None: - self.state = _TQDMLoggerInstanceState(total=total, - epoch=epoch, - is_train=is_train, - n=n, - epoch_metrics=(epoch_metrics or {})) - desc = f'Epoch {epoch + 1}{"" if is_train else " (val)"}' - position = 0 if is_train else 1 - self.pbar = tqdm.tqdm(total=total, - desc=desc, - position=position, - initial=n, + def __init__(self, state: _TQDMLoggerInstanceState) -> None: + self.state = state + self.pbar = tqdm.tqdm(total=state.total, + desc=state.description, + position=state.position, bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}") - self.pbar.set_postfix(epoch_metrics) + self.pbar.set_postfix(state.epoch_metrics) def log_metric(self, data: TLogData): - formatted_data = { - k: format_log_data_value(v) for (k, v) in data.items() if k in _IS_TRAIN_TO_KEYS_TO_LOG[self.state.is_train] - } + formatted_data = {k: format_log_data_value(v) for (k, v) in data.items() if k in self.state.keys_to_log} self.state.epoch_metrics.update(formatted_data) self.pbar.set_postfix(self.state.epoch_metrics) @@ -68,7 +57,7 @@ def state_dict(self) -> StateDict: return asdict(self.state) -class TQDMLoggerBackend(RankZeroLoggerBackend): +class 
TQDMLoggerBackend(BaseLoggerBackend): """Shows TQDM progress bars. During training, the progress bar logs the batch and training loss. @@ -95,11 +84,11 @@ def __init__(self, config: Optional[Dict[str, Any]] = None) -> None: self.is_train: Optional[bool] = None self.config = config - def _will_log(self, state: State, log_level: LogLevel) -> bool: + def will_log(self, state: State, log_level: LogLevel) -> bool: del state # Unused - return log_level <= LogLevel.BATCH + return dist.get_global_rank() == 0 and log_level <= LogLevel.BATCH - def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None: + def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None: del epoch, step, log_level # Unused if self.is_train in self.pbars: # Logging outside an epoch @@ -116,36 +105,60 @@ def init(self, state: State, logger: Logger) -> None: print() def _start(self, state: State): + if dist.get_global_rank() != 0: + return assert self.is_train is not None, "self.is_train should be set by the callback" # TODO(anis) -- in #120, len(state.eval_dataloader) is inaccurate, as it does not incorporate # trainer._eval_subset_num_batches. The evaluator spec should fix this. total_steps = state.steps_per_epoch if self.is_train else len(state.eval_dataloader) - self.pbars[self.is_train] = _TQDMLoggerInstance(total=total_steps, epoch=state.epoch, is_train=self.is_train) + desc = f'Epoch {int(state.timer.epoch)}' + position = 0 if self.is_train else 1 + if not self.is_train: + desc += f", Batch {int(state.timer.batch)} (val)" + self.pbars[self.is_train] = _TQDMLoggerInstance( + _TQDMLoggerInstanceState(total=total_steps, + position=position, + n=0, + keys_to_log=_IS_TRAIN_TO_KEYS_TO_LOG[self.is_train], + description=desc, + epoch_metrics={})) def epoch_start(self, state: State, logger: Logger) -> None: del logger # unused + if dist.get_global_rank() != 0: + return self.is_train = True self._start(state) def eval_start(self, state: State, logger: Logger) -> None: del logger # unused + if dist.get_global_rank() != 0: + return self.is_train = False self._start(state) def _update(self): + if dist.get_global_rank() != 0: + return if self.is_train in self.pbars: assert self.is_train is not None self.pbars[self.is_train].update() - def after_backward(self, state: State, logger: Logger) -> None: + def batch_end(self, state: State, logger: Logger) -> None: del state, logger # unused + if dist.get_global_rank() != 0: + return self._update() def eval_after_forward(self, state: State, logger: Logger) -> None: del state, logger # unused + if dist.get_global_rank() != 0: + return self._update() def _end(self): + if dist.get_global_rank() != 0: + return if self.is_train in self.pbars: assert self.is_train is not None self.pbars[self.is_train].close() @@ -154,10 +167,14 @@ def _end(self): def epoch_end(self, state: State, logger: Logger) -> None: del state, logger # unused + if dist.get_global_rank() != 0: + return self._end() def eval_end(self, state: State, logger: Logger) -> None: del state, logger # unused + if dist.get_global_rank() != 0: + return self._end() def state_dict(self) -> StateDict: diff --git a/composer/loggers/wandb_logger.py b/composer/loggers/wandb_logger.py index 38c7ba8416..4ac23b79ac 100644 --- a/composer/loggers/wandb_logger.py +++ b/composer/loggers/wandb_logger.py @@ -4,6 +4,8 @@ import os import sys +import textwrap +import warnings from typing import Any, Dict, Optional from composer.core.logging import BaseLoggerBackend, LogLevel, TLogData @@ -23,14 
+25,25 @@ class WandBLoggerBackend(BaseLoggerBackend): was realized when logging and uploading artifacts, so it is recommended to do so infrequently. Only applicable when `log_artifacts` is True (default: ``100``) + rank_zero_only (bool, optional): Whether to log only on the rank-zero process (default: ``False``). + When logging artifacts, it is highly recommended to log on all ranks. Artifacts from ranks 1+ + will not be stored, which may discard pertinent information. For example, when using Deepspeed + ZeRO, it would be impossible to restore from checkpoints without artifacts from all ranks. init_params (Dict[str, Any], optional): Parameters to pass into :meth:`wandb.init`. """ def __init__(self, log_artifacts: bool = False, log_artifacts_every_n_batches: int = 100, + rank_zero_only: bool = False, init_params: Optional[Dict[str, Any]] = None) -> None: - super().__init__() + if log_artifacts and rank_zero_only: + warnings.warn( + textwrap.dedent("""When logging artifacts, `rank_zero_only` should be set to False. + Artifacts from other ranks will not be collected, leading to a loss of information required to + restore from checkpoints.""")) + self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0 + self._log_artifacts = log_artifacts self._log_artifacts_every_n_batches = log_artifacts_every_n_batches self._last_upload_timestamp = 0.0 @@ -40,35 +53,40 @@ def __init__(self, def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData): del epoch, log_level # unused - wandb.log(data, step=step) + if self._enabled: + wandb.log(data, step=step) def state_dict(self) -> StateDict: # Storing these fields in the state dict to support run resuming in the future. - return { - "name": wandb.run.name, - "project": wandb.run.project, - "entity": wandb.run.entity, - "id": wandb.run.id, - "group": wandb.run.group - } + if self._enabled: + return { + "name": wandb.run.name, + "project": wandb.run.project, + "entity": wandb.run.entity, + "id": wandb.run.id, + "group": wandb.run.group + } + else: + return {} def init(self, state: State, logger: Logger) -> None: del state, logger # unused - wandb.init(**self._init_params) + if self._enabled: + wandb.init(**self._init_params) def batch_end(self, state: State, logger: Logger) -> None: del logger # unused - if self._log_artifacts and (state.step + 1) % self._log_artifacts_every_n_batches == 0: + if self._enabled and self._log_artifacts and (state.step + 1) % self._log_artifacts_every_n_batches == 0: self._upload_artifacts() def epoch_end(self, state: State, logger: Logger) -> None: del state, logger # unused - if self._log_artifacts: + if self._enabled and self._log_artifacts: self._upload_artifacts() def training_end(self, state: State, logger: Logger) -> None: del state, logger # unused - if self._log_artifacts: + if self._enabled and self._log_artifacts: self._upload_artifacts() def _upload_artifacts(self): @@ -79,14 +97,10 @@ def _upload_artifacts(self): # so uploads will not block the training loop # slowdown is likely from extra I/O of scanning the directory and/or # scheduling uploads - run_directory_name = run_directory.get_run_directory() - if run_directory_name is None: - return - modified_files = run_directory.get_modified_files(self._last_upload_timestamp) for modified_file in modified_files: file_type = modified_file.split(".")[-1] - relpath = os.path.relpath(modified_file, run_directory_name) + relpath = os.path.relpath(modified_file, run_directory.get_run_directory()) relpath = f"rank_{dist.get_global_rank()}-" + 
relpath.replace("/", "-") artifact = wandb.Artifact(name=relpath, type=file_type) artifact.add_file(os.path.abspath(modified_file)) @@ -95,6 +109,9 @@ def _upload_artifacts(self): def post_close(self) -> None: # Cleaning up on post_close so all artifacts are uploaded + if not self._enabled: + return + if self._log_artifacts: self._upload_artifacts() diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index ef2ded8d1f..f5b24b0f48 100755 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -639,13 +639,15 @@ def _train_loop(self) -> None: for scheduler in state.schedulers: scheduler.step(interval='batch') # type: ignore - if self.validate_every_n_batches > 0 and (state.step + 1) % self.validate_every_n_batches == 0: - self.eval(is_batch=True) - state.timer.on_batch_complete( samples=int(num_samples_in_batch.item()), tokens=int(num_tokens_in_batch.item()), ) + + if self.validate_every_n_batches > 0 and int( + state.timer.batch) % self.validate_every_n_batches == 0: + self.eval(is_batch=True) + if self.checkpoint_saver and self.checkpoint_saver.should_checkpoint(state=state, event=Event.BATCH_END): self.checkpoint_saver.save_checkpoint(state=state, @@ -660,11 +662,11 @@ def _train_loop(self) -> None: self.engine.run_event(Event.EPOCH_END) - if self.validate_every_n_epochs > 0 and (state.epoch + 1) % self.validate_every_n_epochs == 0: - self.eval(is_batch=False) - state.timer.on_epoch_complete() + if self.validate_every_n_epochs > 0 and int(state.timer.epoch) % self.validate_every_n_epochs == 0: + self.eval(is_batch=False) + if self.checkpoint_saver and self.checkpoint_saver.should_checkpoint(state=state, event=Event.EPOCH_END): self.checkpoint_saver.save_checkpoint(state=state, seed=self.seed, diff --git a/composer/yamls/models/bert-base.yaml b/composer/yamls/models/bert-base.yaml index d6e642208b..6c7c4779a3 100644 --- a/composer/yamls/models/bert-base.yaml +++ b/composer/yamls/models/bert-base.yaml @@ -54,13 +54,7 @@ schedulers: interval: step verbose: false loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 7ep # Baseline is 256M samples, 7 epochs is ~280M samples train_batch_size: 4000 eval_batch_size: 2000 diff --git a/composer/yamls/models/glue/cola.yaml b/composer/yamls/models/glue/cola.yaml index ce4e58f7e5..031a696b39 100644 --- a/composer/yamls/models/glue/cola.yaml +++ b/composer/yamls/models/glue/cola.yaml @@ -41,13 +41,7 @@ schedulers: interval: step verbose: false loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 10ep train_batch_size: 32 eval_batch_size: 32 diff --git a/composer/yamls/models/glue/mnli-m.yaml b/composer/yamls/models/glue/mnli-m.yaml index 87db18ef36..1fbe15bde8 100755 --- a/composer/yamls/models/glue/mnli-m.yaml +++ b/composer/yamls/models/glue/mnli-m.yaml @@ -41,13 +41,7 @@ schedulers: interval: step verbose: false loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 3ep train_batch_size: 48 eval_batch_size: 48 diff --git a/composer/yamls/models/glue/mnli.yaml b/composer/yamls/models/glue/mnli.yaml index c833fd9380..671595e861 100755 --- a/composer/yamls/models/glue/mnli.yaml +++ b/composer/yamls/models/glue/mnli.yaml @@ -41,13 +41,7 @@ schedulers: interval: step 
verbose: false loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 3ep train_batch_size: 48 eval_batch_size: 48 diff --git a/composer/yamls/models/glue/mrpc.yaml b/composer/yamls/models/glue/mrpc.yaml index 5d1a130293..fc284a3af9 100755 --- a/composer/yamls/models/glue/mrpc.yaml +++ b/composer/yamls/models/glue/mrpc.yaml @@ -41,13 +41,7 @@ schedulers: interval: step verbose: false loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 10ep train_batch_size: 32 eval_batch_size: 32 diff --git a/composer/yamls/models/glue/qnli.yaml b/composer/yamls/models/glue/qnli.yaml index 26a6c968ce..fd7aa65071 100644 --- a/composer/yamls/models/glue/qnli.yaml +++ b/composer/yamls/models/glue/qnli.yaml @@ -41,13 +41,7 @@ schedulers: interval: step verbose: false loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 10ep train_batch_size: 16 eval_batch_size: 16 diff --git a/composer/yamls/models/glue/qqp.yaml b/composer/yamls/models/glue/qqp.yaml index d7fdc62f25..20045c686c 100755 --- a/composer/yamls/models/glue/qqp.yaml +++ b/composer/yamls/models/glue/qqp.yaml @@ -41,13 +41,7 @@ schedulers: interval: step verbose: false loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 5ep train_batch_size: 16 eval_batch_size: 16 diff --git a/composer/yamls/models/glue/rte.yaml b/composer/yamls/models/glue/rte.yaml index 8a2448e056..877442d4da 100644 --- a/composer/yamls/models/glue/rte.yaml +++ b/composer/yamls/models/glue/rte.yaml @@ -41,13 +41,7 @@ schedulers: interval: step verbose: false loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 3ep train_batch_size: 16 eval_batch_size: 16 diff --git a/composer/yamls/models/glue/sst-2.yaml b/composer/yamls/models/glue/sst-2.yaml index c5fc0a7e9b..06b5ce50a1 100644 --- a/composer/yamls/models/glue/sst-2.yaml +++ b/composer/yamls/models/glue/sst-2.yaml @@ -41,13 +41,7 @@ schedulers: interval: step verbose: false loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 3ep train_batch_size: 16 eval_batch_size: 16 diff --git a/composer/yamls/models/glue/stsb.yaml b/composer/yamls/models/glue/stsb.yaml index aa4e3a1199..a3c988a6f8 100644 --- a/composer/yamls/models/glue/stsb.yaml +++ b/composer/yamls/models/glue/stsb.yaml @@ -41,13 +41,7 @@ schedulers: interval: step verbose: false loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 10ep train_batch_size: 32 eval_batch_size: 32 diff --git a/composer/yamls/models/gpt2_1,3b.yaml b/composer/yamls/models/gpt2_1,3b.yaml index 5c24c88e9a..2902d8f798 100755 --- a/composer/yamls/models/gpt2_1,3b.yaml +++ b/composer/yamls/models/gpt2_1,3b.yaml @@ -73,13 +73,7 @@ schedulers: verbose: false T_max: 13860ba loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - 
every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 1ep train_batch_size: 1024 eval_batch_size: 8 diff --git a/composer/yamls/models/gpt2_125m.yaml b/composer/yamls/models/gpt2_125m.yaml index 517868b878..f6848416c2 100644 --- a/composer/yamls/models/gpt2_125m.yaml +++ b/composer/yamls/models/gpt2_125m.yaml @@ -75,13 +75,7 @@ schedulers: verbose: false T_max: 13860ba loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 1ep train_batch_size: 512 eval_batch_size: 8 # use micro_bs_per_gpu = 1 to accomodate 10GB limit diff --git a/composer/yamls/models/gpt2_13b.yaml b/composer/yamls/models/gpt2_13b.yaml index 0712c0d437..c6c1c50d3a 100755 --- a/composer/yamls/models/gpt2_13b.yaml +++ b/composer/yamls/models/gpt2_13b.yaml @@ -73,13 +73,7 @@ schedulers: verbose: false T_max: 13860ba loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 1ep train_batch_size: 2048 eval_batch_size: 8 diff --git a/composer/yamls/models/gpt2_2,7b.yaml b/composer/yamls/models/gpt2_2,7b.yaml index 09e1934649..88f3c132c5 100755 --- a/composer/yamls/models/gpt2_2,7b.yaml +++ b/composer/yamls/models/gpt2_2,7b.yaml @@ -73,13 +73,7 @@ schedulers: verbose: false T_max: 13860ba loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 1ep train_batch_size: 1024 eval_batch_size: 8 diff --git a/composer/yamls/models/gpt2_350m.yaml b/composer/yamls/models/gpt2_350m.yaml index c76cbc767b..faa8ea10a7 100755 --- a/composer/yamls/models/gpt2_350m.yaml +++ b/composer/yamls/models/gpt2_350m.yaml @@ -73,13 +73,7 @@ schedulers: verbose: false T_max: 10890ba loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 1ep train_batch_size: 512 eval_batch_size: 8 diff --git a/composer/yamls/models/gpt2_52m.yaml b/composer/yamls/models/gpt2_52m.yaml index 5fafb97f77..b47796cf88 100644 --- a/composer/yamls/models/gpt2_52m.yaml +++ b/composer/yamls/models/gpt2_52m.yaml @@ -75,13 +75,7 @@ schedulers: verbose: false T_max: 8910ba loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 1ep train_batch_size: 512 eval_batch_size: 16 # use micro_bs_per_gpu = 2 to accomodate 10GB limit diff --git a/composer/yamls/models/gpt2_6,7b.yaml b/composer/yamls/models/gpt2_6,7b.yaml index 7d7d90085c..deebfde324 100755 --- a/composer/yamls/models/gpt2_6,7b.yaml +++ b/composer/yamls/models/gpt2_6,7b.yaml @@ -73,13 +73,7 @@ schedulers: verbose: false T_max: 13860ba loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 1ep train_batch_size: 2048 eval_batch_size: 8 diff --git a/composer/yamls/models/gpt2_760m.yaml b/composer/yamls/models/gpt2_760m.yaml index dacc4ca26c..d1f53e6cda 100755 --- a/composer/yamls/models/gpt2_760m.yaml +++ b/composer/yamls/models/gpt2_760m.yaml @@ -73,13 +73,7 @@ schedulers: verbose: false T_max: 13860ba loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - 
every_n_epochs: 1 + - tqdm: {} max_duration: 1ep train_batch_size: 512 eval_batch_size: 8 diff --git a/composer/yamls/models/gpt2_83m.yaml b/composer/yamls/models/gpt2_83m.yaml index 2bdcb41003..abb12b7416 100644 --- a/composer/yamls/models/gpt2_83m.yaml +++ b/composer/yamls/models/gpt2_83m.yaml @@ -75,13 +75,7 @@ schedulers: verbose: false T_max: 10890ba loggers: - - file: - log_level: batch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_batches: 100 - every_n_epochs: 1 + - tqdm: {} max_duration: 1ep train_batch_size: 512 eval_batch_size: 16 # use micro_bs_per_gpu = 2 to accomodate 10GB limit diff --git a/composer/yamls/models/resnet56_cifar10.yaml b/composer/yamls/models/resnet56_cifar10.yaml index 168501c1a5..0d5fe56a78 100644 --- a/composer/yamls/models/resnet56_cifar10.yaml +++ b/composer/yamls/models/resnet56_cifar10.yaml @@ -37,13 +37,7 @@ model: - bn_uniform num_classes: 10 loggers: - - file: - log_level: epoch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_epochs: 1 - every_n_batches: 100 + - tqdm: {} max_duration: 160ep train_batch_size: 1024 eval_batch_size: 1000 diff --git a/composer/yamls/models/resnet56_cifar10_synthetic.yaml b/composer/yamls/models/resnet56_cifar10_synthetic.yaml index a97decda59..f5c98da987 100644 --- a/composer/yamls/models/resnet56_cifar10_synthetic.yaml +++ b/composer/yamls/models/resnet56_cifar10_synthetic.yaml @@ -37,13 +37,7 @@ model: - bn_uniform num_classes: 10 loggers: - - file: - log_level: epoch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_epochs: 1 - every_n_batches: 100 + - tqdm: {} max_duration: 160ep train_batch_size: 1024 eval_batch_size: 1000 diff --git a/composer/yamls/models/resnet9_cifar10.yaml b/composer/yamls/models/resnet9_cifar10.yaml index c96d44dc3d..e406aec79e 100644 --- a/composer/yamls/models/resnet9_cifar10.yaml +++ b/composer/yamls/models/resnet9_cifar10.yaml @@ -37,13 +37,7 @@ model: - bn_uniform num_classes: 10 loggers: - - file: - log_level: epoch - filename: stdout - buffer_size: 1 - flush_every_n_batches: 100 - every_n_epochs: 1 - every_n_batches: 100 + - tqdm: {} max_duration: 160ep train_batch_size: 1024 eval_batch_size: 1000 diff --git a/docs/source/core/callback.rst b/docs/source/core/callback.rst index 7925a7d068..1016c031fa 100644 --- a/docs/source/core/callback.rst +++ b/docs/source/core/callback.rst @@ -56,4 +56,3 @@ Callbacks can be implemented in two ways: :nosignatures: ~composer.Callback - ~composer.core.callback.RankZeroCallback diff --git a/docs/source/core/logger.rst b/docs/source/core/logger.rst index f890b9cf9f..0f2d44a185 100644 --- a/docs/source/core/logger.rst +++ b/docs/source/core/logger.rst @@ -30,6 +30,5 @@ For example, to define a new logging backend: ~composer.loggers.logger_hparams.BaseLoggerBackendHparams ~composer.core.logging.base_backend.BaseLoggerBackend - ~composer.core.logging.base_backend.RankZeroLoggerBackend .. 
autoclass:: Logger diff --git a/tests/test_logger.py b/tests/test_logger.py index db0015be80..bc983ea4a7 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -5,17 +5,16 @@ from unittest.mock import MagicMock import pytest -import torch.distributed as dist import tqdm from _pytest.monkeypatch import MonkeyPatch from composer.core.event import Event from composer.core.logging import Logger, LogLevel from composer.core.state import State -from composer.loggers.file_logger import FileLoggerBackend from composer.loggers.logger_hparams import (FileLoggerBackendHparams, TQDMLoggerBackendHparams, WandBLoggerBackendHparams) from composer.trainer.trainer_hparams import TrainerHparams +from composer.utils import dist @pytest.fixture @@ -23,50 +22,54 @@ def log_file_name(tmpdir: pathlib.Path) -> str: return os.path.join(tmpdir, "output.log") -@pytest.fixture -def log_destination(log_file_name: str) -> FileLoggerBackend: - return FileLoggerBackendHparams( - every_n_batches=3, - every_n_epochs=2, - log_level=LogLevel.BATCH, +@pytest.mark.parametrize("log_level", [LogLevel.EPOCH, LogLevel.BATCH]) +def test_file_logger(dummy_state: State, log_level: LogLevel, log_file_name: str): + log_destination = FileLoggerBackendHparams( + log_interval=3, + log_level=log_level, filename=log_file_name, buffer_size=1, - flush_every_n_batches=1, + flush_interval=1, ).initialize_object() - - -def test_file_logger(dummy_state: State, log_destination: FileLoggerBackend, monkeypatch: MonkeyPatch, - log_file_name: str): dummy_state.timer.on_batch_complete() dummy_state.timer.on_batch_complete() dummy_state.timer.on_epoch_complete() - dummy_state.timer.on_epoch_complete() logger = Logger(dummy_state, backends=[log_destination]) - monkeypatch.setattr(dist, "get_rank", lambda: 0) log_destination.run_event(Event.INIT, dummy_state, logger) logger.metric_fit({"metric": "fit"}) # should print - logger.metric_epoch({"metric": "epoch"}) # should print - logger.metric_batch({"metric": "batch"}) # should print - logger.metric_verbose({"metric": "verbose"}) # should NOT print, since we're on the BATCH log level + logger.metric_epoch({"metric": "epoch"}) # should print on batch level, since epoch calls are always printed + logger.metric_batch({"metric": "batch"}) # should print on batch level, since we print every 3 steps dummy_state.timer.on_epoch_complete() - logger.metric_epoch({"metric": "epoch1"}) # should NOT print, since we print every 2 epochs + logger.metric_epoch({"metric": "epoch1"}) # should print, since we log every 3 epochs dummy_state.timer.on_epoch_complete() dummy_state.timer.on_batch_complete() log_destination.run_event(Event.BATCH_END, dummy_state, logger) - logger.metric_epoch({"metric": "epoch2"}) # should print - logger.metric_batch({"metric": "batch1"}) # should NOT print, since we print every 3 steps + logger.metric_epoch({"metric": "epoch2"}) # should print on batch level, since epoch calls are always printed + logger.metric_batch({"metric": "batch1"}) # should NOT print log_destination.run_event(Event.BATCH_END, dummy_state, logger) log_destination.run_event(Event.TRAINING_END, dummy_state, logger) with open(log_file_name, 'r') as f: - assert f.readlines() == [ - '[FIT][step=2]: { "metric": "fit", }\n', - '[EPOCH][step=2]: { "metric": "epoch", }\n', - '[BATCH][step=2]: { "metric": "batch", }\n', - '[EPOCH][step=3]: { "metric": "epoch2", }\n', - ] + if log_level == LogLevel.EPOCH: + assert f.readlines() == [ + '[FIT][step=2]: { "metric": "fit", }\n', + '[EPOCH][step=2]: { "metric": "epoch1", }\n', + ] 
+ else: + assert log_level == LogLevel.BATCH + assert f.readlines() == [ + '[FIT][step=2]: { "metric": "fit", }\n', + '[EPOCH][step=2]: { "metric": "epoch", }\n', + '[BATCH][step=2]: { "metric": "batch", }\n', + '[EPOCH][step=2]: { "metric": "epoch1", }\n', + '[EPOCH][step=3]: { "metric": "epoch2", }\n', + ] -def test_tqdm_logger(mosaic_trainer_hparams: TrainerHparams, monkeypatch: MonkeyPatch): +@pytest.mark.parametrize("world_size", [ + pytest.param(1), + pytest.param(2, marks=pytest.mark.world_size(2)), +]) +def test_tqdm_logger(mosaic_trainer_hparams: TrainerHparams, monkeypatch: MonkeyPatch, world_size: int): is_train_to_mock_tqdms = { True: [], False: [], @@ -85,6 +88,8 @@ def get_mock_tqdm(position: int, *args, **kwargs): mosaic_trainer_hparams.loggers = [TQDMLoggerBackendHparams()] trainer = mosaic_trainer_hparams.initialize_object() trainer.fit() + if dist.get_global_rank() == 1: + return assert len(is_train_to_mock_tqdms[True]) == max_epochs assert mosaic_trainer_hparams.validate_every_n_batches < 0 assert len(is_train_to_mock_tqdms[False]) == mosaic_trainer_hparams.validate_every_n_epochs * max_epochs diff --git a/tests/test_mosaicml_logger.py b/tests/test_mosaicml_logger.py index e0cf351a2c..91af8de6d6 100755 --- a/tests/test_mosaicml_logger.py +++ b/tests/test_mosaicml_logger.py @@ -62,7 +62,7 @@ def _mock_upsert_run(run_id: str, expected_log_calls = 0 for i in range(num_times_to_log): data_point = {f'data-{i}': 'value'} - logger._log_metric(epoch=1, step=i, log_level=LogLevel.BATCH, data=data_point) + logger.log_metric(epoch=1, step=i, log_level=LogLevel.BATCH, data=data_point) dummy_state.timer.on_batch_complete() logger.batch_end(dummy_state, dummy_logger) diff --git a/tests/trainer/test_ddp.py b/tests/trainer/test_ddp.py index ffd300934e..f0ddcf953f 100755 --- a/tests/trainer/test_ddp.py +++ b/tests/trainer/test_ddp.py @@ -215,8 +215,9 @@ def test_ddp(device: DeviceHparams, world_size: int, mosaic_trainer_hparams: Tra for epoch in range(max_epochs): for local_rank in range(dist.get_local_world_size()): for is_train in (True, False): + real_epoch = epoch if is_train else epoch + 1 # validation is 1 ahead of training data: Dict[str, types.Tensor] = torch.load( # type: ignore - get_batch_file_path(rank=local_rank, epoch=epoch, is_train=is_train), + get_batch_file_path(rank=local_rank, epoch=real_epoch, is_train=is_train), map_location='cpu', ) for pickle in is_train_to_pickles[is_train]:
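Below is a minimal usage sketch, not part of the patch itself, showing how the reworked logger backends described in the commit message could be constructed. It only uses constructor arguments that appear in this diff; the filename and interval values are illustrative placeholders.

    from composer.core.logging import LogLevel
    from composer.loggers.file_logger import FileLoggerBackend
    from composer.loggers.tqdm_logger import TQDMLoggerBackend
    from composer.loggers.wandb_logger import WandBLoggerBackend

    # File logger: `log_interval` replaces `every_n_epochs`/`every_n_batches`, and
    # `flush_interval` replaces `flush_every_n_batches`. Both are interpreted relative
    # to `log_level` (epochs for LogLevel.EPOCH, batches for LogLevel.BATCH).
    # Values here are illustrative, not defaults.
    file_logger = FileLoggerBackend(
        filename="stdout",
        log_level=LogLevel.BATCH,
        log_interval=100,    # record BATCH-level metrics every 100 batches
        flush_interval=100,  # flush the file every 100 batches
    )

    # TQDM logger: the new default in the YAMLs; progress bars render only on
    # global rank zero.
    tqdm_logger = TQDMLoggerBackend()

    # WandB logger: runs on every rank unless rank_zero_only=True. Keep it False
    # when logging artifacts so per-rank files (e.g. DeepSpeed ZeRO checkpoint
    # shards) are not discarded.
    wandb_logger = WandBLoggerBackend(log_artifacts=True, rank_zero_only=False)

In the YAML configs touched by this patch, the equivalent setting is the `loggers:` list (e.g. `- tqdm: {}`), which replaces the removed `file:` logger blocks.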