Fixed loggers and callbacks (#240)
1. Removed rank zero callbacks and loggers, since these hid complexity and led to code that blocked indefinitely when using distributed functions. Closes #239.
2. Incremented `state.timer` _before_ calling `.eval()` in the trainer, so the batch count is consistent for both batch-wise and epoch-wise evaluators; this batch number is what appears in the logs.
3. Fixed the TQDM logger so it works properly with gradient accumulation.
4. Removed `LogLevel.ALGORITHM`, `LogLevel.MICROBATCH`, and `LogLevel.VERBOSE`, since these were rarely used. Instead, verbose output should go through the built-in Python logger (it wouldn't be a useful metric), MICROBATCH logging should use BATCH (a microbatch is effectively another GPU's batch), and ALGORITHM logging should use BATCH or EPOCH, depending on where the algorithm runs (see the migration sketch after this list).
5. Updated the file logger to take a `log_interval` instead of `every_n_epochs` and `every_n_batches`, and a `flush_interval` instead of `flush_every_n_batches`.
6. Switched the default logger in all YAML configs to tqdm.
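A minimal migration sketch for the removed levels and helpers. The call sites here are hypothetical; `metric_batch` and the standard `logging` module are the replacements named in item 4:

import logging

log = logging.getLogger(__name__)

def log_diagnostics(logger, loss: float) -> None:
    # Before: logger.metric_verbose({"debug/notes": "..."})
    # After: verbose output goes through the built-in Python logger,
    # since it wouldn't be a useful metric.
    log.debug("extra diagnostic output, loss=%f", loss)

    # Before: logger.metric_microbatch({"loss": loss})
    # After: microbatch metrics are recorded at BATCH granularity.
    logger.metric_batch({"loss": loss})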
ravi-mosaicml authored Jan 19, 2022
1 parent 656bfe3 commit 083aff1
Showing 39 changed files with 231 additions and 414 deletions.
22 changes: 0 additions & 22 deletions composer/core/callback.py
@@ -8,7 +8,6 @@
from typing import TYPE_CHECKING

from composer.core.serializable import Serializable
from composer.utils import dist

try:
from typing import final
@@ -315,24 +314,3 @@ def post_close(self) -> None:
callbacks during :meth:`close`.
"""
pass


class RankZeroCallback(Callback, abc.ABC):
"""Base class for callbacks that only run on the local rank zero process.
Callbacks can be implemented in two ways:
#. Override the individual methods named for each :class:`Event`. (See
the parent class, :class:`Callback`.)
#. Override :meth:`_run_event` (**not** :meth:`run_event`) to run in response
to all events. If this method is overridden, then the individual methods
corresponding to each event name will not be automatically called (however,
the subclass implementation can invoke these methods as it wishes.)
"""

@final
def run_event(self, event: Event, state: State, logger: Logger) -> None:
if dist.get_local_rank() != 0:
return
return self._run_event(event, state, logger)
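With `RankZeroCallback` gone, a callback that should act only on rank zero now guards explicitly in its own body. A minimal sketch, assuming the `Callback` and `dist` APIs shown in this diff (the callback class and print statement are illustrative):

from composer.core.callback import Callback
from composer.utils import dist

class EpochPrinter(Callback):
    def epoch_end(self, state, logger) -> None:
        # Any distributed collectives must run on every rank *before*
        # this guard, or the ranks will block forever waiting on each other.
        if dist.get_local_rank() != 0:
            return
        print(f"finished epoch {int(state.timer.epoch)}")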
11 changes: 10 additions & 1 deletion composer/core/engine.py
@@ -9,6 +9,7 @@
from composer.core.callback import Callback
from composer.core.event import Event
from composer.core.logging import Logger
from composer.core.logging.logger import LogLevel
from composer.core.state import State

log = logging.getLogger(__name__)
@@ -126,7 +127,15 @@ def _run_algorithms(
trace[trace_key] = Trace(exit_code=exit_code, order=order, run=True)

if self.logger is not None:
self.logger.metric_verbose(data={key: 1 if tr.run else 0 for key, tr in trace.items()})
if event in (Event.INIT, Event.TRAINING_START, Event.TRAINING_END):
log_level = LogLevel.FIT
elif event in (Event.EPOCH_START, Event.EPOCH_END):
log_level = LogLevel.EPOCH
else:
# algs don't run on eval events, so don't have to worry about
# batch-frequency vs epoch-frequency evaluators
log_level = LogLevel.BATCH
self.logger.metric(log_level=log_level, data={key: 1 if tr.run else 0 for key, tr in trace.items()})

return trace

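The mapping above can be read as a small helper; a sketch, using the `Event` and `LogLevel` names from this diff, that makes the fall-through explicit:

from composer.core.event import Event
from composer.core.logging.logger import LogLevel

def _algorithm_trace_log_level(event: Event) -> LogLevel:
    if event in (Event.INIT, Event.TRAINING_START, Event.TRAINING_END):
        return LogLevel.FIT  # once-per-run events
    if event in (Event.EPOCH_START, Event.EPOCH_END):
        return LogLevel.EPOCH
    # Algorithms don't run on eval events, so every remaining event
    # is batch-frequency.
    return LogLevel.BATCH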
1 change: 0 additions & 1 deletion composer/core/logging/__init__.py
@@ -1,7 +1,6 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from composer.core.logging.base_backend import BaseLoggerBackend as BaseLoggerBackend
from composer.core.logging.base_backend import RankZeroLoggerBackend as RankZeroLoggerBackend
from composer.core.logging.logger import Logger as Logger
from composer.core.logging.logger import LogLevel as LogLevel
from composer.core.logging.logger import TLogData as TLogData
78 changes: 1 addition & 77 deletions composer/core/logging/base_backend.py
@@ -5,18 +5,12 @@
from abc import ABC
from typing import TYPE_CHECKING

from composer.core.callback import Callback, RankZeroCallback
from composer.utils import dist
from composer.core.callback import Callback

if TYPE_CHECKING:
from composer.core.logging.logger import LogLevel, TLogData
from composer.core.state import State

try:
from typing import final
except ImportError:
final = lambda x: x # final is not available in python 3.7


class BaseLoggerBackend(Callback, ABC):
"""Base class for logging backends.
@@ -59,73 +53,3 @@ def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData)
"""
del epoch, step, log_level, data # unused
pass


class RankZeroLoggerBackend(BaseLoggerBackend, RankZeroCallback, ABC):
"""Base class for logging backends that run only on the rank zero process.
In a multi-process training setup (e.g. when using DistributedDataParallel),
some logging backends require that only the rank zero process log data.
For example, when logging to a file, only the main process should open the file
and save data.
When using this class, override
:func:`_will_log` and :func:`_log_metric` instead of
:func:`will_log` and :func:`log_metric`, respectively.
This class ensures that :func:`_will_log` and :func:`_log_metric`
are invoked only on the rank zero process.
.. automethod:: _will_log
.. automethod:: _log_metric
"""

def __init__(self) -> None:
super().__init__()

def _will_log(self, state: State, log_level: LogLevel) -> bool:
"""Called by the :class:`~composer.core.logging.logger.Logger`
to determine whether the logging backend will log a metric.
By default, it always returns ``True``, but this method
can be overridden.
Args:
state (State): The global state object.
log_level (LogLevel): The log level.
Returns:
bool: Whether to log a metric call, given the
:class:`~composer.core.state.State` and
:class:`~composer.core.logging.logger.LogLevel`.
"""
del state, log_level # Unused
return True

@final
def will_log(self, state: State, log_level: LogLevel) -> bool:
if dist.get_local_rank() != 0:
return False
return self._will_log(state, log_level)

def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None:
"""Called by the :class:`~composer.core.logging.logger.Logger`
for metrics where :func:`will_log` returned ``True``.
The logging backend should override this function to log the data
(e.g. write it to a file, send it to a server, etc...).
Args:
epoch (int): The epoch for the logged data.
step (int): The global step for the logged data.
log_level (LogLevel): The log level.
data (TLogData): The metric to log.
"""
del epoch, step, log_level, data # Unused
pass

@final
def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None:
if dist.get_local_rank() != 0:
return
return self._log_metric(epoch, step, log_level, data)
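A logger backend that still wants rank-zero-only behavior can now express it directly through `will_log`. A sketch against the `BaseLoggerBackend` interface shown above; the backend class itself is hypothetical:

from composer.core.logging import BaseLoggerBackend, LogLevel, TLogData
from composer.utils import dist

class RankZeroPrintBackend(BaseLoggerBackend):
    def will_log(self, state, log_level: LogLevel) -> bool:
        # The guard is now visible in the backend itself instead of
        # hidden in a RankZero* base class.
        return dist.get_local_rank() == 0

    def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None:
        print(f"[{log_level.name}][step={step}]: {data}")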
12 changes: 0 additions & 12 deletions composer/core/logging/logger.py
@@ -30,14 +30,10 @@ class LogLevel(IntEnum):
FIT: Logged once per training run.
EPOCH: Logged once per epoch.
BATCH: Logged once per batch.
MICROBATCH: Logged once per microbatch (e.g. forward pass).
VERBOSE: Logged for debugging.
"""
FIT = 1
EPOCH = 2
BATCH = 3
MICROBATCH = 4
VERBOSE = 5


class Logger:
Expand Down Expand Up @@ -109,14 +105,6 @@ def metric_batch(self, data: Union[TLogData, Callable[[], TLogData]]) -> None:
"""Helper function for ``self.metric(LogLevel.BATCH, data)``"""
self.metric(LogLevel.BATCH, data)

def metric_microbatch(self, data: Union[TLogData, Callable[[], TLogData]]) -> None:
"""Helper function for ``self.metric(LogLevel.MICROBATCH, data)``"""
self.metric(LogLevel.MICROBATCH, data)

def metric_verbose(self, data: Union[TLogData, Callable[[], TLogData]]) -> None:
"""Helper function for ``self.metric(LogLevel.VERBOSE, data)``"""
self.metric(LogLevel.VERBOSE, data)


def format_log_data_value(data: TLogDataValue) -> str:
"""Recursively formats a given log data value into a string.
1 change: 0 additions & 1 deletion composer/loggers/__init__.py
100644 → 100755
@@ -1,7 +1,6 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from composer.core.logging.base_backend import BaseLoggerBackend as BaseLoggerBackend
from composer.core.logging.base_backend import RankZeroLoggerBackend as RankZeroLoggerBackend
from composer.core.logging.logger import Logger as Logger
from composer.core.logging.logger import LogLevel as LogLevel
from composer.core.logging.logger import TLogData as TLogData
92 changes: 52 additions & 40 deletions composer/loggers/file_logger.py
@@ -8,13 +8,12 @@

import yaml

from composer.core.logging import Logger, LogLevel, RankZeroLoggerBackend, TLogData, format_log_data_value
from composer.core.logging import BaseLoggerBackend, Logger, LogLevel, TLogData, format_log_data_value
from composer.core.state import State
from composer.loggers.logger_hparams import FileLoggerBackendHparams
from composer.utils import run_directory


class FileLoggerBackend(RankZeroLoggerBackend):
class FileLoggerBackend(BaseLoggerBackend):
"""Logs to a file or to the terminal.
Example output::
@@ -33,14 +32,17 @@ class FileLoggerBackend(RankZeroLoggerBackend):
log_level (LogLevel, optional): Maximum
:class:`~composer.core.logging.logger.LogLevel`. to record.
(default: :attr:`~composer.core.logging.logger.LogLevel.EPOCH`)
every_n_epochs (int, optional):
Frequency to print :attr:`~composer.core.logging.logger.LogLevel.EPOCH` logs.
(default: ``1``)
every_n_batches (int, optional):
Frequency to print :attr:`~composer.core.logging.logger.LogLevel.BATCH` logs.
(default: ``1``)
flush_every_n_batches (int, optional): How frequently to flush the log to the file.
(default: ``1``)
log_interval (int, optional):
Frequency to record logs. If ``log_level`` is :attr:`~composer.core.logging.logger.LogLevel.EPOCH`,
logs will only be recorded every n epochs. If ``log_level`` is
:attr:`~composer.core.logging.logger.LogLevel.BATCH`, logs will be recorded every n batches.
If ``log_level`` is :attr:`~composer.core.logging.logger.LogLevel.FIT`, this parameter is
ignored, as calls at the FIT log level are always recorded. (default: ``1``)
flush_interval (int, optional): How frequently to flush the log to the file, relative to the ``log_level``.
For example, if the ``log_level`` is :attr:`~composer.core.logging.logger.LogLevel.EPOCH`,
then the logfile will be flushed every n epochs.
If the ``log_level`` is :attr:`~composer.core.logging.logger.LogLevel.BATCH`, then the logfile will be flushed
every n batches. (default: ``100``)
"""

def __init__(
@@ -49,66 +51,76 @@ def __init__(
*,
buffer_size: int = 1,
log_level: LogLevel = LogLevel.EPOCH,
every_n_epochs: int = 1,
every_n_batches: int = 1,
flush_every_n_batches: int = 1,
log_interval: int = 1,
flush_interval: int = 100,
config: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
self.hparams = FileLoggerBackendHparams(
filename=filename,
buffer_size=buffer_size,
log_level=log_level,
every_n_epochs=every_n_epochs,
every_n_batches=every_n_batches,
flush_every_n_batches=flush_every_n_batches,
)
self.filename = filename
self.buffer_size = buffer_size
self.log_level = log_level
self.log_interval = log_interval
self.flush_interval = flush_interval
self.file: Optional[TextIO] = None
self.config = config

def _will_log(self, state: State, log_level: LogLevel) -> bool:
if log_level > self.hparams.log_level:
return False
if log_level >= LogLevel.EPOCH and state.epoch % self.hparams.every_n_epochs != 0:
return False
if log_level >= LogLevel.BATCH and (state.step + 1) % self.hparams.every_n_batches != 0:
return False
return True

def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData):
def will_log(self, state: State, log_level: LogLevel) -> bool:
if log_level == LogLevel.FIT:
return True # fit is always logged
if log_level == LogLevel.EPOCH:
if self.log_level < LogLevel.EPOCH:
return False
if self.log_level > LogLevel.EPOCH:
return True
return (int(state.timer.epoch) + 1) % self.log_interval == 0
if log_level == LogLevel.BATCH:
if self.log_level < LogLevel.BATCH:
return False
if self.log_level > LogLevel.BATCH:
return True
return (int(state.timer.batch) + 1) % self.log_interval == 0
raise ValueError(f"Unknown log level: {log_level}")

def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData):
data_str = format_log_data_value(data)
if self.file is None:
raise RuntimeError("Attempted to log before self.init() or after self.close()")
print(f"[{log_level.name}][step={step}]: {data_str}", file=self.file)
print(f"[{log_level.name}][step={step}]: {data_str}", file=self.file, flush=False)

def init(self, state: State, logger: Logger) -> None:
del state, logger # unused
if self.file is not None:
raise RuntimeError("The file logger is already initialized")
if self.hparams.filename == "stdout":
if self.filename == "stdout":
self.file = sys.stdout
elif self.hparams.filename == "stderr":
elif self.filename == "stderr":
self.file = sys.stderr
else:
self.file = open(os.path.join(run_directory.get_run_directory(), self.hparams.filename),
self.file = open(os.path.join(run_directory.get_run_directory(), self.filename),
"x+",
buffering=self.hparams.buffer_size)
buffering=self.buffer_size)
if self.config is not None:
print("Config", file=self.file)
print("-" * 30, file=self.file)
yaml.safe_dump(self.config, stream=self.file)
print("-" * 30, file=self.file)
print(file=self.file)

def training_start(self, state: State, logger: Logger) -> None:
del state, logger # unused
self._flush_file()

def batch_end(self, state: State, logger: Logger) -> None:
del logger # unused
assert self.file is not None
if (state.step + 1) % self.hparams.flush_every_n_batches == 0:
if self.log_level == LogLevel.BATCH and (int(state.timer.batch) + 1) % self.flush_interval == 0:
self._flush_file()

def epoch_end(self, state: State, logger: Logger) -> None:
del state, logger # unused
self._flush_file()
del logger # unused
if self.log_level > LogLevel.EPOCH or (self.log_level == LogLevel.EPOCH and (int(state.timer.epoch) + 1) % self.flush_interval == 0):
self._flush_file()

def training_end(self, state: State, logger: Logger) -> None:
self._flush_file()
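Putting the new arguments together, a usage sketch (the values are illustrative) in which BATCH-level metrics are recorded every 10th batch and the file is flushed every 100th:

from composer.core.logging import LogLevel
from composer.loggers.file_logger import FileLoggerBackend

backend = FileLoggerBackend(
    "stdout",                  # or a filename inside the run directory
    log_level=LogLevel.BATCH,
    log_interval=10,           # record every 10th batch
    flush_interval=100,        # flush the file every 100th batch
)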
20 changes: 9 additions & 11 deletions composer/loggers/logger_hparams.py
@@ -56,17 +56,13 @@ class FileLoggerBackendHparams(BaseLoggerBackendHparams):
buffer_size: int = hp.optional("Number of bytes to buffer. Defaults to 1 for line-buffering. "
"See https://docs.python.org/3/library/functions.html#open",
default=1) # line buffering. Python's default is -1.
flush_every_n_batches: int = hp.optional(
"Even if the buffer is not full, write to the file after this many steps. "
"Defaults to 1 (every step).",
default=1)
every_n_epochs: int = hp.optional(
"Frequency of logging messages for messages of LogLevel.EPOCH and higher."
"Defaults to 1 (every epoch).",
default=1)
every_n_batches: int = hp.optional(
"Frequency of logging messages for messages of LogLevel.BATCH and higher."
"Defaults to 1 (every batch).",
flush_interval: int = hp.optional(
"Frequency to flush the file, relative to the ``log_level``. "
"Defaults to 100 of the unit of ``log_level``.",
default=100)
log_interval: int = hp.optional(
"Frequency to record log messages, relative to the ``log_level``."
"Defaults to 1 (record all messages).",
default=1)

def initialize_object(self, config: Optional[Dict[str, Any]] = None) -> FileLoggerBackend:
@@ -98,6 +94,7 @@ class WandBLoggerBackendHparams(BaseLoggerBackendHparams):
tags: str = hp.optional(doc="wandb tags comma separated", default="")
log_artifacts: bool = hp.optional(doc="Whether to log artifacts", default=False)
log_artifacts_every_n_batches: int = hp.optional(doc="interval, in batches, to log artifacts", default=100)
rank_zero_only: bool = hp.optional("Whether to log on rank zero only", default=False)
extra_init_params: Dict[str, JSON] = hp.optional(doc="wandb parameters", default_factory=dict)

def initialize_object(self, config: Optional[Dict[str, Any]] = None) -> WandBLoggerBackend:
@@ -202,6 +199,7 @@ def get_flattened_dict(data: Dict[str, Any], _prefix: List[str] = []) -> Dict[st
from composer.loggers.wandb_logger import WandBLoggerBackend
return WandBLoggerBackend(
log_artifacts=self.log_artifacts,
rank_zero_only=self.rank_zero_only,
log_artifacts_every_n_batches=self.log_artifacts_every_n_batches,
init_params=init_params,
)
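And the new WandB switch in use, a sketch assuming `WandBLoggerBackendHparams` can be constructed directly with the fields shown in this diff (all other fields keep their defaults):

from composer.loggers.logger_hparams import WandBLoggerBackendHparams

hparams = WandBLoggerBackendHparams(
    log_artifacts=True,
    rank_zero_only=True,   # only the rank zero process talks to wandb
)
wandb_backend = hparams.initialize_object()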