Skip to content

Commit

Permalink
feat: Flickr2K dataset added
Browse files Browse the repository at this point in the history
  • Loading branch information
bagxi committed Jan 22, 2022
1 parent 711cfb5 commit 8f0a064
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 73 deletions.
16 changes: 8 additions & 8 deletions docs/pages/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@ For example: ::
div2k_data = esrgan.dataset.DIV2KDataset('path/to/div2k_root/')
data_loader = torch.utils.data.DataLoader(div2k_data, batch_size=4, shuffle=True)


DIV2K
^^^^^

.. autoclass:: esrgan.dataset.div2k.DIV2KDataset
.. autoclass:: esrgan.dataset.DIV2KDataset
:members:


Folder of Images
^^^^^^^^^^^^^^^^
Flickr2K
^^^^^^^^

.. autoclass:: esrgan.dataset.image_folder.ImageFolderDataset
.. autoclass:: esrgan.dataset.Flickr2K
:members:
:undoc-members:


Augmentation
^^^^^^^^^^^^
Folder of Images
^^^^^^^^^^^^^^^^

.. automodule:: esrgan.dataset.aug
.. autoclass:: esrgan.dataset.ImageFolderDataset
:members:
:undoc-members:
8 changes: 8 additions & 0 deletions docs/pages/api/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ Utilities
Set of utilities that can make life a little bit easier.


Augmentation
^^^^^^^^^^^^

.. automodule:: esrgan.utils.aug
:members:
:undoc-members:


Model init
^^^^^^^^^^

Expand Down
162 changes: 150 additions & 12 deletions esrgan/dataset/div2k.py → esrgan/dataset.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
import glob
from pathlib import Path
import random
from typing import Any, Callable, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from albumentations.augmentations.crops import functional as F
from catalyst import data, utils
from catalyst import data
from catalyst.contrib.datasets import misc
import numpy as np
from torch.utils.data import Dataset

from esrgan.dataset import aug
from esrgan import utils

__all__ = ["DIV2KDataset"]
__all__ = ["DIV2KDataset", "Flickr2K", "ImageFolderDataset"]


def has_image_extension(uri: Union[str, Path]) -> bool:
    """Tell whether a resource name ends with a known image extension.

    The comparison is case-insensitive and only looks at the final suffix
    of the name (as reported by ``pathlib``).

    Args:
        uri: The resource to load the file from.

    Returns:
        ``True`` if file has image extension, ``False`` otherwise.

    """
    image_suffixes = (".bmp", ".png", ".jpeg", ".jpg", ".tif", ".tiff")
    suffix = Path(uri).suffix.lower()
    return suffix in image_suffixes


