From 49be5918f6ebe4ae39555aa0d629c7ccb3dbc2bf Mon Sep 17 00:00:00 2001
From: moto <855818+mthrok@users.noreply.github.com>
Date: Thu, 21 Sep 2023 10:44:05 -0700
Subject: [PATCH] WIP: Add Ray Tracing (#3604)

Summary:
Revamped version of https://github.com/pytorch/audio/pull/3234 (which was itself a revamp of https://github.com/pytorch/audio/pull/2850).

Differential Revision: D49197174

Pulled By: mthrok
---
 docs/source/prototype.functional.rst          |   1 +
 .../functional/functional_test_impl.py        | 249 ++++++++++++++++++
 .../pyroomacoustics_compatibility_test.py     |  90 +++++++
 .../torchscript_consistency_test_impl.py      |  39 +++
 torchaudio/csrc/rir/ray_tracing.cpp           |   6 +-
 torchaudio/csrc/rir/wall.h                    |   5 +-
 torchaudio/prototype/functional/__init__.py   |   3 +-
 torchaudio/prototype/functional/_rir.py       | 117 +++++++-
 8 files changed, 500 insertions(+), 10 deletions(-)

diff --git a/docs/source/prototype.functional.rst b/docs/source/prototype.functional.rst
index 72f390c71ac..f0527b4c7ae 100644
--- a/docs/source/prototype.functional.rst
+++ b/docs/source/prototype.functional.rst
@@ -35,4 +35,5 @@ Room Impulse Response Simulation
    :toctree: generated
    :nosignatures:
 
+   ray_tracing
    simulate_rir_ism
diff --git a/test/torchaudio_unittest/prototype/functional/functional_test_impl.py b/test/torchaudio_unittest/prototype/functional/functional_test_impl.py
index 61e6d869337..114a2b5cc8d 100644
--- a/test/torchaudio_unittest/prototype/functional/functional_test_impl.py
+++ b/test/torchaudio_unittest/prototype/functional/functional_test_impl.py
@@ -412,6 +412,255 @@ def test_exp_sigmoid_input_diff(self, linspace_input_values, exp_sigmoid_paramet
 
         self.assertEqual(torch_out, torch.tensor(np_out))
 
+    @parameterized.expand(
+        [
+            # both float
+            (0.1, 0.2, (2, 1, 2500)),
+            # Per-wall
+            ((6,), 0.2, (2, 1, 2500)),
+            (0.1, (6,), (2, 1, 2500)),
+            ((6,), (6,), (2, 1, 2500)),
+            # Per-band and per-wall
+            ((3, 6), 0.2, (2, 3, 2500)),
+            (0.1, (5, 6), (2, 5, 2500)),
+            ((7, 6), (7, 6), (2, 7, 2500)),
+        ]
+    )
+    def test_ray_tracing_output_shape(self, abs_, scat_, expected_shape):
+        if isinstance(abs_, float):
+            absorption = abs_
+        else:
+            absorption = torch.rand(abs_, dtype=self.dtype)
+        if isinstance(scat_, float):
+            scattering = scat_
+        else:
+            scattering = torch.rand(scat_, dtype=self.dtype)
+
+        room_dim = torch.tensor([3, 4, 5], dtype=self.dtype)
+        mic_array = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=self.dtype)
+        source = torch.tensor([1, 2, 3], dtype=self.dtype)
+        num_rays = 100
+
+        hist = F.ray_tracing(
+            room=room_dim,
+            source=source,
+            mic_array=mic_array,
+            num_rays=num_rays,
+            absorption=absorption,
+            scattering=scattering,
+        )
+        assert hist.shape == expected_shape
+
+    def test_ray_tracing_input_errors(self):
+        room = torch.tensor([3.0, 4.0, 5.0], dtype=self.dtype)
+        source = torch.tensor([0.0, 0.0, 0.0], dtype=self.dtype)
+        mic = torch.tensor([[1.0, 2.0, 3.0]], dtype=self.dtype)
+
+        # baseline. This should not raise
+        _ = F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10)
+
+        # invalid room shape
+        for invalid in ([[4, 5]], [4, 5, 4, 5]):
+            invalid = torch.tensor(invalid, dtype=self.dtype)
+            with self.assertRaises(ValueError) as cm:
+                F.ray_tracing(room=invalid, source=source, mic_array=mic, num_rays=10)
+
+            error = str(cm.exception)
+            self.assertIn("`room` must be a 1D Tensor with 3 elements.", error)
+            self.assertIn(str(invalid.shape), error)
+
+        # invalid microphone shape
+        invalid = torch.tensor([[[3, 4]]], dtype=self.dtype)
+        with self.assertRaises(ValueError) as cm:
+            F.ray_tracing(room=room, source=source, mic_array=invalid, num_rays=10)
+
+        error = str(cm.exception)
+        self.assertIn("`mic_array` must be a 2D Tensor with shape (num_channels, 3).", error)
+        self.assertIn(str(invalid.shape), error)
+
+        # incompatible dtypes
+        with self.assertRaises(ValueError) as cm:
+            F.ray_tracing(
+                room=room.to(torch.float64),
+                source=source.to(torch.float32),
+                mic_array=mic,
+                num_rays=10,
+            )
+        error = str(cm.exception)
+        self.assertIn("dtype of `room`, `source` and `mic_array` must match.", error)
+        self.assertIn("`room` (torch.float64)", error)
+        self.assertIn("`source` (torch.float32)", error)
+        self.assertIn("`mic_array` (torch.float32)", error)
+
+        # invalid time configuration
+        with self.assertRaises(ValueError) as cm:
+            F.ray_tracing(
+                room=room,
+                source=source,
+                mic_array=mic,
+                num_rays=10,
+                time_thres=10,
+                hist_bin_size=11,
+            )
+        error = str(cm.exception)
+        self.assertIn("`time_thres` must be greater than `hist_bin_size`.", error)
+        self.assertIn("hist_bin_size=11", error)
+        self.assertIn("time_thres=10", error)
+
+        # invalid absorption shape 1D
+        invalid_abs = torch.tensor([1, 2, 3], dtype=self.dtype)
+        with self.assertRaises(ValueError) as cm:
+            F.ray_tracing(
+                room=room,
+                source=source,
+                mic_array=mic,
+                num_rays=10,
+                absorption=invalid_abs,
+            )
+        error = str(cm.exception)
+        self.assertIn("The shape of `absorption` must be (6,) when", error)
+        self.assertIn(str(invalid_abs.shape), error)
+
+        # invalid absorption shape 2D
+        invalid_abs = torch.tensor([[1, 2, 3]], dtype=self.dtype)
+        with self.assertRaises(ValueError) as cm:
+            F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, absorption=invalid_abs)
+        error = str(cm.exception)
+        self.assertIn("The shape of `absorption` must be (NUM_BANDS, 6) when", error)
+        self.assertIn(str(invalid_abs.shape), error)
+
+        # invalid scattering shape 1D
+        invalid_scat = torch.tensor([1, 2, 3], dtype=self.dtype)
+        with self.assertRaises(ValueError) as cm:
+            F.ray_tracing(
+                room=room,
+                source=source,
+                mic_array=mic,
+                num_rays=10,
+                scattering=invalid_scat,
+            )
+        error = str(cm.exception)
+        self.assertIn("The shape of `scattering` must be (6,) when", error)
+        self.assertIn(str(invalid_scat.shape), error)
+
+        # invalid scattering shape 2D
+        invalid_scat = torch.tensor([[1, 2, 3]], dtype=self.dtype)
+        with self.assertRaises(ValueError) as cm:
+            F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, scattering=invalid_scat)
+        error = str(cm.exception)
+        self.assertIn("The shape of `scattering` must be (NUM_BANDS, 6) when", error)
+        self.assertIn(str(invalid_scat.shape), error)
+
+        # TODO: Invalid absorption/scattering value
+
+        # incompatible scattering and absorption
+        abs_ = torch.zeros((7, 6), dtype=self.dtype)
+        scat = torch.zeros((5, 6), dtype=self.dtype)
+        with self.assertRaises(ValueError) as cm:
+            F.ray_tracing(
+                room=room,
+                source=source,
+                mic_array=mic,
+                num_rays=10,
+                absorption=abs_,
+                scattering=scat,
+            )
+        error = str(cm.exception)
+        self.assertIn(
+            "`absorption` and `scattering` must be broadcastable to the same number of bands and walls", error
+        )
+        self.assertIn(f"absorption={abs_.shape}", error)
+        self.assertIn(f"scattering={scat.shape}", error)
+
+        # Make sure passing different shapes for absorption or scattering doesn't raise an error
+        # float and tensor
+        F.ray_tracing(
+            room=room,
+            source=source,
+            mic_array=mic,
+            num_rays=10,
+            absorption=0.1,
+            scattering=torch.randn((5, 6), dtype=self.dtype),
+        )
+        F.ray_tracing(
+            room=room,
+            source=source,
+            mic_array=mic,
+            num_rays=10,
+            absorption=torch.randn((7, 6), dtype=self.dtype),
+            scattering=0.1,
+        )
+        # per-wall only and per-band + per-wall
+        F.ray_tracing(
+            room=room,
+            source=source,
+            mic_array=mic,
+            num_rays=10,
+            absorption=torch.rand(6, dtype=self.dtype),
+            scattering=torch.rand(7, 6, dtype=self.dtype),
+        )
+        F.ray_tracing(
+            room=room,
+            source=source,
+            mic_array=mic,
+            num_rays=10,
+            absorption=torch.rand(7, 6, dtype=self.dtype),
+            scattering=torch.rand(6, dtype=self.dtype),
+        )
+
+    def test_ray_tracing_per_band_per_wall_absorption(self):
+        """Check that when the value of absorption and scattering are the same
+        across walls and frequency bands, the output histograms are:
+        - all equal across frequency bands
+        - equal to simply passing a float value instead of a (num_bands, 6) or
+        (6,) tensor.
+        """
+
+        room_dim = torch.tensor([20, 25, 5], dtype=self.dtype)
+        mic_array = torch.tensor([[2, 2, 0], [8, 8, 0]], dtype=self.dtype)
+        source = torch.tensor([7, 6, 0], dtype=self.dtype)
+        num_rays = 1_000
+        ABS, SCAT = 0.1, 0.2
+
+        absorption = torch.full(fill_value=ABS, size=(7, 6), dtype=self.dtype)
+        scattering = torch.full(fill_value=SCAT, size=(7, 6), dtype=self.dtype)
+        hist_per_band_per_wall = F.ray_tracing(
+            room=room_dim,
+            source=source,
+            mic_array=mic_array,
+            num_rays=num_rays,
+            absorption=absorption,
+            scattering=scattering,
+        )
+        absorption = torch.full(fill_value=ABS, size=(6,), dtype=self.dtype)
+        scattering = torch.full(fill_value=SCAT, size=(6,), dtype=self.dtype)
+        hist_per_wall = F.ray_tracing(
+            room=room_dim,
+            source=source,
+            mic_array=mic_array,
+            num_rays=num_rays,
+            absorption=absorption,
+            scattering=scattering,
+        )
+
+        absorption = ABS
+        scattering = SCAT
+        hist_single = F.ray_tracing(
+            room=room_dim,
+            source=source,
+            mic_array=mic_array,
+            num_rays=num_rays,
+            absorption=absorption,
+            scattering=scattering,
+        )
+        self.assertEqual(hist_per_band_per_wall.shape, (2, 7, 2500))
+        self.assertEqual(hist_per_wall.shape, (2, 1, 2500))
+        self.assertEqual(hist_single.shape, (2, 1, 2500))
+        torch.testing.assert_close(hist_single, hist_per_wall)
+
+        hist_single = hist_single.expand(hist_per_band_per_wall.shape)
+        torch.testing.assert_close(hist_single, hist_per_band_per_wall)
+
 
 class Functional64OnlyTestImpl(TestBaseMixin):
     @nested_params(
diff --git a/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py b/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py
index 59da1a49ac2..af2281d213e 100644
--- a/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py
+++ b/test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py
@@ -1,3 +1,5 @@
+import math
+import numpy as np
 import torch
 import torchaudio.prototype.functional as F
 
@@ -9,6 +11,43 @@
     import pyroomacoustics as pra
 
 
+def _pra_ray_tracing(room_dim, absorption, scattering, num_bands, mic_array, source, num_rays, energy_thres, time_thres, hist_bin_size, mic_radius, sound_speed):
+    walls = ["west", "east", "south", "north", "floor", "ceiling"]
+    absorption = absorption.T.tolist()
+    scattering = scattering.T.tolist()
+    freqs = 125 * 2 ** np.arange(num_bands)
+
+    room = pra.ShoeBox(
+        room_dim.tolist(),
+        ray_tracing=True,
+        materials={
+            wall: pra.Material(
+                energy_absorption={"coeffs": absorp, "center_freqs": freqs},
+                scattering={"coeffs": scat, "center_freqs": freqs},
+            )
+            for wall, absorp, scat in zip(walls, absorption, scattering)
+        },
+        air_absorption=False,
+        max_order=0,  # Make sure PRA doesn't use the hybrid method (we just want ray tracing)
+    )
+    room.add_microphone_array(mic_array.T.tolist())
+    room.add_source(source.tolist())
+    room.set_ray_tracing(
+        n_rays=num_rays,
+        energy_thres=energy_thres,
+        time_thres=time_thres,
+        hist_bin_size=hist_bin_size,
+        receiver_radius=mic_radius,
+    )
+    room.set_sound_speed(sound_speed)
+    room.compute_rir()
+    hist_pra = np.array(room.rt_histograms, dtype=np.double)[:, 0, 0]
+
+    # PRA continues the simulation beyond time threshold, but torchaudio does not.
+    num_bins = math.ceil(time_thres / hist_bin_size)
+    return hist_pra[:, :, :num_bins]
+
+
 @skipIfNoModule("pyroomacoustics")
 @skipIfNoRIR
 class CompatibilityTest(PytorchTestCase):
@@ -91,3 +130,54 @@ def test_simulate_rir_ism_multi_band(self, channel):
             expected[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0])
         actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, absorption)
         self.assertEqual(expected, actual, atol=1e-3, rtol=1e-3)
