diff --git a/CHANGELOG.md b/CHANGELOG.md index c588d8acf..44cf068a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to downsample recorded near fields to speed up server-side far field projections. - `FieldData.apply_phase(phase)` to multiply field data by a phase. - Optional `phase` argument to `SimulationData.plot_field` that applies a phase to complex-valued fields. +- Ability to mix regular mediums and geometries with differentiable analogues in `JaxStructure`. Enables support for shape optimization with dispersive mediums. New classes `JaxStructureStaticGeometry` and `JaxStructureStaticMedium` accept regular `Tidy3D` geometry and medium classes, respectively. ### Changed - Indent for the json string of Tidy3D models has been changed to `None` when used internally; kept as `indent=4` for writing to `json` and `yaml` files. diff --git a/tests/sims/simulation_2_5_0rc3.h5 b/tests/sims/simulation_2_5_0rc3.h5 index 6742fa208..c93f159de 100644 Binary files a/tests/sims/simulation_2_5_0rc3.h5 and b/tests/sims/simulation_2_5_0rc3.h5 differ diff --git a/tests/test_plugins/test_adjoint.py b/tests/test_plugins/test_adjoint.py index 59eaf6b23..ba9e672b3 100644 --- a/tests/test_plugins/test_adjoint.py +++ b/tests/test_plugins/test_adjoint.py @@ -20,7 +20,11 @@ from tidy3d.plugins.adjoint.components.geometry import JaxGeometryGroup from tidy3d.plugins.adjoint.components.medium import JaxMedium, JaxAnisotropicMedium from tidy3d.plugins.adjoint.components.medium import JaxCustomMedium, MAX_NUM_CELLS_CUSTOM_MEDIUM -from tidy3d.plugins.adjoint.components.structure import JaxStructure +from tidy3d.plugins.adjoint.components.structure import ( + JaxStructure, + JaxStructureStaticMedium, + JaxStructureStaticGeometry, +) from tidy3d.plugins.adjoint.components.simulation import JaxSimulation, JaxInfo, RUN_TIME_FACTOR from tidy3d.plugins.adjoint.components.simulation import MAX_NUM_INPUT_STRUCTURES from tidy3d.plugins.adjoint.components.data.sim_data import JaxSimulationData @@ -33,7 +37,7 @@ from tidy3d.plugins.adjoint.utils.penalty import RadiusPenalty from tidy3d.plugins.adjoint.utils.filter import ConicFilter, BinaryProjector, CircularFilter from tidy3d.web.api.container import BatchData - +import tidy3d.material_library as material_library from ..utils import run_emulated, assert_log_level, log_capture, run_async_emulated from ..test_components.test_custom import CUSTOM_MEDIUM @@ -253,6 +257,14 @@ def make_sim( jax_geo_group = JaxGeometryGroup(geometries=[jax_polyslab1, jax_polyslab1]) jax_struct_group = JaxStructure(geometry=jax_geo_group, medium=jax_med1) + + jax_struct_static_med = JaxStructureStaticMedium( + geometry=jax_box1, medium=td.Medium() # material_library["Ag"]["Rakic1998BB"] + ) + jax_struct_static_geo = JaxStructureStaticGeometry( + geometry=td.Box(size=(1, 1, 1)), medium=jax_med1 + ) + # TODO: Add new geometries as they are created. # NOTE: Any new output monitors should be added below as they are made @@ -305,6 +317,8 @@ def make_sim( jax_struct3, jax_struct_group, jax_struct_custom_anis, + jax_struct_static_med, + jax_struct_static_geo, ), output_monitors=(output_mnt1, output_mnt2, output_mnt3, output_mnt4), sources=[src], diff --git a/tidy3d/plugins/adjoint/__init__.py b/tidy3d/plugins/adjoint/__init__.py index 7630d93f1..9bfc03267 100644 --- a/tidy3d/plugins/adjoint/__init__.py +++ b/tidy3d/plugins/adjoint/__init__.py @@ -4,7 +4,11 @@ try: from .components.geometry import JaxBox, JaxPolySlab, JaxGeometryGroup from .components.medium import JaxMedium, JaxAnisotropicMedium, JaxCustomMedium - from .components.structure import JaxStructure + from .components.structure import ( + JaxStructure, + JaxStructureStaticGeometry, + JaxStructureStaticMedium, + ) from .components.simulation import JaxSimulation from .components.data.sim_data import JaxSimulationData from .components.data.monitor_data import JaxModeData @@ -12,9 +16,9 @@ from .components.data.data_array import JaxDataArray except ImportError as e: raise ImportError( - "The 'jax' package is required for adjoint plugin and not installed. " - "To get the appropriate packages, install tidy3d using '[jax]' option, for example: " - "$pip install 'tidy3d[jax]'." + "The 'jax' package is required for adjoint plugin. We were not able to import it. " + "To get the appropriate packages for your system, install tidy3d using '[jax]' option, " + "for example: $pip install 'tidy3d[jax]'." ) from e try: @@ -30,6 +34,8 @@ "JaxAnisotropicMedium", "JaxCustomMedium", "JaxStructure", + "JaxStructureStaticMedium", + "JaxStructureStaticGeometry", "JaxSimulation", "JaxSimulationData", "JaxModeData", diff --git a/tidy3d/plugins/adjoint/components/__init__.py b/tidy3d/plugins/adjoint/components/__init__.py index 59d5df0ba..e70157765 100644 --- a/tidy3d/plugins/adjoint/components/__init__.py +++ b/tidy3d/plugins/adjoint/components/__init__.py @@ -1,9 +1,9 @@ """Component imports for adjoint plugin. from tidy3d.plugins.adjoint.components import *""" # import the jax version of tidy3d components -from .geometry import JaxBox # , JaxPolySlab +from .geometry import JaxBox, JaxPolySlab from .medium import JaxMedium, JaxAnisotropicMedium, JaxCustomMedium -from .structure import JaxStructure +from .structure import JaxStructure, JaxStructureStaticMedium, JaxStructureStaticGeometry from .simulation import JaxSimulation from .data.sim_data import JaxSimulationData from .data.monitor_data import JaxModeData @@ -18,6 +18,8 @@ "JaxAnisotropicMedium", "JaxCustomMedium", "JaxStructure", + "JaxStructureStaticMedium", + "JaxStructureStaticGeometry", "JaxSimulation", "JaxSimulationData", "JaxModeData", diff --git a/tidy3d/plugins/adjoint/components/medium.py b/tidy3d/plugins/adjoint/components/medium.py index 3512270bd..dc6826e7d 100644 --- a/tidy3d/plugins/adjoint/components/medium.py +++ b/tidy3d/plugins/adjoint/components/medium.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Dict, Tuple, Union, Callable, Optional -from abc import ABC +from abc import ABC, abstractmethod import pydantic.v1 as pd import numpy as np @@ -34,6 +34,14 @@ class AbstractJaxMedium(ABC, JaxObject): """Holds some utility functions for Jax medium types.""" + def to_tidy3d(self) -> AbstractJaxMedium: + """Convert self to tidy3d component.""" + return self.to_medium() + + @abstractmethod + def to_medium(self) -> AbstractJaxMedium: + """Convert self to medium.""" + def _get_volume_disc( self, grad_data: FieldData, sim_bounds: Bound, wvl_mat: float ) -> Tuple[Dict[str, np.ndarray], float]: diff --git a/tidy3d/plugins/adjoint/components/simulation.py b/tidy3d/plugins/adjoint/components/simulation.py index 08213bc40..62a27dc31 100644 --- a/tidy3d/plugins/adjoint/components/simulation.py +++ b/tidy3d/plugins/adjoint/components/simulation.py @@ -1,7 +1,7 @@ """Defines a jax-compatible simulation.""" from __future__ import annotations -from typing import Tuple, Union, List, Dict +from typing import Tuple, Union, List, Dict, Literal from multiprocessing import Pool import pydantic.v1 as pd @@ -21,7 +21,12 @@ from ....exceptions import AdjointError from .base import JaxObject -from .structure import JaxStructure +from .structure import ( + JaxStructure, + JaxStructureType, + JaxStructureStaticMedium, + JaxStructureStaticGeometry, +) from .geometry import JaxPolySlab, JaxGeometryGroup @@ -82,12 +87,20 @@ class JaxInfo(Tidy3dBaseModel): units=SECOND, ) + input_structure_types: Tuple[ + Literal["JaxStructure", "JaxStructureStaticMedium", "JaxStructureStaticGeometry"], ... + ] = pd.Field( + (), + title="Input Structure Types", + description="Type of the original input_structures (as strings).", + ) + @register_pytree_node_class class JaxSimulation(Simulation, JaxObject): """A :class:`.Simulation` registered with jax.""" - input_structures: Tuple[JaxStructure, ...] = pd.Field( + input_structures: Tuple[annotate_type(JaxStructureType), ...] = pd.Field( (), title="Input Structures", description="Tuple of jax-compatible structures" @@ -171,7 +184,8 @@ def _restrict_input_structures(cls, val): def _warn_overlap(cls, val, values): """Print appropriate warning if structures intersect in ways that cause gradient error.""" - input_structures = list(val) + input_structures = [s for s in val if "geometry" in s._differentiable_fields] + structures = list(values.get("structures")) # if the center and size of all structure geometries do not contain all numbers, skip check @@ -349,6 +363,7 @@ def to_simulation(self) -> Tuple[Simulation, JaxInfo]: num_grad_eps_monitors=len(self.grad_eps_monitors), fwidth_adjoint=self.fwidth_adjoint, run_time_adjoint=self.run_time_adjoint, + input_structure_types=[s.type for s in self.input_structures], ) return sim, jax_info @@ -556,7 +571,19 @@ def split_structures( # split the list based on these numbers structures = all_structures[:num_structs] - input_structures = [JaxStructure.from_structure(s) for s in all_structures[num_structs:]] + structure_type_map = dict( + JaxStructure=JaxStructure, + JaxStructureStaticMedium=JaxStructureStaticMedium, + JaxStructureStaticGeometry=JaxStructureStaticGeometry, + ) + + input_structures = [] + for struct_type_str, struct in zip( + jax_info.input_structure_types, all_structures[num_structs:] + ): + struct_type = structure_type_map[struct_type_str] + new_structure = struct_type.from_structure(struct) + input_structures.append(new_structure) # return a dictionary containing these split structures return dict(structures=structures, input_structures=input_structures) diff --git a/tidy3d/plugins/adjoint/components/structure.py b/tidy3d/plugins/adjoint/components/structure.py index 065cfcb09..c358de6cc 100644 --- a/tidy3d/plugins/adjoint/components/structure.py +++ b/tidy3d/plugins/adjoint/components/structure.py @@ -1,7 +1,7 @@ """Defines a jax-compatible structure and its conversion to a gradient monitor.""" from __future__ import annotations -from typing import List +from typing import List, Union, Dict import pydantic.v1 as pd import numpy as np @@ -12,60 +12,91 @@ from ....components.monitor import FieldMonitor from ....components.data.monitor_data import FieldData, PermittivityData from ....components.types import Bound, TYPE_TAG_STR +from ....components.medium import MediumType +from ....components.geometry.utils import GeometryType from .base import JaxObject from .medium import JaxMediumType, JAX_MEDIUM_MAP -from .geometry import JaxGeometryType, JAX_GEOMETRY_MAP +from .geometry import JaxGeometryType, JAX_GEOMETRY_MAP, JaxBox +GEO_MED_MAPPINGS = dict(geometry=JAX_GEOMETRY_MAP, medium=JAX_MEDIUM_MAP) -@register_pytree_node_class -class JaxStructure(Structure, JaxObject): + +class AbstractJaxStructure(Structure, JaxObject): """A :class:`.Structure` registered with jax.""" - geometry: JaxGeometryType = pd.Field( - ..., - title="Geometry", - description="Geometry of the structure, which is jax-compatible.", - jax_field=True, - discriminator=TYPE_TAG_STR, - ) + geometry: Union[JaxGeometryType, GeometryType] + medium: Union[JaxMediumType, MediumType] - medium: JaxMediumType = pd.Field( - ..., - title="Medium", - description="Medium of the structure, which is jax-compatible.", - jax_field=True, - discriminator=TYPE_TAG_STR, - ) + # which of "geometry" or "medium" is differentiable for this class + _differentiable_fields = () + + @pd.validator("medium", always=True) + def _check_2d_geometry(cls, val, values): + """Override validator checking 2D geometry, which triggers unnecessarily for gradients.""" + return val + + @property + def jax_fields(self): + """The fields that are jax-traced for this class.""" + return dict(geometry=self.geometry, medium=self.medium) + + @property + def exclude_fields(self): + """Fields to exclude from the self dict.""" + return set(["type"] + list(self.jax_fields.keys())) def to_structure(self) -> Structure: """Convert :class:`.JaxStructure` instance to :class:`.Structure`""" - self_dict = self.dict(exclude={"type", "geometry", "medium"}) - self_dict["geometry"] = self.geometry.to_tidy3d() - self_dict["medium"] = self.medium.to_medium() + self_dict = self.dict(exclude=self.exclude_fields) + for key, component in self.jax_fields.items(): + if key in self._differentiable_fields: + self_dict[key] = component.to_tidy3d() + else: + self_dict[key] = component return Structure.parse_obj(self_dict) @classmethod def from_structure(cls, structure: Structure) -> JaxStructure: """Convert :class:`.Structure` to :class:`.JaxStructure`.""" - # get the appropriate jax types corresponding to the td.Structure fields - jax_geometry_type = JAX_GEOMETRY_MAP[type(structure.geometry)] - jax_medium_type = JAX_MEDIUM_MAP[type(structure.medium)] + struct_dict = structure.dict(exclude={"type"}) + + jax_fields = dict(geometry=structure.geometry, medium=structure.medium) - # load them into the JaxStructure dictionary and parse it into an instance - struct_dict = structure.dict(exclude={"type", "geometry", "medium"}) - struct_dict["geometry"] = jax_geometry_type.from_tidy3d(structure.geometry) - struct_dict["medium"] = jax_medium_type.from_tidy3d(structure.medium) + for key, component in jax_fields.items(): + if key in cls._differentiable_fields: + type_map = GEO_MED_MAPPINGS[key] + jax_type = type_map[type(component)] + struct_dict[key] = jax_type.from_tidy3d(component) + else: + struct_dict[key] = component return cls.parse_obj(struct_dict) - @pd.validator("medium", always=True) - def _check_2d_geometry(cls, val, values): - """Override validator checking 2D geometry, which triggers unnecessarily for gradients.""" - return val + def make_grad_monitors(self, freqs: List[float], name: str) -> FieldMonitor: + """Return gradient monitor associated with this object.""" + if "geometry" not in self._differentiable_fields: + # make a fake JaxBox to be able to call .make_grad_monitors + rmin, rmax = self.geometry.bounds + geometry = JaxBox.from_bounds(rmin=rmin, rmax=rmax) + else: + geometry = self.geometry + return geometry.make_grad_monitors(freqs=freqs, name=name) + + def _get_medium_params( + self, + grad_data_eps: PermittivityData, + ) -> Dict[str, float]: + """Compute params in the material of this structure.""" + freq_max = max(grad_data_eps.eps_xx.f) + eps_in = self.medium.eps_model(frequency=freq_max) + ref_ind = np.sqrt(np.max(np.real(eps_in))) + wvl_free_space = C_0 / freq_max + wvl_mat = wvl_free_space / ref_ind + return dict(wvl_mat=wvl_mat, eps_in=eps_in) - def store_vjp( + def geometry_vjp( self, grad_data_fwd: FieldData, grad_data_adj: FieldData, @@ -73,37 +104,150 @@ def store_vjp( sim_bounds: Bound, eps_out: complex, num_proc: int = 1, - ) -> JaxStructure: - """Returns the gradient of the structure parameters given forward and adjoint field data.""" + ) -> JaxGeometryType: + """Compute the VJP for the structure geometry.""" - # compute minimum wavelength in material (to use for determining integration points) - freq_max = max(grad_data_eps.eps_xx.f) - wvl_free_space = C_0 / freq_max - eps_in = self.medium.eps_model(frequency=freq_max) - ref_ind = np.sqrt(np.max(np.real(eps_in))) - wvl_mat = wvl_free_space / ref_ind + medium_params = self._get_medium_params(grad_data_eps=grad_data_eps) - geo_vjp = self.geometry.store_vjp( + return self.geometry.store_vjp( grad_data_fwd=grad_data_fwd, grad_data_adj=grad_data_adj, grad_data_eps=grad_data_eps, sim_bounds=sim_bounds, - wvl_mat=wvl_mat, + wvl_mat=medium_params["wvl_mat"], eps_out=eps_out, - eps_in=eps_in, + eps_in=medium_params["eps_in"], num_proc=num_proc, ) - medium_vjp = self.medium.store_vjp( + def medium_vjp( + self, + grad_data_fwd: FieldData, + grad_data_adj: FieldData, + grad_data_eps: PermittivityData, + sim_bounds: Bound, + ) -> JaxMediumType: + """Compute the VJP for the structure medium.""" + + medium_params = self._get_medium_params(grad_data_eps=grad_data_eps) + + return self.medium.store_vjp( grad_data_fwd=grad_data_fwd, grad_data_adj=grad_data_adj, sim_bounds=sim_bounds, - wvl_mat=wvl_mat, + wvl_mat=medium_params["wvl_mat"], inside_fn=self.geometry.inside, ) - return self.copy(update=dict(geometry=geo_vjp, medium=medium_vjp)) + def store_vjp( + self, + # field_keys: List[Literal["medium", "geometry"]], + grad_data_fwd: FieldData, + grad_data_adj: FieldData, + grad_data_eps: PermittivityData, + sim_bounds: Bound, + eps_out: complex, + num_proc: int = 1, + ) -> JaxStructure: + """Returns the gradient of the structure parameters given forward and adjoint field data.""" - def make_grad_monitors(self, freqs: List[float], name: str) -> FieldMonitor: - """Return gradient monitor associated with this object.""" - return self.geometry.make_grad_monitors(freqs=freqs, name=name) + # return right away if field_keys are not present for some reason + if not self._differentiable_fields: + return self + + vjp_dict = {} + + # compute minimum wavelength in material (to use for determining integration points) + if "geometry" in self._differentiable_fields: + vjp_dict["geometry"] = self.geometry_vjp( + grad_data_fwd=grad_data_fwd, + grad_data_adj=grad_data_adj, + grad_data_eps=grad_data_eps, + sim_bounds=sim_bounds, + eps_out=eps_out, + num_proc=num_proc, + ) + + if "medium" in self._differentiable_fields: + vjp_dict["medium"] = self.medium_vjp( + grad_data_fwd=grad_data_fwd, + grad_data_adj=grad_data_adj, + grad_data_eps=grad_data_eps, + sim_bounds=sim_bounds, + ) + + return self.updated_copy(**vjp_dict) + + +@register_pytree_node_class +class JaxStructure(AbstractJaxStructure, JaxObject): + """A :class:`.Structure` registered with jax.""" + + geometry: JaxGeometryType = pd.Field( + ..., + title="Geometry", + description="Geometry of the structure, which is jax-compatible.", + jax_field=True, + discriminator=TYPE_TAG_STR, + ) + + medium: JaxMediumType = pd.Field( + ..., + title="Medium", + description="Medium of the structure, which is jax-compatible.", + jax_field=True, + discriminator=TYPE_TAG_STR, + ) + + _differentiable_fields = ("medium", "geometry") + + +@register_pytree_node_class +class JaxStructureStaticMedium(AbstractJaxStructure, JaxObject): + """A :class:`.Structure` registered with jax.""" + + geometry: JaxGeometryType = pd.Field( + ..., + title="Geometry", + description="Geometry of the structure, which is jax-compatible.", + jax_field=True, + discriminator=TYPE_TAG_STR, + ) + + medium: MediumType = pd.Field( + ..., + title="Medium", + description="Regular ``tidy3d`` medium of the structure, non differentiable. " + "Supports dispersive materials.", + jax_field=False, + discriminator=TYPE_TAG_STR, + ) + + _differentiable_fields = ("geometry",) + + +@register_pytree_node_class +class JaxStructureStaticGeometry(AbstractJaxStructure, JaxObject): + """A :class:`.Structure` registered with jax.""" + + geometry: GeometryType = pd.Field( + ..., + title="Geometry", + description="Regular ``tidy3d`` geometry of the structure, non differentiable. " + "Supports angled sidewalls and other complex geometries.", + jax_field=False, + discriminator=TYPE_TAG_STR, + ) + + medium: JaxMediumType = pd.Field( + ..., + title="Medium", + description="Medium of the structure, which is jax-compatible.", + jax_field=True, + discriminator=TYPE_TAG_STR, + ) + + _differentiable_fields = ("medium",) + + +JaxStructureType = Union[JaxStructure, JaxStructureStaticMedium, JaxStructureStaticGeometry]