From 59bc3e68e6affff027ad047bc8967404635f3206 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 21 Sep 2023 10:44:05 -0700 Subject: [PATCH] WIP: Add Ray Tracing (#3604) Summary: Revamped version of https://github.com/pytorch/audio/pull/3234 (which was also revamp of https://github.com/pytorch/audio/pull/2850) Differential Revision: D49197174 Pulled By: mthrok --- docs/source/prototype.functional.rst | 1 + .../functional/functional_test_impl.py | 223 ++++++++++++++++++ .../pyroomacoustics_compatibility_test.py | 78 ++++++ .../torchscript_consistency_test_impl.py | 40 ++++ torchaudio/prototype/functional/__init__.py | 3 +- torchaudio/prototype/functional/_rir.py | 100 ++++++++ 6 files changed, 444 insertions(+), 1 deletion(-) diff --git a/docs/source/prototype.functional.rst b/docs/source/prototype.functional.rst index 72f390c71ac..f0527b4c7ae 100644 --- a/docs/source/prototype.functional.rst +++ b/docs/source/prototype.functional.rst @@ -35,4 +35,5 @@ Room Impulse Response Simulation :toctree: generated :nosignatures: + ray_tracing simulate_rir_ism diff --git a/test/torchaudio_unittest/prototype/functional/functional_test_impl.py b/test/torchaudio_unittest/prototype/functional/functional_test_impl.py index 61e6d869337..81d6cbbb571 100644 --- a/test/torchaudio_unittest/prototype/functional/functional_test_impl.py +++ b/test/torchaudio_unittest/prototype/functional/functional_test_impl.py @@ -460,3 +460,226 @@ def _debug_plot(): except AssertionError: _debug_plot() raise + + @parameterized.expand( + [ + (0.1, 0.2, (2, 1, 2500)), # both float + # Per-wall + (torch.rand(4), 0.2, (2, 1, 2500)), + (0.1, torch.rand(4), (2, 1, 2500)), + (torch.rand(4), torch.rand(4), (2, 1, 2500)), + # Per-band and per-wall + (torch.rand(6, 4), 0.2, (2, 6, 2500)), + (0.1, torch.rand(6, 4), (2, 6, 2500)), + (torch.rand(6, 4), torch.rand(6, 4), (2, 6, 2500)), + ] + ) + def test_ray_tracing_output_shape(self, absorption, scattering, expected_shape): + room_dim = torch.tensor([20, 25], dtype=self.dtype) + mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype) + source = torch.tensor([7, 6], dtype=self.dtype) + num_rays = 100 + + hist = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + + assert hist.shape == expected_shape + + def test_ray_tracing_input_errors(self): + with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"): + F.ray_tracing( + room=torch.tensor([[4, 5]]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[3, 4]]), num_rays=10 + ) + with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"): + F.ray_tracing( + room=torch.tensor([4, 5, 4, 5]), + source=torch.tensor([0, 0]), + mic_array=torch.tensor([[3, 4]]), + num_rays=10, + ) + with self.assertRaisesRegex(ValueError, r"mic_array must be 1D tensor of shape \(D,\), or 2D tensor"): + F.ray_tracing( + room=torch.tensor([4, 5]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[[3, 4]]]), num_rays=10 + ) + with self.assertRaisesRegex(ValueError, "room must be of float32 or float64 dtype"): + F.ray_tracing( + room=torch.tensor([4, 5]).to(torch.int), + source=torch.tensor([0, 0]), + mic_array=torch.tensor([3, 4]), + num_rays=10, + ) + with self.assertRaisesRegex(ValueError, "dtype of room, source and mic_array must be the same"): + F.ray_tracing( + room=torch.tensor([4, 5]).to(torch.float64), + source=torch.tensor([0, 0]).to(torch.float32), + mic_array=torch.tensor([3, 4]), + 
num_rays=10, + ) + with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"): + F.ray_tracing( + room=torch.tensor([4, 5, 10], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + ) + with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + ) + with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"): + F.ray_tracing( + room=torch.tensor([4, 5, 10], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + ) + with self.assertRaisesRegex(ValueError, "time_thres=10 must be at least greater than hist_bin_size=11"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + time_thres=10, + hist_bin_size=11, + ) + with self.assertRaisesRegex(ValueError, "The shape of absorption must be"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(5, dtype=torch.float), + ) + with self.assertRaisesRegex(ValueError, "The shape of scattering must be"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + scattering=torch.rand(5, 5, dtype=torch.float), + ) + with self.assertRaisesRegex(ValueError, "The shape of absorption must be"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(5, 5, dtype=torch.float), + ) + with self.assertRaisesRegex(ValueError, "The shape of scattering must be"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + scattering=torch.rand(5, dtype=torch.float), + ) + with self.assertRaisesRegex( + ValueError, "absorption and scattering must have the same number of bands and walls" + ): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(6, 4, dtype=torch.float), + scattering=torch.rand(5, 4, dtype=torch.float), + ) + + # Make sure passing different shapes for absorption or scattering doesn't raise an error + # float and tensor + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=0.1, + scattering=torch.rand(5, 4, dtype=torch.float), + ) + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(5, 4, dtype=torch.float), + scattering=0.1, + ) + # per-wall only and per-band + per-wall + 
F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(4, dtype=torch.float), + scattering=torch.rand(6, 4, dtype=torch.float), + ) + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(6, 4, dtype=torch.float), + scattering=torch.rand(4, dtype=torch.float), + ) + + def test_ray_tracing_per_band_per_wall_absorption(self): + """Check that when the value of absorption and scattering are the same + across walls and frequency bands, the output histograms are: + - all equal across frequency bands + - equal to simply passing a float value instead of a (num_bands, D) or + (D,) tensor. + """ + + room_dim = torch.tensor([20, 25], dtype=self.dtype) + mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype) + source = torch.tensor([7, 6], dtype=self.dtype) + num_rays = 1_000 + ABS, SCAT = 0.1, 0.2 + + absorption = torch.full(fill_value=ABS, size=(6, 4), dtype=self.dtype) + scattering = torch.full(fill_value=SCAT, size=(6, 4), dtype=self.dtype) + hist_per_band_per_wall = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + absorption = torch.full(fill_value=ABS, size=(4,), dtype=self.dtype) + scattering = torch.full(fill_value=SCAT, size=(4,), dtype=self.dtype) + hist_per_wall = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + + absorption = ABS + scattering = SCAT + hist_single = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + assert hist_per_band_per_wall.shape == (2, 6, 2500) + assert hist_per_wall.shape == (2, 1, 2500) + assert hist_single.shape == (2, 1, 2500) + torch.testing.assert_close(hist_single, hist_per_wall) + + hist_single = hist_single.expand(2, 6, 2500) + torch.testing.assert_close(hist_single, hist_per_band_per_wall) diff --git a/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py b/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py index 59da1a49ac2..29f2caa5b62 100644 --- a/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py +++ b/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py @@ -1,3 +1,4 @@ +import numpy as np import torch import torchaudio.prototype.functional as F @@ -91,3 +92,80 @@ def test_simulate_rir_ism_multi_band(self, channel): expected[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0]) actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, absorption) self.assertEqual(expected, actual, atol=1e-3, rtol=1e-3) + + @parameterized.expand( + [ + ([20, 25], [2, 2], [[8, 8], [7, 6]], 10_000), # 2D with 2 mics + ([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic + ] + ) + def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays): + + walls = ["west", "east", "south", "north"] + if len(room_dim) == 3: + walls += ["floor", "ceiling"] + num_walls = len(walls) + num_bands = 6 # Note: in ray tracing, we don't need to restrict the number of bands to 7 + + absorption = 
torch.rand(num_bands, num_walls, dtype=self.dtype) + scattering = torch.rand(num_bands, num_walls, dtype=self.dtype) + energy_thres = 1e-7 + time_thres = 10.0 + hist_bin_size = 0.004 + mic_radius = 0.5 + sound_speed = 343.0 + + room_dim = torch.tensor(room_dim, dtype=self.dtype) + source = torch.tensor(source, dtype=self.dtype) + mic_array = torch.tensor(mic_array, dtype=self.dtype) + + room = pra.ShoeBox( + room_dim.tolist(), + ray_tracing=True, + materials={ + walls[i]: pra.Material( + energy_absorption={ + "coeffs": absorption[:, i].reshape(-1).detach().numpy(), + "center_freqs": 125 * 2 ** np.arange(num_bands), + }, + scattering={ + "coeffs": scattering[:, i].reshape(-1).detach().numpy(), + "center_freqs": 125 * 2 ** np.arange(num_bands), + }, + ) + for i in range(num_walls) + }, + air_absorption=False, + max_order=0, # Make sure PRA doesn't use the hybrid method (we just want ray tracing) + ) + room.add_microphone_array(mic_array.T.tolist()) + room.add_source(source.tolist()) + room.set_ray_tracing( + n_rays=num_rays, + energy_thres=energy_thres, + time_thres=time_thres, + hist_bin_size=hist_bin_size, + receiver_radius=mic_radius, + ) + room.set_sound_speed(sound_speed) + + room.compute_rir() + hist_pra = torch.tensor(np.array(room.rt_histograms))[:, 0, 0] + + hist = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + sound_speed=sound_speed, + mic_radius=mic_radius, + energy_thres=energy_thres, + time_thres=time_thres, + hist_bin_size=hist_bin_size, + ) + + assert hist.ndim == 3 + assert hist.shape == hist_pra.shape + self.assertEqual(hist.to(torch.float32), hist_pra) diff --git a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py index 60806d9afbc..2f769fcb0ae 100644 --- a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py +++ b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py @@ -112,3 +112,43 @@ def test_simulate_rir_ism_multi_band(self, channel): F.simulate_rir_ism, (room_dim, source, mic_array, max_order, absorption, None, 81, center_frequency, 343.0, 16000.0), ) + + @parameterized.expand( + [ + ([20, 25], [2, 2], [[8, 8], [7, 6]], 1_000), # 2D with 2 mics + ([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500), # 3D with 1 mic + ] + ) + def test_ray_tracing(self, room_dim, source, mic_array, num_rays): + num_walls = 4 if len(room_dim) == 2 else 6 + num_bands = 3 + + absorption = torch.rand(num_bands, num_walls, dtype=torch.float32) + scattering = torch.rand(num_bands, num_walls, dtype=torch.float32) + + energy_thres = 1e-7 + time_thres = 10.0 + hist_bin_size = 0.004 + mic_radius = 0.5 + sound_speed = 343.0 + + room_dim = torch.tensor(room_dim, dtype=self.dtype) + source = torch.tensor(source, dtype=self.dtype) + mic_array = torch.tensor(mic_array, dtype=self.dtype) + + self._assert_consistency( + F.ray_tracing, + ( + room_dim, + source, + mic_array, + num_rays, + absorption, + scattering, + mic_radius, + sound_speed, + energy_thres, + time_thres, + hist_bin_size, + ), + ) diff --git a/torchaudio/prototype/functional/__init__.py b/torchaudio/prototype/functional/__init__.py index 3c08461b70f..20bc181731e 100644 --- a/torchaudio/prototype/functional/__init__.py +++ b/torchaudio/prototype/functional/__init__.py @@ -7,7 +7,7 @@ oscillator_bank, sinc_impulse_response, ) -from ._rir import simulate_rir_ism +from ._rir 
import ray_tracing, simulate_rir_ism
 from .functional import barkscale_fbanks, chroma_filterbank
 
 
@@ -20,6 +20,7 @@
     "filter_waveform",
     "frequency_impulse_response",
     "oscillator_bank",
+    "ray_tracing",
     "sinc_impulse_response",
     "simulate_rir_ism",
 ]
diff --git a/torchaudio/prototype/functional/_rir.py b/torchaudio/prototype/functional/_rir.py
index 1f44d8c47a8..ded0ab44f9b 100644
--- a/torchaudio/prototype/functional/_rir.py
+++ b/torchaudio/prototype/functional/_rir.py
@@ -272,3 +272,103 @@ def simulate_rir_ism(
             rir = rir[..., :output_length]
 
     return rir
+
+
+def ray_tracing(
+    room: torch.Tensor,
+    source: torch.Tensor,
+    mic_array: torch.Tensor,
+    num_rays: int,
+    absorption: Union[float, torch.Tensor] = 0.0,
+    scattering: Union[float, torch.Tensor] = 0.0,
+    mic_radius: float = 0.5,
+    sound_speed: float = 343.0,
+    energy_thres: float = 1e-7,
+    time_thres: float = 10.0,
+    hist_bin_size: float = 0.004,
+) -> torch.Tensor:
+    r"""Compute energy histogram via ray tracing.
+
+    The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.
+
+    ``num_rays`` rays are cast uniformly in all directions from the source; when a ray intersects a wall,
+    it is reflected and part of its energy is absorbed. It is also scattered (sent directly to the microphone(s))
+    according to the ``scattering`` coefficient. When a ray is close to the microphone, its current energy is
+    recorded in the output histogram for that given time slot.
+
+    .. devices:: CPU
+
+    .. properties:: TorchScript
+
+    Args:
+        room (torch.Tensor): Room coordinates. The shape of `room` must be `(D,)`, where ``D`` is 2 for
+            a 2D room and 3 for a 3D room.
+        source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(D,)`.
+        mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, D)`.
+        num_rays (int): The number of rays cast from the source.
+        absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials.
+            (Default: ``0.0``).
+            If a single ``float`` is given, the same absorption coefficient is used for all walls and
+            all frequencies.
+            If ``absorption`` is a 1D Tensor, the shape must be `(4,)` if the room is a 2D room,
+            representing absorption coefficients of ``"west"``, ``"east"``, ``"south"``, and
+            ``"north"`` walls, respectively.
+            Or the shape must be `(6,)` if the room is a 3D room, representing absorption coefficients
+            of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively.
+            If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 4)` if the room is a 2D room,
+            or `(num_bands, 6)` if the room is a 3D room. ``num_bands`` is the number of frequency bands
+            (usually 7), but you can choose other values.
+        scattering (float or torch.Tensor, optional): The scattering coefficients of wall materials.
+            (Default: ``0.0``). The shape and type of this parameter are the same as for ``absorption``.
+        mic_radius (float, optional): The radius of the microphone in meters. (Default: ``0.5``)
+        sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``)
+        energy_thres (float, optional): The energy level below which we stop tracing a ray. (Default: ``1e-7``).
+            The initial energy of each ray is ``2 / num_rays``.
+        time_thres (float, optional): The maximum duration (in seconds) for which rays are traced. (Default: ``10.0``)
+        hist_bin_size (float, optional): The size (in seconds) of each bin in the output histogram. (Default: ``0.004``)
+    Returns:
+        (torch.Tensor): The 3D histogram(s) where the energy of the traced ray is recorded. Each bin corresponds
+            to a given time slot. The shape is `(channel, num_bands, num_bins)`, where
+            ``num_bins = ceil(time_thres / hist_bin_size)``. If both ``absorption`` and ``scattering``
+            are floats, then ``num_bands == 1``.
+    """
+    if time_thres < hist_bin_size:
+        raise ValueError(f"time_thres={time_thres} must be at least greater than hist_bin_size={hist_bin_size}.")
+    dim: int
+    if room.numel() == 2:
+        dim = 2
+    elif room.numel() == 3:
+        dim = 3
+    else:
+        raise ValueError(f"`room` must be 1D tensor with 2 or 3 elements. Found: {room.shape}")
+
+    _validate_inputs(dim, room, source, mic_array)
+    absorption = _adjust_coeff(dim, absorption, "absorption")
+    if scattering is not None:
+        scattering = _adjust_coeff(dim, scattering, "scattering")
+
+    # Bring absorption and scattering to the same shape
+    if absorption.shape[0] == 1 and scattering.shape[0] > 1:
+        absorption = absorption.expand(scattering.shape)
+    if scattering.shape[0] == 1 and absorption.shape[0] > 1:
+        scattering = scattering.expand(absorption.shape)
+    if absorption.shape != scattering.shape:
+        raise ValueError(
+            "absorption and scattering must have the same number of bands and walls. "
+            f"Inferred shapes are {absorption.shape} and {scattering.shape}"
+        )
+
+    histograms = torch.ops.torchaudio.ray_tracing(
+        room,
+        source,
+        mic_array,
+        num_rays,
+        absorption,
+        scattering,
+        mic_radius,
+        sound_speed,
+        energy_thres,
+        time_thres,
+        hist_bin_size,
+    )
+
+    return histograms
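
Usage sketch (illustrative, not part of the diff above): the snippet below shows how the new `ray_tracing` API is expected to be called once this patch is applied, mirroring the shapes exercised in `functional_test_impl.py`. The numeric values and the `num_rays` choice are placeholders, not recommended settings.

    import torch
    import torchaudio.prototype.functional as F

    # 2D room of 20 m x 25 m (D = 2), which has 4 walls.
    room = torch.tensor([20.0, 25.0])
    # Source position and two microphones, shapes (D,) and (channel, D).
    source = torch.tensor([7.0, 6.0])
    mic_array = torch.tensor([[2.0, 2.0], [8.0, 8.0]])

    # Per-band, per-wall coefficients with shape (num_bands, num_walls) = (6, 4).
    absorption = torch.full((6, 4), 0.1)
    scattering = torch.full((6, 4), 0.2)

    hist = F.ray_tracing(
        room=room,
        source=source,
        mic_array=mic_array,
        num_rays=10_000,
        absorption=absorption,
        scattering=scattering,
    )

    # With the defaults time_thres=10.0 and hist_bin_size=0.004,
    # num_bins = ceil(10.0 / 0.004) = 2500, so hist.shape == (2, 6, 2500).
    print(hist.shape)

Passing plain floats for `absorption` and `scattering` instead collapses the band dimension, giving a histogram of shape `(channel, 1, num_bins)`, as asserted in the shape tests above.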