-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
2nd round of autograd fixes #1923
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,6 @@ | |
from ...log import log | ||
from ..base import TYPE_TAG_STR, cached_property, skip_if_fields_missing | ||
from ..base_sim.data.monitor_data import AbstractMonitorData | ||
from ..geometry.base import Box | ||
from ..grid.grid import Coords, Grid | ||
from ..medium import Medium, MediumType | ||
from ..monitor import ( | ||
|
@@ -1069,46 +1068,33 @@ def to_adjoint_field_sources(self, fwidth: float) -> List[CustomCurrentSource]: | |
|
||
sources = [] | ||
|
||
# Define source geometry based on coordinates in the data | ||
data_mins = [] | ||
data_maxs = [] | ||
source_geo = self.monitor.geometry | ||
freqs = self.monitor.freqs | ||
|
||
def shift_value(coords) -> float: | ||
"""How much to shift the geometry by along a dimension (only if > 1D).""" | ||
return SHIFT_VALUE_ADJ_FLD_SRC if len(coords) > 1 else 0 | ||
|
||
for _, field_component in self.field_components.items(): | ||
coords = field_component.coords | ||
data_mins.append({key: min(val) + shift_value(val) for key, val in coords.items()}) | ||
data_maxs.append({key: max(val) + shift_value(val) for key, val in coords.items()}) | ||
|
||
rmin = [] | ||
rmax = [] | ||
for dim in "xyz": | ||
rmin.append(max(val[dim] for val in data_mins)) | ||
rmax.append(min(val[dim] for val in data_maxs)) | ||
|
||
source_geo = Box.from_bounds(rmin=rmin, rmax=rmax) | ||
|
||
# Define source dataset | ||
# Offset coordinates by source center since local coords are assumed in CustomCurrentSource | ||
|
||
for freq0 in tuple(self.field_components.values())[0].coords["f"]: | ||
for freq0 in freqs: | ||
src_field_components = {} | ||
for name, field_component in self.field_components.items(): | ||
# get the VJP values at frequency and apply adjoint phase | ||
field_component = field_component.sel(f=freq0) | ||
forward_amps = field_component.values | ||
values = -1j * forward_amps | ||
values = -1j * field_component.values | ||
|
||
# make source go backwards | ||
if "H" in name: | ||
values *= -1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this what we do in the adjoint plugin too? I am trying to think about whether this is the correct thing to do in all cases. It's probably possible to see it somehow from the adjoint formulation if written for the H field. I think this is correct. I'm just wondering a little because the intuition of this minus sign "making source go backwards" makes sense for a real propagating mode, but it gets a bit murkier for evanescent or lossy modes. The other option is to take np.conj(values). If the E field is real and the mode is lossless and propagating, then the H field is imaginary and the two operations are the same. In general the I do think the |
||
|
||
# make coords that are shifted relative to geometry (0,0,0) = geometry.center | ||
coords = dict(field_component.coords.copy()) | ||
for dim, key in enumerate("xyz"): | ||
coords[key] = np.array(coords[key]) - source_geo.center[dim] | ||
coords["f"] = np.array([freq0]) | ||
values = np.expand_dims(values, axis=-1) | ||
|
||
# ignore zero components | ||
if not np.all(values == 0): | ||
src_field_components[name] = ScalarFieldDataArray(values, coords=coords) | ||
|
||
# construct custom Current source | ||
dataset = FieldDataset(**src_field_components) | ||
|
||
custom_source = CustomCurrentSource( | ||
center=source_geo.center, | ||
size=source_geo.size, | ||
|
@@ -1763,7 +1749,7 @@ def make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[ | |
for name in dataset_names: | ||
if name == "amps": | ||
adjoint_sources += self.make_adjoint_sources_amps(fwidth=fwidth) | ||
else: | ||
elif not np.all(self.n_complex.values == 0.0): | ||
log.warning( | ||
f"Can't create adjoint source for 'ModeData.{type(self)}.{name}'. " | ||
f"for monitor '{self.monitor.name}'. " | ||
|
@@ -1948,6 +1934,10 @@ def make_adjoint_sources( | |
) -> List[Union[CustomCurrentSource, PointDipole]]: | ||
"""Converts a :class:`.FieldData` to a list of adjoint current or point sources.""" | ||
|
||
# avoids error in edge case where there are extraneous flux monitors not used in objective | ||
if np.all(self.flux.values == 0.0): | ||
return [] | ||
|
||
raise NotImplementedError( | ||
"Could not formulate adjoint source for 'FluxMonitor' output. To compute derivatives " | ||
"with respect to flux data, please use a 'FieldMonitor' and call '.flux' on the " | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could be nice to have private debug run function, or maybe a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. which line? the adjoint field visualization? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah exactly, that and the global toggle to trigger that code |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
double
abs
? not that it matters i guess