Skip to content

Commit

Permalink
yews.datasets.dirs under cover
Browse files Browse the repository at this point in the history
  • Loading branch information
lijunzh committed Apr 17, 2019
1 parent 1f8be24 commit 744daae
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 36 deletions.
Binary file added tests/assets/array_folder/samples.npy
Binary file not shown.
Binary file added tests/assets/array_folder/targets.npy
Binary file not shown.
Binary file added tests/assets/folder/a/0/class_a.xxx.npy
Binary file not shown.
Binary file added tests/assets/folder/a/0/class_b.yyy.npy
Binary file not shown.
Binary file added tests/assets/folder/a/1/class_a.xxx.npy
Binary file not shown.
Binary file added tests/assets/folder/a/1/class_b.efg.npy
Binary file not shown.
Binary file added tests/assets/folder/b/0/class_a.xxx.npy
Binary file not shown.
Binary file added tests/assets/folder/b/0/class_b.xyz.npy
Binary file not shown.
Binary file added tests/assets/folder/b/1/class_a.xxx.npy
Binary file not shown.
Binary file added tests/assets/folder/b/1/class_b.abc.npy
Binary file not shown.
107 changes: 71 additions & 36 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import yews.transforms as transforms

from pathlib import Path
import numpy as np

root_dir = Path('tests/assets').resolve()

def test_is_dataset():
assert not datasets.is_dataset(0)
Expand All @@ -22,51 +24,52 @@ def __getitem__(self, index):
def __len__(self):
return self.size

class TestMandatoryMethods:

class DummyBaseDataset(datasets.BaseDataset):
def test_call_method(self):
assert all([hasattr(getattr(datasets, t), '__getitem__') for t in
datasets.__all__])

def build_dataset(self):
return DummpyDatasetlike(), DummpyDatasetlike()


class DummyBaseDatasetNoSamples(datasets.BaseDataset):

def build_dataset(self):
return 0, DummpyDatasetlike()
def test_repr_method(self):
assert all([hasattr(getattr(datasets, t), '__len__') for t in
datasets.__all__])

class TestBaseDataset:

class DummyBaseDatasetNoTargets(datasets.BaseDataset):
class DummyBaseDataset(datasets.BaseDataset):

def build_dataset(self):
return DummpyDatasetlike(), 0
def build_dataset(self):
return DummpyDatasetlike(), DummpyDatasetlike()


class DummyBaseDatasetWrongLength(datasets.BaseDataset):
class DummyBaseDatasetNoSamples(datasets.BaseDataset):

def build_dataset(self):
return DummpyDatasetlike(1), DummpyDatasetlike(2)
def build_dataset(self):
return 0, DummpyDatasetlike()


class DummyTransform(transforms.BaseTransform):
class DummyBaseDatasetNoTargets(datasets.BaseDataset):

def __call__(self, data):
return "transformed"
def build_dataset(self):
return DummpyDatasetlike(), 0


class DummyPathDataset(datasets.PathDataset):
class DummyBaseDatasetWrongLength(datasets.BaseDataset):

def build_dataset(self):
return DummpyDatasetlike(), DummpyDatasetlike()
def build_dataset(self):
return DummpyDatasetlike(1), DummpyDatasetlike(2)

class DummyTransform(transforms.BaseTransform):

class TestBaseDataset:
def __call__(self, data):
return "transformed"

def test_empty_construct(self):
dset = datasets.BaseDataset()
assert len(dset) == 0

def test_noempty_constrct(self):
dset = DummyBaseDataset(root='.')
dset = self.DummyBaseDataset(root='.')
assert len(dset) == 1

def test_raise_notimplmenetederror(self):
Expand All @@ -75,36 +78,68 @@ def test_raise_notimplmenetederror(self):

def test_no_samples(self):
with pytest.raises(ValueError):
dset = DummyBaseDatasetNoSamples(root='.')
dset = self.DummyBaseDatasetNoSamples(root='.')

def test_no_targets(self):
with pytest.raises(ValueError):
dset = DummyBaseDatasetNoTargets(root='.')
dset = self.DummyBaseDatasetNoTargets(root='.')

def test_samples_targets_not_match(self):
with pytest.raises(ValueError):
dset = DummyBaseDatasetWrongLength(root='.')
dset = self.DummyBaseDatasetWrongLength(root='.')

def test_getitem_with_transform(self):
dset = DummyBaseDataset(root='.',
sample_transform=DummyTransform(),
target_transform=DummyTransform())
dset = self.DummyBaseDataset(root='.',
sample_transform=self.DummyTransform(),
target_transform=self.DummyTransform())
assert dset[0] == ('transformed', 'transformed')
dset = DummyBaseDataset(root='.')
dset = self.DummyBaseDataset(root='.')
assert dset[0] == ('a item', 'a item')

def test_repr(self):
dset = DummyBaseDataset(root='.',
sample_transform='t',
target_transform='tt')
dset = self.DummyBaseDataset(root='.',
sample_transform='t',
target_transform='tt')
assert type(dset.__repr__()) is str
dset = DummyBaseDataset()
dset = self.DummyBaseDataset()
assert type(dset.__repr__()) is str


class TestPathDataset:

dset = DummyPathDataset(root='.')
class DummyPathDataset(datasets.PathDataset):

def build_dataset(self):
return DummpyDatasetlike(), DummpyDatasetlike()

def test_root_is_path(self):
assert self.dset.root == Path(self.dset.root).resolve()
dset = self.DummyPathDataset(root='.')
assert dset.root == Path(dset.root).resolve()


class TestDirDataset:

class DummyDirDataset(datasets.DirDataset):

def build_dataset(self):
return DummpyDatasetlike(), DummpyDatasetlike()

def test_dir_check(self):
dset = self.DummyDirDataset(root='.')
with pytest.raises(ValueError):
dset = self.DummyDirDataset(root='abc')


class TestDatasetArrayFolder:

def test_loading_npy(self):
dset = datasets.DatasetArrayFolder(root=root_dir / 'array_folder')
assert all([dset[0][0].shape == (3, 100), dset[0][1].shape == ()])


class TestDatasetFolder:

def test_loading_folder(self):
dset = datasets.DatasetFolder(root=root_dir/ 'folder', loader=np.load)
assert all([dset[0][0].shape == (3, 100), type(dset[0][1]) is str])

0 comments on commit 744daae

Please sign in to comment.