diff --git a/src/libtorchaudio/sox/effects.cpp b/src/libtorchaudio/sox/effects.cpp index 0c9fbc4fc7..ffc12e097c 100644 --- a/src/libtorchaudio/sox/effects.cpp +++ b/src/libtorchaudio/sox/effects.cpp @@ -130,17 +130,4 @@ auto apply_effects_file( return std::tuple( tensor, chain.getOutputSampleRate()); } - -namespace { - -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def( - "torchaudio::sox_effects_initialize_sox_effects", - &initialize_sox_effects); - m.def("torchaudio::sox_effects_shutdown_sox_effects", &shutdown_sox_effects); - m.def("torchaudio::sox_effects_apply_effects_tensor", &apply_effects_tensor); - m.def("torchaudio::sox_effects_apply_effects_file", &apply_effects_file); -} - -} // namespace } // namespace torchaudio::sox diff --git a/src/libtorchaudio/sox/io.cpp b/src/libtorchaudio/sox/io.cpp index 1ecd7dc965..395c40fa5f 100644 --- a/src/libtorchaudio/sox/io.cpp +++ b/src/libtorchaudio/sox/io.cpp @@ -125,11 +125,4 @@ void save_audio_file( chain.addOutputFile(sf); chain.run(); } - -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def("torchaudio::sox_io_get_info", &get_info_file); - m.def("torchaudio::sox_io_load_audio_file", &load_audio_file); - m.def("torchaudio::sox_io_save_audio_file", &save_audio_file); -} - } // namespace torchaudio::sox diff --git a/src/libtorchaudio/sox/pybind/pybind.cpp b/src/libtorchaudio/sox/pybind/pybind.cpp index 178a1e4014..bd9c82c349 100644 --- a/src/libtorchaudio/sox/pybind/pybind.cpp +++ b/src/libtorchaudio/sox/pybind/pybind.cpp @@ -1,3 +1,5 @@ +#include +#include #include #include @@ -5,6 +7,16 @@ namespace torchaudio { namespace sox { namespace { +TORCH_LIBRARY(torchaudio_sox, m) { + m.def("torchaudio_sox::get_info", &get_info_file); + m.def("torchaudio_sox::load_audio_file", &load_audio_file); + m.def("torchaudio_sox::save_audio_file", &save_audio_file); + m.def("torchaudio_sox::initialize_sox_effects", &initialize_sox_effects); + m.def("torchaudio_sox::shutdown_sox_effects", &shutdown_sox_effects); + m.def("torchaudio_sox::apply_effects_tensor", &apply_effects_tensor); + m.def("torchaudio_sox::apply_effects_file", &apply_effects_file); +} + PYBIND11_MODULE(_torchaudio_sox, m) { m.def("set_seed", &set_seed, "Set random seed."); m.def("set_verbosity", &set_verbosity, "Set verbosity."); diff --git a/src/torchaudio/_backend/sox.py b/src/torchaudio/_backend/sox.py index 549ab59033..592890c95c 100644 --- a/src/torchaudio/_backend/sox.py +++ b/src/torchaudio/_backend/sox.py @@ -2,10 +2,13 @@ from typing import BinaryIO, Optional, Tuple, Union import torch +import torchaudio from .backend import Backend from .common import AudioMetaData +sox_ext = torchaudio._extension.lazy_import_sox_ext() + class SoXBackend(Backend): @staticmethod @@ -16,7 +19,7 @@ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_s "Please use an alternative backend that does support reading from file-like objects, e.g. FFmpeg.", ) else: - sinfo = torch.ops.torchaudio.sox_io_get_info(uri, format) + sinfo = sox_ext.get_info(uri, format) if sinfo: return AudioMetaData(*sinfo) else: @@ -38,9 +41,7 @@ def load( "Please use an alternative backend that does support loading from file-like objects, e.g. FFmpeg.", ) else: - ret = torch.ops.torchaudio.sox_io_load_audio_file( - uri, frame_offset, num_frames, normalize, channels_first, format - ) + ret = sox_ext.load_audio_file(uri, frame_offset, num_frames, normalize, channels_first, format) if not ret: raise RuntimeError(f"Failed to load audio from {uri}.") return ret @@ -62,7 +63,7 @@ def save( "Please use an alternative backend that does support writing to file-like objects, e.g. FFmpeg.", ) else: - torch.ops.torchaudio.sox_io_save_audio_file( + sox_ext.save_audio_file( uri, src, sample_rate, diff --git a/src/torchaudio/_backend/utils.py b/src/torchaudio/_backend/utils.py index 83d7729b9e..36cd5f11b1 100644 --- a/src/torchaudio/_backend/utils.py +++ b/src/torchaudio/_backend/utils.py @@ -4,7 +4,7 @@ import torch -from torchaudio._extension import _SOX_INITIALIZED, lazy_import_ffmpeg_ext +from torchaudio._extension import lazy_import_ffmpeg_ext, lazy_import_sox_ext from . import soundfile_backend @@ -20,7 +20,7 @@ def get_available_backends() -> Dict[str, Type[Backend]]: backend_specs: Dict[str, Type[Backend]] = {} if lazy_import_ffmpeg_ext().is_available(): backend_specs["ffmpeg"] = FFmpegBackend - if _SOX_INITIALIZED: + if lazy_import_sox_ext().is_available(): backend_specs["sox"] = SoXBackend if soundfile_backend._IS_SOUNDFILE_AVAILABLE: backend_specs["soundfile"] = SoundfileBackend diff --git a/src/torchaudio/_extension/__init__.py b/src/torchaudio/_extension/__init__.py index 5778f075c8..d19cb2956e 100644 --- a/src/torchaudio/_extension/__init__.py +++ b/src/torchaudio/_extension/__init__.py @@ -2,17 +2,9 @@ import os import sys -from torchaudio._internal.module_utils import eval_env, fail_with_message, is_module_available, no_op - -from .utils import ( - _check_cuda_version, - _fail_since_no_sox, - _init_dll_path, - _init_ffmpeg, - _init_sox, - _LazyImporter, - _load_lib, -) +from torchaudio._internal.module_utils import fail_with_message, is_module_available, no_op + +from .utils import _check_cuda_version, _init_dll_path, _init_ffmpeg, _init_sox, _LazyImporter, _load_lib _LG = logging.getLogger(__name__) @@ -22,11 +14,10 @@ # Builder uses it for debugging purpose, so we export it. # https://github.com/pytorch/builder/blob/e2e4542b8eb0bdf491214451a1a4128bd606cce2/test/smoke_test/smoke_test.py#L80 __all__ = [ - "fail_if_no_sox", "_check_cuda_version", "_IS_TORCHAUDIO_EXT_AVAILABLE", "_IS_RIR_AVAILABLE", - "_SOX_INITIALIZED", + "lazy_import_sox_ext", "lazy_import_ffmpeg_ext", ] @@ -54,34 +45,16 @@ _IS_ALIGN_AVAILABLE = torchaudio.lib._torchaudio.is_align_available() -# Initialize libsox-related features -_SOX_INITIALIZED = False -_USE_SOX = False if os.name == "nt" else eval_env("TORCHAUDIO_USE_SOX", True) -_SOX_MODULE_AVAILABLE = is_module_available("torchaudio.lib._torchaudio_sox") -if _USE_SOX and _SOX_MODULE_AVAILABLE: - try: - _init_sox() - _SOX_INITIALIZED = True - except Exception: - # The initialization of sox extension will fail if supported sox - # libraries are not found in the system. - # Since the rest of the torchaudio works without it, we do not report the - # error here. - # The error will be raised when user code attempts to use these features. - _LG.debug("Failed to initialize sox extension", exc_info=True) - - -if os.name == "nt": - fail_if_no_sox = fail_with_message("requires sox extension, which is not supported on Windows.") -elif not _USE_SOX: - fail_if_no_sox = fail_with_message("requires sox extension, but it is disabled. (TORCHAUDIO_USE_SOX=0)") -elif not _SOX_MODULE_AVAILABLE: - fail_if_no_sox = fail_with_message( - "requires sox extension, but TorchAudio is not compiled with it. " - "Please build TorchAudio with libsox support. (BUILD_SOX=1)" - ) -else: - fail_if_no_sox = no_op if _SOX_INITIALIZED else _fail_since_no_sox +_SOX_EXT = None + + +def lazy_import_sox_ext(): + """Load SoX integration based on availability in lazy manner""" + + global _SOX_EXT + if _SOX_EXT is None: + _SOX_EXT = _LazyImporter("_torchaudio_sox", _init_sox) + return _SOX_EXT _FFMPEG_EXT = None diff --git a/src/torchaudio/_extension/utils.py b/src/torchaudio/_extension/utils.py index ffad4bd735..b34af4bc02 100644 --- a/src/torchaudio/_extension/utils.py +++ b/src/torchaudio/_extension/utils.py @@ -9,10 +9,10 @@ import logging import os import types -from functools import wraps from pathlib import Path import torch +from torchaudio._internal.module_utils import eval_env _LG = logging.getLogger(__name__) _LIB_DIR = Path(__file__).parent.parent / "lib" @@ -62,16 +62,49 @@ def _load_lib(lib: str) -> bool: return True -def _init_sox(): +def _import_sox_ext(): + if os.name == "nt": + raise RuntimeError("sox extension is not supported on Windows") + if not eval_env("TORCHAUDIO_USE_SOX", True): + raise RuntimeError("sox extension is disabled. (TORCHAUDIO_USE_SOX=0)") + + ext = "torchaudio.lib._torchaudio_sox" + + if not importlib.util.find_spec(ext): + raise RuntimeError( + # fmt: off + "TorchAudio is not built with sox extension. " + "Please build TorchAudio with libsox support. (BUILD_SOX=1)" + # fmt: on + ) + _load_lib("libtorchaudio_sox") - import torchaudio.lib._torchaudio_sox # noqa + return importlib.import_module(ext) - torchaudio.lib._torchaudio_sox.set_verbosity(0) + +def _init_sox(): + ext = _import_sox_ext() + ext.set_verbosity(0) import atexit - torch.ops.torchaudio.sox_effects_initialize_sox_effects() - atexit.register(torch.ops.torchaudio.sox_effects_shutdown_sox_effects) + torch.ops.torchaudio_sox.initialize_sox_effects() + atexit.register(torch.ops.torchaudio_sox.shutdown_sox_effects) + + # Bundle functions registered with TORCH_LIBRARY into extension + # so that they can also be accessed in the same (lazy) manner + # from the extension. + keys = [ + "get_info", + "load_audio_file", + "save_audio_file", + "apply_effects_tensor", + "apply_effects_file", + ] + for key in keys: + setattr(ext, key, getattr(torch.ops.torchaudio_sox, key)) + + return ext _FFMPEG_VERS = ["6", "5", "4", ""] @@ -197,22 +230,3 @@ def _check_cuda_version(): "Please install the TorchAudio version that matches your PyTorch version." ) return version - - -def _fail_since_no_sox(func): - @wraps(func) - def wrapped(*_args, **_kwargs): - try: - # Note: - # We run _init_sox again just to show users the stacktrace. - # _init_sox would not succeed here. - _init_sox() - except Exception as err: - raise RuntimeError( - f"{func.__name__} requires sox extension which is not available. " - "Please refer to the stacktrace above for how to resolve this." - ) from err - # This should not happen in normal execution, but just in case. - return func(*_args, **_kwargs) - - return wrapped diff --git a/src/torchaudio/backend/_sox_io_backend.py b/src/torchaudio/backend/_sox_io_backend.py index 62b3f713b4..6af267b17a 100644 --- a/src/torchaudio/backend/_sox_io_backend.py +++ b/src/torchaudio/backend/_sox_io_backend.py @@ -5,8 +5,9 @@ import torchaudio from torchaudio import AudioMetaData +sox_ext = torchaudio._extension.lazy_import_sox_ext() + -@torchaudio._extension.fail_if_no_sox def info( filepath: str, format: Optional[str] = None, @@ -29,11 +30,10 @@ def info( if hasattr(filepath, "read"): raise RuntimeError("sox_io backend does not support file-like object.") filepath = os.fspath(filepath) - sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format) + sinfo = sox_ext.get_info(filepath, format) return AudioMetaData(*sinfo) -@torchaudio._extension.fail_if_no_sox def load( filepath: str, frame_offset: int = 0, @@ -123,12 +123,9 @@ def load( if hasattr(filepath, "read"): raise RuntimeError("sox_io backend does not support file-like object.") filepath = os.fspath(filepath) - return torch.ops.torchaudio.sox_io_load_audio_file( - filepath, frame_offset, num_frames, normalize, channels_first, format - ) + return sox_ext.load_audio_file(filepath, frame_offset, num_frames, normalize, channels_first, format) -@torchaudio._extension.fail_if_no_sox def save( filepath: str, src: torch.Tensor, @@ -285,7 +282,7 @@ def save( if hasattr(filepath, "write"): raise RuntimeError("sox_io backend does not handle file-like object.") filepath = os.fspath(filepath) - torch.ops.torchaudio.sox_io_save_audio_file( + sox_ext.save_audio_file( filepath, src, sample_rate, diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 75eeb07318..af34e707e5 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -1295,7 +1295,6 @@ def spectral_centroid( return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim) -@torchaudio._extension.fail_if_no_sox @deprecated("Please migrate to :py:class:`torchaudio.io.AudioEffector`.", remove=False) def apply_codec( waveform: Tensor, @@ -1329,11 +1328,13 @@ def apply_codec( Tensor: Resulting Tensor. If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`. """ + from torchaudio.backend import _sox_io_backend + with tempfile.NamedTemporaryFile() as f: - torchaudio.backend.sox_io_backend.save( + torchaudio.backend._sox_io_backend.save( f.name, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample ) - augmented, sr = torchaudio.backend.sox_io_backend.load(f.name, channels_first=channels_first, format=format) + augmented, sr = _sox_io_backend.load(f.name, channels_first=channels_first, format=format) if sr != sample_rate: augmented = resample(augmented, sr, sample_rate) return augmented @@ -1371,7 +1372,8 @@ def _get_sinc_resample_kernel( warnings.warn( f'"{resampling_method}" resampling method name is being deprecated and replaced by ' f'"{method_map[resampling_method]}" in the next release. ' - "The default behavior remains unchanged." + "The default behavior remains unchanged.", + stacklevel=3, ) elif resampling_method not in ["sinc_interp_hann", "sinc_interp_kaiser"]: raise ValueError("Invalid resampling method: {}".format(resampling_method)) diff --git a/src/torchaudio/sox_effects/sox_effects.py b/src/torchaudio/sox_effects/sox_effects.py index e7f8be7408..3d64d465ac 100644 --- a/src/torchaudio/sox_effects/sox_effects.py +++ b/src/torchaudio/sox_effects/sox_effects.py @@ -7,6 +7,9 @@ from torchaudio.utils.sox_utils import list_effects +sox_ext = torchaudio._extension.lazy_import_sox_ext() + + @deprecated("Please remove the call. This function is called automatically.") def init_sox_effects(): """Initialize resources required to use sox effects. @@ -36,7 +39,6 @@ def shutdown_sox_effects(): pass -@torchaudio._extension.fail_if_no_sox def effect_names() -> List[str]: """Gets list of valid sox effect names @@ -50,7 +52,6 @@ def effect_names() -> List[str]: return list(list_effects().keys()) -@torchaudio._extension.fail_if_no_sox def apply_effects_tensor( tensor: torch.Tensor, sample_rate: int, @@ -152,10 +153,9 @@ def apply_effects_tensor( >>> waveform, sample_rate = transform(waveform, input_sample_rate) >>> assert sample_rate == 8000 """ - return torch.ops.torchaudio.sox_effects_apply_effects_tensor(tensor, sample_rate, effects, channels_first) + return sox_ext.apply_effects_tensor(tensor, sample_rate, effects, channels_first) -@torchaudio._extension.fail_if_no_sox def apply_effects_file( path: str, effects: List[List[str]], @@ -269,4 +269,4 @@ def apply_effects_file( "Please use torchaudio.io.AudioEffector." ) path = os.fspath(path) - return torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format) + return sox_ext.apply_effects_file(path, effects, normalize, channels_first, format) diff --git a/src/torchaudio/utils/sox_utils.py b/src/torchaudio/utils/sox_utils.py index a978e8d1db..5212b77ea9 100644 --- a/src/torchaudio/utils/sox_utils.py +++ b/src/torchaudio/utils/sox_utils.py @@ -6,8 +6,9 @@ import torchaudio +sox_ext = torchaudio._extension.lazy_import_sox_ext() + -@torchaudio._extension.fail_if_no_sox def set_seed(seed: int): """Set libsox's PRNG @@ -17,10 +18,9 @@ def set_seed(seed: int): See Also: http://sox.sourceforge.net/sox.html """ - torchaudio.lib._torchaudio_sox.set_seed(seed) + sox_ext.set_seed(seed) -@torchaudio._extension.fail_if_no_sox def set_verbosity(verbosity: int): """Set libsox's verbosity @@ -35,10 +35,9 @@ def set_verbosity(verbosity: int): See Also: http://sox.sourceforge.net/sox.html """ - torchaudio.lib._torchaudio_sox.set_verbosity(verbosity) + sox_ext.set_verbosity(verbosity) -@torchaudio._extension.fail_if_no_sox def set_buffer_size(buffer_size: int): """Set buffer size for sox effect chain @@ -48,10 +47,9 @@ def set_buffer_size(buffer_size: int): See Also: http://sox.sourceforge.net/sox.html """ - torchaudio.lib._torchaudio_sox.set_buffer_size(buffer_size) + sox_ext.set_buffer_size(buffer_size) -@torchaudio._extension.fail_if_no_sox def set_use_threads(use_threads: bool): """Set multithread option for sox effect chain @@ -62,44 +60,40 @@ def set_use_threads(use_threads: bool): See Also: http://sox.sourceforge.net/sox.html """ - torchaudio.lib._torchaudio_sox.set_use_threads(use_threads) + sox_ext.set_use_threads(use_threads) -@torchaudio._extension.fail_if_no_sox def list_effects() -> Dict[str, str]: """List the available sox effect names Returns: Dict[str, str]: Mapping from ``effect name`` to ``usage`` """ - return dict(torchaudio.lib._torchaudio_sox.list_effects()) + return dict(sox_ext.list_effects()) -@torchaudio._extension.fail_if_no_sox def list_read_formats() -> List[str]: """List the supported audio formats for read Returns: List[str]: List of supported audio formats """ - return torchaudio.lib._torchaudio_sox.list_read_formats() + return sox_ext.list_read_formats() -@torchaudio._extension.fail_if_no_sox def list_write_formats() -> List[str]: """List the supported audio formats for write Returns: List[str]: List of supported audio formats """ - return torchaudio.lib._torchaudio_sox.list_write_formats() + return sox_ext.list_write_formats() -@torchaudio._extension.fail_if_no_sox def get_buffer_size() -> int: """Get buffer size for sox effect chain Returns: int: size in bytes of buffers used for processing audio. """ - return torchaudio.lib._torchaudio_sox.get_buffer_size() + return sox_ext.get_buffer_size() diff --git a/test/torchaudio_unittest/backend/dispatcher/sox/load_test.py b/test/torchaudio_unittest/backend/dispatcher/sox/load_test.py index 4ceef268f1..efa5808b58 100644 --- a/test/torchaudio_unittest/backend/dispatcher/sox/load_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/sox/load_test.py @@ -293,6 +293,8 @@ def test_amr_nb(self): class TestLoadParams(TempDirMixin, PytorchTestCase): """Test the correctness of frame parameters of `sox_io_backend.load`""" + _load = partial(get_load_func(), backend="sox") + def _test(self, func, frame_offset, num_frames, channels_first, normalize): original = get_wav_data("int16", num_channels=2, normalize=False) path = self.get_temp_path("test.wav") @@ -316,7 +318,7 @@ def _test(self, func, frame_offset, num_frames, channels_first, normalize): def test_sox(self, frame_offset, num_frames, channels_first, normalize): """The combination of properly changes the output tensor""" - self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize) + self._test(self._load, frame_offset, num_frames, channels_first, normalize) @skipIfNoSox diff --git a/test/torchaudio_unittest/backend/sox_io/load_test.py b/test/torchaudio_unittest/backend/sox_io/load_test.py index 4cf2ff6a4b..89e456efb2 100644 --- a/test/torchaudio_unittest/backend/sox_io/load_test.py +++ b/test/torchaudio_unittest/backend/sox_io/load_test.py @@ -313,7 +313,7 @@ def _test(self, func, frame_offset, num_frames, channels_first, normalize): def test_sox(self, frame_offset, num_frames, channels_first, normalize): """The combination of properly changes the output tensor""" - self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize) + self._test(sox_io_backend.load, frame_offset, num_frames, channels_first, normalize) @skipIfNoSox diff --git a/test/torchaudio_unittest/common_utils/case_utils.py b/test/torchaudio_unittest/common_utils/case_utils.py index 243102b238..2b139e2697 100644 --- a/test/torchaudio_unittest/common_utils/case_utils.py +++ b/test/torchaudio_unittest/common_utils/case_utils.py @@ -112,6 +112,7 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase): _IS_FFMPEG_AVAILABLE = torchaudio._extension.lazy_import_ffmpeg_ext().is_available() +_IS_SOX_AVAILABLE = torchaudio._extension.lazy_import_sox_ext().is_available() _IS_CTC_DECODER_AVAILABLE = None _IS_CUDA_CTC_DECODER_AVAILABLE = None @@ -209,7 +210,7 @@ def skipIfNoModule(module, display_name=None): key="CUDA_SMALL_MEMORY", ) skipIfNoSox = _skipIf( - not torchaudio._extension._SOX_INITIALIZED, + not _IS_SOX_AVAILABLE, reason="Sox features are not available.", key="NO_SOX", ) @@ -217,7 +218,7 @@ def skipIfNoModule(module, display_name=None): def skipIfNoSoxDecoder(ext): return _skipIf( - not torchaudio._extension._SOX_INITIALIZED or ext not in torchaudio.utils.sox_utils.list_read_formats(), + not _IS_SOX_AVAILABLE or ext not in torchaudio.utils.sox_utils.list_read_formats(), f'sox does not handle "{ext}" for read.', key="NO_SOX_DECODER", ) @@ -225,7 +226,7 @@ def skipIfNoSoxDecoder(ext): def skipIfNoSoxEncoder(ext): return _skipIf( - not torchaudio._extension._SOX_INITIALIZED or ext not in torchaudio.utils.sox_utils.list_write_formats(), + not _IS_SOX_AVAILABLE or ext not in torchaudio.utils.sox_utils.list_write_formats(), f'sox does not handle "{ext}" for write.', key="NO_SOX_ENCODER", )