From 0be21fb8c05fce222e452ab529a3db8715c6f66c Mon Sep 17 00:00:00 2001 From: Zhaoheng Ni Date: Wed, 5 Apr 2023 18:06:03 -0400 Subject: [PATCH] fix unit tests --- .../functional/functional_test_impl.py | 171 ++++++++---------- .../torchscript_consistency_test_impl.py | 7 +- torchaudio/prototype/functional/_rir.py | 87 ++++----- 3 files changed, 117 insertions(+), 148 deletions(-) diff --git a/test/torchaudio_unittest/prototype/functional/functional_test_impl.py b/test/torchaudio_unittest/prototype/functional/functional_test_impl.py index 5a10986cba8..e304a108355 100644 --- a/test/torchaudio_unittest/prototype/functional/functional_test_impl.py +++ b/test/torchaudio_unittest/prototype/functional/functional_test_impl.py @@ -551,19 +551,24 @@ def test_simulate_rir_ism_multi_band(self, channel): [ (0.1, 0.2, (2, 1, 2500)), # both float # Per-wall - (torch.rand(4), 0.2, (2, 1, 2500)), - (0.1, torch.rand(4), (2, 1, 2500)), - (torch.rand(4), torch.rand(4), (2, 1, 2500)), + (torch.rand(6), 0.2, (2, 1, 2500)), + (0.1, torch.rand(6), (2, 1, 2500)), + (torch.rand(6), torch.rand(6), (2, 1, 2500)), # Per-band and per-wall - (torch.rand(6, 4), 0.2, (2, 6, 2500)), - (0.1, torch.rand(6, 4), (2, 6, 2500)), - (torch.rand(6, 4), torch.rand(6, 4), (2, 6, 2500)), + (torch.rand(4, 6), 0.2, (2, 4, 2500)), + (0.1, torch.rand(4, 6), (2, 4, 2500)), + (torch.rand(4, 6), torch.rand(4, 6), (2, 4, 2500)), ] ) def test_ray_tracing_output_shape(self, absorption, scattering, expected_shape): - room_dim = torch.tensor([20, 25], dtype=self.dtype) - mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype) - source = torch.tensor([7, 6], dtype=self.dtype) + room_dim = torch.tensor([20, 25, 30], dtype=self.dtype) + mic_array = torch.tensor([[2, 2, 2], [8, 8, 8]], dtype=self.dtype) + source = torch.tensor([7, 6, 5], dtype=self.dtype) + if isinstance(absorption, torch.Tensor): + absorption = absorption.to(self.dtype) + if isinstance(scattering, torch.Tensor): + scattering = scattering.to(self.dtype) + num_rays = 100 hist = F.ray_tracing( @@ -578,94 +583,76 @@ def test_ray_tracing_output_shape(self, absorption, scattering, expected_shape): assert hist.shape == expected_shape def test_ray_tracing_input_errors(self): - with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"): + with self.assertRaisesRegex(ValueError, "room must be a 1D Tensor."): F.ray_tracing( room=torch.tensor([[4, 5]]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[3, 4]]), num_rays=10 ) - with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"): + with self.assertRaisesRegex(ValueError, "The shape of room must be"): F.ray_tracing( room=torch.tensor([4, 5, 4, 5]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[3, 4]]), num_rays=10, ) - with self.assertRaisesRegex(ValueError, r"mic_array must be 1D tensor of shape \(D,\), or 2D tensor"): + with self.assertRaisesRegex(ValueError, "The second dimension of mic_array must be 3"): F.ray_tracing( - room=torch.tensor([4, 5]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[[3, 4]]]), num_rays=10 + room=torch.tensor([4, 5, 6]), + source=torch.tensor([0, 0, 0]), + mic_array=torch.tensor([[3, 4]]), + num_rays=10, ) with self.assertRaisesRegex(ValueError, "room must be of float32 or float64 dtype"): F.ray_tracing( - room=torch.tensor([4, 5]).to(torch.int), - source=torch.tensor([0, 0]), - mic_array=torch.tensor([3, 4]), + room=torch.tensor([4, 5, 6]).to(torch.int), + source=torch.tensor([0, 0, 0]), + mic_array=torch.tensor([[3, 4, 5]]), num_rays=10, ) with self.assertRaisesRegex(ValueError, "dtype of room, source and mic_array must be the same"): F.ray_tracing( - room=torch.tensor([4, 5]).to(torch.float64), - source=torch.tensor([0, 0]).to(torch.float32), - mic_array=torch.tensor([3, 4]), - num_rays=10, - ) - with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"): - F.ray_tracing( - room=torch.tensor([4, 5, 10], dtype=torch.float), - source=torch.tensor([0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), - num_rays=10, - ) - with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"): - F.ray_tracing( - room=torch.tensor([4, 5], dtype=torch.float), - source=torch.tensor([0, 0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), - num_rays=10, - ) - with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"): - F.ray_tracing( - room=torch.tensor([4, 5, 10], dtype=torch.float), - source=torch.tensor([0, 0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), + room=torch.tensor([4, 5, 6]).to(torch.float64), + source=torch.tensor([0, 0, 0]).to(torch.float32), + mic_array=torch.tensor([[3, 4, 5]]), num_rays=10, ) with self.assertRaisesRegex(ValueError, "time_thres=10 must be at least greater than hist_bin_size=11"): F.ray_tracing( - room=torch.tensor([4, 5], dtype=torch.float), - source=torch.tensor([0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), + room=torch.tensor([4, 5, 6], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float), num_rays=10, time_thres=10, hist_bin_size=11, ) - with self.assertRaisesRegex(ValueError, "The shape of absorption must be"): + with self.assertRaisesRegex(ValueError, "The shape of coefficient must be"): F.ray_tracing( - room=torch.tensor([4, 5], dtype=torch.float), - source=torch.tensor([0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), + room=torch.tensor([4, 5, 6], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float), num_rays=10, absorption=torch.rand(5, dtype=torch.float), ) - with self.assertRaisesRegex(ValueError, "The shape of scattering must be"): + with self.assertRaisesRegex(ValueError, "The shape of coefficient must be"): F.ray_tracing( - room=torch.tensor([4, 5], dtype=torch.float), - source=torch.tensor([0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), + room=torch.tensor([4, 5, 6], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float), num_rays=10, scattering=torch.rand(5, 5, dtype=torch.float), ) - with self.assertRaisesRegex(ValueError, "The shape of absorption must be"): + with self.assertRaisesRegex(ValueError, "The shape of coefficient must be"): F.ray_tracing( - room=torch.tensor([4, 5], dtype=torch.float), - source=torch.tensor([0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), + room=torch.tensor([4, 5, 6], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float), num_rays=10, absorption=torch.rand(5, 5, dtype=torch.float), ) - with self.assertRaisesRegex(ValueError, "The shape of scattering must be"): + with self.assertRaisesRegex(ValueError, "The shape of coefficient must be"): F.ray_tracing( - room=torch.tensor([4, 5], dtype=torch.float), - source=torch.tensor([0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), + room=torch.tensor([4, 5, 6], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float), num_rays=10, scattering=torch.rand(5, dtype=torch.float), ) @@ -673,48 +660,48 @@ def test_ray_tracing_input_errors(self): ValueError, "absorption and scattering must have the same number of bands and walls" ): F.ray_tracing( - room=torch.tensor([4, 5], dtype=torch.float), - source=torch.tensor([0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), + room=torch.tensor([4, 5, 6], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float), num_rays=10, - absorption=torch.rand(6, 4, dtype=torch.float), - scattering=torch.rand(5, 4, dtype=torch.float), + absorption=torch.rand(6, 6, dtype=torch.float), + scattering=torch.rand(5, 6, dtype=torch.float), ) # Make sure passing different shapes for absorption or scattering doesn't raise an error # float and tensor F.ray_tracing( - room=torch.tensor([4, 5], dtype=torch.float), - source=torch.tensor([0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), + room=torch.tensor([4, 5, 6], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float), num_rays=10, absorption=0.1, - scattering=torch.rand(5, 4, dtype=torch.float), + scattering=torch.rand(5, 6, dtype=torch.float), ) F.ray_tracing( - room=torch.tensor([4, 5], dtype=torch.float), - source=torch.tensor([0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), + room=torch.tensor([4, 5, 6], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float), num_rays=10, - absorption=torch.rand(5, 4, dtype=torch.float), + absorption=torch.rand(5, 6, dtype=torch.float), scattering=0.1, ) # per-wall only and per-band + per-wall F.ray_tracing( - room=torch.tensor([4, 5], dtype=torch.float), - source=torch.tensor([0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), + room=torch.tensor([4, 5, 6], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float), num_rays=10, - absorption=torch.rand(4, dtype=torch.float), - scattering=torch.rand(6, 4, dtype=torch.float), + absorption=torch.rand(6, dtype=torch.float), + scattering=torch.rand(6, 6, dtype=torch.float), ) F.ray_tracing( - room=torch.tensor([4, 5], dtype=torch.float), - source=torch.tensor([0, 0], dtype=torch.float), - mic_array=torch.tensor([3, 4], dtype=torch.float), + room=torch.tensor([4, 5, 6], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float), num_rays=10, - absorption=torch.rand(6, 4, dtype=torch.float), - scattering=torch.rand(4, dtype=torch.float), + absorption=torch.rand(6, 6, dtype=torch.float), + scattering=torch.rand(6, dtype=torch.float), ) def test_ray_tracing_per_band_per_wall_absorption(self): @@ -725,14 +712,14 @@ def test_ray_tracing_per_band_per_wall_absorption(self): (D,) tensor. """ - room_dim = torch.tensor([20, 25], dtype=self.dtype) - mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype) - source = torch.tensor([7, 6], dtype=self.dtype) + room_dim = torch.tensor([20, 25, 30], dtype=self.dtype) + mic_array = torch.tensor([[2, 2, 2], [8, 8, 8]], dtype=self.dtype) + source = torch.tensor([7, 6, 5], dtype=self.dtype) num_rays = 1_000 ABS, SCAT = 0.1, 0.2 - absorption = torch.full(fill_value=ABS, size=(6, 4), dtype=self.dtype) - scattering = torch.full(fill_value=SCAT, size=(6, 4), dtype=self.dtype) + absorption = torch.full(fill_value=ABS, size=(4, 6), dtype=self.dtype) + scattering = torch.full(fill_value=SCAT, size=(4, 6), dtype=self.dtype) hist_per_band_per_wall = F.ray_tracing( room=room_dim, source=source, @@ -741,8 +728,8 @@ def test_ray_tracing_per_band_per_wall_absorption(self): absorption=absorption, scattering=scattering, ) - absorption = torch.full(fill_value=ABS, size=(4,), dtype=self.dtype) - scattering = torch.full(fill_value=SCAT, size=(4,), dtype=self.dtype) + absorption = torch.full(fill_value=ABS, size=(6,), dtype=self.dtype) + scattering = torch.full(fill_value=SCAT, size=(6,), dtype=self.dtype) hist_per_wall = F.ray_tracing( room=room_dim, source=source, @@ -762,22 +749,20 @@ def test_ray_tracing_per_band_per_wall_absorption(self): absorption=absorption, scattering=scattering, ) - assert hist_per_band_per_wall.shape == (2, 6, 2500) + assert hist_per_band_per_wall.shape == (2, 4, 2500) assert hist_per_wall.shape == (2, 1, 2500) assert hist_single.shape == (2, 1, 2500) torch.testing.assert_close(hist_single, hist_per_wall) - hist_single = hist_single.expand(2, 6, 2500) + hist_single = hist_single.expand(2, 4, 2500) torch.testing.assert_close(hist_single, hist_per_band_per_wall) @parameterized.expand( [ - ([20, 25], [2, 2], [[8, 8], [7, 6]], 10_000), # 2D with 2 mics ([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic ] ) def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays): - walls = ["west", "east", "south", "north"] if len(room_dim) == 3: walls += ["floor", "ceiling"] diff --git a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py index 2f769fcb0ae..cd8cef7f898 100644 --- a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py +++ b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py @@ -115,16 +115,15 @@ def test_simulate_rir_ism_multi_band(self, channel): @parameterized.expand( [ - ([20, 25], [2, 2], [[8, 8], [7, 6]], 1_000), # 2D with 2 mics ([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500), # 3D with 1 mic ] ) def test_ray_tracing(self, room_dim, source, mic_array, num_rays): - num_walls = 4 if len(room_dim) == 2 else 6 + num_walls = 6 num_bands = 3 - absorption = torch.rand(num_bands, num_walls, dtype=torch.float32) - scattering = torch.rand(num_bands, num_walls, dtype=torch.float32) + absorption = torch.rand(num_bands, num_walls, dtype=self.dtype) + scattering = torch.rand(num_bands, num_walls, dtype=self.dtype) energy_thres = 1e-7 time_thres = 10.0 diff --git a/torchaudio/prototype/functional/_rir.py b/torchaudio/prototype/functional/_rir.py index a22243ca7d3..b7ac61da628 100644 --- a/torchaudio/prototype/functional/_rir.py +++ b/torchaudio/prototype/functional/_rir.py @@ -110,75 +110,60 @@ def _frac_delay(delay: torch.Tensor, delay_i: torch.Tensor, delay_filter_length: return torch.special.sinc(n - delay) * _hann(n - delay, 2 * pad) -def _adjust_coefficient(coefficient: Union[float, torch.Tensor]) -> torch.Tensor: +def _adjust_coefficient(coefficient: Union[float, torch.Tensor], dype: torch.dtype) -> torch.Tensor: """Validates and converts absorption or scattering parameters to a tensor with appropriate shape""" num_wall = 6 if isinstance(coefficient, float): - absorption = torch.ones(1, num_wall) * coefficient - elif isinstance(absorption, Tensor) and absorption.ndim == 1: - if absorption.shape[0] != num_wall: + coefficient = torch.ones(1, num_wall, dtype=dype) * coefficient + elif isinstance(coefficient, Tensor) and coefficient.ndim == 1: + if coefficient.shape[0] != num_wall: raise ValueError( - "The shape of absorption must be `(6,)` if it is a 1D Tensor." f"Found the shape {absorption.shape}." + "The shape of coefficient must be `(6,)` if it is a 1D Tensor." f"Found the shape {coefficient.shape}." ) - absorption = absorption.unsqueeze(0) - elif isinstance(absorption, Tensor) and absorption.ndim == 2: - if absorption.shape != (7, num_wall): + coefficient = coefficient.unsqueeze(0) + elif isinstance(coefficient, Tensor) and coefficient.ndim == 2: + if coefficient.shape[1] != num_wall: raise ValueError( - "The shape of coefficient must be `(7, 6)` if it is a 2D Tensor." - f"Found the shape of room is 3 and shape of absorption is {coefficient.shape}." + "The shape of coefficient must be `(num_bands, 6)` if it is a 2D Tensor." + f"Found the shape of room is 3 and shape of coefficient is {coefficient.shape}." ) - absorption = absorption + coefficient = coefficient else: - absorption = absorption - return absorption + coefficient = coefficient + return coefficient def _validate_inputs( room: torch.Tensor, source: torch.Tensor, mic_array: torch.Tensor, - absorption: Union[float, torch.Tensor], - scattering: Optional[Union[float, torch.Tensor]] = None, - -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Validate dimensions of input arguments, and normalize different kinds of absorption into the same dimension. +) -> None: + """Validate dimensions of input arguments. Args: room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents three dimensions of the room. source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`. mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`. - absorption (float or torch.Tensor): The absorption coefficients of wall materials. - If the dtype is ``float``, the absorption coefficient is identical for all walls and - all frequencies. - If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, where the values represent - absorption coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, - and ``"ceiling"``, respectively. - If ``absorption`` is a 2D Tensor, the shape must be `(7, 6)`, where 7 represents the number of octave bands. - scattering(float or torch.Tensor or None, optional): The scattering coefficients of wall materials. - (Default: None). The shape and type of this parameter is the same as for ``absorption``. - - Returns: - (torch.Tensor): The absorption Tensor. The shape is `(1, 6)` for single octave band case, - or `(7, 6)` for multi octave band case. - (torch.Tensor or None): The scattering Tensor. The shape is `(1, 6)` for single octave band case, - or `(7, 6)` for multi octave band case. """ if room.ndim != 1: raise ValueError(f"room must be a 1D Tensor. Found {room.shape}.") D = room.shape[0] if D != 3: - raise ValueError(f"room must be a 3D room. Found {room.shape}.") + raise ValueError(f"The shape of room must be `(3,)`. Found {room.shape}.") if source.shape[0] != D: raise ValueError(f"The shape of source must be `(3,)`. Found {source.shape}") if mic_array.ndim != 2: raise ValueError(f"mic_array must be a 2D Tensor. Found {mic_array.shape}.") if mic_array.shape[1] != D: raise ValueError(f"The second dimension of mic_array must be 3. Found {mic_array.shape}.") - absorption = _adjust_coefficient(absorption) - if scattering is not None: - scattering = _adjust_coefficient(scattering) - return absorption, scattering + if room.dtype not in (torch.float32, torch.float64): + raise ValueError(f"room must be of float32 or float64 dtype, got {room.dtype} instead.") + if not (room.dtype == source.dtype == mic_array.dtype): + raise ValueError( + "dtype of room, source and mic_array must be the same. " + f"Got {room.dtype}, {source.dtype}, and {mic_array.dtype}" + ) def simulate_rir_ism( @@ -237,7 +222,8 @@ def simulate_rir_ism( of octave bands are fixed to ``[125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0]``. Users need to tune the values of ``absorption`` to the corresponding frequencies. """ - absorption, _ = _validate_inputs(room, source, mic_array, absorption) + _validate_inputs(room, source, mic_array) + absorption = _adjust_coefficient(absorption, room.dtype) img_location, att = _compute_image_sources(room, source, max_order, absorption) # compute distances between image sources and microphones @@ -315,14 +301,11 @@ def ray_tracing( (Default: ``0.0``). If the dtype is ``float``, the absorption coefficient is identical to all walls and all frequencies. - If ``absorption`` is a 1D Tensor, the shape must be `(4,)` if the room is a 2D room, - representing absorption coefficients of ``"west"``, ``"east"``, ``"south"``, and - ``"north"`` walls, respectively. - Or the shape must be `(6,)` if the room is a 3D room, representing absorption coefficients - of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively. - If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 4)` if the room is a 2D room, - or `(num_bands, 6)` if the room is a 3D room. ``num_bands`` is the number of frequency bands (usually 7), - but you can choose other values. + If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, where the values represent + absorption coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, + and ``"ceiling"``, respectively. + If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 6)`, where ``num_bands`` is + the number of frequency bands (usually 7), but you can choose other values. scattering(float or torch.Tensor, optional): The scattering coefficients of wall materials. (Default: ``0.0``). The shape and type of this parameter is the same as for ``absorption``. mic_radius(float, optional): The radius of the microphone in meters. (Default: 0.5) @@ -339,14 +322,16 @@ def ray_tracing( """ if time_thres < hist_bin_size: raise ValueError(f"time_thres={time_thres} must be at least greater than hist_bin_size={hist_bin_size}.") - absorption, scattering = _validate_inputs(room, source, mic_array, absorption, scattering) + _validate_inputs(room, source, mic_array) + absorption = _adjust_coefficient(absorption, room.dtype) + scattering = _adjust_coefficient(scattering, room.dtype) # Bring absorption and scattering to the same shape - if absorption.shape[0] == 1 and scattering.shape[0] > 1: + if absorption.shape[0] == 1 and isinstance(scattering, torch.Tensor) and scattering.shape[0] > 1: absorption = absorption.expand(scattering.shape) - if scattering.shape[0] == 1 and absorption.shape[0] > 1: + if isinstance(scattering, torch.Tensor) and scattering.shape[0] == 1 and absorption.shape[0] > 1: scattering = scattering.expand(absorption.shape) - if absorption.shape != scattering.shape: + if isinstance(scattering, torch.Tensor) and absorption.shape != scattering.shape: raise ValueError( "absorption and scattering must have the same number of bands and walls. " f"Inferred shapes are {absorption.shape} and {scattering.shape}"