Skip to content

Commit

Permalink
JaxStructureStaticGeometry and JaxStructureStaticMedium to mix differ…
Browse files Browse the repository at this point in the history
…entiable and static geometry / medium in JaxStructure
  • Loading branch information
tylerflex committed Dec 3, 2023
1 parent b07b167 commit 86d9a76
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 64 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Ability to downsample recorded near fields to speed up server-side far field projections.
- `FieldData.apply_phase(phase)` to multiply field data by a phase.
- Optional `phase` argument to `SimulationData.plot_field` that applies a phase to complex-valued fields.
- Ability to mix regular mediums and geometries with differentiable analogues in `JaxStructure`. Enables support for shape optimization with dispersive mediums. New classes `JaxStructureStaticGeometry` and `JaxStructureStaticMedium` accept regular `Tidy3D` geometry and medium classes, respectively.

### Changed
- Indent for the json string of Tidy3D models has been changed to `None` when used internally; kept as `indent=4` for writing to `json` and `yaml` files.
Expand Down
Binary file modified tests/sims/simulation_2_5_0rc3.h5
Binary file not shown.
18 changes: 16 additions & 2 deletions tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from tidy3d.plugins.adjoint.components.geometry import JaxGeometryGroup
from tidy3d.plugins.adjoint.components.medium import JaxMedium, JaxAnisotropicMedium
from tidy3d.plugins.adjoint.components.medium import JaxCustomMedium, MAX_NUM_CELLS_CUSTOM_MEDIUM
from tidy3d.plugins.adjoint.components.structure import JaxStructure
from tidy3d.plugins.adjoint.components.structure import (
JaxStructure,
JaxStructureStaticMedium,
JaxStructureStaticGeometry,
)
from tidy3d.plugins.adjoint.components.simulation import JaxSimulation, JaxInfo, RUN_TIME_FACTOR
from tidy3d.plugins.adjoint.components.simulation import MAX_NUM_INPUT_STRUCTURES
from tidy3d.plugins.adjoint.components.data.sim_data import JaxSimulationData
Expand All @@ -33,7 +37,7 @@
from tidy3d.plugins.adjoint.utils.penalty import RadiusPenalty
from tidy3d.plugins.adjoint.utils.filter import ConicFilter, BinaryProjector, CircularFilter
from tidy3d.web.api.container import BatchData

import tidy3d.material_library as material_library
from ..utils import run_emulated, assert_log_level, log_capture, run_async_emulated
from ..test_components.test_custom import CUSTOM_MEDIUM

