Change functional surgery method return values to None (#1543)
Some functional surgery methods previously returned the model object; they now perform the surgery in place and return None.
nik-mosaic authored Feb 6, 2023
1 parent 48d40f9 commit c191b37
Showing 11 changed files with 16 additions and 38 deletions.
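
Before the per-file diffs, a minimal before/after sketch of the new call pattern (mirroring the README change below):

```python
import composer.functional as cf
from torchvision import models

my_model = models.resnet18()

# Before this commit, the functional surgery methods returned the modified model:
#   my_model = cf.apply_blurpool(my_model)
# They now perform the surgery in place and return None:
cf.apply_blurpool(my_model)
cf.apply_squeeze_excite(my_model)
```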
4 changes: 2 additions & 2 deletions .ci/release_tests/example_1.py
@@ -4,7 +4,7 @@
my_model = models.resnet18()

# add blurpool and squeeze excite layers
-my_model = cf.apply_blurpool(my_model)
-my_model = cf.apply_squeeze_excite(my_model)
+cf.apply_blurpool(my_model)
+cf.apply_squeeze_excite(my_model)

# your own training code starts here
4 changes: 2 additions & 2 deletions README.md
@@ -112,8 +112,8 @@ from torchvision import models
my_model = models.resnet18()

# add blurpool and squeeze excite layers
-my_model = cf.apply_blurpool(my_model)
-my_model = cf.apply_squeeze_excite(my_model)
+cf.apply_blurpool(my_model)
+cf.apply_squeeze_excite(my_model)

# your own training code starts here
```
7 changes: 1 addition & 6 deletions composer/algorithms/blurpool/blurpool.py
@@ -26,7 +26,7 @@ def apply_blurpool(model: torch.nn.Module,
replace_maxpools: bool = True,
blur_first: bool = True,
min_channels: int = 16,
-optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> torch.nn.Module:
+optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> None:
"""Add anti-aliasing filters to strided :class:`torch.nn.Conv2d` and/or :class:`torch.nn.MaxPool2d` modules.
These filters increase invariance to small spatial shifts in the input
@@ -55,9 +55,6 @@ def apply_blurpool(model: torch.nn.Module,
then it is safe to omit this parameter. These optimizers will see
the correct model parameters.
-Returns:
-    The modified model
Example:
.. testcode::
@@ -78,8 +75,6 @@ def apply_blurpool(model: torch.nn.Module,
module_surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms)
_log_surgery_result(model)

-return model


class BlurPool(Algorithm):
"""`BlurPool <http://proceedings.mlr.press/v97/zhang19a.html>`_ adds anti-aliasing filters to convolutional layers.
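
A hedged usage sketch of the in-place form, using the keyword names shown in the signature above (calling `cf.apply_blurpool` as in the README diff; the optimizer setup is illustrative only):

```python
import torch
import composer.functional as cf
from torchvision import models

model = models.resnet18()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Surgery happens in place; passing already-constructed optimizers lets them
# track the replaced parameters (per the docstring above). Nothing is returned.
cf.apply_blurpool(model, replace_maxpools=True, blur_first=True, min_channels=16, optimizers=opt)
```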
2 changes: 1 addition & 1 deletion composer/algorithms/ema/ema.py
@@ -21,7 +21,7 @@
__all__ = ['EMA', 'compute_ema']


-def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: float = 0.99):
+def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: float = 0.99) -> None:
r"""Updates the weights of ``ema_model`` to be closer to the weights of ``model``
according to an exponential weighted average. Weights are updated according to
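
A brief sketch of the updated `compute_ema` (assuming it is exposed through `composer.functional`; the `deepcopy` setup is illustrative and not taken from this diff):

```python
import copy
import composer.functional as cf
from torchvision import models

model = models.resnet18()
ema_model = copy.deepcopy(model)

# Typically called after each optimizer step: pulls ema_model's weights toward
# model's weights via an exponential moving average. Updates in place, returns None.
cf.compute_ema(model, ema_model, smoothing=0.99)
```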
6 changes: 1 addition & 5 deletions composer/algorithms/factorize/factorize.py
@@ -28,7 +28,7 @@ def apply_factorization(model: torch.nn.Module,
latent_channels: Union[int, float] = 0.25,
min_features: int = 512,
latent_features: Union[int, float] = 0.25,
-optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> torch.nn.Module:
+optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> None:
"""Replaces :class:`torch.nn.Linear` and :class:`torch.nn.Conv2d` modules with
:class:`.FactorizedLinear` and :class:`.FactorizedConv2d` modules.
@@ -71,9 +71,6 @@ def apply_factorization(model: torch.nn.Module,
then it is safe to omit this parameter. These optimizers will see
the correct model parameters.
-Returns:
-    The modified model
Example:
.. testcode::
@@ -92,7 +89,6 @@ def apply_factorization(model: torch.nn.Module,
min_features=min_features,
latent_features=latent_features,
optimizers=optimizers)
-return model


class Factorize(Algorithm):
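
A hedged sketch of the in-place call (assuming `apply_factorization` is exposed through `composer.functional` like the other surgery functions; keyword names are taken from the signature above):

```python
import composer.functional as cf
from torchvision import models

model = models.resnet18()

# Replaces eligible Conv2d/Linear modules with factorized versions in place;
# the function no longer returns the model.
cf.apply_factorization(model, latent_channels=0.25, min_features=512, latent_features=0.25)
```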
composer/algorithms/gated_linear_units/gated_linear_units.py
@@ -96,7 +96,6 @@ def apply_gated_linear_units(model: torch.nn.Module,
NoEffectWarning('No instances of BertIntermediate were found so Gated Linear Units will be skipped '
'as no modules can be replaced. This is likely because Gated Linear Units has already '
'been applied to this model.'))
return

# get the activation functions used
act_fns = {module.intermediate_act_fn for module in intermediate_modules}
5 changes: 2 additions & 3 deletions composer/algorithms/ghost_batchnorm/ghost_batchnorm.py
@@ -19,7 +19,7 @@

def apply_ghost_batchnorm(model: torch.nn.Module,
ghost_batch_size: int = 32,
-optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> torch.nn.Module:
+optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> None:
"""Replace batch normalization modules with ghost batch normalization modules.
Ghost batch normalization modules split their input into chunks of
@@ -39,7 +39,7 @@ def apply_ghost_batchnorm(model: torch.nn.Module,
model parameters.
Returns:
-    The modified model
+    The number of modules modified.
Example:
.. testcode::
@@ -59,7 +59,6 @@ def maybe_replace(module: torch.nn.Module, module_index: int) -> Optional[torch.
# now checks if `module.__class__ == cls`, rather than `isinstance(module, cls)`
transforms = {cls: maybe_replace for cls in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]}
module_surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms)
-return model


class GhostBatchNorm(Algorithm):
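
A short in-place usage sketch (assuming `apply_ghost_batchnorm` is exposed through `composer.functional`):

```python
import composer.functional as cf
from torchvision import models

model = models.resnet18()

# Replaces BatchNorm1d/2d/3d modules with ghost batch normalization in place,
# normalizing over chunks of ghost_batch_size samples; the call returns None.
cf.apply_ghost_batchnorm(model, ghost_batch_size=32)
```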
composer/algorithms/low_precision_layernorm/low_precision_layernorm.py
@@ -103,8 +103,6 @@ def apply_low_precision_layernorm(model, optimizers: Union[torch.optim.Optimizer
warnings.warn(NoEffectWarning('No instances of torch.nn.LayerNorm found.'))
log.info(f'Successfully replaced {len(replaced_instances)} instances of LayerNorm with LowPrecisionLayerNorm')

-return model


class LowPrecisionLayerNorm(Algorithm):
"""
15 changes: 5 additions & 10 deletions composer/algorithms/squeeze_excite/squeeze_excite.py
@@ -21,7 +21,7 @@ def apply_squeeze_excite(
latent_channels: float = 64,
min_channels: int = 128,
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
-):
+) -> None:
"""Adds Squeeze-and-Excitation blocks (`Hu et al, 2019 <https://arxiv.org/abs/1709.01507>`_) after
:class:`torch.nn.Conv2d` layers.
@@ -50,9 +50,6 @@ def apply_squeeze_excite(
then it is safe to omit this parameter. These optimizers will see the correct
model parameters.
-Returns:
-    The modified model
Example:
.. testcode::
@@ -73,8 +70,6 @@ def convert_module(module: torch.nn.Module, module_index: int):

module_surgery.replace_module_classes(model, optimizers=optimizers, policies={torch.nn.Conv2d: convert_module})

-return model


class SqueezeExcite2d(torch.nn.Module):
"""Squeeze-and-Excitation block from (`Hu et al, 2019 <https://arxiv.org/abs/1709.01507>`_)
@@ -164,10 +159,10 @@ def match(self, event: Event, state: State) -> bool:
return event == Event.INIT

def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
-state.model = apply_squeeze_excite(state.model,
-                                   optimizers=state.optimizers,
-                                   latent_channels=self.latent_channels,
-                                   min_channels=self.min_channels)
+apply_squeeze_excite(state.model,
+                     optimizers=state.optimizers,
+                     latent_channels=self.latent_channels,
+                     min_channels=self.min_channels)
layer_count = module_surgery.count_module_instances(state.model, SqueezeExciteConv2d)

log.info(f'Applied SqueezeExcite to model {state.model.__class__.__name__} '
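
Because nothing is returned, callers that want the replacement count can mirror the pattern in `SqueezeExcite.apply` above. A hedged sketch (the `module_surgery` and `SqueezeExciteConv2d` import paths are assumptions):

```python
import composer.functional as cf
from composer.algorithms.squeeze_excite import SqueezeExciteConv2d  # assumed import path
from composer.utils import module_surgery  # assumed import path
from torchvision import models

model = models.resnet18()
cf.apply_squeeze_excite(model, latent_channels=64, min_channels=128)

# Count the inserted blocks instead of relying on a return value.
n_se = module_surgery.count_module_instances(model, SqueezeExciteConv2d)
print(f'Added {n_se} SqueezeExcite blocks')
```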
6 changes: 1 addition & 5 deletions composer/algorithms/stochastic_depth/stochastic_depth.py
@@ -33,7 +33,7 @@ def apply_stochastic_depth(model: torch.nn.Module,
target_layer_name: str,
stochastic_method: str = 'block',
drop_rate: float = 0.2,
-drop_distribution: str = 'linear') -> torch.nn.Module:
+drop_distribution: str = 'linear') -> None:
"""Applies Stochastic Depth (`Huang et al, 2016 <https://arxiv.org/abs/1603.09382>`_) to the specified model.
The algorithm replaces the specified target layer with a stochastic version
@@ -67,9 +67,6 @@ def apply_stochastic_depth(model: torch.nn.Module,
starting with 0 drop rate and ending with ``drop_rate``.
Default: ``"linear"``.
-Returns:
-    The modified model
Example:
.. testcode::
@@ -95,7 +92,6 @@ def apply_stochastic_depth(model: torch.nn.Module,
stochastic_method=stochastic_method)
transforms[target_layer] = stochastic_from_target_layer
module_surgery.replace_module_classes(model, policies=transforms)
-return model


class StochasticDepth(Algorithm):
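
A hedged sketch of the in-place call (assuming `apply_stochastic_depth` is exposed through `composer.functional`; the `'ResNetBottleneck'` target name is an assumption, not taken from this diff):

```python
import composer.functional as cf
from torchvision import models

model = models.resnet50()

# Replaces the target layer with a stochastic version in place; the function
# now returns None instead of the model.
cf.apply_stochastic_depth(model,
                          target_layer_name='ResNetBottleneck',  # assumed target name
                          stochastic_method='block',
                          drop_rate=0.2,
                          drop_distribution='linear')
```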
2 changes: 1 addition & 1 deletion examples/exporting_for_inference.ipynb
@@ -80,7 +80,7 @@
"import composer.functional as cf\n",
"\n",
"model = ComposerClassifier(module=resnet.resnet50())\n",
-"model = cf.apply_squeeze_excite(model)\n",
+"cf.apply_squeeze_excite(model)\n",
"\n",
"# switch to eval mode\n",
"model.eval()"
