Skip to content

Commit

Permalink
broadband adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Oct 11, 2023
1 parent a9706bc commit 6b66abf
Show file tree
Hide file tree
Showing 10 changed files with 352 additions and 203 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
- Support for multiple frequencies in `output_monitors` in `adjoint` plugin.

### Changed

Expand All @@ -22,9 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Internal refactor of Web API functionality.
- `Geometry.from_gds` doesn't create unecessary groups of single elements.
- Properly handle `.freqs` in `output_monitors` of adjoint plugin.

### Fixed
- Properly handle `.freqs` in `output_monitors` of adjoint plugin.

## [2.4.2] - 2023-9-28

Expand Down
110 changes: 65 additions & 45 deletions tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tidy3d.plugins.adjoint.components.medium import JaxMedium, JaxAnisotropicMedium
from tidy3d.plugins.adjoint.components.medium import JaxCustomMedium, MAX_NUM_CELLS_CUSTOM_MEDIUM
from tidy3d.plugins.adjoint.components.structure import JaxStructure
from tidy3d.plugins.adjoint.components.simulation import JaxSimulation, JaxInfo
from tidy3d.plugins.adjoint.components.simulation import JaxSimulation, JaxInfo, RUN_TIME_FACTOR
from tidy3d.plugins.adjoint.components.simulation import MAX_NUM_INPUT_STRUCTURES
from tidy3d.plugins.adjoint.components.data.sim_data import JaxSimulationData
from tidy3d.plugins.adjoint.components.data.monitor_data import JaxModeData, JaxDiffractionData
Expand Down Expand Up @@ -54,6 +54,12 @@
# name of the output monitor used in tests
MNT_NAME = "mode"

src = td.PointDipole(
center=(0, 0, 0),
source_time=td.GaussianPulse(freq0=FREQ0, fwidth=FREQ0 / 10),
polarization="Ex",
)

