Skip to content

Commit

Permalink
unified validation check for dependency fields
Browse files Browse the repository at this point in the history
  • Loading branch information
dbochkov-flexcompute committed Jan 4, 2024
1 parent efb87ae commit d0a7616
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 44 deletions.
6 changes: 3 additions & 3 deletions tests/test_components/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,11 @@ def test_validate_components_none():
assert SIM._source_homogeneous_isotropic(val=None, values=SIM.dict()) is None


def test_sources_edge_case_validation():
def test_sources_edge_case_validation(log_capture):
values = SIM.dict()
values.pop("sources")
with pytest.raises(ValidationError):
SIM._warn_monitor_simulation_frequency_range(val="test", values=values)
SIM._warn_monitor_simulation_frequency_range(val="test", values=values)
assert_log_level(log_capture, "WARNING")


def test_validate_size_run_time(monkeypatch):
Expand Down
22 changes: 22 additions & 0 deletions tidy3d/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,28 @@ def _get_valid_extension(fname: str) -> str:
)


def check_previous_fields_validation(required_fields):
"""Decorate ``validator`` to check that other fields have passed validation."""

def actual_decorator(validator):
@wraps(validator)
def _validator(cls, val, values):
"""New validator function."""
for field in required_fields:
if field not in values:
log.warning(
f"Could not execute validator '{validator.__name__}' because field "
f"'{field}' failed validation."
)
return val

return validator(cls, val, values)

return _validator

return actual_decorator


