Skip to content

Commit

Permalink
fix: remove adjoint temporary simulation files
Browse files Browse the repository at this point in the history
  • Loading branch information
yaugenst-flex committed Oct 24, 2024
1 parent 3d85cde commit 2687193
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Minor gradient direction and normalization fixes for polyslab, field monitors, and diffraction monitors in autograd.
- Resolved an issue where temporary files for adjoint simulations were not being deleted properly.

## [2.7.5] - 2024-10-16

Expand Down
42 changes: 27 additions & 15 deletions tidy3d/plugins/adjoint/web.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Adjoint-specific webapi."""

import os
import tempfile
from functools import partial
from typing import Dict, List, Tuple
Expand All @@ -8,6 +9,7 @@
from jax import custom_vjp
from jax.tree_util import register_pytree_node_class

import tidy3d as td
from tidy3d.web.api.asynchronous import run_async as web_run_async
from tidy3d.web.api.webapi import run as web_run
from tidy3d.web.api.webapi import wait_for_connection
Expand Down Expand Up @@ -219,26 +221,36 @@ def run_bwd(
@wait_for_connection
def upload_jax_info(jax_info: JaxInfo, task_id: str, verbose: bool) -> None:
"""Upload jax_info for a task with a given task_id."""

data_file = tempfile.NamedTemporaryFile(suffix=".json")
data_file.close()
jax_info.to_file(data_file.name)
upload_file(
task_id,
data_file.name,
JAX_INFO_FILE,
verbose=verbose,
)
handle, fname = tempfile.mkstemp(suffix=".json")
os.close(handle)
try:
jax_info.to_file(fname)
upload_file(
task_id,
fname,
JAX_INFO_FILE,
verbose=verbose,
)
except Exception as e:
td.log.error(f"Error occurred while uploading 'jax_info': {e}")
raise e
finally:
os.unlink(fname)


@wait_for_connection
def download_sim_vjp(task_id: str, verbose: bool) -> JaxSimulation:
"""Download the vjp loaded simulation from the server to return to jax."""

data_file = tempfile.NamedTemporaryFile(suffix=".hdf5")
data_file.close()
download_file(task_id, SIM_VJP_FILE, to_file=data_file.name, verbose=verbose)
return JaxSimulation.from_file(data_file.name)
handle, fname = tempfile.mkstemp(suffix=".hdf5")
os.close(handle)
try:
download_file(task_id, SIM_VJP_FILE, to_file=fname, verbose=verbose)
return JaxSimulation.from_file(fname)
except Exception as e:
td.log.error(f"Error occurred while downloading 'sim_vjp': {e}")
raise e
finally:
os.unlink(fname)


AdjointSimulationType = Literal["tidy3d", "adjoint_fwd", "adjoint_bwd"]
Expand Down
39 changes: 26 additions & 13 deletions tidy3d/web/api/autograd/autograd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# autograd wrapper for web functions

import os
import tempfile
import typing
from collections import defaultdict
Expand Down Expand Up @@ -529,26 +530,38 @@ def postprocess_fwd(

def upload_sim_fields_keys(sim_fields_keys: list[tuple], task_id: str, verbose: bool = False):
"""Function to grab the VJP result for the simulation fields from the adjoint task ID."""
data_file = tempfile.NamedTemporaryFile(suffix=".hdf5")
data_file.close()
TracerKeys(keys=sim_fields_keys).to_file(data_file.name)
upload_file(
task_id,
data_file.name,
SIM_FIELDS_KEYS_FILE,
verbose=verbose,
)
handle, fname = tempfile.mkstemp(suffix=".hdf5")
os.close(handle)
try:
TracerKeys(keys=sim_fields_keys).to_file(fname)
upload_file(
task_id,
fname,
SIM_FIELDS_KEYS_FILE,
verbose=verbose,
)
except Exception as e:
td.log.error(f"Error occurred while uploading simulation fields keys: {e}")
raise e
finally:
os.unlink(fname)


""" VJP maker for ADJ pass."""


def get_vjp_traced_fields(task_id_adj: str, verbose: bool) -> AutogradFieldMap:
"""Function to grab the VJP result for the simulation fields from the adjoint task ID."""
data_file = tempfile.NamedTemporaryFile(suffix=".hdf5")
data_file.close()
download_file(task_id_adj, SIM_VJP_FILE, to_file=data_file.name, verbose=verbose)
field_map = FieldMap.from_file(data_file.name)
handle, fname = tempfile.mkstemp(suffix=".hdf5")
os.close(handle)
try:
download_file(task_id_adj, SIM_VJP_FILE, to_file=fname, verbose=verbose)
field_map = FieldMap.from_file(fname)
except Exception as e:
td.log.error(f"Error occurred while getting VJP traced fields: {e}")
raise e
finally:
os.unlink(fname)
return field_map.to_autograd_field_map


Expand Down

0 comments on commit 2687193

Please sign in to comment.