+
+    @parameterized.expand(
+        [
+            ([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000),  # 3D with 1 mic
+        ]
+    )
+    def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays):
+        num_bands = 6
+        energy_thres = 1e-7
+        time_thres = 10.0
+        hist_bin_size = 0.004
+        mic_radius = 0.5
+        sound_speed = 343.0
+
+        absorption = torch.rand((num_bands, 6), dtype=self.dtype)
+        scattering = torch.rand((num_bands, 6), dtype=self.dtype)
+        room_dim = torch.tensor(room_dim, dtype=self.dtype)
+        source = torch.tensor(source, dtype=self.dtype)
+        mic_array = torch.tensor(mic_array, dtype=self.dtype)
+        
+        hist_pra = _pra_ray_tracing(
+            room_dim,
+            absorption,
+            scattering,
+            num_bands,
+            mic_array,
+            source,
+            num_rays,
+            energy_thres,
+            time_thres,
+            hist_bin_size,
+            mic_radius,
+            sound_speed)
+
+        hist = F.ray_tracing(
+            room=room_dim,
+            source=source,
+            mic_array=mic_array,
+            num_rays=num_rays,
+            absorption=absorption,
+            scattering=scattering,
+            sound_speed=sound_speed,
+            mic_radius=mic_radius,
+            energy_thres=energy_thres,
+            time_thres=time_thres,
+            hist_bin_size=hist_bin_size,
+        )
+
+        assert hist.ndim == 3
+        assert hist.shape == hist_pra.shape
+        self.assertEqual(hist, hist_pra)
diff --git a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py
index 60806d9afbc..5b947a83857 100644
--- a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py
+++ b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py
@@ -112,3 +112,42 @@ def test_simulate_rir_ism_multi_band(self, channel):
             F.simulate_rir_ism,
             (room_dim, source, mic_array, max_order, absorption, None, 81, center_frequency, 343.0, 16000.0),
         )
