Skip to content

Commit

Permalink
Simplify the logic to initialize sox
Browse files Browse the repository at this point in the history
Differential Revision: D50197331

Pull Request resolved: pytorch#3654
  • Loading branch information
moto-meta authored Oct 13, 2023
1 parent f62367a commit dde08ba
Show file tree
Hide file tree
Showing 14 changed files with 107 additions and 131 deletions.
13 changes: 0 additions & 13 deletions src/libtorchaudio/sox/effects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,4 @@ auto apply_effects_file(
return std::tuple<torch::Tensor, int64_t>(
tensor, chain.getOutputSampleRate());
}

namespace {

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::sox_effects_initialize_sox_effects",
&initialize_sox_effects);
m.def("torchaudio::sox_effects_shutdown_sox_effects", &shutdown_sox_effects);
m.def("torchaudio::sox_effects_apply_effects_tensor", &apply_effects_tensor);
m.def("torchaudio::sox_effects_apply_effects_file", &apply_effects_file);
}

} // namespace
} // namespace torchaudio::sox
7 changes: 0 additions & 7 deletions src/libtorchaudio/sox/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,4 @@ void save_audio_file(
chain.addOutputFile(sf);
chain.run();
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::sox_io_get_info", &get_info_file);
m.def("torchaudio::sox_io_load_audio_file", &load_audio_file);
m.def("torchaudio::sox_io_save_audio_file", &save_audio_file);
}

} // namespace torchaudio::sox
12 changes: 12 additions & 0 deletions src/libtorchaudio/sox/pybind/pybind.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
#include <libtorchaudio/sox/effects.h>
#include <libtorchaudio/sox/io.h>
#include <libtorchaudio/sox/utils.h>
#include <torch/extension.h>

