Optimizer Surgery (#249)
* Added surgery

* Fixed most tests

* Fixed more tests

* Fixes

* Fixed tests

* PR Cleanup

* Fixed sorting

* Formatting

* Cleaned up PR

* Addressed most PR feedback

* new method for post-device change optimizer surgery

* fix len bug

* fix failing test

* fix test_load

* minor docstring update

* exclude deeplab

Co-authored-by: Ravi Rahman <r_rahman@mit.edu>
Co-authored-by: hanlint <hanlin@mosaicml.com>
Co-authored-by: root <jamie@mosaicml.com>
4 people authored and A-Jacobson committed Feb 10, 2022
1 parent a34bf3b commit 85c82ca
Showing 20 changed files with 655 additions and 247 deletions.
33 changes: 26 additions & 7 deletions composer/algorithms/alibi/alibi.py
@@ -8,13 +8,14 @@
from dataclasses import asdict, dataclass
from operator import attrgetter
from types import MethodType, ModuleType
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Type, Union, cast

import torch
import yahp as hp

from composer.algorithms import AlgorithmHparams
from composer.core import Algorithm, Event, Logger, State, surgery
from composer.core.types import Optimizers

log = logging.getLogger(__name__)

@@ -56,9 +57,17 @@ def initialize_object(self) -> "Alibi":
return Alibi(**asdict(self))


def apply_alibi(model: torch.nn.Module, heads_per_layer: int, max_sequence_length: int,
position_embedding_attribute: str, attention_module: torch.nn.Module, attr_to_replace: str,
alibi_attention: Callable, mask_replacement_function: Union[Callable, None]) -> None:
def apply_alibi(
model: torch.nn.Module,
heads_per_layer: int,
max_sequence_length: int,
position_embedding_attribute: str,
attention_module: Type[torch.nn.Module],
attr_to_replace: str,
alibi_attention: Callable,
mask_replacement_function: Union[Callable, None],
optimizers: Optional[Optimizers] = None,
) -> None:
"""
Removes position embeddings and replaces the attention function and attention mask
according to `ALiBi <https://arxiv.org/abs/2108.12409>`_.
@@ -85,6 +94,14 @@ def apply_alibi(model: torch.nn.Module, heads_per_layer: int, max_sequence_lengt
attention mask. This is sometimes necessary for evaluating
on sequence lengths longer than the model was initialized to
accommodate.
optimizers (Optimizers, optional): Existing optimizers bound to ``model.parameters()``.
All optimizers that have already been constructed with
``model.parameters()`` must be specified here so that they will optimize
the correct parameters.
If the optimizer(s) are constructed *after* calling this function,
then it is safe to omit this parameter; those optimizers will already
see the correct model parameters.
"""

zero_and_freeze_expand_position_embeddings(model=model,
@@ -100,8 +117,9 @@ def convert_attention(module: torch.nn.Module, module_index: Optional[int] = Non
module = mask_replacement_function(module, max_sequence_length)
return module

transforms = {attention_module: convert_attention}
replaced_pairs = surgery.replace_module_classes(model, transforms) # type: ignore
replaced_pairs = surgery.replace_module_classes(model,
optimizers=optimizers,
policies={attention_module: convert_attention})

count = len(replaced_pairs)
log.info(f" {count} instances of ALiBi added")
@@ -183,7 +201,8 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:

apply_alibi(
state.model,
heads_per_layer=self.hparams.heads_per_layer, # type: ignore
optimizers=state.optimizers,
heads_per_layer=cast(int, self.hparams.heads_per_layer),
max_sequence_length=self.hparams.max_sequence_length,
position_embedding_attribute=self.hparams.position_embedding_attribute,
attr_to_replace=self.hparams.attr_to_replace,
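
Why the new ``optimizers`` argument matters: module surgery swaps modules out of the model, so an optimizer constructed beforehand is left holding references to parameters that no longer belong to the model. A minimal sketch of the pattern; the toy model and replacement policy are illustrative, only the ``replace_module_classes`` call mirrors this diff:

import torch

from composer.core import surgery

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Replace every Linear with a fresh Linear of the same shape.
def replace_linear(module: torch.nn.Module, module_index: int):
    return torch.nn.Linear(4, 4)

# Passing the optimizer lets surgery swap the stale parameter references in
# opt.param_groups for the new module's parameters. Without it, opt would
# keep updating the old, now-orphaned weights.
surgery.replace_module_classes(model, optimizers=opt, policies={torch.nn.Linear: replace_linear})
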
14 changes: 12 additions & 2 deletions composer/algorithms/blurpool/blurpool.py
@@ -14,6 +14,7 @@
from composer.algorithms import AlgorithmHparams
from composer.algorithms.blurpool.blurpool_layers import BlurConv2d, BlurMaxPool2d
from composer.core import Algorithm, Event, Logger, State, surgery
from composer.core.types import Optimizers

log = logging.getLogger(__name__)

@@ -27,6 +28,7 @@ def _log_surgery_result(model: torch.nn.Module):


def apply_blurpool(model: torch.nn.Module,
optimizers: Optional[Optimizers] = None,
replace_convs: bool = True,
replace_maxpools: bool = True,
blur_first: bool = True) -> None:
@@ -38,6 +40,14 @@ def apply_blurpool(model: torch.nn.Module,
Args:
model: model to modify
optimizers (Optimizers, optional): Existing optimizers bound to ``model.parameters()``.
All optimizers that have already been constructed with
``model.parameters()`` must be specified here so that they will optimize
the correct parameters.
If the optimizer(s) are constructed *after* calling this function,
then it is safe to omit this parameter; those optimizers will already
see the correct model parameters.
replace_convs: replace strided :class:`torch.nn.Conv2d` modules with
:class:`BlurConv2d` modules
replace_maxpools: replace eligible :class:`torch.nn.MaxPool2d` modules
@@ -57,7 +67,7 @@ def apply_blurpool(model: torch.nn.Module,
_maybe_replace_strided_conv2d,
blur_first=blur_first,
)
surgery.replace_module_classes(model, policies=transforms)
surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms)
_log_surgery_result(model)


@@ -133,7 +143,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
"""
assert state.model is not None

apply_blurpool(state.model, **asdict(self.hparams))
apply_blurpool(state.model, optimizers=state.optimizers, **asdict(self.hparams))
self._log_results(event, state, logger)

def _log_results(self, event: Event, state: State, logger: Logger) -> None:
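
A usage sketch of the two call orders the docstring describes; the import path follows the file touched in this diff, while the torchvision model is illustrative:

import torch
import torchvision.models as models

from composer.algorithms.blurpool.blurpool import apply_blurpool

# Optimizer constructed first: it must be passed through to surgery.
model = models.resnet18()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
apply_blurpool(model, optimizers=opt)

# Optimizer constructed after surgery: safe to omit the parameter, since it
# is built from the already-modified model's parameters.
model2 = models.resnet18()
apply_blurpool(model2)
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
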
25 changes: 19 additions & 6 deletions composer/algorithms/factorize/factorize.py
@@ -13,6 +13,7 @@
from composer.algorithms.factorize.factorize_modules import (FactorizedConv2d, FactorizedLinear,
factorizing_could_speedup)
from composer.core import Algorithm, Event, Logger, State, surgery
from composer.core.types import Optimizers

log = logging.getLogger(__name__)

@@ -26,7 +27,10 @@ def _python_log_surgery_result(model: torch.nn.Module, new_class: Type[torch.nn.
f'Model now has {num_replaced_modules} {new_class.__name__} modules')


def factorize_conv2d_modules(model: torch.nn.Module, min_channels: int, latent_channels: Union[int, float]):
def factorize_conv2d_modules(model: torch.nn.Module,
min_channels: int,
latent_channels: Union[int, float],
optimizers: Optional[Optimizers] = None):
"""Replaces :class:`torch.nn.Conv2d` modules in ``model`` with :class:`~composer.algorithms.factorize.FactorizedConv2d` modules. See :class:`Factorize` for details."""

def _maybe_replace_conv2d(module: torch.nn.Module, module_index: int) -> Optional[torch.nn.Module]:
@@ -36,12 +40,17 @@ def _maybe_replace_conv2d(module: torch.nn.Module, module_index: int) -> Optiona
return FactorizedConv2d.from_conv2d(module, module_index, latent_channels=latent_channels)
return None # not enough rank reduction to be worth it

ret = surgery.replace_module_classes(model, {torch.nn.Conv2d: _maybe_replace_conv2d})
ret = surgery.replace_module_classes(model,
optimizers=optimizers,
policies={torch.nn.Conv2d: _maybe_replace_conv2d})
_python_log_surgery_result(model, FactorizedConv2d)
return ret


def factorize_linear_modules(model: torch.nn.Module, min_features: int, latent_features: Union[int, float]):
def factorize_linear_modules(model: torch.nn.Module,
min_features: int,
latent_features: Union[int, float],
optimizers: Optional[Optimizers] = None):
"""Replaces :class:`torch.nn.Linear` modules in ``model`` with :class:`~composer.algorithms.factorize.FactorizedLinear` modules. See :class:`Factorize` for details."""

def _maybe_replace_linear(module: torch.nn.Module, module_index: int) -> Optional[torch.nn.Module]:
@@ -51,7 +60,9 @@ def _maybe_replace_linear(module: torch.nn.Module, module_index: int) -> Optiona
return FactorizedLinear.from_linear(module, module_index, latent_features=latent_features)
return None # not enough rank reduction to be worth it

ret = surgery.replace_module_classes(model, {torch.nn.Linear: _maybe_replace_linear})
ret = surgery.replace_module_classes(model,
optimizers=optimizers,
policies={torch.nn.Linear: _maybe_replace_linear})
_python_log_surgery_result(model, FactorizedLinear)
return ret

@@ -175,15 +186,17 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
if self.hparams.factorize_convs:
factorize_conv2d_modules(state.model,
min_channels=self.hparams.min_channels,
latent_channels=self.hparams.latent_channels)
latent_channels=self.hparams.latent_channels,
optimizers=state.optimizers)
num_factorized = surgery.count_module_instances(state.model, FactorizedConv2d)
logger.metric_fit({
LOG_NUM_CONV2D_REPLACEMENTS_KEY: num_factorized,
})
if self.hparams.factorize_linears:
factorize_linear_modules(state.model,
min_features=self.hparams.min_features,
latent_features=self.hparams.latent_features)
latent_features=self.hparams.latent_features,
optimizers=state.optimizers)
num_factorized = surgery.count_module_instances(state.model, FactorizedLinear)
logger.metric_fit({
LOG_NUM_LINEAR_REPLACEMENTS_KEY: num_factorized,
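
The same pattern for factorization, with the replacement count checked the way the algorithm's ``apply`` does; the model choice and thresholds are illustrative:

import torch
import torchvision.models as models

from composer.algorithms.factorize.factorize import factorize_conv2d_modules
from composer.algorithms.factorize.factorize_modules import FactorizedConv2d
from composer.core import surgery

model = models.resnet50()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Only convolutions with at least min_channels are candidates; the optimizer
# is passed so the factorized parameters replace the originals in its state.
factorize_conv2d_modules(model, min_channels=256, latent_channels=0.25, optimizers=opt)
print(surgery.count_module_instances(model, FactorizedConv2d))
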
17 changes: 14 additions & 3 deletions composer/algorithms/ghost_batchnorm/ghost_batchnorm.py
@@ -12,6 +12,7 @@

from composer.algorithms import AlgorithmHparams
from composer.core import Algorithm, Event, Logger, State, surgery
from composer.core.types import Optimizers

log = logging.getLogger(__name__)

@@ -99,7 +100,9 @@ class GhostBatchNorm3d(_GhostBatchNorm):
pass


def apply_ghost_batchnorm(model: torch.nn.Module, ghost_batch_size: int) -> torch.nn.Module:
def apply_ghost_batchnorm(model: torch.nn.Module,
ghost_batch_size: int,
optimizers: Optional[Optimizers] = None) -> torch.nn.Module:
"""Replace batch normalization modules with ghost batch normalization modules.
Must be run before the model has been moved to accelerators and before
@@ -108,6 +111,14 @@ def apply_ghost_batchnorm(model: torch.nn.Module, ghost_batch_size: int) -> torc
Args:
model: model to transform
ghost_batch_size: size of sub-batches to normalize over
optimizers (Optimizers, optional): Existing optimizers bound to ``model.parameters()``.
All optimizers that have already been constructed with
``model.parameters()`` must be specified here so that they will optimize
the correct parameters.
If the optimizer(s) are constructed *after* calling this function,
then it is safe to omit this parameter; those optimizers will already
see the correct model parameters.
"""

def maybe_replace(module: torch.nn.Module, module_index: int) -> Optional[torch.nn.Module]:
@@ -117,7 +128,7 @@ def maybe_replace(module: torch.nn.Module, module_index: int) -> Optional[torch.
# we have to specify class names explicitly because replace_module_classes
# now checks if `module.__class__ == cls`, rather than `isinstance(module, cls)`
transforms = {cls: maybe_replace for cls in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]}
surgery.replace_module_classes(model, policies=transforms)
surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms)
return model


@@ -162,7 +173,7 @@ def apply(self, event: Event, state: State, logger: Optional[Logger] = None) ->
"""
assert state.model is not None, "Model must be in state"

apply_ghost_batchnorm(model=state.model, ghost_batch_size=self.ghost_batch_size)
apply_ghost_batchnorm(model=state.model, optimizers=state.optimizers, ghost_batch_size=self.ghost_batch_size)
self._log_results(event, state, logger)

def _log_results(self, event: Event, state: State, logger: Optional[Logger] = None) -> None:
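
The in-diff comment about listing class names explicitly is worth unpacking: ``replace_module_classes`` matches on ``module.__class__ == cls`` rather than ``isinstance``, so a policy keyed on the shared ``_BatchNorm`` base class would match nothing. A standalone illustration of the distinction:

import torch

bn = torch.nn.BatchNorm2d(8)

# isinstance would also accept the shared base class...
print(isinstance(bn, torch.nn.modules.batchnorm._BatchNorm))  # True

# ...but exact-class matching only fires on the concrete class, which is
# why the policy dict lists BatchNorm1d/2d/3d individually.
print(bn.__class__ == torch.nn.BatchNorm2d)                   # True
print(bn.__class__ == torch.nn.modules.batchnorm._BatchNorm)  # False
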
16 changes: 12 additions & 4 deletions composer/algorithms/squeeze_excite/squeeze_excite.py
@@ -11,6 +11,7 @@

from composer.algorithms.algorithm_hparams import AlgorithmHparams
from composer.core import Algorithm, Event, Logger, State, surgery
from composer.core.types import Optimizers

log = logging.getLogger(__name__)

@@ -81,16 +82,22 @@ def from_conv2d(module: torch.nn.Conv2d, module_index: int, latent_channels: flo
return SqueezeExciteConv2d(conv=module, latent_channels=latent_channels)


def apply_se(model: torch.nn.Module, latent_channels: float, min_channels: int):
def apply_se(
model: torch.nn.Module,
latent_channels: float,
min_channels: int,
optimizers: Optional[Optimizers] = None,
):
"""See :class:`SqueezeExcite`"""

def convert_module(module: torch.nn.Conv2d, module_index: int):
def convert_module(module: torch.nn.Module, module_index: int):
assert isinstance(module, torch.nn.Conv2d), "should only be called with conv2d"
if min(module.in_channels, module.out_channels) < min_channels:
return None
return SqueezeExciteConv2d.from_conv2d(module, module_index, latent_channels=latent_channels)

transforms = {torch.nn.Conv2d: convert_module}
surgery.replace_module_classes(model, transforms) # type: ignore
surgery.replace_module_classes(model, optimizers=optimizers, policies={torch.nn.Conv2d: convert_module})

return model


@@ -142,6 +149,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
logger (Logger): the training logger
"""
state.model = apply_se(state.model,
optimizers=state.optimizers,
latent_channels=self.hparams.latent_channels,
min_channels=self.hparams.min_channels)
layer_count = surgery.count_module_instances(state.model, SqueezeExciteConv2d)
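
Functional usage, mirroring how ``apply`` wires it up from ``state``; the model and the threshold values are illustrative:

import torch
import torchvision.models as models

from composer.algorithms.squeeze_excite.squeeze_excite import apply_se

model = models.resnet50()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Squeeze-Excite blocks wrap every Conv2d whose channel counts meet
# min_channels; apply_se returns the same (mutated) model object.
model = apply_se(model, latent_channels=64, min_channels=128, optimizers=opt)
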
13 changes: 12 additions & 1 deletion composer/algorithms/stochastic_depth/stochastic_depth.py
@@ -14,6 +14,7 @@
from composer.algorithms.stochastic_depth.sample_stochastic_layers import SampleStochasticBottleneck
from composer.algorithms.stochastic_depth.stochastic_layers import StochasticBottleneck
from composer.core import Algorithm, Event, Logger, State, surgery
from composer.core.types import Optimizers
from composer.models.resnets import Bottleneck

log = logging.getLogger(__name__)
@@ -92,6 +93,7 @@ def validate(self):
def apply_stochastic_depth(model: torch.nn.Module,
stochastic_method: str,
target_layer_name: str,
optimizers: Optional[Optimizers] = None,
drop_rate: float = 0.2,
drop_distribution: str = 'linear',
use_same_gpu_seed: bool = True) -> None:
@@ -113,6 +115,14 @@ def apply_stochastic_depth(model: torch.nn.Module,
equivalent. The name must be registered in ``STOCHASTIC_LAYER_MAPPING``
dictionary with the target layer class and the stochastic layer class.
Currently, only ``'ResNetBottleneck'`` is supported.
optimizers (Optimizers, optional): Existing optimizers bound to ``model.parameters()``.
All optimizers that have already been constructed with
``model.parameters()`` must be specified here so that they will optimize
the correct parameters.
If the optimizer(s) are constructed *after* calling this function,
then it is safe to omit this parameter; those optimizers will already
see the correct model parameters.
drop_rate: The base probability of dropping a layer or sample. Must be
between 0.0 and 1.0.
drop_distribution: How ``drop_rate`` is distributed across
@@ -145,7 +155,7 @@ def apply_stochastic_depth(model: torch.nn.Module,
raise ValueError(f"stochastic_method {stochastic_method} is not supported."
f" Must be one of {list(STOCHASTIC_LAYER_MAPPING.keys())}")
transforms[target_layer] = stochastic_from_target_layer
surgery.replace_module_classes(model, policies=transforms)
surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms)


def _update_drop_rate(module: torch.nn.Module, stochastic_block: Type[torch.nn.Module], drop_rate: float,
@@ -258,6 +268,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
log.warning(f'No {self.hparams.target_layer_name} found in model! Algorithm will function as a no-op.')

apply_stochastic_depth(state.model,
optimizers=state.optimizers,
stochastic_method=self.hparams.stochastic_method,
target_layer_name=self.hparams.target_layer_name,
drop_rate=self.hparams.drop_rate,
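
A usage sketch. ``build_composer_resnet()`` is a hypothetical stand-in, since the target blocks must be ``composer.models.resnets.Bottleneck`` (a torchvision ResNet would not match the registered ``'ResNetBottleneck'``), and ``'block'`` as the stochastic method is likewise an assumption not shown in this diff; the remaining values are the documented defaults:

import torch

from composer.algorithms.stochastic_depth.stochastic_depth import apply_stochastic_depth

model = build_composer_resnet()  # hypothetical: must use composer.models.resnets.Bottleneck blocks
opt = torch.optim.SGD(model.parameters(), lr=0.1)

apply_stochastic_depth(model,
                       stochastic_method='block',  # assumed value; see STOCHASTIC_LAYER_MAPPING
                       target_layer_name='ResNetBottleneck',
                       optimizers=opt,
                       drop_rate=0.2,
                       drop_distribution='linear')
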
7 changes: 5 additions & 2 deletions composer/core/engine.py
@@ -185,10 +185,13 @@ def _compile(
Returns:
algorithms_to_run(Sequence[Algorithm]): modified sequence of algorithms
"""
from composer.algorithms import SelectiveBackprop
from composer.algorithms import SelectiveBackprop, StochasticDepth

# Move selective backprop to the beginning while maintaining order of other algorithms
algorithms = sorted(algorithms_to_run, key=lambda x: not isinstance(x, SelectiveBackprop))
algorithms = sorted(algorithms_to_run,
key=lambda x: not isinstance(x, SelectiveBackprop) and not isinstance(x, StochasticDepth))

print(event, algorithms)

if event.is_after_event:
"""Establish a FILO queue of algorithms before_ and after_ an event.
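
The reordering in ``_compile`` leans on two facts: ``False`` sorts before ``True``, and Python's ``sorted`` is stable. Algorithms whose key evaluates to ``False`` (SelectiveBackprop, StochasticDepth) therefore move to the front, and everything else keeps its relative order. A standalone sketch of the same trick:

algos = ['CutMix', 'StochasticDepth', 'BlurPool', 'SelectiveBackprop']
run_first = {'StochasticDepth', 'SelectiveBackprop'}

# Stable sort on a boolean key: False < True, ties keep original order.
print(sorted(algos, key=lambda name: name not in run_first))
# ['StochasticDepth', 'SelectiveBackprop', 'CutMix', 'BlurPool']
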