Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix adjoint source creation #2282

Merged
merged 1 commit into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `ComponentModeler.to_file` when its batch is empty.
- In web api, mode solver is patched with remote data so that certain methods like `plot_field` show remote data.
- Added `viz_spec` property to `AbstractStructure` to fix error when plotting structures that have no `medium`.
- Fixed invalid adjoint source creation by disallowing sources made from non-traced fields.

## [2.8.0rc2] - 2025-01-28

Expand Down
13 changes: 12 additions & 1 deletion tidy3d/web/api/autograd/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,8 +861,19 @@ def setup_adj(
# immediately filter out any data_vjps with all 0's in the data
data_fields_vjp = {key: get_static(value) for key, value in data_fields_vjp.items()}

# start with the full simulation data structure and either zero out the fields
# that have no tracer data for them or insert the tracer data
full_sim_data_dict = sim_data_orig.strip_traced_fields(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could this be an alternative approach? (only new code is "clearing" the data of untraced fields)

# get all of the adjoint source data
data_fields_vjp = {key: get_static(value) for key, value in data_fields_vjp.items()}

# remove all of the untraced data
traced_data = [data for data in sim_data_orig.data if data.monitor.name in data_fields_vjp]
sim_data_vjp = sim_data_orig.updated_copy(data=traced_data, deep=False)

# insert the VJP (as before)
sim_data_vjp = sim_data_vjp.insert_traced_fields(field_mapping=data_fields_vjp)

I suppose a downside of this is that it introduces an additional copy, but maybe because it is shallow it is not so bad? Upside is that it's bit simpler to understand.

Just a thought, I think what you have also works.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, I like the upside of easier to understand.

The issue I'm running into with this approach is we actually are needing to clear certain fields in each monitor data according to the data_fields_vjp keys instead of just the whole data associated with a monitor. The monitor name is not a key into the data_fields_vjp because that key contains information on a specific field in the monitor data (i.e. - an example of one of the keys in data_fields_vjp looks like ('data', 0, 'Ex'))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh right, I've already forgotten a bit how this data is stored. I guess it could look rather like

traced_indices = {index for _, index, * in data_fields_vjp.keys()}
traced_data = [sim_data_orig.data[i] for i in traced_indices]
sim_data_vjp = sim_data_orig.updated_copy(data=traced_data, deep=False)

or does that not work either?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

side note: might need to make sure traced_indices is sorted before iterating through it to avoid potential issues.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

traced_indices = {index for _, index, * in data_fields_vjp.keys()}
Is the * here intended to also be a _?

This isn't working when I test it either because while it strips the index correctly to pull out a specific monitor data from the simulation data, it still doesn't zero out fields inside that monitor data that are not present in the vjp data. Those non zero'd fields then get turned into adjoint sources.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, it should be a *_, basically to unpack everything after the index

it still doesn't zero out fields inside that monitor data that are not present in the vjp data

oh, interesting. Yea maybe in that case what you propose is ultimately the correct approach. thanks for going through that with me. I think it makes sense now. I've started somewhat forgetting how all of this works after spending more time away from it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, that makes sense on the syntax - I hadn't used that unpack before!

And definitely, no problem! I'm definitely still getting up to speed with a lot of how all this stuff is passed around and connects to the autograd library. It's quite the feat how seamless this stuff is for the user! I'm open to other approaches to this as well, but keeping with how the current adjoint source creation works inside monitor data, this was the best way I could come up with for now that seemed to get the job done. Yannick and I discussed this morning a potential different approach this morning in our 1:1 of using the vjp data directly to create the adjoint sources instead of filling in the monitor data as an intermediate.

include_untraced_data_arrays=True, starting_path=("data",)
)
for path in full_sim_data_dict.keys():
if path in data_fields_vjp:
full_sim_data_dict[path] = data_fields_vjp[path]
else:
full_sim_data_dict[path] *= 0

# insert the raw VJP data into the .data of the original SimulationData
sim_data_vjp = sim_data_orig.insert_traced_fields(field_mapping=data_fields_vjp)
sim_data_vjp = sim_data_orig.insert_traced_fields(field_mapping=full_sim_data_dict)

# make adjoint simulation from that SimulationData
data_vjp_paths = set(data_fields_vjp.keys())
Expand Down