namespace torchaudio {
namespace sox {
namespace {

TORCH_LIBRARY(torchaudio_sox, m) {
m.def("torchaudio_sox::get_info", &get_info_file);
m.def("torchaudio_sox::load_audio_file", &load_audio_file);
m.def("torchaudio_sox::save_audio_file", &save_audio_file);
m.def("torchaudio_sox::initialize_sox_effects", &initialize_sox_effects);
m.def("torchaudio_sox::shutdown_sox_effects", &shutdown_sox_effects);
m.def("torchaudio_sox::apply_effects_tensor", &apply_effects_tensor);
m.def("torchaudio_sox::apply_effects_file", &apply_effects_file);
}

PYBIND11_MODULE(_torchaudio_sox, m) {
m.def("set_seed", &set_seed, "Set random seed.");
m.def("set_verbosity", &set_verbosity, "Set verbosity.");
Expand Down
11 changes: 6 additions & 5 deletions src/torchaudio/_backend/sox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from typing import BinaryIO, Optional, Tuple, Union

import torch
import torchaudio

from .backend import Backend
from .common import AudioMetaData

sox_ext = torchaudio._extension.lazy_import_sox_ext()


class SoXBackend(Backend):
@staticmethod
Expand All @@ -16,7 +19,7 @@ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_s
"Please use an alternative backend that does support reading from file-like objects, e.g. FFmpeg.",
)
else:
sinfo = torch.ops.torchaudio.sox_io_get_info(uri, format)
sinfo = sox_ext.get_info(uri, format)
if sinfo:
return AudioMetaData(*sinfo)
else:
Expand All @@ -38,9 +41,7 @@ def load(
"Please use an alternative backend that does support loading from file-like objects, e.g. FFmpeg.",
)
else:
ret = torch.ops.torchaudio.sox_io_load_audio_file(
uri, frame_offset, num_frames, normalize, channels_first, format
)
ret = sox_ext.load_audio_file(uri, frame_offset, num_frames, normalize, channels_first, format)
if not ret:
raise RuntimeError(f"Failed to load audio from {uri}.")
return ret
Expand All @@ -62,7 +63,7 @@ def save(
"Please use an alternative backend that does support writing to file-like objects, e.g. FFmpeg.",
)
else:
torch.ops.torchaudio.sox_io_save_audio_file(
sox_ext.save_audio_file(
uri,
src,
sample_rate,
Expand Down
4 changes: 2 additions & 2 deletions src/torchaudio/_backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from torchaudio._extension import _SOX_INITIALIZED, lazy_import_ffmpeg_ext
from torchaudio._extension import lazy_import_ffmpeg_ext, lazy_import_sox_ext

from . import soundfile_backend

Expand All @@ -20,7 +20,7 @@ def get_available_backends() -> Dict[str, Type[Backend]]:
backend_specs: Dict[str, Type[Backend]] = {}
if lazy_import_ffmpeg_ext().is_available():
backend_specs["ffmpeg"] = FFmpegBackend
if _SOX_INITIALIZED:
if lazy_import_sox_ext().is_available():
backend_specs["sox"] = SoXBackend
if soundfile_backend._IS_SOUNDFILE_AVAILABLE:
backend_specs["soundfile"] = SoundfileBackend
Expand Down
55 changes: 14 additions & 41 deletions src/torchaudio/_extension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,9 @@
import os
import sys

from torchaudio._internal.module_utils import eval_env, fail_with_message, is_module_available, no_op

from .utils import (
_check_cuda_version,
_fail_since_no_sox,
_init_dll_path,
_init_ffmpeg,
_init_sox,
_LazyImporter,
_load_lib,
)
from torchaudio._internal.module_utils import fail_with_message, is_module_available, no_op

from .utils import _check_cuda_version, _init_dll_path, _init_ffmpeg, _init_sox, _LazyImporter, _load_lib

_LG = logging.getLogger(__name__)

Expand All @@ -22,11 +14,10 @@
# Builder uses it for debugging purpose, so we export it.
# https://github.com/pytorch/builder/blob/e2e4542b8eb0bdf491214451a1a4128bd606cce2/test/smoke_test/smoke_test.py#L80
__all__ = [
"fail_if_no_sox",
"_check_cuda_version",
"_IS_TORCHAUDIO_EXT_AVAILABLE",
"_IS_RIR_AVAILABLE",
"_SOX_INITIALIZED",
"lazy_import_sox_ext",
"lazy_import_ffmpeg_ext",
]

Expand Down Expand Up @@ -54,34 +45,16 @@
_IS_ALIGN_AVAILABLE = torchaudio.lib._torchaudio.is_align_available()


# Initialize libsox-related features
_SOX_INITIALIZED = False
_USE_SOX = False if os.name == "nt" else eval_env("TORCHAUDIO_USE_SOX", True)
_SOX_MODULE_AVAILABLE = is_module_available("torchaudio.lib._torchaudio_sox")
if _USE_SOX and _SOX_MODULE_AVAILABLE:
try:
_init_sox()
_SOX_INITIALIZED = True
except Exception:
# The initialization of sox extension will fail if supported sox
# libraries are not found in the system.
# Since the rest of the torchaudio works without it, we do not report the
# error here.
# The error will be raised when user code attempts to use these features.
_LG.debug("Failed to initialize sox extension", exc_info=True)


if os.name == "nt":
fail_if_no_sox = fail_with_message("requires sox extension, which is not supported on Windows.")
elif not _USE_SOX:
fail_if_no_sox = fail_with_message("requires sox extension, but it is disabled. (TORCHAUDIO_USE_SOX=0)")
elif not _SOX_MODULE_AVAILABLE:
fail_if_no_sox = fail_with_message(
"requires sox extension, but TorchAudio is not compiled with it. "
"Please build TorchAudio with libsox support. (BUILD_SOX=1)"
)
else:
fail_if_no_sox = no_op if _SOX_INITIALIZED else _fail_since_no_sox
_SOX_EXT = None


def lazy_import_sox_ext():
"""Load SoX integration based on availability in lazy manner"""

global _SOX_EXT
if _SOX_EXT is None:
_SOX_EXT = _LazyImporter("_torchaudio_sox", _init_sox)
return _SOX_EXT


_FFMPEG_EXT = None
Expand Down
64 changes: 39 additions & 25 deletions src/torchaudio/_extension/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import logging
import os
import types
from functools import wraps
from pathlib import Path

import torch
from torchaudio._internal.module_utils import eval_env

_LG = logging.getLogger(__name__)
_LIB_DIR = Path(__file__).parent.parent / "lib"
Expand Down Expand Up @@ -62,16 +62,49 @@ def _load_lib(lib: str) -> bool:
return True


def _init_sox():
def _import_sox_ext():
if os.name == "nt":
raise RuntimeError("sox extension is not supported on Windows")
if not eval_env("TORCHAUDIO_USE_SOX", True):
raise RuntimeError("sox extension is disabled. (TORCHAUDIO_USE_SOX=0)")

ext = "torchaudio.lib._torchaudio_sox"

if not importlib.util.find_spec(ext):
raise RuntimeError(
# fmt: off
"TorchAudio is not built with sox extension. "
"Please build TorchAudio with libsox support. (BUILD_SOX=1)"
# fmt: on
)

_load_lib("libtorchaudio_sox")
import torchaudio.lib._torchaudio_sox # noqa
return importlib.import_module(ext)

torchaudio.lib._torchaudio_sox.set_verbosity(0)

def _init_sox():
ext = _import_sox_ext()
ext.set_verbosity(0)

import atexit

torch.ops.torchaudio.sox_effects_initialize_sox_effects()
atexit.register(torch.ops.torchaudio.sox_effects_shutdown_sox_effects)
torch.ops.torchaudio_sox.initialize_sox_effects()
atexit.register(torch.ops.torchaudio_sox.shutdown_sox_effects)

# Bundle functions registered with TORCH_LIBRARY into extension
# so that they can also be accessed in the same (lazy) manner
# from the extension.
keys = [
"get_info",
"load_audio_file",
"save_audio_file",
"apply_effects_tensor",
"apply_effects_file",
]
for key in keys:
setattr(ext, key, getattr(torch.ops.torchaudio_sox, key))

return ext


_FFMPEG_VERS = ["6", "5", "4", ""]
Expand Down Expand Up @@ -197,22 +230,3 @@ def _check_cuda_version():
"Please install the TorchAudio version that matches your PyTorch version."
)
return version


def _fail_since_no_sox(func):
@wraps(func)
def wrapped(*_args, **_kwargs):
try:
# Note:
# We run _init_sox again just to show users the stacktrace.
# _init_sox would not succeed here.
_init_sox()
except Exception as err:
raise RuntimeError(
f"{func.__name__} requires sox extension which is not available. "
"Please refer to the stacktrace above for how to resolve this."
) from err
# This should not happen in normal execution, but just in case.
return func(*_args, **_kwargs)

return wrapped
13 changes: 5 additions & 8 deletions src/torchaudio/backend/_sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import torchaudio
from torchaudio import AudioMetaData

sox_ext = torchaudio._extension.lazy_import_sox_ext()


@torchaudio._extension.fail_if_no_sox
def info(
filepath: str,
format: Optional[str] = None,
Expand All @@ -29,11 +30,10 @@ def info(
if hasattr(filepath, "read"):
raise RuntimeError("sox_io backend does not support file-like object.")
filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
sinfo = sox_ext.get_info(filepath, format)
return AudioMetaData(*sinfo)


@torchaudio._extension.fail_if_no_sox
def load(
filepath: str,
frame_offset: int = 0,
Expand Down Expand Up @@ -123,12 +123,9 @@ def load(
if hasattr(filepath, "read"):
raise RuntimeError("sox_io backend does not support file-like object.")
filepath = os.fspath(filepath)
return torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
return sox_ext.load_audio_file(filepath, frame_offset, num_frames, normalize, channels_first, format)


@torchaudio._extension.fail_if_no_sox
def save(
filepath: str,
src: torch.Tensor,
Expand Down Expand Up @@ -285,7 +282,7 @@ def save(
if hasattr(filepath, "write"):
raise RuntimeError("sox_io backend does not handle file-like object.")
filepath = os.fspath(filepath)
torch.ops.torchaudio.sox_io_save_audio_file(
sox_ext.save_audio_file(
filepath,
src,
sample_rate,
Expand Down
10 changes: 6 additions & 4 deletions src/torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,6 @@ def spectral_centroid(
return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)


@torchaudio._extension.fail_if_no_sox
@deprecated("Please migrate to :py:class:`torchaudio.io.AudioEffector`.", remove=False)
def apply_codec(
waveform: Tensor,
Expand Down Expand Up @@ -1329,11 +1328,13 @@ def apply_codec(
Tensor: Resulting Tensor.
If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
"""
from torchaudio.backend import _sox_io_backend

with tempfile.NamedTemporaryFile() as f:
torchaudio.backend.sox_io_backend.save(
torchaudio.backend._sox_io_backend.save(
f.name, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
augmented, sr = torchaudio.backend.sox_io_backend.load(f.name, channels_first=channels_first, format=format)
augmented, sr = _sox_io_backend.load(f.name, channels_first=channels_first, format=format)
if sr != sample_rate:
augmented = resample(augmented, sr, sample_rate)
return augmented
Expand Down Expand Up @@ -1371,7 +1372,8 @@ def _get_sinc_resample_kernel(
warnings.warn(
f'"{resampling_method}" resampling method name is being deprecated and replaced by '
f'"{method_map[resampling_method]}" in the next release. '
"The default behavior remains unchanged."
"The default behavior remains unchanged.",
stacklevel=3,
)
elif resampling_method not in ["sinc_interp_hann", "sinc_interp_kaiser"]:
raise ValueError("Invalid resampling method: {}".format(resampling_method))
Expand Down
Loading

0 comments on commit dde08ba

Please sign in to comment.