Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🗻 autograd: broadband #1774

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Broadband adjoint support in autograd for adjoint sources with the same spatial dependence.

## [2.7.1]

### Added
Expand All @@ -13,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `ModeSolver` methods to plot the mode plane simulation components, including `.plot()`, `.plot_eps()`, `.plot_structures_eps()`, `.plot_grid()`, and `.plot_pml()`.
- Support for differentiation with respect to monitor attributes that require interpolation, such as flux and intensity.
- Support for automatic differentiation with respect to `.eps_inf` and `.poles` contained in dispersive mediums `td.PoleResidue` and `td.CustomPoleResidue`.
- Support for automatic differentiation with respect to `.eps_inf` and `.poles` contained in dispersive mediums `td.PoleResidue` and `td.CustomPoleResidue`.

### Changed

### Fixed
- Bug where boundary layers would be plotted too small in 2D simulations.
Expand Down
271 changes: 239 additions & 32 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import numpy as np
import pytest
import tidy3d as td
import xarray as xr
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
from tidy3d.web import run_async
from tidy3d.web.api.autograd.autograd import run
from tidy3d.web import run, run_async

from ..utils import SIM_FULL, AssertLogLevel, run_emulated

Expand Down Expand Up @@ -53,6 +53,7 @@

WVL = 1.0
FREQ0 = td.C_0 / WVL
FREQS = [0.9 * FREQ0, FREQ0, FREQ0 * 1.1]

# sim sizes
LZ = 7 * WVL
Expand Down Expand Up @@ -154,8 +155,10 @@ def run_async_emulated(simulations, **kwargs):
def make_structures(params: anp.ndarray) -> dict[str, td.Structure]:
"""Make a dictionary of the structures given the parameters."""

np.random.seed(0)

vector = np.random.random(N_PARAMS) - 0.5
vector /= np.linalg.norm(vector)
vector = vector / np.linalg.norm(vector)

# static components
box = td.Box(center=(0, 0, 0), size=(1, 1, 1))
Expand Down Expand Up @@ -411,6 +414,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = False) -> None:
if TEST_POLYSLAB_SPEED:
args = [("polyslab", "mode")]


# args = [("geo_group", "mode")]


Expand Down Expand Up @@ -612,33 +616,13 @@ def objective(*params):
ag.grad(objective)(params0)


def test_warning_no_adjoint_sources(log_capture, monkeypatch, use_emulated_run):
    """Check that an empty list of adjoint sources logs a warning rather than raising."""

    mnt_name = "mode"
    struct_name = "size_element"
    monitor, postprocess = make_monitors()[mnt_name]

    def build_sim(*args):
        """Put the selected traced structure and the selected monitor into the base simulation."""
        return SIM_BASE.updated_copy(
            structures=[make_structures(*args)[struct_name]], monitors=[monitor]
        )

    def objective(*args):
        """Objective function."""
        sim_data = run(build_sim(*args), task_name="autograd_test", verbose=False)
        return postprocess(sim_data, sim_data[mnt_name])

    # patch the adjoint source construction so that it always comes back empty
    monkeypatch.setattr(td.SimulationData, "make_adjoint_sources", lambda *args, **kwargs: [])

    with AssertLogLevel(log_capture, "WARNING", contains_str="No adjoint sources"):
        ag.grad(objective)(params0)


def test_web_failure_handling(log_capture, monkeypatch, use_emulated_run, use_emulated_run_async):
"""Test what happens when autograd run pipeline fails."""

def fail(*args, **kwargs):
"""Just raise an exception."""
raise ValueError("test")

monitor_key = "mode"
structure_key = "size_element"
monitor, postprocess = make_monitors()[monitor_key]
Expand All @@ -650,14 +634,10 @@ def make_sim(*args):
def objective(*args):
"""Objective function."""
sim = make_sim(*args)
data = run(sim, task_name="autograd_test", verbose=False)
data = run(sim, task_name=None, verbose=False)
value = postprocess(data, data[monitor_key])
return value

def fail(*args, **kwargs):
"""Just raise an exception."""
raise ValueError("test")

""" if autograd run raises exception, raise a warning and continue with regular ."""

