Skip to content

Commit

Permalink
move adjoint source post norm from source info to simulation
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex authored and momchil-flex committed Oct 4, 2024
1 parent ce64857 commit 545edf7
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 46 deletions.
4 changes: 0 additions & 4 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import tidy3d.web as web
import xarray as xr
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
from tidy3d.components.data.sim_data import AdjointSourceInfo
from tidy3d.web import run, run_async
from tidy3d.web.api.autograd.utils import FieldMap

Expand Down Expand Up @@ -197,15 +196,12 @@ def emulated_run_bwd(simulation, task_name, **run_kwargs) -> td.SimulationData:
# get the original traced fields
sim_fields_keys = cache[task_id_fwd][AUX_KEY_SIM_FIELDS_KEYS]

adjoint_source_info = AdjointSourceInfo(sources=[], post_norm=1.0, normalize_sim=True)

# postprocess (compute adjoint gradients)
traced_fields_vjp = postprocess_adj(
sim_data_adj=sim_data_adj,
sim_data_orig=sim_data_orig,
sim_data_fwd=sim_data_fwd,
sim_fields_keys=sim_fields_keys,
adjoint_source_info=adjoint_source_info,
)

return traced_fields_vjp
Expand Down
5 changes: 3 additions & 2 deletions tidy3d/components/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ def split_original_fwd(self, num_mnts_original: int) -> Tuple[SimulationData, Si

def make_adjoint_sim(
self, data_vjp_paths: set[tuple], adjoint_monitors: list[Monitor]
) -> tuple[Simulation, AdjointSourceInfo]:
) -> Simulation:
"""Make the adjoint simulation from the original simulation and the VJP-containing data."""

sim_original = self.simulation
Expand All @@ -1043,6 +1043,7 @@ def make_adjoint_sim(
sources=adjoint_source_info.sources,
boundary_spec=bc_adj,
monitors=adjoint_monitors,
post_norm=adjoint_source_info.post_norm,
)

if not adjoint_source_info.normalize_sim:
Expand All @@ -1055,7 +1056,7 @@ def make_adjoint_sim(
grid_spec_adj = grid_spec_original.updated_copy(wavelength=wavelength_original)
sim_adj_update_dict["grid_spec"] = grid_spec_adj

return sim_original.updated_copy(**sim_adj_update_dict), adjoint_source_info
return sim_original.updated_copy(**sim_adj_update_dict)

def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, SourceType]:
"""Generate all of the non-zero sources for the adjoint simulation given the VJP data."""
Expand Down
8 changes: 8 additions & 0 deletions tidy3d/components/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
PMCBoundary,
StablePML,
)
from .data.data_array import FreqDataArray
from .data.dataset import CustomSpatialDataType, Dataset
from .geometry.base import Box, Geometry
from .geometry.mesh import TriangleMesh
Expand Down Expand Up @@ -210,6 +211,13 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC):
"``autograd`` gradient processing.",
)

post_norm: Union[float, FreqDataArray] = pydantic.Field(
1.0,
title="Post Normalization Values",
description="Factor to multiply the fields by after running, "
"given the adjoint source pipeline used. Note: this is used internally only.",
)