Expand Down Expand Up @@ -253,6 +257,14 @@ def make_sim(

jax_geo_group = JaxGeometryGroup(geometries=[jax_polyslab1, jax_polyslab1])
jax_struct_group = JaxStructure(geometry=jax_geo_group, medium=jax_med1)

jax_struct_static_med = JaxStructureStaticMedium(
geometry=jax_box1, medium=td.Medium() # material_library["Ag"]["Rakic1998BB"]
)
jax_struct_static_geo = JaxStructureStaticGeometry(
geometry=td.Box(size=(1, 1, 1)), medium=jax_med1
)

# TODO: Add new geometries as they are created.

# NOTE: Any new output monitors should be added below as they are made
Expand Down Expand Up @@ -305,6 +317,8 @@ def make_sim(
jax_struct3,
jax_struct_group,
jax_struct_custom_anis,
jax_struct_static_med,
jax_struct_static_geo,
),
output_monitors=(output_mnt1, output_mnt2, output_mnt3, output_mnt4),
sources=[src],
Expand Down
14 changes: 10 additions & 4 deletions tidy3d/plugins/adjoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
try:
from .components.geometry import JaxBox, JaxPolySlab, JaxGeometryGroup
from .components.medium import JaxMedium, JaxAnisotropicMedium, JaxCustomMedium
from .components.structure import JaxStructure
from .components.structure import (
JaxStructure,
JaxStructureStaticGeometry,
JaxStructureStaticMedium,
)
from .components.simulation import JaxSimulation
from .components.data.sim_data import JaxSimulationData
from .components.data.monitor_data import JaxModeData
from .components.data.dataset import JaxPermittivityDataset
from .components.data.data_array import JaxDataArray
except ImportError as e:
raise ImportError(
"The 'jax' package is required for adjoint plugin and not installed. "
"To get the appropriate packages, install tidy3d using '[jax]' option, for example: "
"$pip install 'tidy3d[jax]'."
"The 'jax' package is required for adjoint plugin. We were not able to import it. "
"To get the appropriate packages for your system, install tidy3d using '[jax]' option, "
"for example: $pip install 'tidy3d[jax]'."
) from e

try:
Expand All @@ -30,6 +34,8 @@
"JaxAnisotropicMedium",
"JaxCustomMedium",
"JaxStructure",
"JaxStructureStaticMedium",
"JaxStructureStaticGeometry",
"JaxSimulation",
"JaxSimulationData",
"JaxModeData",
Expand Down
6 changes: 4 additions & 2 deletions tidy3d/plugins/adjoint/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Component imports for adjoint plugin. from tidy3d.plugins.adjoint.components import *"""

# import the jax version of tidy3d components
from .geometry import JaxBox # , JaxPolySlab
from .geometry import JaxBox, JaxPolySlab
from .medium import JaxMedium, JaxAnisotropicMedium, JaxCustomMedium
from .structure import JaxStructure
from .structure import JaxStructure, JaxStructureStaticMedium, JaxStructureStaticGeometry
from .simulation import JaxSimulation
from .data.sim_data import JaxSimulationData
from .data.monitor_data import JaxModeData
Expand All @@ -18,6 +18,8 @@
"JaxAnisotropicMedium",
"JaxCustomMedium",
"JaxStructure",
"JaxStructureStaticMedium",
"JaxStructureStaticGeometry",
"JaxSimulation",
"JaxSimulationData",
"JaxModeData",
Expand Down
10 changes: 9 additions & 1 deletion tidy3d/plugins/adjoint/components/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

from typing import Dict, Tuple, Union, Callable, Optional
from abc import ABC
from abc import ABC, abstractmethod

import pydantic.v1 as pd
import numpy as np
Expand Down Expand Up @@ -34,6 +34,14 @@
class AbstractJaxMedium(ABC, JaxObject):
"""Holds some utility functions for Jax medium types."""

def to_tidy3d(self) -> AbstractJaxMedium:
"""Convert self to tidy3d component."""
return self.to_medium()

@abstractmethod
def to_medium(self) -> AbstractJaxMedium:
"""Convert self to medium."""

def _get_volume_disc(
self, grad_data: FieldData, sim_bounds: Bound, wvl_mat: float
) -> Tuple[Dict[str, np.ndarray], float]:
Expand Down
37 changes: 32 additions & 5 deletions tidy3d/plugins/adjoint/components/simulation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Defines a jax-compatible simulation."""
from __future__ import annotations

from typing import Tuple, Union, List, Dict
from typing import Tuple, Union, List, Dict, Literal
from multiprocessing import Pool

import pydantic.v1 as pd
Expand All @@ -21,7 +21,12 @@
from ....exceptions import AdjointError

from .base import JaxObject
from .structure import JaxStructure
from .structure import (
JaxStructure,
JaxStructureType,
JaxStructureStaticMedium,
JaxStructureStaticGeometry,
)
from .geometry import JaxPolySlab, JaxGeometryGroup


Expand Down Expand Up @@ -82,12 +87,20 @@ class JaxInfo(Tidy3dBaseModel):
units=SECOND,
)

input_structure_types: Tuple[
Literal["JaxStructure", "JaxStructureStaticMedium", "JaxStructureStaticGeometry"], ...
] = pd.Field(
(),
title="Input Structure Types",
description="Type of the original input_structures (as strings).",
)


@register_pytree_node_class
class JaxSimulation(Simulation, JaxObject):
"""A :class:`.Simulation` registered with jax."""

input_structures: Tuple[JaxStructure, ...] = pd.Field(
input_structures: Tuple[annotate_type(JaxStructureType), ...] = pd.Field(
(),
title="Input Structures",
description="Tuple of jax-compatible structures"
Expand Down Expand Up @@ -171,7 +184,8 @@ def _restrict_input_structures(cls, val):
def _warn_overlap(cls, val, values):
"""Print appropriate warning if structures intersect in ways that cause gradient error."""

input_structures = list(val)
input_structures = [s for s in val if "geometry" in s._differentiable_fields]

structures = list(values.get("structures"))

# if the center and size of all structure geometries do not contain all numbers, skip check
Expand Down Expand Up @@ -349,6 +363,7 @@ def to_simulation(self) -> Tuple[Simulation, JaxInfo]:
num_grad_eps_monitors=len(self.grad_eps_monitors),
fwidth_adjoint=self.fwidth_adjoint,
run_time_adjoint=self.run_time_adjoint,
input_structure_types=[s.type for s in self.input_structures],
)

return sim, jax_info
Expand Down Expand Up @@ -556,7 +571,19 @@ def split_structures(

# split the list based on these numbers
structures = all_structures[:num_structs]
input_structures = [JaxStructure.from_structure(s) for s in all_structures[num_structs:]]
structure_type_map = dict(
JaxStructure=JaxStructure,
JaxStructureStaticMedium=JaxStructureStaticMedium,
JaxStructureStaticGeometry=JaxStructureStaticGeometry,
)

input_structures = []
for struct_type_str, struct in zip(
jax_info.input_structure_types, all_structures[num_structs:]
):
struct_type = structure_type_map[struct_type_str]
new_structure = struct_type.from_structure(struct)
input_structures.append(new_structure)

# return a dictionary containing these split structures
return dict(structures=structures, input_structures=input_structures)
Expand Down
Loading

0 comments on commit 86d9a76

Please sign in to comment.