Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2nd round of autograd fixes #1923

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Improved passivity enforcement near high-Q poles in `FastDispersionFitter`. Failed passivity enforcement could lead to simulation divergences.
- More helpful error and suggestion if users try to differentiate w.r.t. unsupported `FluxMonitor` output.
- Removed positive warnings in Simulation validators for Bloch boundary conditions.
- Improved accuracy in `Box` shifting boundary gradients.
- Improved accuracy in `FieldData` operations involving H fields (like `.flux`).
- Better error and warning handling in the autograd pipeline.

## [2.7.2] - 2024-08-07

Expand Down
18 changes: 9 additions & 9 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

# whether to run numerical gradient tests, off by default because it runs real simulations
RUN_NUMERICAL = False
_NUMERICAL_COMBINATION = ("size_element", "mode")

TEST_MODES = ("pipeline", "adjoint", "speed")
TEST_MODE = "speed" if TEST_POLYSLAB_SPEED else "pipeline"
Expand All @@ -63,7 +64,7 @@
FWIDTH = FREQ0 / 10

# sim sizes
LZ = 7 * WVL
LZ = 7.0 * WVL

IS_3D = False

Expand Down Expand Up @@ -420,9 +421,11 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]:
def make_monitors() -> dict[str, tuple[td.Monitor, typing.Callable[[td.SimulationData], float]]]:
"""Make a dictionary of all the possible monitors in the simulation."""

X = 0.75

mode_mnt = td.ModeMonitor(
size=(2, 2, 0),
center=(0, 0, LZ / 2 - WVL),
center=(0, 0, +LZ / 2 - X * WVL),
mode_spec=td.ModeSpec(),
freqs=[FREQ0],
name="mode",
Expand All @@ -444,7 +447,7 @@ def diff_postprocess_fn(sim_data, mnt_data):

field_vol = td.FieldMonitor(
size=(1, 1, 0),
center=(0, 0, +LZ / 2 - WVL),
center=(0, 0, +LZ / 2 - X * WVL),
freqs=[FREQ0],
name="field_vol",
)
Expand All @@ -453,12 +456,9 @@ def field_vol_postprocess_fn(sim_data, mnt_data):
value = 0.0
for _, val in mnt_data.field_components.items():
value = value + abs(anp.sum(val.values))
# field components numerical is 3x higher
intensity = anp.nan_to_num(anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values))
value += intensity
# intensity numerical is 4.79x higher
value += anp.sum(mnt_data.flux.values)
# flux is 18.4x lower
return value