+
+    @parameterized.expand(
+        [
+            ([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500),  # 3D with 1 mic
+        ]
+    )
+    def test_ray_tracing(self, room_dim, source, mic_array, num_rays):
+        num_walls = 4 if len(room_dim) == 2 else 6
+        num_bands = 3
+
+        absorption = torch.rand(num_bands, num_walls, dtype=torch.float32)
+        scattering = torch.rand(num_bands, num_walls, dtype=torch.float32)
+
+        energy_thres = 1e-7
+        time_thres = 10.0
+        hist_bin_size = 0.004
+        mic_radius = 0.5
+        sound_speed = 343.0
+
+        room_dim = torch.tensor(room_dim, dtype=self.dtype)
+        source = torch.tensor(source, dtype=self.dtype)
+        mic_array = torch.tensor(mic_array, dtype=self.dtype)
+
+        self._assert_consistency(
+            F.ray_tracing,
+            (
+                room_dim,
+                source,
+                mic_array,
+                num_rays,
+                absorption,
+                scattering,
+                mic_radius,
+                sound_speed,
+                energy_thres,
+                time_thres,
+                hist_bin_size,
+            ),
+        )
diff --git a/torchaudio/csrc/rir/ray_tracing.cpp b/torchaudio/csrc/rir/ray_tracing.cpp
index 839a999ed87..0329061be45 100644
--- a/torchaudio/csrc/rir/ray_tracing.cpp
+++ b/torchaudio/csrc/rir/ray_tracing.cpp
@@ -220,9 +220,13 @@ class RayTracer {
           if (NORM(to_mic - dir * impact_distance) < mic_radius + EPS) {
             // The length of this last hop
             auto travel_dist_at_mic = travel_dist + std::abs(impact_distance);
+            auto bin_idx = get_bin_idx(travel_dist_at_mic);
+            if (bin_idx >= histograms.size(1)) {
+              continue;
+            }
             auto coeff = get_energy_coeff(travel_dist_at_mic, mic_radius_sq);
             auto energy = energies / coeff;
-            histograms[mic_idx][get_bin_idx(travel_dist_at_mic)] += energy;
+            histograms[mic_idx][bin_idx] += energy;
           }
         }
       }
