Skip to content

Commit

Permalink
Support resampling when torchaudio is missing using scipy
Browse files Browse the repository at this point in the history
  • Loading branch information
pzelasko committed Jan 4, 2024
1 parent 2303883 commit d167bdd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
35 changes: 26 additions & 9 deletions lhotse/augmentation/torchaudio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from dataclasses import dataclass
from decimal import ROUND_HALF_UP
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -11,6 +11,7 @@
Seconds,
compute_num_samples,
during_docs_build,
is_module_available,
is_torchaudio_available,
perturb_num_samples,
)
Expand Down Expand Up @@ -181,19 +182,35 @@ class Resample(AudioTransform):
def __post_init__(self):
self.source_sampling_rate = int(self.source_sampling_rate)
self.target_sampling_rate = int(self.target_sampling_rate)
self.resampler = get_or_create_resampler(
self.source_sampling_rate, self.target_sampling_rate
)
if not is_torchaudio_available():
assert is_module_available(

Check warning on line 186 in lhotse/augmentation/torchaudio.py

View check run for this annotation

Codecov / codecov/patch

lhotse/augmentation/torchaudio.py#L186

Added line #L186 was not covered by tests
"scipy"
), "In order to use resampling, either torchaudio or scipy needs to be installed."
else:
self.resampler = get_or_create_resampler(
self.source_sampling_rate, self.target_sampling_rate
)

def __call__(self, samples: np.ndarray, *args, **kwargs) -> np.ndarray:
check_for_torchaudio()
if self.source_sampling_rate == self.target_sampling_rate:
return samples

if isinstance(samples, np.ndarray):
samples = torch.from_numpy(samples)
augmented = self.resampler(samples)
return augmented.numpy()
if is_torchaudio_available():
if isinstance(samples, np.ndarray):
samples = torch.from_numpy(samples)
augmented = self.resampler(samples)
return augmented.numpy()
else:
import scipy

Check warning on line 204 in lhotse/augmentation/torchaudio.py

View check run for this annotation

Codecov / codecov/patch

lhotse/augmentation/torchaudio.py#L204

Added line #L204 was not covered by tests

gcd = np.gcd(self.source_sampling_rate, self.target_sampling_rate)
augmented = scipy.signal.resample_poly(

Check warning on line 207 in lhotse/augmentation/torchaudio.py

View check run for this annotation

Codecov / codecov/patch

lhotse/augmentation/torchaudio.py#L206-L207

Added lines #L206 - L207 were not covered by tests
samples,
up=self.target_sampling_rate // gcd,
down=self.source_sampling_rate // gcd,
axis=-1,
)
return augmented

Check warning on line 213 in lhotse/augmentation/torchaudio.py

View check run for this annotation

Codecov / codecov/patch

lhotse/augmentation/torchaudio.py#L213

Added line #L213 was not covered by tests

def reverse_timestamps(
self, offset: Seconds, duration: Optional[Seconds], sampling_rate: int
Expand Down
13 changes: 13 additions & 0 deletions test/test_missing_torchaudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ def test_lhotse_load_audio():
assert isinstance(audio, np.ndarray)


@notorchaudio
@pytest.mark.parametrize("sr", [8000, 16000, 22500, 24000, 44100])
def test_lhotse_resample(sr):
import lhotse

cuts = lhotse.CutSet.from_file("test/fixtures/libri/cuts.json")
cut = cuts[0]
cut = cut.resample(sr)
audio = cut.load_audio()
assert isinstance(audio, np.ndarray)
assert audio.shape == (1, cut.num_samples)


@notorchaudio
def test_lhotse_audio_in_memory():
import lhotse
Expand Down

0 comments on commit d167bdd

Please sign in to comment.