From 49be5918f6ebe4ae39555aa0d629c7ccb3dbc2bf Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 21 Sep 2023 10:44:05 -0700 Subject: [PATCH] WIP: Add Ray Tracing (#3604) Summary: Revamped version of https://github.com/pytorch/audio/pull/3234 (which was also revamp of https://github.com/pytorch/audio/pull/2850) Differential Revision: D49197174 Pulled By: mthrok --- docs/source/prototype.functional.rst | 1 + .../functional/functional_test_impl.py | 249 ++++++++++++++++++ .../pyroomacoustics_compatibility_test.py | 90 +++++++ .../torchscript_consistency_test_impl.py | 39 +++ torchaudio/csrc/rir/ray_tracing.cpp | 6 +- torchaudio/csrc/rir/wall.h | 5 +- torchaudio/prototype/functional/__init__.py | 3 +- torchaudio/prototype/functional/_rir.py | 117 +++++++- 8 files changed, 500 insertions(+), 10 deletions(-) diff --git a/docs/source/prototype.functional.rst b/docs/source/prototype.functional.rst index 72f390c71ac..f0527b4c7ae 100644 --- a/docs/source/prototype.functional.rst +++ b/docs/source/prototype.functional.rst @@ -35,4 +35,5 @@ Room Impulse Response Simulation :toctree: generated :nosignatures: + ray_tracing simulate_rir_ism diff --git a/test/torchaudio_unittest/prototype/functional/functional_test_impl.py b/test/torchaudio_unittest/prototype/functional/functional_test_impl.py index 61e6d869337..114a2b5cc8d 100644 --- a/test/torchaudio_unittest/prototype/functional/functional_test_impl.py +++ b/test/torchaudio_unittest/prototype/functional/functional_test_impl.py @@ -412,6 +412,255 @@ def test_exp_sigmoid_input_diff(self, linspace_input_values, exp_sigmoid_paramet self.assertEqual(torch_out, torch.tensor(np_out)) + @parameterized.expand( + [ + # both float + (0.1, 0.2, (2, 1, 2500)), + # Per-wall + ((6,), 0.2, (2, 1, 2500)), + (0.1, (6,), (2, 1, 2500)), + ((6,), (6,), (2, 1, 2500)), + # Per-band and per-wall + ((3, 6), 0.2, (2, 3, 2500)), + (0.1, (5, 6), (2, 5, 2500)), + ((7, 6), (7, 6), (2, 7, 2500)), + ] + ) + def test_ray_tracing_output_shape(self, abs_, scat_, expected_shape): + if isinstance(abs_, float): + absorption = abs_ + else: + absorption = torch.rand(abs_, dtype=self.dtype) + if isinstance(scat_, float): + scattering = scat_ + else: + scattering = torch.rand(scat_, dtype=self.dtype) + + room_dim = torch.tensor([3, 4, 5], dtype=self.dtype) + mic_array = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=self.dtype) + source = torch.tensor([1, 2, 3], dtype=self.dtype) + num_rays = 100 + + hist = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + assert hist.shape == expected_shape + + def test_ray_tracing_input_errors(self): + room = torch.tensor([3.0, 4.0, 5.0], dtype=self.dtype) + source = torch.tensor([0.0, 0.0, 0.0], dtype=self.dtype) + mic = torch.tensor([[1.0, 2.0, 3.0]], dtype=self.dtype) + + # baseline. This should not raise + _ = F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10) + + # invlaid room shape + for invalid in ([[4, 5]], [4, 5, 4, 5]): + invalid = torch.tensor(invalid, dtype=self.dtype) + with self.assertRaises(ValueError) as cm: + F.ray_tracing(room=invalid, source=source, mic_array=mic, num_rays=10) + + error = str(cm.exception) + self.assertIn("`room` must be a 1D Tensor with 3 elements.", error) + self.assertIn(str(invalid.shape), error) + + # invalid microphone shape + invalid = torch.tensor([[[3, 4]]], dtype=self.dtype) + with self.assertRaises(ValueError) as cm: + F.ray_tracing(room=room, source=source, mic_array=invalid, num_rays=10) + + error = str(cm.exception) + self.assertIn("`mic_array` must be a 2D Tensor with shape (num_channels, 3).", error) + self.assertIn(str(invalid.shape), error) + + # incompatible dtypes + with self.assertRaises(ValueError) as cm: + F.ray_tracing( + room=room.to(torch.float64), + source=source.to(torch.float32), + mic_array=mic, + num_rays=10, + ) + error = str(cm.exception) + self.assertIn("dtype of `room`, `source` and `mic_array` must match.", error) + self.assertIn("`room` (torch.float64)", error) + self.assertIn("`source` (torch.float32)", error) + self.assertIn("`mic_array` (torch.float32)", error) + + # invalid time configuration + with self.assertRaises(ValueError) as cm: + F.ray_tracing( + room=room, + source=source, + mic_array=mic, + num_rays=10, + time_thres=10, + hist_bin_size=11, + ) + error = str(cm.exception) + self.assertIn("`time_thres` must be greater than `hist_bin_size`.", error) + self.assertIn("hist_bin_size=11", error) + self.assertIn("time_thres=10", error) + + # invalid absorption shape 1D + invalid_abs = torch.tensor([1, 2, 3], dtype=self.dtype) + with self.assertRaises(ValueError) as cm: + F.ray_tracing( + room=room, + source=source, + mic_array=mic, + num_rays=10, + absorption=invalid_abs, + ) + error = str(cm.exception) + self.assertIn("The shape of `absorption` must be (6,) when", error) + self.assertIn(str(invalid_abs.shape), error) + + # invalid absorption shape 2D + invalid_abs = torch.tensor([[1, 2, 3]], dtype=self.dtype) + with self.assertRaises(ValueError) as cm: + F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, absorption=invalid_abs) + error = str(cm.exception) + self.assertIn("The shape of `absorption` must be (NUM_BANDS, 6) when", error) + self.assertIn(str(invalid_abs.shape), error) + + # invalid scattering shape 1D + invalid_scat = torch.tensor([1, 2, 3], dtype=self.dtype) + with self.assertRaises(ValueError) as cm: + F.ray_tracing( + room=room, + source=source, + mic_array=mic, + num_rays=10, + scattering=invalid_scat, + ) + error = str(cm.exception) + self.assertIn("The shape of `scattering` must be (6,) when", error) + self.assertIn(str(invalid_scat.shape), error) + + # invalid scattering shape 2D + invalid_scat = torch.tensor([[1, 2, 3]], dtype=self.dtype) + with self.assertRaises(ValueError) as cm: + F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, scattering=invalid_scat) + error = str(cm.exception) + self.assertIn("The shape of `scattering` must be (NUM_BANDS, 6) when", error) + self.assertIn(str(invalid_scat.shape), error) + + # TODO: Invalid absorption/scattering value + + # incompatible scattering and absorption + abs_ = torch.zeros((7, 6), dtype=self.dtype) + scat = torch.zeros((5, 6), dtype=self.dtype) + with self.assertRaises(ValueError) as cm: + F.ray_tracing( + room=room, + source=source, + mic_array=mic, + num_rays=10, + absorption=abs_, + scattering=scat, + ) + error = str(cm.exception) + self.assertIn( + "`absorption` and `scattering` must be broadcastable to the same number of bands and walls", error + ) + self.assertIn(f"absorption={abs_.shape}", error) + self.assertIn(f"scattering={scat.shape}", error) + + # Make sure passing different shapes for absorption or scattering doesn't raise an error + # float and tensor + F.ray_tracing( + room=room, + source=source, + mic_array=mic, + num_rays=10, + absorption=0.1, + scattering=torch.randn((5, 6), dtype=self.dtype), + ) + F.ray_tracing( + room=room, + source=source, + mic_array=mic, + num_rays=10, + absorption=torch.randn((7, 6), dtype=self.dtype), + scattering=0.1, + ) + # per-wall only and per-band + per-wall + F.ray_tracing( + room=room, + source=source, + mic_array=mic, + num_rays=10, + absorption=torch.rand(6, dtype=self.dtype), + scattering=torch.rand(7, 6, dtype=self.dtype), + ) + F.ray_tracing( + room=room, + source=source, + mic_array=mic, + num_rays=10, + absorption=torch.rand(7, 6, dtype=self.dtype), + scattering=torch.rand(6, dtype=self.dtype), + ) + + def test_ray_tracing_per_band_per_wall_absorption(self): + """Check that when the value of absorption and scattering are the same + across walls and frequency bands, the output histograms are: + - all equal across frequency bands + - equal to simply passing a float value instead of a (num_bands, D) or + (D,) tensor. + """ + + room_dim = torch.tensor([20, 25, 5], dtype=self.dtype) + mic_array = torch.tensor([[2, 2, 0], [8, 8, 0]], dtype=self.dtype) + source = torch.tensor([7, 6, 0], dtype=self.dtype) + num_rays = 1_000 + ABS, SCAT = 0.1, 0.2 + + absorption = torch.full(fill_value=ABS, size=(7, 6), dtype=self.dtype) + scattering = torch.full(fill_value=SCAT, size=(7, 6), dtype=self.dtype) + hist_per_band_per_wall = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + absorption = torch.full(fill_value=ABS, size=(6,), dtype=self.dtype) + scattering = torch.full(fill_value=SCAT, size=(6,), dtype=self.dtype) + hist_per_wall = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + + absorption = ABS + scattering = SCAT + hist_single = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + self.assertEqual(hist_per_band_per_wall.shape, (2, 7, 2500)) + self.assertEqual(hist_per_wall.shape, (2, 1, 2500)) + self.assertEqual(hist_single.shape, (2, 1, 2500)) + torch.testing.assert_close(hist_single, hist_per_wall) + + hist_single = hist_single.expand(hist_per_band_per_wall.shape) + torch.testing.assert_close(hist_single, hist_per_band_per_wall) + class Functional64OnlyTestImpl(TestBaseMixin): @nested_params( diff --git a/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py b/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py index 59da1a49ac2..af2281d213e 100644 --- a/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py +++ b/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py @@ -1,3 +1,5 @@ +import math +import numpy as np import torch import torchaudio.prototype.functional as F @@ -9,6 +11,43 @@ import pyroomacoustics as pra +def _pra_ray_tracing(room_dim, absorption, scattering, num_bands, mic_array, source, num_rays, energy_thres, time_thres, hist_bin_size, mic_radius, sound_speed): + walls = ["west", "east", "south", "north", "floor", "ceiling"] + absorption = absorption.T.tolist() + scattering = scattering.T.tolist() + freqs = 125 * 2 ** np.arange(num_bands) + + room = pra.ShoeBox( + room_dim.tolist(), + ray_tracing=True, + materials={ + wall: pra.Material( + energy_absorption={"coeffs": absorp, "center_freqs": freqs}, + scattering={"coeffs": scat, "center_freqs": freqs}, + ) + for wall, absorp, scat in zip(walls, absorption, scattering) + }, + air_absorption=False, + max_order=0, # Make sure PRA doesn't use the hybrid method (we just want ray tracing) + ) + room.add_microphone_array(mic_array.T.tolist()) + room.add_source(source.tolist()) + room.set_ray_tracing( + n_rays=num_rays, + energy_thres=energy_thres, + time_thres=time_thres, + hist_bin_size=hist_bin_size, + receiver_radius=mic_radius, + ) + room.set_sound_speed(sound_speed) + room.compute_rir() + hist_pra = np.array(room.rt_histograms, dtype=np.double)[:, 0, 0] + + # PRA continues the simulation beyond time threshold, but torchaudio does not. + num_bins = math.ceil(time_thres / hist_bin_size) + return hist_pra[:, :, :num_bins] + + @skipIfNoModule("pyroomacoustics") @skipIfNoRIR class CompatibilityTest(PytorchTestCase): @@ -91,3 +130,54 @@ def test_simulate_rir_ism_multi_band(self, channel): expected[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0]) actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, absorption) self.assertEqual(expected, actual, atol=1e-3, rtol=1e-3) + + @parameterized.expand( + [ + ([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic + ] + ) + def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays): + num_bands = 6 + energy_thres = 1e-7 + time_thres = 10.0 + hist_bin_size = 0.004 + mic_radius = 0.5 + sound_speed = 343.0 + + absorption = torch.rand((num_bands, 6), dtype=self.dtype) + scattering = torch.rand((num_bands, 6), dtype=self.dtype) + room_dim = torch.tensor(room_dim, dtype=self.dtype) + source = torch.tensor(source, dtype=self.dtype) + mic_array = torch.tensor(mic_array, dtype=self.dtype) + + hist_pra = _pra_ray_tracing( + room_dim, + absorption, + scattering, + num_bands, + mic_array, + source, + num_rays, + energy_thres, + time_thres, + hist_bin_size, + mic_radius, + sound_speed) + + hist = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + sound_speed=sound_speed, + mic_radius=mic_radius, + energy_thres=energy_thres, + time_thres=time_thres, + hist_bin_size=hist_bin_size, + ) + + assert hist.ndim == 3 + assert hist.shape == hist_pra.shape + self.assertEqual(hist, hist_pra) diff --git a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py index 60806d9afbc..5b947a83857 100644 --- a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py +++ b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py @@ -112,3 +112,42 @@ def test_simulate_rir_ism_multi_band(self, channel): F.simulate_rir_ism, (room_dim, source, mic_array, max_order, absorption, None, 81, center_frequency, 343.0, 16000.0), ) + + @parameterized.expand( + [ + ([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500), # 3D with 1 mic + ] + ) + def test_ray_tracing(self, room_dim, source, mic_array, num_rays): + num_walls = 4 if len(room_dim) == 2 else 6 + num_bands = 3 + + absorption = torch.rand(num_bands, num_walls, dtype=torch.float32) + scattering = torch.rand(num_bands, num_walls, dtype=torch.float32) + + energy_thres = 1e-7 + time_thres = 10.0 + hist_bin_size = 0.004 + mic_radius = 0.5 + sound_speed = 343.0 + + room_dim = torch.tensor(room_dim, dtype=self.dtype) + source = torch.tensor(source, dtype=self.dtype) + mic_array = torch.tensor(mic_array, dtype=self.dtype) + + self._assert_consistency( + F.ray_tracing, + ( + room_dim, + source, + mic_array, + num_rays, + absorption, + scattering, + mic_radius, + sound_speed, + energy_thres, + time_thres, + hist_bin_size, + ), + ) diff --git a/torchaudio/csrc/rir/ray_tracing.cpp b/torchaudio/csrc/rir/ray_tracing.cpp index 839a999ed87..0329061be45 100644 --- a/torchaudio/csrc/rir/ray_tracing.cpp +++ b/torchaudio/csrc/rir/ray_tracing.cpp @@ -220,9 +220,13 @@ class RayTracer { if (NORM(to_mic - dir * impact_distance) < mic_radius + EPS) { // The length of this last hop auto travel_dist_at_mic = travel_dist + std::abs(impact_distance); + auto bin_idx = get_bin_idx(travel_dist_at_mic); + if (bin_idx >= histograms.size(1)) { + continue; + } auto coeff = get_energy_coeff(travel_dist_at_mic, mic_radius_sq); auto energy = energies / coeff; - histograms[mic_idx][get_bin_idx(travel_dist_at_mic)] += energy; + histograms[mic_idx][bin_idx] += energy; } } } diff --git a/torchaudio/csrc/rir/wall.h b/torchaudio/csrc/rir/wall.h index a7933b13eb2..cbd6de349e2 100644 --- a/torchaudio/csrc/rir/wall.h +++ b/torchaudio/csrc/rir/wall.h @@ -18,7 +18,6 @@ struct Wall { const torch::Tensor origin; const torch::Tensor normal; const torch::Tensor scattering; - const torch::Tensor reflection; Wall( @@ -26,8 +25,8 @@ struct Wall { const torch::ArrayRef<scalar_t>& normal, const torch::Tensor& absorption, const torch::Tensor& scattering) - : origin(torch::tensor(origin)), - normal(torch::tensor(normal)), + : origin(torch::tensor(origin).to(scattering.dtype())), + normal(torch::tensor(normal).to(scattering.dtype())), scattering(scattering), reflection(1. - absorption) {} }; diff --git a/torchaudio/prototype/functional/__init__.py b/torchaudio/prototype/functional/__init__.py index 3c08461b70f..20bc181731e 100644 --- a/torchaudio/prototype/functional/__init__.py +++ b/torchaudio/prototype/functional/__init__.py @@ -7,7 +7,7 @@ oscillator_bank, sinc_impulse_response, ) -from ._rir import simulate_rir_ism +from ._rir import ray_tracing, simulate_rir_ism from .functional import barkscale_fbanks, chroma_filterbank @@ -20,6 +20,7 @@ "filter_waveform", "frequency_impulse_response", "oscillator_bank", + "ray_tracing", "sinc_impulse_response", "simulate_rir_ism", ] diff --git a/torchaudio/prototype/functional/_rir.py b/torchaudio/prototype/functional/_rir.py index 959d8e98b60..0e67a5494d2 100644 --- a/torchaudio/prototype/functional/_rir.py +++ b/torchaudio/prototype/functional/_rir.py @@ -133,20 +133,24 @@ def _adjust_coeff(coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor """ num_walls = 6 if isinstance(coeffs, float): + if coeffs < 0: + raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}") return torch.full((1, num_walls), coeffs) if isinstance(coeffs, Tensor): + if torch.any(coeffs < 0): + raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}") if coeffs.ndim == 1: if coeffs.numel() != num_walls: raise ValueError( - f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor." + f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor. " f"Found the shape {coeffs.shape}." ) return coeffs.unsqueeze(0) if coeffs.ndim == 2: - if coeffs.shape != (7, num_walls): + if coeffs.shape[1] != num_walls: raise ValueError( - f"The shape of `{name}` must be (7, {num_walls}) when it is a 2D Tensor." - f"Found the shape {coeffs.shape}." + f"The shape of `{name}` must be (NUM_BANDS, {num_walls}) when it " + f"is a 2D Tensor. Found: {coeffs.shape}." ) return coeffs raise TypeError(f"`{name}` must be float or Tensor.") @@ -169,7 +173,7 @@ def _validate_inputs( if not (source.ndim == 1 and source.numel() == 3): raise ValueError(f"`source` must be 1D Tensor with 3 elements. Found {source.shape}.") if not (mic_array.ndim == 2 and mic_array.shape[1] == 3): - raise ValueError(f"mic_array must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.") + raise ValueError(f"`mic_array` must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.") def simulate_rir_ism( @@ -270,3 +274,106 @@ def simulate_rir_ism( rir = rir[..., :output_length] return rir + + +def ray_tracing( + room: torch.Tensor, + source: torch.Tensor, + mic_array: torch.Tensor, + num_rays: int, + absorption: Union[float, torch.Tensor] = 0.0, + scattering: Union[float, torch.Tensor] = 0.0, + mic_radius: float = 0.5, + sound_speed: float = 343.0, + energy_thres: float = 1e-7, + time_thres: float = 10.0, + hist_bin_size: float = 0.004, +) -> torch.Tensor: + r"""Compute energy histogram via ray tracing. + + The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`. + + ``num_rays`` rays are casted uniformly in all directions from the source; + when a ray intersects a wall, it is reflected and part of its energy is absorbed. + It is also scattered (sent directly to the microphone(s)) according to the ``scattering`` + coefficient. + When a ray is close to the microphone, its current energy is recorded in the output + histogram for that given time slot. + + .. devices:: CPU + + .. properties:: TorchScript + + Args: + room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents + three dimensions of the room. + source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`. + mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`. + absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials. + (Default: ``0.0``). + If the type is ``float``, the absorption coefficient is identical to all walls and + all frequencies. + If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, representing absorption + coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and + ``"ceiling"``, respectively. + If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 6)`. + ``num_bands`` is the number of frequency bands (usually 7). + scattering(float or torch.Tensor, optional): The scattering coefficients of wall materials. (Default: ``0.0``) + The shape and type of this parameter is the same as for ``absorption``. + mic_radius(float, optional): The radius of the microphone in meters. (Default: 0.5) + sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``) + energy_thres (float, optional): The energy level below which we stop tracing a ray. (Default: ``1e-7``) + The initial energy of each ray is ``2 / num_rays``. + time_thres (float, optional): The maximal duration for which rays are traced. (Unit: seconds) (Default: 10.0) + hist_bin_size (float, optional): The size of each bin in the output histogram. (Unit: seconds) (Default: 0.004) + + Returns: + (torch.Tensor): The 3D histogram(s) where the energy of the traced ray is recorded. + Each bin corresponds to a given time slot. + The shape is `(channel, num_bands, num_bins)`, where + ``num_bins = ceil(time_thres / hist_bin_size)``. + If both ``absorption`` and ``scattering`` are floats, then ``num_bands == 1``. + """ + if time_thres < hist_bin_size: + raise ValueError( + "`time_thres` must be greater than `hist_bin_size`. " + f"Found: hist_bin_size={hist_bin_size}, time_thres={time_thres}." + ) + + if room.dtype != source.dtype or source.dtype != mic_array.dtype: + raise ValueError( + "dtype of `room`, `source` and `mic_array` must match. " + f"Found: `room` ({room.dtype}), `source` ({source.dtype}) and " + f"`mic_array` ({mic_array.dtype})" + ) + + _validate_inputs(room, source, mic_array) + absorption = _adjust_coeff(absorption, "absorption").to(room.dtype) + scattering = _adjust_coeff(scattering, "scattering").to(room.dtype) + + # Bring absorption and scattering to the same shape + if absorption.shape[0] == 1 and scattering.shape[0] > 1: + absorption = absorption.expand(scattering.shape) + if scattering.shape[0] == 1 and absorption.shape[0] > 1: + scattering = scattering.expand(absorption.shape) + if absorption.shape != scattering.shape: + raise ValueError( + "`absorption` and `scattering` must be broadcastable to the same number of bands and walls. " + f"Inferred shapes absorption={absorption.shape} and scattering={scattering.shape}" + ) + + histograms = torch.ops.torchaudio.ray_tracing( + room, + source, + mic_array, + num_rays, + absorption, + scattering, + mic_radius, + sound_speed, + energy_thres, + time_thres, + hist_bin_size, + ) + + return histograms