diff --git a/torchaudio/csrc/rir/wall.h b/torchaudio/csrc/rir/wall.h
index a7933b13eb2..cbd6de349e2 100644
--- a/torchaudio/csrc/rir/wall.h
+++ b/torchaudio/csrc/rir/wall.h
@@ -18,7 +18,6 @@ struct Wall {
   const torch::Tensor origin;
   const torch::Tensor normal;
   const torch::Tensor scattering;
-
   const torch::Tensor reflection;
 
   Wall(
@@ -26,8 +25,8 @@ struct Wall {
       const torch::ArrayRef<scalar_t>& normal,
       const torch::Tensor& absorption,
       const torch::Tensor& scattering)
-      : origin(torch::tensor(origin)),
-        normal(torch::tensor(normal)),
+      : origin(torch::tensor(origin).to(scattering.dtype())),
+        normal(torch::tensor(normal).to(scattering.dtype())),
         scattering(scattering),
         reflection(1. - absorption) {}
 };
diff --git a/torchaudio/prototype/functional/__init__.py b/torchaudio/prototype/functional/__init__.py
index 3c08461b70f..20bc181731e 100644
--- a/torchaudio/prototype/functional/__init__.py
+++ b/torchaudio/prototype/functional/__init__.py
@@ -7,7 +7,7 @@
     oscillator_bank,
     sinc_impulse_response,
 )