field_point = td.FieldMonitor(
Expand All @@ -471,7 +471,7 @@ def field_vol_postprocess_fn(sim_data, mnt_data):
def field_point_postprocess_fn(sim_data, mnt_data):
value = 0.0
for _, val in mnt_data.field_components.items():
value += abs(anp.sum(val.values))
value += abs(anp.sum(abs(val.values)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double `abs`? Not that it matters, I guess.

value += anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values)
return value

Expand Down Expand Up @@ -529,7 +529,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = True) -> None:
args = [("polyslab", "mode")]


# args = [("custom_med", "mode")]
# args = [("size_element", "mode")]


def get_functions(structure_key: str, monitor_key: str) -> typing.Callable:
Expand Down Expand Up @@ -599,7 +599,7 @@ def test_polyslab_axis_ops(axis):


@pytest.mark.skipif(not RUN_NUMERICAL, reason="Numerical gradient tests runs through web API.")
@pytest.mark.parametrize("structure_key, monitor_key", (("cylinder", "mode"),))
@pytest.mark.parametrize("structure_key, monitor_key", (_NUMERICAL_COMBINATION,))
def test_autograd_numerical(structure_key, monitor_key):
"""Test an objective function through tidy3d autograd."""

Expand Down
48 changes: 19 additions & 29 deletions tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ...log import log
from ..base import TYPE_TAG_STR, cached_property, skip_if_fields_missing
from ..base_sim.data.monitor_data import AbstractMonitorData
from ..geometry.base import Box
from ..grid.grid import Coords, Grid
from ..medium import Medium, MediumType
from ..monitor import (
Expand Down Expand Up @@ -1069,46 +1068,33 @@ def to_adjoint_field_sources(self, fwidth: float) -> List[CustomCurrentSource]:

sources = []

# Define source geometry based on coordinates in the data
data_mins = []
data_maxs = []
source_geo = self.monitor.geometry
freqs = self.monitor.freqs

def shift_value(coords) -> float:
"""How much to shift the geometry by along a dimension (only if > 1D)."""
return SHIFT_VALUE_ADJ_FLD_SRC if len(coords) > 1 else 0

for _, field_component in self.field_components.items():
coords = field_component.coords
data_mins.append({key: min(val) + shift_value(val) for key, val in coords.items()})
data_maxs.append({key: max(val) + shift_value(val) for key, val in coords.items()})

rmin = []
rmax = []
for dim in "xyz":
rmin.append(max(val[dim] for val in data_mins))
rmax.append(min(val[dim] for val in data_maxs))

source_geo = Box.from_bounds(rmin=rmin, rmax=rmax)

# Define source dataset
# Offset coordinates by source center since local coords are assumed in CustomCurrentSource

for freq0 in tuple(self.field_components.values())[0].coords["f"]:
for freq0 in freqs:
src_field_components = {}
for name, field_component in self.field_components.items():
# get the VJP values at frequency and apply adjoint phase
field_component = field_component.sel(f=freq0)
forward_amps = field_component.values
values = -1j * forward_amps
values = -1j * field_component.values

# make source go backwards
if "H" in name:
values *= -1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this what we do in the adjoint plugin too?

I am trying to think about whether this is the correct thing to do in all cases. It's probably possible to see it somehow from the adjoint formulation if written for the H field. I think this is correct. I'm just wondering a little because the intuition of this minus sign "making source go backwards" makes sense for a real propagating mode, but it gets a bit murkier for evanescent or lossy modes. The other option is to take np.conj(values). If the E field is real and the mode is lossless and propagating, then the H field is imaginary and the two operations are the same. In general the np.conj(values) operation would be like time reversal while the values *= -1 operation is like reflection of the mode in the source plane, which produce different results in some cases.

I do think the -1 is correct here, though. Do you remember whether it comes out of the adjoint math?


# make coords that are shifted relative to geometry (0,0,0) = geometry.center
coords = dict(field_component.coords.copy())
for dim, key in enumerate("xyz"):
coords[key] = np.array(coords[key]) - source_geo.center[dim]
coords["f"] = np.array([freq0])
values = np.expand_dims(values, axis=-1)

# ignore zero components
if not np.all(values == 0):
src_field_components[name] = ScalarFieldDataArray(values, coords=coords)

# construct custom Current source
dataset = FieldDataset(**src_field_components)

custom_source = CustomCurrentSource(
center=source_geo.center,
size=source_geo.size,
Expand Down Expand Up @@ -1763,7 +1749,7 @@ def make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[
for name in dataset_names:
if name == "amps":
adjoint_sources += self.make_adjoint_sources_amps(fwidth=fwidth)
else:
elif not np.all(self.n_complex.values == 0.0):
log.warning(
f"Can't create adjoint source for 'ModeData.{type(self)}.{name}'. "
f"for monitor '{self.monitor.name}'. "
Expand Down Expand Up @@ -1948,6 +1934,10 @@ def make_adjoint_sources(
) -> List[Union[CustomCurrentSource, PointDipole]]:
"""Converts a :class:`.FieldData` to a list of adjoint current or point sources."""

# avoids error in edge case where there are extraneous flux monitors not used in objective
if np.all(self.flux.values == 0.0):
return []

raise NotImplementedError(
"Could not formulate adjoint source for 'FluxMonitor' output. To compute derivatives "
"with respect to flux data, please use a 'FieldMonitor' and call '.flux' on the "
Expand Down
4 changes: 2 additions & 2 deletions tidy3d/components/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
class AdjointSourceInfo(Tidy3dBaseModel):
"""Stores information about the adjoint sources to pass to autograd pipeline."""

sources: tuple[SourceType, ...] = pd.Field(
sources: Tuple[annotate_type(SourceType), ...] = pd.Field(
...,
title="Adjoint Sources",
description="Set of processed sources to include in the adjoint simulation.",
Expand Down Expand Up @@ -1074,7 +1074,7 @@ def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, SourceTy
)
sources_adj_all[mnt_data.monitor.name] = sources_adj

if not sources_adj_all:
if not any(src for _, src in sources_adj_all.items()):
raise ValueError(
"No adjoint sources created for this simulation. "
"This could indicate a bug in your setup, for example the objective function "
Expand Down
5 changes: 3 additions & 2 deletions tidy3d/components/geometry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2433,10 +2433,11 @@ def derivative_face(
eps_xyz = [derivative_info.eps_data[f"eps_{dim}{dim}"] for dim in "xyz"]

# number of cells from the edge of data to register "inside" (index = num_cells_in - 1)
num_cells_in = 4
num_cells_in = 3

# if not enough data, just use best guess using eps in medium and simulation
needs_eps_approx = any(len(eps.coords[dim_normal]) <= num_cells_in for eps in eps_xyz)

if derivative_info.eps_approx or needs_eps_approx:
eps_xyz_inside = 3 * [derivative_info.eps_in]
eps_xyz_outside = 3 * [derivative_info.eps_out]
Expand All @@ -2447,7 +2448,7 @@ def derivative_face(
if min_max_index == 0:
index_out, index_in = (0, num_cells_in - 1)
else:
index_out, index_in = (-1, -num_cells_in - 1)
index_out, index_in = (-1, -num_cells_in)
eps_xyz_inside = [eps.isel(**{dim_normal: index_in}) for eps in eps_xyz]
eps_xyz_outside = [eps.isel(**{dim_normal: index_out}) for eps in eps_xyz]

Expand Down
36 changes: 17 additions & 19 deletions tidy3d/plugins/autograd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,37 +127,35 @@ The following components are traceable as inputs to the `td.Simulation`
- `CustomPoleResidue.eps_inf`
- `CustomPoleResidue.poles`

The following components are traceable as outputs of the `td.SimulationData`

- `ModeData.amps`
- `DiffractionData.amps`
- `FieldData.field_components`
- `Cylinder.radius`
- `Cylinder.center` (along non-axis dimensions)

We currently have the following restrictions:

- All monitors in the `Simulation` must be single frequency only.
- Only 500 max structures containing tracers can be added to the `Simulation` to cut down on processing time. In the future, `GeometryGroup` support will allow us to relax this restriction.
- `web.run_async` for simulations with tracers does not return a `BatchData` but rather a `dict` mapping task name to `SimulationData`. There may be high memory usage with many simulations or a lot of data for each.
- Gradient calculations are done client-side, meaning the field data in the traced structure regions must be downloaded. This can be a large amount of data for large, 3D structures.
- `ComplexPolySlab.sub_polyslabs`

### To be supported soon
The following components are traceable as outputs of the `td.SimulationData`

Next on our roadmap (targeting 2.8 and 2.9, summer 2024) is to support:
- `ModeData.amps`

- support for multi-frequency monitors in certain situations (single adjoint source).
- server-side gradient processing.
- `DiffractionData.amps`

- `FieldData.field_components`
- `FieldData` operations:
- `FieldData.flux`
- `SimulationData.get_intensity`
- `SimulationData.get_poynting`

- `PoleResidue` and other dispersive models.
- custom (spatially-dependent) dispersive models, allowing topology optimization with metals.
Other features
- support for multi-frequency monitors in certain situations (single adjoint source).
- server-side gradient processing by passing `local_gradient=False` to the `web` functions. This can dramatically cut down on data storage time; just be careful about using this with multi-frequency monitors and large design regions, as it can lead to large data storage on our servers.

We currently have the following restrictions:

- `ComplexPolySlab`
- At most 500 structures containing tracers can be added to the `Simulation`, to cut down on processing time. To bypass this restriction, use `GeometryGroup` to group structures with the same medium.
- `web.run_async` for simulations with tracers does not return a `BatchData` but rather a `dict` mapping task name to `SimulationData`. There may be high memory usage with many simulations or a lot of data for each.

### To be supported soon

Later this year (2024), we plan to support:
Next on our roadmap (targeting 2.8 and 2.9, fall 2024) is to support:

- `TriangleMesh`.
- `GUI` integration of invdes plugin.
Expand Down
31 changes: 28 additions & 3 deletions tidy3d/web/api/autograd/autograd.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be nice to have a private debug run function, or maybe a `@debug` decorator, that could do things like this. Not really important, but it might be convenient to include in the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which line? the adjoint field visualization?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah exactly, that and the global toggle to trigger that code

Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@
# default value for whether to do local gradient calculation (True) or server side (False)
LOCAL_GRADIENT = True

# if True, will plot the adjoint fields on the plane provided. used for debugging only
_INSPECT_ADJOINT_FIELDS = False
_INSPECT_ADJOINT_PLANE = td.Box(center=(0, 0, 0), size=(td.inf, td.inf, 0))


def is_valid_for_autograd(simulation: td.Simulation) -> bool:
"""Check whether a supplied simulation can use autograd run."""
Expand Down Expand Up @@ -732,9 +736,7 @@ def setup_adj(
td.log.info("Running custom vjp (adjoint) pipeline.")

# immediately filter out any data_vjps with all 0's in the data
data_fields_vjp = {
key: get_static(value) for key, value in data_fields_vjp.items() if not np.all(value == 0.0)
}
data_fields_vjp = {key: get_static(value) for key, value in data_fields_vjp.items()}

# insert the raw VJP data into the .data of the original SimulationData
sim_data_vjp = sim_data_orig.insert_traced_fields(field_mapping=data_fields_vjp)
Expand All @@ -751,6 +753,29 @@ def setup_adj(
data_vjp_paths=data_vjp_paths, adjoint_monitors=adjoint_monitors
)

if _INSPECT_ADJOINT_FIELDS:
adj_fld_mnt = td.FieldMonitor(
center=_INSPECT_ADJOINT_PLANE.center,
size=_INSPECT_ADJOINT_PLANE.size,
freqs=adjoint_monitors[0].freqs,
name="adjoint_fields",
)

import matplotlib.pylab as plt

import tidy3d.web as web

sim_data_new = web.run(
sim_adj.updated_copy(monitors=[adj_fld_mnt]),
task_name="adjoint_field_viz",
verbose=False,
)
_, (ax1, ax2, ax3) = plt.subplots(1, 3, tight_layout=True, figsize=(10, 4))
sim_data_new.plot_field("adjoint_fields", "Ex", "re", ax=ax1)
sim_data_new.plot_field("adjoint_fields", "Ey", "re", ax=ax2)
sim_data_new.plot_field("adjoint_fields", "Ez", "re", ax=ax3)
plt.show()

td.log.info(f"Adjoint simulation created with {len(sim_adj.sources)} sources.")

return sim_adj, adjoint_source_info
Expand Down
Loading