From bd3e8474c595fb21be189373585a9404367cca12 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Fri, 25 Mar 2022 16:51:54 -0700 Subject: [PATCH 01/17] [Eval-Only]: Made the `state.dataloader` optional; removed `state.steps_per_epoch`. 1. Made the `state.dataloader` optional, since it will not be provided on `__init__` as part of #40. 2. Binding the active dataloader to the state on `Event.FIT_START`, and switching the dataloader to each evaluation dataloader before `Event.EVAL_START`. Restoring the previous (training) dataloader after `Event.EVAL_END`. 3. Moved `Event.EVAL_START` and `Event.EVAL_END` to run for each evaluator, instead of once for all evaluators. With #40, the eval() will take in a dataloader, which would then require `Event.EVAL_START` and `Event.EVAL_END`. This change also permits algorithms that wish to modify (each) evaluation dataloader. 4. Moved scaling of the LR schedulers to `Trainer.fit()` before `Event.FIT_START` fires. Schedulers will be passed in on `Trainer.fit()` as part of #40. 5. Removed `steps_per_epoch` as part of the state. Instead, algorithms and callbacks can read len(state.dataloader) directly. While this change will make schedulers no longer accurate when using `train_subset_num_batches`, that flag should only be used for performance measurements. As such, it is not necessary that SSR behaves correctly for performance runs. Added a warning for the `train_subset_num_batches` field. Implements the first part of #40. Closes #363. 
--- composer/algorithms/augmix/augmix.py | 8 +- composer/algorithms/colout/colout.py | 8 +- .../algorithms/randaugment/randaugment.py | 9 +- .../seq_length_warmup/seq_length_warmup.py | 9 +- composer/core/state.py | 34 +----- composer/loggers/progress_bar_logger.py | 3 +- composer/optim/scheduler.py | 8 +- composer/profiler/dataloader_profiler.py | 5 +- composer/profiler/torch_profiler.py | 4 - composer/trainer/_deepspeed.py | 11 +- composer/trainer/trainer.py | 106 +++++++++++------- tests/algorithms/test_colout.py | 2 +- tests/algorithms/test_layer_freezing.py | 3 +- tests/algorithms/test_selective_backprop.py | 4 +- tests/algorithms/test_stochastic_depth.py | 3 +- tests/callbacks/test_speed_monitor.py | 5 +- tests/fixtures/dummy_fixtures.py | 5 +- tests/fixtures/new_fixtures.py | 5 +- tests/loggers/test_progress_bar_logger.py | 4 +- tests/test_state.py | 2 +- tests/trainer/test_checkpoint.py | 8 +- tests/trainer/test_ddp.py | 1 - tests/trainer/test_ddp_sync_strategy.py | 2 +- tests/trainer/test_scale_schedule.py | 44 +++++--- tests/trainer/test_scheduler.py | 16 +-- tests/utils/classifer.py | 16 +-- tests/utils/trainer_fit.py | 6 +- 27 files changed, 180 insertions(+), 151 deletions(-) diff --git a/composer/algorithms/augmix/augmix.py b/composer/algorithms/augmix/augmix.py index 55383cd1ba..1a06463cb5 100644 --- a/composer/algorithms/augmix/augmix.py +++ b/composer/algorithms/augmix/augmix.py @@ -251,7 +251,10 @@ def __init__(self, self._transformed_datasets = weakref.WeakSet() def match(self, event: Event, state: State) -> bool: - return event == Event.FIT_START and state.train_dataloader.dataset not in self._transformed_datasets + if event != Event.FIT_START: + return False + assert state.dataloader is not None, "dataloader should be defined on fit start" + return state.dataloader.dataset not in self._transformed_datasets def apply(self, event: Event, state: State, logger: Logger) -> None: am = AugmentAndMixTransform(severity=self.severity, @@ -259,7 +262,8 @@ 
def apply(self, event: Event, state: State, logger: Logger) -> None: width=self.width, alpha=self.alpha, augmentation_set=self.augmentation_set) - dataset = state.train_dataloader.dataset + assert state.dataloader is not None, "dataloader should be defined on fit start" + dataset = state.dataloader.dataset if not isinstance(dataset, VisionDataset): raise TypeError( textwrap.dedent(f"""\ diff --git a/composer/algorithms/colout/colout.py b/composer/algorithms/colout/colout.py index 4b628a195c..b9de21ae73 100644 --- a/composer/algorithms/colout/colout.py +++ b/composer/algorithms/colout/colout.py @@ -160,11 +160,15 @@ def match(self, event: Event, state: State) -> bool: if self.batch: return event == Event.AFTER_DATALOADER else: - return event == Event.FIT_START and state.train_dataloader.dataset not in self._transformed_datasets + if event != Event.FIT_START: + return False + assert state.dataloader is not None, "dataloader should be defined on fit start" + return state.dataloader.dataset not in self._transformed_datasets def _apply_sample(self, state: State) -> None: """Add the ColOut dataset transform to the dataloader.""" - dataset = state.train_dataloader.dataset + assert state.dataloader is not None, "dataloader should be defined on fit start" + dataset = state.dataloader.dataset transform = ColOutTransform(p_row=self.p_row, p_col=self.p_col) diff --git a/composer/algorithms/randaugment/randaugment.py b/composer/algorithms/randaugment/randaugment.py index 7501bc372d..db3c2b0acf 100644 --- a/composer/algorithms/randaugment/randaugment.py +++ b/composer/algorithms/randaugment/randaugment.py @@ -188,12 +188,15 @@ def __init__(self, severity: int = 9, depth: int = 2, augmentation_set: str = "a self._transformed_datasets = weakref.WeakSet() def match(self, event: Event, state: State) -> bool: - return event == Event.FIT_START and state.train_dataloader.dataset not in self._transformed_datasets + if event != Event.FIT_START: + return False + assert state.dataloader is 
not None, "dataloader should be defined on fit start" + return state.dataloader.dataset not in self._transformed_datasets def apply(self, event: Event, state: State, logger: Logger) -> None: ra = RandAugmentTransform(severity=self.severity, depth=self.depth, augmentation_set=self.augmentation_set) - assert state.train_dataloader is not None - dataset = state.train_dataloader.dataset + assert state.dataloader is not None, "dataloader should be defined on fit start" + dataset = state.dataloader.dataset if not isinstance(dataset, VisionDataset): raise TypeError( textwrap.dedent(f"""\ diff --git a/composer/algorithms/seq_length_warmup/seq_length_warmup.py b/composer/algorithms/seq_length_warmup/seq_length_warmup.py index 11bb420385..15bb71cd93 100644 --- a/composer/algorithms/seq_length_warmup/seq_length_warmup.py +++ b/composer/algorithms/seq_length_warmup/seq_length_warmup.py @@ -179,12 +179,11 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: {type(self).__name__} requires state.model to be of type {ComposerTransformer.__name__}, not of type {type(state.model)}""" )) - if state.train_dataloader.batch_size is None: - raise RuntimeError("Sequence Length Warmup algorithm requires constant batch size.") - self._original_model = state.model return + assert state.dataloader is not None, "dataloader should be set on AFTER_DATALOADER" + # in order to avoid OOMs, we do a forward and a backward pass on a dummy input. if not self._activated: # ensure that input_ids is a valid model input. 
since we don't need the @@ -204,7 +203,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: # all of the parameters device = next(state.model.parameters()).device - per_gpu_macrobatch = state.train_dataloader.batch_size + per_gpu_macrobatch = state.dataloader.batch_size if per_gpu_macrobatch is None: raise RuntimeError("Sequence Length Warmup algorithm requires constant batch size.") per_gpu_batch = ceil(per_gpu_macrobatch / state.grad_accum) @@ -238,7 +237,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: self._activated = True if state.max_duration.unit == TimeUnit.EPOCH: - num_optimization_steps = state.steps_per_epoch * state.max_duration.value + num_optimization_steps = len(state.dataloader) * state.max_duration.value elif state.max_duration.unit == TimeUnit.BATCH: num_optimization_steps = state.max_duration.value else: diff --git a/composer/core/state.py b/composer/core/state.py index 38b668b1d2..afae46a72d 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -5,7 +5,6 @@ import contextlib import logging -import textwrap import warnings from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Optional, Sequence, Union, cast @@ -97,8 +96,6 @@ class State(Serializable): ``rank_zero_seed + dist.get_global_rank()``. grad_accum (int): The number of gradient accumulation steps to use. With this argument, micro batch size for each device becomes ``microbatch_size = train_batch_size / (num_devices * grad_accum)``. - train_dataloader (types.DataLoader, DataSpec, or dict): - The :class:`~.types.DataLoader`, :class:`~.DataSpec`, or dict of :class:`~.DataSpec` kwargs to used for training. evaluators (evaluator.Evaluator | Sequence[evaluator.Evaluator]): The evaluators contain the evaluation dataset(s) used for evaluation with specific metrics. max_duration (str or Time): The maximum duration to train for. 
@@ -115,6 +112,7 @@ class State(Serializable): profiler (Optional[Profiler]): The Composer profiler. Attributes: + dataloader (types.DataLoader): The active :class:`~.types.DataLoader`. batch (types.Batch): The batch. This will be the entire batch during the :attr:`.Event.AFTER_DATALOADER`, or a microbatch between :attr:`.Event.BATCH_START` and :attr:`.Event.BATCH_END`. batch_num_samples (int): The number of samples in the :attr:`batch`. @@ -152,7 +150,6 @@ class State(Serializable): """ _max_duration: Time[int] - _steps_per_epoch: Optional[int] batch: types.Batch batch_num_samples: int batch_num_tokens: int @@ -170,7 +167,6 @@ def __init__( rank_zero_seed: int, # data configurations - train_dataloader: types.DataLoader, evaluators: Optional[Union[Evaluator, Sequence[Evaluator]]] = None, grad_accum: int = 1, @@ -187,17 +183,13 @@ def __init__( # algorithms and callbacks algorithms: Optional[Union[Algorithm, Sequence[Algorithm]]] = None, callbacks: Optional[Union[Callback, Sequence[Callback]]] = None, - - # steps per epoch - steps_per_epoch: Optional[int] = None, ): self.rank_zero_seed = rank_zero_seed self.model = model self.grad_accum = grad_accum - self.train_dataloader = train_dataloader + self.dataloader: Optional[types.DataLoader] = None self.evaluators = list(ensure_tuple(evaluators)) self.max_duration = max_duration - self.steps_per_epoch = steps_per_epoch self.timer = Timer() self._precision = Precision(precision) @@ -215,6 +207,7 @@ def __init__( self._callbacks = list(ensure_tuple(callbacks)) self.profiler: Optional[Profiler] = None + # These attributes will be serialized using .state_dict(), and loaded with .load_state_dict() # All other attributes will not be serialized. # For simplicity, omit the leading underscore for private attributes. @@ -365,27 +358,6 @@ def load_state_dict(self, state: Dict[str, Any], strict: bool = False): # ignore AttributeError for properties that have getters but not setters. 
pass - @property - def steps_per_epoch(self): - """int: The maximum number of steps (batches) per epoch.""" - if self._steps_per_epoch is None: - return len(self.train_dataloader) - return self._steps_per_epoch - - @steps_per_epoch.setter - def steps_per_epoch(self, steps_per_epoch: Optional[int]): - try: - dataloader_len = len(self.train_dataloader) - except (TypeError, NotImplementedError): - dataloader_len = None - if dataloader_len is not None and steps_per_epoch is not None and steps_per_epoch > dataloader_len: - warnings.warn( - textwrap.dedent(f"""\ - SubsetNumBatchesWarning: The steps_per_epoch({steps_per_epoch}) - is greater than the number of batches in the training dataloader - ({dataloader_len})""")) - self._steps_per_epoch = steps_per_epoch - @property def precision(self): """The numerical precision to use for training. diff --git a/composer/loggers/progress_bar_logger.py b/composer/loggers/progress_bar_logger.py index e7cd9aa3c9..6889c63cff 100644 --- a/composer/loggers/progress_bar_logger.py +++ b/composer/loggers/progress_bar_logger.py @@ -99,8 +99,9 @@ def _start(self, state: State): if dist.get_global_rank() != 0: return assert self.is_train is not None, "self.is_train should be set by the callback" + assert state.dataloader is not None, "dataloader should be set when using tqdm" if self.is_train: - total_steps = state.steps_per_epoch + total_steps = len(state.dataloader) else: total_steps = 0 for evaluator in state.evaluators: diff --git a/composer/optim/scheduler.py b/composer/optim/scheduler.py index 47268e2d62..226bd20114 100644 --- a/composer/optim/scheduler.py +++ b/composer/optim/scheduler.py @@ -131,15 +131,19 @@ def _convert_time(time: Union[str, Time[int], Time[float]], state: State, ssr: f time = Time.from_timestring(time) if time.unit == TimeUnit.DURATION: + if state.dataloader is None: + raise RuntimeError("Cannot convert time, as state.dataloader is None.") if state.max_duration.unit == TimeUnit.EPOCH: - return Time(int(time.value * 
state.steps_per_epoch * state.max_duration.value), TimeUnit.BATCH) + return Time(int(time.value * len(state.dataloader) * state.max_duration.value), TimeUnit.BATCH) return Time(int(time.value * state.max_duration.value), state.max_duration.unit) if time.unit == TimeUnit.EPOCH: # Epochs do not provide sufficient granularity for SSR scaling # e.g. if max_duration = 1ep, then any SSR would result in a new duration of 0. # so, convert the time into batches - time = Time(value=time.value * state.steps_per_epoch, unit=TimeUnit.BATCH) + if state.dataloader is None: + raise RuntimeError("Cannot convert time, as state.dataloader is None.") + time = Time(value=time.value * len(state.dataloader), unit=TimeUnit.BATCH) return Time(value=int(time.value * ssr), unit=time.unit) diff --git a/composer/profiler/dataloader_profiler.py b/composer/profiler/dataloader_profiler.py index 4a843c9966..0e1a217f44 100644 --- a/composer/profiler/dataloader_profiler.py +++ b/composer/profiler/dataloader_profiler.py @@ -71,8 +71,9 @@ def fit_start(self, state: State, logger: Logger): textwrap.dedent("""To use the dataloader profiler, state.profiler must be set. Make sure to run composer with the profiler -- i.e. 
with the `--profiler` CLI flag.""")) - if not _ProfiledDataLoader.is_dataloader_already_wrapped(state.train_dataloader): - state.train_dataloader = _ProfiledDataLoader(state.profiler, state.train_dataloader, "train") + assert state.dataloader, "dataloader should be set on FIT_START" + if not _ProfiledDataLoader.is_dataloader_already_wrapped(state.dataloader): + state.dataloader = _ProfiledDataLoader(state.profiler, state.dataloader, "train") for evaluator in state.evaluators: diff --git a/composer/profiler/torch_profiler.py b/composer/profiler/torch_profiler.py index fb318c82c9..80be32227e 100644 --- a/composer/profiler/torch_profiler.py +++ b/composer/profiler/torch_profiler.py @@ -100,10 +100,6 @@ def _scheduler_fn(self, profiler_step: int, state: State) -> TorchProfilerAction assert state.profiler is not None, "composer profiler should be defined" composer_profiler_action = state.profiler.get_action(next_batch_in_epoch) next_composer_profiler_action = state.profiler.get_action(next_batch_in_epoch + 1) - if next_batch_in_epoch == state.steps_per_epoch: - if composer_profiler_action == ProfilerAction.ACTIVE: - # force saving at epoch boundaries - return TorchProfilerAction.RECORD_AND_SAVE if composer_profiler_action == ProfilerAction.ACTIVE and next_composer_profiler_action != ProfilerAction.ACTIVE: return TorchProfilerAction.RECORD_AND_SAVE if composer_profiler_action == ProfilerAction.ACTIVE: diff --git a/composer/trainer/_deepspeed.py b/composer/trainer/_deepspeed.py index 9695673df8..17a7de642c 100644 --- a/composer/trainer/_deepspeed.py +++ b/composer/trainer/_deepspeed.py @@ -17,19 +17,20 @@ def _add_batch_config(config: Dict[str, Any], state: State): - if state.train_dataloader.batch_size is None: + assert state.dataloader is not None, "dataloader should be set on FIT_START, which is where the Deepspeed config is applied." 
+ if state.dataloader.batch_size is None: raise RuntimeError("DeepSpeed requires a dataloader with a known batch size.") - if state.train_dataloader.batch_size % state.grad_accum != 0: + if state.dataloader.batch_size % state.grad_accum != 0: # DeepSpeed will throw an error in this configuration. raise ValueError("The Mosaic trainer has been configured to use batch size=" - f"{state.train_dataloader.batch_size}, but this is not divisible by the " + f"{state.dataloader.batch_size}, but this is not divisible by the " f"grad accum={state.grad_accum}. This is unsupported when using DeepSpeed.") - train_batch_size = state.train_dataloader.batch_size * dist.get_world_size() + train_batch_size = state.dataloader.batch_size * dist.get_world_size() grad_accum = state.grad_accum # Per the check at the start of this function, the following division is always clean. - per_gpu_microbatch_size = state.train_dataloader.batch_size // state.grad_accum + per_gpu_microbatch_size = state.dataloader.batch_size // state.grad_accum if "train_batch_size" in config: ds_train_batch_size = config["train_batch_size"] diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 3a20682fa5..c814433730 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -356,6 +356,15 @@ class Trainer: train_subset_num_batches (int, optional): If specified, finish every epoch early after training on this many batches. This parameter has no effect if it is greater than ``len(train_dataloader)``. If ``None``, then the entire dataloader will be iterated over. (default: ``None``) + + .. warning:: + + This flag should be used for performance profiling only. Only a subset of the batches from the + dataloader will be used to training the model. + + In addition, learning rate schedulers will not be scaled to reflect ``train_subset_num_batches``, + and instead will assume ``len(train_dataloader)`` batches per epoch. 
+ eval_subset_num_batches (int, optional): If specified, evaluate on this many batches. This parameter has no effect if it is greater than ``len(eval_dataloader)``. If ``None``, then the entire dataloader will be iterated over. (default: ``None``) @@ -415,12 +424,13 @@ def __init__( self, *, model: ComposerModel, - train_dataloader: Union[DataLoader, DataSpec], - max_duration: Union[int, str, Time], - eval_dataloader: Optional[Union[DataLoader, DataSpec, Evaluator, Sequence[Evaluator]]] = None, - algorithms: Optional[Union[Algorithm, Sequence[Algorithm]]] = None, - optimizers: Optional[torch.optim.Optimizer] = None, - schedulers: Optional[Union[ComposerScheduler, Sequence[ComposerScheduler]]] = None, + train_dataloader: Union[DataLoader, DataSpec], # TODO move param to fit() + max_duration: Union[int, str, Time], # TODO move param to fit() + eval_dataloader: Optional[Union[DataLoader, DataSpec, Evaluator, + Sequence[Evaluator]]] = None, # TODO move param to fit() + algorithms: Optional[Union[Algorithm, Sequence[Algorithm]]] = None, # TODO move param to fit() + optimizers: Optional[torch.optim.Optimizer] = None, # TODO move param to fit() + schedulers: Optional[Union[ComposerScheduler, Sequence[ComposerScheduler]]] = None, # TODO move param to fit() # device device: Optional[Union[str, Device]] = None, @@ -428,12 +438,12 @@ def __init__( # training hparams grad_accum: int = 1, grad_clip_norm: Optional[float] = None, - validate_every_n_batches: int = -1, - validate_every_n_epochs: int = 1, - compute_training_metrics: bool = False, + validate_every_n_batches: int = -1, # TODO move param to fit() + validate_every_n_epochs: int = 1, # TODO move param to fit() + compute_training_metrics: bool = False, # TODO move param to fit() precision: Union[str, Precision] = Precision.FP32, - scale_schedule_ratio: float = 1.0, - step_schedulers_every_batch: Optional[bool] = None, + scale_schedule_ratio: float = 1.0, # TODO move param to fit() + step_schedulers_every_batch: 
Optional[bool] = None, # TODO move param to fit() # dist hparams dist_timeout: float = 300.0, @@ -465,11 +475,11 @@ def __init__( save_weights_only: bool = False, # subset parameters - train_subset_num_batches: Optional[int] = None, - eval_subset_num_batches: Optional[int] = None, + train_subset_num_batches: Optional[int] = None, # TODO move param to fit() + eval_subset_num_batches: Optional[int] = None, # TODO move param to fit() # DeepSpeed - deepspeed_config: Union[bool, Dict[str, Any]] = False, + deepspeed_config: Union[bool, Dict[str, Any]] = False, # TODO move param to fit() # profiling profiler_trace_file: Optional[str] = None, @@ -580,7 +590,7 @@ def __init__( evaluator = Evaluator(label="eval_dataset", dataloader=evaluator, metrics=metrics) self.evaluators.append(evaluator) - self._eval_subset_num_batches = eval_subset_num_batches + self.eval_subset_num_batches = eval_subset_num_batches # do a check here to make sure there is at least one validation set if len(self.evaluators) == 0: @@ -632,12 +642,13 @@ def __init__( grad_accum=grad_accum, precision=precision, precision_context=precision_context, - train_dataloader=train_dataloader.dataloader, evaluators=self.evaluators, optimizers=optimizers, - steps_per_epoch=train_subset_num_batches, ) + self._train_dataloader = train_dataloader.dataloader + self.train_subset_num_batches = train_subset_num_batches + pytorch_schedulers = [ scheduler for scheduler in ensure_tuple(schedulers) if isinstance(scheduler, PyTorchScheduler) ] @@ -669,15 +680,11 @@ def __init__( step_schedulers_every_batch = True self._step_schedulers_every_batch = step_schedulers_every_batch + self._scale_schedule_ratio = scale_schedule_ratio - for scheduler in ensure_tuple(schedulers): - if isinstance(scheduler, PyTorchScheduler): - scale_pytorch_scheduler(scheduler, scale_schedule_ratio) - self.state.schedulers.append(scheduler) - else: # it's a composer scheduler - self.state.schedulers.append(compile_composer_scheduler(scheduler, 
self.state, scale_schedule_ratio)) + self._schedulers = ensure_tuple(schedulers) - if len(self.state.schedulers) == 0: + if len(self._schedulers) == 0: warnings.warn(f"NoSchedulerWarning: No schedulers were specified. The learning rate will be constant.") # Configure profilers if profiling is enabled @@ -918,12 +925,13 @@ def _spin_dataloaders(self): break # spin the train dataloader's sampler to get to the state of the desired epoch + assert self.state.dataloader is not None, "train dataloader is set on state after FIT_START" for epoch in range(int(self.state.timer.epoch)): # TODO: hasattr check will be removed while fixing https://github.com/mosaicml/composer/issues/424 - if hasattr(self.state.train_dataloader, "sampler") and isinstance(self.state.train_dataloader.sampler, - torch.utils.data.DistributedSampler): - self.state.train_dataloader.sampler.set_epoch(epoch) - for _ in self.state.train_dataloader: + if hasattr(self.state.dataloader, "sampler") and isinstance(self.state.dataloader.sampler, + torch.utils.data.DistributedSampler): + self.state.dataloader.sampler.set_epoch(epoch) + for _ in self.state.dataloader: break def _train_loop(self) -> None: @@ -944,6 +952,18 @@ def _train_loop(self) -> None: else: train_metrics = None + if self.state.dataloader is None: + # If the dataloader is not already set on State, then set it from what was passed in on init + self.state.dataloader = self._train_data_spec.dataloader + + for scheduler in ensure_tuple(self._schedulers): + if isinstance(scheduler, PyTorchScheduler): + scale_pytorch_scheduler(scheduler, self._scale_schedule_ratio) + self.state.schedulers.append(scheduler) + else: # it's a composer scheduler + self.state.schedulers.append( + compile_composer_scheduler(scheduler, self.state, self._scale_schedule_ratio)) + self.engine.run_event(Event.FIT_START) self.state.scaler = ClosureGradScaler() if self._use_closures() else GradScaler() @@ -965,12 +985,12 @@ def _train_loop(self) -> None: 
self.logger.data_epoch({"epoch": int(self.state.timer.epoch)}) # TODO: hasattr check will be removed while fixing https://github.com/mosaicml/composer/issues/424 - if hasattr(self.state.train_dataloader, "sampler") and isinstance(self.state.train_dataloader.sampler, - torch.utils.data.DistributedSampler): - self.state.train_dataloader.sampler.set_epoch(int(self.state.timer.epoch)) + if hasattr(self.state.dataloader, "sampler") and isinstance(self.state.dataloader.sampler, + torch.utils.data.DistributedSampler): + self.state.dataloader.sampler.set_epoch(int(self.state.timer.epoch)) for batch_idx, self.state.batch in enumerate( - itertools.islice(self.state.train_dataloader, self.state.steps_per_epoch)): + itertools.islice(self.state.dataloader, self.train_subset_num_batches)): # if resuming, skip dataloader forward to the minibatch index if batch_idx < int(self.state.timer.batch_in_epoch): @@ -1208,24 +1228,28 @@ def eval(self, is_batch: bool): """ restore_model_train = self.state.model.training + # back up the original dataloader on the state, so we can restore it after evaluation is finished + original_dataloader = self.state.dataloader + self.state.model.eval() with torch.no_grad(): - self.engine.run_event(Event.EVAL_START) - for evaluator in self.state.evaluators: - dataloader = evaluator.dataloader.dataloader + self.state.dataloader = evaluator.dataloader.dataloader + + self.engine.run_event(Event.EVAL_START) + metrics = self._ensure_metrics_device_and_dtype(evaluator.metrics) # TODO: hasattr check will be removed while fixing https://github.com/mosaicml/composer/issues/424 - if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, - torch.utils.data.DistributedSampler): + if hasattr(self.state.dataloader, "sampler") and isinstance(self.state.dataloader.sampler, + torch.utils.data.DistributedSampler): # The distributed sampler uses `set_epoch` to set the random seed # Because evaluation can run on each batch, we use the batch to seed the sampler # 
so each evaluation will get a proper shuffle. # The epoch provided to `set_epoch` need not be sequential, so this is fine. - dataloader.sampler.set_epoch(int(self.state.timer.batch)) + self.state.dataloader.sampler.set_epoch(int(self.state.timer.batch)) - for self.state.batch in itertools.islice(dataloader, self._eval_subset_num_batches): + for self.state.batch in itertools.islice(self.state.dataloader, self.eval_subset_num_batches): self.state.batch = self._device.batch_to_device(self.state.batch) if evaluator.dataloader.device_transforms: self.state.batch = evaluator.dataloader.device_transforms(self.state.batch) @@ -1250,11 +1274,13 @@ def eval(self, is_batch: bool): self._compute_and_log_metrics(metrics, is_train=False, is_batch=is_batch, logging_label=evaluator.label) - self.engine.run_event(Event.EVAL_END) + self.engine.run_event(Event.EVAL_END) if restore_model_train: self.state.model.train() + self.state.dataloader = original_dataloader + def _use_grad_scaling(self, precision: Union[str, Precision], scaler: Optional[GradScaler]) -> bool: """Determines based on precision when to use grad scaling. 
diff --git a/tests/algorithms/test_colout.py b/tests/algorithms/test_colout.py index cd7ffa29b6..edb7f99d85 100644 --- a/tests/algorithms/test_colout.py +++ b/tests/algorithms/test_colout.py @@ -212,7 +212,7 @@ def test_apply_sample(self, colout_algorithm: ColOut, minimal_state: State, empt original_image, _ = dataset[0] assert isinstance(original_image, Image.Image) - minimal_state.train_dataloader = dataloader + minimal_state.dataloader = dataloader colout_algorithm.apply(Event.INIT, minimal_state, empty_logger) new_image, _ = dataset[0] diff --git a/tests/algorithms/test_layer_freezing.py b/tests/algorithms/test_layer_freezing.py index 748c7bde4e..5f73941983 100644 --- a/tests/algorithms/test_layer_freezing.py +++ b/tests/algorithms/test_layer_freezing.py @@ -20,7 +20,6 @@ def _generate_state(epoch: int, max_epochs: int): optimizers=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99), precision=Precision.FP32, grad_accum=1, - train_dataloader=Mock(__len__=lambda x: 100), evaluators=Mock(), max_duration=f'{max_epochs}ep') @@ -28,6 +27,8 @@ def _generate_state(epoch: int, max_epochs: int): for _ in range(epoch): state.timer.on_epoch_complete() + state.dataloader = Mock(__len__=lambda x: 100) + return state diff --git a/tests/algorithms/test_selective_backprop.py b/tests/algorithms/test_selective_backprop.py index f42fa31677..c3dcc4a089 100644 --- a/tests/algorithms/test_selective_backprop.py +++ b/tests/algorithms/test_selective_backprop.py @@ -144,12 +144,12 @@ def conv_model(Ximage: torch.Tensor, D: int) -> ComposerClassifier: def state(minimal_state: State, conv_model: ComposerClassifier, loss_fun_tuple: Callable, epoch: int, batch: int) -> State: """State with required values set for Selective Backprop.""" - + assert minimal_state.dataloader is not None conv_model.loss = loss_fun_tuple minimal_state.model = conv_model minimal_state.timer.epoch._value = epoch - minimal_state.timer.batch._value = epoch * minimal_state.steps_per_epoch + batch + 
minimal_state.timer.batch._value = epoch * len(minimal_state.dataloader) + batch minimal_state.timer.batch_in_epoch._value = batch return minimal_state diff --git a/tests/algorithms/test_stochastic_depth.py b/tests/algorithms/test_stochastic_depth.py index 68ec3dd4c8..583796221d 100644 --- a/tests/algorithms/test_stochastic_depth.py +++ b/tests/algorithms/test_stochastic_depth.py @@ -232,7 +232,8 @@ def test_drop_rate_warmup(self, algorithm: StochasticDepth, step: int, state: St self.get_drop_rate_list(state.model, drop_rates=new_drop_rates) assert state.max_duration.unit == TimeUnit.EPOCH - drop_warmup_iters = int(state.steps_per_epoch * int(state.max_duration.value) * algorithm.drop_warmup) + assert state.dataloader is not None + drop_warmup_iters = int(len(state.dataloader) * int(state.max_duration.value) * algorithm.drop_warmup) assert torch.all(torch.tensor(new_drop_rates) == ((step / drop_warmup_iters) * torch.tensor(old_drop_rates))) diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py index cda0eb7bad..f19f707b4e 100644 --- a/tests/callbacks/test_speed_monitor.py +++ b/tests/callbacks/test_speed_monitor.py @@ -35,8 +35,9 @@ def test_speed_monitor(composer_trainer_hparams: TrainerHparams): if 'wall_clock_train' in metrics: wall_clock_train_calls += 1 - assert isinstance(trainer.state.train_dataloader, collections.abc.Sized) - expected_step_calls = (trainer.state.steps_per_epoch - speed_monitor_hparams.window_size) * max_epochs + assert isinstance(trainer.state.dataloader, collections.abc.Sized) + assert trainer.train_subset_num_batches is not None + expected_step_calls = (trainer.train_subset_num_batches - speed_monitor_hparams.window_size) * max_epochs assert throughput_step_calls == expected_step_calls assert throughput_epoch_calls == max_epochs assert wall_clock_train_calls == max_epochs diff --git a/tests/fixtures/dummy_fixtures.py b/tests/fixtures/dummy_fixtures.py index 37be0a7f88..66dc268659 100644 --- 
a/tests/fixtures/dummy_fixtures.py +++ b/tests/fixtures/dummy_fixtures.py @@ -111,12 +111,12 @@ def dummy_state(dummy_model: SimpleBatchPairModel, dummy_train_dataloader: DataL precision=Precision.FP32, grad_accum=1, rank_zero_seed=rank_zero_seed, - train_dataloader=dummy_train_dataloader, evaluators=evaluators, optimizers=dummy_optimizer, max_duration="10ep", ) state.schedulers = dummy_scheduler + state.dataloader = dummy_train_dataloader return state @@ -230,9 +230,10 @@ def state_with_model(simple_conv_model: torch.nn.Module, dummy_train_dataloader: max_duration="100ep", model=simple_conv_model, precision=Precision.FP32, - train_dataloader=dummy_train_dataloader, evaluators=evaluators, ) + state.dataloader = dummy_train_dataloader + return state diff --git a/tests/fixtures/new_fixtures.py b/tests/fixtures/new_fixtures.py index 59da65cc0d..b43482f821 100644 --- a/tests/fixtures/new_fixtures.py +++ b/tests/fixtures/new_fixtures.py @@ -15,13 +15,14 @@ def minimal_state(rank_zero_seed: int): Tests should configure the state for their specific needs. 
""" - return State( + state = State( model=SimpleModel(), rank_zero_seed=rank_zero_seed, - train_dataloader=DataLoader(RandomClassificationDataset()), evaluators=[], max_duration='100ep', ) + state.dataloader = DataLoader(RandomClassificationDataset()) + return state @pytest.fixture diff --git a/tests/loggers/test_progress_bar_logger.py b/tests/loggers/test_progress_bar_logger.py index 94d03d9a13..0191bc0b1f 100644 --- a/tests/loggers/test_progress_bar_logger.py +++ b/tests/loggers/test_progress_bar_logger.py @@ -41,8 +41,8 @@ def get_mock_tqdm(position: int, *args: object, **kwargs: object): assert composer_trainer_hparams.validate_every_n_batches < 0 assert len(is_train_to_mock_tqdms[False]) == composer_trainer_hparams.validate_every_n_epochs * max_epochs for mock_tqdm in is_train_to_mock_tqdms[True]: - assert mock_tqdm.update.call_count == trainer.state.steps_per_epoch + assert mock_tqdm.update.call_count == trainer.train_subset_num_batches mock_tqdm.close.assert_called_once() for mock_tqdm in is_train_to_mock_tqdms[False]: - assert mock_tqdm.update.call_count == trainer._eval_subset_num_batches + assert mock_tqdm.update.call_count == trainer.eval_subset_num_batches mock_tqdm.close.assert_called_once() diff --git a/tests/test_state.py b/tests/test_state.py index 29499fda73..7d6ab7db50 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -29,7 +29,6 @@ def get_dummy_state(model: ComposerModel, train_dataloader: DataLoader, val_data rank_zero_seed=random.randint(0, 100), precision=Precision.AMP, max_duration=f"{random.randint(0, 100)}ep", - train_dataloader=train_dataloader, evaluators=evaluators, optimizers=optimizers, algorithms=[ChannelsLastHparams().initialize_object()]) @@ -37,6 +36,7 @@ def get_dummy_state(model: ComposerModel, train_dataloader: DataLoader, val_data state.loss = random_tensor() state.batch = (random_tensor(), random_tensor()) state.outputs = random_tensor() + state.dataloader = train_dataloader return state diff --git 
a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 5db4b2d2bc..ebe17033c8 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -385,10 +385,10 @@ def _test_checkpoint_trainer(trainer_hparams: TrainerHparams): def _validate_events_called_expected_number_of_times(trainer: Trainer): state = trainer.state - + assert trainer.train_subset_num_batches is not None assert state.max_duration.unit == TimeUnit.EPOCH num_epochs = state.max_duration.value - num_total_steps = num_epochs * state.steps_per_epoch + num_total_steps = num_epochs * trainer.train_subset_num_batches num_total_microbatches = num_total_steps * state.grad_accum num_evals = 0 if trainer._validate_every_n_batches > 0: @@ -399,8 +399,8 @@ def _validate_events_called_expected_number_of_times(trainer: Trainer): assert state.evaluators is not None for evaluator in state.evaluators: assert evaluator.dataloader is not None - assert trainer._eval_subset_num_batches is not None - num_eval_steps = num_evals * trainer._eval_subset_num_batches * len(state.evaluators) + assert trainer.eval_subset_num_batches is not None + num_eval_steps = num_evals * trainer.eval_subset_num_batches * len(state.evaluators) event_to_num_expected_invocations = { Event.INIT: 1, diff --git a/tests/trainer/test_ddp.py b/tests/trainer/test_ddp.py index 80bf86ba67..e2900699c2 100644 --- a/tests/trainer/test_ddp.py +++ b/tests/trainer/test_ddp.py @@ -182,7 +182,6 @@ def test_ddp(device: DeviceHparams, world_size: int, composer_trainer_hparams: T if deepspeed: hparams.deepspeed = {} trainer = hparams.initialize_object() - assert isinstance(trainer.state.train_dataloader.dataset, collections.abc.Sized) for evaluator in trainer.evaluators: assert isinstance(evaluator.dataloader, DataSpec) diff --git a/tests/trainer/test_ddp_sync_strategy.py b/tests/trainer/test_ddp_sync_strategy.py index 50d6884661..30f65303df 100644 --- a/tests/trainer/test_ddp_sync_strategy.py +++ 
b/tests/trainer/test_ddp_sync_strategy.py @@ -60,9 +60,9 @@ def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional optimizers=optimizer, grad_accum=2, max_duration="1ep", - train_dataloader=dummy_train_dataloader, evaluators=evaluators, precision='fp32') + state.dataloader = dummy_train_dataloader batches = [[(1, Tensor([1])), (1, Tensor([2]))], [(2, Tensor([1])), (2, Tensor([2]))]] state.model = _prepare_ddp_module(state.model, find_unused_parameters=True) diff --git a/tests/trainer/test_scale_schedule.py b/tests/trainer/test_scale_schedule.py index 88e3766b52..2084d1a988 100644 --- a/tests/trainer/test_scale_schedule.py +++ b/tests/trainer/test_scale_schedule.py @@ -9,8 +9,11 @@ from torch.optim.lr_scheduler import ExponentialLR from composer.algorithms import ScaleScheduleHparams +from composer.core import State +from composer.core.callback import Callback from composer.core.time import TimeUnit from composer.core.types import PyTorchScheduler +from composer.loggers.logger import Logger from composer.optim import MultiStepSchedulerHparams, SGDHparams from composer.trainer import TrainerHparams from composer.trainer._scale_schedule import scale_pytorch_scheduler @@ -27,6 +30,7 @@ def flatten(lst: list): @pytest.mark.parametrize('ssr', [0.5, 0.75, 1.0]) +@pytest.mark.filterwarnings(r"ignore:.*Detected call of \`lr_schedule.*:UserWarning") class TestScaleSchedule(): @staticmethod @@ -73,9 +77,13 @@ def test_scale_schedule_cosine_warm_restarts(self, optimizer: Optimizer, ssr: fl @pytest.mark.parametrize('ssr', [0.5, 0.75, 1.0]) -@pytest.mark.parametrize('use_algorithm', [False, True]) +@pytest.mark.parametrize('use_algorithm', [ + False, + pytest.param(True, marks=pytest.mark.filterwarnings(r"ignore:.*ScaleScheduleDeprecationWarning.*")), +]) class TestScaleScheduleTrainer(): + @pytest.mark.filterwarnings(r"ignore:.*Detected call of \`lr_schedule.*:UserWarning") def test_epochs_scaled( self, ssr: float, @@ -91,21 +99,29 @@ def 
test_epochs_scaled( composer_trainer_hparams.algorithms = [ScaleScheduleHparams(ratio=ssr)] else: composer_trainer_hparams.scale_schedule_ratio = ssr - trainer = composer_trainer_hparams.initialize_object() - assert trainer.state.max_duration.unit == TimeUnit.EPOCH - assert trainer.state.max_duration.value == int(10 * ssr) - scheduler = trainer.state.schedulers[0] + class CheckScaleSchedule(Callback): - test_steps = [int(20 * ssr), int(40 * ssr), int(60 * ssr)] - target_lrs = [1.0, 0.1, 0.01] - current_step = 0 - for test_step, target_lr in zip(test_steps, target_lrs): + def fit_start(self, state: State, logger: Logger) -> None: + scheduler = state.schedulers[0] - while current_step < test_step: - trainer.state.timer.on_batch_complete() - current_step += 1 + test_steps = [int(20 * ssr), int(40 * ssr), int(60 * ssr)] + target_lrs = [1.0, 0.1, 0.01] + current_step = 0 + for test_step, target_lr in zip(test_steps, target_lrs): - scheduler.step() + while current_step < test_step: + trainer.state.timer.on_batch_complete() + current_step += 1 + + scheduler.step() + + assert scheduler.get_last_lr()[0] == pytest.approx(target_lr) + + trainer = composer_trainer_hparams.initialize_object() + trainer.state.callbacks.append(CheckScaleSchedule()) + + assert trainer.state.max_duration.unit == TimeUnit.EPOCH + assert trainer.state.max_duration.value == int(10 * ssr) - assert scheduler.get_last_lr()[0] == pytest.approx(target_lr) + trainer.fit() diff --git a/tests/trainer/test_scheduler.py b/tests/trainer/test_scheduler.py index 6568595dc5..127ef3c7dd 100644 --- a/tests/trainer/test_scheduler.py +++ b/tests/trainer/test_scheduler.py @@ -5,6 +5,7 @@ import pytest import torch +import torch.utils.data from composer.core import State, Time from composer.core.time import TimeUnit @@ -21,14 +22,14 @@ @pytest.fixture -def dummy_schedulers_state(dummy_model: torch.nn.Module, dummy_train_dataloader: DataLoader, rank_zero_seed: int): - return State( +def dummy_schedulers_state(dummy_model: 
torch.nn.Module, rank_zero_seed: int): + state = State( model=dummy_model, rank_zero_seed=rank_zero_seed, - train_dataloader=dummy_train_dataloader, max_duration=MAX_DURATION, - steps_per_epoch=STEPS_PER_EPOCH, ) + state.dataloader = cast(torch.utils.data.DataLoader, [None] * STEPS_PER_EPOCH) + return state @pytest.mark.parametrize("scheduler,ssr,test_times,expected_lrs", [ @@ -88,16 +89,17 @@ def test_scheduler_init(scheduler: ComposerScheduler, ssr: float, test_times: Li dummy_schedulers_state: State): state = dummy_schedulers_state - state._max_duration = Time(value=int(state.max_duration.value * ssr), unit=state.max_duration.unit) + assert state.dataloader is not None + state.max_duration = Time(value=int(state.max_duration.value * ssr), unit=state.max_duration.unit) for test_time, expected_lr in zip(test_times, expected_lrs): parsed_time = Time.from_timestring(test_time) assert parsed_time.unit in [TimeUnit.EPOCH, TimeUnit.BATCH] if parsed_time.unit == TimeUnit.EPOCH: state.timer._epoch = parsed_time - state.timer._batch = Time(int(state.steps_per_epoch * state.timer._epoch.value), TimeUnit.BATCH) + state.timer._batch = Time(int(len(state.dataloader) * int(state.timer.epoch)), TimeUnit.BATCH) else: state.timer._batch = parsed_time - state.timer._epoch = Time(int(state.timer._batch.value / state.steps_per_epoch), TimeUnit.EPOCH) + state.timer._epoch = Time(int(state.timer.batch) // len(state.dataloader), TimeUnit.EPOCH) lr = scheduler(state, ssr) assert lr == pytest.approx(expected_lr, abs=1e-3) diff --git a/tests/utils/classifer.py b/tests/utils/classifer.py index f587bd0fc3..736c798094 100644 --- a/tests/utils/classifer.py +++ b/tests/utils/classifer.py @@ -14,24 +14,21 @@ from tests.utils.model import SimpleModel -def _get_state(train_dataloader: DataLoader, - eval_dataloader: DataLoader, - rank_zero_seed: int, - steps_per_epoch: int = 1): +def _get_state(train_dataloader: DataLoader, eval_dataloader: DataLoader, rank_zero_seed: int): model = SimpleModel() - 
steps_per_epoch = steps_per_epoch metric_coll = MetricCollection([Accuracy()]) evaluators = [Evaluator(label="dummy_label", dataloader=eval_dataloader, metrics=metric_coll)] - return State( + state = State( model=model, rank_zero_seed=rank_zero_seed, optimizers=SGD(model.parameters(), lr=.001, momentum=0.0), max_duration="1ep", - train_dataloader=train_dataloader, evaluators=evaluators, grad_accum=1, precision=Precision.FP32, ) + state.dataloader = train_dataloader + return state def test_classifier_trains( @@ -47,7 +44,6 @@ def test_classifier_trains( state = _get_state(train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, - steps_per_epoch=n_steps, rank_zero_seed=rank_zero_seed) model = state.model @@ -72,9 +68,9 @@ def f_loss(state): opt = state.optimizers assert isinstance(opt, Optimizer) - assert state.train_dataloader is not None, "train dataloader should be set" + assert state.dataloader is not None, "train dataloader should be set" - for step, (X, y) in enumerate(state.train_dataloader): + for step, (X, y) in enumerate(state.dataloader): # reseed here so data is same for different sets of algorithms torch.manual_seed(step) diff --git a/tests/utils/trainer_fit.py b/tests/utils/trainer_fit.py index c9c64612e3..e5ad6e5d5b 100644 --- a/tests/utils/trainer_fit.py +++ b/tests/utils/trainer_fit.py @@ -57,10 +57,10 @@ def train_model(composer_trainer_hparams: TrainerHparams, max_epochs: int = 2, r if isinstance(trainer._device, DeviceGPU): original_model = trainer._device.module_to_device(original_model) - if run_loss_check and trainer.state.train_dataloader: - initial_loss = get_total_loss(original_model, trainer.state.train_dataloader, trainer._device) + if run_loss_check and trainer.state.dataloader: + initial_loss = get_total_loss(original_model, trainer.state.dataloader, trainer._device) unwrapped_model = trainer.state.model.module assert isinstance(unwrapped_model, ComposerModel) - post_fit_loss = get_total_loss(unwrapped_model, 
trainer.state.train_dataloader, trainer._device) + post_fit_loss = get_total_loss(unwrapped_model, trainer.state.dataloader, trainer._device) assert post_fit_loss < initial_loss + 1e-5, f"post_fit_loss({post_fit_loss}) - initial_loss({initial_loss}) >= 1e-5" From 558f2790fa091458e3bc4823c7e604f943fbe965 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 12 Apr 2022 12:51:17 -0700 Subject: [PATCH 02/17] Restored `dataloader_len` on state --- .../seq_length_warmup/seq_length_warmup.py | 4 +- composer/core/state.py | 46 ++++++++++++++++++- composer/loggers/progress_bar_logger.py | 13 +----- composer/optim/scheduler.py | 12 ++--- composer/trainer/trainer.py | 24 +++++----- tests/algorithms/test_selective_backprop.py | 4 +- tests/algorithms/test_stochastic_depth.py | 4 +- tests/trainer/test_scheduler.py | 6 +-- 8 files changed, 74 insertions(+), 39 deletions(-) diff --git a/composer/algorithms/seq_length_warmup/seq_length_warmup.py b/composer/algorithms/seq_length_warmup/seq_length_warmup.py index 15bb71cd93..7fd573025e 100644 --- a/composer/algorithms/seq_length_warmup/seq_length_warmup.py +++ b/composer/algorithms/seq_length_warmup/seq_length_warmup.py @@ -237,7 +237,9 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: self._activated = True if state.max_duration.unit == TimeUnit.EPOCH: - num_optimization_steps = len(state.dataloader) * state.max_duration.value + if state.dataloader_len is None: + raise RuntimeError("Sequential Length Warmup requires the dataloader to be sized.") + num_optimization_steps = int(state.dataloader_len) * state.max_duration.value elif state.max_duration.unit == TimeUnit.BATCH: num_optimization_steps = state.max_duration.value else: diff --git a/composer/core/state.py b/composer/core/state.py index afae46a72d..b51d8399ff 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -112,7 +112,6 @@ class State(Serializable): profiler (Optional[Profiler]): The Composer profiler. 
Attributes: - dataloader (types.DataLoader): The active :class:`~.types.DataLoader`. batch (types.Batch): The batch. This will be the entire batch during the :attr:`.Event.AFTER_DATALOADER`, or a microbatch between :attr:`.Event.BATCH_START` and :attr:`.Event.BATCH_END`. batch_num_samples (int): The number of samples in the :attr:`batch`. @@ -187,7 +186,8 @@ def __init__( self.rank_zero_seed = rank_zero_seed self.model = model self.grad_accum = grad_accum - self.dataloader: Optional[types.DataLoader] = None + self._dataloader: Optional[types.DataLoader] = None + self._dataloader_len: Optional[Time[int]] = None self.evaluators = list(ensure_tuple(evaluators)) self.max_duration = max_duration @@ -358,6 +358,48 @@ def load_state_dict(self, state: Dict[str, Any], strict: bool = False): # ignore AttributeError for properties that have getters but not setters. pass + @property + def dataloader(self): + """The dataloader.""" + return self._dataloader + + @dataloader.setter + def dataloader(self, dataloader: Optional[types.DataLoader]): + self._dataloader = dataloader + self.dataloader_len = None # setting it to None will do a failsafe read of len(dataloader) + + @property + def dataloader_len(self): + """The number of batches per dataloader iteration (e.g. epoch), as used by the trainer. + + Returns: + Optional[Time[int]]: The number of batches per dataloader iteration (e.g. epoch), or None if no dataloader + is defined or if the dataloader has an unknown length (e.g. streaming dataloaders). 
+ """ + return self._dataloader_len + + @dataloader_len.setter + def dataloader_len(self, num_batches: Optional[Union[int, Time[int]]]): + if isinstance(num_batches, int): + num_batches = Time(num_batches, TimeUnit.BATCH) + if self._dataloader is None: + raise RuntimeError("`State.dataloader_len` cannot be set if the dataloader is not defined.") + try: + dataloader_len = len(self._dataloader) + except (TypeError, NotImplementedError): + dataloader_len = None + if dataloader_len is not None and num_batches is not None and int(num_batches) > dataloader_len: + warnings.warn((f"DataloaderNumBatchesWarning: The dataloader_len ({int(num_batches)}) " + f"is greater than the length (i.e. number of batches) of the dataloader, which is " + f"{dataloader_len}. State.dataloader_len is thus being set to {dataloader_len}.")) + num_batches = Time(dataloader_len, TimeUnit.BATCH) + if num_batches is None and dataloader_len is not None: + # len(dataloader) is an approximation -- see https://pytorch.org/docs/stable/data.html. + # However, in the worst case where additional last batches are dropped, this calculation should be + # an over-estimate, leading to the entire dataloader still being iterated over. + num_batches = Time(dataloader_len, TimeUnit.BATCH) + self._dataloader_len = num_batches + @property def precision(self): """The numerical precision to use for training. 
diff --git a/composer/loggers/progress_bar_logger.py b/composer/loggers/progress_bar_logger.py index 6889c63cff..576dd4173e 100644 --- a/composer/loggers/progress_bar_logger.py +++ b/composer/loggers/progress_bar_logger.py @@ -4,7 +4,6 @@ from __future__ import annotations -import collections.abc from dataclasses import asdict, dataclass from typing import Any, Dict, List, Optional @@ -99,22 +98,14 @@ def _start(self, state: State): if dist.get_global_rank() != 0: return assert self.is_train is not None, "self.is_train should be set by the callback" - assert state.dataloader is not None, "dataloader should be set when using tqdm" - if self.is_train: - total_steps = len(state.dataloader) - else: - total_steps = 0 - for evaluator in state.evaluators: - dataloader_spec = evaluator.dataloader - assert isinstance(dataloader_spec.dataloader, collections.abc.Sized) - total_steps += len(dataloader_spec.dataloader) + assert state.dataloader_len is not None, "dataloader_len should be set when using tqdm" desc = f'Epoch {int(state.timer.epoch)}' position = 0 if self.is_train else 1 if not self.is_train: desc += f", Batch {int(state.timer.batch)} (val)" self.pbars[self.is_train] = _ProgressBarLoggerInstance( - _ProgressBarLoggerInstanceState(total=total_steps, + _ProgressBarLoggerInstanceState(total=state.dataloader_len, position=position, n=0, keys_to_log=_IS_TRAIN_TO_KEYS_TO_LOG[self.is_train], diff --git a/composer/optim/scheduler.py b/composer/optim/scheduler.py index 226bd20114..168f78d0a7 100644 --- a/composer/optim/scheduler.py +++ b/composer/optim/scheduler.py @@ -131,19 +131,19 @@ def _convert_time(time: Union[str, Time[int], Time[float]], state: State, ssr: f time = Time.from_timestring(time) if time.unit == TimeUnit.DURATION: - if state.dataloader is None: - raise RuntimeError("Cannot convert time, as state.dataloader is None.") + if state.dataloader_len is None: + raise RuntimeError("Cannot convert time, as state.dataloader_len is None.") if state.max_duration.unit 
== TimeUnit.EPOCH: - return Time(int(time.value * len(state.dataloader) * state.max_duration.value), TimeUnit.BATCH) + return Time(int(time.value * state.dataloader_len * state.max_duration.value), TimeUnit.BATCH) return Time(int(time.value * state.max_duration.value), state.max_duration.unit) if time.unit == TimeUnit.EPOCH: # Epochs do not provide sufficient granularity for SSR scaling # e.g. if max_duration = 1ep, then any SSR would result in a new duration of 0. # so, convert the time into batches - if state.dataloader is None: - raise RuntimeError("Cannot convert time, as state.dataloader is None.") - time = Time(value=time.value * len(state.dataloader), unit=TimeUnit.BATCH) + if state.dataloader_len is None: + raise RuntimeError("Cannot convert time, as state.dataloader_len is None.") + time = Time(value=time.value * state.dataloader_len, unit=TimeUnit.BATCH) return Time(value=int(time.value * ssr), unit=time.unit) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index c814433730..7728d9eaba 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -357,15 +357,7 @@ class Trainer: on this many batches. This parameter has no effect if it is greater than ``len(train_dataloader)``. If ``None``, then the entire dataloader will be iterated over. (default: ``None``) - .. warning:: - - This flag should be used for performance profiling only. Only a subset of the batches from the - dataloader will be used to training the model. - - In addition, learning rate schedulers will not be scaled to reflect ``train_subset_num_batches``, - and instead will assume ``len(train_dataloader)`` batches per epoch. - - eval_subset_num_batches (int, optional): If specified, evaluate on this many batches. + eval_subset_num_batches (int, optional): If specified, evaluate on this many batches per evaluation dataloader. This parameter has no effect if it is greater than ``len(eval_dataloader)``. 
If ``None``, then the entire dataloader will be iterated over. (default: ``None``) deepspeed_config (bool or Dict[str, Any], optional): Configuration for DeepSpeed, formatted as a JSON @@ -955,6 +947,7 @@ def _train_loop(self) -> None: if self.state.dataloader is None: # If the dataloader is not already set on State, then set it from what was passed in on init self.state.dataloader = self._train_data_spec.dataloader + self.state.dataloader_len = self.train_subset_num_batches for scheduler in ensure_tuple(self._schedulers): if isinstance(scheduler, PyTorchScheduler): @@ -989,8 +982,12 @@ def _train_loop(self) -> None: torch.utils.data.DistributedSampler): self.state.dataloader.sampler.set_epoch(int(self.state.timer.epoch)) - for batch_idx, self.state.batch in enumerate( - itertools.islice(self.state.dataloader, self.train_subset_num_batches)): + if self.state.dataloader_len is None: + iterable = self.state.dataloader + else: + iterable = itertools.islice(self.state.dataloader, int(self.state.dataloader_len)) + + for batch_idx, self.state.batch in enumerate(iterable): # if resuming, skip dataloader forward to the minibatch index if batch_idx < int(self.state.timer.batch_in_epoch): @@ -1230,12 +1227,14 @@ def eval(self, is_batch: bool): # back up the original dataloader on the state, so we can restore it after evaluation is finished original_dataloader = self.state.dataloader + original_num_batches = self.state.dataloader_len self.state.model.eval() with torch.no_grad(): for evaluator in self.state.evaluators: self.state.dataloader = evaluator.dataloader.dataloader + self.state.dataloader_len = self.eval_subset_num_batches self.engine.run_event(Event.EVAL_START) @@ -1249,7 +1248,7 @@ def eval(self, is_batch: bool): # The epoch provided to `set_epoch` need not be sequential, so this is fine. 
self.state.dataloader.sampler.set_epoch(int(self.state.timer.batch)) - for self.state.batch in itertools.islice(self.state.dataloader, self.eval_subset_num_batches): + for self.state.batch in itertools.islice(self.state.dataloader, self.state.dataloader_len): self.state.batch = self._device.batch_to_device(self.state.batch) if evaluator.dataloader.device_transforms: self.state.batch = evaluator.dataloader.device_transforms(self.state.batch) @@ -1280,6 +1279,7 @@ def eval(self, is_batch: bool): self.state.model.train() self.state.dataloader = original_dataloader + self.state.dataloader_len = original_num_batches def _use_grad_scaling(self, precision: Union[str, Precision], scaler: Optional[GradScaler]) -> bool: """Determines based on precision when to use grad scaling. diff --git a/tests/algorithms/test_selective_backprop.py b/tests/algorithms/test_selective_backprop.py index c3dcc4a089..9fb1a7bea0 100644 --- a/tests/algorithms/test_selective_backprop.py +++ b/tests/algorithms/test_selective_backprop.py @@ -144,12 +144,12 @@ def conv_model(Ximage: torch.Tensor, D: int) -> ComposerClassifier: def state(minimal_state: State, conv_model: ComposerClassifier, loss_fun_tuple: Callable, epoch: int, batch: int) -> State: """State with required values set for Selective Backprop.""" - assert minimal_state.dataloader is not None + assert minimal_state.dataloader_len is not None conv_model.loss = loss_fun_tuple minimal_state.model = conv_model minimal_state.timer.epoch._value = epoch - minimal_state.timer.batch._value = epoch * len(minimal_state.dataloader) + batch + minimal_state.timer.batch._value = epoch * minimal_state.dataloader_len + batch minimal_state.timer.batch_in_epoch._value = batch return minimal_state diff --git a/tests/algorithms/test_stochastic_depth.py b/tests/algorithms/test_stochastic_depth.py index 583796221d..a356d21526 100644 --- a/tests/algorithms/test_stochastic_depth.py +++ b/tests/algorithms/test_stochastic_depth.py @@ -232,8 +232,8 @@ def 
test_drop_rate_warmup(self, algorithm: StochasticDepth, step: int, state: St self.get_drop_rate_list(state.model, drop_rates=new_drop_rates) assert state.max_duration.unit == TimeUnit.EPOCH - assert state.dataloader is not None - drop_warmup_iters = int(len(state.dataloader) * int(state.max_duration.value) * algorithm.drop_warmup) + assert state.dataloader_len is not None + drop_warmup_iters = int(int(state.dataloader_len) * int(state.max_duration.value) * algorithm.drop_warmup) assert torch.all(torch.tensor(new_drop_rates) == ((step / drop_warmup_iters) * torch.tensor(old_drop_rates))) diff --git a/tests/trainer/test_scheduler.py b/tests/trainer/test_scheduler.py index 127ef3c7dd..f5a62d196f 100644 --- a/tests/trainer/test_scheduler.py +++ b/tests/trainer/test_scheduler.py @@ -89,17 +89,17 @@ def test_scheduler_init(scheduler: ComposerScheduler, ssr: float, test_times: Li dummy_schedulers_state: State): state = dummy_schedulers_state - assert state.dataloader is not None + assert state.dataloader_len is not None state.max_duration = Time(value=int(state.max_duration.value * ssr), unit=state.max_duration.unit) for test_time, expected_lr in zip(test_times, expected_lrs): parsed_time = Time.from_timestring(test_time) assert parsed_time.unit in [TimeUnit.EPOCH, TimeUnit.BATCH] if parsed_time.unit == TimeUnit.EPOCH: state.timer._epoch = parsed_time - state.timer._batch = Time(int(len(state.dataloader) * int(state.timer.epoch)), TimeUnit.BATCH) + state.timer._batch = Time(int(state.dataloader_len) * int(state.timer.epoch), TimeUnit.BATCH) else: state.timer._batch = parsed_time - state.timer._epoch = Time(int(state.timer.batch) // len(state.dataloader), TimeUnit.EPOCH) + state.timer._epoch = Time(int(state.timer.batch) // int(state.dataloader_len), TimeUnit.EPOCH) lr = scheduler(state, ssr) assert lr == pytest.approx(expected_lr, abs=1e-3) From f5d0a1bce8b9877b98f3d10c7c0907c8fbe56bc2 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 12 Apr 2022 13:01:20 -0700 
Subject: [PATCH 03/17] Fixed tests --- composer/loggers/progress_bar_logger.py | 2 +- composer/optim/scheduler.py | 4 ++-- composer/trainer/trainer.py | 7 ++++++- tests/algorithms/test_selective_backprop.py | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/composer/loggers/progress_bar_logger.py b/composer/loggers/progress_bar_logger.py index 576dd4173e..3760c6422c 100644 --- a/composer/loggers/progress_bar_logger.py +++ b/composer/loggers/progress_bar_logger.py @@ -105,7 +105,7 @@ def _start(self, state: State): if not self.is_train: desc += f", Batch {int(state.timer.batch)} (val)" self.pbars[self.is_train] = _ProgressBarLoggerInstance( - _ProgressBarLoggerInstanceState(total=state.dataloader_len, + _ProgressBarLoggerInstanceState(total=int(state.dataloader_len), position=position, n=0, keys_to_log=_IS_TRAIN_TO_KEYS_TO_LOG[self.is_train], diff --git a/composer/optim/scheduler.py b/composer/optim/scheduler.py index 168f78d0a7..4ce97980a4 100644 --- a/composer/optim/scheduler.py +++ b/composer/optim/scheduler.py @@ -134,7 +134,7 @@ def _convert_time(time: Union[str, Time[int], Time[float]], state: State, ssr: f if state.dataloader_len is None: raise RuntimeError("Cannot convert time, as state.dataloader_len is None.") if state.max_duration.unit == TimeUnit.EPOCH: - return Time(int(time.value * state.dataloader_len * state.max_duration.value), TimeUnit.BATCH) + return Time(int(time.value * int(state.dataloader_len) * state.max_duration.value), TimeUnit.BATCH) return Time(int(time.value * state.max_duration.value), state.max_duration.unit) if time.unit == TimeUnit.EPOCH: @@ -143,7 +143,7 @@ def _convert_time(time: Union[str, Time[int], Time[float]], state: State, ssr: f # so, convert the time into batches if state.dataloader_len is None: raise RuntimeError("Cannot convert time, as state.dataloader_len is None.") - time = Time(value=time.value * state.dataloader_len, unit=TimeUnit.BATCH) + time = Time(value=time.value * int(state.dataloader_len), 
unit=TimeUnit.BATCH) return Time(value=int(time.value * ssr), unit=time.unit) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 7728d9eaba..811b3b63e0 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1248,7 +1248,12 @@ def eval(self, is_batch: bool): # The epoch provided to `set_epoch` need not be sequential, so this is fine. self.state.dataloader.sampler.set_epoch(int(self.state.timer.batch)) - for self.state.batch in itertools.islice(self.state.dataloader, self.state.dataloader_len): + if self.state.dataloader_len is None: + iterable = self.state.dataloader + else: + iterable = itertools.islice(self.state.dataloader, int(self.state.dataloader_len)) + + for self.state.batch in iterable: self.state.batch = self._device.batch_to_device(self.state.batch) if evaluator.dataloader.device_transforms: self.state.batch = evaluator.dataloader.device_transforms(self.state.batch) diff --git a/tests/algorithms/test_selective_backprop.py b/tests/algorithms/test_selective_backprop.py index 9fb1a7bea0..8568631755 100644 --- a/tests/algorithms/test_selective_backprop.py +++ b/tests/algorithms/test_selective_backprop.py @@ -149,7 +149,7 @@ def state(minimal_state: State, conv_model: ComposerClassifier, loss_fun_tuple: minimal_state.model = conv_model minimal_state.timer.epoch._value = epoch - minimal_state.timer.batch._value = epoch * minimal_state.dataloader_len + batch + minimal_state.timer.batch._value = epoch * int(minimal_state.dataloader_len) + batch minimal_state.timer.batch_in_epoch._value = batch return minimal_state From e4facaa3bcfd0ebf483d38e2f6a3bab3039fed30 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 12 Apr 2022 13:53:32 -0700 Subject: [PATCH 04/17] Added `dataloader_label`; removed `evaluators` from State It can be useful for algorithms and callbacks to know which dataloader is active, so added the `dataloader_label` to the state. Removed `evaluators` from state, as nothing is using that anymore. 
--- composer/core/state.py | 60 +++++++++++++++++++----- composer/core/time.py | 6 ++- composer/profiler/dataloader_profiler.py | 17 ++++--- composer/profiler/profiler_schedule.py | 2 +- composer/trainer/trainer.py | 44 ++++++++--------- docs/source/doctest_fixtures.py | 3 +- tests/algorithms/test_colout.py | 4 +- tests/algorithms/test_layer_freezing.py | 3 +- tests/fixtures/dummy_fixtures.py | 20 ++------ tests/fixtures/new_fixtures.py | 3 +- tests/test_state.py | 3 +- tests/trainer/test_checkpoint.py | 6 +-- tests/trainer/test_ddp_sync_strategy.py | 11 ++--- tests/trainer/test_scheduler.py | 2 +- tests/utils/classifer.py | 9 +--- 15 files changed, 104 insertions(+), 89 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 30c8f7bb41..3a576bf8b9 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -24,7 +24,6 @@ import composer.core.types as types from composer.core.algorithm import Algorithm from composer.core.callback import Callback - from composer.core.evaluator import Evaluator from composer.profiler import Profiler __all__ = ["State"] @@ -96,9 +95,14 @@ class State(Serializable): ``rank_zero_seed + dist.get_global_rank()``. grad_accum (int): The number of gradient accumulation steps to use. With this argument, micro batch size for each device becomes ``microbatch_size = train_batch_size / (num_devices * grad_accum)``. - evaluators (evaluator.Evaluator | Sequence[evaluator.Evaluator]): - The evaluators contain the evaluation dataset(s) used for evaluation with specific metrics. - max_duration (str or Time): The maximum duration to train for. + dataloader (types.DataLoader, optional): The active DataLoader. + dataloader_len (int | Time[int], optional): The number of batches per dataloader iteration (e.g. epoch). + The trainer will yield the first ``dataloader_len`` batches per iteration. If ``None`` (the default), + the entire dataloader will be iterated over. 
+ dataloader_label (str, optional): The name for the dataloader. Required if ``dataloader`` is specified. (default: ``None``) + By convention, the training dataloader is called ``'train'``. The evaluator dataloader is called + ``'eval'``, or when multiple evaluators are used, the name of the evaluator. + max_duration (str | Time): The maximum duration to train for. precision (str | Precision): The numerical precision to use for training. See :class:`~.Precision` for the supported precisions. precision_context (Callable[[Precision], ContextManager]): Function to produce a context manager to mandate precision. @@ -145,6 +149,9 @@ class State(Serializable): +-----------------------+-------------------------------------------------------------+ """ + _dataloader: Optional[types.DataLoader] + _dataloader_label: Optional[str] + _dataloader_len: Optional[Time[int]] _max_duration: Time[int] batch: types.Batch batch_num_samples: int @@ -163,8 +170,10 @@ def __init__( rank_zero_seed: int, # data configurations - evaluators: Optional[Union[Evaluator, Sequence[Evaluator]]] = None, grad_accum: int = 1, + dataloader: Optional[types.DataLoader] = None, + dataloader_label: Optional[str] = None, + dataloader_len: Optional[Union[int, Time[int]]] = None, # precision precision: Union[str, Precision] = Precision.FP32, @@ -183,9 +192,8 @@ def __init__( self.rank_zero_seed = rank_zero_seed self.model = model self.grad_accum = grad_accum - self._dataloader: Optional[types.DataLoader] = None - self._dataloader_len: Optional[Time[int]] = None - self.evaluators = list(ensure_tuple(evaluators)) + self.set_dataloader(dataloader, dataloader_label, dataloader_len) + self.dataloader_len = dataloader_len self.max_duration = max_duration self.timer = Timer() @@ -364,11 +372,36 @@ def load_state_dict(self, state: Dict[str, Any], strict: bool = False): def dataloader(self): """The dataloader.""" return self._dataloader - - @dataloader.setter - def dataloader(self, dataloader: 
Optional[types.DataLoader]): + + @property + def dataloader_label(self): + """The dataloader label. By convention, the training dataloader is called ``'train'``. The evaluator dataloader + is called ``'eval'``, or when multiple evaluators are used, the name of the evaluator. + + Returns: + Optional[str]: The dataloader label, or None if no dataloader is set. + """ + return self._dataloader_label + + def set_dataloader( + self, + dataloader: Optional[types.DataLoader] = None, + dataloader_label: Optional[str] = None, + dataloader_len: Optional[Union[int, Time[int]]] = None, + ): + """Update the dataloader and dataloader label. + + Args: + dataloader (types.DataLoader, optional): The dataloader. Defaults to None. + dataloader_label (str, optional): The dataloader label. Must be ``None`` if and only if + ``dataloader`` is None. Defaults to None. + dataloader_len (int | Time[int], optional): The number of batches per dataloader iteration. If ``None`` + (the default), the entire dataloader will be iterated over. + """ + if (dataloader is None) != (dataloader_label is None): + raise ValueError("Both `dataloader` and `dataloader_label` should be None, or neither should be None.") self._dataloader = dataloader - self.dataloader_len = None # setting it to None will do a failsafe read of len(dataloader) + self._dataloader_label = dataloader_label + self.dataloader_len = dataloader_len # setting it to None will do a failsafe read of len(dataloader) @property def dataloader_len(self): @@ -385,6 +418,9 @@ def dataloader_len(self, num_batches: Optional[Union[int, Time[int]]]): if isinstance(num_batches, int): num_batches = Time(num_batches, TimeUnit.BATCH) if self._dataloader is None: + if num_batches is None: + self._dataloader_len = None + return raise RuntimeError("`State.dataloader_len` cannot be set if the dataloader is not defined.") try: dataloader_len = len(self._dataloader) diff --git a/composer/core/time.py b/composer/core/time.py index 6af0220d3e..c9fafe6517 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -224,13 +224,15 @@ def _parse(self, other: object) -> Time: # parse ``other`` into
a Time object if isinstance(other, Time): return other + if isinstance(other, int): + return Time(other, self.unit) if isinstance(other, str): other_parsed = Time.from_timestring(other) warnings.warn( textwrap.dedent(f"""\ TimeImplicitStringConversion: - Implicitly converting {other} to {other_parsed}. - To fix this warning, replace {other} with {other_parsed}.""")) + Implicitly converting '{other}' to '{repr(other_parsed)}'. + To fix this warning, replace '{other}' with '{repr(other_parsed)}'.""")) return other_parsed raise TypeError(f"Cannot convert type {other} to {self.__class__.__name__}") diff --git a/composer/profiler/dataloader_profiler.py b/composer/profiler/dataloader_profiler.py index 2d3f408efc..012a42149a 100644 --- a/composer/profiler/dataloader_profiler.py +++ b/composer/profiler/dataloader_profiler.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Iterator, Optional from composer.core.callback import Callback +from composer.core.event import Event from composer.datasets.dataloader import WrappedDataLoader if TYPE_CHECKING: @@ -60,18 +61,16 @@ class DataLoaderProfiler(Callback): instance of this :class:`.DataLoaderProfiler` callback. """ - def fit_start(self, state: State, logger: Logger): + def run_event(self, event: Event, state: State, logger: Logger): del logger # unused + if event not in (Event.FIT_START, Event.EVAL_START): + return if state.profiler is None: raise RuntimeError(("The Composer Profiler was not enabled, which is required to use the " f"{type(self).__name__}. 
To enable, set the `prof_schedule` argument of the Trainer.")) - assert state.dataloader, "dataloader should be set on FIT_START" + assert state.dataloader, "dataloader should be set on FIT_START or EVAL_START" + assert state.dataloader_label is not None, "dataloader label should be set on FIT_START or EVAL_START" if not _ProfiledDataLoader.is_dataloader_already_wrapped(state.dataloader): - state.dataloader = _ProfiledDataLoader(state.profiler, state.dataloader, "train") - - for evaluator in state.evaluators: - - if not _ProfiledDataLoader.is_dataloader_already_wrapped(evaluator.dataloader.dataloader): - evaluator.dataloader.dataloader = _ProfiledDataLoader(state.profiler, evaluator.dataloader.dataloader, - evaluator.label) + state.set_dataloader(_ProfiledDataLoader(state.profiler, state.dataloader, state.dataloader_label), + state.dataloader_label) diff --git a/composer/profiler/profiler_schedule.py b/composer/profiler/profiler_schedule.py index 2f75cd6ee5..bd764b941c 100644 --- a/composer/profiler/profiler_schedule.py +++ b/composer/profiler/profiler_schedule.py @@ -55,7 +55,7 @@ def schedule(state: State): return ProfilerAction.SKIP if position_in_cycle < wait + warmup: return ProfilerAction.WARMUP - is_last_batch_in_epoch = state.timer.batch_in_epoch == state.steps_per_epoch - 1 + is_last_batch_in_epoch = state.dataloader_len is not None and state.timer.batch_in_epoch == state.dataloader_len - 1 if position_in_cycle == cycle_len - 1 or is_last_batch_in_epoch: return ProfilerAction.ACTIVE_AND_SAVE return ProfilerAction.ACTIVE diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 65e1da5c97..0239f12e82 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -451,13 +451,12 @@ def __init__( self, *, model: ComposerModel, - train_dataloader: Union[DataLoader, DataSpec], # TODO move param to fit() - max_duration: Union[int, str, Time], # TODO move param to fit() - eval_dataloader: Optional[Union[DataLoader, DataSpec, 
Evaluator, - Sequence[Evaluator]]] = None, # TODO move param to fit() - algorithms: Optional[Union[Algorithm, Sequence[Algorithm]]] = None, # TODO move param to fit() - optimizers: Optional[torch.optim.Optimizer] = None, # TODO move param to fit() - schedulers: Optional[Union[ComposerScheduler, Sequence[ComposerScheduler]]] = None, # TODO move param to fit() + train_dataloader: Union[DataLoader, DataSpec], + max_duration: Union[int, str, Time], + eval_dataloader: Optional[Union[DataLoader, DataSpec, Evaluator, Sequence[Evaluator]]] = None, + algorithms: Optional[Union[Algorithm, Sequence[Algorithm]]] = None, + optimizers: Optional[torch.optim.Optimizer] = None, + schedulers: Optional[Union[ComposerScheduler, Sequence[ComposerScheduler]]] = None, # device device: Optional[Union[str, Device]] = None, @@ -465,12 +464,12 @@ def __init__( # training hparams grad_accum: int = 1, grad_clip_norm: Optional[float] = None, - validate_every_n_batches: int = -1, # TODO move param to fit() - validate_every_n_epochs: int = 1, # TODO move param to fit() - compute_training_metrics: bool = False, # TODO move param to fit() + validate_every_n_batches: int = -1, + validate_every_n_epochs: int = 1, + compute_training_metrics: bool = False, precision: Union[str, Precision] = Precision.FP32, - scale_schedule_ratio: float = 1.0, # TODO move param to fit() - step_schedulers_every_batch: Optional[bool] = None, # TODO move param to fit() + scale_schedule_ratio: float = 1.0, + step_schedulers_every_batch: Optional[bool] = None, # dist hparams dist_timeout: float = 300.0, @@ -504,11 +503,11 @@ def __init__( save_num_checkpoints_to_keep: int = -1, # subset parameters - train_subset_num_batches: Optional[int] = None, # TODO move param to fit() - eval_subset_num_batches: Optional[int] = None, # TODO move param to fit() + train_subset_num_batches: Optional[int] = None, + eval_subset_num_batches: Optional[int] = None, # DeepSpeed - deepspeed_config: Union[bool, Dict[str, Any]] = False, # TODO move 
param to fit() + deepspeed_config: Union[bool, Dict[str, Any]] = False, # profiling prof_trace_handlers: Optional[Union[TraceHandler, Sequence[TraceHandler]]] = None, @@ -670,7 +669,6 @@ def __init__( grad_accum=grad_accum, precision=precision, precision_context=precision_context, - evaluators=self.evaluators, optimizers=optimizers, ) @@ -930,7 +928,7 @@ def _spin_dataloaders(self): """ # spin the evaluator dataloaders once to initialize its sampler deterministically # so it does not affect any other RNG reads - for evaluator in self.state.evaluators: + for evaluator in self.evaluators: dataloader = evaluator.dataloader.dataloader # FFCV dataloaders use their own sampling strategy if isinstance(dataloader, torch.utils.data.DataLoader) and isinstance(dataloader.sampler, @@ -969,9 +967,11 @@ def _train_loop(self) -> None: if self.state.dataloader is None: # If the dataloader is not already set on State, then set it from what was passed in on init - self.state.dataloader = self._train_data_spec.dataloader + self.state.set_dataloader(self._train_data_spec.dataloader, "train") self.state.dataloader_len = self.train_subset_num_batches + assert self.state.dataloader is not None + for scheduler in ensure_tuple(self._schedulers): if isinstance(scheduler, PyTorchScheduler): scale_pytorch_scheduler(scheduler, self._scale_schedule_ratio) @@ -1250,13 +1250,15 @@ def eval(self, is_batch: bool): # back up the original dataloader on the state, so we can restore it after evaluation is finished original_dataloader = self.state.dataloader + original_dataloader_label = self.state.dataloader_label original_num_batches = self.state.dataloader_len self.state.model.eval() with torch.no_grad(): - for evaluator in self.state.evaluators: - self.state.dataloader = evaluator.dataloader.dataloader + for evaluator in self.evaluators: + self.state.set_dataloader(evaluator.dataloader.dataloader, evaluator.label) + assert self.state.dataloader is not None, "dataloader is set" 
self.state.dataloader_len = self.eval_subset_num_batches self.engine.run_event(Event.EVAL_START) @@ -1306,7 +1308,7 @@ def eval(self, is_batch: bool): if restore_model_train: self.state.model.train() - self.state.dataloader = original_dataloader + self.state.set_dataloader(original_dataloader, original_dataloader_label) self.state.dataloader_len = original_num_batches def _use_grad_scaling(self, precision: Union[str, Precision], scaler: Optional[GradScaler]) -> bool: diff --git a/docs/source/doctest_fixtures.py b/docs/source/doctest_fixtures.py index 7b11dd3b38..6bb3a34126 100644 --- a/docs/source/doctest_fixtures.py +++ b/docs/source/doctest_fixtures.py @@ -110,8 +110,7 @@ model=model, optimizers=optimizer, grad_accum=1, - train_dataloader=train_dataloader, - evaluators=[], + dataloader=train_dataloader, max_duration="1ep", precision="fp32", ) diff --git a/tests/algorithms/test_colout.py b/tests/algorithms/test_colout.py index edb7f99d85..b41e435dd1 100644 --- a/tests/algorithms/test_colout.py +++ b/tests/algorithms/test_colout.py @@ -212,8 +212,8 @@ def test_apply_sample(self, colout_algorithm: ColOut, minimal_state: State, empt original_image, _ = dataset[0] assert isinstance(original_image, Image.Image) - minimal_state.dataloader = dataloader - colout_algorithm.apply(Event.INIT, minimal_state, empty_logger) + minimal_state.set_dataloader(dataloader, "train") + colout_algorithm.apply(Event.FIT_START, minimal_state, empty_logger) new_image, _ = dataset[0] assert isinstance(new_image, Image.Image) diff --git a/tests/algorithms/test_layer_freezing.py b/tests/algorithms/test_layer_freezing.py index 5f73941983..9deec8f033 100644 --- a/tests/algorithms/test_layer_freezing.py +++ b/tests/algorithms/test_layer_freezing.py @@ -20,14 +20,13 @@ def _generate_state(epoch: int, max_epochs: int): optimizers=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99), precision=Precision.FP32, grad_accum=1, - evaluators=Mock(), max_duration=f'{max_epochs}ep') # fast forward by 
epochs for _ in range(epoch): state.timer.on_epoch_complete() - state.dataloader = Mock(__len__=lambda x: 100) + state.set_dataloader(Mock(__len__=lambda x: 100), "train") return state diff --git a/tests/fixtures/dummy_fixtures.py b/tests/fixtures/dummy_fixtures.py index 3f4cb03df8..7589e70560 100644 --- a/tests/fixtures/dummy_fixtures.py +++ b/tests/fixtures/dummy_fixtures.py @@ -7,10 +7,8 @@ import torch import torch.utils.data from torch.optim import Optimizer -from torchmetrics import MetricCollection -from torchmetrics.classification.accuracy import Accuracy -from composer.core import DataSpec, Evaluator, Precision, State +from composer.core import DataSpec, Precision, State from composer.core.types import DataLoader, PyTorchScheduler from composer.datasets import DataLoaderHparams, DatasetHparams from composer.loggers import Logger @@ -102,21 +100,17 @@ def dummy_scheduler(dummy_optimizer: Optimizer): @pytest.fixture() def dummy_state(dummy_model: SimpleBatchPairModel, dummy_train_dataloader: DataLoader, dummy_optimizer: Optimizer, - dummy_scheduler: PyTorchScheduler, dummy_val_dataloader: DataLoader, rank_zero_seed: int) -> State: - evaluators = [ - Evaluator(label="dummy_label", dataloader=dummy_val_dataloader, metrics=dummy_model.metrics(train=False)) - ] + dummy_scheduler: PyTorchScheduler, rank_zero_seed: int) -> State: state = State( model=dummy_model, precision=Precision.FP32, grad_accum=1, rank_zero_seed=rank_zero_seed, - evaluators=evaluators, optimizers=dummy_optimizer, max_duration="10ep", ) state.schedulers = dummy_scheduler - state.dataloader = dummy_train_dataloader + state.set_dataloader(dummy_train_dataloader, "train") return state @@ -220,19 +214,15 @@ def simple_conv_model_input(): @pytest.fixture() -def state_with_model(simple_conv_model: torch.nn.Module, dummy_train_dataloader: DataLoader, - dummy_val_dataloader: DataLoader, rank_zero_seed: int): - metric_coll = MetricCollection([Accuracy()]) - evaluators = [Evaluator(label="dummy_label", 
dataloader=dummy_val_dataloader, metrics=metric_coll)] +def state_with_model(simple_conv_model: torch.nn.Module, dummy_train_dataloader: DataLoader, rank_zero_seed: int): state = State( grad_accum=1, rank_zero_seed=rank_zero_seed, max_duration="100ep", model=simple_conv_model, precision=Precision.FP32, - evaluators=evaluators, ) - state.dataloader = dummy_train_dataloader + state.set_dataloader(dummy_train_dataloader, "train") return state diff --git a/tests/fixtures/new_fixtures.py b/tests/fixtures/new_fixtures.py index b43482f821..2b36795ed3 100644 --- a/tests/fixtures/new_fixtures.py +++ b/tests/fixtures/new_fixtures.py @@ -18,10 +18,9 @@ def minimal_state(rank_zero_seed: int): state = State( model=SimpleModel(), rank_zero_seed=rank_zero_seed, - evaluators=[], max_duration='100ep', ) - state.dataloader = DataLoader(RandomClassificationDataset()) + state.set_dataloader(DataLoader(RandomClassificationDataset()), "train") return state diff --git a/tests/test_state.py b/tests/test_state.py index 7d6ab7db50..5fd691cdba 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -29,14 +29,13 @@ def get_dummy_state(model: ComposerModel, train_dataloader: DataLoader, val_data rank_zero_seed=random.randint(0, 100), precision=Precision.AMP, max_duration=f"{random.randint(0, 100)}ep", - evaluators=evaluators, optimizers=optimizers, algorithms=[ChannelsLastHparams().initialize_object()]) state.schedulers = torch.optim.lr_scheduler.StepLR(optimizers, step_size=3) state.loss = random_tensor() state.batch = (random_tensor(), random_tensor()) state.outputs = random_tensor() - state.dataloader = train_dataloader + state.set_dataloader(train_dataloader, "train") return state diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 333641b85d..bc10b52ad6 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -416,11 +416,11 @@ def _validate_events_called_expected_number_of_times(trainer: Trainer): if 
trainer._validate_every_n_epochs > 0: num_evals = num_epochs // trainer._validate_every_n_epochs - assert state.evaluators is not None - for evaluator in state.evaluators: + assert trainer.evaluators is not None + for evaluator in trainer.evaluators: assert evaluator.dataloader is not None assert trainer.eval_subset_num_batches is not None - num_eval_steps = num_evals * trainer.eval_subset_num_batches * len(state.evaluators) + num_eval_steps = num_evals * trainer.eval_subset_num_batches * len(trainer.evaluators) event_to_num_expected_invocations = { Event.INIT: 1, diff --git a/tests/trainer/test_ddp_sync_strategy.py b/tests/trainer/test_ddp_sync_strategy.py index 30f65303df..ad88534477 100644 --- a/tests/trainer/test_ddp_sync_strategy.py +++ b/tests/trainer/test_ddp_sync_strategy.py @@ -6,10 +6,8 @@ import torch import torch.nn as nn from torch import Tensor -from torchmetrics import MetricCollection -from torchmetrics.classification.accuracy import Accuracy -from composer.core import Evaluator, State +from composer.core import State from composer.core.types import DataLoader from composer.trainer.ddp import _ddp_sync_context, _prepare_ddp_module from composer.utils import dist @@ -49,20 +47,17 @@ def loss(self, output: Tensor, target: Tensor): ]) @pytest.mark.world_size(2) def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional[float]], - dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader, rank_zero_seed: int): + dummy_train_dataloader: DataLoader, rank_zero_seed: int): original_model = MinimalConditionalModel() # ddp = DDP(backend="gloo", find_unused_parameters=True, sync_strategy=ddp_sync_strategy, timeout=5.) 
optimizer = torch.optim.SGD(original_model.parameters(), 0.1) - metric_coll = MetricCollection([Accuracy()]) - evaluators = [Evaluator(label="dummy_label", dataloader=dummy_val_dataloader, metrics=metric_coll)] state = State(model=original_model, rank_zero_seed=rank_zero_seed, optimizers=optimizer, grad_accum=2, max_duration="1ep", - evaluators=evaluators, precision='fp32') - state.dataloader = dummy_train_dataloader + state.set_dataloader(dummy_train_dataloader, "train") batches = [[(1, Tensor([1])), (1, Tensor([2]))], [(2, Tensor([1])), (2, Tensor([2]))]] state.model = _prepare_ddp_module(state.model, find_unused_parameters=True) diff --git a/tests/trainer/test_scheduler.py b/tests/trainer/test_scheduler.py index f5a62d196f..b191ed32c3 100644 --- a/tests/trainer/test_scheduler.py +++ b/tests/trainer/test_scheduler.py @@ -28,7 +28,7 @@ def dummy_schedulers_state(dummy_model: torch.nn.Module, rank_zero_seed: int): rank_zero_seed=rank_zero_seed, max_duration=MAX_DURATION, ) - state.dataloader = cast(torch.utils.data.DataLoader, [None] * STEPS_PER_EPOCH) + state.set_dataloader(cast(torch.utils.data.DataLoader, [None] * STEPS_PER_EPOCH), "train") return state diff --git a/tests/utils/classifer.py b/tests/utils/classifer.py index 736c798094..a4f085a6a5 100644 --- a/tests/utils/classifer.py +++ b/tests/utils/classifer.py @@ -5,10 +5,8 @@ import torch import torch.nn.functional as F from torch.optim import SGD, Optimizer -from torchmetrics import MetricCollection -from torchmetrics.classification.accuracy import Accuracy -from composer.core import Algorithm, Engine, Evaluator, Event, Precision, State +from composer.core import Algorithm, Engine, Event, Precision, State from composer.core.types import DataLoader from composer.loggers import Logger from tests.utils.model import SimpleModel @@ -16,18 +14,15 @@ def _get_state(train_dataloader: DataLoader, eval_dataloader: DataLoader, rank_zero_seed: int): model = SimpleModel() - metric_coll = MetricCollection([Accuracy()]) - 
evaluators = [Evaluator(label="dummy_label", dataloader=eval_dataloader, metrics=metric_coll)] state = State( model=model, rank_zero_seed=rank_zero_seed, optimizers=SGD(model.parameters(), lr=.001, momentum=0.0), max_duration="1ep", - evaluators=evaluators, grad_accum=1, precision=Precision.FP32, ) - state.dataloader = train_dataloader + state.set_dataloader(train_dataloader, "train") return state From 56fd87de5dfe5315de07ff993e27b6811684d2e5 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 12 Apr 2022 14:13:08 -0700 Subject: [PATCH 05/17] Fixed pyright --- tests/test_state.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_state.py b/tests/test_state.py index 5fd691cdba..516e071cef 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -20,10 +20,9 @@ def random_tensor(size=(4, 10)): return torch.rand(*size) -def get_dummy_state(model: ComposerModel, train_dataloader: DataLoader, val_dataloader: DataLoader): +def get_dummy_state(model: ComposerModel, train_dataloader: DataLoader): optimizers = torch.optim.Adadelta(model.parameters()) - evaluators = [Evaluator(label="dummy_label", dataloader=val_dataloader, metrics=model.metrics(train=False))] state = State(model=model, grad_accum=random.randint(0, 100), rank_zero_seed=random.randint(0, 100), From b89c3bc16f222638f638d1543b7482a5c218d0e6 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Tue, 12 Apr 2022 14:14:30 -0700 Subject: [PATCH 06/17] Fixed pyright --- tests/test_state.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/test_state.py b/tests/test_state.py index 516e071cef..7c14c0499c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -8,7 +8,7 @@ from torch.functional import Tensor from composer.algorithms import ChannelsLastHparams -from composer.core import DataSpec, Evaluator, Precision, State +from composer.core import DataSpec, Precision, State from composer.core.types import Batch, DataLoader from 
composer.datasets.dataloader import DataLoaderHparams from composer.datasets.hparams import DatasetHparams @@ -90,14 +90,18 @@ def get_batch(dataset_hparams: DatasetHparams, dataloader_hparams: DataLoaderHpa raise RuntimeError("No batch in dataloader") -def test_state_serialize(tmpdir: pathlib.Path, dummy_model: ComposerModel, dummy_dataloader_hparams: DataLoaderHparams, - dummy_train_dataset_hparams: DatasetHparams, dummy_train_dataloader: DataLoader, - dummy_val_dataset_hparams: DatasetHparams, dummy_val_dataloader: DataLoader): +def test_state_serialize( + tmpdir: pathlib.Path, + dummy_model: ComposerModel, + dummy_dataloader_hparams: DataLoaderHparams, + dummy_train_dataset_hparams: DatasetHparams, + dummy_train_dataloader: DataLoader, +): assert isinstance(dummy_model, SimpleBatchPairModel) - state1 = get_dummy_state(dummy_model, dummy_train_dataloader, dummy_val_dataloader) - state2 = get_dummy_state(dummy_model, dummy_train_dataloader, dummy_val_dataloader) + state1 = get_dummy_state(dummy_model, dummy_train_dataloader) + state2 = get_dummy_state(dummy_model, dummy_train_dataloader) # train one step to set the optimizer states batch = get_batch(dummy_train_dataset_hparams, dummy_dataloader_hparams) From 096b44c87cdba2e599a4c89cd9ffd9cf006a2a1a Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Wed, 13 Apr 2022 12:47:36 -0700 Subject: [PATCH 07/17] Made `max_duration` optional --- .../layer_freezing/layer_freezing.py | 4 ++- .../progressive_resizing.py | 4 ++- .../selective_backprop/selective_backprop.py | 5 +++- .../seq_length_warmup/seq_length_warmup.py | 1 + .../stochastic_depth/stochastic_depth.py | 6 +++-- composer/callbacks/checkpoint_saver.py | 13 +++++++--- composer/core/state.py | 26 ++++++++++++------- composer/optim/scheduler.py | 2 ++ composer/trainer/trainer.py | 15 +++++------ tests/algorithms/test_progressive_resizing.py | 1 + tests/algorithms/test_stochastic_depth.py | 1 + tests/trainer/test_checkpoint.py | 1 + 
tests/trainer/test_scale_schedule.py | 1 + tests/trainer/test_scheduler.py | 1 + tests/trainer/test_trainer.py | 1 + 15 files changed, 57 insertions(+), 25 deletions(-) diff --git a/composer/algorithms/layer_freezing/layer_freezing.py b/composer/algorithms/layer_freezing/layer_freezing.py index 4f278e9d58..c309f461e7 100644 --- a/composer/algorithms/layer_freezing/layer_freezing.py +++ b/composer/algorithms/layer_freezing/layer_freezing.py @@ -133,10 +133,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: del event # unused optimizers = state.optimizers assert optimizers is not None + elapsed_duration = state.get_elapsed_duration() + assert elapsed_duration is not None, "elapsed duration is available on epoch end" freeze_depth, freeze_percentage = freeze_layers( model=state.model, optimizers=optimizers, - current_duration=float(state.get_elapsed_duration()), + current_duration=float(elapsed_duration), freeze_start=self.freeze_start, freeze_level=self.freeze_level, ) diff --git a/composer/algorithms/progressive_resizing/progressive_resizing.py b/composer/algorithms/progressive_resizing/progressive_resizing.py index 737ce82469..c5c7f6977f 100644 --- a/composer/algorithms/progressive_resizing/progressive_resizing.py +++ b/composer/algorithms/progressive_resizing/progressive_resizing.py @@ -197,7 +197,9 @@ def apply(self, event: Event, state: State, logger: Optional[Logger] = None) -> # Calculate the current size of the inputs to use initial_size = self.initial_scale finetune_fraction = self.finetune_fraction - scale_frac_elapsed = min([state.get_elapsed_duration().value / (1 - finetune_fraction), 1]) + elapsed_duration = state.get_elapsed_duration() + assert elapsed_duration is not None, "elapsed duration is set on Event.AFTER_DATALOADER" + scale_frac_elapsed = min([elapsed_duration.value / (1 - finetune_fraction), 1]) # Linearly increase to full size at the start of the fine tuning period scale_factor = initial_size + (1 - 
initial_size) * scale_frac_elapsed diff --git a/composer/algorithms/selective_backprop/selective_backprop.py b/composer/algorithms/selective_backprop/selective_backprop.py index e9c64ce24e..788194de6c 100644 --- a/composer/algorithms/selective_backprop/selective_backprop.py +++ b/composer/algorithms/selective_backprop/selective_backprop.py @@ -204,8 +204,11 @@ def match(self, event: Event, state: State) -> bool: if not is_keep: return False + elapsed_duration = state.get_elapsed_duration() + assert elapsed_duration is not None, "elapsed duration should be set on Event.AFTER_DATALOADER" + is_chosen = should_selective_backprop( - current_duration=float(state.get_elapsed_duration()), + current_duration=float(elapsed_duration), batch_idx=state.timer.batch_in_epoch.value, start=self.start, end=self.end, diff --git a/composer/algorithms/seq_length_warmup/seq_length_warmup.py b/composer/algorithms/seq_length_warmup/seq_length_warmup.py index c9345665f7..b826f302b1 100644 --- a/composer/algorithms/seq_length_warmup/seq_length_warmup.py +++ b/composer/algorithms/seq_length_warmup/seq_length_warmup.py @@ -183,6 +183,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: return assert state.dataloader is not None, "dataloader should be set on AFTER_DATALOADER" + assert state.max_duration is not None, "max_duration should be set on AFTER_DATALOADER" # in order to avoid OOMs, we do a forward and a backward pass on a dummy input.
if not self._activated: diff --git a/composer/algorithms/stochastic_depth/stochastic_depth.py b/composer/algorithms/stochastic_depth/stochastic_depth.py index ab70b88a19..3740ede5f8 100644 --- a/composer/algorithms/stochastic_depth/stochastic_depth.py +++ b/composer/algorithms/stochastic_depth/stochastic_depth.py @@ -232,8 +232,10 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: logger.data_epoch({'stochastic_depth/num_stochastic_layers': num_stochastic_layers}) elif event == Event.BATCH_START: - if state.get_elapsed_duration() < self.drop_warmup: - current_drop_rate = float(state.get_elapsed_duration() / self.drop_warmup) * self.drop_rate + elapsed_duration = state.get_elapsed_duration() + assert elapsed_duration is not None, "elapsed duration is set on BATCH_START" + if elapsed_duration < self.drop_warmup: + current_drop_rate = float(elapsed_duration / self.drop_warmup) * self.drop_rate _update_drop_rate(state.model, stochastic_layer, current_drop_rate, self.drop_distribution) else: current_drop_rate = self.drop_rate diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py index ae811fb6ce..710834dd13 100644 --- a/composer/callbacks/checkpoint_saver.py +++ b/composer/callbacks/checkpoint_saver.py @@ -57,7 +57,10 @@ def checkpoint_periodically(interval: Union[str, int, Time]) -> Callable[[State, def save_interval(state: State, event: Event): nonlocal last_checkpoint_batch - if state.get_elapsed_duration() >= 1.0: + elapsed_duration = state.get_elapsed_duration() + assert elapsed_duration is not None, "elapsed_duration is set on the BATCH_CHECKPOINT and EPOCH_CHECKPOINT" + + if elapsed_duration >= 1.0: # if doing batch-wise checkpointing, and we saved a checkpoint at the batch_checkpoint event # right before the epoch_checkpoint event, do not save another checkpoint at the epoch_checkpoint # event if the batch count didn't increase. 
@@ -317,12 +320,16 @@ def fit_start(self, state: State, logger: Logger) -> None: def batch_checkpoint(self, state: State, logger: Logger): if self.save_interval(state, Event.BATCH_CHECKPOINT): # If training is finished, log at the FIT loglevel - log_level = LogLevel.BATCH if state.get_elapsed_duration() < 1.0 else LogLevel.FIT + elapsed_duration = state.get_elapsed_duration() + assert elapsed_duration is not None, "elapsed_duration is set on Event.BATCH_CHECKPOINT" + log_level = LogLevel.BATCH if elapsed_duration < 1.0 else LogLevel.FIT self._save_checkpoint(state, logger, log_level) def epoch_checkpoint(self, state: State, logger: Logger): if self.save_interval(state, Event.EPOCH_CHECKPOINT): - log_level = LogLevel.EPOCH if state.get_elapsed_duration() < 1.0 else LogLevel.FIT + elapsed_duration = state.get_elapsed_duration() + assert elapsed_duration is not None, "elapsed_duration is set on Event.EPOCH_CHECKPOINT" + log_level = LogLevel.EPOCH if elapsed_duration < 1.0 else LogLevel.FIT self._save_checkpoint(state, logger, log_level) def _save_checkpoint(self, state: State, logger: Logger, log_level: LogLevel): diff --git a/composer/core/state.py b/composer/core/state.py index 3a576bf8b9..474806fccc 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -93,7 +93,7 @@ class State(Serializable): model (torch.nn.Module): The model, typically as a subclass of :class:`~.ComposerModel`. rank_zero_seed (int): The seed used on the rank zero process. It is assumed that each rank's seed is ``rank_zero_seed + dist.get_global_rank()``. - grad_accum (int): The number of gradient accumulation steps to use. With this argument, micro batch size for + grad_accum (int, optional): The number of gradient accumulation steps to use. With this argument, micro batch size for each device becomes ``microbatch_size = train_batch_size / (num_devices * grad_accum)``. dataloader (types.DataLoader, optional): The active DataLoader. 
dataloader_len (int | Time[int], optional): The number of batches per dataloader iteration (e.g. epoch). @@ -102,7 +102,7 @@ class State(Serializable): dataloader_label (str, optional): The name for the dataloader. Required if ``dataloader`` is specified. (default: ``None``) By convention, the training dataloader is called ``'train'``. The evaluator dataloader is called ``'eval'``, or when multiple evaluators are used, the name of the evaluator. - max_duration (str | Time): The maximum duration to train for. + max_duration (str | Time, optional): The maximum duration to train for. (default: ``None``) precision (str | Precision): The numerical precision to use for training. See :class:`~.Precision` for the supported precisions. precision_context (Callable[[Precision], ContextManager]): Function to produce a context manager to mandate precision. @@ -152,7 +152,7 @@ class State(Serializable): _dataloader: Optional[types.DataLoader] _dataloader_label: Optional[str] _dataloader_len: Optional[Time[int]] - _max_duration: Time[int] + _max_duration: Optional[Time[int]] batch: types.Batch batch_num_samples: int batch_num_tokens: int @@ -165,10 +165,12 @@ def __init__( # model model: torch.nn.Module, - # stopping conditions - max_duration: Union[str, Time[int]], + # determinism rank_zero_seed: int, + # stopping conditions + max_duration: Optional[Union[str, Time[int]]] = None, + # data configurations grad_accum: int = 1, dataloader: Optional[types.DataLoader] = None, @@ -240,20 +242,26 @@ def max_duration(self): return self._max_duration @max_duration.setter - def max_duration(self, max_duration: Union[str, Time[int]]): + def max_duration(self, max_duration: Optional[Union[str, Time[int]]]): + if max_duration is None: + self._max_duration = None + return if isinstance(max_duration, str): max_duration = cast(Time[int], Time.from_timestring(max_duration)) if max_duration.unit == TimeUnit.DURATION: raise ValueError("TimeUnit.DURATION is not allowed as a unit for max_duration") 
self._max_duration = max_duration - def get_elapsed_duration(self) -> Time[float]: + def get_elapsed_duration(self) -> Optional[Time[float]]: """Get the elapsed training duration. Returns: - Time: The elapsed duration, in :attr:`TimeUnit.DURATION`. ``Time(0.0, TimeUnit.DURATION)`` represents the - beginning of training and ``Time(1.0, TimeUnit.DURATION)`` represents a completed training process. + Optional[Time[float]]: The elapsed duration, in :attr:`TimeUnit.DURATION`. + ``Time(0.0, TimeUnit.DURATION)`` represents the beginning of training and ``Time(1.0, TimeUnit.DURATION)`` + represents a completed training process. Returns ``None`` if ``max_duration`` is None. """ + if self.max_duration is None: + return None return self.timer.get(self.max_duration.unit) / self.max_duration @property diff --git a/composer/optim/scheduler.py b/composer/optim/scheduler.py index 4ce97980a4..e3b19ba111 100644 --- a/composer/optim/scheduler.py +++ b/composer/optim/scheduler.py @@ -130,6 +130,8 @@ def _convert_time(time: Union[str, Time[int], Time[float]], state: State, ssr: f if isinstance(time, str): time = Time.from_timestring(time) + assert state.max_duration is not None, "max_duration should be set whenever schedulers are invoked" + if time.unit == TimeUnit.DURATION: if state.dataloader_len is None: raise RuntimeError("Cannot convert time, as state.dataloader_len is None.") diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index fea9f9cf26..dc3bc04458 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -678,7 +678,6 @@ def __init__( raise ValueError("grad_accum must be an int or ``auto``") self.state = State( - max_duration=max_duration, rank_zero_seed=rank_zero_seed, algorithms=algorithms, model=model, @@ -689,7 +688,6 @@ def __init__( optimizers=optimizers, ) - self._train_dataloader = train_dataloader.dataloader self.train_subset_num_batches = train_subset_num_batches pytorch_schedulers = [ @@ -800,6 +798,12 @@ def __init__( 
self.engine.run_event(Event.INIT) + # After running Event.INIT, then set the "optional" elements of state that could be passed in on FIT instead of INIT + # Setting these attributes here ensures that algorithms do not depend on unavailable attributes during Event.INIT + self.state.set_dataloader(train_dataloader.dataloader, 'train') + self.state.dataloader_len = self.train_subset_num_batches + self.state.max_duration = max_duration + assert isinstance(self.state.model, ComposerModel) self._original_model = self.state.model # TODO(ravi) -- update the state to add an original model helper @@ -982,12 +986,7 @@ def _train_loop(self) -> None: else: train_metrics = None - if self.state.dataloader is None: - # If the dataloader is not already set on State, then set it from what was passed in on init - self.state.set_dataloader(self._train_data_spec.dataloader, "train") - self.state.dataloader_len = self.train_subset_num_batches - - assert self.state.dataloader is not None + assert self.state.dataloader is not None, "dataloader is set in __init__" for scheduler in ensure_tuple(self._schedulers): if isinstance(scheduler, PyTorchScheduler): diff --git a/tests/algorithms/test_progressive_resizing.py b/tests/algorithms/test_progressive_resizing.py index f80e74afc7..4a68ea226f 100644 --- a/tests/algorithms/test_progressive_resizing.py +++ b/tests/algorithms/test_progressive_resizing.py @@ -149,6 +149,7 @@ def test_match_incorrect(self, event: Event, pr_algorithm: ProgressiveResizing, def test_apply(self, epoch_frac: float, X: torch.Tensor, y: torch.Tensor, pr_algorithm: ProgressiveResizing, minimal_state: State, empty_logger: Logger): """Test apply at different epoch fractions (fraction of max epochs)""" + assert minimal_state.max_duration is not None assert minimal_state.max_duration.unit == TimeUnit.EPOCH minimal_state.timer.epoch._value = int(epoch_frac * minimal_state.max_duration.value) s = pr_algorithm.initial_scale diff --git a/tests/algorithms/test_stochastic_depth.py 
b/tests/algorithms/test_stochastic_depth.py index 5ab2cbebf9..d00bfcfde1 100644 --- a/tests/algorithms/test_stochastic_depth.py +++ b/tests/algorithms/test_stochastic_depth.py @@ -232,6 +232,7 @@ def test_drop_rate_warmup(self, algorithm: StochasticDepth, step: int, state: St new_drop_rates = [] self.get_drop_rate_list(state.model, drop_rates=new_drop_rates) + assert state.max_duration is not None assert state.max_duration.unit == TimeUnit.EPOCH assert state.dataloader_len is not None drop_warmup_iters = int(int(state.dataloader_len) * int(state.max_duration.value) * algorithm.drop_warmup) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index c7b03b7c96..c8e84507d2 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -414,6 +414,7 @@ def _test_checkpoint_trainer(trainer_hparams: TrainerHparams): def _validate_events_called_expected_number_of_times(trainer: Trainer): state = trainer.state assert trainer.train_subset_num_batches is not None + assert state.max_duration is not None assert state.max_duration.unit == TimeUnit.EPOCH num_epochs = state.max_duration.value num_total_steps = num_epochs * trainer.train_subset_num_batches diff --git a/tests/trainer/test_scale_schedule.py b/tests/trainer/test_scale_schedule.py index 2084d1a988..0d98624929 100644 --- a/tests/trainer/test_scale_schedule.py +++ b/tests/trainer/test_scale_schedule.py @@ -121,6 +121,7 @@ def fit_start(self, state: State, logger: Logger) -> None: trainer = composer_trainer_hparams.initialize_object() trainer.state.callbacks.append(CheckScaleSchedule()) + assert trainer.state.max_duration is not None assert trainer.state.max_duration.unit == TimeUnit.EPOCH assert trainer.state.max_duration.value == int(10 * ssr) diff --git a/tests/trainer/test_scheduler.py b/tests/trainer/test_scheduler.py index b191ed32c3..cfa21a61f4 100644 --- a/tests/trainer/test_scheduler.py +++ b/tests/trainer/test_scheduler.py @@ -90,6 +90,7 @@ def 
test_scheduler_init(scheduler: ComposerScheduler, ssr: float, test_times: Li state = dummy_schedulers_state assert state.dataloader_len is not None + assert state.max_duration is not None state.max_duration = Time(value=int(state.max_duration.value * ssr), unit=state.max_duration.unit) for test_time, expected_lr in zip(test_times, expected_lrs): parsed_time = Time.from_timestring(test_time) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 48c113b76b..6b718e2984 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -109,6 +109,7 @@ def test_init_with_integers(self, config, tmpdir: pathlib.Path): def test_init_with_max_duration_in_batches(self, config): config["max_duration"] = '1ba' trainer = Trainer(**config) + assert trainer.state.max_duration is not None assert trainer.state.max_duration.to_timestring() == "1ba" From fe701332345c034ee6f8fbb701d27ce8793ed68c Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 14 Apr 2022 10:37:18 -0700 Subject: [PATCH 08/17] Addressed PR feedback; fixed Time type annotations --- .../layer_freezing/layer_freezing.py | 2 +- .../progressive_resizing.py | 2 +- composer/algorithms/swa/swa.py | 19 ++++++--- composer/callbacks/speed_monitor.py | 1 + composer/core/engine.py | 25 +++++++++++ composer/core/time.py | 22 +++++----- tests/trainer/test_scale_schedule.py | 42 ++++++++++--------- 7 files changed, 75 insertions(+), 38 deletions(-) diff --git a/composer/algorithms/layer_freezing/layer_freezing.py b/composer/algorithms/layer_freezing/layer_freezing.py index c309f461e7..d56544d726 100644 --- a/composer/algorithms/layer_freezing/layer_freezing.py +++ b/composer/algorithms/layer_freezing/layer_freezing.py @@ -134,7 +134,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: optimizers = state.optimizers assert optimizers is not None elapsed_duration = state.get_elapsed_duration() - assert elapsed_duration is not None, "elapsed duration is 
available on epoch end" + assert elapsed_duration is not None, "elapsed duration should be set on Event.EPOCH_END" freeze_depth, freeze_percentage = freeze_layers( model=state.model, optimizers=optimizers, diff --git a/composer/algorithms/progressive_resizing/progressive_resizing.py b/composer/algorithms/progressive_resizing/progressive_resizing.py index c5c7f6977f..71a68d3d49 100644 --- a/composer/algorithms/progressive_resizing/progressive_resizing.py +++ b/composer/algorithms/progressive_resizing/progressive_resizing.py @@ -198,7 +198,7 @@ def apply(self, event: Event, state: State, logger: Optional[Logger] = None) -> initial_size = self.initial_scale finetune_fraction = self.finetune_fraction elapsed_duration = state.get_elapsed_duration() - assert elapsed_duration is not None, "elapsed duration is set on Event.AFTER_DATALOADER" + assert elapsed_duration is not None, "elapsed duration should be set on Event.AFTER_DATALOADER" scale_frac_elapsed = min([elapsed_duration.value / (1 - finetune_fraction), 1]) # Linearly increase to full size at the start of the fine tuning period diff --git a/composer/algorithms/swa/swa.py b/composer/algorithms/swa/swa.py index a2e1e9782e..761c4594e4 100644 --- a/composer/algorithms/swa/swa.py +++ b/composer/algorithms/swa/swa.py @@ -172,8 +172,12 @@ def __init__(self, self.match_event = Event.EPOCH_END def match(self, event: Event, state: State) -> bool: + if event != self.match_event: + return False if self.swa_start.unit == TimeUnit.DURATION: - should_start_swa = state.get_elapsed_duration() >= self.swa_start and not self.swa_completed + elapsed_duration = state.get_elapsed_duration() + assert elapsed_duration is not None, "elapsed duration should be set on Event.BATCH_END or Event.EPOCH_END" + should_start_swa = elapsed_duration >= self.swa_start and not self.swa_completed elif self.swa_start.unit == TimeUnit.EPOCH: should_start_swa = state.timer.get("ep") >= self.swa_start and not self.swa_completed else: @@ -219,16 +223,19 @@ 
def apply(self, event: Event, state: State, logger: Logger) -> None: self.step_counter += 1 + elapsed_duration = state.get_elapsed_duration() + assert elapsed_duration is not None, "elapsed duration should be set on Event.BATCH_END or Event.EPOCH_END" + # Determine whether it's time to end SWA - if self.swa_end.unit == TimeUnit.DURATION and (state.get_elapsed_duration() >= self.swa_end): + if self.swa_end.unit == TimeUnit.DURATION and (elapsed_duration >= self.swa_end): self.swa_completed = True if self.swa_end.unit == TimeUnit.EPOCH and (state.timer.get("ep") >= self.swa_end): self.swa_completed = True if self.swa_completed: if state.get_elapsed_duration() == 1: - log.warning("The baseline model was replaced with the SWA model after the end of " - "training. This means that SWA model will not have its batch norm " - "statistics updated. This will negatively impact accuracy. See the " - "documentation for the `swa_end` parameter for details.") + log.warning(("The baseline model was replaced with the SWA model after the end of " + "training. This means that SWA model will not have its batch norm " + "statistics updated. This will negatively impact accuracy. 
See the " + "documentation for the `swa_end` parameter for details.")) state.model.load_state_dict(self.swa_model.module.state_dict()) # type: ignore log.info('Set model to the averaged model') diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index ff2f480a46..6e29724200 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -126,6 +126,7 @@ def epoch_start(self, state: State, logger: Logger): def batch_end(self, state: State, logger: Logger): self.batch_end_times.append(time.time()) new_num_samples = state.timer.sample + assert self.batch_start_num_samples is not None, "self.batch_start_num_samples should have been set on Event.BATCH_START" batch_num_samples = int(new_num_samples - self.batch_start_num_samples) self.batch_num_samples.append(batch_num_samples) self.train_examples_per_epoch += batch_num_samples diff --git a/composer/core/engine.py b/composer/core/engine.py index 33ec3173a0..35cee9ec72 100644 --- a/composer/core/engine.py +++ b/composer/core/engine.py @@ -83,6 +83,25 @@ Traces = Dict[str, "Trace"] _ALWAYS_RECORD_EVENTS = [Event.INIT, Event.FIT_START, Event.EPOCH_START, Event.EPOCH_END] +_EVENTS_WHERE_DATALOADER_IS_SET = [e for e in Event if e != Event.INIT] +_EVENTS_WHERE_MAX_DURATION_IS_SET = [ + Event.FIT_START, + Event.EPOCH_START, + Event.BATCH_START, + Event.AFTER_DATALOADER, + Event.BEFORE_TRAIN_BATCH, + Event.BEFORE_FORWARD, + Event.AFTER_FORWARD, + Event.BEFORE_LOSS, + Event.AFTER_LOSS, + Event.BEFORE_BACKWARD, + Event.AFTER_BACKWARD, + Event.AFTER_TRAIN_BATCH, + Event.BATCH_END, + Event.BATCH_CHECKPOINT, + Event.EPOCH_END, + Event.EPOCH_CHECKPOINT, +] @dataclass @@ -174,6 +193,12 @@ def run_event( if event.is_after_event and duration_marker is not None: duration_marker.finish() + if event in _EVENTS_WHERE_DATALOADER_IS_SET: + assert self.state.dataloader is not None, f"The trainer should have set state.dataloader for event {event}." 
+ + if event in _EVENTS_WHERE_MAX_DURATION_IS_SET: + assert self.state.max_duration is not None, f"The trainer should have set state.max_duration for event {event}." + if event == Event.INIT: # For the INIT event, run the callbacks first to initialize the loggers # For other events, run the algorithms first, so the callbacks have the state diff --git a/composer/core/time.py b/composer/core/time.py index c9fafe6517..ee40905d1e 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -237,7 +237,7 @@ def _parse(self, other: object) -> Time: raise TypeError(f"Cannot convert type {other} to {self.__class__.__name__}") - def _cmp(self, other: object) -> int: + def _cmp(self, other: Union[int, float, Time, str]) -> int: # When doing comparisions, and other is an integer (or float), we can safely infer # the unit from self.unit # E.g. calls like this should be allowed: if batch < 42: do_something() @@ -256,40 +256,40 @@ def _cmp(self, other: object) -> int: assert self.value > other.value return 1 - def __eq__(self, other: object): + def __eq__(self, other: Union[int, float, Time, str]): return self._cmp(other) == 0 - def __ne__(self, other: object): + def __ne__(self, other: Union[int, float, Time, str]): return self._cmp(other) != 0 - def __lt__(self, other: object): + def __lt__(self, other: Union[int, float, Time, str]): return self._cmp(other) < 0 - def __le__(self, other: object): + def __le__(self, other: Union[int, float, Time, str]): return self._cmp(other) <= 0 - def __gt__(self, other: object): + def __gt__(self, other: Union[int, float, Time, str]): return self._cmp(other) > 0 - def __ge__(self, other: object): + def __ge__(self, other: Union[int, float, Time, str]): return self._cmp(other) >= 0 - def __add__(self, other: object) -> Time[TValue]: + def __add__(self, other: Union[int, float, Time, str]) -> Time[TValue]: other = self._parse(other) if self.unit != other.unit: raise RuntimeError(f"Cannot add {self} to {other} since they have different 
units.") return Time(self.value + other.value, self.unit) - def __radd__(self, other: object) -> Time[TValue]: + def __radd__(self, other: Union[int, float, Time, str]) -> Time[TValue]: return self + other - def __sub__(self, other: object) -> Time[TValue]: + def __sub__(self, other: Union[int, float, Time, str]) -> Time[TValue]: other = self._parse(other) if self.unit != other.unit: raise RuntimeError(f"Cannot subtract {other} from {self} since they have different units.") return Time(self.value - other.value, self.unit) - def __rsub__(self, other: object) -> Time[TValue]: + def __rsub__(self, other: Union[int, float, Time, str]) -> Time[TValue]: return (-self) + other def __neg__(self) -> Time[TValue]: diff --git a/tests/trainer/test_scale_schedule.py b/tests/trainer/test_scale_schedule.py index 0d98624929..f460fb1f68 100644 --- a/tests/trainer/test_scale_schedule.py +++ b/tests/trainer/test_scale_schedule.py @@ -76,6 +76,28 @@ def test_scale_schedule_cosine_warm_restarts(self, optimizer: Optimizer, ssr: fl raise NotImplementedError +class CheckScaleSchedule(Callback): + + def __init__(self, ssr: float) -> None: + self.ssr = ssr + + def fit_start(self, state: State, logger: Logger) -> None: + scheduler = state.schedulers[0] + + test_steps = [int(20 * self.ssr), int(40 * self.ssr), int(60 * self.ssr)] + target_lrs = [1.0, 0.1, 0.01] + current_step = 0 + for test_step, target_lr in zip(test_steps, target_lrs): + + while current_step < test_step: + state.timer.on_batch_complete() + current_step += 1 + + scheduler.step() + + assert scheduler.get_last_lr()[0] == pytest.approx(target_lr) + + @pytest.mark.parametrize('ssr', [0.5, 0.75, 1.0]) @pytest.mark.parametrize('use_algorithm', [ False, @@ -100,26 +122,8 @@ def test_epochs_scaled( else: composer_trainer_hparams.scale_schedule_ratio = ssr - class CheckScaleSchedule(Callback): - - def fit_start(self, state: State, logger: Logger) -> None: - scheduler = state.schedulers[0] - - test_steps = [int(20 * ssr), int(40 * 
ssr), int(60 * ssr)] - target_lrs = [1.0, 0.1, 0.01] - current_step = 0 - for test_step, target_lr in zip(test_steps, target_lrs): - - while current_step < test_step: - trainer.state.timer.on_batch_complete() - current_step += 1 - - scheduler.step() - - assert scheduler.get_last_lr()[0] == pytest.approx(target_lr) - trainer = composer_trainer_hparams.initialize_object() - trainer.state.callbacks.append(CheckScaleSchedule()) + trainer.state.callbacks.append(CheckScaleSchedule(ssr)) assert trainer.state.max_duration is not None assert trainer.state.max_duration.unit == TimeUnit.EPOCH From 11544bcff9f2c50aaea831ba94a9fe4a4b402e67 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 14 Apr 2022 11:02:34 -0700 Subject: [PATCH 09/17] Fixed doctests --- docs/source/doctest_fixtures.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/doctest_fixtures.py b/docs/source/doctest_fixtures.py index 222ae50f7c..61785aa545 100644 --- a/docs/source/doctest_fixtures.py +++ b/docs/source/doctest_fixtures.py @@ -116,6 +116,7 @@ optimizers=optimizer, grad_accum=1, dataloader=train_dataloader, + dataloader_label="train", max_duration="1ep", precision="fp32", ) From 1313a4966dd9504e0e01e2e727fe7e5b15482a22 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 14 Apr 2022 11:34:53 -0700 Subject: [PATCH 10/17] Fixed selective backprop --- composer/algorithms/selective_backprop/selective_backprop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/algorithms/selective_backprop/selective_backprop.py b/composer/algorithms/selective_backprop/selective_backprop.py index 788194de6c..eb4b59a9cf 100644 --- a/composer/algorithms/selective_backprop/selective_backprop.py +++ b/composer/algorithms/selective_backprop/selective_backprop.py @@ -208,7 +208,7 @@ def match(self, event: Event, state: State) -> bool: assert elapsed_duration is not None, "elapsed duration should be set on Event.AFTER_DATALOADER" is_chosen = should_selective_backprop( - 
current_duration=float(), + current_duration=float(elapsed_duration), batch_idx=state.timer.batch_in_epoch.value, start=self.start, end=self.end, From b5b6192e2584f02e14bce1381f0f65fc270382d3 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 14 Apr 2022 12:41:40 -0700 Subject: [PATCH 11/17] Inceased timeout --- tests/profiler/test_json_trace_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiler/test_json_trace_handler.py b/tests/profiler/test_json_trace_handler.py index bfcc3178ce..56314bb921 100644 --- a/tests/profiler/test_json_trace_handler.py +++ b/tests/profiler/test_json_trace_handler.py @@ -11,7 +11,7 @@ # This test shouldn't run with the Torch profiler enabled, not providing a model or data can cause a seg fault -@pytest.mark.timeout(10) +@pytest.mark.timeout(30) def test_json_trace_profiler_handler(composer_trainer_hparams: TrainerHparams, tmpdir: pathlib.Path): profiler_file = os.path.join(tmpdir, 'trace.json') json_trace_handler_params = JSONTraceHparams(folder=str(tmpdir), merged_trace_filename='trace.json') From d3408ba890c57234de480da3d1c8b849d388329e Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Fri, 15 Apr 2022 08:03:45 -0700 Subject: [PATCH 12/17] Remove optimizers from state on init; clean up PR --- composer/core/state.py | 1 - composer/trainer/trainer.py | 12 +++++------- tests/algorithms/test_layer_freezing.py | 4 ++-- tests/fixtures/dummy_fixtures.py | 7 +++---- tests/fixtures/new_fixtures.py | 6 +++--- tests/trainer/test_ddp_sync_strategy.py | 17 ++++++++++------- tests/utils/classifer.py | 6 +++--- 7 files changed, 26 insertions(+), 27 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 474806fccc..9d6d15de05 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -195,7 +195,6 @@ def __init__( self.model = model self.grad_accum = grad_accum self.set_dataloader(dataloader, dataloader_label, dataloader_len) - self.dataloader_len = dataloader_len 
self.max_duration = max_duration self.timer = Timer() diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 10f777b58e..5ae112af00 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -457,7 +457,8 @@ def __init__( eval_dataloader: Optional[Union[DataLoader, DataSpec, Evaluator, Sequence[Evaluator]]] = None, algorithms: Optional[Union[Algorithm, Sequence[Algorithm]]] = None, optimizers: Optional[torch.optim.Optimizer] = None, - schedulers: Optional[Union[ComposerScheduler, Sequence[ComposerScheduler]]] = None, + schedulers: Optional[Union[ComposerScheduler, PyTorchScheduler, Sequence[Union[ComposerScheduler, + PyTorchScheduler]]]] = None, # device device: Optional[Union[str, Device]] = None, @@ -685,11 +686,8 @@ def __init__( grad_accum=grad_accum, precision=precision, precision_context=precision_context, - optimizers=optimizers, ) - self.train_subset_num_batches = train_subset_num_batches - pytorch_schedulers = [ scheduler for scheduler in ensure_tuple(schedulers) if isinstance(scheduler, PyTorchScheduler) ] @@ -800,9 +798,9 @@ def __init__( # After running Event.INIT, then set the "optional" elements of state that could be passed in on FIT instead of INIT # Setting these attributes here ensures that algorithms do not depend on unavailable attributes during Event.INIT - self.state.set_dataloader(train_dataloader.dataloader, 'train') - self.state.dataloader_len = self.train_subset_num_batches + self.state.set_dataloader(train_dataloader.dataloader, 'train', train_subset_num_batches) self.state.max_duration = max_duration + self.state.optimizers = optimizers assert isinstance(self.state.model, ComposerModel) self._original_model = self.state.model # TODO(ravi) -- update the state to add an original model helper @@ -988,7 +986,7 @@ def _train_loop(self) -> None: assert self.state.dataloader is not None, "dataloader is set in __init__" - for scheduler in ensure_tuple(self._schedulers): + for scheduler in 
self._schedulers: if isinstance(scheduler, PyTorchScheduler): scale_pytorch_scheduler(scheduler, self._scale_schedule_ratio) self.state.schedulers.append(scheduler) diff --git a/tests/algorithms/test_layer_freezing.py b/tests/algorithms/test_layer_freezing.py index 9deec8f033..77aba31c79 100644 --- a/tests/algorithms/test_layer_freezing.py +++ b/tests/algorithms/test_layer_freezing.py @@ -19,6 +19,8 @@ def _generate_state(epoch: int, max_epochs: int): rank_zero_seed=0, optimizers=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99), precision=Precision.FP32, + dataloader=Mock(__len__=lambda x: 100), + dataloader_label="train", grad_accum=1, max_duration=f'{max_epochs}ep') @@ -26,8 +28,6 @@ def _generate_state(epoch: int, max_epochs: int): for _ in range(epoch): state.timer.on_epoch_complete() - state.set_dataloader(Mock(__len__=lambda x: 100), "train") - return state diff --git a/tests/fixtures/dummy_fixtures.py b/tests/fixtures/dummy_fixtures.py index 7589e70560..a3023ef87d 100644 --- a/tests/fixtures/dummy_fixtures.py +++ b/tests/fixtures/dummy_fixtures.py @@ -215,16 +215,15 @@ def simple_conv_model_input(): @pytest.fixture() def state_with_model(simple_conv_model: torch.nn.Module, dummy_train_dataloader: DataLoader, rank_zero_seed: int): - state = State( + return State( grad_accum=1, rank_zero_seed=rank_zero_seed, max_duration="100ep", model=simple_conv_model, precision=Precision.FP32, + dataloader=dummy_train_dataloader, + dataloader_label="train", ) - state.set_dataloader(dummy_train_dataloader, "train") - - return state @pytest.fixture() diff --git a/tests/fixtures/new_fixtures.py b/tests/fixtures/new_fixtures.py index 7d8f8d9681..5c460ed72d 100644 --- a/tests/fixtures/new_fixtures.py +++ b/tests/fixtures/new_fixtures.py @@ -15,13 +15,13 @@ def minimal_state(rank_zero_seed: int): Tests should configure the state for their specific needs. 
""" - state = State( + return State( model=SimpleModel(), rank_zero_seed=rank_zero_seed, max_duration='100ep', + dataloader=DataLoader(RandomClassificationDataset()), + dataloader_label="train", ) - state.set_dataloader(DataLoader(RandomClassificationDataset()), "train") - return state @pytest.fixture diff --git a/tests/trainer/test_ddp_sync_strategy.py b/tests/trainer/test_ddp_sync_strategy.py index ad88534477..594444ce49 100644 --- a/tests/trainer/test_ddp_sync_strategy.py +++ b/tests/trainer/test_ddp_sync_strategy.py @@ -51,13 +51,16 @@ def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional original_model = MinimalConditionalModel() # ddp = DDP(backend="gloo", find_unused_parameters=True, sync_strategy=ddp_sync_strategy, timeout=5.) optimizer = torch.optim.SGD(original_model.parameters(), 0.1) - state = State(model=original_model, - rank_zero_seed=rank_zero_seed, - optimizers=optimizer, - grad_accum=2, - max_duration="1ep", - precision='fp32') - state.set_dataloader(dummy_train_dataloader, "train") + state = State( + model=original_model, + rank_zero_seed=rank_zero_seed, + optimizers=optimizer, + grad_accum=2, + max_duration="1ep", + dataloader=dummy_train_dataloader, + dataloader_label="train", + precision='fp32', + ) batches = [[(1, Tensor([1])), (1, Tensor([2]))], [(2, Tensor([1])), (2, Tensor([2]))]] state.model = _prepare_ddp_module(state.model, find_unused_parameters=True) diff --git a/tests/utils/classifer.py b/tests/utils/classifer.py index a4f085a6a5..e32fef0ea4 100644 --- a/tests/utils/classifer.py +++ b/tests/utils/classifer.py @@ -14,16 +14,16 @@ def _get_state(train_dataloader: DataLoader, eval_dataloader: DataLoader, rank_zero_seed: int): model = SimpleModel() - state = State( + return State( model=model, rank_zero_seed=rank_zero_seed, optimizers=SGD(model.parameters(), lr=.001, momentum=0.0), max_duration="1ep", grad_accum=1, + dataloader=train_dataloader, + dataloader_label="train", precision=Precision.FP32, ) - 
state.set_dataloader(train_dataloader, "train") - return state def test_classifier_trains( From 1470d11fd8e45a49053087ca20775d9606463cee Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Fri, 15 Apr 2022 08:35:44 -0700 Subject: [PATCH 13/17] Bind the schedulers to the state in `__init__()`, rather than on `fit()`. Preferable to keep variables on the state object rather than as trainer members, where appropriate. Before, the state.schedulers was empty after `__init__()` but before `fit()`. Now, state.schedulers contains the compiled composer schedulers or original pytorch schedulers. Restored optimizers on Event.INIT; it will be a bigger issue to rewrite algs to not depend on optimizers on init. --- composer/trainer/trainer.py | 23 ++++++++++------------- tests/callbacks/test_speed_monitor.py | 5 +++-- tests/loggers/test_progress_bar_logger.py | 4 +++- tests/test_notebooks.py | 4 ++-- tests/trainer/test_checkpoint.py | 5 +++-- 5 files changed, 21 insertions(+), 20 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 5ae112af00..5a53468e31 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -686,6 +686,7 @@ def __init__( grad_accum=grad_accum, precision=precision, precision_context=precision_context, + optimizers=optimizers, ) pytorch_schedulers = [ @@ -719,11 +720,8 @@ def __init__( step_schedulers_every_batch = True self._step_schedulers_every_batch = step_schedulers_every_batch - self._scale_schedule_ratio = scale_schedule_ratio - self._schedulers = ensure_tuple(schedulers) - - if len(self._schedulers) == 0: + if len(ensure_tuple(schedulers)) == 0: warnings.warn(f"NoSchedulerWarning: No schedulers were specified. 
The learning rate will be constant.") # Configure profilers if profiling is enabled @@ -800,7 +798,6 @@ def __init__( # Setting these attributes here ensures that algorithms do not depend on unavailable attributes during Event.INIT self.state.set_dataloader(train_dataloader.dataloader, 'train', train_subset_num_batches) self.state.max_duration = max_duration - self.state.optimizers = optimizers assert isinstance(self.state.model, ComposerModel) self._original_model = self.state.model # TODO(ravi) -- update the state to add an original model helper @@ -830,6 +827,14 @@ def __init__( if "optimizers" in self.state.serialized_attributes: self.state.serialized_attributes.remove("optimizers") + # Compile and bind the schedulers, after deepspeed might have potentially changed the Optimizers + for scheduler in ensure_tuple(schedulers): + if isinstance(scheduler, PyTorchScheduler): + scale_pytorch_scheduler(scheduler, scale_schedule_ratio) + self.state.schedulers.append(scheduler) + else: # it's a composer scheduler + self.state.schedulers.append(compile_composer_scheduler(scheduler, self.state, scale_schedule_ratio)) + # If using DeepSpeed, the model must be loaded from checkpoint after the engine has been # initialized, but if using PyTorch DDP, the model must be loaded before it is wrapped with # DDP. 
@@ -986,14 +991,6 @@ def _train_loop(self) -> None: assert self.state.dataloader is not None, "dataloader is set in __init__" - for scheduler in self._schedulers: - if isinstance(scheduler, PyTorchScheduler): - scale_pytorch_scheduler(scheduler, self._scale_schedule_ratio) - self.state.schedulers.append(scheduler) - else: # it's a composer scheduler - self.state.schedulers.append( - compile_composer_scheduler(scheduler, self.state, self._scale_schedule_ratio)) - self.engine.run_event(Event.FIT_START) self.state.scaler = ClosureGradScaler() if self._use_closures() else GradScaler() diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py index f19f707b4e..8c72cf49f8 100644 --- a/tests/callbacks/test_speed_monitor.py +++ b/tests/callbacks/test_speed_monitor.py @@ -36,8 +36,9 @@ def test_speed_monitor(composer_trainer_hparams: TrainerHparams): wall_clock_train_calls += 1 assert isinstance(trainer.state.dataloader, collections.abc.Sized) - assert trainer.train_subset_num_batches is not None - expected_step_calls = (trainer.train_subset_num_batches - speed_monitor_hparams.window_size) * max_epochs + assert trainer.state.dataloader_label is not None + assert trainer.state.dataloader_len is not None + expected_step_calls = (trainer.state.dataloader_len - speed_monitor_hparams.window_size) * max_epochs assert throughput_step_calls == expected_step_calls assert throughput_epoch_calls == max_epochs assert wall_clock_train_calls == max_epochs diff --git a/tests/loggers/test_progress_bar_logger.py b/tests/loggers/test_progress_bar_logger.py index 0191bc0b1f..ebdaf8c30a 100644 --- a/tests/loggers/test_progress_bar_logger.py +++ b/tests/loggers/test_progress_bar_logger.py @@ -41,7 +41,9 @@ def get_mock_tqdm(position: int, *args: object, **kwargs: object): assert composer_trainer_hparams.validate_every_n_batches < 0 assert len(is_train_to_mock_tqdms[False]) == composer_trainer_hparams.validate_every_n_epochs * max_epochs for mock_tqdm in 
is_train_to_mock_tqdms[True]: - assert mock_tqdm.update.call_count == trainer.train_subset_num_batches + assert trainer.state.dataloader_len is not None + assert trainer.state.dataloader_label == "train" + assert mock_tqdm.update.call_count == int(trainer.state.dataloader_len) mock_tqdm.close.assert_called_once() for mock_tqdm in is_train_to_mock_tqdms[False]: assert mock_tqdm.update.call_count == trainer.eval_subset_num_batches diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index 484d421454..c0f89c90c0 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -56,10 +56,10 @@ def test_notebook(tb): tb.inject(""" from composer.core import Time trainer.state.max_duration = Time.from_timestring('2ep') - trainer.state.train_subset_num_batches = 2 + trainer.state.dataloader_len = 2 """) trainer = tb.ref("trainer") - assert trainer.state.train_subset_num_batches == 2 + assert trainer.state.dataloader_len == 2 except Exception as e: raise Exception( textwrap.dedent(""" diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index c8e84507d2..8834ed81c2 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -413,11 +413,12 @@ def _test_checkpoint_trainer(trainer_hparams: TrainerHparams): def _validate_events_called_expected_number_of_times(trainer: Trainer): state = trainer.state - assert trainer.train_subset_num_batches is not None + assert state.dataloader_label == "train" + assert state.dataloader_len is not None assert state.max_duration is not None assert state.max_duration.unit == TimeUnit.EPOCH num_epochs = state.max_duration.value - num_total_steps = num_epochs * trainer.train_subset_num_batches + num_total_steps = num_epochs * int(state.dataloader_len) num_total_microbatches = num_total_steps * state.grad_accum num_evals = 0 if trainer._validate_every_n_batches > 0: From 54d6b8898a7ef702bb5469f9bbfdc24141923f95 Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Wed, 20 Apr 2022 
16:32:24 -0700 Subject: [PATCH 14/17] Fixed the deepspeed schedulers --- composer/trainer/trainer.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 3aa30d40fd..d190e6733f 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -870,6 +870,14 @@ def __init__( assert isinstance(self.state.model, ComposerModel) self._original_model = self.state.model # TODO(ravi) -- update the state to add an original model helper + # Compile and bind the schedulers + for scheduler in ensure_tuple(schedulers): + if isinstance(scheduler, PyTorchScheduler): + scale_pytorch_scheduler(scheduler, scale_schedule_ratio) + self.state.schedulers.append(scheduler) + else: # it's a composer scheduler + self.state.schedulers.append(compile_composer_scheduler(scheduler, self.state, scale_schedule_ratio)) + # place the state, model in the proper devices, and initialize from a checkpoint if provided if self.deepspeed_enabled: try: @@ -887,7 +895,12 @@ def __init__( model=self.state.model, optimizer=optimizer, ) - # The deepspeed engine is responsible for serializing the model and optimizer state, + # Since the DeepSpeed ZeRO optimizer does not inherit torch.optim.Optimizer, the schedulers must be + # compiled and bound BEFORE DeepSpeed initialization. However, this is OK, as the the DeepSpeed Zero + # optimizer uses the same underlying parameter groups as the original optimizer. See + # * https://github.com/microsoft/DeepSpeed/blob/fee73135980e78f8be7e1a3ff556751623ef6aaa/deepspeed/runtime/zero/stage_1_and_2.py#L1911-L1917 + # * https://github.com/microsoft/DeepSpeed/blob/ef17c89570ceae5b26a5f886e9d8cd0941afc0ac/deepspeed/runtime/zero/stage3.py#L2532-L2538 + # In addition, the deepspeed engine is responsible for serializing the model and optimizer state, # so these attributes should not be serialized with the composer state. 
if "model" in self.state.serialized_attributes: self.state.serialized_attributes.remove("model") @@ -895,14 +908,6 @@ def __init__( if "optimizers" in self.state.serialized_attributes: self.state.serialized_attributes.remove("optimizers") - # Compile and bind the schedulers, after deepspeed might have potentially changed the Optimizers - for scheduler in ensure_tuple(schedulers): - if isinstance(scheduler, PyTorchScheduler): - scale_pytorch_scheduler(scheduler, scale_schedule_ratio) - self.state.schedulers.append(scheduler) - else: # it's a composer scheduler - self.state.schedulers.append(compile_composer_scheduler(scheduler, self.state, scale_schedule_ratio)) - # If using DeepSpeed, the model must be loaded from checkpoint after the engine has been # initialized, but if using PyTorch DDP, the model must be loaded before it is wrapped with # DDP. From 49e68e844656877694c46f0e515febccba02fe7d Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Thu, 21 Apr 2022 10:17:26 -0700 Subject: [PATCH 15/17] * Addressed PR Feedback * Removed `precision_context` from state * Switched `train_subset_num_batches` and `eval_subset_num_batches` to use `-1` as the default value instead of `None`. 
--- .../selective_backprop/selective_backprop.py | 5 +- .../seq_length_warmup/seq_length_warmup.py | 3 +- composer/core/precision.py | 37 ++++++++- composer/core/state.py | 75 ++++++++----------- composer/trainer/devices/device.py | 33 +------- composer/trainer/devices/device_cpu.py | 14 +--- composer/trainer/devices/device_gpu.py | 25 +------ composer/trainer/trainer.py | 37 ++++----- composer/trainer/trainer_hparams.py | 8 +- 9 files changed, 100 insertions(+), 137 deletions(-) diff --git a/composer/algorithms/selective_backprop/selective_backprop.py b/composer/algorithms/selective_backprop/selective_backprop.py index 5d13b34b75..fce013f74d 100644 --- a/composer/algorithms/selective_backprop/selective_backprop.py +++ b/composer/algorithms/selective_backprop/selective_backprop.py @@ -12,6 +12,7 @@ from torch.nn import functional as F from composer.core import Algorithm, Event, State +from composer.core.precision import get_precision_context from composer.loggers import Logger from composer.models import ComposerModel @@ -93,7 +94,7 @@ def select_using_loss(input: torch.Tensor, This function runs an extra forward pass through the model on the batch of data. If you are using a non-default precision, ensure that this forward pass runs in your desired precision. For example: - + .. 
testsetup:: N_sb, D_sb = 16, 8 @@ -253,6 +254,6 @@ def loss(p, y, reduction="none"): assert self._loss_fn is not None, "loss_fn should be set on Event.INIT" return self._loss_fn(p, (torch.Tensor(), y), reduction=reduction) - with state.precision_context: + with get_precision_context(state.precision): new_input, new_target = select_using_loss(input, target, model, loss, self.keep, self.scale_factor) state.batch = (new_input, new_target) diff --git a/composer/algorithms/seq_length_warmup/seq_length_warmup.py b/composer/algorithms/seq_length_warmup/seq_length_warmup.py index b826f302b1..6232feadd7 100644 --- a/composer/algorithms/seq_length_warmup/seq_length_warmup.py +++ b/composer/algorithms/seq_length_warmup/seq_length_warmup.py @@ -9,6 +9,7 @@ import torch from composer.core import Algorithm, Event, State +from composer.core.precision import get_precision_context from composer.core.time import TimeUnit from composer.core.types import Batch from composer.loggers import Logger @@ -223,7 +224,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: # start by running a forward and backward pass # of the maximum sequence length to allocate cache. 
- with state.precision_context: + with get_precision_context(state.precision): outputs = state.model.forward(model_inputs) loss = self._original_model.loss(outputs, model_inputs) diff --git a/composer/core/precision.py b/composer/core/precision.py index db11f98fe9..575937955d 100644 --- a/composer/core/precision.py +++ b/composer/core/precision.py @@ -2,9 +2,15 @@ """Enum class for the numerical precision to be used by the model.""" +import contextlib +from typing import Generator, Union + +import torch +from packaging import version + from composer.utils.string_enum import StringEnum -__all__ = ["Precision"] +__all__ = ["Precision", "get_precision_context"] class Precision(StringEnum): @@ -23,3 +29,32 @@ class Precision(StringEnum): FP16 = "fp16" FP32 = "fp32" BF16 = "bf16" + + +@contextlib.contextmanager +def get_precision_context(precision: Union[str, Precision]) -> Generator[None, None, None]: + """Returns a context manager to automatically cast to a specific precision. + + Args: + precision (str or Precision): Precision for the context + """ + + precision = Precision(precision) + enabled = False + if precision == Precision.FP32: + if not torch.cuda.is_available(): + # Yield here to avoid warnings about cuda not being available + yield + return + enabled = False + elif precision == Precision.AMP: + enabled = True + elif precision == Precision.BF16: + if version.parse(torch.__version__) < version.parse("1.10"): + raise ValueError(f"BF16 precision requires torch > 1.10, got version {torch.__version__}") + with torch.cuda.amp.autocast(True, torch.bfloat16): # type: ignore + yield + # Retain compatibility with PyTorch < 1.10 + if precision != Precision.BF16: + with torch.cuda.amp.autocast(enabled): # type: ignore + yield diff --git a/composer/core/state.py b/composer/core/state.py index 327d92fc03..dad8eed675 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -3,10 +3,9 @@ """The state of the trainer.""" from __future__ import annotations -import 
contextlib import logging import warnings -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Optional, Sequence, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast import torch import torch.nn.modules.utils @@ -31,24 +30,6 @@ logger = logging.getLogger(__name__) -def _default_precision_factory() -> Callable[[Union[str, Precision]], ContextManager]: - """Returns a context manager to automatically cast to a specific precision. - - Args: - precision (str or Precision): Precision for the context - """ - if torch.cuda.is_available(): - return lambda precision: torch.cuda.amp.autocast(Precision(precision) == Precision.AMP) - else: - - def null(precision): - assert Precision( - precision) != Precision.AMP, "Precision AMP is only available when `torch.cuda.is_available() == True`." - return contextlib.nullcontext() - - return null - - def _ensure_backwards_compatible_checkpointing(state_dict: Dict[str, Any]): # v0.4.1 removed the leading underscores for the keys in the state_dict # It also renamed _is_model_ddp_wrapped to is_model_ddp @@ -97,7 +78,7 @@ class State(Serializable): each device becomes ``microbatch_size = train_batch_size / (num_devices * grad_accum)``. dataloader (types.DataLoader, optional): The active DataLoader. dataloader_len (int | Time[int], optional): The number of batches per dataloader iteration (e.g. epoch). - The trainer will yield the first ``dataloader_len`` batches per iteration. If ``None`` (the default), + The trainer will yield the first ``dataloader_len`` batches per iteration. If ``-1`` (the default), the entire dataloader will be iterated over. dataloader_label (str, optional): The name for the dataloader. Required if ``dataloader`` is specified. (default: ``None``) By convention, the training dataloader is called ``'train'``. 
The evaluator dataloader is called @@ -105,7 +86,6 @@ class State(Serializable): max_duration (str | Time, optional): The maximum duration to train for. (default: ``None``) precision (str | Precision): The numerical precision to use for training. See :class:`~.Precision` for the supported precisions. - precision_context (Callable[[Precision], ContextManager]): Function to produce a context manager to mandate precision. optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional): The optimizer being used to train the model. Multiple optimizers are not currently supported. schedulers (types.PyTorchScheduler | Sequence[types.PyTorchScheduler], optional): @@ -214,11 +194,10 @@ def __init__( grad_accum: int = 1, dataloader: Optional[types.DataLoader] = None, dataloader_label: Optional[str] = None, - dataloader_len: Optional[Union[int, Time[int]]] = None, + dataloader_len: Union[int, Time[int]] = -1, # precision precision: Union[str, Precision] = Precision.FP32, - precision_context: Callable[[Precision], ContextManager] = _default_precision_factory(), # optimizers optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None, @@ -238,7 +217,6 @@ def __init__( self.timer = Timer() self._precision = Precision(precision) - self._precision_context = precision_context if optimizers is None: self._optimizers = [] @@ -318,7 +296,7 @@ def schedulers(self): return self._schedulers @schedulers.setter - def schedulers(self, schedulers: types.PyTorchScheduler): + def schedulers(self, schedulers: Union[types.PyTorchScheduler, Sequence[types.PyTorchScheduler]]): self._schedulers[:] = ensure_tuple(schedulers) @property @@ -436,7 +414,7 @@ def set_dataloader( self, dataloader: Optional[types.DataLoader] = None, dataloader_label: Optional[str] = None, - dataloader_len: Optional[Union[int, Time[int]]] = None, + dataloader_len: Union[int, Time[int]] = -1, ): """Update the dataloader and dataloader label. 
@@ -444,18 +422,28 @@ def set_dataloader( dataloader (types.DataLoader, optional): The dataloader. Defaults to None. dataloader_label (str, optional): The dataloader label. Must be ``None`` if and only if ``dataloader`` is None. Defaults to None. - dataloader_len (int, int): + dataloader_len (int, int): The number of batches per dataloader iteration (e.g. epoch), as used by the trainer. + Set to ``-1`` to iterate over the entire dataset. (Default: ``-1``.) """ - if (dataloader == None) != (dataloader_label == None): - raise ValueError("Both `dataloader` and `dataloader_label` should be None, or neither should be None.") + if dataloader is None: + dataloader_label = None + else: + if dataloader_label is None: + raise ValueError("If the `dataloader` is specified, then `dataloader_label` must not be None.") self._dataloader = dataloader self._dataloader_label = dataloader_label - self.dataloader_len = dataloader_len # setting it to None will do a failsafe read of len(dataloader) + if dataloader is not None: + self.dataloader_len = dataloader_len # setting it to -1 will do a failsafe read of len(dataloader) @property def dataloader_len(self): """The number of batches per dataloader iteration (e.g. epoch), as used by the trainer. + .. note:: + + If not explicitely specified, this value is an approximation, as it depends on ``len(self.dataloader)``. + See the :doc:`PyTorch DataLoader Documentation ` for more information. + Returns: Optional[Time[int]]: The number of batches per dataloader iteration (e.g. epoch), or None if no dataloader is defined or if the dataloader has an unknown length (e.g. streaming dataloaders). 
@@ -463,13 +451,10 @@ def dataloader_len(self): return self._dataloader_len @dataloader_len.setter - def dataloader_len(self, num_batches: Optional[Union[int, Time[int]]]): + def dataloader_len(self, num_batches: Union[int, Time[int]]): if isinstance(num_batches, int): num_batches = Time(num_batches, TimeUnit.BATCH) if self._dataloader is None: - if num_batches is None: - self._dataloader_len = None - return raise RuntimeError("`State.dataloader_len` cannot be set if the dataloader is not defined.") try: dataloader_len = len(self._dataloader) @@ -479,12 +464,16 @@ def dataloader_len(self, num_batches: Optional[Union[int, Time[int]]]): warnings.warn((f"DataloaderNumBatchesWarning: The dataloader_len ({int(num_batches)}) " f"is greater than the length (i.e. number of batches) of the dataloader, which is " f"{dataloader_len}. State.dataloader_len is thus being set to {dataloader_len}.")) - num_batches = Time(dataloader_len, TimeUnit.BATCH) - if num_batches is None and dataloader_len is not None: - # len(dataloader) is an approximation -- see https://pytorch.org/docs/stable/data.html. - # However, in the worst case where additional last batches are dropped, this calculation should be - # an over-estimate, leading to the entire dataloader still being iterated over. - num_batches = Time(dataloader_len, TimeUnit.BATCH) + self._dataloader_len = Time(dataloader_len, TimeUnit.BATCH) + if num_batches == -1: + if dataloader_len is not None: + # len(dataloader) is an approximation -- see https://pytorch.org/docs/stable/data.html. + # However, in the worst case where additional last batches are dropped, this calculation should be + # an over-estimate, leading to the entire dataloader still being iterated over. + self._dataloader_len = Time(dataloader_len, TimeUnit.BATCH) + else: + # The dataloader length is unknown. 
+ self._dataloader_len = None self._dataloader_len = num_batches @property @@ -519,10 +508,6 @@ def batch_dict(self) -> types.BatchDict: from composer.core.types import as_batch_dict return as_batch_dict(self.batch) - @property - def precision_context(self): - return self._precision_context(self.precision) - @property def is_model_deepspeed(self) -> bool: """Whether :attr:`model` is an instance of a :class:`~deepspeed.DeepSpeedEngine`.""" diff --git a/composer/trainer/devices/device.py b/composer/trainer/devices/device.py index 7bb581df99..a3cc14212c 100644 --- a/composer/trainer/devices/device.py +++ b/composer/trainer/devices/device.py @@ -4,14 +4,12 @@ from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence -from contextlib import contextmanager -from typing import Any, Callable, Generator, TypeVar, Union +from typing import Any, Callable, TypeVar import torch import torch.nn from torch.optim import Optimizer -from composer.core.precision import Precision from composer.core.serializable import Serializable __all__ = ["Device", "T_nnModule"] @@ -90,35 +88,6 @@ def optimizer_to_device(self, optimizer: Optimizer) -> Optimizer: state[k] = self.tensor_to_device(v) return optimizer - @abstractmethod - @contextmanager - def precision_context(self, precision: Union[str, Precision]) -> Generator[None, None, None]: - """Precision returns a context manager that uses the specified precision. - - Example usage: - - .. doctest:: - - >>> from composer.core.precision import Precision - >>> from composer.trainer.devices import DeviceCPU - >>> - >>> device = DeviceCPU() - >>> for batch in train_dataloader: - ... with device.precision_context(Precision.FP32): - ... outputs = model.forward(batch) - ... - ... with device.precision_context(Precision.FP32): - ... loss = model.loss(outputs, batch) - >>> - - Args: - precision (Precision): The desired precision for the device. - - Yields: - Generator[None, None, None]: A context for the precision. 
- """ - pass - def _map_batch(batch: Any, map_fn: Callable) -> Any: """Recursively maps a function to all items in a batch. diff --git a/composer/trainer/devices/device_cpu.py b/composer/trainer/devices/device_cpu.py index 3665bd678e..539d3cb96b 100644 --- a/composer/trainer/devices/device_cpu.py +++ b/composer/trainer/devices/device_cpu.py @@ -5,13 +5,11 @@ from __future__ import annotations import logging -from contextlib import contextmanager -from typing import Any, Dict, Generator, TypeVar, Union +from typing import Any, Dict, TypeVar import torch -from composer.core import Precision -from composer.trainer.devices.device import Device, T_nnModule +from composer.trainer.devices.device import Device logger = logging.getLogger(__name__) @@ -35,14 +33,6 @@ def module_to_device(self, module: T_nnModule) -> T_nnModule: def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor: return tensor.to(self._device) - @contextmanager - def precision_context(self, precision: Union[str, Precision]) -> Generator[None, None, None]: - precision = Precision(precision) - if precision == Precision.FP32: - yield - else: - raise ValueError(f"Precision {precision} not supported for a CPU") - def state_dict(self) -> Dict[str, Any]: # CPU device has no RNG state return {} diff --git a/composer/trainer/devices/device_gpu.py b/composer/trainer/devices/device_gpu.py index 15a6e1ab9b..81e7ffdf7d 100644 --- a/composer/trainer/devices/device_gpu.py +++ b/composer/trainer/devices/device_gpu.py @@ -4,16 +4,13 @@ from __future__ import annotations -from contextlib import contextmanager -from typing import Any, Dict, Generator, TypeVar, Union +from typing import Any, Dict, TypeVar import torch import torch.cuda.amp import torch.utils.data -from packaging import version -from composer.core.precision import Precision -from composer.trainer.devices.device import Device, T_nnModule +from composer.trainer.devices.device import Device from composer.utils import dist __all__ = ["DeviceGPU"] @@ 
-40,24 +37,6 @@ def module_to_device(self, module: T_nnModule) -> T_nnModule: def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor: return tensor.to(self._device, non_blocking=True) - @contextmanager - def precision_context(self, precision: Union[str, Precision]) -> Generator[None, None, None]: - precision = Precision(precision) - enabled = False - if precision == Precision.FP32: - enabled = False - elif precision == Precision.AMP: - enabled = True - elif precision == Precision.BF16: - if version.parse(torch.__version__) < version.parse("1.10"): - raise ValueError(f"BF16 precision requires torch > 1.10, got version {torch.__version__}") - with torch.cuda.amp.autocast(True, torch.bfloat16): # type: ignore - yield - # Retain compatibility with PyTorch < 1.10 - if precision != Precision.BF16: - with torch.cuda.amp.autocast(enabled): # type: ignore - yield - def state_dict(self) -> Dict[str, Any]: return { "rng": torch.cuda.get_rng_state(), diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index d190e6733f..236257c920 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -85,6 +85,7 @@ from composer.algorithms import ScaleSchedule from composer.callbacks import CheckpointSaver from composer.core import Algorithm, Callback, DataSpec, Engine, Evaluator, Event, Precision, State, Time, Timestamp +from composer.core.precision import get_precision_context from composer.core.types import Batch, BreakEpochException, DataLoader, PyTorchScheduler from composer.datasets.dataloader import unwrap_data_loader from composer.loggers import Logger, LoggerDestination, LogLevel, ProgressBarLogger @@ -115,7 +116,7 @@ class Trainer: with Composer. .. seealso:: :mod:`composer.models` for models built into Composer. 
- train_dataloader (DataLoader, DataSpec, or dict): The :class:`.DataLoader`, :class:`.DataSpec`, + train_dataloader (DataLoader, DataSpec, or dict, optional): The :class:`.DataLoader`, :class:`.DataSpec`, or dict of :class:`.DataSpec` kwargs for the training data. In order to specify custom preprocessing steps on each data batch, specify a :class:`.DataSpec` instead of a :class:`.DataLoader`. @@ -171,8 +172,12 @@ class Trainer: .. note:: ``fp16`` only works if ``deepspeed_config`` is also provided. scale_schedule_ratio (float, optional): Ratio by which to scale the training duration and learning rate - schedules. E.g., ``0.5`` makes the schedule take half as many epochs and ``2.0`` makes it take twice as - many epochs. ``1.0`` means no change. (default: ``1.0``) + schedules. (default: ``1.0``) + + E.g., ``0.5`` makes the schedule take half as many epochs and ``2.0`` makes it take twice as + many epochs. ``1.0`` means no change. + + This parameter has no effect if ``schedulers`` is not specified. .. note :: @@ -399,11 +404,13 @@ class Trainer: artifact stores. train_subset_num_batches (int, optional): If specified, finish every epoch early after training on this many batches. This parameter has no effect if it is greater than ``len(train_dataloader)``. - If ``None``, then the entire dataloader will be iterated over. (default: ``None``) + If ``-1``, then the entire dataloader will be iterated over. (default: ``-1``) + + This parameter is ignored if ``train_dataloader`` is not specified. eval_subset_num_batches (int, optional): If specified, evaluate on this many batches per evaluation dataloader. This parameter has no effect if it is greater than ``len(eval_dataloader)``. - If ``None``, then the entire dataloader will be iterated over. (default: ``None``) + If ``-1``, then the entire dataloader will be iterated over. 
(default: ``-1``) deepspeed_config (bool or Dict[str, Any], optional): Configuration for DeepSpeed, formatted as a JSON according to `DeepSpeed's documentation `_. If ``True`` is provided, the trainer will initialize the DeepSpeed engine with an empty config ``{}``. If ``False`` @@ -543,8 +550,8 @@ def __init__( save_num_checkpoints_to_keep: int = -1, # subset parameters - train_subset_num_batches: Optional[int] = None, - eval_subset_num_batches: Optional[int] = None, + train_subset_num_batches: int = -1, + eval_subset_num_batches: int = -1, # DeepSpeed deepspeed_config: Union[bool, Dict[str, Any]] = False, @@ -684,10 +691,6 @@ def __init__( To fix, please do not iterate over the dataloader before passing it into the trainer.""")) - # TODO(#123): DeepSpeed still needs a precision context, but it's not completely clear how to - # handle this with our version of Pytorch - precision_context = self._device.precision_context if not self.deepspeed_enabled else cast( - Callable[..., ContextManager], contextlib.nullcontext) if isinstance(precision, str): precision = Precision(precision) @@ -723,7 +726,6 @@ def __init__( callbacks=callbacks, grad_accum=grad_accum, precision=precision, - precision_context=precision_context, optimizers=optimizers, ) @@ -1330,7 +1332,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int, tot # forward pass self.engine.run_event(Event.BEFORE_FORWARD) - with self.state.precision_context: + with get_precision_context(self.state.precision): self.state.outputs = self.state.model(self.state.batch) self.engine.run_event(Event.AFTER_FORWARD) @@ -1338,7 +1340,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int, tot # loss self.engine.run_event(Event.BEFORE_LOSS) - with self.state.precision_context: + with get_precision_context(self.state.precision): self.state.loss = self._original_model.loss(self.state.outputs, self.state.batch) # We always want to scale loss by the grad_accum before the 
backwards pass and @@ -1396,9 +1398,9 @@ def eval(self, log_level: LogLevel = LogLevel.FIT): with torch.no_grad(): for evaluator in self.evaluators: - self.state.set_dataloader(evaluator.dataloader.dataloader, evaluator.label) + self.state.set_dataloader(evaluator.dataloader.dataloader, evaluator.label, + self.eval_subset_num_batches) assert self.state.dataloader is not None, "dataloader is set" - self.state.dataloader_len = self.eval_subset_num_batches self.engine.run_event(Event.EVAL_START) @@ -1452,7 +1454,8 @@ def eval(self, log_level: LogLevel = LogLevel.FIT): self.state.model.train() self.state.set_dataloader(original_dataloader, original_dataloader_label) - self.state.dataloader_len = original_num_batches + if original_num_batches is not None: + self.state.dataloader_len = original_num_batches def _use_grad_scaling(self, precision: Union[str, Precision], scaler: Optional[GradScaler]) -> bool: """Determines based on precision when to use grad scaling. diff --git a/composer/trainer/trainer_hparams.py b/composer/trainer/trainer_hparams.py index 96ec5314bc..06e3fb3034 100755 --- a/composer/trainer/trainer_hparams.py +++ b/composer/trainer/trainer_hparams.py @@ -378,10 +378,10 @@ class TrainerHparams(hp.Hparams): ) # subset parameters - train_subset_num_batches: Optional[int] = hp.optional( - "If specified, finish every epoch early after training on this many batches.", default=None) - eval_subset_num_batches: Optional[int] = hp.optional("If specified, stop each evaluation after this many batches.", - default=None) + train_subset_num_batches: int = hp.optional( + "If specified, finish every epoch early after training on this many batches.", default=-1) + eval_subset_num_batches: int = hp.optional("If specified, stop each evaluation after this many batches.", + default=-1) # DeepSpeed deepspeed: Optional[Dict[str, JSON]] = hp.optional(doc="Configuration for DeepSpeed.", default=None) From da5bd40370c4f6c3bfa92da9daf9191aceb4af3d Mon Sep 17 00:00:00 2001 From: Ravi 
Rahman Date: Thu, 21 Apr 2022 12:15:50 -0700 Subject: [PATCH 16/17] Fixing the `dataloader_len` setter --- composer/core/state.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index dad8eed675..1aac1aea06 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -460,12 +460,13 @@ def dataloader_len(self, num_batches: Union[int, Time[int]]): dataloader_len = len(self._dataloader) except (TypeError, NotImplementedError): dataloader_len = None - if dataloader_len is not None and num_batches is not None and int(num_batches) > dataloader_len: + if dataloader_len is not None and num_batches >= 0 and int(num_batches) > dataloader_len: warnings.warn((f"DataloaderNumBatchesWarning: The dataloader_len ({int(num_batches)}) " f"is greater than the length (i.e. number of batches) of the dataloader, which is " f"{dataloader_len}. State.dataloader_len is thus being set to {dataloader_len}.")) self._dataloader_len = Time(dataloader_len, TimeUnit.BATCH) - if num_batches == -1: + return + if num_batches < 0: if dataloader_len is not None: # len(dataloader) is an approximation -- see https://pytorch.org/docs/stable/data.html. # However, in the worst case where additional last batches are dropped, this calculation should be @@ -474,6 +475,7 @@ def dataloader_len(self, num_batches: Union[int, Time[int]]): else: # The dataloader length is unknown. 
self._dataloader_len = None + return self._dataloader_len = num_batches @property From 42ab447c87dcf4bcb6ad5c179a5ae77574787dff Mon Sep 17 00:00:00 2001 From: Ravi Rahman Date: Fri, 22 Apr 2022 08:27:35 -0700 Subject: [PATCH 17/17] Fix tests --- composer/core/precision.py | 20 ++++++++++---------- composer/core/state.py | 1 + tests/fixtures/new_fixtures.py | 7 ++++++- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/composer/core/precision.py b/composer/core/precision.py index 575937955d..5d124a15ed 100644 --- a/composer/core/precision.py +++ b/composer/core/precision.py @@ -40,21 +40,21 @@ def get_precision_context(precision: Union[str, Precision]) -> Generator[None, N """ precision = Precision(precision) - enabled = False if precision == Precision.FP32: - if not torch.cuda.is_available(): + if torch.cuda.is_available(): + with torch.cuda.amp.autocast(False): + yield + else: # Yield here to avoid warnings about cuda not being available yield - return - enabled = False elif precision == Precision.AMP: - enabled = True + # Retain compatibility with PyTorch < 1.10 + with torch.cuda.amp.autocast(True): + yield elif precision == Precision.BF16: if version.parse(torch.__version__) < version.parse("1.10"): raise ValueError(f"BF16 precision requires torch > 1.10, got version {torch.__version__}") - with torch.cuda.amp.autocast(True, torch.bfloat16): # type: ignore + with torch.cuda.amp.autocast(True, torch.bfloat16): yield - # Retain compatibility with PyTorch < 1.10 - if precision != Precision.BF16: - with torch.cuda.amp.autocast(enabled): # type: ignore - yield + else: + raise ValueError(f"Unsupported precision: {precision}") diff --git a/composer/core/state.py b/composer/core/state.py index 1aac1aea06..d8bb8bc0b2 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -212,6 +212,7 @@ def __init__( self.rank_zero_seed = rank_zero_seed self.model = model self.grad_accum = grad_accum + self._dataloader_len = None self.set_dataloader(dataloader, 
dataloader_label, dataloader_len) self.max_duration = max_duration diff --git a/tests/fixtures/new_fixtures.py b/tests/fixtures/new_fixtures.py index 7df0e24756..9f01f93657 100644 --- a/tests/fixtures/new_fixtures.py +++ b/tests/fixtures/new_fixtures.py @@ -41,8 +41,13 @@ def disable_wandb(monkeypatch: pytest.MonkeyPatch): @pytest.fixture(autouse=True) def configure_dist(request: pytest.FixtureRequest): - # Configure dist globally, so individual tests that do not use the trainer + # Configure dist globally when the world size is greater than 1, + # so individual tests that do not use the trainer # do not need to worry about manually configuring dist. + + if dist.get_world_size() == 1: + return + backend = 'gloo' if request.node.get_closest_marker('gpu') is None else 'nccl' if not dist.is_initialized(): dist.initialize_dist(backend, timeout=datetime.timedelta(seconds=300))