From 6357f2e4432cbeb015c1301bab8397ecd890c921 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Mon, 15 Nov 2021 14:33:34 -0800 Subject: [PATCH 01/38] Added `run_event` to callback Closes #11 This PR helps clean up some of the tests, rank zero callbacks, and will be used by future profiling work. --- composer/core/callback.py | 57 ++++++++++++++----------------- composer/core/engine.py | 6 +--- tests/callbacks/test_callbacks.py | 23 ++++++++----- tests/test_logger.py | 21 ++++++------ tests/trainer/test_checkpoint.py | 8 ++--- 5 files changed, 54 insertions(+), 61 deletions(-) diff --git a/composer/core/callback.py b/composer/core/callback.py index 94b8f925b0..1c71736434 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -5,15 +5,13 @@ from __future__ import annotations import abc -from functools import wraps -from types import MethodType -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING from composer.core.serializable import Serializable from composer.utils.ddp import is_rank_zero if TYPE_CHECKING: - from composer import Logger, State + from composer import Event, Logger, State class Callback(Serializable, abc.ABC): @@ -44,6 +42,25 @@ def init(self, state: State, logger: Logger) -> None: del state, logger # unused pass + def run_event(self, event: Event, state: State, logger: Logger) -> None: + """This method is called by the engine on each event. By default, it + invokes the callback function for the event (for example, + `self.run_event(Event.TRAINING_START, state, logger)` invokes + `self.training_start(state, logger)`). If this method is overridden, + the subclass method should include `super().run_event(event, state, logger)` + so all callback methods will be invoked. + + Args: + event (Event): The event. + state (State): The state. + logger (Logger): The logger. + """ + try: + event_cb = getattr(self, event.value) + except AttributeError: + raise ValueError(f'Callback {self} has no method for event {event}') + return event_cb(state, logger) + def training_start(self, state: State, logger: Logger) -> None: """Called on the :attr:`Event.TRAINING_START` event. @@ -278,33 +295,9 @@ def eval_end(self, state: State, logger: Logger) -> None: class RankZeroCallback(Callback, abc.ABC): """Base class for callbacks that only run on the rank zero process. - - .. Note:: - - :meth:`init` and :meth:`load_state_dict` are executed - before the DDP fork and will be called on all ranks. 
""" - def __init__(self) -> None: - from composer.core import Event - - super().__init__() - - # ensure all callbacks are executed only on rank 0 - functions_to_wrap = [*(event.value for event in Event), "state_dict"] - - for fn_name in functions_to_wrap: - original_fn = getattr(self, fn_name) - - @wraps(original_fn) - def wrapped_fn( - backend: RankZeroCallback, - *args: Any, - original_fn: Callable[[State, Logger], None] = original_fn, - **kwargs: Any, - ) -> None: - if not is_rank_zero(): - return - return original_fn(*args, **kwargs) - - setattr(self, fn_name, MethodType(wrapped_fn, self)) + def run_event(self, event: Event, state: State, logger: Logger) -> None: + if not is_rank_zero(): + return + return super().run_event(event, state, logger) diff --git a/composer/core/engine.py b/composer/core/engine.py index ab7dee6e9a..dfdfb7983d 100644 --- a/composer/core/engine.py +++ b/composer/core/engine.py @@ -179,8 +179,4 @@ def _run_callbacks( event = Event(event) for cb in self.callbacks: - if not hasattr(cb, event.value): - raise ValueError(f'f{cb} has no method for event {event}') - else: - f = getattr(cb, event.value) - f(self.state, self.logger) + cb.run_event(event, self.state, self.logger) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 2bc0b10187..485f78abb0 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -1,7 +1,6 @@ # Copyright 2021 MosaicML. All Rights Reserved. -from unittest.mock import Mock - +import _pytest.monkeypatch import pytest from composer.core import Event @@ -15,20 +14,26 @@ def test_callbacks_map_to_events(): # callback methods must be 1:1 mapping with events # exception for private methods cb = Callback() - excluded_methods = ["state_dict", "load_state_dict", "setup"] + excluded_methods = ["state_dict", "load_state_dict", "setup", "run_event"] methods = set(m for m in dir(cb) if (m not in excluded_methods and not m.startswith("_"))) event_names = set(e.value for e in Event) assert methods == event_names @pytest.mark.parametrize('event', list(Event)) -def test_run_event_callbacks(event: Event, dummy_state: State): - callbacks = [Mock() for _ in range(5)] +def test_run_event_callbacks(event: Event, dummy_state: State, monkeypatch: _pytest.monkeypatch.MonkeyPatch): + callback = Callback() logger = Logger(dummy_state) - engine = Engine(state=dummy_state, algorithms=[], logger=logger, callbacks=callbacks) + engine = Engine(state=dummy_state, algorithms=[], logger=logger, callbacks=[callback]) + + called = [False] # storing as an array so it is a pointer, not primitive + + def patched_callback(state: State, logger: Logger): + + called[0] = True + + monkeypatch.setattr(callback, event.value, patched_callback) engine.run_event(event) - for cb in callbacks: - f = getattr(cb, event.value) - f.assert_called_once_with(dummy_state, logger) + assert called[0], "callback method not invoked" diff --git a/tests/test_logger.py b/tests/test_logger.py index 95758f2600..0dd5e5a700 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -7,6 +7,7 @@ import torch.distributed as dist from _pytest.monkeypatch import MonkeyPatch +from composer.core.event import Event from composer.core.logging import Logger, LogLevel from composer.core.state import State from composer.loggers.file_logger import FileLoggerBackend @@ -36,7 +37,7 @@ def test_file_logger(dummy_state: State, log_destination: FileLoggerBackend, mon dummy_state.epoch = 2 logger = Logger(dummy_state, backends=[log_destination]) 
monkeypatch.setattr(dist, "get_rank", lambda: 0) - log_destination.training_start(dummy_state, logger) + log_destination.run_event(Event.TRAINING_START, dummy_state, logger) logger.metric_fit({"metric": "fit"}) # should print logger.metric_epoch({"metric": "epoch"}) # should print logger.metric_batch({"metric": "batch"}) # should print @@ -45,11 +46,11 @@ def test_file_logger(dummy_state: State, log_destination: FileLoggerBackend, mon logger.metric_epoch({"metric": "epoch1"}) # should NOT print, since we print every 2 epochs dummy_state.epoch = 4 dummy_state.step = 3 - log_destination.batch_end(dummy_state, logger) + log_destination.run_event(Event.BATCH_END, dummy_state, logger) logger.metric_epoch({"metric": "epoch2"}) # should print logger.metric_batch({"metric": "batch1"}) # should NOT print, since we print every 3 steps - log_destination.batch_end(dummy_state, logger) - log_destination.training_end(dummy_state, logger) + log_destination.run_event(Event.BATCH_END, dummy_state, logger) + log_destination.run_event(Event.TRAINING_END, dummy_state, logger) with open(log_file_name, 'r') as f: assert f.readlines() == [ '[FIT][step=2]: { "metric": "fit", }\n', @@ -70,10 +71,10 @@ def test_deferred(self, dummy_state_without_rank: State, log_file_name: str, mon logger = Logger(dummy_state, backends=[log_destination]) logger.metric_batch({"metric": "before_training_start"}) monkeypatch.setattr(dist, "get_rank", lambda: rank) - log_destination.training_start(dummy_state, logger) + log_destination.run_event(Event.TRAINING_START, dummy_state, logger) logger.metric_batch({"metric": "after_training_start"}) - log_destination.batch_end(dummy_state, logger) - log_destination.training_end(dummy_state, logger) + log_destination.run_event(Event.BATCH_END, dummy_state, logger) + log_destination.run_event(Event.TRAINING_END, dummy_state, logger) if rank == 0: with open(log_file_name, 'r') as f: assert f.readlines() == [ @@ -96,10 +97,10 @@ def test_deep_copy(self, dummy_state_without_rank: State, log_destination: FileL logger.metric_batch({"metric": metric_data}) metric_data[0] = ["world"] monkeypatch.setattr(dist, "get_rank", lambda: 0) - log_destination.training_start(dummy_state, logger) + log_destination.run_event(Event.TRAINING_START, dummy_state, logger) logger.metric_batch({"metric": metric_data}) - log_destination.batch_end(dummy_state, logger) - log_destination.training_end(dummy_state, logger) + log_destination.run_event(Event.BATCH_END, dummy_state, logger) + log_destination.run_event(Event.TRAINING_END, dummy_state, logger) with open(log_file_name, 'r') as f: assert f.readlines() == [ '[BATCH][step=2]: { "metric": [["hello"]], }\n', diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 82edcd0103..359c13b0ad 100755 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -1,11 +1,9 @@ # Copyright 2021 MosaicML. All Rights Reserved. 
-import functools import os import pathlib import random import shutil -from logging import Logger from typing import Dict, Optional import pytest @@ -17,7 +15,7 @@ from composer.core.callback import Callback from composer.core.event import Event from composer.core.state import State -from composer.core.types import StateDict +from composer.core.types import Logger, StateDict from composer.trainer.devices import CPUDeviceHparams, DeviceHparams, GPUDeviceHparams from composer.trainer.trainer import Trainer from composer.trainer.trainer_hparams import TrainerHparams, callback_registry @@ -54,9 +52,9 @@ def __init__(self) -> None: for event in Event: self.event_to_num_calls[event] = 0 - setattr(self, event.value, functools.partial(self._event_catchall, event=event)) - def _event_catchall(self, state: State, logger: Logger, event: Event): + def run_event(self, event: Event, state: State, logger: Logger): + super().run_event(event, state, logger) if event == Event.TRAINING_START: # ignoring training start as it is called once per startup # and the states otherwise won't match From f395df4e93906b227a7ff0e629a5ddd565c73bb3 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Mon, 15 Nov 2021 16:13:36 -0800 Subject: [PATCH 02/38] Removed callback helper methods --- composer/callbacks/benchmarker.py | 20 +- composer/callbacks/grad_monitor.py | 8 +- composer/callbacks/lr_monitor.py | 7 +- composer/callbacks/memory_monitor.py | 15 +- composer/callbacks/speed_monitor.py | 23 ++- composer/callbacks/torch_profiler.py | 19 +- composer/core/callback.py | 279 +++----------------------- composer/core/logging/base_backend.py | 28 +-- composer/loggers/file_logger.py | 10 +- composer/loggers/tqdm_logger.py | 113 ++++------- composer/loggers/wandb_logger.py | 35 ++-- docs/source/core/callback.rst | 8 +- pyproject.toml | 2 + tests/callbacks/test_callbacks.py | 28 +-- tests/trainer/test_ddp.py | 45 ++--- 15 files changed, 197 insertions(+), 443 deletions(-) diff --git a/composer/callbacks/benchmarker.py b/composer/callbacks/benchmarker.py index 9e72b2339e..4a7353472a 100644 --- a/composer/callbacks/benchmarker.py +++ b/composer/callbacks/benchmarker.py @@ -10,6 +10,7 @@ from composer import Logger, State from composer.callbacks.callback_hparams import BenchmarkerHparams from composer.core.callback import Callback +from composer.core.event import Event from composer.core.types import BreakEpochException log = logging.getLogger(__name__) @@ -106,6 +107,17 @@ def __init__(self, self.original_max_epochs = -1 self.wct_dict = {} + def run_event(self, event: Event, state: State, logger: Logger) -> None: + super().run_event(event, state, logger) + if event == Event.TRAINING_START: + self._training_start(state, logger) + if event == Event.BATCH_START: + self._batch_start(state, logger) + if event == Event.BATCH_END: + self._batch_end(state, logger) + if event == Event.EPOCH_END: + self._epoch_end(state, logger) + def _compute_elapsed_wct(self, epoch_wct_dict, steps_per_epoch: int, n_epochs: int): wct = 0.0 wct_per_step = 0 @@ -116,7 +128,7 @@ def _compute_elapsed_wct(self, epoch_wct_dict, steps_per_epoch: int, n_epochs: i wct += wct_per_step return wct * n_epochs - def training_start(self, state: State, logger: Logger): + def _training_start(self, state: State, logger: Logger): del logger # Unused warnings.warn("The timing monitor is activated. The model will not be fully trained." 
"All quality metrics for this run will be incorrect.") @@ -129,7 +141,7 @@ def training_start(self, state: State, logger: Logger): self.wct_dict = {e: {s: -1.0 for s in self.step_list} for e in self.epoch_list} state.max_epochs = len(self.epoch_list) - def epoch_end(self, state: State, logger: Logger): + def _epoch_end(self, state: State, logger: Logger): prev_epoch = self.epoch_list[self.epoch_ix] epoch_wct_dict = self.wct_dict[prev_epoch] self.epoch_ix += 1 @@ -145,7 +157,7 @@ def epoch_end(self, state: State, logger: Logger): self.wall_clock_train += self._compute_elapsed_wct(epoch_wct_dict, state.steps_per_epoch, n_epochs) logger.metric_epoch({'wall_clock_train': self.wall_clock_train}) - def batch_start(self, state: State, logger: Logger): + def _batch_start(self, state: State, logger: Logger): del state, logger # Unused if self.current_time is None: self.current_time = time.time() @@ -153,7 +165,7 @@ def batch_start(self, state: State, logger: Logger): self.profile_steps = 0 self.profile_time = 0.0 - def batch_end(self, state: State, logger: Logger): + def _batch_end(self, state: State, logger: Logger): if self.current_time is not None: now = time.time() elapsed = now - self.current_time diff --git a/composer/callbacks/grad_monitor.py b/composer/callbacks/grad_monitor.py index 9fb227dd9c..bef07df5c3 100644 --- a/composer/callbacks/grad_monitor.py +++ b/composer/callbacks/grad_monitor.py @@ -1,6 +1,6 @@ # Copyright 2021 MosaicML. All Rights Reserved. -from composer.core import Logger, State +from composer.core import Event, Logger, State from composer.core.callback import Callback @@ -24,7 +24,7 @@ def __init__(self, log_layer_grad_norms: bool = False): super().__init__() self.log_layer_grad_norms = log_layer_grad_norms - def after_train_batch(self, state: State, logger: Logger): + def run_event(self, event: Event, state: State, logger: Logger): """Compute the gradient L2 norm after the reduction of the backwards pass across GPUs. This function iterates over the parameters of the model and hence may cause a reduction in @@ -33,11 +33,15 @@ def after_train_batch(self, state: State, logger: Logger): unscaling in cases where gradients are scaled. Args: + event (Event): The :class:`~composer.core.Event` object state (State): The :class:`~composer.core.State` object used during training. logger (Logger): The :class:`~composer.core.logging.logger.Logger` object. """ + super().run_event(event, state, logger) + if event != Event.AFTER_TRAIN_BATCH: + return norm = None layer_norms = {} for name, p in state.model.named_parameters(): diff --git a/composer/callbacks/lr_monitor.py b/composer/callbacks/lr_monitor.py index 461e0b4862..3958405f92 100644 --- a/composer/callbacks/lr_monitor.py +++ b/composer/callbacks/lr_monitor.py @@ -1,6 +1,6 @@ # Copyright 2021 MosaicML. All Rights Reserved. 
-from composer.core import Callback, Logger, State +from composer.core import Callback, Event, Logger, State from composer.utils import ensure_tuple @@ -14,7 +14,10 @@ class LRMonitor(Callback): def __init__(self) -> None: super().__init__() - def batch_end(self, state: State, logger: Logger): + def run_event(self, event: Event, state: State, logger: Logger): + super().run_event(event, state, logger) + if event != Event.BATCH_END: + return assert state.optimizers is not None, "optimizers must be defined" for optimizer in ensure_tuple(state.optimizers): lrs = [group['lr'] for group in optimizer.param_groups] diff --git a/composer/callbacks/memory_monitor.py b/composer/callbacks/memory_monitor.py index cc38f88d27..ef42861b43 100755 --- a/composer/callbacks/memory_monitor.py +++ b/composer/callbacks/memory_monitor.py @@ -4,7 +4,7 @@ from torch.cuda import device_count, memory_stats -from composer.core import Logger, State +from composer.core import Event, Logger, State from composer.core.callback import Callback log = logging.getLogger(__name__) @@ -26,16 +26,11 @@ def __init__(self): if device_count == 0: log.warn("Memory monitor only works on GPU devices.") - def after_train_batch(self, state: State, logger: Logger): - """This function calls the torch cuda memory stats and reports basic memory - statistics. + def run_event(self, event: Event, state: State, logger: Logger): + super().run_event(event, state, logger) + if event != Event.AFTER_TRAIN_BATCH: + return - Args: - state (State): The :class:`~composer.core.State` object - used during training. - logger (Logger): - The :class:`~composer.core.logging.logger.Logger` object. - """ memory_report = {} default_stats = { diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index c2d031e6f5..e8c0c9f5e5 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -6,7 +6,7 @@ from collections import deque from typing import Deque, Optional -from composer import Logger, State +from composer import Logger, State, Event from composer.callbacks.callback_hparams import SpeedMonitorHparams from composer.core.callback import RankZeroCallback from composer.core.types import StateDict @@ -63,11 +63,18 @@ def _load_state(self) -> None: self.batch_num_samples = self.loaded_state["batch_num_samples"] self.loaded_state = None - def batch_start(self, state: State, logger: Logger) -> None: - del state, logger # unused - self._load_state() - - def epoch_start(self, state: State, logger: Logger): + def _run_event(self, event: Event, state: State, logger: Logger) -> None: + if event == Event.EPOCH_START: + self._epoch_start(state, logger) + if event == Event.BATCH_START: + self._load_state() + if event == Event.BATCH_END: + self._batch_end(state, logger) + if event == Event.EPOCH_END: + self._epoch_end(state, logger) + super()._run_event(event, state, logger) + + def _epoch_start(self, state: State, logger: Logger): del state, logger # unused self._load_state() self.epoch_start_time = time.time() @@ -75,7 +82,7 @@ def epoch_start(self, state: State, logger: Logger): self.batch_num_samples.clear() self.train_examples_per_epoch = 0 - def batch_end(self, state: State, logger: Logger): + def _batch_end(self, state: State, logger: Logger): self.batch_end_times.append(time.time()) batch_num_samples = 0 batch_num_samples += state.last_batch_size @@ -90,7 +97,7 @@ def batch_end(self, state: State, logger: Logger): throughput = sum(self.batch_num_samples) / (self.batch_end_times[-1] - 
self.batch_end_times[0]) logger.metric_batch({'throughput/step': throughput}) - def epoch_end(self, state: State, logger: Logger): + def _epoch_end(self, state: State, logger: Logger): del state # unused epoch_time = time.time() - self.epoch_start_time self.wall_clock_train += epoch_time diff --git a/composer/callbacks/torch_profiler.py b/composer/callbacks/torch_profiler.py index 0e3724a04f..d8ca2467e3 100644 --- a/composer/callbacks/torch_profiler.py +++ b/composer/callbacks/torch_profiler.py @@ -12,7 +12,7 @@ from composer import Callback from composer.callbacks.callback_hparams import TorchProfilerHparams -from composer.core.types import StateDict +from composer.core.types import Event, StateDict from composer.utils.ddp import get_global_rank from composer.utils.run_directory import get_relative_to_run_directory @@ -132,7 +132,16 @@ def scheduler_fn(self, profiler_step: int) -> ProfilerAction: torch_scheduler_action = ProfilerAction.RECORD_AND_SAVE return torch_scheduler_action - def training_start(self, state: State, logger: Logger) -> None: + def run_event(self, event: Event, state: State, logger: Logger) -> None: + super().run_event(event, state, logger) + if event == Event.TRAINING_START: + self._training_start(state, logger) + if event == Event.BATCH_START: + self._batch_start(state, logger) + if event == Event.BATCH_END: + self._batch_end(state, logger) + + def _training_start(self, state: State, logger: Logger) -> None: del state, logger # unused assert self.profiler is None, _PROFILE_MISSING_ERROR self.profiler = torch.profiler.profile( @@ -151,16 +160,16 @@ def training_start(self, state: State, logger: Logger) -> None: self.profiler.__enter__() atexit.register(self._close_profiler) - def batch_end(self, state: State, logger: Logger) -> None: + def _batch_end(self, state: State, logger: Logger) -> None: del state, logger # unused assert self.profiler is not None, _PROFILE_MISSING_ERROR self.profiler.step() - def epoch_start(self, state: State, logger: Logger) -> None: + def _epoch_start(self, state: State, logger: Logger) -> None: del logger # unused self.profiler_state.batches_per_epoch = state.steps_per_epoch - def batch_start(self, state: State, logger: Logger) -> None: + def _batch_start(self, state: State, logger: Logger) -> None: self.profiler_state.batch_in_epoch = state.batch_idx assert self.profiler is not None, _PROFILE_MISSING_ERROR logger.metric_batch({"profiler/state": self.profiler.current_action.name}) diff --git a/composer/core/callback.py b/composer/core/callback.py index 1c71736434..aa6c9533d7 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -5,11 +5,17 @@ from __future__ import annotations import abc +import warnings from typing import TYPE_CHECKING from composer.core.serializable import Serializable from composer.utils.ddp import is_rank_zero +try: + from typing import final +except ImportError: + final = lambda x: x # final is not available in python 3.7 + if TYPE_CHECKING: from composer import Event, Logger, State @@ -22,33 +28,15 @@ class Callback(Serializable, abc.ABC): they are run on specific events. By convention, Callbacks should not modify :class:`State`. - Each method name corresponds to an :class:`Event`. - - Subclasses of callbacks should override these methods to run in response - to given :class:`Event` invocations. + Subclasses should override :meth:`~Callback.run_event` + to run in response to given :class:`Event` invocations. 
""" def __init__(self) -> None: super().__init__() - def init(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.INIT` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - def run_event(self, event: Event, state: State, logger: Logger) -> None: - """This method is called by the engine on each event. By default, it - invokes the callback function for the event (for example, - `self.run_event(Event.TRAINING_START, state, logger)` invokes - `self.training_start(state, logger)`). If this method is overridden, - the subclass method should include `super().run_event(event, state, logger)` - so all callback methods will be invoked. + """This method is called by the engine on each event. Args: event (Event): The event. @@ -58,246 +46,29 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: try: event_cb = getattr(self, event.value) except AttributeError: - raise ValueError(f'Callback {self} has no method for event {event}') + # Good -- the callback does not override any methods + return + warnings.warn( + f"CallbackMethodDeprecationWarning: `self.{event.value}()` will be removed in callbacks." + "Instead, override `self.run_event()`.", + category=DeprecationWarning) return event_cb(state, logger) - def training_start(self, state: State, logger: Logger) -> None: - """Called on the :attr:`Event.TRAINING_START` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def epoch_start(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.EPOCH_START` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def batch_start(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.BATCH_START` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def after_dataloader(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.AFTER_DATALOADER` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def before_train_batch(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.BEFORE_TRAIN_BATCH` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def before_forward(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.BEFORE_FORWARD` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def after_forward(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.AFTER_FORWARD` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def before_loss(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.BEFORE_LOSS` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def after_loss(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.AFTER_LOSS` event. - - Args: - state (State): The global state. - logger (Logger): The logger. 
- - """ - del state, logger # unused - pass - - def before_backward(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.BEFORE_BACKWARD` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def after_backward(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.AFTER_BACKWARD` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def after_train_batch(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.AFTER_TRAIN_BATCH` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def batch_end(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.BATCH_END` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def epoch_end(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.EPOCH_END` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def training_end(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.TRAINING_END` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def eval_start(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.EVAL_START` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def eval_batch_start(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.EVAL_BATCH_START` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def eval_before_forward(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.EVAL_BATCH_FORWARD` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def eval_after_forward(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.EVAL_AFTER_FORWARD` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def eval_batch_end(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.EVAL_BATCH_END` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - - def eval_end(self, state: State, logger: Logger) -> None: - """Called on the :attr:`~Event.EVAL_END` event. - - Args: - state (State): The global state. - logger (Logger): The logger. - - """ - del state, logger # unused - pass - class RankZeroCallback(Callback, abc.ABC): """Base class for callbacks that only run on the rank zero process. + + Subclasses should override :meth:`_run_event` + (**not** `run_event`) to run in response + to given :class:`Event` invocations. 
""" + @final def run_event(self, event: Event, state: State, logger: Logger) -> None: + super().run_event(event, state, logger) if not is_rank_zero(): return - return super().run_event(event, state, logger) + return self._run_event(event, state, logger) + + def _run_event(self, event: Event, state: State, logger: Logger) -> None: + pass diff --git a/composer/core/logging/base_backend.py b/composer/core/logging/base_backend.py index ef99ac79be..e0a61bc9fb 100644 --- a/composer/core/logging/base_backend.py +++ b/composer/core/logging/base_backend.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple from composer.core.callback import Callback, RankZeroCallback +from composer.core.event import Event from composer.core.logging.logger import Logger from composer.utils.ddp import is_rank_zero @@ -150,22 +151,11 @@ def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) return return self._log_metric(epoch, step, log_level, data) - def _training_start(self, state: State, logger: Logger) -> None: - """Callback called on the - :attr:`~composer.core.event.Event.TRAINING_START` event. - - Args: - state (State): The global state. - logger (Logger): The global logger. - """ - del state, logger # unused - pass - - @final - def training_start(self, state: State, logger: Logger) -> None: - self._training_start(state, logger) # initialize the logger - if self._deferred_log_metric_calls is None: - raise RuntimeError("_deferred_log_metric_calls should not be None") - for epoch, step, log_level, data in self._deferred_log_metric_calls: - self._log_metric(epoch, step, log_level, data) - self._deferred_log_metric_calls = None + def _run_event(self, event: Event, state: State, logger: Logger) -> None: + super()._run_event(event, state, logger) + if event == Event.TRAINING_START: + if self._deferred_log_metric_calls is None: + raise RuntimeError("_deferred_log_metric_calls should not be None") + for epoch, step, log_level, data in self._deferred_log_metric_calls: + self._log_metric(epoch, step, log_level, data) + self._deferred_log_metric_calls = None diff --git a/composer/loggers/file_logger.py b/composer/loggers/file_logger.py index 3cc42c88e9..a32c850cf8 100644 --- a/composer/loggers/file_logger.py +++ b/composer/loggers/file_logger.py @@ -9,6 +9,7 @@ import yaml +from composer.core.event import Event from composer.core.logging import Logger, LogLevel, RankZeroLoggerBackend, TLogData, format_log_data_value from composer.core.state import State from composer.loggers.logger_hparams import FileLoggerBackendHparams @@ -79,6 +80,13 @@ def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData data_str = format_log_data_value(data) print(f"[{log_level.name}][step={step}]: {data_str}", file=self.file) + def _run_event(self, event: Event, state: State, logger: Logger) -> None: + if event == Event.TRAINING_START: + self._training_start(state, logger) + if event == Event.BATCH_END: + self._batch_end(state, logger) + super()._run_event(event, state, logger) + def _training_start(self, state: State, logger: Logger) -> None: if self.hparams.filename == "stdout": self.file = sys.stdout @@ -94,7 +102,7 @@ def _training_start(self, state: State, logger: Logger) -> None: print("-" * 30, file=self.file) print(file=self.file) - def batch_end(self, state: State, logger: Logger) -> None: + def _batch_end(self, state: State, logger: Logger) -> None: assert self.file is not None if (state.step + 1) % self.hparams.flush_every_n_batches == 0 and self.file not in 
(sys.stdout, sys.stderr): self.file.flush() diff --git a/composer/loggers/tqdm_logger.py b/composer/loggers/tqdm_logger.py index 3ea9755b04..ba28a3ca07 100644 --- a/composer/loggers/tqdm_logger.py +++ b/composer/loggers/tqdm_logger.py @@ -9,6 +9,7 @@ import yaml from tqdm import tqdm +from composer.core.event import Event from composer.core.logging import LogLevel, RankZeroLoggerBackend, TLogData, TLogDataValue, format_log_data_value from composer.core.state import State from composer.core.types import StateDict @@ -16,12 +17,14 @@ if TYPE_CHECKING: from composer.core.logging import Logger +_IS_TRAIN_TO_KEYS_TO_LOG = {True: ['loss/train'], False: ['accuracy/val']} + @dataclass class _TQDMLoggerInstanceState: total: int epoch: int - val: bool + is_train: bool n: int keys_to_log: Sequence[str] epoch_metrics: Dict[str, TLogDataValue] = field(default_factory=dict) @@ -32,18 +35,17 @@ class _TQDMLoggerInstance: def __init__(self, total: int, epoch: int, - val: bool, - keys_to_log: Sequence[str], + is_train: bool, n: int = 0, epoch_metrics: Optional[Dict[str, TLogDataValue]] = None) -> None: self.state = _TQDMLoggerInstanceState(total=total, epoch=epoch, - val=val, + is_train=is_train, n=n, - keys_to_log=keys_to_log, + keys_to_log=_IS_TRAIN_TO_KEYS_TO_LOG[is_train], epoch_metrics=(epoch_metrics or {})) - desc = f'Epoch {epoch + 1}{" (val)" if val else ""}' - position = 1 if val else 0 + desc = f'Epoch {epoch + 1}{"" if is_train else " (val)"}' + position = 0 if is_train else 1 self.pbar = tqdm(total=total, desc=desc, position=position, @@ -90,9 +92,8 @@ class TQDMLoggerBackend(RankZeroLoggerBackend): def __init__(self, config: Optional[Dict[str, Any]] = None) -> None: super().__init__() - self.pbar_train: Optional[_TQDMLoggerInstance] = None - self.pbar_val: Optional[_TQDMLoggerInstance] = None - self.is_validating = False + self.pbars: Dict[bool, _TQDMLoggerInstance] = {} + self.active_pbar: Optional[_TQDMLoggerInstance] = None self.config = config def _will_log(self, state: State, log_level: LogLevel) -> bool: @@ -101,73 +102,37 @@ def _will_log(self, state: State, log_level: LogLevel) -> bool: def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None: del epoch, step, log_level # Unused - pbar = self.pbar_val if self.is_validating else self.pbar_train - if pbar is None: + if self.active_pbar is None: # Logging outside an epoch return - pbar.log_metric(data) - - def _training_start(self, state: State, logger: Logger) -> None: - del state, logger # Unused - if self.config is not None: - print("Config") - print("-" * 30) - yaml.safe_dump(self.config, stream=sys.stdout) - print("-" * 30) - print() - - def epoch_start(self, state: State, logger: Logger) -> None: - del logger # Unused - assert self.pbar_train is None - self.pbar_train = _TQDMLoggerInstance(total=state.steps_per_epoch, - epoch=state.epoch, - val=False, - keys_to_log=["loss/train"]) - - def after_backward(self, state: State, logger: Logger) -> None: - del state, logger # Unused - assert self.pbar_train is not None - self.pbar_train.update() - - def epoch_end(self, state: State, logger: Logger) -> None: - del state, logger # Unused - assert self.pbar_train is not None - self.pbar_train.close() - self.pbar_train = None - - def eval_start(self, state: State, logger: Logger) -> None: - del logger # Unused - assert self.pbar_val is None - assert state.eval_dataloader is not None - self.pbar_val = _TQDMLoggerInstance(total=len(state.eval_dataloader), - epoch=state.epoch, - val=True, - 
keys_to_log=["accuracy/val"]) - self.is_validating = True - - def eval_after_forward(self, state: State, logger: Logger) -> None: - del state, logger # Unused - assert self.pbar_val is not None - self.pbar_val.update() - - def eval_end(self, state: State, logger: Logger) -> None: - del state, logger # Unused - assert self.pbar_val is not None - self.pbar_val.close() - self.pbar_val = None - self.is_validating = False + self.active_pbar.log_metric(data) + + def _run_event(self, event: Event, state: State, logger: Logger) -> None: + if event == Event.TRAINING_START: + if self.config is not None: + print("Config") + print("-" * 30) + yaml.safe_dump(self.config, stream=sys.stdout) + print("-" * 30) + print() + if event in (Event.EPOCH_START, Event.EVAL_START): + is_train = event == Event.TRAINING_START + self.pbars[is_train] = _TQDMLoggerInstance(total=state.steps_per_epoch, + epoch=state.epoch, + is_train=is_train) + self.active_pbar = self.pbars[is_train] + if event in (Event.AFTER_BACKWARD, Event.EVAL_AFTER_FORWARD): + assert self.active_pbar is not None + self.active_pbar.update() + if event in (Event.EPOCH_END, Event.EVAL_END): + assert self.active_pbar is not None + self.active_pbar.close() + self.active_pbar = None + + super()._run_event(event, state, logger) def state_dict(self) -> StateDict: - state = {"is_validating": self.is_validating} - if self.pbar_train: - state["pbar_train"] = self.pbar_train.state_dict() - if self.pbar_val: - state["pbar_val"] = self.pbar_val.state_dict() - return state + return {"pbars": {k: v.state_dict() for (k, v) in self.pbars.items()}} def load_state_dict(self, state: StateDict) -> None: - self.is_validating = state["is_validating"] - if "pbar_train" in state: - self.pbar_train = _TQDMLoggerInstance(**state["pbar_train"]) - if "pbar_val" in state: - self.pbar_val = _TQDMLoggerInstance(**state["pbar"]) + self.pbars = {k: _TQDMLoggerInstance(**v) for (k, v) in state["pbars"].items()} diff --git a/composer/loggers/wandb_logger.py b/composer/loggers/wandb_logger.py index ae8948d40f..d23d2036f5 100644 --- a/composer/loggers/wandb_logger.py +++ b/composer/loggers/wandb_logger.py @@ -7,6 +7,7 @@ import sys from typing import Any +from composer.core.event import Event from composer.core.logging import LogLevel, RankZeroLoggerBackend, TLogData from composer.core.types import Logger, State, StateDict from composer.utils.run_directory import get_run_directory @@ -34,32 +35,32 @@ def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData del epoch, log_level # unused wandb.log(data, step=step) - def _training_start(self, state: State, logger: Logger) -> None: - del state, logger # unused - wandb.init(**self._init_params) - atexit.register(self._close_wandb) - def state_dict(self) -> StateDict: # Storing these fields in the state dict to support run resuming in the future. 
return {"name": wandb.run.name, "project": wandb.run.project, "entity": wandb.run.entity, "id": wandb.run.id} - def batch_end(self, state: State, logger: Logger) -> None: - del logger # unused - # On resnet50, _log_artifacts() caused a 22% throughput degradation - # wandb.log_artifact() is async according to the docs - # (see https://docs.wandb.ai/guides/artifacts/api#2.-create-an-artifact) - # so uploads will not block the training loop - # slowdown is likely from extra I/O - # Hence, logging every n batches instead of every batch - if (state.step + 1) % self.log_artifacts_every_n_batches == 0: + def _run_event(self, event: Event, state: State, logger: Logger) -> None: + if event == Event.TRAINING_START: + wandb.init(**self._init_params) + atexit.register(self._close_wandb) + + if event == Event.BATCH_END: + if (state.step + 1) % self.log_artifacts_every_n_batches == 0: + self._log_artifacts() + + if event == Event.EPOCH_END: self._log_artifacts() - def epoch_end(self, state: State, logger: Logger) -> None: - del state, logger # unused - self._log_artifacts() + super()._run_event(event, state, logger) def _log_artifacts(self): # Scan the run directory and upload artifacts to wandb + # On resnet50, _log_artifacts() caused a 22% throughput degradation + # wandb.log_artifact() is async according to the docs + # (see https://docs.wandb.ai/guides/artifacts/api#2.-create-an-artifact) + # so uploads will not block the training loop + # slowdown is likely from extra I/O of scanning the directory and/or + # scheduling uploads run_directory = get_run_directory() if run_directory is not None: for subfile in os.listdir(run_directory): diff --git a/docs/source/core/callback.rst b/docs/source/core/callback.rst index b3d4340ccf..2b3c9f6ecc 100644 --- a/docs/source/core/callback.rst +++ b/docs/source/core/callback.rst @@ -12,8 +12,7 @@ By convention, callbacks should not modify the :class:`State`. Each callback inherits from the :class:`Callback` base class, -and overrides functions corresponding to the event. - +and overrides the :meth:`~Callback.run_event` method. For example: @@ -23,8 +22,9 @@ For example: class MyCallback(Callback) - def epoch_start(self, state: State, logger: Logger): - print(f'Epoch {state.epoch}/{state.max_epochs}') + def run_event(self, event: Event, state: State, logger: Logger): + if event == Event.EPOCH_START: + print(f'Epoch {state.epoch}/{state.max_epochs}') .. 
note:: diff --git a/pyproject.toml b/pyproject.toml index 26f0faed2a..8e528a02ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,8 @@ filterwarnings = [ # "error", # warnings should be treated like errors, but still need to fix some warnings 'ignore:ExtraArgumentWarning', # extra arguments originate from pytest-specific CLI args 'ignore:DeferredLogMetricWarning', # deferred logging is fine + 'ignore:DDPDefaultValueWarning', # OK to assume no ddp + 'ignore:NoDDPWarning', # OK to assume no ddp ] # Coverage diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 485f78abb0..11ba661ba4 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -10,30 +10,22 @@ from composer.core.state import State -def test_callbacks_map_to_events(): - # callback methods must be 1:1 mapping with events - # exception for private methods - cb = Callback() - excluded_methods = ["state_dict", "load_state_dict", "setup", "run_event"] - methods = set(m for m in dir(cb) if (m not in excluded_methods and not m.startswith("_"))) - event_names = set(e.value for e in Event) - assert methods == event_names +class EventTrackerCallback(Callback): + + def __init__(self) -> None: + super().__init__() + self.event = None + + def run_event(self, event: Event, state: State, logger: Logger) -> None: + self.event = event @pytest.mark.parametrize('event', list(Event)) def test_run_event_callbacks(event: Event, dummy_state: State, monkeypatch: _pytest.monkeypatch.MonkeyPatch): - callback = Callback() + callback = EventTrackerCallback() logger = Logger(dummy_state) engine = Engine(state=dummy_state, algorithms=[], logger=logger, callbacks=[callback]) - called = [False] # storing as an array so it is a pointer, not primitive - - def patched_callback(state: State, logger: Logger): - - called[0] = True - - monkeypatch.setattr(callback, event.value, patched_callback) - engine.run_event(event) - assert called[0], "callback method not invoked" + assert callback.event == event diff --git a/tests/trainer/test_ddp.py b/tests/trainer/test_ddp.py index c67791a33e..f762fe9214 100755 --- a/tests/trainer/test_ddp.py +++ b/tests/trainer/test_ddp.py @@ -13,13 +13,14 @@ from _pytest.monkeypatch import MonkeyPatch import composer.core.types as types -from composer import Callback +from composer import Callback, Event from composer.callbacks import CallbackHparams from composer.core.logging import Logger from composer.core.state import State from composer.datasets import DataloaderHparams, DataloaderSpec, MemoryFormat, SyntheticDataset, SyntheticDatasetHparams from composer.trainer.devices import CPUDeviceHparams, GPUDeviceHparams from composer.trainer.trainer_hparams import TrainerHparams, callback_registry, dataset_registry +from composer.utils.ddp import get_global_rank from tests.fixtures.ddp_fixtures import with_distributed from tests.fixtures.models import SimpleBatchPairModelHparams @@ -91,29 +92,23 @@ def __init__(self, tmpdir: str): super().__init__() self.tmpdir = tmpdir - def before_forward(self, state: State, logger: Logger): - if state.batch_idx > 0: - return - rank: int = torch.distributed.get_rank() - last_input, last_target = state.batch_pair - torch.save( # type: ignore - { - "last_input": last_input, - "last_target": last_target, - }, get_batch_file_path(self.tmpdir, rank=rank, epoch=state.epoch, is_train=True)) - - def eval_before_forward(self, state: State, logger: Logger): - rank: int = torch.distributed.get_rank() - filepath = 
get_batch_file_path(self.tmpdir, rank=rank, epoch=state.epoch, is_train=False) - if os.path.exists(filepath): - return - assert not state.model.training - last_input, last_target = state.batch_pair - torch.save( # type: ignore - { - "last_input": last_input, - "last_target": last_target, - }, get_batch_file_path(self.tmpdir, rank=rank, epoch=state.epoch, is_train=False)) + def run_event(self, event: Event, state: State, logger: Logger) -> None: + super().run_event(event, state, logger) + if event in (Event.BEFORE_FORWARD, Event.EVAL_BEFORE_FORWARD): + filepath = get_batch_file_path(self.tmpdir, + rank=get_global_rank(), + epoch=state.epoch, + is_train=state.model.training) + if state.batch_idx > 0: + return + if os.path.exists(filepath): + return + last_input, last_target = state.batch_pair + torch.save( # type: ignore + { + "last_input": last_input, + "last_target": last_target, + }, filepath) @dataclass @@ -200,7 +195,7 @@ def _test_ddp(is_gpu: bool, num_procs: int, tmpdir: pathlib.Path, mosaic_trainer hparams.loggers = [] hparams.validate_every_n_batches = 0 hparams.validate_every_n_epochs = 1 - hparams.callbacks.append(CheckBatch0Hparams(tmpdir=tmpdir)) + hparams.callbacks.append(CheckBatch0Hparams(tmpdir=str(tmpdir))) trainer = hparams.initialize_object() assert trainer.state.world_size == num_procs assert trainer.state.local_world_size == num_procs From 0f1aa695d518b8e4fdca3aad85e80ce2c68271fc Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Mon, 15 Nov 2021 17:06:41 -0800 Subject: [PATCH 03/38] Fixed tests --- composer/loggers/tqdm_logger.py | 44 +++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/composer/loggers/tqdm_logger.py b/composer/loggers/tqdm_logger.py index ba28a3ca07..10f8c7f717 100644 --- a/composer/loggers/tqdm_logger.py +++ b/composer/loggers/tqdm_logger.py @@ -4,7 +4,7 @@ import sys from dataclasses import asdict, dataclass, field -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, Optional import yaml from tqdm import tqdm @@ -26,7 +26,6 @@ class _TQDMLoggerInstanceState: epoch: int is_train: bool n: int - keys_to_log: Sequence[str] epoch_metrics: Dict[str, TLogDataValue] = field(default_factory=dict) @@ -42,7 +41,6 @@ def __init__(self, epoch=epoch, is_train=is_train, n=n, - keys_to_log=_IS_TRAIN_TO_KEYS_TO_LOG[is_train], epoch_metrics=(epoch_metrics or {})) desc = f'Epoch {epoch + 1}{"" if is_train else " (val)"}' position = 0 if is_train else 1 @@ -54,7 +52,9 @@ def __init__(self, self.pbar.set_postfix(epoch_metrics) def log_metric(self, data: TLogData): - formatted_data = {k: format_log_data_value(v) for (k, v) in data.items() if k in self.state.keys_to_log} + formatted_data = { + k: format_log_data_value(v) for (k, v) in data.items() if k in _IS_TRAIN_TO_KEYS_TO_LOG[self.state.is_train] + } self.state.epoch_metrics.update(formatted_data) self.pbar.set_postfix(self.state.epoch_metrics) @@ -93,7 +93,7 @@ class TQDMLoggerBackend(RankZeroLoggerBackend): def __init__(self, config: Optional[Dict[str, Any]] = None) -> None: super().__init__() self.pbars: Dict[bool, _TQDMLoggerInstance] = {} - self.active_pbar: Optional[_TQDMLoggerInstance] = None + self.is_train: Optional[bool] = None self.config = config def _will_log(self, state: State, log_level: LogLevel) -> bool: @@ -102,10 +102,10 @@ def _will_log(self, state: State, log_level: LogLevel) -> bool: def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None: del epoch, step, 
log_level # Unused - if self.active_pbar is None: + if self.is_train in self.pbars: # Logging outside an epoch - return - self.active_pbar.log_metric(data) + assert self.is_train is not None + self.pbars[self.is_train].log_metric(data) def _run_event(self, event: Event, state: State, logger: Logger) -> None: if event == Event.TRAINING_START: @@ -116,23 +116,29 @@ def _run_event(self, event: Event, state: State, logger: Logger) -> None: print("-" * 30) print() if event in (Event.EPOCH_START, Event.EVAL_START): - is_train = event == Event.TRAINING_START - self.pbars[is_train] = _TQDMLoggerInstance(total=state.steps_per_epoch, - epoch=state.epoch, - is_train=is_train) - self.active_pbar = self.pbars[is_train] + self.is_train = event == Event.EPOCH_START + self.pbars[self.is_train] = _TQDMLoggerInstance(total=state.steps_per_epoch, + epoch=state.epoch, + is_train=self.is_train) if event in (Event.AFTER_BACKWARD, Event.EVAL_AFTER_FORWARD): - assert self.active_pbar is not None - self.active_pbar.update() + if self.is_train in self.pbars: + assert self.is_train is not None + self.pbars[self.is_train].update() if event in (Event.EPOCH_END, Event.EVAL_END): - assert self.active_pbar is not None - self.active_pbar.close() - self.active_pbar = None + if self.is_train in self.pbars: + assert self.is_train is not None + self.pbars[self.is_train].close() + del self.pbars[self.is_train] + self.is_train = None super()._run_event(event, state, logger) def state_dict(self) -> StateDict: - return {"pbars": {k: v.state_dict() for (k, v) in self.pbars.items()}} + return { + "pbars": {k: v.state_dict() for (k, v) in self.pbars.items()}, + "is_train": self.is_train, + } def load_state_dict(self, state: StateDict) -> None: self.pbars = {k: _TQDMLoggerInstance(**v) for (k, v) in state["pbars"].items()} + self.is_train = state["is_train"] From 06cac4b926ce975856fac986a4241941a5e3a6b3 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Mon, 15 Nov 2021 17:07:01 -0800 Subject: [PATCH 04/38] Formatting --- composer/callbacks/speed_monitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index e8c0c9f5e5..f58fb267c8 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -6,7 +6,7 @@ from collections import deque from typing import Deque, Optional -from composer import Logger, State, Event +from composer import Event, Logger, State from composer.callbacks.callback_hparams import SpeedMonitorHparams from composer.core.callback import RankZeroCallback from composer.core.types import StateDict From d886af6169a12955b8655ddd412f42e2255d7e0d Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 18 Nov 2021 11:02:44 -0800 Subject: [PATCH 05/38] Addressed PR feedback --- composer/callbacks/benchmarker.py | 3 +-- composer/callbacks/grad_monitor.py | 3 +-- composer/callbacks/lr_monitor.py | 3 +-- composer/callbacks/memory_monitor.py | 3 +-- composer/callbacks/speed_monitor.py | 1 - composer/callbacks/torch_profiler.py | 3 +-- composer/core/callback.py | 10 ++++++++-- composer/core/logging/base_backend.py | 14 ++++++++++---- composer/loggers/file_logger.py | 1 - composer/loggers/tqdm_logger.py | 2 -- composer/loggers/wandb_logger.py | 2 -- docs/source/core/callback.rst | 2 +- tests/callbacks/test_callbacks.py | 2 +- tests/trainer/test_checkpoint.py | 2 +- tests/trainer/test_ddp.py | 2 +- 15 files changed, 27 insertions(+), 26 deletions(-) diff --git a/composer/callbacks/benchmarker.py 
b/composer/callbacks/benchmarker.py index 4a7353472a..bfd7c9a24c 100644 --- a/composer/callbacks/benchmarker.py +++ b/composer/callbacks/benchmarker.py @@ -107,8 +107,7 @@ def __init__(self, self.original_max_epochs = -1 self.wct_dict = {} - def run_event(self, event: Event, state: State, logger: Logger) -> None: - super().run_event(event, state, logger) + def _run_event(self, event: Event, state: State, logger: Logger) -> None: if event == Event.TRAINING_START: self._training_start(state, logger) if event == Event.BATCH_START: diff --git a/composer/callbacks/grad_monitor.py b/composer/callbacks/grad_monitor.py index bef07df5c3..0bcca9fb22 100644 --- a/composer/callbacks/grad_monitor.py +++ b/composer/callbacks/grad_monitor.py @@ -24,7 +24,7 @@ def __init__(self, log_layer_grad_norms: bool = False): super().__init__() self.log_layer_grad_norms = log_layer_grad_norms - def run_event(self, event: Event, state: State, logger: Logger): + def _run_event(self, event: Event, state: State, logger: Logger): """Compute the gradient L2 norm after the reduction of the backwards pass across GPUs. This function iterates over the parameters of the model and hence may cause a reduction in @@ -39,7 +39,6 @@ def run_event(self, event: Event, state: State, logger: Logger): logger (Logger): The :class:`~composer.core.logging.logger.Logger` object. """ - super().run_event(event, state, logger) if event != Event.AFTER_TRAIN_BATCH: return norm = None diff --git a/composer/callbacks/lr_monitor.py b/composer/callbacks/lr_monitor.py index 3958405f92..aa98988c7f 100644 --- a/composer/callbacks/lr_monitor.py +++ b/composer/callbacks/lr_monitor.py @@ -14,8 +14,7 @@ class LRMonitor(Callback): def __init__(self) -> None: super().__init__() - def run_event(self, event: Event, state: State, logger: Logger): - super().run_event(event, state, logger) + def _run_event(self, event: Event, state: State, logger: Logger): if event != Event.BATCH_END: return assert state.optimizers is not None, "optimizers must be defined" diff --git a/composer/callbacks/memory_monitor.py b/composer/callbacks/memory_monitor.py index ef42861b43..e9382f8bdb 100755 --- a/composer/callbacks/memory_monitor.py +++ b/composer/callbacks/memory_monitor.py @@ -26,8 +26,7 @@ def __init__(self): if device_count == 0: log.warn("Memory monitor only works on GPU devices.") - def run_event(self, event: Event, state: State, logger: Logger): - super().run_event(event, state, logger) + def _run_event(self, event: Event, state: State, logger: Logger): if event != Event.AFTER_TRAIN_BATCH: return diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index f58fb267c8..a21cbc24d2 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -72,7 +72,6 @@ def _run_event(self, event: Event, state: State, logger: Logger) -> None: self._batch_end(state, logger) if event == Event.EPOCH_END: self._epoch_end(state, logger) - super()._run_event(event, state, logger) def _epoch_start(self, state: State, logger: Logger): del state, logger # unused diff --git a/composer/callbacks/torch_profiler.py b/composer/callbacks/torch_profiler.py index d8ca2467e3..c797f6581b 100644 --- a/composer/callbacks/torch_profiler.py +++ b/composer/callbacks/torch_profiler.py @@ -132,8 +132,7 @@ def scheduler_fn(self, profiler_step: int) -> ProfilerAction: torch_scheduler_action = ProfilerAction.RECORD_AND_SAVE return torch_scheduler_action - def run_event(self, event: Event, state: State, logger: Logger) -> None: - 
super().run_event(event, state, logger) + def _run_event(self, event: Event, state: State, logger: Logger) -> None: if event == Event.TRAINING_START: self._training_start(state, logger) if event == Event.BATCH_START: diff --git a/composer/core/callback.py b/composer/core/callback.py index aa6c9533d7..a8063f05ff 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -28,13 +28,15 @@ class Callback(Serializable, abc.ABC): they are run on specific events. By convention, Callbacks should not modify :class:`State`. - Subclasses should override :meth:`~Callback.run_event` - to run in response to given :class:`Event` invocations. + Subclasses should override :meth:`_run_event` + (**not** `run_event`) to run in response + to given :class:`Event` invocations. """ def __init__(self) -> None: super().__init__() + @final def run_event(self, event: Event, state: State, logger: Logger) -> None: """This method is called by the engine on each event. @@ -47,6 +49,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: event_cb = getattr(self, event.value) except AttributeError: # Good -- the callback does not override any methods + self._run_event(event, state, logger) return warnings.warn( f"CallbackMethodDeprecationWarning: `self.{event.value}()` will be removed in callbacks." @@ -54,6 +57,9 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: category=DeprecationWarning) return event_cb(state, logger) + def _run_event(self, event: Event, state: State, logger: Logger) -> None: + pass + class RankZeroCallback(Callback, abc.ABC): """Base class for callbacks that only run on the rank zero process. diff --git a/composer/core/logging/base_backend.py b/composer/core/logging/base_backend.py index e0a61bc9fb..febdab245d 100644 --- a/composer/core/logging/base_backend.py +++ b/composer/core/logging/base_backend.py @@ -6,7 +6,7 @@ from abc import ABC from typing import TYPE_CHECKING, List, Optional, Tuple -from composer.core.callback import Callback, RankZeroCallback +from composer.core.callback import Callback from composer.core.event import Event from composer.core.logging.logger import Logger from composer.utils.ddp import is_rank_zero @@ -64,7 +64,7 @@ def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) pass -class RankZeroLoggerBackend(BaseLoggerBackend, RankZeroCallback, ABC): +class RankZeroLoggerBackend(BaseLoggerBackend, Callback, ABC): """Base class for logging backends that run only on the rank zero process. In a multi-process training setup (e.g. 
when using DistributedDataParallel), @@ -151,11 +151,17 @@ def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) return return self._log_metric(epoch, step, log_level, data) - def _run_event(self, event: Event, state: State, logger: Logger) -> None: - super()._run_event(event, state, logger) + @final + def run_event(self, event: Event, state: State, logger: Logger) -> None: + if not is_rank_zero(): + return + self._run_event(event, state, logger) if event == Event.TRAINING_START: if self._deferred_log_metric_calls is None: raise RuntimeError("_deferred_log_metric_calls should not be None") for epoch, step, log_level, data in self._deferred_log_metric_calls: self._log_metric(epoch, step, log_level, data) self._deferred_log_metric_calls = None + + def _run_event(self, event: Event, state: State, logger: Logger) -> None: + pass diff --git a/composer/loggers/file_logger.py b/composer/loggers/file_logger.py index a32c850cf8..809f033b7d 100644 --- a/composer/loggers/file_logger.py +++ b/composer/loggers/file_logger.py @@ -85,7 +85,6 @@ def _run_event(self, event: Event, state: State, logger: Logger) -> None: self._training_start(state, logger) if event == Event.BATCH_END: self._batch_end(state, logger) - super()._run_event(event, state, logger) def _training_start(self, state: State, logger: Logger) -> None: if self.hparams.filename == "stdout": diff --git a/composer/loggers/tqdm_logger.py b/composer/loggers/tqdm_logger.py index 10f8c7f717..f6a6bceb9b 100644 --- a/composer/loggers/tqdm_logger.py +++ b/composer/loggers/tqdm_logger.py @@ -131,8 +131,6 @@ def _run_event(self, event: Event, state: State, logger: Logger) -> None: del self.pbars[self.is_train] self.is_train = None - super()._run_event(event, state, logger) - def state_dict(self) -> StateDict: return { "pbars": {k: v.state_dict() for (k, v) in self.pbars.items()}, diff --git a/composer/loggers/wandb_logger.py b/composer/loggers/wandb_logger.py index d23d2036f5..25cccabe2a 100644 --- a/composer/loggers/wandb_logger.py +++ b/composer/loggers/wandb_logger.py @@ -51,8 +51,6 @@ def _run_event(self, event: Event, state: State, logger: Logger) -> None: if event == Event.EPOCH_END: self._log_artifacts() - super()._run_event(event, state, logger) - def _log_artifacts(self): # Scan the run directory and upload artifacts to wandb # On resnet50, _log_artifacts() caused a 22% throughput degradation diff --git a/docs/source/core/callback.rst b/docs/source/core/callback.rst index 2b3c9f6ecc..9c4b0ab020 100644 --- a/docs/source/core/callback.rst +++ b/docs/source/core/callback.rst @@ -22,7 +22,7 @@ For example: class MyCallback(Callback) - def run_event(self, event: Event, state: State, logger: Logger): + def _run_event(self, event: Event, state: State, logger: Logger): if event == Event.EPOCH_START: print(f'Epoch {state.epoch}/{state.max_epochs}') diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 11ba661ba4..8007993d85 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -16,7 +16,7 @@ def __init__(self) -> None: super().__init__() self.event = None - def run_event(self, event: Event, state: State, logger: Logger) -> None: + def _run_event(self, event: Event, state: State, logger: Logger) -> None: self.event = event diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 359c13b0ad..df95046bf1 100755 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -53,7 +53,7 @@ def __init__(self) -> 
None: for event in Event: self.event_to_num_calls[event] = 0 - def run_event(self, event: Event, state: State, logger: Logger): + def _run_event(self, event: Event, state: State, logger: Logger): super().run_event(event, state, logger) if event == Event.TRAINING_START: # ignoring training start as it is called once per startup diff --git a/tests/trainer/test_ddp.py b/tests/trainer/test_ddp.py index f762fe9214..e99f2f0d4c 100755 --- a/tests/trainer/test_ddp.py +++ b/tests/trainer/test_ddp.py @@ -92,7 +92,7 @@ def __init__(self, tmpdir: str): super().__init__() self.tmpdir = tmpdir - def run_event(self, event: Event, state: State, logger: Logger) -> None: + def _run_event(self, event: Event, state: State, logger: Logger) -> None: super().run_event(event, state, logger) if event in (Event.BEFORE_FORWARD, Event.EVAL_BEFORE_FORWARD): filepath = get_batch_file_path(self.tmpdir, From 9644ad96154945b9035f978ed6fe53affd79697f Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 18 Nov 2021 13:18:14 -0800 Subject: [PATCH 06/38] Fixed tests --- composer/callbacks/speed_monitor.py | 10 ++++------ composer/core/callback.py | 17 ++++------------- composer/core/logging/base_backend.py | 3 --- tests/trainer/test_checkpoint.py | 1 - tests/trainer/test_ddp.py | 1 - 5 files changed, 8 insertions(+), 24 deletions(-) diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index a21cbc24d2..92ac6c2e6a 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -65,16 +65,15 @@ def _load_state(self) -> None: def _run_event(self, event: Event, state: State, logger: Logger) -> None: if event == Event.EPOCH_START: - self._epoch_start(state, logger) + self._epoch_start() if event == Event.BATCH_START: self._load_state() if event == Event.BATCH_END: self._batch_end(state, logger) if event == Event.EPOCH_END: - self._epoch_end(state, logger) + self._epoch_end(logger) - def _epoch_start(self, state: State, logger: Logger): - del state, logger # unused + def _epoch_start(self): self._load_state() self.epoch_start_time = time.time() self.batch_end_times.clear() @@ -96,8 +95,7 @@ def _batch_end(self, state: State, logger: Logger): throughput = sum(self.batch_num_samples) / (self.batch_end_times[-1] - self.batch_end_times[0]) logger.metric_batch({'throughput/step': throughput}) - def _epoch_end(self, state: State, logger: Logger): - del state # unused + def _epoch_end(self, logger: Logger): epoch_time = time.time() - self.epoch_start_time self.wall_clock_train += epoch_time logger.metric_epoch({ diff --git a/composer/core/callback.py b/composer/core/callback.py index a8063f05ff..54ae8477b2 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -45,21 +45,16 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: state (State): The state. logger (Logger): The logger. """ - try: - event_cb = getattr(self, event.value) - except AttributeError: - # Good -- the callback does not override any methods - self._run_event(event, state, logger) - return + self._run_event(event, state, logger) + + def _run_event(self, event: Event, state: State, logger: Logger) -> None: warnings.warn( f"CallbackMethodDeprecationWarning: `self.{event.value}()` will be removed in callbacks." 
"Instead, override `self.run_event()`.", category=DeprecationWarning) + event_cb = getattr(self, event.value) return event_cb(state, logger) - def _run_event(self, event: Event, state: State, logger: Logger) -> None: - pass - class RankZeroCallback(Callback, abc.ABC): """Base class for callbacks that only run on the rank zero process. @@ -71,10 +66,6 @@ class RankZeroCallback(Callback, abc.ABC): @final def run_event(self, event: Event, state: State, logger: Logger) -> None: - super().run_event(event, state, logger) if not is_rank_zero(): return return self._run_event(event, state, logger) - - def _run_event(self, event: Event, state: State, logger: Logger) -> None: - pass diff --git a/composer/core/logging/base_backend.py b/composer/core/logging/base_backend.py index febdab245d..126c7aa02e 100644 --- a/composer/core/logging/base_backend.py +++ b/composer/core/logging/base_backend.py @@ -162,6 +162,3 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: for epoch, step, log_level, data in self._deferred_log_metric_calls: self._log_metric(epoch, step, log_level, data) self._deferred_log_metric_calls = None - - def _run_event(self, event: Event, state: State, logger: Logger) -> None: - pass diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index df95046bf1..b8528b7516 100755 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -54,7 +54,6 @@ def __init__(self) -> None: self.event_to_num_calls[event] = 0 def _run_event(self, event: Event, state: State, logger: Logger): - super().run_event(event, state, logger) if event == Event.TRAINING_START: # ignoring training start as it is called once per startup # and the states otherwise won't match diff --git a/tests/trainer/test_ddp.py b/tests/trainer/test_ddp.py index e99f2f0d4c..b4157b5830 100755 --- a/tests/trainer/test_ddp.py +++ b/tests/trainer/test_ddp.py @@ -93,7 +93,6 @@ def __init__(self, tmpdir: str): self.tmpdir = tmpdir def _run_event(self, event: Event, state: State, logger: Logger) -> None: - super().run_event(event, state, logger) if event in (Event.BEFORE_FORWARD, Event.EVAL_BEFORE_FORWARD): filepath = get_batch_file_path(self.tmpdir, rank=get_global_rank(), From cf5e533b3c4a47c168bba27d704525c940e0e6cd Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 18 Nov 2021 13:23:39 -0800 Subject: [PATCH 07/38] Formatting --- composer/core/callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/core/callback.py b/composer/core/callback.py index 54ae8477b2..4001f69725 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -52,7 +52,7 @@ def _run_event(self, event: Event, state: State, logger: Logger) -> None: f"CallbackMethodDeprecationWarning: `self.{event.value}()` will be removed in callbacks." 
"Instead, override `self.run_event()`.", category=DeprecationWarning) - event_cb = getattr(self, event.value) + event_cb = getattr(self, event.value) return event_cb(state, logger) From b1bf400abb5686f3f3066564ebd2ea61b72468a2 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 18 Nov 2021 13:35:21 -0800 Subject: [PATCH 08/38] Fixed _run_event --- composer/core/callback.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/composer/core/callback.py b/composer/core/callback.py index 4001f69725..240ed85f6d 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -48,11 +48,15 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: self._run_event(event, state, logger) def _run_event(self, event: Event, state: State, logger: Logger) -> None: + # default fallback if the callback does not override _run_event + try: + event_cb = getattr(self, event.value) + except AttributeError: + return warnings.warn( f"CallbackMethodDeprecationWarning: `self.{event.value}()` will be removed in callbacks." - "Instead, override `self.run_event()`.", + "Instead, override `self._run_event()`.", category=DeprecationWarning) - event_cb = getattr(self, event.value) return event_cb(state, logger) From 4ed9f4fee2f3fcfea875fb8c995b7f4513831269 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Fri, 19 Nov 2021 08:54:11 -0800 Subject: [PATCH 09/38] Formatting --- tests/trainer/test_ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_ddp.py b/tests/trainer/test_ddp.py index cd11c018f3..44790d1de6 100755 --- a/tests/trainer/test_ddp.py +++ b/tests/trainer/test_ddp.py @@ -17,11 +17,11 @@ from composer.callbacks import CallbackHparams from composer.core.logging import Logger from composer.core.state import State -from composer.utils.ddp import get_global_rank from composer.datasets import DataloaderHparams, DataloaderSpec, MemoryFormat, SyntheticDataset, SyntheticDatasetHparams from composer.trainer.devices import CPUDeviceHparams, GPUDeviceHparams from composer.trainer.devices.device_hparams import DeviceHparams from composer.trainer.trainer_hparams import TrainerHparams, callback_registry, dataset_registry +from composer.utils.ddp import get_global_rank from tests.fixtures.models import SimpleBatchPairModelHparams From 75944eb8513de806e2b1042e2b5b63ff21df4cf6 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Fri, 19 Nov 2021 08:54:28 -0800 Subject: [PATCH 10/38] Removed ip --- composer/core/instrumentation_point.py | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 composer/core/instrumentation_point.py diff --git a/composer/core/instrumentation_point.py b/composer/core/instrumentation_point.py deleted file mode 100644 index 50a7c916c1..0000000000 --- a/composer/core/instrumentation_point.py +++ /dev/null @@ -1,3 +0,0 @@ -class InstrumentationPoint: - def __init__(self) -> None: - pass \ No newline at end of file From cee479f0b619171fd5033c78f8a652a17fb9193c Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Fri, 19 Nov 2021 11:28:58 -0800 Subject: [PATCH 11/38] Create dataloader on trainer __init__() #65 made the global rank available in the process start, so it is no longer necessarry to wait until training_start() to create the dataloader. Instead, dataloaders are now initialized in __init__. This change will help with dataloader profiling, as now the dataloader will be immediately bound to the state. 
--- composer/core/state.py | 8 +-- composer/trainer/trainer.py | 75 ++++++++++----------- tests/algorithms/test_blurpool_algorithm.py | 6 +- tests/algorithms/test_channels_last.py | 6 +- tests/algorithms/test_layer_freezing.py | 25 +++++-- tests/algorithms/test_stochastic_depth.py | 1 + tests/fixtures/dummy_fixtures.py | 10 ++- tests/test_state.py | 11 +-- tests/trainer/test_ddp_sync_strategy.py | 7 +- 9 files changed, 86 insertions(+), 63 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index b980255d6d..98b22bb890 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -108,6 +108,10 @@ class State(Serializable): # stopping conditions max_epochs: int + # dataloaders + train_dataloader: types.DataLoader + eval_dataloader: types.DataLoader + # precision # storing precision internally so strings can be passed into the constructor and setter # but the getter will always return a Precision enum @@ -134,10 +138,6 @@ class State(Serializable): # scaler scaler: Optional[types.Scaler] = None - # dataloaders - train_dataloader: Optional[types.DataLoader] = None - eval_dataloader: Optional[types.DataLoader] = None - # algorithms algorithms: Sequence[Algorithm] = tuple() callbacks: Sequence[Callback] = tuple() diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index e835891cc8..6f62dd3dd4 100755 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -203,15 +203,39 @@ def __init__( timeout=ddp_timeout, ) - self.state = State(max_epochs=max_epochs, - train_batch_size=train_batch_size, - eval_batch_size=eval_batch_size, - algorithms=algorithms, - callbacks=callbacks, - model=model, - grad_accum=grad_accum, - precision=precision, - precision_context=self.device.precision_context) + dl_hparams = DataloaderHparams(num_workers=num_workers, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, + pin_memory=pin_memory, + timeout=timeout) + + train_gpu_batch_size = train_batch_size // self.ddp.world_size + train_dataloader = self.device.dataloader_to_device( + self.ddp.create_dataloader(train_gpu_batch_size, dl_hparams, train_dataloader_spec), + train_dataloader_spec.prefetch_fn, + ) + self.train_dl_spec = train_dataloader_spec + + eval_gpu_batch_size = eval_batch_size // self.ddp.world_size + eval_dataloader = self.device.dataloader_to_device( + self.ddp.create_dataloader(eval_gpu_batch_size, dl_hparams, eval_dataloader_spec), + eval_dataloader_spec.prefetch_fn, + ) + self.eval_dl_spec = eval_dataloader_spec + + self.state = State( + max_epochs=max_epochs, + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + algorithms=algorithms, + callbacks=callbacks, + model=model, + grad_accum=grad_accum, + precision=precision, + precision_context=self.device.precision_context, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + ) if not log_destinations: log_destinations = [TQDMLoggerBackend()] @@ -220,15 +244,6 @@ def __init__( self.engine = Engine(self.state, self.state.algorithms, self.logger, self.state.callbacks) - self.train_dl_spec = train_dataloader_spec - self.eval_dl_spec = eval_dataloader_spec - - self.dl_hparams = DataloaderHparams(num_workers=num_workers, - prefetch_factor=prefetch_factor, - persistent_workers=persistent_workers, - pin_memory=pin_memory, - timeout=timeout) - self.validate_every_n_batches = validate_every_n_batches self.validate_every_n_epochs = validate_every_n_epochs self.compute_training_metrics = compute_training_metrics @@ -243,8 +258,8 @@ def 
__init__( # run INIT event before optimizers and schedulers are created self.engine.run_event(Event.INIT) - assert isinstance(self.train_dl_spec.dataset, collections.abc.Sized) - steps_per_epoch = len(self.train_dl_spec.dataset) // train_batch_size + assert isinstance(self.state.train_dataloader.dataset, collections.abc.Sized) + steps_per_epoch = len(self.state.train_dataloader.dataset) // train_batch_size # Need to use hparams here because optimizer and schedulers need to be created after Event.INIT if not optimizer_hparams: optimizer_hparams = DecoupledSGDWHparams(lr=0.1, momentum=0.9, weight_decay=1.0e-4) @@ -365,21 +380,6 @@ def _create_dataloaders(self) -> None: state = self.state # compute per gpu batch size - train_gpu_batch_size = state.train_batch_size // state.world_size - eval_gpu_batch_size = state.eval_batch_size // state.world_size - - train_dataloader = self.ddp.create_dataloader(train_gpu_batch_size, self.dl_hparams, self.train_dl_spec) - eval_dataloader = self.ddp.create_dataloader(eval_gpu_batch_size, self.dl_hparams, self.eval_dl_spec) - - # move to device - state.train_dataloader = self.device.dataloader_to_device( - train_dataloader, - self.train_dl_spec.prefetch_fn, - ) - state.eval_dataloader = self.device.dataloader_to_device( - eval_dataloader, - self.eval_dl_spec.prefetch_fn, - ) def _get_metrics_as_collection(self, *, is_train: bool) -> MetricCollection: """Get metrics relevant to the model. Metrics are all implemented as subclasses @@ -477,11 +477,6 @@ def _train_loop(self) -> None: state.model = self.device.module_to_device(state.model) state.optimizers = map_collection(state.optimizers, self.device.optimizer_to_device) - # create dataloaders here after distributed training has started - self._create_dataloaders() - if state.train_dataloader is None or state.eval_dataloader is None: - raise ValueError('Dataloaders were not created properly, and are None.') - # wrap model with DDP state.model = self.ddp.prepare_module(state.model) original_model = state.model.module diff --git a/tests/algorithms/test_blurpool_algorithm.py b/tests/algorithms/test_blurpool_algorithm.py index d3fea2d235..4baaed5613 100644 --- a/tests/algorithms/test_blurpool_algorithm.py +++ b/tests/algorithms/test_blurpool_algorithm.py @@ -12,12 +12,12 @@ from composer.algorithms import BlurPool, BlurPoolHparams from composer.algorithms.blurpool.blurpool_layers import BlurConv2d, BlurMaxPool2d from composer.core import Event, State -from composer.core.types import Model, Precision +from composer.core.types import DataLoader, Model, Precision from tests.fixtures.models import SimpleConvModel @pytest.fixture -def state(simple_conv_model: Model): +def state(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader): state = State( epoch=50, step=50, @@ -27,6 +27,8 @@ def state(simple_conv_model: Model): max_epochs=100, model=simple_conv_model, precision=Precision.FP32, + train_dataloader=dummy_train_dataloader, + eval_dataloader=dummy_val_dataloader, ) return state diff --git a/tests/algorithms/test_channels_last.py b/tests/algorithms/test_channels_last.py index 4bdf82386d..9cae2873d1 100644 --- a/tests/algorithms/test_channels_last.py +++ b/tests/algorithms/test_channels_last.py @@ -7,7 +7,7 @@ from composer.algorithms import ChannelsLastHparams from composer.core.event import Event from composer.core.state import State -from composer.core.types import Model, Precision, Tensor +from composer.core.types import DataLoader, Model, Precision, Tensor def 
_has_singleton_dimension(tensor: Tensor) -> bool: @@ -31,7 +31,7 @@ def _infer_memory_format(tensor: Tensor) -> str: @pytest.fixture() -def state(simple_conv_model: Model): +def state(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader): return State( model=simple_conv_model, train_batch_size=100, @@ -39,6 +39,8 @@ def state(simple_conv_model: Model): precision=Precision.FP32, grad_accum=1, max_epochs=10, + train_dataloader=dummy_train_dataloader, + eval_dataloader=dummy_val_dataloader, ) diff --git a/tests/algorithms/test_layer_freezing.py b/tests/algorithms/test_layer_freezing.py index a77a358d65..6f70a167b6 100644 --- a/tests/algorithms/test_layer_freezing.py +++ b/tests/algorithms/test_layer_freezing.py @@ -6,14 +6,15 @@ from composer.algorithms import LayerFreezing, LayerFreezingHparams from composer.core.state import State -from composer.core.types import Event, Model, Precision +from composer.core.types import DataLoader, Event, Model, Precision from composer.loggers import Logger from composer.trainer.trainer_hparams import TrainerHparams from composer.utils import ensure_tuple from tests.utils.trainer_fit import train_model -def _generate_state(epoch: int, max_epochs: int, model: Model): +def _generate_state(epoch: int, max_epochs: int, model: Model, train_dataloader: DataLoader, + val_dataloader: DataLoader): state = State( epoch=epoch, step=epoch, @@ -24,6 +25,8 @@ def _generate_state(epoch: int, max_epochs: int, model: Model): model=model, optimizers=(torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99),), precision=Precision.FP32, + train_dataloader=train_dataloader, + eval_dataloader=val_dataloader, ) return state @@ -39,8 +42,13 @@ def _check_param_groups(expected_groups, actual_groups): assert (actual_groups[i]['params'][j] == expected_params).all() -def test_freeze_layers_no_freeze(simple_conv_model: Model, noop_dummy_logger: Logger): - state = _generate_state(epoch=10, max_epochs=100, model=simple_conv_model) +def test_freeze_layers_no_freeze(simple_conv_model: Model, noop_dummy_logger: Logger, + dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader): + state = _generate_state(epoch=10, + max_epochs=100, + model=simple_conv_model, + train_dataloader=dummy_train_dataloader, + val_dataloader=dummy_val_dataloader) first_optimizer = ensure_tuple(state.optimizers)[0] assert first_optimizer is not None @@ -53,8 +61,13 @@ def test_freeze_layers_no_freeze(simple_conv_model: Model, noop_dummy_logger: Lo _check_param_groups(expected_param_groups, updated_param_groups) -def test_freeze_layers_with_freeze(simple_conv_model: Model, noop_dummy_logger: Logger): - state = _generate_state(epoch=80, max_epochs=100, model=simple_conv_model) +def test_freeze_layers_with_freeze(simple_conv_model: Model, noop_dummy_logger: Logger, + dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader): + state = _generate_state(epoch=80, + max_epochs=100, + model=simple_conv_model, + train_dataloader=dummy_train_dataloader, + val_dataloader=dummy_val_dataloader) first_optimizer = ensure_tuple(state.optimizers)[0] assert first_optimizer is not None diff --git a/tests/algorithms/test_stochastic_depth.py b/tests/algorithms/test_stochastic_depth.py index 697f2f4a3e..0fa95507a4 100644 --- a/tests/algorithms/test_stochastic_depth.py +++ b/tests/algorithms/test_stochastic_depth.py @@ -35,6 +35,7 @@ def dummy_state(dummy_dataloader_hparams: DataloaderHparams): grad_accum=1, max_epochs=100, model=model, + eval_dataloader=train_dataloader, 
precision=Precision.FP32) diff --git a/tests/fixtures/dummy_fixtures.py b/tests/fixtures/dummy_fixtures.py index aafc319e7d..1dfec89dfe 100755 --- a/tests/fixtures/dummy_fixtures.py +++ b/tests/fixtures/dummy_fixtures.py @@ -71,8 +71,8 @@ def dummy_val_dataloader_spec(dummy_train_dataset_hparams: SyntheticDatasetHpara @pytest.fixture() -def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_batch_size: int, - dummy_val_batch_size: int) -> State: +def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_batch_size: int, dummy_val_batch_size: int, + dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader) -> State: state = State( model=dummy_model, epoch=5, @@ -81,6 +81,8 @@ def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_batc grad_accum=1, train_batch_size=dummy_train_batch_size, eval_batch_size=dummy_val_batch_size, + train_dataloader=dummy_train_dataloader, + eval_dataloader=dummy_val_dataloader, max_epochs=10, ) return state @@ -191,7 +193,7 @@ def simple_conv_model_input(): @pytest.fixture() -def state_with_model(simple_conv_model: Model): +def state_with_model(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader): state = State( epoch=50, step=50, @@ -201,6 +203,8 @@ def state_with_model(simple_conv_model: Model): max_epochs=100, model=simple_conv_model, precision=Precision.FP32, + train_dataloader=dummy_train_dataloader, + eval_dataloader=dummy_val_dataloader, ) return state diff --git a/tests/test_state.py b/tests/test_state.py index 356acbc559..745b94663c 100755 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -22,7 +22,7 @@ def random_tensor(size=(4, 10)): return torch.rand(*size) -def get_dummy_state(model: BaseMosaicModel): +def get_dummy_state(model: BaseMosaicModel, train_dataloader: types.DataLoader, val_dataloader: types.DataLoader): optimizers = torch.optim.Adadelta(model.parameters()) return State(model=model, @@ -36,6 +36,8 @@ def get_dummy_state(model: BaseMosaicModel): loss=random_tensor(), batch=(random_tensor(), random_tensor()), outputs=random_tensor(), + train_dataloader=train_dataloader, + eval_dataloader=val_dataloader, optimizers=optimizers, schedulers=torch.optim.lr_scheduler.StepLR(optimizers, step_size=3), algorithms=[DummyHparams().initialize_object()]) @@ -106,12 +108,13 @@ def get_batch(model: SimpleBatchPairModel, dataloader_hparams: DataloaderHparams def test_state_serialize(tmpdir: pathlib.Path, dummy_model: BaseMosaicModel, - dummy_dataloader_hparams: DataloaderHparams): + dummy_dataloader_hparams: DataloaderHparams, dummy_train_dataloader: types.DataLoader, + dummy_val_dataloader: types.DataLoader): assert isinstance(dummy_model, SimpleBatchPairModel) - state1 = get_dummy_state(dummy_model) - state2 = get_dummy_state(dummy_model) + state1 = get_dummy_state(dummy_model, dummy_train_dataloader, dummy_val_dataloader) + state2 = get_dummy_state(dummy_model, dummy_train_dataloader, dummy_val_dataloader) # train one step to set the optimizer states batch = get_batch(dummy_model, dummy_dataloader_hparams) diff --git a/tests/trainer/test_ddp_sync_strategy.py b/tests/trainer/test_ddp_sync_strategy.py index 09ef3fab72..9436240794 100755 --- a/tests/trainer/test_ddp_sync_strategy.py +++ b/tests/trainer/test_ddp_sync_strategy.py @@ -7,7 +7,7 @@ import torch.nn as nn from composer.core.state import State -from composer.core.types import Tensor +from composer.core.types import DataLoader, Tensor from composer.trainer.ddp import DDP @@ -44,7 
+44,8 @@ def loss(self, output: Tensor, target: Tensor): pytest.param('forced_sync', ([-1, None, None], [-1, -1, None], [-1.5, -1.5, None]), id='forced_sync'), ]) @pytest.mark.world_size(2) -def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional[float]]): +def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional[float]], + dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader): original_model = MinimalConditionalModel() ddp = DDP(backend="gloo", find_unused_parameters=True, sync_strategy=ddp_sync_strategy, timeout=5.) optimizer = torch.optim.SGD(original_model.parameters(), 0.1) @@ -55,6 +56,8 @@ def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional eval_batch_size=1, grad_accum=2, max_epochs=1, + train_dataloader=dummy_train_dataloader, + eval_dataloader=dummy_val_dataloader, precision='fp32') batches = [[(1, Tensor([1])), (1, Tensor([2]))], [(2, Tensor([1])), (2, Tensor([2]))]] From 8b3563e15392ba1d2693fab50d811743db32c7ba Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Mon, 22 Nov 2021 15:37:35 -0800 Subject: [PATCH 12/38] Run Directory Uploader Added uploading of the run directory to various cloud providers via a callback. Depends on the LibCloud plugin. Closes #98. Depends on #85 and (for tests) #92. --- composer/callbacks/__init__.py | 2 + composer/callbacks/callback_hparams.py | 62 ++++- composer/callbacks/run_directory_uploader.py | 250 ++++++++++++++++++ composer/trainer/trainer_hparams.py | 4 +- setup.py | 2 +- .../callbacks/test_run_directory_uploader.py | 41 +++ 6 files changed, 358 insertions(+), 3 deletions(-) create mode 100644 composer/callbacks/run_directory_uploader.py create mode 100644 tests/callbacks/test_run_directory_uploader.py diff --git a/composer/callbacks/__init__.py b/composer/callbacks/__init__.py index ced53e1b73..8d158f1ad9 100644 --- a/composer/callbacks/__init__.py +++ b/composer/callbacks/__init__.py @@ -6,8 +6,10 @@ from composer.callbacks.callback_hparams import GradMonitorHparams as GradMonitorHparams from composer.callbacks.callback_hparams import LRMonitorHparams as LRMonitorHparams from composer.callbacks.callback_hparams import MemoryMonitorHparams as MemoryMonitorHparams +from composer.callbacks.callback_hparams import RunDirectoryUploaderHparams as RunDirectoryUploaderHparams from composer.callbacks.callback_hparams import SpeedMonitorHparams as SpeedMonitorHparams from composer.callbacks.callback_hparams import TorchProfilerHparams as TorchProfilerHparams from composer.callbacks.lr_monitor import LRMonitor as LRMonitor +from composer.callbacks.run_directory_uploader import RunDirectoryUploader as RunDirectoryUploader from composer.callbacks.speed_monitor import SpeedMonitor as SpeedMonitor from composer.callbacks.torch_profiler import TorchProfiler as TorchProfiler diff --git a/composer/callbacks/callback_hparams.py b/composer/callbacks/callback_hparams.py index 9c529e2a07..668e96687f 100644 --- a/composer/callbacks/callback_hparams.py +++ b/composer/callbacks/callback_hparams.py @@ -4,8 +4,9 @@ from __future__ import annotations import abc +import textwrap from dataclasses import asdict, dataclass -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Any, Dict, List, Optional import yahp as hp @@ -16,6 +17,7 @@ from composer.callbacks.grad_monitor import GradMonitor from composer.callbacks.lr_monitor import LRMonitor from composer.callbacks.memory_monitor import MemoryMonitor + from composer.callbacks.run_directory_uploader import 
RunDirectoryUploader
     from composer.callbacks.speed_monitor import SpeedMonitor
     from composer.callbacks.torch_profiler import TorchProfiler
 
@@ -153,3 +155,61 @@ class TorchProfilerHparams(CallbackHparams):
     def initialize_object(self) -> TorchProfiler:
         from composer.callbacks.torch_profiler import TorchProfiler
         return TorchProfiler(**asdict(self))
+
+
+@dataclass
+class RunDirectoryUploaderHparams(CallbackHparams):
+    """:class:`~composer.callbacks.run_directory_uploader.RunDirectoryUploader` hyperparameters.
+
+    See :class:`~composer.callbacks.run_directory_uploader.RunDirectoryUploader` for documentation.
+    """
+
+    # Args:
+    #     provider_init_kwargs (Dict[str, Any], optional): Parameters to pass into the constructor for the
+    #         :class:`~libcloud.storage.providers.Provider` constructor. These arguments would usually include the cloud region
+    #         and credentials. Defaults to None, which is equivalent to an empty dictionary.
+    provider: str = hp.required("Cloud provider to use.")
+    container: Optional[str] = hp.optional("The name of the container (i.e. bucket) to use.", default=None)
+    key: Optional[str] = hp.optional(textwrap.dedent(
+        """API key or username to use to connect to the provider. For security, do NOT hardcode the key in the YAML.
+        Instead, please specify via CLI arguments, or even better, environment variables."""),
+                                     default=None)
+    secret: Optional[str] = hp.optional(textwrap.dedent(
+        """API secret to use to connect to the provider. For security, do NOT hardcode the secret in the YAML.
+        Instead, please specify via CLI arguments, or even better, environment variables."""),
+                                        default=None)
+    region: Optional[str] = hp.optional("Cloud region to use.", default=None)
+    host: Optional[str] = hp.optional("Override hostname for connections.", default=None)
+    port: Optional[int] = hp.optional("Override port for connections.", default=None)
+    num_concurrent_uploads: int = hp.optional("Maximum number of concurrent uploads. Defaults to 4.", default=4)
+    use_procs: bool = hp.optional(
+        "Whether to perform file uploads in background processes (as opposed to threads). Defaults to True.",
+        default=True)
+    upload_staging_folder: Optional[str] = hp.optional(
+        "Staging folder for uploads. If not specified, will use a temporary directory.", default=None)
+    extra_init_kwargs: Dict[str, Any] = hp.optional(
+        "Extra keyword arguments to pass into the constructor for the specified provider.", default_factory=dict)
+    upload_every_n_batches: int = hp.optional(
+        textwrap.dedent("""Interval at which to scan the run directory for changes and to
+            queue uploads of files. Uploads are also queued at the end of the epoch.
+            Defaults to every 100 batches."""),
+        default=100)
+
+    def initialize_object(self) -> RunDirectoryUploader:
+        from composer.callbacks.run_directory_uploader import RunDirectoryUploader
+        init_kwargs = {
+            "key": self.key,
+            "secret": self.secret,
+            "host": self.host,
+            "port": self.port,
+            "region": self.region,
+        }
+        init_kwargs.update(self.extra_init_kwargs)
+        return RunDirectoryUploader(
+            provider=self.provider,
+            container=self.container,
+            num_concurrent_uploads=self.num_concurrent_uploads,
+            upload_staging_folder=self.upload_staging_folder,
+            use_procs=self.use_procs,
+            provider_init_kwargs=init_kwargs,
+            upload_every_n_batches=self.upload_every_n_batches,
+        )
diff --git a/composer/callbacks/run_directory_uploader.py b/composer/callbacks/run_directory_uploader.py
new file mode 100644
index 0000000000..6f02318f05
--- /dev/null
+++ b/composer/callbacks/run_directory_uploader.py
@@ -0,0 +1,250 @@
+# Copyright 2021 MosaicML. All Rights Reserved.
+
+from __future__ import annotations
+
+import multiprocessing
+import os
+import queue
+import shutil
+import tempfile
+import threading
+import time
+import warnings
+from typing import Any, Callable, Dict, Optional, Type, Union
+
+from composer.core.callback import RankZeroCallback
+from composer.core.event import Event
+from composer.core.logging import Logger
+from composer.core.logging.logger import LogLevel
+from composer.core.state import State
+from composer.utils.run_directory import get_run_directory
+
+
+class RunDirectoryUploader(RankZeroCallback):
+    """Callback to upload the run directory to a blob store.
+
+    This callback checks the run directory for new or modified files
+    at the end of every epoch, and after every `upload_every_n_batches` batches.
+    New or modified files are detected via the file modification timestamp:
+    only files that have been modified since the last upload
+    will be uploaded.
+
+    This uploader is compatible with multi-GPU training. It blocks the main thread
+    for each local rank while creating a copy of the modified files in the run directory
+    before yielding back to the training loop. Uploads are performed from the copied files.
+    It assumes that only the main thread on each rank writes to the run directory.
+
+    While all uploads happen in the background, here are some additional tips for minimizing
+    the performance impact:
+
+    * Ensure that `upload_every_n_batches` is sufficiently infrequent so as to limit how often
+      the run directory is scanned and modified files are copied, since both block the training loop.
+      However, do not make it too infrequent, in case the training process unexpectedly dies,
+      since any data written after the last upload would be lost.
+
+    * Set `use_procs=True` (the default) to use background processes,
+      instead of threads, to perform the file uploads. Processes are recommended to
+      ensure that the GIL does not block the training loop when performing CPU
+      operations on uploaded files (e.g. comparing and computing checksums).
+      Network I/O always occurs in the background.
+
+    * Provide a RAM disk path for the `upload_staging_folder` parameter. Staging copies in RAM
+      will be faster than writing them to disk. However, you must have sufficient excess RAM on your system,
+      or you may experience out-of-memory errors.
+
+    .. note::
+
+        To use this callback, install composer with `pip install mosaicml[logging]`.
+
+    Args:
+        provider (str): Cloud provider to use.
+
+            Specify the last part of the Apache Libcloud module here.
+            `This document `
+            lists all supported providers.
For example, the module name for Amazon S3 is `libcloud.storage.drivers.s3`, so + to use S3, specify 's3' here. + + container (str): The name of the container (i.e. bucket) to use. + num_concurrent_uploads (int, optional): Maximum number of concurrent uploads. Defaults to 4. + upload_staging_folder (Optional[str], optional): A folder to use for staging uploads. + If not specified, defaults to using a :class:`~tempfile.TemporaryDirectory`. + use_procs (bool, optional): Whether to perform file uploads in background processes (as opposed to threads). + Defaults to True. + upload_every_n_batches (int, optional): Interval at which to scan the run directory for changes and to + queue uploads of files. Uploads are always queued at the end of the epoch. Defaults to every 100 batches. + provider_init_kwargs (Dict[str, Any], optional): Parameters to pass into the constructor for the + :class:`~libcloud.storage.providers.Provider` constructor. These arguments would usually include the cloud region + and credentials. Defaults to None, which is equivalent to an empty dictionary. + """ + + def __init__( + self, + provider: str, + container: Optional[str] = None, + num_concurrent_uploads: int = 4, + upload_staging_folder: Optional[str] = None, + use_procs: bool = True, + upload_every_n_batches: int = 100, + provider_init_kwargs: Dict[str, Any] = None, + ) -> None: + run_directory = get_run_directory() + if run_directory is None: + warnings.warn("NoRunDirectory: The run directory is not set, so the RunDirectoryUploader will be a no-op") + return + + if provider_init_kwargs is None: + provider_init_kwargs = {} + self._provider_init_kwargs = provider_init_kwargs + self._upload_every_n_batches = upload_every_n_batches + self._object_name_prefix = "" # TODO ravi. Decide how this will be set. Hparams? Run directory name? + + self._last_upload_timestamp = 0.0 # unix timestamp of last uploaded time + if upload_staging_folder is None: + self._tempdir = tempfile.TemporaryDirectory() + self._upload_staging_folder = self._tempdir.name + else: + self._tempdir = None + self._upload_staging_folder = upload_staging_folder + + if num_concurrent_uploads < 1: + raise ValueError("num_concurrent_uploads must be >= 1. 
Blocking uploads are not supported.") + self._num_concurrent_uploads = num_concurrent_uploads + self._provider = provider + self._container = container + + if use_procs: + self._file_upload_queue: Union[queue.Queue[str], + multiprocessing.JoinableQueue[str]] = multiprocessing.JoinableQueue() + self._finished_cls: Union[Callable[[], multiprocessing._EventType], + Type[threading.Event]] = multiprocessing.Event + self._proc_class = multiprocessing.Process + else: + self._file_upload_queue = queue.Queue() + self._finished_cls = threading.Event + self._proc_class = threading.Thread + self._finished: Union[None, multiprocessing._EventType, threading.Event] = None + self._workers = [] + + def _init(self) -> None: + self._finished = self._finished_cls() + self._last_upload_timestamp = 0.0 + self._workers = [ + self._proc_class(target=_upload_worker, + kwargs={ + "file_queue": self._file_upload_queue, + "is_finished": self._finished, + "upload_staging_dir": self._upload_staging_folder, + "provider_name": self._provider, + "container_name": self._container, + "object_name_prefix": self._object_name_prefix, + "init_kwargs": self._provider_init_kwargs, + }) for _ in range(self._num_concurrent_uploads) + ] + for worker in self._workers: + worker.start() + + def _run_event(self, event: Event, state: State, logger: Logger) -> None: + if get_run_directory() is None: + return + if event == Event.INIT: + self._init() + if event == Event.BATCH_END: + if (state.batch_idx + 1) % self._upload_every_n_batches == 0: + self._trigger_upload(state, logger, LogLevel.BATCH) + if event == Event.EPOCH_END: + self._trigger_upload(state, logger, LogLevel.EPOCH) + if event == Event.TRAINING_END: + self._trigger_upload(state, logger, LogLevel.FIT) + # TODO -- we are missing logfiles from other callbacks / loggers that write on training end but after + # the run directory uploader is invoked. This callback either needs to fire last, + # or we need another event such as cleanup + self._close() + + def _close(self): + assert self._finished is not None, "finished should not be None" + self._finished.set() + for worker in self._workers: + worker.join() + + def _trigger_upload(self, state: State, logger: Logger, log_level: LogLevel) -> None: + # Ensure that every rank is at this point + # Assuming only the main thread on each rank writes to the run directory, then the barrier here will ensure + # that the run directory is not being modified after we pass this barrier + # TODO(ravi) -- add in a ddp barrier here. 
+ # state.ddp.barrier() + new_last_uploaded_timestamp = time.time() + # Now, for each file that was modified since self._last_upload_timestamp, copy it to the temporary directory + # IMPROTANT: From now, until self._last_upload_timestamp is updated, no files should be written to the run directory + run_directory = get_run_directory() + assert run_directory is not None, "invariant error" + files_to_be_uploaded = [] + for root, dirs, files in os.walk(run_directory): + del dirs # unused + for file in files: + filepath = os.path.join(root, file) + relpath = os.path.relpath(filepath, run_directory) # chop off the run directory + modified_time = os.path.getmtime(filepath) + if modified_time > self._last_upload_timestamp: + copied_path = os.path.join(self._upload_staging_folder, str(new_last_uploaded_timestamp), relpath) + files_to_be_uploaded.append(relpath) + copied_path_dirname = os.path.dirname(copied_path) + os.makedirs(copied_path_dirname, exist_ok=True) + # shutil.copyfile(filepath, copied_path) + shutil.copy2(filepath, copied_path) + self._file_upload_queue.put_nowait(copied_path) + self._last_upload_timestamp = new_last_uploaded_timestamp + # now log which files are being uploaded. OK to do, since we're done reading the directory, + # and any logfiles will now have their last modified timestamp + # incremented past self._last_upload_timestamp + logger.metric(log_level, {"run_directory/uploaded_files": files_to_be_uploaded}) + + +def _upload_worker( + file_queue: Union[queue.Queue[str], multiprocessing.JoinableQueue[str]], + is_finished: Union[multiprocessing._EventType, threading.Event], + upload_staging_dir: str, + provider_name: str, + container_name: str, + object_name_prefix: str, + init_kwargs: Dict[str, Any], +): + """A long-running function to handle uploading files. + + Args: + file_queue (queue.Queue or multiprocessing.JoinableQueue): The worker will poll + this queue for files to upload. + is_finished (threading.Event or multiprocessing.Event): An event that will be + set when training is finished and no new files will be added to the queue. + The worker will continue to upload existing files that are in the queue. + When the queue is empty, the worker will exit. + upload_staging_dir (str): The upload staging directory. + provider_name (str): The cloud provider name. + container_name (str): The container name (e.g. s3 bucket) for the provider + where files will be uploaded. + object_name_prefix (str): Prefix to prepend to the object names + before they are uploaded to the blob store. + init_kwargs (Dict[str, Any]): Arguments to pass in to the + :class:`~libcloud.storage.providers.Provider` constructor. + """ + from libcloud.storage.providers import get_driver + provider_cls = get_driver(provider_name) + provider = provider_cls(**init_kwargs) + container = provider.get_container(container_name) + while True: + try: + file_path_to_upload = file_queue.get_nowait() + except queue.Empty: + if is_finished.is_set(): + break + else: + time.sleep(0.5) + continue + obj_name = ",".join(os.path.relpath(file_path_to_upload, upload_staging_dir).split( + os.path.sep)[1:]) # the first folder is the upload timestamp. Chop that off. 
+ provider.upload_object( + file_path=file_path_to_upload, + container=container, + object_name=object_name_prefix + obj_name, + ) + os.remove(file_path_to_upload) + file_queue.task_done() diff --git a/composer/trainer/trainer_hparams.py b/composer/trainer/trainer_hparams.py index d818de6971..c92c1cc77a 100755 --- a/composer/trainer/trainer_hparams.py +++ b/composer/trainer/trainer_hparams.py @@ -15,7 +15,8 @@ import composer.datasets as datasets from composer.algorithms import AlgorithmHparams, get_algorithm_registry from composer.callbacks import (BenchmarkerHparams, CallbackHparams, GradMonitorHparams, LRMonitorHparams, - MemoryMonitorHparams, SpeedMonitorHparams, TorchProfilerHparams) + MemoryMonitorHparams, RunDirectoryUploaderHparams, SpeedMonitorHparams, + TorchProfilerHparams) from composer.core.types import Precision from composer.datasets import DataloaderHparams from composer.loggers import (BaseLoggerBackendHparams, FileLoggerBackendHparams, TQDMLoggerBackendHparams, @@ -80,6 +81,7 @@ "lr_monitor": LRMonitorHparams, "grad_monitor": GradMonitorHparams, "memory_monitor": MemoryMonitorHparams, + "run_directory_uploader": RunDirectoryUploaderHparams, } logger_registry = { diff --git a/setup.py b/setup.py index bd31c35993..d591e29ba5 100755 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ def package_files(directory): 'testbook>=0.4.2', 'myst-parser>=0.15.2', ] -extra_deps['wandb'] = ['wandb>=0.12.2'] +extra_deps['logging'] = ['wandb>=0.12.2', 'apache-libcloud>=3.4.1'] extra_deps['nlp'] = [ 'transformers>=4.11.3', diff --git a/tests/callbacks/test_run_directory_uploader.py b/tests/callbacks/test_run_directory_uploader.py new file mode 100644 index 0000000000..26dccfc8ad --- /dev/null +++ b/tests/callbacks/test_run_directory_uploader.py @@ -0,0 +1,41 @@ +# Copyright 2021 MosaicML. All Rights Reserved. + +import os +import pathlib + +import pytest + +from composer.callbacks import RunDirectoryUploaderHparams +from composer.core.event import Event +from composer.core.logging import Logger +from composer.core.state import State +from composer.utils.run_directory import get_run_directory + + +@pytest.mark.parametrize("use_procs", [True, False]) +def test_run_directory_uploader(tmpdir: pathlib.Path, use_procs: bool, dummy_state: State, dummy_logger: Logger): + dummy_state.epoch = 0 + dummy_state.step = 0 + remote_dir = str(tmpdir / "run_directory_copy") + os.makedirs(remote_dir, exist_ok=True) + hparams = RunDirectoryUploaderHparams( + provider='local', + upload_every_n_batches=1, + key=remote_dir, # for the local option, the key is the path + container=".", + num_concurrent_uploads=1, + use_procs=use_procs, + ) + + uploader = hparams.initialize_object() + uploader.run_event(Event.INIT, dummy_state, dummy_logger) + run_directory = get_run_directory() + assert run_directory is not None + with open(os.path.join(run_directory, "dummy_file"), "w+") as f: + f.write("Hello, world!") + uploader.run_event(Event.BATCH_END, dummy_state, dummy_logger) + uploader.run_event(Event.TRAINING_END, dummy_state, dummy_logger) + + # now assert that we have a dummy file in the run directory copy folder + with open(os.path.join(remote_dir, "dummy_file"), "r") as f: + assert f.read() == "Hello, world!" 
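For context, a minimal sketch of how the new uploader might be wired up from user code (the bucket
name, region, and environment-variable names below are hypothetical placeholders, not part of this
patch):

    import os

    from composer.callbacks import RunDirectoryUploaderHparams

    # Credentials come from the environment rather than being hardcoded, as the
    # hparams docstrings recommend. The provider name 's3' is the last segment of
    # the libcloud driver module, per the RunDirectoryUploader docstring.
    hparams = RunDirectoryUploaderHparams(
        provider='s3',
        container='my-training-artifacts',
        key=os.environ.get('AWS_ACCESS_KEY_ID'),
        secret=os.environ.get('AWS_SECRET_ACCESS_KEY'),
        region='us-west-2',
        upload_every_n_batches=500,
        use_procs=True,
    )
    uploader = hparams.initialize_object()
    # `uploader` can then be passed to the trainer via its `callbacks` list, or selected
    # from YAML via the 'run_directory_uploader' callback registry key added in this patch.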
From 5214f390c448852338d5365400a285c82017fbb0 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Mon, 22 Nov 2021 16:27:08 -0800 Subject: [PATCH 13/38] Supporting both styles for callbacks Removed deferred logging since rank is now known at the init event --- composer/callbacks/benchmarker.py | 19 +- composer/callbacks/grad_monitor.py | 7 +- composer/callbacks/lr_monitor.py | 6 +- composer/callbacks/memory_monitor.py | 14 +- composer/callbacks/speed_monitor.py | 24 +-- composer/callbacks/torch_profiler.py | 18 +- composer/core/callback.py | 253 ++++++++++++++++++++++++-- composer/core/logging/base_backend.py | 47 +---- composer/loggers/file_logger.py | 28 ++- composer/loggers/tqdm_logger.py | 7 +- composer/loggers/wandb_logger.py | 28 +-- docs/source/core/callback.rst | 36 +++- pyproject.toml | 5 +- tests/callbacks/test_callbacks.py | 10 + tests/test_logger.py | 2 +- 15 files changed, 359 insertions(+), 145 deletions(-) diff --git a/composer/callbacks/benchmarker.py b/composer/callbacks/benchmarker.py index bfd7c9a24c..9e72b2339e 100644 --- a/composer/callbacks/benchmarker.py +++ b/composer/callbacks/benchmarker.py @@ -10,7 +10,6 @@ from composer import Logger, State from composer.callbacks.callback_hparams import BenchmarkerHparams from composer.core.callback import Callback -from composer.core.event import Event from composer.core.types import BreakEpochException log = logging.getLogger(__name__) @@ -107,16 +106,6 @@ def __init__(self, self.original_max_epochs = -1 self.wct_dict = {} - def _run_event(self, event: Event, state: State, logger: Logger) -> None: - if event == Event.TRAINING_START: - self._training_start(state, logger) - if event == Event.BATCH_START: - self._batch_start(state, logger) - if event == Event.BATCH_END: - self._batch_end(state, logger) - if event == Event.EPOCH_END: - self._epoch_end(state, logger) - def _compute_elapsed_wct(self, epoch_wct_dict, steps_per_epoch: int, n_epochs: int): wct = 0.0 wct_per_step = 0 @@ -127,7 +116,7 @@ def _compute_elapsed_wct(self, epoch_wct_dict, steps_per_epoch: int, n_epochs: i wct += wct_per_step return wct * n_epochs - def _training_start(self, state: State, logger: Logger): + def training_start(self, state: State, logger: Logger): del logger # Unused warnings.warn("The timing monitor is activated. The model will not be fully trained." 
"All quality metrics for this run will be incorrect.") @@ -140,7 +129,7 @@ def _training_start(self, state: State, logger: Logger): self.wct_dict = {e: {s: -1.0 for s in self.step_list} for e in self.epoch_list} state.max_epochs = len(self.epoch_list) - def _epoch_end(self, state: State, logger: Logger): + def epoch_end(self, state: State, logger: Logger): prev_epoch = self.epoch_list[self.epoch_ix] epoch_wct_dict = self.wct_dict[prev_epoch] self.epoch_ix += 1 @@ -156,7 +145,7 @@ def _epoch_end(self, state: State, logger: Logger): self.wall_clock_train += self._compute_elapsed_wct(epoch_wct_dict, state.steps_per_epoch, n_epochs) logger.metric_epoch({'wall_clock_train': self.wall_clock_train}) - def _batch_start(self, state: State, logger: Logger): + def batch_start(self, state: State, logger: Logger): del state, logger # Unused if self.current_time is None: self.current_time = time.time() @@ -164,7 +153,7 @@ def _batch_start(self, state: State, logger: Logger): self.profile_steps = 0 self.profile_time = 0.0 - def _batch_end(self, state: State, logger: Logger): + def batch_end(self, state: State, logger: Logger): if self.current_time is not None: now = time.time() elapsed = now - self.current_time diff --git a/composer/callbacks/grad_monitor.py b/composer/callbacks/grad_monitor.py index 81aeeaf817..5025edf788 100644 --- a/composer/callbacks/grad_monitor.py +++ b/composer/callbacks/grad_monitor.py @@ -1,6 +1,6 @@ # Copyright 2021 MosaicML. All Rights Reserved. -from composer.core import Event, Logger, State +from composer.core import Logger, State from composer.core.callback import Callback @@ -24,7 +24,7 @@ def __init__(self, log_layer_grad_norms: bool = False): super().__init__() self.log_layer_grad_norms = log_layer_grad_norms - def _run_event(self, event: Event, state: State, logger: Logger): + def after_train_batch(self, state: State, logger: Logger): """Compute the gradient L2 norm after the reduction of the backwards pass across GPUs. This function iterates over the parameters of the model and hence may cause a reduction in @@ -33,14 +33,11 @@ def _run_event(self, event: Event, state: State, logger: Logger): unscaling in cases where gradients are scaled. Args: - event (Event): The :class:`~composer.core.Event` object state (State): The :class:`~composer.core.State` object used during training. logger (Logger): The :class:`~composer.core.logging.logger.Logger` object. """ - if event != Event.AFTER_TRAIN_BATCH: - return norm = 0.0 layer_norms = {} for name, p in state.model.named_parameters(): diff --git a/composer/callbacks/lr_monitor.py b/composer/callbacks/lr_monitor.py index aa98988c7f..461e0b4862 100644 --- a/composer/callbacks/lr_monitor.py +++ b/composer/callbacks/lr_monitor.py @@ -1,6 +1,6 @@ # Copyright 2021 MosaicML. All Rights Reserved. 
-from composer.core import Callback, Event, Logger, State +from composer.core import Callback, Logger, State from composer.utils import ensure_tuple @@ -14,9 +14,7 @@ class LRMonitor(Callback): def __init__(self) -> None: super().__init__() - def _run_event(self, event: Event, state: State, logger: Logger): - if event != Event.BATCH_END: - return + def batch_end(self, state: State, logger: Logger): assert state.optimizers is not None, "optimizers must be defined" for optimizer in ensure_tuple(state.optimizers): lrs = [group['lr'] for group in optimizer.param_groups] diff --git a/composer/callbacks/memory_monitor.py b/composer/callbacks/memory_monitor.py index e9382f8bdb..cc38f88d27 100755 --- a/composer/callbacks/memory_monitor.py +++ b/composer/callbacks/memory_monitor.py @@ -4,7 +4,7 @@ from torch.cuda import device_count, memory_stats -from composer.core import Event, Logger, State +from composer.core import Logger, State from composer.core.callback import Callback log = logging.getLogger(__name__) @@ -26,10 +26,16 @@ def __init__(self): if device_count == 0: log.warn("Memory monitor only works on GPU devices.") - def _run_event(self, event: Event, state: State, logger: Logger): - if event != Event.AFTER_TRAIN_BATCH: - return + def after_train_batch(self, state: State, logger: Logger): + """This function calls the torch cuda memory stats and reports basic memory + statistics. + Args: + state (State): The :class:`~composer.core.State` object + used during training. + logger (Logger): + The :class:`~composer.core.logging.logger.Logger` object. + """ memory_report = {} default_stats = { diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index 92ac6c2e6a..c2d031e6f5 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -6,7 +6,7 @@ from collections import deque from typing import Deque, Optional -from composer import Event, Logger, State +from composer import Logger, State from composer.callbacks.callback_hparams import SpeedMonitorHparams from composer.core.callback import RankZeroCallback from composer.core.types import StateDict @@ -63,24 +63,19 @@ def _load_state(self) -> None: self.batch_num_samples = self.loaded_state["batch_num_samples"] self.loaded_state = None - def _run_event(self, event: Event, state: State, logger: Logger) -> None: - if event == Event.EPOCH_START: - self._epoch_start() - if event == Event.BATCH_START: - self._load_state() - if event == Event.BATCH_END: - self._batch_end(state, logger) - if event == Event.EPOCH_END: - self._epoch_end(logger) - - def _epoch_start(self): + def batch_start(self, state: State, logger: Logger) -> None: + del state, logger # unused + self._load_state() + + def epoch_start(self, state: State, logger: Logger): + del state, logger # unused self._load_state() self.epoch_start_time = time.time() self.batch_end_times.clear() self.batch_num_samples.clear() self.train_examples_per_epoch = 0 - def _batch_end(self, state: State, logger: Logger): + def batch_end(self, state: State, logger: Logger): self.batch_end_times.append(time.time()) batch_num_samples = 0 batch_num_samples += state.last_batch_size @@ -95,7 +90,8 @@ def _batch_end(self, state: State, logger: Logger): throughput = sum(self.batch_num_samples) / (self.batch_end_times[-1] - self.batch_end_times[0]) logger.metric_batch({'throughput/step': throughput}) - def _epoch_end(self, logger: Logger): + def epoch_end(self, state: State, logger: Logger): + del state # unused epoch_time = time.time() - 
self.epoch_start_time self.wall_clock_train += epoch_time logger.metric_epoch({ diff --git a/composer/callbacks/torch_profiler.py b/composer/callbacks/torch_profiler.py index c797f6581b..0e3724a04f 100644 --- a/composer/callbacks/torch_profiler.py +++ b/composer/callbacks/torch_profiler.py @@ -12,7 +12,7 @@ from composer import Callback from composer.callbacks.callback_hparams import TorchProfilerHparams -from composer.core.types import Event, StateDict +from composer.core.types import StateDict from composer.utils.ddp import get_global_rank from composer.utils.run_directory import get_relative_to_run_directory @@ -132,15 +132,7 @@ def scheduler_fn(self, profiler_step: int) -> ProfilerAction: torch_scheduler_action = ProfilerAction.RECORD_AND_SAVE return torch_scheduler_action - def _run_event(self, event: Event, state: State, logger: Logger) -> None: - if event == Event.TRAINING_START: - self._training_start(state, logger) - if event == Event.BATCH_START: - self._batch_start(state, logger) - if event == Event.BATCH_END: - self._batch_end(state, logger) - - def _training_start(self, state: State, logger: Logger) -> None: + def training_start(self, state: State, logger: Logger) -> None: del state, logger # unused assert self.profiler is None, _PROFILE_MISSING_ERROR self.profiler = torch.profiler.profile( @@ -159,16 +151,16 @@ def _training_start(self, state: State, logger: Logger) -> None: self.profiler.__enter__() atexit.register(self._close_profiler) - def _batch_end(self, state: State, logger: Logger) -> None: + def batch_end(self, state: State, logger: Logger) -> None: del state, logger # unused assert self.profiler is not None, _PROFILE_MISSING_ERROR self.profiler.step() - def _epoch_start(self, state: State, logger: Logger) -> None: + def epoch_start(self, state: State, logger: Logger) -> None: del logger # unused self.profiler_state.batches_per_epoch = state.steps_per_epoch - def _batch_start(self, state: State, logger: Logger) -> None: + def batch_start(self, state: State, logger: Logger) -> None: self.profiler_state.batch_in_epoch = state.batch_idx assert self.profiler is not None, _PROFILE_MISSING_ERROR logger.metric_batch({"profiler/state": self.profiler.current_action.name}) diff --git a/composer/core/callback.py b/composer/core/callback.py index 240ed85f6d..6a496d6fa5 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -5,7 +5,6 @@ from __future__ import annotations import abc -import warnings from typing import TYPE_CHECKING from composer.core.serializable import Serializable @@ -28,9 +27,14 @@ class Callback(Serializable, abc.ABC): they are run on specific events. By convention, Callbacks should not modify :class:`State`. - Subclasses should override :meth:`_run_event` - (**not** `run_event`) to run in response - to given :class:`Event` invocations. + Callbacks can be implemented in two ways: + + #. Override the individual methods named for each :class:`Event`. + + #. Override :meth:`_run_event` (**not** :meth:`run_event`) to run in response + to all events. If this method is overridden, then the individual methods + corresponding to each event name will not be automatically called (however, + the subclass implementation can invoke these methods as it wishes.) 
""" def __init__(self) -> None: @@ -49,23 +53,242 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: def _run_event(self, event: Event, state: State, logger: Logger) -> None: # default fallback if the callback does not override _run_event - try: - event_cb = getattr(self, event.value) - except AttributeError: - return - warnings.warn( - f"CallbackMethodDeprecationWarning: `self.{event.value}()` will be removed in callbacks." - "Instead, override `self._run_event()`.", - category=DeprecationWarning) + event_cb = getattr(self, event.value) return event_cb(state, logger) + def init(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.INIT` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def training_start(self, state: State, logger: Logger) -> None: + """Called on the :attr:`Event.TRAINING_START` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def epoch_start(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.EPOCH_START` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def batch_start(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.BATCH_START` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def after_dataloader(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.AFTER_DATALOADER` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def before_train_batch(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.BEFORE_TRAIN_BATCH` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def before_forward(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.BEFORE_FORWARD` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def after_forward(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.AFTER_FORWARD` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def before_loss(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.BEFORE_LOSS` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def after_loss(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.AFTER_LOSS` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def before_backward(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.BEFORE_BACKWARD` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def after_backward(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.AFTER_BACKWARD` event. + Args: + state (State): The global state. + logger (Logger): The logger. 
+ + """ + del state, logger # unused + pass + + def after_train_batch(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.AFTER_TRAIN_BATCH` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def batch_end(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.BATCH_END` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def epoch_end(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.EPOCH_END` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def training_end(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.TRAINING_END` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def eval_start(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.EVAL_START` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def eval_batch_start(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.EVAL_BATCH_START` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def eval_before_forward(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.EVAL_BATCH_FORWARD` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def eval_after_forward(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.EVAL_AFTER_FORWARD` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def eval_batch_end(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.EVAL_BATCH_END` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + + def eval_end(self, state: State, logger: Logger) -> None: + """Called on the :attr:`~Event.EVAL_END` event. + Args: + state (State): The global state. + logger (Logger): The logger. + + """ + del state, logger # unused + pass + class RankZeroCallback(Callback, abc.ABC): """Base class for callbacks that only run on the rank zero process. - Subclasses should override :meth:`_run_event` - (**not** `run_event`) to run in response - to given :class:`Event` invocations. + Callbacks can be implemented in two ways: + + #. Override the individual methods named for each :class:`Event`. (See + the parent class, :class:`Callback`.) + + #. Override :meth:`_run_event` (**not** :meth:`run_event`) to run in response + to all events. If this method is overridden, then the individual methods + corresponding to each event name will not be automatically called (however, + the subclass implementation can invoke these methods as it wishes.) 
""" @final diff --git a/composer/core/logging/base_backend.py b/composer/core/logging/base_backend.py index 126c7aa02e..5fca88532a 100644 --- a/composer/core/logging/base_backend.py +++ b/composer/core/logging/base_backend.py @@ -2,13 +2,10 @@ from __future__ import annotations -import warnings from abc import ABC -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING -from composer.core.callback import Callback -from composer.core.event import Event -from composer.core.logging.logger import Logger +from composer.core.callback import Callback, RankZeroCallback from composer.utils.ddp import is_rank_zero if TYPE_CHECKING: @@ -64,7 +61,7 @@ def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) pass -class RankZeroLoggerBackend(BaseLoggerBackend, Callback, ABC): +class RankZeroLoggerBackend(BaseLoggerBackend, RankZeroCallback, ABC): """Base class for logging backends that run only on the rank zero process. In a multi-process training setup (e.g. when using DistributedDataParallel), @@ -73,28 +70,18 @@ class RankZeroLoggerBackend(BaseLoggerBackend, Callback, ABC): and save data. When using this class, override - :func:`_will_log`, :func:`_log_metric`, and :func:`_training_start` instead of - :func:`will_log`, :func:`log_metric`, and :func:`training_start`, respectively. + :func:`_will_log` and :func:`_log_metric`` instead of + :func:`will_log` and :func:`log_metric`, respectively. - This class ensures that :func:`_log_metric` and :func:`_training_start` are invoked only - on the rank zero process. - - It caputres all logged data before the global rank is available. - On the rank zero process, during the - :attr:`~composer.core.event.Event.TRAINING_START` event (which occurs - after the global rank is set), it routes all captured logged data to - :func:`_log_metric`. For other processes, the captured log data - is eventually discarded. + This class ensures that :func:`_will_log` and :func:`_log_metric` + are invoked only on the rank zero process. .. automethod:: _will_log .. automethod:: _log_metric - .. automethod:: _training_start """ def __init__(self) -> None: super().__init__() - # self._deferred_log_metric_calls is set to None once the logger is initialized - self._deferred_log_metric_calls: Optional[List[Tuple[int, int, LogLevel, TLogData]]] = [] def _will_log(self, state: State, log_level: LogLevel) -> bool: """Called by the :class:`~composer.core.logging.logger.Logger` @@ -140,25 +127,5 @@ def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData @final def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None: if not is_rank_zero(): - # no log if not on rank zero, clear deferred calls to free memory - self._deferred_log_metric_calls = None - return - if self._deferred_log_metric_calls is not None: - warnings.warn(f"DeferredLogMetricWarning: {self.__class__.__name__}.log_metric()" - "was invoked before training_start()." 
- "This log call will be queued and processed after training_start().") - self._deferred_log_metric_calls.append((epoch, step, log_level, data)) return return self._log_metric(epoch, step, log_level, data) - - @final - def run_event(self, event: Event, state: State, logger: Logger) -> None: - if not is_rank_zero(): - return - self._run_event(event, state, logger) - if event == Event.TRAINING_START: - if self._deferred_log_metric_calls is None: - raise RuntimeError("_deferred_log_metric_calls should not be None") - for epoch, step, log_level, data in self._deferred_log_metric_calls: - self._log_metric(epoch, step, log_level, data) - self._deferred_log_metric_calls = None diff --git a/composer/loggers/file_logger.py b/composer/loggers/file_logger.py index 809f033b7d..904a973ce9 100644 --- a/composer/loggers/file_logger.py +++ b/composer/loggers/file_logger.py @@ -9,7 +9,6 @@ import yaml -from composer.core.event import Event from composer.core.logging import Logger, LogLevel, RankZeroLoggerBackend, TLogData, format_log_data_value from composer.core.state import State from composer.loggers.logger_hparams import FileLoggerBackendHparams @@ -80,13 +79,8 @@ def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData data_str = format_log_data_value(data) print(f"[{log_level.name}][step={step}]: {data_str}", file=self.file) - def _run_event(self, event: Event, state: State, logger: Logger) -> None: - if event == Event.TRAINING_START: - self._training_start(state, logger) - if event == Event.BATCH_END: - self._batch_end(state, logger) - - def _training_start(self, state: State, logger: Logger) -> None: + def init(self, state: State, logger: Logger) -> None: + del state, logger # unused if self.hparams.filename == "stdout": self.file = sys.stdout elif self.hparams.filename == "stderr": @@ -101,13 +95,27 @@ def _training_start(self, state: State, logger: Logger) -> None: print("-" * 30, file=self.file) print(file=self.file) - def _batch_end(self, state: State, logger: Logger) -> None: + def batch_end(self, state: State, logger: Logger) -> None: + del logger # unused + assert self.file is not None + if (state.step + 1) % self.hparams.flush_every_n_batches == 0: + self._flush_file() + + def epoch_end(self, state: State, logger: Logger) -> None: + del state, logger # unused + self._flush_file() + + def training_end(self, state: State, logger: Logger) -> None: + self._flush_file() + + def _flush_file(self) -> None: assert self.file is not None - if (state.step + 1) % self.hparams.flush_every_n_batches == 0 and self.file not in (sys.stdout, sys.stderr): + if self.file not in (sys.stdout, sys.stderr): self.file.flush() os.fsync(self.file.fileno()) def _close_file(self) -> None: assert self.file is not None assert self.file not in (sys.stdout, sys.stderr) + self._flush_file() self.file.close() diff --git a/composer/loggers/tqdm_logger.py b/composer/loggers/tqdm_logger.py index f6a6bceb9b..52af518790 100644 --- a/composer/loggers/tqdm_logger.py +++ b/composer/loggers/tqdm_logger.py @@ -108,7 +108,7 @@ def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData self.pbars[self.is_train].log_metric(data) def _run_event(self, event: Event, state: State, logger: Logger) -> None: - if event == Event.TRAINING_START: + if event == Event.INIT: if self.config is not None: print("Config") print("-" * 30) @@ -117,7 +117,10 @@ def _run_event(self, event: Event, state: State, logger: Logger) -> None: print() if event in (Event.EPOCH_START, Event.EVAL_START): self.is_train = 
event == Event.EPOCH_START - self.pbars[self.is_train] = _TQDMLoggerInstance(total=state.steps_per_epoch, + assert state.train_dataloader is not None + assert state.eval_dataloader is not None + total_steps = len(state.train_dataloader) if self.is_train else len(state.eval_dataloader) + self.pbars[self.is_train] = _TQDMLoggerInstance(total=total_steps, epoch=state.epoch, is_train=self.is_train) if event in (Event.AFTER_BACKWARD, Event.EVAL_AFTER_FORWARD): diff --git a/composer/loggers/wandb_logger.py b/composer/loggers/wandb_logger.py index e494bf9b69..0ff6e34838 100644 --- a/composer/loggers/wandb_logger.py +++ b/composer/loggers/wandb_logger.py @@ -7,7 +7,6 @@ import sys from typing import Any, Dict, Optional -from composer.core.event import Event from composer.core.logging import LogLevel, RankZeroLoggerBackend, TLogData from composer.core.types import Logger, State, StateDict from composer.utils.run_directory import get_run_directory @@ -47,18 +46,25 @@ def state_dict(self) -> StateDict: # Storing these fields in the state dict to support run resuming in the future. return {"name": wandb.run.name, "project": wandb.run.project, "entity": wandb.run.entity, "id": wandb.run.id} - def _run_event(self, event: Event, state: State, logger: Logger) -> None: - if event == Event.TRAINING_START: - wandb.init(**self._init_params) - atexit.register(self._close_wandb) + def init(self, state: State, logger: Logger) -> None: + del state, logger # unused + wandb.init(**self._init_params) + atexit.register(self._close_wandb) - if event == Event.BATCH_END: - if self._log_artifacts and (state.step + 1) % self._log_artifacts_every_n_batches == 0: - self._upload_artifacts() + def batch_end(self, state: State, logger: Logger) -> None: + del logger # unused + if self._log_artifacts and (state.step + 1) % self._log_artifacts_every_n_batches == 0: + self._upload_artifacts() + + def epoch_end(self, state: State, logger: Logger) -> None: + del state, logger # unused + if self._log_artifacts: + self._upload_artifacts() - if event == Event.EPOCH_END: - if self._log_artifacts: - self._upload_artifacts() + def training_end(self, state: State, logger: Logger) -> None: + del state, logger # unused + if self._log_artifacts: + self._upload_artifacts() def _upload_artifacts(self): # Scan the run directory and upload artifacts to wandb diff --git a/docs/source/core/callback.rst b/docs/source/core/callback.rst index 9c4b0ab020..797d010f9a 100644 --- a/docs/source/core/callback.rst +++ b/docs/source/core/callback.rst @@ -11,21 +11,41 @@ they do not modify the training of the model. By convention, callbacks should not modify the :class:`State`. -Each callback inherits from the :class:`Callback` base class, -and overrides the :meth:`~Callback.run_event` method. +Each callback inherits from the :class:`Callback` base class. +Callbacks can be implemented in two ways: -For example: +#. Override the individual methods named for each :class:`Event`. -.. code-block:: python + For example, - from composer import Callback + .. code-block:: python - class MyCallback(Callback) + from composer import Callback - def _run_event(self, event: Event, state: State, logger: Logger): - if event == Event.EPOCH_START: + class MyCallback(Callback) + + def epoch_start(self, state: State, logger: Logger): print(f'Epoch {state.epoch}/{state.max_epochs}') + +#. Override :meth:`_run_event` (**not** :meth:`run_event`) to run in response + to all events. 
If this method is overridden, then the individual methods + corresponding to each event name will not be automatically called (however, + the subclass implementation can invoke these methods as it wishes.) + + For example, + + .. code-block:: python + + from composer import Callback + + class MyCallback(Callback) + + def _run_event(self, event: Event, state: State, logger: Logger): + if event == Event.EPOCH_START: + print(f'Epoch {state.epoch}/{state.max_epochs}') + + .. note:: To use Composer's built in callbacks, see :doc:`/callbacks`. diff --git a/pyproject.toml b/pyproject.toml index 9b9c3162e7..eb69b16ddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,9 +48,8 @@ markers = [ filterwarnings = [ # "error", # warnings should be treated like errors, but still need to fix some warnings 'ignore:ExtraArgumentWarning', # extra arguments originate from pytest-specific CLI args - 'ignore:DeferredLogMetricWarning', # deferred logging is fine - 'ignore:DDPDefaultValueWarning', # OK to assume no ddp - 'ignore:NoDDPWarning', # OK to assume no ddp + 'ignore:DDPDefaultValueWarning', # default DDP values are fine + 'ignore:NoDDPWarning', # default DDP values are fine ] # Coverage diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 8007993d85..92c1eef557 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -10,6 +10,16 @@ from composer.core.state import State +def test_callbacks_map_to_events(): + # callback methods must be 1:1 mapping with events + # exception for private methods + cb = Callback() + excluded_methods = ["state_dict", "load_state_dict", "run_event"] + methods = set(m for m in dir(cb) if (m not in excluded_methods and not m.startswith("_"))) + event_names = set(e.value for e in Event) + assert methods == event_names + + class EventTrackerCallback(Callback): def __init__(self) -> None: diff --git a/tests/test_logger.py b/tests/test_logger.py index 6a8af93ba5..729de160a5 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -64,7 +64,7 @@ def test_file_logger(dummy_state: State, log_destination: FileLoggerBackend, mon class TestCoreLogger: @pytest.mark.world_size(2) - def test_deferred(self, dummy_state_without_rank: State, log_file_name: str, log_destination: FileLoggerBackend): + def test_rank_zero(self, dummy_state_without_rank: State, log_file_name: str, log_destination: FileLoggerBackend): dummy_state = dummy_state_without_rank dummy_state.step = 2 dummy_state.epoch = 0 From 47158fbad2f5c1f95be461fe677fe8011f8ba35b Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Mon, 22 Nov 2021 16:30:01 -0800 Subject: [PATCH 14/38] Minimizing Diff --- composer/core/callback.py | 21 +++++++++++++++++++++ docs/source/core/callback.rst | 4 ++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/composer/core/callback.py b/composer/core/callback.py index 6a496d6fa5..c01a1f5fcc 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -68,6 +68,7 @@ def init(self, state: State, logger: Logger) -> None: def training_start(self, state: State, logger: Logger) -> None: """Called on the :attr:`Event.TRAINING_START` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -78,6 +79,7 @@ def training_start(self, state: State, logger: Logger) -> None: def epoch_start(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.EPOCH_START` event. + Args: state (State): The global state. logger (Logger): The logger. 
@@ -88,6 +90,7 @@ def epoch_start(self, state: State, logger: Logger) -> None: def batch_start(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.BATCH_START` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -98,6 +101,7 @@ def batch_start(self, state: State, logger: Logger) -> None: def after_dataloader(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.AFTER_DATALOADER` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -108,6 +112,7 @@ def after_dataloader(self, state: State, logger: Logger) -> None: def before_train_batch(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.BEFORE_TRAIN_BATCH` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -118,6 +123,7 @@ def before_train_batch(self, state: State, logger: Logger) -> None: def before_forward(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.BEFORE_FORWARD` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -128,6 +134,7 @@ def before_forward(self, state: State, logger: Logger) -> None: def after_forward(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.AFTER_FORWARD` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -138,6 +145,7 @@ def after_forward(self, state: State, logger: Logger) -> None: def before_loss(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.BEFORE_LOSS` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -148,6 +156,7 @@ def before_loss(self, state: State, logger: Logger) -> None: def after_loss(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.AFTER_LOSS` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -158,6 +167,7 @@ def after_loss(self, state: State, logger: Logger) -> None: def before_backward(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.BEFORE_BACKWARD` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -168,6 +178,7 @@ def before_backward(self, state: State, logger: Logger) -> None: def after_backward(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.AFTER_BACKWARD` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -178,6 +189,7 @@ def after_backward(self, state: State, logger: Logger) -> None: def after_train_batch(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.AFTER_TRAIN_BATCH` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -188,6 +200,7 @@ def after_train_batch(self, state: State, logger: Logger) -> None: def batch_end(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.BATCH_END` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -198,6 +211,7 @@ def batch_end(self, state: State, logger: Logger) -> None: def epoch_end(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.EPOCH_END` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -208,6 +222,7 @@ def epoch_end(self, state: State, logger: Logger) -> None: def training_end(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.TRAINING_END` event. + Args: state (State): The global state. logger (Logger): The logger. 
@@ -218,6 +233,7 @@ def training_end(self, state: State, logger: Logger) -> None: def eval_start(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.EVAL_START` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -228,6 +244,7 @@ def eval_start(self, state: State, logger: Logger) -> None: def eval_batch_start(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.EVAL_BATCH_START` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -238,6 +255,7 @@ def eval_batch_start(self, state: State, logger: Logger) -> None: def eval_before_forward(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.EVAL_BATCH_FORWARD` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -248,6 +266,7 @@ def eval_before_forward(self, state: State, logger: Logger) -> None: def eval_after_forward(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.EVAL_AFTER_FORWARD` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -258,6 +277,7 @@ def eval_after_forward(self, state: State, logger: Logger) -> None: def eval_batch_end(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.EVAL_BATCH_END` event. + Args: state (State): The global state. logger (Logger): The logger. @@ -268,6 +288,7 @@ def eval_batch_end(self, state: State, logger: Logger) -> None: def eval_end(self, state: State, logger: Logger) -> None: """Called on the :attr:`~Event.EVAL_END` event. + Args: state (State): The global state. logger (Logger): The logger. diff --git a/docs/source/core/callback.rst b/docs/source/core/callback.rst index 797d010f9a..3cc11ad5c8 100644 --- a/docs/source/core/callback.rst +++ b/docs/source/core/callback.rst @@ -14,7 +14,7 @@ By convention, callbacks should not modify the :class:`State`. Each callback inherits from the :class:`Callback` base class. Callbacks can be implemented in two ways: -#. Override the individual methods named for each :class:`Event`. +#. Override the individual methods named for each :class:`Event`. For example, @@ -28,7 +28,7 @@ Callbacks can be implemented in two ways: print(f'Epoch {state.epoch}/{state.max_epochs}') -#. Override :meth:`_run_event` (**not** :meth:`run_event`) to run in response +#. Override :meth:`_run_event` (**not** :meth:`run_event`) to run in response to all events. If this method is overridden, then the individual methods corresponding to each event name will not be automatically called (however, the subclass implementation can invoke these methods as it wishes.) 
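The docstrings above note that a subclass overriding `_run_event` may still invoke the individual event methods itself. A minimal sketch of that pattern, assuming the `Callback`/`RankZeroCallback` API introduced in these patches (the `EventTimer` class and its attributes are hypothetical, not part of this patch):

    import time

    from composer.core import Event, Logger, State
    from composer.core.callback import RankZeroCallback


    class EventTimer(RankZeroCallback):
        """Accumulates wall-clock time spent handling each event."""

        def __init__(self) -> None:
            super().__init__()
            self.durations = {}  # maps each Event to total seconds spent in its handlers

        def _run_event(self, event: Event, state: State, logger: Logger) -> None:
            start = time.time()
            # Delegate to the default dispatch so the per-event methods
            # (e.g. `batch_end`) still run if a subclass defines them.
            super()._run_event(event, state, logger)
            self.durations[event] = self.durations.get(event, 0.0) + (time.time() - start)

Because `RankZeroCallback` only runs on the rank zero process, the timings above would be collected on rank zero only; deriving from `Callback` instead would collect them on every rank.
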
From 35faa2919f74ebd76e0c3b45580adbda5dff55a5 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Mon, 22 Nov 2021 16:41:15 -0800 Subject: [PATCH 15/38] Fixed tests --- tests/test_logger.py | 48 +------------------------------------------- 1 file changed, 1 insertion(+), 47 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index 729de160a5..f9265ab477 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -12,7 +12,6 @@ from composer.core.state import State from composer.loggers.file_logger import FileLoggerBackend from composer.loggers.logger_hparams import FileLoggerBackendHparams -from composer.utils.ddp import is_rank_zero @pytest.fixture @@ -38,7 +37,7 @@ def test_file_logger(dummy_state: State, log_destination: FileLoggerBackend, mon dummy_state.epoch = 2 logger = Logger(dummy_state, backends=[log_destination]) monkeypatch.setattr(dist, "get_rank", lambda: 0) - log_destination.run_event(Event.TRAINING_START, dummy_state, logger) + log_destination.run_event(Event.INIT, dummy_state, logger) logger.metric_fit({"metric": "fit"}) # should print logger.metric_epoch({"metric": "epoch"}) # should print logger.metric_batch({"metric": "batch"}) # should print @@ -59,48 +58,3 @@ def test_file_logger(dummy_state: State, log_destination: FileLoggerBackend, mon '[BATCH][step=2]: { "metric": "batch", }\n', '[EPOCH][step=3]: { "metric": "epoch2", }\n', ] - - -class TestCoreLogger: - - @pytest.mark.world_size(2) - def test_rank_zero(self, dummy_state_without_rank: State, log_file_name: str, log_destination: FileLoggerBackend): - dummy_state = dummy_state_without_rank - dummy_state.step = 2 - dummy_state.epoch = 0 - logger = Logger(dummy_state, backends=[log_destination]) - logger.metric_batch({"metric": "before_training_start"}) - log_destination.run_event(Event.TRAINING_START, dummy_state, logger) - logger.metric_batch({"metric": "after_training_start"}) - log_destination.run_event(Event.BATCH_END, dummy_state, logger) - log_destination.run_event(Event.TRAINING_END, dummy_state, logger) - if is_rank_zero(): - with open(log_file_name, 'r') as f: - assert f.readlines() == [ - '[BATCH][step=2]: { "metric": "before_training_start", }\n', - '[BATCH][step=2]: { "metric": "after_training_start", }\n', - ] - return - else: - assert not os.path.exists(log_file_name), "nothing should be logged on rank 1" - - def test_deep_copy(self, dummy_state_without_rank: State, log_destination: FileLoggerBackend, - monkeypatch: MonkeyPatch, log_file_name: str): - # This test ensures that the logger deepcopies the logged metric when using deferred logging - dummy_state = dummy_state_without_rank - dummy_state.step = 2 - dummy_state.epoch = 0 - logger = Logger(dummy_state, backends=[log_destination]) - metric_data = [["hello"]] - logger.metric_batch({"metric": metric_data}) - metric_data[0] = ["world"] - monkeypatch.setattr(dist, "get_rank", lambda: 0) - log_destination.run_event(Event.TRAINING_START, dummy_state, logger) - logger.metric_batch({"metric": metric_data}) - log_destination.run_event(Event.BATCH_END, dummy_state, logger) - log_destination.run_event(Event.TRAINING_END, dummy_state, logger) - with open(log_file_name, 'r') as f: - assert f.readlines() == [ - '[BATCH][step=2]: { "metric": [["hello"]], }\n', - '[BATCH][step=2]: { "metric": [["world"]], }\n', - ] From d568aa621a99b65133d1e8db5f0f6ebf9a658a15 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Mon, 22 Nov 2021 16:53:05 -0800 Subject: [PATCH 16/38] Added fasteners --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/setup.py b/setup.py index 1893c49fea..bb879548cd 100755 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ def package_files(directory): extra_deps['dev'] = [ 'junitparser>=2.1.1', 'coverage[toml]>=6.1.1', + 'fasteners>=0.16.3', # run_directory_uploader tests require fasteners 'pytest>=6.2.0', 'yapf>=0.13.0', 'isort>=5.9.3', From f0d2090966d37243dbc425aebc2adcf127d5367f Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 23 Nov 2021 11:30:15 -0800 Subject: [PATCH 17/38] Lazy population of kwargs --- composer/callbacks/callback_hparams.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/composer/callbacks/callback_hparams.py b/composer/callbacks/callback_hparams.py index 668e96687f..872e1e9a2b 100644 --- a/composer/callbacks/callback_hparams.py +++ b/composer/callbacks/callback_hparams.py @@ -196,13 +196,11 @@ class RunDirectoryUploaderHparams(CallbackHparams): def initialize_object(self) -> RunDirectoryUploader: from composer.callbacks.run_directory_uploader import RunDirectoryUploader - init_kwargs = { - "key": self.key, - "secret": self.secret, - "host": self.host, - "port": self.port, - "region": self.region, - } + init_kwargs = {} + for key in ("key", "secret", "host", "port", "region"): + kwarg = getattr(self, key) + if getattr(self, key) is not None: + init_kwargs[key] = kwarg init_kwargs.update(self.extra_init_kwargs) return RunDirectoryUploader( provider=self.provider, From 06ade34017eaa6b6af4f39cf4ad2b614774e67fe Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 23 Nov 2021 14:02:19 -0800 Subject: [PATCH 18/38] 1. Added object_name_prefix 2. Tested on google cloud storage 3. Added exponential backoff and retrying for transient errors --- composer/callbacks/callback_hparams.py | 8 +- composer/callbacks/run_directory_uploader.py | 97 ++++++++++++++++---- composer/loggers/file_logger.py | 5 +- pyproject.toml | 2 + 4 files changed, 94 insertions(+), 18 deletions(-) diff --git a/composer/callbacks/callback_hparams.py b/composer/callbacks/callback_hparams.py index 872e1e9a2b..102387c510 100644 --- a/composer/callbacks/callback_hparams.py +++ b/composer/callbacks/callback_hparams.py @@ -169,7 +169,12 @@ class RunDirectoryUploaderHparams(CallbackHparams): # :class:`~libcloud.storage.providers.Provider` constructor. These arguments would usually include the cloud region # and credentials. Defaults to None, which is equivalent to an empty dictionary. provider: str = hp.required("Cloud provider to use.") - container: str = hp.optional("he name of the container (i.e. bucket) to use.", default=None) + container: str = hp.required("he name of the container (i.e. bucket) to use.") + object_name_prefix: Optional[str] = hp.optional(textwrap.dedent("""A prefix to prepend to all object keys. + An object's key is this prefix combined with its path relative to the run directory. + If the container prefix is non-empty, a trailing slash ('/') will + be added if necessary. If not specified, then the prefix defaults to the run directory. To disable prefixing, + set to the empty string."""), default=None) key: Optional[str] = hp.optional(textwrap.dedent( """API key or username to use to connect to the provider. For security. do NOT hardcode the key in the YAML. 
Instead, please specify via CLI arguments, or even better, environment variables."""), @@ -205,6 +210,7 @@ def initialize_object(self) -> RunDirectoryUploader: return RunDirectoryUploader( provider=self.provider, container=self.container, + object_name_prefix=self.object_name_prefix, num_concurrent_uploads=self.num_concurrent_uploads, upload_staging_folder=self.upload_staging_folder, use_procs=self.use_procs, diff --git a/composer/callbacks/run_directory_uploader.py b/composer/callbacks/run_directory_uploader.py index 6f02318f05..eb5129c825 100644 --- a/composer/callbacks/run_directory_uploader.py +++ b/composer/callbacks/run_directory_uploader.py @@ -10,7 +10,10 @@ import threading import time import warnings +import sys +import atexit from typing import Any, Callable, Dict, Optional, Type, Union +import logging from composer.core.callback import RankZeroCallback from composer.core.event import Event @@ -19,6 +22,8 @@ from composer.core.state import State from composer.utils.run_directory import get_run_directory +log = logging.getLogger(__name__) + class RunDirectoryUploader(RankZeroCallback): """Callback to upload the run directory to a blob store. @@ -65,8 +70,15 @@ class RunDirectoryUploader(RankZeroCallback): to use S3, specify 's3' here. container (str): The name of the container (i.e. bucket) to use. + object_name_prefix (str, optional): A prefix to prepend to all object keys. An object's key is this prefix combined + with its path relative to the run directory. If the container prefix is non-empty, a trailing slash ('/') will + be added if necessary. If not specified, then the prefix defaults to the run directory. To disable prefixing, + set to the empty string. + + For example, if `object_name_prefix = 'foo'` and there is a file in the run directory named `bar`, then that file + would be uploaded to `foo/bar` in the container. num_concurrent_uploads (int, optional): Maximum number of concurrent uploads. Defaults to 4. - upload_staging_folder (Optional[str], optional): A folder to use for staging uploads. + upload_staging_folder (str, optional): A folder to use for staging uploads. If not specified, defaults to using a :class:`~tempfile.TemporaryDirectory`. use_procs (bool, optional): Whether to perform file uploads in background processes (as opposed to threads). Defaults to True. @@ -80,7 +92,8 @@ class RunDirectoryUploader(RankZeroCallback): def __init__( self, provider: str, - container: Optional[str] = None, + container: str, + object_name_prefix: Optional[str] = None, num_concurrent_uploads: int = 4, upload_staging_folder: Optional[str] = None, use_procs: bool = True, @@ -96,8 +109,17 @@ def __init__( provider_init_kwargs = {} self._provider_init_kwargs = provider_init_kwargs self._upload_every_n_batches = upload_every_n_batches - self._object_name_prefix = "" # TODO ravi. Decide how this will be set. Hparams? Run directory name? 
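        # Prefix normalization for uploaded object names (per the docstring above):
        #   None  -> default to the run directory name, with a trailing '/'
        #   ""    -> disable prefixing entirely
        #   "foo" -> "foo/" (a trailing '/' is appended if missing), so a run-directory
        #            file "bar" is uploaded as object "foo/bar"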
- + if object_name_prefix is None: + self._object_name_prefix = f"{run_directory}" + if not run_directory.endswith("/"): + self._object_name_prefix += "/" + else: + if object_name_prefix == "": + self._object_name_prefix = "" + else: + if not object_name_prefix.endswith('/'): + object_name_prefix = f"{object_name_prefix}/" + self._object_name_prefix = object_name_prefix self._last_upload_timestamp = 0.0 # unix timestamp of last uploaded time if upload_staging_folder is None: self._tempdir = tempfile.TemporaryDirectory() @@ -112,12 +134,15 @@ def __init__( self._provider = provider self._container = container + _validate_credentials(provider, container, self._object_name_prefix, provider_init_kwargs) + if use_procs: + mp_ctx = multiprocessing.get_context('spawn') self._file_upload_queue: Union[queue.Queue[str], - multiprocessing.JoinableQueue[str]] = multiprocessing.JoinableQueue() + multiprocessing.JoinableQueue[str]] = mp_ctx.JoinableQueue() self._finished_cls: Union[Callable[[], multiprocessing._EventType], - Type[threading.Event]] = multiprocessing.Event - self._proc_class = multiprocessing.Process + Type[threading.Event]] = mp_ctx.Event + self._proc_class = mp_ctx.Process else: self._file_upload_queue = queue.Queue() self._finished_cls = threading.Event @@ -148,6 +173,7 @@ def _run_event(self, event: Event, state: State, logger: Logger) -> None: return if event == Event.INIT: self._init() + atexit.register(self._close) if event == Event.BATCH_END: if (state.batch_idx + 1) % self._upload_every_n_batches == 0: self._trigger_upload(state, logger, LogLevel.BATCH) @@ -161,8 +187,8 @@ def _run_event(self, event: Event, state: State, logger: Logger) -> None: self._close() def _close(self): - assert self._finished is not None, "finished should not be None" - self._finished.set() + if self._finished is not None: + self._finished.set() for worker in self._workers: worker.join() @@ -178,6 +204,13 @@ def _trigger_upload(self, state: State, logger: Logger, log_level: LogLevel) -> run_directory = get_run_directory() assert run_directory is not None, "invariant error" files_to_be_uploaded = [] + + # check if any upload threads have crashed. if so, then shutdown the training process + for worker in self._workers: + if not worker.is_alive(): + # assert self._finished is not None, "invariant error" + # self._finished.set() + raise RuntimeError("Upload worker crashed unexpectedly") for root, dirs, files in os.walk(run_directory): del dirs # unused for file in files: @@ -198,6 +231,22 @@ def _trigger_upload(self, state: State, logger: Logger, log_level: LogLevel) -> # incremented past self._last_upload_timestamp logger.metric(log_level, {"run_directory/uploaded_files": files_to_be_uploaded}) +def _validate_credentials( + provider_name: str, + container_name: str, + object_name_prefix: str, + init_kwargs: Dict[str, Any], +) -> None: + # Validates the credentails by attempting to touch a file in the bucket + from libcloud.storage.providers import get_driver + provider_cls = get_driver(provider_name) + provider = provider_cls(**init_kwargs) + container = provider.get_container(container_name) + provider.upload_object_via_stream( + iterator=iter([i.to_bytes(1, sys.byteorder) for i in b"validate_credentials"]), + container=container, + object_name=f"{object_name_prefix}.validate_credentials_success", + ) def _upload_worker( file_queue: Union[queue.Queue[str], multiprocessing.JoinableQueue[str]], @@ -227,6 +276,7 @@ def _upload_worker( :class:`~libcloud.storage.providers.Provider` constructor. 
""" from libcloud.storage.providers import get_driver + from libcloud.common.types import LibcloudError provider_cls = get_driver(provider_name) provider = provider_cls(**init_kwargs) container = provider.get_container(container_name) @@ -241,10 +291,25 @@ def _upload_worker( continue obj_name = ",".join(os.path.relpath(file_path_to_upload, upload_staging_dir).split( os.path.sep)[1:]) # the first folder is the upload timestamp. Chop that off. - provider.upload_object( - file_path=file_path_to_upload, - container=container, - object_name=object_name_prefix + obj_name, - ) - os.remove(file_path_to_upload) - file_queue.task_done() + log.info("Uploading file %s to %s://%s/%s%s", file_path_to_upload, provider_name, container_name, object_name_prefix, obj_name) + retry_counter = 0 + while True: + try: + provider.upload_object( + file_path=file_path_to_upload, + container=container, + object_name=object_name_prefix + obj_name, + ) + except LibcloudError as e: + # The S3 driver does not encode the error code in an easy-to-parse manner + # So doing something fairly basic to retry on transient error codes + if any(x in str(e) for x in ("408", "409", "425", "429", "500", "503", '504')): + if retry_counter < 3: + retry_counter += 1 + # exponential backoff + time.sleep(2**(retry_counter - 1)) + continue + raise e + os.remove(file_path_to_upload) + file_queue.task_done() + break diff --git a/composer/loggers/file_logger.py b/composer/loggers/file_logger.py index 904a973ce9..2d9f86785a 100644 --- a/composer/loggers/file_logger.py +++ b/composer/loggers/file_logger.py @@ -12,6 +12,7 @@ from composer.core.logging import Logger, LogLevel, RankZeroLoggerBackend, TLogData, format_log_data_value from composer.core.state import State from composer.loggers.logger_hparams import FileLoggerBackendHparams +from composer.utils.run_directory import get_relative_to_run_directory class FileLoggerBackend(RankZeroLoggerBackend): @@ -86,7 +87,9 @@ def init(self, state: State, logger: Logger) -> None: elif self.hparams.filename == "stderr": self.file = sys.stderr else: - self.file = open(self.hparams.filename, "x+", buffering=self.hparams.buffer_size) + self.file = open(get_relative_to_run_directory(self.hparams.filename), + "x+", + buffering=self.hparams.buffer_size) atexit.register(self._close_file) if self.config is not None: print("Config", file=self.file) diff --git a/pyproject.toml b/pyproject.toml index eb69b16ddd..a4fcf24eaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,8 @@ reportUnusedImport = "error" reportMissingModuleSource = "none" reportPrivateImportUsage = "warning" reportUndefinedVariable = "error" +reportUnboundVariable = "error" +strictParameterNoneValue = "error" pythonVersion = "3.8" pythonPlatform = "Linux" From e5614630b9d34071e64990f791292d04e2a5f863 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 23 Nov 2021 14:35:02 -0800 Subject: [PATCH 19/38] Addressed PR feedback --- composer/callbacks/callback_hparams.py | 9 +++------ composer/callbacks/run_directory_uploader.py | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/composer/callbacks/callback_hparams.py b/composer/callbacks/callback_hparams.py index 102387c510..814ee5165f 100644 --- a/composer/callbacks/callback_hparams.py +++ b/composer/callbacks/callback_hparams.py @@ -164,17 +164,14 @@ class RunDirectoryUploaderHparams(CallbackHparams): See :class:`~composer.callbacks.torch_profiler.RunDirectoryUploader` for documentation. 
""" - # Args: - # provider_init_kwargs (Dict[str, Any], optional): Parameters to pass into the constructor for the - # :class:`~libcloud.storage.providers.Provider` constructor. These arguments would usually include the cloud region - # and credentials. Defaults to None, which is equivalent to an empty dictionary. provider: str = hp.required("Cloud provider to use.") - container: str = hp.required("he name of the container (i.e. bucket) to use.") + container: str = hp.required("The name of the container (i.e. bucket) to use.") object_name_prefix: Optional[str] = hp.optional(textwrap.dedent("""A prefix to prepend to all object keys. An object's key is this prefix combined with its path relative to the run directory. If the container prefix is non-empty, a trailing slash ('/') will be added if necessary. If not specified, then the prefix defaults to the run directory. To disable prefixing, - set to the empty string."""), default=None) + set to the empty string."""), + default=None) key: Optional[str] = hp.optional(textwrap.dedent( """API key or username to use to connect to the provider. For security. do NOT hardcode the key in the YAML. Instead, please specify via CLI arguments, or even better, environment variables."""), diff --git a/composer/callbacks/run_directory_uploader.py b/composer/callbacks/run_directory_uploader.py index eb5129c825..81586e69e0 100644 --- a/composer/callbacks/run_directory_uploader.py +++ b/composer/callbacks/run_directory_uploader.py @@ -2,18 +2,18 @@ from __future__ import annotations +import atexit +import logging import multiprocessing import os import queue import shutil +import sys import tempfile import threading import time import warnings -import sys -import atexit from typing import Any, Callable, Dict, Optional, Type, Union -import logging from composer.core.callback import RankZeroCallback from composer.core.event import Event @@ -140,8 +140,7 @@ def __init__( mp_ctx = multiprocessing.get_context('spawn') self._file_upload_queue: Union[queue.Queue[str], multiprocessing.JoinableQueue[str]] = mp_ctx.JoinableQueue() - self._finished_cls: Union[Callable[[], multiprocessing._EventType], - Type[threading.Event]] = mp_ctx.Event + self._finished_cls: Union[Callable[[], multiprocessing._EventType], Type[threading.Event]] = mp_ctx.Event self._proc_class = mp_ctx.Process else: self._file_upload_queue = queue.Queue() @@ -231,6 +230,7 @@ def _trigger_upload(self, state: State, logger: Logger, log_level: LogLevel) -> # incremented past self._last_upload_timestamp logger.metric(log_level, {"run_directory/uploaded_files": files_to_be_uploaded}) + def _validate_credentials( provider_name: str, container_name: str, @@ -248,6 +248,7 @@ def _validate_credentials( object_name=f"{object_name_prefix}.validate_credentials_success", ) + def _upload_worker( file_queue: Union[queue.Queue[str], multiprocessing.JoinableQueue[str]], is_finished: Union[multiprocessing._EventType, threading.Event], @@ -275,23 +276,23 @@ def _upload_worker( init_kwargs (Dict[str, Any]): Arguments to pass in to the :class:`~libcloud.storage.providers.Provider` constructor. 
""" - from libcloud.storage.providers import get_driver from libcloud.common.types import LibcloudError + from libcloud.storage.providers import get_driver provider_cls = get_driver(provider_name) provider = provider_cls(**init_kwargs) container = provider.get_container(container_name) while True: try: - file_path_to_upload = file_queue.get_nowait() + file_path_to_upload = file_queue.get(block=True, timeout=0.5) except queue.Empty: if is_finished.is_set(): break else: - time.sleep(0.5) continue obj_name = ",".join(os.path.relpath(file_path_to_upload, upload_staging_dir).split( os.path.sep)[1:]) # the first folder is the upload timestamp. Chop that off. - log.info("Uploading file %s to %s://%s/%s%s", file_path_to_upload, provider_name, container_name, object_name_prefix, obj_name) + log.info("Uploading file %s to %s://%s/%s%s", file_path_to_upload, provider_name, container_name, + object_name_prefix, obj_name) retry_counter = 0 while True: try: From f3aa6bdcf7d86e5994ae94e87f46f122b9c7e6f1 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 23 Nov 2021 15:41:13 -0800 Subject: [PATCH 20/38] Remove the composer.trainer.ddp class Before #65, composer.trainer.ddp ensured that DDP functionality was accessed only after ddp was initialized. Now, DDP is available from process start, so this class is no longer needed. Moved all the functionality from this class to the global composer.utils.ddp. This change allows callbacks, algroithms, etc... to use DDP (such as barriers and reductions) as needed. #97 and #101 depend on this functionality. Also removed DDP from the state, as that is available globally. --- .../curriculum_learning.py | 6 +- .../seq_length_warmup/seq_length_warmup.py | 6 +- composer/callbacks/benchmarker.py | 3 +- composer/callbacks/speed_monitor.py | 3 +- composer/core/callback.py | 4 +- composer/core/logging/base_backend.py | 6 +- composer/core/state.py | 21 -- composer/datasets/__init__.py | 1 + composer/datasets/dataloader.py | 42 ++- composer/trainer/__init__.py | 1 - composer/trainer/checkpoint.py | 30 +- composer/trainer/ddp.py | 262 ------------------ composer/trainer/devices/device_gpu.py | 4 +- composer/trainer/trainer.py | 57 ++-- composer/trainer/trainer_hparams.py | 13 +- composer/utils/__init__.py | 6 +- composer/utils/ddp.py | 210 +++++++++++++- tests/test_logger.py | 4 +- tests/trainer/test_checkpoint.py | 4 +- tests/trainer/test_ddp.py | 3 +- tests/trainer/test_ddp_sync_strategy.py | 14 +- tests/trainer/test_trainer.py | 4 +- tests/utils/dataloader.py | 3 +- tests/utils/trainer_fit.py | 11 +- 24 files changed, 333 insertions(+), 385 deletions(-) delete mode 100755 composer/trainer/ddp.py diff --git a/composer/algorithms/curriculum_learning/curriculum_learning.py b/composer/algorithms/curriculum_learning/curriculum_learning.py index e0d85b234e..566d118b3a 100644 --- a/composer/algorithms/curriculum_learning/curriculum_learning.py +++ b/composer/algorithms/curriculum_learning/curriculum_learning.py @@ -10,7 +10,7 @@ from composer.algorithms import AlgorithmHparams from composer.core.types import Algorithm, Batch, Event, Logger, State, Tensor from composer.models.transformer_shared import MosaicTransformer -from composer.utils import ensure_tuple +from composer.utils import ddp, ensure_tuple def apply_curriculum(batch: Dict[str, Tensor], curr_seq_len: int, truncate: bool) -> Batch: @@ -153,8 +153,8 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: # all of the parameters device = next(state.model.parameters()).device - assert 
(state.train_batch_size % state.world_size) == 0 - per_gpu_batch = math.ceil(state.train_batch_size / (state.world_size * state.grad_accum)) + assert (state.train_batch_size % ddp.get_world_size()) == 0 + per_gpu_batch = math.ceil(state.train_batch_size / (ddp.get_world_size() * state.grad_accum)) input_ids = torch.randint(low=0, high=vocab_size - 1, size=(per_gpu_batch, self.hparams.max_seq_length), diff --git a/composer/algorithms/seq_length_warmup/seq_length_warmup.py b/composer/algorithms/seq_length_warmup/seq_length_warmup.py index e6d5d572a5..2b3d0a50cd 100644 --- a/composer/algorithms/seq_length_warmup/seq_length_warmup.py +++ b/composer/algorithms/seq_length_warmup/seq_length_warmup.py @@ -10,7 +10,7 @@ from composer.algorithms import AlgorithmHparams from composer.core.types import Algorithm, Batch, Event, Logger, State, Tensor from composer.models.transformer_shared import MosaicTransformer -from composer.utils import ensure_tuple +from composer.utils import ddp, ensure_tuple def apply_seq_length_warmup(batch: Dict[str, Tensor], curr_seq_len: int, truncate: bool) -> Batch: @@ -180,8 +180,8 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: # all of the parameters device = next(state.model.parameters()).device - assert (state.train_batch_size % state.world_size) == 0 - per_gpu_batch = math.ceil(state.train_batch_size / (state.world_size * state.grad_accum)) + assert (state.train_batch_size % ddp.get_world_size()) == 0 + per_gpu_batch = math.ceil(state.train_batch_size / (ddp.get_world_size() * state.grad_accum)) input_ids = torch.randint(low=0, high=vocab_size - 1, size=(per_gpu_batch, self.hparams.max_seq_length), diff --git a/composer/callbacks/benchmarker.py b/composer/callbacks/benchmarker.py index 9e72b2339e..f89495c15c 100644 --- a/composer/callbacks/benchmarker.py +++ b/composer/callbacks/benchmarker.py @@ -11,6 +11,7 @@ from composer.callbacks.callback_hparams import BenchmarkerHparams from composer.core.callback import Callback from composer.core.types import BreakEpochException +from composer.utils import ddp log = logging.getLogger(__name__) @@ -158,7 +159,7 @@ def batch_end(self, state: State, logger: Logger): now = time.time() elapsed = now - self.current_time self.current_time = now - self.profile_examples += state.last_batch_size * state.world_size + self.profile_examples += state.last_batch_size * ddp.get_world_size() self.profile_steps += 1 self.profile_time += elapsed diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index c2d031e6f5..a7e9a74f6a 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -10,6 +10,7 @@ from composer.callbacks.callback_hparams import SpeedMonitorHparams from composer.core.callback import RankZeroCallback from composer.core.types import StateDict +from composer.utils import ddp class SpeedMonitor(RankZeroCallback): @@ -83,7 +84,7 @@ def batch_end(self, state: State, logger: Logger): # Ideally, callbacks would have a way of reducing tensors. 
# It assumes that each process has equal batch sizing # For the speed monitor, we might be able to use the static step converter with num_samples - batch_num_samples *= state.world_size + batch_num_samples *= ddp.get_world_size() self.batch_num_samples.append(batch_num_samples) self.train_examples_per_epoch += batch_num_samples if len(self.batch_end_times) == self.hparams.window_size + 1: diff --git a/composer/core/callback.py b/composer/core/callback.py index 94b8f925b0..f6a96af416 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Callable from composer.core.serializable import Serializable -from composer.utils.ddp import is_rank_zero +from composer.utils.ddp import get_global_rank if TYPE_CHECKING: from composer import Logger, State @@ -303,7 +303,7 @@ def wrapped_fn( original_fn: Callable[[State, Logger], None] = original_fn, **kwargs: Any, ) -> None: - if not is_rank_zero(): + if not get_global_rank() == 0: return return original_fn(*args, **kwargs) diff --git a/composer/core/logging/base_backend.py b/composer/core/logging/base_backend.py index ef99ac79be..86057cad48 100644 --- a/composer/core/logging/base_backend.py +++ b/composer/core/logging/base_backend.py @@ -8,7 +8,7 @@ from composer.core.callback import Callback, RankZeroCallback from composer.core.logging.logger import Logger -from composer.utils.ddp import is_rank_zero +from composer.utils.ddp import get_global_rank if TYPE_CHECKING: from composer.core.logging.logger import LogLevel, TLogData @@ -116,7 +116,7 @@ def _will_log(self, state: State, log_level: LogLevel) -> bool: @final def will_log(self, state: State, log_level: LogLevel) -> bool: - if not state.is_rank_zero: + if get_global_rank() != 0: return False return self._will_log(state, log_level) @@ -138,7 +138,7 @@ def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData @final def log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData) -> None: - if not is_rank_zero(): + if get_global_rank() != 0: # no log if not on rank zero, clear deferred calls to free memory self._deferred_log_metric_calls = None return diff --git a/composer/core/state.py b/composer/core/state.py index 98b22bb890..8dd1719bd6 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -16,7 +16,6 @@ from composer.core.precision import Precision from composer.core.serializable import Serializable from composer.utils import ensure_tuple -from composer.utils.ddp import get_global_rank, get_local_rank, get_local_world_size, get_world_size from composer.utils.precision import default_precision_factory if TYPE_CHECKING: @@ -142,26 +141,6 @@ class State(Serializable): algorithms: Sequence[Algorithm] = tuple() callbacks: Sequence[Callback] = tuple() - @property - def world_size(self) -> int: - return get_world_size() - - @property - def global_rank(self) -> int: - return get_global_rank() - - @property - def local_world_size(self) -> int: - return get_local_world_size() - - @property - def local_rank(self) -> int: - return get_local_rank() - - @property - def is_rank_zero(self) -> bool: - return self.global_rank == 0 - def state_dict(self) -> types.StateDict: """Returns the state as a :class:`dict`.""" state_dict: types.StateDict = {} diff --git a/composer/datasets/__init__.py b/composer/datasets/__init__.py index e18bf8e24f..ddbe47c78d 100644 --- a/composer/datasets/__init__.py +++ b/composer/datasets/__init__.py @@ -3,6 +3,7 @@ from composer.datasets.brats import 
BratsDatasetHparams as BratsDatasetHparams from composer.datasets.cifar10 import CIFAR10DatasetHparams as CIFAR10DatasetHparams from composer.datasets.dataloader import DataloaderHparams as DataloaderHparams +from composer.datasets.dataloader import DDPDataLoader as DDPDataLoader from composer.datasets.dataloader import WrappedDataLoader as WrappedDataLoader from composer.datasets.hparams import DataloaderSpec as DataloaderSpec from composer.datasets.hparams import DatasetHparams as DatasetHparams diff --git a/composer/datasets/dataloader.py b/composer/datasets/dataloader.py index 754454551d..0d94096720 100644 --- a/composer/datasets/dataloader.py +++ b/composer/datasets/dataloader.py @@ -2,13 +2,15 @@ from __future__ import annotations +import warnings from dataclasses import dataclass -from typing import Any, Iterator +from typing import Any, Iterator, Optional import torch import torch.distributed import torch.utils.data import yahp as hp +from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import Sampler from composer.core.types import Batch, DataLoader @@ -44,6 +46,44 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +class DDPDataLoader(WrappedDataLoader): + """Ensure sampler.set_epoch() is called after each iteration. + + DDPDataLoader wraps a dataloader and a distributed sampler and is + called after each iteration (epoch) through the dataset. + See: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler + """ + + def __init__(self, dataloader: DataLoader) -> None: + super().__init__(dataloader) + if not isinstance(self.dataloader.sampler, DistributedSampler): + raise ValueError("When using the DDP data loader, the sampler must be a DistributedSampler") + self._iterator: Optional[Iterator[Batch]] = None + + def __iter__(self) -> DDPDataLoader: + if self._iterator is not None: + warnings.warn( + "DataloaderMultipleIterationWarning: " + "The dataloader detected the start of a new iteration before the previous iteration finished. " + "The dataloader is skipping ahead to the start of the next epoch. " + "Multiple simultaneous iterations through the DDP dataloader prohibited, since " + "it automatically tracks the current epoch.") + assert isinstance(self.sampler, DistributedSampler) + self.sampler.set_epoch(epoch=self.sampler.epoch + 1) + self._iterator = iter(self.dataloader) + return self + + def __next__(self) -> Batch: + assert self._iterator is not None + try: + return next(self._iterator) + except StopIteration: + self._iterator = None + assert isinstance(self.sampler, DistributedSampler) + self.sampler.set_epoch(epoch=self.sampler.epoch + 1) + raise + + @dataclass class DataloaderHparams(hp.Hparams): """Hyperparameters to initialize a ``torch.utils.data.Dataloader``.""" diff --git a/composer/trainer/__init__.py b/composer/trainer/__init__.py index 9802b3dc56..b3d7cae6ff 100644 --- a/composer/trainer/__init__.py +++ b/composer/trainer/__init__.py @@ -1,7 +1,6 @@ # Copyright 2021 MosaicML. All Rights Reserved. 
from composer.trainer import devices as devices -from composer.trainer.ddp import DDPDataLoader as DDPDataLoader from composer.trainer.trainer import Trainer as Trainer from composer.trainer.trainer_hparams import TrainerHparams as TrainerHparams diff --git a/composer/trainer/checkpoint.py b/composer/trainer/checkpoint.py index 145b28fd33..b73332c5a2 100644 --- a/composer/trainer/checkpoint.py +++ b/composer/trainer/checkpoint.py @@ -12,9 +12,8 @@ from composer.core import Event, State from composer.core.types import StateDict -from composer.trainer.ddp import DDP from composer.trainer.devices.device import Device -from composer.utils import seed_all +from composer.utils import ddp, seed_all log = logging.getLogger(__name__) @@ -46,25 +45,25 @@ def restore_checkpoint_rng_state(self, state: State, device: Device): if self.checkpoint_rng_state is None: return - assert state.world_size == len( + assert ddp.get_world_size() == len( self.checkpoint_rng_state['torch'] ), f"invariant violation: if the rng state is being restored, then" \ "the world size should be the same as in the checkpoint." - torch.set_rng_state(self.checkpoint_rng_state['torch'][state.global_rank]) - device.load_state_dict(self.checkpoint_rng_state['device'][state.global_rank]) - random.setstate(self.checkpoint_rng_state['python'][state.global_rank]) - np.random.set_state(self.checkpoint_rng_state['numpy'][state.global_rank]) + torch.set_rng_state(self.checkpoint_rng_state['torch'][ddp.get_global_rank()]) + device.load_state_dict(self.checkpoint_rng_state['device'][ddp.get_global_rank()]) + random.setstate(self.checkpoint_rng_state['python'][ddp.get_global_rank()]) + np.random.set_state(self.checkpoint_rng_state['numpy'][ddp.get_global_rank()]) self.checkpoint_rng_state = None def _get_checkpoint_rng_state(self, state: State, checkpoint_rng_state: StateDict) -> Optional[StateDict]: original_world_size = len(checkpoint_rng_state["torch"]) - if original_world_size == state.world_size: + if original_world_size == ddp.get_world_size(): return checkpoint_rng_state else: warnings.warn(f"The checkpoint was created with world_size({original_world_size}), " - f"which differs from the current world_size({state.world_size})." + f"which differs from the current world_size({ddp.get_world_size()})." f"RNG state will not be restored.") @@ -104,12 +103,7 @@ def should_checkpoint(self, state: State, event: Event) -> bool: return state.step % self.save_interval == 0 return False - def save_checkpoint(self, - state: State, - seed: int, - device: Device, - ddp: DDP, - config: Optional[Dict[str, Any]] = None) -> None: + def save_checkpoint(self, state: State, seed: int, device: Device, config: Optional[Dict[str, Any]] = None) -> None: """Save the current state to a a new checkpoint file. Args: @@ -126,10 +120,10 @@ def save_checkpoint(self, # This will be fixed by: https://github.com/mosaicml/composer/issues/12 state_dict = { 'state': state.state_dict(), # should be the same across all ranks. 
per-rank state not stored - 'rng': self._get_rng_state(device=device, ddp=ddp), # stored across all ranks + 'rng': self._get_rng_state(device=device), # stored across all ranks 'seed': seed, } - if not state.is_rank_zero: + if ddp.get_global_rank() != 0: # only rank 0 saves checkpoints # Need the check down here so all the DDP syncs will work for generating the checkpoint return @@ -160,7 +154,7 @@ def save_checkpoint(self, torch.save(state_dict, f) log.info(f'Trainer checkpoint saved to {save_file}') - def _get_rng_state(self, device: Device, ddp: DDP) -> StateDict: + def _get_rng_state(self, device: Device) -> StateDict: rng_state = { "python": ddp.all_gather_object(random.getstate()), "numpy": ddp.all_gather_object(np.random.get_state()), diff --git a/composer/trainer/ddp.py b/composer/trainer/ddp.py deleted file mode 100755 index 977baa34ff..0000000000 --- a/composer/trainer/ddp.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright 2021 MosaicML. All Rights Reserved. - -from __future__ import annotations - -import collections.abc -import datetime -import logging -import os -import warnings -from contextlib import contextmanager, nullcontext -from dataclasses import dataclass -from typing import Callable, ContextManager, Iterator, List, Optional, Sequence, TypeVar, cast - -import torch -import torch.distributed -import torch.utils.data -import yahp as hp -from torch.nn.parallel import DistributedDataParallel -from torch.utils.data.distributed import DistributedSampler - -from composer.core.state import State -from composer.core.types import Batch, DataLoader, Model, Tensor -from composer.datasets import DataloaderHparams, DataloaderSpec, WrappedDataLoader -from composer.utils.ddp import get_world_size -from composer.utils.iter_helpers import ensure_tuple -from composer.utils.string_enum import StringEnum - -logger = logging.getLogger(__name__) - -TObj = TypeVar("TObj") - -CLEANUP_TIMEOUT = datetime.timedelta(seconds=5) - - -class DataloaderMultipleIterationWarning(Warning): - pass - - -class DDPDataLoader(WrappedDataLoader): - """Ensure sampler.set_epoch() is called after each iteration. - - DDPDataLoader wraps a dataloader and a distributed sampler and is - called after each iteration (epoch) through the dataset. - See: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler - """ - - def __init__(self, dataloader: DataLoader) -> None: - super().__init__(dataloader) - if not isinstance(self.dataloader.sampler, DistributedSampler): - raise ValueError("When using the DDP data loader, the sampler must be a DistributedSampler") - self._iterator: Optional[Iterator[Batch]] = None - - def __iter__(self) -> DDPDataLoader: - if self._iterator is not None: - warnings.warn( - "The dataloader detected the start of a new iteration before the previous iteration finished. " - "The dataloader is skipping ahead to the start of the next epoch. 
" - "Multiple simultaneous iterations through the DDP dataloader prohibited, since " - "it automatically tracks the current epoch.", - category=DataloaderMultipleIterationWarning) - assert isinstance(self.sampler, DistributedSampler) - self.sampler.set_epoch(epoch=self.sampler.epoch + 1) - self._iterator = iter(self.dataloader) - return self - - def __next__(self) -> Batch: - assert self._iterator is not None - try: - return next(self._iterator) - except StopIteration: - self._iterator = None - assert isinstance(self.sampler, DistributedSampler) - self.sampler.set_epoch(epoch=self.sampler.epoch + 1) - raise - - -class DDPSyncStrategy(StringEnum): - """How and when DDP gradient synchronization should happen. - - Attributes: - SINGLE_AUTO_SYNC: The default behavior for DDP. Gradients are synchronized as they - computed, for only the final microbatch of a batch. This is the most efficient - strategy, but can lead to errors when ``find_unused_parameters`` is set, since - it is possible different microbatches may use different sets of parameters, - leading to an incomplete sync. - MULTI_AUTO_SYNC: The default behavior for DDP when ``find_unused_parameters`` is set. - Gradients are synchronized as they are computed for all microbatches. This ensures - complete synchronization, but is less efficient than :attr:`SINGLE_AUTO_SYNC`. This - efficiency gap is usually small, as long as either DDP syncs are a small portion - of the trainer's overall runtime, or the number of microbatches per batch is - relatively small. - FORCED_SYNC: Gradients are manually synchronized only after all gradients have been - computed for the final microbatch of a batch. Like :attr:`MULTI_AUTO_SYNC`, this - strategy ensures complete gradient synchronization, but this tends to be slower than - :attr:`MULTI_AUTO_SYNC`. This is because ordinarily syncs can happen in parallel - with the ``loss.backward()`` computation, meaning syncs can be mostly complete by - the time that function finishes. However, in certain circumstances, syncs may take - a very long time to complete - if there are also a lot of microbatches per batch, - this strategy may be optimal. - """ - SINGLE_AUTO_SYNC = "single_auto_sync" - MULTI_AUTO_SYNC = "multi_auto_sync" - FORCED_SYNC = "forced_sync" - - -class DDP: - - def __init__(self, - *, - backend: str, - timeout: float, - find_unused_parameters: bool = False, - sync_strategy: Optional[str] = None): - self.backend = backend - self.find_unused_parameters = find_unused_parameters - if sync_strategy is None: - self.sync_strategy = DDPSyncStrategy.SINGLE_AUTO_SYNC if not find_unused_parameters else DDPSyncStrategy.FORCED_SYNC - else: - self.sync_strategy = DDPSyncStrategy(sync_strategy) - - _timeout = datetime.timedelta(seconds=timeout) - - if torch.distributed.is_initialized(): - - if not torch.distributed.get_backend() == self.backend.lower(): - raise RuntimeError( - f"The requested backend ({self.backend}) differs from the backend " - "of the current process group ({torch.distributed.get_backend()}). If you wish to change backends, " - "please restart the python process.") - return - - if "RANK" in os.environ and "WORLD_SIZE" in os.environ: - # Assume we can initialize based off of env vars - torch.distributed.init_process_group(self.backend, timeout=_timeout) - return - - warnings.warn("NoDDPWarning: RANK and WORLD_SIZE env vars not set; assuming no parallelization. 
" - "If this is unexpected, make sure you are running your training script with the " - "composer executable.") - store = torch.distributed.HashStore() - - torch.distributed.init_process_group(self.backend, timeout=_timeout, store=store, world_size=1, rank=0) - - @property - def world_size(self) -> int: - return get_world_size() - - def barrier(self) -> None: - if torch.distributed.is_available(): - torch.distributed.barrier() - # If not on DDP, then do nothing - - def all_reduce( - self, - tensor: torch.Tensor, - reduce_operation: str = "SUM", - ) -> None: - if torch.distributed.is_available(): - reduce_op = getattr(torch.distributed.ReduceOp, reduce_operation.upper()) - torch.distributed.all_reduce(tensor, op=reduce_op) - else: - raise NotImplementedError("Non-DDP versions of reduce operations are not yet implemented") - - def all_gather(self, tensor: torch.Tensor) -> Sequence[Tensor]: - """gather_to_rank_zero collects a tensor from each rank, and returns a sequence of tensors indexed by rank - - Args: - tensor (torch.Tensor): tensor from each rank to be gathered - - Returns: - Sequence[Tensor]: A sequence of tensors indexed by rank - """ - if torch.distributed.is_available(): - obj_gather_list = [torch.zeros_like(tensor) for _ in range(self.world_size)] - torch.distributed.all_gather(obj_gather_list, tensor) - return obj_gather_list - else: - return [tensor] - - def all_gather_object(self, obj: TObj) -> List[TObj]: - """gather_object_to_rank_zero collects a pickleable object from each rank, and returns a list of - these objects indexed by rank - - Args: - obj (TObj): Object to be gathered - - Returns: - List[TObj]: A list of objects indexed by rank - """ - if torch.distributed.is_available(): - obj_gather_list = [None for _ in range(self.world_size)] - torch.distributed.all_gather_object(obj_gather_list, obj) - # torch.distributed will replace the None's in obj_gather_list with the gathered objects on rank 0 - # or will just be None on non-rank-0 - return cast(List[TObj], obj_gather_list) - else: - return [obj] - - def prepare_module(self, module: Model) -> Model: - if torch.distributed.is_available(): - if any((p.requires_grad for p in module.parameters())): - ddp_model = DistributedDataParallel(module, find_unused_parameters=self.find_unused_parameters) - return cast(Model, ddp_model) - return module - else: - return module - - def create_dataloader(self, batch_size: int, dataloader_hparams: DataloaderHparams, - dataloader_spec: DataloaderSpec) -> DataLoader: - if torch.distributed.is_available(): - sampler = torch.utils.data.DistributedSampler[int](dataloader_spec.dataset, - drop_last=dataloader_spec.drop_last, - shuffle=dataloader_spec.shuffle) - else: - assert isinstance(dataloader_spec.dataset, collections.abc.Sized) - sampler = torch.utils.data.RandomSampler(dataloader_spec.dataset, generator=dataloader_spec.generator) - dataloader = dataloader_hparams.initialize_object(batch_size, sampler, dataloader_spec) - if torch.distributed.is_available(): - dataloader = DDPDataLoader(dataloader) - return dataloader - - @contextmanager - def sync_context(self, state: State, is_final_microbatch: bool): - assert isinstance(state.model, DistributedDataParallel), "state.model is not wrapped by DDP" - assert state.optimizers is not None, "optimizers have not been initialized" - - no_sync_context = cast(Callable[[], ContextManager], state.model.no_sync) - auto_sync_context = nullcontext - - if self.sync_strategy == DDPSyncStrategy.SINGLE_AUTO_SYNC: - context = auto_sync_context if 
is_final_microbatch else no_sync_context - with context(): - yield - - elif self.sync_strategy == DDPSyncStrategy.MULTI_AUTO_SYNC: - with auto_sync_context(): - yield - - elif self.sync_strategy == DDPSyncStrategy.FORCED_SYNC: - try: - with no_sync_context(): - yield - finally: - if is_final_microbatch: - for optimizer in ensure_tuple(state.optimizers): - for group in optimizer.param_groups: - for p in group["params"]: - if p.grad is not None: - self.all_reduce(p.grad) - p.grad = p.grad / state.world_size - - else: - raise ValueError("Unknown sync strategy", self.sync_strategy) - - -@dataclass -class DDPHparams(hp.Hparams): - sync_strategy: Optional[str] = hp.optional( - doc="The strategy for synchronizing DDP. Default value ``None`` causes the " - "trainer to auto-select a value depending on what algorithms are used.", - default=None) - timeout: float = hp.optional(doc="Timeout, in seconds, for initializing the DDP process group.", default=5.0) diff --git a/composer/trainer/devices/device_gpu.py b/composer/trainer/devices/device_gpu.py index cd9fb0225e..4a29ded9d2 100755 --- a/composer/trainer/devices/device_gpu.py +++ b/composer/trainer/devices/device_gpu.py @@ -13,7 +13,7 @@ from composer.core.types import Batch, BatchPair, DataLoader, Precision, StateDict, Tensor, Tensors, TPrefetchFn from composer.datasets.dataloader import WrappedDataLoader from composer.trainer.devices.device import Device, T_nnModule -from composer.utils import map_collection +from composer.utils import ddp, map_collection class CudaDataLoader(WrappedDataLoader): @@ -112,7 +112,7 @@ def __init__( def prepare(self, state: State) -> None: if self._device is not None: raise ValueError("device is already set") - gpu = state.local_rank + gpu = ddp.get_local_rank() self._device = torch.device(f"cuda:{gpu}") torch.cuda.set_device(self._device) assert torch.cuda.current_device() == gpu diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 6f62dd3dd4..edfe77ba3f 100755 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -4,6 +4,7 @@ import collections.abc import contextlib +import datetime import logging import warnings from typing import Any, Dict, List, Optional, Sequence, Union @@ -28,12 +29,11 @@ SchedulerHparams, WarmUpLRHparams) from composer.optim.scheduler import ensure_warmup_last from composer.trainer.checkpoint import Checkpointer, CheckpointLoader -from composer.trainer.ddp import DDP, DataloaderMultipleIterationWarning from composer.trainer.devices.device import Device from composer.trainer.devices.device_cpu import DeviceCPU from composer.trainer.scaler import ClosureGradScaler from composer.trainer.trainer_hparams import TrainerHparams -from composer.utils import ensure_tuple, get_random_seed, map_collection, seed_all +from composer.utils import ddp, ensure_tuple, get_random_seed, map_collection, seed_all from composer.utils.run_directory import get_relative_to_run_directory log = logging.getLogger(__name__) @@ -149,7 +149,7 @@ def __init__( timeout: int = 0, # ddp hparams - ddp_sync_strategy: Optional[str] = None, + ddp_sync_strategy: Optional[Union[str, ddp.DDPSyncStrategy]] = None, ddp_timeout: float = 5.0, # Randomness @@ -174,8 +174,6 @@ def __init__( self.config = config - self.ddp_sync_strategy = ddp_sync_strategy - if not device: device = DeviceCPU() self.device = device @@ -196,12 +194,14 @@ def __init__( self.backwards_create_graph = any(map(lambda x: x.backwards_create_graph, algorithms)) find_unused_parameters = any(map(lambda x: x.find_unused_parameters, 
algorithms)) - self.ddp = DDP( - backend=self.device.ddp_backend, - find_unused_parameters=find_unused_parameters, - sync_strategy=ddp_sync_strategy, - timeout=ddp_timeout, - ) + + self.find_unused_parameters = find_unused_parameters + if ddp_sync_strategy is None: + self.ddp_sync_strategy = ddp.DDPSyncStrategy.SINGLE_AUTO_SYNC if not find_unused_parameters else ddp.DDPSyncStrategy.FORCED_SYNC + else: + self.ddp_sync_strategy = ddp.DDPSyncStrategy(ddp_sync_strategy) + + ddp.initialize_ddp(device.ddp_backend, datetime.timedelta(seconds=ddp_timeout)) dl_hparams = DataloaderHparams(num_workers=num_workers, prefetch_factor=prefetch_factor, @@ -209,16 +209,16 @@ def __init__( pin_memory=pin_memory, timeout=timeout) - train_gpu_batch_size = train_batch_size // self.ddp.world_size + train_gpu_batch_size = train_batch_size // ddp.get_world_size() train_dataloader = self.device.dataloader_to_device( - self.ddp.create_dataloader(train_gpu_batch_size, dl_hparams, train_dataloader_spec), + ddp.create_dataloader(train_gpu_batch_size, dl_hparams, train_dataloader_spec), train_dataloader_spec.prefetch_fn, ) self.train_dl_spec = train_dataloader_spec - eval_gpu_batch_size = eval_batch_size // self.ddp.world_size + eval_gpu_batch_size = eval_batch_size // ddp.get_world_size() eval_dataloader = self.device.dataloader_to_device( - self.ddp.create_dataloader(eval_gpu_batch_size, dl_hparams, eval_dataloader_spec), + ddp.create_dataloader(eval_gpu_batch_size, dl_hparams, eval_dataloader_spec), eval_dataloader_spec.prefetch_fn, ) self.eval_dl_spec = eval_dataloader_spec @@ -344,8 +344,8 @@ def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer: timeout=hparams.dataloader.timeout, # ddp hparams - ddp_sync_strategy=hparams.ddp.sync_strategy, - ddp_timeout=hparams.ddp.timeout, + ddp_sync_strategy=hparams.ddp_sync_strategy, + ddp_timeout=hparams.ddp_timeout, # Randomness seed=seed, @@ -433,7 +433,7 @@ def _spin_dataloaders(self): not be completely iterated through. 
""" # surpressing this multiple iteration warning -- it is OK to ignore - warnings.simplefilter(action="ignore", category=DataloaderMultipleIterationWarning, append=True) + warnings.filterwarnings(action="ignore", message=r"^DataloaderMultipleIterationWarning", append=True) assert self.state.train_dataloader is not None, "train dataloader should be set" assert self.state.eval_dataloader is not None, "eval dataloader should be set" @@ -478,7 +478,7 @@ def _train_loop(self) -> None: state.optimizers = map_collection(state.optimizers, self.device.optimizer_to_device) # wrap model with DDP - state.model = self.ddp.prepare_module(state.model) + state.model = ddp.prepare_module(state.model, self.find_unused_parameters) original_model = state.model.module assert isinstance(original_model, BaseMosaicModel) @@ -502,12 +502,12 @@ def _train_loop(self) -> None: def _ddp_reduce_scalar_and(flag: bool) -> bool: value = 1 if flag else 0 flag_tensor = self.device.tensor_to_device(torch.tensor(value).int()) - self.ddp.all_reduce(flag_tensor, reduce_operation='PRODUCT') + ddp.all_reduce(flag_tensor, reduce_operation='PRODUCT') return flag_tensor.item() == 1 def _ddp_reduce_tensor_sum(tensor: Tensor) -> Tensor: # Happens in-place; that's fine - self.ddp.all_reduce(tensor, reduce_operation="SUM") + ddp.all_reduce(tensor, reduce_operation="SUM") return tensor state.scaler = ClosureGradScaler(ddp_reduce_scalar_and=_ddp_reduce_scalar_and, @@ -588,10 +588,10 @@ def _ddp_reduce_tensor_sum(tensor: Tensor) -> Tensor: assert isinstance(total_loss, Tensor) # total_loss can be None if gradient scaling failed - self.ddp.all_reduce(total_loss, reduce_operation="SUM") - self.ddp.barrier() + ddp.all_reduce(total_loss, reduce_operation="SUM") + ddp.barrier() full_loss = total_loss.cpu().item() - self.logger.metric_batch({'loss/train': full_loss / state.world_size}) + self.logger.metric_batch({'loss/train': full_loss / ddp.get_world_size()}) if self.compute_training_metrics: self._compute_and_log_metrics(train_metrics, is_train=True, is_batch=True) @@ -608,7 +608,6 @@ def _ddp_reduce_tensor_sum(tensor: Tensor) -> Tensor: self.checkpointer.save_checkpoint(state=state, seed=self.seed, device=self.device, - ddp=self.ddp, config=self.config) except BreakEpochException: log.info(f'Skipping the rest of Epoch {state.epoch}') @@ -622,11 +621,7 @@ def _ddp_reduce_tensor_sum(tensor: Tensor) -> Tensor: state.epoch += 1 if self.checkpointer and self.checkpointer.should_checkpoint(state=state, event=Event.EPOCH_END): - self.checkpointer.save_checkpoint(state=state, - seed=self.seed, - device=self.device, - ddp=self.ddp, - config=self.config) + self.checkpointer.save_checkpoint(state=state, seed=self.seed, device=self.device, config=self.config) self.engine.run_event(Event.TRAINING_END) @@ -673,7 +668,7 @@ def _train_batch_inner(self, microbatches: Sequence[Batch]): for microbatch_idx, state.batch in enumerate(microbatches): is_final_microbatch = microbatch_idx + 1 == len(microbatches) - with self.ddp.sync_context(state, is_final_microbatch): + with ddp.sync_context(state, is_final_microbatch, self.ddp_sync_strategy): last_microbatch_size = self._get_batch_size(state.batch) # forward pass diff --git a/composer/trainer/trainer_hparams.py b/composer/trainer/trainer_hparams.py index d818de6971..857a5e51ec 100755 --- a/composer/trainer/trainer_hparams.py +++ b/composer/trainer/trainer_hparams.py @@ -24,9 +24,8 @@ ModelHparams, ResNet18Hparams, ResNet50Hparams, ResNet101Hparams, UnetHparams) from composer.optim import (AdamHparams, 
AdamWHparams, DecoupledAdamWHparams, DecoupledSGDWHparams, OptimizerHparams, RAdamHparams, RMSPropHparams, SchedulerHparams, SGDHparams, scheduler) -from composer.trainer.ddp import DDPHparams from composer.trainer.devices import CPUDeviceHparams, DeviceHparams, GPUDeviceHparams -from composer.utils.ddp import get_world_size +from composer.utils import ddp if TYPE_CHECKING: from composer.trainer.trainer import Trainer @@ -143,7 +142,11 @@ class TrainerHparams(hp.Hparams): "Determines the number of microbatches to split a per-gpu batch into, used to compensate for low-memory-capacity devices." ) precision: Precision = hp.required(doc="Precision to use for training", template_default=Precision.AMP) - ddp: DDPHparams = hp.optional(doc="DDP configuration", default_factory=DDPHparams) + ddp_sync_strategy: Optional[ddp.DDPSyncStrategy] = hp.optional( + doc="The strategy for synchronizing DDP. Default value ``None`` causes the " + "trainer to auto-select a value depending on what algorithms are used.", + default=None) + ddp_timeout: float = hp.optional(doc="Timeout, in seconds, for initializing the DDP process group.", default=5.0) grad_clip_norm: Optional[float] = hp.optional( default=None, doc='the norm to clip gradient magnitudes to. Default: None (no clip)') @@ -178,12 +181,10 @@ class TrainerHparams(hp.Hparams): compute_training_metrics: bool = hp.optional(doc="Log validation metrics on training data", default=False) log_level: str = hp.optional(doc="Python loglevel to use composer", default="INFO") - ddp_sync_strategy: Optional[str] = hp.optional(doc="Strategy for DDP syncing", default=None) - def validate(self): super().validate() - world_size = get_world_size() + world_size = ddp.get_world_size() if self.total_batch_size % world_size != 0: raise ValueError( diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index 3d3d26ffcc..813fd6d4e6 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -1,11 +1,7 @@ # Copyright 2021 MosaicML. All Rights Reserved. from composer.utils import augmentation_primitives as augmentation_primitives -from composer.utils.ddp import get_global_rank as get_global_rank -from composer.utils.ddp import get_local_rank as get_local_rank -from composer.utils.ddp import get_local_world_size as get_local_world_size -from composer.utils.ddp import get_world_size as get_world_size -from composer.utils.ddp import is_rank_zero as is_rank_zero +from composer.utils import ddp as ddp from composer.utils.determinism import get_random_seed as get_random_seed from composer.utils.determinism import seed_all as seed_all from composer.utils.iter_helpers import ensure_tuple as ensure_tuple diff --git a/composer/utils/ddp.py b/composer/utils/ddp.py index 0f2d838a79..69c8f77456 100644 --- a/composer/utils/ddp.py +++ b/composer/utils/ddp.py @@ -1,10 +1,57 @@ # Copyright 2021 MosaicML. All Rights Reserved. 
+from __future__ import annotations + +import collections.abc +import datetime import os import warnings -from typing import Optional +from contextlib import contextmanager, nullcontext +from typing import TYPE_CHECKING, Callable, ContextManager, List, Optional, Sequence, TypeVar, Union, cast +import torch import torch.distributed as dist +import torch.utils.data +from torch.nn.parallel import DistributedDataParallel + +from composer.utils.iter_helpers import ensure_tuple +from composer.utils.string_enum import StringEnum + +if TYPE_CHECKING: + from composer.core.state import State + from composer.core.types import DataLoader, Model + from composer.datasets.dataloader import DataloaderHparams, DataloaderSpec + +TObj = TypeVar("TObj") + + +class DDPSyncStrategy(StringEnum): + """How and when DDP gradient synchronization should happen. + + Attributes: + SINGLE_AUTO_SYNC: The default behavior for DDP. Gradients are synchronized as they + computed, for only the final microbatch of a batch. This is the most efficient + strategy, but can lead to errors when ``find_unused_parameters`` is set, since + it is possible different microbatches may use different sets of parameters, + leading to an incomplete sync. + MULTI_AUTO_SYNC: The default behavior for DDP when ``find_unused_parameters`` is set. + Gradients are synchronized as they are computed for all microbatches. This ensures + complete synchronization, but is less efficient than :attr:`SINGLE_AUTO_SYNC`. This + efficiency gap is usually small, as long as either DDP syncs are a small portion + of the trainer's overall runtime, or the number of microbatches per batch is + relatively small. + FORCED_SYNC: Gradients are manually synchronized only after all gradients have been + computed for the final microbatch of a batch. Like :attr:`MULTI_AUTO_SYNC`, this + strategy ensures complete gradient synchronization, but this tends to be slower than + :attr:`MULTI_AUTO_SYNC`. This is because ordinarily syncs can happen in parallel + with the ``loss.backward()`` computation, meaning syncs can be mostly complete by + the time that function finishes. However, in certain circumstances, syncs may take + a very long time to complete - if there are also a lot of microbatches per batch, + this strategy may be optimal. + """ + SINGLE_AUTO_SYNC = "single_auto_sync" + MULTI_AUTO_SYNC = "multi_auto_sync" + FORCED_SYNC = "forced_sync" def _get_distributed_config_var(env_var: str, @@ -31,6 +78,11 @@ def _get_distributed_config_var(env_var: str, def get_world_size() -> int: + """Returns the DDP world size + + Returns: + int: The world size + """ return _get_distributed_config_var(env_var="WORLD_SIZE", human_name="world size", default=1, @@ -38,18 +90,170 @@ def get_world_size() -> int: def get_global_rank() -> int: + """Returns the global rank of the current process, which is on `[0, WORLD_SIZE - 1]` + + Returns: + int: The global rank + """ return _get_distributed_config_var(env_var="RANK", human_name="global rank", default=0, fetch_fn_name="get_rank") def get_local_world_size() -> int: + """Returns the local world size, which is the number of processes for the current node. 
+ + Returns: + int: The local world size + """ return _get_distributed_config_var(env_var="LOCAL_WORLD_SIZE", human_name="local world size", default=1) def get_local_rank() -> int: + """Returns the local rank for the current process, which is on `[0, LOCAL_WORLD_SIZE - 1]` + + Returns: + int: The local world size + """ local_rank = _get_distributed_config_var(env_var="LOCAL_RANK", human_name="local rank", default=0) assert local_rank == get_global_rank() % get_local_world_size(), "invariant violation" return local_rank -def is_rank_zero() -> bool: - return get_global_rank() == 0 +def barrier() -> None: + if dist.is_available(): + dist.barrier() + # If not on DDP, then do nothing + + +def all_reduce( + tensor: torch.Tensor, + reduce_operation: str = "SUM", +) -> None: + if dist.is_available(): + reduce_op = getattr(dist.ReduceOp, reduce_operation.upper()) + dist.all_reduce(tensor, op=reduce_op) + else: + raise NotImplementedError("Non-DDP versions of reduce operations are not yet implemented") + + +def all_gather(tensor: torch.Tensor) -> Sequence[torch.Tensor]: + """gather_to_rank_zero collects a tensor from each rank, and returns a sequence of tensors indexed by rank + + Args: + tensor (torch.Tensor): tensor from each rank to be gathered + + Returns: + Sequence[Tensor]: A sequence of tensors indexed by rank + """ + if dist.is_available(): + obj_gather_list = [torch.zeros_like(tensor) for _ in range(get_world_size())] + dist.all_gather(obj_gather_list, tensor) + return obj_gather_list + else: + return [tensor] + + +def all_gather_object(obj: TObj) -> List[TObj]: + """gather_object_to_rank_zero collects a pickleable object from each rank, and returns a list of + these objects indexed by rank + + Args: + obj (TObj): Object to be gathered + + Returns: + List[TObj]: A list of objects indexed by rank + """ + if dist.is_available(): + obj_gather_list = [None for _ in range(get_world_size())] + dist.all_gather_object(obj_gather_list, obj) + # torch.distributed will replace the None's in obj_gather_list with the gathered objects on rank 0 + # or will just be None on non-rank-0 + return cast(List[TObj], obj_gather_list) + else: + return [obj] + + +def initialize_ddp(backend: str, timeout: datetime.timedelta): + if not dist.is_available(): + return + if dist.is_initialized(): + + if not dist.get_backend() == backend.lower(): + raise RuntimeError( + f"The requested backend ({backend}) differs from the backend " + "of the current process group ({torch.distributed.get_backend()}). If you wish to change backends, " + "please restart the python process.") + return + + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + # Assume we can initialize based off of env vars + dist.init_process_group(backend, timeout=timeout) + return + + warnings.warn("NoDDPWarning: RANK and WORLD_SIZE env vars not set; assuming no parallelization. 
" + "If this is unexpected, make sure you are running your training script with the " + "composer executable.") + store = dist.HashStore() + + dist.init_process_group(backend, timeout=timeout, store=store, world_size=1, rank=0) + + +def prepare_module(module: Model, find_unused_parameters: bool) -> Model: + if dist.is_available(): + if any((p.requires_grad for p in module.parameters())): + ddp_model = DistributedDataParallel(module, find_unused_parameters=find_unused_parameters) + return ddp_model + return module + else: + return module + + +def create_dataloader(batch_size: int, dataloader_hparams: DataloaderHparams, + dataloader_spec: DataloaderSpec) -> DataLoader: + # TODO(ravi) refactor this function to return a sampler rather than create the dataloader + from composer.datasets.dataloader import DDPDataLoader + if dist.is_available(): + sampler = torch.utils.data.DistributedSampler[int](dataloader_spec.dataset, + drop_last=dataloader_spec.drop_last, + shuffle=dataloader_spec.shuffle) + else: + assert isinstance(dataloader_spec.dataset, collections.abc.Sized) + sampler = torch.utils.data.RandomSampler(dataloader_spec.dataset, generator=dataloader_spec.generator) + dataloader = dataloader_hparams.initialize_object(batch_size, sampler, dataloader_spec) + if dist.is_available(): + dataloader = DDPDataLoader(dataloader) + return dataloader + + +@contextmanager +def sync_context(state: State, is_final_microbatch: bool, sync_strategy: Union[str, DDPSyncStrategy]): + assert isinstance(state.model, DistributedDataParallel), "state.model is not wrapped by DDP" + assert state.optimizers is not None, "optimizers have not been initialized" + sync_strategy = DDPSyncStrategy(sync_strategy) + + no_sync_context = cast(Callable[[], ContextManager], state.model.no_sync) + auto_sync_context = nullcontext + + if sync_strategy == DDPSyncStrategy.SINGLE_AUTO_SYNC: + context = auto_sync_context if is_final_microbatch else no_sync_context + with context(): + yield + + elif sync_strategy == DDPSyncStrategy.MULTI_AUTO_SYNC: + with auto_sync_context(): + yield + + elif sync_strategy == DDPSyncStrategy.FORCED_SYNC: + try: + with no_sync_context(): + yield + finally: + if is_final_microbatch: + for optimizer in ensure_tuple(state.optimizers): + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is not None: + all_reduce(p.grad) + p.grad = p.grad / get_world_size() + + else: + raise ValueError("Unknown sync strategy", sync_strategy) diff --git a/tests/test_logger.py b/tests/test_logger.py index 2688d79ab6..2c6c5ed711 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -11,7 +11,7 @@ from composer.core.state import State from composer.loggers.file_logger import FileLoggerBackend from composer.loggers.logger_hparams import FileLoggerBackendHparams -from composer.utils.ddp import is_rank_zero +from composer.utils import ddp @pytest.fixture @@ -73,7 +73,7 @@ def test_deferred(self, dummy_state_without_rank: State, log_file_name: str, log logger.metric_batch({"metric": "after_training_start"}) log_destination.batch_end(dummy_state, logger) log_destination.training_end(dummy_state, logger) - if is_rank_zero(): + if ddp.get_global_rank() == 0: with open(log_file_name, 'r') as f: assert f.readlines() == [ '[BATCH][step=2]: { "metric": "before_training_start", }\n', diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index a72cbd9481..0eee3980c6 100755 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -19,7 +19,7 @@ from 
composer.trainer.devices import CPUDeviceHparams, DeviceHparams, GPUDeviceHparams from composer.trainer.trainer import Trainer from composer.trainer.trainer_hparams import TrainerHparams, callback_registry -from composer.utils.ddp import is_rank_zero +from composer.utils import ddp from tests.test_state import assert_state_equivalent from tests.utils.deep_compare import deep_compare @@ -175,7 +175,7 @@ def test_checkpoint( checkpoint_c_file_path = os.path.join(checkpoint_b_folder, final_checkpoint) trainer_2_hparams_filepath = os.path.join(checkpoint_b_folder, "hparams.yaml") - if is_rank_zero(): + if ddp.get_global_rank() == 0: assert_checkpoints_equivalent( hparams_file_a=trainer_1_hparams_filepath, diff --git a/tests/trainer/test_ddp.py b/tests/trainer/test_ddp.py index 4fdadcbadb..2e3226e619 100755 --- a/tests/trainer/test_ddp.py +++ b/tests/trainer/test_ddp.py @@ -21,6 +21,7 @@ from composer.trainer.devices import CPUDeviceHparams, GPUDeviceHparams from composer.trainer.devices.device_hparams import DeviceHparams from composer.trainer.trainer_hparams import TrainerHparams, callback_registry, dataset_registry +from composer.utils import ddp from tests.fixtures.models import SimpleBatchPairModelHparams @@ -228,7 +229,7 @@ def test_ddp(device: DeviceHparams, world_size: int, ddp_tmpdir: str, mosaic_tra is_train_to_pickles: Dict[bool, List[Dict[str, types.Tensor]]] = {True: [], False: []} for epoch in range(num_epochs): - for local_rank in range(trainer.state.local_world_size): + for local_rank in range(ddp.get_local_world_size()): for is_train in (True, False): data: Dict[str, types.Tensor] = torch.load( # type: ignore get_batch_file_path(ddp_tmpdir, rank=local_rank, epoch=epoch, is_train=is_train), diff --git a/tests/trainer/test_ddp_sync_strategy.py b/tests/trainer/test_ddp_sync_strategy.py index 9436240794..bcf6af9eab 100755 --- a/tests/trainer/test_ddp_sync_strategy.py +++ b/tests/trainer/test_ddp_sync_strategy.py @@ -8,7 +8,7 @@ from composer.core.state import State from composer.core.types import DataLoader, Tensor -from composer.trainer.ddp import DDP +from composer.utils import ddp class MinimalConditionalModel(nn.Module): @@ -47,7 +47,7 @@ def loss(self, output: Tensor, target: Tensor): def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional[float]], dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader): original_model = MinimalConditionalModel() - ddp = DDP(backend="gloo", find_unused_parameters=True, sync_strategy=ddp_sync_strategy, timeout=5.) + # ddp = DDP(backend="gloo", find_unused_parameters=True, sync_strategy=ddp_sync_strategy, timeout=5.) 
optimizer = torch.optim.SGD(original_model.parameters(), 0.1) state = State(model=original_model, @@ -61,24 +61,24 @@ def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional precision='fp32') batches = [[(1, Tensor([1])), (1, Tensor([2]))], [(2, Tensor([1])), (2, Tensor([2]))]] - state.model = ddp.prepare_module(state.model) + state.model = ddp.prepare_module(state.model, find_unused_parameters=True) optimizer.zero_grad() for microbatch_idx in range(2): - with ddp.sync_context(state, microbatch_idx == 1): - input, target = batches[microbatch_idx][state.local_rank] + with ddp.sync_context(state, microbatch_idx == 1, sync_strategy=ddp_sync_strategy): + input, target = batches[microbatch_idx][ddp.get_local_rank()] output = state.model.forward(input) loss = original_model.loss(output, target) loss.mul_(1 / 2) loss.backward() - if state.is_rank_zero: + if ddp.get_global_rank() == 0: grads = [p.grad.item() if p.grad else None for p in original_model.parameters()] for expected, actual in zip(expected_grads[microbatch_idx], grads): # type: ignore assert expected == actual - if state.is_rank_zero: + if ddp.get_global_rank() == 0: grads = [p.grad.item() if p.grad else None for p in original_model.parameters()] for expected, actual in zip(expected_grads[-1], grads): # type: ignore assert expected == actual diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a0581bf763..adc40ff2fb 100755 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -87,7 +87,7 @@ def test_trainer_determinism(mosaic_trainer_hparams: TrainerHparams): first_model = first_trainer.state.model.module assert isinstance(first_model, BaseMosaicModel) assert first_trainer.state.train_dataloader is not None - first_loss = get_total_loss(first_model, first_trainer.state.train_dataloader, first_trainer.ddp) + first_loss = get_total_loss(first_model, first_trainer.state.train_dataloader) # Second trainer must be created after fitting the first so that the # seeds get fully reset for the second training run @@ -96,7 +96,7 @@ def test_trainer_determinism(mosaic_trainer_hparams: TrainerHparams): second_model = second_trainer.state.model.module assert isinstance(second_model, BaseMosaicModel) assert second_trainer.state.train_dataloader is not None - second_loss = get_total_loss(second_model, second_trainer.state.train_dataloader, second_trainer.ddp) + second_loss = get_total_loss(second_model, second_trainer.state.train_dataloader) torch.testing.assert_allclose(second_loss, first_loss) diff --git a/tests/utils/dataloader.py b/tests/utils/dataloader.py index 6b5e3ca299..ff8b2cd847 100644 --- a/tests/utils/dataloader.py +++ b/tests/utils/dataloader.py @@ -3,8 +3,7 @@ import torch.utils.data from composer.core.types import DataLoader -from composer.datasets import DataloaderHparams, DataloaderSpec -from composer.trainer.ddp import DDPDataLoader +from composer.datasets import DataloaderHparams, DataloaderSpec, DDPDataLoader def get_dataloader(dataloader_spec: DataloaderSpec, dataloader_hparams: DataloaderHparams, diff --git a/tests/utils/trainer_fit.py b/tests/utils/trainer_fit.py index 75851ba864..862d5bc194 100644 --- a/tests/utils/trainer_fit.py +++ b/tests/utils/trainer_fit.py @@ -9,14 +9,13 @@ from composer.models.base import BaseMosaicModel from composer.models.classify_mnist.mnist_hparams import MnistClassifierHparams from composer.optim.optimizer_hparams import SGDHparams -from composer.trainer.ddp import DDP from composer.trainer.devices.device_gpu import 
DeviceGPU from composer.trainer.trainer import Trainer from composer.trainer.trainer_hparams import TrainerHparams -from composer.utils import ensure_tuple +from composer.utils import ddp, ensure_tuple -def get_total_loss(model: BaseMosaicModel, dataloader: DataLoader, ddp: DDP): +def get_total_loss(model: BaseMosaicModel, dataloader: DataLoader): with torch.no_grad(): total_loss = 0 for batch in dataloader: @@ -27,7 +26,7 @@ def get_total_loss(model: BaseMosaicModel, dataloader: DataLoader, ddp: DDP): total_loss_tensor = torch.Tensor([total_loss]) ddp.all_reduce(total_loss_tensor) - return total_loss_tensor.item() / ddp.world_size + return total_loss_tensor.item() / ddp.get_world_size() def train_model(mosaic_trainer_hparams: TrainerHparams, max_epochs: int = 2, run_loss_check: bool = False): @@ -69,10 +68,10 @@ def train_model(mosaic_trainer_hparams: TrainerHparams, max_epochs: int = 2, run original_model = trainer.device.module_to_device(original_model) if run_loss_check and trainer.state.train_dataloader: - initial_loss = get_total_loss(original_model, trainer.state.train_dataloader, trainer.ddp) + initial_loss = get_total_loss(original_model, trainer.state.train_dataloader) unwrapped_model = trainer.state.model.module assert isinstance(unwrapped_model, BaseMosaicModel) - post_fit_loss = get_total_loss(unwrapped_model, trainer.state.train_dataloader, trainer.ddp) + post_fit_loss = get_total_loss(unwrapped_model, trainer.state.train_dataloader) assert post_fit_loss < initial_loss + 1e-5 From f706ff8c7939cd89225750ca0f0a09f7cd6d0251 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 23 Nov 2021 15:52:19 -0800 Subject: [PATCH 21/38] Added in DDP barrier --- composer/callbacks/run_directory_uploader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/composer/callbacks/run_directory_uploader.py b/composer/callbacks/run_directory_uploader.py index 81586e69e0..fca918db74 100644 --- a/composer/callbacks/run_directory_uploader.py +++ b/composer/callbacks/run_directory_uploader.py @@ -20,6 +20,7 @@ from composer.core.logging import Logger from composer.core.logging.logger import LogLevel from composer.core.state import State +from composer.utils import ddp from composer.utils.run_directory import get_run_directory log = logging.getLogger(__name__) @@ -175,11 +176,11 @@ def _run_event(self, event: Event, state: State, logger: Logger) -> None: atexit.register(self._close) if event == Event.BATCH_END: if (state.batch_idx + 1) % self._upload_every_n_batches == 0: - self._trigger_upload(state, logger, LogLevel.BATCH) + self._trigger_upload(logger, LogLevel.BATCH) if event == Event.EPOCH_END: - self._trigger_upload(state, logger, LogLevel.EPOCH) + self._trigger_upload(logger, LogLevel.EPOCH) if event == Event.TRAINING_END: - self._trigger_upload(state, logger, LogLevel.FIT) + self._trigger_upload(logger, LogLevel.FIT) # TODO -- we are missing logfiles from other callbacks / loggers that write on training end but after # the run directory uploader is invoked. 
This callback either needs to fire last, # or we need another event such as cleanup @@ -191,12 +192,11 @@ def _close(self): for worker in self._workers: worker.join() - def _trigger_upload(self, state: State, logger: Logger, log_level: LogLevel) -> None: + def _trigger_upload(self, logger: Logger, log_level: LogLevel) -> None: # Ensure that every rank is at this point # Assuming only the main thread on each rank writes to the run directory, then the barrier here will ensure # that the run directory is not being modified after we pass this barrier - # TODO(ravi) -- add in a ddp barrier here. - # state.ddp.barrier() + ddp.barrier() new_last_uploaded_timestamp = time.time() # Now, for each file that was modified since self._last_upload_timestamp, copy it to the temporary directory # IMPROTANT: From now, until self._last_upload_timestamp is updated, no files should be written to the run directory From d7841ca7f07e6979538337df2dd3209480f15b42 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 23 Nov 2021 15:55:34 -0800 Subject: [PATCH 22/38] Fixed tests --- tests/callbacks/test_run_directory_uploader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_run_directory_uploader.py b/tests/callbacks/test_run_directory_uploader.py index 26dccfc8ad..271de33ba9 100644 --- a/tests/callbacks/test_run_directory_uploader.py +++ b/tests/callbacks/test_run_directory_uploader.py @@ -37,5 +37,5 @@ def test_run_directory_uploader(tmpdir: pathlib.Path, use_procs: bool, dummy_sta uploader.run_event(Event.TRAINING_END, dummy_state, dummy_logger) # now assert that we have a dummy file in the run directory copy folder - with open(os.path.join(remote_dir, "dummy_file"), "r") as f: + with open(os.path.join(remote_dir, run_directory, "dummy_file"), "r") as f: assert f.read() == "Hello, world!" 
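The `ddp.barrier()` call added in this patch is what guarantees that no rank is still writing to the run directory when the copy begins. For illustration only, a minimal sketch of that coordination pattern using the module-level helpers introduced in this series (`composer.utils.ddp.barrier` and `ddp.get_global_rank`); `upload_files` is a hypothetical stand-in for the real upload logic and is not part of the patch:

import os
from typing import Callable

from composer.utils import ddp


def upload_run_directory(run_directory: str, upload_files: Callable[[str], None]) -> None:
    # All ranks must reach this point before any files are read, so no rank is still
    # writing to the run directory once the copy/upload starts.
    ddp.barrier()
    # To keep the sketch small, only rank zero performs the upload here.
    if ddp.get_global_rank() == 0:
        upload_files(run_directory)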
From 9e1206183296cfb80655c1889b3e45c17be177cc Mon Sep 17 00:00:00 2001 From: Jamie Bloxham Date: Mon, 29 Nov 2021 16:44:07 -0800 Subject: [PATCH 23/38] Update composer/utils/ddp.py --- composer/utils/ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/utils/ddp.py b/composer/utils/ddp.py index 69c8f77456..2a1a08f3e0 100644 --- a/composer/utils/ddp.py +++ b/composer/utils/ddp.py @@ -90,7 +90,7 @@ def get_world_size() -> int: def get_global_rank() -> int: - """Returns the global rank of the current process, which is on `[0, WORLD_SIZE - 1]` + """Returns the global rank of the current process, which is in `[0, WORLD_SIZE - 1]` Returns: int: The global rank From 87c044130d545d4df11a2c009df475b78ba10134 Mon Sep 17 00:00:00 2001 From: Jamie Bloxham Date: Mon, 29 Nov 2021 16:44:11 -0800 Subject: [PATCH 24/38] Update composer/utils/ddp.py --- composer/utils/ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/utils/ddp.py b/composer/utils/ddp.py index 2a1a08f3e0..dab481c9e9 100644 --- a/composer/utils/ddp.py +++ b/composer/utils/ddp.py @@ -108,7 +108,7 @@ def get_local_world_size() -> int: def get_local_rank() -> int: - """Returns the local rank for the current process, which is on `[0, LOCAL_WORLD_SIZE - 1]` + """Returns the local rank for the current process, which is in `[0, LOCAL_WORLD_SIZE - 1]` Returns: int: The local world size From 6a6427ef5fa95d585fbc51ae99eb9f82bf97268d Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 30 Nov 2021 08:28:52 -0800 Subject: [PATCH 25/38] Switched tqdm to using callback hooks Added test case for TQDM --- composer/loggers/tqdm_logger.py | 91 ++++++++++++++++++++----------- tests/callbacks/test_callbacks.py | 4 +- tests/test_logger.py | 34 +++++++++++- 3 files changed, 93 insertions(+), 36 deletions(-) diff --git a/composer/loggers/tqdm_logger.py b/composer/loggers/tqdm_logger.py index 52af518790..2658b293d3 100644 --- a/composer/loggers/tqdm_logger.py +++ b/composer/loggers/tqdm_logger.py @@ -6,10 +6,9 @@ from dataclasses import asdict, dataclass, field from typing import TYPE_CHECKING, Any, Dict, Optional +import tqdm import yaml -from tqdm import tqdm -from composer.core.event import Event from composer.core.logging import LogLevel, RankZeroLoggerBackend, TLogData, TLogDataValue, format_log_data_value from composer.core.state import State from composer.core.types import StateDict @@ -44,11 +43,11 @@ def __init__(self, epoch_metrics=(epoch_metrics or {})) desc = f'Epoch {epoch + 1}{"" if is_train else " (val)"}' position = 0 if is_train else 1 - self.pbar = tqdm(total=total, - desc=desc, - position=position, - initial=n, - bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}") + self.pbar = tqdm.tqdm(total=total, + desc=desc, + position=position, + initial=n, + bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}") self.pbar.set_postfix(epoch_metrics) def log_metric(self, data: TLogData): @@ -107,32 +106,58 @@ def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData assert self.is_train is not None self.pbars[self.is_train].log_metric(data) - def _run_event(self, event: Event, state: State, logger: Logger) -> None: - if event == Event.INIT: - if self.config is not None: - print("Config") - print("-" * 30) - yaml.safe_dump(self.config, stream=sys.stdout) - print("-" * 30) - print() - if event in (Event.EPOCH_START, Event.EVAL_START): - self.is_train = event == Event.EPOCH_START - assert state.train_dataloader is not None - assert state.eval_dataloader is not None - total_steps = 
len(state.train_dataloader) if self.is_train else len(state.eval_dataloader) - self.pbars[self.is_train] = _TQDMLoggerInstance(total=total_steps, - epoch=state.epoch, - is_train=self.is_train) - if event in (Event.AFTER_BACKWARD, Event.EVAL_AFTER_FORWARD): - if self.is_train in self.pbars: - assert self.is_train is not None - self.pbars[self.is_train].update() - if event in (Event.EPOCH_END, Event.EVAL_END): - if self.is_train in self.pbars: - assert self.is_train is not None - self.pbars[self.is_train].close() - del self.pbars[self.is_train] - self.is_train = None + def init(self, state: State, logger: Logger) -> None: + del state, logger # unused + if self.config is not None: + print("Config") + print("-" * 30) + yaml.safe_dump(self.config, stream=sys.stdout) + print("-" * 30) + print() + + def start(self, state: State): + assert state.train_dataloader is not None + assert state.eval_dataloader is not None + total_steps = len(state.train_dataloader) if self.is_train else len(state.eval_dataloader) + self.pbars[self.is_train] = _TQDMLoggerInstance(total=total_steps, epoch=state.epoch, is_train=self.is_train) + + def epoch_start(self, state: State, logger: Logger) -> None: + del logger # unused + self.is_train = True + self.start(state) + + def eval_start(self, state: State, logger: Logger) -> None: + del logger # unused + self.is_train = False + self.start(state) + + def update(self): + if self.is_train in self.pbars: + assert self.is_train is not None + self.pbars[self.is_train].update() + + def after_backward(self, state: State, logger: Logger) -> None: + del state, logger # unused + self.update() + + def eval_after_forward(self, state: State, logger: Logger) -> None: + del state, logger # unused + self.update() + + def end(self): + if self.is_train in self.pbars: + assert self.is_train is not None + self.pbars[self.is_train].close() + del self.pbars[self.is_train] + self.is_train = None + + def epoch_end(self, state: State, logger: Logger) -> None: + del state, logger # unused + self.end() + + def eval_end(self, state: State, logger: Logger) -> None: + del state, logger # unused + self.end() def state_dict(self) -> StateDict: return { diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 92c1eef557..0f6d1a4abe 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -1,6 +1,5 @@ # Copyright 2021 MosaicML. All Rights Reserved. 
-import _pytest.monkeypatch import pytest from composer.core import Event @@ -27,11 +26,12 @@ def __init__(self) -> None: self.event = None def _run_event(self, event: Event, state: State, logger: Logger) -> None: + del state, logger # unused self.event = event @pytest.mark.parametrize('event', list(Event)) -def test_run_event_callbacks(event: Event, dummy_state: State, monkeypatch: _pytest.monkeypatch.MonkeyPatch): +def test_run_event_callbacks(event: Event, dummy_state: State): callback = EventTrackerCallback() logger = Logger(dummy_state) engine = Engine(state=dummy_state, algorithms=[], logger=logger, callbacks=[callback]) diff --git a/tests/test_logger.py b/tests/test_logger.py index f9265ab477..7feb8b1e82 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -2,16 +2,19 @@ import os import pathlib +from unittest.mock import MagicMock import pytest import torch.distributed as dist +import tqdm from _pytest.monkeypatch import MonkeyPatch from composer.core.event import Event from composer.core.logging import Logger, LogLevel from composer.core.state import State from composer.loggers.file_logger import FileLoggerBackend -from composer.loggers.logger_hparams import FileLoggerBackendHparams +from composer.loggers.logger_hparams import FileLoggerBackendHparams, TQDMLoggerBackendHparams +from composer.trainer.trainer_hparams import TrainerHparams @pytest.fixture @@ -58,3 +61,32 @@ def test_file_logger(dummy_state: State, log_destination: FileLoggerBackend, mon '[BATCH][step=2]: { "metric": "batch", }\n', '[EPOCH][step=3]: { "metric": "epoch2", }\n', ] + + +def test_tqdm_logger(mosaic_trainer_hparams: TrainerHparams, monkeypatch: MonkeyPatch): + is_train_to_mock_tqdms = { + True: [], + False: [], + } + + def get_mock_tqdm(position: int, *args, **kwargs): + del args, kwargs # unused + is_train = position == 0 + mock_tqdm = MagicMock() + is_train_to_mock_tqdms[is_train].append(mock_tqdm) + return mock_tqdm + + monkeypatch.setattr(tqdm, "tqdm", get_mock_tqdm) + mosaic_trainer_hparams.loggers = [TQDMLoggerBackendHparams()] + trainer = mosaic_trainer_hparams.initialize_object() + trainer.fit() + assert len(is_train_to_mock_tqdms[True]) == mosaic_trainer_hparams.max_epochs + assert mosaic_trainer_hparams.validate_every_n_batches < 0 + assert len(is_train_to_mock_tqdms[False] + ) == mosaic_trainer_hparams.validate_every_n_epochs * mosaic_trainer_hparams.max_epochs + for mock_tqdm in is_train_to_mock_tqdms[True]: + assert mock_tqdm.update.call_count == trainer.state.steps_per_epoch + mock_tqdm.close.assert_called_once() + for mock_tqdm in is_train_to_mock_tqdms[False]: + assert mock_tqdm.update.call_count == len(trainer.state.eval_dataloader) + mock_tqdm.close.assert_called_once() From eb9def858ffc88c80333e848234ea9f75a835d57 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 30 Nov 2021 08:31:47 -0800 Subject: [PATCH 26/38] Fixed pyright --- composer/loggers/tqdm_logger.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/composer/loggers/tqdm_logger.py b/composer/loggers/tqdm_logger.py index 2658b293d3..ec3736a4e2 100644 --- a/composer/loggers/tqdm_logger.py +++ b/composer/loggers/tqdm_logger.py @@ -115,36 +115,35 @@ def init(self, state: State, logger: Logger) -> None: print("-" * 30) print() - def start(self, state: State): - assert state.train_dataloader is not None - assert state.eval_dataloader is not None + def _start(self, state: State): + assert self.is_train is not None, "self.is_train should be set by the callback" total_steps = 
len(state.train_dataloader) if self.is_train else len(state.eval_dataloader) self.pbars[self.is_train] = _TQDMLoggerInstance(total=total_steps, epoch=state.epoch, is_train=self.is_train) def epoch_start(self, state: State, logger: Logger) -> None: del logger # unused self.is_train = True - self.start(state) + self._start(state) def eval_start(self, state: State, logger: Logger) -> None: del logger # unused self.is_train = False - self.start(state) + self._start(state) - def update(self): + def _update(self): if self.is_train in self.pbars: assert self.is_train is not None self.pbars[self.is_train].update() def after_backward(self, state: State, logger: Logger) -> None: del state, logger # unused - self.update() + self._update() def eval_after_forward(self, state: State, logger: Logger) -> None: del state, logger # unused - self.update() + self._update() - def end(self): + def _end(self): if self.is_train in self.pbars: assert self.is_train is not None self.pbars[self.is_train].close() @@ -153,11 +152,11 @@ def end(self): def epoch_end(self, state: State, logger: Logger) -> None: del state, logger # unused - self.end() + self._end() def eval_end(self, state: State, logger: Logger) -> None: del state, logger # unused - self.end() + self._end() def state_dict(self) -> StateDict: return { From b8863dabe71e7dee2d078073b08fd7c1843a8ad9 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 30 Nov 2021 09:04:08 -0800 Subject: [PATCH 27/38] Fixed DDP barriers --- composer/utils/ddp.py | 60 ++++++++++++++++++++++++++++++++----------- 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/composer/utils/ddp.py b/composer/utils/ddp.py index 69c8f77456..1c69224dfd 100644 --- a/composer/utils/ddp.py +++ b/composer/utils/ddp.py @@ -119,24 +119,33 @@ def get_local_rank() -> int: def barrier() -> None: - if dist.is_available(): + if dist.is_available() and dist.is_initialized(): dist.barrier() - # If not on DDP, then do nothing + return + world_size = get_world_size() + if world_size == 1: + return + raise RuntimeError(f"Since the world_size({world_size}) > 1, please configure DDP to use ddp.barrier(). " + "The mosaic trainer will automatically do this for you.") def all_reduce( tensor: torch.Tensor, reduce_operation: str = "SUM", ) -> None: - if dist.is_available(): + if dist.is_available() and dist.is_initialized(): reduce_op = getattr(dist.ReduceOp, reduce_operation.upper()) dist.all_reduce(tensor, op=reduce_op) - else: - raise NotImplementedError("Non-DDP versions of reduce operations are not yet implemented") + return + world_size = get_world_size() + if world_size == 1: + return + raise RuntimeError(f"Since the world_size({world_size}) > 1, please configure DDP to use ddp.all_reduce(). 
" + "The mosaic trainer will automatically do this for you.") def all_gather(tensor: torch.Tensor) -> Sequence[torch.Tensor]: - """gather_to_rank_zero collects a tensor from each rank, and returns a sequence of tensors indexed by rank + """all_gather collects a tensor from each rank, and returns a sequence of tensors indexed by rank Args: tensor (torch.Tensor): tensor from each rank to be gathered @@ -144,16 +153,19 @@ def all_gather(tensor: torch.Tensor) -> Sequence[torch.Tensor]: Returns: Sequence[Tensor]: A sequence of tensors indexed by rank """ - if dist.is_available(): + if dist.is_available() and dist.is_initialized(): obj_gather_list = [torch.zeros_like(tensor) for _ in range(get_world_size())] dist.all_gather(obj_gather_list, tensor) return obj_gather_list - else: + world_size = get_world_size() + if world_size == 1: return [tensor] + raise RuntimeError(f"Since the world_size({world_size}) > 1, please configure DDP to use ddp.all_gather(). " + "The mosaic trainer will automatically do this for you.") def all_gather_object(obj: TObj) -> List[TObj]: - """gather_object_to_rank_zero collects a pickleable object from each rank, and returns a list of + """all_gather_object collects a pickleable object from each rank, and returns a list of these objects indexed by rank Args: @@ -168,12 +180,18 @@ def all_gather_object(obj: TObj) -> List[TObj]: # torch.distributed will replace the None's in obj_gather_list with the gathered objects on rank 0 # or will just be None on non-rank-0 return cast(List[TObj], obj_gather_list) - else: + world_size = get_world_size() + if world_size == 1: return [obj] + raise RuntimeError(f"Since the world_size({world_size}) > 1, please configure DDP to use ddp.all_gather_object(). " + "The mosaic trainer will automatically do this for you.") def initialize_ddp(backend: str, timeout: datetime.timedelta): if not dist.is_available(): + if get_world_size() != 1: + raise RuntimeError("When the world size is > 1, DDP must be used. However, it is not available in your " + "installation of PyTorch. Please install or build PyTorch with DDP support.") return if dist.is_initialized(): @@ -198,26 +216,38 @@ def initialize_ddp(backend: str, timeout: datetime.timedelta): def prepare_module(module: Model, find_unused_parameters: bool) -> Model: - if dist.is_available(): + if dist.is_available() and dist.is_initialized(): if any((p.requires_grad for p in module.parameters())): ddp_model = DistributedDataParallel(module, find_unused_parameters=find_unused_parameters) return ddp_model return module - else: + if get_world_size() == 1: return module + if dist.is_available(): + raise RuntimeError("Please call ddp.initialize_ddp() before calling ddp.prepare_module()") + raise RuntimeError("When the world size is > 1, DDP must be used. However, it is not available in your " + "installation of PyTorch. 
Please install or build PyTorch with DDP support.") def create_dataloader(batch_size: int, dataloader_hparams: DataloaderHparams, dataloader_spec: DataloaderSpec) -> DataLoader: # TODO(ravi) refactor this function to return a sampler rather than create the dataloader from composer.datasets.dataloader import DDPDataLoader - if dist.is_available(): + if dist.is_available() and dist.is_initialized(): sampler = torch.utils.data.DistributedSampler[int](dataloader_spec.dataset, drop_last=dataloader_spec.drop_last, shuffle=dataloader_spec.shuffle) - else: + elif get_world_size() == 1: assert isinstance(dataloader_spec.dataset, collections.abc.Sized) - sampler = torch.utils.data.RandomSampler(dataloader_spec.dataset, generator=dataloader_spec.generator) + if dataloader_spec.shuffle: + sampler = torch.utils.data.RandomSampler(dataloader_spec.dataset, generator=dataloader_spec.generator) + else: + sampler = torch.utils.data.SequentialSampler(dataloader_spec.dataset) + else: + if dist.is_available(): + raise RuntimeError("Please call ddp.initialize_ddp() before calling ddp.create_dataloader()") + raise RuntimeError("When the world size is > 1, DDP must be used. However, it is not available in your " + "installation of PyTorch. Please install or build PyTorch with DDP support.") dataloader = dataloader_hparams.initialize_object(batch_size, sampler, dataloader_spec) if dist.is_available(): dataloader = DDPDataLoader(dataloader) From a913fa96692f84e49eb302f08d113676be818b72 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 30 Nov 2021 09:33:15 -0800 Subject: [PATCH 28/38] Increased timeout for run directory uploader --- tests/callbacks/test_run_directory_uploader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/callbacks/test_run_directory_uploader.py b/tests/callbacks/test_run_directory_uploader.py index 271de33ba9..faaca2f5f4 100644 --- a/tests/callbacks/test_run_directory_uploader.py +++ b/tests/callbacks/test_run_directory_uploader.py @@ -13,6 +13,9 @@ @pytest.mark.parametrize("use_procs", [True, False]) +# TODO(ravi) -- remove the pytest.in #110. The TRAINING_END event is likely slow as it has to copy many +# files created by the ddp test. #110 grately reduces the number of files from the DDP test. 
+@pytest.mark.timeout(15)
 def test_run_directory_uploader(tmpdir: pathlib.Path, use_procs: bool, dummy_state: State, dummy_logger: Logger):
     dummy_state.epoch = 0
     dummy_state.step = 0

From 00fcc33462e239a591e7cc3a7aea14871119ab17 Mon Sep 17 00:00:00 2001
From: Ravi Rahman
Date: Tue, 30 Nov 2021 10:13:54 -0800
Subject: [PATCH 29/38] Switched callback format for run directory uploader

---
 composer/callbacks/run_directory_uploader.py | 41 ++++++++++++--------
 1 file changed, 24 insertions(+), 17 deletions(-)

diff --git a/composer/callbacks/run_directory_uploader.py b/composer/callbacks/run_directory_uploader.py
index fca918db74..5c86f26969 100644
--- a/composer/callbacks/run_directory_uploader.py
+++ b/composer/callbacks/run_directory_uploader.py
@@ -16,7 +16,6 @@
 from typing import Any, Callable, Dict, Optional, Type, Union

 from composer.core.callback import RankZeroCallback
-from composer.core.event import Event
 from composer.core.logging import Logger
 from composer.core.logging.logger import LogLevel
 from composer.core.state import State
@@ -150,7 +149,11 @@ def __init__(
         self._finished: Union[None, multiprocessing._EventType, threading.Event] = None
         self._workers = []

-    def _init(self) -> None:
+    def init(self, state: State, logger: Logger) -> None:
+        if get_run_directory() is None:
+            return
+        del state, logger  # unused
+        atexit.register(self._close)
         self._finished = self._finished_cls()
         self._last_upload_timestamp = 0.0
         self._workers = [
@@ -168,23 +171,27 @@
         for worker in self._workers:
             worker.start()

-    def _run_event(self, event: Event, state: State, logger: Logger) -> None:
+    def batch_end(self, state: State, logger: Logger) -> None:
+        if get_run_directory() is None:
+            return
+        if (state.batch_idx + 1) % self._upload_every_n_batches == 0:
+            self._trigger_upload(logger, LogLevel.BATCH)
+
+    def epoch_end(self, state: State, logger: Logger) -> None:
+        del state  # unused
+        if get_run_directory() is None:
+            return
+        self._trigger_upload(logger, LogLevel.EPOCH)
+
+    def training_end(self, state: State, logger: Logger) -> None:
+        del state  # unused
         if get_run_directory() is None:
             return
-        if event == Event.INIT:
-            self._init()
-            atexit.register(self._close)
-        if event == Event.BATCH_END:
-            if (state.batch_idx + 1) % self._upload_every_n_batches == 0:
-                self._trigger_upload(logger, LogLevel.BATCH)
-        if event == Event.EPOCH_END:
-            self._trigger_upload(logger, LogLevel.EPOCH)
-        if event == Event.TRAINING_END:
-            self._trigger_upload(logger, LogLevel.FIT)
+        self._trigger_upload(logger, LogLevel.FIT)
+        # TODO -- we are missing logfiles from other callbacks / loggers that write on training end but after
+        # the run directory uploader is invoked. This callback either needs to fire last,
+        # or we need another event such as cleanup
+        self._close()

     def _close(self):
         if self._finished is not None:

From 5066bdcf85440ef4d4c421e3387c9011d5ef74d1 Mon Sep 17 00:00:00 2001
From: Ravi Rahman
Date: Tue, 30 Nov 2021 11:05:21 -0800
Subject: [PATCH 30/38] Replaced `atexit` with cleanup methods

When running the trainer multiple times, such as in interactive
environments, `atexit` does not fire. Instead, replaced it with
`.close()` and `.post_close()` hooks on callbacks. `.close()` can be used to write and flush files.
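For illustration only -- the `ExampleFileCallback` below is a hypothetical sketch, not code in this PR -- a file-writing callback can open its file on the INIT event and release it in `close()` instead of registering an `atexit` handler:

    from typing import Optional, TextIO

    from composer.core.callback import Callback
    from composer.core.logging import Logger
    from composer.core.state import State

    class ExampleFileCallback(Callback):
        # Hypothetical callback that appends one line per batch to a text file.

        def __init__(self, filename: str) -> None:
            super().__init__()
            self.filename = filename
            self.file: Optional[TextIO] = None

        def init(self, state: State, logger: Logger) -> None:
            del state, logger  # unused
            # open on INIT rather than in __init__, mirroring the loggers in this PR
            self.file = open(self.filename, "w")

        def batch_end(self, state: State, logger: Logger) -> None:
            del logger  # unused
            assert self.file is not None
            self.file.write(f"finished step {state.step}\n")

        def close(self) -> None:
            # invoked even if training raised an exception, so flush and release the handle here
            if self.file is not None:
                self.file.flush()
                self.file.close()
                self.file = None
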
`.post_close()` can be used to backup the run directory and capture any changes that may have been made on `.close()` --- composer/callbacks/torch_profiler.py | 10 ++++------ composer/core/callback.py | 19 ++++++++++++++++++ composer/core/engine.py | 29 ++++++++++++++++++++++++++++ composer/loggers/file_logger.py | 17 +++++++++------- composer/loggers/wandb_logger.py | 5 ++--- composer/trainer/trainer.py | 5 ++++- tests/callbacks/test_callbacks.py | 2 +- tests/conftest.py | 16 --------------- 8 files changed, 69 insertions(+), 34 deletions(-) diff --git a/composer/callbacks/torch_profiler.py b/composer/callbacks/torch_profiler.py index 0e3724a04f..e73e1e4321 100644 --- a/composer/callbacks/torch_profiler.py +++ b/composer/callbacks/torch_profiler.py @@ -2,7 +2,6 @@ from __future__ import annotations -import atexit import warnings from dataclasses import asdict, dataclass from typing import TYPE_CHECKING, Optional @@ -132,7 +131,7 @@ def scheduler_fn(self, profiler_step: int) -> ProfilerAction: torch_scheduler_action = ProfilerAction.RECORD_AND_SAVE return torch_scheduler_action - def training_start(self, state: State, logger: Logger) -> None: + def init(self, state: State, logger: Logger) -> None: del state, logger # unused assert self.profiler is None, _PROFILE_MISSING_ERROR self.profiler = torch.profiler.profile( @@ -149,7 +148,6 @@ def training_start(self, state: State, logger: Logger) -> None: with_flops=self.hparams.with_flops, ) self.profiler.__enter__() - atexit.register(self._close_profiler) def batch_end(self, state: State, logger: Logger) -> None: del state, logger # unused @@ -165,6 +163,6 @@ def batch_start(self, state: State, logger: Logger) -> None: assert self.profiler is not None, _PROFILE_MISSING_ERROR logger.metric_batch({"profiler/state": self.profiler.current_action.name}) - def _close_profiler(self) -> None: - assert self.profiler is not None - self.profiler.__exit__(None, None, None) + def close(self) -> None: + if self.profiler is not None: + self.profiler.__exit__(None, None, None) diff --git a/composer/core/callback.py b/composer/core/callback.py index d8e357e9bc..7082921b59 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -302,6 +302,25 @@ def eval_end(self, state: State, logger: Logger) -> None: del state, logger # unused pass + def close(self) -> None: + """Called whenever the trainer finishes training. + Unlike the :attr:`~Event.TRAINING_END` event, :meth:`close` is + invoked even when there was an exception. + + It should be used for flushing and closing any files, etc... + that may have been opened during the :attr:`~Event.INIT` event. + """ + pass + + def post_close(self) -> None: + """This hook is called after :meth:`close` has been invoked for each callback. + Very few callbacks should need to implement :meth:`post_close`. + + This callback can be used to back up any data that may have been written by other + callbacks during :meth:`close`. + """ + pass + class RankZeroCallback(Callback, abc.ABC): """Base class for callbacks that only run on the local rank zero process. diff --git a/composer/core/engine.py b/composer/core/engine.py index dfdfb7983d..38b1ffcc6c 100644 --- a/composer/core/engine.py +++ b/composer/core/engine.py @@ -180,3 +180,32 @@ def _run_callbacks( for cb in self.callbacks: cb.run_event(event, self.state, self.logger) + + def close(self) -> None: + """Invoke :meth:`~Callback.close` and :meth:`~Callback.post_close` for each callback. + + :meth:`~Callback.close` is invoked for each callback. 
+ For all callbacks where :meth:`~Callback.close` did not raise an exception, then + :meth:`~Callback.post_close` is invoked. + + Does not re-raise any exceptions from :meth:`~Callback.close` and :meth:`~Callback.post_close`. + Instead, these exceptions are logged. + """ + callback_to_has_exception: Dict[Callback, bool] = {} + for callback in self.callbacks: + try: + callback.close() + except Exception as e: + log.error( + f"Error running {callback.__class__.__name__}.close(). Skipping {callback.__class__.__name__}.post_close().", + exc_info=e, + stack_info=True) + callback_to_has_exception[callback] = True + else: + callback_to_has_exception[callback] = False + for callback in self.callbacks: + if callback_to_has_exception[callback] is False: + try: + callback.post_close() + except Exception as e: + log.error(f"Error running {callback.__class__.__name__}.post_close().", exc_info=e, stack_info=True) diff --git a/composer/loggers/file_logger.py b/composer/loggers/file_logger.py index 904a973ce9..37f9f16fa1 100644 --- a/composer/loggers/file_logger.py +++ b/composer/loggers/file_logger.py @@ -2,7 +2,6 @@ from __future__ import annotations -import atexit import os import sys from typing import Any, Dict, Optional, TextIO @@ -77,17 +76,20 @@ def _will_log(self, state: State, log_level: LogLevel) -> bool: def _log_metric(self, epoch: int, step: int, log_level: LogLevel, data: TLogData): data_str = format_log_data_value(data) + if self.file is None: + raise RuntimeError("Attempted to log before self.init() or after self.close()") print(f"[{log_level.name}][step={step}]: {data_str}", file=self.file) def init(self, state: State, logger: Logger) -> None: del state, logger # unused + if self.file is not None: + raise RuntimeError("The file logger is already initialized") if self.hparams.filename == "stdout": self.file = sys.stdout elif self.hparams.filename == "stderr": self.file = sys.stderr else: self.file = open(self.hparams.filename, "x+", buffering=self.hparams.buffer_size) - atexit.register(self._close_file) if self.config is not None: print("Config", file=self.file) print("-" * 30, file=self.file) @@ -114,8 +116,9 @@ def _flush_file(self) -> None: self.file.flush() os.fsync(self.file.fileno()) - def _close_file(self) -> None: - assert self.file is not None - assert self.file not in (sys.stdout, sys.stderr) - self._flush_file() - self.file.close() + def close(self) -> None: + if self.file is not None: + if self.file not in (sys.stdout, sys.stderr): + self._flush_file() + self.file.close() + self.file = None diff --git a/composer/loggers/wandb_logger.py b/composer/loggers/wandb_logger.py index 0ff6e34838..df84b6816b 100644 --- a/composer/loggers/wandb_logger.py +++ b/composer/loggers/wandb_logger.py @@ -2,7 +2,6 @@ from __future__ import annotations -import atexit import os import sys from typing import Any, Dict, Optional @@ -49,7 +48,6 @@ def state_dict(self) -> StateDict: def init(self, state: State, logger: Logger) -> None: del state, logger # unused wandb.init(**self._init_params) - atexit.register(self._close_wandb) def batch_end(self, state: State, logger: Logger) -> None: del logger # unused @@ -85,7 +83,8 @@ def _upload_artifacts(self): artifact.add_file(full_path) wandb.log_artifact(artifact) - def _close_wandb(self) -> None: + def post_close(self) -> None: + # Cleaning up on post_close so all artifacts are uploaded if self._log_artifacts: self._upload_artifacts() diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 0fefadf81f..807b304e17 100755 --- 
a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -369,7 +369,10 @@ def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer: def fit(self): """Train and evaluate the model on the provided data.""" - self._train_loop() + try: + self._train_loop() + finally: + self.engine.close() def _create_dataloaders(self) -> None: """Create the dataloaders. diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 0f6d1a4abe..08f0283ed5 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -13,7 +13,7 @@ def test_callbacks_map_to_events(): # callback methods must be 1:1 mapping with events # exception for private methods cb = Callback() - excluded_methods = ["state_dict", "load_state_dict", "run_event"] + excluded_methods = ["state_dict", "load_state_dict", "run_event", "close", "post_close"] methods = set(m for m in dir(cb) if (m not in excluded_methods and not m.startswith("_"))) event_names = set(e.value for e in Event) assert methods == event_names diff --git a/tests/conftest.py b/tests/conftest.py index f61f2591f2..dd295bb4ff 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ # Copyright 2021 MosaicML. All Rights Reserved. -import atexit import datetime import logging import os @@ -14,7 +13,6 @@ import _pytest.mark import pytest import torch.distributed -from _pytest.monkeypatch import MonkeyPatch import composer from composer.utils.run_directory import get_relative_to_run_directory @@ -108,20 +106,6 @@ def pytest_collection_modifyitems(session: pytest.Session, config: _pytest.confi _filter_items_for_timeout(config, items) -@pytest.fixture(autouse=True) -def atexit_at_test_end(monkeypatch: MonkeyPatch): - # monkeypatch atexit so it is called when a test exits, not when the python process exits - atexit_callbacks = [] - - def register(func, *args, **kwargs): - atexit_callbacks.append((func, args, kwargs)) - - monkeypatch.setattr(atexit, "register", register) - yield - for func, args, kwargs in atexit_callbacks: - func(*args, **kwargs) - - @pytest.fixture(autouse=True) def set_loglevels(): logging.basicConfig() From 5171468943c959b52098e5b8a24a953708b43d82 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 30 Nov 2021 11:22:19 -0800 Subject: [PATCH 31/38] Uncommented code --- composer/callbacks/run_directory_uploader.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/composer/callbacks/run_directory_uploader.py b/composer/callbacks/run_directory_uploader.py index 627b2f6789..92487e9121 100644 --- a/composer/callbacks/run_directory_uploader.py +++ b/composer/callbacks/run_directory_uploader.py @@ -210,8 +210,8 @@ def _trigger_upload(self, logger: Optional[Logger], log_level: Optional[LogLevel # check if any upload threads have crashed. 
if so, then shutdown the training process
         for worker in self._workers:
             if not worker.is_alive():
-                # assert self._finished is not None, "invariant error"
-                # self._finished.set()
+                assert self._finished is not None, "invariant error"
+                self._finished.set()
                 raise RuntimeError("Upload worker crashed unexpectedly")
         for root, dirs, files in os.walk(run_directory):
             del dirs  # unused
@@ -224,7 +224,6 @@
                 files_to_be_uploaded.append(relpath)
                 copied_path_dirname = os.path.dirname(copied_path)
                 os.makedirs(copied_path_dirname, exist_ok=True)
-                # shutil.copyfile(filepath, copied_path)
                 shutil.copy2(filepath, copied_path)
                 self._file_upload_queue.put_nowait(copied_path)
         self._last_upload_timestamp = new_last_uploaded_timestamp

From 97326bdb3528fa21e709ce46c25d53de5e3a3718 Mon Sep 17 00:00:00 2001
From: Ravi Rahman
Date: Tue, 30 Nov 2021 11:28:01 -0800
Subject: [PATCH 32/38] Running callbacks before algorithms for the INIT event
 in the engine

* For the INIT event, run the callbacks first to initialize the loggers.
* For other events, run the algorithms first, so the callbacks have the state
  after algorithms modify it.
---
 composer/core/engine.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/composer/core/engine.py b/composer/core/engine.py
index dfdfb7983d..63cb27f88c 100644
--- a/composer/core/engine.py
+++ b/composer/core/engine.py
@@ -96,8 +96,15 @@ def run_event(
         Returns:
             Dict[str, Trace]: dictionary of trace for each algorithm.
         """
-        traces = self._run_algorithms(event)
-        self._run_callbacks(event)
+        if event == Event.INIT:
+            # For the INIT event, run the callbacks first to initialize the loggers
+            # For other events, run the algorithms first, so the callbacks have the state
+            # after algorithms modify it
+            self._run_callbacks(event)
+            traces = self._run_algorithms(event)
+        else:
+            traces = self._run_algorithms(event)
+            self._run_callbacks(event)
         return traces

From 20dc89616baf681445e7ae91546c14a15268c977 Mon Sep 17 00:00:00 2001
From: Ravi Rahman
Date: Tue, 30 Nov 2021 19:19:31 -0800
Subject: [PATCH 33/38] Fixed tests

---
 tests/callbacks/test_run_directory_uploader.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/callbacks/test_run_directory_uploader.py b/tests/callbacks/test_run_directory_uploader.py
index faaca2f5f4..32ccdfa575 100644
--- a/tests/callbacks/test_run_directory_uploader.py
+++ b/tests/callbacks/test_run_directory_uploader.py
@@ -12,7 +12,7 @@
 from composer.utils.run_directory import get_run_directory


-@pytest.mark.parametrize("use_procs", [True, False])
+@pytest.mark.parametrize("use_procs", [False, True])
 # TODO(ravi) -- remove the pytest.in #110. The TRAINING_END event is likely slow as it has to copy many
 # files created by the ddp test. #110 greatly reduces the number of files from the DDP test.
@pytest.mark.timeout(15) @@ -38,6 +38,8 @@ def test_run_directory_uploader(tmpdir: pathlib.Path, use_procs: bool, dummy_sta f.write("Hello, world!") uploader.run_event(Event.BATCH_END, dummy_state, dummy_logger) uploader.run_event(Event.TRAINING_END, dummy_state, dummy_logger) + uploader.close() + uploader.post_close() # now assert that we have a dummy file in the run directory copy folder with open(os.path.join(remote_dir, run_directory, "dummy_file"), "r") as f: From 42f9ab3e3b2fc46a62e7bd9b7d8a4285de3f5374 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Wed, 1 Dec 2021 14:24:10 -0800 Subject: [PATCH 34/38] Addressed PR feedback --- composer/callbacks/run_directory_uploader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/composer/callbacks/run_directory_uploader.py b/composer/callbacks/run_directory_uploader.py index 92487e9121..33f09d0654 100644 --- a/composer/callbacks/run_directory_uploader.py +++ b/composer/callbacks/run_directory_uploader.py @@ -202,7 +202,6 @@ def _trigger_upload(self, logger: Optional[Logger], log_level: Optional[LogLevel ddp.barrier() new_last_uploaded_timestamp = time.time() # Now, for each file that was modified since self._last_upload_timestamp, copy it to the temporary directory - # IMPROTANT: From now, until self._last_upload_timestamp is updated, no files should be written to the run directory run_directory = get_run_directory() assert run_directory is not None, "invariant error" files_to_be_uploaded = [] @@ -311,7 +310,12 @@ def _upload_worker( if retry_counter < 3: retry_counter += 1 # exponential backoff - time.sleep(2**(retry_counter - 1)) + sleep_time = 2**(retry_counter - 1) + log.warn("Request failed with a transient error code. Sleeping %s seconds and retrying", + sleep_time, + exc_info=e, + stack_info=True) + time.sleep() continue raise e os.remove(file_path_to_upload) From 481ab37af7af374d8cf2bb5de400d0692a161dbb Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 2 Dec 2021 10:53:35 -0800 Subject: [PATCH 35/38] Fixed bug --- composer/callbacks/run_directory_uploader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/callbacks/run_directory_uploader.py b/composer/callbacks/run_directory_uploader.py index 33f09d0654..73efda6d01 100644 --- a/composer/callbacks/run_directory_uploader.py +++ b/composer/callbacks/run_directory_uploader.py @@ -315,7 +315,7 @@ def _upload_worker( sleep_time, exc_info=e, stack_info=True) - time.sleep() + time.sleep(sleep_time) continue raise e os.remove(file_path_to_upload) From 6fc5555a75915ac47302a106fa666238168d4744 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 2 Dec 2021 22:49:58 +0000 Subject: [PATCH 36/38] Fixed bugs --- composer/callbacks/run_directory_uploader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/composer/callbacks/run_directory_uploader.py b/composer/callbacks/run_directory_uploader.py index 73efda6d01..b483bb4962 100644 --- a/composer/callbacks/run_directory_uploader.py +++ b/composer/callbacks/run_directory_uploader.py @@ -14,7 +14,7 @@ import warnings from typing import Any, Callable, Dict, Optional, Type, Union -from composer.core.callback import RankZeroCallback +from composer.core.callback import Callback from composer.core.logging import Logger from composer.core.logging.logger import LogLevel from composer.core.state import State @@ -24,7 +24,7 @@ log = logging.getLogger(__name__) -class RunDirectoryUploader(RankZeroCallback): +class RunDirectoryUploader(Callback): """Callback to upload the run 
directory to a blob store. This callback checks the run directory for new or modified files @@ -291,7 +291,7 @@ def _upload_worker( break else: continue - obj_name = ",".join(os.path.relpath(file_path_to_upload, upload_staging_dir).split( + obj_name = os.path.sep.join(os.path.relpath(file_path_to_upload, upload_staging_dir).split( os.path.sep)[1:]) # the first folder is the upload timestamp. Chop that off. log.info("Uploading file %s to %s://%s/%s%s", file_path_to_upload, provider_name, container_name, object_name_prefix, obj_name) From ec7011e3e39c23bbe60a5495f38318702f942308 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 2 Dec 2021 23:01:51 +0000 Subject: [PATCH 37/38] Fixed rank 0 only uploads --- composer/callbacks/run_directory_uploader.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/composer/callbacks/run_directory_uploader.py b/composer/callbacks/run_directory_uploader.py index b483bb4962..b5cee5b8e7 100644 --- a/composer/callbacks/run_directory_uploader.py +++ b/composer/callbacks/run_directory_uploader.py @@ -133,8 +133,6 @@ def __init__( self._provider = provider self._container = container - _validate_credentials(provider, container, self._object_name_prefix, provider_init_kwargs) - if use_procs: mp_ctx = multiprocessing.get_context('spawn') self._file_upload_queue: Union[queue.Queue[str], @@ -148,9 +146,14 @@ def __init__( self._finished: Union[None, multiprocessing._EventType, threading.Event] = None self._workers = [] + if ddp.get_local_rank() == 0: + _validate_credentials(provider, container, self._object_name_prefix, provider_init_kwargs) + def init(self, state: State, logger: Logger) -> None: if get_run_directory() is None: return + if not ddp.get_local_rank() == 0: + return del state, logger # unused self._finished = self._finished_cls() self._last_upload_timestamp = 0.0 @@ -190,6 +193,8 @@ def training_end(self, state: State, logger: Logger) -> None: def post_close(self): # Cleaning up on post_close to ensure that all artifacts are uploaded self._trigger_upload(logger=None, log_level=None) + if not ddp.get_local_rank() == 0: + return if self._finished is not None: self._finished.set() for worker in self._workers: @@ -200,6 +205,8 @@ def _trigger_upload(self, logger: Optional[Logger], log_level: Optional[LogLevel # Assuming only the main thread on each rank writes to the run directory, then the barrier here will ensure # that the run directory is not being modified after we pass this barrier ddp.barrier() + if not ddp.get_local_rank() == 0: + return new_last_uploaded_timestamp = time.time() # Now, for each file that was modified since self._last_upload_timestamp, copy it to the temporary directory run_directory = get_run_directory() From 2d0b058f11a18a07e3f3102e7b84400b5d9bdaaf Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 2 Dec 2021 23:48:21 +0000 Subject: [PATCH 38/38] Using filesystem timestamps instead of python process timestamps to determine changed files --- composer/callbacks/run_directory_uploader.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/composer/callbacks/run_directory_uploader.py b/composer/callbacks/run_directory_uploader.py index b5cee5b8e7..95bad960ae 100644 --- a/composer/callbacks/run_directory_uploader.py +++ b/composer/callbacks/run_directory_uploader.py @@ -5,6 +5,7 @@ import logging import multiprocessing import os +import pathlib import queue import shutil import sys @@ -207,10 +208,15 @@ def _trigger_upload(self, logger: Optional[Logger], log_level: 
Optional[LogLevel ddp.barrier() if not ddp.get_local_rank() == 0: return - new_last_uploaded_timestamp = time.time() - # Now, for each file that was modified since self._last_upload_timestamp, copy it to the temporary directory run_directory = get_run_directory() assert run_directory is not None, "invariant error" + # the disk time can differ from system time, so going to touch a file and then read the timestamp from it to get the real time + python_time = time.time() + touch_file = (pathlib.Path(run_directory) / f".{python_time}") + touch_file.touch() + new_last_uploaded_timestamp = os.path.getmtime(str(touch_file)) + + # Now, for each file that was modified since self._last_upload_timestamp, copy it to the temporary directory files_to_be_uploaded = [] # check if any upload threads have crashed. if so, then shutdown the training process @@ -222,6 +228,9 @@ def _trigger_upload(self, logger: Optional[Logger], log_level: Optional[LogLevel for root, dirs, files in os.walk(run_directory): del dirs # unused for file in files: + if any(x.startswith(".") for x in file.split(os.path.sep)): + # skip hidden files and folders + continue filepath = os.path.join(root, file) relpath = os.path.relpath(filepath, run_directory) # chop off the run directory modified_time = os.path.getmtime(filepath)