Skip to content

Commit

Permalink
Support most CutSet operations without dill; fix tests; infinite_mu…
Browse files Browse the repository at this point in the history
…x works
  • Loading branch information
pzelasko committed Jan 23, 2024
1 parent abb1e52 commit 8c53f96
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 38 deletions.
128 changes: 104 additions & 24 deletions lhotse/cut/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ def filter_supervisions(
:param predicate: A callable that accepts `SupervisionSegment` and returns bool
:return: a CutSet with filtered supervisions
"""
return self.map(lambda cut: cut.filter_supervisions(predicate))
return self.map(partial(_filter_supervisions, predicate=predicate))

def merge_supervisions(
self,
Expand All @@ -982,8 +982,10 @@ def merge_supervisions(
``custom_merge_fn(custom_key, [s.custom[custom_key] for s in sups])``
"""
return self.map(
lambda cut: cut.merge_supervisions(
merge_policy=merge_policy, custom_merge_fn=custom_merge_fn
partial(
_merge_supervisions,
merge_policy=merge_policy,
custom_merge_fn=custom_merge_fn,
)
)

Expand Down Expand Up @@ -1341,7 +1343,8 @@ def pad(
duration = max(cut.duration for cut in self)

return self.map(
lambda cut: cut.pad(
partial(
_pad,
duration=duration,
num_frames=num_frames,
num_samples=num_samples,
Expand Down Expand Up @@ -1422,7 +1425,8 @@ def extend_by(
:return: a new CutSet instance.
"""
return self.map(
lambda cut: cut.extend_by(
partial(
_extend_by,
duration=duration,
direction=direction,
preserve_id=preserve_id,
Expand Down Expand Up @@ -1535,7 +1539,9 @@ def resample(self, sampling_rate: int, affix_id: bool = False) -> "CutSet":
cut are going to be present in a single manifest).
:return: a modified copy of the ``CutSet``.
"""
return self.map(lambda cut: cut.resample(sampling_rate, affix_id=affix_id))
return self.map(
partial(_resample, sampling_rate=sampling_rate, affix_id=affix_id)
)

def perturb_speed(self, factor: float, affix_id: bool = True) -> "CutSet":
"""
Expand All @@ -1550,7 +1556,7 @@ def perturb_speed(self, factor: float, affix_id: bool = True) -> "CutSet":
cut are going to be present in a single manifest).
:return: a modified copy of the ``CutSet``.
"""
return self.map(lambda cut: cut.perturb_speed(factor=factor, affix_id=affix_id))
return self.map(partial(_perturb_speed, factor=factor, affix_id=affix_id))

def perturb_tempo(self, factor: float, affix_id: bool = True) -> "CutSet":
"""
Expand All @@ -1568,7 +1574,7 @@ def perturb_tempo(self, factor: float, affix_id: bool = True) -> "CutSet":
cut are going to be present in a single manifest).
:return: a modified copy of the ``CutSet``.
"""
return self.map(lambda cut: cut.perturb_tempo(factor=factor, affix_id=affix_id))
return self.map(partial(_perturb_tempo, factor=factor, affix_id=affix_id))

def perturb_volume(self, factor: float, affix_id: bool = True) -> "CutSet":
"""
Expand All @@ -1582,9 +1588,7 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "CutSet":
cut are going to be present in a single manifest).
:return: a modified copy of the ``CutSet``.
"""
return self.map(
lambda cut: cut.perturb_volume(factor=factor, affix_id=affix_id)
)
return self.map(partial(_perturb_volume, factor=factor, affix_id=affix_id))

def normalize_loudness(
self, target: float, mix_first: bool = True, affix_id: bool = True
Expand All @@ -1599,8 +1603,11 @@ def normalize_loudness(
:return: a modified copy of the current ``CutSet``.
"""
return self.map(
lambda cut: cut.normalize_loudness(
target=target, mix_first=mix_first, affix_id=affix_id
partial(
_normalize_loudness,
target=target,
mix_first=mix_first,
affix_id=affix_id,
)
)

Expand All @@ -1612,7 +1619,7 @@ def dereverb_wpe(self, affix_id: bool = True) -> "CutSet":
by affixing it with "_wpe".
:return: a modified copy of the current ``CutSet``.
"""
return self.map(lambda cut: cut.dereverb_wpe(affix_id=affix_id))
return self.map(partial(_dereverb_wpe, affix_id=affix_id))

def reverb_rir(
self,
Expand Down Expand Up @@ -1643,7 +1650,8 @@ def reverb_rir(
"""
rir_recordings = list(rir_recordings) if rir_recordings else None
return self.map(
lambda cut: cut.reverb_rir(
partial(
_reverb_rir,
rir_recording=random.choice(rir_recordings) if rir_recordings else None,
normalize_output=normalize_output,
early_only=early_only,
Expand Down Expand Up @@ -1713,25 +1721,25 @@ def drop_features(self) -> "CutSet":
"""
Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its extracted features.
"""
return self.map(lambda cut: cut.drop_features())
return self.map(_drop_features)

def drop_recordings(self) -> "CutSet":
"""
Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its recordings.
"""
return self.map(lambda cut: cut.drop_recording())
return self.map(_drop_recordings)

def drop_supervisions(self) -> "CutSet":
"""
Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its supervisions.
"""
return self.map(lambda cut: cut.drop_supervisions())
return self.map(_drop_supervisions)

def drop_alignments(self) -> "CutSet":
"""
Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from the alignments present in its supervisions.
"""
return self.map(lambda cut: cut.drop_alignments())
return self.map(_drop_alignments)

def compute_and_store_features(
self,
Expand Down Expand Up @@ -2461,7 +2469,7 @@ def fill_supervisions(
of calling this method.
"""
return self.map(
lambda cut: cut.fill_supervision(add_empty=add_empty, shrink_ok=shrink_ok)
partial(_fill_supervision, add_empty=add_empty, shrink_ok=shrink_ok)
)

def map_supervisions(
Expand All @@ -2473,7 +2481,7 @@ def map_supervisions(
:param transform_fn: a function that modifies a supervision as an argument.
:return: a new, modified CutSet.
"""
return self.map(lambda cut: cut.map_supervisions(transform_fn))
return self.map(partial(_map_supervisions, transform_fn=transform_fn))

def transform_text(self, transform_fn: Callable[[str], str]) -> "CutSet":
"""
Expand All @@ -2483,7 +2491,9 @@ def transform_text(self, transform_fn: Callable[[str], str]) -> "CutSet":
:param transform_fn: a function that accepts a string and returns a string.
:return: a new, modified CutSet.
"""
return self.map_supervisions(lambda s: s.transform_text(transform_fn))
return self.map_supervisions(
partial(_transform_text, transform_fn=transform_fn)
)

def __repr__(self) -> str:
try:
Expand Down Expand Up @@ -3269,8 +3279,78 @@ def _with_id(cut, transform_fn):
return cut.with_id(transform_fn(cut.id))


def _call(obj, member_fn: str, *args, **kwargs) -> Callable:
return getattr(obj, member_fn)(*args, **kwargs)
def _fill_supervision(cut, add_empty, shrink_ok):
return cut.fill_supervision(add_empty=add_empty, shrink_ok=shrink_ok)


def _map_supervisions(cut, transform_fn):
return cut.map_supervisions(transform_fn)


def _transform_text(sup, transform_fn):
return sup.transform_text(transform_fn)


def _filter_supervisions(cut, predicate):
return cut.filter_supervisions(predicate)


def _merge_supervisions(cut, merge_policy, custom_merge_fn):
return cut.merge_supervisions(
merge_policy=merge_policy, custom_merge_fn=custom_merge_fn
)


def _pad(cut, *args, **kwargs):
return cut.pad(*args, **kwargs)


def _extend_by(cut, *args, **kwargs):
return cut.extend_by(*args, **kwargs)


def _resample(cut, *args, **kwargs):
return cut.resample(*args, **kwargs)


def _perturb_speed(cut, *args, **kwargs):
return cut.perturb_speed(*args, **kwargs)


def _perturb_tempo(cut, *args, **kwargs):
return cut.perturb_speed(*args, **kwargs)


def _perturb_volume(cut, *args, **kwargs):
return cut.perturb_speed(*args, **kwargs)


def _reverb_rir(cut, *args, **kwargs):
return cut.perturb_speed(*args, **kwargs)


def _normalize_loudness(cut, *args, **kwargs):
return cut.normalize_loudness(*args, **kwargs)


def _dereverb_wpe(cut, *args, **kwargs):
return cut.dereverb_wpe(*args, **kwargs)


def _drop_features(cut, *args, **kwargs):
return cut.drop_features(*args, **kwargs)


def _drop_recordings(cut, *args, **kwargs):
return cut.drop_recording(*args, **kwargs)


def _drop_alignments(cut, *args, **kwargs):
return cut.drop_alignments(*args, **kwargs)


def _drop_supervisions(cut, *args, **kwargs):
return cut.drop_supervisions(*args, **kwargs)


def _export_to_shar_single(
Expand Down
24 changes: 15 additions & 9 deletions lhotse/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,20 +494,28 @@ def shuffled_streams():
# Sample the first M active streams to be multiplexed.
# As streams get depleted, we will replace them with
# new streams sampled from the stream source.
active_streams = []
active_weights = []
active_streams = [None] * self.max_open_streams
active_weights = [None] * self.max_open_streams
stream_indexes = list(range(self.max_open_streams))
for _ in range(self.max_open_streams):

def sample_new_stream_at(pos: int) -> None:
sampled_stream, sampled_weight = next(stream_source)
active_streams.append(iter(sampled_stream))
active_weights.append(sampled_weight)
active_streams[pos] = iter(sampled_stream)
active_weights[pos] = sampled_weight

for stream_pos in range(self.max_open_streams):
sample_new_stream_at(stream_pos)

# The actual multiplexing loop.
while True:
# Select a stream from the currently active streams.
# We actually sample an index so that we know which position
# to replace if a stream is exhausted.
stream_pos = rng.choices(stream_indexes, weights=active_weights, k=1)[0]
stream_pos = rng.choices(
stream_indexes,
weights=active_weights if sum(active_weights) > 0 else None,
k=1,
)[0]
selected = active_streams[stream_pos]
try:
# Sample from the selected stream.
Expand All @@ -516,9 +524,7 @@ def shuffled_streams():
except StopIteration:
# The selected stream is exhausted. Replace it with another one,
# and return a sample from the newly opened stream.
sampled_stream, sampled_weight = next(stream_source)
active_streams[stream_pos] = iter(sampled_stream)
active_weights[stream_pos] = sampled_weight
sample_new_stream_at(stream_pos)
item = next(active_streams[stream_pos])
yield item

Expand Down
6 changes: 6 additions & 0 deletions lhotse/testing/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, List

import numpy as np
import pytest
import torch

from lhotse import (
Expand All @@ -22,6 +23,11 @@
from lhotse.utils import Seconds, uuid4


@pytest.fixture()
def with_dill_enabled():
os.environ["LHOTSE_ENABLE_DILL"] = "1"


def random_cut_set(n_cuts=100) -> CutSet:
sr = 16000
return CutSet.from_cuts(
Expand Down
9 changes: 8 additions & 1 deletion test/dataset/sampling/test_sampler_pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
)
from lhotse.dataset.sampling.dynamic import DynamicCutSampler
from lhotse.testing.dummies import DummyManifest
from lhotse.testing.fixtures import with_dill_enabled
from lhotse.utils import is_module_available

CUTS = DummyManifest(CutSet, begin_id=0, end_id=100)
CUTS_MOD = CUTS.modify_ids(lambda cid: cid + "_alt")
Expand Down Expand Up @@ -120,8 +122,13 @@ def test_sampler_pickling_with_filter(sampler):
assert batches_restored[0][0].id == "dummy-mono-cut-0000"


@pytest.mark.xfail(
not is_module_available("dill"),
reason="This test will fail when 'dill' module is not installed as it won't be able to pickle a closure.",
raises=AttributeError,
)
@pytest.mark.parametrize("sampler", create_samplers_to_test_filter())
def test_sampler_pickling_with_filter_local_closure(sampler):
def test_sampler_pickling_with_filter_local_closure(with_dill_enabled, sampler):

selected_id = "dummy-mono-cut-0000"

Expand Down
18 changes: 15 additions & 3 deletions test/dataset/test_controllable_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,6 @@ def test_mux_with_controllable_weights_subprocess_sampler_shared_memory(
assert_sources_are(b, [2, 2])


@pytest.mark.skip(
reason="Infinite mux is not yet fully supported for shared memory weights."
)
def test_infinite_mux_with_controllable_weights_subprocess_sampler_shared_memory(
deterministic_rng,
):
Expand Down Expand Up @@ -196,10 +193,25 @@ def test_infinite_mux_with_controllable_weights_subprocess_sampler_shared_memory
b = next(dloader)
assert_sources_are(b, [0, 0])

# Note the latency for several batches. The reason is the following:
# infinite_mux() samples streams with replacement, and at the beginning of the test is samples
# 3x stream #0, which has 3 items each with equal probability.
# It will only start returning items from stream #1 once one of the previous streams is exhausted.
weights[:] = torch.tensor([0, 1, 0]) # atomic update
b = next(dloader)
assert_sources_are(b, [0, 0])
b = next(dloader)
assert_sources_are(b, [0, 0])
b = next(dloader)
assert_sources_are(b, [1, 1])

# The latency strikes again as now we have both streams #0 and #1 open,
# but they have zero weight. It means they will be uniformly sampled until
# one of them is exhausted and a new stream #2 is opened.
weights[:] = torch.tensor([0, 0, 1]) # atomic update
b = next(dloader)
assert_sources_are(b, [0, 0])
b = next(dloader)
assert_sources_are(b, [1, 2])
b = next(dloader)
assert_sources_are(b, [2, 2])
3 changes: 2 additions & 1 deletion test/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lhotse import CutSet, FeatureSet, RecordingSet, SupervisionSet, combine
from lhotse.lazy import LazyJsonlIterator
from lhotse.testing.dummies import DummyManifest, as_lazy
from lhotse.testing.fixtures import with_dill_enabled
from lhotse.utils import fastcopy, is_module_available


Expand Down Expand Up @@ -235,7 +236,7 @@ def _get_ids(cuts):
reason="This test will fail when 'dill' module is not installed as it won't be able to pickle a lambda.",
raises=AttributeError,
)
def test_dillable():
def test_dillable(with_dill_enabled):
cuts = DummyManifest(CutSet, begin_id=0, end_id=2)
with as_lazy(cuts) as lazy_cuts:
lazy_cuts = lazy_cuts.map(lambda c: fastcopy(c, id=c.id + "-random-suffix"))
Expand Down

0 comments on commit 8c53f96

Please sign in to comment.