diff --git a/composer/algorithms/augmix/augmix.py b/composer/algorithms/augmix/augmix.py
index 55383cd1ba..1a06463cb5 100644
--- a/composer/algorithms/augmix/augmix.py
+++ b/composer/algorithms/augmix/augmix.py
@@ -251,7 +251,10 @@ def __init__(self,
         self._transformed_datasets = weakref.WeakSet()

     def match(self, event: Event, state: State) -> bool:
-        return event == Event.FIT_START and state.train_dataloader.dataset not in self._transformed_datasets
+        if event != Event.FIT_START:
+            return False
+        assert state.dataloader is not None, "dataloader should be defined on fit start"
+        return state.dataloader.dataset not in self._transformed_datasets

     def apply(self, event: Event, state: State, logger: Logger) -> None:
         am = AugmentAndMixTransform(severity=self.severity,
@@ -259,7 +262,8 @@ def apply(self, event: Event, state: State, logger: Logger) -> None:
                                     width=self.width,
                                     alpha=self.alpha,
                                     augmentation_set=self.augmentation_set)
-        dataset = state.train_dataloader.dataset
+        assert state.dataloader is not None, "dataloader should be defined on fit start"
+        dataset = state.dataloader.dataset
         if not isinstance(dataset, VisionDataset):
             raise TypeError(
                 textwrap.dedent(f"""\
diff --git a/composer/algorithms/colout/colout.py b/composer/algorithms/colout/colout.py
index d81ab6d0aa..f5d2704feb 100644
--- a/composer/algorithms/colout/colout.py
+++ b/composer/algorithms/colout/colout.py
@@ -214,11 +214,15 @@ def match(self, event: Event, state: State) -> bool:
         if self.batch:
             return event == Event.AFTER_DATALOADER
         else:
-            return event == Event.FIT_START and state.train_dataloader.dataset not in self._transformed_datasets
+            if event != Event.FIT_START:
+                return False
+            assert state.dataloader is not None, "dataloader should be defined on fit start"
+            return state.dataloader.dataset not in self._transformed_datasets

     def _apply_sample(self, state: State) -> None:
         """Add the ColOut dataset transform to the dataloader."""
-        dataset = state.train_dataloader.dataset
+        assert state.dataloader is not None, "dataloader should be defined on fit start"
+        dataset = state.dataloader.dataset

         transform = ColOutTransform(p_row=self.p_row, p_col=self.p_col, resize_target=self.resize_target)

diff --git a/composer/algorithms/layer_freezing/layer_freezing.py b/composer/algorithms/layer_freezing/layer_freezing.py
index 4f278e9d58..d56544d726 100644
--- a/composer/algorithms/layer_freezing/layer_freezing.py
+++ b/composer/algorithms/layer_freezing/layer_freezing.py
@@ -133,10 +133,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
         del event  # unused
         optimizers = state.optimizers
         assert optimizers is not None
+        elapsed_duration = state.get_elapsed_duration()
+        assert elapsed_duration is not None, "elapsed duration should be set on Event.EPOCH_END"
         freeze_depth, freeze_percentage = freeze_layers(
             model=state.model,
             optimizers=optimizers,
-            current_duration=float(state.get_elapsed_duration()),
+            current_duration=float(elapsed_duration),
             freeze_start=self.freeze_start,
             freeze_level=self.freeze_level,
         )
diff --git a/composer/algorithms/progressive_resizing/progressive_resizing.py b/composer/algorithms/progressive_resizing/progressive_resizing.py
index 737ce82469..71a68d3d49 100644
--- a/composer/algorithms/progressive_resizing/progressive_resizing.py
+++ b/composer/algorithms/progressive_resizing/progressive_resizing.py
@@ -197,7 +197,9 @@ def apply(self, event: Event, state: State, logger: Optional[Logger] = None) ->
         # Calculate the current size of the inputs to use
         initial_size = self.initial_scale
         finetune_fraction = self.finetune_fraction
-        scale_frac_elapsed = min([state.get_elapsed_duration().value / (1 - finetune_fraction), 1])
+        elapsed_duration = state.get_elapsed_duration()
+        assert elapsed_duration is not None, "elapsed duration should be set on Event.AFTER_DATALOADER"
+        scale_frac_elapsed = min([elapsed_duration.value / (1 - finetune_fraction), 1])

         # Linearly increase to full size at the start of the fine tuning period
         scale_factor = initial_size + (1 - initial_size) * scale_frac_elapsed
diff --git a/composer/algorithms/randaugment/randaugment.py b/composer/algorithms/randaugment/randaugment.py
index 7501bc372d..db3c2b0acf 100644
--- a/composer/algorithms/randaugment/randaugment.py
+++ b/composer/algorithms/randaugment/randaugment.py
@@ -188,12 +188,15 @@ def __init__(self, severity: int = 9, depth: int = 2, augmentation_set: str = "a
         self._transformed_datasets = weakref.WeakSet()

     def match(self, event: Event, state: State) -> bool:
-        return event == Event.FIT_START and state.train_dataloader.dataset not in self._transformed_datasets
+        if event != Event.FIT_START:
+            return False
+        assert state.dataloader is not None, "dataloader should be defined on fit start"
+        return state.dataloader.dataset not in self._transformed_datasets

     def apply(self, event: Event, state: State, logger: Logger) -> None:
         ra = RandAugmentTransform(severity=self.severity, depth=self.depth, augmentation_set=self.augmentation_set)
-        assert state.train_dataloader is not None
-        dataset = state.train_dataloader.dataset
+        assert state.dataloader is not None, "dataloader should be defined on fit start"
+        dataset = state.dataloader.dataset
         if not isinstance(dataset, VisionDataset):
             raise TypeError(
                 textwrap.dedent(f"""\
diff --git a/composer/algorithms/selective_backprop/selective_backprop.py b/composer/algorithms/selective_backprop/selective_backprop.py
index 40a6799ffa..fce013f74d 100644
--- a/composer/algorithms/selective_backprop/selective_backprop.py
+++ b/composer/algorithms/selective_backprop/selective_backprop.py
@@ -12,6 +12,7 @@
 from torch.nn import functional as F

 from composer.core import Algorithm, Event, State
+from composer.core.precision import get_precision_context
 from composer.loggers import Logger
 from composer.models import ComposerModel

@@ -93,7 +94,7 @@ def select_using_loss(input: torch.Tensor,
     This function runs an extra forward pass through the model on the batch of data.
     If you are using a non-default precision, ensure that this forward pass
     runs in your desired precision. For example:
-
+
     .. testsetup::

         N_sb, D_sb = 16, 8
@@ -223,8 +224,11 @@ def match(self, event: Event, state: State) -> bool:
         if not is_keep:
             return False

+        elapsed_duration = state.get_elapsed_duration()
+        assert elapsed_duration is not None, "elapsed duration should be set on Event.AFTER_DATALOADER"
+
         is_chosen = should_selective_backprop(
-            current_duration=float(state.get_elapsed_duration()),
+            current_duration=float(elapsed_duration),
             batch_idx=state.timer.batch_in_epoch.value,
             start=self.start,
             end=self.end,
@@ -250,6 +254,6 @@ def loss(p, y, reduction="none"):
             assert self._loss_fn is not None, "loss_fn should be set on Event.INIT"
             return self._loss_fn(p, (torch.Tensor(), y), reduction=reduction)

-        with state.precision_context:
+        with get_precision_context(state.precision):
             new_input, new_target = select_using_loss(input, target, model, loss, self.keep, self.scale_factor)
             state.batch = (new_input, new_target)
diff --git a/composer/algorithms/seq_length_warmup/seq_length_warmup.py b/composer/algorithms/seq_length_warmup/seq_length_warmup.py
index 63cfd9a182..6232feadd7 100644
--- a/composer/algorithms/seq_length_warmup/seq_length_warmup.py
+++ b/composer/algorithms/seq_length_warmup/seq_length_warmup.py
@@ -9,6 +9,7 @@
 import torch

 from composer.core import Algorithm, Event, State
+from composer.core.precision import get_precision_context
 from composer.core.time import TimeUnit
 from composer.core.types import Batch
 from composer.loggers import Logger
@@ -179,12 +180,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
                         {type(self).__name__} requires state.model to be of type {ComposerTransformer.__name__}, not of type {type(state.model)}"""
                     ))

-            if state.train_dataloader.batch_size is None:
-                raise RuntimeError("Sequence Length Warmup algorithm requires constant batch size.")
-
             self._original_model = state.model
             return

+        assert state.dataloader is not None, "dataloader should be set on AFTER_DATALOADER"
+        assert state.max_duration is not None, "max_duration should be set on AFTER_DATALOADER"
+
         # in order to avoid OOMs, we do a forward and a backward pass on a dummy input.
         if not self._activated:
             # ensure that input_ids is a valid model input. since we don't need the
@@ -204,7 +205,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
             # all of the parameters
             device = next(state.model.parameters()).device

-            per_gpu_macrobatch = state.train_dataloader.batch_size
+            per_gpu_macrobatch = state.dataloader.batch_size
             if per_gpu_macrobatch is None:
                 raise RuntimeError("Sequence Length Warmup algorithm requires constant batch size.")
             per_gpu_batch = ceil(per_gpu_macrobatch / state.grad_accum)
@@ -223,7 +224,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:

             # start by running a forward and backward pass
             # of the maximum sequence length to allocate cache.
-            with state.precision_context:
+            with get_precision_context(state.precision):
                 outputs = state.model.forward(model_inputs)
                 loss = self._original_model.loss(outputs, model_inputs)

@@ -238,7 +239,9 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
         self._activated = True

         if state.max_duration.unit == TimeUnit.EPOCH:
-            num_optimization_steps = state.steps_per_epoch * state.max_duration.value
+            if state.dataloader_len is None:
+                raise RuntimeError("Sequence Length Warmup requires the dataloader to be sized.")
+            num_optimization_steps = int(state.dataloader_len) * state.max_duration.value
         elif state.max_duration.unit == TimeUnit.BATCH:
             num_optimization_steps = state.max_duration.value
         else:
diff --git a/composer/algorithms/stochastic_depth/stochastic_depth.py b/composer/algorithms/stochastic_depth/stochastic_depth.py
index ab70b88a19..3740ede5f8 100644
--- a/composer/algorithms/stochastic_depth/stochastic_depth.py
+++ b/composer/algorithms/stochastic_depth/stochastic_depth.py
@@ -232,8 +232,10 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
             logger.data_epoch({'stochastic_depth/num_stochastic_layers': num_stochastic_layers})

         elif event == Event.BATCH_START:
-            if state.get_elapsed_duration() < self.drop_warmup:
-                current_drop_rate = float(state.get_elapsed_duration() / self.drop_warmup) * self.drop_rate
+            elapsed_duration = state.get_elapsed_duration()
+            assert elapsed_duration is not None, "elapsed duration is set on BATCH_START"
+            if elapsed_duration < self.drop_warmup:
+                current_drop_rate = float(elapsed_duration / self.drop_warmup) * self.drop_rate
                 _update_drop_rate(state.model, stochastic_layer, current_drop_rate, self.drop_distribution)
             else:
                 current_drop_rate = self.drop_rate
diff --git a/composer/algorithms/swa/swa.py b/composer/algorithms/swa/swa.py
index a2e1e9782e..761c4594e4 100644
--- a/composer/algorithms/swa/swa.py
+++ b/composer/algorithms/swa/swa.py
@@ -172,8 +172,12 @@ def __init__(self,
         self.match_event = Event.EPOCH_END

     def match(self, event: Event, state: State) -> bool:
+        if event != self.match_event:
+            return False
         if self.swa_start.unit == TimeUnit.DURATION:
-            should_start_swa = state.get_elapsed_duration() >= self.swa_start and not self.swa_completed
+            elapsed_duration = state.get_elapsed_duration()
+            assert elapsed_duration is not None, "elapsed duration should be set on Event.BATCH_END or Event.EPOCH_END"
+            should_start_swa = elapsed_duration >= self.swa_start and not self.swa_completed
         elif self.swa_start.unit == TimeUnit.EPOCH:
             should_start_swa = state.timer.get("ep") >= self.swa_start and not self.swa_completed
         else:
@@ -219,16 +223,19 @@ def apply(self, event: Event, state: State, logger: Logger) -> None:
         self.step_counter += 1

+        elapsed_duration = state.get_elapsed_duration()
+        assert elapsed_duration is not None, "elapsed duration should be set on Event.BATCH_END or Event.EPOCH_END"
+
         # Determine whether it's time to end SWA
-        if self.swa_end.unit == TimeUnit.DURATION and (state.get_elapsed_duration() >= self.swa_end):
+        if self.swa_end.unit == TimeUnit.DURATION and (elapsed_duration >= self.swa_end):
             self.swa_completed = True

         if self.swa_end.unit == TimeUnit.EPOCH and (state.timer.get("ep") >= self.swa_end):
             self.swa_completed = True

         if self.swa_completed:
             if state.get_elapsed_duration() == 1:
-                log.warning("The baseline model was replaced with the SWA model after the end of "
-                            "training. This means that SWA model will not have its batch norm "
-                            "statistics updated. This will negatively impact accuracy. See the "
-                            "documentation for the `swa_end` parameter for details.")
+                log.warning(("The baseline model was replaced with the SWA model after the end of "
+                             "training. This means that SWA model will not have its batch norm "
+                             "statistics updated. This will negatively impact accuracy. See the "
+                             "documentation for the `swa_end` parameter for details."))
             state.model.load_state_dict(self.swa_model.module.state_dict())  # type: ignore
             log.info('Set model to the averaged model')
diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py
index 8df7211029..ed8d1bebf1 100644
--- a/composer/callbacks/checkpoint_saver.py
+++ b/composer/callbacks/checkpoint_saver.py
@@ -57,7 +57,10 @@ def checkpoint_periodically(interval: Union[str, int, Time]) -> Callable[[State,

     def save_interval(state: State, event: Event):
         nonlocal last_checkpoint_batch
-        if state.get_elapsed_duration() >= 1.0:
+        elapsed_duration = state.get_elapsed_duration()
+        assert elapsed_duration is not None, "elapsed_duration is set on the BATCH_CHECKPOINT and EPOCH_CHECKPOINT"
+
+        if elapsed_duration >= 1.0:
             # if doing batch-wise checkpointing, and we saved a checkpoint at the batch_checkpoint event
             # right before the epoch_checkpoint event, do not save another checkpoint at the epoch_checkpoint
             # event if the batch count didn't increase.
@@ -332,12 +335,16 @@ def fit_start(self, state: State, logger: Logger) -> None:
     def batch_checkpoint(self, state: State, logger: Logger):
         if self.save_interval(state, Event.BATCH_CHECKPOINT):
             # If training is finished, log at the FIT loglevel
-            log_level = LogLevel.BATCH if state.get_elapsed_duration() < 1.0 else LogLevel.FIT
+            elapsed_duration = state.get_elapsed_duration()
+            assert elapsed_duration is not None, "elapsed_duration is set on Event.BATCH_CHECKPOINT"
+            log_level = LogLevel.BATCH if elapsed_duration < 1.0 else LogLevel.FIT
             self._save_checkpoint(state, logger, log_level)

     def epoch_checkpoint(self, state: State, logger: Logger):
         if self.save_interval(state, Event.EPOCH_CHECKPOINT):
-            log_level = LogLevel.EPOCH if state.get_elapsed_duration() < 1.0 else LogLevel.FIT
+            elapsed_duration = state.get_elapsed_duration()
+            assert elapsed_duration is not None, "elapsed_duration is set on Event.EPOCH_CHECKPOINT"
+            log_level = LogLevel.EPOCH if elapsed_duration < 1.0 else LogLevel.FIT
             self._save_checkpoint(state, logger, log_level)

     def _save_checkpoint(self, state: State, logger: Logger, log_level: LogLevel):
diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py
index ff2f480a46..6e29724200 100644
--- a/composer/callbacks/speed_monitor.py
+++ b/composer/callbacks/speed_monitor.py
@@ -126,6 +126,7 @@ def epoch_start(self, state: State, logger: Logger):
     def batch_end(self, state: State, logger: Logger):
         self.batch_end_times.append(time.time())
         new_num_samples = state.timer.sample
+        assert self.batch_start_num_samples is not None, "self.batch_start_num_samples should have been set on Event.BATCH_START"
         batch_num_samples = int(new_num_samples - self.batch_start_num_samples)
         self.batch_num_samples.append(batch_num_samples)
         self.train_examples_per_epoch += batch_num_samples
diff --git a/composer/core/engine.py b/composer/core/engine.py
index 33ec3173a0..35cee9ec72 100644
--- a/composer/core/engine.py
+++ b/composer/core/engine.py
@@ -83,6 +83,25 @@
 Traces = Dict[str, "Trace"]

 _ALWAYS_RECORD_EVENTS = [Event.INIT, Event.FIT_START, Event.EPOCH_START, Event.EPOCH_END]
+_EVENTS_WHERE_DATALOADER_IS_SET = [e for e in Event if e != Event.INIT]
+_EVENTS_WHERE_MAX_DURATION_IS_SET = [
+    Event.FIT_START,
+    Event.EPOCH_START,
+    Event.BATCH_START,
+    Event.AFTER_DATALOADER,
+    Event.BEFORE_TRAIN_BATCH,
+    Event.BEFORE_FORWARD,
+    Event.AFTER_FORWARD,
+    Event.BEFORE_LOSS,
+    Event.AFTER_LOSS,
+    Event.BEFORE_BACKWARD,
+    Event.AFTER_BACKWARD,
+    Event.AFTER_TRAIN_BATCH,
+    Event.BATCH_END,
+    Event.BATCH_CHECKPOINT,
+    Event.EPOCH_END,
+    Event.EPOCH_CHECKPOINT,
+]


 @dataclass
@@ -174,6 +193,12 @@ def run_event(
         if event.is_after_event and duration_marker is not None:
             duration_marker.finish()

+        if event in _EVENTS_WHERE_DATALOADER_IS_SET:
+            assert self.state.dataloader is not None, f"The trainer should have set state.dataloader for event {event}."
+
+        if event in _EVENTS_WHERE_MAX_DURATION_IS_SET:
+            assert self.state.max_duration is not None, f"The trainer should have set state.max_duration for event {event}."
+
         if event == Event.INIT:
             # For the INIT event, run the callbacks first to initialize the loggers
             # For other events, run the algorithms first, so the callbacks have the state
diff --git a/composer/core/precision.py b/composer/core/precision.py
index db11f98fe9..5d124a15ed 100644
--- a/composer/core/precision.py
+++ b/composer/core/precision.py
@@ -2,9 +2,15 @@

 """Enum class for the numerical precision to be used by the model."""

+import contextlib
+from typing import Generator, Union
+
+import torch
+from packaging import version
+
 from composer.utils.string_enum import StringEnum

-__all__ = ["Precision"]
+__all__ = ["Precision", "get_precision_context"]


 class Precision(StringEnum):
@@ -23,3 +29,32 @@ class Precision(StringEnum):
     FP16 = "fp16"
     FP32 = "fp32"
     BF16 = "bf16"
+
+
+@contextlib.contextmanager
+def get_precision_context(precision: Union[str, Precision]) -> Generator[None, None, None]:
+    """Returns a context manager to automatically cast to a specific precision.
+
+    Args:
+        precision (str or Precision): Precision for the context
+    """
+
+    precision = Precision(precision)
+    if precision == Precision.FP32:
+        if torch.cuda.is_available():
+            with torch.cuda.amp.autocast(False):
+                yield
+        else:
+            # Yield here to avoid warnings about cuda not being available
+            yield
+    elif precision == Precision.AMP:
+        # Retain compatibility with PyTorch < 1.10
+        with torch.cuda.amp.autocast(True):
+            yield
+    elif precision == Precision.BF16:
+        if version.parse(torch.__version__) < version.parse("1.10"):
+            raise ValueError(f"BF16 precision requires torch > 1.10, got version {torch.__version__}")
+        with torch.cuda.amp.autocast(True, torch.bfloat16):
+            yield
+    else:
+        raise ValueError(f"Unsupported precision: {precision}")
diff --git a/composer/core/state.py b/composer/core/state.py
index 784df5e929..d8bb8bc0b2 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -3,11 +3,9 @@
 """The state of the trainer."""
 from __future__ import annotations

-import contextlib
 import logging
-import textwrap
 import warnings
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Optional, Sequence, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast

 import torch
 import torch.nn.modules.utils
@@ -25,7 +23,6 @@
     import composer.core.types as types
     from composer.core.algorithm import Algorithm
     from composer.core.callback import Callback
-    from composer.core.evaluator import Evaluator
     from composer.profiler import Profiler

 __all__ = ["State"]
@@ -33,24 +30,6 @@
 logger = logging.getLogger(__name__)


-def _default_precision_factory() -> Callable[[Union[str, Precision]], ContextManager]:
-    """Returns a context manager to automatically cast to a specific precision.
-
-    Args:
-        precision (str or Precision): Precision for the context
-    """
-    if torch.cuda.is_available():
-        return lambda precision: torch.cuda.amp.autocast(Precision(precision) == Precision.AMP)
-    else:
-
-        def null(precision):
-            assert Precision(
-                precision) != Precision.AMP, "Precision AMP is only available when `torch.cuda.is_available() == True`."
-            return contextlib.nullcontext()
-
-        return null
-
-
 def _ensure_backwards_compatible_checkpointing(state_dict: Dict[str, Any]):
     # v0.4.1 removed the leading underscores for the keys in the state_dict
     # It also renamed _is_model_ddp_wrapped to is_model_ddp
@@ -95,16 +74,18 @@ class State(Serializable):
         model (torch.nn.Module): The model, typically as a subclass of :class:`~.ComposerModel`.
         rank_zero_seed (int): The seed used on the rank zero process. It is assumed that each rank's seed is
             ``rank_zero_seed + dist.get_global_rank()``.
-        grad_accum (int): The number of gradient accumulation steps to use. With this argument, micro batch size for
+        grad_accum (int, optional): The number of gradient accumulation steps to use. With this argument, micro batch size for
             each device becomes ``microbatch_size = train_batch_size / (num_devices * grad_accum)``.
-        train_dataloader (types.DataLoader, DataSpec, or dict):
-            The :class:`~.types.DataLoader`, :class:`~.DataSpec`, or dict of :class:`~.DataSpec` kwargs to used for training.
-        evaluators (evaluator.Evaluator | Sequence[evaluator.Evaluator]):
-            The evaluators contain the evaluation dataset(s) used for evaluation with specific metrics.
-        max_duration (str or Time): The maximum duration to train for.
+        dataloader (types.DataLoader, optional): The active DataLoader.
+        dataloader_len (int | Time[int], optional): The number of batches per dataloader iteration (e.g. epoch).
+            The trainer will yield the first ``dataloader_len`` batches per iteration. If ``-1`` (the default),
+            the entire dataloader will be iterated over.
+        dataloader_label (str, optional): The name for the dataloader. Required if ``dataloader`` is specified. (default: ``None``)
+            By convention, the training dataloader is called ``'train'``. The evaluator dataloader is called
+            ``'eval'``, or when multiple evaluators are used, the name of the evaluator.
+        max_duration (str | Time, optional): The maximum duration to train for. (default: ``None``)
         precision (str | Precision): The numerical precision to use for training. See :class:`~.Precision` for the
             supported precisions.
-        precision_context (Callable[[Precision], ContextManager]): Function to produce a context manager to mandate precision.
         optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional): The optimizer being used to
             train the model. Multiple optimizers are not currently supported.
         schedulers (types.PyTorchScheduler | Sequence[types.PyTorchScheduler], optional):
@@ -187,8 +168,10 @@ class State(Serializable):
     +-----------------------+-------------------------------------------------------------+
     """

-    _max_duration: Time[int]
-    _steps_per_epoch: Optional[int]
+    _dataloader: Optional[types.DataLoader]
+    _dataloader_label: Optional[str]
+    _dataloader_len: Optional[Time[int]]
+    _max_duration: Optional[Time[int]]
     batch: types.Batch
     batch_num_samples: int
     batch_num_tokens: int
@@ -201,18 +184,20 @@ def __init__(
         # model
         model: torch.nn.Module,

-        # stopping conditions
-        max_duration: Union[str, Time[int]],
+        # determinism
         rank_zero_seed: int,

+        # stopping conditions
+        max_duration: Optional[Union[str, Time[int]]] = None,
+
         # data configurations
-        train_dataloader: types.DataLoader,
-        evaluators: Optional[Union[Evaluator, Sequence[Evaluator]]] = None,
         grad_accum: int = 1,
+        dataloader: Optional[types.DataLoader] = None,
+        dataloader_label: Optional[str] = None,
+        dataloader_len: Union[int, Time[int]] = -1,

         # precision
         precision: Union[str, Precision] = Precision.FP32,
-        precision_context: Callable[[Precision], ContextManager] = _default_precision_factory(),

         # optimizers
         optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
@@ -223,21 +208,16 @@ def __init__(
         # algorithms and callbacks
         algorithms: Optional[Union[Algorithm, Sequence[Algorithm]]] = None,
         callbacks: Optional[Union[Callback, Sequence[Callback]]] = None,
-
-        # steps per epoch
-        steps_per_epoch: Optional[int] = None,
     ):
         self.rank_zero_seed = rank_zero_seed
         self.model = model
         self.grad_accum = grad_accum
-        self.train_dataloader = train_dataloader
-        self.evaluators = list(ensure_tuple(evaluators))
+        self._dataloader_len = None
+        self.set_dataloader(dataloader, dataloader_label, dataloader_len)
         self.max_duration = max_duration
-        self.steps_per_epoch = steps_per_epoch

         self.timer = Timer()
         self._precision = Precision(precision)
-        self._precision_context = precision_context

         if optimizers is None:
             self._optimizers = []
@@ -251,6 +231,7 @@ def __init__(
         self._callbacks = list(ensure_tuple(callbacks))

         self.profiler: Optional[Profiler] = None
+
         # These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
         # All other attributes will not be serialized.
         # For simplicity, omit the leading underscore for private attributes.
@@ -281,20 +262,26 @@ def max_duration(self):
         return self._max_duration

     @max_duration.setter
-    def max_duration(self, max_duration: Union[str, Time[int]]):
+    def max_duration(self, max_duration: Optional[Union[str, Time[int]]]):
+        if max_duration is None:
+            self._max_duration = None
+            return
         if isinstance(max_duration, str):
             max_duration = cast(Time[int], Time.from_timestring(max_duration))
         if max_duration.unit == TimeUnit.DURATION:
             raise ValueError("TimeUnit.DURATION is not allowed as a unit for max_duration")
         self._max_duration = max_duration

-    def get_elapsed_duration(self) -> Time[float]:
+    def get_elapsed_duration(self) -> Optional[Time[float]]:
         """Get the elapsed training duration.

         Returns:
-            Time: The elapsed duration, in :attr:`TimeUnit.DURATION`. ``Time(0.0, TimeUnit.DURATION)`` represents the
-                beginning of training and ``Time(1.0, TimeUnit.DURATION)`` represents a completed training process.
+            Optional[Time[float]]: The elapsed duration, in :attr:`TimeUnit.DURATION`.
+                ``Time(0.0, TimeUnit.DURATION)`` represents the beginning of training and ``Time(1.0, TimeUnit.DURATION)``
+                represents a completed training process. Returns ``None`` if ``max_duration`` is None.
         """
+        if self.max_duration is None:
+            return None
         return self.timer.get(self.max_duration.unit) / self.max_duration

     @property
@@ -310,7 +297,7 @@ def schedulers(self):
         return self._schedulers

     @schedulers.setter
-    def schedulers(self, schedulers: types.PyTorchScheduler):
+    def schedulers(self, schedulers: Union[types.PyTorchScheduler, Sequence[types.PyTorchScheduler]]):
         self._schedulers[:] = ensure_tuple(schedulers)

     @property
@@ -410,25 +397,87 @@ def load_state_dict(self, state: Dict[str, Any], strict: bool = False):
         pass

     @property
-    def steps_per_epoch(self):
-        """int: The maximum number of steps (batches) per epoch."""
-        if self._steps_per_epoch is None:
-            return len(self.train_dataloader)
-        return self._steps_per_epoch
-
-    @steps_per_epoch.setter
-    def steps_per_epoch(self, steps_per_epoch: Optional[int]):
+    def dataloader(self):
+        """The dataloader."""
+        return self._dataloader
+
+    @property
+    def dataloader_label(self):
+        """The dataloader label. By convention, the training dataloader is called ``'train'``. The evaluator dataloader
+        is called ``'eval'``, or when multiple evaluators are used, the name of the evaluator.
+
+        Returns:
+            Optional[str]: The dataloader label, or None if no dataloader is set.
+        """
+        return self._dataloader_label
+
+    def set_dataloader(
+        self,
+        dataloader: Optional[types.DataLoader] = None,
+        dataloader_label: Optional[str] = None,
+        dataloader_len: Union[int, Time[int]] = -1,
+    ):
+        """Update the dataloader and dataloader label.
+
+        Args:
+            dataloader (types.DataLoader, optional): The dataloader. Defaults to None.
+            dataloader_label (str, optional): The dataloader label. Must be ``None`` if and only if
+                ``dataloader`` is None. Defaults to None.
+            dataloader_len (int | Time[int], optional): The number of batches per dataloader iteration (e.g. epoch), as used by the trainer.
+                Set to ``-1`` to iterate over the entire dataset. (Default: ``-1``.)
+        """
+        if dataloader is None:
+            dataloader_label = None
+        else:
+            if dataloader_label is None:
+                raise ValueError("If the `dataloader` is specified, then `dataloader_label` must not be None.")
+        self._dataloader = dataloader
+        self._dataloader_label = dataloader_label
+        if dataloader is not None:
+            self.dataloader_len = dataloader_len  # setting it to -1 will do a failsafe read of len(dataloader)
+
+    @property
+    def dataloader_len(self):
+        """The number of batches per dataloader iteration (e.g. epoch), as used by the trainer.
+
+        .. note::
+
+            If not explicitly specified, this value is an approximation, as it depends on ``len(self.dataloader)``.
+            See the :doc:`PyTorch DataLoader Documentation ` for more information.
+
+        Returns:
+            Optional[Time[int]]: The number of batches per dataloader iteration (e.g. epoch), or None if no dataloader
+                is defined or if the dataloader has an unknown length (e.g. streaming dataloaders).
+        """
+        return self._dataloader_len
+
+    @dataloader_len.setter
+    def dataloader_len(self, num_batches: Union[int, Time[int]]):
+        if isinstance(num_batches, int):
+            num_batches = Time(num_batches, TimeUnit.BATCH)
+        if self._dataloader is None:
+            raise RuntimeError("`State.dataloader_len` cannot be set if the dataloader is not defined.")
         try:
-            dataloader_len = len(self.train_dataloader)
+            dataloader_len = len(self._dataloader)
         except (TypeError, NotImplementedError):
             dataloader_len = None
-        if dataloader_len is not None and steps_per_epoch is not None and steps_per_epoch > dataloader_len:
-            warnings.warn(
-                textwrap.dedent(f"""\
-                    SubsetNumBatchesWarning: The steps_per_epoch({steps_per_epoch})
-                    is greater than the number of batches in the training dataloader
-                    ({dataloader_len})"""))
-        self._steps_per_epoch = steps_per_epoch
+        if dataloader_len is not None and num_batches >= 0 and int(num_batches) > dataloader_len:
+            warnings.warn((f"DataloaderNumBatchesWarning: The dataloader_len ({int(num_batches)}) "
+                           f"is greater than the length (i.e. number of batches) of the dataloader, which is "
+                           f"{dataloader_len}. State.dataloader_len is thus being set to {dataloader_len}."))
+            self._dataloader_len = Time(dataloader_len, TimeUnit.BATCH)
+            return
+        if num_batches < 0:
+            if dataloader_len is not None:
+                # len(dataloader) is an approximation -- see https://pytorch.org/docs/stable/data.html.
+                # However, in the worst case where additional last batches are dropped, this calculation should be
+                # an over-estimate, leading to the entire dataloader still being iterated over.
+                self._dataloader_len = Time(dataloader_len, TimeUnit.BATCH)
+            else:
+                # The dataloader length is unknown.
+                self._dataloader_len = None
+            return
+        self._dataloader_len = num_batches

     @property
     def precision(self):
@@ -462,10 +511,6 @@ def batch_dict(self) -> types.BatchDict:
         from composer.core.types import as_batch_dict
         return as_batch_dict(self.batch)

-    @property
-    def precision_context(self):
-        return self._precision_context(self.precision)
-
     @property
     def is_model_deepspeed(self) -> bool:
         """Whether :attr:`model` is an instance of a :class:`~deepspeed.DeepSpeedEngine`."""
diff --git a/composer/core/time.py b/composer/core/time.py
index 6af0220d3e..ee40905d1e 100644
--- a/composer/core/time.py
+++ b/composer/core/time.py
@@ -224,18 +224,20 @@ def _parse(self, other: object) -> Time:
         # parse ``other`` into a Time object
         if isinstance(other, Time):
             return other
+        if isinstance(other, int):
+            return Time(other, self.unit)
         if isinstance(other, str):
             other_parsed = Time.from_timestring(other)
             warnings.warn(
                 textwrap.dedent(f"""\
                     TimeImplicitStringConversion:
-                    Implicitly converting {other} to {other_parsed}.
-                    To fix this warning, replace {other} with {other_parsed}."""))
+                    Implicitly converting '{other}' to '{repr(other_parsed)}'.
+                    To fix this warning, replace '{other}' with '{repr(other_parsed)}'."""))
             return other_parsed

         raise TypeError(f"Cannot convert type {other} to {self.__class__.__name__}")

-    def _cmp(self, other: object) -> int:
+    def _cmp(self, other: Union[int, float, Time, str]) -> int:
         # When doing comparisions, and other is an integer (or float), we can safely infer
         # the unit from self.unit
         # E.g. calls like this should be allowed: if batch < 42: do_something()
@@ -254,40 +256,40 @@ def _cmp(self, other: Union[int, float, Time, str]) -> int:
             assert self.value > other.value
             return 1

-    def __eq__(self, other: object):
+    def __eq__(self, other: Union[int, float, Time, str]):
         return self._cmp(other) == 0

-    def __ne__(self, other: object):
+    def __ne__(self, other: Union[int, float, Time, str]):
         return self._cmp(other) != 0

-    def __lt__(self, other: object):
+    def __lt__(self, other: Union[int, float, Time, str]):
         return self._cmp(other) < 0

-    def __le__(self, other: object):
+    def __le__(self, other: Union[int, float, Time, str]):
         return self._cmp(other) <= 0

-    def __gt__(self, other: object):
+    def __gt__(self, other: Union[int, float, Time, str]):
         return self._cmp(other) > 0

-    def __ge__(self, other: object):
+    def __ge__(self, other: Union[int, float, Time, str]):
         return self._cmp(other) >= 0

-    def __add__(self, other: object) -> Time[TValue]:
+    def __add__(self, other: Union[int, float, Time, str]) -> Time[TValue]:
         other = self._parse(other)
         if self.unit != other.unit:
             raise RuntimeError(f"Cannot add {self} to {other} since they have different units.")
         return Time(self.value + other.value, self.unit)

-    def __radd__(self, other: object) -> Time[TValue]:
+    def __radd__(self, other: Union[int, float, Time, str]) -> Time[TValue]:
         return self + other

-    def __sub__(self, other: object) -> Time[TValue]:
+    def __sub__(self, other: Union[int, float, Time, str]) -> Time[TValue]:
         other = self._parse(other)
         if self.unit != other.unit:
             raise RuntimeError(f"Cannot subtract {other} from {self} since they have different units.")
         return Time(self.value - other.value, self.unit)

-    def __rsub__(self, other: object) -> Time[TValue]:
+    def __rsub__(self, other: Union[int, float, Time, str]) -> Time[TValue]:
         return (-self) + other

     def __neg__(self) -> Time[TValue]:
diff --git a/composer/loggers/progress_bar_logger.py b/composer/loggers/progress_bar_logger.py
index 6357f124d2..cb1ba0e310 100644
--- a/composer/loggers/progress_bar_logger.py
+++ b/composer/loggers/progress_bar_logger.py
@@ -4,7 +4,6 @@

 from __future__ import annotations

-import collections.abc
 import sys
 from dataclasses import asdict, dataclass
 from typing import Any, Callable, Dict, List, Optional, TextIO, Union
@@ -155,14 +154,7 @@ def _start(self, state: State):
         if dist.get_local_rank() != 0 or not self.show_pbar:
             return
         assert self.is_train is not None, "self.is_train should be set by the callback"
-        if self.is_train:
-            total_steps = state.steps_per_epoch
-        else:
-            total_steps = 0
-            for evaluator in state.evaluators:
-                dataloader_spec = evaluator.dataloader
-                assert isinstance(dataloader_spec.dataloader, collections.abc.Sized)
-                total_steps += len(dataloader_spec.dataloader)
+        assert state.dataloader_len is not None, "dataloader_len should be set when using tqdm"

         desc = f'Epoch {int(state.timer.epoch)}'
         position = 0 if self.is_train else 1
@@ -171,7 +163,7 @@ def _start(self, state: State):
         self.pbars[self.is_train] = _ProgressBarLoggerInstance(
             file=self.stream,
             state=_ProgressBarLoggerInstanceState(
-                total=total_steps,
+                total=int(state.dataloader_len),
                 position=position,
                 n=0,
                 keys_to_log=_IS_TRAIN_TO_KEYS_TO_LOG[self.is_train],
diff --git a/composer/optim/scheduler.py b/composer/optim/scheduler.py
index 47268e2d62..e3b19ba111 100644
--- a/composer/optim/scheduler.py
+++ b/composer/optim/scheduler.py
@@ -130,16 +130,22 @@ def _convert_time(time: Union[str, Time[int], Time[float]], state: State, ssr: f
     if isinstance(time, str):
         time = Time.from_timestring(time)

+    assert state.max_duration is not None, "max_duration should be set whenever schedulers are invoked"
+
     if time.unit == TimeUnit.DURATION:
+        if state.dataloader_len is None:
+            raise RuntimeError("Cannot convert time, as state.dataloader_len is None.")
         if state.max_duration.unit == TimeUnit.EPOCH:
-            return Time(int(time.value * state.steps_per_epoch * state.max_duration.value), TimeUnit.BATCH)
+            return Time(int(time.value * int(state.dataloader_len) * state.max_duration.value), TimeUnit.BATCH)
         return Time(int(time.value * state.max_duration.value), state.max_duration.unit)

     if time.unit == TimeUnit.EPOCH:
         # Epochs do not provide sufficient granularity for SSR scaling
         # e.g. if max_duration = 1ep, then any SSR would result in a new duration of 0.
         # so, convert the time into batches
-        time = Time(value=time.value * state.steps_per_epoch, unit=TimeUnit.BATCH)
+        if state.dataloader_len is None:
+            raise RuntimeError("Cannot convert time, as state.dataloader_len is None.")
+        time = Time(value=time.value * int(state.dataloader_len), unit=TimeUnit.BATCH)

     return Time(value=int(time.value * ssr), unit=time.unit)

diff --git a/composer/profiler/dataloader_profiler.py b/composer/profiler/dataloader_profiler.py
index 439c5d88ac..012a42149a 100644
--- a/composer/profiler/dataloader_profiler.py
+++ b/composer/profiler/dataloader_profiler.py
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING, Iterator, Optional

 from composer.core.callback import Callback
+from composer.core.event import Event
 from composer.datasets.dataloader import WrappedDataLoader

 if TYPE_CHECKING:
@@ -60,17 +61,16 @@ class DataLoaderProfiler(Callback):
         instance of this :class:`.DataLoaderProfiler` callback.
     """

-    def fit_start(self, state: State, logger: Logger):
+    def run_event(self, event: Event, state: State, logger: Logger):
         del logger  # unused
+        if event not in (Event.FIT_START, Event.EVAL_START):
+            return
         if state.profiler is None:
             raise RuntimeError(("The Composer Profiler was not enabled, which is required to use the "
                                 f"{type(self).__name__}. To enable, set the `prof_schedule` argument of the Trainer."))

-        if not _ProfiledDataLoader.is_dataloader_already_wrapped(state.train_dataloader):
-            state.train_dataloader = _ProfiledDataLoader(state.profiler, state.train_dataloader, "train")
-
-        for evaluator in state.evaluators:
-
-            if not _ProfiledDataLoader.is_dataloader_already_wrapped(evaluator.dataloader.dataloader):
-                evaluator.dataloader.dataloader = _ProfiledDataLoader(state.profiler, evaluator.dataloader.dataloader,
-                                                                      evaluator.label)
+        assert state.dataloader, "dataloader should be set on FIT_START or EVAL_START"
+        assert state.dataloader_label is not None, "dataloader label should be set on FIT_START or EVAL_START"
+        if not _ProfiledDataLoader.is_dataloader_already_wrapped(state.dataloader):
+            state.set_dataloader(_ProfiledDataLoader(state.profiler, state.dataloader, state.dataloader_label),
+                                 state.dataloader_label)
diff --git a/composer/profiler/profiler_schedule.py b/composer/profiler/profiler_schedule.py
index 2f75cd6ee5..bd764b941c 100644
--- a/composer/profiler/profiler_schedule.py
+++ b/composer/profiler/profiler_schedule.py
@@ -55,7 +55,7 @@ def schedule(state: State):
             return ProfilerAction.SKIP
         if position_in_cycle < wait + warmup:
             return ProfilerAction.WARMUP
-        is_last_batch_in_epoch = state.timer.batch_in_epoch == state.steps_per_epoch - 1
+        is_last_batch_in_epoch = state.dataloader_len is not None and state.timer.batch_in_epoch == state.dataloader_len - 1
         if position_in_cycle == cycle_len - 1 or is_last_batch_in_epoch:
             return ProfilerAction.ACTIVE_AND_SAVE
         return ProfilerAction.ACTIVE
diff --git a/composer/trainer/_deepspeed.py b/composer/trainer/_deepspeed.py
index 9695673df8..17a7de642c 100644
--- a/composer/trainer/_deepspeed.py
+++ b/composer/trainer/_deepspeed.py
@@ -17,19 +17,20 @@


 def _add_batch_config(config: Dict[str, Any], state: State):
-    if state.train_dataloader.batch_size is None:
+    assert state.dataloader is not None, "dataloader should be set on FIT_START, which is where the Deepspeed config is applied."
+    if state.dataloader.batch_size is None:
         raise RuntimeError("DeepSpeed requires a dataloader with a known batch size.")

-    if state.train_dataloader.batch_size % state.grad_accum != 0:
+    if state.dataloader.batch_size % state.grad_accum != 0:
         # DeepSpeed will throw an error in this configuration.
         raise ValueError("The Mosaic trainer has been configured to use batch size="
-                         f"{state.train_dataloader.batch_size}, but this is not divisible by the "
+                         f"{state.dataloader.batch_size}, but this is not divisible by the "
                          f"grad accum={state.grad_accum}. This is unsupported when using DeepSpeed.")

-    train_batch_size = state.train_dataloader.batch_size * dist.get_world_size()
+    train_batch_size = state.dataloader.batch_size * dist.get_world_size()
     grad_accum = state.grad_accum
     # Per the check at the start of this function, the following division is always clean.
-    per_gpu_microbatch_size = state.train_dataloader.batch_size // state.grad_accum
+    per_gpu_microbatch_size = state.dataloader.batch_size // state.grad_accum

     if "train_batch_size" in config:
         ds_train_batch_size = config["train_batch_size"]
diff --git a/composer/trainer/devices/device.py b/composer/trainer/devices/device.py
index 7bb581df99..a3cc14212c 100644
--- a/composer/trainer/devices/device.py
+++ b/composer/trainer/devices/device.py
@@ -4,14 +4,12 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping, Sequence
-from contextlib import contextmanager
-from typing import Any, Callable, Generator, TypeVar, Union
+from typing import Any, Callable, TypeVar

 import torch
 import torch.nn
 from torch.optim import Optimizer

-from composer.core.precision import Precision
 from composer.core.serializable import Serializable

 __all__ = ["Device", "T_nnModule"]

@@ -90,35 +88,6 @@ def optimizer_to_device(self, optimizer: Optimizer) -> Optimizer:
                     state[k] = self.tensor_to_device(v)
         return optimizer

-    @abstractmethod
-    @contextmanager
-    def precision_context(self, precision: Union[str, Precision]) -> Generator[None, None, None]:
-        """Precision returns a context manager that uses the specified precision.
-
-        Example usage:
-
-        .. doctest::
-
-            >>> from composer.core.precision import Precision
-            >>> from composer.trainer.devices import DeviceCPU
-            >>>
-            >>> device = DeviceCPU()
-            >>> for batch in train_dataloader:
-            ...     with device.precision_context(Precision.FP32):
-            ...         outputs = model.forward(batch)
-            ...
-            ...     with device.precision_context(Precision.FP32):
-            ...         loss = model.loss(outputs, batch)
-            >>>
-
-        Args:
-            precision (Precision): The desired precision for the device.
-
-        Yields:
-            Generator[None, None, None]: A context for the precision.
-        """
-        pass
-

 def _map_batch(batch: Any, map_fn: Callable) -> Any:
     """Recursively maps a function to all items in a batch.
diff --git a/composer/trainer/devices/device_cpu.py b/composer/trainer/devices/device_cpu.py
index 3665bd678e..539d3cb96b 100644
--- a/composer/trainer/devices/device_cpu.py
+++ b/composer/trainer/devices/device_cpu.py
@@ -5,13 +5,11 @@
 from __future__ import annotations

 import logging
-from contextlib import contextmanager
-from typing import Any, Dict, Generator, TypeVar, Union
+from typing import Any, Dict, TypeVar

 import torch

-from composer.core import Precision
-from composer.trainer.devices.device import Device, T_nnModule
+from composer.trainer.devices.device import Device

 logger = logging.getLogger(__name__)

@@ -35,14 +33,6 @@ def module_to_device(self, module: T_nnModule) -> T_nnModule:
     def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
         return tensor.to(self._device)

-    @contextmanager
-    def precision_context(self, precision: Union[str, Precision]) -> Generator[None, None, None]:
-        precision = Precision(precision)
-        if precision == Precision.FP32:
-            yield
-        else:
-            raise ValueError(f"Precision {precision} not supported for a CPU")
-
     def state_dict(self) -> Dict[str, Any]:
         # CPU device has no RNG state
         return {}
diff --git a/composer/trainer/devices/device_gpu.py b/composer/trainer/devices/device_gpu.py
index 15a6e1ab9b..81e7ffdf7d 100644
--- a/composer/trainer/devices/device_gpu.py
+++ b/composer/trainer/devices/device_gpu.py
@@ -4,16 +4,13 @@
 from __future__ import annotations

-from contextlib import contextmanager
-from typing import Any, Dict, Generator, TypeVar, Union
+from typing import Any, Dict, TypeVar

 import torch
 import torch.cuda.amp
 import torch.utils.data
-from packaging import version

-from composer.core.precision import Precision
-from composer.trainer.devices.device import Device, T_nnModule
+from composer.trainer.devices.device import Device
 from composer.utils import dist

 __all__ = ["DeviceGPU"]

@@ -40,24 +37,6 @@ def module_to_device(self, module: T_nnModule) -> T_nnModule:
     def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
         return tensor.to(self._device, non_blocking=True)

-    @contextmanager
-    def precision_context(self, precision: Union[str, Precision]) -> Generator[None, None, None]:
-        precision = Precision(precision)
-        enabled = False
-        if precision == Precision.FP32:
-            enabled = False
-        elif precision == Precision.AMP:
-            enabled = True
-        elif precision == Precision.BF16:
-            if version.parse(torch.__version__) < version.parse("1.10"):
-                raise ValueError(f"BF16 precision requires torch > 1.10, got version {torch.__version__}")
-            with torch.cuda.amp.autocast(True, torch.bfloat16):  # type: ignore
-                yield
-        # Retain compatibility with PyTorch < 1.10
-        if precision != Precision.BF16:
-            with torch.cuda.amp.autocast(enabled):  # type: ignore
-                yield
-
     def state_dict(self) -> Dict[str, Any]:
         return {
             "rng": torch.cuda.get_rng_state(),
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 9b38fd2d4f..236257c920 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -85,6 +85,7 @@
 from composer.algorithms import ScaleSchedule
 from composer.callbacks import CheckpointSaver
 from composer.core import Algorithm, Callback, DataSpec, Engine, Evaluator, Event, Precision, State, Time, Timestamp
+from composer.core.precision import get_precision_context
 from composer.core.types import Batch, BreakEpochException, DataLoader, PyTorchScheduler
 from composer.datasets.dataloader import unwrap_data_loader
 from composer.loggers import Logger, LoggerDestination, LogLevel, ProgressBarLogger
@@ -115,7 +116,7 @@ class Trainer:
             with Composer.

             .. seealso:: :mod:`composer.models` for models built into Composer.
-        train_dataloader (DataLoader, DataSpec, or dict): The :class:`.DataLoader`, :class:`.DataSpec`,
+        train_dataloader (DataLoader, DataSpec, or dict, optional): The :class:`.DataLoader`, :class:`.DataSpec`,
             or dict of :class:`.DataSpec` kwargs for the training data. In order to specify custom
             preprocessing steps on each data batch, specify a :class:`.DataSpec` instead of a
             :class:`.DataLoader`.
@@ -171,8 +172,12 @@ class Trainer:

             .. note:: ``fp16`` only works if ``deepspeed_config`` is also provided.
         scale_schedule_ratio (float, optional): Ratio by which to scale the training duration and learning rate
-            schedules. E.g., ``0.5`` makes the schedule take half as many epochs and ``2.0`` makes it take twice as
-            many epochs. ``1.0`` means no change. (default: ``1.0``)
+            schedules. (default: ``1.0``)
+
+            E.g., ``0.5`` makes the schedule take half as many epochs and ``2.0`` makes it take twice as
+            many epochs. ``1.0`` means no change.
+
+            This parameter has no effect if ``schedulers`` is not specified.

             .. note ::

@@ -399,10 +404,13 @@ class Trainer:
             artifact stores.
         train_subset_num_batches (int, optional): If specified, finish every epoch early after training
             on this many batches. This parameter has no effect if it is greater than ``len(train_dataloader)``.
-            If ``None``, then the entire dataloader will be iterated over. (default: ``None``)
-        eval_subset_num_batches (int, optional): If specified, evaluate on this many batches.
+            If ``-1``, then the entire dataloader will be iterated over. (default: ``-1``)
+
+            This parameter is ignored if ``train_dataloader`` is not specified.
+
+        eval_subset_num_batches (int, optional): If specified, evaluate on this many batches per evaluation dataloader.
             This parameter has no effect if it is greater than ``len(eval_dataloader)``.
-            If ``None``, then the entire dataloader will be iterated over. (default: ``None``)
+            If ``-1``, then the entire dataloader will be iterated over. (default: ``-1``)
         deepspeed_config (bool or Dict[str, Any], optional): Configuration for DeepSpeed, formatted as a JSON
             according to `DeepSpeed's documentation `_. If ``True`` is provided, the trainer
             will initialize the DeepSpeed engine with an empty config ``{}``. If ``False``
@@ -489,7 +497,8 @@ def __init__(
         eval_dataloader: Optional[Union[DataLoader, DataSpec, Evaluator, Sequence[Evaluator]]] = None,
         algorithms: Optional[Union[Algorithm, Sequence[Algorithm]]] = None,
         optimizers: Optional[torch.optim.Optimizer] = None,
-        schedulers: Optional[Union[ComposerScheduler, Sequence[ComposerScheduler]]] = None,
+        schedulers: Optional[Union[ComposerScheduler, PyTorchScheduler, Sequence[Union[ComposerScheduler,
+                                                                                       PyTorchScheduler]]]] = None,

         # device
         device: Optional[Union[str, Device]] = None,
@@ -541,8 +550,8 @@ def __init__(
         save_num_checkpoints_to_keep: int = -1,

         # subset parameters
-        train_subset_num_batches: Optional[int] = None,
-        eval_subset_num_batches: Optional[int] = None,
+        train_subset_num_batches: int = -1,
+        eval_subset_num_batches: int = -1,

         # DeepSpeed
         deepspeed_config: Union[bool, Dict[str, Any]] = False,
@@ -655,7 +664,7 @@ def __init__(
                 evaluator = Evaluator(label="eval", dataloader=evaluator, metrics=metrics)
             self.evaluators.append(evaluator)

-        self._eval_subset_num_batches = eval_subset_num_batches
+        self.eval_subset_num_batches = eval_subset_num_batches

         # do a check here to make sure there is at least one validation set
         if len(self.evaluators) == 0:
@@ -682,10 +691,6 @@ def __init__(
                     To fix, please do not iterate over the dataloader before passing it into
                     the trainer."""))

-        # TODO(#123): DeepSpeed still needs a precision context, but it's not completely clear how to
-        # handle this with our version of Pytorch
-        precision_context = self._device.precision_context if not self.deepspeed_enabled else cast(
-            Callable[..., ContextManager], contextlib.nullcontext)
         if isinstance(precision, str):
             precision = Precision(precision)

@@ -715,18 +720,13 @@ def __init__(
             raise ValueError("grad_accum must be an int or ``auto``")

         self.state = State(
-            max_duration=max_duration,
             rank_zero_seed=rank_zero_seed,
             algorithms=algorithms,
             model=model,
             callbacks=callbacks,
             grad_accum=grad_accum,
             precision=precision,
-            precision_context=precision_context,
-            train_dataloader=train_dataloader.dataloader,
-            evaluators=self.evaluators,
             optimizers=optimizers,
-            steps_per_epoch=train_subset_num_batches,
         )

         pytorch_schedulers = [
@@ -761,14 +761,7 @@ def __init__(

         self._step_schedulers_every_batch = step_schedulers_every_batch

-        for scheduler in ensure_tuple(schedulers):
-            if isinstance(scheduler, PyTorchScheduler):
-                scale_pytorch_scheduler(scheduler, scale_schedule_ratio)
-                self.state.schedulers.append(scheduler)
-            else:  # it's a composer scheduler
-                self.state.schedulers.append(compile_composer_scheduler(scheduler, self.state, scale_schedule_ratio))
-
-        if len(self.state.schedulers) == 0:
+        if len(ensure_tuple(schedulers)) == 0:
             warnings.warn(f"NoSchedulerWarning: No schedulers were specified. The learning rate will be constant.")

         # Configure profilers if profiling is enabled
@@ -870,11 +863,23 @@ def __init__(

         self.engine.run_event(Event.INIT)

+        # After running Event.INIT, then set the "optional" elements of state that could be passed in on FIT instead of INIT
+        # Setting these attributes here ensures that algorithms do not depend on unavailable attributes during Event.INIT
+        self.state.set_dataloader(train_dataloader.dataloader, 'train', train_subset_num_batches)
+        self.state.max_duration = max_duration

         self.logger.data_fit({"rank_zero_seed": rank_zero_seed})

         assert isinstance(self.state.model, ComposerModel)
         self._original_model = self.state.model  # TODO(ravi) -- update the state to add an original model helper

+        # Compile and bind the schedulers
+        for scheduler in ensure_tuple(schedulers):
+            if isinstance(scheduler, PyTorchScheduler):
+                scale_pytorch_scheduler(scheduler, scale_schedule_ratio)
+                self.state.schedulers.append(scheduler)
+            else:  # it's a composer scheduler
+                self.state.schedulers.append(compile_composer_scheduler(scheduler, self.state, scale_schedule_ratio))
+
         # place the state, model in the proper devices, and initialize from a checkpoint if provided
         if self.deepspeed_enabled:
             try:
@@ -892,7 +897,12 @@ def __init__(
                     model=self.state.model,
                     optimizer=optimizer,
                 )
-            # The deepspeed engine is responsible for serializing the model and optimizer state,
+            # Since the DeepSpeed ZeRO optimizer does not inherit torch.optim.Optimizer, the schedulers must be
+            # compiled and bound BEFORE DeepSpeed initialization. However, this is OK, as the DeepSpeed ZeRO
+            # optimizer uses the same underlying parameter groups as the original optimizer. See
+            # * https://github.com/microsoft/DeepSpeed/blob/fee73135980e78f8be7e1a3ff556751623ef6aaa/deepspeed/runtime/zero/stage_1_and_2.py#L1911-L1917
+            # * https://github.com/microsoft/DeepSpeed/blob/ef17c89570ceae5b26a5f886e9d8cd0941afc0ac/deepspeed/runtime/zero/stage3.py#L2532-L2538
+            # In addition, the deepspeed engine is responsible for serializing the model and optimizer state,
             # so these attributes should not be serialized with the composer state.
             if "model" in self.state.serialized_attributes:
                 self.state.serialized_attributes.remove("model")
@@ -1004,7 +1014,7 @@ def _spin_dataloaders(self):
         """
         # spin the evaluator dataloaders once to initialize its sampler deterministically
        # so it does not affect any other RNG reads
-        for evaluator in self.state.evaluators:
+        for evaluator in self.evaluators:
             dataloader = evaluator.dataloader.dataloader
             # FFCV dataloaders use their own sampling strategy
             if isinstance(dataloader, torch.utils.data.DataLoader) and isinstance(dataloader.sampler,
@@ -1014,12 +1024,13 @@ def _spin_dataloaders(self):
                 break

         # spin the train dataloader's sampler to get to the state of the desired epoch
+        assert self.state.dataloader is not None, "train dataloader is set on state after FIT_START"
         for epoch in range(int(self.state.timer.epoch)):
             # TODO: hasattr check will be removed while fixing https://github.com/mosaicml/composer/issues/424
-            if hasattr(self.state.train_dataloader, "sampler") and isinstance(self.state.train_dataloader.sampler,
-                                                                              torch.utils.data.DistributedSampler):
-                self.state.train_dataloader.sampler.set_epoch(epoch)
-            for _ in self.state.train_dataloader:
+            if hasattr(self.state.dataloader, "sampler") and isinstance(self.state.dataloader.sampler,
+                                                                        torch.utils.data.DistributedSampler):
+                self.state.dataloader.sampler.set_epoch(epoch)
+            for _ in self.state.dataloader:
                 break

     def _train_loop(self) -> None:
@@ -1027,6 +1038,8 @@ def _train_loop(self) -> None:
         # print training start
         self.logger.data_fit({"trainer/algorithms": [str(algo) for algo in self.state.algorithms]})

+        assert self.state.dataloader is not None, "dataloader is set in __init__"
+
         self.engine.run_event(Event.FIT_START)

         self.state.scaler = ClosureGradScaler() if self._use_closures() else GradScaler()
@@ -1051,12 +1064,16 @@ def _train_loop(self) -> None:
                 self.train_metrics.reset()

             # TODO: hasattr check will be removed while fixing https://github.com/mosaicml/composer/issues/424
-            if hasattr(self.state.train_dataloader, "sampler") and isinstance(self.state.train_dataloader.sampler,
-                                                                              torch.utils.data.DistributedSampler):
-                self.state.train_dataloader.sampler.set_epoch(int(self.state.timer.epoch))
+            if hasattr(self.state.dataloader, "sampler") and isinstance(self.state.dataloader.sampler,
+                                                                        torch.utils.data.DistributedSampler):
+                self.state.dataloader.sampler.set_epoch(int(self.state.timer.epoch))

-            for batch_idx, self.state.batch in enumerate(
-                    itertools.islice(self.state.train_dataloader, self.state.steps_per_epoch)):
+            if self.state.dataloader_len is None:
+                iterable = self.state.dataloader
+            else:
+                iterable = itertools.islice(self.state.dataloader, int(self.state.dataloader_len))
+
+            for batch_idx, self.state.batch in enumerate(iterable):

                 # if resuming, skip dataloader forward to the minibatch index
                 if batch_idx < int(self.state.timer.batch_in_epoch):
@@ -1315,7 +1332,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int, tot

         # forward pass
         self.engine.run_event(Event.BEFORE_FORWARD)

-        with self.state.precision_context:
+        with get_precision_context(self.state.precision):
             self.state.outputs = self.state.model(self.state.batch)
         self.engine.run_event(Event.AFTER_FORWARD)
@@ -1323,7 +1340,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int, tot
         # loss
         self.engine.run_event(Event.BEFORE_LOSS)

-        with self.state.precision_context:
+        with get_precision_context(self.state.precision):
             self.state.loss = self._original_model.loss(self.state.outputs, self.state.batch)

         # We always want to scale loss by the grad_accum before the backwards pass and
@@ -1372,25 +1389,38 @@ def eval(self, log_level: LogLevel = LogLevel.FIT):
         """
         restore_model_train = self.state.model.training

+        # back up the original dataloader on the state, so we can restore it after evaluation is finished
+        original_dataloader = self.state.dataloader
+        original_dataloader_label = self.state.dataloader_label
+        original_num_batches = self.state.dataloader_len
+
         self.state.model.eval()
         with torch.no_grad():
-            self.engine.run_event(Event.EVAL_START)
+            for evaluator in self.evaluators:
+                self.state.set_dataloader(evaluator.dataloader.dataloader, evaluator.label,
+                                          self.eval_subset_num_batches)
+                assert self.state.dataloader is not None, "dataloader is set"
+
+                self.engine.run_event(Event.EVAL_START)

-            for evaluator in self.state.evaluators:
-                dataloader = evaluator.dataloader.dataloader
                 metrics = self._ensure_metrics_device_and_dtype(evaluator.metrics)
                 metrics.reset()

                 # TODO: hasattr check will be removed while fixing https://github.com/mosaicml/composer/issues/424
-                if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler,
-                                                                 torch.utils.data.DistributedSampler):
+                if hasattr(self.state.dataloader, "sampler") and isinstance(self.state.dataloader.sampler,
+                                                                            torch.utils.data.DistributedSampler):
                     # The distributed sampler uses `set_epoch` to set the random seed
                     # Because evaluation can run on each batch, we use the batch to seed the sampler
                     # so each evaluation will get a proper shuffle.
                     # The epoch provided to `set_epoch` need not be sequential, so this is fine.
-                    dataloader.sampler.set_epoch(int(self.state.timer.batch))
+                    self.state.dataloader.sampler.set_epoch(int(self.state.timer.batch))

-                for self.state.batch in itertools.islice(dataloader, self._eval_subset_num_batches):
+                if self.state.dataloader_len is None:
+                    iterable = self.state.dataloader
+                else:
+                    iterable = itertools.islice(self.state.dataloader, int(self.state.dataloader_len))
+
+                for self.state.batch in iterable:
                     self.state.batch = self._device.batch_to_device(self.state.batch)
                     if evaluator.dataloader.device_transforms:
                         self.state.batch = evaluator.dataloader.device_transforms(self.state.batch)
@@ -1418,11 +1448,15 @@ def eval(self, log_level: LogLevel = LogLevel.FIT):

                 self._compute_and_log_metrics(dataloader_label=evaluator.label, metrics=metrics, log_level=log_level)

-        self.engine.run_event(Event.EVAL_END)
+            self.engine.run_event(Event.EVAL_END)

         if restore_model_train:
             self.state.model.train()

+        self.state.set_dataloader(original_dataloader, original_dataloader_label)
+        if original_num_batches is not None:
+            self.state.dataloader_len = original_num_batches
+
     def _use_grad_scaling(self, precision: Union[str, Precision], scaler: Optional[GradScaler]) -> bool:
         """Determines based on precision when to use grad scaling.
diff --git a/composer/trainer/trainer_hparams.py b/composer/trainer/trainer_hparams.py
index 98de4bbe7e..72d3dda798 100755
--- a/composer/trainer/trainer_hparams.py
+++ b/composer/trainer/trainer_hparams.py
@@ -378,10 +378,10 @@ class TrainerHparams(hp.Hparams):
     )

     # subset parameters
-    train_subset_num_batches: Optional[int] = hp.optional(
-        "If specified, finish every epoch early after training on this many batches.", default=None)
-    eval_subset_num_batches: Optional[int] = hp.optional("If specified, stop each evaluation after this many batches.", default=None)
+    train_subset_num_batches: int = hp.optional(
+        "If specified, finish every epoch early after training on this many batches.", default=-1)
+    eval_subset_num_batches: int = hp.optional("If specified, stop each evaluation after this many batches.", default=-1)

     # DeepSpeed
     deepspeed: Optional[Dict[str, JSON]] = hp.optional(doc="Configuration for DeepSpeed.", default=None)
diff --git a/docs/source/doctest_fixtures.py b/docs/source/doctest_fixtures.py
index 837a49581b..32b0ced579 100644
--- a/docs/source/doctest_fixtures.py
+++ b/docs/source/doctest_fixtures.py
@@ -115,8 +115,8 @@
     model=model,
     optimizers=optimizer,
     grad_accum=1,
-    train_dataloader=train_dataloader,
-    evaluators=[],
+    dataloader=train_dataloader,
+    dataloader_label="train",
     max_duration="1ep",
     precision="fp32",
 )
diff --git a/tests/algorithms/test_colout.py b/tests/algorithms/test_colout.py
index 7ae3e90266..29645ec93c 100644
--- a/tests/algorithms/test_colout.py
+++ b/tests/algorithms/test_colout.py
@@ -291,8 +291,8 @@ def test_apply_sample(self, colout_algorithm: ColOut, minimal_state: State, empt
         original_image, _ = dataset[0]
         assert isinstance(original_image, Image.Image)

-        minimal_state.train_dataloader = dataloader
-        colout_algorithm.apply(Event.INIT, minimal_state, empty_logger)
+        minimal_state.set_dataloader(dataloader, "train")
+        colout_algorithm.apply(Event.FIT_START, minimal_state, empty_logger)

         new_image, _ = dataset[0]
         assert isinstance(new_image, Image.Image)
diff --git a/tests/algorithms/test_layer_freezing.py b/tests/algorithms/test_layer_freezing.py
index 748c7bde4e..77aba31c79 100644
--- a/tests/algorithms/test_layer_freezing.py
+++ b/tests/algorithms/test_layer_freezing.py
@@ -19,9 +19,9 @@ def _generate_state(epoch: int, max_epochs: int):
                  rank_zero_seed=0,
                  optimizers=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99),
                  precision=Precision.FP32,
+                 dataloader=Mock(__len__=lambda x: 100),
+                 dataloader_label="train",
                  grad_accum=1,
-                 train_dataloader=Mock(__len__=lambda x: 100),
-                 evaluators=Mock(),
                  max_duration=f'{max_epochs}ep')

     # fast forward by epochs
diff --git a/tests/algorithms/test_progressive_resizing.py b/tests/algorithms/test_progressive_resizing.py
index f80e74afc7..4a68ea226f 100644
--- a/tests/algorithms/test_progressive_resizing.py
+++ b/tests/algorithms/test_progressive_resizing.py
@@ -149,6 +149,7 @@ def test_match_incorrect(self, event: Event, pr_algorithm: ProgressiveResizing,
     def test_apply(self, epoch_frac: float, X: torch.Tensor, y: torch.Tensor, pr_algorithm: ProgressiveResizing,
                    minimal_state: State, empty_logger: Logger):
         """Test apply at different epoch fractions (fraction of max epochs)"""
+        assert minimal_state.max_duration is not None
         assert minimal_state.max_duration.unit == TimeUnit.EPOCH
         minimal_state.timer.epoch._value = int(epoch_frac * minimal_state.max_duration.value)
         s = pr_algorithm.initial_scale
diff --git a/tests/algorithms/test_selective_backprop.py b/tests/algorithms/test_selective_backprop.py
index f42fa31677..8568631755 100644
--- a/tests/algorithms/test_selective_backprop.py
+++ b/tests/algorithms/test_selective_backprop.py
@@ -144,12 +144,12 @@ def conv_model(Ximage: torch.Tensor, D: int) -> ComposerClassifier:
 def state(minimal_state: State, conv_model: ComposerClassifier, loss_fun_tuple: Callable, epoch: int,
           batch: int) -> State:
     """State with required values set for Selective Backprop."""
-
+    assert minimal_state.dataloader_len is not None
     conv_model.loss = loss_fun_tuple
     minimal_state.model = conv_model

     minimal_state.timer.epoch._value = epoch
-    minimal_state.timer.batch._value = epoch * minimal_state.steps_per_epoch + batch
+    minimal_state.timer.batch._value = epoch * int(minimal_state.dataloader_len) + batch
     minimal_state.timer.batch_in_epoch._value = batch

     return minimal_state
diff --git a/tests/algorithms/test_stochastic_depth.py b/tests/algorithms/test_stochastic_depth.py
index 6451534c46..d00bfcfde1 100644
--- a/tests/algorithms/test_stochastic_depth.py
+++ b/tests/algorithms/test_stochastic_depth.py
@@ -232,8 +232,10 @@ def test_drop_rate_warmup(self, algorithm: StochasticDepth, step: int, state: St
         new_drop_rates = []
         self.get_drop_rate_list(state.model, drop_rates=new_drop_rates)

+        assert state.max_duration is not None
         assert state.max_duration.unit == TimeUnit.EPOCH
-        drop_warmup_iters = int(state.steps_per_epoch * int(state.max_duration.value) * algorithm.drop_warmup)
+        assert state.dataloader_len is not None
+        drop_warmup_iters = int(int(state.dataloader_len) * int(state.max_duration.value) * algorithm.drop_warmup)
         assert torch.all(torch.tensor(new_drop_rates) == ((step / drop_warmup_iters) * torch.tensor(old_drop_rates)))
diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py
index cda0eb7bad..8c72cf49f8 100644
--- a/tests/callbacks/test_speed_monitor.py
+++ b/tests/callbacks/test_speed_monitor.py
@@ -35,8 +35,10 @@ def test_speed_monitor(composer_trainer_hparams: TrainerHparams):
         if 'wall_clock_train' in metrics:
             wall_clock_train_calls += 1

-    assert isinstance(trainer.state.train_dataloader, collections.abc.Sized)
-    expected_step_calls = (trainer.state.steps_per_epoch - speed_monitor_hparams.window_size) * max_epochs
+    assert isinstance(trainer.state.dataloader, collections.abc.Sized)
+    assert trainer.state.dataloader_label is not None
+    assert trainer.state.dataloader_len is not None
+    expected_step_calls = (trainer.state.dataloader_len - speed_monitor_hparams.window_size) * max_epochs
     assert throughput_step_calls == expected_step_calls
     assert throughput_epoch_calls == max_epochs
     assert wall_clock_train_calls == max_epochs
diff --git a/tests/fixtures/dummy_fixtures.py b/tests/fixtures/dummy_fixtures.py
index 9b6beb01bd..6a07b47be6 100644
--- a/tests/fixtures/dummy_fixtures.py
+++ b/tests/fixtures/dummy_fixtures.py
@@ -7,10 +7,8 @@
 import torch
 import torch.utils.data
 from torch.optim import Optimizer
-from torchmetrics import MetricCollection
-from torchmetrics.classification.accuracy import Accuracy

-from composer.core import DataSpec, Evaluator, Precision, State
+from composer.core import DataSpec, Precision, State
 from composer.core.types import DataLoader, PyTorchScheduler
 from composer.datasets import DataLoaderHparams, DatasetHparams
 from composer.loggers import Logger
@@ -107,21 +105,17 @@ def dummy_scheduler(dummy_optimizer: Optimizer):

 @pytest.fixture()
 def dummy_state(dummy_model: SimpleBatchPairModel, dummy_train_dataloader: DataLoader, dummy_optimizer: Optimizer,
-                dummy_scheduler: PyTorchScheduler, dummy_val_dataloader: DataLoader, rank_zero_seed: int) -> State:
-    evaluators = [
-        Evaluator(label="dummy_label", dataloader=dummy_val_dataloader, metrics=dummy_model.metrics(train=False))
-    ]
+                dummy_scheduler: PyTorchScheduler, rank_zero_seed: int) -> State:
     state = State(
         model=dummy_model,
         precision=Precision.FP32,
         grad_accum=1,
         rank_zero_seed=rank_zero_seed,
-        train_dataloader=dummy_train_dataloader,
-        evaluators=evaluators,
         optimizers=dummy_optimizer,
         max_duration="10ep",
     )
     state.schedulers = dummy_scheduler
+    state.set_dataloader(dummy_train_dataloader, "train")

     return state
@@ -224,20 +218,16 @@ def simple_conv_model_input():

 @pytest.fixture()
-def state_with_model(simple_conv_model: torch.nn.Module, dummy_train_dataloader: DataLoader,
-                     dummy_val_dataloader: DataLoader, rank_zero_seed: int):
-    metric_coll = MetricCollection([Accuracy()])
-    evaluators = [Evaluator(label="dummy_label", dataloader=dummy_val_dataloader, metrics=metric_coll)]
-    state = State(
+def state_with_model(simple_conv_model: torch.nn.Module, dummy_train_dataloader: DataLoader, rank_zero_seed: int):
+    return State(
         grad_accum=1,
         rank_zero_seed=rank_zero_seed,
         max_duration="100ep",
         model=simple_conv_model,
         precision=Precision.FP32,
-        train_dataloader=dummy_train_dataloader,
-        evaluators=evaluators,
+        dataloader=dummy_train_dataloader,
+        dataloader_label="train",
     )
-    return state


 @pytest.fixture()
diff --git a/tests/fixtures/new_fixtures.py b/tests/fixtures/new_fixtures.py
index df7034030e..9f01f93657 100644
--- a/tests/fixtures/new_fixtures.py
+++ b/tests/fixtures/new_fixtures.py
@@ -22,9 +22,9 @@ def minimal_state(rank_zero_seed: int):
     return State(
         model=SimpleModel(),
         rank_zero_seed=rank_zero_seed,
-        train_dataloader=DataLoader(RandomClassificationDataset()),
-        evaluators=[],
         max_duration='100ep',
+        dataloader=DataLoader(RandomClassificationDataset()),
+        dataloader_label="train",
     )
@@ -41,8 +41,13 @@ def disable_wandb(monkeypatch: pytest.MonkeyPatch):

 @pytest.fixture(autouse=True)
 def configure_dist(request: pytest.FixtureRequest):
-    # Configure dist globally, so individual tests that do not use the trainer
+    # Configure dist globally when the world size is greater than 1,
+    # so individual tests that do not use the trainer
     # do not need to worry about manually configuring dist.
+
+    if dist.get_world_size() == 1:
+        return
+
     backend = 'gloo' if request.node.get_closest_marker('gpu') is None else 'nccl'
     if not dist.is_initialized():
         dist.initialize_dist(backend, timeout=datetime.timedelta(seconds=300))
diff --git a/tests/loggers/test_progress_bar_logger.py b/tests/loggers/test_progress_bar_logger.py
index 94d03d9a13..ebdaf8c30a 100644
--- a/tests/loggers/test_progress_bar_logger.py
+++ b/tests/loggers/test_progress_bar_logger.py
@@ -41,8 +41,10 @@ def get_mock_tqdm(position: int, *args: object, **kwargs: object):
     assert composer_trainer_hparams.validate_every_n_batches < 0
     assert len(is_train_to_mock_tqdms[False]) == composer_trainer_hparams.validate_every_n_epochs * max_epochs
     for mock_tqdm in is_train_to_mock_tqdms[True]:
-        assert mock_tqdm.update.call_count == trainer.state.steps_per_epoch
+        assert trainer.state.dataloader_len is not None
+        assert trainer.state.dataloader_label == "train"
+        assert mock_tqdm.update.call_count == int(trainer.state.dataloader_len)
         mock_tqdm.close.assert_called_once()
     for mock_tqdm in is_train_to_mock_tqdms[False]:
-        assert mock_tqdm.update.call_count == trainer._eval_subset_num_batches
+        assert mock_tqdm.update.call_count == trainer.eval_subset_num_batches
         mock_tqdm.close.assert_called_once()
diff --git a/tests/profiler/test_json_trace_handler.py b/tests/profiler/test_json_trace_handler.py
index bfcc3178ce..56314bb921 100644
--- a/tests/profiler/test_json_trace_handler.py
+++ b/tests/profiler/test_json_trace_handler.py
@@ -11,7 +11,7 @@

 # This test shouldn't run with the Torch profiler enabled, not providing a model or data can cause a seg fault
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
 def test_json_trace_profiler_handler(composer_trainer_hparams: TrainerHparams, tmpdir: pathlib.Path):
     profiler_file = os.path.join(tmpdir, 'trace.json')
     json_trace_handler_params = JSONTraceHparams(folder=str(tmpdir), merged_trace_filename='trace.json')
diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py
index 484d421454..c0f89c90c0 100644
--- a/tests/test_notebooks.py
+++ b/tests/test_notebooks.py
@@ -56,10 +56,10 @@ def test_notebook(tb):
         tb.inject("""
             from composer.core import Time
             trainer.state.max_duration = Time.from_timestring('2ep')
-            trainer.state.train_subset_num_batches = 2
+            trainer.state.dataloader_len = 2
         """)
         trainer = tb.ref("trainer")
-        assert trainer.state.train_subset_num_batches == 2
+        assert trainer.state.dataloader_len == 2
     except Exception as e:
         raise Exception(
             textwrap.dedent("""
diff --git a/tests/test_state.py b/tests/test_state.py
index 6912ebac5e..caec67d527 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -9,7 +9,7 @@
 from torch.functional import Tensor

 from composer.algorithms import ChannelsLastHparams
-from composer.core import DataSpec, Evaluator, Precision, State
+from composer.core import DataSpec, Precision, State
 from composer.core.types import Batch, DataLoader
 from composer.datasets.dataloader import DataLoaderHparams
 from composer.datasets.hparams import DatasetHparams
@@ -21,23 +21,21 @@ def random_tensor(size=(4, 10)):
     return torch.rand(*size)


-def get_dummy_state(model: ComposerModel, train_dataloader: DataLoader, val_dataloader: DataLoader):
+def get_dummy_state(model: ComposerModel, train_dataloader: DataLoader):
     optimizers = torch.optim.Adadelta(model.parameters())
-    evaluators = [Evaluator(label="dummy_label", dataloader=val_dataloader, metrics=model.metrics(train=False))]

     state = State(model=model,
                   grad_accum=random.randint(0, 100),
                   rank_zero_seed=random.randint(0, 100),
                   precision=Precision.AMP,
                   max_duration=f"{random.randint(0, 100)}ep",
-                  train_dataloader=train_dataloader,
-                  evaluators=evaluators,
                   optimizers=optimizers,
                   algorithms=[ChannelsLastHparams().initialize_object()])
     state.schedulers = torch.optim.lr_scheduler.StepLR(optimizers, step_size=3)

     state.loss = random_tensor()
     state.batch = (random_tensor(), random_tensor())
     state.outputs = random_tensor()
+    state.set_dataloader(train_dataloader, "train")
     return state
@@ -111,14 +109,18 @@ def get_batch(dataset_hparams: DatasetHparams, dataloader_hparams: DataLoaderHpa
     raise RuntimeError("No batch in dataloader")


-def test_state_serialize(tmpdir: pathlib.Path, dummy_model: ComposerModel, dummy_dataloader_hparams: DataLoaderHparams,
-                         dummy_train_dataset_hparams: DatasetHparams, dummy_train_dataloader: DataLoader,
-                         dummy_val_dataset_hparams: DatasetHparams, dummy_val_dataloader: DataLoader):
+def test_state_serialize(
+    tmpdir: pathlib.Path,
+    dummy_model: ComposerModel,
+    dummy_dataloader_hparams: DataLoaderHparams,
+    dummy_train_dataset_hparams: DatasetHparams,
+    dummy_train_dataloader: DataLoader,
+):
     assert isinstance(dummy_model, SimpleBatchPairModel)

-    state1 = get_dummy_state(dummy_model, dummy_train_dataloader, dummy_val_dataloader)
-    state2 = get_dummy_state(dummy_model, dummy_train_dataloader, dummy_val_dataloader)
+    state1 = get_dummy_state(dummy_model, dummy_train_dataloader)
+    state2 = get_dummy_state(dummy_model, dummy_train_dataloader)

     # train one step to set the optimizer states
     batch = get_batch(dummy_train_dataset_hparams, dummy_dataloader_hparams)
diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py
index 60dfe156f0..59d0fd18c7 100644
--- a/tests/trainer/test_checkpoint.py
+++ b/tests/trainer/test_checkpoint.py
@@ -428,10 +428,12 @@ def _test_checkpoint_trainer(trainer_hparams: TrainerHparams):

 def _validate_events_called_expected_number_of_times(trainer: Trainer):
     state = trainer.state
-
+    assert state.dataloader_label == "train"
+    assert state.dataloader_len is not None
+    assert state.max_duration is not None
     assert state.max_duration.unit == TimeUnit.EPOCH
     num_epochs = state.max_duration.value
-    num_total_steps = num_epochs * state.steps_per_epoch
+    num_total_steps = num_epochs * int(state.dataloader_len)
     num_total_microbatches = num_total_steps * state.grad_accum
     num_evals = 0
     if trainer._validate_every_n_batches > 0:
@@ -439,11 +441,11 @@ def _validate_events_called_expected_number_of_times(trainer: Trainer):
     if trainer._validate_every_n_epochs > 0:
         num_evals = num_epochs // trainer._validate_every_n_epochs

-    assert state.evaluators is not None
-    for evaluator in state.evaluators:
+    assert trainer.evaluators is not None
+    for evaluator in trainer.evaluators:
         assert evaluator.dataloader is not None
-    assert trainer._eval_subset_num_batches is not None
-    num_eval_steps = num_evals * trainer._eval_subset_num_batches * len(state.evaluators)
+    assert trainer.eval_subset_num_batches is not None
+    num_eval_steps = num_evals * trainer.eval_subset_num_batches * len(trainer.evaluators)

     event_to_num_expected_invocations = {
         Event.INIT: 1,
diff --git a/tests/trainer/test_ddp.py b/tests/trainer/test_ddp.py
index 045f31db02..3bd82755ca 100644
--- a/tests/trainer/test_ddp.py
+++ b/tests/trainer/test_ddp.py
@@ -190,7 +190,6 @@ def test_ddp(device_hparams: DeviceHparams, world_size: int, dummy_model_hparams
                       train_subset_num_batches=train_subset_num_batches,
                       deepspeed_config={} if deepspeed else False,
                       callbacks=[CheckBatch0(tmpdir)])
-    assert isinstance(trainer.state.train_dataloader.dataset, collections.abc.Sized)
     for evaluator in trainer.evaluators:
         assert isinstance(evaluator.dataloader, DataSpec)
diff --git a/tests/trainer/test_ddp_sync_strategy.py b/tests/trainer/test_ddp_sync_strategy.py
index 50d6884661..594444ce49 100644
--- a/tests/trainer/test_ddp_sync_strategy.py
+++ b/tests/trainer/test_ddp_sync_strategy.py
@@ -6,10 +6,8 @@
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torchmetrics import MetricCollection
-from torchmetrics.classification.accuracy import Accuracy

-from composer.core import Evaluator, State
+from composer.core import State
 from composer.core.types import DataLoader
 from composer.trainer.ddp import _ddp_sync_context, _prepare_ddp_module
 from composer.utils import dist
@@ -49,20 +47,20 @@ def loss(self, output: Tensor, target: Tensor):
 ])
 @pytest.mark.world_size(2)
 def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional[float]],
-                           dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader, rank_zero_seed: int):
+                           dummy_train_dataloader: DataLoader, rank_zero_seed: int):
     original_model = MinimalConditionalModel()
     # ddp = DDP(backend="gloo", find_unused_parameters=True, sync_strategy=ddp_sync_strategy, timeout=5.)
     optimizer = torch.optim.SGD(original_model.parameters(), 0.1)
-    metric_coll = MetricCollection([Accuracy()])
-    evaluators = [Evaluator(label="dummy_label", dataloader=dummy_val_dataloader, metrics=metric_coll)]
-    state = State(model=original_model,
-                  rank_zero_seed=rank_zero_seed,
-                  optimizers=optimizer,
-                  grad_accum=2,
-                  max_duration="1ep",
-                  train_dataloader=dummy_train_dataloader,
-                  evaluators=evaluators,
-                  precision='fp32')
+    state = State(
+        model=original_model,
+        rank_zero_seed=rank_zero_seed,
+        optimizers=optimizer,
+        grad_accum=2,
+        max_duration="1ep",
+        dataloader=dummy_train_dataloader,
+        dataloader_label="train",
+        precision='fp32',
+    )
     batches = [[(1, Tensor([1])), (1, Tensor([2]))], [(2, Tensor([1])), (2, Tensor([2]))]]

     state.model = _prepare_ddp_module(state.model, find_unused_parameters=True)
diff --git a/tests/trainer/test_scale_schedule.py b/tests/trainer/test_scale_schedule.py
index 88e3766b52..f460fb1f68 100644
--- a/tests/trainer/test_scale_schedule.py
+++ b/tests/trainer/test_scale_schedule.py
@@ -9,8 +9,11 @@
 from torch.optim.lr_scheduler import ExponentialLR

 from composer.algorithms import ScaleScheduleHparams
+from composer.core import State
+from composer.core.callback import Callback
 from composer.core.time import TimeUnit
 from composer.core.types import PyTorchScheduler
+from composer.loggers.logger import Logger
 from composer.optim import MultiStepSchedulerHparams, SGDHparams
 from composer.trainer import TrainerHparams
 from composer.trainer._scale_schedule import scale_pytorch_scheduler
@@ -27,6 +30,7 @@ def flatten(lst: list):

 @pytest.mark.parametrize('ssr', [0.5, 0.75, 1.0])
+@pytest.mark.filterwarnings(r"ignore:.*Detected call of \`lr_schedule.*:UserWarning")
 class TestScaleSchedule():

     @staticmethod
@@ -72,10 +76,36 @@ def test_scale_schedule_cosine_warm_restarts(self, optimizer: Optimizer, ssr: fl
         raise NotImplementedError


+class CheckScaleSchedule(Callback):
+
+    def __init__(self, ssr: float) -> None:
+        self.ssr = ssr
+
+    def fit_start(self, state: State, logger: Logger) -> None:
+        scheduler = state.schedulers[0]
+
+        test_steps = [int(20 * self.ssr), int(40 * self.ssr), int(60 * self.ssr)]
+        target_lrs = [1.0, 0.1, 0.01]
+        current_step = 0
+        for test_step, target_lr in zip(test_steps, target_lrs):
+
+            while current_step < test_step:
+                state.timer.on_batch_complete()
+                current_step += 1
+
+            scheduler.step()
+
+            assert scheduler.get_last_lr()[0] == pytest.approx(target_lr)
+
+
 @pytest.mark.parametrize('ssr', [0.5, 0.75, 1.0])
-@pytest.mark.parametrize('use_algorithm', [False, True])
+@pytest.mark.parametrize('use_algorithm', [
+    False,
+    pytest.param(True, marks=pytest.mark.filterwarnings(r"ignore:.*ScaleScheduleDeprecationWarning.*")),
+])
 class TestScaleScheduleTrainer():

+    @pytest.mark.filterwarnings(r"ignore:.*Detected call of \`lr_schedule.*:UserWarning")
     def test_epochs_scaled(
         self,
         ssr: float,
@@ -91,21 +121,12 @@ def test_epochs_scaled(
             composer_trainer_hparams.algorithms = [ScaleScheduleHparams(ratio=ssr)]
         else:
             composer_trainer_hparams.scale_schedule_ratio = ssr
+
         trainer = composer_trainer_hparams.initialize_object()
+        trainer.state.callbacks.append(CheckScaleSchedule(ssr))
+        assert trainer.state.max_duration is not None
         assert trainer.state.max_duration.unit == TimeUnit.EPOCH
         assert trainer.state.max_duration.value == int(10 * ssr)

-        scheduler = trainer.state.schedulers[0]
-
-        test_steps = [int(20 * ssr), int(40 * ssr), int(60 * ssr)]
-        target_lrs = [1.0, 0.1, 0.01]
-        current_step = 0
-        for test_step, target_lr in zip(test_steps, target_lrs):
-            while current_step < test_step:
-                trainer.state.timer.on_batch_complete()
-                current_step += 1
-
-            scheduler.step()
-
-            assert scheduler.get_last_lr()[0] == pytest.approx(target_lr)
+        trainer.fit()
diff --git a/tests/trainer/test_scheduler.py b/tests/trainer/test_scheduler.py
index 6568595dc5..cfa21a61f4 100644
--- a/tests/trainer/test_scheduler.py
+++ b/tests/trainer/test_scheduler.py
@@ -5,6 +5,7 @@
 import pytest
 import torch
+import torch.utils.data

 from composer.core import State, Time
 from composer.core.time import TimeUnit
@@ -21,14 +22,14 @@

 @pytest.fixture
-def dummy_schedulers_state(dummy_model: torch.nn.Module, dummy_train_dataloader: DataLoader, rank_zero_seed: int):
-    return State(
+def dummy_schedulers_state(dummy_model: torch.nn.Module, rank_zero_seed: int):
+    state = State(
         model=dummy_model,
         rank_zero_seed=rank_zero_seed,
-        train_dataloader=dummy_train_dataloader,
         max_duration=MAX_DURATION,
-        steps_per_epoch=STEPS_PER_EPOCH,
     )
+    state.set_dataloader(cast(torch.utils.data.DataLoader, [None] * STEPS_PER_EPOCH), "train")
+    return state


 @pytest.mark.parametrize("scheduler,ssr,test_times,expected_lrs", [
@@ -88,16 +89,18 @@ def test_scheduler_init(scheduler: ComposerScheduler, ssr: float, test_times: Li
                         dummy_schedulers_state: State):

     state = dummy_schedulers_state
-    state._max_duration = Time(value=int(state.max_duration.value * ssr), unit=state.max_duration.unit)
+    assert state.dataloader_len is not None
+    assert state.max_duration is not None
+    state.max_duration = Time(value=int(state.max_duration.value * ssr), unit=state.max_duration.unit)

     for test_time, expected_lr in zip(test_times, expected_lrs):
         parsed_time = Time.from_timestring(test_time)
         assert parsed_time.unit in [TimeUnit.EPOCH, TimeUnit.BATCH]
         if parsed_time.unit == TimeUnit.EPOCH:
             state.timer._epoch = parsed_time
-            state.timer._batch = Time(int(state.steps_per_epoch * state.timer._epoch.value), TimeUnit.BATCH)
+            state.timer._batch = Time(int(state.dataloader_len) * int(state.timer.epoch), TimeUnit.BATCH)
         else:
             state.timer._batch = parsed_time
-            state.timer._epoch = Time(int(state.timer._batch.value / state.steps_per_epoch), TimeUnit.EPOCH)
+            state.timer._epoch = Time(int(state.timer.batch) // int(state.dataloader_len), TimeUnit.EPOCH)

         lr = scheduler(state, ssr)
         assert lr == pytest.approx(expected_lr, abs=1e-3)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 2af6fd0209..087c1e5dc6 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -111,6 +111,7 @@ def test_init_with_integers(self, config, tmpdir: pathlib.Path):
     def test_init_with_max_duration_in_batches(self, config):
         config["max_duration"] = '1ba'
         trainer = Trainer(**config)
+        assert trainer.state.max_duration is not None
         assert trainer.state.max_duration.to_timestring() == "1ba"

diff --git a/tests/utils/classifer.py b/tests/utils/classifer.py
index f587bd0fc3..e32fef0ea4 100644
--- a/tests/utils/classifer.py
+++ b/tests/utils/classifer.py
@@ -5,31 +5,23 @@
 import torch
 import torch.nn.functional as F
 from torch.optim import SGD, Optimizer
-from torchmetrics import MetricCollection
-from torchmetrics.classification.accuracy import Accuracy

-from composer.core import Algorithm, Engine, Evaluator, Event, Precision, State
+from composer.core import Algorithm, Engine, Event, Precision, State
 from composer.core.types import DataLoader
 from composer.loggers import Logger
 from tests.utils.model import SimpleModel


-def _get_state(train_dataloader: DataLoader,
-               eval_dataloader: DataLoader,
-               rank_zero_seed: int,
-               steps_per_epoch: int = 1):
+def _get_state(train_dataloader: DataLoader, eval_dataloader: DataLoader, rank_zero_seed: int):
     model = SimpleModel()
-    steps_per_epoch = steps_per_epoch
-    metric_coll = MetricCollection([Accuracy()])
-    evaluators = [Evaluator(label="dummy_label", dataloader=eval_dataloader, metrics=metric_coll)]
     return State(
         model=model,
         rank_zero_seed=rank_zero_seed,
         optimizers=SGD(model.parameters(), lr=.001, momentum=0.0),
         max_duration="1ep",
-        train_dataloader=train_dataloader,
-        evaluators=evaluators,
         grad_accum=1,
+        dataloader=train_dataloader,
+        dataloader_label="train",
         precision=Precision.FP32,
     )
@@ -47,7 +39,6 @@ def test_classifier_trains(

     state = _get_state(train_dataloader=train_dataloader,
                        eval_dataloader=eval_dataloader,
-                       steps_per_epoch=n_steps,
                        rank_zero_seed=rank_zero_seed)

     model = state.model
@@ -72,9 +63,9 @@ def f_loss(state):
     opt = state.optimizers
     assert isinstance(opt, Optimizer)

-    assert state.train_dataloader is not None, "train dataloader should be set"
+    assert state.dataloader is not None, "train dataloader should be set"

-    for step, (X, y) in enumerate(state.train_dataloader):
+    for step, (X, y) in enumerate(state.dataloader):
         # reseed here so data is same for different sets of algorithms
         torch.manual_seed(step)