Skip to content

Commit

Permalink
Escaping web dependence in invdes plugin until needed
Browse files Browse the repository at this point in the history
  • Loading branch information
momchil-flex committed Oct 8, 2024
1 parent 9fd7cf2 commit b18ebe8
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions tidy3d/plugins/invdes/design.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pydantic.v1 as pd

import tidy3d as td
import tidy3d.web as web
from tidy3d.components.autograd import get_static
from tidy3d.exceptions import ValidationError
from tidy3d.plugins.expressions.metrics import Metric
Expand Down Expand Up @@ -71,7 +70,7 @@ def objective_fn(params: anp.ndarray, aux_data: dict = None) -> float:
post_process_val = post_process_fn(data)
elif isinstance(data, td.SimulationData):
post_process_val = self.metric.evaluate(data)
elif isinstance(data, web.BatchData):
elif getattr(data, "type", None) == "BatchData":
raise NotImplementedError("Metrics currently do not support 'BatchData'")
else:
raise ValueError(f"Invalid data type: {type(data)}")
Expand Down Expand Up @@ -100,6 +99,21 @@ def initial_simulation(self) -> td.Simulation:
initial_params = self.design_region.initial_parameters
return self.to_simulation(initial_params)

def run(self, simulation, **kwargs) -> td.SimulationData:
"""Run a single tidy3d simulation."""
from tidy3d.web import run

kwargs.setdefault("verbose", self.verbose)
kwargs.setdefault("task_name", self.task_name)
return run(simulation, **kwargs)

def run_async(self, simulations, **kwargs) -> web.BatchData: # noqa: F821
"""Run a batch of tidy3d simulations."""
from tidy3d.web import run_async

kwargs.setdefault("verbose", self.verbose)
return run_async(simulations, **kwargs)


class InverseDesign(AbstractInverseDesign):
"""Container for an inverse design problem."""
Expand Down Expand Up @@ -221,8 +235,7 @@ def to_simulation(self, params: anp.ndarray) -> td.Simulation:
def to_simulation_data(self, params: anp.ndarray, **kwargs) -> td.SimulationData:
"""Convert the ``InverseDesign`` to a ``td.Simulation`` and run it."""
simulation = self.to_simulation(params=params)
kwargs.setdefault("task_name", self.task_name)
return web.run(simulation, verbose=self.verbose, **kwargs)
return self.run(simulation, **kwargs)


class InverseDesignMulti(AbstractInverseDesign):
Expand Down Expand Up @@ -292,11 +305,10 @@ def to_simulation(self, params: anp.ndarray) -> dict[str, td.Simulation]:
simulation_list = [design.to_simulation(params) for design in self.designs]
return dict(zip(self.task_names, simulation_list))

def to_simulation_data(self, params: anp.ndarray, **kwargs) -> web.BatchData:
def to_simulation_data(self, params: anp.ndarray, **kwargs) -> web.BatchData: # noqa: F821
"""Convert the ``InverseDesignMulti`` to a set of ``td.Simulation``s and run async."""
simulations = self.to_simulation(params)
kwargs.setdefault("verbose", self.verbose)
return web.run_async(simulations, **kwargs)
return self.run_async(simulations, **kwargs)


InverseDesignType = typing.Union[InverseDesign, InverseDesignMulti]

0 comments on commit b18ebe8

Please sign in to comment.