Skip to content

Commit

Permalink
WIP: Add Ray Tracing (pytorch#3604)
Browse files Browse the repository at this point in the history
Summary:
Revamped version of pytorch#3234 (which was also revamp of pytorch#2850)

Differential Revision: D49197174

Pulled By: mthrok
  • Loading branch information
mthrok committed Oct 5, 2023
1 parent d9942ba commit 574ae45
Show file tree
Hide file tree
Showing 9 changed files with 565 additions and 19 deletions.
1 change: 1 addition & 0 deletions docs/source/prototype.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ Room Impulse Response Simulation
:toctree: generated
:nosignatures:

ray_tracing
simulate_rir_ism
52 changes: 46 additions & 6 deletions test/cpp/rir/wall_collision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,24 @@

using namespace torchaudio::rir;

using DTYPE = double;

struct CollisionTestParam {
// Input
torch::Tensor origin;
torch::Tensor direction;
// Expected
torch::Tensor hit_point;
int next_wall_index;
float hit_distance;
DTYPE hit_distance;
};

CollisionTestParam par(
torch::ArrayRef<float> origin,
torch::ArrayRef<float> direction,
torch::ArrayRef<float> hit_point,
torch::ArrayRef<DTYPE> origin,
torch::ArrayRef<DTYPE> direction,
torch::ArrayRef<DTYPE> hit_point,
int next_wall_index,
float hit_distance) {
DTYPE hit_distance) {
auto dir = torch::tensor(direction);
return {
torch::tensor(origin),
Expand Down Expand Up @@ -50,7 +52,7 @@ TEST_P(Simple3DRoomCollisionTest, CollisionTest3D) {

auto param = GetParam();
auto [hit_point, next_wall_index, hit_distance] =
find_collision_wall<float>(room, param.origin, param.direction);
find_collision_wall<DTYPE>(room, param.origin, param.direction);

EXPECT_EQ(param.next_wall_index, next_wall_index);
EXPECT_FLOAT_EQ(param.hit_distance, hit_distance);
Expand Down Expand Up @@ -100,3 +102,41 @@ INSTANTIATE_TEST_CASE_P(
par({.5, .5, 1}, {0.0, -1., -1.}, {.5, .0, .5}, 2, ISQRT2),
par({.5, .5, 1}, {0.0, 1.0, -1.}, {.5, 1., .5}, 3, ISQRT2),
par({.5, .5, 1}, {0.0, 0.0, -1.}, {.5, .5, .0}, 4, 1.0)));


INSTANTIATE_TEST_CASE_P(
EdgeCollisionTest,
Simple3DRoomCollisionTest,
::testing::Values(
par({1, 1, 0}, {1., 1., 0.}, {1., 1., 0.}, 1, 0.0),
par({1, 1, 0}, {-1., 1., 0.}, {1., 1., 0.}, 3, 0.0),
//
par({1, 1, 1}, {1., 1., 1.}, {1., 1., 1.}, 1, 0.0),
par({1, 1, 1}, {-1., 1., 1.}, {1., 1., 1.}, 3, 0.0),
par({1, 1, 1}, {-1., -1., 1.}, {1., 1., 1.}, 5, 0.0)
));

class Simple3DRoomCollisionTest2
: public ::testing::TestWithParam<CollisionTestParam> {};

TEST_P(Simple3DRoomCollisionTest2, CollisionTest3D) {
auto room = torch::tensor({3, 4, 5});

auto param = GetParam();
auto [hit_point, next_wall_index, hit_distance] =
find_collision_wall<DTYPE>(room, param.origin, param.direction);

EXPECT_EQ(param.next_wall_index, next_wall_index);
EXPECT_FLOAT_EQ(param.hit_distance, hit_distance);
EXPECT_NEAR(param.hit_point[0].item<DTYPE>(), hit_point[0].item<DTYPE>(), 1e-5);
EXPECT_NEAR(param.hit_point[1].item<DTYPE>(), hit_point[1].item<DTYPE>(), 1e-5);
EXPECT_NEAR(param.hit_point[2].item<DTYPE>(), hit_point[2].item<DTYPE>(), 1e-5);
}


INSTANTIATE_TEST_CASE_P(
EdgeCollisionTest2,
Simple3DRoomCollisionTest2,
::testing::Values(
par({3., 4., 4.6542}, {-0.9798, 0.1733, 0.1000}, {3., 4., 4.6542}, 3, 0.0)
));
263 changes: 263 additions & 0 deletions test/torchaudio_unittest/prototype/functional/functional_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,269 @@ def test_exp_sigmoid_input_diff(self, linspace_input_values, exp_sigmoid_paramet

self.assertEqual(torch_out, torch.tensor(np_out))

@parameterized.expand(
[
# both float
(0.1, 0.2, (2, 1, 2500)),
# Per-wall
((6,), 0.2, (2, 1, 2500)),
(0.1, (6,), (2, 1, 2500)),
((6,), (6,), (2, 1, 2500)),
# Per-band and per-wall
((3, 6), 0.2, (2, 3, 2500)),
(0.1, (5, 6), (2, 5, 2500)),
((7, 6), (7, 6), (2, 7, 2500)),
]
)
def test_ray_tracing_output_shape(self, abs_, scat_, expected_shape):
if isinstance(abs_, float):
absorption = abs_
else:
absorption = torch.rand(abs_, dtype=self.dtype)
if isinstance(scat_, float):
scattering = scat_
else:
scattering = torch.rand(scat_, dtype=self.dtype)

room_dim = torch.tensor([3, 4, 5], dtype=self.dtype)
mic_array = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=self.dtype)
source = torch.tensor([1, 2, 3], dtype=self.dtype)
num_rays = 100

hist = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)
assert hist.shape == expected_shape

def test_ray_tracing_input_errors(self):
room = torch.tensor([3.0, 4.0, 5.0], dtype=self.dtype)
source = torch.tensor([0.0, 0.0, 0.0], dtype=self.dtype)
mic = torch.tensor([[1.0, 2.0, 3.0]], dtype=self.dtype)

# baseline. This should not raise
_ = F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10)

# invlaid room shape
for invalid in ([[4, 5]], [4, 5, 4, 5]):
invalid = torch.tensor(invalid, dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=invalid, source=source, mic_array=mic, num_rays=10)

error = str(cm.exception)
self.assertIn("`room` must be a 1D Tensor with 3 elements.", error)
self.assertIn(str(invalid.shape), error)

# invalid microphone shape
invalid = torch.tensor([[[3, 4]]], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=invalid, num_rays=10)

error = str(cm.exception)
self.assertIn("`mic_array` must be a 2D Tensor with shape (num_channels, 3).", error)
self.assertIn(str(invalid.shape), error)

# incompatible dtypes
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room.to(torch.float64),
source=source.to(torch.float32),
mic_array=mic.to(torch.float32),
num_rays=10,
)
error = str(cm.exception)
self.assertIn("dtype of `room`, `source` and `mic_array` must match.", error)
self.assertIn("`room` (torch.float64)", error)
self.assertIn("`source` (torch.float32)", error)
self.assertIn("`mic_array` (torch.float32)", error)

