From 07a27f243fe644b345c790ad7edd5c8e16567807 Mon Sep 17 00:00:00 2001 From: Tyler Hughes Date: Tue, 31 Oct 2023 11:35:00 -0400 Subject: [PATCH] truncate adjoint filter if input shape small --- CHANGELOG.md | 1 + tests/test_plugins/test_adjoint.py | 14 ++++++++++ tidy3d/plugins/adjoint/utils/filter.py | 37 ++++++++++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c5647502f..605f4243b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed the duplication of log messages in Jupyter when `set_logging_file` is used. +- If input to circular filters in adjoint have size smaller than the diameter, instead of erroring, warn user and truncate the filter kernel accordingly. ## [2.5.0rc2] - 2023-10-30 diff --git a/tests/test_plugins/test_adjoint.py b/tests/test_plugins/test_adjoint.py index a03cfbfa9..59eaf6b23 100644 --- a/tests/test_plugins/test_adjoint.py +++ b/tests/test_plugins/test_adjoint.py @@ -1473,6 +1473,20 @@ def test_adjoint_utils(strict_binarize): _ = radius_penalty.evaluate(polyslab.vertices) +@pytest.mark.parametrize( + "input_size_y, log_level_expected", [(13, None), (12, "WARNING"), (11, "WARNING"), (14, None)] +) +def test_adjoint_filter_sizes(log_capture, input_size_y, log_level_expected): + """Warn if filter size along a dim is smaller than radius.""" + + signal_in = np.ones((266, input_size_y)) + + _filter = ConicFilter(radius=0.08, design_region_dl=0.015) + _filter.evaluate(signal_in) + + assert_log_level(log_capture, log_level_expected) + + def test_sim_data_plot_field(use_emulated_run): """Test splitting of regular simulation data into user and server data.""" diff --git a/tidy3d/plugins/adjoint/utils/filter.py b/tidy3d/plugins/adjoint/utils/filter.py index 89487d07e..dc81f66a2 100644 --- a/tidy3d/plugins/adjoint/utils/filter.py +++ b/tidy3d/plugins/adjoint/utils/filter.py @@ -8,6 +8,7 @@ from ....components.base import Tidy3dBaseModel from ....constants import MICROMETER +from ....log import log class Filter(Tidy3dBaseModel, ABC): @@ -59,6 +60,39 @@ def _deprecate_feature_size(cls, values): def make_kernel(self, coords_rad: jnp.array) -> jnp.array: """Function to make the kernel out of a coordinate grid of radius values.""" + @staticmethod + def _check_kernel_size(kernel: jnp.array, signal_in: jnp.array) -> jnp.array: + """Make sure kernel isn't larger than signal and warn and truncate if so.""" + + kernel_shape = kernel.shape + input_shape = signal_in.shape + + if any((k_shape > in_shape for k_shape, in_shape in zip(kernel_shape, input_shape))): + + # remove some pixels from the kernel to make things right + new_kernel = kernel.copy() + for axis, (len_kernel, len_input) in enumerate(zip(kernel_shape, input_shape)): + if len_kernel > len_input: + rm_pixels_total = len_kernel - len_input + rm_pixels_edge = int(np.ceil(rm_pixels_total / 2)) + indices_truncated = np.arange(rm_pixels_edge, len_kernel - rm_pixels_edge) + new_kernel = new_kernel.take(indices=indices_truncated.astype(int), axis=axis) + + log.warning( + f"The filter input has shape {input_shape} whereas the " + f"kernel has shape {kernel_shape}. " + "These shapes are incompatible as the input must " + "be larger than the kernel along all dimensions. " + "The kernel will automatically be " + f"resized to {new_kernel.shape} to be less than the input shape. " + "If this is unexpected, " + "either reduce the filter 'radius' or increase the input array's size." + ) + + return new_kernel + + return kernel + def evaluate(self, spatial_data: jnp.array) -> jnp.array: """Process on supplied spatial data.""" @@ -74,6 +108,9 @@ def evaluate(self, spatial_data: jnp.array) -> jnp.array: # construct the kernel kernel = self.make_kernel(coords_rad) + # handle when kernel is too large compared to input + kernel = self._check_kernel_size(kernel=kernel, signal_in=rho) + # normalize by the kernel operating on a spatial_data of all ones num = jsp.signal.convolve(rho, kernel, mode="same") den = jsp.signal.convolve(jnp.ones_like(rho), kernel, mode="same")