From c2352d0065df0922281aa89e8523c2988cc688b4 Mon Sep 17 00:00:00 2001 From: Yannick Augenstein Date: Tue, 9 Jul 2024 15:46:16 +0200 Subject: [PATCH] Add class-based filters --- tidy3d/plugins/autograd/invdes/filters.py | 157 ++++++++++++++++++---- 1 file changed, 134 insertions(+), 23 deletions(-) diff --git a/tidy3d/plugins/autograd/invdes/filters.py b/tidy3d/plugins/autograd/invdes/filters.py index 999febe954..03cc0a82ab 100644 --- a/tidy3d/plugins/autograd/invdes/filters.py +++ b/tidy3d/plugins/autograd/invdes/filters.py @@ -1,13 +1,133 @@ -from functools import partial -from typing import Tuple, Union +import abc +from functools import lru_cache, partial +from typing import Callable, Iterable, Tuple, Union import numpy as np +import pydantic as pd + +from tidy3d.components.base import Tidy3dBaseModel from ..functions import convolve from ..types import KernelType, PaddingType from ..utilities import get_kernel_size_px, make_kernel +class AbstractFilter(Tidy3dBaseModel, abc.ABC): + """A filter class for creating and applying convolution filters. + + Parameters + ---------- + kernel_size : Tuple[pd.PositiveInt, ...] + Size of the kernel in pixels for each dimension. + normalize : bool = True + Whether to normalize the kernel so that it sums to 1. + padding : PaddingType = "reflect" + The padding mode to use. + """ + + kernel_size: Tuple[pd.PositiveInt, ...] = pd.Field( + ..., description="Size of the kernel in pixels for each dimension." + ) + normalize: bool = pd.Field( + True, description="Whether to normalize the kernel so that it sums to 1." + ) + padding: PaddingType = pd.Field("reflect", description="The padding mode to use.") + + @classmethod + def from_radius_dl( + cls, radius: Union[float, Tuple[float, ...]], dl: Union[float, Tuple[float, ...]], **kwargs + ) -> "AbstractFilter": + """Create a filter from radius and grid spacing. + + Parameters + ---------- + radius : Union[float, Tuple[float, ...]] + The radius of the kernel. Can be a scalar or a tuple. + dl : Union[float, Tuple[float, ...]] + The grid spacing. Can be a scalar or a tuple. + **kwargs + Additional keyword arguments to pass to the filter constructor. + + Returns + ------- + AbstractFilter + An instance of the filter. + """ + kernel_size = get_kernel_size_px(radius=radius, dl=dl) + return cls(kernel_size, **kwargs) + + @staticmethod + @abc.abstractmethod + def get_kernel(size_px: Iterable[int], normalize: bool) -> np.ndarray: + """Get the kernel for the filter. + + Parameters + ---------- + size_px : Iterable[int] + Size of the kernel in pixels for each dimension. + normalize : bool + Whether to normalize the kernel so that it sums to 1. + + Returns + ------- + np.ndarray + The kernel. + """ + ... + + def __call__(self, array: np.ndarray) -> np.ndarray: + """Apply the filter to an input array. + + Parameters + ---------- + array : np.ndarray + The input array to filter. + + Returns + ------- + np.ndarray + The filtered array. + """ + original_shape = array.shape + squeezed_array = np.squeeze(array) + size_px = self.kernel_size + if len(size_px) != squeezed_array.ndim: + size_px *= squeezed_array.ndim + kernel = self.get_kernel(size_px, self.normalize) + convolved_array = convolve(squeezed_array, kernel, padding=self.padding) + return np.reshape(convolved_array, original_shape) + + +class ConicFilter(AbstractFilter): + """A conic filter for creating and applying convolution filters.""" + + @staticmethod + @lru_cache(maxsize=1) + def get_kernel(size_px: Iterable[int], normalize: bool) -> np.ndarray: + """Get the conic kernel. + + See Also + -------- + :func:`~filters.AbstractFilter.get_kernel` For full method documentation. + """ + return make_kernel(kernel_type="conic", size=size_px, normalize=normalize) + + +class CircularFilter(AbstractFilter): + """A circular filter for creating and applying convolution filters.""" + + @staticmethod + @lru_cache(maxsize=1) + def get_kernel(size_px: Iterable[int], normalize: bool) -> np.ndarray: + """Get the circular kernel. + + See Also + -------- + :func:`~filters.AbstractFilter.get_kernel` For full method documentation. + """ + return make_kernel(kernel_type="circular", size=size_px, normalize=normalize) + + def _get_kernel_size( radius: Union[float, Tuple[float, ...]], dl: Union[float, Tuple[float, ...]], @@ -51,7 +171,7 @@ def make_filter( normalize: bool = True, padding: PaddingType = "reflect", filter_type: KernelType, -): +) -> Callable[[np.ndarray], np.ndarray]: """Create a filter function based on the specified kernel type and size. Parameters @@ -71,29 +191,20 @@ def make_filter( Returns ------- - function + Callable[[np.ndarray], np.ndarray] A function that applies the created filter to an input array. """ - _kernel = {} - kernel_size = _get_kernel_size(radius, dl, size_px) - def _filter(array): - original_shape = array.shape - squeezed_array = np.squeeze(array) - - if squeezed_array.ndim not in _kernel: - ks = kernel_size - if len(ks) != squeezed_array.ndim: - ks *= squeezed_array.ndim - _kernel[squeezed_array.ndim] = make_kernel( - kernel_type=filter_type, size=ks, normalize=normalize - ) - - convolved_array = convolve(squeezed_array, _kernel[squeezed_array.ndim], padding=padding) - return np.reshape(convolved_array, original_shape) + if filter_type == "conic": + filter_class = ConicFilter + elif filter_type == "circular": + filter_class = CircularFilter + else: + raise ValueError(f"Unsupported filter_type: {filter_type}") - return _filter + filter_instance = filter_class(kernel_size=kernel_size, normalize=normalize, padding=padding) + return filter_instance make_conic_filter = partial(make_filter, filter_type="conic") @@ -101,7 +212,7 @@ def _filter(array): See Also -------- -make_filter : Function to create a filter based on the specified kernel type and size. +:func:`~filters.make_filter` : Function to create a filter based on the specified kernel type and size. """ make_circular_filter = partial(make_filter, filter_type="circular") @@ -109,5 +220,5 @@ def _filter(array): See Also -------- -make_filter : Function to create a filter based on the specified kernel type and size. +:func:`~filters.make_filter` : Function to create a filter based on the specified kernel type and size. """