# Emulated forward and backward run functions
def run_emulated_fwd(
simulation: td.Simulation,
Expand Down Expand Up @@ -255,7 +261,7 @@ def make_sim(
output_mnt1 = td.ModeMonitor(
size=(10, 10, 0),
mode_spec=td.ModeSpec(num_modes=3),
freqs=[FREQ0],
freqs=[FREQ0, FREQ0 * 1.1],
name=MNT_NAME + "1",
)

Expand All @@ -276,13 +282,13 @@ def make_sim(

output_mnt4 = td.FieldMonitor(
size=(0, 0, 0),
freqs=[FREQ0],
freqs=np.array([FREQ0, FREQ0 * 1.1]),
name=MNT_NAME + "4",
)

extraneous_field_monitor = td.FieldMonitor(
size=(10, 10, 0),
freqs=[1e14, 2e14],
freqs=np.array([1e14, 2e14]),
name="field",
)

Expand All @@ -301,6 +307,7 @@ def make_sim(
jax_struct_custom_anis,
),
output_monitors=(output_mnt1, output_mnt2, output_mnt3, output_mnt4),
sources=[src],
boundary_spec=td.BoundarySpec.pml(x=False, y=False, z=False),
symmetry=(0, 1, -1),
)
Expand Down Expand Up @@ -550,23 +557,8 @@ def _test_adjoint_setup_adj(use_emulated_run):
assert len(sim_vjp.input_structures) == len(sim_orig.input_structures)


# @pytest.mark.parametrize("add_grad_monitors", (True, False))
# def test_convert_tidy3d_to_jax(add_grad_monitors):
# """test conversion of JaxSimulation to Simulation and SimulationData to JaxSimulationData."""
# jax_sim = make_sim(permittivity=EPS, size=SIZE, vertices=VERTICES, base_eps_val=BASE_EPS_VAL)
# if add_grad_monitors:
# jax_sim = jax_sim.add_grad_monitors()
# sim, jax_info = jax_sim.to_simulation()
# assert type(sim) == td.Simulation
# assert sim.type == "Simulation"
# sim_data = run_emulated(sim)
# jax_sim_data = JaxSimulationData.from_sim_data(sim_data, jax_info)
# jax_sim2 = jax_sim_data.simulation
# assert jax_sim_data.simulation == jax_sim


def test_multiple_freqs():
"""Test that sim validation fails when output monitors have multiple frequencies."""
"""Test that sim validation doesnt fail when output monitors have multiple frequencies."""

output_mnt = td.ModeMonitor(
size=(10, 10, 0),
Expand All @@ -575,20 +567,19 @@ def test_multiple_freqs():
name=MNT_NAME,
)

with pytest.raises(pydantic.ValidationError):
_ = JaxSimulation(
size=(10, 10, 10),
run_time=1e-12,
grid_spec=td.GridSpec(wavelength=1.0),
monitors=(),
structures=(),
output_monitors=(output_mnt,),
input_structures=(),
)
_ = JaxSimulation(
size=(10, 10, 10),
run_time=1e-12,
grid_spec=td.GridSpec(wavelength=1.0),
monitors=(),
structures=(),
output_monitors=(output_mnt,),
input_structures=(),
)


def test_different_freqs():
"""Test that sim validation fails when output monitors have different frequencies."""
"""Test that sim validation doesnt fail when output monitors have different frequencies."""

output_mnt1 = td.ModeMonitor(
size=(10, 10, 0),
Expand All @@ -602,16 +593,15 @@ def test_different_freqs():
freqs=[2e14],
name=MNT_NAME + "2",
)
with pytest.raises(pydantic.ValidationError):
_ = JaxSimulation(
size=(10, 10, 10),
run_time=1e-12,
grid_spec=td.GridSpec(wavelength=1.0),
monitors=(),
structures=(),
output_monitors=(output_mnt1, output_mnt2),
input_structures=(),
)
_ = JaxSimulation(
size=(10, 10, 10),
run_time=1e-12,
grid_spec=td.GridSpec(wavelength=1.0),
monitors=(),
structures=(),
output_monitors=(output_mnt1, output_mnt2),
input_structures=(),
)


def test_get_freq_adjoint():
Expand All @@ -628,9 +618,11 @@ def test_get_freq_adjoint():
)

with pytest.raises(AdjointError):
_ = sim.freq_adjoint
_ = sim.freqs_adjoint

freq0 = 2e14
freq1 = 3e14
freq2 = 1e14
output_mnt1 = td.ModeMonitor(
size=(10, 10, 0),
mode_spec=td.ModeSpec(num_modes=3),
Expand All @@ -640,7 +632,7 @@ def test_get_freq_adjoint():
output_mnt2 = td.ModeMonitor(
size=(10, 10, 0),
mode_spec=td.ModeSpec(num_modes=3),
freqs=[freq0],
freqs=[freq1, freq2, freq0],
name=MNT_NAME + "2",
)
sim = JaxSimulation(
Expand All @@ -652,7 +644,11 @@ def test_get_freq_adjoint():
output_monitors=(output_mnt1, output_mnt2),
input_structures=(),
)
assert sim.freq_adjoint == freq0

freqs = [freq0, freq1, freq2]
freqs.sort()

assert sim.freqs_adjoint == freqs


def test_get_fwidth_adjoint():
Expand Down Expand Up @@ -691,7 +687,7 @@ def make_sim(sources=(), fwidth_adjoint=None):
src_times = [td.GaussianPulse(freq0=freq0, fwidth=fwidth) for fwidth in fwidths]
srcs = [td.PointDipole(source_time=src_time, polarization="Ex") for src_time in src_times]
sim = make_sim(sources=srcs, fwidth_adjoint=None)
assert np.isclose(sim._fwidth_adjoint, np.mean(fwidths))
assert np.isclose(sim._fwidth_adjoint, np.max(fwidths))

# a few sources, with custom fwidth specified
fwidth_custom = 3e13
Expand Down Expand Up @@ -1548,3 +1544,27 @@ def f(x):
return jnp.sum(jnp.abs(jnp.array(sd["test"].amps.values)))

jax.grad(f)(0.5)


fwidth_run_time_expected = [
(FREQ0 / 10, 1e-11, 1e-11), # run time supplied explicitly, use that
(FREQ0 / 10, None, RUN_TIME_FACTOR / (FREQ0 / 10)), # no run_time, use fwidth supplied
(FREQ0 / 20, None, RUN_TIME_FACTOR / (FREQ0 / 20)), # no run_time, use fwidth supplied
]


@pytest.mark.parametrize("fwidth, run_time, run_time_expected", fwidth_run_time_expected)
def test_adjoint_run_time(use_emulated_run, tmp_path, fwidth, run_time, run_time_expected):

sim = make_sim(permittivity=EPS, size=SIZE, vertices=VERTICES, base_eps_val=BASE_EPS_VAL)

sim = sim.updated_copy(run_time_adjoint=run_time, fwidth_adjoint=fwidth)

sim_data = run(sim, task_name="test", path=str(tmp_path / RUN_FILE))

run_time_adj = sim._run_time_adjoint
fwidth_adj = sim._fwidth_adjoint

sim_adj = sim_data.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj)

assert sim_adj.run_time == run_time_expected
4 changes: 4 additions & 0 deletions tidy3d/components/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def spectrum(
if not complex_fields:
time_amps = np.real(time_amps)

# if all time amplitudes are zero, just return (complex-valued) zeros for spectrum
if np.allclose(time_amps, 0.0):
return (0.0 + 0.0j) * np.zeros_like(freqs)

# Cut to only relevant times
relevant_time_inds = np.where(np.abs(time_amps) / np.amax(np.abs(time_amps)) > DFT_CUTOFF)
# find first and last index where the filter is True
Expand Down
Loading

0 comments on commit 6b66abf

Please sign in to comment.