# invalid time configuration
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
time_thres=10,
hist_bin_size=11,
)
error = str(cm.exception)
self.assertIn("`time_thres` must be greater than `hist_bin_size`.", error)
self.assertIn("hist_bin_size=11", error)
self.assertIn("time_thres=10", error)

# invalid absorption shape 1D
invalid_abs = torch.tensor([1, 2, 3], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=invalid_abs,
)
error = str(cm.exception)
self.assertIn("The shape of `absorption` must be (6,) when", error)
self.assertIn(str(invalid_abs.shape), error)

# invalid absorption shape 2D
invalid_abs = torch.tensor([[1, 2, 3]], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, absorption=invalid_abs)
error = str(cm.exception)
self.assertIn("The shape of `absorption` must be (NUM_BANDS, 6) when", error)
self.assertIn(str(invalid_abs.shape), error)

# invalid scattering shape 1D
invalid_scat = torch.tensor([1, 2, 3], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
scattering=invalid_scat,
)
error = str(cm.exception)
self.assertIn("The shape of `scattering` must be (6,) when", error)
self.assertIn(str(invalid_scat.shape), error)

# invalid scattering shape 2D
invalid_scat = torch.tensor([[1, 2, 3]], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, scattering=invalid_scat)
error = str(cm.exception)
self.assertIn("The shape of `scattering` must be (NUM_BANDS, 6) when", error)
self.assertIn(str(invalid_scat.shape), error)

