diff --git a/docs/pages/api/datasets.rst b/docs/pages/api/datasets.rst
index 40e8cbf..923bdaf 100755
--- a/docs/pages/api/datasets.rst
+++ b/docs/pages/api/datasets.rst
@@ -9,24 +9,24 @@ For example: ::
 
     div2k_data = esrgan.dataset.DIV2KDataset('path/to/div2k_root/')
    data_loader = torch.utils.data.DataLoader(div2k_data, batch_size=4, shuffle=True)
 
+
 DIV2K
 ^^^^^
 
-.. autoclass:: esrgan.dataset.div2k.DIV2KDataset
+.. autoclass:: esrgan.dataset.DIV2KDataset
     :members:
 
-Folder of Images
-^^^^^^^^^^^^^^^^
+Flickr2K
+^^^^^^^^
 
-.. autoclass:: esrgan.dataset.image_folder.ImageFolderDataset
+.. autoclass:: esrgan.dataset.Flickr2K
     :members:
-    :undoc-members:
 
-Augmentation
-^^^^^^^^^^^^
+Folder of Images
+^^^^^^^^^^^^^^^^
 
-.. automodule:: esrgan.dataset.aug
+.. autoclass:: esrgan.dataset.ImageFolderDataset
     :members:
     :undoc-members:
 
diff --git a/docs/pages/api/utils.rst b/docs/pages/api/utils.rst
index b934bbd..ba566fc 100755
--- a/docs/pages/api/utils.rst
+++ b/docs/pages/api/utils.rst
@@ -4,6 +4,14 @@ Utilities
 
 Set of utilities that can make life a little bit easier.
 
+Augmentation
+^^^^^^^^^^^^
+
+.. automodule:: esrgan.utils.aug
+    :members:
+    :undoc-members:
+
+
 Model init
 ^^^^^^^^^^
 
diff --git a/esrgan/dataset/div2k.py b/esrgan/dataset.py
similarity index 54%
rename from esrgan/dataset/div2k.py
rename to esrgan/dataset.py
index c5e3ebf..880a18b 100644
--- a/esrgan/dataset/div2k.py
+++ b/esrgan/dataset.py
@@ -1,17 +1,31 @@
 import glob
 from pathlib import Path
 import random
-from typing import Any, Callable, Iterable, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 from albumentations.augmentations.crops import functional as F
-from catalyst import data, utils
+from catalyst import data
 from catalyst.contrib.datasets import misc
 import numpy as np
 from torch.utils.data import Dataset
 
-from esrgan.dataset import aug
+from esrgan import utils
 
