diff --git a/CHANGELOG.md b/CHANGELOG.md index a649562e1..354bf38a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Support for multiple frequencies in `output_monitors` in `adjoint` plugin. ### Changed @@ -22,9 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Internal refactor of Web API functionality. - `Geometry.from_gds` doesn't create unecessary groups of single elements. -- Properly handle `.freqs` in `output_monitors` of adjoint plugin. ### Fixed +- Properly handle `.freqs` in `output_monitors` of adjoint plugin. ## [2.4.2] - 2023-9-28 diff --git a/tests/test_plugins/test_adjoint.py b/tests/test_plugins/test_adjoint.py index 5ebfb57fb..9fd7d0137 100644 --- a/tests/test_plugins/test_adjoint.py +++ b/tests/test_plugins/test_adjoint.py @@ -21,7 +21,7 @@ from tidy3d.plugins.adjoint.components.medium import JaxMedium, JaxAnisotropicMedium from tidy3d.plugins.adjoint.components.medium import JaxCustomMedium, MAX_NUM_CELLS_CUSTOM_MEDIUM from tidy3d.plugins.adjoint.components.structure import JaxStructure -from tidy3d.plugins.adjoint.components.simulation import JaxSimulation, JaxInfo +from tidy3d.plugins.adjoint.components.simulation import JaxSimulation, JaxInfo, RUN_TIME_FACTOR from tidy3d.plugins.adjoint.components.simulation import MAX_NUM_INPUT_STRUCTURES from tidy3d.plugins.adjoint.components.data.sim_data import JaxSimulationData from tidy3d.plugins.adjoint.components.data.monitor_data import JaxModeData, JaxDiffractionData @@ -54,6 +54,12 @@ # name of the output monitor used in tests MNT_NAME = "mode" +src = td.PointDipole( + center=(0, 0, 0), + source_time=td.GaussianPulse(freq0=FREQ0, fwidth=FREQ0 / 10), + polarization="Ex", +) + # Emulated forward and backward run functions def run_emulated_fwd( simulation: td.Simulation, @@ -255,7 +261,7 @@ def make_sim( output_mnt1 = td.ModeMonitor( size=(10, 10, 0), mode_spec=td.ModeSpec(num_modes=3), - freqs=[FREQ0], + freqs=[FREQ0, FREQ0 * 1.1], name=MNT_NAME + "1", ) @@ -276,13 +282,13 @@ def make_sim( output_mnt4 = td.FieldMonitor( size=(0, 0, 0), - freqs=[FREQ0], + freqs=np.array([FREQ0, FREQ0 * 1.1]), name=MNT_NAME + "4", ) extraneous_field_monitor = td.FieldMonitor( size=(10, 10, 0), - freqs=[1e14, 2e14], + freqs=np.array([1e14, 2e14]), name="field", ) @@ -301,6 +307,7 @@ def make_sim( jax_struct_custom_anis, ), output_monitors=(output_mnt1, output_mnt2, output_mnt3, output_mnt4), + sources=[src], boundary_spec=td.BoundarySpec.pml(x=False, y=False, z=False), symmetry=(0, 1, -1), ) @@ -550,23 +557,8 @@ def _test_adjoint_setup_adj(use_emulated_run): assert len(sim_vjp.input_structures) == len(sim_orig.input_structures) -# @pytest.mark.parametrize("add_grad_monitors", (True, False)) -# def test_convert_tidy3d_to_jax(add_grad_monitors): -# """test conversion of JaxSimulation to Simulation and SimulationData to JaxSimulationData.""" -# jax_sim = make_sim(permittivity=EPS, size=SIZE, vertices=VERTICES, base_eps_val=BASE_EPS_VAL) -# if add_grad_monitors: -# jax_sim = jax_sim.add_grad_monitors() -# sim, jax_info = jax_sim.to_simulation() -# assert type(sim) == td.Simulation -# assert sim.type == "Simulation" -# sim_data = run_emulated(sim) -# jax_sim_data = JaxSimulationData.from_sim_data(sim_data, jax_info) -# jax_sim2 = jax_sim_data.simulation -# assert jax_sim_data.simulation == jax_sim - - def test_multiple_freqs(): - """Test that sim validation fails when output monitors have multiple frequencies.""" + """Test that sim validation doesnt fail when output monitors have multiple frequencies.""" output_mnt = td.ModeMonitor( size=(10, 10, 0), @@ -575,20 +567,19 @@ def test_multiple_freqs(): name=MNT_NAME, ) - with pytest.raises(pydantic.ValidationError): - _ = JaxSimulation( - size=(10, 10, 10), - run_time=1e-12, - grid_spec=td.GridSpec(wavelength=1.0), - monitors=(), - structures=(), - output_monitors=(output_mnt,), - input_structures=(), - ) + _ = JaxSimulation( + size=(10, 10, 10), + run_time=1e-12, + grid_spec=td.GridSpec(wavelength=1.0), + monitors=(), + structures=(), + output_monitors=(output_mnt,), + input_structures=(), + ) def test_different_freqs(): - """Test that sim validation fails when output monitors have different frequencies.""" + """Test that sim validation doesnt fail when output monitors have different frequencies.""" output_mnt1 = td.ModeMonitor( size=(10, 10, 0), @@ -602,16 +593,15 @@ def test_different_freqs(): freqs=[2e14], name=MNT_NAME + "2", ) - with pytest.raises(pydantic.ValidationError): - _ = JaxSimulation( - size=(10, 10, 10), - run_time=1e-12, - grid_spec=td.GridSpec(wavelength=1.0), - monitors=(), - structures=(), - output_monitors=(output_mnt1, output_mnt2), - input_structures=(), - ) + _ = JaxSimulation( + size=(10, 10, 10), + run_time=1e-12, + grid_spec=td.GridSpec(wavelength=1.0), + monitors=(), + structures=(), + output_monitors=(output_mnt1, output_mnt2), + input_structures=(), + ) def test_get_freq_adjoint(): @@ -628,9 +618,11 @@ def test_get_freq_adjoint(): ) with pytest.raises(AdjointError): - _ = sim.freq_adjoint + _ = sim.freqs_adjoint freq0 = 2e14 + freq1 = 3e14 + freq2 = 1e14 output_mnt1 = td.ModeMonitor( size=(10, 10, 0), mode_spec=td.ModeSpec(num_modes=3), @@ -640,7 +632,7 @@ def test_get_freq_adjoint(): output_mnt2 = td.ModeMonitor( size=(10, 10, 0), mode_spec=td.ModeSpec(num_modes=3), - freqs=[freq0], + freqs=[freq1, freq2, freq0], name=MNT_NAME + "2", ) sim = JaxSimulation( @@ -652,7 +644,11 @@ def test_get_freq_adjoint(): output_monitors=(output_mnt1, output_mnt2), input_structures=(), ) - assert sim.freq_adjoint == freq0 + + freqs = [freq0, freq1, freq2] + freqs.sort() + + assert sim.freqs_adjoint == freqs def test_get_fwidth_adjoint(): @@ -691,7 +687,7 @@ def make_sim(sources=(), fwidth_adjoint=None): src_times = [td.GaussianPulse(freq0=freq0, fwidth=fwidth) for fwidth in fwidths] srcs = [td.PointDipole(source_time=src_time, polarization="Ex") for src_time in src_times] sim = make_sim(sources=srcs, fwidth_adjoint=None) - assert np.isclose(sim._fwidth_adjoint, np.mean(fwidths)) + assert np.isclose(sim._fwidth_adjoint, np.max(fwidths)) # a few sources, with custom fwidth specified fwidth_custom = 3e13 @@ -1548,3 +1544,27 @@ def f(x): return jnp.sum(jnp.abs(jnp.array(sd["test"].amps.values))) jax.grad(f)(0.5) + + +fwidth_run_time_expected = [ + (FREQ0 / 10, 1e-11, 1e-11), # run time supplied explicitly, use that + (FREQ0 / 10, None, RUN_TIME_FACTOR / (FREQ0 / 10)), # no run_time, use fwidth supplied + (FREQ0 / 20, None, RUN_TIME_FACTOR / (FREQ0 / 20)), # no run_time, use fwidth supplied +] + + +@pytest.mark.parametrize("fwidth, run_time, run_time_expected", fwidth_run_time_expected) +def test_adjoint_run_time(use_emulated_run, tmp_path, fwidth, run_time, run_time_expected): + + sim = make_sim(permittivity=EPS, size=SIZE, vertices=VERTICES, base_eps_val=BASE_EPS_VAL) + + sim = sim.updated_copy(run_time_adjoint=run_time, fwidth_adjoint=fwidth) + + sim_data = run(sim, task_name="test", path=str(tmp_path / RUN_FILE)) + + run_time_adj = sim._run_time_adjoint + fwidth_adj = sim._fwidth_adjoint + + sim_adj = sim_data.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj) + + assert sim_adj.run_time == run_time_expected diff --git a/tidy3d/components/source.py b/tidy3d/components/source.py index e7868fe35..e94eed959 100644 --- a/tidy3d/components/source.py +++ b/tidy3d/components/source.py @@ -98,6 +98,10 @@ def spectrum( if not complex_fields: time_amps = np.real(time_amps) + # if all time amplitudes are zero, just return (complex-valued) zeros for spectrum + if np.allclose(time_amps, 0.0): + return (0.0 + 0.0j) * np.zeros_like(freqs) + # Cut to only relevant times relevant_time_inds = np.where(np.abs(time_amps) / np.amax(np.abs(time_amps)) > DFT_CUTOFF) # find first and last index where the filter is True diff --git a/tidy3d/plugins/adjoint/components/data/monitor_data.py b/tidy3d/plugins/adjoint/components/data/monitor_data.py index ab1c6bc48..eb620c9a4 100644 --- a/tidy3d/plugins/adjoint/components/data/monitor_data.py +++ b/tidy3d/plugins/adjoint/components/data/monitor_data.py @@ -232,90 +232,93 @@ def time_reversed_copy(self) -> FieldData: def to_adjoint_sources(self, fwidth: float) -> List[CustomFieldSource]: """Converts a :class:`.JaxFieldData` to a list of adjoint :class:`.CustomFieldSource.""" - # parse the frequency from the scalar field data - freqs = [scalar_fld.coords["f"] for _, scalar_fld in self.field_components.items()] - if any(len(fs) != 1 for fs in freqs): - raise AdjointError("FieldData must have only one frequency.") - freqs = [fs[0] for fs in freqs] - if len(set(freqs)) != 1: - raise AdjointError("FieldData must all contain the same frequency.") - freq0 = freqs[0] - - omega0 = 2 * np.pi * freq0 - scaling_factor = 1 / (MU_0 * omega0) - interpolate_source = True + sources = [] - # dipole case if np.allclose(np.array(self.monitor.size), np.zeros(3)): - dipoles = [] for polarization, field_component in self.field_components.items(): + if field_component is None: continue - forward_amp = complex(field_component.as_ndarray) - adj_phase = 3 * np.pi / 2 + np.angle(forward_amp) + for freq0 in field_component.coords["f"]: + + omega0 = 2 * np.pi * freq0 + scaling_factor = 1 / (MU_0 * omega0) + + forward_amp = complex(field_component.sel(f=freq0).values) + + adj_phase = 3 * np.pi / 2 + np.angle(forward_amp) + + adj_amp = scaling_factor * forward_amp + + src_adj = PointDipole( + center=self.monitor.center, + polarization=polarization, + source_time=GaussianPulse( + freq0=freq0, fwidth=fwidth, amplitude=abs(adj_amp), phase=adj_phase + ), + interpolate=interpolate_source, + ) + + sources.append(src_adj) + else: + + # Define source geometry based on coordinates in the data + data_mins = [] + data_maxs = [] - adj_amp = scaling_factor * forward_amp + def shift_value(coords) -> float: + """How much to shift the geometry by along a dimension (only if > 1D).""" + return 1e-5 if len(coords) > 1 else 0 - src_adj = PointDipole( - center=self.monitor.center, - polarization=polarization, + for _, field_component in self.field_components.items(): + coords = field_component.coords + data_mins.append({key: min(val) + shift_value(val) for key, val in coords.items()}) + data_maxs.append({key: max(val) + shift_value(val) for key, val in coords.items()}) + + rmin = [] + rmax = [] + for dim in "xyz": + rmin.append(max(val[dim] for val in data_mins)) + rmax.append(min(val[dim] for val in data_maxs)) + + source_geo = Box.from_bounds(rmin=rmin, rmax=rmax) + + # Define source dataset + # Offset coordinates by source center since local coords are assumed in CustomCurrentSource + + for freq0 in tuple(self.field_components.values())[0].coords["f"]: + + src_field_components = {} + for name, field_component in self.field_components.items(): + field_component = field_component.sel(f=freq0) + forward_amps = field_component.as_ndarray + values = -1j * forward_amps + coords = field_component.coords + for dim, key in enumerate("xyz"): + coords[key] = np.array(coords[key]) - source_geo.center[dim] + coords["f"] = np.array([freq0]) + values = np.expand_dims(values, axis=-1) + if not np.all(values == 0): + src_field_components[name] = ScalarFieldDataArray(values, coords=coords) + + dataset = FieldDataset(**src_field_components) + + custom_source = CustomCurrentSource( + center=source_geo.center, + size=source_geo.size, source_time=GaussianPulse( - freq0=freq0, fwidth=fwidth, amplitude=abs(adj_amp), phase=adj_phase + freq0=freq0, + fwidth=fwidth, ), + current_dataset=dataset, interpolate=interpolate_source, ) - dipoles.append(src_adj) - return dipoles - - # Define source geometry based on coordinates in the data - data_mins = [] - data_maxs = [] - - def shift_value(coords) -> float: - """How much to shift the geometry by along a dimension (only if > 1D).""" - return 1e-5 if len(coords) > 1 else 0 - - for _, field_component in self.field_components.items(): - coords = field_component.coords - data_mins.append({key: min(val) + shift_value(val) for key, val in coords.items()}) - data_maxs.append({key: max(val) + shift_value(val) for key, val in coords.items()}) - - rmin = [] - rmax = [] - for dim in "xyz": - rmin.append(max(val[dim] for val in data_mins)) - rmax.append(min(val[dim] for val in data_maxs)) - - source_geo = Box.from_bounds(rmin=rmin, rmax=rmax) - - # Define source dataset - # Offset coordinates by source center since local coords are assumed in CustomCurrentSource - src_field_components = {} - for name, field_component in self.field_components.items(): - forward_amps = field_component.as_ndarray - values = -1j * forward_amps - coords = field_component.coords - for dim, key in enumerate("xyz"): - coords[key] = np.array(coords[key]) - source_geo.center[dim] - if not np.all(values == 0): - src_field_components[name] = ScalarFieldDataArray(values, coords=coords) - - dataset = FieldDataset(**src_field_components) - custom_source = CustomCurrentSource( - center=source_geo.center, - size=source_geo.size, - source_time=GaussianPulse( - freq0=freq0, - fwidth=fwidth, - ), - current_dataset=dataset, - interpolate=interpolate_source, - ) + sources.append(custom_source) - return [custom_source] + return sources @register_pytree_node_class diff --git a/tidy3d/plugins/adjoint/components/data/sim_data.py b/tidy3d/plugins/adjoint/components/data/sim_data.py index 21cdbae2d..a4b28304a 100644 --- a/tidy3d/plugins/adjoint/components/data/sim_data.py +++ b/tidy3d/plugins/adjoint/components/data/sim_data.py @@ -4,6 +4,8 @@ from typing import Tuple, Dict, Union, List import pydantic.v1 as pd +import numpy as np +import xarray as xr from jax.tree_util import register_pytree_node_class @@ -157,7 +159,7 @@ def split_fwd_sim_data( return user_sim_data, adjoint_sim_data - def make_adjoint_simulation(self, fwidth: float) -> JaxSimulation: + def make_adjoint_simulation(self, fwidth: float, run_time: float) -> JaxSimulation: """Make an adjoint simulation out of the data provided (generally, the vjp sim data).""" sim_fwd = self.simulation @@ -171,11 +173,19 @@ def make_adjoint_simulation(self, fwidth: float) -> JaxSimulation: for adj_source in mnt_data_vjp.to_adjoint_sources(fwidth=fwidth): adj_srcs.append(adj_source) - update_dict = dict(boundary_spec=bc_adj, sources=adj_srcs, monitors=(), output_monitors=()) + update_dict = dict( + boundary_spec=bc_adj, + sources=adj_srcs, + monitors=(), + output_monitors=(), + run_time=run_time, + normalize_index=None, # normalize later, frequency-by-frequency + ) + update_dict.update( sim_fwd.get_grad_monitors( input_structures=sim_fwd.input_structures, - freq_adjoint=sim_fwd.freq_adjoint, + freqs_adjoint=sim_fwd.freqs_adjoint, include_eps_mnts=False, ) ) @@ -188,3 +198,28 @@ def make_adjoint_simulation(self, fwidth: float) -> JaxSimulation: update_dict.update(dict(grid_spec=grid_spec_adj)) return sim_fwd.updated_copy(**update_dict) + + def normalize_adjoint_fields(self) -> JaxSimulationData: + """Make copy of jax_sim_data with grad_data (fields) normalized by adjoint sources.""" + + grad_data_norm = [] + for field_data in self.grad_data: + field_components_norm = {} + for field_name, field_component in field_data.field_components.items(): + freqs = field_component.coords["f"] + norm_factor_f = np.zeros(len(freqs), dtype=complex) + for i, freq in enumerate(freqs): + freq = float(freq) + for source_index, source in enumerate(self.simulation.sources): + if source.source_time.freq0 == freq and source.source_time.amplitude > 0: + spectrum_fn = self.source_spectrum(source_index) + norm_factor_f[i] = complex(spectrum_fn([freq])[0]) + + norm_factor_f_darr = xr.DataArray(norm_factor_f, coords=dict(f=freqs)) + field_component_norm = field_component / norm_factor_f_darr + field_components_norm[field_name] = field_component_norm + + field_data_norm = field_data.updated_copy(**field_components_norm) + grad_data_norm.append(field_data_norm) + + return self.updated_copy(grad_data=grad_data_norm) diff --git a/tidy3d/plugins/adjoint/components/geometry.py b/tidy3d/plugins/adjoint/components/geometry.py index 0bff9607c..78a061657 100644 --- a/tidy3d/plugins/adjoint/components/geometry.py +++ b/tidy3d/plugins/adjoint/components/geometry.py @@ -2,13 +2,12 @@ from __future__ import annotations from abc import ABC -from typing import Tuple, Union, Dict +from typing import Tuple, Union, Dict, List from multiprocessing import Pool import pydantic.v1 as pd import numpy as np import xarray as xr -import jax.numpy as jnp from jax.tree_util import register_pytree_node_class import jax @@ -69,7 +68,7 @@ def bounding_box(self): return JaxBox.from_bounds(*self.bounds) def make_grad_monitors( - self, freq: float, name: str + self, freqs: List[float], name: str ) -> Tuple[FieldMonitor, PermittivityMonitor]: """Return gradient monitor associated with this object.""" size_enlarged = tuple(s + 2 * GRAD_MONITOR_EXPANSION for s in self.bound_size) @@ -77,7 +76,7 @@ def make_grad_monitors( size=size_enlarged, center=self.bound_center, fields=["Ex", "Ey", "Ez"], - freqs=[freq], + freqs=freqs, name=name + "_field", colocate=False, ) @@ -85,7 +84,7 @@ def make_grad_monitors( eps_mnt = PermittivityMonitor( size=size_enlarged, center=self.bound_center, - freqs=[freq], + freqs=freqs, name=name + "_eps", ) return field_mnt, eps_mnt @@ -234,7 +233,7 @@ def store_vjp( # select the permittivity data eps_field_name = f"eps_{field_cmp_dim}{field_cmp_dim}" - eps_data = grad_data_eps.field_components[eps_field_name].isel(f=0) + eps_data = grad_data_eps.field_components[eps_field_name] # get the permittivity values just inside and outside the edge @@ -265,7 +264,7 @@ def store_vjp( delta_eps_inv = 1.0 / eps1 - 1.0 / eps2 d_integrand = -(delta_eps_inv * d_normal).real d_integrand = d_integrand.interp(**area_coords, assume_sorted=True) - grad_contrib = d_area * jnp.sum(d_integrand.values) + grad_contrib = d_area * np.sum(d_integrand.values) # get gradient contribution for parallel components using parallel E fields else: @@ -278,14 +277,14 @@ def store_vjp( delta_eps = eps1 - eps2 e_integrand = +(delta_eps * e_parallel).real e_integrand = e_integrand.interp(**area_coords, assume_sorted=True) - grad_contrib = d_area * jnp.sum(e_integrand.values) + grad_contrib = d_area * np.sum(e_integrand.values) # add this field contribution to the dict storing the surface contributions vjp_surfs[dim_normal][min_max_index] += grad_contrib - # convert surface vjps to center, size vjps. Note, convert these to jax types w/ jnp.sum() - vjp_center = tuple(jnp.sum(vjp_surfs[dim][1] - vjp_surfs[dim][0]) for dim in "xyz") - vjp_size = tuple(jnp.sum(0.5 * (vjp_surfs[dim][1] + vjp_surfs[dim][0])) for dim in "xyz") + # convert surface vjps to center, size vjps. Note, convert these to jax types w/ np.sum() + vjp_center = tuple(np.sum(vjp_surfs[dim][1] - vjp_surfs[dim][0]) for dim in "xyz") + vjp_size = tuple(np.sum(0.5 * (vjp_surfs[dim][1] + vjp_surfs[dim][0])) for dim in "xyz") return self.copy(update=dict(center=vjp_center, size=vjp_size)) @@ -448,7 +447,6 @@ def compute_integrand(s: np.array, z: np.array) -> np.array: def evaluate(scalar_field: ScalarFieldDataArray) -> float: """Evaluate a scalar field at a coordinate along the edge.""" - scalar_field = scalar_field.isel(f=0) # if only 1 z coordinate, just isel the data. if len(z) == 1: @@ -506,7 +504,7 @@ def evaluate(scalar_field: ScalarFieldDataArray) -> float: dz = 1.0 # integrate by summing over axis edge (z) and parameterization point (s) - integrand = compute_integrand(s=s_vals, z=z_vals) + integrand = compute_integrand(s=s_vals, z=z_vals).sum(dim="f") integral_result = np.sum(integrand.fillna(0).values) # project to the normal direction diff --git a/tidy3d/plugins/adjoint/components/medium.py b/tidy3d/plugins/adjoint/components/medium.py index 81659a80d..3512270bd 100644 --- a/tidy3d/plugins/adjoint/components/medium.py +++ b/tidy3d/plugins/adjoint/components/medium.py @@ -191,10 +191,16 @@ def store_vjp( inside_fn=inside_fn, ) - vjp_eps_complex = np.sum(d_eps_map.values) + vjp_eps_complex = d_eps_map.sum(dim=("x", "y", "z")) - freq = d_eps_map.coords["f"][0] - vjp_eps, vjp_sigma = self.eps_complex_to_eps_sigma(vjp_eps_complex, freq) + vjp_eps = 0.0 + vjp_sigma = 0.0 + + for freq in d_eps_map.coords["f"]: + vjp_eps_complex_f = vjp_eps_complex.sel(f=freq) + _vjp_eps, _vjp_sigma = self.eps_complex_to_eps_sigma(vjp_eps_complex_f, freq) + vjp_eps += _vjp_eps + vjp_sigma += _vjp_sigma return self.copy( update=dict( @@ -274,9 +280,19 @@ def store_vjp( inside_fn=inside_fn, ) - vjp_eps_complex_ii = np.sum(e_mult_dim.values) + vjp_eps_complex_ii = e_mult_dim.sum(dim=("x", "y", "z")) freq = e_mult_dim.coords["f"][0] - vjp_eps_ii, vjp_sigma_ii = self.eps_complex_to_eps_sigma(vjp_eps_complex_ii, freq) + + vjp_eps_ii = 0.0 + vjp_sigma_ii = 0.0 + + for freq in e_mult_dim.coords["f"]: + vjp_eps_complex_ii_f = vjp_eps_complex_ii.sel(f=freq) + _vjp_eps_ii, _vjp_sigma_ii = self.eps_complex_to_eps_sigma( + vjp_eps_complex_ii_f, freq + ) + vjp_eps_ii += _vjp_eps_ii + vjp_sigma_ii += _vjp_sigma_ii vjp_fields[component_name] = JaxMedium( permittivity=vjp_eps_ii, @@ -511,14 +527,18 @@ def store_vjp( # grab the correpsonding dotted fields at these interp_coords and sum over len-1 pixels field_name = "E" + dim - e_dotted = self.e_mult_volume( - field=field_name, - grad_data_fwd=grad_data_fwd, - grad_data_adj=grad_data_adj, - vol_coords=interp_coords, - d_vol=d_vols, - inside_fn=inside_fn, - ).sum(sum_axes) + e_dotted = ( + self.e_mult_volume( + field=field_name, + grad_data_fwd=grad_data_fwd, + grad_data_adj=grad_data_adj, + vol_coords=interp_coords, + d_vol=d_vols, + inside_fn=inside_fn, + ) + .sum(sum_axes) + .sum(dim="f") + ) # reshape values to the expected vjp shape to be more safe vjp_shape = tuple(len(coord) for _, coord in coords.items()) diff --git a/tidy3d/plugins/adjoint/components/simulation.py b/tidy3d/plugins/adjoint/components/simulation.py index 030d70939..08213bc40 100644 --- a/tidy3d/plugins/adjoint/components/simulation.py +++ b/tidy3d/plugins/adjoint/components/simulation.py @@ -17,7 +17,7 @@ from ....components.data.monitor_data import FieldData, PermittivityData from ....components.structure import Structure from ....components.types import Ax, annotate_type -from ....constants import HERTZ +from ....constants import HERTZ, SECOND from ....exceptions import AdjointError from .base import JaxObject @@ -25,9 +25,15 @@ from .geometry import JaxPolySlab, JaxGeometryGroup -# bandwidth of adjoint source in units of freq0 if no sources and no `fwidth_adjoint` specified +# bandwidth of adjoint source in units of freq0 if no `fwidth_adjoint`, and one output freq FWIDTH_FACTOR = 1.0 / 10 +# bandwidth of adjoint sources in units of the minimum difference between output frequencies +FWIDTH_FACTOR_MULTIFREQ = 0.1 + +# the adjoint run time is RUN_TIME_FACTOR / fwidth +RUN_TIME_FACTOR = 100 + # how many processors to use for server and client side adjoint NUM_PROC_LOCAL = 1 @@ -69,6 +75,13 @@ class JaxInfo(Tidy3dBaseModel): units=HERTZ, ) + run_time_adjoint: float = pd.Field( + None, + title="Adjoint Run Time", + description="Custom run time of the original JaxSimulation.", + units=SECOND, + ) + @register_pytree_node_class class JaxSimulation(Simulation, JaxObject): @@ -105,33 +118,18 @@ class JaxSimulation(Simulation, JaxObject): fwidth_adjoint: pd.PositiveFloat = pd.Field( None, title="Adjoint Frequency Width", - description="Custom frequency width to use for 'source_time' of adjoint sources. " - "If not supplied or 'None', uses the average fwidth of the original simulation's sources.", + description="Custom frequency width to use for ``source_time`` of adjoint sources. " + "If not supplied or ``None``, uses the average fwidth of the original simulation's sources.", units=HERTZ, ) - @pd.validator("output_monitors", always=True) - def _output_monitors_single_freq(cls, val): - """Assert all output monitors have just one frequency.""" - for mnt in val: - if len(mnt.freqs) != 1: - raise AdjointError( - "All output monitors must have single frequency for adjoint feature. " - f"Monitor '{mnt.name}' had {len(mnt.freqs)} frequencies." - ) - return val - - @pd.validator("output_monitors", always=True) - def _output_monitors_same_freq(cls, val): - """Assert all output monitors have the same frequency.""" - freqs = [mnt.freqs[0] for mnt in val] - if len(set(freqs)) > 1: - raise AdjointError( - "All output monitors must have the same frequency, " - f"given frequencies of {[f'{f:.2e}' for f in freqs]} (Hz) " - f"for monitors named '{[mnt.name for mnt in val]}', respectively." - ) - return val + run_time_adjoint: pd.PositiveFloat = pd.Field( + None, + title="Adjoint Run Time", + description="Custom ``run_time`` to use for adjoint simulation. " + "If not supplied or ``None``, uses a factor times the adjoint source ``fwidth``.", + units=SECOND, + ) @pd.validator("output_monitors", always=True) def _output_monitors_colocate_false(cls, val): @@ -227,18 +225,38 @@ def _warn_if_colocate(cls, val): return val @staticmethod - def get_freq_adjoint(output_monitors: List[Monitor]) -> float: - """Return the single adjoint frequency stripped from the output monitors.""" + def get_freqs_adjoint(output_monitors: List[Monitor]) -> List[float]: + """Return sorted list of unique frequencies stripped from a collection of monitors.""" if len(output_monitors) == 0: raise AdjointError("Can't get adjoint frequency as no output monitors present.") - return output_monitors[0].freqs[0] + output_freqs = [] + for mnt in output_monitors: + for freq in mnt.freqs: + output_freqs.append(freq) + + return np.unique(output_freqs).tolist() + + @cached_property + def freqs_adjoint(self) -> List[float]: + """Return sorted list of frequencies stripped from the output monitors.""" + return self.get_freqs_adjoint(output_monitors=self.output_monitors) + + @cached_property + def _is_multi_freq(self) -> bool: + """Does this simulation have a multi-frequency output?""" + return len(self.freqs_adjoint) > 1 @cached_property - def freq_adjoint(self) -> float: - """Return the single adjoint frequency stripped from the output monitors.""" - return self.get_freq_adjoint(output_monitors=self.output_monitors) + def _min_delta_freq(self) -> float: + """Minimum spacing between output_frequencies (Hz).""" + + if not self._is_multi_freq: + return None + + delta_freqs = np.abs(np.diff(np.sort(np.array(self.freqs_adjoint)))) + return np.min(delta_freqs) @cached_property def _fwidth_adjoint(self) -> float: @@ -248,19 +266,51 @@ def _fwidth_adjoint(self) -> float: if self.fwidth_adjoint is not None: return self.fwidth_adjoint - # otherwise, grab from sources - num_sources = len(self.sources) + freqs_adjoint = self.freqs_adjoint + + # multiple output frequency case + if self._is_multi_freq: + return FWIDTH_FACTOR_MULTIFREQ * self._min_delta_freq - # if no sources, just use a constant factor times the adjoint frequency + # otherwise, grab from sources and output monitors + num_sources = len(self.sources) # should be 0 for adjoint already but worth checking + + # if no sources, just use a constant factor times the mean adjoint frequency if num_sources == 0: - return FWIDTH_FACTOR * self.freq_adjoint + return FWIDTH_FACTOR * np.mean(freqs_adjoint) - # if more than one forward source, use their average + # if more than one forward source, use their maximum if num_sources > 1: - log.warning(f"{num_sources} sources, using their average 'fwidth' for adjoint source.") + log.warning(f"{num_sources} sources, using their maximum 'fwidth' for adjoint source.") fwidths = [src.source_time.fwidth for src in self.sources] - return np.mean(fwidths) + return np.max(fwidths) + + @cached_property + def _run_time_adjoint(self: float) -> float: + """Return the run time of the adjoint simulation as a function of its fwidth.""" + + if self.run_time_adjoint is not None: + return self.run_time_adjoint + + run_time_adjoint = RUN_TIME_FACTOR / self._fwidth_adjoint + + if self._is_multi_freq: + + log.warning( + f"{len(self.freqs_adjoint)} unique frequencies detected in the output monitors " + f"with a minimum spacing of {self._min_delta_freq:.3e} (Hz). " + f"Setting the 'fwidth' of the adjoint sources to {FWIDTH_FACTOR_MULTIFREQ} times " + f"this value = {self._fwidth_adjoint:.3e} (Hz) to avoid spectral overlap. " + "To account for this, the corresponding 'run_time' in the adjoint simulation is " + f"will be set to {run_time_adjoint:3e} " + f"compared to {self.run_time:3e} in the forward simulation. " + "If the adjoint 'run_time' is large due to small frequency spacing, " + "it could be better to instead run one simulation per frequency, " + "which can be done in parallel using 'tidy3d.plugins.adjoint.web.run_async'." + ) + + return run_time_adjoint def to_simulation(self) -> Tuple[Simulation, JaxInfo]: """Convert :class:`.JaxSimulation` instance to :class:`.Simulation` with an info dict.""" @@ -275,8 +325,9 @@ def to_simulation(self) -> Tuple[Simulation, JaxInfo]: "grad_eps_monitors", "input_structures", "fwidth_adjoint", + "run_time_adjoint", } - ) # .copy() + ) sim = Simulation.parse_obj(sim_dict) # put all structures and monitors in one list @@ -288,7 +339,7 @@ def to_simulation(self) -> Tuple[Simulation, JaxInfo]: + list(self.grad_eps_monitors) ) - sim = sim.copy(update=dict(structures=all_structures, monitors=all_monitors)) + sim = sim.updated_copy(structures=all_structures, monitors=all_monitors) # information about the state of the original JaxSimulation to stash for reconstruction jax_info = JaxInfo( @@ -297,6 +348,7 @@ def to_simulation(self) -> Tuple[Simulation, JaxInfo]: num_grad_monitors=len(self.grad_monitors), num_grad_eps_monitors=len(self.grad_eps_monitors), fwidth_adjoint=self.fwidth_adjoint, + run_time_adjoint=self.run_time_adjoint, ) return sim, jax_info @@ -522,7 +574,12 @@ def from_simulation(cls, simulation: Simulation, jax_info: JaxInfo) -> JaxSimula # update the dictionary with these and the adjoint fwidth sim_dict.update(**structures) sim_dict.update(**monitors) - sim_dict.update(dict(fwidth_adjoint=jax_info.fwidth_adjoint)) + sim_dict.update( + dict( + fwidth_adjoint=jax_info.fwidth_adjoint, + run_time_adjoint=jax_info.run_time_adjoint, + ) + ) # load JaxSimulation from the dictionary return cls.parse_obj(sim_dict) @@ -539,7 +596,7 @@ def make_sim_fwd(cls, simulation: Simulation, jax_info: JaxInfo) -> Tuple[Simula input_structures = structure_dict["input_structures"] grad_mnt_dict = cls.get_grad_monitors( input_structures=input_structures, - freq_adjoint=cls.get_freq_adjoint(output_monitors=output_monitors), + freqs_adjoint=cls.get_freqs_adjoint(output_monitors=output_monitors), ) grad_mnts = grad_mnt_dict["grad_monitors"] @@ -569,14 +626,14 @@ def to_simulation_fwd(self) -> Tuple[Simulation, JaxInfo, JaxInfo]: @staticmethod def get_grad_monitors( - input_structures: List[Structure], freq_adjoint: float, include_eps_mnts: bool = True + input_structures: List[Structure], freqs_adjoint: List[float], include_eps_mnts: bool = True ) -> dict: """Return dictionary of gradient monitors for simulation.""" grad_mnts = [] grad_eps_mnts = [] for index, structure in enumerate(input_structures): grad_mnt, grad_eps_mnt = structure.make_grad_monitors( - freq=freq_adjoint, name=f"grad_mnt_{index}" + freqs=freqs_adjoint, name=f"grad_mnt_{index}" ) grad_mnts.append(grad_mnt) if include_eps_mnts: @@ -593,8 +650,8 @@ def _store_vjp_structure( ) -> JaxStructure: """Store the vjp for a single structure.""" - freq = float(eps_data.eps_xx.coords["f"]) - eps_out = self.medium.eps_model(frequency=freq) + freq_max = float(max(eps_data.eps_xx.coords["f"])) + eps_out = self.medium.eps_model(frequency=freq_max) return structure.store_vjp( grad_data_fwd=fld_fwd, grad_data_adj=fld_adj, diff --git a/tidy3d/plugins/adjoint/components/structure.py b/tidy3d/plugins/adjoint/components/structure.py index 1438a6896..065cfcb09 100644 --- a/tidy3d/plugins/adjoint/components/structure.py +++ b/tidy3d/plugins/adjoint/components/structure.py @@ -1,6 +1,8 @@ """Defines a jax-compatible structure and its conversion to a gradient monitor.""" from __future__ import annotations +from typing import List + import pydantic.v1 as pd import numpy as np from jax.tree_util import register_pytree_node_class @@ -74,10 +76,10 @@ def store_vjp( ) -> JaxStructure: """Returns the gradient of the structure parameters given forward and adjoint field data.""" - # compute wavelength in material (to use for determining integration points) - freq = float(grad_data_eps.eps_xx.f) - wvl_free_space = C_0 / freq - eps_in = self.medium.eps_model(frequency=freq) + # compute minimum wavelength in material (to use for determining integration points) + freq_max = max(grad_data_eps.eps_xx.f) + wvl_free_space = C_0 / freq_max + eps_in = self.medium.eps_model(frequency=freq_max) ref_ind = np.sqrt(np.max(np.real(eps_in))) wvl_mat = wvl_free_space / ref_ind @@ -102,6 +104,6 @@ def store_vjp( return self.copy(update=dict(geometry=geo_vjp, medium=medium_vjp)) - def make_grad_monitors(self, freq: float, name: str) -> FieldMonitor: + def make_grad_monitors(self, freqs: List[float], name: str) -> FieldMonitor: """Return gradient monitor associated with this object.""" - return self.geometry.make_grad_monitors(freq=freq, name=name) + return self.geometry.make_grad_monitors(freqs=freqs, name=name) diff --git a/tidy3d/plugins/adjoint/web.py b/tidy3d/plugins/adjoint/web.py index 23c97e6f4..87dd5401b 100644 --- a/tidy3d/plugins/adjoint/web.py +++ b/tidy3d/plugins/adjoint/web.py @@ -165,7 +165,8 @@ def run_bwd( fwd_task_id = res[0].fwd_task_id fwidth_adj = sim_data_vjp.simulation._fwidth_adjoint - jax_sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj) + run_time_adj = sim_data_vjp.simulation._run_time_adjoint + jax_sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj) sim_adj, jax_info_adj = jax_sim_adj.to_simulation() sim_vjp = webapi_run_adjoint_bwd( @@ -484,7 +485,8 @@ def run_async_bwd( for sim_data_vjp, fwd_task_id in zip(batch_data_vjp, fwd_task_ids): parent_tasks_adj.append([str(fwd_task_id)]) fwidth_adj = sim_data_vjp.simulation._fwidth_adjoint - jax_sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj) + run_time_adj = sim_data_vjp.simulation._run_time_adjoint + jax_sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj) sim_adj, jax_info_adj = jax_sim_adj.to_simulation() sims_adj.append(sim_adj) jax_infos_adj.append(jax_info_adj) @@ -647,7 +649,7 @@ def run_local_fwd( # add the gradient monitors and run the forward simulation grad_mnts = simulation.get_grad_monitors( - input_structures=simulation.input_structures, freq_adjoint=simulation.freq_adjoint + input_structures=simulation.input_structures, freqs_adjoint=simulation.freqs_adjoint ) sim_fwd = simulation.updated_copy(**grad_mnts) sim_data_fwd = run( @@ -682,7 +684,8 @@ def run_local_bwd( # make and run adjoint simulation fwidth_adj = sim_data_fwd.simulation._fwidth_adjoint - sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj) + run_time_adj = sim_data_fwd.simulation._run_time_adjoint + sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj) sim_data_adj = run( simulation=sim_adj, task_name=_task_name_adj(task_name), @@ -691,6 +694,9 @@ def run_local_bwd( callback_url=callback_url, verbose=verbose, ) + + sim_data_adj = sim_data_adj.normalize_adjoint_fields() + grad_data_adj = sim_data_adj.grad_data_symmetry # get gradient and insert into the resulting simulation structure medium @@ -804,7 +810,7 @@ def run_async_local_fwd( for simulation in simulations: grad_mnts = simulation.get_grad_monitors( - input_structures=simulation.input_structures, freq_adjoint=simulation.freq_adjoint + input_structures=simulation.input_structures, freqs_adjoint=simulation.freqs_adjoint ) sim_fwd = simulation.updated_copy(**grad_mnts) sims_fwd.append(sim_fwd) @@ -857,8 +863,9 @@ def run_async_local_bwd( sims_adj = [] for i, sim_data_fwd in enumerate(batch_data_fwd): fwidth_adj = sim_data_fwd.simulation._fwidth_adjoint + run_time_adj = sim_data_fwd.simulation._run_time_adjoint sim_data_vjp = batch_data_vjp[i] - sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj) + sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj) sims_adj.append(sim_adj) batch_data_adj = run_async_local( @@ -874,6 +881,8 @@ def run_async_local_bwd( sims_vjp = [] for i, (sim_data_fwd, sim_data_adj) in enumerate(zip(batch_data_fwd, batch_data_adj)): + sim_data_adj = sim_data_adj.normalize_adjoint_fields() + grad_data_fwd = sim_data_fwd.grad_data_symmetry grad_data_adj = sim_data_adj.grad_data_symmetry grad_data_eps_fwd = sim_data_fwd.grad_eps_data_symmetry