monkeypatch.setattr(td.web.api.autograd.autograd, "_run", fail)
Expand Down Expand Up @@ -1008,3 +988,230 @@ def f(x):
* no copy : 16 sec
* no to_static(): 13 sec
"""

# second monitor frequency, 10% above FREQ0
FREQ1 = FREQ0 * 1.1

# settings shared by both mode monitors; only the frequencies and the name differ
_MODE_MNT_KWARGS = dict(
    size=(2, 2, 0),
    center=(0, 0, LZ / 2 - WVL),
    mode_spec=td.ModeSpec(num_modes=2),
)

# mode monitor recording at FREQ0 only
mnt_single = td.ModeMonitor(freqs=[FREQ0], name="single", **_MODE_MNT_KWARGS)

# mode monitor recording at both FREQ0 and FREQ1
mnt_multi = td.ModeMonitor(freqs=[FREQ0, FREQ1], name="multi", **_MODE_MNT_KWARGS)


def make_objective(postprocess_fn: typing.Callable, structure_key: str) -> typing.Callable:
    """Build an objective that simulates one traced structure and post-processes the data.

    ``postprocess_fn`` is called with the data returned by ``run`` and its return value is the
    objective value. ``structure_key`` selects which entry of ``make_structures(params)`` is
    placed in the simulation.
    """

    def objective(params):
        """Evaluate the objective for one set of parameters."""
        traced = make_structures(params)[structure_key]
        # keep the monitors of the base simulation and append both mode monitors
        all_monitors = [*SIM_BASE.monitors, mnt_single, mnt_multi]
        sim_with_mnts = SIM_BASE.updated_copy(structures=[traced], monitors=all_monitors)
        return postprocess_fn(run(sim_with_mnts, task_name="multifreq_test"))

    return objective


def get_amps(sim_data: td.SimulationData, mnt_name: str) -> xr.DataArray:
    """Return the ``.amps`` attribute of the monitor data stored under ``mnt_name``."""
    mnt_data = sim_data[mnt_name]
    return mnt_data.amps


def power(amps: xr.DataArray) -> float:
    """Collapse a selected DataArray into one number: the sum of ``|amps|^2`` over its values."""
    magnitudes = anp.abs(amps.values)
    return anp.sum(magnitudes**2)


def postprocess_0_src(sim_data: td.SimulationData) -> float:
    """Return a constant that ignores ``sim_data``, which should give 0 adjoint sources."""
    constant_objective = 0.0
    return constant_objective


def compute_grad(postprocess_fn: typing.Callable, structure_key: str) -> np.ndarray:
    """Evaluate the gradient of the objective built from ``postprocess_fn``.

    Parameters
    ----------
    postprocess_fn : typing.Callable
        Post-processing function handed to ``make_objective``.
    structure_key : str
        Structure key handed to ``make_objective``.

    Returns
    -------
    np.ndarray
        The value of ``ag.grad(objective)`` at ``params0 + 1.0``. This is the evaluated
        gradient, not a callable.
    """
    objective = make_objective(postprocess_fn, structure_key=structure_key)
    params = params0 + 1.0  # +1 is to avoid a warning in size_element with value 0
    return ag.grad(objective)(params)


def check_0_src(log_capture, structure_key):
    """Differentiate a constant objective, which should return 0 adjoint sources."""
    # the module-level constant post-processing function is identical to a local one
    compute_grad(postprocess_0_src, structure_key=structure_key)


# NOTE: `check_0_src` is not included in the parametrized test cases; it just raises the regular warning


def check_1_src_single(log_capture, structure_key):
    """Expect the one-source INFO message for one mode and direction of the "single" monitor."""

    expected_msg = "One monitor with one adjoint source."

    def forward_power(sim_data: td.SimulationData) -> float:
        """Postprocess function that should return 1 adjoint sources."""
        selected = get_amps(sim_data, "single").sel(mode_index=0, direction="+")
        return power(selected)

    with AssertLogLevel(log_capture, log_level_expected="INFO", contains_str=expected_msg):
        compute_grad(forward_power, structure_key=structure_key)


def check_2_src_single(log_capture, structure_key):
    """Expect the two-source INFO message for one mode of the "single" monitor."""

    expected_msg = "One monitor with 2 adjoint sources."

    def mode_power(sim_data: td.SimulationData) -> float:
        """Postprocess function that should return 2 different adjoint sources."""
        # no direction is selected here, unlike in the one-source check
        selected = get_amps(sim_data, "single").sel(mode_index=0)
        return power(selected)

    with AssertLogLevel(log_capture, log_level_expected="INFO", contains_str=expected_msg):
        compute_grad(mode_power, structure_key=structure_key)


def check_1_src_multi(log_capture, structure_key):
    """Expect the one-source INFO message for one mode, direction and frequency of "multi"."""

    expected_msg = "One monitor with one adjoint source."

    def forward_power_f0(sim_data: td.SimulationData) -> float:
        """Postprocess function that should return 1 adjoint sources."""
        selected = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ0)
        return power(selected)

    with AssertLogLevel(log_capture, log_level_expected="INFO", contains_str=expected_msg):
        compute_grad(forward_power_f0, structure_key=structure_key)


def check_2_src_multi(log_capture, structure_key):
    """Expect the two-source INFO message for one mode and one frequency of "multi"."""

    expected_msg = "One monitor with 2 adjoint sources."

    def mode_power_f1(sim_data: td.SimulationData) -> float:
        """Postprocess function that should return 2 different adjoint sources."""
        # no direction is selected here, only the mode and the frequency
        selected = get_amps(sim_data, "multi").sel(mode_index=0, f=FREQ1)
        return power(selected)

    with AssertLogLevel(log_capture, log_level_expected="INFO", contains_str=expected_msg):
        compute_grad(mode_power_f1, structure_key=structure_key)


def check_2_src_both(log_capture, structure_key):
    """Expect the same-frequency INFO message when both monitors contribute at FREQ0."""

    expected_msg = (
        "Several adjoint sources from different monitors, all with same single frequency."
    )

    def combined_power(sim_data: td.SimulationData) -> float:
        """Postprocess function that should return 2 different adjoint sources."""
        from_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+")
        from_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ0)
        return power(from_single) + power(from_multi)

    with AssertLogLevel(log_capture, log_level_expected="INFO", contains_str=expected_msg):
        compute_grad(combined_power, structure_key=structure_key)


def check_1_error_multisrc(log_capture, structure_key):
    """Expect a ValueError when "single" is combined with "multi" at FREQ1."""

    def combined_power(sim_data: td.SimulationData) -> float:
        """Postprocess function that should raise ValueError because diff sources, diff freqs."""
        from_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+")
        from_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ1)
        return power(from_single) + power(from_multi)

    with pytest.raises(ValueError):
        compute_grad(combined_power, structure_key=structure_key)


def check_2_error_multisrc(log_capture, structure_key):
    """Expect a ValueError when "single" is combined with all frequencies of "multi"."""

    def combined_power(sim_data: td.SimulationData) -> float:
        """Postprocess function that should raise ValueError because diff sources, diff freqs."""
        from_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+")
        # no frequency is selected on the "multi" monitor here
        from_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+")
        return power(from_single) + power(from_multi)

    with pytest.raises(ValueError):
        compute_grad(combined_power, structure_key=structure_key)


def check_1_src_broadband(log_capture, structure_key):
    """Expect the broadband INFO message for one mode and direction over all "multi" freqs."""

    expected_msg = "Constructing broadband adjoint source and performing post-run normalization"

    def all_freq_power(sim_data: td.SimulationData) -> float:
        """Postprocess function that should return 1 broadband adjoint sources with many freqs."""
        selected = get_amps(sim_data, "multi").sel(mode_index=0, direction="+")
        return power(selected)

    with AssertLogLevel(log_capture, log_level_expected="INFO", contains_str=expected_msg):
        compute_grad(all_freq_power, structure_key=structure_key)


# maps a test label to the check function that is run for it
MULT_FREQ_TEST_CASES = dict(
    src_1_freq_1=check_1_src_single,
    src_2_freq_1=check_2_src_single,
    src_1_freq_2=check_1_src_multi,
    # this label used to repeat ``check_1_src_multi``, which left ``check_2_src_multi`` unused
    src_2_freq_1_mon_1=check_2_src_multi,
    src_2_freq_1_mon_2=check_2_src_both,
    src_2_freq_2_mon_1=check_1_error_multisrc,
    src_2_freq_2_mon_2=check_2_error_multisrc,
    src_1_freq_2_broadband=check_1_src_broadband,
)

# (label, check function) pairs in the form expected by ``pytest.mark.parametrize``
checks = list(MULT_FREQ_TEST_CASES.items())


@pytest.mark.parametrize("label, check_fn", checks)
@pytest.mark.parametrize("structure_key", structure_keys_)
def test_multi_freq_edge_cases(log_capture, use_emulated_run, structure_key, label, check_fn):
    """Run every multi-frequency adjoint check for every structure key."""
    check_kwargs = dict(log_capture=log_capture, structure_key=structure_key)
    check_fn(**check_kwargs)


@pytest.mark.parametrize("structure_key", structure_keys_)
def test_multi_frequency_equivalence(use_emulated_run, structure_key):
    """Compare a per-frequency objective with one evaluated over all frequencies at once.

    The two objective values must agree. The two gradients are not compared with each other;
    each is only checked to have no entries close to zero.
    """

    def objective_indi(params, structure_key) -> float:
        """Sum the power at each frequency of ``mnt_multi``, with one ``run`` per frequency."""
        power_sum = 0.0

        for f in mnt_multi.freqs:
            # the simulation is rebuilt and re-run in every iteration
            structure_traced = make_structures(params)[structure_key]
            sim = SIM_BASE.updated_copy(
                structures=[structure_traced],
                monitors=list(SIM_BASE.monitors) + [mnt_multi],
            )

            sim_data = run(sim, task_name="multifreq_test")
            # only the amplitude at the current frequency contributes to this term
            amps_i = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=f)
            power_i = power(amps_i)
            power_sum = power_sum + power_i

        return power_sum

    def objective_multi(params, structure_key) -> float:
        """Compute the power over all frequencies of ``mnt_multi`` from a single ``run``."""
        structure_traced = make_structures(params)[structure_key]
        sim = SIM_BASE.updated_copy(
            structures=[structure_traced],
            monitors=list(SIM_BASE.monitors) + [mnt_multi],
        )
        sim_data = run(sim, task_name="multifreq_test")
        # no frequency is selected, so all monitor frequencies enter the power
        amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+")
        return power(amps)

    # same +1.0 parameter offset as used in ``compute_grad``
    params0_ = params0 + 1.0

    J_indi = objective_indi(params0_, structure_key)
    J_multi = objective_multi(params0_, structure_key)

    # the objective values of the two formulations must match
    np.testing.assert_allclose(J_indi, J_multi)

    grad_indi = ag.grad(objective_indi)(params0_, structure_key=structure_key)
    grad_multi = ag.grad(objective_multi)(params0_, structure_key=structure_key)

    # gradients are only required to be non-zero everywhere
    assert not np.any(np.isclose(grad_indi, 0))
    assert not np.any(np.isclose(grad_multi, 0))
8 changes: 5 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
""" utilities shared between all tests """
np.random.seed(4)

# function used to generate the data for emulated runs
DATA_GEN_FN = np.random.random
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we planning to use something else here? I think in general it might be a good idea to introduce a fixture that returns a generator, instead of relying on a global np.random.seed (considered legacy and might be deprecated at some point). I have something like that for the autograd plugin tests, but we might want to tweak and reuse this globally.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, and no plans (yours sounds good). I just left it here since it makes testing easier. Should I change it, or is it fine to leave it like this for now?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was just a general comment, let's leave it for now. It would make sense to add this in one swoop together with the rest of the tests.


FREQS = np.array([1.90, 2.01, 2.2]) * 1e12
SIM_MONITORS = td.Simulation(
Expand Down Expand Up @@ -880,7 +882,7 @@ def make_data(
"""make a random DataArray out of supplied coordinates and data_type."""
data_shape = [len(coords[k]) for k in data_array_type._dims]
np.random.seed(1)
data = np.random.random(data_shape)
data = DATA_GEN_FN(data_shape)

data = (1 + 0.5j) * data if is_complex else data
data = gaussian_filter(data, sigma=1.0) # smooth out the data a little so it isnt random
Expand Down Expand Up @@ -939,7 +941,7 @@ def make_mode_solver_data(monitor: td.ModeSolverMonitor) -> td.ModeSolverData:
index_coords["mode_index"] = np.arange(monitor.mode_spec.num_modes)
index_data_shape = (len(index_coords["f"]), len(index_coords["mode_index"]))
index_data = ModeIndexDataArray(
(1 + 1j) * np.random.random(index_data_shape), coords=index_coords
(1 + 1j) * DATA_GEN_FN(index_data_shape), coords=index_coords
)
for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]:
coords = get_spatial_coords_dict(simulation, monitor, field_name)
Expand Down Expand Up @@ -977,7 +979,7 @@ def make_diff_data(monitor: td.DiffractionMonitor) -> td.DiffractionData:
orders_x = np.linspace(-1, 1, 3)
orders_y = np.linspace(-2, 2, 5)
coords = dict(orders_x=orders_x, orders_y=orders_y, f=f)
values = np.random.random((len(orders_x), len(orders_y), len(f)))
values = DATA_GEN_FN((len(orders_x), len(orders_y), len(f)))
data = td.DiffractionDataArray(values, coords=coords)
field_data = {field: data for field in ("Er", "Etheta", "Ephi", "Hr", "Htheta", "Hphi")}
return td.DiffractionData(monitor=monitor, sim_size=(1, 1), bloch_vecs=(0, 0), **field_data)
Expand Down
1 change: 1 addition & 0 deletions tidy3d/components/data/data_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(self, data, *args, **kwargs):
# initialize with untraced data
super().__init__(getval(data), *args, **kwargs)
# and put tracers in .attrs

if isbox(data):
self.attrs[AUTOGRAD_KEY] = data

Expand Down
Loading
Loading