Skip to content

Commit

Permalink
WIP: Add Ray Tracing (pytorch#3604)
Browse files Browse the repository at this point in the history
Summary:
Revamped version of pytorch#3234 (which was also revamp of pytorch#2850)

Differential Revision: D49197174

Pulled By: mthrok
  • Loading branch information
mthrok committed Oct 3, 2023
1 parent d9942ba commit 98838b8
Show file tree
Hide file tree
Showing 6 changed files with 444 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/prototype.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ Room Impulse Response Simulation
:toctree: generated
:nosignatures:

ray_tracing
simulate_rir_ism
223 changes: 223 additions & 0 deletions test/torchaudio_unittest/prototype/functional/functional_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,3 +460,226 @@ def _debug_plot():
except AssertionError:
_debug_plot()
raise

@parameterized.expand(
[
(0.1, 0.2, (2, 1, 2500)), # both float
# Per-wall
(torch.rand(4), 0.2, (2, 1, 2500)),
(0.1, torch.rand(4), (2, 1, 2500)),
(torch.rand(4), torch.rand(4), (2, 1, 2500)),
# Per-band and per-wall
(torch.rand(6, 4), 0.2, (2, 6, 2500)),
(0.1, torch.rand(6, 4), (2, 6, 2500)),
(torch.rand(6, 4), torch.rand(6, 4), (2, 6, 2500)),
]
)
def test_ray_tracing_output_shape(self, absorption, scattering, expected_shape):
room_dim = torch.tensor([20, 25], dtype=self.dtype)
mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype)
source = torch.tensor([7, 6], dtype=self.dtype)
num_rays = 100

hist = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)

assert hist.shape == expected_shape

def test_ray_tracing_input_errors(self):
with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"):
F.ray_tracing(
room=torch.tensor([[4, 5]]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[3, 4]]), num_rays=10
)
with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"):
F.ray_tracing(
room=torch.tensor([4, 5, 4, 5]),
source=torch.tensor([0, 0]),
mic_array=torch.tensor([[3, 4]]),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, r"mic_array must be 1D tensor of shape \(D,\), or 2D tensor"):
F.ray_tracing(
room=torch.tensor([4, 5]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[[3, 4]]]), num_rays=10
)
with self.assertRaisesRegex(ValueError, "room must be of float32 or float64 dtype"):
F.ray_tracing(
room=torch.tensor([4, 5]).to(torch.int),
source=torch.tensor([0, 0]),
mic_array=torch.tensor([3, 4]),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, "dtype of room, source and mic_array must be the same"):
F.ray_tracing(
room=torch.tensor([4, 5]).to(torch.float64),
source=torch.tensor([0, 0]).to(torch.float32),
mic_array=torch.tensor([3, 4]),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
F.ray_tracing(
room=torch.tensor([4, 5, 10], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
F.ray_tracing(
room=torch.tensor([4, 5, 10], dtype=torch.float),
source=torch.tensor([0, 0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
)
with self.assertRaisesRegex(ValueError, "time_thres=10 must be at least greater than hist_bin_size=11"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
time_thres=10,
hist_bin_size=11,
)
with self.assertRaisesRegex(ValueError, "The shape of absorption must be"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(5, dtype=torch.float),
)
with self.assertRaisesRegex(ValueError, "The shape of scattering must be"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
scattering=torch.rand(5, 5, dtype=torch.float),
)
with self.assertRaisesRegex(ValueError, "The shape of absorption must be"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(5, 5, dtype=torch.float),
)
with self.assertRaisesRegex(ValueError, "The shape of scattering must be"):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
scattering=torch.rand(5, dtype=torch.float),
)
with self.assertRaisesRegex(
ValueError, "absorption and scattering must have the same number of bands and walls"
):
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(6, 4, dtype=torch.float),
scattering=torch.rand(5, 4, dtype=torch.float),
)

# Make sure passing different shapes for absorption or scattering doesn't raise an error
# float and tensor
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=0.1,
scattering=torch.rand(5, 4, dtype=torch.float),
)
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(5, 4, dtype=torch.float),
scattering=0.1,
)
# per-wall only and per-band + per-wall
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(4, dtype=torch.float),
scattering=torch.rand(6, 4, dtype=torch.float),
)
F.ray_tracing(
room=torch.tensor([4, 5], dtype=torch.float),
source=torch.tensor([0, 0], dtype=torch.float),
mic_array=torch.tensor([3, 4], dtype=torch.float),
num_rays=10,
absorption=torch.rand(6, 4, dtype=torch.float),
scattering=torch.rand(4, dtype=torch.float),
)

def test_ray_tracing_per_band_per_wall_absorption(self):
"""Check that when the value of absorption and scattering are the same
across walls and frequency bands, the output histograms are:
- all equal across frequency bands
- equal to simply passing a float value instead of a (num_bands, D) or
(D,) tensor.
"""

room_dim = torch.tensor([20, 25], dtype=self.dtype)
mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype)
source = torch.tensor([7, 6], dtype=self.dtype)
num_rays = 1_000
ABS, SCAT = 0.1, 0.2

