Skip to content

Commit

Permalink
Latest review changes
Browse files Browse the repository at this point in the history
Generally focusing on docs and verbose outputs
Other small bug fixes included
  • Loading branch information
m-bone committed Sep 13, 2024
1 parent 55d6c95 commit 5d8e426
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 105 deletions.
24 changes: 14 additions & 10 deletions tests/test_plugins/test_design.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@

SWEEP_METHODS = dict(
grid=tdd.MethodGrid(),
monte_carlo=tdd.MethodMonteCarlo(num_points=5, rng_seed=1),
bay_opt=tdd.MethodBayOpt(initial_iter=5, n_iter=2, rng_seed=1),
monte_carlo=tdd.MethodMonteCarlo(num_points=5, seed=1),
bay_opt=tdd.MethodBayOpt(initial_iter=5, n_iter=2, seed=1),
gen_alg=tdd.MethodGenAlg(
solutions_per_pop=6,
n_generations=2,
n_parents_mating=4,
rng_seed=1,
seed=1,
mutation_prob=0,
keep_parents=0,
),
part_swarm=tdd.MethodParticleSwarm(n_particles=3, n_iter=2, rng_seed=1),
part_swarm=tdd.MethodParticleSwarm(n_particles=3, n_iter=2, seed=1),
)

# Task names that should be produced for the different methods
Expand Down Expand Up @@ -245,12 +245,16 @@ def scs_post_list(sim_list):

def scs_pre_dict(radius: float, num_spheres: int, tag: str):
sim = scs_pre(radius, num_spheres, tag)
return {"test1": sim, "test2": sim, "3": sim}
batch = web.Batch(simulations={"a": sim, "b": sim})
return {"test1": sim, "test2": sim, "batch1": batch}


def scs_post_dict(sim_dict):
sim_data = [scs_post(sim) for sim in sim_dict.values()]
return sum(sim_data)
sim_data = [sim_dict["test1"], sim_dict["test2"]]
batched_data = [sim for _, sim in sim_dict["batch1"].items()]
sim_data.extend(batched_data)
post_sim_data = [scs_post(sim) for sim in sim_data]
return sum(post_sim_data)


def scs_pre_list_const(radius: float, num_spheres: int, tag: str):
Expand Down Expand Up @@ -690,7 +694,7 @@ def test_genalg_early_stop():
n_parents_mating=2,
stop_criteria_type="reach",
stop_criteria_number=1,
rng_seed=1,
seed=1,
)
design_space_pass = init_design_space(gen_alg_pass)

Expand All @@ -706,7 +710,7 @@ def test_genalg_early_stop():
n_parents_mating=2,
stop_criteria_type=None,
stop_criteria_number=1,
rng_seed=1,
seed=1,
)