"""
Supply :class:`SubpixelSpec` to select subpixel averaging methods separately for dielectric, metal, and
PEC material interfaces. Alternatively, supply ``True`` to use default subpixel averaging methods,
Expand Down
46 changes: 6 additions & 40 deletions tidy3d/web/api/autograd/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
# server-side auxiliary files to upload/download
SIM_VJP_FILE = "output/autograd_sim_vjp.hdf5"
SIM_FIELDS_KEYS_FILE = "autograd_sim_fields_keys.hdf5"
ADJOINT_SOURCE_INFO_FILE = "autograd_adjoint_source_info_file.hdf5"

ISSUE_URL = (
"https://github.com/flexcompute/tidy3d/issues/new?"
Expand Down Expand Up @@ -540,21 +539,6 @@ def upload_sim_fields_keys(sim_fields_keys: list[tuple], task_id: str, verbose:
)


def upload_adjoint_source_info(
adjoint_source_info: AdjointSourceInfo, task_id: str, verbose: bool = False
) -> None:
"""Upload the adjoint source information for the adjoint run."""
data_file = tempfile.NamedTemporaryFile(suffix=".hdf5")
data_file.close()
adjoint_source_info.to_file(data_file.name)
upload_file(
task_id,
data_file.name,
ADJOINT_SOURCE_INFO_FILE,
verbose=verbose,
)


""" VJP maker for ADJ pass."""


Expand Down Expand Up @@ -591,7 +575,7 @@ def _run_bwd(
def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
"""dJ/d{sim.traced_fields()} as a function of Function of dJ/d{data.traced_fields()}"""

sim_adj, adjoint_source_info = setup_adj(
sim_adj = setup_adj(
data_fields_vjp=data_fields_vjp,
sim_data_orig=sim_data_orig,
sim_fields_keys=sim_fields_keys,
Expand All @@ -608,7 +592,6 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
sim_data_orig=sim_data_orig,
sim_data_fwd=sim_data_fwd,
sim_fields_keys=sim_fields_keys,
adjoint_source_info=adjoint_source_info,
)

else:
Expand All @@ -620,7 +603,6 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
vjp_traced_fields = _run_tidy3d_bwd(
sim_adj,
task_name=task_name_adj,
adjoint_source_info=adjoint_source_info,
**run_kwargs,
)

Expand Down Expand Up @@ -662,19 +644,17 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
task_names_adj = {task_name + "_adjoint" for task_name in task_names}

sims_adj = {}
adjoint_source_info_dict = {}
for task_name, task_name_adj in zip(task_names, task_names_adj):
data_fields_vjp = data_fields_dict_vjp[task_name]
sim_data_orig = sim_data_orig_dict[task_name]
sim_fields_keys = sim_fields_keys_dict[task_name]

sim_adj, adjoint_source_info = setup_adj(
sim_adj = setup_adj(
data_fields_vjp=data_fields_vjp,
sim_data_orig=sim_data_orig,
sim_fields_keys=sim_fields_keys,
)
sims_adj[task_name_adj] = sim_adj
adjoint_source_info_dict[task_name_adj] = adjoint_source_info
# TODO: handle case where no adjoint sources?

if local_gradient:
Expand All @@ -687,14 +667,12 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
sim_data_orig = sim_data_orig_dict[task_name]
sim_data_fwd = sim_data_fwd_dict[task_name]
sim_fields_keys = sim_fields_keys_dict[task_name]
adjoint_source_info = adjoint_source_info_dict[task_name_adj]

sim_fields_vjp = postprocess_adj(
sim_data_adj=sim_data_adj,
sim_data_orig=sim_data_orig,
sim_data_fwd=sim_data_fwd,
sim_fields_keys=sim_fields_keys,
adjoint_source_info=adjoint_source_info,
)
sim_fields_vjp_dict[task_name] = sim_fields_vjp

Expand All @@ -712,7 +690,6 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
}
sim_fields_vjp_dict_adj_keys = _run_async_tidy3d_bwd(
simulations=sims_adj,
adjoint_source_info_dict=adjoint_source_info_dict,
**run_async_kwargs,
)

Expand Down Expand Up @@ -749,7 +726,7 @@ def setup_adj(
num_monitors:
]

sim_adj, adjoint_source_info = sim_data_vjp.make_adjoint_sim(
sim_adj = sim_data_vjp.make_adjoint_sim(
data_vjp_paths=data_vjp_paths, adjoint_monitors=adjoint_monitors
)

Expand Down Expand Up @@ -778,15 +755,14 @@ def setup_adj(

td.log.info(f"Adjoint simulation created with {len(sim_adj.sources)} sources.")

return sim_adj, adjoint_source_info
return sim_adj


def postprocess_adj(
sim_data_adj: td.SimulationData,
sim_data_orig: td.SimulationData,
sim_data_fwd: td.SimulationData,
sim_fields_keys: list[tuple],
adjoint_source_info: AdjointSourceInfo,
) -> AutogradFieldMap:
"""Postprocess some data from the adjoint simulation into the VJP for the original sim flds."""

Expand All @@ -809,7 +785,7 @@ def postprocess_adj(

fwd_flds_normed = {}
for key, val in E_adj.field_components.items():
fwd_flds_normed[key] = val * adjoint_source_info.post_norm
fwd_flds_normed[key] = val * sim_data_adj.simulation.post_norm

E_adj = E_adj.updated_copy(**fwd_flds_normed)

Expand Down Expand Up @@ -890,14 +866,10 @@ def _run_tidy3d(
return data, job.task_id


def _run_tidy3d_bwd(
simulation: td.Simulation, task_name: str, adjoint_source_info: AdjointSourceInfo, **run_kwargs
) -> AutogradFieldMap:
def _run_tidy3d_bwd(simulation: td.Simulation, task_name: str, **run_kwargs) -> AutogradFieldMap:
"""Run a simulation without any tracers using regular web.run()."""
job_init_kwargs = parse_run_kwargs(**run_kwargs)
job = Job(simulation=simulation, task_name=task_name, **job_init_kwargs)
verbose = run_kwargs.get("verbose", False)
upload_adjoint_source_info(adjoint_source_info, task_id=job.task_id, verbose=verbose)
td.log.info(f"running {job.simulation_type} simulation with '_run_tidy3d_bwd()'")
job.start()
job.monitor()
Expand Down Expand Up @@ -939,7 +911,6 @@ def _run_async_tidy3d(

def _run_async_tidy3d_bwd(
simulations: dict[str, td.Simulation],
adjoint_source_info_dict: dict[str, AdjointSourceInfo],
**run_kwargs,
) -> dict[str, AutogradFieldMap]:
"""Run a simulation without any tracers using regular web.run()."""
Expand All @@ -949,11 +920,6 @@ def _run_async_tidy3d_bwd(
batch = Batch(simulations=simulations, **batch_init_kwargs)
td.log.info(f"running {batch.simulation_type} simulation with '_run_tidy3d_bwd()'")

task_ids = {key: job.task_id for key, job in batch.jobs.items()}
for task_name, adjoint_source_info in adjoint_source_info_dict.items():
task_id = task_ids[task_name]
upload_adjoint_source_info(adjoint_source_info, task_id=task_id, verbose=batch.verbose)

batch.start()
batch.monitor()

Expand Down

0 comments on commit 545edf7

Please sign in to comment.