Skip to content

Commit

Permalink
WIP: Add Ray Tracing (pytorch#3604)
Browse files Browse the repository at this point in the history
Summary:
Revamped version of pytorch#3234 (which was also revamp of pytorch#2850)

Differential Revision: D49197174

Pulled By: mthrok
  • Loading branch information
mthrok committed Oct 4, 2023
1 parent d9942ba commit 49be591
Show file tree
Hide file tree
Showing 8 changed files with 500 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/source/prototype.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ Room Impulse Response Simulation
:toctree: generated
:nosignatures:

ray_tracing
simulate_rir_ism
249 changes: 249 additions & 0 deletions test/torchaudio_unittest/prototype/functional/functional_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,255 @@ def test_exp_sigmoid_input_diff(self, linspace_input_values, exp_sigmoid_paramet

self.assertEqual(torch_out, torch.tensor(np_out))

@parameterized.expand(
[
# both float
(0.1, 0.2, (2, 1, 2500)),
# Per-wall
((6,), 0.2, (2, 1, 2500)),
(0.1, (6,), (2, 1, 2500)),
((6,), (6,), (2, 1, 2500)),
# Per-band and per-wall
((3, 6), 0.2, (2, 3, 2500)),
(0.1, (5, 6), (2, 5, 2500)),
((7, 6), (7, 6), (2, 7, 2500)),
]
)
def test_ray_tracing_output_shape(self, abs_, scat_, expected_shape):
if isinstance(abs_, float):
absorption = abs_
else:
absorption = torch.rand(abs_, dtype=self.dtype)
if isinstance(scat_, float):
scattering = scat_
else:
scattering = torch.rand(scat_, dtype=self.dtype)

room_dim = torch.tensor([3, 4, 5], dtype=self.dtype)
mic_array = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=self.dtype)
source = torch.tensor([1, 2, 3], dtype=self.dtype)
num_rays = 100

hist = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)
assert hist.shape == expected_shape

def test_ray_tracing_input_errors(self):
room = torch.tensor([3.0, 4.0, 5.0], dtype=self.dtype)
source = torch.tensor([0.0, 0.0, 0.0], dtype=self.dtype)
mic = torch.tensor([[1.0, 2.0, 3.0]], dtype=self.dtype)

# baseline. This should not raise
_ = F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10)

# invlaid room shape
for invalid in ([[4, 5]], [4, 5, 4, 5]):
invalid = torch.tensor(invalid, dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=invalid, source=source, mic_array=mic, num_rays=10)

error = str(cm.exception)
self.assertIn("`room` must be a 1D Tensor with 3 elements.", error)
self.assertIn(str(invalid.shape), error)

# invalid microphone shape
invalid = torch.tensor([[[3, 4]]], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=invalid, num_rays=10)

error = str(cm.exception)
self.assertIn("`mic_array` must be a 2D Tensor with shape (num_channels, 3).", error)
self.assertIn(str(invalid.shape), error)

# incompatible dtypes
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room.to(torch.float64),
source=source.to(torch.float32),
mic_array=mic,
num_rays=10,
)
error = str(cm.exception)
self.assertIn("dtype of `room`, `source` and `mic_array` must match.", error)
self.assertIn("`room` (torch.float64)", error)
self.assertIn("`source` (torch.float32)", error)
self.assertIn("`mic_array` (torch.float32)", error)

# invalid time configuration
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
time_thres=10,
hist_bin_size=11,
)
error = str(cm.exception)
self.assertIn("`time_thres` must be greater than `hist_bin_size`.", error)
self.assertIn("hist_bin_size=11", error)
self.assertIn("time_thres=10", error)

# invalid absorption shape 1D
invalid_abs = torch.tensor([1, 2, 3], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=invalid_abs,
)
error = str(cm.exception)
self.assertIn("The shape of `absorption` must be (6,) when", error)
self.assertIn(str(invalid_abs.shape), error)

# invalid absorption shape 2D
invalid_abs = torch.tensor([[1, 2, 3]], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, absorption=invalid_abs)
error = str(cm.exception)
self.assertIn("The shape of `absorption` must be (NUM_BANDS, 6) when", error)
self.assertIn(str(invalid_abs.shape), error)

# invalid scattering shape 1D
invalid_scat = torch.tensor([1, 2, 3], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
scattering=invalid_scat,
)
error = str(cm.exception)
self.assertIn("The shape of `scattering` must be (6,) when", error)
self.assertIn(str(invalid_scat.shape), error)

# invalid scattering shape 2D
invalid_scat = torch.tensor([[1, 2, 3]], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, scattering=invalid_scat)
error = str(cm.exception)
self.assertIn("The shape of `scattering` must be (NUM_BANDS, 6) when", error)
self.assertIn(str(invalid_scat.shape), error)

# TODO: Invalid absorption/scattering value

# incompatible scattering and absorption
abs_ = torch.zeros((7, 6), dtype=self.dtype)
scat = torch.zeros((5, 6), dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=abs_,
scattering=scat,
)
error = str(cm.exception)
self.assertIn(
"`absorption` and `scattering` must be broadcastable to the same number of bands and walls", error
)
self.assertIn(f"absorption={abs_.shape}", error)
self.assertIn(f"scattering={scat.shape}", error)