absorption = torch.full(fill_value=ABS, size=(6, 4), dtype=self.dtype)
scattering = torch.full(fill_value=SCAT, size=(6, 4), dtype=self.dtype)
hist_per_band_per_wall = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)
absorption = torch.full(fill_value=ABS, size=(4,), dtype=self.dtype)
scattering = torch.full(fill_value=SCAT, size=(4,), dtype=self.dtype)
hist_per_wall = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)

absorption = ABS
scattering = SCAT
hist_single = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)
assert hist_per_band_per_wall.shape == (2, 6, 2500)
assert hist_per_wall.shape == (2, 1, 2500)
assert hist_single.shape == (2, 1, 2500)
torch.testing.assert_close(hist_single, hist_per_wall)

hist_single = hist_single.expand(2, 6, 2500)
torch.testing.assert_close(hist_single, hist_per_band_per_wall)
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch
import torchaudio.prototype.functional as F

Expand Down Expand Up @@ -91,3 +92,80 @@ def test_simulate_rir_ism_multi_band(self, channel):
expected[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0])
actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, absorption)
self.assertEqual(expected, actual, atol=1e-3, rtol=1e-3)

@parameterized.expand(
[
([20, 25], [2, 2], [[8, 8], [7, 6]], 10_000), # 2D with 2 mics
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic
]
)
def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays):

walls = ["west", "east", "south", "north"]
if len(room_dim) == 3:
walls += ["floor", "ceiling"]
num_walls = len(walls)
num_bands = 6 # Note: in ray tracing, we don't need to restrict the number of bands to 7

absorption = torch.rand(num_bands, num_walls, dtype=self.dtype)
scattering = torch.rand(num_bands, num_walls, dtype=self.dtype)
energy_thres = 1e-7
time_thres = 10.0
hist_bin_size = 0.004
mic_radius = 0.5
sound_speed = 343.0

room_dim = torch.tensor(room_dim, dtype=self.dtype)
source = torch.tensor(source, dtype=self.dtype)
mic_array = torch.tensor(mic_array, dtype=self.dtype)

room = pra.ShoeBox(
room_dim.tolist(),
ray_tracing=True,
materials={
walls[i]: pra.Material(
energy_absorption={
"coeffs": absorption[:, i].reshape(-1).detach().numpy(),
"center_freqs": 125 * 2 ** np.arange(num_bands),
},
scattering={
"coeffs": scattering[:, i].reshape(-1).detach().numpy(),
"center_freqs": 125 * 2 ** np.arange(num_bands),
},
)
for i in range(num_walls)
},
air_absorption=False,
max_order=0, # Make sure PRA doesn't use the hybrid method (we just want ray tracing)
)
room.add_microphone_array(mic_array.T.tolist())
room.add_source(source.tolist())
room.set_ray_tracing(
n_rays=num_rays,
energy_thres=energy_thres,
time_thres=time_thres,
hist_bin_size=hist_bin_size,
receiver_radius=mic_radius,
)
room.set_sound_speed(sound_speed)

room.compute_rir()
hist_pra = torch.tensor(np.array(room.rt_histograms))[:, 0, 0]

hist = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
sound_speed=sound_speed,
mic_radius=mic_radius,
energy_thres=energy_thres,
time_thres=time_thres,
hist_bin_size=hist_bin_size,
)

assert hist.ndim == 3
assert hist.shape == hist_pra.shape
self.assertEqual(hist.to(torch.float32), hist_pra)
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,43 @@ def test_simulate_rir_ism_multi_band(self, channel):
F.simulate_rir_ism,
(room_dim, source, mic_array, max_order, absorption, None, 81, center_frequency, 343.0, 16000.0),
)

@parameterized.expand(
[
([20, 25], [2, 2], [[8, 8], [7, 6]], 1_000), # 2D with 2 mics
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500), # 3D with 1 mic
]
)
def test_ray_tracing(self, room_dim, source, mic_array, num_rays):
num_walls = 4 if len(room_dim) == 2 else 6
num_bands = 3

absorption = torch.rand(num_bands, num_walls, dtype=torch.float32)
scattering = torch.rand(num_bands, num_walls, dtype=torch.float32)

energy_thres = 1e-7
time_thres = 10.0
hist_bin_size = 0.004
mic_radius = 0.5
sound_speed = 343.0

room_dim = torch.tensor(room_dim, dtype=self.dtype)
source = torch.tensor(source, dtype=self.dtype)
mic_array = torch.tensor(mic_array, dtype=self.dtype)

self._assert_consistency(
F.ray_tracing,
(
room_dim,
source,
mic_array,
num_rays,
absorption,
scattering,
mic_radius,
sound_speed,
energy_thres,
time_thres,
hist_bin_size,
),
)
3 changes: 2 additions & 1 deletion torchaudio/prototype/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
oscillator_bank,
sinc_impulse_response,
)
from ._rir import simulate_rir_ism
from ._rir import ray_tracing, simulate_rir_ism
from .functional import barkscale_fbanks, chroma_filterbank


Expand All @@ -20,6 +20,7 @@
"filter_waveform",
"frequency_impulse_response",
"oscillator_bank",
"ray_tracing",
"sinc_impulse_response",
"simulate_rir_ism",
]
Loading

0 comments on commit 98838b8

Please sign in to comment.