Initial partial support for infinite_mux
pzelasko committed Jan 22, 2024
1 parent bbabca5 commit abb1e52
Showing 3 changed files with 72 additions and 11 deletions.
6 changes: 5 additions & 1 deletion lhotse/cut/set.py
@@ -2439,7 +2439,7 @@ def modify_ids(self, transform_fn: Callable[[str], str]) -> "CutSet":
             a new string (new cut ID).
         :return: a new ``CutSet`` with cuts with modified IDs.
         """
-        return self.map(lambda cut: cut.with_id(transform_fn(cut.id)))
+        return self.map(partial(_with_id, transform_fn=transform_fn))
 
     def fill_supervisions(
         self, add_empty: bool = True, shrink_ok: bool = False
@@ -3265,6 +3265,10 @@ def _add_features_path_prefix_single(cut, path):
     return cut.with_features_path_prefix(path)
 
 
+def _with_id(cut, transform_fn):
+    return cut.with_id(transform_fn(cut.id))
+
+
 def _call(obj, member_fn: str, *args, **kwargs) -> Callable:
     return getattr(obj, member_fn)(*args, **kwargs)
7 changes: 4 additions & 3 deletions lhotse/lazy.py
@@ -202,7 +202,7 @@ def __getstate__(self):
     def __setstate__(self, state):
         if (
             is_module_available("dill")
-            and os.environ.get("LHOTSE_DILL_ENABLED", "0") not in self._ENABLED_VALUES
+            and os.environ.get("LHOTSE_DILL_ENABLED", "0") in self._ENABLED_VALUES
         ):
             import dill
@@ -481,9 +481,10 @@ def shuffled_streams():
             # towards the beginning of an "epoch" and then keep yielding
             # from it1 shards until the epoch is finished and we can sample
             # from it0 again...
-            zipped_iter_weights = list(zip(self.iterators, self.weights))
+            indexes = list(range(len(self.iterators)))
             while True:
-                yield rng.choices(zipped_iter_weights, self.weights, k=1)[0]
+                selected = rng.choices(indexes, self.weights, k=1)[0]
+                yield self.iterators[selected], self.weights[selected]
 
         # Initialize an infinite sequence of finite streams.
         # It is sampled with weights and replacement from ``self.iterators``,
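
This rewrite is what makes the weights controllable: the old code zipped (iterator, weight) pairs once, freezing the weight values at construction time, whereas sampling by index re-reads `self.weights` on every draw, so in-place updates (including shared-memory tensors) steer all future samples. A self-contained sketch of the pattern (`streams`/`weights` stand in for self.iterators/self.weights):

    import random

    def sample_streams(streams, weights, seed=0):
        # Weighted sampling with replacement, by index, so that in-place
        # mutation of `weights` is picked up by every subsequent draw.
        rng = random.Random(seed)
        indexes = list(range(len(streams)))
        while True:
            selected = rng.choices(indexes, weights, k=1)[0]
            yield streams[selected], weights[selected]

    it = sample_streams(["a", "b"], [1.0, 0.0])
    assert next(it) == ("a", 1.0)

Mutating the list in place afterwards, e.g. `weights[0] = 0.0; weights[1] = 1.0`, makes every later draw yield ("b", 1.0).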
70 changes: 63 additions & 7 deletions test/dataset/test_controllable_weights.py
@@ -1,10 +1,13 @@
+from uuid import uuid4
+
 import numpy as np
 import pytest
 import torch
 
 from lhotse import CutSet
 from lhotse.dataset import DynamicCutSampler, IterableDatasetWrapper
 from lhotse.testing.dummies import DummyManifest
+from lhotse.testing.random import deterministic_rng
 
 
 class DummyDataset(torch.utils.data.Dataset):
@@ -20,13 +23,17 @@ def _inner(cut):
     return _inner
 
 
+def random_id(*args):
+    return str(uuid4())
+
+
 def assert_sources_are(cuts: CutSet, expected: list[int]):
     actual = [c.source for c in cuts]
     assert actual == expected
 
 
 @pytest.mark.parametrize("weight_type", [list, np.array, torch.tensor])
-def test_mux_with_controllable_weights(weight_type):
+def test_mux_with_controllable_weights(deterministic_rng, weight_type):
     """The sampler and the worker are both in the main process."""
 
     # 3 infinite iterables
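
The `deterministic_rng` parameter added to each test is a pytest fixture (imported above from lhotse.testing.random) that pins random seeds so the weighted-sampling assertions do not flake. A hedged sketch of what a fixture like this does; the actual implementation in Lhotse may differ:

    import random

    import numpy as np
    import pytest
    import torch

    @pytest.fixture
    def deterministic_rng():
        # Seed every RNG the muxing code might touch, then hand control to
        # the test; requesting the fixture by name is enough to activate it.
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        yield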
@@ -59,7 +66,7 @@ def test_mux_with_controllable_weights(weight_type):
     assert_sources_are(b, [2, 2])
 
 
-def test_mux_with_controllable_weights_subprocess_worker():
+def test_mux_with_controllable_weights_subprocess_worker(deterministic_rng):
     """
     The sampler is in the main process but the worker is in a sub-process.
@@ -106,7 +113,9 @@ def test_mux_with_controllable_weights_subprocess_worker():
     assert_sources_are(b, [2, 2])
 
 
-def test_mux_with_controllable_weights_subprocess_sampler_shared_memory():
+def test_mux_with_controllable_weights_subprocess_sampler_shared_memory(
+    deterministic_rng,
+):
     """
     The sampler is placed in the dataloading subprocess.
@@ -138,12 +147,59 @@ def test_mux_with_controllable_weights_subprocess_sampler_shared_memory():
     b = next(dloader)
     assert_sources_are(b, [0, 0])
 
-    weights[0] = 0.0
-    weights[1] = 1.0
+    weights[:] = torch.tensor([0, 1, 0])  # atomic update
     b = next(dloader)
     assert_sources_are(b, [1, 1])
 
+    weights[:] = torch.tensor([0, 0, 1])  # atomic update
+    b = next(dloader)
+    assert_sources_are(b, [2, 2])
+
+
+@pytest.mark.skip(
+    reason="Infinite mux is not yet fully supported for shared memory weights."
+)
+def test_infinite_mux_with_controllable_weights_subprocess_sampler_shared_memory(
+    deterministic_rng,
+):
+    """
+    The sampler is placed in the dataloading subprocess.
+    Note: we are using PyTorch shared memory to share the weight tensor across processes.
+    In general, expect a latency of ``prefetch_factor * num_workers`` batches in the
+    propagation of weights between the main process and the dataloading subprocesses.
+    """
+
+    # 3 infinite iterables
+    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=3).map(mark(0))
+    cuts2 = DummyManifest(CutSet, begin_id=10, end_id=13).map(mark(1))
+    cuts3 = DummyManifest(CutSet, begin_id=100, end_id=103).map(mark(2))
+
+    weights = torch.tensor([1, 0, 0]).share_memory_()
+    assert weights.is_shared()
+    # random_id is required because infinite_mux may sample the same cut in a mini-batch
+    muxd = CutSet.infinite_mux(cuts1, cuts2, cuts3, weights=weights).modify_ids(
+        random_id
+    )
+
+    dloader = torch.utils.data.DataLoader(
+        dataset=IterableDatasetWrapper(
+            dataset=DummyDataset(), sampler=DynamicCutSampler(muxd, max_cuts=2)
+        ),
+        batch_size=None,
+        num_workers=1,
+        prefetch_factor=1,
+    )
+
+    dloader = iter(dloader)
+    b = next(dloader)
+    assert_sources_are(b, [0, 0])
+
+    weights[:] = torch.tensor([0, 1, 0])  # atomic update
+    b = next(dloader)
+    assert_sources_are(b, [1, 1])
+
-    weights[1] = 0.0
-    weights[2] = 1.0
+    weights[:] = torch.tensor([0, 0, 1])  # atomic update
     b = next(dloader)
     assert_sources_are(b, [2, 2])
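
A side note on the `# atomic update` comments (an editorial sketch, not from the commit): slice assignment writes through the tensor's shared storage in a single copy, which is what the tests rely on, while element-wise writes expose mixed intermediate states to concurrent readers, and rebinding the name breaks sharing entirely:

    import torch

    weights = torch.tensor([1.0, 0.0, 0.0]).share_memory_()

    # Element-wise writes pass through mixed intermediate states: after the
    # first line below, a concurrent reader could observe [0, 0, 0].
    weights[0] = 0.0
    weights[1] = 1.0

    # Slice assignment writes the whole new distribution through the same
    # shared storage in one step.
    weights[:] = torch.tensor([0.0, 0.0, 1.0])

    # Rebinding the name allocates fresh, non-shared storage and silently
    # stops affecting the dataloading workers.
    weights = torch.tensor([1.0, 0.0, 0.0])
    assert not weights.is_shared()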
