Skip to content

Commit

Permalink
truncate adjoint filter if input shape small
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Oct 31, 2023
1 parent 9e93740 commit b83cf65
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

### Fixed
- If input to circular filters in adjoint have size smaller than the diameter, instead of erroring, warn user and truncate the filter kernel accordingly.

## [2.5.0rc2] - 2023-10-30

Expand Down
14 changes: 14 additions & 0 deletions tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,20 @@ def test_adjoint_utils(strict_binarize):
_ = radius_penalty.evaluate(polyslab.vertices)


@pytest.mark.parametrize(
"input_size_y, log_level_expected", [(13, None), (12, "WARNING"), (11, "WARNING"), (14, None)]
)
def test_adjoint_filter_sizes(log_capture, input_size_y, log_level_expected):
"""Warn if filter size along a dim is smaller than radius."""

signal_in = np.ones((266, input_size_y))

_filter = ConicFilter(radius=0.08, design_region_dl=0.015)
_filter.evaluate(signal_in)

assert_log_level(log_capture, log_level_expected)


def test_sim_data_plot_field(use_emulated_run):
"""Test splitting of regular simulation data into user and server data."""

Expand Down
51 changes: 51 additions & 0 deletions tidy3d/plugins/adjoint/utils/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ....components.base import Tidy3dBaseModel
from ....constants import MICROMETER
from ....log import log


class Filter(Tidy3dBaseModel, ABC):
Expand Down Expand Up @@ -59,6 +60,53 @@ def _deprecate_feature_size(cls, values):
def make_kernel(self, coords_rad: jnp.array) -> jnp.array:
"""Function to make the kernel out of a coordinate grid of radius values."""

@staticmethod
def _check_kernel_size(kernel: jnp.array, signal_in: jnp.array, warn: bool = True) -> jnp.array:
"""Make sure kernel isn't larger than signal and warn and truncate if so."""

kernel_shape = kernel.shape
input_shape = signal_in.shape

pixel_buffers = [in_shape - k_shape for k_shape, in_shape in zip(kernel_shape, input_shape)]

if np.any(np.array(pixel_buffers) < 0):

if warn:

# remove some pixels from the kernel to make things right
remove_pixels_total = np.abs(np.min(pixel_buffers))
remove_pixels_edge = int(np.ceil(remove_pixels_total / 2))
new_kernel = kernel.copy()
for axis, len_dim in enumerate(kernel_shape):
if kernel_shape[axis] > input_shape[axis]:
indices = np.arange(remove_pixels_edge, len_dim - remove_pixels_edge)
new_kernel = new_kernel.take(indices=indices.astype(int), axis=axis)

log.warning(
f"The filter input has shape {input_shape} whereas the "
f"kernel has shape {kernel_shape}. "
"These shapes are incompatible as the input must "
"be larger than the kernel along all dimensions. "
"The kernel will automatically be "
f"resized to {new_kernel.shape} to be less than the input shape. "
"smooth features. If this is unexpected, "
"either reduce the filter 'radius' or increase the input array's size."
)

return new_kernel

else:

raise ValueError(
f"The filter input has shape {input_shape} whereas the "
f"kernel has shape {kernel_shape}. "
"These shapes are incompatible as the input must "
"be larger than the kernel along all dimensions. "
" Either reduce the filter radius or increase the input size."
)

return kernel

def evaluate(self, spatial_data: jnp.array) -> jnp.array:
"""Process on supplied spatial data."""

Expand All @@ -74,6 +122,9 @@ def evaluate(self, spatial_data: jnp.array) -> jnp.array:
# construct the kernel
kernel = self.make_kernel(coords_rad)

# handle when kernel is too large compared to input
kernel = self._check_kernel_size(kernel=kernel, signal_in=rho)

# normalize by the kernel operating on a spatial_data of all ones
num = jsp.signal.convolve(rho, kernel, mode="same")
den = jsp.signal.convolve(jnp.ones_like(rho), kernel, mode="same")
Expand Down

0 comments on commit b83cf65

Please sign in to comment.