From 85c82ca28a141a1ec16b38a5e01b5843325aa61c Mon Sep 17 00:00:00 2001 From: ravi-mosaicml Date: Mon, 31 Jan 2022 17:33:16 -0500 Subject: [PATCH] Optimizer Surgery (#249) * Added surgery * Fixed most tests * Fixed more tests * Fixes * Fixed tests * PR Cleanup * Fixed sorting * Formatting * Cleaned up PR * Adressed most pr feedback * new method for post-device change optimizer surgery * fix len bug * fix failing test * fix test_load * minor docstring update * exclude deeplab Co-authored-by: Ravi Rahman Co-authored-by: hanlint Co-authored-by: root --- composer/algorithms/alibi/alibi.py | 33 +- composer/algorithms/blurpool/blurpool.py | 14 +- composer/algorithms/factorize/factorize.py | 25 +- .../ghost_batchnorm/ghost_batchnorm.py | 17 +- .../squeeze_excite/squeeze_excite.py | 16 +- .../stochastic_depth/stochastic_depth.py | 13 +- composer/core/engine.py | 7 +- composer/core/state.py | 23 +- composer/core/surgery.py | 286 +++++++++++++++--- composer/optim/pytorch_future.py | 4 +- composer/optim/scheduler.py | 36 ++- composer/trainer/trainer.py | 181 ++++++----- tests/algorithms/test_blurpool_algorithm.py | 43 ++- tests/algorithms/test_scale_schedule.py | 2 +- tests/fixtures/dummy_fixtures.py | 15 +- tests/fixtures/models.py | 13 +- tests/test_load.py | 5 + tests/test_surgery.py | 89 +++++- tests/trainer/test_scheduler.py | 67 ++-- tests/trainer/test_trainer.py | 13 +- 20 files changed, 655 insertions(+), 247 deletions(-) diff --git a/composer/algorithms/alibi/alibi.py b/composer/algorithms/alibi/alibi.py index 4608fed7b6..db424e2a31 100644 --- a/composer/algorithms/alibi/alibi.py +++ b/composer/algorithms/alibi/alibi.py @@ -8,13 +8,14 @@ from dataclasses import asdict, dataclass from operator import attrgetter from types import MethodType, ModuleType -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Type, Union, cast import torch import yahp as hp from composer.algorithms import AlgorithmHparams from composer.core import Algorithm, Event, Logger, State, surgery +from composer.core.types import Optimizers log = logging.getLogger(__name__) @@ -56,9 +57,17 @@ def initialize_object(self) -> "Alibi": return Alibi(**asdict(self)) -def apply_alibi(model: torch.nn.Module, heads_per_layer: int, max_sequence_length: int, - position_embedding_attribute: str, attention_module: torch.nn.Module, attr_to_replace: str, - alibi_attention: Callable, mask_replacement_function: Union[Callable, None]) -> None: +def apply_alibi( + model: torch.nn.Module, + heads_per_layer: int, + max_sequence_length: int, + position_embedding_attribute: str, + attention_module: Type[torch.nn.Module], + attr_to_replace: str, + alibi_attention: Callable, + mask_replacement_function: Union[Callable, None], + optimizers: Optional[Optimizers] = None, +) -> None: """ Removes position embeddings and replaces the attention function and attention mask according to `AliBi `_. @@ -85,6 +94,14 @@ def apply_alibi(model: torch.nn.Module, heads_per_layer: int, max_sequence_lengt attention mask. This is sometimes necessary for evaluating on sequence lengths longer than the model was initialized to accommodate. + optimizers (Optimizers, optional): Existing optimizers bound to ``model.parameters()``. + All optimizers that have already been constructed with, + ``model.parameters()`` must be specified here so they will optimize + the correct parameters. + + If the optimizer(s) are constructed *after* calling this function, + then it is safe to omit this parameter. 
These optimizers will see the correct + model parameters. """ zero_and_freeze_expand_position_embeddings(model=model, @@ -100,8 +117,9 @@ def convert_attention(module: torch.nn.Module, module_index: Optional[int] = Non module = mask_replacement_function(module, max_sequence_length) return module - transforms = {attention_module: convert_attention} - replaced_pairs = surgery.replace_module_classes(model, transforms) # type: ignore + replaced_pairs = surgery.replace_module_classes(model, + optimizers=optimizers, + policies={attention_module: convert_attention}) count = len(replaced_pairs) log.info(f" {count} instances of ALiBi added") @@ -183,7 +201,8 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: apply_alibi( state.model, - heads_per_layer=self.hparams.heads_per_layer, # type: ignore + optimizers=state.optimizers, + heads_per_layer=cast(int, self.hparams.heads_per_layer), max_sequence_length=self.hparams.max_sequence_length, position_embedding_attribute=self.hparams.position_embedding_attribute, attr_to_replace=self.hparams.attr_to_replace, diff --git a/composer/algorithms/blurpool/blurpool.py b/composer/algorithms/blurpool/blurpool.py index 422b0538d3..3ba87c9f38 100644 --- a/composer/algorithms/blurpool/blurpool.py +++ b/composer/algorithms/blurpool/blurpool.py @@ -14,6 +14,7 @@ from composer.algorithms import AlgorithmHparams from composer.algorithms.blurpool.blurpool_layers import BlurConv2d, BlurMaxPool2d from composer.core import Algorithm, Event, Logger, State, surgery +from composer.core.types import Optimizers log = logging.getLogger(__name__) @@ -27,6 +28,7 @@ def _log_surgery_result(model: torch.nn.Module): def apply_blurpool(model: torch.nn.Module, + optimizers: Optional[Optimizers] = None, replace_convs: bool = True, replace_maxpools: bool = True, blur_first: bool = True) -> None: @@ -38,6 +40,14 @@ def apply_blurpool(model: torch.nn.Module, Args: model: model to modify + optimizers (Optimizers, optional): Existing optimizers bound to ``model.parameters()``. + All optimizers that have already been constructed with, + ``model.parameters()`` must be specified here so they will optimize + the correct parameters. + + If the optimizer(s) are constructed *after* calling this function, + then it is safe to omit this parameter. These optimizers will see the correct + model parameters. 
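A minimal usage sketch for this new ``optimizers`` argument (illustrative only; ``SimpleConvModel`` is the small conv model from the repo's test fixtures, and the call pattern mirrors the blurpool test added later in this patch)::

    import torch

    from composer.algorithms.blurpool import apply_blurpool
    from tests.fixtures.models import SimpleConvModel

    # Option 1: run surgery first, then build the optimizer -- no `optimizers` needed.
    model = SimpleConvModel()
    apply_blurpool(model)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # Option 2: the optimizer already exists, so pass it in; surgery swaps the
    # replaced modules' parameters inside the optimizer's param groups.
    model2 = SimpleConvModel()
    opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
    apply_blurpool(model2, optimizers=opt2)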
replace_convs: replace strided :class:`torch.nn.Conv2d` modules with :class:`BlurConv2d` modules replace_maxpools: replace eligible :class:`torch.nn.MaxPool2d` modules @@ -57,7 +67,7 @@ def apply_blurpool(model: torch.nn.Module, _maybe_replace_strided_conv2d, blur_first=blur_first, ) - surgery.replace_module_classes(model, policies=transforms) + surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms) _log_surgery_result(model) @@ -133,7 +143,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: """ assert state.model is not None - apply_blurpool(state.model, **asdict(self.hparams)) + apply_blurpool(state.model, optimizers=state.optimizers, **asdict(self.hparams)) self._log_results(event, state, logger) def _log_results(self, event: Event, state: State, logger: Logger) -> None: diff --git a/composer/algorithms/factorize/factorize.py b/composer/algorithms/factorize/factorize.py index 2d46b6a683..b947459ff7 100644 --- a/composer/algorithms/factorize/factorize.py +++ b/composer/algorithms/factorize/factorize.py @@ -13,6 +13,7 @@ from composer.algorithms.factorize.factorize_modules import (FactorizedConv2d, FactorizedLinear, factorizing_could_speedup) from composer.core import Algorithm, Event, Logger, State, surgery +from composer.core.types import Optimizers log = logging.getLogger(__name__) @@ -26,7 +27,10 @@ def _python_log_surgery_result(model: torch.nn.Module, new_class: Type[torch.nn. f'Model now has {num_replaced_modules} {new_class.__name__} modules') -def factorize_conv2d_modules(model: torch.nn.Module, min_channels: int, latent_channels: Union[int, float]): +def factorize_conv2d_modules(model: torch.nn.Module, + min_channels: int, + latent_channels: Union[int, float], + optimizers: Optional[Optimizers] = None): """Replaces :class:`torch.nn.Conv2d` modules in ``model`` with :class:`~composer.algorithms.factorize.FactorizedConv2d` modules. See :class:`Factorize` for details.""" def _maybe_replace_conv2d(module: torch.nn.Module, module_index: int) -> Optional[torch.nn.Module]: @@ -36,12 +40,17 @@ def _maybe_replace_conv2d(module: torch.nn.Module, module_index: int) -> Optiona return FactorizedConv2d.from_conv2d(module, module_index, latent_channels=latent_channels) return None # not enough rank reduction to be worth it - ret = surgery.replace_module_classes(model, {torch.nn.Conv2d: _maybe_replace_conv2d}) + ret = surgery.replace_module_classes(model, + optimizers=optimizers, + policies={torch.nn.Conv2d: _maybe_replace_conv2d}) _python_log_surgery_result(model, FactorizedConv2d) return ret -def factorize_linear_modules(model: torch.nn.Module, min_features: int, latent_features: Union[int, float]): +def factorize_linear_modules(model: torch.nn.Module, + min_features: int, + latent_features: Union[int, float], + optimizers: Optional[Optimizers] = None): """Replaces :class:`torch.nn.Linear` modules in ``model`` with :class:`~composer.algorithms.factorize.FactorizedLinear` modules. 
See :class:`Factorize` for details.""" def _maybe_replace_linear(module: torch.nn.Module, module_index: int) -> Optional[torch.nn.Module]: @@ -51,7 +60,9 @@ def _maybe_replace_linear(module: torch.nn.Module, module_index: int) -> Optiona return FactorizedLinear.from_linear(module, module_index, latent_features=latent_features) return None # not enough rank reduction to be worth it - ret = surgery.replace_module_classes(model, {torch.nn.Linear: _maybe_replace_linear}) + ret = surgery.replace_module_classes(model, + optimizers=optimizers, + policies={torch.nn.Linear: _maybe_replace_linear}) _python_log_surgery_result(model, FactorizedLinear) return ret @@ -175,7 +186,8 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: if self.hparams.factorize_convs: factorize_conv2d_modules(state.model, min_channels=self.hparams.min_channels, - latent_channels=self.hparams.latent_channels) + latent_channels=self.hparams.latent_channels, + optimizers=state.optimizers) num_factorized = surgery.count_module_instances(state.model, FactorizedConv2d) logger.metric_fit({ LOG_NUM_CONV2D_REPLACEMENTS_KEY: num_factorized, @@ -183,7 +195,8 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: if self.hparams.factorize_linears: factorize_linear_modules(state.model, min_features=self.hparams.min_features, - latent_features=self.hparams.latent_features) + latent_features=self.hparams.latent_features, + optimizers=state.optimizers) num_factorized = surgery.count_module_instances(state.model, FactorizedLinear) logger.metric_fit({ LOG_NUM_LINEAR_REPLACEMENTS_KEY: num_factorized, diff --git a/composer/algorithms/ghost_batchnorm/ghost_batchnorm.py b/composer/algorithms/ghost_batchnorm/ghost_batchnorm.py index 26b759e013..c4a7ec9eac 100644 --- a/composer/algorithms/ghost_batchnorm/ghost_batchnorm.py +++ b/composer/algorithms/ghost_batchnorm/ghost_batchnorm.py @@ -12,6 +12,7 @@ from composer.algorithms import AlgorithmHparams from composer.core import Algorithm, Event, Logger, State, surgery +from composer.core.types import Optimizers log = logging.getLogger(__name__) @@ -99,7 +100,9 @@ class GhostBatchNorm3d(_GhostBatchNorm): pass -def apply_ghost_batchnorm(model: torch.nn.Module, ghost_batch_size: int) -> torch.nn.Module: +def apply_ghost_batchnorm(model: torch.nn.Module, + ghost_batch_size: int, + optimizers: Optional[Optimizers] = None) -> torch.nn.Module: """Replace batch normalization modules with ghost batch normalization modules. Must be run before the model has been moved to accelerators and before @@ -108,6 +111,14 @@ def apply_ghost_batchnorm(model: torch.nn.Module, ghost_batch_size: int) -> torc Args: model: model to transform ghost_batch_size: size of sub-batches to normalize over + optimizers (Optimizers, optional): Existing optimizers bound to ``model.parameters()``. + All optimizers that have already been constructed with, + ``model.parameters()`` must be specified here so they will optimize + the correct parameters. + + If the optimizer(s) are constructed *after* calling this function, + then it is safe to omit this parameter. These optimizers will see the correct + model parameters. """ def maybe_replace(module: torch.nn.Module, module_index: int) -> Optional[torch.nn.Module]: @@ -117,7 +128,7 @@ def maybe_replace(module: torch.nn.Module, module_index: int) -> Optional[torch. 
# we have to specify class names explicitly because replace_module_classes # now checks if `module.__class__ == cls`, rather than `isinstance(module, cls)` transforms = {cls: maybe_replace for cls in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]} - surgery.replace_module_classes(model, policies=transforms) + surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms) return model @@ -162,7 +173,7 @@ def apply(self, event: Event, state: State, logger: Optional[Logger] = None) -> """ assert state.model is not None, "Model must be in state" - apply_ghost_batchnorm(model=state.model, ghost_batch_size=self.ghost_batch_size) + apply_ghost_batchnorm(model=state.model, optimizers=state.optimizers, ghost_batch_size=self.ghost_batch_size) self._log_results(event, state, logger) def _log_results(self, event: Event, state: State, logger: Optional[Logger] = None) -> None: diff --git a/composer/algorithms/squeeze_excite/squeeze_excite.py b/composer/algorithms/squeeze_excite/squeeze_excite.py index cd71fa513c..0788b1d250 100644 --- a/composer/algorithms/squeeze_excite/squeeze_excite.py +++ b/composer/algorithms/squeeze_excite/squeeze_excite.py @@ -11,6 +11,7 @@ from composer.algorithms.algorithm_hparams import AlgorithmHparams from composer.core import Algorithm, Event, Logger, State, surgery +from composer.core.types import Optimizers log = logging.getLogger(__name__) @@ -81,16 +82,22 @@ def from_conv2d(module: torch.nn.Conv2d, module_index: int, latent_channels: flo return SqueezeExciteConv2d(conv=module, latent_channels=latent_channels) -def apply_se(model: torch.nn.Module, latent_channels: float, min_channels: int): +def apply_se( + model: torch.nn.Module, + latent_channels: float, + min_channels: int, + optimizers: Optional[Optimizers] = None, +): """See :class:`SqueezeExcite`""" - def convert_module(module: torch.nn.Conv2d, module_index: int): + def convert_module(module: torch.nn.Module, module_index: int): + assert isinstance(module, torch.nn.Conv2d), "should only be called with conv2d" if min(module.in_channels, module.out_channels) < min_channels: return None return SqueezeExciteConv2d.from_conv2d(module, module_index, latent_channels=latent_channels) - transforms = {torch.nn.Conv2d: convert_module} - surgery.replace_module_classes(model, transforms) # type: ignore + surgery.replace_module_classes(model, optimizers=optimizers, policies={torch.nn.Conv2d: convert_module}) + return model @@ -142,6 +149,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: logger (Logger): the training logger """ state.model = apply_se(state.model, + optimizers=state.optimizers, latent_channels=self.hparams.latent_channels, min_channels=self.hparams.min_channels) layer_count = surgery.count_module_instances(state.model, SqueezeExciteConv2d) diff --git a/composer/algorithms/stochastic_depth/stochastic_depth.py b/composer/algorithms/stochastic_depth/stochastic_depth.py index 3cc5857f60..951d59df4e 100644 --- a/composer/algorithms/stochastic_depth/stochastic_depth.py +++ b/composer/algorithms/stochastic_depth/stochastic_depth.py @@ -14,6 +14,7 @@ from composer.algorithms.stochastic_depth.sample_stochastic_layers import SampleStochasticBottleneck from composer.algorithms.stochastic_depth.stochastic_layers import StochasticBottleneck from composer.core import Algorithm, Event, Logger, State, surgery +from composer.core.types import Optimizers from composer.models.resnets import Bottleneck log = logging.getLogger(__name__) @@ -92,6 +93,7 @@ def 
validate(self): def apply_stochastic_depth(model: torch.nn.Module, stochastic_method: str, target_layer_name: str, + optimizers: Optional[Optimizers] = None, drop_rate: float = 0.2, drop_distribution: str = 'linear', use_same_gpu_seed: bool = True) -> None: @@ -113,6 +115,14 @@ def apply_stochastic_depth(model: torch.nn.Module, equivalent. The name must be registered in ``STOCHASTIC_LAYER_MAPPING`` dictionary with the target layer class and the stochastic layer class. Currently, only ``'ResNetBottleneck'`` is supported. + optimizers (Optimizers, optional): Existing optimizers bound to ``model.parameters()``. + All optimizers that have already been constructed with, + ``model.parameters()`` must be specified here so they will optimize + the correct parameters. + + If the optimizer(s) are constructed *after* calling this function, + then it is safe to omit this parameter. These optimizers will see the correct + model parameters. drop_rate: The base probability of dropping a layer or sample. Must be between 0.0 and 1.0. drop_distribution: How ``drop_rate`` is distributed across @@ -145,7 +155,7 @@ def apply_stochastic_depth(model: torch.nn.Module, raise ValueError(f"stochastic_method {stochastic_method} is not supported." f" Must be one of {list(STOCHASTIC_LAYER_MAPPING.keys())}") transforms[target_layer] = stochastic_from_target_layer - surgery.replace_module_classes(model, policies=transforms) + surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms) def _update_drop_rate(module: torch.nn.Module, stochastic_block: Type[torch.nn.Module], drop_rate: float, @@ -258,6 +268,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: log.warning(f'No {self.hparams.target_layer_name} found in model! Algorithm will function as a no-op.') apply_stochastic_depth(state.model, + optimizers=state.optimizers, stochastic_method=self.hparams.stochastic_method, target_layer_name=self.hparams.target_layer_name, drop_rate=self.hparams.drop_rate, diff --git a/composer/core/engine.py b/composer/core/engine.py index 487e5e99af..10e7e94bc5 100755 --- a/composer/core/engine.py +++ b/composer/core/engine.py @@ -185,10 +185,13 @@ def _compile( Returns: algorithms_to_run(Sequence[Algorithm]): modified sequence of algorithms """ - from composer.algorithms import SelectiveBackprop + from composer.algorithms import SelectiveBackprop, StochasticDepth # Move selective backprop to the beginning while maintaining order of other algorithms - algorithms = sorted(algorithms_to_run, key=lambda x: not isinstance(x, SelectiveBackprop)) + algorithms = sorted(algorithms_to_run, + key=lambda x: not isinstance(x, SelectiveBackprop) and not isinstance(x, StochasticDepth)) + + print(event, algorithms) if event.is_after_event: """Establish a FILO queue of algorithms before_ and after_ an event. diff --git a/composer/core/state.py b/composer/core/state.py index 4b954e3c8f..28a4c636ef 100755 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -86,8 +86,8 @@ class State(Serializable): precision (str | Precision): The numerical precision to use for training. Should be one of ``[fp32, amp]``. precision_context ((precision: Precision) -> ContextManager): Function to produce a context manager to mandate precision. - optimizers (types.Optimizers): The optimizers being used to train the model. Multiple optimizers are not currently supported. - schedulers (types.Schedulers): The learning rate schedulers, typically wrapped in :class:`ComposableScheduler`. 
+ optimizers (types.Optimizers, optional): The optimizers being used to train the model. Multiple optimizers are not currently supported. + schedulers (types.Schedulers, optional): The learning rate schedulers, typically wrapped in :class:`ComposableScheduler`. scaler (torch.cuda.amp.GradScaler, optional): The gradient scaler in use for mixed precision training. algorithms (Sequence[Algorithm]): The algorithms used for training. @@ -107,6 +107,7 @@ class State(Serializable): """ _max_duration: Time[int] + _steps_per_epoch: Optional[int] batch: types.Batch batch_num_samples: int batch_num_tokens: int @@ -140,16 +141,19 @@ def __init__( # algorithms and callbacks algorithms: Sequence[Algorithm] = tuple(), callbacks: Sequence[Callback] = tuple(), + + # steps per epoch + steps_per_epoch: Optional[int] = None, ): self.model = model self.grad_accum = grad_accum self.train_dataloader = train_dataloader self.eval_dataloader = eval_dataloader self.max_duration = max_duration + self.steps_per_epoch = steps_per_epoch self.timer = Timer() self._precision = Precision(precision) - self._steps_per_epoch = None self._precision_context = precision_context if optimizers is None: @@ -355,8 +359,17 @@ def steps_per_epoch(self): return self._steps_per_epoch @steps_per_epoch.setter - def steps_per_epoch(self, val: Optional[int]): - self._steps_per_epoch = val + def steps_per_epoch(self, steps_per_epoch: Optional[int]): + try: + dataloader_len = len(self.train_dataloader) + except (TypeError, NotImplementedError): + dataloader_len = None + if dataloader_len is not None and steps_per_epoch is not None and steps_per_epoch > dataloader_len: + warnings.warn( + textwrap.dedent(f"""SubsetNumBatchesWarning: The steps_per_epoch({steps_per_epoch}) + is greater than the number of batches in the training dataloader + ({dataloader_len})""")) + self._steps_per_epoch = steps_per_epoch @property def precision(self): diff --git a/composer/core/surgery.py b/composer/core/surgery.py index 536b1a3b3c..7d3a0e3a77 100644 --- a/composer/core/surgery.py +++ b/composer/core/surgery.py @@ -1,7 +1,12 @@ # Copyright 2021 MosaicML. All Rights Reserved. +import collections +import itertools import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +import textwrap +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, OrderedDict, Tuple, Type + +from composer.utils.iter_helpers import ensure_tuple try: from typing import Protocol @@ -12,6 +17,9 @@ from typing import Protocol import torch +import torch.distributed + +from composer.core.types import Optimizers log = logging.getLogger(__name__) @@ -35,15 +43,29 @@ def __call__(self, module: torch.nn.Module, module_index: int) -> Optional[torch ... 
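To make the protocol concrete, here is a small self-contained sketch (not code from this patch) of a replacement policy together with the new ``optimizers`` plumbing::

    import torch

    from composer.core import surgery

    # A policy matching the ReplacementFunction protocol above: given
    # (module, module_index), return a replacement module, or None to skip.
    def drop_bias(module: torch.nn.Module, module_index: int):
        assert isinstance(module, torch.nn.Linear), "policy is only registered for Linear"
        if module.bias is None:
            return None  # nothing to change for this instance
        replacement = torch.nn.Linear(module.in_features, module.out_features, bias=False)
        replacement.weight = module.weight  # keep the existing weight Parameter
        return replacement

    model = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # Because `optimizers` is passed, the now-unused bias Parameters are also
    # removed from the optimizer's param group (via update_params_in_optimizer below).
    replaced = surgery.replace_module_classes(model,
                                              policies={torch.nn.Linear: drop_bias},
                                              optimizers=opt)
    assert len(replaced) == 2  # {old Linear: new Linear} for both Linear layers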
+def _add_children_recursive( + module: torch.nn.Module, + children_to_parents_and_names: OrderedDict[torch.nn.Module, List[Tuple[torch.nn.Module, str]]], +) -> None: + # recursively build up children_to_parents_and_names so it maps a module to the list of + # (parent_module, attribute name) + for name, child in module.named_children(): + if child not in children_to_parents_and_names: + children_to_parents_and_names[child] = [] + _add_children_recursive(child, children_to_parents_and_names) + children_to_parents_and_names[child].append((module, name)) + + # adapted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_module.py#L408 def replace_module_classes( - model: torch.nn.Module, - policies: Dict[Any, ReplacementFunction], + module: torch.nn.Module, + policies: Mapping[Type[torch.nn.Module], ReplacementFunction], + optimizers: Optional[Optimizers] = None, recurse_on_replacements: bool = False, indices: Optional[Dict[Any, int]] = None, -) -> List[Tuple[torch.nn.Module, torch.nn.Module]]: +) -> Dict[torch.nn.Module, torch.nn.Module]: """Modify model in-place by recursively applying replacement policies. Replacement policies are a mapping - of source classes and `ReplacementFunction`. + of source classes and :class:`ReplacementFunction`. Examples: The following policy:: @@ -59,61 +81,85 @@ def replace_module_classes( Arguments: - module: Model to modify. - policies: Mapping of source class to replacement function. The - replacement may be either another module or `None`. If the latter, + module (torch.nn.Module): Model to modify. + policies (Mapping[torch.nn.Module, ReplacementFunction]): Mapping of source class to replacement function. The + replacement may be either another module or ``None``. If the latter, this replacement is skipped. - recurse_on_replacements: If true, policies will be applied to any module returned - by another policy. E.g., if one replaces a `Conv2d` with a module containing - another `Conv2d`, this new child `Conv2d` might also be replaced. This can recurse + recurse_on_replacements (bool): If true, policies will be applied to any module returned + by another policy. E.g., if one replaces a ``Conv2d`` with a module containing + another ``Conv2d``, this new child ``Conv2d`` might also be replaced. This can recurse infinitely if the replacement policies are not conditioned on module properties that change over the course of the recursion. - indices: A dictionary mapping module types to the number of times + indices (Dict[Any, int], optional): A dictionary mapping module types to the number of times they've occurred so far in the recursive traversal of - `model` and its child modules. Allows us to pass `module_index` + ``module`` and its child modules. Allows us to pass ``module_index`` to the replacement policies, so that a policy may switch behavior on the i-th instance of the module_class. Note that these indices may not correspond to the order in which modules get called in the forward pass. + optimizers (Optimizers, optional): One or more :class:`~torch.optim.Optimizer` objects. If provided, + this function will attempt to remove parameters in replaced modules + from these optimizers, and add parameters from the newly-created + modules. See :func:`update_params_in_optimizer` for more information. Returns: - replaced_pairs: a list of pairs of - (original module, replacement module), reflecting the replacements - applied to `module` and its children. 
+ Dict[torch.nn.Module, torch.nn.Module]: + A dictionary of ``{original_module: replacement_module}`` + reflecting the replacements applied to ``module`` and its children. + """ - replaced_pairs = [] + if isinstance(module, torch.nn.parallel.DistributedDataParallel): + raise TypeError( + textwrap.dedent("""Surgery is not supported after a module is wrapped with + `torch.nn.parallel.DistributedDataParallel` Instead, please preform surgery on the underlying + `module.module` and re-wrap the `module.module` with `torch.nn.parallel.DistributedDataParallel`""")) + try: + import deepspeed + except ImportError: + pass + else: + if isinstance(module, deepspeed.DeepSpeedEngine): + raise TypeError( + textwrap.dedent("""Surgery is not supported after a module is wrapped with + `deepspeed.DeepSpeedEngine` Instead, please perform surgery on the underlying module`, + and re-wrap it with `deepspeed.DeepSpeedEngine`""")) + replaced_pairs = {} + children_to_parents_and_names: OrderedDict[torch.nn.Module, List[Tuple[torch.nn.Module, + str]]] = collections.OrderedDict() + _add_children_recursive(module, children_to_parents_and_names) indices = indices if indices is not None else {c: 0 for c in policies} - for name, child in model.named_children(): - already_recursed = False - child_class = child.__class__ - if child_class in policies: - module_index = indices[child_class] - replacement = policies[child_class]( + while len(children_to_parents_and_names) > 0: + child, parents = children_to_parents_and_names.popitem(last=False) + for policy_class, replacement_fn in policies.items(): + if not isinstance(child, policy_class): + continue + module_index = indices[policy_class] + replacement = replacement_fn( child, module_index=module_index, ) - indices[child_class] += 1 + indices[policy_class] += 1 if replacement is not None: - replaced_pairs.append((child, replacement)) + assert child not in replaced_pairs + replaced_pairs[child] = replacement + + for parent, name in parents: + # update each parent with the replaced child + setattr(parent, name, replacement) + # recurse on new child object if recurse_on_replacements: - # recurse on new child object - replaced_pairs += replace_module_classes( - replacement, - policies, - recurse_on_replacements=recurse_on_replacements, - indices=indices, - ) - already_recursed = True - setattr(model, name, replacement) - - if not already_recursed: - replaced_pairs += replace_module_classes( - child, - policies, - recurse_on_replacements=recurse_on_replacements, - indices=indices, - ) + children_to_parents_and_names[replacement] = list(parents) # copy the parents list + _add_children_recursive(replacement, children_to_parents_and_names) + if optimizers: + for old_module, new_module in replaced_pairs.items(): + update_params_in_optimizer(old_params=old_module.parameters(), + new_params=new_module.parameters(), + optimizers=optimizers) + elif len(replaced_pairs) > 0: + log.info( + f"optimizers was not provided. Be sure to either create the optimizer after invoking this method, or manually add new parameters to the existing optimizer." + ) return replaced_pairs @@ -142,3 +188,159 @@ def count_module_instances(model: torch.nn.Module, module_class: Type[torch.nn.M count += count_module_instances(child, module_class) return count + + +def _tensor_in(tensor: torch.Tensor, iterable: Iterable[torch.Tensor]): + """Returns whether `tensor is element` for any element in `iterable` + This function is necessary because `tensor in iterable` does not work + reliably for `Tensor`s. 
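For reviewers, a quick standalone illustration of the pitfall that motivates ``_tensor_in`` (not part of the patch)::

    import torch

    a = torch.zeros(3)
    params = [torch.zeros(3)]  # equal values, but a different tensor object

    # `a in params` falls back to element-wise __eq__ and then bool(), which raises
    # "Boolean value of Tensor with more than one element is ambiguous"; for
    # one-element tensors it can silently return True for a *different* object.
    # Checking identity explicitly avoids both failure modes:
    assert not any(a is p for p in params)
    assert any(params[0] is p for p in params)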
+ See https://discuss.pytorch.org/t/how-to-judge-a-tensor-is-in-a-list/15998/4 + for further discussion. + """ + return any(tensor is elem for elem in iterable) + + +def _find_param_in_optimizer(param: torch.nn.parameter.Parameter, optimizer: torch.optim.Optimizer) -> int: + """Returns the index of the optimizer ``param_group`` containing ``param`` + Optimizers store their parameters within an iterable of ``dict``s called + :attr:`~torch.optim.Optimizer.param_groups`. + By default, there is only one group in :attr:`~torch.optim.Optimizer.param_groups` + that containing all the parameters, but there can be more than one. This + function is a simple utility to identify which parameter group in + :attr:`~torch.optim.Optimizer.param_groups` contains a given parameter, if any. The information + might be desirable to, e.g., inspect the optimizer settings being used + for a given parameter, or to remove unused parameter tensors from + the optimizer. + + Args: + param (torch.nn.parameter.Parameter): The parameter to search for. + optimizer (torch.optim.Optimizer): The optimizer to search within. + + Returns: + int: The index within `opt.param_groups` of the first group containing ``param``, + or `-1` if ``param`` is not in the ``opt`. + """ + for i, group in enumerate(optimizer.param_groups): + param_list: List[torch.nn.parameter.Parameter] = group['params'] + if _tensor_in(param, param_list): + return i + + return -1 + + +def update_params_in_optimizer(old_params: Iterable[torch.nn.parameter.Parameter], + new_params: Iterable[torch.nn.parameter.Parameter], optimizers: Optimizers) -> None: + """Removes old parameters from an optimizer and adds in new parameters + Parameters found in `old_params` but not `new_params` will be removed + from the optimizers. Similarly, parameters found in `new_params` but not + `old_params` will be added to the optimizer. Newly added parameters will + be added to the same optimizer `param_group` as the removed parameters + on a best-effort basis. If different removed parameters for a given + module are in different `param_group`s a RuntimeError will be thrown. + Dynamically removing parameters from an optimizer and adding parameters + to an existing `param_group` are not officially supported, so this + function may fail when PyTorch is updated. The recommended practice is + to instead recreate the optimizer when the parameter set changes if + possible. See `recommended practice `_. + To simply add new parameters without replacing existing ones, use + :meth:`~torch.optim.Optimizer.add_param_group`. + Args: + old_params: Parameters in this iterable should be removed if they are + not present in `new_params`. + new_params: Parameters in this iterable should be added if they are + not present in `old_params`. 
+ optimizers (Optimizers): One or more `torch.optim.Optimizer` objects + Raises: + NotImplementedError: If `optimizers` contains more than one optimizer + RuntimeError: If not all removed parameters are found in the + same parameter group, or if any of them are not found at all + """ + if len(ensure_tuple(optimizers)) > 1: + raise NotImplementedError( + textwrap.dedent("""Surgery with multiple optimizers + is not yet supported.""")) + opt = ensure_tuple(optimizers)[0] + + # diff the two sets of parameters to find what needs to be removed or added + old_values = set(old_params) + new_values = set(new_params) + removed_params = old_values - new_values + added_params = new_values - old_values + + if len(removed_params) == 0 and len(added_params) == 0: + return # nothing to do + + # rip out the removed_params' states from the optimizer + for p in removed_params: + if _tensor_in(p, opt.state): # only true after training starts + opt.state.pop(p) + + if len(opt.param_groups) == 1: + group_idx = 0 + else: + # if there is more than one group, use the ripped out parameters to infer the group + # to add the new parameters into + old_group_idxs = [_find_param_in_optimizer(p, opt) for p in removed_params] + + if len(old_group_idxs) == 0: + raise RuntimeError("No parameters were removed, so unable to infer the group into which to add parameters.") + + missing_param_groups = [x for x in old_group_idxs if x < 0] + if len(missing_param_groups) > 0: + raise RuntimeError(f"Parameter groups {missing_param_groups} are not in the optimizer") + + if min(old_group_idxs) != max(old_group_idxs) and len(added_params): + raise RuntimeError( + textwrap.dedent("""Not all removed parameters are in the same parameter group. + This makes it unclear where to add the new parameters.""")) + group_idx = old_group_idxs[0] + + param_group = opt.param_groups[group_idx] + new_param_list = [p for p in param_group['params'] if not _tensor_in(p, removed_params)] + new_param_list += list(added_params) + log.info(f'adding {len(added_params)} new parameters to parameter group #{group_idx}') + param_group['params'] = new_param_list + + +def replace_params_in_optimizer(old_params: Iterable[torch.nn.parameter.Parameter], + new_params: Iterable[torch.nn.parameter.Parameter], optimizers: Optimizers) -> None: + """Fully replaces an optimizer's parameters. + + This differs from `update_params_in_optimizer` in that this method is capable + of replacing parameters spanning multiple param groups. To accomplish this, + this function assumes that parameters in `new_params` should inherit the + param group of the corresponding parameter from `old_params`. Thus, this + function also assumes that `old_params` and `new_params` have the same length. + Args: + old_params: Current parameters of the optimizer. + new_params: New parameters of the optimizer, given in the same order as + `old_params`. Must be the same length as `old_params`. + optimizers (Optimizers): One or more `torch.optim.Optimizer` objects. + Raises: + NotImplementedError: If `optimizers` contains more than one optimizer + RuntimeError: If `old_params` and `new_params` have different lengths, or + if a param from `old_params` cannot be found. 
+ """ + if len(ensure_tuple(optimizers)) > 1: + raise NotImplementedError( + textwrap.dedent("""Surgery with multiple optimizers + is not yet supported.""")) + + opt = ensure_tuple(optimizers)[0] + opt.state.clear() + + param_to_idxs_map = {} + for group_idx, param_group in enumerate(opt.param_groups): + param_list = param_group["params"] + for param_idx, param in enumerate(param_list): + param_to_idxs_map[param] = (group_idx, param_idx) + + for old_param, new_param in itertools.zip_longest(old_params, new_params): + if old_params is None or new_params is None: + raise RuntimeError("old_params and new_params have different lengths.") + + if not old_param in param_to_idxs_map: + raise RuntimeError(f"Parameter {old_param} is missing from the optimizer.") + + group_idx, param_idx = param_to_idxs_map[old_param] + opt.param_groups[group_idx]["params"][param_idx] = new_param diff --git a/composer/optim/pytorch_future.py b/composer/optim/pytorch_future.py index 3d73d30bad..c95e9374cf 100644 --- a/composer/optim/pytorch_future.py +++ b/composer/optim/pytorch_future.py @@ -32,7 +32,7 @@ class WarmUpLR(_LRScheduler): learning rate schedule. Default: ``-1``. verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``. - interval (str): Frequency of ``step()`` calls, either ``step`` or ``epoch``. Default: ``step``. + interval (str): Frequency of ``step()`` calls, either ``step`` or ``epoch``. Default: ``epoch``. Example: >>> # Assuming optimizer uses lr = 0.05 for all groups @@ -68,7 +68,7 @@ def __init__(self, warmup_method="linear", last_epoch=-1, verbose=False, - interval='step'): + interval='epoch'): if warmup_method not in ("constant", "linear"): raise ValueError("Only 'constant' or 'linear' warmup_method accepted, but got {}".format(warmup_method)) diff --git a/composer/optim/scheduler.py b/composer/optim/scheduler.py index f770bd390e..95c1385a30 100755 --- a/composer/optim/scheduler.py +++ b/composer/optim/scheduler.py @@ -3,7 +3,7 @@ import logging from abc import ABC from dataclasses import asdict, dataclass -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union import torch import yahp as hp @@ -11,9 +11,10 @@ StepLR, _LRScheduler) from composer.core.time import TimeUnit -from composer.core.types import Optimizer, Scheduler, Time +from composer.core.types import Optimizer, Scheduler, Schedulers, Time from composer.optim.pytorch_future import LinearLR, WarmUpLR from composer.utils._time_conversion import convert as convert_time +from composer.utils.iter_helpers import ensure_tuple log = logging.getLogger(__name__) @@ -76,7 +77,7 @@ def initialize_object( samples_per_epoch: Optional[int] = None, dataset_num_tokens: Optional[int] = None, max_training_duration: Optional[Union[str, Time[int]]] = None, - ) -> Tuple[Scheduler, str]: + ) -> Scheduler: """Create the scheduler object from the current hparams. Args: @@ -86,7 +87,7 @@ def initialize_object( dataset_num_tokens (int, optional): The number of tokens in the dataset. max_training_duration (str or Time, optional): The total training duration. Returns: - (Scheduler, str): (The parametrized scheduler instance, schedule step interval) + Scheduler: The parametrized scheduler instance """ assert self.scheduler_object is not None, "Scheduler Hparams needs scheduler_object to initialize." 
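Returning briefly to the new ``surgery.replace_params_in_optimizer`` helper defined above, a minimal sketch of its intended one-to-one calling pattern (illustrative; the real caller is the device-move code added to the trainer later in this patch)::

    import torch

    from composer.core import surgery

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # Capture the current parameter objects, apply a transformation that rebuilds
    # them one-for-one, then hand both orderings to the helper. A fresh module of
    # the same shape stands in here for e.g. a device move.
    old_params = list(model.parameters())
    model = torch.nn.Linear(4, 2)  # stand-in: same architecture, brand-new Parameter objects
    new_params = list(model.parameters())

    surgery.replace_params_in_optimizer(old_params=old_params, new_params=new_params, optimizers=opt)

    group_params = opt.param_groups[0]["params"]
    assert all(any(p is q for q in group_params) for p in new_params)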
@@ -103,7 +104,7 @@ def initialize_object( obj = self.scheduler_object(optimizer, **kwargs) obj.interval = self.interval # type: ignore obj.steps_per_epoch = steps_per_epoch # type: ignore - return obj, self.interval + return obj class ConstantLR(_LRScheduler): @@ -358,7 +359,7 @@ class ComposedScheduler(_LRScheduler): >>> # lr = 0.729 if epoch == 4 >>> scheduler1 = WarmUpLR(self.opt, warmup_factor=0.1, warmup_iters=2, warmup_method="constant") >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) - >>> scheduler = ComposedScheduler(zip([scheduler1, scheduler2], ["epoch", "epoch"])) + >>> scheduler = ComposedScheduler([scheduler1, scheduler2]) >>> for epoch in range(100): >>> train(...) >>> validate(...) @@ -372,22 +373,18 @@ class ComposedScheduler(_LRScheduler): >>> # lr = 0.2 if epoch == 4 . # MultiStepLR effect starts here >>> scheduler1 = WarmUpLR(self.opt, warmup_factor=0.1, warmup_iters=2, warmup_method="constant") >>> scheduler2 = MultiStepLR(optimizer, milestones=[4], gamma=0.2) - >>> scheduler = ComposedScheduler(zip([scheduler1, scheduler2], ["epoch", "epoch"])) + >>> scheduler = ComposedScheduler([scheduler1, scheduler2]) >>> for epoch in range(100): >>> train(...) >>> validate(...) >>> scheduler.step() """ - def __init__(self, schedulers): - - # check for tuple - if not all(isinstance(scheduler, tuple) for scheduler in schedulers): - raise ValueError('Schedulers must be a tuple of (Scheduler, interval), ' - 'where interval is one of "epoch" or "batch".') - + def __init__(self, schedulers: Schedulers): + schedulers = ensure_tuple(schedulers) self._validate_same_optimizers(schedulers) - self.schedulers, self.intervals = list(zip(*schedulers)) # unpack (scheduler, interval) + self.schedulers = schedulers + self.intervals = [getattr(scheduler, "interval", "epoch") for scheduler in schedulers] # generous with spelling (batch, batches)/(step, steps) and (epoch, epochs) self.intervals = [INTERVAL_MAP[interval] for interval in self.intervals] @@ -467,9 +464,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: scheduler.load_state_dict(state_dict["schedulers"][scheduler.__class__.__qualname__]) self._warmup_counter = state_dict["_warmup_counter"] - def _validate_same_optimizers(self, schedulers): + def _validate_same_optimizers(self, schedulers: Schedulers): """Verify that all schedulers correspond to the same optimizer.""" - for scheduler_idx in range(1, len(schedulers)): - if (schedulers[scheduler_idx][0].optimizer != schedulers[0][0].optimizer): # type: ignore + schedulers = ensure_tuple(schedulers) + for i, scheduler in enumerate(schedulers): + if (getattr(scheduler, "optimizer") != getattr(schedulers[0], "optimizer")): raise ValueError("ComposedScheduler expects all schedulers to belong to the same optimizer, but " - "got schedulers at index {} and {} to be different".format(0, scheduler_idx)) + f"got schedulers at index 0 and {i} to be different") diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index c8925c4bce..5911817ac5 100755 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -15,17 +15,19 @@ import torch.utils.data from torch.cuda.amp.grad_scaler import GradScaler from torch.nn.parallel import DistributedDataParallel +from torch.optim.lr_scheduler import CosineAnnealingLR from torchmetrics.collections import MetricCollection from torchmetrics.metric import Metric -from composer.core import Callback, DataSpec, Engine, Event, Logger, State, Time +from composer.core import Callback, DataSpec, Engine, Event, Logger, State, 
Time, surgery from composer.core.algorithm import Algorithm from composer.core.logging import BaseLoggerBackend, LogLevel -from composer.core.types import Batch, BreakEpochException, DataLoader, Metrics, Precision +from composer.core.time import TimeUnit +from composer.core.types import Batch, BreakEpochException, DataLoader, Metrics, Optimizers, Precision, Schedulers from composer.loggers.tqdm_logger import TQDMLoggerBackend from composer.models.base import BaseMosaicModel -from composer.optim import (ComposedScheduler, CosineAnnealingLRHparams, DecoupledSGDWHparams, OptimizerHparams, - SchedulerHparams, WarmUpLRHparams) +from composer.optim import ComposedScheduler +from composer.optim.decoupled_weight_decay import DecoupledSGDW from composer.optim.scheduler import ensure_warmup_last from composer.profiler.profiler_hparams import ProfilerHparams from composer.trainer.checkpoint_hparams import CheckpointLoaderHparams, CheckpointSaverHparams @@ -57,21 +59,13 @@ class Trainer: or dict of :class:`DataSpec` kwargs for the training data. eval_dataloader (DataLoader, DataSpec, or dict): The :class:`DataLoader`, :class:`DataSpec`, or dict of :class:`DataSpec` kwargs for the evaluation data. - max_duration (Union[str, `~composer.core.Time`]): The maxmimum number amount of Time to train for. - See `~composer.core.Time` for details. + max_duration (Time or str): The maximum duration to train. See `~composer.core.Time` for details. algorithms (List[Algorithm], optional): The algorithms to use during training. (default: ``[]``) - optimizer_hparams: (OptimizerHparams, optional): The OptimizerHparams for constructing - the optimizer for training. Must pass OptimizerHparams instead of a `torch.optim.Optimizer` - object because the optimizer has to be constructed after certain algorithms which modify - the model architecture have run on the model. (default: - ``MosaicMLSGDWHparams(lr=0.1, momentum=0.9, weight_decay=1.0e-4)``) - schedulers_hparams: (Union[SchedulerHparams, List[SchedulerHparams]], optional): The - SchedulerHparams for constructing the one or more learning rate schedulers used - during training. Must pass SchedulerHparams instead of a `torch.optim.lr_scheduler._LRScheduler` - object because the scheduler needs an optimizer to be constructed and we construct the optimizer - in `__init__`. (default: - ``[CosineAnnealingLRHparams(T_max=f"{max_epochs}ep"), WarmUpLRHparams()]``). + optimizers: (Optimizers, optional): The optimizers. + (default: ``DecoupledSGDW(model.parameters(), lr=0.1)``) + schedulers: (Schedulers, optional): The schedulers. + (default: ``[CosineAnnealingLR()]``). device (str or Device, optional): The device to use for training. Either `cpu` or `gpu`. (default `cpu`) grad_accum (int, optional): The number of microbatches to split a per-device batch into. 
Gradients @@ -125,8 +119,8 @@ def __init__( eval_dataloader: Union[DataLoader, DataSpec], max_duration: Union[str, Time], algorithms: Optional[List[Algorithm]] = None, - optimizer_hparams: Optional[OptimizerHparams] = None, - schedulers_hparams: Optional[Union[SchedulerHparams, List[SchedulerHparams]]] = None, + optimizers: Optional[Optimizers] = None, + schedulers: Optional[Schedulers] = None, # device device: Optional[Union[str, Device]] = None, @@ -171,6 +165,9 @@ def __init__( # self._use_grad_scaling() will raise a RuntimeError if grad scaling is not available when it is required warnings.filterwarnings(action="ignore", message="torch.cuda.amp.GradScaler") + if isinstance(max_duration, str): + max_duration = Time.from_timestring(max_duration) + self.config = config if isinstance(deepspeed_hparams, dict): @@ -237,6 +234,39 @@ def __init__( self._train_data_spec = train_dataloader self._eval_data_spec = eval_dataloader + if eval_subset_num_batches is not None: + try: + eval_dataloader_len = len(eval_dataloader.dataloader) + except (NotImplementedError, TypeError): + pass + else: + if eval_subset_num_batches > eval_dataloader_len: + warnings.warn( + textwrap.dedent( + f"""SubsetNumBatchesWarning: The eval_subset_num_batches({eval_subset_num_batches}) + is greater than the number of batches in the evaluation dataloader + ({len(eval_dataloader.dataloader)})""")) + self._eval_subset_num_batches = eval_subset_num_batches + + if not optimizers: + optimizers = DecoupledSGDW(list(model.parameters()), lr=0.1) + warnings.warn(f"No optimizer was specified. Defaulting to {repr(optimizers)}") + + num_optimizers = len(ensure_tuple(optimizers)) + + if num_optimizers != 1: + raise NotImplementedError(f"Only one optimizer is supported; found {num_optimizers} optimizers") + + if not schedulers: + optimizer = ensure_tuple(optimizers)[0] + if not max_duration.unit == TimeUnit.EPOCH: + raise ValueError("If a scheduler is not provided, max duration must be in epochs") + schedulers = CosineAnnealingLR(optimizer, T_max=max_duration.value) + warnings.warn(f"No scheduler was specified. 
Defaulting to {repr(schedulers)}") + if not isinstance(schedulers, (tuple, list)): + schedulers = [schedulers] + schedulers = ComposedScheduler(schedulers) + self.state = State( max_duration=max_duration, algorithms=algorithms, @@ -247,6 +277,9 @@ def __init__( precision_context=precision_context, train_dataloader=train_dataloader.dataloader, eval_dataloader=eval_dataloader.dataloader, + optimizers=optimizers, + steps_per_epoch=train_subset_num_batches, + schedulers=schedulers, ) # Configure the profiler @@ -254,26 +287,6 @@ def __init__( self.state.profiler = profiler.initialize_object(self.state) self.state.callbacks.extend(self.state.profiler.event_handlers) - # Steps per epoch - if train_subset_num_batches is not None: - if train_subset_num_batches > self.state.steps_per_epoch: - warnings.warn( - textwrap.dedent( - f"""SubsetNumBatchesWarning: The train_subset_num_batches({train_subset_num_batches}) - is greater than the number of batches in the training dataloader - ({self.state.steps_per_epoch})""")) - else: - self.state.steps_per_epoch = train_subset_num_batches - - if eval_subset_num_batches is not None: - if eval_subset_num_batches > len(self.state.eval_dataloader): - warnings.warn( - textwrap.dedent(f"""SubsetNumBatchesWarning: The eval_subset_num_batches({eval_subset_num_batches}) - is greater than the number of batches in the evaluation dataloader - ({len(self.state.eval_dataloader)})""")) - - self._eval_subset_num_batches = eval_subset_num_batches - if log_destinations is None: log_destinations = [TQDMLoggerBackend()] self.logger = Logger(self.state, log_destinations) @@ -292,36 +305,10 @@ def __init__( if deterministic_mode: reproducibility.configure_deterministic_mode() - # run INIT event before optimizers and schedulers are created self.engine.run_event(Event.INIT) - # Need to use hparams here because optimizer and schedulers need to be created after Event.INIT - if not optimizer_hparams: - optimizer_hparams = DecoupledSGDWHparams(lr=0.1, momentum=0.9, weight_decay=1.0e-4) - if not schedulers_hparams: - schedulers_hparams = [CosineAnnealingLRHparams(T_max=str(max_duration)), WarmUpLRHparams()] - if not isinstance(schedulers_hparams, list): - schedulers_hparams = [schedulers_hparams] - optimizer = optimizer_hparams.initialize_object(param_group=self.state.model.parameters()) - if self._train_data_spec.num_samples is None or self.state.train_dataloader.batch_size is None: - samples_per_epoch = None - else: - batch_size = self.state.train_dataloader.batch_size * dist.get_world_size() - - samples_per_epoch = min(self.state.steps_per_epoch * batch_size, self._train_data_spec.num_samples) - schedulers = [ - x.initialize_object(optimizer=optimizer, - max_training_duration=self.state.max_duration, - steps_per_epoch=self.state.steps_per_epoch, - samples_per_epoch=samples_per_epoch, - dataset_num_tokens=self._train_data_spec.num_tokens) - for x in ensure_warmup_last(schedulers_hparams) - ] - self.state.optimizers = optimizer - self.state.schedulers = ComposedScheduler(schedulers=schedulers) - assert isinstance(self.state.model, BaseMosaicModel) - self.original_model = self.state.model # type: ignore # TODO(ravi) -- update the state to add an original model helper + self.original_model = self.state.model # TODO(ravi) -- update the state to add an original model helper self.checkpoint_saver = None if checkpoint_saver is not None: @@ -353,7 +340,17 @@ def __init__( self.seed = restored_seed if not self.deepspeed_enabled: + host_model_params = self.state.model.parameters() 
self.state.model = self.device.module_to_device(self.state.model) + device_model_params = self.state.model.parameters() + + # use surgery to update the parameters of the optimizers, now that the model is on the device + # see https://pytorch.org/docs/stable/optim.html#constructing-it + surgery.replace_params_in_optimizer(old_params=host_model_params, + new_params=device_model_params, + optimizers=self.state.optimizers) + + # Move any remaining optimizer parameters onto the device self.state.optimizers = map_collection(self.state.optimizers, self.device.optimizer_to_device) # wrap model with DDP @@ -400,7 +397,7 @@ def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer: textwrap.dedent(f"""SubsetNumBatchesWarning: When specifying train_subset_num_batches, (set to {hparams.train_subset_num_batches}), train_datset.shuffle should be set to False. Otherwise, each training epoch may load a different subset of samples.""")) - train_dataloader = hparams.train_dataset.initialize_object(train_device_batch_size, hparams.dataloader) + train_data = hparams.train_dataset.initialize_object(train_device_batch_size, hparams.dataloader) eval_device_batch_size = hparams.eval_batch_size // dist.get_world_size() if hparams.val_dataset.shuffle and hparams.eval_subset_num_batches is not None: @@ -408,16 +405,49 @@ def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer: textwrap.dedent(f"""SubsetNumBatchesWarning: When specifying eval_subset_num_batches, (set to {hparams.eval_subset_num_batches}), val_dataset.shuffle should be set to False. Otherwise, each evaluation epoch may load a different subset of samples.""")) - eval_dataloader = hparams.val_dataset.initialize_object(eval_device_batch_size, hparams.dataloader) + eval_data = hparams.val_dataset.initialize_object(eval_device_batch_size, hparams.dataloader) + + optimizers = hparams.optimizer.initialize_object(model.parameters()) + + train_dataloader = train_data + + samples_per_epoch = None + tokens_per_epoch = None + + if isinstance(train_dataloader, DataSpec): + if train_dataloader.num_samples is not None: + samples_per_epoch = train_dataloader.num_samples + tokens_per_epoch = train_dataloader.num_tokens + train_dataloader = train_dataloader.dataloader + + try: + steps_per_epoch = len(train_dataloader) + except (AttributeError, NotImplementedError): + steps_per_epoch = None + + batch_size = None + if train_dataloader.batch_size is not None: + batch_size = train_dataloader.batch_size * dist.get_world_size() + + if samples_per_epoch is None and steps_per_epoch is not None and batch_size is not None: + samples_per_epoch = steps_per_epoch * batch_size + + schedulers = [ + x.initialize_object(optimizer=optimizers, + max_training_duration=hparams.max_duration, + steps_per_epoch=steps_per_epoch, + samples_per_epoch=samples_per_epoch, + dataset_num_tokens=tokens_per_epoch) for x in ensure_warmup_last(hparams.schedulers) + ] trainer = cls( model=model, - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, + train_dataloader=train_data, + eval_dataloader=eval_data, max_duration=hparams.max_duration, algorithms=algorithms, - optimizer_hparams=hparams.optimizer, - schedulers_hparams=hparams.schedulers, + optimizers=optimizers, + schedulers=schedulers, # device device=device, @@ -545,13 +575,6 @@ def _train_loop(self) -> None: # shorthand state = self.state - assert state.optimizers is not None - assert state.schedulers is not None - - if len(ensure_tuple(state.optimizers)) != 1: - raise NotImplementedError("The Mosaic trainer only 
supports one optimizer; " - f"found {len(ensure_tuple(state.optimizers))} optimizers") - # print training start self.logger.metric_fit({"trainer/algorithms": [str(algo) for algo in self.state.algorithms]}) diff --git a/tests/algorithms/test_blurpool_algorithm.py b/tests/algorithms/test_blurpool_algorithm.py index 7032dce7b6..f3c04a8233 100644 --- a/tests/algorithms/test_blurpool_algorithm.py +++ b/tests/algorithms/test_blurpool_algorithm.py @@ -4,15 +4,17 @@ Test the blurpool algorithm. Primitives are tested in test_blurpool.py """ import itertools +from typing import List from unittest.mock import MagicMock import pytest import torch from composer.algorithms import BlurPool, BlurPoolHparams +from composer.algorithms.blurpool import apply_blurpool from composer.algorithms.blurpool.blurpool_layers import BlurConv2d, BlurMaxPool2d -from composer.core import Event, State -from composer.core.types import DataLoader, Model, Precision +from composer.core import Event, State, surgery +from composer.core.types import DataLoader, Logger, Model, Precision from tests.fixtures.models import SimpleConvModel @@ -45,7 +47,7 @@ def dummy_logger(): return MagicMock() -def test_blurconv(state, blurpool_instance, dummy_logger): +def test_blurconv(state: State, blurpool_instance: BlurPool, dummy_logger: Logger): blurpool_instance.apply(Event.INIT, state, dummy_logger) assert isinstance(state.model.module, SimpleConvModel) @@ -55,21 +57,21 @@ def test_blurconv(state, blurpool_instance, dummy_logger): assert type(state.model.module.conv1) is torch.nn.Conv2d -def test_maybe_replace_strided_conv_stride(state, blurpool_instance, dummy_logger): +def test_maybe_replace_strided_conv_stride(state: State, blurpool_instance: BlurPool, dummy_logger: Logger): blurpool_instance.apply(Event.INIT, state, dummy_logger) assert isinstance(state.model.module, SimpleConvModel) assert type(state.model.module.conv3) is torch.nn.Conv2d # stride = 1, should be no replacement -def test_maybe_replace_strided_conv_channels(state, blurpool_instance, dummy_logger): +def test_maybe_replace_strided_conv_channels(state: State, blurpool_instance: BlurPool, dummy_logger: Logger): blurpool_instance.apply(Event.INIT, state, dummy_logger) assert isinstance(state.model.module, SimpleConvModel) assert type(state.model.module.conv2) is torch.nn.Conv2d # channels < 16, should be no replacement -def test_blurconv_weights_preserved(state, blurpool_instance, dummy_logger): +def test_blurconv_weights_preserved(state: State, blurpool_instance: BlurPool, dummy_logger: Logger): assert isinstance(state.model.module, SimpleConvModel) original_weights = state.model.module.conv1.weight.clone() @@ -84,7 +86,7 @@ def test_blurconv_weights_preserved(state, blurpool_instance, dummy_logger): assert torch.allclose(original_weights, new_weights) -def test_blurpool(state, blurpool_instance, dummy_logger): +def test_blurpool(state: State, blurpool_instance: BlurPool, dummy_logger: Logger): blurpool_instance.apply(Event.INIT, state, dummy_logger) assert isinstance(state.model.module, SimpleConvModel) @@ -94,18 +96,39 @@ def test_blurpool(state, blurpool_instance, dummy_logger): assert type(state.model.module.pool1) is torch.nn.MaxPool2d -def test_blurpool_wrong_event(state, blurpool_instance): +def test_blurpool_wrong_event(state: State, blurpool_instance: BlurPool): assert blurpool_instance.match(Event.BATCH_START, state) == False -def test_blurpool_correct_event(state, blurpool_instance): +def test_blurpool_correct_event(state: State, blurpool_instance: BlurPool): 
assert blurpool_instance.match(Event.INIT, state) == True -def test_blurpool_algorithm_logging(state, blurpool_instance, dummy_logger): +def test_blurpool_algorithm_logging(state: State, blurpool_instance: BlurPool, dummy_logger: Logger): blurpool_instance.apply(Event.INIT, state, dummy_logger) dummy_logger.metric_fit.assert_called_once_with({ 'blurpool/num_blurpool_layers': 1 if blurpool_instance.hparams.replace_maxpools else 0, 'blurpool/num_blurconv_layers': 1 if blurpool_instance.hparams.replace_convs else 0, }) + + +def test_blurconv2d_optimizer_params_updated(): + model = SimpleConvModel() + orig_conv = model.conv1 + assert orig_conv.stride == (2, 2) # fail fast if test model changes + opt = torch.optim.SGD(model.parameters(), lr=.01) + apply_blurpool(model, optimizers=opt) + new_conv = model.conv1 + param_list: List[torch.Tensor] = opt.param_groups[0]['params'] + + # old params removed + assert not surgery._tensor_in(orig_conv.weight, param_list) + + # new params added + new_conv2d = new_conv.conv + assert isinstance(new_conv2d, torch.nn.Module) + new_weight = new_conv2d.weight + assert new_weight is not orig_conv.weight + assert isinstance(new_weight, torch.Tensor) + assert surgery._tensor_in(new_weight, param_list) diff --git a/tests/algorithms/test_scale_schedule.py b/tests/algorithms/test_scale_schedule.py index 03619d7995..2a7a118c57 100644 --- a/tests/algorithms/test_scale_schedule.py +++ b/tests/algorithms/test_scale_schedule.py @@ -73,7 +73,7 @@ def test_scale_schedule_cosine_warm_restarts(self, optimizer: Optimizer, ssr: fl def test_scale_schedule_warmup(self, optimizer: Optimizer, ssr: float): targets = [0.5] * 4 + [1.0] * 5 # no effect - scheduler = WarmUpLR(optimizer, warmup_factor=0.5, warmup_iters=4, warmup_method='constant') + scheduler = WarmUpLR(optimizer, warmup_factor=0.5, warmup_iters=4, warmup_method='constant', interval='step') epochs = int(9 * ssr) targets = targets[:epochs] self._test(targets, scheduler, epochs, optimizer, ssr) diff --git a/tests/fixtures/dummy_fixtures.py b/tests/fixtures/dummy_fixtures.py index 0d785d4650..b78c2a5022 100755 --- a/tests/fixtures/dummy_fixtures.py +++ b/tests/fixtures/dummy_fixtures.py @@ -8,7 +8,7 @@ import torch.utils.data from composer import Logger, State -from composer.core.types import DataLoader, DataSpec, Model, Precision +from composer.core.types import DataLoader, DataSpec, Model, Optimizer, Precision, Scheduler from composer.datasets import DataloaderHparams, DatasetHparams from composer.models import ModelHparams, MosaicClassifier from composer.optim import AdamHparams, ExponentialLRHparams @@ -74,8 +74,19 @@ def dummy_val_dataset_hparams(dummy_model: SimpleBatchPairModel, ) +@pytest.fixture +def dummy_optimizer(dummy_model: SimpleBatchPairModel): + return torch.optim.SGD(dummy_model.parameters(), lr=0.001) + + +@pytest.fixture +def dummy_scheduler(dummy_optimizer: Optimizer): + return torch.optim.lr_scheduler.LambdaLR(dummy_optimizer, lambda _: 1.0) + + @pytest.fixture() def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_dataloader: DataLoader, + dummy_optimizer: Optimizer, dummy_scheduler: Scheduler, dummy_val_dataloader: DataLoader) -> State: state = State( model=dummy_model, @@ -83,6 +94,8 @@ def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_data grad_accum=1, train_dataloader=dummy_train_dataloader, eval_dataloader=dummy_val_dataloader, + optimizers=dummy_optimizer, + schedulers=dummy_scheduler, max_duration="10ep", ) diff --git a/tests/fixtures/models.py 
b/tests/fixtures/models.py index 7c1806af65..2b48210238 100755 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -34,11 +34,20 @@ def __init__(self, in_shape: Tuple[int, ...], num_classes: int) -> None: self.train_acc = torchmetrics.Accuracy() self.val_acc = torchmetrics.Accuracy() + # Important: It is crucial that the FC layers are bound to `self` + # for the optimizer surgery tests. + # These tests attempt to perform surgery on `fc1` layer, and we want + # to make sure that post-surgery, self.fc1 refers to the same parameters + # as self.net[1] + self.fc1 = torch.nn.Linear(in_features_flattened, 5) + + self.fc2 = torch.nn.Linear(5, num_classes) + self.net = torch.nn.Sequential( torch.nn.Flatten(), - torch.nn.Linear(in_features_flattened, 5), + self.fc1, torch.nn.ReLU(), - torch.nn.Linear(5, num_classes), + self.fc2, torch.nn.Softmax(dim=-1), ) diff --git a/tests/test_load.py b/tests/test_load.py index fd4e896a42..d5c8dffc58 100755 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -21,6 +21,8 @@ def get_model_algs(model_name: str) -> List[str]: algs = algorithms.list_algorithms() + algs.remove("dummy") + algs.remove("no_op_model") is_image_model = any(x in model_name for x in ("resnet", "mnist", "efficientnet")) if is_image_model: algs.remove("alibi") @@ -39,6 +41,9 @@ def get_model_algs(model_name: str) -> List[str]: @pytest.mark.parametrize('model_name', model_names) @pytest.mark.timeout(15) def test_load(model_name: str): + if model_name in ['deeplabv3_ade20k']: + pytest.skip(f"Model {model_name} requires GPU") + trainer_hparams = trainer.load(model_name) trainer_hparams.precision = Precision.FP32 trainer_hparams.algorithms = algorithms.load_multiple(*get_model_algs(model_name)) diff --git a/tests/test_surgery.py b/tests/test_surgery.py index 695be367b9..bf172e9e91 100644 --- a/tests/test_surgery.py +++ b/tests/test_surgery.py @@ -1,5 +1,6 @@ # Copyright 2021 MosaicML. All Rights Reserved. 
+from typing import List, Mapping, Tuple, Type, cast from unittest.mock import Mock import pytest @@ -7,11 +8,13 @@ from torch import nn from composer.core import surgery +from composer.core.types import Optimizer +from tests.fixtures.models import SimpleBatchPairModel class RecursiveLinear(nn.Linear): - def __init__(self, in_features, out_features): + def __init__(self, in_features: int, out_features: int): super().__init__(in_features, out_features) # submodule has modified out_features to prevent infinite recursion during test @@ -29,15 +32,16 @@ def __init__(self): self.fc2 = nn.Linear(in_features=32, out_features=10) @staticmethod - def maybe_replace_linear(module, module_index): - if module.out_features in (10, 9): - return RecursiveLinear(module.in_features, module.out_features) + def maybe_replace_linear(module: torch.nn.Module, module_index: int): + del module_index # unused + if module.out_features in (10, 9) and not isinstance(module, RecursiveLinear): + return RecursiveLinear(cast(int, module.in_features), cast(int, module.out_features)) return None - def policy(self): + def policy(self) -> Mapping[Type[torch.nn.Module], surgery.ReplacementFunction]: return {nn.Linear: self.maybe_replace_linear} - def validate_replacements(self, recurse_on_replacements): + def validate_replacements(self, recurse_on_replacements: bool): assert type(self.fc1) is nn.Linear assert type(self.fc2) is RecursiveLinear @@ -54,12 +58,13 @@ class ModuleIdxReplacementPolicy(SimpleReplacementPolicy): """ @staticmethod - def maybe_replace_linear(module, module_index): + def maybe_replace_linear(module: torch.nn.Module, module_index: int): if module_index == 0: - return RecursiveLinear(module.in_features, module.out_features) + return RecursiveLinear(cast(int, module.in_features), cast(int, module.out_features)) return None - def validate_replacements(self, recurse_on_replacements): + def validate_replacements(self, recurse_on_replacements: bool): + del recurse_on_replacements # unused assert type(self.fc1) is RecursiveLinear assert type(self.fc2) is nn.Linear assert type(self.fc1.submodule) is nn.Linear @@ -70,22 +75,24 @@ class NoOpReplacementPolicy(SimpleReplacementPolicy): def policy(self): return {nn.Conv2d: Mock(side_effect=AssertionError('test should not match on this layer'))} - def validate_replacements(self, recurse_on_replacements): + def validate_replacements(self, recurse_on_replacements: bool): + del recurse_on_replacements # unused assert type(self.fc1) is nn.Linear assert type(self.fc2) is nn.Linear @pytest.mark.parametrize('recurse_on_replacements', [True, False]) -@pytest.mark.parametrize('model', [ +@pytest.mark.parametrize('model_cls', [ SimpleReplacementPolicy, ModuleIdxReplacementPolicy, NoOpReplacementPolicy, ]) -def test_module_replacement(model, recurse_on_replacements): - model = model() +def test_module_replacement(model_cls: Type[SimpleReplacementPolicy], recurse_on_replacements: bool): + model = model_cls() surgery.replace_module_classes( model, - model.policy(), + optimizers=None, + policies=model.policy(), recurse_on_replacements=recurse_on_replacements, ) @@ -98,7 +105,7 @@ def __init__(self, in_features: int, out_features: int): super().__init__() self.in_features = in_features self.out_features = out_features - self.weight = torch.nn.Parameter(torch.empty((out_features, in_features))) + self.weight = torch.nn.parameter.Parameter(torch.empty((out_features, in_features))) self.bias = None @staticmethod @@ -109,3 +116,55 @@ def from_linear(module: torch.nn.Module, module_index: 
int = -1): ret.weight.copy_(module.weight) # type: ignore ret.bias = module.bias # same param object return ret + + +@pytest.fixture +def optimizer_surgery_state(): + input_shape = (1, 18, 18) + n_classes = 10 + model = SimpleBatchPairModel(input_shape, n_classes) + policy: Mapping[Type[torch.nn.Module], surgery.ReplacementFunction] = {torch.nn.Linear: _CopyLinear.from_linear} + opt = torch.optim.SGD(model.parameters(), lr=.001) + orig_linear_modules = [model.fc1, model.fc2] + surgery.replace_module_classes(model, policies=policy, optimizers=opt) + new_linear_modules = [model.fc1, model.fc2] + return orig_linear_modules, new_linear_modules, opt + + +def test_optimizer_surgery_no_duplicate_params(optimizer_surgery_state: Tuple[List[torch.nn.Module], + List[torch.nn.Module], Optimizer]): + _, _, opt = optimizer_surgery_state + params_list = opt.param_groups[0]['params'] + params_set = set(params_list) + assert len(params_list) == len(params_set) + + +def _param_in_optimizer(param: torch.nn.parameter.Parameter, opt: torch.optim.Optimizer): + return surgery._find_param_in_optimizer(param, opt) >= 0 + + +def test_optimizer_surgery_removed_params_gone(optimizer_surgery_state: Tuple[List[torch.nn.Module], + List[torch.nn.Module], Optimizer]): + orig_linear_modules, _, opt = optimizer_surgery_state + for module in orig_linear_modules: + assert isinstance(module.weight, torch.nn.parameter.Parameter) + assert not _param_in_optimizer(module.weight, opt) + + +def test_optimizer_surgery_new_params_present(optimizer_surgery_state: Tuple[List[torch.nn.Module], + List[torch.nn.Module], Optimizer]): + _, new_linear_modules, opt = optimizer_surgery_state + for module in new_linear_modules: + assert isinstance(module.weight, torch.nn.parameter.Parameter) + assert _param_in_optimizer(module.weight, opt) + assert isinstance(module.bias, torch.nn.parameter.Parameter) + assert _param_in_optimizer(module.bias, opt) + + +def test_optimizer_surgery_params_not_removed_still_there(optimizer_surgery_state: Tuple[List[torch.nn.Module], + List[torch.nn.Module], + Optimizer]): + orig_linear_modules, _, opt = optimizer_surgery_state + for module in orig_linear_modules: + assert isinstance(module.bias, torch.nn.parameter.Parameter) + assert _param_in_optimizer(module.bias, opt) diff --git a/tests/trainer/test_scheduler.py b/tests/trainer/test_scheduler.py index 809f9f90a0..916dc45918 100644 --- a/tests/trainer/test_scheduler.py +++ b/tests/trainer/test_scheduler.py @@ -1,13 +1,13 @@ # Copyright 2021 MosaicML. All Rights Reserved. 
-from typing import Dict, Type, Union +from typing import Dict, List, Type, Union from unittest import mock import pytest import torch from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR, StepLR -from composer.core.types import ModelParameters +from composer.core.types import Optimizer, Scheduler from composer.optim.pytorch_future import WarmUpLR from composer.optim.scheduler import (ComposedScheduler, ConstantLRHparams, CosineAnnealingLRHparams, CosineAnnealingWarmRestartsHparams, ExponentialLRHparams, LinearLRHparams, @@ -92,35 +92,23 @@ } -@pytest.fixture -def dummy_parameters() -> ModelParameters: - net = torch.nn.Sequential(torch.nn.Linear(5, 2), torch.nn.ReLU()) - return net.parameters() - - -@pytest.fixture -def dummy_optimizer(dummy_parameters) -> torch.optim.Optimizer: - return torch.optim.SGD(dummy_parameters, 0.1) - - @pytest.mark.parametrize("scheduler_name", scheduler_registry.keys()) class TestSchedulerInit(): - def test_scheduler_initialization(self, scheduler_name: str, dummy_optimizer): + def test_scheduler_initialization(self, scheduler_name: str, dummy_optimizer: Optimizer): # create the scheduler hparams object obj: Type[SchedulerHparams] = scheduler_registry[scheduler_name] scheduler_hparams = schedulers[obj] # create the scheduler object using the hparams - scheduler, interval = scheduler_hparams.initialize_object(dummy_optimizer, steps_per_epoch=1) + scheduler = scheduler_hparams.initialize_object(dummy_optimizer, steps_per_epoch=1) assert isinstance(scheduler, scheduler_hparams.scheduler_object) # type: ignore - assert interval == scheduler_hparams.interval # type: ignore @pytest.mark.parametrize('timestrings', EXPECTED_RESULTS_TIME_CONVERSION.keys()) @pytest.mark.parametrize('interval', ['steps', 'epochs']) - def test_scheduler_time_conversion(self, scheduler_name: str, dummy_optimizer, timestrings: Union[str, int], - interval: str): + def test_scheduler_time_conversion(self, scheduler_name: str, dummy_optimizer: Optimizer, + timestrings: Union[str, int], interval: str): expected = EXPECTED_RESULTS_TIME_CONVERSION[timestrings][interval] obj: Type[SchedulerHparams] = scheduler_registry[scheduler_name] steps_per_epoch = TIME_HPARAMS[timestrings]['steps_per_epoch'] @@ -131,39 +119,43 @@ def test_scheduler_time_conversion(self, scheduler_name: str, dummy_optimizer, t with mock.patch.object(scheduler_hparams, time_field[obj], timestrings), \ mock.patch.object(scheduler_hparams, 'interval', interval): - scheduler, interval = scheduler_hparams.initialize_object(dummy_optimizer, - steps_per_epoch=steps_per_epoch, - max_training_duration=f"{max_epochs}ep") + scheduler = scheduler_hparams.initialize_object(dummy_optimizer, + steps_per_epoch=steps_per_epoch, + max_training_duration=f"{max_epochs}ep") assert getattr(scheduler, time_field[obj]) == expected @pytest.fixture -def optimizer(dummy_model): +def optimizer(dummy_model: torch.nn.Module): return torch.optim.SGD(dummy_model.parameters(), lr=1) class TestComposedScheduler(): - def _test(self, scheduler, targets, epochs, optimizer, interval='epoch'): + def _test(self, + scheduler: Scheduler, + targets: List[List[float]], + epochs: int, + optimizer: Optimizer, + interval: str = 'epoch'): for epoch in range(epochs): for param_group, target in zip(optimizer.param_groups, targets): torch.testing.assert_allclose(target[epoch], param_group['lr']) optimizer.step() - scheduler.step(interval) + scheduler.step(interval) # type: ignore - def test_composed(self, optimizer): + def test_composed(self, optimizer: Optimizer): 
epochs = 9 targets = [[1 * 0.2 for _ in range(4)] + [1 * 0.9**x for x in range(7)]] schedulers = [ ExponentialLR(optimizer, gamma=0.9), WarmUpLR(optimizer, warmup_factor=0.2, warmup_iters=4, warmup_method="constant") ] - schedulers = [(s, 'epoch') for s in schedulers] scheduler = ComposedScheduler(schedulers) self._test(scheduler, targets, epochs, optimizer) - def test_composed_linear(self, optimizer): + def test_composed_linear(self, optimizer: Optimizer): epochs = 9 targets = [[1 * 0.5 + (x/4 * 0.5) for x in range(4)] + [1 * 0.9**x for x in range(2)] + \ [1 * 0.9**x for x in range(2, 7)]] @@ -171,11 +163,10 @@ def test_composed_linear(self, optimizer): ExponentialLR(optimizer, gamma=0.9), WarmUpLR(optimizer, warmup_factor=0.5, warmup_iters=4, warmup_method="linear") ] - schedulers = [(s, 'epoch') for s in schedulers] scheduler = ComposedScheduler(schedulers) self._test(scheduler, targets, epochs, optimizer) - def test_composed_linear2(self, optimizer): + def test_composed_linear2(self, optimizer: Optimizer): epochs = 9 targets = [[1 * 0.5 + (x/4 * 0.5) for x in range(4)] + \ [1 * 0.9**x for x in range(2)] + [1 * 0.9**x * 0.1 for x in range(2, 7)]] @@ -184,51 +175,49 @@ def test_composed_linear2(self, optimizer): MultiStepLR(optimizer, milestones=[6], gamma=0.1), WarmUpLR(optimizer, warmup_factor=0.5, warmup_iters=4, warmup_method="linear") ] - schedulers = [(s, 'epoch') for s in schedulers] scheduler = ComposedScheduler(schedulers) self._test(scheduler, targets, epochs, optimizer) - def test_composed_linear_from_zero(self, optimizer): + def test_composed_linear_from_zero(self, optimizer: Optimizer): epochs = 9 targets = [[1 * 0.0 + (x / 4 * 1.0) for x in range(4)] + [1 * 0.9**x for x in range(7)]] schedulers = [ ExponentialLR(optimizer, gamma=0.9), WarmUpLR(optimizer, warmup_factor=0, warmup_iters=4, warmup_method="linear") ] - schedulers = [(s, 'epoch') for s in schedulers] scheduler = ComposedScheduler(schedulers) self._test(scheduler, targets, epochs, optimizer) - def test_composed_linear_from_zero_step(self, optimizer): + def test_composed_linear_from_zero_step(self, optimizer: Optimizer): epochs = 9 targets = [[x / 4 for x in range(4)] + [1.0 for _ in range(7)]] schedulers = [ - (ExponentialLR(optimizer, gamma=0.9), 'epoch'), # should never trigger - (WarmUpLR(optimizer, warmup_factor=0, warmup_iters=4, warmup_method="linear"), 'batch') + ExponentialLR(optimizer, gamma=0.9), + WarmUpLR(optimizer, warmup_factor=0, warmup_iters=4, warmup_method="linear"), ] + schedulers[0].interval = 'epoch' # should never trigger + schedulers[1].interval = 'batch' scheduler = ComposedScheduler(schedulers) self._test(scheduler, targets, epochs, optimizer, interval='batch') @pytest.mark.xfail - def test_validate_compose_multistep(self, optimizer): + def test_validate_compose_multistep(self, optimizer: Optimizer): schedulers = [ ExponentialLR(optimizer, gamma=0.9), WarmUpLR(optimizer, warmup_factor=0, warmup_iters=4, warmup_method="linear"), MultiStepLR(optimizer, milestones=[3], gamma=0.1) ] - schedulers = [(s, 'epoch') for s in schedulers] with pytest.raises(ValueError): ComposedScheduler(schedulers) @pytest.mark.xfail - def test_validate_compose_step(self, optimizer): + def test_validate_compose_step(self, optimizer: Optimizer): schedulers = [ ExponentialLR(optimizer, gamma=0.9), WarmUpLR(optimizer, warmup_factor=0, warmup_iters=4, warmup_method="linear"), StepLR(optimizer, step_size=2, gamma=0.1) ] - schedulers = [(s, 'epoch') for s in schedulers] with pytest.raises(ValueError): 
ComposedScheduler(schedulers) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4c68e46208..d3bb6865c1 100755 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -5,16 +5,14 @@ import pytest import torch import torch.distributed -from torch.optim import Adam from composer.callbacks.lr_monitor import LRMonitor from composer.core.logging.logger import Logger from composer.core.precision import Precision -from composer.core.types import DataLoader +from composer.core.types import DataLoader, Optimizer, Scheduler from composer.loggers.tqdm_logger import TQDMLoggerBackend from composer.models.base import BaseMosaicModel -from composer.optim.optimizer_hparams import AdamHparams -from composer.optim.scheduler import ComposedScheduler, ExponentialLRHparams +from composer.optim.scheduler import ComposedScheduler from composer.trainer import Trainer, TrainerHparams from composer.trainer.devices.device_hparams import CPUDeviceHparams, DeviceHparams, GPUDeviceHparams from tests.utils.trainer_fit import get_total_loss, train_model @@ -31,20 +29,21 @@ def test_trainer_init_all_defaults(dummy_train_dataloader: DataLoader, dummy_val def test_trainer_init_additional_args(dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader, + dummy_optimizer: Optimizer, dummy_scheduler: Scheduler, dummy_model: BaseMosaicModel): trainer = Trainer( model=dummy_model, train_dataloader=dummy_train_dataloader, eval_dataloader=dummy_val_dataloader, max_duration="10ep", - optimizer_hparams=AdamHparams(), - schedulers_hparams=[ExponentialLRHparams(gamma=0.1)], + optimizers=dummy_optimizer, + schedulers=dummy_scheduler, log_destinations=[TQDMLoggerBackend()], callbacks=(LRMonitor(),), ) assert isinstance(trainer, Trainer) - assert isinstance(trainer.state.optimizers[0], Adam) + assert trainer.state.optimizers[0] == dummy_optimizer assert isinstance(trainer.state.schedulers[0], ComposedScheduler)
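
The recurring pattern in this patch is that module surgery (and, in the trainer, moving the model to the device) can leave an optimizer holding references to parameters the model no longer owns. As a rough illustration of the bookkeeping the new surgery helpers perform, the sketch below uses a hypothetical swap_params_in_optimizer function; it is a simplified stand-in, not Composer's actual replace_params_in_optimizer / replace_module_classes API.

# Illustrative sketch only; not part of this patch.
import torch


def swap_params_in_optimizer(old_params, new_params, optimizer):
    # Map each stale parameter (by identity) to its replacement, then rewrite
    # every param group in place so the optimizer steps the live tensors.
    old_to_new = {id(old): new for old, new in zip(old_params, new_params)}
    for group in optimizer.param_groups:
        group["params"] = [old_to_new.get(id(p), p) for p in group["params"]]


model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
opt = torch.optim.SGD(model.parameters(), lr=0.1)

old_params = list(model.parameters())
model[0] = torch.nn.Linear(4, 8)  # stand-in for a surgery replacement (e.g. Conv2d -> BlurConv2d)
swap_params_in_optimizer(old_params, model.parameters(), opt)

# The optimizer now tracks only parameters that the model still owns, which is
# what tests such as test_optimizer_surgery_removed_params_gone assert above.
live = {id(p) for p in model.parameters()}
assert all(id(p) in live for group in opt.param_groups for p in group["params"])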