diff --git a/CHANGELOG.md b/CHANGELOG.md index c7800b170..f95ae0fa8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Broadband adjoint support in autograd for adjoint sources with the same spatial dependence. + ## [2.7.1] ### Added @@ -13,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `ModeSolver` methods to plot the mode plane simulation components, including `.plot()`, `.plot_eps()`, `.plot_structures_eps()`, `.plot_grid()`, and `.plot_pml()`. - Support for differentiation with respect to monitor attributes that require interpolation, such as flux and intensity. - Support for automatic differentiation with respect to `.eps_inf` and `.poles` contained in dispersive mediums `td.PoleResidue` and `td.CustomPoleResidue`. + +### Changed ### Fixed - Bug where boundary layers would be plotted too small in 2D simulations. 
diff --git a/tests/test_components/test_autograd.py b/tests/test_components/test_autograd.py index d811d7363..064eeb959 100644 --- a/tests/test_components/test_autograd.py +++ b/tests/test_components/test_autograd.py @@ -12,9 +12,9 @@ import numpy as np import pytest import tidy3d as td +import xarray as xr from tidy3d.components.autograd.derivative_utils import DerivativeInfo -from tidy3d.web import run_async -from tidy3d.web.api.autograd.autograd import run +from tidy3d.web import run, run_async from ..utils import SIM_FULL, AssertLogLevel, run_emulated @@ -53,6 +53,7 @@ WVL = 1.0 FREQ0 = td.C_0 / WVL +FREQS = [0.9 * FREQ0, FREQ0, FREQ0 * 1.1] # sim sizes LZ = 7 * WVL @@ -154,8 +155,10 @@ def run_async_emulated(simulations, **kwargs): def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: """Make a dictionary of the structures given the parameters.""" + np.random.seed(0) + vector = np.random.random(N_PARAMS) - 0.5 - vector /= np.linalg.norm(vector) + vector = vector / np.linalg.norm(vector) # static components box = td.Box(center=(0, 0, 0), size=(1, 1, 1)) @@ -411,6 +414,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = False) -> None: if TEST_POLYSLAB_SPEED: args = [("polyslab", "mode")] + # args = [("geo_group", "mode")] @@ -612,33 +616,13 @@ def objective(*params): ag.grad(objective)(params0) -def test_warning_no_adjoint_sources(log_capture, monkeypatch, use_emulated_run): - """Make sure we get the right warning with no adjoint sources, and no error.""" - - monitor_key = "mode" - structure_key = "size_element" - monitor, postprocess = make_monitors()[monitor_key] - - def make_sim(*args): - structure = make_structures(*args)[structure_key] - return SIM_BASE.updated_copy(structures=[structure], monitors=[monitor]) - - def objective(*args): - """Objective function.""" - sim = make_sim(*args) - data = run(sim, task_name="autograd_test", verbose=False) - value = postprocess(data, data[monitor_key]) - return value - - 
monkeypatch.setattr(td.SimulationData, "make_adjoint_sources", lambda *args, **kwargs: []) - - with AssertLogLevel(log_capture, "WARNING", contains_str="No adjoint sources"): - ag.grad(objective)(params0) - - def test_web_failure_handling(log_capture, monkeypatch, use_emulated_run, use_emulated_run_async): """Test what happens when autograd run pipeline fails.""" + def fail(*args, **kwargs): + """Just raise an exception.""" + raise ValueError("test") + monitor_key = "mode" structure_key = "size_element" monitor, postprocess = make_monitors()[monitor_key] @@ -650,14 +634,10 @@ def make_sim(*args): def objective(*args): """Objective function.""" sim = make_sim(*args) - data = run(sim, task_name="autograd_test", verbose=False) + data = run(sim, task_name=None, verbose=False) value = postprocess(data, data[monitor_key]) return value - def fail(*args, **kwargs): - """Just raise an exception.""" - raise ValueError("test") - """ if autograd run raises exception, raise a warning and continue with regular .""" monkeypatch.setattr(td.web.api.autograd.autograd, "_run", fail) @@ -1008,3 +988,230 @@ def f(x): * no copy : 16 sec * no to_static(): 13 sec """ + +FREQ1 = FREQ0 * 1.1 + +mnt_single = td.ModeMonitor( + size=(2, 2, 0), + center=(0, 0, LZ / 2 - WVL), + mode_spec=td.ModeSpec(num_modes=2), + freqs=[FREQ0], + name="single", +) + +mnt_multi = td.ModeMonitor( + size=(2, 2, 0), + center=(0, 0, LZ / 2 - WVL), + mode_spec=td.ModeSpec(num_modes=2), + freqs=[FREQ0, FREQ1], + name="multi", +) + + +def make_objective(postprocess_fn: typing.Callable, structure_key: str) -> typing.Callable: + def objective(params): + structure_traced = make_structures(params)[structure_key] + sim = SIM_BASE.updated_copy( + structures=[structure_traced], + monitors=list(SIM_BASE.monitors) + [mnt_single, mnt_multi], + ) + data = run(sim, task_name="multifreq_test") + return postprocess_fn(data) + + return objective + + +def get_amps(sim_data: td.SimulationData, mnt_name: str) -> xr.DataArray: + return 
sim_data[mnt_name].amps + + +def power(amps: xr.DataArray) -> float: + """Reduce a selected DataArray into just a float for objective function.""" + return anp.sum(anp.abs(amps.values) ** 2) + + +def postprocess_0_src(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 0 adjoint sources.""" + return 0.0 + + +def compute_grad(postprocess_fn: typing.Callable, structure_key: str) -> typing.Callable: + objective = make_objective(postprocess_fn, structure_key=structure_key) + params = params0 + 1.0 # +1 is to avoid a warning in size_element with value 0 + return ag.grad(objective)(params) + + +def check_0_src(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 0 adjoint sources.""" + return 0.0 + + compute_grad(postprocess, structure_key=structure_key) + + +# NOTE: not tested, just raises regular warning + + +def check_1_src_single(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 1 adjoint sources.""" + amps = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + return power(amps) + + with AssertLogLevel( + log_capture, log_level_expected="INFO", contains_str="One monitor with one adjoint source." + ): + compute_grad(postprocess, structure_key=structure_key) + + +def check_2_src_single(log_capture, structure_key): + def postprocess_2_src_single(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 2 different adjoint sources.""" + amps = get_amps(sim_data, "single").sel(mode_index=0) + return power(amps) + + with AssertLogLevel( + log_capture, log_level_expected="INFO", contains_str="One monitor with 2 adjoint sources." 
+ ): + compute_grad(postprocess_2_src_single, structure_key=structure_key) + + +def check_1_src_multi(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 1 adjoint sources.""" + amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ0) + return power(amps) + + with AssertLogLevel( + log_capture, log_level_expected="INFO", contains_str="One monitor with one adjoint source." + ): + compute_grad(postprocess, structure_key=structure_key) + + +def check_2_src_multi(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 2 different adjoint sources.""" + amps = get_amps(sim_data, "multi").sel(mode_index=0, f=FREQ1) + return power(amps) + + with AssertLogLevel( + log_capture, log_level_expected="INFO", contains_str="One monitor with 2 adjoint sources." + ): + compute_grad(postprocess, structure_key=structure_key) + + +def check_2_src_both(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 2 different adjoint sources.""" + amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ0) + return power(amps_single) + power(amps_multi) + + with AssertLogLevel( + log_capture, + log_level_expected="INFO", + contains_str="Several adjoint sources from different monitors, all with same single frequency.", + ): + compute_grad(postprocess, structure_key=structure_key) + + +def check_1_error_multisrc(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should raise ValueError because diff sources, diff freqs.""" + amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ1) + 
return power(amps_single) + power(amps_multi) + + with pytest.raises(ValueError): + compute_grad(postprocess, structure_key=structure_key) + + +def check_2_error_multisrc(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should raise ValueError because diff sources, diff freqs.""" + amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+") + return power(amps_single) + power(amps_multi) + + with pytest.raises(ValueError): + compute_grad(postprocess, structure_key=structure_key) + + +def check_1_src_broadband(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 1 broadband adjoint sources with many freqs.""" + amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+") + return power(amps) + + with AssertLogLevel( + log_capture, + log_level_expected="INFO", + contains_str="Constructing broadband adjoint source and performing post-run normalization", + ): + compute_grad(postprocess, structure_key=structure_key) + + +MULT_FREQ_TEST_CASES = dict( + src_1_freq_1=check_1_src_single, + src_2_freq_1=check_2_src_single, + src_1_freq_2=check_1_src_multi, + src_2_freq_1_mon_1=check_1_src_multi, + src_2_freq_1_mon_2=check_2_src_both, + src_2_freq_2_mon_1=check_1_error_multisrc, + src_2_freq_2_mon_2=check_2_error_multisrc, + src_1_freq_2_broadband=check_1_src_broadband, +) + +checks = list(MULT_FREQ_TEST_CASES.items()) + + +@pytest.mark.parametrize("label, check_fn", checks) +@pytest.mark.parametrize("structure_key", structure_keys_) +def test_multi_freq_edge_cases(log_capture, use_emulated_run, structure_key, label, check_fn): + # test multi-frequency adjoint handling + check_fn(structure_key=structure_key, log_capture=log_capture) + + +@pytest.mark.parametrize("structure_key", structure_keys_) +def 
test_multi_frequency_equivalence(use_emulated_run, structure_key): + """Test an objective function through tidy3d autograd.""" + + def objective_indi(params, structure_key) -> float: + power_sum = 0.0 + + for f in mnt_multi.freqs: + structure_traced = make_structures(params)[structure_key] + sim = SIM_BASE.updated_copy( + structures=[structure_traced], + monitors=list(SIM_BASE.monitors) + [mnt_multi], + ) + + sim_data = run(sim, task_name="multifreq_test") + amps_i = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=f) + power_i = power(amps_i) + power_sum = power_sum + power_i + + return power_sum + + def objective_multi(params, structure_key) -> float: + structure_traced = make_structures(params)[structure_key] + sim = SIM_BASE.updated_copy( + structures=[structure_traced], + monitors=list(SIM_BASE.monitors) + [mnt_multi], + ) + sim_data = run(sim, task_name="multifreq_test") + amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+") + return power(amps) + + params0_ = params0 + 1.0 + + J_indi = objective_indi(params0_, structure_key) + J_multi = objective_multi(params0_, structure_key) + + np.testing.assert_allclose(J_indi, J_multi) + + grad_indi = ag.grad(objective_indi)(params0_, structure_key=structure_key) + grad_multi = ag.grad(objective_multi)(params0_, structure_key=structure_key) + + assert not np.any(np.isclose(grad_indi, 0)) + assert not np.any(np.isclose(grad_multi, 0)) diff --git a/tests/utils.py b/tests/utils.py index 141a6882b..94e40a7e0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,6 +16,8 @@ """ utilities shared between all tests """ np.random.seed(4) +# function used to generate the data for emulated runs +DATA_GEN_FN = np.random.random FREQS = np.array([1.90, 2.01, 2.2]) * 1e12 SIM_MONITORS = td.Simulation( @@ -880,7 +882,7 @@ def make_data( """make a random DataArray out of supplied coordinates and data_type.""" data_shape = [len(coords[k]) for k in data_array_type._dims] np.random.seed(1) - data = 
np.random.random(data_shape) + data = DATA_GEN_FN(data_shape) data = (1 + 0.5j) * data if is_complex else data data = gaussian_filter(data, sigma=1.0) # smooth out the data a little so it isnt random @@ -939,7 +941,7 @@ def make_mode_solver_data(monitor: td.ModeSolverMonitor) -> td.ModeSolverData: index_coords["mode_index"] = np.arange(monitor.mode_spec.num_modes) index_data_shape = (len(index_coords["f"]), len(index_coords["mode_index"])) index_data = ModeIndexDataArray( - (1 + 1j) * np.random.random(index_data_shape), coords=index_coords + (1 + 1j) * DATA_GEN_FN(index_data_shape), coords=index_coords ) for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: coords = get_spatial_coords_dict(simulation, monitor, field_name) @@ -977,7 +979,7 @@ def make_diff_data(monitor: td.DiffractionMonitor) -> td.DiffractionData: orders_x = np.linspace(-1, 1, 3) orders_y = np.linspace(-2, 2, 5) coords = dict(orders_x=orders_x, orders_y=orders_y, f=f) - values = np.random.random((len(orders_x), len(orders_y), len(f))) + values = DATA_GEN_FN((len(orders_x), len(orders_y), len(f))) data = td.DiffractionDataArray(values, coords=coords) field_data = {field: data for field in ("Er", "Etheta", "Ephi", "Hr", "Htheta", "Hphi")} return td.DiffractionData(monitor=monitor, sim_size=(1, 1), bloch_vecs=(0, 0), **field_data) diff --git a/tidy3d/components/data/data_array.py b/tidy3d/components/data/data_array.py index 4d7eea82a..7f78e0f52 100644 --- a/tidy3d/components/data/data_array.py +++ b/tidy3d/components/data/data_array.py @@ -76,6 +76,7 @@ def __init__(self, data, *args, **kwargs): # initialize with untraced data super().__init__(getval(data), *args, **kwargs) # and put tracers in .attrs + if isbox(data): self.attrs[AUTOGRAD_KEY] = data diff --git a/tidy3d/components/data/sim_data.py b/tidy3d/components/data/sim_data.py index f0107b4dd..dd1aac610 100644 --- a/tidy3d/components/data/sim_data.py +++ b/tidy3d/components/data/sim_data.py @@ -20,7 +20,7 @@ from ..file_util import 
replace_values from ..monitor import Monitor from ..simulation import Simulation -from ..source import Source +from ..source import ModeSource, Source from ..structure import Structure from ..types import Ax, Axis, ColormapType, FieldVal, PlotScale, annotate_type from ..viz import add_ax_if_none, equal_aspect @@ -953,13 +953,17 @@ def source_spectrum_fn(freqs): def make_adjoint_sim( self, data_vjp_paths: set[tuple], adjoint_monitors: list[Monitor] - ) -> Simulation: + ) -> tuple[Simulation, xr.DataArray]: """Make the adjoint simulation from the original simulation and the VJP-containing data.""" sim_original = self.simulation # generate the adjoint sources - sources_adj = self.make_adjoint_sources(data_vjp_paths=data_vjp_paths) + sources_adj_dict = self.make_adjoint_sources(data_vjp_paths=data_vjp_paths) + + sources_adj, post_norm_amps = self.process_adjoint_sources( + sources_adj_dict=sources_adj_dict + ) # grab boundary conditions with flipped Bloch vectors (for adjoint) bc_adj = sim_original.boundary_spec.flipped_bloch_vecs @@ -978,9 +982,153 @@ def make_adjoint_sim( grid_spec_adj = grid_spec_original.updated_copy(wavelength=wavelength_original) sim_adj_update_dict["grid_spec"] = grid_spec_adj - return sim_original.updated_copy(**sim_adj_update_dict) + return sim_original.updated_copy(**sim_adj_update_dict), post_norm_amps + + def process_adjoint_sources( + self, sources_adj_dict: dict[str, list[Source]] + ) -> tuple[list[Source], xr.DataArray]: + """Process mapping of monitor name to adjoint sources to adj srcs and post-norm amps.""" + + # number of monitors contributing to adjoint sources + num_adj_src_monitors = len(sources_adj_dict) + + # no adjoint sources, gradient is 0 + if num_adj_src_monitors == 0: + log.warning( + "Something unexpected happened. There are No adjoint sources, " + " yet the adjoint pipeline was triggered. No gradient will be computed " + "with respect to simulation data output. 
If you received this warning, you may have" + " multiplied your objective function by 0 and might want to investigate further." + ) + return [], None + + # adjoint sources from just one monitor, try to do broadband + if num_adj_src_monitors == 1: + # grab the sources for that monitor + adj_srcs = list(sources_adj_dict.values())[0] + + # just one source, just handle it as normal + if len(adj_srcs) == 1: + log.info("One monitor with one adjoint source.") + return self._process_adjoint_sources_same_freq(sources_adj_dict) + + # multiple sources, check if they are same except for source_time + unique_sources = {src.json(exclude={"source_time"}) for src in adj_srcs} + if len(unique_sources) > 1: + log.info( + f"One monitor with {len(unique_sources)} adjoint sources. " + "But, because of different source characteristics, " + f"the monitor still requires f{len(unique_sources)} different adjoint sources." + " No special handling." + ) + return self._process_adjoint_sources_same_freq(sources_adj_dict) + + # many sources differ only by source can be handled with broadband adjoint + return self._process_adjoint_sources_broadband(sources_adj_dict) + + # typical case: several monitors with adjoint sources at same freq + return self._process_adjoint_sources_same_freq(sources_adj_dict) + + def _process_adjoint_sources_same_freq( + self, sources_adj_dict: dict[str, list[Source]] + ) -> tuple[list[Source], xr.DataArray]: + """Process adjoint sources for the case of one adjoint source at several freqs.""" + + # map of monitor name to set of unique frequencies in the corresponding data + sources_unique_freqs = { + key: {src.source_time.freq0 for src in val} for key, val in sources_adj_dict.items() + } + + def error_msg_pre(mnt_name: str) -> str: + """First part of error message for specific monitor.""" + return f"Can't compute adjoint source for data from monitor '{mnt_name}'. 
" + + error_msg_post = ( + "Can only compute adjoint source for several monitor output if each" + "data has one frequency only." + ) + + # perform validation of the adjoint sources + all_freqs = set() + for mnt_name, freqs in sources_unique_freqs.items(): + # first, make sure each monitor data only has one frequency + if len(freqs) > 1: + raise ValueError( + error_msg_pre(mnt_name) + + f"Monitor has {len(freqs)} frequencies. " + + error_msg_post + ) - def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> list[Source]: + # then, more generally, ensure that all monitor data have the same frequency + freq = tuple(freqs)[0] + all_freqs.add(freq) + if len(all_freqs) > 1: + raise ValueError( + error_msg_pre(mnt_name) + + error_msg_post + + " Detected different frequencies in different monitors." + ) + + log.info("Several adjoint sources from different monitors, all with same single frequency.") + + # passed validation, put all sources into a single list and set None for post-normalize amps + adj_srcs = sum(sources_adj_dict.values(), []) + + return adj_srcs, None + + def _process_adjoint_sources_broadband( + self, sources_adj_dict: dict[str, list[Source]] + ) -> tuple[list[Source], xr.DataArray]: + """Process adjoint sources for the case of several sources at the same freq.""" + + adj_srcs = list(sources_adj_dict.values())[0] + + src_broadband = self._make_broadband_source(adj_srcs=adj_srcs) + post_norm_amps = self._make_post_norm_amps(adj_srcs=adj_srcs) + + log.info( + "Several adjoint sources, from one monitor. " + "Only difference between them is the source time. " + "Constructing broadband adjoint source and performing post-run normalization " + f"of fields with {len(post_norm_amps)} frequencies." 
+ ) + + return [src_broadband], post_norm_amps + + def _make_broadband_source(self, adj_srcs: list[Source], num_fwidth: float = 0.5) -> Source: + """Make a broadband source for a set of adjoint sources.""" + + source_index = self.simulation.normalize_index or 0 + src_time_base = self.simulation.sources[source_index].source_time.copy() + src_broadband = adj_srcs[0].updated_copy(source_time=src_time_base) + + # TODO: make this a broadband mode source, if applicable + if isinstance(src_broadband, ModeSource): + # src_un_normalized = src_un_normalized.updated_copy(...) + log.info( + "Making multi-frequency adjoint 'ModeSource' into a broadband " + f"mode source with {len(adj_srcs)} frequencies." + ) + + return src_broadband + + @staticmethod + def _make_post_norm_amps(adj_srcs: list[Source]) -> xr.DataArray: + """Make a ``DataArray`` containing the complex amplitudes to multiply with adjoint field.""" + + freqs = [] + amps_complex = [] + for src in adj_srcs: + src_time = src.source_time + freqs.append(src_time.freq0) + amp_complex = src_time.amplitude * np.exp(1j * src_time.phase) + amps_complex.append(amp_complex) + + coords = dict(f=freqs) + amps_complex = np.array(amps_complex) + return xr.DataArray(amps_complex, coords=coords) + + def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, Source]: """Generate all of the non-zero sources for the adjoint simulation given the VJP data.""" # TODO: determine if we can do multi-frequency sources @@ -990,12 +1138,12 @@ def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> list[Source]: for _, index, dataset_name in data_vjp_paths: adj_src_map[index].append(dataset_name) - # gather a list of adjoint sources for every monitor data in the VJP that needs one - sources_adj_all = [] + # gather a dict of adjoint sources for every monitor data in the VJP that needs one + sources_adj_all = defaultdict(list) for data_index, dataset_names in adj_src_map.items(): mnt_data = self.data[data_index] sources_adj = 
mnt_data.make_adjoint_sources(dataset_names=dataset_names) - sources_adj_all += sources_adj + sources_adj_all[mnt_data.monitor.name] = sources_adj return sources_adj_all diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index 79258b0a9..4b5fd6bb4 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -1150,7 +1150,7 @@ def derivative_eps_complex_volume( ) vjp_value += vjp_value_fld - return vjp_value + return vjp_value.sum("f") class AbstractCustomMedium(AbstractMedium, ABC): @@ -1395,7 +1395,7 @@ def _derivative_field_cmp( # TODO: probably this could be more robust. eg if the DataArray has weird edge cases E_der_dim = E_der_map[f"E{dim}"] - E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum) + E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum).sum("f") vjp_array = np.array(E_der_dim_interp.values).astype(complex) vjp_array = vjp_array.reshape(eps_data.shape) @@ -2551,8 +2551,11 @@ def _derivative_field_cmp( eps_data: PermittivityDataset, dim: str, ) -> np.ndarray: - coords_interp = {key: val for key, val in eps_data.coords.items() if len(val) > 1} - dims_sum = {dim for dim in eps_data.coords.keys() if dim not in coords_interp} + """Compute derivative with respect to the ``dim`` components within the custom medium.""" + + coords_interp = {key: eps_data.coords[key] for key in "xyz"} + coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1} + dims_sum = [dim for dim in "xyz" if dim not in coords_interp] # compute sizes along each of the interpolation dimensions sizes_list = [] @@ -2581,8 +2584,11 @@ def _derivative_field_cmp( # TODO: probably this could be more robust. 
eg if the DataArray has weird edge cases E_der_dim = E_der_map[f"E{dim}"] - E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum) - vjp_array = np.array(E_der_dim_interp.values).astype(complex) + E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum).real + E_der_dim_interp = E_der_dim_interp.sum("f") + + vjp_array = np.array(E_der_dim_interp.values).astype(float) + vjp_array = vjp_array.reshape(eps_data.shape) # multiply by volume elements (if possible, being defensive here..) diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index 7a35e8177..c2b1f6940 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -3254,14 +3254,6 @@ def freqs_adjoint(self) -> list[float]: if isinstance(mnt, FreqMonitor): freqs.update(mnt.freqs) freqs = sorted(freqs) - - if len(freqs) > 1: - raise ValueError( - "Only the same, single frequency is supported in all monitors " - "when using autograd differentiation. " - f"Found {len(freqs)} distinct frequencies in the monitors." 
- ) - return freqs """ Accounting """ diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index 40a8bd079..cd4d7729b 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -2,8 +2,10 @@ import traceback import typing +from collections import defaultdict import numpy as np +import xarray as xr from autograd.builtins import dict as dict_ag from autograd.extend import defvjp, primitive @@ -485,28 +487,13 @@ def _run_bwd( def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap: """dJ/d{sim.traced_fields()} as a function of Function of dJ/d{data.traced_fields()}""" - sim_adj = setup_adj( + sim_adj, post_norm_amps = setup_adj( data_fields_vjp=data_fields_vjp, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, ) - # no adjoint sources, no gradient for you :( - if not len(sim_adj.sources): - td.log.warning( - "No adjoint sources generated. " - "There is likely zero output in the data, or you have no traceable monitors. " - "As a result, the 'SimulationData' returned has no contribution to the gradient. " - "Skipping the adjoint simulation. " - "If this is unexpected, please double check the post-processing function to ensure " - "there is a path from the 'SimulationData' to the objective function return value." 
- ) - - # TODO: add a test for this - # construct a VJP of all zeros for all tracers in the original simulation - return {path: 0 * value for path, value in sim_fields_original.items()} - # run adjoint simulation task_name_adj = str(task_name) + "_adjoint" sim_data_adj = _run_tidy3d(sim_adj, task_name=task_name_adj, **run_kwargs) @@ -516,6 +503,7 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap: sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, + post_norm_amps=post_norm_amps, ) return vjp @@ -548,19 +536,21 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd task_names_adj = {task_name + "_adjoint" for task_name in task_names} sims_adj = {} + post_norm_amps_dict = {} for task_name, task_name_adj in zip(task_names, task_names_adj): data_fields_vjp = data_fields_dict_vjp[task_name] sim_data_orig = sim_data_orig_dict[task_name] sim_data_fwd = sim_data_fwd_dict[task_name] sim_fields_original = sim_fields_original_dict[task_name] - sim_adj = setup_adj( + sim_adj, post_norm_amps = setup_adj( data_fields_vjp=data_fields_vjp, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, ) sims_adj[task_name_adj] = sim_adj + post_norm_amps_dict[task_name_adj] = post_norm_amps # TODO: handle case where no adjoint sources? 
@@ -570,15 +560,16 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd sim_fields_vjp_dict = {} for task_name, task_name_adj in zip(task_names, task_names_adj): sim_data_adj = batch_data_adj[task_name_adj] + post_norm_amps = post_norm_amps_dict[task_name_adj] sim_data_orig = sim_data_orig_dict[task_name] sim_data_fwd = sim_data_fwd_dict[task_name] sim_fields_original = sim_fields_original_dict[task_name] - sim_fields_vjp = postprocess_adj( sim_data_adj=sim_data_adj, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, + post_norm_amps=post_norm_amps, ) sim_fields_vjp_dict[task_name] = sim_fields_vjp @@ -592,7 +583,7 @@ def setup_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_original: AutogradFieldMap, -) -> td.Simulation: +) -> tuple[td.Simulation, xr.DataArray]: """Construct an adjoint simulation from a set of data_fields for the VJP.""" td.log.info("Running custom vjp (adjoint) pipeline.") @@ -607,13 +598,13 @@ def setup_adj( # make adjoint simulation from that SimulationData data_vjp_paths = set(data_fields_vjp.keys()) - sim_adj = sim_data_vjp.make_adjoint_sim( + sim_adj, post_norm_amps = sim_data_vjp.make_adjoint_sim( data_vjp_paths=data_vjp_paths, adjoint_monitors=sim_data_fwd.simulation.monitors ) td.log.info(f"Adjoint simulation created with {len(sim_adj.sources)} sources.") - return sim_adj + return sim_adj, post_norm_amps def postprocess_adj( @@ -621,17 +612,15 @@ def postprocess_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_original: AutogradFieldMap, + post_norm_amps: xr.DataArray, ) -> AutogradFieldMap: """Postprocess some data from the adjoint simulation into the VJP for the original sim flds.""" # map of index into 'structures' to the list of paths we need vjps for - sim_vjp_map = {} + sim_vjp_map = defaultdict(list) for _, structure_index, *structure_path in sim_fields_original.keys(): structure_path = 
tuple(structure_path) - if structure_index in sim_vjp_map: - sim_vjp_map[structure_index].append(structure_path) - else: - sim_vjp_map[structure_index] = [structure_path] + sim_vjp_map[structure_index].append(structure_path) # store the derivative values given the forward and adjoint data sim_fields_vjp = {} @@ -642,6 +631,13 @@ def postprocess_adj( fld_adj = sim_data_adj.get_adjoint_data(structure_index, data_type="fld") eps_adj = sim_data_adj.get_adjoint_data(structure_index, data_type="eps") + # post normalize the adjoint fields if a single, broadband source + if post_norm_amps is not None: + fwd_flds_normed = { + key: val * post_norm_amps for key, val in fld_adj.field_components.items() + } + fld_adj = fld_adj.updated_copy(**fwd_flds_normed) + # maps of the E_fwd * E_adj and D_fwd * D_adj, each as as td.FieldData & 'Ex', 'Ey', 'Ez' der_maps = get_derivative_maps( fld_fwd=fld_fwd, eps_fwd=eps_fwd, fld_adj=fld_adj, eps_adj=eps_adj @@ -656,7 +652,7 @@ def postprocess_adj( frequencies = {src.source_time.freq0 for src in sim_data_adj.simulation.sources} frequencies = list(frequencies) assert len(frequencies) == 1, "Multiple adjoint freqs found" - freq_adj = frequencies[0] + freq_adj = frequencies[0] or None eps_in = np.mean(structure.medium.eps_model(freq_adj)) eps_out = np.mean(sim_data_orig.simulation.medium.eps_model(freq_adj))