From 0f689738e6d7966e8dbe2b5658c7e4e5e902adfd Mon Sep 17 00:00:00 2001 From: Yannick Augenstein Date: Wed, 23 Oct 2024 11:17:32 +0200 Subject: [PATCH] fix: remove adjoint temporary simulation files --- CHANGELOG.md | 2 ++ tidy3d/plugins/adjoint/web.py | 42 ++++++++++++++++++----------- tidy3d/web/api/autograd/autograd.py | 39 ++++++++++++++++++--------- 3 files changed, 55 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df019607f8..9a266e983a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Minor gradient direction and normalization fixes for polyslab, field monitors, and diffraction monitors in autograd. +### Fixed +- Resolved an issue where temporary files for adjoint simulations were not being deleted properly. ## [2.7.5] - 2024-10-16 diff --git a/tidy3d/plugins/adjoint/web.py b/tidy3d/plugins/adjoint/web.py index acb33479e3..0357b6b7c2 100644 --- a/tidy3d/plugins/adjoint/web.py +++ b/tidy3d/plugins/adjoint/web.py @@ -1,5 +1,6 @@ """Adjoint-specific webapi.""" +import os import tempfile from functools import partial from typing import Dict, List, Tuple @@ -8,6 +9,7 @@ from jax import custom_vjp from jax.tree_util import register_pytree_node_class +import tidy3d as td from tidy3d.web.api.asynchronous import run_async as web_run_async from tidy3d.web.api.webapi import run as web_run from tidy3d.web.api.webapi import wait_for_connection @@ -219,26 +221,36 @@ def run_bwd( @wait_for_connection def upload_jax_info(jax_info: JaxInfo, task_id: str, verbose: bool) -> None: """Upload jax_info for a task with a given task_id.""" - - data_file = tempfile.NamedTemporaryFile(suffix=".json") - data_file.close() - jax_info.to_file(data_file.name) - upload_file( - task_id, - data_file.name, - JAX_INFO_FILE, - verbose=verbose, - ) + handle, fname = tempfile.mkstemp(suffix=".json") + os.close(handle) + try: + jax_info.to_file(fname) + upload_file( + task_id, + fname, + JAX_INFO_FILE, + verbose=verbose, + ) + except Exception as e: + td.log.error(f"Error occurred while uploading 'jax_info': {e}") + raise e + finally: + os.unlink(fname) @wait_for_connection def download_sim_vjp(task_id: str, verbose: bool) -> JaxSimulation: """Download the vjp loaded simulation from the server to return to jax.""" - - data_file = tempfile.NamedTemporaryFile(suffix=".hdf5") - data_file.close() - download_file(task_id, SIM_VJP_FILE, to_file=data_file.name, verbose=verbose) - return JaxSimulation.from_file(data_file.name) + handle, fname = tempfile.mkstemp(suffix=".hdf5") + os.close(handle) + try: + download_file(task_id, SIM_VJP_FILE, to_file=fname, verbose=verbose) + return JaxSimulation.from_file(fname) + except Exception as e: + td.log.error(f"Error occurred while downloading 'sim_vjp': {e}") + raise e + finally: + os.unlink(fname) AdjointSimulationType = Literal["tidy3d", "adjoint_fwd", "adjoint_bwd"] diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index a99247c8de..5abe68967e 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -1,5 +1,6 @@ # autograd wrapper for web functions +import os import tempfile import typing from collections import defaultdict @@ -529,15 +530,21 @@ def postprocess_fwd( def upload_sim_fields_keys(sim_fields_keys: list[tuple], task_id: str, verbose: bool = False): """Function to grab the VJP result for the simulation fields from the adjoint task ID.""" - data_file = tempfile.NamedTemporaryFile(suffix=".hdf5") - data_file.close() - TracerKeys(keys=sim_fields_keys).to_file(data_file.name) - upload_file( - task_id, - data_file.name, - SIM_FIELDS_KEYS_FILE, - verbose=verbose, - ) + handle, fname = tempfile.mkstemp(suffix=".hdf5") + os.close(handle) + try: + TracerKeys(keys=sim_fields_keys).to_file(fname) + upload_file( + task_id, + fname, + SIM_FIELDS_KEYS_FILE, + verbose=verbose, + ) + except Exception as e: + td.log.error(f"Error occurred while uploading simulation fields keys: {e}") + raise e + finally: + os.unlink(fname) """ VJP maker for ADJ pass.""" @@ -545,10 +552,16 @@ def upload_sim_fields_keys(sim_fields_keys: list[tuple], task_id: str, verbose: def get_vjp_traced_fields(task_id_adj: str, verbose: bool) -> AutogradFieldMap: """Function to grab the VJP result for the simulation fields from the adjoint task ID.""" - data_file = tempfile.NamedTemporaryFile(suffix=".hdf5") - data_file.close() - download_file(task_id_adj, SIM_VJP_FILE, to_file=data_file.name, verbose=verbose) - field_map = FieldMap.from_file(data_file.name) + handle, fname = tempfile.mkstemp(suffix=".hdf5") + os.close(handle) + try: + download_file(task_id_adj, SIM_VJP_FILE, to_file=fname, verbose=verbose) + field_map = FieldMap.from_file(fname) + except Exception as e: + td.log.error(f"Error occurred while getting VJP traced fields: {e}") + raise e + finally: + os.unlink(fname) return field_map.to_autograd_field_map