-__all__ = ["DIV2KDataset"]
+__all__ = ["DIV2KDataset", "Flickr2K", "ImageFolderDataset"]
+
+
+def has_image_extension(uri: Union[str, Path]) -> bool:
+    """Check that file has image extension.
+
+    Args:
+        uri: The resource to load the file from.
+
+    Returns:
+        ``True`` if file has image extension, ``False`` otherwise.
+ + """ + ext = Path(uri).suffix + return ext.lower() in {".bmp", ".png", ".jpeg", ".jpg", ".tif", ".tiff"} def paired_random_crop( @@ -90,7 +104,7 @@ def __init__( train: bool = True, target_type: str = "bicubic_X4", patch_size: Tuple[int, int] = (96, 96), - transform: Optional[Callable[[Any], dict]] = None, + transform: Optional[Callable[[Any], Dict]] = None, low_resolution_image_key: str = "lr_image", high_resolution_image_key: str = "hr_image", download: bool = False, @@ -120,9 +134,11 @@ def __init__( self.lr_key = low_resolution_image_key self.hr_key = high_resolution_image_key + _, downscaling = target_type.split("_") + # 'index' files - lr_images = self._images_in_dir(Path(root) / Path(filename_lr).stem) - hr_images = self._images_in_dir(Path(root) / Path(filename_hr).stem) + lr_images = self._images_in_dir(root, Path(filename_lr).stem) + hr_images = self._images_in_dir(root, Path(filename_hr).stem) assert len(lr_images) == len(hr_images) self.data = [ @@ -135,14 +151,14 @@ def __init__( data.ImageReader(input_key="hr_image", output_key=self.hr_key), ]) - self.scale = int(target_type[-1]) if target_type[-1].isdigit() else 4 + self.scale = int(downscaling) if downscaling.isdigit() else 4 height, width = patch_size self.target_patch_size = patch_size self.input_patch_size = (height // self.scale, width // self.scale) - self.transform = aug.Augmentor(transform) + self.transform = utils.Augmentor(transform) - def __getitem__(self, index: int) -> dict: + def __getitem__(self, index: int) -> Dict: """Gets element of the dataset. Args: @@ -177,13 +193,135 @@ def __len__(self) -> int: """ return len(self.data) - def _images_in_dir(self, path: Path) -> List[str]: + def _images_in_dir(self, *path: Union[str, Path]) -> List[str]: # fix path to dir for `NTIRE 2017` datasets + path = Path(*path) if not path.exists(): idx = path.name.rfind("_") path = path.parent / path.name[:idx] / path.name[idx + 1:] files = glob.iglob(f"{path}/**/*", recursive=True) - images = sorted(filter(utils.has_image_extension, files)) + images = sorted(filter(has_image_extension, files)) return images + + +class Flickr2K(DIV2KDataset): + """`Flickr2K `_ Dataset. + + Args: + root: Root directory where images are downloaded to. + train: If True, creates dataset from training set, + otherwise creates from validation set. + target_type: Type of target to use, ``'bicubic_X2'``, ``'unknown_X4'``, + ... + patch_size: If ``train == True``, define sizes of patches to produce, + return full image otherwise. Tuple of height and width. + transform: A function / transform that takes in dictionary (with low + and high resolution images) and returns a transformed version. + low_resolution_image_key: Key to use to store images of low resolution. + high_resolution_image_key: Key to use to store high resolution images. + download: If true, downloads the dataset from the internet + and puts it in root directory. If dataset is already downloaded, + it is not downloaded again. 
+ + """ + + url = "https://cv.snu.ac.kr/research/EDSR/" + resources = { + "Flickr2K.tar": "5d3f39443d5e9489bff8963f8f26cb03", + } + + def __init__( + self, + root: str, + train: bool = True, + target_type: str = "bicubic_X4", + patch_size: Tuple[int, int] = (96, 96), + transform: Optional[Callable[[Any], Dict]] = None, + low_resolution_image_key: str = "lr_image", + high_resolution_image_key: str = "hr_image", + download: bool = False, + ) -> None: + filename = "Flickr2K.tar" + if download: + # download images + misc.download_and_extract_archive( + f"{self.url}{filename}", + download_root=root, + filename=filename, + md5=self.resources[filename], + ) + + self.train = train + + self.lr_key = low_resolution_image_key + self.hr_key = high_resolution_image_key + + degradation, downscaling = target_type.split("_") + + # 'index' files + subdir_hr = "Flickr2K_HR" + subdir_lr = Path(f"Flickr2K_LR_{degradation}", downscaling) + lr_images = self._images_in_dir(root, Path(filename).stem, subdir_hr) + hr_images = self._images_in_dir(root, Path(filename).stem, subdir_lr) + assert len(lr_images) == len(hr_images) + + self.data = [ + {"lr_image": lr_image, "hr_image": hr_image} + for lr_image, hr_image in zip(lr_images, hr_images) + ] + + self.open_fn = data.ReaderCompose([ + data.ImageReader(input_key="lr_image", output_key=self.lr_key), + data.ImageReader(input_key="hr_image", output_key=self.hr_key), + ]) + + self.scale = int(downscaling) if downscaling.isdigit() else 4 + height, width = patch_size + self.target_patch_size = patch_size + self.input_patch_size = (height // self.scale, width // self.scale) + + self.transform = utils.Augmentor(transform) + + +class ImageFolderDataset(data.ListDataset): + """A generic data loader where the samples are arranged in this way: :: + + /xxx.ext + /xxy.ext + /xxz.ext + ... + /123.ext + /nsdf3.ext + /asd932_.ext + + Args: + pathname: Root directory of dataset. + image_key: Key to use to store image. + image_name_key: Key to use to store name of the image. + transform: A function / transform that takes in dictionary + and returns its transformed version. 
+ + """ + + def __init__( + self, + pathname: str, + image_key: str = "image", + image_name_key: str = "filename", + transform: Optional[Callable[[Dict], Dict]] = None, + ) -> None: + files = glob.iglob(pathname, recursive=True) + images = sorted(filter(has_image_extension, files)) + + list_data = [{"image": filename} for filename in images] + open_fn = data.ReaderCompose([ + data.ImageReader(input_key="image", output_key=image_key), + data.LambdaReader(input_key="image", output_key=image_name_key), + ]) + transform = utils.Augmentor(transform) + + super().__init__( + list_data=list_data, open_fn=open_fn, dict_transform=transform + ) diff --git a/esrgan/dataset/__init__.py b/esrgan/dataset/__init__.py deleted file mode 100644 index 89e6706..0000000 --- a/esrgan/dataset/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# flake8: noqa -from esrgan.dataset.div2k import DIV2KDataset -from esrgan.dataset.image_folder import ImageFolderDataset diff --git a/esrgan/dataset/image_folder.py b/esrgan/dataset/image_folder.py deleted file mode 100644 index 886b8ca..0000000 --- a/esrgan/dataset/image_folder.py +++ /dev/null @@ -1,50 +0,0 @@ -import glob -from typing import Callable, Dict, Optional - -from catalyst import data, utils - -from esrgan.dataset import aug - -__all__ = ["ImageFolderDataset"] - - -class ImageFolderDataset(data.ListDataset): - """A generic data loader where the samples are arranged in this way: :: - - /xxx.ext - /xxy.ext - /xxz.ext - ... - /123.ext - /nsdf3.ext - /asd932_.ext - - Args: - pathname: Root directory of dataset. - image_key: Key to use to store image. - image_name_key: Key to use to store name of the image. - transform: A function / transform that takes in dictionary - and returns its transformed version. - - """ - - def __init__( - self, - pathname: str, - image_key: str = "image", - image_name_key: str = "filename", - transform: Optional[Callable[[Dict], Dict]] = None, - ) -> None: - files = glob.iglob(pathname, recursive=True) - images = sorted(filter(utils.has_image_extension, files)) - - list_data = [{"image": filename} for filename in images] - open_fn = data.ReaderCompose([ - data.ImageReader(input_key="image", output_key=image_key), - data.LambdaReader(input_key="image", output_key=image_name_key), - ]) - transform = aug.Augmentor(transform) - - super().__init__( - list_data=list_data, open_fn=open_fn, dict_transform=transform - ) diff --git a/esrgan/utils/__init__.py b/esrgan/utils/__init__.py index dea9b7c..dbec8a1 100644 --- a/esrgan/utils/__init__.py +++ b/esrgan/utils/__init__.py @@ -1,4 +1,5 @@ # flake8: noqa +from esrgan.utils.aug import Augmentor from esrgan.utils.init import module_init_, net_init_ from esrgan.utils.misc import is_power_of_two, pairwise from esrgan.utils.module_params import create_layer diff --git a/esrgan/dataset/aug.py b/esrgan/utils/aug.py similarity index 100% rename from esrgan/dataset/aug.py rename to esrgan/utils/aug.py