
Change functional surgery method return values to None #1543

Merged
33 commits merged on Feb 6, 2023
Changes from 11 commits

Commits
dde9fd0
Remove model return values
nik-mosaic Sep 20, 2022
a98939a
Update example ci release test
nik-mosaic Sep 20, 2022
a5f3248
Update notebook/readme
nik-mosaic Sep 20, 2022
fd23b44
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Sep 21, 2022
b424858
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Sep 23, 2022
72cd516
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Sep 26, 2022
f6d9ecc
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Sep 28, 2022
2647820
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Oct 1, 2022
9c72fec
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Oct 3, 2022
b0cba66
Return number of modules replaced for applicable model surgery methods
nik-mosaic Oct 4, 2022
289dba2
Fix factorize
nik-mosaic Oct 4, 2022
84e1501
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Oct 4, 2022
5d0fe36
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Oct 4, 2022
e5f9ddd
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Oct 5, 2022
cd9ef4f
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Oct 6, 2022
1ae17cc
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Oct 7, 2022
933751e
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Oct 11, 2022
d250c44
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Oct 18, 2022
4b2b7f5
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Oct 21, 2022
44cce02
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Oct 27, 2022
f490293
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Nov 1, 2022
d57c696
Merge branch 'mosaicml:dev' into nikhil/consistent-surgery
nik-mosaic Nov 16, 2022
4f22daa
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Jan 24, 2023
c8cd91b
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Jan 30, 2023
859902f
Remove int return values
nik-mosaic Jan 30, 2023
54ccd18
Merge branch 'nikhil/consistent-surgery' of https://github.com/nik-mo…
nik-mosaic Jan 30, 2023
179f872
Remove int return value from LPLN
nik-mosaic Jan 30, 2023
647d506
Update GLU
nik-mosaic Jan 30, 2023
896d0a5
Rerun Jenkins
nik-mosaic Jan 30, 2023
9be5cee
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Jan 30, 2023
aa58150
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Feb 1, 2023
72f04b2
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Feb 2, 2023
bbc6cb5
Merge branch 'dev' into nikhil/consistent-surgery
nik-mosaic Feb 6, 2023
4 changes: 2 additions & 2 deletions .ci/release_tests/example_1.py
@@ -4,7 +4,7 @@
my_model = models.resnet18()

# add blurpool and squeeze excite layers
my_model = cf.apply_blurpool(my_model)
my_model = cf.apply_squeeze_excite(my_model)
cf.apply_blurpool(my_model)
cf.apply_squeeze_excite(my_model)

# your own training code starts here
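
The hunk above shows the calling pattern this PR standardizes: surgery methods modify the model in place instead of being assigned back. A minimal usage sketch for the state of this snapshot ("Changes from 11 commits"), where the `apply_*` functions still return the number of modules they replaced; per the later "Remove int return values" commits, the merged version returns `None` instead:

```python
import composer.functional as cf
from torchvision import models

my_model = models.resnet18()

# The model is modified in place; at this point in the branch the calls also
# report how many modules were swapped or wrapped.
num_blurpool = cf.apply_blurpool(my_model)
num_squeeze_excite = cf.apply_squeeze_excite(my_model)
print(f'blurpool replacements: {num_blurpool}, squeeze-excite additions: {num_squeeze_excite}')
```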
4 changes: 2 additions & 2 deletions README.md
@@ -112,8 +112,8 @@ from torchvision import models
my_model = models.resnet18()

# add blurpool and squeeze excite layers
my_model = cf.apply_blurpool(my_model)
my_model = cf.apply_squeeze_excite(my_model)
cf.apply_blurpool(my_model)
cf.apply_squeeze_excite(my_model)

# your own training code starts here
```
7 changes: 6 additions & 1 deletion composer/algorithms/alibi/alibi.py
@@ -24,7 +24,7 @@ def apply_alibi(
model: torch.nn.Module,
max_sequence_length: int,
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
) -> None:
) -> int:
"""Removes position embeddings and replaces the attention function and attention mask
as per :class:`.Alibi`. Note that the majority of the training speed-up from using ALiBi
comes from being able to train on shorter sequence lengths; this function does not scale
@@ -67,6 +67,9 @@ def apply_alibi(
If the optimizer(s) are constructed *after* calling this function,
then it is safe to omit this parameter. These optimizers will see the correct
model parameters.

Returns:
The number of modules modified.
"""
try:
from composer.algorithms.alibi.attention_surgery_functions import policy_registry
@@ -108,6 +111,8 @@ def replacement_function(module: torch.nn.Module, module_index: int):
else:
log.info(f' {count} instances of ALiBi added')

return count


class Alibi(Algorithm):
"""ALiBi (Attention with Linear Biases; `Press et al, 2021 <https://arxiv.org/abs/2108.12409>`_) dispenses with
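
The docstring above describes replacing the attention function and mask rather than the underlying math. As a rough illustration of the ALiBi idea itself, each head adds a linear penalty to the attention logits that grows with query/key distance. This is illustration only, not Composer's surgery code; the function name and the power-of-two slope formula are assumptions:

```python
import torch

def alibi_bias_sketch(num_heads: int, seq_len: int) -> torch.Tensor:
    # Head-specific slopes: a geometric sequence, shown here for the simple case
    # where num_heads is a power of two.
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
    distances = (torch.arange(seq_len)[None, :] - torch.arange(seq_len)[:, None]).abs()
    # Shape (num_heads, seq_len, seq_len); added to the attention logits before the softmax.
    return -slopes[:, None, None] * distances

bias = alibi_bias_sketch(num_heads=8, seq_len=4)
```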
8 changes: 4 additions & 4 deletions composer/algorithms/blurpool/blurpool.py
@@ -26,7 +26,7 @@ def apply_blurpool(model: torch.nn.Module,
replace_maxpools: bool = True,
blur_first: bool = True,
min_channels: int = 16,
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> torch.nn.Module:
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> int:
"""Add anti-aliasing filters to strided :class:`torch.nn.Conv2d` and/or :class:`torch.nn.MaxPool2d` modules.

These filters increase invariance to small spatial shifts in the input
@@ -56,7 +56,7 @@ def apply_blurpool(model: torch.nn.Module,
the correct model parameters.

Returns:
The modified model
The number of modules modified.

Example:
.. testcode::
@@ -75,10 +75,10 @@ def apply_blurpool(model: torch.nn.Module,
blur_first=blur_first,
min_channels=min_channels,
)
module_surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms)
replaced_instances = module_surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms)
_log_surgery_result(model)

return model
return len(replaced_instances)


class BlurPool(Algorithm):
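
For context on the algorithm being modified here: BlurPool replaces strided convolutions and max-pools with a blur-then-subsample step. A self-contained sketch of that idea, not Composer's implementation; the helper name is made up:

```python
import torch
import torch.nn.functional as F

def blurpool_sketch(x: torch.Tensor) -> torch.Tensor:
    channels = x.shape[1]
    # Fixed 3x3 binomial low-pass filter, applied depthwise to every channel.
    kernel = torch.tensor([[1., 2., 1.], [2., 4., 2.], [1., 2., 1.]]) / 16.0
    kernel = kernel[None, None].repeat(channels, 1, 1, 1)  # (C, 1, 3, 3)
    # Blur first, then downsample by striding; the blur is what reduces aliasing from small shifts.
    return F.conv2d(x, kernel, stride=2, padding=1, groups=channels)

out = blurpool_sketch(torch.randn(2, 16, 32, 32))  # -> (2, 16, 16, 16)
```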
2 changes: 1 addition & 1 deletion composer/algorithms/ema/ema.py
@@ -21,7 +21,7 @@
__all__ = ['EMA', 'compute_ema']


def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: float = 0.99):
def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: float = 0.99) -> None:
r"""Updates the weights of ``ema_model`` to be closer to the weights of ``model``
according to an exponential weighted average. Weights are updated according to

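
`compute_ema` only gains an explicit `-> None` annotation here, since it already updated `ema_model` in place. Because the docstring describes the update rule, a minimal sketch of an exponential-moving-average parameter update may help; illustration only, not Composer's implementation:

```python
import copy
import torch

def compute_ema_sketch(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: float = 0.99) -> None:
    # Each EMA parameter moves toward the matching model parameter by a factor of (1 - smoothing).
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(smoothing).add_(param, alpha=1.0 - smoothing)

model = torch.nn.Linear(4, 4)
ema_model = copy.deepcopy(model)
compute_ema_sketch(model, ema_model, smoothing=0.99)
```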
25 changes: 14 additions & 11 deletions composer/algorithms/factorize/factorize.py
@@ -28,7 +28,7 @@ def apply_factorization(model: torch.nn.Module,
latent_channels: Union[int, float] = 0.25,
min_features: int = 512,
latent_features: Union[int, float] = 0.25,
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> torch.nn.Module:
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> int:
"""Replaces :class:`torch.nn.Linear` and :class:`torch.nn.Conv2d` modules with
:class:`.FactorizedLinear` and :class:`.FactorizedConv2d` modules.

@@ -72,7 +72,7 @@ def apply_factorization(model: torch.nn.Module,
the correct model parameters.

Returns:
The modified model
The number of modules modified.

Example:
.. testcode::
@@ -82,17 +82,20 @@ def apply_factorization(model: torch.nn.Module,
model = models.resnet50()
cf.apply_factorization(model)
"""
replaced_conv_instances = {}
replaced_linear_instances = {}
if factorize_convs:
_factorize_conv2d_modules(model,
min_channels=min_channels,
latent_channels=latent_channels,
optimizers=optimizers)
replaced_conv_instances = _factorize_conv2d_modules(model,
min_channels=min_channels,
latent_channels=latent_channels,
optimizers=optimizers)
if factorize_linears:
_factorize_linear_modules(model,
min_features=min_features,
latent_features=latent_features,
optimizers=optimizers)
return model
replaced_linear_instances = _factorize_linear_modules(model,
min_features=min_features,
latent_features=latent_features,
optimizers=optimizers)

return len(replaced_conv_instances) + len(replaced_linear_instances)


class Factorize(Algorithm):
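
The hunk above only changes what `apply_factorization` returns; the surgery itself swaps `Linear`/`Conv2d` modules for low-rank factorized pairs. A rough sketch of the low-rank idea for a single linear layer, using a hypothetical helper rather than Composer's `FactorizedLinear`:

```python
import torch

def factorize_linear_sketch(linear: torch.nn.Linear, latent_features: int) -> torch.nn.Sequential:
    # Replace one Linear with two smaller ones whose composition approximates it,
    # initialized from a truncated SVD of the original weight.
    W = linear.weight.detach()                      # (out_features, in_features)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    k = latent_features
    first = torch.nn.Linear(linear.in_features, k, bias=False)
    second = torch.nn.Linear(k, linear.out_features, bias=linear.bias is not None)
    first.weight.data.copy_(torch.diag(S[:k].sqrt()) @ Vh[:k])      # (k, in_features)
    second.weight.data.copy_(U[:, :k] @ torch.diag(S[:k].sqrt()))   # (out_features, k)
    if linear.bias is not None:
        second.bias.data.copy_(linear.bias.detach())
    return torch.nn.Sequential(first, second)

layer = torch.nn.Linear(512, 512)
factorized = factorize_linear_sketch(layer, latent_features=128)
```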
5 changes: 4 additions & 1 deletion composer/algorithms/fused_layernorm/fused_layernorm.py
@@ -40,11 +40,13 @@ def from_LayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerN


def apply_fused_layernorm(model: torch.nn.Module, optimizers: Union[torch.optim.Optimizer,
Sequence[torch.optim.Optimizer]]) -> None:
Sequence[torch.optim.Optimizer]]) -> int:
"""Replaces all instances of `torch.nn.LayerNorm` with a `apex.normalization.fused_layer_norm.FusedLayerNorm
<https://nvidia.github.io/apex/layernorm.html>`_.

By fusing multiple kernel launches into one, this usually improves GPU utilization.
Returns:
The number of LayerNorms replaced with FusedLayerNorms.
"""
check_if_apex_installed()

@@ -56,6 +58,7 @@ def apply_fused_layernorm(model: torch.nn.Module, optimizers: Union[torch.optim.
NoEffectWarning(
'No instances of `torch.nn.LayerNorm` were found, and therefore, there were no modules to replace.'))
log.info(f'Successfully replaced {len(replaced_instances)} of LayerNorm with a Fused LayerNorm.')
return len(replaced_instances)


class FusedLayerNorm(Algorithm):
@@ -59,7 +59,7 @@ def apply_gated_linear_units(model: torch.nn.Module,
optimizers: Union[torch.optim.Optimizer, Sequence[torch.optim.Optimizer]],
act_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
gated_layer_bias: bool = False,
non_gated_layer_bias: bool = False) -> None:
non_gated_layer_bias: bool = False) -> int:
"""
Replaces the Linear layers in the feed-forward network with `Gated Linear Units <https://arxiv.org/abs/2002.05202>`_.

@@ -77,6 +77,9 @@ def apply_gated_linear_units(model: torch.nn.Module,
use the existing activation function in the model.
gated_layer_bias (bool, optional): Whether to use biases in the linear layers within the GLU. Default: ``False``.
non_gated_layer_bias (bool, optional): Whether to use biases in the linear layers within the GLU. Default: ``False``.

Returns:
The number of modules modified.
"""
if not IS_TRANSFORMERS_INSTALLED:
raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers')
@@ -94,7 +97,7 @@ def from_bound_BertOutput(layer: torch.nn.Module, module_index: int) -> BERTGate
NoEffectWarning('No instances of BertIntermediate were found so Gated Linear Units will be skipped '
'as no modules can be replaced. This is likely because Gated Linear Units has already '
'been applied to this model.'))
return
return 0

# get the activation functions used
act_fns = {module.intermediate_act_fn for module in intermediate_modules}
@@ -133,6 +136,7 @@ def from_bound_BertOutput(layer: torch.nn.Module, module_index: int) -> BERTGate
NoEffectWarning('No instances of BertIntermediate and BertOutput were found so no modules were replaced.'))
log.info(
f'Successfully replaced {len(replaced_instances)} of BertIntermediate and BertOutput with a GatedLinearUnit.')
return len(replaced_instances)


class GatedLinearUnits(Algorithm):
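
The docstring above says the Linear layers of the feed-forward network are replaced by Gated Linear Units. As a rough illustration of what a gated feed-forward block computes, here is a hypothetical class, not Composer's replacement module:

```python
import torch
import torch.nn.functional as F

class GatedFFNSketch(torch.nn.Module):
    def __init__(self, d_model: int, d_ff: int, act_fn=F.gelu,
                 gated_layer_bias: bool = False, non_gated_layer_bias: bool = False):
        super().__init__()
        self.gate = torch.nn.Linear(d_model, d_ff, bias=gated_layer_bias)
        self.value = torch.nn.Linear(d_model, d_ff, bias=non_gated_layer_bias)
        self.out = torch.nn.Linear(d_ff, d_model)
        self.act_fn = act_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Element-wise gating: activation(gate(x)) * value(x), then project back to d_model.
        return self.out(self.act_fn(self.gate(x)) * self.value(x))

ffn = GatedFFNSketch(d_model=768, d_ff=3072)
y = ffn(torch.randn(2, 5, 768))
```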
8 changes: 4 additions & 4 deletions composer/algorithms/ghost_batchnorm/ghost_batchnorm.py
@@ -19,7 +19,7 @@

def apply_ghost_batchnorm(model: torch.nn.Module,
ghost_batch_size: int = 32,
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> torch.nn.Module:
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> int:
"""Replace batch normalization modules with ghost batch normalization modules.

Ghost batch normalization modules split their input into chunks of
@@ -39,7 +39,7 @@ def apply_ghost_batchnorm(model: torch.nn.Module,
model parameters.

Returns:
The modified model
The number of modules modified.

Example:
.. testcode::
@@ -57,8 +57,8 @@ def maybe_replace(module: torch.nn.Module, module_index: int) -> Optional[torch.
# we have to specify class names explicitly because replace_module_classes
# now checks if `module.__class__ == cls`, rather than `isinstance(module, cls)`
transforms = {cls: maybe_replace for cls in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]}
module_surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms)
return model
replaced_instances = module_surgery.replace_module_classes(model, optimizers=optimizers, policies=transforms)
return len(replaced_instances)


class GhostBatchNorm(Algorithm):
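
As context for the hunk above: ghost batch normalization normalizes each chunk of `ghost_batch_size` samples independently instead of the whole batch. A minimal functional sketch of that behavior, illustration only rather than the replacement modules installed by `apply_ghost_batchnorm`:

```python
import torch

def ghost_batchnorm_sketch(x: torch.Tensor, bn: torch.nn.BatchNorm2d, ghost_batch_size: int = 32) -> torch.Tensor:
    # Normalize each ghost batch separately, then stitch the results back together.
    chunks = x.split(ghost_batch_size, dim=0)
    return torch.cat([bn(chunk) for chunk in chunks], dim=0)

bn = torch.nn.BatchNorm2d(8)
out = ghost_batchnorm_sketch(torch.randn(64, 8, 4, 4), bn, ghost_batch_size=32)
```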
2 changes: 1 addition & 1 deletion composer/algorithms/gradient_clipping/gradient_clipping.py
@@ -19,7 +19,7 @@


def apply_gradient_clipping(parameters: Union[torch.Tensor, Iterable[torch.Tensor]], clipping_type: str,
clipping_threshold: float):
clipping_threshold: float) -> None:
"""Clips all gradients in model based on specified clipping_type.

Args:
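
`apply_gradient_clipping` only gains an explicit `-> None` here, since clipping mutates gradients in place and has nothing meaningful to return. A hedged usage sketch; it assumes the function is exported from `composer.functional` like the other functional methods and that `'norm'` is an accepted `clipping_type`, so check the GradientClipping docs for the actual values:

```python
import torch
import composer.functional as cf

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
# Clip once per step, after backward() and before optimizer.step().
cf.apply_gradient_clipping(model.parameters(), clipping_type='norm', clipping_threshold=1.0)
optimizer.step()
```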
19 changes: 10 additions & 9 deletions composer/algorithms/squeeze_excite/squeeze_excite.py
@@ -21,7 +21,7 @@ def apply_squeeze_excite(
latent_channels: float = 64,
min_channels: int = 128,
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
):
) -> int:
"""Adds Squeeze-and-Excitation blocks (`Hu et al, 2019 <https://arxiv.org/abs/1709.01507>`_) after
:class:`torch.nn.Conv2d` layers.

@@ -51,7 +51,7 @@ def apply_squeeze_excite(
model parameters.

Returns:
The modified model
The number of modified modules.

Example:
.. testcode::
@@ -71,9 +71,10 @@ def convert_module(module: torch.nn.Module, module_index: int):
return None
return SqueezeExciteConv2d.from_conv2d(module, module_index, latent_channels=latent_channels)

module_surgery.replace_module_classes(model, optimizers=optimizers, policies={torch.nn.Conv2d: convert_module})

return model
replaced_instances = module_surgery.replace_module_classes(model,
optimizers=optimizers,
policies={torch.nn.Conv2d: convert_module})
return len(replaced_instances)


class SqueezeExcite2d(torch.nn.Module):
@@ -163,10 +164,10 @@ def match(self, event: Event, state: State) -> bool:
return event == Event.INIT

def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
state.model = apply_squeeze_excite(state.model,
optimizers=state.optimizers,
latent_channels=self.latent_channels,
min_channels=self.min_channels)
apply_squeeze_excite(state.model,
optimizers=state.optimizers,
latent_channels=self.latent_channels,
min_channels=self.min_channels)
layer_count = module_surgery.count_module_instances(state.model, SqueezeExciteConv2d)

log.info(f'Applied SqueezeExcite to model {state.model.__class__.__name__} '
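
The `apply` method above now calls `apply_squeeze_excite` purely for its side effects and recounts the inserted `SqueezeExciteConv2d` modules afterwards. For context, a rough sketch of what a squeeze-and-excitation block computes; illustration only, not Composer's `SqueezeExcite2d`:

```python
import torch

class SqueezeExciteSketch(torch.nn.Module):
    def __init__(self, channels: int, latent_channels: int = 64):
        super().__init__()
        self.fc1 = torch.nn.Linear(channels, latent_channels)
        self.fc2 = torch.nn.Linear(latent_channels, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Squeeze: global average pool each channel down to one scalar.
        scale = x.mean(dim=(2, 3))
        # Excite: two small linear layers produce per-channel gates in (0, 1).
        scale = torch.sigmoid(self.fc2(torch.relu(self.fc1(scale))))
        # Recalibrate: rescale the input channel-wise.
        return x * scale[:, :, None, None]

se = SqueezeExciteSketch(channels=64)
y = se(torch.randn(2, 64, 8, 8))
```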
8 changes: 4 additions & 4 deletions composer/algorithms/stochastic_depth/stochastic_depth.py
@@ -33,7 +33,7 @@ def apply_stochastic_depth(model: torch.nn.Module,
target_layer_name: str,
stochastic_method: str = 'block',
drop_rate: float = 0.2,
drop_distribution: str = 'linear') -> torch.nn.Module:
drop_distribution: str = 'linear') -> int:
"""Applies Stochastic Depth (`Huang et al, 2016 <https://arxiv.org/abs/1603.09382>`_) to the specified model.

The algorithm replaces the specified target layer with a stochastic version
@@ -68,7 +68,7 @@ def apply_stochastic_depth(model: torch.nn.Module,
Default: ``"linear"``.

Returns:
The modified model
The number of modified modules

Example:
.. testcode::
@@ -94,8 +94,8 @@ def apply_stochastic_depth(model: torch.nn.Module,
module_count=module_count,
stochastic_method=stochastic_method)
transforms[target_layer] = stochastic_from_target_layer
module_surgery.replace_module_classes(model, policies=transforms)
return model
replaced_instances = module_surgery.replace_module_classes(model, policies=transforms)
return len(replaced_instances)


class StochasticDepth(Algorithm):
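
Same return-value change as the other surgery methods. For context, block-wise stochastic depth randomly skips whole blocks during training; a rough sketch of that idea with a hypothetical wrapper, not Composer's stochastic block implementations:

```python
import torch

class StochasticBlockSketch(torch.nn.Module):
    def __init__(self, block: torch.nn.Module, drop_rate: float = 0.2):
        super().__init__()
        self.block = block
        self.drop_rate = drop_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # During training, skip the wrapped block entirely with probability drop_rate.
        if self.training and torch.rand(1).item() < self.drop_rate:
            return x
        return self.block(x)

block = StochasticBlockSketch(torch.nn.Linear(16, 16), drop_rate=0.2)
y = block(torch.randn(4, 16))
```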
2 changes: 1 addition & 1 deletion examples/exporting_for_inference.ipynb
@@ -80,7 +80,7 @@
"import composer.functional as cf\n",
"\n",
"model = ComposerClassifier(module=resnet.resnet50())\n",
"model = cf.apply_squeeze_excite(model)\n",
"cf.apply_squeeze_excite(model)\n",
"\n",
"# switch to eval mode\n",
"model.eval()"