From 02615d3bd708133bb395e2cdc9df84d0e0a996e3 Mon Sep 17 00:00:00 2001 From: Yauheni Kachan <19803638+bagxi@users.noreply.github.com> Date: Fri, 4 Feb 2022 02:51:30 +0300 Subject: [PATCH] refactor: package structure updated, docs updated --- docs/index.rst | 4 +- docs/pages/api/catalyst.rst | 28 +++ docs/pages/api/core.rst | 9 - docs/pages/api/criterions.rst | 17 -- docs/pages/api/datasets.rst | 17 +- docs/pages/api/models.rst | 79 +++--- docs/pages/api/nn.rst | 101 ++++++++ docs/pages/api/utils.rst | 8 +- esrgan/criterions/__init__.py | 5 - esrgan/{dataset.py => datasets.py} | 0 esrgan/model/__init__.py | 4 - esrgan/model/discriminator.py | 46 ---- esrgan/model/module/__init__.py | 10 - esrgan/model/module/blocks/__init__.py | 11 - esrgan/model/module/conv.py | 118 --------- esrgan/model/module/linear.py | 83 ------- esrgan/models/__init__.py | 5 + esrgan/models/discriminator.py | 232 ++++++++++++++++++ esrgan/{model/module => models}/esrnet.py | 14 +- esrgan/{model => models}/generator.py | 0 esrgan/{model/module => models}/srresnet.py | 10 +- esrgan/nn/__init__.py | 9 + esrgan/nn/criterions/__init__.py | 5 + esrgan/{ => nn}/criterions/adversarial.py | 0 esrgan/{ => nn}/criterions/perceptual.py | 0 esrgan/nn/modules/__init__.py | 7 + .../module/blocks => nn/modules}/container.py | 0 .../module/blocks => nn/modules}/misc.py | 0 .../module/blocks => nn/modules}/rrdb.py | 3 +- .../blocks => nn/modules}/upsampling.py | 2 +- esrgan/runner.py | 90 ++++--- esrgan/utils/aug.py | 45 ++-- 32 files changed, 533 insertions(+), 429 deletions(-) create mode 100755 docs/pages/api/catalyst.rst delete mode 100755 docs/pages/api/core.rst delete mode 100755 docs/pages/api/criterions.rst create mode 100644 docs/pages/api/nn.rst delete mode 100644 esrgan/criterions/__init__.py rename esrgan/{dataset.py => datasets.py} (100%) delete mode 100644 esrgan/model/__init__.py delete mode 100644 esrgan/model/discriminator.py delete mode 100644 esrgan/model/module/__init__.py delete mode 100644 esrgan/model/module/blocks/__init__.py delete mode 100644 esrgan/model/module/conv.py delete mode 100644 esrgan/model/module/linear.py create mode 100644 esrgan/models/__init__.py create mode 100644 esrgan/models/discriminator.py rename esrgan/{model/module => models}/esrnet.py (92%) rename esrgan/{model => models}/generator.py (100%) rename esrgan/{model/module => models}/srresnet.py (94%) create mode 100644 esrgan/nn/__init__.py create mode 100644 esrgan/nn/criterions/__init__.py rename esrgan/{ => nn}/criterions/adversarial.py (100%) rename esrgan/{ => nn}/criterions/perceptual.py (100%) create mode 100644 esrgan/nn/modules/__init__.py rename esrgan/{model/module/blocks => nn/modules}/container.py (100%) rename esrgan/{model/module/blocks => nn/modules}/misc.py (100%) rename esrgan/{model/module/blocks => nn/modules}/rrdb.py (97%) rename esrgan/{model/module/blocks => nn/modules}/upsampling.py (98%) diff --git a/docs/index.rst b/docs/index.rst index 436ee22..e4e29a5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -120,11 +120,11 @@ License :maxdepth: 2 :caption: API - pages/api/core + pages/api/nn pages/api/models - pages/api/criterions pages/api/datasets pages/api/utils + pages/api/catalyst Indices and tables ================== diff --git a/docs/pages/api/catalyst.rst b/docs/pages/api/catalyst.rst new file mode 100755 index 0000000..097469a --- /dev/null +++ b/docs/pages/api/catalyst.rst @@ -0,0 +1,28 @@ +Catalyst +======== + +Various features for customization/modification of `Catalyst `__ pipelines e.g., runners, 
metrics, callbacks:
+
+.. toctree::
+   :titlesonly:
+
+.. contents::
+   :local:
+
+
+Runners
+-------
+
+GANRunner
+^^^^^^^^^
+
+.. autoclass:: esrgan.runner.GANRunner
+   :members:
+   :undoc-members:
+
+GANConfigRunner
+^^^^^^^^^^^^^^^
+
+.. autoclass:: esrgan.runner.GANConfigRunner
+   :members:
+   :undoc-members:
diff --git a/docs/pages/api/core.rst b/docs/pages/api/core.rst
deleted file mode 100755
index 7143e94..0000000
--- a/docs/pages/api/core.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-Core (Catalyst abstractions)
-============================
-
-GAN Runner
-^^^^^^^^^^^
-
-.. automodule:: esrgan.runner
-   :members:
-   :undoc-members:
diff --git a/docs/pages/api/criterions.rst b/docs/pages/api/criterions.rst
deleted file mode 100755
index 8530ce3..0000000
--- a/docs/pages/api/criterions.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-Criterions
-==========
-
-Adversarial Loss
-^^^^^^^^^^^^^^^^
-
-.. automodule:: esrgan.criterions.adversarial
-   :members:
-   :undoc-members:
-
-
-Perceptual Loss
-^^^^^^^^^^^^^^^
-
-.. automodule:: esrgan.criterions.perceptual
-   :members:
-   :undoc-members:
diff --git a/docs/pages/api/datasets.rst b/docs/pages/api/datasets.rst
index 01f6c7e..1b01c90 100755
--- a/docs/pages/api/datasets.rst
+++ b/docs/pages/api/datasets.rst
@@ -6,27 +6,36 @@
 implemented. Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
 which can load multiple samples parallelly using ``torch.multiprocessing`` workers.
 For example: ::
 
-    div2k_data = esrgan.dataset.DIV2KDataset('path/to/div2k_root/')
+    div2k_data = esrgan.datasets.DIV2KDataset('path/to/div2k_root/')
     data_loader = torch.utils.data.DataLoader(div2k_data,
                                               batch_size=4,
                                               shuffle=True)
 
+The datasets subpackage provides the following datasets for image super-resolution:
+
+.. toctree::
+   :titlesonly:
+
+.. contents::
+   :local:
+
+
 DIV2K
 ^^^^^
 
-.. autoclass:: esrgan.dataset.DIV2KDataset
+.. autoclass:: esrgan.datasets.DIV2KDataset
    :members:
 
 
 Flickr2K
 ^^^^^^^^
 
-.. autoclass:: esrgan.dataset.Flickr2KDataset
+.. autoclass:: esrgan.datasets.Flickr2KDataset
    :members:
 
 
 Folder of Images
 ^^^^^^^^^^^^^^^^
 
-.. autoclass:: esrgan.dataset.ImageFolderDataset
+.. autoclass:: esrgan.datasets.ImageFolderDataset
    :members:
    :undoc-members:
diff --git a/docs/pages/api/models.rst b/docs/pages/api/models.rst
index 2169c1b..1348a17 100755
--- a/docs/pages/api/models.rst
+++ b/docs/pages/api/models.rst
@@ -1,71 +1,82 @@
 Models
 ======
 
-Generator
----------
+The models subpackage contains definitions of models for addressing image super-resolution tasks:
 
-.. automodule:: esrgan.model.generator
+.. toctree::
+   :titlesonly:
+
+.. contents::
+   :local:
+
+
+Generators
+----------
+
+EncoderDecoderNet
+^^^^^^^^^^^^^^^^^
+
+.. autoclass:: esrgan.models.EncoderDecoderNet
    :members:
    :undoc-members:
 
+
 SRGAN
 ^^^^^
 
-.. automodule:: esrgan.model.module.srresnet
+SRResNetEncoder
+~~~~~~~~~~~~~~~
+
+.. autoclass:: esrgan.models.SRResNetEncoder
    :members:
    :undoc-members:
 
-ESRGAN
-^^^^^^
+SRResNetDecoder
+~~~~~~~~~~~~~~~
 
-.. automodule:: esrgan.model.module.esrnet
+.. autoclass:: esrgan.models.SRResNetDecoder
    :members:
    :undoc-members:
 
 
-Discriminator
--------------
+ESRGAN
+^^^^^^
 
-.. automodule:: esrgan.model.discriminator
-   :members:
-   :undoc-members:
-.. automodule:: esrgan.model.module.conv
-   :members:
-   :undoc-members:
-.. automodule:: esrgan.model.module.linear
+ESREncoder
+~~~~~~~~~~
+
+.. autoclass:: esrgan.models.ESREncoder
    :members:
    :undoc-members:
 
+ESRNetDecoder
+~~~~~~~~~~~~~
 
-Layers
-------
-
-These are the basic building block for graphs
-
-Containers
-^^^^^^^^^^
-.. automodule:: esrgan.model.module.blocks.container
+.. autoclass:: esrgan.models.ESRNetDecoder
    :members:
    :undoc-members:
 
-Residual-in-Residual Block
-^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-.. automodule:: esrgan.model.module.blocks.rrdb
+Discriminators
+--------------
+
+VGGConv
+^^^^^^^
+
+.. autoclass:: esrgan.models.VGGConv
    :members:
    :undoc-members:
 
-Upsample
-^^^^^^^^
+StridedConvEncoder
+~~~~~~~~~~~~~~~~~~
 
-.. automodule:: esrgan.model.module.blocks.upsampling
+.. autoclass:: esrgan.models.StridedConvEncoder
    :members:
    :undoc-members:
 
-Misc
-^^^^
+LinearHead
+~~~~~~~~~~
 
-.. automodule:: esrgan.model.module.blocks.misc
+.. autoclass:: esrgan.models.LinearHead
    :members:
    :undoc-members:
diff --git a/docs/pages/api/nn.rst b/docs/pages/api/nn.rst
new file mode 100644
index 0000000..f7f6eab
--- /dev/null
+++ b/docs/pages/api/nn.rst
@@ -0,0 +1,101 @@
+NN
+==
+
+These are the basic building blocks for graphs:
+
+.. toctree::
+   :titlesonly:
+
+.. contents::
+   :local:
+
+
+Containers
+----------
+
+ConcatInputModule
+^^^^^^^^^^^^^^^^^
+
+.. autoclass:: esrgan.nn.ConcatInputModule
+   :members:
+   :undoc-members:
+
+ResidualModule
+^^^^^^^^^^^^^^
+
+.. autoclass:: esrgan.nn.ResidualModule
+   :members:
+   :undoc-members:
+
+
+Residual-in-Residual layers
+---------------------------
+
+ResidualDenseBlock
+^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: esrgan.nn.ResidualDenseBlock
+   :members:
+   :undoc-members:
+
+ResidualInResidualDenseBlock
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: esrgan.nn.ResidualInResidualDenseBlock
+   :members:
+   :undoc-members:
+
+
+Upsampling layers
+-----------------
+
+InterpolateConv
+^^^^^^^^^^^^^^^
+
+.. autoclass:: esrgan.nn.InterpolateConv
+   :members:
+   :undoc-members:
+
+SubPixelConv
+^^^^^^^^^^^^
+
+.. autoclass:: esrgan.nn.SubPixelConv
+   :members:
+   :undoc-members:
+
+
+Loss functions
+--------------
+
+AdversarialLoss
+^^^^^^^^^^^^^^^
+
+.. autoclass:: esrgan.nn.AdversarialLoss
+   :members:
+   :undoc-members:
+
+RelativisticAdversarialLoss
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: esrgan.nn.RelativisticAdversarialLoss
+   :members:
+   :undoc-members:
+
+PerceptualLoss
+^^^^^^^^^^^^^^
+
+.. autoclass:: esrgan.nn.PerceptualLoss
+   :members:
+   :undoc-members:
+
+
+Misc
+----
+
+.. autoclass:: esrgan.nn.Conv2dSN
+   :members:
+   :undoc-members:
+
+.. autoclass:: esrgan.nn.LinearSN
+   :members:
+   :undoc-members:
diff --git a/docs/pages/api/utils.rst b/docs/pages/api/utils.rst
index ba566fc..b61e266 100755
--- a/docs/pages/api/utils.rst
+++ b/docs/pages/api/utils.rst
@@ -1,7 +1,13 @@
 Utilities
 =========
 
-Set of utilities that can make life a little bit easier.
+Set of utilities that can make life a little bit easier:
+
+.. toctree::
+   :titlesonly:
+
+.. 
contents:: + :local: Augmentation diff --git a/esrgan/criterions/__init__.py b/esrgan/criterions/__init__.py deleted file mode 100644 index 81f5bb2..0000000 --- a/esrgan/criterions/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# flake8: noqa -from esrgan.criterions.adversarial import ( - AdversarialLoss, RelativisticAdversarialLoss, -) -from esrgan.criterions.perceptual import PerceptualLoss diff --git a/esrgan/dataset.py b/esrgan/datasets.py similarity index 100% rename from esrgan/dataset.py rename to esrgan/datasets.py diff --git a/esrgan/model/__init__.py b/esrgan/model/__init__.py deleted file mode 100644 index 51236e0..0000000 --- a/esrgan/model/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# flake8: noqa -from esrgan.model import module -from esrgan.model.discriminator import VGGConv -from esrgan.model.generator import EncoderDecoderNet diff --git a/esrgan/model/discriminator.py b/esrgan/model/discriminator.py deleted file mode 100644 index 6adfc0b..0000000 --- a/esrgan/model/discriminator.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch -from torch import nn - -from esrgan import utils - -__all__ = ["VGGConv"] - - -class VGGConv(nn.Module): - """VGG-like neural network for image classification. - - Args: - encoder: Image encoder module, usually used for the extraction - of embeddings from input signals. - pool: Pooling layer, used to reduce embeddings from the encoder. - head: Classification head, usually consists of Fully Connected layers. - - """ - - def __init__( - self, encoder: nn.Module, pool: nn.Module, head: nn.Module, - ) -> None: - super().__init__() - - self.encoder = encoder - self.pool = pool - self.head = head - - # TODO: - utils.net_init_(self) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward call. - - Args: - x: Batch of images. - - Returns: - Batch of logits. 
- - """ - x = self.pool(self.encoder(x)) - x = x.view(x.shape[0], -1) - x = self.head(x) - - return x diff --git a/esrgan/model/module/__init__.py b/esrgan/model/module/__init__.py deleted file mode 100644 index caa3f57..0000000 --- a/esrgan/model/module/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# flake8: noqa -from esrgan.model.module.blocks import ( - ConcatInputModule, Conv2d, Conv2dSN, InterpolateConv, LeakyReLU, LinearSN, - ResidualDenseBlock, ResidualInResidualDenseBlock, ResidualModule, - SubPixelConv, -) -from esrgan.model.module.conv import StridedConvEncoder -from esrgan.model.module.esrnet import ESREncoder, ESRNetDecoder -from esrgan.model.module.linear import LinearHead -from esrgan.model.module.srresnet import SRResNetDecoder, SRResNetEncoder diff --git a/esrgan/model/module/blocks/__init__.py b/esrgan/model/module/blocks/__init__.py deleted file mode 100644 index f046cd3..0000000 --- a/esrgan/model/module/blocks/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# flake8: noqa -from esrgan.model.module.blocks.container import ( - ConcatInputModule, ResidualModule, -) -from esrgan.model.module.blocks.misc import ( - Conv2d, Conv2dSN, LeakyReLU, LinearSN, -) -from esrgan.model.module.blocks.rrdb import ( - ResidualDenseBlock, ResidualInResidualDenseBlock, -) -from esrgan.model.module.blocks.upsampling import InterpolateConv, SubPixelConv diff --git a/esrgan/model/module/conv.py b/esrgan/model/module/conv.py deleted file mode 100644 index cdf5b54..0000000 --- a/esrgan/model/module/conv.py +++ /dev/null @@ -1,118 +0,0 @@ -import collections -from typing import Callable, Dict, Iterable, List, Optional, Tuple - -import torch -from torch import nn - -from esrgan import utils -from esrgan.model.module import blocks - -__all__ = ["StridedConvEncoder"] - - -class StridedConvEncoder(nn.Module): - """Generalized Fully Convolutional encoder. - - Args: - layers: List of feature maps sizes of each block. - layer_order: Ordered list of layers applied within each block. - For instance, if you don't want to use normalization layer - just exclude it from this list. - conv: Class constructor or partial object which when called - should return convolutional layer e.g., :py:class:`nn.Conv2d`. - norm: Class constructor or partial object which when called should - return normalization layer e.g., :py:class:`.nn.BatchNorm2d`. - activation: Class constructor or partial object which when called - should return activation function to use e.g., :py:class:`nn.ReLU`. - residual: Class constructor or partial object which when called - should return block wrapper module e.g., - :py:class:`~.blocks.container.ResidualModule` can be used - to add residual connections between blocks. 
- - """ - - def __init__( - self, - layers: Iterable[int] = (3, 64, 128, 128, 256, 256, 512, 512), - layer_order: Iterable[str] = ("conv", "norm", "activation"), - conv: Callable[..., nn.Module] = blocks.Conv2d, - norm: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d, - activation: Callable[..., nn.Module] = blocks.LeakyReLU, - residual: Optional[Callable[..., nn.Module]] = None, - ): - super().__init__() - - name2fn: Dict[str, Callable[..., nn.Module]] = { - "activation": activation, - "conv": conv, - "norm": norm, - } - - self._layers = list(layers) - - net: List[Tuple[str, nn.Module]] = [] - - first_conv = collections.OrderedDict([ - ("conv_0", name2fn["conv"](self._layers[0], self._layers[1])), - ("act", name2fn["activation"]()), - ]) - net.append(("block_0", nn.Sequential(first_conv))) - - channels = utils.pairwise(self._layers[1:]) - for i, (in_ch, out_ch) in enumerate(channels, start=1): - block_list: List[Tuple[str, nn.Module]] = [] - for name in layer_order: - # `conv + 2x2 pooling` is equal to `conv with stride=2` - kwargs = {"stride": out_ch // in_ch} if name == "conv" else {} - - module = utils.create_layer( - layer_name=name, - layer=name2fn[name], - in_channels=in_ch, - out_channels=out_ch, - **kwargs - ) - block_list.append((name, module)) - block = nn.Sequential(collections.OrderedDict(block_list)) - - # add residual connection, like in resnet blocks - if residual is not None and in_ch == out_ch: - block = residual(block) - - net.append((f"block_{i}", block)) - - self.net = nn.Sequential(collections.OrderedDict(net)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass. - - Args: - x: Batch of inputs. - - Returns: - Batch of embeddings. - - """ - output = self.net(x) - - return output - - @property - def in_channels(self) -> int: - """The number of channels in the feature map of the input. - - Returns: - Size of the input feature map. - - """ - return self._layers[0] - - @property - def out_channels(self) -> int: - """Number of channels produced by the block. - - Returns: - Size of the output feature map. - - """ - return self._layers[-1] diff --git a/esrgan/model/module/linear.py b/esrgan/model/module/linear.py deleted file mode 100644 index 84d148c..0000000 --- a/esrgan/model/module/linear.py +++ /dev/null @@ -1,83 +0,0 @@ -from typing import Callable, Dict, Iterable, List, Optional, Tuple - -import torch -from torch import nn - -from esrgan import utils -from esrgan.model.module import blocks - -__all__ = ["LinearHead"] - - -class LinearHead(nn.Module): - """Stack of linear layers used for embeddings classification. - - Args: - in_channels: Size of each input sample. - out_channels: Size of each output sample. - latent_channels: Size of the latent space. - layer_order: Ordered list of layers applied within each block. - For instance, if you don't want to use activation function - just exclude it from this list. - linear: Class constructor or partial object which when called - should return linear layer e.g., :py:class:`nn.Linear`. - activation: Class constructor or partial object which when called - should return activation function layer e.g., :py:class:`nn.ReLU`. - norm: Class constructor or partial object which when called - should return normalization layer e.g., :py:class:`nn.BatchNorm1d`. - dropout: Class constructor or partial object which when called - should return dropout layer e.g., :py:class:`nn.Dropout`. 
- - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - latent_channels: Optional[Iterable[int]] = None, - layer_order: Iterable[str] = ("linear", "activation"), - linear: Callable[..., nn.Module] = nn.Linear, - activation: Callable[..., nn.Module] = blocks.LeakyReLU, - norm: Optional[Callable[..., nn.Module]] = None, - dropout: Optional[Callable[..., nn.Module]] = None, - ) -> None: - super().__init__() - - name2fn: Dict[str, Callable[..., nn.Module]] = { - "activation": activation, - "dropout": dropout, - "linear": linear, - "norm": norm, - } - - latent_channels = latent_channels or [] - channels = [in_channels, *latent_channels, out_channels] - channels_pairs: List[Tuple[int, int]] = list(utils.pairwise(channels)) - - net: List[nn.Module] = [] - for in_ch, out_ch in channels_pairs[:-1]: - for name in layer_order: - module = utils.create_layer( - layer_name=name, - layer=name2fn[name], - in_channels=in_ch, - out_channels=out_ch, - ) - net.append(module) - net.append(name2fn["linear"](*channels_pairs[-1])) - - self.net = nn.Sequential(*net) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass. - - Args: - x: Batch of inputs e.g. images. - - Returns: - Batch of logits. - - """ - output = self.net(x) - - return output diff --git a/esrgan/models/__init__.py b/esrgan/models/__init__.py new file mode 100644 index 0000000..27e828b --- /dev/null +++ b/esrgan/models/__init__.py @@ -0,0 +1,5 @@ +# flake8: noqa +from esrgan.models.discriminator import LinearHead, StridedConvEncoder, VGGConv +from esrgan.models.esrnet import ESREncoder, ESRNetDecoder +from esrgan.models.generator import EncoderDecoderNet +from esrgan.models.srresnet import SRResNetDecoder, SRResNetEncoder diff --git a/esrgan/models/discriminator.py b/esrgan/models/discriminator.py new file mode 100644 index 0000000..d3091f5 --- /dev/null +++ b/esrgan/models/discriminator.py @@ -0,0 +1,232 @@ +import collections +from typing import Callable, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from esrgan import utils +from esrgan.nn import modules + +__all__ = ["StridedConvEncoder", "LinearHead", "VGGConv"] + + +class StridedConvEncoder(nn.Module): + """Generalized Fully Convolutional encoder. + + Args: + layers: List of feature maps sizes of each block. + layer_order: Ordered list of layers applied within each block. + For instance, if you don't want to use normalization layer + just exclude it from this list. + conv: Class constructor or partial object which when called + should return convolutional layer e.g., :py:class:`nn.Conv2d`. + norm: Class constructor or partial object which when called should + return normalization layer e.g., :py:class:`.nn.BatchNorm2d`. + activation: Class constructor or partial object which when called + should return activation function to use e.g., :py:class:`nn.ReLU`. + residual: Class constructor or partial object which when called + should return block wrapper module e.g., + :py:class:`esrgan.nn.ResidualModule` can be used + to add residual connections between blocks. 
+ + """ + + def __init__( + self, + layers: Iterable[int] = (3, 64, 128, 128, 256, 256, 512, 512), + layer_order: Iterable[str] = ("conv", "norm", "activation"), + conv: Callable[..., nn.Module] = modules.Conv2d, + norm: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d, + activation: Callable[..., nn.Module] = modules.LeakyReLU, + residual: Optional[Callable[..., nn.Module]] = None, + ): + super().__init__() + + name2fn: Dict[str, Callable[..., nn.Module]] = { + "activation": activation, + "conv": conv, + "norm": norm, + } + + self._layers = list(layers) + + net: List[Tuple[str, nn.Module]] = [] + + first_conv = collections.OrderedDict([ + ("conv_0", name2fn["conv"](self._layers[0], self._layers[1])), + ("act", name2fn["activation"]()), + ]) + net.append(("block_0", nn.Sequential(first_conv))) + + channels = utils.pairwise(self._layers[1:]) + for i, (in_ch, out_ch) in enumerate(channels, start=1): + block_list: List[Tuple[str, nn.Module]] = [] + for name in layer_order: + # `conv + 2x2 pooling` is equal to `conv with stride=2` + kwargs = {"stride": out_ch // in_ch} if name == "conv" else {} + + module = utils.create_layer( + layer_name=name, + layer=name2fn[name], + in_channels=in_ch, + out_channels=out_ch, + **kwargs + ) + block_list.append((name, module)) + block = nn.Sequential(collections.OrderedDict(block_list)) + + # add residual connection, like in resnet blocks + if residual is not None and in_ch == out_ch: + block = residual(block) + + net.append((f"block_{i}", block)) + + self.net = nn.Sequential(collections.OrderedDict(net)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Batch of inputs. + + Returns: + Batch of embeddings. + + """ + output = self.net(x) + + return output + + @property + def in_channels(self) -> int: + """The number of channels in the feature map of the input. + + Returns: + Size of the input feature map. + + """ + return self._layers[0] + + @property + def out_channels(self) -> int: + """Number of channels produced by the block. + + Returns: + Size of the output feature map. + + """ + return self._layers[-1] + + +class LinearHead(nn.Module): + """Stack of linear layers used for embeddings classification. + + Args: + in_channels: Size of each input sample. + out_channels: Size of each output sample. + latent_channels: Size of the latent space. + layer_order: Ordered list of layers applied within each block. + For instance, if you don't want to use activation function + just exclude it from this list. + linear: Class constructor or partial object which when called + should return linear layer e.g., :py:class:`nn.Linear`. + activation: Class constructor or partial object which when called + should return activation function layer e.g., :py:class:`nn.ReLU`. + norm: Class constructor or partial object which when called + should return normalization layer e.g., :py:class:`nn.BatchNorm1d`. + dropout: Class constructor or partial object which when called + should return dropout layer e.g., :py:class:`nn.Dropout`. 
+ + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + latent_channels: Optional[Iterable[int]] = None, + layer_order: Iterable[str] = ("linear", "activation"), + linear: Callable[..., nn.Module] = nn.Linear, + activation: Callable[..., nn.Module] = modules.LeakyReLU, + norm: Optional[Callable[..., nn.Module]] = None, + dropout: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + + name2fn: Dict[str, Callable[..., nn.Module]] = { + "activation": activation, + "dropout": dropout, + "linear": linear, + "norm": norm, + } + + latent_channels = latent_channels or [] + channels = [in_channels, *latent_channels, out_channels] + channels_pairs: List[Tuple[int, int]] = list(utils.pairwise(channels)) + + net: List[nn.Module] = [] + for in_ch, out_ch in channels_pairs[:-1]: + for name in layer_order: + module = utils.create_layer( + layer_name=name, + layer=name2fn[name], + in_channels=in_ch, + out_channels=out_ch, + ) + net.append(module) + net.append(name2fn["linear"](*channels_pairs[-1])) + + self.net = nn.Sequential(*net) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Batch of inputs e.g. images. + + Returns: + Batch of logits. + + """ + output = self.net(x) + + return output + + +class VGGConv(nn.Module): + """VGG-like neural network for image classification. + + Args: + encoder: Image encoder module, usually used for the extraction + of embeddings from input signals. + pool: Pooling layer, used to reduce embeddings from the encoder. + head: Classification head, usually consists of Fully Connected layers. + + """ + + def __init__( + self, encoder: nn.Module, pool: nn.Module, head: nn.Module, + ) -> None: + super().__init__() + + self.encoder = encoder + self.pool = pool + self.head = head + + # TODO: + utils.net_init_(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward call. + + Args: + x: Batch of images. + + Returns: + Batch of logits. 
+ + """ + x = self.pool(self.encoder(x)) + x = x.view(x.shape[0], -1) + x = self.head(x) + + return x diff --git a/esrgan/model/module/esrnet.py b/esrgan/models/esrnet.py similarity index 92% rename from esrgan/model/module/esrnet.py rename to esrgan/models/esrnet.py index 4e419e8..51af6f9 100644 --- a/esrgan/model/module/esrnet.py +++ b/esrgan/models/esrnet.py @@ -5,7 +5,7 @@ from torch import nn from esrgan import utils -from esrgan.model.module import blocks +from esrgan.nn import modules __all__ = ["ESREncoder", "ESRNetDecoder"] @@ -43,8 +43,8 @@ def __init__( num_basic_blocks: int = 23, num_dense_blocks: int = 3, num_residual_blocks: int = 5, - conv: Callable[..., nn.Module] = blocks.Conv2d, - activation: Callable[..., nn.Module] = blocks.LeakyReLU, + conv: Callable[..., nn.Module] = modules.Conv2d, + activation: Callable[..., nn.Module] = modules.LeakyReLU, residual_scaling: float = 0.2, ) -> None: super().__init__() @@ -57,7 +57,7 @@ def __init__( # basic blocks - sequence of rrdb layers for _ in range(num_basic_blocks): - basic_block = blocks.ResidualInResidualDenseBlock( + basic_block = modules.ResidualInResidualDenseBlock( num_features=out_channels, growth_channels=growth_channels, conv=conv, @@ -118,8 +118,8 @@ def __init__( in_channels: int = 64, out_channels: int = 3, scale_factor: int = 2, - conv: Callable[..., nn.Module] = blocks.Conv2d, - activation: Callable[..., nn.Module] = blocks.LeakyReLU, + conv: Callable[..., nn.Module] = modules.Conv2d, + activation: Callable[..., nn.Module] = modules.LeakyReLU, ) -> None: super().__init__() @@ -133,7 +133,7 @@ def __init__( # upsampling for i in range(scale_factor // 2): - upsampling_block = blocks.InterpolateConv( + upsampling_block = modules.InterpolateConv( num_features=in_channels, conv=conv, activation=activation, diff --git a/esrgan/model/generator.py b/esrgan/models/generator.py similarity index 100% rename from esrgan/model/generator.py rename to esrgan/models/generator.py diff --git a/esrgan/model/module/srresnet.py b/esrgan/models/srresnet.py similarity index 94% rename from esrgan/model/module/srresnet.py rename to esrgan/models/srresnet.py index 422b447..a6ded02 100644 --- a/esrgan/model/module/srresnet.py +++ b/esrgan/models/srresnet.py @@ -5,7 +5,7 @@ from torch import nn from esrgan import utils -from esrgan.model.module import blocks +from esrgan.nn import modules __all__ = ["SRResNetEncoder", "SRResNetDecoder"] @@ -38,7 +38,7 @@ def __init__( in_channels: int = 3, out_channels: int = 64, num_basic_blocks: int = 16, - conv: Callable[..., nn.Module] = blocks.Conv2d, + conv: Callable[..., nn.Module] = modules.Conv2d, norm: Callable[..., nn.Module] = nn.BatchNorm2d, activation: Callable[..., nn.Module] = nn.PReLU, ) -> None: @@ -62,7 +62,7 @@ def __init__( conv(num_features, num_features), norm(num_features), ) - blocks_list.append(blocks.ResidualModule(basic_block)) + blocks_list.append(modules.ResidualModule(basic_block)) # last conv of the encoder last_conv = nn.Sequential( @@ -116,7 +116,7 @@ def __init__( in_channels: int = 64, out_channels: int = 3, scale_factor: int = 2, - conv: Callable[..., nn.Module] = blocks.Conv2d, + conv: Callable[..., nn.Module] = modules.Conv2d, activation: Callable[..., nn.Module] = nn.PReLU, ) -> None: super().__init__() @@ -131,7 +131,7 @@ def __init__( # upsampling for i in range(scale_factor // 2): - upsampling_block = blocks.SubPixelConv( + upsampling_block = modules.SubPixelConv( num_features=in_channels, conv=conv, activation=activation, diff --git a/esrgan/nn/__init__.py 
b/esrgan/nn/__init__.py new file mode 100644 index 0000000..649ed62 --- /dev/null +++ b/esrgan/nn/__init__.py @@ -0,0 +1,9 @@ +# flake8: noqa +from esrgan.nn.criterions import ( + AdversarialLoss, PerceptualLoss, RelativisticAdversarialLoss, +) +from esrgan.nn.modules import ( + ConcatInputModule, Conv2d, Conv2dSN, InterpolateConv, LeakyReLU, LinearSN, + ResidualDenseBlock, ResidualInResidualDenseBlock, ResidualModule, + SubPixelConv, +) diff --git a/esrgan/nn/criterions/__init__.py b/esrgan/nn/criterions/__init__.py new file mode 100644 index 0000000..5784e86 --- /dev/null +++ b/esrgan/nn/criterions/__init__.py @@ -0,0 +1,5 @@ +# flake8: noqa +from esrgan.nn.criterions.adversarial import ( + AdversarialLoss, RelativisticAdversarialLoss, +) +from esrgan.nn.criterions.perceptual import PerceptualLoss diff --git a/esrgan/criterions/adversarial.py b/esrgan/nn/criterions/adversarial.py similarity index 100% rename from esrgan/criterions/adversarial.py rename to esrgan/nn/criterions/adversarial.py diff --git a/esrgan/criterions/perceptual.py b/esrgan/nn/criterions/perceptual.py similarity index 100% rename from esrgan/criterions/perceptual.py rename to esrgan/nn/criterions/perceptual.py diff --git a/esrgan/nn/modules/__init__.py b/esrgan/nn/modules/__init__.py new file mode 100644 index 0000000..2f9f3aa --- /dev/null +++ b/esrgan/nn/modules/__init__.py @@ -0,0 +1,7 @@ +# flake8: noqa +from esrgan.nn.modules.container import ConcatInputModule, ResidualModule +from esrgan.nn.modules.misc import Conv2d, Conv2dSN, LeakyReLU, LinearSN +from esrgan.nn.modules.rrdb import ( + ResidualDenseBlock, ResidualInResidualDenseBlock, +) +from esrgan.nn.modules.upsampling import InterpolateConv, SubPixelConv diff --git a/esrgan/model/module/blocks/container.py b/esrgan/nn/modules/container.py similarity index 100% rename from esrgan/model/module/blocks/container.py rename to esrgan/nn/modules/container.py diff --git a/esrgan/model/module/blocks/misc.py b/esrgan/nn/modules/misc.py similarity index 100% rename from esrgan/model/module/blocks/misc.py rename to esrgan/nn/modules/misc.py diff --git a/esrgan/model/module/blocks/rrdb.py b/esrgan/nn/modules/rrdb.py similarity index 97% rename from esrgan/model/module/blocks/rrdb.py rename to esrgan/nn/modules/rrdb.py index 85022b0..cce787c 100644 --- a/esrgan/model/module/blocks/rrdb.py +++ b/esrgan/nn/modules/rrdb.py @@ -3,7 +3,8 @@ from torch import nn -from esrgan.model.module.blocks import container, Conv2d, LeakyReLU +from esrgan.nn.modules import container +from esrgan.nn.modules.misc import Conv2d, LeakyReLU __all__ = ["ResidualDenseBlock", "ResidualInResidualDenseBlock"] diff --git a/esrgan/model/module/blocks/upsampling.py b/esrgan/nn/modules/upsampling.py similarity index 98% rename from esrgan/model/module/blocks/upsampling.py rename to esrgan/nn/modules/upsampling.py index 36b4317..f3513e7 100644 --- a/esrgan/model/module/blocks/upsampling.py +++ b/esrgan/nn/modules/upsampling.py @@ -4,7 +4,7 @@ from torch import nn from torch.nn import functional as F -from esrgan.model.module.blocks.misc import Conv2d, LeakyReLU +from esrgan.nn.modules.misc import Conv2d, LeakyReLU __all__ = ["SubPixelConv", "InterpolateConv"] diff --git a/esrgan/runner.py b/esrgan/runner.py index 46d9ffb..f0b0170 100644 --- a/esrgan/runner.py +++ b/esrgan/runner.py @@ -10,6 +10,26 @@ class GANRunner(IRunner): """Runner for ESRGAN, please check `catalyst docs`__ for more info. + Args: + input_key: Key in batch dict mapping for model input. 
+        target_key: Key in batch dict mapping for target.
+        generator_output_key: Key in output dict under which the output
+            of the generator will be stored.
+        discriminator_real_output_gkey: Key to store predictions of the
+            discriminator for real inputs; these contain gradients for
+            the generator.
+        discriminator_fake_output_gkey: Key to store predictions of the
+            discriminator for outputs of the generator; these contain
+            gradients for the generator.
+        discriminator_real_output_dkey: Key to store predictions of the
+            discriminator for real inputs; these contain gradients for
+            the discriminator only.
+        discriminator_fake_output_dkey: Key to store predictions of the
+            discriminator for items produced by the generator; these
+            contain gradients for the discriminator only.
+        generator_key: Key in model dict mapping for generator model.
+        discriminator_key: Key in model dict mapping for discriminator
+            model (will be used in GAN stages only).
+
     __ https://catalyst-team.github.io/catalyst/api/core.html#experiment
     """
 
@@ -26,29 +46,6 @@ def __init__(
         generator_key: str = "generator",
         discriminator_key: str = "discriminator",
     ) -> None:
-        """Constructor method for the :py:class:`GANRunner` class.
-
-        Args:
-            input_key: Key in batch dict mapping for model input.
-            target_key: Key in batch dict mapping for target.
-            generator_output_key: Key in output dict model output
-                of the generator will be stored under.
-            discriminator_real_output_gkey: Key to store predictions of
-                discriminator for real inputs, contain gradients for generator.
-            discriminator_fake_output_gkey: Key to store predictions of
-                discriminator for predictions of generator,
-                contain gradients for generator.
-            discriminator_real_output_dkey: Key to store predictions of
-                discriminator for real inputs,
-                contain gradients for discriminator only.
-            discriminator_fake_output_dkey: Key to store predictions of
-                discriminator for items produced by generator,
-                contain gradients for discriminator only.
-            generator_key: Key in model dict mapping for generator model.
-            discriminator_key: Key in model dict mapping for discriminator
-                model (will be used in gan stages only).
-
-        """
         super().__init__()
 
         self.generator_key = generator_key
@@ -153,12 +150,33 @@ def on_stage_start(self, runner: IRunner) -> None:
         elif self.stage_key.endswith("_gan"):
             self.handle_batch = self._handle_batch_gan
         else:
-            raise NotImplementedError()
+            raise NotImplementedError(f"`{self.stage_key}` is not supported")
 
 
 class GANConfigRunner(runners.ConfigRunner, GANRunner):
     """ConfigRunner for ESRGAN, please check `catalyst docs`__ for more info.
 
+    Args:
+        config: Dictionary with parameters, e.g., model or engine to use.
+        input_key: Key in batch dict mapping for model input.
+        target_key: Key in batch dict mapping for target.
+        generator_output_key: Key in output dict under which the output
+            of the generator will be stored.
+        discriminator_real_output_gkey: Key to store predictions of the
+            discriminator for real inputs; these contain gradients for
+            the generator.
+        discriminator_fake_output_gkey: Key to store predictions of the
+            discriminator for outputs of the generator; these contain
+            gradients for the generator.
+        discriminator_real_output_dkey: Key to store predictions of the
+            discriminator for real inputs; these contain gradients for
+            the discriminator only.
+        discriminator_fake_output_dkey: Key to store predictions of the
+            discriminator for items produced by the generator; these
+            contain gradients for the discriminator only.
+        generator_key: Key in model dict mapping for generator model.
+        discriminator_key: Key in model dict mapping for discriminator
+            model (will be used in GAN stages only).
+
     __ https://catalyst-team.github.io/catalyst/api/core.html#experiment
     """
 
@@ -176,30 +194,6 @@ def __init__(
         generator_key: str = "generator",
         discriminator_key: str = "discriminator",
     ):
-        """Constructor method for the :py:class:`GANConfigRunner` class.
-
-        Args:
-            config: Dictionary with parameters e.g., model or engine to use.
-            input_key: Key in batch dict mapping for model input.
-            target_key: Key in batch dict mapping for target.
-            generator_output_key: Key in output dict model output
-                of the generator will be stored under.
-            discriminator_real_output_gkey: Key to store predictions of
-                discriminator for real inputs, contain gradients for generator.
-            discriminator_fake_output_gkey: Key to store predictions of
-                discriminator for predictions of generator,
-                contain gradients for generator.
-            discriminator_real_output_dkey: Key to store predictions of
-                discriminator for real inputs,
-                contain gradients for discriminator only.
-            discriminator_fake_output_dkey: Key to store predictions of
-                discriminator for items produced by generator,
-                contain gradients for discriminator only.
-            generator_key: Key in model dict mapping for generator model.
-            discriminator_key: Key in model dict mapping for discriminator
-                model (will be used in gan stages only).
-
-        """
         GANRunner.__init__(
             self,
            input_key=input_key,
diff --git a/esrgan/utils/aug.py b/esrgan/utils/aug.py
index fe2159b..565ecfd 100644
--- a/esrgan/utils/aug.py
+++ b/esrgan/utils/aug.py
@@ -3,21 +3,33 @@
 __all__ = ["Augmentor"]
 
 
+def identity(**d: Any) -> Dict:
+    """A placeholder identity operator that is argument-insensitive.
+
+    Accepts the keyword arguments unpacked by :py:meth:`Augmentor.__call__`
+    and returns them unchanged.
+
+    Args:
+        d: Dictionary with the data that describes a sample.
+
+    Returns:
+        Same dictionary ``d``.
+
+    """
+    return d
+
+
 class Augmentor:
-    """Applies provided transformation on dictionaries."""
+    """Applies provided transformation on dictionaries.
+
+    Args:
+        transform: A function / transform that takes in a dictionary
+            and returns a transformed version.
+            If ``None``, the identity function is used.
+
+    """
 
     def __init__(
         self, transform: Optional[Callable[[Any], Dict]] = None
     ) -> None:
-        """Constructor method for the :py:class:`Augmentor` class.
-
-        Args:
-            transform: A function / transform that takes in dictionary
-                and returns a transformed version.
-                If ``None``, the identity function is used.
-
-        """
-        self.transform = transform if transform is not None else self.indentity
+        self.transform = transform if transform is not None else identity
 
     def __call__(self, d: Dict) -> Dict:
         """Applies ``transform`` to the dictionary ``d``.
@@ -30,16 +42,3 @@ def __call__(self, d: Dict) -> Dict:
         """
         return self.transform(**d)
-
-    @staticmethod
-    def indentity(d: Dict) -> Dict:
-        """A placeholder identity operator that is argument-insensitive.
-
-        Args:
-            d: Dictionary with the data that describes sample.
-
-        Returns:
-            Same dictionary ``d``.
-
-        """
-        return d
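
A quick way to sanity-check the refactored layout is to wire the relocated discriminator classes together under the new ``esrgan.models`` import path. The sketch below is illustrative and not part of the patch: the ``nn.AdaptiveAvgPool2d`` pooling layer, the single-logit head, and the 128x128 input size are assumptions chosen to match the default ``layers`` of ``StridedConvEncoder``, not values taken from this diff: ::

    import torch
    from torch import nn

    from esrgan import models

    # Encoder maps 3 -> 512 feature channels with the default `layers`
    # tuple; its three stride-2 convs reduce 128x128 inputs to 16x16.
    encoder = models.StridedConvEncoder()

    # Assumed pooling layer: collapse the spatial map to 1x1 so the
    # flattened embedding has `encoder.out_channels` (512) features.
    pool = nn.AdaptiveAvgPool2d(1)

    # With `latent_channels` omitted the head is a single linear layer,
    # here producing one real/fake logit per image.
    head = models.LinearHead(in_channels=encoder.out_channels, out_channels=1)

    # VGGConv runs encoder -> pool -> flatten -> head (see its forward).
    discriminator = models.VGGConv(encoder=encoder, pool=pool, head=head)

    x = torch.randn(4, 3, 128, 128)  # toy batch of 128x128 RGB crops
    logits = discriminator(x)        # -> torch.Size([4, 1])

The same imports exercise the other moves in this patch, e.g. ``from esrgan.nn import PerceptualLoss`` (formerly ``esrgan.criterions.perceptual``) and ``from esrgan import datasets`` (formerly ``esrgan.dataset``).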