Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _validation_data property to Metric classes #2012

Merged
merged 3 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `tidy3d.plugins.pytorch` to wrap autograd functions for interoperability with PyTorch via the `to_torch` wrapper.

### Changed
- Renamed `Metric.freqs` --> `Metric.f` and made frequency argument optional, in which case all frequencies from the relevant monitor will be extracted.
- Renamed `Metric.freqs` --> `Metric.f` and made frequency argument optional, in which case all frequencies from the relevant monitor will be extracted. Metrics can still be initialized with both `f` or `freqs`.

### Fixed
- Some validation fixes for design region.
- Bug in adjoint source creation that included empty sources for extraneous `FieldMonitor` objects, triggering unnecessary errors.
- Correct sign in objective function history depending on `Optimizer.maximize`.

## [2.7.4] - 2024-09-25

Expand Down
10 changes: 9 additions & 1 deletion tests/test_plugins/test_invdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
import tidy3d as td
import tidy3d.plugins.invdes as tdi
from tidy3d.plugins.expressions import ModePower
from tidy3d.plugins.expressions import ModeAmp, ModePower
from tidy3d.plugins.invdes.initialization import (
CustomInitializationSpec,
RandomInitializationSpec,
Expand Down Expand Up @@ -606,3 +606,11 @@ def test_validate_invdes_metric():
invdes = invdes.updated_copy(simulation=simulation.updated_copy(monitors=[monitor]))
with pytest.raises(ValueError, match="single frequency"):
invdes.updated_copy(metric=metric)

metric = ModeAmp(monitor_name=MNT_NAME2, mode_index=0) + ModePower(
monitor_name=MNT_NAME2, mode_index=0
)
monitor = mnt2.updated_copy(freqs=[FREQ0])
invdes = invdes.updated_copy(simulation=simulation.updated_copy(monitors=[monitor]))
with pytest.raises(ValueError, match="must return a real"):
invdes.updated_copy(metric=metric)
3 changes: 2 additions & 1 deletion tidy3d/plugins/expressions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base import Expression
from .functions import Cos, Exp, Log, Log10, Sin, Sqrt, Tan
from .metrics import ModeAmp, ModePower
from .metrics import ModeAmp, ModePower, generate_validation_data
from .variables import Constant, Variable

__all__ = [
Expand All @@ -9,6 +9,7 @@
"Variable",
"ModeAmp",
"ModePower",
"generate_validation_data",
"Sin",
"Cos",
"Tan",
Expand Down
44 changes: 43 additions & 1 deletion tidy3d/plugins/expressions/metrics.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,49 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Union

import autograd.numpy as np
import pydantic.v1 as pd
import xarray as xr

from tidy3d.components.monitor import ModeMonitor
from tidy3d.components.types import Direction, FreqArray

from .base import Expression
from .types import NumberType
from .variables import Variable


class Metric(Variable):
def generate_validation_data(expr: Expression) -> dict[str, xr.Dataset]:
"""Generate combined dummy simulation data for all metrics in the expression.

Parameters
----------
expr : Expression
The expression containing metrics.

Returns
-------
dict[str, xr.Dataset]
The combined validation data.
"""
metrics = set(expr.filter(target_type=Metric))
combined_data = {k: v for metric in metrics for k, v in metric._validation_data.items()}
return combined_data


class Metric(Variable, ABC):
"""
Base class for all metrics.

To subclass Metric, you must implement an evaluate() method that takes a SimulationData
object and returns a scalar value.
"""

@property
@abstractmethod
def _validation_data(self) -> Any:
"""Return dummy data for this metric."""

def __repr__(self) -> str:
return f'{self.type}("{self.monitor_name}")'

Expand All @@ -43,6 +69,7 @@ class ModeAmp(Metric):
None,
title="Frequency Array",
description="The frequency array. If None, all frequencies in the monitor will be used.",
alias="freqs",
)
direction: Direction = pd.Field(
"+",
Expand All @@ -63,6 +90,21 @@ def from_mode_monitor(
monitor_name=monitor.name, f=monitor.freqs, mode_index=mode_index, direction=direction
)

@property
def _validation_data(self) -> Any:
"""Return dummy data for this metric (complex array of mode amplitudes)."""
f = list(self.f) if self.f is not None else [1.0]
amps_data = np.random.rand(len(f)) + 1j * np.random.rand(len(f))
amps = xr.DataArray(
amps_data.reshape(1, 1, -1),
coords={
"direction": [self.direction],
"mode_index": [self.mode_index],
"f": f,
},
)
return {self.monitor_name: xr.Dataset({"amps": amps})}

def evaluate(self, *args: Any, **kwargs: Any) -> NumberType:
data = super().evaluate(*args, **kwargs)
amps = data[self.monitor_name].amps.sel(
Expand Down
20 changes: 19 additions & 1 deletion tidy3d/plugins/invdes/design.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tidy3d as td
from tidy3d.components.autograd import get_static
from tidy3d.exceptions import ValidationError
from tidy3d.plugins.expressions.metrics import Metric
from tidy3d.plugins.expressions.metrics import Metric, generate_validation_data
from tidy3d.plugins.expressions.types import ExpressionType

from .base import InvdesBaseModel
Expand Down Expand Up @@ -151,6 +151,7 @@ def _validate_metric(values: dict) -> dict:
InverseDesign._validate_metric_monitor_name(metric, simulation)
InverseDesign._validate_metric_mode_index(metric, simulation)
InverseDesign._validate_metric_f(metric, simulation)
InverseDesign._validate_metric_data(metric_expr, simulation)
return values

@staticmethod
Expand Down Expand Up @@ -192,6 +193,23 @@ def _validate_metric_f(metric: Metric, simulation: td.Simulation) -> None:
f"Monitor '{metric.monitor_name}' must contain only a single frequency when metric.f is None."
)

@staticmethod
def _validate_metric_data(expr: ExpressionType, simulation: td.Simulation) -> None:
"""Validate that expression can be evaluated and returns a real scalar."""
data = generate_validation_data(expr)
try:
result = expr(data)
except Exception as e:
raise ValidationError(f"Failed to evaluate the metric expression: {str(e)}") from e
if len(np.ravel(result)) > 1:
raise ValidationError(
f"The expression must return a scalar value or an array of length 1 (got {result})."
)
if not np.all(np.isreal(result)):
raise ValidationError(
f"The expression must return a real (not complex) value (got {result})."
)

def is_output_monitor(self, monitor: td.Monitor) -> bool:
"""Whether a monitor is added to the ``JaxSimulation`` as an ``output_monitor``."""

Expand Down
2 changes: 1 addition & 1 deletion tidy3d/plugins/invdes/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def continue_run(
params = anp.clip(params, a_min=0.0, a_max=1.0)

# save the history of scalar values
history["objective_fn_val"].append(val)
history["objective_fn_val"].append(aux_data["objective_fn_val"])
history["penalty"].append(penalty)
history["post_process_val"].append(post_process_val)

Expand Down
Loading