Skip to content

Commit

Permalink
Merge branch 'dev' into 4980-get-wsi-at-mpp
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolasSchmitz authored Mar 27, 2024
2 parents cfa7f8d + e5bebfc commit 0b7322e
Show file tree
Hide file tree
Showing 9 changed files with 436 additions and 5 deletions.
42 changes: 42 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,27 @@ Post-processing
:members:
:special-members: __call__

Regularization
^^^^^^^^^^^^^^

`CutMix`
""""""""
.. autoclass:: CutMix
:members:
:special-members: __call__

`CutOut`
""""""""
.. autoclass:: CutOut
:members:
:special-members: __call__

`MixUp`
"""""""
.. autoclass:: MixUp
:members:
:special-members: __call__

Signal
^^^^^^^

Expand Down Expand Up @@ -1707,6 +1728,27 @@ Post-processing (Dict)
:members:
:special-members: __call__

Regularization (Dict)
^^^^^^^^^^^^^^^^^^^^^

`CutMixd`
"""""""""
.. autoclass:: CutMixd
:members:
:special-members: __call__

`CutOutd`
"""""""""
.. autoclass:: CutOutd
:members:
:special-members: __call__

`MixUpd`
""""""""
.. autoclass:: MixUpd
:members:
:special-members: __call__

Signal (Dict)
^^^^^^^^^^^^^

Expand Down
10 changes: 10 additions & 0 deletions docs/source/transforms_idx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ Post-processing
post.array
post.dictionary

Regularization
^^^^^^^^^^^^^^

.. autosummary::
:toctree: _gen
:nosignatures:

regularization.array
regularization.dictionary

Signal
^^^^^^

Expand Down
5 changes: 1 addition & 4 deletions monai/bundle/config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,7 @@ def instantiate(self, **kwargs: Any) -> object:
mode = self.get_config().get("_mode_", CompInitMode.DEFAULT)
args = self.resolve_args()
args.update(kwargs)
try:
return instantiate(modname, mode, **args)
except Exception as e:
raise RuntimeError(f"Failed to instantiate {self}") from e
return instantiate(modname, mode, **args)


class ConfigExpression(ConfigItem):
Expand Down
12 changes: 12 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,18 @@
VoteEnsembled,
VoteEnsembleDict,
)
from .regularization.array import CutMix, CutOut, MixUp
from .regularization.dictionary import (
CutMixd,
CutMixD,
CutMixDict,
CutOutd,
CutOutD,
CutOutDict,
MixUpd,
MixUpD,
MixUpDict,
)
from .signal.array import (
SignalContinuousWavelet,
SignalFillEmpty,
Expand Down
10 changes: 10 additions & 0 deletions monai/transforms/regularization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
173 changes: 173 additions & 0 deletions monai/transforms/regularization/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import abstractmethod
from math import ceil, sqrt

import torch

from ..transform import RandomizableTransform

__all__ = ["MixUp", "CutMix", "CutOut", "Mixer"]


class Mixer(RandomizableTransform):
    def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
        """
        Mixer is a base class providing the basic logic for the mixup-class of
        augmentations. In all cases, we need to sample the mixing weights for each
        sample (lambda in the notation used in the papers). Also, pairs of samples
        being mixed are picked by randomly shuffling the batch samples.

        Args:
            batch_size (int): number of samples per batch. That is, samples are expected to
                be of size batchsize x channels [x depth] x height x width.
            alpha (float, optional): mixing weights are sampled from the Beta(alpha, alpha)
                distribution. Defaults to 1.0, the uniform distribution.

        Raises:
            ValueError: if ``alpha`` is not a positive number.
        """
        super().__init__()
        if alpha <= 0:
            raise ValueError(f"Expected positive number, but got {alpha = }")
        self.alpha = alpha
        self.batch_size = batch_size

    @abstractmethod
    def apply(self, data: torch.Tensor):
        """Apply the transform to ``data`` using the parameters drawn by :py:meth:`randomize`."""
        raise NotImplementedError()

    def randomize(self, data=None) -> None:
        """
        Sometimes you may need to apply the same transform to different tensors.
        The idea is to get a sample and then apply it with apply() as often
        as needed. You need to call this method every time you apply the transform to a new
        batch.
        """
        # _params holds (per-sample mixing weights ~ Beta(alpha, alpha) as float32,
        # a random permutation pairing each sample with a partner). Both are drawn
        # from self.R so results are reproducible via Randomizable.set_random_state().
        self._params = (
            torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),
            self.R.permutation(self.batch_size),
        )


