Streaming vision datasets #284

Merged
merged 63 commits into dev from james/wds on Mar 8, 2022
Changes from 25 commits
Commits
63 commits
b6dab36
build_(vision datasets).
knighton Jan 27, 2022
4b4e8d4
Streaming dataset: Web(MNIST, CIFAR-10/20/100, TinyImagenet-200, Imag…
knighton Jan 28, 2022
302f666
WebDataset index files that indirect num_shards, (load, create)_webda…
knighton Feb 1, 2022
4a61894
Make uniform.
knighton Feb 1, 2022
aa89dda
Link up webdataset creation -> load_dataset -> dataset hparam init.
knighton Feb 2, 2022
84e8254
Fix.
knighton Feb 2, 2022
fa22cf0
isort.
knighton Feb 2, 2022
75aa4a8
yapf.
knighton Feb 2, 2022
842029c
build/ -> create/ ("build" is overloaded term).
knighton Feb 2, 2022
ec582ff
pyright (typing annotations).
knighton Feb 2, 2022
2fd7f53
Rm imagenet1k-multiproc example.
knighton Feb 2, 2022
43023d6
Typing.
knighton Feb 3, 2022
679e0f2
isort, etc.
knighton Feb 8, 2022
adb3289
ADE20k.
knighton Feb 11, 2022
9f5f561
Refactor: dataset_s3_bucket.
knighton Feb 14, 2022
3172c4c
Fix (sharding).
knighton Feb 22, 2022
48c4c52
fix wds length
abhi-mosaic Feb 23, 2022
cd34c82
add wds cifar10 yaml
abhi-mosaic Feb 23, 2022
72ac7d4
add deps
abhi-mosaic Feb 23, 2022
42f2bb4
minor fixes
abhi-mosaic Feb 23, 2022
142245b
typo
abhi-mosaic Feb 23, 2022
c41a94c
size_webdataset().
knighton Feb 24, 2022
9b56693
shuffle (following webdataset defaults).
knighton Feb 24, 2022
47fafb9
Add YAMLs.
knighton Feb 25, 2022
e1a4d90
Fix.
knighton Feb 26, 2022
ccafa8c
Split JpgCls webdataset hparams.
knighton Feb 28, 2022
55822f6
Extract hyperparam for shuffle_buffer_per_worker.
knighton Mar 1, 2022
3c818ee
Merge branch 'dev' into james/wds
knighton Mar 3, 2022
6263224
Remove directory of webdataset versions of yamls save one, minor rena…
knighton Mar 3, 2022
eb41cc4
Make dataset_s3_bucket configurable, add typing annotations.
knighton Mar 3, 2022
78bae50
Drop use_synthetic from WebDataset datasets.
knighton Mar 3, 2022
9e4780b
Merge branch 'dev' into james/wds
knighton Mar 3, 2022
c997cbf
Make s3 bucket configurable for all webdatasets.
knighton Mar 3, 2022
a8947d5
Make naming more uniform, make all wds s3/local configurable, etc.
knighton Mar 3, 2022
b4c33b7
Merge remote-tracking branch 'refs/remotes/origin/james/wds' into jam…
knighton Mar 3, 2022
2c91dd6
Update ResNet50 yaml
Landanjs Mar 4, 2022
d0a96f7
Fix: use_synthetic.
knighton Mar 4, 2022
c480aad
Fix (remove mixin).
knighton Mar 4, 2022
25a3dd5
Raise on error instead of silently failing.
knighton Mar 4, 2022
d831033
load_webdataset: s3_bucket -> remote (either s3 or local fs).
knighton Mar 4, 2022
aa45054
Set up webdataset with one call.
knighton Mar 4, 2022
6e03882
Class variables <-> yahp.
knighton Mar 4, 2022
b6ead1a
Weaken docutils requirement to make .[all] happy.
knighton Mar 4, 2022
bfd2e4c
Optional webdataset.
knighton Mar 4, 2022
1a75f8e
Also Wurlitzer.
knighton Mar 5, 2022
e25f125
Lint.
knighton Mar 5, 2022
e0223d5
Docstrings.
knighton Mar 5, 2022
eecc937
Merge branch 'dev' into james/wds
knighton Mar 5, 2022
13035c0
Python usage too modern.
knighton Mar 5, 2022
cb84613
Skip over webdatasets (wds_).
knighton Mar 5, 2022
8df5392
Fix (split).
knighton Mar 7, 2022
a7b4334
Test hparams abstract base class solution.
knighton Mar 7, 2022
1d8fb9d
Required hparams before optional hparams.
knighton Mar 7, 2022
40eae59
All optional.
knighton Mar 7, 2022
c525d3b
Fix (default).
knighton Mar 7, 2022
523dc41
Merge branch 'dev' into james/wds
hanlint Mar 7, 2022
74e8dda
fix typing
hanlint Mar 7, 2022
aee6084
fix lint
hanlint Mar 7, 2022
393f2d5
lazy import load_webdataset
hanlint Mar 7, 2022
a14bf6c
Skip over WebDatasets properly.
knighton Mar 7, 2022
c3cea69
Ported Dockerfile README to rst, style fixes
Mar 7, 2022
e8681e2
Merge branch 'james/wds' of github.com:mosaicml/composer into james/wds
Mar 7, 2022
a6ef182
Minor doc tweak
Mar 7, 2022
10 changes: 6 additions & 4 deletions composer/datasets/__init__.py
@@ -1,16 +1,18 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from composer.datasets.ade20k import ADE20kDatasetHparams as ADE20kDatasetHparams
from composer.datasets.ade20k import ADE20kDatasetHparams, ADE20kWebDatasetHparams
from composer.datasets.brats import BratsDatasetHparams as BratsDatasetHparams
from composer.datasets.cifar10 import CIFAR10DatasetHparams as CIFAR10DatasetHparams
from composer.datasets.cifar import (CIFAR10DatasetHparams, CIFAR10WebDatasetHparams, CIFAR20WebDatasetHparams,
CIFAR100WebDatasetHparams)
from composer.datasets.dataloader import DataloaderHparams as DataloaderHparams
from composer.datasets.dataloader import WrappedDataLoader as WrappedDataLoader
from composer.datasets.glue import GLUEHparams as GLUEHparams
from composer.datasets.hparams import DatasetHparams as DatasetHparams
from composer.datasets.hparams import SyntheticHparamsMixin as SyntheticHparamsMixin
from composer.datasets.imagenet import ImagenetDatasetHparams as ImagenetDatasetHparams
from composer.datasets.imagenet import (Imagenet1KWebDatasetHparams, ImagenetDatasetHparams,
TinyImagenet200WebDatasetHparams)
from composer.datasets.lm_datasets import LMDatasetHparams as LMDatasetHparams
from composer.datasets.mnist import MNISTDatasetHparams as MNISTDatasetHparams
from composer.datasets.mnist import MNISTDatasetHparams, MNISTWebDatasetHparams
from composer.datasets.synthetic import MemoryFormat as MemoryFormat
from composer.datasets.synthetic import SyntheticBatchPairDataset as SyntheticBatchPairDataset
from composer.datasets.synthetic import SyntheticDataLabelType as SyntheticDataLabelType
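
The import changes above re-export the new WebDataset hparams from the composer.datasets namespace. A minimal sketch of the resulting import surface (illustrative only, not part of the diff; in practice these fields are populated from the YAMLs added in this PR):

# Illustrative sketch: these names are re-exported by composer/datasets/__init__.py above.
from composer.datasets import (ADE20kWebDatasetHparams, CIFAR10WebDatasetHparams,
                               Imagenet1KWebDatasetHparams, MNISTWebDatasetHparams)

cifar_hparams = CIFAR10WebDatasetHparams()  # fields assumed optional; real runs configure them via YAML
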
131 changes: 128 additions & 3 deletions composer/datasets/ade20k.py
@@ -14,9 +14,10 @@
from torchvision import transforms

from composer.core.types import DataSpec
from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin
from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin, WebDatasetHparams
from composer.datasets.imagenet import IMAGENET_CHANNEL_MEAN, IMAGENET_CHANNEL_STD
from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.datasets.webdataset import load_webdataset, size_webdataset
from composer.utils import dist
from composer.utils.data import NormalizationFn, pil_image_collate

@@ -249,7 +250,7 @@ class ADE20kDatasetHparams(DatasetHparams, SyntheticHparamsMixin):
min_resize_scale (float): the minimum value the samples can be rescaled. Default is 0.5.
max_resize_scale (float): the maximum value the samples can be rescaled. Default is 2.0.
final_size (int): the final size of the image and target. Default is 512.
ignore_background (bool): if true, ignore the background class when calculating the training loss.
ignore_background (bool): if true, ignore the background class when calculating the training loss.
Default is true.

"""
@@ -311,7 +312,8 @@ def initialize_object(self, batch_size, dataloader_hparams) -> DataSpec:
RandomHFlipPair(),
)

# Photometric distoration values come from mmsegmentation: https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/pipelines/transforms.py#L837
# Photometric distoration values come from mmsegmentation:
# https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/pipelines/transforms.py#L837
r_mean, g_mean, b_mean = IMAGENET_CHANNEL_MEAN
image_transforms = torch.nn.Sequential(
PhotometricDistoration(brightness=32. / 255, contrast=0.5, saturation=0.5, hue=18. / 255),
@@ -345,3 +347,126 @@ def initialize_object(self, batch_size, dataloader_hparams) -> DataSpec:
collate_fn=collate_fn,
drop_last=self.drop_last),
device_transforms=device_transform_fn)


@dataclass
class ADE20kWebDatasetHparams(WebDatasetHparams, SyntheticHparamsMixin):
"""Defines an instance of the ADE20k dataset for semantic segmentation.

Args:
split (str): the dataset split to use either 'train', 'val', or 'test'. Default is `train`.
base_size (int): initial size of the image and target before other augmentations. Default is 512.
min_resize_scale (float): the minimum value the samples can be rescaled. Default is 0.5.
max_resize_scale (float): the maximum value the samples can be rescaled. Default is 2.0.
final_size (int): the final size of the image and target. Default is 512.
ignore_background (bool): if true, ignore the background class when calculating the training loss.
Default is true.

"""

split: str = hp.optional("Which split of the dataset to use. Either ['train', 'val', 'test']", default='train')
base_size: int = hp.optional("Initial size of the image and target before other augmentations", default=512)
min_resize_scale: float = hp.optional("Minimum value that the image and target can be scaled", default=0.5)
max_resize_scale: float = hp.optional("Maximum value that the image and target can be scaled", default=2.0)
final_size: int = hp.optional("Final size of the image and target", default=512)
ignore_background: bool = hp.optional("If true, ignore the background class in training loss", default=True)

def validate(self):
if self.split not in ['train', 'val', 'test']:
raise ValueError(f"split value {self.split} must be one of ['train', 'val', 'test'].")

if self.base_size <= 0:
raise ValueError("base_size cannot be zero or negative.")

if self.min_resize_scale <= 0:
raise ValueError("min_resize_scale cannot be zero or negative")

if self.max_resize_scale < self.min_resize_scale:
raise ValueError("max_resize_scale cannot be less than min_resize_scale")

def initialize_object(self, batch_size, dataloader_hparams) -> DataSpec:
self.validate()

if self.use_synthetic:
if self.split == 'train':
total_dataset_size = 20_206
elif self.split == 'val':
total_dataset_size = 2_000
else:
total_dataset_size = 3_352

dataset = SyntheticBatchPairDataset(
total_dataset_size=total_dataset_size,
data_shape=[3, self.final_size, self.final_size],
label_shape=[self.final_size, self.final_size],
num_classes=150,
num_unique_samples_to_create=self.synthetic_num_unique_samples,
device=self.synthetic_device,
memory_format=self.synthetic_memory_format,
)
collate_fn = None
device_transform_fn = None

else:
# Define data transformations based on data split
if self.split == 'train':
both_transforms = torch.nn.Sequential(
RandomResizePair(min_scale=self.min_resize_scale,
max_scale=self.max_resize_scale,
base_size=(self.base_size, self.base_size)),
RandomCropPair(crop_size=(self.final_size, self.final_size)),
RandomHFlipPair(),
)

# Photometric distoration values come from mmsegmentation:
# https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/pipelines/transforms.py#L837
r_mean, g_mean, b_mean = IMAGENET_CHANNEL_MEAN
image_transforms = torch.nn.Sequential(
PhotometricDistoration(brightness=32. / 255, contrast=0.5, saturation=0.5, hue=18. / 255),
PadToSize(size=(self.final_size, self.final_size), fill=(int(r_mean), int(g_mean), int(b_mean))))

target_transforms = transforms.Compose([
PadToSize(size=(self.final_size, self.final_size), fill=0),
transforms.Grayscale(),
])
else:
both_transforms = None
image_transforms = transforms.Resize(size=(self.final_size, self.final_size),
interpolation=TF.InterpolationMode.BILINEAR)
target_transforms = transforms.Compose([
transforms.Resize(size=(self.final_size, self.final_size),
interpolation=TF.InterpolationMode.NEAREST),
transforms.Grayscale(),
])

def map_fn(args):
x, y = args
if both_transforms:
x, y = both_transforms((x, y))
if image_transforms:
x = image_transforms(x)
if target_transforms:
y = target_transforms(y)
return x, y

dataset, meta = load_webdataset('mosaicml-internal-dataset-ade20k', 'ade20k', self.split,
self.webdataset_cache_dir, self.webdataset_cache_verbose)
if self.shuffle:
dataset = dataset.shuffle(512)
dataset = dataset.decode('pil').to_tuple('scene.jpg', 'annotation.png').map(map_fn)
dataset = size_webdataset(dataset, meta['n_shards'], meta['samples_per_shard'], dist.get_world_size(),
dataloader_hparams.num_workers, batch_size, self.drop_last)

collate_fn = pil_image_collate
device_transform_fn = NormalizationFn(mean=IMAGENET_CHANNEL_MEAN,
std=IMAGENET_CHANNEL_STD,
ignore_background=self.ignore_background)

return DataSpec(dataloader=dataloader_hparams.initialize_object(
dataset=dataset,
batch_size=batch_size,
sampler=None,
drop_last=self.drop_last,
collate_fn=collate_fn,
),
device_transforms=device_transform_fn)
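
A hedged usage sketch for the class added above (not part of the diff): constructing the WebDataset hparams and turning them into a DataSpec. The DataloaderHparams field names and all values below are assumptions based on the existing API, not values taken from this PR.

# Usage sketch; field values are placeholders.
from composer.datasets import ADE20kWebDatasetHparams
from composer.datasets.dataloader import DataloaderHparams

dataset_hparams = ADE20kWebDatasetHparams(split='train', shuffle=True, drop_last=True)
dataloader_hparams = DataloaderHparams(num_workers=8, prefetch_factor=2, persistent_workers=True,
                                       pin_memory=True, timeout=0)
data_spec = dataset_hparams.initialize_object(batch_size=128, dataloader_hparams=dataloader_hparams)
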
56 changes: 51 additions & 5 deletions composer/datasets/cifar10.py → composer/datasets/cifar.py
@@ -8,22 +8,23 @@

from composer.core.types import DataLoader
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin
from composer.datasets.hparams import DatasetHparams, JpgClsWebDatasetHparams, SyntheticHparamsMixin
from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.utils import dist


@dataclass
class CIFAR10DatasetHparams(DatasetHparams, SyntheticHparamsMixin):
"""Defines an instance of the CIFAR-10 dataset for image classification.

Parameters:
download (bool): Whether to download the dataset, if needed.
"""
download: bool = hp.optional("whether to download the dataset, if needed", default=True)

def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataLoader:
cifar10_mean, cifar10_std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
cifar10_mean = 0.4914, 0.4822, 0.4465
cifar10_std = 0.247, 0.243, 0.261

if self.use_synthetic:
total_dataset_size = 50_000 if self.is_train else 10_000
@@ -45,12 +46,12 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
transforms.Normalize(cifar10_mean, cifar10_std),
])
else:
transformation = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
transforms.Normalize(cifar10_mean, cifar10_std),
])

dataset = CIFAR10(
@@ -65,3 +66,48 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
batch_size=batch_size,
sampler=sampler,
drop_last=self.drop_last)


@dataclass
class CIFAR10WebDatasetHparams(JpgClsWebDatasetHparams):
"""Defines an instance of the CIFAR-10 WebDataset for image classification."""

dataset_s3_bucket = 'mosaicml-internal-dataset-cifar10'
dataset_name = 'cifar10'
n_train_samples = 50_000
n_val_samples = 10_000
height = 32
width = 32
n_classes = 10
channel_means = 0.4914, 0.4822, 0.4465
channel_stds = 0.247, 0.243, 0.261


@dataclass
class CIFAR20WebDatasetHparams(JpgClsWebDatasetHparams):
"""Defines an instance of the CIFAR-20 WebDataset for image classification."""

dataset_s3_bucket = 'mosaicml-internal-dataset-cifar20'
dataset_name = 'cifar20'
n_train_samples = 50_000
n_val_samples = 10_000
height = 32
width = 32
n_classes = 20
channel_means = 0.5071, 0.4867, 0.4408
channel_stds = 0.2675, 0.2565, 0.2761


@dataclass
class CIFAR100WebDatasetHparams(JpgClsWebDatasetHparams):
"""Defines an instance of the CIFAR-100 WebDataset for image classification."""

dataset_s3_bucket = 'mosaicml-internal-dataset-cifar100'
dataset_name = 'cifar100'
n_train_samples = 50_000
n_val_samples = 10_000
height = 32
width = 32
n_classes = 100
channel_means = 0.5071, 0.4867, 0.4408
channel_stds = 0.2675, 0.2565, 0.2761
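
The three classes above share all behavior through JpgClsWebDatasetHparams and differ only in class-level constants. A hedged sketch of how another JPEG-classification corpus could be added under the same pattern (the bucket name, sample counts, and channel statistics are placeholders, not a dataset in this PR):

from dataclasses import dataclass

from composer.datasets.hparams import JpgClsWebDatasetHparams


@dataclass
class ExampleJpgWebDatasetHparams(JpgClsWebDatasetHparams):
    """Hypothetical dataset defined by analogy with the classes above; all values are placeholders."""

    dataset_s3_bucket = 'mosaicml-internal-dataset-example'
    dataset_name = 'example'
    n_train_samples = 100_000
    n_val_samples = 10_000
    height = 64
    width = 64
    n_classes = 200
    channel_means = 0.5, 0.5, 0.5
    channel_stds = 0.25, 0.25, 0.25
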
Empty file.
68 changes: 68 additions & 0 deletions composer/datasets/create/ade20k.py
@@ -0,0 +1,68 @@
from argparse import ArgumentParser, Namespace
from glob import glob
from random import shuffle
from typing import Any, Dict, Generator, List, Tuple

from PIL import Image

from composer.datasets.webdataset import create_webdataset


def parse_args() -> Namespace:
args = ArgumentParser()
args.add_argument('--in_root', type=str, required=True)
args.add_argument('--out_root', type=str, required=True)
args.add_argument('--train_shards', type=int, default=512)
args.add_argument('--val_shards', type=int, default=64)
args.add_argument('--tqdm', type=int, default=1)
return args.parse_args()


def each_sample(pairs: List[Tuple[str, str]]) -> Generator[Dict[str, Any], None, None]:
for idx, (scene_file, annotation_file) in enumerate(pairs):
scene = Image.open(scene_file)
annotation = Image.open(annotation_file)
yield {
'__key__': f'{idx:05d}',
'scene.jpg': scene,
'annotation.png': annotation,
}


def process_split(in_root: str, out_root: str, split: str, n_shards: int, use_tqdm: int):
pattern = f'{in_root}/images/{split}/ADE_{split}_*.jpg'
scenes = sorted(glob(pattern))

pattern = f'{in_root}/annotations/{split}/ADE_{split}_*.png'
annotations = sorted(glob(pattern))

pairs = list(zip(scenes, annotations))
shuffle(pairs)

create_webdataset(each_sample(pairs), out_root, split, len(pairs), n_shards, use_tqdm)


def main(args: Namespace) -> None:
'''
Directory layout:

ADE20k/
annotations/
train/
ADE_train_%08d.png
val/
ADE_val_%08d.png
images/
test/
ADE_test_%08d.jpg
train/
ADE_train_%08d.jpg
val/
ADE_val_%08d.jpg
'''
process_split(args.in_root, args.out_root, 'train', args.train_shards, args.tqdm)
process_split(args.in_root, args.out_root, 'val', args.val_shards, args.tqdm)


if __name__ == '__main__':
main(parse_args())
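
A hedged invocation sketch for the script above (not part of the diff). The input and output paths are placeholders and assume ADE20k has been extracted into the layout described in main()'s docstring:

# Equivalent to: python composer/datasets/create/ade20k.py --in_root ... --out_root ...
from argparse import Namespace

from composer.datasets.create.ade20k import main

main(Namespace(in_root='/datasets/ade20k', out_root='/tmp/ade20k_webdataset',
               train_shards=512, val_shards=64, tqdm=1))
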
49 changes: 49 additions & 0 deletions composer/datasets/create/cifar10.py
@@ -0,0 +1,49 @@
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, Generator, Tuple

import numpy as np
from torchvision.datasets import CIFAR10
from wurlitzer import pipes

from composer.datasets.webdataset import create_webdataset


def parse_args() -> Namespace:
args = ArgumentParser()
args.add_argument('--out_root', type=str, required=True)
args.add_argument('--train_shards', type=int, default=128)
args.add_argument('--val_shards', type=int, default=128)
args.add_argument('--tqdm', type=int, default=1)
return args.parse_args()


def shuffle(dataset: CIFAR10) -> Tuple[np.ndarray, np.ndarray]:
indices = np.random.permutation(len(dataset))
images = dataset.data[indices]
classes = np.array(dataset.targets)[indices]
return images, classes


def each_sample(images: np.ndarray, classes: np.ndarray) -> Generator[Dict[str, Any], None, None]:
for idx, (img, cls) in enumerate(zip(images, classes)):
yield {
'__key__': f'{idx:05d}',
'jpg': img,
'cls': cls,
}


def main(args: Namespace) -> None:
with pipes():
dataset = CIFAR10(root='/datasets/cifar10', train=True, download=True)
images, classes = shuffle(dataset)
create_webdataset(each_sample(images, classes), args.out_root, 'train', len(images), args.train_shards, args.tqdm)

with pipes():
dataset = CIFAR10(root='/datasets/cifar10', train=False, download=True)
images, classes = shuffle(dataset)
create_webdataset(each_sample(images, classes), args.out_root, 'val', len(images), args.val_shards, args.tqdm)


if __name__ == '__main__':
main(parse_args())
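
Once the shards are written, they can be consumed through the loader used by the hparams classes earlier in this diff. A hedged sketch that mirrors the load_webdataset call in ADE20kWebDatasetHparams; the cache path is a placeholder, and the bucket name comes from CIFAR10WebDatasetHparams above:

# Consumption sketch; mirrors the ADE20k pipeline above with this dataset's keys.
from composer.datasets.webdataset import load_webdataset

dataset, meta = load_webdataset('mosaicml-internal-dataset-cifar10', 'cifar10', 'train',
                                '/tmp/webdataset_cache', False)
dataset = dataset.decode('pil').to_tuple('jpg', 'cls')  # keys written by each_sample() above
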