design_space_fail = init_design_space(gen_alg_fail)
Expand All @@ -729,7 +733,7 @@ def test_genalg_run_count():
mutation_type="scramble",
crossover_prob=0.6,
crossover_type="uniform",
rng_seed=1,
seed=1,
save_solution=True,
)
design_space = init_design_space(gen_alg)
Expand Down
122 changes: 84 additions & 38 deletions tidy3d/plugins/design/design.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,27 @@
class DesignSpace(Tidy3dBaseModel):
"""Manages all exploration of a parameter space within specified parameters using a supplied search method.
The ``DesignSpace`` forms the basis of the ``Design`` plugin, and receives a ``Method`` and ``Parameter`` list that
define the scope of the design space and how it should be searched. ``DesignSpace.run()`` can then be called with
a function(s) to generate different solutions from parameters suggested by the ``Method``. The ``Method`` can either
sample the design space systematically or randomly, or can optimize for a given problem through an iterative search
and evaluate approach.
Schematic outline of how to use the ``Design`` plugin to explore a design space.
.. image:: ../../_static/img/design.png
:width: 50%
:align: left
The `Design <https://www.flexcompute.com/tidy3d/examples/notebooks/Design/>'_ notebook contains an overview of the
``Design`` plugin and is the best place to learn how to get started.
Detailed examples using the ``Design`` plugin can be found in the following notebooks:
`All-Dielectric Structural Colors <https://www.flexcompute.com/tidy3d/examples/notebooks/AllDielectricStructuralColor/>'_
`Bayesian Optimization of Y-Junction <https://www.flexcompute.com/tidy3d/examples/notebooks/BayesianOptimizationYJunction/>'_
`Genetic Algorithm Reflector <https://www.flexcompute.com/tidy3d/examples/notebooks/GeneticAlgorithmReflector/>'_
`Particle Swarm Optimizer PBS <https://www.flexcompute.com/tidy3d/examples/notebooks/ParticleSwarmOptimizedPBS/>'_
`Particle Swarm Optimizer Bullseye Cavity <https://www.flexcompute.com/tidy3d/examples/notebooks/BullseyeCavityPSO/>'_
Example
-------
>>> import tidy3d.plugins.design as tdd
Expand Down Expand Up @@ -55,17 +76,17 @@ class DesignSpace(Tidy3dBaseModel):
task_name: str = pd.Field(
"",
title="Task Name",
description="Task name assigned to tasks along with a simulation counter in the form of {task_name}_{counter}. \
If the pre function outputs a dictionary the key will be included in the task name as {task_name}_{dict_key}_{counter}. \
Only used when pre-post functions are supplied.",
description="Task name assigned to tasks along with a simulation counter in the form of {task_name}_{counter}. "
"If the pre function outputs a dictionary the key will be included in the task name as {task_name}_{dict_key}_{counter}. "
"Only used when pre-post functions are supplied.",
)

name: str = pd.Field(None, title="Name", description="Optional name for the design space.")

path_dir: str = pd.Field(
".",
title="Path Directory",
description="Directory where simulation data files will be locally saved to. Only used when pre-post functions are supplied.",
description="Directory where simulation data files will be locally saved to. Only used when pre and post functions are supplied.",
)

folder_name: str = pd.Field(
Expand Down Expand Up @@ -132,7 +153,20 @@ def run(self, fn: Callable, fn_post: Callable = None, verbose: bool = True) -> R
If used as a pre function, the output of ``fn`` must be a float, ``Simulation``, ``Batch``, list, or dict. Supplied ``Batch`` objects are
run without modification and are run in series. A list or dict of ``Simulation`` objects is flattened into a single ``Batch`` to enable
parallel computation on the cloud. The original structure is then restored for output; all `Simulation`` objects are replaced by ``SimulationData`` objects.
parallel computation on the cloud. The original structure is then restored for output; all ``Simulation`` objects are replaced by ``SimulationData`` objects.
Example pre return formats and associated post inputs can be seen in the table below.
| fn_pre return | fn_post call |
|-------------------------------------------|---------------------------------------------------|
| 1.0 | fn_post(1.0) |
| [1,2,3] | fn_post(1,2,3) |
| {'a': 2, 'b': 'hi'} | fn_post(a=2, b='hi') |
| Simulation | fn_post(SimulationData) |
| Batch | fn_post(BatchData) |
| [Simulation, Simulation] | fn_post(SimulationData, SimulationData) |
| [Simulation, 1.0] | fn_post(SimulationData, 1.0) |
| [Simulation, Batch] | fn_post(SimulationData, BatchData) |
| {'a': Simulation, 'b': Batch, 'c': 2.0} | fn_post(a=SimulationData, b=BatchData, c=2.0) |
The output of ``fn_post`` (or ``fn`` if only one function is supplied) must be a float
or a container where the first element is a ``float`` and second element is a ``list`` / ``dict`` e,g. [float {"aux_1": str}].
Expand Down Expand Up @@ -175,7 +209,7 @@ def run(self, fn: Callable, fn_post: Callable = None, verbose: bool = True) -> R

else:
fn_args, fn_values, aux_values, opt_output, sim_names, sim_paths = self.run_pre_post(
fn_pre=fn, fn_post=fn_post, console=console, verbose=verbose
fn_pre=fn, fn_post=fn_post, console=console
)

if len(sim_names) == 0:
Expand All @@ -199,12 +233,12 @@ def run_single(self, fn: Callable, console: Console) -> Tuple(list[dict], list,
evaluate_fn = self._get_evaluate_fn_single(fn=fn)
return self.method._run(run_fn=evaluate_fn, parameters=self.parameters, console=console)

def run_pre_post(
self, fn_pre: Callable, fn_post: Callable, console: Console, verbose: bool
) -> Tuple(list[dict], list[dict], list[Any]):
def run_pre_post(self, fn_pre: Callable, fn_post: Callable, console: Console) -> Tuple(
list[dict], list[dict], list[Any]
):
"""Run a function with Tidy3D implicitly called in between."""
handler = self._get_evaluate_fn_pre_post(
fn_pre=fn_pre, fn_post=fn_post, fn_mid=self._fn_mid, verbose=verbose
fn_pre=fn_pre, fn_post=fn_post, fn_mid=self._fn_mid, console=console
)
fn_args, fn_values, aux_values, opt_output = self.method._run(
run_fn=handler.fn_combined, parameters=self.parameters, console=console
Expand All @@ -223,16 +257,16 @@ def evaluate(args_list: list) -> list[Any]:
return evaluate

def _get_evaluate_fn_pre_post(
self, fn_pre: Callable, fn_post: Callable, fn_mid: Callable, verbose: bool
self, fn_pre: Callable, fn_post: Callable, fn_mid: Callable, console: Console
):
"""Get function that tries to use batch processing on a set of arguments."""

class Pre_Post_Handler:
def __init__(self, verbose):
def __init__(self, console):
self.sim_counter = 0
self.sim_names = []
self.sim_paths = []
self.verbose = verbose
self.console = console

def fn_combined(self, args_list: list[dict[str, Any]]) -> list[Any]:
"""Compute fn_pre and fn_post functions and capture other outputs."""
Expand All @@ -247,20 +281,20 @@ def fn_combined(self, args_list: list[dict[str, Any]]) -> list[Any]:
)

data, task_names, task_paths, sim_counter = fn_mid(
sim_dict, self.sim_counter, self.verbose
sim_dict, self.sim_counter, self.console
)
self.sim_names.extend(task_names)
self.sim_paths.extend(task_paths)
self.sim_counter = sim_counter
post_out = [fn_post(val) for val in data.values()]
return post_out

handler = Pre_Post_Handler(verbose)
handler = Pre_Post_Handler(console)

return handler

def _fn_mid(
self, pre_out: dict[int, Any], sim_counter: int, verbose: bool
self, pre_out: dict[int, Any], sim_counter: int, console: Console
) -> Union[dict[int, Any], BatchData]:
"""A function of the output of ``fn_pre`` that gives the input to ``fn_post``."""

Expand Down Expand Up @@ -324,20 +358,29 @@ def _find_and_map(
translate_sims[sim_name] = sim_key
sim_counter += 1

# Log the simulations and batches for the user
if console is not None:
# Writen like this to include batches on the same line if present
run_statement = f"{len(named_sims)} Simulations"
if len(batches) > 0:
run_statement = run_statement + f" and {len(batches)} user Batches"

console.log(f"Running {run_statement}")

# Running simulations and batches
sims_out = Batch(
simulations=named_sims,
folder_name=self.folder_name,
simulation_type="tidy3d_design",
verbose=verbose,
verbose=False, # Using a custom output instead of Batch.monitor updates
).run(path_dir=self.path_dir)

batch_results = {}
for batch_key, batch in batches.items():
batch_out = batch.run(path_dir=self.path_dir)
batch_results[batch_key] = batch_out

def _return_to_dict(return_dict, key, return_obj):
def _return_to_dict(return_dict: dict, key: str, return_obj: Any) -> None:
"""Recursively insert items into a dict by keys split with underscore. Only works for dict or dict of dict inputs."""
split_key = key.split("_", 1)
if len(split_key) > 1:
Expand All @@ -356,7 +399,7 @@ def _return_to_dict(return_dict, key, return_obj):
for batch_name, batch in batch_results.items():
_return_to_dict(pre_out, batch_name, batch)

def _remove_or_replace(search_dict, attr_name):
def _remove_or_replace(search_dict: dict, attr_name: str) -> dict:
"""Recursively search through a dict replacing Sims and Batches or ignoring other items thus removing them"""
new_dict = {}
for key, value in search_dict.items():
Expand Down Expand Up @@ -398,7 +441,7 @@ def run_batch(
fn_post: Callable[
Union[SimulationData, List[SimulationData], Dict[str, SimulationData]], Any
],
path_dir: str = None,
path_dir: str = ".",
**batch_kwargs,
) -> Result:
"""
Expand Down Expand Up @@ -495,15 +538,15 @@ def _estimate_sim_cost(sim):
raise ValueError("Unrecognized output from pre-function, unable to estimate cost.")

