From fea9a28890605c8951f761df29e1836808d5e103 Mon Sep 17 00:00:00 2001 From: Tyler Hughes Date: Mon, 17 Jun 2024 03:54:11 +0200 Subject: [PATCH] autograd geometry group --- CHANGELOG.md | 3 + tests/test_components/test_autograd.py | 108 ++++++++++------ tidy3d/components/autograd.py | 84 ------------ tidy3d/components/autograd/__init__.py | 23 ++++ .../components/autograd/derivative_utils.py | 121 ++++++++++++++++++ tidy3d/components/autograd/types.py | 49 +++++++ tidy3d/components/autograd/utils.py | 15 +++ tidy3d/components/base.py | 3 +- tidy3d/components/geometry/base.py | 99 ++++++-------- tidy3d/components/geometry/polyslab.py | 65 ++++------ tidy3d/components/medium.py | 62 +++------ tidy3d/components/structure.py | 33 ++--- tidy3d/plugins/autograd/README.md | 4 +- tidy3d/web/api/autograd/autograd.py | 18 ++- 14 files changed, 388 insertions(+), 299 deletions(-) delete mode 100644 tidy3d/components/autograd.py create mode 100644 tidy3d/components/autograd/__init__.py create mode 100644 tidy3d/components/autograd/derivative_utils.py create mode 100644 tidy3d/components/autograd/types.py create mode 100644 tidy3d/components/autograd/utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 345ec4a24..4666b88d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,9 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [2.8.0rc1] ### Added +- Support for differentiation with respect to `GeometryGroup.geometries` elements. - Users can now export `SimulationData` to MATLAB `.mat` files with the `to_mat_file` method. - Introduce RF material library. Users can now export `rf_material_library` from `tidy3d.plugins.microwave`. +### Changed + ### Fixed - Bug where boundary layers would be plotted too small in 2D simulations. diff --git a/tests/test_components/test_autograd.py b/tests/test_components/test_autograd.py index ecc576d03..0ee6d366b 100644 --- a/tests/test_components/test_autograd.py +++ b/tests/test_components/test_autograd.py @@ -243,6 +243,18 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: medium=med, ) + # geometry group + geo_group = td.Structure( + geometry=td.GeometryGroup( + geometries=[ + medium.geometry, + center_list.geometry, + size_element.geometry, + ], + ), + medium=td.Medium(permittivity=eps, conductivity=conductivity), + ) + return dict( medium=medium, center_list=center_list, @@ -250,6 +262,7 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: custom_med=custom_med, custom_med_vec=custom_med_vec, polyslab=polyslab, + geo_group=geo_group, ) @@ -336,6 +349,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = False) -> None: "custom_med", "custom_med_vec", "polyslab", + "geo_group", ) monitor_keys_ = ("mode", "diff", "field_vol", "field_point") @@ -356,7 +370,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = False) -> None: args = [("polyslab", "mode")] -# args = [("custom_med_vec", "mode")] +# args = [("geo_group", "mode")] def get_functions(structure_key: str, monitor_key: str) -> typing.Callable: @@ -716,42 +730,62 @@ def f3(x): # @pytest.mark.timeout(18.0) -# def test_many_structures_timeout(): -# """Test that a metalens-like simulation with many structures can be initialized fast enough.""" - -# with cProfile.Profile() as pr: -# import time - -# t = time.time() - -# Nx, Ny = 200, 200 -# sim_size = [Nx, Ny, 5] - -# geoms = [] -# for ix in range(Nx): -# for iy in range(Ny): -# verts = ((ix, iy), (ix + 0.5, iy), (ix + 0.5, iy + 0.5), (ix, iy + 0.5)) -# geom = td.PolySlab(slab_bounds=(0, 1), vertices=verts) -# geoms.append(geom) - -# metalens = td.Structure( -# geometry=td.GeometryGroup(geometries=geoms), -# medium=td.material_library["Si3N4"]["Horiba"], -# ) - -# src = td.PlaneWave( -# source_time=td.GaussianPulse(freq0=2.5e14, fwidth=1e13), -# center=(0, 0, -1), -# size=(td.inf, td.inf, 0), -# direction="+", -# ) - -# sim = td.Simulation(size=sim_size, structures=[metalens], sources=[src], run_time=1e-12) - -# t2 = time.time() - t -# pr.print_stats(sort="cumtime") -# pr.dump_stats("sim_test.prof") -# print(f"structures took {t2} seconds") +def _test_many_structures(): + """Test that a metalens-like simulation with many structures can be initialized fast enough.""" + + with cProfile.Profile() as pr: + import time + + t = time.time() + + N_length = 200 + Nx, Ny = N_length, N_length + sim_size = [Nx, Ny, 5] + + def f(x): + monitor, postprocess = make_monitors()["field_point"] + monitor = monitor.updated_copy(center=(0, 0, 0)) + + geoms = [] + for ix in range(Nx): + for iy in range(Ny): + ix = ix + x + iy = iy + x + verts = ((ix, iy), (ix + 0.5, iy), (ix + 0.5, iy + 0.5), (ix, iy + 0.5)) + geom = td.PolySlab(slab_bounds=(0, 1), vertices=verts) + geoms.append(geom) + + metalens = td.Structure( + geometry=td.GeometryGroup(geometries=geoms), + medium=td.material_library["Si3N4"]["Horiba"], + ) + + src = td.PlaneWave( + source_time=td.GaussianPulse(freq0=2.5e14, fwidth=1e13), + center=(0, 0, -1), + size=(td.inf, td.inf, 0), + direction="+", + ) + + sim = td.Simulation( + size=sim_size, + structures=[metalens], + sources=[src], + monitors=[monitor], + run_time=1e-12, + ) + + data = run_emulated(sim, task_name="test") + return postprocess(data, data[monitor.name]) + + x0 = 0.0 + ag.grad(f)(x0) + + t2 = time.time() - t + pr.print_stats(sort="cumtime") + pr.dump_stats("sim_test.prof") + print(f"structures took {t2} seconds") + """ times (tyler's system) * original : 35 sec diff --git a/tidy3d/components/autograd.py b/tidy3d/components/autograd.py deleted file mode 100644 index 25080c1cb..000000000 --- a/tidy3d/components/autograd.py +++ /dev/null @@ -1,84 +0,0 @@ -# utilities for working with autograd - -import copy -import typing - -import numpy as np -import xarray as xr -from autograd.builtins import dict as dict_ag -from autograd.extend import Box, defvjp, primitive -from autograd.tracer import getval - -from tidy3d.components.type_util import _add_schema - -from .types import ArrayFloat2D, ArrayLike, Bound, Size1D - -# add schema to the Box -_add_schema(Box, title="AutogradBox", field_type_str="autograd.tracer.Box") - -# make sure Boxes in tidy3d properly define VJPs for copy operations, for computational graph -_copy = primitive(copy.copy) -_deepcopy = primitive(copy.deepcopy) - -defvjp(_copy, lambda ans, x: lambda g: _copy(g)) -defvjp(_deepcopy, lambda ans, x, memo: lambda g: _deepcopy(g, memo)) - -Box.__copy__ = lambda v: _copy(v) -Box.__deepcopy__ = lambda v, memo: _deepcopy(v, memo) - -# Types for floats, or collections of floats that can also be autograd tracers -TracedFloat = typing.Union[float, Box] -TracedSize1D = typing.Union[Size1D, Box] -TracedSize = typing.Union[tuple[TracedSize1D, TracedSize1D, TracedSize1D], Box] -TracedCoordinate = typing.Union[tuple[TracedFloat, TracedFloat, TracedFloat], Box] -TracedVertices = typing.Union[ArrayFloat2D, Box] - - -# The data type that we pass in and out of the web.run() @autograd.primitive -AutogradTraced = typing.Union[Box, ArrayLike] -AutogradFieldMap = dict_ag[tuple[str, ...], AutogradTraced] - - -def get_static(x: typing.Any) -> typing.Any: - """Get the 'static' (untraced) version of some value.""" - return getval(x) - - -# TODO: could we move this into a DataArray method? -def integrate_within_bounds(arr: xr.DataArray, dims: list[str], bounds: Bound) -> xr.DataArray: - """integrate a data array within bounds, assumes bounds are [2, N] for N dims.""" - - _arr = arr.copy() - - # order bounds with dimension first (N, 2) - bounds = np.array(bounds).T - - all_coords = {} - - # loop over all dimensions - for dim, (bmin, bmax) in zip(dims, bounds): - bmin = get_static(bmin) - bmax = get_static(bmax) - - coord_values = _arr.coords[dim].values - - # reset all coordinates outside of bounds to the bounds, so that dL = 0 in integral - coord_values[coord_values < bmin] = bmin - coord_values[coord_values > bmax] = bmax - - all_coords[dim] = coord_values - - _arr = _arr.assign_coords(**all_coords) - - # uses trapezoidal rule - # https://docs.xarray.dev/en/stable/generated/xarray.DataArray.integrate.html - return _arr.integrate(coord=dims) - - -__all__ = [ - "Box", - "primitive", - "defvjp", - "get_static", - "integrate_within_bounds", -] diff --git a/tidy3d/components/autograd/__init__.py b/tidy3d/components/autograd/__init__.py new file mode 100644 index 000000000..f8ce7168b --- /dev/null +++ b/tidy3d/components/autograd/__init__.py @@ -0,0 +1,23 @@ +from .types import ( + AutogradFieldMap, + AutogradTraced, + TracedCoordinate, + TracedFloat, + TracedSize, + TracedSize1D, + TracedVertices, +) +from .utils import get_static + +__all__ = [ + "TracedFloat", + "TracedSize1D", + "TracedSize", + "TracedCoordinate", + "TracedVertices", + "AutogradTraced", + "AutogradFieldMap", + "get_static", + "integrate_within_bounds", + "DerivativeInfo", +] diff --git a/tidy3d/components/autograd/derivative_utils.py b/tidy3d/components/autograd/derivative_utils.py new file mode 100644 index 000000000..0bcd61b22 --- /dev/null +++ b/tidy3d/components/autograd/derivative_utils.py @@ -0,0 +1,121 @@ +# utilities for autograd derivative passing +from __future__ import annotations + +import numpy as np +import pydantic.v1 as pd +import xarray as xr + +from ..base import Tidy3dBaseModel +from ..data.data_array import ScalarFieldDataArray +from ..types import Bound, tidycomplex +from .types import PathType +from .utils import get_static + +# we do this because importing these creates circular imports +FieldData = dict[str, ScalarFieldDataArray] +PermittivityData = dict[str, ScalarFieldDataArray] + + +class DerivativeInfo(Tidy3dBaseModel): + """Stores derivative information passed to the ``.compute_derivatives`` methods.""" + + paths: list[PathType] = pd.Field( + ..., + title="Paths to Traced Fields", + description="List of paths to the traced fields that need derivatives calculated.", + ) + + E_der_map: FieldData = pd.Field( + ..., + title="Electric Field Gradient Map", + description='Dataset where the field components ``("Ex", "Ey", "Ez")`` store the ' + "multiplication of the forward and adjoint electric fields. The tangential components " + "of this dataset is used when computing adjoint gradients for shifting boundaries. " + "All components are used when computing volume-based gradients.", + ) + + D_der_map: FieldData = pd.Field( + ..., + title="Displacement Field Gradient Map", + description='Dataset where the field components ``("Ex", "Ey", "Ez")`` store the ' + "multiplication of the forward and adjoint displacement fields. The normal component " + "of this dataset is used when computing adjoint gradients for shifting boundaries.", + ) + + eps_data: PermittivityData = pd.Field( + ..., + title="Permittivity Dataset", + description="Dataset of relative permittivity values along all three dimensions. " + "Used for automatically computing permittivity inside or outside of a simple geometry.", + ) + + eps_in: tidycomplex = pd.Field( + title="Permittivity Inside", + description="Permittivity inside of the ``Structure``. " + "Typically computed from ``Structure.medium.eps_model``." + "Used when it can not be computed from ``eps_data`` or when ``eps_approx==True``.", + ) + + eps_out: tidycomplex = pd.Field( + ..., + title="Permittivity Outside", + description="Permittivity outside of the ``Structure``. " + "Typically computed from ``Simulation.medium.eps_model``." + "Used when it can not be computed from ``eps_data`` or when ``eps_approx==True``.", + ) + + bounds: Bound = pd.Field( + ..., + title="Geometry Bounds", + description="Bounds corresponding to the structure, used in ``Medium`` calculations.", + ) + + eps_approx: bool = pd.Field( + False, + title="Use Permittivity Approximation", + description="If ``True``, approximates outside permittivity using ``Simulation.medium``" + "and the inside permittivity using ``Structure.medium``. " + "Only set ``True`` for ``GeometryGroup`` handling where it is difficult to automatically " + "evaluate the inside and outside relative permittivity for each geometry.", + ) + + def updated_paths(self, paths: list[PathType]) -> DerivativeInfo: + """Update this ``DerivativeInfo`` with new set of paths.""" + return self.updated_copy(paths=paths) + + +# TODO: could we move this into a DataArray method? +def integrate_within_bounds(arr: xr.DataArray, dims: list[str], bounds: Bound) -> xr.DataArray: + """integrate a data array within bounds, assumes bounds are [2, N] for N dims.""" + + _arr = arr.copy() + + # order bounds with dimension first (N, 2) + bounds = np.array(bounds).T + + all_coords = {} + + # loop over all dimensions + for dim, (bmin, bmax) in zip(dims, bounds): + bmin = get_static(bmin) + bmax = get_static(bmax) + + coord_values = _arr.coords[dim].values + + # reset all coordinates outside of bounds to the bounds, so that dL = 0 in integral + coord_values[coord_values < bmin] = bmin + coord_values[coord_values > bmax] = bmax + + all_coords[dim] = coord_values + + _arr = _arr.assign_coords(**all_coords) + + # uses trapezoidal rule + # https://docs.xarray.dev/en/stable/generated/xarray.DataArray.integrate.html + return _arr.integrate(coord=dims) + + +__all__ = [ + "integrate_within_bounds", + "DerivativeInfo", +] diff --git a/tidy3d/components/autograd/types.py b/tidy3d/components/autograd/types.py new file mode 100644 index 000000000..da40ea6e5 --- /dev/null +++ b/tidy3d/components/autograd/types.py @@ -0,0 +1,49 @@ +# type information for autograd + +# utilities for working with autograd + +import copy +import typing + +from autograd.builtins import dict as dict_ag +from autograd.extend import Box, defvjp, primitive + +from tidy3d.components.type_util import _add_schema + +from ..types import ArrayFloat2D, ArrayLike, Size1D + +# add schema to the Box +_add_schema(Box, title="AutogradBox", field_type_str="autograd.tracer.Box") + +# make sure Boxes in tidy3d properly define VJPs for copy operations, for computational graph +_copy = primitive(copy.copy) +_deepcopy = primitive(copy.deepcopy) + +defvjp(_copy, lambda ans, x: lambda g: _copy(g)) +defvjp(_deepcopy, lambda ans, x, memo: lambda g: _deepcopy(g, memo)) + +Box.__copy__ = lambda v: _copy(v) +Box.__deepcopy__ = lambda v, memo: _deepcopy(v, memo) + +# Types for floats, or collections of floats that can also be autograd tracers +TracedFloat = typing.Union[float, Box] +TracedSize1D = typing.Union[Size1D, Box] +TracedSize = typing.Union[tuple[TracedSize1D, TracedSize1D, TracedSize1D], Box] +TracedCoordinate = typing.Union[tuple[TracedFloat, TracedFloat, TracedFloat], Box] +TracedVertices = typing.Union[ArrayFloat2D, Box] + + +# The data type that we pass in and out of the web.run() @autograd.primitive +AutogradTraced = typing.Union[Box, ArrayLike] +PathType = tuple[typing.Union[int, str], ...] +AutogradFieldMap = dict_ag[PathType, AutogradTraced] + +__all__ = [ + "TracedFloat", + "TracedSize1D", + "TracedSize", + "TracedCoordinate", + "TracedVertices", + "AutogradTraced", + "AutogradFieldMap", +] diff --git a/tidy3d/components/autograd/utils.py b/tidy3d/components/autograd/utils.py new file mode 100644 index 000000000..d2e0eb2e3 --- /dev/null +++ b/tidy3d/components/autograd/utils.py @@ -0,0 +1,15 @@ +# utilities for working with autograd + +import typing + +from autograd.tracer import getval + + +def get_static(x: typing.Any) -> typing.Any: + """Get the 'static' (untraced) version of some value.""" + return getval(x) + + +__all__ = [ + "get_static", +] diff --git a/tidy3d/components/base.py b/tidy3d/components/base.py index 27fc26e3b..57542589a 100644 --- a/tidy3d/components/base.py +++ b/tidy3d/components/base.py @@ -24,7 +24,8 @@ from ..exceptions import FileError from ..log import log -from .autograd import AutogradFieldMap, Box, get_static +from .autograd.types import AutogradFieldMap, Box +from .autograd.utils import get_static from .data.data_array import AUTOGRAD_KEY, DATA_ARRAY_MAP, DataArray from .file_util import compress_file_to_gzip, extract_gzip_file from .types import TYPE_TAG_STR, ComplexNumber, Literal diff --git a/tidy3d/components/geometry/base.py b/tidy3d/components/geometry/base.py index 0f088473c..ba858a615 100644 --- a/tidy3d/components/geometry/base.py +++ b/tidy3d/components/geometry/base.py @@ -23,9 +23,9 @@ ) from ...log import log from ...packaging import check_import, verify_packages_import -from ..autograd import TracedCoordinate, TracedSize, get_static, integrate_within_bounds +from ..autograd import AutogradFieldMap, TracedCoordinate, TracedSize, get_static +from ..autograd.derivative_utils import DerivativeInfo, integrate_within_bounds from ..base import Tidy3dBaseModel, cached_property -from ..data.dataset import ElectromagneticFieldDataset, PermittivityDataset from ..transformation import RotationAroundAxis from ..types import ( ArrayFloat2D, @@ -1427,18 +1427,8 @@ def to_gds_file( pathlib.Path(fname).parent.mkdir(parents=True, exist_ok=True) library.write_gds(fname) - def compute_derivatives( - self, - field_paths: list[tuple[str, ...]], - E_der_map: ElectromagneticFieldDataset, - D_der_map: ElectromagneticFieldDataset, - eps_data: PermittivityDataset, - eps_in: complex, - eps_out: complex, - bounds: Bound, - ) -> dict[str, Any]: - """Compute the adjoint derivative for this geometry.""" - + def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" raise NotImplementedError(f"Can't compute derivative for 'Geometry': '{type(self)}'.") def _as_union(self) -> List[Geometry]: @@ -2339,34 +2329,18 @@ def _surface_area(self, bounds: Bound) -> float: """ Autograd code """ - def compute_derivatives( - self, - field_paths: list[tuple[str, ...]], - E_der_map: ElectromagneticFieldDataset, - D_der_map: ElectromagneticFieldDataset, - eps_data: PermittivityDataset, - eps_in: complex, - eps_out: complex, - bounds: Bound, - ) -> dict[str, Any]: - """Compute adjoint derivatives for each of the ``field_path``s.""" + def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" # get gradients w.r.t. each of the 6 faces (in normal direction) - vjps_faces = self.derivative_faces( - E_der_map=E_der_map, - D_der_map=D_der_map, - eps_data=eps_data, - eps_in=eps_in, - eps_out=eps_out, - bounds=bounds, - ) + vjps_faces = self.derivative_faces(derivative_info=derivative_info) # post-process these values to give the gradients w.r.t. center and size vjps_center_size = self.derivatives_center_size(vjps_faces=vjps_faces) # store only the gradients asked for in 'field_paths' derivative_map = {} - for field_path in field_paths: + for field_path in derivative_info.paths: field_name, *index = field_path if field_name in vjps_center_size: @@ -2397,15 +2371,7 @@ def derivatives_center_size(vjps_faces: Bound) -> dict[str, Coordinate]: size=tuple(vjp_size.tolist()), ) - def derivative_faces( - self, - E_der_map: ElectromagneticFieldDataset, - D_der_map: ElectromagneticFieldDataset, - eps_data: PermittivityDataset, - eps_in: complex, - eps_out: complex, - bounds: Bound, - ) -> Bound: + def derivative_faces(self, derivative_info: DerivativeInfo) -> Bound: """Derivative with respect to normal position of 6 faces of ``Box``.""" # change in permittivity between inside and outside @@ -2416,12 +2382,7 @@ def derivative_faces( vjp_face = self.derivative_face( min_max_index=min_max_index, axis_normal=axis, - E_der_map=E_der_map, - D_der_map=D_der_map, - eps_data=eps_data, - eps_in=eps_in, - eps_out=eps_out, - bounds=bounds, + derivative_info=derivative_info, ) # record vjp for this face @@ -2433,12 +2394,7 @@ def derivative_face( self, min_max_index: int, axis_normal: Axis, - E_der_map: ElectromagneticFieldDataset, - D_der_map: ElectromagneticFieldDataset, - eps_data: PermittivityDataset, - eps_in: complex, - eps_out: complex, - bounds: Bound, + derivative_info: DerivativeInfo, ) -> float: """Compute the derivative w.r.t. shifting a face in the normal direction.""" @@ -2447,11 +2403,11 @@ def derivative_face( fld_normal, flds_perp = self.pop_axis(("Ex", "Ey", "Ez"), axis=axis_normal) # normal and tangential fields - D_normal = D_der_map.field_components[fld_normal] - Es_perp = tuple(E_der_map.field_components[key] for key in flds_perp) + D_normal = derivative_info.D_der_map[fld_normal] + Es_perp = tuple(derivative_info.E_der_map[key] for key in flds_perp) # normal and tangential bounds - bounds_T = np.array(bounds).T # put (xyz) first dimension + bounds_T = np.array(derivative_info.bounds).T # put (xyz) first dimension bounds_normal, bounds_perp = self.pop_axis(bounds_T, axis=axis_normal) # define the integration plane @@ -2473,15 +2429,16 @@ def derivative_face( return 0.0 # grab permittivity data inside and outside edge in normal direction - eps_xyz = [eps_data.field_components[f"eps_{dim}{dim}"] for dim in "xyz"] + eps_xyz = [derivative_info.eps_data[f"eps_{dim}{dim}"] for dim in "xyz"] # number of cells from the edge of data to register "inside" (index = num_cells_in - 1) num_cells_in = 4 # if not enough data, just use best guess using eps in medium and simulation - if any(len(eps.coords[dim_normal]) <= num_cells_in for eps in eps_xyz): - eps_xyz_inside = 3 * [eps_in] - eps_xyz_outside = 3 * [eps_out] + needs_eps_approx = any(len(eps.coords[dim_normal]) <= num_cells_in for eps in eps_xyz) + if derivative_info.eps_approx or needs_eps_approx: + eps_xyz_inside = 3 * [derivative_info.eps_in] + eps_xyz_outside = 3 * [derivative_info.eps_out] # TODO: not tested... # otherwise, try to grab the data at the edges @@ -3246,5 +3203,23 @@ def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> Geomet ] return self.updated_copy(geometries=new_geometries) + def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + + grad_vjps = {} + + for field_path in derivative_info.paths: + _, index, *geo_path = field_path + geo = self.geometries[index] + geo_info = derivative_info.updated_copy( + paths=[geo_path], bounds=geo.bounds, eps_approx=True + ) + vjp_dict_geo = geo.compute_derivatives(geo_info) + grad_vjp_values = list(vjp_dict_geo.values()) + assert len(grad_vjp_values) == 1, "Got multiple gradients for single geometry field." + grad_vjps[field_path] = grad_vjp_values[0] + + return grad_vjps + from .utils import GeometryType, from_shapely, vertices_from_shapely # noqa: E402 diff --git a/tidy3d/components/geometry/polyslab.py b/tidy3d/components/geometry/polyslab.py index 28f7be05b..778230ed0 100644 --- a/tidy3d/components/geometry/polyslab.py +++ b/tidy3d/components/geometry/polyslab.py @@ -3,7 +3,7 @@ from __future__ import annotations from math import isclose -from typing import Any, List, Tuple +from typing import List, Tuple import autograd.numpy as np import pydantic.v1 as pydantic @@ -15,9 +15,10 @@ from ...exceptions import SetupError, ValidationError from ...log import log from ...packaging import verify_packages_import -from ..autograd import TracedVertices, get_static +from ..autograd import AutogradFieldMap, TracedVertices, get_static +from ..autograd.derivative_utils import DerivativeInfo from ..base import cached_property, skip_if_fields_missing -from ..data.dataset import ElectromagneticFieldDataset, PermittivityDataset +from ..data.dataset import ElectromagneticFieldDataset from ..types import ( ArrayFloat2D, ArrayLike, @@ -1368,40 +1369,18 @@ def _surface_area(self, bounds: Bound) -> float: """ Autograd code """ - def compute_derivatives( - self, - field_paths: list[tuple[str, ...]], - E_der_map: ElectromagneticFieldDataset, - D_der_map: ElectromagneticFieldDataset, - eps_data: PermittivityDataset, - eps_in: complex, - eps_out: complex, - bounds: Bound, - ) -> dict[str, Any]: - """Compute adjoint derivatives for each of the ``field_path``s.""" - - assert field_paths == [("vertices",)], "only support derivative wrt 'PolySlab.vertices'." - - vjp_vertices = self.compute_derivative_vertices( - E_der_map=E_der_map, - D_der_map=D_der_map, - eps_data=eps_data, - eps_in=eps_in, - eps_out=eps_out, - bounds=bounds, - ) + def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + + assert derivative_info.paths == [ + ("vertices",) + ], "only support derivative wrt 'PolySlab.vertices'." + + vjp_vertices = self.compute_derivative_vertices(derivative_info=derivative_info) return {("vertices",): vjp_vertices} - def compute_derivative_vertices( - self, - E_der_map: ElectromagneticFieldDataset, - D_der_map: ElectromagneticFieldDataset, - eps_data: PermittivityDataset, - eps_in: complex, - eps_out: complex, - bounds: Bound, - ) -> TracedVertices: + def compute_derivative_vertices(self, derivative_info: DerivativeInfo) -> TracedVertices: # derivative w.r.t each edge vertices = np.array(self.vertices) @@ -1420,8 +1399,12 @@ def compute_derivative_vertices( assert edge_centers_xyz.shape == (num_vertices, 3), "something bad happened" # compute the E and D fields at the edge centers - E_der_at_edges = self.der_at_centers(der_map=E_der_map, edge_centers=edge_centers_xyz) - D_der_at_edges = self.der_at_centers(der_map=D_der_map, edge_centers=edge_centers_xyz) + E_der_at_edges = self.der_at_centers( + der_map=derivative_info.E_der_map, edge_centers=edge_centers_xyz + ) + D_der_at_edges = self.der_at_centers( + der_map=derivative_info.D_der_map, edge_centers=edge_centers_xyz + ) # compute the basis vectors along each edge basis_vectors = self.edge_basis_vectors(edges=edges) @@ -1432,8 +1415,8 @@ def compute_derivative_vertices( E_der_slab = self.project_in_basis(E_der_at_edges, basis_vector=basis_vectors["slab"]) # approximate permittivity in and out - delta_eps_inv = 1.0 / eps_in - 1.0 / eps_out - delta_eps = eps_in - eps_out + delta_eps_inv = 1.0 / derivative_info.eps_in - 1.0 / derivative_info.eps_out + delta_eps = derivative_info.eps_in - derivative_info.eps_out # put together VJP using D_normal and E_perp integration vjps_edges = 0.0 @@ -1453,7 +1436,7 @@ def compute_derivative_vertices( # correction to edge area based on sidewall distance along slab axis dim_axis = "xyz"[self.axis] - field_coords_axis = E_der_map.field_components[f"E{dim_axis}"].coords[dim_axis] + field_coords_axis = derivative_info.E_der_map[f"E{dim_axis}"].coords[dim_axis] if len(field_coords_axis) > 1: slab_height = abs(float(np.squeeze(np.diff(self.slab_bounds)))) if not np.isinf(slab_height): @@ -1487,12 +1470,12 @@ def der_at_centers( interp_kwargs = {} for dim, centers_dim in zip("xyz", edge_centers.T): # only include dims where the data has more than 1 coord, to avoid warnings and errors - coords_data = der_map.field_components[f"E{dim}"].coords + coords_data = der_map[f"E{dim}"].coords if np.array(coords_data).size > 1: interp_kwargs[dim] = xr.DataArray(centers_dim, dims=edge_index_dim) components = {} - for fld_name, arr in der_map.field_components.items(): + for fld_name, arr in der_map.items(): components[fld_name] = arr.interp(**interp_kwargs).sum("f") return xr.Dataset(components) diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index d82320a36..7b3070682 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -5,7 +5,7 @@ import functools from abc import ABC, abstractmethod from math import isclose -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import autograd.numpy as np @@ -31,7 +31,8 @@ ) from ..exceptions import SetupError, ValidationError from ..log import log -from .autograd import TracedFloat, integrate_within_bounds +from .autograd.derivative_utils import DerivativeInfo, integrate_within_bounds +from .autograd.types import AutogradFieldMap, TracedFloat from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing from .data.data_array import DATA_ARRAY_MAP, ScalarFieldDataArray, SpatialDataArray from .data.dataset import ( @@ -1109,17 +1110,8 @@ def sel_inside(self, bounds: Bound) -> AbstractMedium: """ Autograd code """ - def compute_derivatives( - self, - field_paths: list[tuple[str, ...]], - E_der_map: ElectromagneticFieldDataset, - D_der_map: ElectromagneticFieldDataset, - eps_data: PermittivityDataset, - eps_in: complex, - eps_out: complex, - bounds: Bound, - ) -> dict[str, Any]: - """Compute the adjoint derivative for this geometry.""" + def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" raise NotImplementedError(f"Can't compute derivative for 'Medium': '{type(self)}'.") @@ -1523,24 +1515,17 @@ def from_nk(cls, n: float, k: float, freq: float, **kwargs): ) return cls(permittivity=eps, conductivity=sigma, **kwargs) - def compute_derivatives( - self, - field_paths: list[tuple[str, ...]], - E_der_map: ElectromagneticFieldDataset, - D_der_map: ElectromagneticFieldDataset, - eps_data: PermittivityDataset, - eps_in: complex, - eps_out: complex, - bounds: Bound, - ) -> dict[str, Any]: - """Compute adjoint derivatives for each of the ``fields`` given the multiplied E and D.""" + def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" # get vjps w.r.t. permittivity and conductivity of the bulk - vjps_volume = self.derivative_eps_sigma_volume(E_der_map=E_der_map, bounds=bounds) + vjps_volume = self.derivative_eps_sigma_volume( + E_der_map=derivative_info.E_der_map, bounds=derivative_info.bounds + ) # store the fields asked for by ``field_paths`` derivative_map = {} - for field_path in field_paths: + for field_path in derivative_info.paths: field_name, *_ = field_path if field_name in vjps_volume: derivative_map[field_path] = vjps_volume[field_name] @@ -1573,7 +1558,7 @@ def derivative_eps_complex_volume( vjp_value = 0.0 for field_name in ("Ex", "Ey", "Ez"): - fld = E_der_map.field_components[field_name] + fld = E_der_map[field_name] vjp_value_fld = integrate_within_bounds( arr=fld, dims=("x", "y", "z"), @@ -2438,26 +2423,17 @@ def _sel_custom_data_inside(self, bounds: Bound): eps_dataset=eps_reduced, ) - def compute_derivatives( - self, - field_paths: list[tuple[str, ...]], - E_der_map: ElectromagneticFieldDataset, - D_der_map: ElectromagneticFieldDataset, - eps_data: PermittivityDataset, - eps_in: complex, - eps_out: complex, - bounds: Bound, - ) -> dict[str, Any]: - """Compute the adjoint derivative for this geometry.""" + def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" vjps = {} - for field_path in field_paths: + for field_path in derivative_info.paths: if field_path == ("permittivity",): vjp_array = 0.0 for dim in "xyz": vjp_array += self._derivative_field_cmp( - E_der_map=E_der_map, eps_data=self.permittivity, dim=dim + E_der_map=derivative_info.E_der_map, eps_data=self.permittivity, dim=dim ) vjps[field_path] = vjp_array @@ -2465,7 +2441,9 @@ def compute_derivatives( key = field_path[1] dim = key[-1] vjps[field_path] = self._derivative_field_cmp( - E_der_map=E_der_map, eps_data=self.eps_dataset.field_components[key], dim=dim + E_der_map=derivative_info.E_der_map, + eps_data=self.eps_dataset.field_components[key], + dim=dim, ) else: @@ -2510,7 +2488,7 @@ def _derivative_field_cmp( d_vol = np.array(1.0) # TODO: probably this could be more robust. eg if the DataArray has weird edge cases - E_der_dim = E_der_map.field_components[f"E{dim}"] + E_der_dim = E_der_map[f"E{dim}"] E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum).real vjp_array = np.array(E_der_dim_interp.values).astype(float) vjp_array = vjp_array.reshape(eps_data.shape) diff --git a/tidy3d/components/structure.py b/tidy3d/components/structure.py index 6f1a83502..ea6011fd6 100644 --- a/tidy3d/components/structure.py +++ b/tidy3d/components/structure.py @@ -4,21 +4,22 @@ import pathlib from collections import defaultdict -from typing import Any, Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import pydantic.v1 as pydantic from ..constants import MICROMETER from ..exceptions import SetupError, Tidy3dError, Tidy3dImportError -from .autograd import get_static +from .autograd.derivative_utils import DerivativeInfo +from .autograd.types import AutogradFieldMap +from .autograd.utils import get_static from .base import Tidy3dBaseModel, skip_if_fields_missing -from .data.monitor_data import FieldData, PermittivityData from .geometry.utils import GeometryType, validate_no_transformed_polyslabs from .grid.grid import Coords from .medium import AbstractCustomMedium, Medium2D, MediumType from .monitor import FieldMonitor, PermittivityMonitor -from .types import TYPE_TAG_STR, Ax, Axis, Bound +from .types import TYPE_TAG_STR, Ax, Axis from .validators import validate_name_str from .viz import add_ax_if_none, equal_aspect @@ -241,21 +242,12 @@ def get_derivative_function(self, path: tuple[str, ...]) -> Callable: raise NotImplementedError(f"Can't compute derivative for structure field path: {path}.") return derivative_map[path] - def compute_derivatives( - self, - structure_paths: list[tuple[str, ...]], - E_der_map: FieldData, - D_der_map: FieldData, - eps_data: PermittivityData, - eps_in: complex, - eps_out: complex, - bounds: Bound, - ) -> dict[tuple[str, ...], Any]: + def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: """Compute adjoint gradients given the forward and adjoint fields""" # generate a mapping from the 'medium', or 'geometry' tag to the list of fields for VJP structure_fields_map = defaultdict(list) - for structure_path in structure_paths: + for structure_path in derivative_info.paths: med_or_geo, *field_path = structure_path field_path = tuple(field_path) if med_or_geo not in ("geometry", "medium"): @@ -273,15 +265,8 @@ def compute_derivatives( for med_or_geo, field_paths in structure_fields_map.items(): # grab derivative values {field_name -> vjp_value} med_or_geo_field = self.medium if med_or_geo == "medium" else self.geometry - derivative_values_map = med_or_geo_field.compute_derivatives( - field_paths=field_paths, - E_der_map=E_der_map, - D_der_map=D_der_map, - eps_data=eps_data, - eps_in=eps_in, - eps_out=eps_out, - bounds=bounds, - ) + info = derivative_info.updated_copy(paths=field_paths) + derivative_values_map = med_or_geo_field.compute_derivatives(derivative_info=info) # construct map of {field path -> derivative value} for field_path, derivative_value in derivative_values_map.items(): diff --git a/tidy3d/plugins/autograd/README.md b/tidy3d/plugins/autograd/README.md index c1052a45c..34fea6497 100644 --- a/tidy3d/plugins/autograd/README.md +++ b/tidy3d/plugins/autograd/README.md @@ -115,9 +115,12 @@ The following components are traceable as inputs to the `td.Simulation` - `Medium.permittivity` - `Medium.conductivity` + - `CustomMedium.permittivity` - `CustomMedium.eps_dataset` +- `GeometryGroup.geometries` + The following components are traceable as outputs of the `td.SimulationData` - `ModeData.amps` @@ -147,7 +150,6 @@ Next on our roadmap (targeting 2.8 and 2.9, summer 2024) is to support: - custom (spatially-dependent) dispersive models, allowing topology optimization with metals. - `ComplexPolySlab` -- `GeometryGroup` Later this year (2024), we plan to support: diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index 66ad1c03b..a242990f8 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -5,9 +5,11 @@ import numpy as np from autograd.builtins import dict as dict_ag +from autograd.extend import defvjp, primitive import tidy3d as td -from tidy3d.components.autograd import AutogradFieldMap, defvjp, get_static, primitive +from tidy3d.components.autograd import AutogradFieldMap, get_static +from tidy3d.components.autograd.derivative_utils import DerivativeInfo from ..asynchronous import DEFAULT_DATA_DIR from ..asynchronous import run_async as run_async_webapi @@ -58,7 +60,7 @@ def is_valid_for_autograd(simulation: td.Simulation) -> bool: return False # if too many structures, raise an error - structure_indices = [i for key, i, *_ in traced_fields.keys() if key == "structures"] + structure_indices = {i for key, i, *_ in traced_fields.keys() if key == "structures"} num_traced_structures = len(structure_indices) if num_traced_structures > MAX_NUM_TRACED_STRUCTURES: raise ValueError( @@ -654,16 +656,18 @@ def postprocess_adj( eps_in = np.mean(structure.medium.eps_model(td.C_0)) eps_out = np.mean(sim_data_orig.simulation.medium.eps_model(td.C_0)) - vjp_value_map = structure.compute_derivatives( - structure_paths=structure_paths, - E_der_map=E_der_map, - D_der_map=D_der_map, - eps_data=eps_fwd, + derivative_info = DerivativeInfo( + paths=structure_paths, + E_der_map=E_der_map.field_components, + D_der_map=D_der_map.field_components, + eps_data=eps_fwd.field_components, eps_in=eps_in, eps_out=eps_out, bounds=structure.geometry.bounds, # TODO: pass intersecting bounds with sim? ) + vjp_value_map = structure.compute_derivatives(derivative_info) + # extract VJPs and put back into sim_fields_vjp AutogradFieldMap for structure_path, vjp_value in vjp_value_map.items(): sim_path = tuple(["structures", structure_index] + list(structure_path))