class Tidy3dBaseModel(pydantic.BaseModel):
"""Base pydantic model that all Tidy3d components inherit from.
Defines configuration for handling data structures
Expand Down
10 changes: 4 additions & 6 deletions tidy3d/components/base_sim/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..simulation import AbstractSimulation
from ...data.dataset import UnstructuredGridDatasetType
from ...base import Tidy3dBaseModel
from ...base import check_previous_fields_validation
from ...types import FieldVal
from ....exceptions import DataError, Tidy3dKeyError, ValidationError

Expand Down Expand Up @@ -51,13 +52,13 @@ def monitor_data(self) -> Dict[str, AbstractMonitorData]:
return {monitor_data.monitor.name: monitor_data for monitor_data in self.data}

@pd.validator("data", always=True)
@check_previous_fields_validation(["simulation"])
def data_monitors_match_sim(cls, val, values):
"""Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in
``.simulation``.
"""
sim = values.get("simulation")
if sim is None:
raise ValidationError("'.simulation' failed validation, can't validate data.")

for mnt_data in val:
try:
monitor_name = mnt_data.monitor.name
Expand All @@ -70,14 +71,11 @@ def data_monitors_match_sim(cls, val, values):
return val

@pd.validator("data", always=True)
@check_previous_fields_validation(["simulation"])
def validate_no_ambiguity(cls, val, values):
"""Ensure all :class:`AbstractMonitorData` entries in ``.data`` correspond to different
monitors in ``.simulation``.
"""
sim = values.get("simulation")
if sim is None:
raise ValidationError("'.simulation' failed validation, can't validate data.")

names = [mnt_data.monitor.name for mnt_data in val]

if len(set(names)) != len(names):
Expand Down
7 changes: 3 additions & 4 deletions tidy3d/components/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from ..viz import equal_aspect, add_ax_if_none, plot_params_grid
from ..base import Tidy3dBaseModel, cached_property
from ..base import check_previous_fields_validation
from ..types import Axis, Bound, ArrayLike, Ax, Coordinate, Literal
from ..types import vtk, requires_vtk
from ...exceptions import DataError, ValidationError, Tidy3dNotImplementedError
Expand Down Expand Up @@ -524,13 +525,12 @@ def match_cells_to_vtk_type(cls, val):
return CellDataArray(val.data.astype(vtk["id_type"], copy=False), coords=val.coords)

@pd.validator("values", always=True)
@check_previous_fields_validation(["points"])
def number_of_values_matches_points(cls, val, values):
"""Check that the number of data values matches the number of grid points."""
num_values = len(val)

points = values.get("points")
if points is None:
raise ValidationError("Cannot validate '.values' because '.points' failed validation.")
num_points = len(points)

if num_points != num_values:
Expand Down Expand Up @@ -565,15 +565,14 @@ def cells_right_type(cls, val):
return val

@pd.validator("cells", always=True)
@check_previous_fields_validation(["points"])
def check_cell_vertex_range(cls, val, values):
"""Check that cell connections use only defined points."""
all_point_indices_used = val.data.ravel()
min_index_used = np.min(all_point_indices_used)
max_index_used = np.max(all_point_indices_used)

points = values.get("points")
if points is None:
raise ValidationError("Cannot validate '.values' because '.points' failed validation.")
num_points = len(points)

if max_index_used != num_points - 1 or min_index_used != 0:
Expand Down
4 changes: 2 additions & 2 deletions tidy3d/components/geometry/polyslab.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from matplotlib import path

from ..base import cached_property
from ..base import check_previous_fields_validation
from ..types import Axis, Bound, PlanePosition, ArrayFloat2D, Coordinate
from ..types import MatrixReal4x4, Shapely, trimesh
from ...log import log
Expand Down Expand Up @@ -154,6 +155,7 @@ def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values):
return val

@pydantic.validator("vertices", always=True)
@check_previous_fields_validation(["sidewall_angle"])
def no_self_intersecting_polygon_during_extrusion(cls, val, values):
"""In this simple polyslab, we don't support self-intersecting polygons yet, meaning that
any normal cross section of the PolySlab cannot be self-intersecting. This part checks
Expand All @@ -168,8 +170,6 @@ def no_self_intersecting_polygon_during_extrusion(cls, val, values):
To detect this, we sample _N_SAMPLE_POLYGON_INTERSECT cross sections to see if any creation
of polygons/holes, and changes in vertices number.
"""
if "sidewall_angle" not in values:
raise ValidationError("'sidewall_angle' failed validation.")

# no need to valiate anything here
if isclose(values["sidewall_angle"], 0):
Expand Down
24 changes: 7 additions & 17 deletions tidy3d/components/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import xarray as xr

from .base import Tidy3dBaseModel, cached_property
from .base import check_previous_fields_validation
from .grid.grid import Coords, Grid
from .types import PoleAndResidue, Ax, FreqBound, TYPE_TAG_STR
from .types import InterpMethod, Bound, ArrayComplex3D, ArrayFloat1D
Expand Down Expand Up @@ -1155,15 +1156,13 @@ def _eps_inf_greater_no_less_than_one(cls, val):
return val

@pd.validator("conductivity", always=True)
@check_previous_fields_validation(["permittivity"])
def _conductivity_real_and_correct_shape(cls, val, values):
"""Assert conductivity is real and of right shape."""

if val is None:
return val

if values.get("permittivity") is None:
raise ValidationError("'permittivity' failed validation.")

if not CustomIsotropicMedium._validate_isreal_dataarray(val):
raise SetupError("'conductivity' must be real.")

Expand Down Expand Up @@ -1429,15 +1428,13 @@ def _eps_inf_greater_no_less_than_one(cls, val, values):
return val

@pd.validator("conductivity", always=True)
@check_previous_fields_validation(["permittivity"])
def _conductivity_non_negative_correct_shape(cls, val, values):
"""Assert conductivity>=0"""

if val is None:
return val

if values.get("permittivity") is None:
raise ValidationError("'permittivity' failed validation.")

if not CustomMedium._validate_isreal_dataarray(val):
raise SetupError("'conductivity' must be real.")

Expand Down Expand Up @@ -2387,10 +2384,9 @@ def _eps_inf_positive(cls, val):
return val

@pd.validator("poles", always=True)
@check_previous_fields_validation(["eps_inf"])
def _poles_correct_shape(cls, val, values):
"""poles must have the same shape."""
if values.get("eps_inf") is None:
raise ValidationError("'eps_inf' failed validation.")

expected_coords = values["eps_inf"].coords
for coeffs in val:
Expand Down Expand Up @@ -2900,11 +2896,9 @@ def _coeffs_unequal_f_delta(cls, val):
return val

@pd.validator("coeffs", always=True)
@check_previous_fields_validation(["eps_inf"])
def _coeffs_correct_shape(cls, val, values):
"""coeffs must have consistent shape."""
if values.get("eps_inf") is None:
raise ValidationError("'eps_inf' failed validation.")

expected_coords = values["eps_inf"].coords
for de, f, delta in val:
if (
Expand Down Expand Up @@ -3089,11 +3083,9 @@ def _eps_inf_positive(cls, val):
return val

@pd.validator("coeffs", always=True)
@check_previous_fields_validation(["eps_inf"])
def _coeffs_correct_shape_and_sign(cls, val, values):
"""coeffs must have consistent shape and sign."""
if values.get("eps_inf") is None:
raise ValidationError("'eps_inf' failed validation.")

expected_coords = values["eps_inf"].coords
for f, delta in val:
if f.coords != expected_coords or delta.coords != expected_coords:
Expand Down Expand Up @@ -3252,11 +3244,9 @@ def _eps_inf_positive(cls, val):
return val

@pd.validator("coeffs", always=True)
@check_previous_fields_validation(["eps_inf"])
def _coeffs_correct_shape(cls, val, values):
"""coeffs must have consistent shape."""
if values.get("eps_inf") is None:
raise ValidationError("'eps_inf' failed validation.")

expected_coords = values["eps_inf"].coords
for de, tau in val:
if de.coords != expected_coords or tau.coords != expected_coords:
Expand Down
9 changes: 2 additions & 7 deletions tidy3d/components/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import matplotlib as mpl

from .base import cached_property
from .base import check_previous_fields_validation
from .validators import assert_objects_in_sim_bounds
from .validators import validate_mode_objects_symmetry
from .geometry.base import Geometry, Box
Expand Down Expand Up @@ -571,19 +572,13 @@ def _warn_monitor_mediums_frequency_range(cls, val, values):
return val

@pydantic.validator("monitors", always=True)
@check_previous_fields_validation(["sources"])
def _warn_monitor_simulation_frequency_range(cls, val, values):
"""Warn if any DFT monitors have frequencies outside of the simulation frequency range."""

if val is None:
return val

# Get simulation frequency range
if "sources" not in values:
raise ValidationError(
"could not validate `_warn_monitor_simulation_frequency_range` "
"as `sources` failed validation"
)

source_ranges = [source.source_time.frequency_range() for source in values["sources"]]
if not source_ranges:
log.info("No sources in simulation.")
Expand Down
7 changes: 3 additions & 4 deletions tidy3d/plugins/dispersion/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from ...log import log, get_logging_console
from ...components.base import Tidy3dBaseModel, cached_property
from ...components.base import check_previous_fields_validation
from ...components.medium import PoleResidue, AbstractMedium
from ...components.viz import add_ax_if_none
from ...components.types import Ax, ArrayFloat1D
Expand Down Expand Up @@ -60,20 +61,18 @@ def _setup_wvl(cls, val):
return val

@validator("n_data", always=True)
@check_previous_fields_validation(["wvl_um"])
def _ndata_length_match_wvl(cls, val, values):
"""Validate n_data"""
if "wvl_um" not in values:
raise ValidationError("'wvl_um' failed validation.")

if val.shape != values["wvl_um"].shape:
raise ValidationError("The length of 'n_data' doesn't match 'wvl_um'.")
return val

@validator("k_data", always=True)
@check_previous_fields_validation(["wvl_um"])
def _kdata_setup_and_length_match(cls, val, values):
"""Validate the length of k_data, or setup k if it's None."""
if "wvl_um" not in values:
raise ValidationError("'wvl_um' failed validation.")

if val is None:
return np.zeros_like(values["wvl_um"])
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/plugins/mode/mode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import xarray as xr

from ...log import log
from ...components.base import Tidy3dBaseModel, cached_property
from ...components.base import Tidy3dBaseModel, cached_property, check_previous_fields_validation
from ...components.geometry.base import Box
from ...components.simulation import Simulation
from ...components.grid.grid import Grid
Expand Down Expand Up @@ -100,6 +100,7 @@ def is_plane(cls, val):
_freqs_lower_bound = validate_freqs_min()

@pydantic.validator("plane", always=True)
@check_previous_fields_validation(["simulation"])
def plane_in_sim_bounds(cls, val, values):
"""Check that the plane is at least partially inside the simulation bounds."""
sim_center = values.get("simulation").center
Expand Down

0 comments on commit d0a7616

Please sign in to comment.