Skip to content

Commit

Permalink
Add random split transform (#1213)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max authored Feb 28, 2020
1 parent 3604f0c commit da69a40
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
61 changes: 61 additions & 0 deletions datumaro/datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import logging as log
import os.path as osp
import random

import pycocotools.mask as mask_utils

Expand Down Expand Up @@ -295,6 +296,66 @@ def transform_item(self, item):
return self.wrap_item(item,
subset=self._mapping.get(item.subset, item.subset))

class RandomSplit(Transform, CliPlugin):
    """
    Joins all subsets into one and splits the result into few parts.
    It is expected that item ids are unique and subset ratios sum up to 1.|n
    |n
    Example:|n
    |s|s%(prog)s --subset train:.67 --subset test:.33
    """

    # NOTE(review): the docstring above carries '|n', '|s' and '%(prog)s'
    # markers, so it presumably doubles as the command-line help text;
    # its wording and layout are intentionally left untouched - confirm
    # against CliPlugin before editing it.

    # Ratios are floats (e.g. 0.1 + 0.7 + 0.2 != 1.0 exactly), so the
    # "sums up to 1" check has to allow for accumulated rounding error.
    _RATIO_TOLERANCE = 1e-7

    @staticmethod
    def _split_arg(s):
        """
        Parses a single '<subset>:<ratio>' command-line value.

        Returns a (subset name, ratio as float) tuple.
        Raises argparse.ArgumentTypeError with an explanatory message
        when the value is malformed, so argparse can report it to the user.
        """
        import argparse

        parts = s.split(':')
        if len(parts) != 2:
            raise argparse.ArgumentTypeError(
                "Expected a value in the form '<subset>:<ratio>', "
                "got '%s'" % s)

        name, ratio = parts
        try:
            return (name, float(ratio))
        except ValueError:
            raise argparse.ArgumentTypeError(
                "Subset ratio is expected to be a number, "
                "got '%s'" % ratio) from None

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """
        Extends the parent command-line parser with the repeatable
        '-s/--subset' option (stored in 'splits') and the '--seed' option.
        """
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument('-s', '--subset', action='append',
            type=cls._split_arg, dest='splits',
            help="Subsets in the form of: '<subset>:<ratio>' (repeatable)")
        parser.add_argument('--seed', type=int, help="Random seed")
        return parser

    def __init__(self, extractor, splits, seed=None):
        """
        Randomly distributes the items of 'extractor' between new subsets.

        Parameters:
            extractor - the source of items; must support len()
            splits - a sequence of (subset name, ratio) pairs,
                the ratios must sum up to 1
            seed - an optional seed making the distribution reproducible

        Raises ValueError if the ratios do not sum up to 1
        (this includes the case of no splits being given).
        """
        super().__init__(extractor)

        # The command-line parser leaves 'splits' as None when the option
        # is omitted; treat it as "no splits" to report a clear error below.
        splits = list(splits or [])

        total_ratio = sum(ratio for _, ratio in splits)
        if abs(total_ratio - 1) > self._RATIO_TOLERANCE:
            raise ValueError(
                "Sum of ratios is expected to be 1, got %s, which is %s" %
                (splits, total_ratio))

        dataset_size = len(extractor)
        indices = list(range(dataset_size))

        # A private generator is used so that the process-wide random state
        # is not reseeded as a side effect of constructing the transform.
        random.Random(seed).shuffle(indices)

        # Each part owns a slice of the *shuffled* positions, which is what
        # makes the assignment of items to subsets random.
        parts = []
        cumulative_ratio = 0
        lower = 0
        last_split = len(splits) - 1
        for split_idx, (subset, ratio) in enumerate(splits):
            cumulative_ratio += ratio
            if split_idx == last_split:
                # The last part absorbs whatever rounding has left over.
                upper = dataset_size
            else:
                upper = int(cumulative_ratio * dataset_size)
                upper = max(lower, min(upper, dataset_size))
            parts.append((frozenset(indices[lower:upper]), subset))
            lower = upper

        self._parts = parts

    def _find_split(self, index):
        """
        Returns the name of the subset that owns the item located at
        the position 'index' of the source iteration order.
        """
        for part_indices, subset in self._parts:
            if index in part_indices:
                return subset
        # Positions beyond the reported dataset size go to the last part.
        return self._parts[-1][1]

    def __iter__(self):
        """Yields the source items re-assigned to their new subsets."""
        for i, item in enumerate(self._extractor):
            yield self.wrap_item(item, subset=self._find_split(i))

class IdFromImageName(Transform, CliPlugin):
def transform_item(self, item):
name = item.id
Expand Down
37 changes: 37 additions & 0 deletions datumaro/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,40 @@ def __iter__(self):

actual = transforms.BoxesToMasks(SrcExtractor())
compare_datasets(self, DstExtractor(), actual)

def test_random_split(self):
    # Seven items spread over three source subsets; the transform is
    # expected to ignore the original subsets and produce a 4 / 3 split.
    source_subsets = ["a", "a", "b", "b", "b", "", ""]

    class SrcExtractor(Extractor):
        def __iter__(self):
            items = [
                DatasetItem(id=item_id, subset=subset_name)
                for item_id, subset_name in enumerate(source_subsets, start=1)
            ]
            return iter(items)

    split_ratios = [
        ('train', 4.0 / 7.0),
        ('test', 3.0 / 7.0),
    ]
    actual = transforms.RandomSplit(SrcExtractor(), splits=split_ratios)

    # Only the sizes are checked: which item lands where is up to the
    # transform.
    for subset_name, expected_size in (('train', 4), ('test', 3)):
        self.assertEqual(expected_size, len(actual.get_subset(subset_name)))

def test_random_split_gives_error_on_non1_ratios(self):
    # The ratios below sum up to 1.2, so constructing the transform
    # is expected to fail.
    class SrcExtractor(Extractor):
        def __iter__(self):
            return iter([DatasetItem(id=1)])

    # assertRaises replaces the manual try/except + flag bookkeeping and
    # reports a descriptive failure when no exception is raised.
    with self.assertRaises(Exception):
        transforms.RandomSplit(SrcExtractor(), splits=[
            ('train', 0.5),
            ('test', 0.7),
        ])

0 comments on commit da69a40

Please sign in to comment.