Skip to content

Commit

Permalink
Fix remainder logic for subset splitting (#1222)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max authored Mar 2, 2020
1 parent 9719523 commit 0d873a3
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
11 changes: 7 additions & 4 deletions datumaro/datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,12 @@ def build_cmdline_parser(cls, **kwargs):
def __init__(self, extractor, splits, seed=None):
super().__init__(extractor)

total_ratio = sum((s[1] for s in splits), 0)
if not total_ratio == 1:
assert 0 < len(splits), "Expected at least one split"
assert all(0.0 <= r and r <= 1.0 for _, r in splits), \
"Ratios are expected to be in the range [0; 1], but got %s" % splits

total_ratio = sum(s[1] for s in splits)
if not abs(total_ratio - 1.0) <= 1e-7:
raise Exception(
"Sum of ratios is expected to be 1, got %s, which is %s" %
(splits, total_ratio))
Expand All @@ -336,7 +340,6 @@ def __init__(self, extractor, splits, seed=None):

random.seed(seed)
random.shuffle(indices)

parts = []
s = 0
for subset, ratio in splits:
Expand All @@ -350,7 +353,7 @@ def _find_split(self, index):
for boundary, subset in self._parts:
if index < boundary:
return subset
return subset
return subset # all the possible remainder goes to the last split

def __iter__(self):
for i, item in enumerate(self._extractor):
Expand Down
7 changes: 1 addition & 6 deletions datumaro/tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,15 +534,10 @@ def __iter__(self):

class DatasetItemTest(TestCase):
def test_ctor_requires_id(self):
has_error = False
try:
with self.assertRaises(Exception):
# pylint: disable=no-value-for-parameter
DatasetItem()
# pylint: enable=no-value-for-parameter
except AssertionError:
has_error = True

self.assertTrue(has_error)

@staticmethod
def test_ctors_with_image():
Expand Down
16 changes: 10 additions & 6 deletions datumaro/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,18 +342,22 @@ def __iter__(self):
self.assertEqual(4, len(actual.get_subset('train')))
self.assertEqual(3, len(actual.get_subset('test')))

def test_random_split_gives_error_on_non1_ratios(self):
def test_random_split_gives_error_on_wrong_ratios(self):
class SrcExtractor(Extractor):
def __iter__(self):
return iter([DatasetItem(id=1)])

has_error = False
try:
with self.assertRaises(Exception):
transforms.RandomSplit(SrcExtractor(), splits=[
('train', 0.5),
('test', 0.7),
])
except Exception:
has_error = True

self.assertTrue(has_error)
with self.assertRaises(Exception):
transforms.RandomSplit(SrcExtractor(), splits=[])

with self.assertRaises(Exception):
transforms.RandomSplit(SrcExtractor(), splits=[
('train', -0.5),
('test', 1.5),
])

0 comments on commit 0d873a3

Please sign in to comment.