From c645d59aeab8cdb1dd19f462fabc4e5f26c648fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Dec 2024 10:53:24 -0500 Subject: [PATCH] Fix consecutive same sampler selection in round robin sampler with num_workers>1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- lhotse/dataset/sampling/round_robin.py | 18 +++++++++++++++--- test/dataset/sampling/test_sampling.py | 22 +++++++++++++++++++++- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/lhotse/dataset/sampling/round_robin.py b/lhotse/dataset/sampling/round_robin.py index 7f9957a37..4fe21e617 100644 --- a/lhotse/dataset/sampling/round_robin.py +++ b/lhotse/dataset/sampling/round_robin.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np +import torch from lhotse import CutSet from lhotse.cut import Cut @@ -171,7 +172,18 @@ def __iter__(self): if self._just_restored_state: return self self._nondepleted_samplers_indices = list(range(len(self.samplers))) + # In case this sampler lives in the dataloading worker subprocess, + # set the starting index to a different value on each dataloading worker. + # This helps avoid situations where the round robin sampler chooses + # the same underlying sampler for N consecutive mini-batches, where N = num_workers (>1). self._cur_sampler_idx = 0 + self._num_dl_workers = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + self._cur_sampler_idx = worker_info.id % len( + self._nondepleted_samplers_indices + ) + self._num_dl_workers = worker_info.num_workers return self def _next_batch(self) -> Union[CutSet, Tuple[CutSet]]: @@ -202,9 +214,9 @@ def _set_next_idx(self) -> None: p = [x / sum(p) for x in p] self._cur_sampler_idx = self.rng.choice(N, size=1, replace=False, p=p)[0] else: - self._cur_sampler_idx = (self._cur_sampler_idx + 1) % len( - self._nondepleted_samplers_indices - ) + self._cur_sampler_idx = ( + self._cur_sampler_idx + self._num_dl_workers + ) % len(self._nondepleted_samplers_indices) def set_epoch(self, epoch: int) -> None: """ diff --git a/test/dataset/sampling/test_sampling.py b/test/dataset/sampling/test_sampling.py index 686b472b9..6f8c10d51 100644 --- a/test/dataset/sampling/test_sampling.py +++ b/test/dataset/sampling/test_sampling.py @@ -1,4 +1,3 @@ -import math import random import re from collections import Counter @@ -856,6 +855,27 @@ def test_round_robin_sampler(randomize): # ... and so on +@pytest.mark.parametrize("num_workers", [0, 1, 2, 3]) +def test_nonrandomized_round_robin_sampler_keeps_round_robin_property_in_iterable_dataset( + num_workers, +): + cuts1 = DummyManifest(CutSet, begin_id=0, end_id=100) + cuts2 = DummyManifest(CutSet, begin_id=500, end_id=600) + cuts3 = DummyManifest(CutSet, begin_id=1000, end_id=1100) + sampler = RoundRobinSampler( + SimpleCutSampler(cuts1, max_cuts=1, shuffle=False), + SimpleCutSampler(cuts2, max_cuts=2, shuffle=False), + SimpleCutSampler(cuts3, max_cuts=3, shuffle=False), + ) + dloader = DataLoader( + dataset=IterableDatasetWrapper(IdentityDataset(), sampler), + batch_size=None, + num_workers=num_workers, + ) + lens = [len(b) for idx, b in zip(range(15), dloader)] + assert lens == [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3] + + @pytest.mark.parametrize("sampler_cls", [SimpleCutSampler, DynamicCutSampler]) def test_single_cut_sampler_drop_last(sampler_cls): # The dummy cuts have a duration of 1 second each