# Make sure passing different shapes for absorption or scattering doesn't raise an error
# float and tensor
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=0.1,
scattering=torch.randn((5, 6), dtype=self.dtype),
)
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=torch.randn((7, 6), dtype=self.dtype),
scattering=0.1,
)
# per-wall only and per-band + per-wall
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=torch.rand(6, dtype=self.dtype),
scattering=torch.rand(7, 6, dtype=self.dtype),
)
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=torch.rand(7, 6, dtype=self.dtype),
scattering=torch.rand(6, dtype=self.dtype),
)

def test_ray_tracing_per_band_per_wall_absorption(self):
"""Check that when the value of absorption and scattering are the same
across walls and frequency bands, the output histograms are:
- all equal across frequency bands
- equal to simply passing a float value instead of a (num_bands, D) or
(D,) tensor.
"""

room_dim = torch.tensor([20, 25, 5], dtype=self.dtype)
mic_array = torch.tensor([[2, 2, 0], [8, 8, 0]], dtype=self.dtype)
source = torch.tensor([7, 6, 0], dtype=self.dtype)
num_rays = 1_000
ABS, SCAT = 0.1, 0.2

absorption = torch.full(fill_value=ABS, size=(7, 6), dtype=self.dtype)
scattering = torch.full(fill_value=SCAT, size=(7, 6), dtype=self.dtype)
hist_per_band_per_wall = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)
absorption = torch.full(fill_value=ABS, size=(6,), dtype=self.dtype)
scattering = torch.full(fill_value=SCAT, size=(6,), dtype=self.dtype)
hist_per_wall = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)

absorption = ABS
scattering = SCAT
hist_single = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)
self.assertEqual(hist_per_band_per_wall.shape, (2, 7, 2500))
self.assertEqual(hist_per_wall.shape, (2, 1, 2500))
self.assertEqual(hist_single.shape, (2, 1, 2500))
torch.testing.assert_close(hist_single, hist_per_wall)

hist_single = hist_single.expand(hist_per_band_per_wall.shape)
torch.testing.assert_close(hist_single, hist_per_band_per_wall)


class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math
import numpy as np
import torch
import torchaudio.prototype.functional as F

Expand All @@ -9,6 +11,43 @@
import pyroomacoustics as pra


def _pra_ray_tracing(room_dim, absorption, scattering, num_bands, mic_array, source, num_rays, energy_thres, time_thres, hist_bin_size, mic_radius, sound_speed):
walls = ["west", "east", "south", "north", "floor", "ceiling"]
absorption = absorption.T.tolist()
scattering = scattering.T.tolist()
freqs = 125 * 2 ** np.arange(num_bands)

room = pra.ShoeBox(
room_dim.tolist(),
ray_tracing=True,
materials={
wall: pra.Material(
energy_absorption={"coeffs": absorp, "center_freqs": freqs},
scattering={"coeffs": scat, "center_freqs": freqs},
)
for wall, absorp, scat in zip(walls, absorption, scattering)
},
air_absorption=False,
max_order=0, # Make sure PRA doesn't use the hybrid method (we just want ray tracing)
)
room.add_microphone_array(mic_array.T.tolist())
room.add_source(source.tolist())
room.set_ray_tracing(
n_rays=num_rays,
energy_thres=energy_thres,
time_thres=time_thres,
hist_bin_size=hist_bin_size,
receiver_radius=mic_radius,
)
room.set_sound_speed(sound_speed)
room.compute_rir()
hist_pra = np.array(room.rt_histograms, dtype=np.double)[:, 0, 0]

# PRA continues the simulation beyond time threshold, but torchaudio does not.
num_bins = math.ceil(time_thres / hist_bin_size)
return hist_pra[:, :, :num_bins]


@skipIfNoModule("pyroomacoustics")
@skipIfNoRIR
class CompatibilityTest(PytorchTestCase):
Expand Down Expand Up @@ -91,3 +130,54 @@ def test_simulate_rir_ism_multi_band(self, channel):
expected[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0])
actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, absorption)
self.assertEqual(expected, actual, atol=1e-3, rtol=1e-3)

@parameterized.expand(
[
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic
]
)
def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays):
num_bands = 6
energy_thres = 1e-7
time_thres = 10.0
hist_bin_size = 0.004
mic_radius = 0.5
sound_speed = 343.0

absorption = torch.rand((num_bands, 6), dtype=self.dtype)
scattering = torch.rand((num_bands, 6), dtype=self.dtype)
room_dim = torch.tensor(room_dim, dtype=self.dtype)
source = torch.tensor(source, dtype=self.dtype)
mic_array = torch.tensor(mic_array, dtype=self.dtype)

hist_pra = _pra_ray_tracing(
room_dim,
absorption,
scattering,
num_bands,
mic_array,
source,
num_rays,
energy_thres,
time_thres,
hist_bin_size,
mic_radius,
sound_speed)

hist = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
sound_speed=sound_speed,
mic_radius=mic_radius,
energy_thres=energy_thres,
time_thres=time_thres,
hist_bin_size=hist_bin_size,
)

assert hist.ndim == 3
assert hist.shape == hist_pra.shape
self.assertEqual(hist, hist_pra)
Loading

0 comments on commit 49be591

Please sign in to comment.