Skip to content

Commit

Permalink
Fix consecutive same sampler selection in round robin sampler with num_workers>1
Browse files Browse the repository at this point in the history

Signed-off-by: Piotr Żelasko <petezor@gmail.com>
  • Loading branch information
pzelasko committed Dec 6, 2024
1 parent a13c084 commit c645d59
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
18 changes: 15 additions & 3 deletions lhotse/dataset/sampling/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from lhotse import CutSet
from lhotse.cut import Cut
Expand Down Expand Up @@ -171,7 +172,18 @@ def __iter__(self):
if self._just_restored_state:
return self
self._nondepleted_samplers_indices = list(range(len(self.samplers)))
# In case this sampler lives in the dataloading worker subprocess,
# set the starting index to a different value on each dataloading worker.
# This helps avoid situations where the round robin sampler chooses
# the same underlying sampler for N consecutive mini-batches, where N = num_workers (>1).
self._cur_sampler_idx = 0
self._num_dl_workers = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
self._cur_sampler_idx = worker_info.id % len(
self._nondepleted_samplers_indices
)
self._num_dl_workers = worker_info.num_workers
return self

def _next_batch(self) -> Union[CutSet, Tuple[CutSet]]:
Expand Down Expand Up @@ -202,9 +214,9 @@ def _set_next_idx(self) -> None:
p = [x / sum(p) for x in p]
self._cur_sampler_idx = self.rng.choice(N, size=1, replace=False, p=p)[0]
else:
self._cur_sampler_idx = (self._cur_sampler_idx + 1) % len(
self._nondepleted_samplers_indices
)
self._cur_sampler_idx = (
self._cur_sampler_idx + self._num_dl_workers
) % len(self._nondepleted_samplers_indices)

def set_epoch(self, epoch: int) -> None:
"""
Expand Down
22 changes: 21 additions & 1 deletion test/dataset/sampling/test_sampling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
import random
import re
from collections import Counter
Expand Down Expand Up @@ -856,6 +855,27 @@ def test_round_robin_sampler(randomize):
# ... and so on


@pytest.mark.parametrize("num_workers", [0, 1, 2, 3])
def test_nonrandomized_round_robin_sampler_keeps_round_robin_property_in_iterable_dataset(
    num_workers,
):
    """Check the batch order of a non-randomized round-robin sampler in a DataLoader.

    Three sub-samplers are configured with ``max_cuts`` of 1, 2 and 3, so the
    size of each mini-batch identifies the sub-sampler it came from. For every
    parametrized ``num_workers`` value, the first 15 batch sizes read from the
    DataLoader must repeat the pattern 1, 2, 3.
    """
    # Three disjoint ID ranges of 100 dummy cuts each: 0-100, 500-600, 1000-1100.
    manifests = [
        DummyManifest(CutSet, begin_id=begin, end_id=begin + 100)
        for begin in (0, 500, 1000)
    ]
    # The i-th sub-sampler (1-based) yields batches of exactly i cuts.
    sub_samplers = [
        SimpleCutSampler(cuts, max_cuts=batch_size, shuffle=False)
        for batch_size, cuts in enumerate(manifests, start=1)
    ]
    round_robin = RoundRobinSampler(*sub_samplers)
    loader = DataLoader(
        dataset=IterableDatasetWrapper(IdentityDataset(), round_robin),
        batch_size=None,
        num_workers=num_workers,
    )

    expected_lens = [1, 2, 3] * 5
    observed_lens = []
    # Zipping with a range (range first) bounds how many batches are pulled.
    for _, batch in zip(range(len(expected_lens)), loader):
        observed_lens.append(len(batch))

    assert observed_lens == expected_lens


@pytest.mark.parametrize("sampler_cls", [SimpleCutSampler, DynamicCutSampler])
def test_single_cut_sampler_drop_last(sampler_cls):
# The dummy cuts have a duration of 1 second each
Expand Down

0 comments on commit c645d59

Please sign in to comment.