diff --git a/tidy3d/components/data/sim_data.py b/tidy3d/components/data/sim_data.py index 0038206217..03c1c206d7 100644 --- a/tidy3d/components/data/sim_data.py +++ b/tidy3d/components/data/sim_data.py @@ -15,7 +15,7 @@ from ...exceptions import DataError, FileError, Tidy3dKeyError from ...log import log -from ..base import JSON_TAG +from ..base import JSON_TAG, Tidy3dBaseModel from ..base_sim.data.sim_data import AbstractSimulationData from ..file_util import replace_values from ..monitor import Monitor @@ -37,7 +37,31 @@ DATA_TYPE_NAME_MAP = {val.__fields__["monitor"].type_.__name__: val for val in MonitorDataTypes} # residuals below this are considered good fits for broadband adjoint source creation -RESIDUAL_CUTOFF_ADJOINT = 0.5 +RESIDUAL_CUTOFF_ADJOINT = 1e-6 + + +class AdjointSourceInfo(Tidy3dBaseModel): + """Stores information about the adjoint sources to pass to autograd pipeline.""" + + sources: tuple[Source, ...] = pd.Field( + ..., + title="Adjoint Sources", + description="Set of processed sources to include in the adjoint simulation.", + ) + + post_norm: Union[float, xr.DataArray] = pd.Field( + ..., + title="Post Normalization Values", + description="Factor to multiply the adjoint fields by after running " + "given the adjoint source pipeline used.", + ) + + normalize_sim: bool = pd.Field( + ..., + title="Normalize Adjoint Simulation", + description="Whether the adjoint simulation needs to be normalized " + "given the adjoint source pipeline used.", + ) class AbstractYeeGridSimulationData(AbstractSimulationData, ABC): @@ -956,7 +980,7 @@ def source_spectrum_fn(freqs): def make_adjoint_sim( self, data_vjp_paths: set[tuple], adjoint_monitors: list[Monitor] - ) -> tuple[Simulation, float]: + ) -> tuple[Simulation, AdjointSourceInfo]: """Make the adjoint simulation from the original simulation and the VJP-containing data.""" sim_original = self.simulation @@ -967,19 +991,19 @@ def make_adjoint_sim( for src_list in sources_adj_dict.values(): adj_srcs += list(src_list) - sources_adj, post_norm, norm_source = self.process_adjoint_sources(adj_srcs=adj_srcs) + adjoint_source_info = self.process_adjoint_sources(adj_srcs=adj_srcs) # grab boundary conditions with flipped Bloch vectors (for adjoint) bc_adj = sim_original.boundary_spec.flipped_bloch_vecs # fields to update the 'fwd' simulation with to make it 'adj' sim_adj_update_dict = dict( - sources=sources_adj, + sources=adjoint_source_info.sources, boundary_spec=bc_adj, monitors=adjoint_monitors, ) - if not norm_source: + if not adjoint_source_info.normalize_sim: sim_adj_update_dict["normalize_index"] = None # set the ADJ grid spec wavelength to the original wavelength (for same meshing) @@ -989,7 +1013,7 @@ def make_adjoint_sim( grid_spec_adj = grid_spec_original.updated_copy(wavelength=wavelength_original) sim_adj_update_dict["grid_spec"] = grid_spec_adj - return sim_original.updated_copy(**sim_adj_update_dict), post_norm + return sim_original.updated_copy(**sim_adj_update_dict), adjoint_source_info def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, Source]: """Generate all of the non-zero sources for the adjoint simulation given the VJP data.""" @@ -1010,29 +1034,14 @@ def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, Source]: return sources_adj_all - @staticmethod - def get_amp(src_time: GaussianPulse) -> complex: - """grab the complex amplitude from a ``SourceTime``.""" - mag = src_time.amplitude - phase = np.exp(1j * src_time.phase) - return mag * phase - - @staticmethod - def set_amp(src_time: GaussianPulse, amp: complex) -> GaussianPulse: - """set the complex amplitude of a ``SourceTime``.""" - amplitude = abs(amp) - phase = np.angle(amp) - return src_time.updated_copy(amplitude=amplitude, phase=phase) - @property def fwidth_adj(self) -> float: # fwidth of forward pass, try as default for adjoint normalize_index_fwd = self.simulation.normalize_index or 0 return self.simulation.sources[normalize_index_fwd].source_time.fwidth - def process_adjoint_sources( - self, adj_srcs: list[Source] - ) -> tuple[list[Source], Union[float, xr.DataArray], bool]: + def process_adjoint_sources(self, adj_srcs: list[Source]) -> AdjointSourceInfo: + # tuple[list[Source], Union[float, xr.DataArray], bool]: """Compute list of final sources along with a post run normalization for adj fields.""" # dictionary mapping unique spatial dependence of each Source to list of time-dependencies @@ -1049,12 +1058,14 @@ def process_adjoint_sources( # next, figure out which treatment / normalization to apply if num_unique_freqs == 1: log.info("adjoint source creation: one unique frequency, no normalization") - return adj_srcs, 1.0, True + return AdjointSourceInfo(sources=adj_srcs, post_norm=1.0, normalize_sim=True) + # return adj_srcs, 1.0, True if num_ports == 1 and len(adj_srcs) == num_unique_freqs: log.info("adjoint source creation: one spatial port detected") adj_srcs, post_norm = self.process_adjoint_sources_broadband(adj_srcs) - return adj_srcs, post_norm, True + return AdjointSourceInfo(sources=adj_srcs, post_norm=post_norm, normalize_sim=True) + # return adj_srcs, post_norm, True # if several spatial ports and several frequencies, try to fit log.info("adjoint source creation: trying multifrequency fit.") @@ -1063,7 +1074,8 @@ def process_adjoint_sources( spatial_to_src_times=spatial_to_src_times, json_to_sources=json_to_sources, ) - return adj_srcs, post_norm, False + return AdjointSourceInfo(sources=adj_srcs, post_norm=post_norm, normalize_sim=False) + # return adj_srcs, post_norm, False """ SIMPLE APPROACH """ @@ -1084,7 +1096,7 @@ def process_adjoint_sources_broadband( return [src_broadband], post_norm_amps - def _make_broadband_source(self, adj_srcs: list[Source], num_fwidth: float = 0.5) -> Source: + def _make_broadband_source(self, adj_srcs: list[Source]) -> Source: """Make a broadband source for a set of adjoint sources.""" source_index = self.simulation.normalize_index or 0 @@ -1139,15 +1151,16 @@ def process_adjoint_sources_fit( # compute amplitudes of each adjoint source, and the norm adj_src_amps = [] for src in new_adj_srcs: - amp = self.get_amp(src.source_time) + amp = src.source_time.amp_complex adj_src_amps.append(amp) norm_amps = np.linalg.norm(adj_src_amps) # normalize all of the adjoint sources by this and return the normalization term used adj_srcs_norm = [] for src in new_adj_srcs: - amp = self.get_amp(src.source_time) - src_time_norm = self.set_amp(src_time=src.source_time, amp=amp / norm_amps) + src_time = src.source_time + amp = src_time.amp_complex + src_time_norm = src_time.from_amp_complex(amp=amp / norm_amps) src_nrm = src.updated_copy(source_time=src_time_norm) adj_srcs_norm.append(src_nrm) @@ -1180,7 +1193,7 @@ def get_coupling_matrix(fwidth: float) -> np.ndarray: ] ).T - amps_adj = np.array([self.get_amp(src_time) for src_time in source_times]) + amps_adj = np.array([src_time.amp_complex for src_time in source_times]) # compute the corrected set of amps to inject at each freq to take coupling into account def get_amps_corrected(fwidth: float) -> tuple[np.ndarray, float]: @@ -1207,7 +1220,7 @@ def get_amps_corrected(fwidth: float) -> tuple[np.ndarray, float]: # construct the new adjoint sources with the corrected amplitudes src_times_corrected = [ - self.set_amp(src_time=src_time, amp=amp).updated_copy(fwidth=self.fwidth_adj) + src_time.from_amp_complex(amp=amp, fwidth=self.fwidth_adj) for src_time, amp in zip(source_times, amps_corrected) ] srcs_corrected = [] diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index 4b5fd6bb4e..27b684f986 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -2587,7 +2587,7 @@ def _derivative_field_cmp( E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum).real E_der_dim_interp = E_der_dim_interp.sum("f") - vjp_array = np.array(E_der_dim_interp.values).astype(float) + vjp_array = np.array(E_der_dim_interp.values, dtype=float) vjp_array = vjp_array.reshape(eps_data.shape) diff --git a/tidy3d/components/source.py b/tidy3d/components/source.py index 6716731e13..fbabdd7d98 100644 --- a/tidy3d/components/source.py +++ b/tidy3d/components/source.py @@ -201,6 +201,20 @@ def end_time(self) -> float | None: return self.offset * self.twidth + END_TIME_FACTOR_GAUSSIAN * self.twidth + @property + def amp_complex(self) -> complex: + """grab the complex amplitude from a ``SourceTime``.""" + mag = self.amplitude + phase = np.exp(1j * self.phase) + return mag * phase + + @classmethod + def from_amp_complex(cls, amp: complex, **kwargs) -> GaussianPulse: + """set the complex amplitude of a ``SourceTime``.""" + amplitude = abs(amp) + phase = np.angle(amp) + return cls(amplitude=amplitude, phase=phase, **kwargs) + class ContinuousWave(Pulse): """Source time dependence that ramps up to continuous oscillation diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index 35f805797d..23c537338c 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -11,6 +11,7 @@ import tidy3d as td from tidy3d.components.autograd import AutogradFieldMap, get_static from tidy3d.components.autograd.derivative_utils import DerivativeInfo +from tidy3d.components.data.sim_data import AdjointSourceInfo from ..asynchronous import DEFAULT_DATA_DIR from ..asynchronous import run_async as run_async_webapi @@ -486,7 +487,7 @@ def _run_bwd( def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap: """dJ/d{sim.traced_fields()} as a function of Function of dJ/d{data.traced_fields()}""" - sim_adj, post_norm = setup_adj( + sim_adj, adjoint_source_info = setup_adj( data_fields_vjp=data_fields_vjp, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, @@ -502,7 +503,7 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap: sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, - post_norm=post_norm, + adjoint_source_info=adjoint_source_info, ) return vjp @@ -535,21 +536,21 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd task_names_adj = {task_name + "_adjoint" for task_name in task_names} sims_adj = {} - post_norm_dict = {} + adjoint_source_info_dict = {} for task_name, task_name_adj in zip(task_names, task_names_adj): data_fields_vjp = data_fields_dict_vjp[task_name] sim_data_orig = sim_data_orig_dict[task_name] sim_data_fwd = sim_data_fwd_dict[task_name] sim_fields_original = sim_fields_original_dict[task_name] - sim_adj, post_norm = setup_adj( + sim_adj, adjoint_source_info = setup_adj( data_fields_vjp=data_fields_vjp, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, ) sims_adj[task_name_adj] = sim_adj - post_norm_dict[task_name_adj] = post_norm + adjoint_source_info_dict[task_name_adj] = adjoint_source_info # TODO: handle case where no adjoint sources? @@ -559,7 +560,7 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd sim_fields_vjp_dict = {} for task_name, task_name_adj in zip(task_names, task_names_adj): sim_data_adj = batch_data_adj[task_name_adj] - post_norm = post_norm_dict[task_name_adj] + adjoint_source_info = adjoint_source_info_dict[task_name_adj] sim_data_orig = sim_data_orig_dict[task_name] sim_data_fwd = sim_data_fwd_dict[task_name] sim_fields_original = sim_fields_original_dict[task_name] @@ -568,7 +569,7 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, - post_norm=post_norm, + adjoint_source_info=adjoint_source_info, ) sim_fields_vjp_dict[task_name] = sim_fields_vjp @@ -582,7 +583,7 @@ def setup_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_original: AutogradFieldMap, -) -> tuple[td.Simulation, float]: +) -> tuple[td.Simulation, AdjointSourceInfo]: """Construct an adjoint simulation from a set of data_fields for the VJP.""" td.log.info("Running custom vjp (adjoint) pipeline.") @@ -597,13 +598,13 @@ def setup_adj( # make adjoint simulation from that SimulationData data_vjp_paths = set(data_fields_vjp.keys()) - sim_adj, post_norm = sim_data_vjp.make_adjoint_sim( + sim_adj, adjoint_source_info = sim_data_vjp.make_adjoint_sim( data_vjp_paths=data_vjp_paths, adjoint_monitors=sim_data_fwd.simulation.monitors ) td.log.info(f"Adjoint simulation created with {len(sim_adj.sources)} sources.") - return sim_adj, post_norm + return sim_adj, adjoint_source_info def postprocess_adj( @@ -611,7 +612,7 @@ def postprocess_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_original: AutogradFieldMap, - post_norm: float, + adjoint_source_info: AdjointSourceInfo, ) -> AutogradFieldMap: """Postprocess some data from the adjoint simulation into the VJP for the original sim flds.""" @@ -634,7 +635,7 @@ def postprocess_adj( fwd_flds_normed = {} for key, val in fld_adj.field_components.items(): - fwd_flds_normed[key] = val * post_norm + fwd_flds_normed[key] = val * adjoint_source_info.post_norm fld_adj = fld_adj.updated_copy(**fwd_flds_normed)