diff --git a/tidy3d/components/apodization.py b/tidy3d/components/apodization.py index 24ee7c3c4a..47f713285d 100644 --- a/tidy3d/components/apodization.py +++ b/tidy3d/components/apodization.py @@ -3,7 +3,7 @@ import pydantic.v1 as pd import numpy as np -from .base import Tidy3dBaseModel +from .base import Tidy3dBaseModel, skip_if_fields_missing from ..constants import SECOND from ..exceptions import SetupError from .types import ArrayFloat1D, Ax @@ -40,6 +40,7 @@ class ApodizationSpec(Tidy3dBaseModel): ) @pd.validator("end", always=True, allow_reuse=True) + @skip_if_fields_missing(["start"]) def end_greater_than_start(cls, val, values): """Ensure end is greater than or equal to start.""" start = values.get("start") @@ -48,6 +49,7 @@ def end_greater_than_start(cls, val, values): return val @pd.validator("width", always=True, allow_reuse=True) + @skip_if_fields_missing(["start", "end"]) def width_provided(cls, val, values): """Check that width is provided if either start or end apodization is requested.""" start = values.get("start") diff --git a/tidy3d/components/base.py b/tidy3d/components/base.py index 7cbca23efb..f57a6428f9 100644 --- a/tidy3d/components/base.py +++ b/tidy3d/components/base.py @@ -86,14 +86,14 @@ def _get_valid_extension(fname: str) -> str: ) -def check_previous_fields_validation(required_fields): +def skip_if_fields_missing(fields: List[str]): """Decorate ``validator`` to check that other fields have passed validation.""" def actual_decorator(validator): @wraps(validator) def _validator(cls, val, values): """New validator function.""" - for field in required_fields: + for field in fields: if field not in values: log.warning( f"Could not execute validator '{validator.__name__}' because field " diff --git a/tidy3d/components/base_sim/data/sim_data.py b/tidy3d/components/base_sim/data/sim_data.py index fbf008755f..9770831050 100644 --- a/tidy3d/components/base_sim/data/sim_data.py +++ b/tidy3d/components/base_sim/data/sim_data.py @@ -12,7 +12,7 @@ from ..simulation import AbstractSimulation from ...data.dataset import UnstructuredGridDatasetType from ...base import Tidy3dBaseModel -from ...base import check_previous_fields_validation +from ...base import skip_if_fields_missing from ...types import FieldVal from ....exceptions import DataError, Tidy3dKeyError, ValidationError @@ -52,7 +52,7 @@ def monitor_data(self) -> Dict[str, AbstractMonitorData]: return {monitor_data.monitor.name: monitor_data for monitor_data in self.data} @pd.validator("data", always=True) - @check_previous_fields_validation(["simulation"]) + @skip_if_fields_missing(["simulation"]) def data_monitors_match_sim(cls, val, values): """Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in ``.simulation``. @@ -71,7 +71,7 @@ def data_monitors_match_sim(cls, val, values): return val @pd.validator("data", always=True) - @check_previous_fields_validation(["simulation"]) + @skip_if_fields_missing(["simulation"]) def validate_no_ambiguity(cls, val, values): """Ensure all :class:`AbstractMonitorData` entries in ``.data`` correspond to different monitors in ``.simulation``. diff --git a/tidy3d/components/base_sim/simulation.py b/tidy3d/components/base_sim/simulation.py index 7988f14b33..cc82734c03 100644 --- a/tidy3d/components/base_sim/simulation.py +++ b/tidy3d/components/base_sim/simulation.py @@ -9,7 +9,7 @@ from .monitor import AbstractMonitor -from ..base import cached_property +from ..base import cached_property, skip_if_fields_missing from ..validators import assert_unique_names, assert_objects_in_sim_bounds from ..geometry.base import Box from ..types import Ax, Bound, Axis, Symmetry, TYPE_TAG_STR @@ -97,6 +97,7 @@ class AbstractSimulation(Box, ABC): _structures_in_bounds = assert_objects_in_sim_bounds("structures", error=False) @pd.validator("structures", always=True) + @skip_if_fields_missing(["size", "center"]) def _structures_not_at_edges(cls, val, values): """Warn if any structures lie at the simulation boundaries.""" diff --git a/tidy3d/components/data/dataset.py b/tidy3d/components/data/dataset.py index 092a1b7a23..9e5849b1c7 100644 --- a/tidy3d/components/data/dataset.py +++ b/tidy3d/components/data/dataset.py @@ -20,7 +20,7 @@ from ..viz import equal_aspect, add_ax_if_none, plot_params_grid from ..base import Tidy3dBaseModel, cached_property -from ..base import check_previous_fields_validation +from ..base import skip_if_fields_missing from ..types import Axis, Bound, ArrayLike, Ax, Coordinate, Literal from ..types import vtk, requires_vtk from ...exceptions import DataError, ValidationError, Tidy3dNotImplementedError @@ -525,7 +525,7 @@ def match_cells_to_vtk_type(cls, val): return CellDataArray(val.data.astype(vtk["id_type"], copy=False), coords=val.coords) @pd.validator("values", always=True) - @check_previous_fields_validation(["points"]) + @skip_if_fields_missing(["points"]) def number_of_values_matches_points(cls, val, values): """Check that the number of data values matches the number of grid points.""" num_values = len(val) @@ -565,7 +565,7 @@ def cells_right_type(cls, val): return val @pd.validator("cells", always=True) - @check_previous_fields_validation(["points"]) + @skip_if_fields_missing(["points"]) def check_cell_vertex_range(cls, val, values): """Check that cell connections use only defined points.""" all_point_indices_used = val.data.ravel() diff --git a/tidy3d/components/field_projection.py b/tidy3d/components/field_projection.py index 0d7eb03352..1f4819f308 100644 --- a/tidy3d/components/field_projection.py +++ b/tidy3d/components/field_projection.py @@ -19,7 +19,7 @@ from .monitor import FieldProjectionCartesianMonitor, FieldProjectionKSpaceMonitor from .types import Direction, Coordinate, ArrayComplex4D from .medium import MediumType -from .base import Tidy3dBaseModel, cached_property +from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing from ..exceptions import SetupError from ..constants import C_0, MICROMETER, ETA_0, EPSILON_0, MU_0 from ..log import get_logging_console @@ -72,6 +72,7 @@ class FieldProjector(Tidy3dBaseModel): ) @pydantic.validator("origin", always=True) + @skip_if_fields_missing(["surfaces"]) def set_origin(cls, val, values): """Sets .origin as the average of centers of all surface monitors if not provided.""" if val is None: diff --git a/tidy3d/components/geometry/polyslab.py b/tidy3d/components/geometry/polyslab.py index 24c582f3c5..0b6ae89824 100644 --- a/tidy3d/components/geometry/polyslab.py +++ b/tidy3d/components/geometry/polyslab.py @@ -11,7 +11,7 @@ from matplotlib import path from ..base import cached_property -from ..base import check_previous_fields_validation +from ..base import skip_if_fields_missing from ..types import Axis, Bound, PlanePosition, ArrayFloat2D, Coordinate from ..types import MatrixReal4x4, Shapely, trimesh from ...log import log @@ -155,7 +155,7 @@ def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values): return val @pydantic.validator("vertices", always=True) - @check_previous_fields_validation(["sidewall_angle"]) + @skip_if_fields_missing(["sidewall_angle"]) def no_self_intersecting_polygon_during_extrusion(cls, val, values): """In this simple polyslab, we don't support self-intersecting polygons yet, meaning that any normal cross section of the PolySlab cannot be self-intersecting. This part checks diff --git a/tidy3d/components/heat/data/monitor_data.py b/tidy3d/components/heat/data/monitor_data.py index 8f3e4e080b..1120f3be98 100644 --- a/tidy3d/components/heat/data/monitor_data.py +++ b/tidy3d/components/heat/data/monitor_data.py @@ -7,6 +7,7 @@ import pydantic.v1 as pd from ..monitor import TemperatureMonitor, HeatMonitorType +from ...base import skip_if_fields_missing from ...base_sim.data.monitor_data import AbstractMonitorData from ...data.data_array import SpatialDataArray from ...data.dataset import TriangularGridDataset, TetrahedralGridDataset @@ -74,6 +75,7 @@ class TemperatureData(HeatMonitorData): ) @pd.validator("temperature", always=True) + @skip_if_fields_missing(["monitor"]) def warn_no_data(cls, val, values): """Warn if no data provided.""" diff --git a/tidy3d/components/heat/grid.py b/tidy3d/components/heat/grid.py index 0572274dde..e182db7d6d 100644 --- a/tidy3d/components/heat/grid.py +++ b/tidy3d/components/heat/grid.py @@ -4,7 +4,7 @@ from typing import Union, Tuple import pydantic.v1 as pd -from ..base import Tidy3dBaseModel +from ..base import Tidy3dBaseModel, skip_if_fields_missing from ...constants import MICROMETER from ...exceptions import ValidationError @@ -107,6 +107,7 @@ class DistanceUnstructuredGrid(Tidy3dBaseModel): ) @pd.validator("distance_bulk", always=True) + @skip_if_fields_missing(["distance_interface"]) def names_exist_bcs(cls, val, values): """Error if distance_bulk is less than distance_interface""" distance_interface = values.get("distance_interface") diff --git a/tidy3d/components/heat/simulation.py b/tidy3d/components/heat/simulation.py index b3608614f7..50431b2aa6 100644 --- a/tidy3d/components/heat/simulation.py +++ b/tidy3d/components/heat/simulation.py @@ -15,7 +15,7 @@ from .viz import plot_params_heat_bc, plot_params_heat_source, HEAT_SOURCE_CMAP from ..base_sim.simulation import AbstractSimulation -from ..base import cached_property +from ..base import cached_property, skip_if_fields_missing from ..types import Ax, Shapely, TYPE_TAG_STR, ScalarSymmetry, Bound from ..viz import add_ax_if_none, equal_aspect, PlotParams from ..structure import Structure @@ -139,6 +139,7 @@ def check_zero_dim_domain(cls, val, values): return val @pd.validator("boundary_spec", always=True) + @skip_if_fields_missing(["structures", "medium"]) def names_exist_bcs(cls, val, values): """Error if boundary conditions point to non-existing structures/media.""" @@ -175,6 +176,7 @@ def names_exist_bcs(cls, val, values): return val @pd.validator("grid_spec", always=True) + @skip_if_fields_missing(["structures"]) def names_exist_grid_spec(cls, val, values): """Warn if UniformUnstructuredGrid points at a non-existing structure.""" @@ -191,6 +193,7 @@ def names_exist_grid_spec(cls, val, values): return val @pd.validator("sources", always=True) + @skip_if_fields_missing(["structures"]) def names_exist_sources(cls, val, values): """Error if a heat source point to non-existing structures.""" structures = values.get("structures") diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index a2c0a94509..6a53719d17 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -11,7 +11,7 @@ import xarray as xr from .base import Tidy3dBaseModel, cached_property -from .base import check_previous_fields_validation +from .base import skip_if_fields_missing from .grid.grid import Coords, Grid from .types import PoleAndResidue, Ax, FreqBound, TYPE_TAG_STR from .types import InterpMethod, Bound, ArrayComplex3D, ArrayFloat1D @@ -544,6 +544,7 @@ def _validate_nonlinear_spec(self): ) @pd.validator("modulation_spec", always=True) + @skip_if_fields_missing(["nonlinear_spec"]) def _validate_modulation_spec(cls, val, values): """Check compatibility with modulation_spec.""" nonlinear_spec = values.get("nonlinear_spec") @@ -1025,6 +1026,7 @@ class Medium(AbstractMedium): ) @pd.validator("conductivity", always=True) + @skip_if_fields_missing(["allow_gain"]) def _passivity_validation(cls, val, values): """Assert passive medium if `allow_gain` is False.""" if not values.get("allow_gain") and val < 0: @@ -1036,6 +1038,7 @@ def _passivity_validation(cls, val, values): return val @pd.validator("permittivity", always=True) + @skip_if_fields_missing(["modulation_spec"]) def _permittivity_modulation_validation(cls, val, values): """Assert modulated permittivity cannot be <= 0.""" modulation = values.get("modulation_spec") @@ -1050,6 +1053,7 @@ def _permittivity_modulation_validation(cls, val, values): return val @pd.validator("conductivity", always=True) + @skip_if_fields_missing(["modulation_spec", "allow_gain"]) def _passivity_modulation_validation(cls, val, values): """Assert passive medium if `allow_gain` is False.""" modulation = values.get("modulation_spec") @@ -1156,7 +1160,7 @@ def _eps_inf_greater_no_less_than_one(cls, val): return val @pd.validator("conductivity", always=True) - @check_previous_fields_validation(["permittivity"]) + @skip_if_fields_missing(["permittivity"]) def _conductivity_real_and_correct_shape(cls, val, values): """Assert conductivity is real and of right shape.""" @@ -1171,6 +1175,7 @@ def _conductivity_real_and_correct_shape(cls, val, values): return val @pd.validator("conductivity", always=True) + @skip_if_fields_missing(["allow_gain"]) def _passivity_validation(cls, val, values): """Assert passive medium if `allow_gain` is False.""" if val is None: @@ -1357,6 +1362,7 @@ def _eps_dataset_single_frequency(cls, val): return val @pd.validator("eps_dataset", always=True) + @skip_if_fields_missing(["modulation_spec", "allow_gain"]) def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(cls, val, values): """Assert any eps_inf must be >=1""" if val is None: @@ -1405,6 +1411,7 @@ def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(cls, val, value return val @pd.validator("permittivity", always=True) + @skip_if_fields_missing(["modulation_spec"]) def _eps_inf_greater_no_less_than_one(cls, val, values): """Assert any eps_inf must be >=1""" if val is None: @@ -1428,7 +1435,7 @@ def _eps_inf_greater_no_less_than_one(cls, val, values): return val @pd.validator("conductivity", always=True) - @check_previous_fields_validation(["permittivity"]) + @skip_if_fields_missing(["permittivity", "allow_gain"]) def _conductivity_non_negative_correct_shape(cls, val, values): """Assert conductivity>=0""" @@ -1452,6 +1459,7 @@ def _conductivity_non_negative_correct_shape(cls, val, values): return val @pd.validator("conductivity", always=True) + @skip_if_fields_missing(["eps_dataset", "modulation_spec", "allow_gain"]) def _passivity_modulation_validation(cls, val, values): """Assert passive medium at any time during modulation if `allow_gain` is False.""" @@ -1816,6 +1824,7 @@ def _permittivity_modulation_validation(): """Assert modulated permittivity cannot be <= 0 at any time.""" @pd.validator("eps_inf", allow_reuse=True, always=True) + @skip_if_fields_missing(["modulation_spec"]) def _validate_permittivity_modulation(cls, val, values): """Assert modulated permittivity cannot be <= 0.""" modulation = values.get("modulation_spec") @@ -1836,6 +1845,7 @@ def _conductivity_modulation_validation(): """Assert passive medium at any time if not ``allow_gain``.""" @pd.validator("modulation_spec", allow_reuse=True, always=True) + @skip_if_fields_missing(["allow_gain"]) def _validate_conductivity_modulation(cls, val, values): """With conductivity modulation, the medium can exhibit gain during the cycle. So `allow_gain` must be True when the conductivity is modulated. @@ -2384,7 +2394,7 @@ def _eps_inf_positive(cls, val): return val @pd.validator("poles", always=True) - @check_previous_fields_validation(["eps_inf"]) + @skip_if_fields_missing(["eps_inf"]) def _poles_correct_shape(cls, val, values): """poles must have the same shape.""" @@ -2506,6 +2516,7 @@ class Sellmeier(DispersiveMedium): ) @pd.validator("coeffs", always=True) + @skip_if_fields_missing(["allow_gain"]) def _passivity_validation(cls, val, values): """Assert passive medium if `allow_gain` is False.""" if values.get("allow_gain"): @@ -2648,6 +2659,7 @@ def _correct_shape_and_sign(cls, val): return val @pd.validator("coeffs", always=True) + @skip_if_fields_missing(["allow_gain"]) def _passivity_validation(cls, val, values): """Assert passive medium if `allow_gain` is False.""" if values.get("allow_gain"): @@ -2775,6 +2787,7 @@ def _coeffs_unequal_f_delta(cls, val): return val @pd.validator("coeffs", always=True) + @skip_if_fields_missing(["allow_gain"]) def _passivity_validation(cls, val, values): """Assert passive medium if `allow_gain` is False.""" if values.get("allow_gain"): @@ -2896,7 +2909,7 @@ def _coeffs_unequal_f_delta(cls, val): return val @pd.validator("coeffs", always=True) - @check_previous_fields_validation(["eps_inf"]) + @skip_if_fields_missing(["eps_inf"]) def _coeffs_correct_shape(cls, val, values): """coeffs must have consistent shape.""" expected_coords = values["eps_inf"].coords @@ -2928,6 +2941,7 @@ def _coeffs_delta_all_smaller_or_larger_than_fi(cls, val): return val @pd.validator("coeffs", always=True) + @skip_if_fields_missing(["allow_gain"]) def _passivity_validation(cls, val, values): """Assert passive medium if `allow_gain` is False.""" allow_gain = values.get("allow_gain") @@ -3083,7 +3097,7 @@ def _eps_inf_positive(cls, val): return val @pd.validator("coeffs", always=True) - @check_previous_fields_validation(["eps_inf"]) + @skip_if_fields_missing(["eps_inf"]) def _coeffs_correct_shape_and_sign(cls, val, values): """coeffs must have consistent shape and sign.""" expected_coords = values["eps_inf"].coords @@ -3150,6 +3164,7 @@ class Debye(DispersiveMedium): ) @pd.validator("coeffs", always=True) + @skip_if_fields_missing(["allow_gain"]) def _passivity_validation(cls, val, values): """Assert passive medium if `allow_gain` is False.""" if values.get("allow_gain"): @@ -3244,7 +3259,7 @@ def _eps_inf_positive(cls, val): return val @pd.validator("coeffs", always=True) - @check_previous_fields_validation(["eps_inf"]) + @skip_if_fields_missing(["eps_inf"]) def _coeffs_correct_shape(cls, val, values): """coeffs must have consistent shape.""" expected_coords = values["eps_inf"].coords @@ -3259,6 +3274,7 @@ def _coeffs_correct_shape(cls, val, values): return val @pd.validator("coeffs", always=True) + @skip_if_fields_missing(["allow_gain"]) def _passivity_validation(cls, val, values): """Assert passive medium if `allow_gain` is False.""" allow_gain = values.get("allow_gain") @@ -3521,6 +3537,7 @@ def permittivity_spd_and_ge_one(cls, val): return val @pd.validator("conductivity", always=True) + @skip_if_fields_missing(["permittivity"]) def conductivity_commutes(cls, val, values): """Check that the symmetric part of conductivity tensor commutes with permittivity tensor (that is, simultaneously diagonalizable). @@ -3538,6 +3555,7 @@ def conductivity_commutes(cls, val, values): return val @pd.validator("conductivity", always=True) + @skip_if_fields_missing(["allow_gain"]) def _passivity_validation(cls, val, values): """Assert passive medium if `allow_gain` is False.""" if values.get("allow_gain"): diff --git a/tidy3d/components/mode.py b/tidy3d/components/mode.py index 8afabe50d9..d526d079fa 100644 --- a/tidy3d/components/mode.py +++ b/tidy3d/components/mode.py @@ -7,7 +7,7 @@ import numpy as np from ..constants import MICROMETER, RADIAN, GLANCING_CUTOFF, fp_eps -from .base import Tidy3dBaseModel +from .base import Tidy3dBaseModel, skip_if_fields_missing from .types import Axis2D, Literal, TrackFreq from ..log import log from ..exceptions import SetupError, ValidationError @@ -115,6 +115,7 @@ class ModeSpec(Tidy3dBaseModel): ) @pd.validator("bend_axis", always=True) + @skip_if_fields_missing(["bend_radius"]) def bend_axis_given(cls, val, values): """Check that ``bend_axis`` is provided if ``bend_radius`` is not ``None``""" if val is None and values.get("bend_radius") is not None: diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index b33f9dca14..516e3c1c6d 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -8,7 +8,7 @@ from .types import Ax, EMField, ArrayFloat1D, FreqArray, FreqBound, Bound, Size from .types import Literal, Direction, Coordinate, Axis, ObsGridArray, BoxSurface from .validators import assert_plane, validate_freqs_not_empty, validate_freqs_min -from .base import cached_property, Tidy3dBaseModel +from .base import cached_property, Tidy3dBaseModel, skip_if_fields_missing from .mode import ModeSpec from .apodization import ApodizationSpec from .medium import MediumType @@ -137,6 +137,7 @@ class TimeMonitor(Monitor, ABC): ) @pydantic.validator("interval", always=True) + @skip_if_fields_missing(["start", "stop"]) def _warn_interval_default(cls, val, values): """If all defaults used for time sampler, warn and set ``interval=1`` internally.""" @@ -163,6 +164,7 @@ def _warn_interval_default(cls, val, values): return val @pydantic.validator("stop", always=True, allow_reuse=True) + @skip_if_fields_missing(["start"]) def stop_greater_than_start(cls, val, values): """Ensure sure stop is greater than or equal to start.""" start = values.get("start") @@ -700,6 +702,7 @@ class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): ) @pydantic.validator("window_size", always=True) + @skip_if_fields_missing(["size", "name"]) def window_size_for_surface(cls, val, values): """Ensures that windowing is applied for surface monitors only.""" size = values.get("size") @@ -714,6 +717,7 @@ def window_size_for_surface(cls, val, values): return val @pydantic.validator("window_size", always=True) + @skip_if_fields_missing(["name"]) def window_size_leq_one(cls, val, values): """Ensures that each component of the window size is less than or equal to 1.""" name = values.get("name") diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index ce0e32686c..b818825c50 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -9,7 +9,7 @@ import matplotlib as mpl from .base import cached_property -from .base import check_previous_fields_validation +from .base import skip_if_fields_missing from .validators import assert_objects_in_sim_bounds from .validators import validate_mode_objects_symmetry from .geometry.base import Geometry, Box @@ -225,6 +225,7 @@ def _update_simulation(cls, values): return updater.update_to_current() @pydantic.validator("grid_spec", always=True) + @skip_if_fields_missing(["sources"]) def _validate_auto_grid_wavelength(cls, val, values): """Check that wavelength can be defined if there is auto grid spec.""" if val.wavelength is None and val.auto_grid_used: @@ -243,6 +244,7 @@ def _validate_auto_grid_wavelength(cls, val, values): # _plane_waves_in_homo = validate_plane_wave_intersections() @pydantic.validator("boundary_spec", always=True) + @skip_if_fields_missing(["symmetry"]) def bloch_with_symmetry(cls, val, values): """Error if a Bloch boundary is applied with symmetry""" boundaries = val.to_list @@ -256,6 +258,7 @@ def bloch_with_symmetry(cls, val, values): return val @pydantic.validator("boundary_spec", always=True) + @skip_if_fields_missing(["medium", "size", "structures", "sources"]) def plane_wave_boundaries(cls, val, values): """Error if there are plane wave sources incompatible with boundary conditions.""" boundaries = val.to_list @@ -297,6 +300,7 @@ def plane_wave_boundaries(cls, val, values): return val @pydantic.validator("boundary_spec", always=True) + @skip_if_fields_missing(["medium", "center", "size", "structures", "sources"]) def tfsf_boundaries(cls, val, values): """Error if the boundary conditions are compatible with TFSF sources, if any.""" boundaries = val.to_list @@ -376,6 +380,7 @@ def tfsf_boundaries(cls, val, values): return val @pydantic.validator("sources", always=True) + @skip_if_fields_missing(["symmetry"]) def tfsf_with_symmetry(cls, val, values): """Error if a TFSF source is applied with symmetry""" symmetry = values.get("symmetry") @@ -385,6 +390,7 @@ def tfsf_with_symmetry(cls, val, values): return val @pydantic.validator("boundary_spec", always=True) + @skip_if_fields_missing(["size", "symmetry"]) def boundaries_for_zero_dims(cls, val, values): """Error if absorbing boundaries, unmatching pec/pmc, or symmetry is used along a zero dimension.""" boundaries = val.to_list @@ -462,6 +468,7 @@ def _validate_2d_geometry_has_2d_medium(cls, val, values): return val @pydantic.validator("boundary_spec", always=True) + @skip_if_fields_missing(["sources", "center", "size", "structures"]) def _structures_not_close_pml(cls, val, values): """Warn if any structures lie at the simulation boundaries.""" @@ -521,6 +528,7 @@ def warn(istruct, side): return val @pydantic.validator("monitors", always=True) + @skip_if_fields_missing(["medium", "structures"]) def _warn_monitor_mediums_frequency_range(cls, val, values): """Warn user if any DFT monitors have frequencies outside of medium frequency range.""" @@ -572,7 +580,7 @@ def _warn_monitor_mediums_frequency_range(cls, val, values): return val @pydantic.validator("monitors", always=True) - @check_previous_fields_validation(["sources"]) + @skip_if_fields_missing(["sources"]) def _warn_monitor_simulation_frequency_range(cls, val, values): """Warn if any DFT monitors have frequencies outside of the simulation frequency range.""" @@ -603,6 +611,7 @@ def _warn_monitor_simulation_frequency_range(cls, val, values): return val @pydantic.validator("monitors", always=True) + @skip_if_fields_missing(["boundary_spec"]) def diffraction_monitor_boundaries(cls, val, values): """If any :class:`.DiffractionMonitor` exists, ensure boundary conditions in the transverse directions are periodic or Bloch.""" @@ -627,6 +636,7 @@ def diffraction_monitor_boundaries(cls, val, values): return val @pydantic.validator("monitors", always=True) + @skip_if_fields_missing(["medium", "center", "size", "structures"]) def _projection_monitors_homogeneous(cls, val, values): """Error if any field projection monitor is not in a homogeneous region.""" @@ -721,6 +731,7 @@ def _projection_direction(cls, val, values): return val @pydantic.validator("monitors", always=True) + @skip_if_fields_missing(["size"]) def proj_distance_for_approx(cls, val, values): """Warn if projection distance for projection monitors is not large compared to monitor or, simulation size, yet far_field_approx is True.""" @@ -749,6 +760,7 @@ def proj_distance_for_approx(cls, val, values): return val @pydantic.validator("monitors", always=True) + @skip_if_fields_missing(["center", "size"]) def _integration_surfaces_in_bounds(cls, val, values): """Error if any of the integration surfaces are outside of the simulation domain.""" @@ -769,6 +781,7 @@ def _integration_surfaces_in_bounds(cls, val, values): return val @pydantic.validator("monitors", always=True) + @skip_if_fields_missing(["size"]) def _projection_monitors_distance(cls, val, values): """Warn if the projection distance is large for exact projections.""" @@ -801,6 +814,7 @@ def _projection_monitors_distance(cls, val, values): return val @pydantic.validator("monitors", always=True) + @skip_if_fields_missing(["medium", "structures"]) def diffraction_monitor_medium(cls, val, values): """If any :class:`.DiffractionMonitor` exists, ensure is does not lie in a lossy medium.""" monitors = val @@ -816,6 +830,7 @@ def diffraction_monitor_medium(cls, val, values): return val @pydantic.validator("grid_spec", always=True) + @skip_if_fields_missing(["medium", "sources", "structures"]) def _warn_grid_size_too_small(cls, val, values): """Warn user if any grid size is too large compared to minimum wavelength in material.""" @@ -877,6 +892,7 @@ def _warn_grid_size_too_small(cls, val, values): return val @pydantic.validator("sources", always=True) + @skip_if_fields_missing(["medium", "center", "size", "structures"]) def _source_homogeneous_isotropic(cls, val, values): """Error if a plane wave or gaussian beam source is not in a homogeneous and isotropic region. @@ -919,6 +935,7 @@ def _source_homogeneous_isotropic(cls, val, values): return val @pydantic.validator("normalize_index", always=True) + @skip_if_fields_missing(["sources"]) def _check_normalize_index(cls, val, values): """Check validity of normalize index in context of simulation.sources.""" diff --git a/tidy3d/components/source.py b/tidy3d/components/source.py index 406e9dde82..695fd6de9a 100644 --- a/tidy3d/components/source.py +++ b/tidy3d/components/source.py @@ -9,12 +9,12 @@ import pydantic.v1 as pydantic import numpy as np -from .base import cached_property +from .base import cached_property, skip_if_fields_missing from .base_sim.source import AbstractSource from .time import AbstractTimeDependence from .types import Coordinate, Direction, Polarization, Ax, FreqBound from .types import ArrayFloat1D, Axis, PlotVal, ArrayComplex1D, TYPE_TAG_STR -from .validators import assert_plane, assert_volumetric, get_value +from .validators import assert_plane, assert_volumetric from .validators import warn_if_dataset_none, assert_single_freq_in_range, _assert_min_freq from .data.dataset import FieldDataset, TimeDataset from .data.data_array import TimeDataArray @@ -679,11 +679,12 @@ class CustomFieldSource(FieldSource, PlanarSource): _field_dataset_single_freq = assert_single_freq_in_range("field_dataset") @pydantic.validator("field_dataset", always=True) + @skip_if_fields_missing(["size"]) def _tangential_component_defined(cls, val: FieldDataset, values: dict) -> FieldDataset: """Assert that at least one tangential field component is provided.""" if val is None: return val - size = get_value(key="size", values=values) + size = values.get("size") normal_axis = size.index(0.0) _, (cmp1, cmp2) = cls.pop_axis("xyz", axis=normal_axis) for field in "EH": diff --git a/tidy3d/components/structure.py b/tidy3d/components/structure.py index 6815db0544..7fb44bca95 100644 --- a/tidy3d/components/structure.py +++ b/tidy3d/components/structure.py @@ -3,7 +3,7 @@ import pydantic.v1 as pydantic import numpy as np -from .base import Tidy3dBaseModel +from .base import Tidy3dBaseModel, skip_if_fields_missing from .validators import validate_name_str from .geometry.utils import GeometryType, validate_no_transformed_polyslabs from .medium import MediumType, AbstractCustomMedium, Medium2D @@ -116,6 +116,7 @@ def eps_diagonal(self, frequency: float, coords: Coords) -> Tuple[complex, compl return self.medium.eps_diagonal(frequency=frequency) @pydantic.validator("medium", always=True) + @skip_if_fields_missing(["geometry"]) def _check_2d_geometry(cls, val, values): """Medium2D is only consistent with certain geometry types""" geom = values.get("geometry") diff --git a/tidy3d/components/time_modulation.py b/tidy3d/components/time_modulation.py index 3b2aa29fc0..cdfa69c7da 100644 --- a/tidy3d/components/time_modulation.py +++ b/tidy3d/components/time_modulation.py @@ -8,7 +8,7 @@ import pydantic.v1 as pd import numpy as np -from .base import Tidy3dBaseModel, cached_property +from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing from .types import InterpMethod from .time import AbstractTimeDependence from .data.data_array import SpatialDataArray @@ -231,6 +231,7 @@ class ModulationSpec(Tidy3dBaseModel): ) @pd.validator("conductivity", always=True) + @skip_if_fields_missing(["permittivity"]) def _same_modulation_frequency(cls, val, values): """Assert same time-modulation applied to permittivity and conductivity.""" permittivity = values.get("permittivity") diff --git a/tidy3d/components/validators.py b/tidy3d/components/validators.py index eb7066f3f1..9479265c19 100644 --- a/tidy3d/components/validators.py +++ b/tidy3d/components/validators.py @@ -1,5 +1,4 @@ """ Defines various validation functions that get used to ensure inputs are legit """ -from typing import Any import pydantic.v1 as pydantic import numpy as np @@ -7,7 +6,7 @@ from .geometry.base import Box from ..exceptions import ValidationError, SetupError from .data.dataset import Dataset, FieldDataset -from .base import DATA_ARRAY_MAP +from .base import DATA_ARRAY_MAP, skip_if_fields_missing from .types import Tuple from ..log import log @@ -47,14 +46,6 @@ MIN_FREQUENCY = 1e5 -def get_value(key: str, values: dict) -> Any: - """Grab value from values dictionary. If not present, raise an error before continuing.""" - val = values.get(key) - if val is None: - raise ValidationError(f"value {key} not defined, must be present to validate.") - return val - - def assert_plane(): """makes sure a field's `size` attribute has exactly 1 zero""" @@ -118,6 +109,7 @@ def validate_mode_objects_symmetry(field_name: str): obj_type = "ModeSource" if field_name == "sources" else "ModeMonitor" @pydantic.validator(field_name, allow_reuse=True, always=True) + @skip_if_fields_missing(["center", "symmetry"]) def check_symmetry(cls, val, values): """check for intersection of each structure with simulation bounds.""" sim_center = values.get("center") @@ -161,6 +153,7 @@ def assert_objects_in_sim_bounds(field_name: str, error: bool = True): """Makes sure all objects in field are at least partially inside of simulation bounds.""" @pydantic.validator(field_name, allow_reuse=True, always=True) + @skip_if_fields_missing(["center", "size"]) def objects_in_sim_bounds(cls, val, values): """check for intersection of each structure with simulation bounds.""" sim_center = values.get("center") @@ -203,6 +196,7 @@ def required_if_symmetry_present(field_name: str): """Make a field required (not None) if any non-zero symmetry eigenvalue is present.""" @pydantic.validator(field_name, allow_reuse=True, always=True) + @skip_if_fields_missing(["symmetry"]) def _make_required(cls, val, values): """Ensure val is not None if the symmetry is non-zero along any dimension.""" symmetry = values.get("symmetry") @@ -232,11 +226,12 @@ def assert_single_freq_in_range(field_name: str): """Assert only one frequency supplied in source and it's in source time range.""" @pydantic.validator(field_name, always=True, allow_reuse=True) + @skip_if_fields_missing(["source_time"]) def _single_frequency_in_range(cls, val: FieldDataset, values: dict) -> FieldDataset: """Assert only one frequency supplied and it's in source time range.""" if val is None: return val - source_time = get_value(key="source_time", values=values) + source_time = values.get("source_time") fmin, fmax = source_time.frequency_range() for name, scalar_field in val.field_components.items(): freqs = scalar_field.f diff --git a/tidy3d/plugins/adjoint/components/data/data_array.py b/tidy3d/plugins/adjoint/components/data/data_array.py index 3bac553252..280b2e61f5 100644 --- a/tidy3d/plugins/adjoint/components/data/data_array.py +++ b/tidy3d/plugins/adjoint/components/data/data_array.py @@ -11,7 +11,7 @@ from jax.tree_util import register_pytree_node_class import xarray as xr -from .....components.base import Tidy3dBaseModel, cached_property +from .....components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing from .....exceptions import DataError, Tidy3dKeyError, AdjointError @@ -48,6 +48,7 @@ def _convert_values_to_np(cls, val): return val @pd.validator("coords", always=True) + @skip_if_fields_missing(["values"]) def _coords_match_values(cls, val, values): """Make sure the coordinate dimensions and shapes match the values data.""" diff --git a/tidy3d/plugins/adjoint/components/simulation.py b/tidy3d/plugins/adjoint/components/simulation.py index 2b8f422cf3..bdcedbf3de 100644 --- a/tidy3d/plugins/adjoint/components/simulation.py +++ b/tidy3d/plugins/adjoint/components/simulation.py @@ -10,7 +10,7 @@ from jax.tree_util import register_pytree_node_class from ....log import log -from ....components.base import cached_property, Tidy3dBaseModel +from ....components.base import cached_property, Tidy3dBaseModel, skip_if_fields_missing from ....components.monitor import FieldMonitor, PermittivityMonitor from ....components.monitor import ModeMonitor, DiffractionMonitor, Monitor from ....components.simulation import Simulation @@ -189,6 +189,7 @@ def _restrict_input_structures(cls, val): return val @pd.validator("input_structures", always=True) + @skip_if_fields_missing(["structures"]) def _warn_overlap(cls, val, values): """Print appropriate warning if structures intersect in ways that cause gradient error.""" diff --git a/tidy3d/plugins/dispersion/fit.py b/tidy3d/plugins/dispersion/fit.py index cb25417418..2a0b79e73d 100644 --- a/tidy3d/plugins/dispersion/fit.py +++ b/tidy3d/plugins/dispersion/fit.py @@ -13,7 +13,7 @@ from ...log import log, get_logging_console from ...components.base import Tidy3dBaseModel, cached_property -from ...components.base import check_previous_fields_validation +from ...components.base import skip_if_fields_missing from ...components.medium import PoleResidue, AbstractMedium from ...components.viz import add_ax_if_none from ...components.types import Ax, ArrayFloat1D @@ -61,7 +61,7 @@ def _setup_wvl(cls, val): return val @validator("n_data", always=True) - @check_previous_fields_validation(["wvl_um"]) + @skip_if_fields_missing(["wvl_um"]) def _ndata_length_match_wvl(cls, val, values): """Validate n_data""" @@ -70,7 +70,7 @@ def _ndata_length_match_wvl(cls, val, values): return val @validator("k_data", always=True) - @check_previous_fields_validation(["wvl_um"]) + @skip_if_fields_missing(["wvl_um"]) def _kdata_setup_and_length_match(cls, val, values): """Validate the length of k_data, or setup k if it's None.""" diff --git a/tidy3d/plugins/mode/mode_solver.py b/tidy3d/plugins/mode/mode_solver.py index fc4c003d8d..5a5f116350 100644 --- a/tidy3d/plugins/mode/mode_solver.py +++ b/tidy3d/plugins/mode/mode_solver.py @@ -11,7 +11,7 @@ import xarray as xr from ...log import log -from ...components.base import Tidy3dBaseModel, cached_property, check_previous_fields_validation +from ...components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing from ...components.geometry.base import Box from ...components.simulation import Simulation from ...components.grid.grid import Grid @@ -100,7 +100,7 @@ def is_plane(cls, val): _freqs_lower_bound = validate_freqs_min() @pydantic.validator("plane", always=True) - @check_previous_fields_validation(["simulation"]) + @skip_if_fields_missing(["simulation"]) def plane_in_sim_bounds(cls, val, values): """Check that the plane is at least partially inside the simulation bounds.""" sim_center = values.get("simulation").center