class MixUp(Mixer):
    """MixUp as described in:
    Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
    mixup: Beyond Empirical Risk Minimization, ICLR 2018

    Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
    documentation for details on the constructor parameters.
    """

    def apply(self, data: torch.Tensor):
        """Blend every sample with its randomly chosen partner using the sampled weights.

        Args:
            data: batch of shape ``batchsize x channels [x depth] x height x width``.

        Raises:
            ValueError: if the batch size or the number of dimensions is unexpected.
        """
        weights, permutation = self._params
        batch, *spatial = data.shape
        if len(weights) != batch:
            raise ValueError(f"Expected batch of size: {len(weights)}, but got {batch}")

        if len(spatial) not in (3, 4):
            raise ValueError("Unexpected number of dimensions")

        # Reshape the per-sample weights to (batch, 1, 1, ...) so they broadcast.
        shaped_weights = weights.reshape((batch,) + (1,) * len(spatial))
        return shaped_weights * data + (1 - shaped_weights) * data[permutation, ...]

    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
        """Draw fresh parameters and apply MixUp to ``data`` (and ``labels`` if given)."""
        self.randomize()
        mixed = self.apply(data)
        return mixed if labels is None else (mixed, self.apply(labels))


class CutMix(Mixer):
    """CutMix augmentation as described in:
    Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.
    CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
    ICCV 2019

    Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
    documentation for details on the constructor parameters. Here, alpha not only determines
    the mixing weight but also the size of the random rectangles used for mixing.
    Please refer to the paper for details.

    The most common use case is something close to:

    .. code-block:: python

        cm = CutMix(batch_size=8, alpha=0.5)
        for batch in loader:
            images, labels = batch
            augimg, auglabels = cm(images, labels)
            output = model(augimg)
            loss = loss_function(output, auglabels)
            ...
    """

    def apply(self, data: torch.Tensor):
        """Replace one random rectangle of each sample with the matching region of its
        shuffled partner.

        Args:
            data: batch of shape ``batchsize x channels [x depth] x height x width``.

        Raises:
            ValueError: if the batch size does not match the one given to the constructor.
        """
        weights, perm = self._params
        nsamples, _, *dims = data.shape
        if len(weights) != nsamples:
            raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")

        mask = torch.ones_like(data)
        for s, weight in enumerate(weights):
            # Draw the rectangle's corner from self.R (instead of torch.randint) so the
            # crop location is reproducible via Randomizable.set_random_state(); randomize()
            # already draws weights/permutation from the same RandomState.
            coords = [self.R.randint(0, d) for d in dims]
            # Edge lengths are chosen so the rectangle covers roughly a (1 - weight)
            # fraction of each dimension; the slice is clipped to the tensor bounds.
            lengths = [d * sqrt(1 - weight) for d in dims]
            idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
            mask[s][idx] = 0

        return mask * data + (1 - mask) * data[perm, ...]

    def apply_on_labels(self, labels: torch.Tensor):
        """Mix labels linearly with the sampled weights (no rectangles for labels).

        Args:
            labels: label batch whose leading dimension equals the batch size.

        Raises:
            ValueError: if the batch size does not match the one given to the constructor.
        """
        weights, perm = self._params
        nsamples, *dims = labels.shape
        if len(weights) != nsamples:
            raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")

        mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
        return mixweight * labels + (1 - mixweight) * labels[perm, ...]

    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
        """Draw fresh parameters and apply CutMix to ``data`` (and ``labels`` if given)."""
        self.randomize()
        augmented = self.apply(data)
        return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented


