diff --git a/CHANGELOG.md b/CHANGELOG.md index 4490741efc..535a43dc4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added YOLO model ([#552](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/552)) +- Added `SRGAN`, `SRImageLoggerCallback`, `TVTDataModule`, `SRCelebA`, `SRMNIST`, `SRSTL10` ([#466](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/466)) + + - Added nn.Module support for FasterRCNN backbone ([#661](https://github.com/PyTorchLightning/lightning-bolts/pull/661)) @@ -122,8 +125,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#323](https://github.com/PyTorchLightning/lightning-bolts/pull/323)) - Added data monitor callbacks `ModuleDataMonitor` and `TrainingDataMonitor` ([#285](https://github.com/PyTorchLightning/lightning-bolts/pull/285)) - Added DCGAN module ([#403](https://github.com/PyTorchLightning/lightning-bolts/pull/403)) -- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`, - and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/lightning-bolts/pull/400)) +- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`, and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/lightning-bolts/pull/400)) - Added GIoU loss ([#347](https://github.com/PyTorchLightning/lightning-bolts/pull/347)) - Added IoU loss ([#469](https://github.com/PyTorchLightning/lightning-bolts/pull/469)) - Added semantic segmentation model `SemSegment` with `UNet` backend ([#259](https://github.com/PyTorchLightning/lightning-bolts/pull/259)) diff --git a/docs/source/_images/gans/srgan-celeba-scale_factor=2.png b/docs/source/_images/gans/srgan-celeba-scale_factor=2.png new file mode 100644 index 0000000000..25b18474d4 Binary files /dev/null and 
b/docs/source/_images/gans/srgan-celeba-scale_factor=2.png differ diff --git a/docs/source/_images/gans/srgan-celeba-scale_factor=4.png b/docs/source/_images/gans/srgan-celeba-scale_factor=4.png new file mode 100644 index 0000000000..86ab633999 Binary files /dev/null and b/docs/source/_images/gans/srgan-celeba-scale_factor=4.png differ diff --git a/docs/source/_images/gans/srgan-mnist-scale_factor=2.png b/docs/source/_images/gans/srgan-mnist-scale_factor=2.png new file mode 100644 index 0000000000..3bafd1d1f0 Binary files /dev/null and b/docs/source/_images/gans/srgan-mnist-scale_factor=2.png differ diff --git a/docs/source/_images/gans/srgan-mnist-scale_factor=4.png b/docs/source/_images/gans/srgan-mnist-scale_factor=4.png new file mode 100644 index 0000000000..881fd0dd14 Binary files /dev/null and b/docs/source/_images/gans/srgan-mnist-scale_factor=4.png differ diff --git a/docs/source/_images/gans/srgan-stl10-scale_factor=2.png b/docs/source/_images/gans/srgan-stl10-scale_factor=2.png new file mode 100644 index 0000000000..0c15bf3c37 Binary files /dev/null and b/docs/source/_images/gans/srgan-stl10-scale_factor=2.png differ diff --git a/docs/source/_images/gans/srgan-stl10-scale_factor=4.png b/docs/source/_images/gans/srgan-stl10-scale_factor=4.png new file mode 100644 index 0000000000..ffc0f1869c Binary files /dev/null and b/docs/source/_images/gans/srgan-stl10-scale_factor=4.png differ diff --git a/docs/source/deprecated/models/gans.rst b/docs/source/deprecated/models/gans.rst index b51d69514c..cdb80dd195 100644 --- a/docs/source/deprecated/models/gans.rst +++ b/docs/source/deprecated/models/gans.rst @@ -86,3 +86,86 @@ LSUN Loss curves: .. autoclass:: pl_bolts.models.gans.DCGAN :noindex: + + +SRGAN +--------- +SRGAN implementation from the paper `Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial +Network `_. The implementation is based on the version from +`deeplearning.ai `_. 
+ +Implemented by: + + - `Christoph Clement `_ + +MNIST results: + + SRGAN MNIST with scale factor of 2 (left: low res, middle: generated high res, right: ground truth high res): + + .. image:: ../../_images/gans/srgan-mnist-scale_factor=2.png + :width: 200 + :alt: SRGAN MNIST with scale factor of 2 + + SRGAN MNIST with scale factor of 4: + + .. image:: ../../_images/gans/srgan-mnist-scale_factor=4.png + :width: 200 + :alt: SRGAN MNIST with scale factor of 4 + + SRResNet pretraining command used:: + >>> python srresnet_module.py --dataset=mnist --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \ + --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000 + + SRGAN training command used:: + >>> python srgan_module.py --dataset=mnist --data_dir=~/Data --scale_factor=4 --batch_size=16 \ + --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000 + +STL10 results: + + SRGAN STL10 with scale factor of 2: + + .. image:: ../../_images/gans/srgan-stl10-scale_factor=2.png + :width: 200 + :alt: SRGAN STL10 with scale factor of 2 + + SRGAN STL10 with scale factor of 4: + + .. image:: ../../_images/gans/srgan-stl10-scale_factor=4.png + :width: 200 + :alt: SRGAN STL10 with scale factor of 4 + + SRResNet pretraining command used:: + >>> python srresnet_module.py --dataset=stl10 --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \ + --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000 + + SRGAN training command used:: + >>> python srgan_module.py --dataset=stl10 --data_dir=~/Data --scale_factor=4 --batch_size=16 \ + --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000 + +CelebA results: + + SRGAN CelebA with scale factor of 2: + + .. image:: ../../_images/gans/srgan-celeba-scale_factor=2.png + :width: 200 + :alt: SRGAN CelebA with scale factor of 2 + + SRGAN CelebA with scale factor of 4: + + .. 
image:: ../../_images/gans/srgan-celeba-scale_factor=4.png + :width: 200 + :alt: SRGAN CelebA with scale factor of 4 + + SRResNet pretraining command used:: + >>> python srresnet_module.py --dataset=celeba --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \ + --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000 + + SRGAN training command used:: + >>> python srgan_module.py --dataset=celeba --data_dir=~/Data --scale_factor=4 --batch_size=16 \ + --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000 + +.. autoclass:: pl_bolts.models.gans.SRGAN + :noindex: + +.. autoclass:: pl_bolts.models.gans.SRResNet + :noindex: diff --git a/pl_bolts/callbacks/__init__.py b/pl_bolts/callbacks/__init__.py index f52a408c48..2225372f48 100644 --- a/pl_bolts/callbacks/__init__.py +++ b/pl_bolts/callbacks/__init__.py @@ -9,6 +9,7 @@ from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback # type: ignore from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback from pl_bolts.callbacks.vision.image_generation import TensorboardGenerativeModelImageSampler +from pl_bolts.callbacks.vision.sr_image_logger import SRImageLoggerCallback __all__ = [ "BatchGradientVerificationCallback", @@ -20,6 +21,7 @@ "LatentDimInterpolator", "ConfusedLogitCallback", "TensorboardGenerativeModelImageSampler", + "SRImageLoggerCallback", "ORTCallback", "SparseMLCallback", ] diff --git a/pl_bolts/callbacks/vision/sr_image_logger.py b/pl_bolts/callbacks/vision/sr_image_logger.py new file mode 100644 index 0000000000..2767b04646 --- /dev/null +++ b/pl_bolts/callbacks/vision/sr_image_logger.py @@ -0,0 +1,67 @@ +from typing import Tuple + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from pytorch_lightning import Callback + +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if 
_TORCHVISION_AVAILABLE: + from torchvision.utils import make_grid +else: # pragma: no cover + warn_missing_pkg("torchvision") + + +class SRImageLoggerCallback(Callback): + """Logs low-res, generated high-res, and ground truth high-res images to TensorBoard Your model must implement + the ``forward`` function for generation. + + Requirements:: + + # model forward must work generating high-res from low-res image + hr_fake = pl_module(lr_image) + + Example:: + + from pl_bolts.callbacks import SRImageLoggerCallback + + trainer = Trainer(callbacks=[SRImageLoggerCallback()]) + """ + + def __init__(self, log_interval: int = 1000, scale_factor: int = 4, num_samples: int = 5) -> None: + """ + Args: + log_interval: Number of steps between logging. Default: ``1000``. + scale_factor: Scale factor used for downsampling the high-res images. Default: ``4``. + num_samples: Number of images of displayed in the grid. Default: ``5``. + """ + super().__init__() + self.log_interval = log_interval + self.scale_factor = scale_factor + self.num_samples = num_samples + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: torch.Tensor, + batch: Tuple[torch.Tensor, torch.Tensor], + batch_idx: int, + dataloader_idx: int, + ) -> None: + global_step = trainer.global_step + if global_step % self.log_interval == 0: + hr_image, lr_image = batch + hr_image, lr_image = hr_image.to(pl_module.device), lr_image.to(pl_module.device) + hr_fake = pl_module(lr_image) + lr_image = F.interpolate(lr_image, scale_factor=self.scale_factor) + + lr_image_grid = make_grid(lr_image[: self.num_samples], nrow=1, normalize=True) + hr_fake_grid = make_grid(hr_fake[: self.num_samples], nrow=1, normalize=True) + hr_image_grid = make_grid(hr_image[: self.num_samples], nrow=1, normalize=True) + + grid = torch.cat((lr_image_grid, hr_fake_grid, hr_image_grid), -1) + title = "sr_images" + trainer.logger.experiment.add_image(title, grid, global_step=global_step) diff --git 
a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index 83fd6efbe3..15515b2562 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -10,6 +10,7 @@ from pl_bolts.datamodules.kitti_datamodule import KittiDataModule from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset, TensorDataset +from pl_bolts.datamodules.sr_datamodule import TVTDataModule from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule from pl_bolts.datamodules.stl10_datamodule import STL10DataModule from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule @@ -31,6 +32,7 @@ "SklearnDataModule", "SklearnDataset", "TensorDataset", + "TVTDataModule", "SSLImagenetDataModule", "STL10DataModule", "VOCDetectionDataModule", diff --git a/pl_bolts/datamodules/sr_datamodule.py b/pl_bolts/datamodules/sr_datamodule.py new file mode 100644 index 0000000000..7a2c06dacf --- /dev/null +++ b/pl_bolts/datamodules/sr_datamodule.py @@ -0,0 +1,73 @@ +from typing import Any + +from pytorch_lightning import LightningDataModule +from torch.utils.data import DataLoader, Dataset + + +class TVTDataModule(LightningDataModule): + """Simple DataModule creating train, val, and test dataloaders from given train, val, and test dataset. 
+ + Example:: + from pl_bolts.datamodules import TVTDataModule + from pl_bolts.datasets.sr_mnist_dataset import SRMNIST + + dataset_dev = SRMNIST(scale_factor=4, root=".", train=True) + dataset_train, dataset_val = random_split(dataset_dev, lengths=[55_000, 5_000]) + dataset_test = SRMNIST(scale_factor=4, root=".", train=True) + dm = TVTDataModule(dataset_train, dataset_val, dataset_test) + """ + + def __init__( + self, + dataset_train: Dataset, + dataset_val: Dataset, + dataset_test: Dataset, + batch_size: int = 16, + shuffle: bool = True, + num_workers: int = 8, + pin_memory: bool = True, + drop_last: bool = True, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Args: + dataset_train: Train dataset + dataset_val: Val dataset + dataset_test: Test dataset + batch_size: How many samples per batch to load + num_workers: How many workers to use for loading data + shuffle: If true shuffles the train data every epoch + pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before + returning them + drop_last: If true drops the last incomplete batch + """ + super().__init__() + + self.dataset_train = dataset_train + self.dataset_val = dataset_val + self.dataset_test = dataset_test + self.num_workers = num_workers + self.batch_size = batch_size + self.shuffle = shuffle + self.pin_memory = pin_memory + self.drop_last = drop_last + + def train_dataloader(self) -> DataLoader: + return self._dataloader(self.dataset_train, shuffle=self.shuffle) + + def val_dataloader(self) -> DataLoader: + return self._dataloader(self.dataset_val, shuffle=False) + + def test_dataloader(self) -> DataLoader: + return self._dataloader(self.dataset_test, shuffle=False) + + def _dataloader(self, dataset: Dataset, shuffle: bool = True) -> DataLoader: + return DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=shuffle, + num_workers=self.num_workers, + drop_last=self.drop_last, + pin_memory=self.pin_memory, + ) diff --git a/pl_bolts/datasets/sr_celeba_dataset.py 
b/pl_bolts/datasets/sr_celeba_dataset.py new file mode 100644 index 0000000000..f912ced3b1 --- /dev/null +++ b/pl_bolts/datasets/sr_celeba_dataset.py @@ -0,0 +1,33 @@ +import os +from typing import Any + +from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin +from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _PIL_AVAILABLE: + from PIL import Image +else: # pragma: no cover + warn_missing_pkg("PIL", pypi_name="Pillow") + +if _TORCHVISION_AVAILABLE: + from torchvision.datasets import CelebA +else: # pragma: no cover + warn_missing_pkg("torchvision") + CelebA = object + + +class SRCelebA(SRDatasetMixin, CelebA): + """CelebA dataset that can be used to train Super Resolution models. + + Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image. + """ + + def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None: + hr_image_size = 128 + lr_image_size = hr_image_size // scale_factor + self.image_channels = 3 + super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs) + + def _get_image(self, index: int): + return Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) diff --git a/pl_bolts/datasets/sr_dataset_mixin.py b/pl_bolts/datasets/sr_dataset_mixin.py new file mode 100644 index 0000000000..17bd176f92 --- /dev/null +++ b/pl_bolts/datasets/sr_dataset_mixin.py @@ -0,0 +1,52 @@ +"""Adapted from: https://github.com/https-deeplearning-ai/GANs-Public.""" +from typing import Any, Tuple + +import torch + +from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _PIL_AVAILABLE: + from PIL import Image +else: # pragma: no cover + warn_missing_pkg("PIL", pypi_name="Pillow") + +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as transform_lib +else: # pragma: no cover + 
warn_missing_pkg("torchvision") + + +class SRDatasetMixin: + """Mixin for Super Resolution datasets. + + Scales range of high resolution images to [-1, 1] and range or low resolution images to [0, 1]. + """ + + def __init__(self, hr_image_size: int, lr_image_size: int, image_channels: int, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + self.hr_transforms = transform_lib.Compose( + [ + transform_lib.RandomCrop(hr_image_size), + transform_lib.ToTensor(), + transform_lib.Normalize(mean=(0.5,) * image_channels, std=(0.5,) * image_channels), + ] + ) + + self.lr_transforms = transform_lib.Compose( + [ + transform_lib.Normalize(mean=(-1.0,) * image_channels, std=(2.0,) * image_channels), + transform_lib.ToPILImage(), + transform_lib.Resize(lr_image_size, Image.BICUBIC), + transform_lib.ToTensor(), + ] + ) + + def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + image = self._get_image(index) + + hr_image = self.hr_transforms(image) + lr_image = self.lr_transforms(hr_image) + + return hr_image, lr_image diff --git a/pl_bolts/datasets/sr_mnist_dataset.py b/pl_bolts/datasets/sr_mnist_dataset.py new file mode 100644 index 0000000000..70fb7c2c23 --- /dev/null +++ b/pl_bolts/datasets/sr_mnist_dataset.py @@ -0,0 +1,27 @@ +from typing import Any + +from pl_bolts.datasets.mnist_dataset import MNIST +from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin +from pl_bolts.utils import _PIL_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _PIL_AVAILABLE: + from PIL import Image +else: # pragma: no cover + warn_missing_pkg("PIL", pypi_name="Pillow") + + +class SRMNIST(SRDatasetMixin, MNIST): + """MNIST dataset that can be used to train Super Resolution models. + + Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image. 
+ """ + + def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None: + hr_image_size = 28 + lr_image_size = hr_image_size // scale_factor + self.image_channels = 1 + super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs) + + def _get_image(self, index: int): + return Image.fromarray(self.data[index].numpy(), mode="L") diff --git a/pl_bolts/datasets/sr_stl10_dataset.py b/pl_bolts/datasets/sr_stl10_dataset.py new file mode 100644 index 0000000000..868565fd22 --- /dev/null +++ b/pl_bolts/datasets/sr_stl10_dataset.py @@ -0,0 +1,34 @@ +from typing import Any + +import numpy as np + +from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin +from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _PIL_AVAILABLE: + import PIL +else: # pragma: no cover + warn_missing_pkg("PIL", pypi_name="Pillow") + +if _TORCHVISION_AVAILABLE: + from torchvision.datasets import STL10 +else: # pragma: no cover + warn_missing_pkg("torchvision") + STL10 = object + + +class SRSTL10(SRDatasetMixin, STL10): + """STL10 dataset that can be used to train Super Resolution models. + + Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image. 
+ """ + + def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None: + hr_image_size = 96 + lr_image_size = hr_image_size // scale_factor + self.image_channels = 3 + super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs) + + def _get_image(self, index: int): + return PIL.Image.fromarray(np.transpose(self.data[index], (1, 2, 0))) diff --git a/pl_bolts/datasets/utils.py b/pl_bolts/datasets/utils.py new file mode 100644 index 0000000000..77946beb5b --- /dev/null +++ b/pl_bolts/datasets/utils.py @@ -0,0 +1,39 @@ +from torch.utils.data.dataset import random_split + +from pl_bolts.datasets.sr_celeba_dataset import SRCelebA +from pl_bolts.datasets.sr_mnist_dataset import SRMNIST +from pl_bolts.datasets.sr_stl10_dataset import SRSTL10 + + +def prepare_sr_datasets(dataset: str, scale_factor: int, data_dir: str): + """Creates train, val, and test datasets for training a Super Resolution GAN. + + Args: + dataset: string indicating which dataset class to use (celeba, mnist, or stl10). + scale_factor: scale factor between low- and high resolution images. + data_dir: root dir of dataset. + + Returns: + sr_datasets: tuple containing train, val, and test dataset. 
+ """ + assert dataset in ["celeba", "mnist", "stl10"] + + if dataset == "celeba": + dataset_cls = SRCelebA + dataset_train = dataset_cls(scale_factor, root=data_dir, split="train", download=True) + dataset_val = dataset_cls(scale_factor, root=data_dir, split="valid", download=True) + dataset_test = dataset_cls(scale_factor, root=data_dir, split="test", download=True) + + elif dataset == "mnist": + dataset_cls = SRMNIST + dataset_dev = dataset_cls(scale_factor, root=data_dir, train=True, download=True) + dataset_train, dataset_val = random_split(dataset_dev, lengths=[55_000, 5_000]) + dataset_test = dataset_cls(scale_factor, root=data_dir, train=False, download=True) + + elif dataset == "stl10": + dataset_cls = SRSTL10 + dataset_dev = dataset_cls(scale_factor, root=data_dir, split="train", download=True) + dataset_train, dataset_val = random_split(dataset_dev, lengths=[4_500, 500]) + dataset_test = dataset_cls(scale_factor, root=data_dir, split="test", download=True) + + return (dataset_train, dataset_val, dataset_test) diff --git a/pl_bolts/models/gans/__init__.py b/pl_bolts/models/gans/__init__.py index 751ba576ef..5132052460 100644 --- a/pl_bolts/models/gans/__init__.py +++ b/pl_bolts/models/gans/__init__.py @@ -1,9 +1,13 @@ from pl_bolts.models.gans.basic.basic_gan_module import GAN from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN from pl_bolts.models.gans.pix2pix.pix2pix_module import Pix2Pix +from pl_bolts.models.gans.srgan.srgan_module import SRGAN +from pl_bolts.models.gans.srgan.srresnet_module import SRResNet __all__ = [ "GAN", "DCGAN", "Pix2Pix", + "SRGAN", + "SRResNet", ] diff --git a/pl_bolts/models/gans/srgan/__init__.py b/pl_bolts/models/gans/srgan/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pl_bolts/models/gans/srgan/components.py b/pl_bolts/models/gans/srgan/components.py new file mode 100644 index 0000000000..63a06006aa --- /dev/null +++ b/pl_bolts/models/gans/srgan/components.py @@ -0,0 +1,150 @@ 
+"""Adapted from: https://github.com/https-deeplearning-ai/GANs-Public.""" +import torch +import torch.nn as nn + +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision.models import vgg19 +else: # pragma: no cover + warn_missing_pkg("torchvision") + + +class ResidualBlock(nn.Module): + def __init__(self, feature_maps: int = 64) -> None: + super().__init__() + + self.block = nn.Sequential( + nn.Conv2d(feature_maps, feature_maps, kernel_size=3, padding=1), + nn.BatchNorm2d(feature_maps), + nn.PReLU(), + nn.Conv2d(feature_maps, feature_maps, kernel_size=3, padding=1), + nn.BatchNorm2d(feature_maps), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.block(x) + + +class SRGANGenerator(nn.Module): + def __init__( + self, + image_channels: int, + feature_maps: int = 64, + num_res_blocks: int = 16, + num_ps_blocks: int = 2, + ) -> None: + super().__init__() + # Input block (k9n64s1) + self.input_block = nn.Sequential( + nn.Conv2d(image_channels, feature_maps, kernel_size=9, padding=4), + nn.PReLU(), + ) + + # B residual blocks (k3n64s1) + res_blocks = [] + for _ in range(num_res_blocks): + res_blocks += [ResidualBlock(feature_maps)] + + # k3n64s1 + res_blocks += [ + nn.Conv2d(feature_maps, feature_maps, kernel_size=3, padding=1), + nn.BatchNorm2d(feature_maps), + ] + self.res_blocks = nn.Sequential(*res_blocks) + + # PixelShuffle blocks (k3n256s1) + ps_blocks = [] + for _ in range(num_ps_blocks): + ps_blocks += [ + nn.Conv2d(feature_maps, 4 * feature_maps, kernel_size=3, padding=1), + nn.PixelShuffle(2), + nn.PReLU(), + ] + self.ps_blocks = nn.Sequential(*ps_blocks) + + # Output block (k9n3s1) + self.output_block = nn.Sequential( + nn.Conv2d(feature_maps, image_channels, kernel_size=9, padding=4), + nn.Tanh(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_res = self.input_block(x) + x = x_res + self.res_blocks(x_res) + x = 
self.ps_blocks(x) + x = self.output_block(x) + return x + + +class SRGANDiscriminator(nn.Module): + def __init__(self, image_channels: int, feature_maps: int = 64) -> None: + super().__init__() + + self.conv_blocks = nn.Sequential( + # k3n64s1, k3n64s2 + self._make_double_conv_block(image_channels, feature_maps, first_batch_norm=False), + # k3n128s1, k3n128s2 + self._make_double_conv_block(feature_maps, feature_maps * 2), + # k3n256s1, k3n256s2 + self._make_double_conv_block(feature_maps * 2, feature_maps * 4), + # k3n512s1, k3n512s2 + self._make_double_conv_block(feature_maps * 4, feature_maps * 8), + ) + + self.mlp = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(feature_maps * 8, feature_maps * 16, kernel_size=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(feature_maps * 16, 1, kernel_size=1), + nn.Flatten(), + ) + + def _make_double_conv_block( + self, + in_channels: int, + out_channels: int, + first_batch_norm: bool = True, + ) -> nn.Sequential: + return nn.Sequential( + self._make_conv_block(in_channels, out_channels, batch_norm=first_batch_norm), + self._make_conv_block(out_channels, out_channels, stride=2), + ) + + @staticmethod + def _make_conv_block( + in_channels: int, + out_channels: int, + stride: int = 1, + batch_norm: bool = True, + ) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1), + nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv_blocks(x) + x = self.mlp(x) + return x + + +class VGG19FeatureExtractor(nn.Module): + def __init__(self, image_channels: int = 3) -> None: + super().__init__() + + assert image_channels in [1, 3] + self.image_channels = image_channels + + vgg = vgg19(pretrained=True) + self.vgg = nn.Sequential(*list(vgg.features)[:-1]).eval() + for p in self.vgg.parameters(): + p.requires_grad = False + + def forward(self, x: 
torch.Tensor) -> torch.Tensor: + if self.image_channels == 1: + x = x.repeat(1, 3, 1, 1) + + return self.vgg(x) diff --git a/pl_bolts/models/gans/srgan/srgan_module.py b/pl_bolts/models/gans/srgan/srgan_module.py new file mode 100644 index 0000000000..434731c761 --- /dev/null +++ b/pl_bolts/models/gans/srgan/srgan_module.py @@ -0,0 +1,228 @@ +"""Adapted from: https://github.com/https-deeplearning-ai/GANs-Public.""" +from argparse import ArgumentParser +from pathlib import Path +from typing import Any, List, Optional, Tuple +from warnings import warn + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F + +from pl_bolts.callbacks import SRImageLoggerCallback +from pl_bolts.datamodules import TVTDataModule +from pl_bolts.datasets.utils import prepare_sr_datasets +from pl_bolts.models.gans.srgan.components import SRGANDiscriminator, SRGANGenerator, VGG19FeatureExtractor + + +class SRGAN(pl.LightningModule): + """SRGAN implementation from the paper `Photo-Realistic Single Image Super-Resolution Using a Generative + Adversarial Network `__. It uses a pretrained SRResNet model as the generator + if available. + + Code adapted from `https-deeplearning-ai/GANs-Public `_ to + Lightning by: + + - `Christoph Clement `_ + + You can pretrain a SRResNet model with :code:`srresnet_module.py`. 
+ + Example:: + + from pl_bolts.models.gan import SRGAN + + m = SRGAN() + Trainer(gpus=1).fit(m) + + Example CLI:: + + # CelebA dataset, scale_factor 4 + python srgan_module.py --dataset=celeba --scale_factor=4 --gpus=1 + + # MNIST dataset, scale_factor 4 + python srgan_module.py --dataset=mnist --scale_factor=4 --gpus=1 + + # STL10 dataset, scale_factor 4 + python srgan_module.py --dataset=stl10 --scale_factor=4 --gpus=1 + """ + + def __init__( + self, + image_channels: int = 3, + feature_maps_gen: int = 64, + feature_maps_disc: int = 64, + num_res_blocks: int = 16, + scale_factor: int = 4, + generator_checkpoint: Optional[str] = None, + learning_rate: float = 1e-4, + scheduler_step: int = 100, + **kwargs: Any, + ) -> None: + """ + Args: + image_channels: Number of channels of the images from the dataset + feature_maps_gen: Number of feature maps to use for the generator + feature_maps_disc: Number of feature maps to use for the discriminator + num_res_blocks: Number of res blocks to use in the generator + scale_factor: Scale factor for the images (either 2 or 4) + generator_checkpoint: Generator checkpoint created with SRResNet module + learning_rate: Learning rate + scheduler_step: Number of epochs after which the learning rate gets decayed + """ + super().__init__() + self.save_hyperparameters() + + if generator_checkpoint: + self.generator = torch.load(generator_checkpoint) + else: + assert scale_factor in [2, 4] + num_ps_blocks = scale_factor // 2 + self.generator = SRGANGenerator(image_channels, feature_maps_gen, num_res_blocks, num_ps_blocks) + + self.discriminator = SRGANDiscriminator(image_channels, feature_maps_disc) + self.vgg_feature_extractor = VGG19FeatureExtractor(image_channels) + + def configure_optimizers(self) -> Tuple[List[torch.optim.Adam], List[torch.optim.lr_scheduler.MultiStepLR]]: + opt_disc = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams.learning_rate) + opt_gen = torch.optim.Adam(self.generator.parameters(), 
lr=self.hparams.learning_rate) + + sched_disc = torch.optim.lr_scheduler.MultiStepLR(opt_disc, milestones=[self.hparams.scheduler_step], gamma=0.1) + sched_gen = torch.optim.lr_scheduler.MultiStepLR(opt_gen, milestones=[self.hparams.scheduler_step], gamma=0.1) + return [opt_disc, opt_gen], [sched_disc, sched_gen] + + def forward(self, lr_image: torch.Tensor) -> torch.Tensor: + """Generates a high resolution image given a low resolution image. + + Example:: + + srgan = SRGAN.load_from_checkpoint(PATH) + hr_image = srgan(lr_image) + """ + return self.generator(lr_image) + + def training_step( + self, + batch: Tuple[torch.Tensor, torch.Tensor], + batch_idx: int, + optimizer_idx: int, + ) -> torch.Tensor: + hr_image, lr_image = batch + + # Train discriminator + result = None + if optimizer_idx == 0: + result = self._disc_step(hr_image, lr_image) + + # Train generator + if optimizer_idx == 1: + result = self._gen_step(hr_image, lr_image) + + return result + + def _disc_step(self, hr_image: torch.Tensor, lr_image: torch.Tensor) -> torch.Tensor: + disc_loss = self._disc_loss(hr_image, lr_image) + self.log("loss/disc", disc_loss, on_step=True, on_epoch=True) + return disc_loss + + def _gen_step(self, hr_image: torch.Tensor, lr_image: torch.Tensor) -> torch.Tensor: + gen_loss = self._gen_loss(hr_image, lr_image) + self.log("loss/gen", gen_loss, on_step=True, on_epoch=True) + return gen_loss + + def _disc_loss(self, hr_image: torch.Tensor, lr_image: torch.Tensor) -> torch.Tensor: + real_pred = self.discriminator(hr_image) + real_loss = self._adv_loss(real_pred, ones=True) + + _, fake_pred = self._fake_pred(lr_image) + fake_loss = self._adv_loss(fake_pred, ones=False) + + disc_loss = 0.5 * (real_loss + fake_loss) + + return disc_loss + + def _gen_loss(self, hr_image: torch.Tensor, lr_image: torch.Tensor) -> torch.Tensor: + fake, fake_pred = self._fake_pred(lr_image) + + perceptual_loss = self._perceptual_loss(hr_image, fake) + adv_loss = self._adv_loss(fake_pred, ones=True) + 
content_loss = self._content_loss(hr_image, fake) + + gen_loss = 0.006 * perceptual_loss + 0.001 * adv_loss + content_loss + + return gen_loss + + def _fake_pred(self, lr_image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + fake = self(lr_image) + fake_pred = self.discriminator(fake) + return fake, fake_pred + + @staticmethod + def _adv_loss(pred: torch.Tensor, ones: bool) -> torch.Tensor: + target = torch.ones_like(pred) if ones else torch.zeros_like(pred) + adv_loss = F.binary_cross_entropy_with_logits(pred, target) + return adv_loss + + def _perceptual_loss(self, hr_image: torch.Tensor, fake: torch.Tensor) -> torch.Tensor: + real_features = self.vgg_feature_extractor(hr_image) + fake_features = self.vgg_feature_extractor(fake) + perceptual_loss = self._content_loss(real_features, fake_features) + return perceptual_loss + + @staticmethod + def _content_loss(hr_image: torch.Tensor, fake: torch.Tensor) -> torch.Tensor: + return F.mse_loss(hr_image, fake) + + @staticmethod + def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument("--feature_maps_gen", default=64, type=int) + parser.add_argument("--feature_maps_disc", default=64, type=int) + parser.add_argument("--learning_rate", default=1e-4, type=float) + parser.add_argument("--scheduler_step", default=100, type=float) + return parser + + +def cli_main(args=None): + pl.seed_everything(1234) + + parser = ArgumentParser() + parser.add_argument("--dataset", default="mnist", type=str, choices=["celeba", "mnist", "stl10"]) + parser.add_argument("--data_dir", default="./", type=str) + parser.add_argument("--log_interval", default=1000, type=int) + parser.add_argument("--scale_factor", default=4, type=int) + parser.add_argument("--save_model_checkpoint", dest="save_model_checkpoint", action="store_true") + + parser = TVTDataModule.add_argparse_args(parser) + parser = SRGAN.add_model_specific_args(parser) 
+ parser = pl.Trainer.add_argparse_args(parser) + + args = parser.parse_args(args) + + datasets = prepare_sr_datasets(args.dataset, args.scale_factor, args.data_dir) + dm = TVTDataModule(*datasets, **vars(args)) + + generator_checkpoint = Path(f"model_checkpoints/srresnet-{args.dataset}-scale_factor={args.scale_factor}.pt") + if not generator_checkpoint.exists(): + warn( + "No generator checkpoint found. Training generator from scratch. \ + Use srresnet_module.py to pretrain the generator." + ) + generator_checkpoint = None + + model = SRGAN( + **vars(args), image_channels=dm.dataset_test.image_channels, generator_checkpoint=generator_checkpoint + ) + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=[SRImageLoggerCallback(log_interval=args.log_interval, scale_factor=args.scale_factor)], + logger=pl.loggers.TensorBoardLogger( + save_dir="lightning_logs", + name="srgan", + version=f"{args.dataset}-scale_factor={args.scale_factor}", + default_hp_metric=False, + ), + ) + trainer.fit(model, dm) + + +if __name__ == "__main__": + cli_main() diff --git a/pl_bolts/models/gans/srgan/srresnet_module.py b/pl_bolts/models/gans/srgan/srresnet_module.py new file mode 100644 index 0000000000..e7fa02bbb1 --- /dev/null +++ b/pl_bolts/models/gans/srgan/srresnet_module.py @@ -0,0 +1,148 @@ +"""Adapted from: https://github.com/https-deeplearning-ai/GANs-Public.""" +from argparse import ArgumentParser +from typing import Any, Tuple + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F + +from pl_bolts.callbacks import SRImageLoggerCallback +from pl_bolts.datamodules import TVTDataModule +from pl_bolts.datasets.utils import prepare_sr_datasets +from pl_bolts.models.gans.srgan.components import SRGANGenerator + + +class SRResNet(pl.LightningModule): + """SRResNet implementation from the paper `Photo-Realistic Single Image Super-Resolution Using a Generative + Adversarial Network `__. A pretrained SRResNet model is used as the generator + for SRGAN. 
class SRResNet(pl.LightningModule):
    """SRResNet implementation from the paper `Photo-Realistic Single Image Super-Resolution Using a Generative
    Adversarial Network <https://arxiv.org/abs/1609.04802>`__. A pretrained SRResNet model is used as the generator
    for SRGAN.

    Code adapted from `https-deeplearning-ai/GANs-Public
    <https://github.com/https-deeplearning-ai/GANs-Public>`_ to Lightning by:

    - Christoph Clement

    Example::

        # FIX: the package is ``pl_bolts.models.gans`` (plural), matching the
        # imports used by the test suite; the original said ``models.gan``.
        from pl_bolts.models.gans import SRResNet

        m = SRResNet()
        Trainer(gpus=1).fit(m)

    Example CLI::

        # CelebA dataset, scale_factor 4
        python srresnet_module.py --dataset=celeba --scale_factor=4 --gpus=1

        # MNIST dataset, scale_factor 4
        python srresnet_module.py --dataset=mnist --scale_factor=4 --gpus=1

        # STL10 dataset, scale_factor 4
        python srresnet_module.py --dataset=stl10 --scale_factor=4 --gpus=1
    """

    def __init__(
        self,
        image_channels: int = 3,
        feature_maps: int = 64,
        num_res_blocks: int = 16,
        scale_factor: int = 4,
        learning_rate: float = 1e-4,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            image_channels: Number of channels of the images from the dataset
            feature_maps: Number of feature maps to use
            num_res_blocks: Number of res blocks to use in the generator
            scale_factor: Scale factor for the images (either 2 or 4)
            learning_rate: Learning rate

        Raises:
            ValueError: If ``scale_factor`` is not 2 or 4.
        """
        super().__init__()
        self.save_hyperparameters()

        # FIX: was ``assert scale_factor in [2, 4]`` -- asserts are stripped
        # under ``python -O``, so validate explicitly.
        if scale_factor not in (2, 4):
            raise ValueError(f"scale_factor must be 2 or 4, got {scale_factor}")
        # Each pixel-shuffle block upsamples by 2x: scale 2 -> 1 block, 4 -> 2.
        num_ps_blocks = scale_factor // 2
        self.srresnet = SRGANGenerator(image_channels, feature_maps, num_res_blocks, num_ps_blocks)

    def configure_optimizers(self) -> torch.optim.Adam:
        return torch.optim.Adam(self.srresnet.parameters(), lr=self.hparams.learning_rate)

    def forward(self, lr_image: torch.Tensor) -> torch.Tensor:
        """Create a high-resolution image from a low-resolution image.

        Example::

            srresnet = SRResNet.load_from_checkpoint(PATH)
            hr_image = srresnet(lr_image)
        """
        return self.srresnet(lr_image)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        loss = self._loss(batch)
        self.log("loss/train", loss, on_epoch=True)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        loss = self._loss(batch)
        self.log("loss/val", loss, sync_dist=True)
        return loss

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        loss = self._loss(batch)
        self.log("loss/test", loss, sync_dist=True)
        return loss

    def _loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Pixel-wise MSE between the real HR image and the upscaled LR image.

        NOTE: the batch is ordered (hr_image, lr_image).
        """
        hr_image, lr_image = batch
        fake = self(lr_image)
        return F.mse_loss(hr_image, fake)

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """Append SRResNet-specific hyper-parameter flags to ``parent_parser``."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--feature_maps", default=64, type=int)
        parser.add_argument("--learning_rate", default=1e-4, type=float)
        parser.add_argument("--num_res_blocks", default=16, type=int)
        return parser


def cli_main(args=None):
    """Pretrain SRResNet on CelebA, MNIST, or STL10 from the command line."""
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--dataset", default="mnist", type=str, choices=["celeba", "mnist", "stl10"])
    parser.add_argument("--data_dir", default="./", type=str)
    parser.add_argument("--log_interval", default=1000, type=int)
    parser.add_argument("--scale_factor", default=4, type=int)
    parser.add_argument("--save_model_checkpoint", dest="save_model_checkpoint", action="store_true")

    parser = TVTDataModule.add_argparse_args(parser)
    parser = SRResNet.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args(args)

    datasets = prepare_sr_datasets(args.dataset, args.scale_factor, args.data_dir)
    dm = TVTDataModule(*datasets, **vars(args))

    model = SRResNet(**vars(args), image_channels=dm.dataset_train.dataset.image_channels)

    trainer = pl.Trainer.from_argparse_args(
        args,
        callbacks=[SRImageLoggerCallback(log_interval=args.log_interval, scale_factor=args.scale_factor)],
        logger=pl.loggers.TensorBoardLogger(
            save_dir="lightning_logs",
            name="srresnet",
            version=f"{args.dataset}-scale_factor={args.scale_factor}",
            default_hp_metric=False,
        ),
    )
    trainer.fit(model, dm)

    if args.save_model_checkpoint:
        import os

        # FIX: torch.save raises FileNotFoundError when the directory is
        # missing; srgan_module.py later reads this exact path.
        os.makedirs("model_checkpoints", exist_ok=True)
        torch.save(model.srresnet, f"model_checkpoints/srresnet-{args.dataset}-scale_factor={args.scale_factor}.pt")


if __name__ == "__main__":
    cli_main()


# --- tests/datamodules/test_datamodules.py (additions) ---
from pl_bolts.datamodules.sr_datamodule import TVTDataModule
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST


def test_sr_datamodule(datadir):
    """Smoke test: TVTDataModule serves one batch from each of the three splits."""
    dataset = SRMNIST(scale_factor=4, root=datadir, download=True)
    dm = TVTDataModule(dataset_train=dataset, dataset_val=dataset, dataset_test=dataset, batch_size=2)

    next(iter(dm.train_dataloader()))
    next(iter(dm.val_dataloader()))
    next(iter(dm.test_dataloader()))
# --- tests/datasets/test_datasets.py (additions) ---
import pytest
import torch
from torch.utils.data import DataLoader

from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST


@pytest.mark.parametrize("scale_factor", [2, 4])
def test_sr_datasets(datadir, scale_factor):
    """SRMNIST yields (hr, lr) pairs: hr is 28x28 scaled to [-1, 1], lr is the
    downscaled image in [0, 1]."""
    loader = DataLoader(SRMNIST(scale_factor, root=datadir, download=True))
    hr_image, lr_image = next(iter(loader))

    hr_image_size = 28
    lr_image_size = hr_image_size // scale_factor
    assert hr_image.shape == (1, 1, hr_image_size, hr_image_size)
    assert lr_image.shape == (1, 1, lr_image_size, lr_image_size)

    # Loose tolerance: the min/max of a single sample only approximate the
    # ends of the normalized value range.
    atol = 0.3
    assert torch.allclose(hr_image.min(), torch.tensor(-1.0), atol=atol)
    assert torch.allclose(hr_image.max(), torch.tensor(1.0), atol=atol)
    assert torch.allclose(lr_image.min(), torch.tensor(0.0), atol=atol)
    assert torch.allclose(lr_image.max(), torch.tensor(1.0), atol=atol)


# --- tests/models/test_gans.py (additions) ---
from pl_bolts.models.gans import SRGAN, SRResNet


@pytest.mark.parametrize("sr_module_cls", [SRResNet, SRGAN])
@pytest.mark.parametrize("scale_factor", [2, 4])
def test_sr_modules(tmpdir, datadir, sr_module_cls, scale_factor):
    """fast_dev_run smoke test of both super-resolution modules on SRMNIST."""
    seed_everything(42)

    loader = DataLoader(SRMNIST(scale_factor, root=datadir, download=True))
    model = sr_module_cls(image_channels=1, scale_factor=scale_factor)
    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, loader)


# --- tests/models/test_scripts.py (additions) ---
@pytest.mark.parametrize("cli_args", ["--dataset mnist --scale_factor 4" + _DEFAULT_ARGS])
def test_cli_run_srgan(cli_args):
    """Run the SRGAN CLI end-to-end with the default fast-run arguments."""
    from pl_bolts.models.gans.srgan.srgan_module import cli_main

    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
        cli_main()


@pytest.mark.parametrize("cli_args", ["--dataset mnist --scale_factor 4" + _DEFAULT_ARGS])
def test_cli_run_srresnet(cli_args):
    """Run the SRResNet CLI end-to-end with the default fast-run arguments."""
    from pl_bolts.models.gans.srgan.srresnet_module import cli_main

    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
        cli_main()