From b83cf6506f332ca8f52e1dbec44311872e8ba010 Mon Sep 17 00:00:00 2001 From: Tyler Hughes Date: Tue, 31 Oct 2023 11:35:00 -0400 Subject: [PATCH] truncate adjoint filter if input shape small --- CHANGELOG.md | 1 + tests/test_plugins/test_adjoint.py | 14 +++++++ tidy3d/plugins/adjoint/utils/filter.py | 51 ++++++++++++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7de09b686d..be84728f54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed ### Fixed +- If input to circular filters in adjoint have size smaller than the diameter, instead of erroring, warn user and truncate the filter kernel accordingly. ## [2.5.0rc2] - 2023-10-30 diff --git a/tests/test_plugins/test_adjoint.py b/tests/test_plugins/test_adjoint.py index a03cfbfa97..59eaf6b231 100644 --- a/tests/test_plugins/test_adjoint.py +++ b/tests/test_plugins/test_adjoint.py @@ -1473,6 +1473,20 @@ def test_adjoint_utils(strict_binarize): _ = radius_penalty.evaluate(polyslab.vertices) +@pytest.mark.parametrize( + "input_size_y, log_level_expected", [(13, None), (12, "WARNING"), (11, "WARNING"), (14, None)] +) +def test_adjoint_filter_sizes(log_capture, input_size_y, log_level_expected): + """Warn if filter size along a dim is smaller than radius.""" + + signal_in = np.ones((266, input_size_y)) + + _filter = ConicFilter(radius=0.08, design_region_dl=0.015) + _filter.evaluate(signal_in) + + assert_log_level(log_capture, log_level_expected) + + def test_sim_data_plot_field(use_emulated_run): """Test splitting of regular simulation data into user and server data.""" diff --git a/tidy3d/plugins/adjoint/utils/filter.py b/tidy3d/plugins/adjoint/utils/filter.py index 89487d07e1..dee2813d48 100644 --- a/tidy3d/plugins/adjoint/utils/filter.py +++ b/tidy3d/plugins/adjoint/utils/filter.py @@ -8,6 +8,7 @@ from ....components.base import Tidy3dBaseModel from ....constants import MICROMETER +from ....log import log class Filter(Tidy3dBaseModel, ABC): @@ -59,6 +60,53 @@ def _deprecate_feature_size(cls, values): def make_kernel(self, coords_rad: jnp.array) -> jnp.array: """Function to make the kernel out of a coordinate grid of radius values.""" + @staticmethod + def _check_kernel_size(kernel: jnp.array, signal_in: jnp.array, warn: bool = True) -> jnp.array: + """Make sure kernel isn't larger than signal and warn and truncate if so.""" + + kernel_shape = kernel.shape + input_shape = signal_in.shape + + pixel_buffers = [in_shape - k_shape for k_shape, in_shape in zip(kernel_shape, input_shape)] + + if np.any(np.array(pixel_buffers) < 0): + + if warn: + + # remove some pixels from the kernel to make things right + remove_pixels_total = np.abs(np.min(pixel_buffers)) + remove_pixels_edge = int(np.ceil(remove_pixels_total / 2)) + new_kernel = kernel.copy() + for axis, len_dim in enumerate(kernel_shape): + if kernel_shape[axis] > input_shape[axis]: + indices = np.arange(remove_pixels_edge, len_dim - remove_pixels_edge) + new_kernel = new_kernel.take(indices=indices.astype(int), axis=axis) + + log.warning( + f"The filter input has shape {input_shape} whereas the " + f"kernel has shape {kernel_shape}. " + "These shapes are incompatible as the input must " + "be larger than the kernel along all dimensions. " + "The kernel will automatically be " + f"resized to {new_kernel.shape} to be less than the input shape. " + "smooth features. If this is unexpected, " + "either reduce the filter 'radius' or increase the input array's size." + ) + + return new_kernel + + else: + + raise ValueError( + f"The filter input has shape {input_shape} whereas the " + f"kernel has shape {kernel_shape}. " + "These shapes are incompatible as the input must " + "be larger than the kernel along all dimensions. " + " Either reduce the filter radius or increase the input size." + ) + + return kernel + def evaluate(self, spatial_data: jnp.array) -> jnp.array: """Process on supplied spatial data.""" @@ -74,6 +122,9 @@ def evaluate(self, spatial_data: jnp.array) -> jnp.array: # construct the kernel kernel = self.make_kernel(coords_rad) + # handle when kernel is too large compared to input + kernel = self._check_kernel_size(kernel=kernel, signal_in=rho) + # normalize by the kernel operating on a spatial_data of all ones num = jsp.signal.convolve(rho, kernel, mode="same") den = jsp.signal.convolve(jnp.ones_like(rho), kernel, mode="same")