class CutOut(Mixer):
    """Cutout as described in the paper:
    Terrance DeVries, Graham W. Taylor.
    Improved Regularization of Convolutional Neural Networks with Cutout,
    arXiv:1708.04552

    Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
    documentation for details on the constructor parameters. Here, alpha not only determines
    the mixing weight but also the size of the random rectangles being cut out.
    Please refer to the paper for details.
    """

    def apply(self, data: torch.Tensor):
        """Zero out one random rectangle per sample.

        Args:
            data: batch of shape ``batchsize x channels [x depth] x height x width``.

        Raises:
            ValueError: if the batch size does not match the one given to the constructor.
        """
        weights, _ = self._params
        nsamples, _, *dims = data.shape
        if len(weights) != nsamples:
            raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")

        mask = torch.ones_like(data)
        for s, weight in enumerate(weights):
            # Draw the rectangle's corner from self.R (instead of torch.randint) so the
            # cutout location is reproducible via Randomizable.set_random_state().
            coords = [self.R.randint(0, d) for d in dims]
            # Edge lengths shrink as the sampled weight grows; slices are clipped to bounds.
            lengths = [d * sqrt(1 - weight) for d in dims]
            idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
            mask[s][idx] = 0

        return mask * data

    def __call__(self, data: torch.Tensor):
        """Draw fresh parameters and apply Cutout to ``data``."""
        self.randomize()
        return self.apply(data)
97 changes: 97 additions & 0 deletions monai/transforms/regularization/dictionary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from monai.config import KeysCollection
from monai.utils.misc import ensure_tuple

from ..transform import MapTransform
from .array import CutMix, CutOut, MixUp

__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]


class MixUpd(MapTransform):
    """
    Dictionary-based version of :py:class:`monai.transforms.MixUp`.

    The mixup parameters are drawn once per call and shared by every key for
    consistency, i.e. images and labels receive an identical augmentation.
    """

    def __init__(
        self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mixup = MixUp(batch_size, alpha)

    def __call__(self, data):
        # One randomize() for all keys: every entry is mixed with the same
        # weights and partner permutation.
        self.mixup.randomize()
        out = dict(data)
        for key in self.keys:
            out[key] = self.mixup.apply(data[key])
        return out


class CutMixd(MapTransform):
    """
    Dictionary-based version of :py:class:`monai.transforms.CutMix`.

    The mixture weights are sampled once per call and shared by all entries for
    consistency, i.e. images and labels are aggregated with the same weights;
    the random crop locations, however, differ from key to key.
    """

    def __init__(
        self,
        keys: KeysCollection,
        batch_size: int,
        label_keys: KeysCollection | None = None,
        alpha: float = 1.0,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mixer = CutMix(batch_size, alpha)
        self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []

    def __call__(self, data):
        # Draw weights/permutation once; rectangles are re-drawn inside apply().
        self.mixer.randomize()
        out = dict(data)
        for image_key in self.keys:
            out[image_key] = self.mixer.apply(data[image_key])
        for label_key in self.label_keys:
            out[label_key] = self.mixer.apply_on_labels(data[label_key])
        return out


class CutOutd(MapTransform):
    """
    Dictionary-based version of :py:class:`monai.transforms.CutOut`.

    Notice that the cutout is different for every entry in the dictionary.

    Args:
        keys: keys of the items to be transformed.
        batch_size: number of samples per batch, forwarded to :py:class:`monai.transforms.CutOut`.
        alpha: rectangle-size parameter forwarded to :py:class:`monai.transforms.CutOut`
            (weights sampled from Beta(alpha, alpha)). Defaults to 1.0, matching the
            previous hard-coded behavior, so existing callers are unaffected.
        allow_missing_keys: don't raise an exception if a key is missing.
    """

    def __init__(
        self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.cutout = CutOut(batch_size, alpha)

    def __call__(self, data):
        result = dict(data)
        # CutOut.__call__ randomizes again for each key, which is why every
        # entry receives its own cutout rectangles.
        self.cutout.randomize()
        for k in self.keys:
            result[k] = self.cutout(data[k])
        return result


# CamelCase/"Dict" aliases, following the MONAI naming convention for dictionary transforms.
MixUpD = MixUpDict = MixUpd
CutMixD = CutMixDict = CutMixd
CutOutD = CutOutDict = CutOutd
2 changes: 1 addition & 1 deletion monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:
return pdb.runcall(component, **kwargs)
except Exception as e:
raise RuntimeError(
f"Failed to instantiate component '{__path}' with kwargs: {kwargs}"
f"Failed to instantiate component '{__path}' with keywords: {','.join(kwargs.keys())}"
f"\n set '_mode_={CompInitMode.DEBUG}' to enter the debugging mode."
) from e

Expand Down
Loading

0 comments on commit 0b7322e

Please sign in to comment.