# Invalid absorption value
for invalid_val in [-1., torch.tensor([i - 1. for i in range(6)])]:
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, absorption=invalid_val)

error = str(cm.exception)
self.assertIn("`absorption` must be non-negative`")

# Invalid scattering value
for invalid_val in [-1., torch.tensor([i - 1. for i in range(6)])]:
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, scattering=invalid_val)

error = str(cm.exception)
self.assertIn("`scattering` must be non-negative`")

# incompatible scattering and absorption
abs_ = torch.zeros((7, 6), dtype=self.dtype)
scat = torch.zeros((5, 6), dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=abs_,
scattering=scat,
)
error = str(cm.exception)
self.assertIn(
"`absorption` and `scattering` must be broadcastable to the same number of bands and walls", error
)
self.assertIn(f"absorption={abs_.shape}", error)
self.assertIn(f"scattering={scat.shape}", error)

# Make sure passing different shapes for absorption or scattering doesn't raise an error
# float and tensor
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=0.1,
scattering=torch.rand((5, 6), dtype=self.dtype),
)
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=torch.rand((7, 6), dtype=self.dtype),
scattering=0.1,
)
# per-wall only and per-band + per-wall
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=torch.rand(6, dtype=self.dtype),
scattering=torch.rand(7, 6, dtype=self.dtype),
)
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=torch.rand(7, 6, dtype=self.dtype),
scattering=torch.rand(6, dtype=self.dtype),
)

def test_ray_tracing_per_band_per_wall_absorption(self):
"""Check that when the value of absorption and scattering are the same
across walls and frequency bands, the output histograms are:
- all equal across frequency bands
- equal to simply passing a float value instead of a (num_bands, D) or
(D,) tensor.
"""

room_dim = torch.tensor([20, 25, 5], dtype=self.dtype)
mic_array = torch.tensor([[2, 2, 0], [8, 8, 0]], dtype=self.dtype)
source = torch.tensor([7, 6, 0], dtype=self.dtype)
num_rays = 1_000
ABS, SCAT = 0.1, 0.2

absorption = torch.full(fill_value=ABS, size=(7, 6), dtype=self.dtype)
scattering = torch.full(fill_value=SCAT, size=(7, 6), dtype=self.dtype)
hist_per_band_per_wall = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)
absorption = torch.full(fill_value=ABS, size=(6,), dtype=self.dtype)
scattering = torch.full(fill_value=SCAT, size=(6,), dtype=self.dtype)
hist_per_wall = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)

absorption = ABS
scattering = SCAT
hist_single = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)
self.assertEqual(hist_per_band_per_wall.shape, (2, 7, 2500))
self.assertEqual(hist_per_wall.shape, (2, 1, 2500))
self.assertEqual(hist_single.shape, (2, 1, 2500))
torch.testing.assert_close(hist_single, hist_per_wall)

hist_single = hist_single.expand(hist_per_band_per_wall.shape)
torch.testing.assert_close(hist_single, hist_per_band_per_wall)


class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params(
Expand Down
Loading

0 comments on commit 574ae45

Please sign in to comment.