def paired_random_crop(
Expand Down Expand Up @@ -90,7 +104,7 @@ def __init__(
train: bool = True,
target_type: str = "bicubic_X4",
patch_size: Tuple[int, int] = (96, 96),
transform: Optional[Callable[[Any], dict]] = None,
transform: Optional[Callable[[Any], Dict]] = None,
low_resolution_image_key: str = "lr_image",
high_resolution_image_key: str = "hr_image",
download: bool = False,
Expand Down Expand Up @@ -120,9 +134,11 @@ def __init__(
self.lr_key = low_resolution_image_key
self.hr_key = high_resolution_image_key

_, downscaling = target_type.split("_")

# 'index' files
lr_images = self._images_in_dir(Path(root) / Path(filename_lr).stem)
hr_images = self._images_in_dir(Path(root) / Path(filename_hr).stem)
lr_images = self._images_in_dir(root, Path(filename_lr).stem)
hr_images = self._images_in_dir(root, Path(filename_hr).stem)
assert len(lr_images) == len(hr_images)

self.data = [
Expand All @@ -135,14 +151,14 @@ def __init__(
data.ImageReader(input_key="hr_image", output_key=self.hr_key),
])

self.scale = int(target_type[-1]) if target_type[-1].isdigit() else 4
self.scale = int(downscaling) if downscaling.isdigit() else 4
height, width = patch_size
self.target_patch_size = patch_size
self.input_patch_size = (height // self.scale, width // self.scale)

self.transform = aug.Augmentor(transform)
self.transform = utils.Augmentor(transform)

def __getitem__(self, index: int) -> dict:
def __getitem__(self, index: int) -> Dict:
"""Gets element of the dataset.
Args:
Expand Down Expand Up @@ -177,13 +193,135 @@ def __len__(self) -> int:
"""
return len(self.data)

# Collect, in sorted order, every file with an image extension found
# recursively under the directory built from the given path component(s).
# NOTE(review): the two consecutive signatures below (and the two
# `images = ...` lines further down) look like the before/after sides of
# a diff -- presumably only the second of each pair is meant to remain.
def _images_in_dir(self, path: Path) -> List[str]:
def _images_in_dir(self, *path: Union[str, Path]) -> List[str]:
# fix path to dir for `NTIRE 2017` datasets
# Join all positional components into a single path.
path = Path(*path)
if not path.exists():
# Split the last component at its final underscore and turn the tail
# into a sub-directory, i.e. `<parent>/A_B` becomes `<parent>/A/B`.
# NOTE(review): `rfind` returns -1 when the name has no underscore, in
# which case this yields `<parent>/<name minus last char>/<name>` --
# confirm callers never hit that case.
idx = path.name.rfind("_")
path = path.parent / path.name[:idx] / path.name[idx + 1:]

# Recursively list everything below `path`, keep image files only, and
# sort so that the returned order is deterministic.
files = glob.iglob(f"{path}/**/*", recursive=True)
images = sorted(filter(utils.has_image_extension, files))
images = sorted(filter(has_image_extension, files))

return images


class Flickr2K(DIV2KDataset):
    """`Flickr2K <https://github.com/LimBee/NTIRE2017>`_ Dataset.

    Images are looked up under ``<root>/Flickr2K/Flickr2K_HR`` (high
    resolution) and ``<root>/Flickr2K/Flickr2K_LR_<degradation>/<scale>``
    (low resolution), where ``<degradation>`` and ``<scale>`` are the two
    underscore-separated parts of ``target_type``. Low and high resolution
    files are paired by their sorted order.

    Args:
        root: Root directory where images are downloaded to.
        train: If True, creates dataset from training set,
            otherwise creates from validation set.
        target_type: Type of target to use, ``'bicubic_X2'``, ``'unknown_X4'``,
            ...
        patch_size: If ``train == True``, define sizes of patches to produce,
            return full image otherwise. Tuple of height and width.
        transform: A function / transform that takes in dictionary (with low
            and high resolution images) and returns a transformed version.
        low_resolution_image_key: Key to use to store images of low resolution.
        high_resolution_image_key: Key to use to store high resolution images.
        download: If true, downloads the dataset from the internet
            and puts it in root directory. If dataset is already downloaded,
            it is not downloaded again.

    Raises:
        ValueError: If ``target_type`` is not of the form
            ``'<degradation>_<scale>'``.
        AssertionError: If the number of low resolution images differs
            from the number of high resolution images.

    """

    url = "https://cv.snu.ac.kr/research/EDSR/"
    resources = {
        "Flickr2K.tar": "5d3f39443d5e9489bff8963f8f26cb03",
    }

    def __init__(
        self,
        root: str,
        train: bool = True,
        target_type: str = "bicubic_X4",
        patch_size: Tuple[int, int] = (96, 96),
        transform: Optional[Callable[[Any], Dict]] = None,
        low_resolution_image_key: str = "lr_image",
        high_resolution_image_key: str = "hr_image",
        download: bool = False,
    ) -> None:
        filename = "Flickr2K.tar"
        if download:
            # download images
            misc.download_and_extract_archive(
                f"{self.url}{filename}",
                download_root=root,
                filename=filename,
                md5=self.resources[filename],
            )

        self.train = train

        self.lr_key = low_resolution_image_key
        self.hr_key = high_resolution_image_key

        try:
            degradation, downscaling = target_type.split("_")
        except ValueError:
            raise ValueError(
                "`target_type` must look like '<degradation>_<scale>' "
                f"(e.g. 'bicubic_X4'), got {target_type!r}"
            ) from None

        # 'index' files
        archive_dir = Path(filename).stem
        subdir_hr = "Flickr2K_HR"
        subdir_lr = Path(f"Flickr2K_LR_{degradation}", downscaling)
        # Low resolution images live in the LR sub-directory and high
        # resolution ones in the HR sub-directory (these were swapped).
        lr_images = self._images_in_dir(root, archive_dir, subdir_lr)
        hr_images = self._images_in_dir(root, archive_dir, subdir_hr)
        # Explicit raise instead of `assert`: a bare assert is stripped under
        # `python -O`, after which `zip` below would silently drop images.
        if len(lr_images) != len(hr_images):
            raise AssertionError(
                f"Found {len(lr_images)} low resolution images but "
                f"{len(hr_images)} high resolution images"
            )

        self.data = [
            {"lr_image": lr_image, "hr_image": hr_image}
            for lr_image, hr_image in zip(lr_images, hr_images)
        ]

        self.open_fn = data.ReaderCompose([
            data.ImageReader(input_key="lr_image", output_key=self.lr_key),
            data.ImageReader(input_key="hr_image", output_key=self.hr_key),
        ])

        # `downscaling` looks like 'X4'; drop the leading 'X' before parsing,
        # otherwise `isdigit()` is always False and the scale is stuck at 4.
        factor = downscaling.lstrip("xX")
        self.scale = int(factor) if factor.isdigit() else 4
        height, width = patch_size
        self.target_patch_size = patch_size
        self.input_patch_size = (height // self.scale, width // self.scale)

        self.transform = utils.Augmentor(transform)


class ImageFolderDataset(data.ListDataset):
    """A generic data loader where the samples are arranged in this way: ::

        <pathname>/xxx.ext
        <pathname>/xxy.ext
        <pathname>/xxz.ext
        ...
        <pathname>/123.ext
        <pathname>/nsdf3.ext
        <pathname>/asd932_.ext

    Args:
        pathname: Root directory of dataset, or a glob pattern (``**`` is
            expanded recursively) that selects the files to use.
        image_key: Key to use to store image.
        image_name_key: Key to use to store name of the image.
        transform: A function / transform that takes in dictionary
            and returns its transformed version.

    """

    def __init__(
        self,
        pathname: str,
        image_key: str = "image",
        image_name_key: str = "filename",
        transform: Optional[Callable[[Dict], Dict]] = None,
    ) -> None:
        pattern = str(pathname)
        # A plain directory handed to `glob` matches only the directory
        # itself, which would produce an empty dataset; search inside it
        # instead. Explicit glob patterns are passed through untouched.
        if Path(pattern).is_dir():
            pattern = str(Path(pattern) / "**" / "*")

        files = glob.iglob(pattern, recursive=True)
        # Sorted so that sample order is deterministic.
        images = sorted(filter(has_image_extension, files))

        list_data = [{"image": filename} for filename in images]
        # The same path is read twice: once decoded as an image, and once
        # stored under `image_name_key`.
        open_fn = data.ReaderCompose([
            data.ImageReader(input_key="image", output_key=image_key),
            data.LambdaReader(input_key="image", output_key=image_name_key),
        ])
        transform = utils.Augmentor(transform)

        super().__init__(
            list_data=list_data, open_fn=open_fn, dict_transform=transform
        )
3 changes: 0 additions & 3 deletions esrgan/dataset/__init__.py

This file was deleted.

50 changes: 0 additions & 50 deletions esrgan/dataset/image_folder.py

This file was deleted.

1 change: 1 addition & 0 deletions esrgan/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# flake8: noqa
from esrgan.utils.aug import Augmentor
from esrgan.utils.init import module_init_, net_init_
from esrgan.utils.misc import is_power_of_two, pairwise
from esrgan.utils.module_params import create_layer
File renamed without changes.

0 comments on commit 8f0a064

Please sign in to comment.