Skip to content

Commit

Permalink
Add Ray Tracing (pytorch#3604) (pytorch#2850) (pytorch#3655)
Browse files Browse the repository at this point in the history
Summary:
Revamped version of pytorch#3234
(which was itself a revamp of pytorch#2850)
  • Loading branch information
mthrok authored Oct 13, 2023
1 parent dde08ba commit fa78fb6
Show file tree
Hide file tree
Showing 9 changed files with 550 additions and 22 deletions.
1 change: 1 addition & 0 deletions docs/source/prototype.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ Room Impulse Response Simulation
:toctree: generated
:nosignatures:

ray_tracing
simulate_rir_ism
8 changes: 6 additions & 2 deletions src/libtorchaudio/rir/ray_tracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,17 +220,21 @@ class RayTracer {
if (NORM(to_mic - dir * impact_distance) < mic_radius + EPS) {
// The length of this last hop
auto travel_dist_at_mic = travel_dist + std::abs(impact_distance);
auto bin_idx = get_bin_idx(travel_dist_at_mic);
if (bin_idx >= histograms.size(1)) {
continue;
}
auto coeff = get_energy_coeff(travel_dist_at_mic, mic_radius_sq);
auto energy = energies / coeff;
histograms[mic_idx][get_bin_idx(travel_dist_at_mic)] += energy;
histograms[mic_idx][bin_idx] += energy;
}
}
}

travel_dist += hit_distance;
energies *= wall.reflection;

// Let's shoot the scattered ray induced by the rebound on the wall
// Let's shoot the scattered ray induced by the rebound on the wall
if (do_scattering) {
scat_ray(histograms, wall, energies, origin, hit_point, travel_dist);
energies *= (1. - wall.scattering);
Expand Down
10 changes: 6 additions & 4 deletions src/libtorchaudio/rir/wall.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@ struct Wall {
const torch::Tensor origin;
const torch::Tensor normal;
const torch::Tensor scattering;

const torch::Tensor reflection;

Wall(
const torch::ArrayRef<scalar_t>& origin,
const torch::ArrayRef<scalar_t>& normal,
const torch::Tensor& absorption,
const torch::Tensor& scattering)
: origin(torch::tensor(origin)),
normal(torch::tensor(normal)),
: origin(torch::tensor(origin).to(scattering.dtype())),
normal(torch::tensor(normal).to(scattering.dtype())),
scattering(scattering),
reflection(1. - absorption) {}
};
Expand Down Expand Up @@ -137,7 +136,6 @@ std::tuple<torch::Tensor, int, scalar_t> find_collision_wall(
for (unsigned int i = 0; i < 3; ++i) {
auto dir0 = SCALAR(direction[i]);
auto abs_dir0 = std::abs(dir0);

// If the ray is almost parallel to a plane, then we delegate the
// computation to the other planes.
if (abs_dir0 < EPS) {
Expand All @@ -148,6 +146,10 @@ std::tuple<torch::Tensor, int, scalar_t> find_collision_wall(
scalar_t distance = (dir0 < 0.)
? SCALAR(origin[i]) // Going towards origin
: SCALAR(room[i] - origin[i]); // Going away from origin
// sometimes origin is slightly outside of room
if (distance < 0) {
distance = 0.;
}
auto ratio = distance / abs_dir0;
int i_increment = dir0 > 0.;

Expand Down
3 changes: 2 additions & 1 deletion src/torchaudio/prototype/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
oscillator_bank,
sinc_impulse_response,
)
from ._rir import simulate_rir_ism
from ._rir import ray_tracing, simulate_rir_ism
from .functional import barkscale_fbanks, chroma_filterbank


Expand All @@ -20,6 +20,7 @@
"filter_waveform",
"frequency_impulse_response",
"oscillator_bank",
"ray_tracing",
"sinc_impulse_response",
"simulate_rir_ism",
]
117 changes: 112 additions & 5 deletions src/torchaudio/prototype/functional/_rir.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,20 +133,24 @@ def _adjust_coeff(coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor
"""
num_walls = 6
if isinstance(coeffs, float):
if coeffs < 0:
raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
return torch.full((1, num_walls), coeffs)
if isinstance(coeffs, Tensor):
if torch.any(coeffs < 0):
raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
if coeffs.ndim == 1:
if coeffs.numel() != num_walls:
raise ValueError(
f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor."
f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor. "
f"Found the shape {coeffs.shape}."
)
return coeffs.unsqueeze(0)
if coeffs.ndim == 2:
if coeffs.shape != (7, num_walls):
if coeffs.shape[1] != num_walls:
raise ValueError(
f"The shape of `{name}` must be (7, {num_walls}) when it is a 2D Tensor."
f"Found the shape {coeffs.shape}."
f"The shape of `{name}` must be (NUM_BANDS, {num_walls}) when it "
f"is a 2D Tensor. Found: {coeffs.shape}."
)
return coeffs
raise TypeError(f"`{name}` must be float or Tensor.")
Expand All @@ -169,7 +173,7 @@ def _validate_inputs(
if not (source.ndim == 1 and source.numel() == 3):
raise ValueError(f"`source` must be 1D Tensor with 3 elements. Found {source.shape}.")
if not (mic_array.ndim == 2 and mic_array.shape[1] == 3):
raise ValueError(f"mic_array must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")
raise ValueError(f"`mic_array` must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")


def simulate_rir_ism(
Expand Down Expand Up @@ -270,3 +274,106 @@ def simulate_rir_ism(
rir = rir[..., :output_length]

return rir


def ray_tracing(
    room: torch.Tensor,
    source: torch.Tensor,
    mic_array: torch.Tensor,
    num_rays: int,
    absorption: Union[float, torch.Tensor] = 0.0,
    scattering: Union[float, torch.Tensor] = 0.0,
    mic_radius: float = 0.5,
    sound_speed: float = 343.0,
    energy_thres: float = 1e-7,
    time_thres: float = 10.0,
    hist_bin_size: float = 0.004,
) -> torch.Tensor:
    r"""Compute energy histogram via ray tracing.

    The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.

    ``num_rays`` rays are cast uniformly in all directions from the source;
    when a ray intersects a wall, it is reflected and part of its energy is absorbed.
    It is also scattered (sent directly to the microphone(s)) according to the ``scattering``
    coefficient.
    When a ray is close to the microphone, its current energy is recorded in the output
    histogram for that given time slot.

    .. devices:: CPU

    .. properties:: TorchScript

    Args:
        room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
            three dimensions of the room.
        source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
        mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
        num_rays (int): The number of rays cast from the source.
        absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials.
            (Default: ``0.0``).
            If the type is ``float``, the absorption coefficient is identical to all walls and
            all frequencies.
            If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, representing absorption
            coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and
            ``"ceiling"``, respectively.
            If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 6)`.
            ``num_bands`` is the number of frequency bands (usually 7).
        scattering (float or torch.Tensor, optional): The scattering coefficients of wall materials.
            (Default: ``0.0``)
            The shape and type of this parameter is the same as for ``absorption``.
        mic_radius (float, optional): The radius of the microphone in meters. (Default: 0.5)
        sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``)
        energy_thres (float, optional): The energy level below which we stop tracing a ray.
            (Default: ``1e-7``)
            The initial energy of each ray is ``2 / num_rays``.
        time_thres (float, optional): The maximal duration for which rays are traced.
            (Unit: seconds) (Default: 10.0)
        hist_bin_size (float, optional): The size of each bin in the output histogram.
            Must be positive. (Unit: seconds) (Default: 0.004)

    Returns:
        (torch.Tensor): The 3D histogram(s) where the energy of the traced ray is recorded.
            Each bin corresponds to a given time slot.
            The shape is `(channel, num_bands, num_bins)`, where
            ``num_bins = ceil(time_thres / hist_bin_size)``.
            If both ``absorption`` and ``scattering`` are floats, then ``num_bands == 1``.

    Raises:
        ValueError: If ``hist_bin_size`` is not positive, if ``time_thres`` is smaller than
            ``hist_bin_size``, if the dtypes of ``room``, ``source`` and ``mic_array`` differ,
            or if ``absorption`` and ``scattering`` cannot be brought to the same shape.
            The input/coefficient validation helpers may raise as well.
    """
    # The number of bins is ceil(time_thres / hist_bin_size); a zero or negative
    # bin size cannot yield a meaningful histogram, so reject it up front instead
    # of forwarding it to the native op.
    if hist_bin_size <= 0.0:
        raise ValueError(f"`hist_bin_size` must be positive. Found: hist_bin_size={hist_bin_size}.")

    # Equality is accepted (it yields a single bin), so the message says so.
    if time_thres < hist_bin_size:
        raise ValueError(
            "`time_thres` must be greater than or equal to `hist_bin_size`. "
            f"Found: hist_bin_size={hist_bin_size}, time_thres={time_thres}."
        )

    if room.dtype != source.dtype or source.dtype != mic_array.dtype:
        raise ValueError(
            "dtype of `room`, `source` and `mic_array` must match. "
            f"Found: `room` ({room.dtype}), `source` ({source.dtype}) and "
            f"`mic_array` ({mic_array.dtype})"
        )

    _validate_inputs(room, source, mic_array)
    # Coefficients are normalized to 2D tensors and cast to the dtype of the geometry inputs.
    absorption = _adjust_coeff(absorption, "absorption").to(room.dtype)
    scattering = _adjust_coeff(scattering, "scattering").to(room.dtype)

    # Bring absorption and scattering to the same shape: a single-band
    # coefficient is expanded to match a multi-band counterpart.
    if absorption.shape[0] == 1 and scattering.shape[0] > 1:
        absorption = absorption.expand(scattering.shape)
    if scattering.shape[0] == 1 and absorption.shape[0] > 1:
        scattering = scattering.expand(absorption.shape)
    if absorption.shape != scattering.shape:
        raise ValueError(
            "`absorption` and `scattering` must be broadcastable to the same number of bands and walls. "
            f"Inferred shapes absorption={absorption.shape} and scattering={scattering.shape}"
        )

    histograms = torch.ops.torchaudio.ray_tracing(
        room,
        source,
        mic_array,
        num_rays,
        absorption,
        scattering,
        mic_radius,
        sound_speed,
        energy_thres,
        time_thres,
        hist_bin_size,
    )

    return histograms
34 changes: 25 additions & 9 deletions test/cpp/rir/wall_collision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,24 @@

using namespace torchaudio::rir;

using DTYPE = double;

struct CollisionTestParam {
// Input
torch::Tensor origin;
torch::Tensor direction;
// Expected
torch::Tensor hit_point;
int next_wall_index;
float hit_distance;
DTYPE hit_distance;
};

CollisionTestParam par(
torch::ArrayRef<float> origin,
torch::ArrayRef<float> direction,
torch::ArrayRef<float> hit_point,
torch::ArrayRef<DTYPE> origin,
torch::ArrayRef<DTYPE> direction,
torch::ArrayRef<DTYPE> hit_point,
int next_wall_index,
float hit_distance) {
DTYPE hit_distance) {
auto dir = torch::tensor(direction);
return {
torch::tensor(origin),
Expand Down Expand Up @@ -50,18 +52,22 @@ TEST_P(Simple3DRoomCollisionTest, CollisionTest3D) {

auto param = GetParam();
auto [hit_point, next_wall_index, hit_distance] =
find_collision_wall<float>(room, param.origin, param.direction);
find_collision_wall<DTYPE>(room, param.origin, param.direction);

EXPECT_EQ(param.next_wall_index, next_wall_index);
EXPECT_FLOAT_EQ(param.hit_distance, hit_distance);
EXPECT_TRUE(torch::allclose(
param.hit_point, hit_point, /*rtol*/ 1e-05, /*atol*/ 1e-07));
EXPECT_NEAR(
param.hit_point[0].item<DTYPE>(), hit_point[0].item<DTYPE>(), 1e-5);
EXPECT_NEAR(
param.hit_point[1].item<DTYPE>(), hit_point[1].item<DTYPE>(), 1e-5);
EXPECT_NEAR(
param.hit_point[2].item<DTYPE>(), hit_point[2].item<DTYPE>(), 1e-5);
}

#define ISQRT2 0.70710678118

INSTANTIATE_TEST_CASE_P(
Collision3DTests,
BasicCollisionTests,
Simple3DRoomCollisionTest,
::testing::Values(
// From 0
Expand Down Expand Up @@ -100,3 +106,13 @@ INSTANTIATE_TEST_CASE_P(
par({.5, .5, 1}, {0.0, -1., -1.}, {.5, .0, .5}, 2, ISQRT2),
par({.5, .5, 1}, {0.0, 1.0, -1.}, {.5, 1., .5}, 3, ISQRT2),
par({.5, .5, 1}, {0.0, 0.0, -1.}, {.5, .5, .0}, 4, 1.0)));

// Degenerate cases for Simple3DRoomCollisionTest: in every entry the expected
// hit point equals the ray origin and the expected hit distance is 0.
// NOTE(review): the origins look like edge/corner points of the room, but the
// room dimensions are defined outside this view; confirm against the fixture.
// Arguments to par(): origin, direction, expected hit_point,
// expected next_wall_index, expected hit_distance.
INSTANTIATE_TEST_CASE_P(
CornerCollisionTest,
Simple3DRoomCollisionTest,
::testing::Values(
// Origin {1, 1, 0}, with the direction's first component flipped.
par({1, 1, 0}, {1., 1., 0.}, {1., 1., 0.}, 1, 0.0),
par({1, 1, 0}, {-1., 1., 0.}, {1., 1., 0.}, 3, 0.0),
// Origin {1, 1, 1}, with the first and second direction components flipped in turn.
par({1, 1, 1}, {1., 1., 1.}, {1., 1., 1.}, 1, 0.0),
par({1, 1, 1}, {-1., 1., 1.}, {1., 1., 1.}, 3, 0.0),
par({1, 1, 1}, {-1., -1., 1.}, {1., 1., 1.}, 5, 0.0)));
Loading

0 comments on commit fa78fb6

Please sign in to comment.