# Calculate maximum number of runs for different methods
run_count = self.method.get_run_count(self.parameters)
run_count = self.method._get_run_count(self.parameters)

# For if tidy3d server cannot determine the estimate
if per_run_estimate is None:
return None
else:
return round(per_run_estimate * run_count, 3)

def summarize(self, fn_pre: Callable = None) -> dict[str, Any]:
def summarize(self, fn_pre: Callable = None, verbose: bool = True) -> dict[str, Any]:
"""Summarize the setup of the DesignSpace
Prints a summary of the DesignSpace including the method and associated args, the parameters,
Expand All @@ -516,6 +559,8 @@ def summarize(self, fn_pre: Callable = None) -> dict[str, Any]:
Function accepting arguments that correspond to the ``name`` fields
of the ``DesignSpace.parameters``. Allows for estimated cost to be included
in the summary.
verbose: bool = True
Toggle if the summary should be output to log. If False, the dict is returned silently.
Returns
-------
Expand All @@ -541,7 +586,7 @@ def summarize(self, fn_pre: Callable = None) -> dict[str, Any]:
else:
param_values.append(f"{param.name}: {param.type} {param.span}\n")

run_count = self.method.get_run_count(self.parameters)
run_count = self.method._get_run_count(self.parameters)

# Compile data into a dict for return
summary_dict = {
Expand All @@ -553,22 +598,23 @@ def summarize(self, fn_pre: Callable = None) -> dict[str, Any]:
"max_run_count": run_count,
}

console.log(
"\nSummary of DesignSpace\n\n"
f"Method: {summary_dict['method']}\n"
f"Method Args\n{summary_dict['method_args']}\n"
f"No. of Parameters: {summary_dict['param_count']}\n"
f"Parameters: {summary_dict['param_names']}\n"
f"{summary_dict['param_vals']}\n"
f"Maximum Run Count: {summary_dict['max_run_count']}\n"
)
if verbose:
console.log(
"\nSummary of DesignSpace\n\n"
f"Method: {summary_dict['method']}\n"
f"Method Args\n{summary_dict['method_args']}\n"
f"No. of Parameters: {summary_dict['param_count']}\n"
f"Parameters: {summary_dict['param_names']}\n"
f"{summary_dict['param_vals']}\n"
f"Maximum Run Count: {summary_dict['max_run_count']}\n"
)

if fn_pre is not None:
cost_estimate = self.estimate_cost(fn_pre)
summary_dict["cost_estimate"] = cost_estimate
console.log(f"Estimated Maximum Cost: {cost_estimate} FlexCredits")
if fn_pre is not None:
cost_estimate = self.estimate_cost(fn_pre)
summary_dict["cost_estimate"] = cost_estimate
console.log(f"Estimated Maximum Cost: {cost_estimate} FlexCredits")

# NOTE: Could then add more details regarding the output of fn_pre - confirm batching?
# NOTE: Could then add more details regarding the output of fn_pre - confirm batching?

# Include additional notes/warnings
notes = []
Expand All @@ -591,7 +637,7 @@ def summarize(self, fn_pre: Callable = None) -> dict[str, Any]:
"Discrete 'int' values are automatically rounded if optimizers generate 'float' predictions.\n"
)

if len(notes) > 0:
if len(notes) > 0 and verbose:
console.log(
"Notes:",
)
Expand Down
Loading

0 comments on commit 5d8e426

Please sign in to comment.