-from ._rir import simulate_rir_ism
+from ._rir import ray_tracing, simulate_rir_ism
 from .functional import barkscale_fbanks, chroma_filterbank
 
 
@@ -20,6 +20,7 @@
     "filter_waveform",
     "frequency_impulse_response",
     "oscillator_bank",
+    "ray_tracing",
     "sinc_impulse_response",
     "simulate_rir_ism",
 ]
diff --git a/torchaudio/prototype/functional/_rir.py b/torchaudio/prototype/functional/_rir.py
index 959d8e98b60..0e67a5494d2 100644
--- a/torchaudio/prototype/functional/_rir.py
+++ b/torchaudio/prototype/functional/_rir.py
@@ -133,20 +133,24 @@ def _adjust_coeff(coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor
     """
     num_walls = 6
     if isinstance(coeffs, float):
+        if coeffs < 0:
+            raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
         return torch.full((1, num_walls), coeffs)
     if isinstance(coeffs, Tensor):
+        if torch.any(coeffs < 0):
+            raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
         if coeffs.ndim == 1:
             if coeffs.numel() != num_walls:
                 raise ValueError(
-                    f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor."
+                    f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor. "
                     f"Found the shape {coeffs.shape}."
                 )
             return coeffs.unsqueeze(0)
         if coeffs.ndim == 2:
-            if coeffs.shape != (7, num_walls):
+            if coeffs.shape[1] != num_walls:
                 raise ValueError(
-                    f"The shape of `{name}` must be (7, {num_walls}) when it is a 2D Tensor."
-                    f"Found the shape {coeffs.shape}."
+                    f"The shape of `{name}` must be (NUM_BANDS, {num_walls}) when it "
+                    f"is a 2D Tensor. Found: {coeffs.shape}."
                 )
             return coeffs
     raise TypeError(f"`{name}` must be float or Tensor.")
@@ -169,7 +173,7 @@ def _validate_inputs(
     if not (source.ndim == 1 and source.numel() == 3):
         raise ValueError(f"`source` must be 1D Tensor with 3 elements. Found {source.shape}.")
     if not (mic_array.ndim == 2 and mic_array.shape[1] == 3):
-        raise ValueError(f"mic_array must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")
+        raise ValueError(f"`mic_array` must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")
 
 
 def simulate_rir_ism(
@@ -270,3 +274,106 @@ def simulate_rir_ism(
             rir = rir[..., :output_length]
 
     return rir
+
+
+def ray_tracing(
+    room: torch.Tensor,
+    source: torch.Tensor,
+    mic_array: torch.Tensor,
+    num_rays: int,
+    absorption: Union[float, torch.Tensor] = 0.0,
+    scattering: Union[float, torch.Tensor] = 0.0,
+    mic_radius: float = 0.5,
+    sound_speed: float = 343.0,
+    energy_thres: float = 1e-7,
+    time_thres: float = 10.0,
+    hist_bin_size: float = 0.004,
+) -> torch.Tensor:
+    r"""Compute energy histogram via ray tracing.
+
+    The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.
+
+    ``num_rays`` rays are cast uniformly in all directions from the source;
+    when a ray intersects a wall, it is reflected and part of its energy is absorbed.
+    It is also scattered (sent directly to the microphone(s)) according to the ``scattering``
+    coefficient.
+    When a ray is close to the microphone, its current energy is recorded in the output
+    histogram for that given time slot.
+
+    .. devices:: CPU
+
+    .. properties:: TorchScript
+
+    Args:
+        room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
+            three dimensions of the room.
+        source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
+        mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
+        num_rays (int): The number of rays cast from the source.
+        absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials. (Default: ``0.0``).
+            If the type is ``float``, the absorption coefficient is identical to all walls and
+            all frequencies.
+            If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, representing absorption
+            coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and
+            ``"ceiling"``, respectively.
+            If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 6)`.
+            ``num_bands`` is the number of frequency bands (usually 7).
+        scattering (float or torch.Tensor, optional): The scattering coefficients of wall materials. (Default: ``0.0``)
+            The shape and type of this parameter is the same as for ``absorption``.
+        mic_radius (float, optional): The radius of the microphone in meters. (Default: ``0.5``)
+        sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``)
+        energy_thres (float, optional): The energy level below which we stop tracing a ray. (Default: ``1e-7``)
+            The initial energy of each ray is ``2 / num_rays``.
+        time_thres (float, optional): The maximal duration for which rays are traced. (Unit: seconds) (Default: 10.0)
+        hist_bin_size (float, optional): The size of each bin in the output histogram. (Unit: seconds) (Default: 0.004)
+
+    Returns:
+        (torch.Tensor): The 3D histogram(s) where the energy of the traced ray is recorded.
+            Each bin corresponds to a given time slot.
+            The shape is `(channel, num_bands, num_bins)`, where
+            ``num_bins = ceil(time_thres / hist_bin_size)``.
+            If both ``absorption`` and ``scattering`` are floats, then ``num_bands == 1``.
+    """
+    if time_thres < hist_bin_size:
+        raise ValueError(
+            "`time_thres` must be greater than `hist_bin_size`. "
+            f"Found: hist_bin_size={hist_bin_size}, time_thres={time_thres}."
+        )
+
+    if room.dtype != source.dtype or source.dtype != mic_array.dtype:
+        raise ValueError(
+            "dtype of `room`, `source` and `mic_array` must match. "
+            f"Found: `room` ({room.dtype}), `source` ({source.dtype}) and "
+            f"`mic_array` ({mic_array.dtype})"
+        )
+
+    _validate_inputs(room, source, mic_array)
+    absorption = _adjust_coeff(absorption, "absorption").to(room.dtype)
+    scattering = _adjust_coeff(scattering, "scattering").to(room.dtype)
+
+    # Bring absorption and scattering to the same shape
+    if absorption.shape[0] == 1 and scattering.shape[0] > 1:
+        absorption = absorption.expand(scattering.shape)
+    if scattering.shape[0] == 1 and absorption.shape[0] > 1:
+        scattering = scattering.expand(absorption.shape)
+    if absorption.shape != scattering.shape:
+        raise ValueError(
+            "`absorption` and `scattering` must be broadcastable to the same number of bands and walls. "
+            f"Inferred shapes absorption={absorption.shape} and scattering={scattering.shape}"
+        )
+
+    histograms = torch.ops.torchaudio.ray_tracing(
+        room,
+        source,
+        mic_array,
+        num_rays,
+        absorption,
+        scattering,
+        mic_radius,
+        sound_speed,
+        energy_thres,
+        time_thres,
+        hist_bin_size,
+    )
+
+    return histograms