diff --git a/tests/test_components/test_heat.py b/tests/test_components/test_heat.py index d1792e2332..9880524944 100644 --- a/tests/test_components/test_heat.py +++ b/tests/test_components/test_heat.py @@ -26,7 +26,6 @@ from tidy3d import TemperatureMonitor from tidy3d import TemperatureData -from tidy3d.exceptions import DataError from ..utils import STL_GEO, assert_log_level, log_capture @@ -349,7 +348,7 @@ def test_sim_data(): _ = heat_sim_data.plot_field("test", z=0) plt.close() - with pytest.raises(DataError): + with pytest.raises(KeyError): _ = heat_sim_data.plot_field("test3", x=0) with pytest.raises(pd.ValidationError): diff --git a/tests/test_components/test_scene.py b/tests/test_components/test_scene.py index 3c6c5939b6..3310c9ff6c 100644 --- a/tests/test_components/test_scene.py +++ b/tests/test_components/test_scene.py @@ -5,7 +5,7 @@ import numpy as np import tidy3d as td -from tidy3d.components.simulation import MAX_NUM_MEDIUMS +from tidy3d.components.scene import MAX_NUM_MEDIUMS, MAX_GEOMETRY_COUNT from ..utils import assert_log_level, log_capture, SIM_FULL SCENE = td.Scene() @@ -187,32 +187,6 @@ def test_names_unique(): ) -def test_allow_gain(): - """Test if simulation allows gain.""" - - medium = td.Medium(permittivity=2.0) - medium_gain = td.Medium(permittivity=2.0, allow_gain=True) - medium_ani = td.AnisotropicMedium(xx=medium, yy=medium, zz=medium) - medium_gain_ani = td.AnisotropicMedium(xx=medium, yy=medium_gain, zz=medium) - - # Test simulation medium - scene = td.Scene(medium=medium) - assert not scene.allow_gain - scene = scene.updated_copy(medium=medium_gain) - assert scene.allow_gain - - # Test structure with anisotropic gain medium - struct = td.Structure(geometry=td.Box(center=(0, 0, 0), size=(1, 1, 1)), medium=medium_ani) - struct_gain = struct.updated_copy(medium=medium_gain_ani) - scene = td.Scene( - medium=medium, - structures=[struct], - ) - assert not scene.allow_gain - scene = scene.updated_copy(structures=[struct_gain]) - assert scene.allow_gain - - def test_perturbed_mediums_copy(): # Non-dispersive @@ -272,3 +246,36 @@ def test_perturbed_mediums_copy(): assert isinstance(new_scene.medium, td.CustomMedium) assert isinstance(new_scene.structures[0].medium, td.CustomPoleResidue) + + +def test_max_geometry_validation(): + too_many = [td.Box(size=(1, 1, 1)) for _ in range(MAX_GEOMETRY_COUNT + 1)] + + fine = [ + td.Structure( + geometry=td.ClipOperation( + operation="union", + geometry_a=td.Box(size=(1, 1, 1)), + geometry_b=td.GeometryGroup(geometries=too_many), + ), + medium=td.Medium(permittivity=2.0), + ), + td.Structure( + geometry=td.GeometryGroup(geometries=too_many), + medium=td.Medium(permittivity=2.0), + ), + ] + _ = td.Scene(structures=fine) + + not_fine = [ + td.Structure( + geometry=td.ClipOperation( + operation="difference", + geometry_a=td.Box(size=(1, 1, 1)), + geometry_b=td.GeometryGroup(geometries=too_many), + ), + medium=td.Medium(permittivity=2.0), + ), + ] + with pytest.raises(pd.ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "): + _ = td.Scene(structures=not_fine) diff --git a/tests/test_components/test_simulation.py b/tests/test_components/test_simulation.py index d591f6f81b..ad6e410b04 100644 --- a/tests/test_components/test_simulation.py +++ b/tests/test_components/test_simulation.py @@ -8,7 +8,7 @@ import tidy3d as td from tidy3d.exceptions import SetupError, ValidationError, Tidy3dKeyError from tidy3d.components import simulation -from tidy3d.components.simulation import MAX_NUM_MEDIUMS +from tidy3d.components.scene import MAX_NUM_MEDIUMS, MAX_GEOMETRY_COUNT from ..utils import assert_log_level, SIM_FULL, log_capture, run_emulated from tidy3d.constants import LARGE_NUMBER @@ -492,7 +492,6 @@ def test_validate_zero_dim_boundaries(log_capture): def test_validate_components_none(): assert SIM._structures_not_at_edges(val=None, values=SIM.dict()) is None - assert SIM._validate_num_mediums(val=None) is None assert SIM._warn_monitor_mediums_frequency_range(val=None, values=SIM.dict()) is None assert SIM._warn_monitor_simulation_frequency_range(val=None, values=SIM.dict()) is None assert SIM._warn_grid_size_too_small(val=None, values=SIM.dict()) is None @@ -537,7 +536,7 @@ def test_validate_mnt_size(monkeypatch, log_capture): def test_max_geometry_validation(): gs = td.GridSpec(wavelength=1.0) - too_many = [td.Box(size=(1, 1, 1)) for _ in range(simulation.MAX_GEOMETRY_COUNT + 1)] + too_many = [td.Box(size=(1, 1, 1)) for _ in range(MAX_GEOMETRY_COUNT + 1)] fine = [ td.Structure( @@ -565,7 +564,7 @@ def test_max_geometry_validation(): medium=td.Medium(permittivity=2.0), ), ] - with pytest.raises(pydantic.ValidationError, match=f" {simulation.MAX_GEOMETRY_COUNT + 2} "): + with pytest.raises(pydantic.ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "): _ = td.Simulation(size=(1, 1, 1), run_time=1, grid_spec=gs, structures=not_fine) @@ -581,7 +580,6 @@ def test_plot_structure(): def test_plot_eps(): ax = SIM_FULL.plot_eps(x=0) - SIM_FULL._add_cbar(eps_min=1, eps_max=2, ax=ax) plt.close() @@ -724,26 +722,6 @@ def test_discretize_non_intersect(log_capture): assert_log_level(log_capture, "ERROR") -def test_filter_structures(): - s1 = td.Structure(geometry=td.Box(size=(1, 1, 1)), medium=SIM.medium) - s2 = td.Structure(geometry=td.Box(size=(1, 1, 1), center=(1, 1, 1)), medium=SIM.medium) - plane = td.Box(center=(0, 0, 1.5), size=(td.inf, td.inf, 0)) - SIM._filter_structures_plane(structures=[s1, s2], plane=plane) - - -def test_get_structure_plot_params(): - pp = SIM_FULL._get_structure_plot_params(mat_index=0, medium=SIM_FULL.medium) - assert pp.facecolor == "white" - pp = SIM_FULL._get_structure_plot_params(mat_index=1, medium=td.PEC) - assert pp.facecolor == "gold" - pp = SIM_FULL._get_structure_eps_plot_params( - medium=SIM_FULL.medium, freq=1, eps_min=1, eps_max=2 - ) - assert float(pp.facecolor) == 1.0 - pp = SIM_FULL._get_structure_eps_plot_params(medium=td.PEC, freq=1, eps_min=1, eps_max=2) - assert pp.facecolor == "gold" - - def test_warn_sim_background_medium_freq_range(log_capture): _ = SIM.copy( update=dict( diff --git a/tidy3d/components/base_sim/data/sim_data.py b/tidy3d/components/base_sim/data/sim_data.py index 3338d73781..be2172d955 100644 --- a/tidy3d/components/base_sim/data/sim_data.py +++ b/tidy3d/components/base_sim/data/sim_data.py @@ -41,8 +41,6 @@ class AbstractSimulationData(Tidy3dBaseModel, ABC): def __getitem__(self, monitor_name: str) -> AbstractMonitorData: """Get a :class:`.AbstractMonitorData` by name. Apply symmetry if applicable.""" - if monitor_name not in self.monitor_data: - raise DataError(f"'{self.type}' does not contain data for monitor '{monitor_name}'.") monitor_data = self.monitor_data[monitor_name] return monitor_data.symmetry_expanded_copy diff --git a/tidy3d/components/base_sim/monitor.py b/tidy3d/components/base_sim/monitor.py index 65cc1ddc03..ff13887fd3 100644 --- a/tidy3d/components/base_sim/monitor.py +++ b/tidy3d/components/base_sim/monitor.py @@ -6,14 +6,10 @@ import numpy as np from ..types import ArrayFloat1D, Numpy -from ..types import Direction, Axis, BoxSurface +from ..types import Axis from ..geometry.base import Box -from ..validators import assert_plane from ..base import cached_property from ..viz import PlotParams, plot_params_monitor -from ...constants import SECOND -from ...exceptions import SetupError -from ...log import log class AbstractMonitor(Box, ABC): @@ -94,140 +90,3 @@ def downsampled_num_cells(self, num_cells: Tuple[int, int, int]) -> Tuple[int, i """ arrs = [np.arange(ncells) for ncells in num_cells] return tuple((self.downsample(arr, axis=dim).size for dim, arr in enumerate(arrs))) - - -class AbstractTimeMonitor(AbstractMonitor, ABC): - """Abstract base class for transient monitors.""" - - start: pd.NonNegativeFloat = pd.Field( - 0.0, - title="Start time", - description="Time at which to start monitor recording.", - units=SECOND, - ) - - stop: pd.NonNegativeFloat = pd.Field( - None, - title="Stop time", - description="Time at which to stop monitor recording. " - "If not specified, record until end of simulation.", - units=SECOND, - ) - - interval: pd.PositiveInt = pd.Field( - 1, - title="Time interval", - description="Number of time step intervals between monitor recordings.", - ) - - @pd.validator("stop", always=True, allow_reuse=True) - def stop_greater_than_start(cls, val, values): - """Ensure sure stop is greater than or equal to start.""" - start = values.get("start") - if val and val < start: - raise SetupError("Monitor start time is greater than stop time.") - return val - - def time_inds(self, tmesh: ArrayFloat1D) -> Tuple[int, int]: - """Compute the starting and stopping index of the monitor in a given discrete time mesh.""" - - tmesh = np.array(tmesh) - tind_beg, tind_end = (0, 0) - - if tmesh.size == 0: - return (tind_beg, tind_end) - - # If monitor.stop is None, record until the end - t_stop = self.stop - if t_stop is None: - tind_end = int(tmesh.size) - t_stop = tmesh[-1] - else: - tend = np.nonzero(tmesh <= t_stop)[0] - if tend.size > 0: - tind_end = int(tend[-1] + 1) - - # Step to compare to in order to handle t_start = t_stop - dt = 1e-20 if np.array(tmesh).size < 2 else tmesh[1] - tmesh[0] - # If equal start and stopping time, record one time step - if np.abs(self.start - t_stop) < dt and self.start <= tmesh[-1]: - tind_beg = max(tind_end - 1, 0) - else: - tbeg = np.nonzero(tmesh[:tind_end] >= self.start)[0] - tind_beg = tbeg[0] if tbeg.size > 0 else tind_end - return (tind_beg, tind_end) - - def num_steps(self, tmesh: ArrayFloat1D) -> int: - """Compute number of time steps for a time monitor.""" - - tind_beg, tind_end = self.time_inds(tmesh) - return int((tind_end - tind_beg) / self.interval) - - -class AbstractPlanarMonitor(AbstractMonitor, ABC): - """:class:`AbstractMonitor` that has a planar geometry.""" - - _plane_validator = assert_plane() - - @cached_property - def normal_axis(self) -> Axis: - """Axis normal to the monitor's plane.""" - return self.size.index(0.0) - - -class AbstractSurfaceIntegrationMonitor(AbstractMonitor, ABC): - """Abstract class for monitors that perform surface integrals during the solver run.""" - - normal_dir: Direction = pd.Field( - None, - title="Normal vector orientation", - description="Direction of the surface monitor's normal vector w.r.t. " - "the positive x, y or z unit vectors. Must be one of ``'+'`` or ``'-'``. " - "Applies to surface monitors only, and defaults to ``'+'`` if not provided.", - ) - - exclude_surfaces: Tuple[BoxSurface, ...] = pd.Field( - None, - title="Excluded surfaces", - description="Surfaces to exclude in the integration, if a volume monitor.", - ) - - @property - def integration_surfaces(self): - """Surfaces of the monitor where fields will be recorded for subsequent integration.""" - if self.size.count(0.0) == 0: - return self.surfaces_with_exclusion(**self.dict()) - return [self] - - @pd.root_validator(skip_on_failure=True) - def normal_dir_exists_for_surface(cls, values): - """If the monitor is a surface, set default ``normal_dir`` if not provided. - If the monitor is a box, warn that ``normal_dir`` is relevant only for surfaces.""" - normal_dir = values.get("normal_dir") - name = values.get("name") - size = values.get("size") - if size.count(0.0) != 1: - if normal_dir is not None: - log.warning( - "The ``normal_dir`` field is relevant only for surface monitors " - f"and will be ignored for monitor {name}, which is a box." - ) - else: - if normal_dir is None: - values["normal_dir"] = "+" - return values - - @pd.root_validator(skip_on_failure=True) - def check_excluded_surfaces(cls, values): - """Error if ``exclude_surfaces`` is provided for a surface monitor.""" - exclude_surfaces = values.get("exclude_surfaces") - if exclude_surfaces is None: - return values - name = values.get("name") - size = values.get("size") - if size.count(0.0) > 0: - raise SetupError( - f"Can't specify ``exclude_surfaces`` for surface monitor {name}; " - "valid for box monitors only." - ) - return values diff --git a/tidy3d/components/base_sim/simulation.py b/tidy3d/components/base_sim/simulation.py index a08b55efb6..a8631cf823 100644 --- a/tidy3d/components/base_sim/simulation.py +++ b/tidy3d/components/base_sim/simulation.py @@ -7,7 +7,7 @@ import pydantic.v1 as pd -from .monitor import AbstractSurfaceIntegrationMonitor +from .monitor import AbstractMonitor from ..base import cached_property from ..validators import assert_unique_names, assert_objects_in_sim_bounds @@ -22,7 +22,7 @@ from ..viz import PlotParams, plot_params_symmetry from ...version import __version__ -from ...exceptions import Tidy3dKeyError, SetupError +from ...exceptions import Tidy3dKeyError from ...log import log @@ -91,30 +91,11 @@ class AbstractSimulation(Box, ABC): # make sure all names are unique _unique_monitor_names = assert_unique_names("monitors") _unique_structure_names = assert_unique_names("structures") + _unique_source_names = assert_unique_names("sources") _monitors_in_bounds = assert_objects_in_sim_bounds("monitors") _structures_in_bounds = assert_objects_in_sim_bounds("structures", error=False) - @pd.validator("monitors", always=True) - def _integration_surfaces_in_bounds(cls, val, values): - """Error if any of the integration surfaces are outside of the simulation domain.""" - - if val is None: - return val - - sim_center = values.get("center") - sim_size = values.get("size") - sim_box = Box(size=sim_size, center=sim_center) - - for mnt in (mnt for mnt in val if isinstance(mnt, AbstractSurfaceIntegrationMonitor)): - if not any(sim_box.intersects(surf) for surf in mnt.integration_surfaces): - raise SetupError( - f"All integration surfaces of monitor '{mnt.name}' are outside of the " - "simulation bounds." - ) - - return val - @pd.validator("structures", always=True) def _structures_not_at_edges(cls, val, values): """Warn if any structures lie at the simulation boundaries.""" @@ -135,14 +116,21 @@ def _structures_not_at_edges(cls, val, values): if isclose(sim_val, struct_val): consolidated_logger.warning( - f"Structure at structures[{istruct}] has bounds that extend exactly to " - "simulation edges. This can cause unexpected behavior. " + f"Structure at 'structures[{istruct}]' has bounds that extend exactly " + "to simulation edges. This can cause unexpected behavior. " "If intending to extend the structure to infinity along one dimension, " - "use td.inf as a size variable instead to make this explicit." + "use td.inf as a size variable instead to make this explicit.", + custom_loc=["structures", istruct], ) return val + """ Post-init validators """ + + def _post_init_validators(self) -> None: + """Call validators taking z`self` that get run after init.""" + _ = self.scene + """ Accounting """ @cached_property @@ -151,7 +139,7 @@ def scene(self) -> Scene: return Scene(medium=self.medium, structures=self.structures) - def get_monitor_by_name(self, name: str): # -> Monitor: + def get_monitor_by_name(self, name: str) -> AbstractMonitor: """Return monitor named 'name'.""" for monitor in self.monitors: if monitor.name == name: @@ -271,7 +259,7 @@ def plot_sources( matplotlib.axes._subplots.Axes The supplied or created matplotlib axes. """ - bounds = self.simulation_bounds + bounds = self.bounds for source in self.sources: ax = source.plot(x=x, y=y, z=z, alpha=alpha, ax=ax, sim_bounds=bounds) ax = Scene._set_plot_bounds( @@ -315,7 +303,7 @@ def plot_monitors( matplotlib.axes._subplots.Axes The supplied or created matplotlib axes. """ - bounds = self.simulation_bounds + bounds = self.bounds for monitor in self.monitors: ax = monitor.plot(x=x, y=y, z=z, alpha=alpha, ax=ax, sim_bounds=bounds) ax = Scene._set_plot_bounds( diff --git a/tidy3d/components/base_sim/source.py b/tidy3d/components/base_sim/source.py new file mode 100644 index 0000000000..d2089523b4 --- /dev/null +++ b/tidy3d/components/base_sim/source.py @@ -0,0 +1,21 @@ +"""Abstract base for classes that define simulation sources.""" +from __future__ import annotations +from abc import ABC, abstractmethod +import pydantic.v1 as pydantic + +from ..base import Tidy3dBaseModel + +from ..validators import validate_name_str +from ..viz import PlotParams + + +class AbstractSource(Tidy3dBaseModel, ABC): + """Abstract base class for all sources.""" + + name: str = pydantic.Field(None, title="Name", description="Optional name for the source.") + + @abstractmethod + def plot_params(self) -> PlotParams: + """Default parameters for plotting a Source object.""" + + _name_validator = validate_name_str() diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index eb0b0a95aa..71e9662738 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -35,11 +35,13 @@ from ...constants import ETA_0, C_0, MICROMETER from ...log import log +from ..base_sim.data.monitor_data import AbstractMonitorData + Coords1D = ArrayFloat1D -class MonitorData(Dataset, ABC): +class MonitorData(AbstractMonitorData, ABC): """Abstract base class of objects that store data pertaining to a single :class:`.monitor`.""" monitor: MonitorType = pd.Field( @@ -54,11 +56,6 @@ def symmetry_expanded(self) -> MonitorData: """Return self with symmetry applied.""" return self - @property - def symmetry_expanded_copy(self) -> MonitorData: - """Return copy of self with symmetry applied.""" - return self.copy() - def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> Dataset: """Return copy of self after normalization is applied using source spectrum function.""" return self.copy() diff --git a/tidy3d/components/data/sim_data.py b/tidy3d/components/data/sim_data.py index 74f67eb730..f52fe2d990 100644 --- a/tidy3d/components/data/sim_data.py +++ b/tidy3d/components/data/sim_data.py @@ -1,26 +1,27 @@ """ Simulation Level Data """ from __future__ import annotations -from typing import Dict, Callable, Tuple +from typing import Callable, Tuple import xarray as xr import pydantic.v1 as pd import numpy as np from .monitor_data import MonitorDataTypes, MonitorDataType, AbstractFieldData, FieldTimeData -from ..base import Tidy3dBaseModel from ..simulation import Simulation from ..boundary import BlochBoundary from ..source import TFSF from ..types import Ax, Axis, annotate_type, FieldVal, PlotScale, ColormapType from ..viz import equal_aspect, add_ax_if_none -from ...exceptions import DataError, Tidy3dKeyError, ValidationError +from ...exceptions import DataError, Tidy3dKeyError from ...log import log +from ..base_sim.data.sim_data import AbstractSimulationData + DATA_TYPE_MAP = {data.__fields__["monitor"].type_: data for data in MonitorDataTypes} -class SimulationData(Tidy3dBaseModel): +class SimulationData(AbstractSimulationData): """Stores data from a collection of :class:`.Monitor` objects in a :class:`.Simulation`. Example @@ -75,45 +76,12 @@ class SimulationData(Tidy3dBaseModel): "associated with the monitors of the original :class:`.Simulation`.", ) - log: str = pd.Field( - None, - title="Solver Log", - description="A string containing the log information from the simulation run.", - ) - diverged: bool = pd.Field( False, title="Diverged", description="A boolean flag denoting whether the simulation run diverged.", ) - def __getitem__(self, monitor_name: str) -> MonitorDataType: - """Get a :class:`.MonitorData` by name. Apply symmetry if applicable.""" - monitor_data = self.monitor_data[monitor_name] - return monitor_data.symmetry_expanded_copy - - @property - def monitor_data(self) -> Dict[str, MonitorDataType]: - """Dictionary mapping monitor name to its associated :class:`.MonitorData`.""" - return {monitor_data.monitor.name: monitor_data for monitor_data in self.data} - - @pd.validator("data", always=True) - def data_monitors_match_sim(cls, val, values): - """Ensure each MonitorData in ``.data`` corresponds to a monitor in ``.simulation``.""" - sim = values.get("simulation") - if sim is None: - raise ValidationError("Simulation.simulation failed validation, can't validate data.") - for mnt_data in val: - try: - monitor_name = mnt_data.monitor.name - sim.get_monitor_by_name(monitor_name) - except Tidy3dKeyError as exc: - raise DataError( - f"Data with monitor name {monitor_name} supplied " - "but not found in the Simulation" - ) from exc - return val - @property def final_decay_value(self) -> float: """Returns value of the field decay at the final time step.""" @@ -315,44 +283,6 @@ def get_poynting_vector(self, field_monitor_name: str) -> xr.Dataset: return xr.Dataset(poynting_components) - @staticmethod - def _field_component_value(field_component: xr.DataArray, val: FieldVal) -> xr.DataArray: - """return the desired value of a field component. - - Parameter - ---------- - field_component : xarray.DataArray - Field component from which to calculate the value. - val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] - Which part of the field to return. - - Returns - ------- - xarray.DataArray - Value extracted from the field component. - """ - if val == "real": - field_value = field_component.real - field_value.name = f"Re{{{field_component.name}}}" - - elif val == "imag": - field_value = field_component.imag - field_value.name = f"Im{{{field_component.name}}}" - - elif val == "abs": - field_value = np.abs(field_component) - field_value.name = f"|{field_component.name}|" - - elif val == "abs^2": - field_value = np.abs(field_component) ** 2 - field_value.name = f"|{field_component.name}|²" - - elif val == "phase": - field_value = np.arctan2(field_component.imag, field_component.real) - field_value.name = f"∠{field_component.name}" - - return field_value - def _get_scalar_field(self, field_monitor_name: str, field_name: str, val: FieldVal): """return ``xarray.DataArray`` of the scalar field of a given monitor at Yee cell centers. diff --git a/tidy3d/components/heat/source.py b/tidy3d/components/heat/source.py index b72f6993af..c477353036 100644 --- a/tidy3d/components/heat/source.py +++ b/tidy3d/components/heat/source.py @@ -8,14 +8,15 @@ from .viz import plot_params_heat_source -from ..base import Tidy3dBaseModel, cached_property +from ..base import cached_property +from ..base_sim.source import AbstractSource from ..data.data_array import TimeDataArray from ..viz import PlotParams from ...constants import VOLUMETRIC_HEAT_RATE -class HeatSource(ABC, Tidy3dBaseModel): +class HeatSource(AbstractSource, ABC): """Abstract heat source.""" structures: Tuple[str, ...] = pd.Field( diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index cfab76b6f5..9a71786a14 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -1,22 +1,23 @@ """Objects that define how data is recorded from simulation.""" -from abc import ABC, abstractmethod +from abc import ABC from typing import Union, Tuple import pydantic.v1 as pydantic import numpy as np -from .types import Ax, EMField, ArrayFloat1D, FreqArray, FreqBound, Numpy +from .types import Ax, EMField, ArrayFloat1D, FreqArray, FreqBound from .types import Literal, Direction, Coordinate, Axis, ObsGridArray, BoxSurface -from .geometry.base import Box from .validators import assert_plane from .base import cached_property, Tidy3dBaseModel from .mode import ModeSpec from .apodization import ApodizationSpec -from .viz import PlotParams, plot_params_monitor, ARROW_COLOR_MONITOR, ARROW_ALPHA +from .viz import ARROW_COLOR_MONITOR, ARROW_ALPHA from ..constants import HERTZ, SECOND, MICROMETER, RADIAN, inf from ..exceptions import SetupError, ValidationError from ..log import log +from .base_sim.monitor import AbstractMonitor + BYTES_REAL = 4 BYTES_COMPLEX = 8 @@ -24,16 +25,9 @@ WARN_NUM_MODES = 100 -class Monitor(Box, ABC): +class Monitor(AbstractMonitor): """Abstract base class for monitors.""" - name: str = pydantic.Field( - ..., - title="Name", - description="Unique name for monitor.", - min_length=1, - ) - interval_space: Tuple[Literal[1], Literal[1], Literal[1]] = pydantic.Field( (1, 1, 1), title="Spatial interval", @@ -51,75 +45,6 @@ class Monitor(Box, ABC): "and is hard-coded for other monitors depending on their specific function.", ) - @cached_property - def plot_params(self) -> PlotParams: - """Default parameters for plotting a Monitor object.""" - return plot_params_monitor - - @cached_property - def geometry(self) -> Box: - """:class:`Box` representation of monitor. - - Returns - ------- - :class:`Box` - Representation of the monitor geometry as a :class:`Box`. - """ - return Box(center=self.center, size=self.size) - - @abstractmethod - def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: - """Size of monitor storage given the number of points after discretization. - - Parameters - ---------- - num_cells : int - Number of grid cells within the monitor after discretization by a :class:`Simulation`. - tmesh : Array - The discretized time mesh of a :class:`Simulation`. - - Returns - ------- - int - Number of bytes to be stored in monitor. - """ - - def downsample(self, arr: Numpy, axis: Axis) -> Numpy: - """Downsample a 1D array making sure to keep the first and last entries, based on the - spatial interval defined for the ``axis``. - - Parameters - ---------- - arr : Numpy - A 1D array of arbitrary type. - axis : Axis - Axis for which to select the interval_space defined for the monitor. - - Returns - ------- - Numpy - Downsampled array. - """ - - size = len(arr) - interval = self.interval_space[axis] - # There should always be at least 3 indices for "surface" monitors. Also, if the - # size along this dim is already smaller than the interval, then don't downsample. - if size < 4 or (size - 1) <= interval: - return arr - # make sure the last index is always included - inds = np.arange(0, size, interval) - if inds[-1] != size - 1: - inds = np.append(inds, size - 1) - return arr[inds] - - def downsampled_num_cells(self, num_cells: Tuple[int, int, int]) -> Tuple[int, int, int]: - """Given a tuple of the number of cells spanned by the monitor along each dimension, - return the number of cells one would have after downsampling based on ``interval_space``. - """ - arrs = [np.arange(ncells) for ncells in num_cells] - return tuple((self.downsample(arr, axis=dim).size for dim, arr in enumerate(arrs))) - class FreqMonitor(Monitor, ABC): """:class:`Monitor` that records data in the frequency-domain.""" diff --git a/tidy3d/components/scene.py b/tidy3d/components/scene.py index ee167233f9..4d3a74d4d4 100644 --- a/tidy3d/components/scene.py +++ b/tidy3d/components/scene.py @@ -11,14 +11,14 @@ from .base import cached_property, Tidy3dBaseModel from .validators import assert_unique_names -from .geometry.base import Box -from .geometry.mesh import TriangleMesh +from .geometry.base import Box, GeometryGroup, ClipOperation +from .geometry.utils import flatten_groups, traverse_geometries from .types import Ax, Shapely, TYPE_TAG_STR, Bound, Size, Coordinate, InterpMethod from .medium import Medium, MediumType, PECMedium from .medium import AbstractCustomMedium, Medium2D, MediumType3D -from .medium import AnisotropicMedium, AbstractPerturbationMedium +from .medium import AbstractPerturbationMedium +from .grid.grid import Grid from .structure import Structure -from .data.dataset import Dataset from .data.data_array import SpatialDataArray from .viz import add_ax_if_none, equal_aspect from .grid.grid import Coords @@ -34,6 +34,9 @@ # maximum number of mediums supported MAX_NUM_MEDIUMS = 65530 +# maximum geometry count in a single structure +MAX_GEOMETRY_COUNT = 100 + class Scene(Tidy3dBaseModel): """Contains generic information about the geometry and medium properties common to all types of @@ -88,6 +91,29 @@ def _validate_num_mediums(cls, val): return val + @pd.validator("structures", always=True) + def _validate_num_geometries(cls, val): + """Error if too many geometries in a single structure.""" + + if val is None: + return val + + for i, structure in enumerate(val): + for geometry in flatten_groups(structure.geometry): + count = sum( + 1 + for g in traverse_geometries(geometry) + if not isinstance(g, (GeometryGroup, ClipOperation)) + ) + if count > MAX_GEOMETRY_COUNT: + raise SetupError( + f"Structure at 'structures[{i}]' has {count} geometries that cannot be " + f"flattened. A maximum of {MAX_GEOMETRY_COUNT} is supported due to " + f"preprocessing performance." + ) + + return val + """ Accounting """ @cached_property @@ -357,13 +383,13 @@ def plot_structures( The supplied or created matplotlib axes. """ - medium_shapes = self._get_structures_plane(structures=self.structures, x=x, y=y, z=z) + medium_shapes = self._get_structures_2dbox( + structures=self.structures, x=x, y=y, z=z, hlim=hlim, vlim=vlim + ) medium_map = self.medium_map - for (medium, shape) in medium_shapes: mat_index = medium_map[medium] ax = self._plot_shape_structure(medium=medium, mat_index=mat_index, shape=shape, ax=ax) - ax = self._set_plot_bounds(bounds=self.bounds, ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) # clean up the axis display @@ -448,11 +474,16 @@ def _set_plot_bounds( ax.set_ylim(vlim) return ax - @staticmethod - def _get_structures_plane( - structures: List[Structure], x: float = None, y: float = None, z: float = None + def _get_structures_2dbox( + self, + structures: List[Structure], + x: float = None, + y: float = None, + z: float = None, + hlim: Tuple[float, float] = None, + vlim: Tuple[float, float] = None, ) -> List[Tuple[Medium, Shapely]]: - """Compute list of shapes to plot on plane specified by {x,y,z}. + """Compute list of shapes to plot on 2d box specified by (x_min, x_max), (y_min, y_max). Parameters ---------- @@ -464,17 +495,42 @@ def _get_structures_plane( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. + hlim : Tuple[float, float] = None + The x range if plotting on xy or xz planes, y range if plotting on yz plane. + vlim : Tuple[float, float] = None + The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns ------- List[Tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] List of shapes and mediums on the plane. """ + # if no hlim and/or vlim given, the bounds will then be the usual pml bounds + axis, _ = Box.parse_xyz_kwargs(x=x, y=y, z=z) + _, (hmin, vmin) = Box.pop_axis(self.bounds[0], axis=axis) + _, (hmax, vmax) = Box.pop_axis(self.bounds[1], axis=axis) + + if hlim is not None: + (hmin, hmax) = hlim + if vlim is not None: + (vmin, vmax) = vlim + + # get center and size with h, v + h_center = (hmin + hmax) / 2.0 + v_center = (vmin + vmax) / 2.0 + h_size = (hmax - hmin) or inf + v_size = (vmax - vmin) or inf + + axis, center_normal = Box.parse_xyz_kwargs(x=x, y=y, z=z) + center = Box.unpop_axis(center_normal, (h_center, v_center), axis=axis) + size = Box.unpop_axis(0.0, (h_size, v_size), axis=axis) + plane = Box(center=center, size=size) + medium_shapes = [] for structure in structures: - intersections = structure.geometry.intersections_plane(x=x, y=y, z=z) - if len(intersections) > 0: - for shape in intersections: + intersections = plane.intersections_with(structure.geometry) + for shape in intersections: + if not shape.is_empty: shape = Box.evaluate_inf_shape(shape) medium_shapes.append((structure.medium, shape)) return medium_shapes @@ -649,6 +705,7 @@ def plot_structures_eps( ax: Ax = None, hlim: Tuple[float, float] = None, vlim: Tuple[float, float] = None, + grid: Grid = None, ) -> Ax: """Plot each of scene's structures on a plane defined by one nonzero x,y,z coordinate. The permittivity is plotted in grayscale based on its value at the specified frequency. @@ -688,7 +745,6 @@ def plot_structures_eps( """ structures = self.structures - structures = [self.background_structure] + list(structures) # alpha is None just means plot without any transparency if alpha is None: @@ -704,7 +760,10 @@ def plot_structures_eps( plane = Box(center=center, size=size) medium_shapes = self._filter_structures_plane_medium(structures=structures, plane=plane) else: - medium_shapes = self._get_structures_plane(structures=structures, x=x, y=y, z=z) + structures = [self.background_structure] + list(structures) + medium_shapes = self._get_structures_2dbox( + structures=structures, x=x, y=y, z=z, hlim=hlim, vlim=vlim + ) eps_min, eps_max = eps_lim @@ -737,7 +796,7 @@ def plot_structures_eps( else: # For custom medium, apply pcolormesh clipped by the shape. self._pcolormesh_shape_custom_medium_structure_eps( - x, y, z, freq, alpha, medium, eps_min, eps_max, reverse, shape, ax + x, y, z, freq, alpha, medium, eps_min, eps_max, reverse, shape, ax, grid ) if cbar: @@ -779,11 +838,10 @@ def eps_bounds(self, freq: float = None) -> Tuple[float, float]: eps_list = [ medium.eps_model(freq).real for medium in medium_list - if not isinstance(medium, AbstractCustomMedium) + if not isinstance(medium, AbstractCustomMedium) and not isinstance(medium, Medium2D) ] - eps_list.append(1) - eps_min = min(eps_list) - eps_max = max(eps_list) + eps_min = min(eps_list, default=1) + eps_max = max(eps_list, default=1) # custom medium, the min and max in the supplied dataset over all components and # spatial locations. for mat in [medium for medium in medium_list if isinstance(medium, AbstractCustomMedium)]: @@ -811,6 +869,7 @@ def _pcolormesh_shape_custom_medium_structure_eps( reverse: bool, shape: Shapely, ax: Ax, + grid: Grid, ): """ Plot shape made of custom medium with ``pcolormesh``. @@ -818,8 +877,6 @@ def _pcolormesh_shape_custom_medium_structure_eps( coords = "xyz" normal_axis_ind, normal_position = Box.parse_xyz_kwargs(x=x, y=y, z=z) normal_axis, plane_axes = Box.pop_axis(coords, normal_axis_ind) - plane_axes_inds = [0, 1, 2] - plane_axes_inds.pop(normal_axis_ind) # make grid for eps interpolation # we will do this by combining shape bounds and points where custom eps is provided @@ -828,31 +885,54 @@ def _pcolormesh_shape_custom_medium_structure_eps( rmin.insert(normal_axis_ind, normal_position) rmax.insert(normal_axis_ind, normal_position) - # in case when different components of custom medium are defined on different grids - # we will combine all points along each dimension - eps_diag = medium.eps_dataarray_freq(frequency=freq) - if eps_diag[0].coords == eps_diag[1].coords and eps_diag[0].coords == eps_diag[2].coords: - coords_to_insert = [eps_diag[0].coords] + if grid is None: + plane_axes_inds = [0, 1, 2] + plane_axes_inds.pop(normal_axis_ind) + + # in case when different components of custom medium are defined on different grids + # we will combine all points along each dimension + eps_diag = medium.eps_dataarray_freq(frequency=freq) + if ( + eps_diag[0].coords == eps_diag[1].coords + and eps_diag[0].coords == eps_diag[2].coords + ): + coords_to_insert = [eps_diag[0].coords] + else: + coords_to_insert = [eps_diag[0].coords, eps_diag[1].coords, eps_diag[2].coords] + + # actual combining of points along each of plane dimensions + plane_coord = [] + for ind, comp in zip(plane_axes_inds, plane_axes): + # first start with an array made of shapes bounds + axis_coords = np.array([rmin[ind], rmax[ind]]) + # now add points in between them + for coords in coords_to_insert: + comp_axis_coords = coords[comp] + inds_inside_shape = np.where( + np.logical_and(comp_axis_coords > rmin[ind], comp_axis_coords < rmax[ind]) + )[0] + if len(inds_inside_shape) > 0: + axis_coords = np.concatenate( + (axis_coords, comp_axis_coords[inds_inside_shape]) + ) + # remove duplicates + axis_coords = np.unique(axis_coords) + + plane_coord.append(axis_coords) else: - coords_to_insert = [eps_diag[0].coords, eps_diag[1].coords, eps_diag[2].coords] - - # actual combining of points along each of plane dimensions - plane_coord = [] - for ind, comp in zip(plane_axes_inds, plane_axes): - # first start with an array made of shapes bounds - axis_coords = np.array([rmin[ind], rmax[ind]]) - # now add points in between them - for coords in coords_to_insert: - comp_axis_coords = coords[comp] - inds_inside_shape = np.where( - np.logical_and(comp_axis_coords > rmin[ind], comp_axis_coords < rmax[ind]) - )[0] - if len(inds_inside_shape) > 0: - axis_coords = np.concatenate((axis_coords, comp_axis_coords[inds_inside_shape])) - # remove duplicates - axis_coords = np.unique(axis_coords) - - plane_coord.append(axis_coords) + span_inds = grid.discretize_inds(Box.from_bounds(rmin=rmin, rmax=rmax), extend=True) + # filter negative or too large inds + n_grid = [len(grid_comp) for grid_comp in grid.boundaries.to_list] + span_inds = [ + (max(fmin, 0), min(fmax, n_grid[f_ind])) + for f_ind, (fmin, fmax) in enumerate(span_inds) + ] + + # assemble the coordinate in the 2d plane + plane_coord = [] + for plane_axis in range(2): + ind_axis = "xyz".index(plane_axes[plane_axis]) + plane_coord.append(grid.boundaries.to_list[ind_axis][slice(*span_inds[ind_axis])]) # prepare `Coords` for interpolation coord_dict = { @@ -1035,7 +1115,6 @@ def plot_structures_heat_conductivity( """ structures = self.structures - structures = [self.background_structure] + list(structures) # alpha is None just means plot without any transparency if alpha is None: @@ -1051,7 +1130,10 @@ def plot_structures_heat_conductivity( plane = Box(center=center, size=size) medium_shapes = self._filter_structures_plane_medium(structures=structures, plane=plane) else: - medium_shapes = self._get_structures_plane(structures=structures, x=x, y=y, z=z) + structures = [self.background_structure] + list(structures) + medium_shapes = self._get_structures_2dbox( + structures=structures, x=x, y=y, z=z, hlim=hlim, vlim=vlim + ) heat_cond_min, heat_cond_max = self.heat_conductivity_bounds() for (medium, shape) in medium_shapes: @@ -1154,31 +1236,6 @@ def _plot_shape_structure_heat_cond( """ Misc """ - @property - def custom_datasets(self) -> List[Dataset]: - """List of custom datasets for verification purposes. If the list is not empty, then - the scene needs to be exported to hdf5 to store the data. - """ - datasets_medium = [mat for mat in self.mediums if isinstance(mat, AbstractCustomMedium)] - datasets_geometry = [ - struct.geometry.mesh_dataset - for struct in self.structures - if isinstance(struct.geometry, TriangleMesh) - ] - return datasets_medium + datasets_geometry - - @cached_property - def allow_gain(self) -> bool: - """``True`` if any of the mediums in the scene allows gain.""" - - for medium in self.mediums: - if isinstance(medium, AnisotropicMedium): - if np.any([med.allow_gain for med in [medium.xx, medium.yy, medium.zz]]): - return True - elif medium.allow_gain: - return True - return False - def perturbed_mediums_copy( self, temperature: SpatialDataArray = None, diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index 2e9a8578b1..c65d50f7f6 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -2,24 +2,21 @@ from __future__ import annotations from typing import Dict, Tuple, List, Set, Union -from math import isclose import pydantic.v1 as pydantic import numpy as np import xarray as xr -import matplotlib.pyplot as plt import matplotlib as mpl -from mpl_toolkits.axes_grid1 import make_axes_locatable from .base import cached_property -from .validators import assert_unique_names, assert_objects_in_sim_bounds +from .validators import assert_objects_in_sim_bounds from .validators import validate_mode_objects_symmetry -from .geometry.base import Geometry, Box, GeometryGroup, ClipOperation +from .geometry.base import Geometry, Box from .geometry.primitives import Cylinder from .geometry.mesh import TriangleMesh from .geometry.polyslab import PolySlab from .geometry.utils import flatten_groups, traverse_geometries -from .types import Ax, Shapely, FreqBound, Axis, annotate_type, Symmetry, TYPE_TAG_STR, InterpMethod +from .types import Ax, FreqBound, Axis, annotate_type, InterpMethod from .grid.grid import Coords1D, Grid, Coords from .grid.grid_spec import GridSpec, UniformGrid, AutoGrid from .medium import Medium, MediumType, AbstractMedium, PECMedium @@ -39,16 +36,17 @@ from .viz import add_ax_if_none, equal_aspect from .scene import Scene -from .viz import MEDIUM_CMAP, STRUCTURE_EPS_CMAP, PlotParams, plot_params_symmetry, polygon_path -from .viz import plot_params_structure, plot_params_pml, plot_params_override_structures +from .viz import PlotParams +from .viz import plot_params_pml, plot_params_override_structures from .viz import plot_params_pec, plot_params_pmc, plot_params_bloch, plot_sim_3d -from ..version import __version__ -from ..constants import C_0, SECOND, inf, fp_eps -from ..exceptions import Tidy3dKeyError, SetupError, ValidationError, Tidy3dError, Tidy3dImportError +from ..constants import C_0, SECOND, fp_eps +from ..exceptions import SetupError, ValidationError, Tidy3dError, Tidy3dImportError from ..log import log from ..updater import Updater +from .base_sim.simulation import AbstractSimulation + try: gdstk_available = True import gdstk @@ -65,12 +63,6 @@ # minimum number of grid points allowed per central wavelength in a medium MIN_GRIDS_PER_WVL = 6.0 -# maximum number of mediums supported -MAX_NUM_MEDIUMS = 65530 - -# maximum geometry count in a single structure -MAX_GEOMETRY_COUNT = 100 - # maximum numbers of simulation parameters MAX_TIME_STEPS = 1e7 WARN_TIME_STEPS = 1e6 @@ -92,7 +84,7 @@ PML_HEIGHT_FOR_0_DIMS = 0.02 -class Simulation(Box): +class Simulation(AbstractSimulation): """Contains all information about Tidy3d simulation. Example @@ -151,33 +143,6 @@ class Simulation(Box): units=SECOND, ) - medium: MediumType3D = pydantic.Field( - Medium(), - title="Background Medium", - description="Background medium of simulation, defaults to vacuum if not specified.", - discriminator=TYPE_TAG_STR, - ) - - symmetry: Tuple[Symmetry, Symmetry, Symmetry] = pydantic.Field( - (0, 0, 0), - title="Symmetries", - description="Tuple of integers defining reflection symmetry across a plane " - "bisecting the simulation domain normal to the x-, y-, and z-axis " - "at the simulation center of each axis, respectively. " - "Each element can be ``0`` (no symmetry), ``1`` (even, i.e. 'PMC' symmetry) or " - "``-1`` (odd, i.e. 'PEC' symmetry). " - "Note that the vectorial nature of the fields must be taken into account to correctly " - "determine the symmetry value.", - ) - - structures: Tuple[Structure, ...] = pydantic.Field( - (), - title="Structures", - description="Tuple of structures present in simulation. " - "Note: Structures defined later in this list override the " - "simulation material properties in regions of spatial overlap.", - ) - sources: Tuple[annotate_type(SourceType), ...] = pydantic.Field( (), title="Sources", @@ -239,12 +204,6 @@ class Simulation(Box): le=1.0, ) - version: str = pydantic.Field( - __version__, - title="Version", - description="String specifying the front end version number.", - ) - """ Validating setup """ @pydantic.root_validator(pre=True) @@ -266,17 +225,10 @@ def _validate_auto_grid_wavelength(cls, val, values): _ = val.wavelength_from_sources(sources=values.get("sources")) return val - _structures_in_bounds = assert_objects_in_sim_bounds("structures", error=False) _sources_in_bounds = assert_objects_in_sim_bounds("sources") - _monitors_in_bounds = assert_objects_in_sim_bounds("monitors") _mode_sources_symmetries = validate_mode_objects_symmetry("sources") _mode_monitors_symmetries = validate_mode_objects_symmetry("monitors") - # make sure all names are unique - _unique_structure_names = assert_unique_names("structures") - _unique_source_names = assert_unique_names("sources") - _unique_monitor_names = assert_unique_names("monitors") - # _few_enough_mediums = validate_num_mediums() # _structures_not_at_edges = validate_structure_bounds_not_at_edges() # _gap_size_ok = validate_pml_gap_size() @@ -442,74 +394,6 @@ def boundaries_for_zero_dims(cls, val, values): ) return val - @pydantic.validator("structures", always=True) - def _validate_num_mediums(cls, val): - """Error if too many mediums present.""" - - if val is None: - return val - - mediums = {structure.medium for structure in val} - if len(mediums) > MAX_NUM_MEDIUMS: - raise SetupError( - f"Tidy3d only supports {MAX_NUM_MEDIUMS} distinct mediums." - f"{len(mediums)} were supplied." - ) - - return val - - @pydantic.validator("structures", always=True) - def _validate_num_geometries(cls, val): - """Error if too many geometries in a single structure.""" - - if val is None: - return val - - for i, structure in enumerate(val): - for geometry in flatten_groups(structure.geometry): - count = sum( - 1 - for g in traverse_geometries(geometry) - if not isinstance(g, (GeometryGroup, ClipOperation)) - ) - if count > MAX_GEOMETRY_COUNT: - raise SetupError( - f"Structure at 'structures[{i}]' has {count} geometries that cannot be " - f"flattened. A maximum of {MAX_GEOMETRY_COUNT} is supported due to " - f"preprocessing performance." - ) - - return val - - @pydantic.validator("structures", always=True) - def _structures_not_at_edges(cls, val, values): - """Warn if any structures lie at the simulation boundaries.""" - - if val is None: - return val - - sim_box = Box(size=values.get("size"), center=values.get("center")) - sim_bound_min, sim_bound_max = sim_box.bounds - sim_bounds = list(sim_bound_min) + list(sim_bound_max) - - with log as consolidated_logger: - for istruct, structure in enumerate(val): - struct_bound_min, struct_bound_max = structure.geometry.bounds - struct_bounds = list(struct_bound_min) + list(struct_bound_max) - - for sim_val, struct_val in zip(sim_bounds, struct_bounds): - - if isclose(sim_val, struct_val): - consolidated_logger.warning( - f"Structure at 'structures[{istruct}]' has bounds that extend exactly " - "to simulation edges. This can cause unexpected behavior. " - "If intending to extend the structure to infinity along one dimension, " - "use td.inf as a size variable instead to make this explicit.", - custom_loc=["structures", istruct], - ) - - return val - @pydantic.validator("structures", always=True) def _validate_2d_geometry_has_2d_medium(cls, val, values): """Warn if a geometry bounding box has zero size in a certain dimension.""" @@ -925,6 +809,7 @@ def _check_normalize_index(cls, val, values): def _post_init_validators(self) -> None: """Call validators taking z`self` that get run after init.""" + _ = self.scene self._validate_no_structures_pml() self._validate_tfsf_nonuniform_grid() @@ -1226,9 +1111,7 @@ def mediums(self) -> Set[MediumType]: List[:class:`.AbstractMedium`] Set of distinct mediums in the simulation. """ - medium_dict = {self.medium: None} - medium_dict.update({structure.medium: None for structure in self.structures}) - return list(medium_dict.keys()) + return self.scene.mediums @cached_property def medium_map(self) -> Dict[MediumType, pydantic.NonNegativeInt]: @@ -1242,20 +1125,12 @@ def medium_map(self) -> Dict[MediumType, pydantic.NonNegativeInt]: Mapping between distinct mediums to index in simulation. """ - return {medium: index for index, medium in enumerate(self.mediums)} - - def get_monitor_by_name(self, name: str) -> Monitor: - """Return monitor named 'name'.""" - for monitor in self.monitors: - if monitor.name == name: - return monitor - raise Tidy3dKeyError(f"No monitor named '{name}'") + return self.scene.medium_map @cached_property def background_structure(self) -> Structure: """Returns structure representing the background of the :class:`.Simulation`.""" - geometry = Box(size=(inf, inf, inf)) - return Structure(geometry=geometry, medium=self.medium) + return self.scene.background_structure @staticmethod def intersecting_media( @@ -1277,19 +1152,7 @@ def intersecting_media( List[:class:`.AbstractMedium`] Set of distinct mediums that intersect with the given planar object. """ - if test_object.size.count(0.0) == 1: - # get all merged structures on the test_object, which is already planar - structures_merged = Simulation._filter_structures_plane(structures, test_object) - mediums = {medium for medium, _ in structures_merged} - return mediums - - # if the test object is a volume, test each surface recursively - surfaces = test_object.surfaces_with_exclusion(**test_object.dict()) - mediums = set() - for surface in surfaces: - _mediums = Simulation.intersecting_media(surface, structures) - mediums.update(_mediums) - return mediums + return Scene.intersecting_media(test_object=test_object, structures=structures) @staticmethod def intersecting_structures( @@ -1311,26 +1174,7 @@ def intersecting_structures( Set of distinct structures that intersect with the given surface, or with the surfaces of the given volume. """ - if test_object.size.count(0.0) == 1: - # get all merged structures on the test_object, which is already planar - normal_axis_index = test_object.size.index(0.0) - dim = "xyz"[normal_axis_index] - pos = test_object.center[normal_axis_index] - xyz_kwargs = {dim: pos} - - structures_merged = [] - for structure in structures: - intersections = structure.geometry.intersections_plane(**xyz_kwargs) - if len(intersections) > 0: - structures_merged.append(structure) - return structures_merged - - # if the test object is a volume, test each surface recursively - surfaces = test_object.surfaces_with_exclusion(**test_object.dict()) - structures_merged = [] - for surface in surfaces: - structures_merged += Simulation.intersecting_structures(surface, structures) - return structures_merged + return Scene.intersecting_structures(test_object=test_object, structures=structures) def monitor_medium(self, monitor: MonitorType): """Return the medium in which the given monitor resides. @@ -1699,28 +1543,18 @@ def plot( matplotlib.axes._subplots.Axes The supplied or created matplotlib axes. """ - # if no hlim and/or vlim given, the bounds will then be the usual pml bounds - axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z) - _, (hmin, vmin) = self.pop_axis(self.bounds_pml[0], axis=axis) - _, (hmax, vmax) = self.pop_axis(self.bounds_pml[1], axis=axis) - - # account for unordered limits - if hlim is None: - hlim = (hmin, hmax) - if vlim is None: - vlim = (vmin, vmax) - - if hlim[0] > hlim[1]: - raise Tidy3dError("Error: 'hmin' > 'hmax'") - if vlim[0] > vlim[1]: - raise Tidy3dError("Error: 'vmin' > 'vmax'") + hlim, vlim = Scene._get_plot_lims( + bounds=self.simulation_bounds, x=x, y=y, z=z, hlim=hlim, vlim=vlim + ) ax = self.plot_structures(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) ax = self.plot_sources(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim, alpha=source_alpha) ax = self.plot_monitors(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim, alpha=monitor_alpha) ax = self.plot_symmetries(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) ax = self.plot_pml(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) - ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) + ax = Scene._set_plot_bounds( + bounds=self.simulation_bounds, ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim + ) ax = self.plot_boundaries(ax=ax, x=x, y=y, z=z) return ax @@ -1772,30 +1606,29 @@ def plot_eps( matplotlib.axes._subplots.Axes The supplied or created matplotlib axes. """ - # if no hlim and/or vlim given, the bounds will then be the usual pml bounds - axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z) - _, (hmin, vmin) = self.pop_axis(self.bounds_pml[0], axis=axis) - _, (hmax, vmax) = self.pop_axis(self.bounds_pml[1], axis=axis) - # account for unordered limits - if hlim is None: - hlim = (hmin, hmax) - if vlim is None: - vlim = (vmin, vmax) - - if hlim[0] > hlim[1]: - raise Tidy3dError("Error: hmin > hmax") - if vlim[0] > vlim[1]: - raise Tidy3dError("Error: vmin > vmax") + hlim, vlim = Scene._get_plot_lims( + bounds=self.simulation_bounds, x=x, y=y, z=z, hlim=hlim, vlim=vlim + ) ax = self.plot_structures_eps( - freq=freq, cbar=True, alpha=alpha, ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim + freq=freq, + cbar=True, + alpha=alpha, + ax=ax, + x=x, + y=y, + z=z, + hlim=hlim, + vlim=vlim, ) ax = self.plot_sources(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim, alpha=source_alpha) ax = self.plot_monitors(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim, alpha=monitor_alpha) ax = self.plot_symmetries(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) ax = self.plot_pml(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) - ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) + ax = Scene._set_plot_bounds( + bounds=self.simulation_bounds, ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim + ) ax = self.plot_boundaries(ax=ax, x=x, y=y, z=z) return ax @@ -1810,7 +1643,7 @@ def plot_structures( hlim: Tuple[float, float] = None, vlim: Tuple[float, float] = None, ) -> Ax: - """Plot each of simulation's structures on a plane defined by one nonzero x,y,z coordinate. + """Plot each of scene's structures on a plane defined by one nonzero x,y,z coordinate. Parameters ---------- @@ -1833,59 +1666,11 @@ def plot_structures( The supplied or created matplotlib axes. """ - medium_shapes = self._get_structures_2dbox( - structures=self.structures, x=x, y=y, z=z, hlim=hlim, vlim=vlim + hlim_new, vlim_new = Scene._get_plot_lims( + bounds=self.simulation_bounds, x=x, y=y, z=z, hlim=hlim, vlim=vlim ) - medium_map = self.medium_map - for medium, shape in medium_shapes: - mat_index = medium_map[medium] - ax = self._plot_shape_structure(medium=medium, mat_index=mat_index, shape=shape, ax=ax) - ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) - # clean up the axis display - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - ax = self.add_ax_labels_lims(axis=axis, ax=ax) - ax.set_title(f"cross section at {'xyz'[axis]}={position:.2f}") - - return ax - - def _plot_shape_structure(self, medium: Medium, mat_index: int, shape: Shapely, ax: Ax) -> Ax: - """Plot a structure's cross section shape for a given medium.""" - plot_params_struct = self._get_structure_plot_params(medium=medium, mat_index=mat_index) - ax = self.plot_shape(shape=shape, plot_params=plot_params_struct, ax=ax) - return ax - - def _get_structure_plot_params(self, mat_index: int, medium: Medium) -> PlotParams: - """Constructs the plot parameters for a given medium in simulation.plot().""" - - plot_params = plot_params_structure.copy(update={"linewidth": 0}) - - if mat_index == 0 or medium == self.medium: - # background medium - plot_params = plot_params.copy(update={"facecolor": "white", "edgecolor": "white"}) - elif isinstance(medium, PECMedium): - # perfect electrical conductor - plot_params = plot_params.copy( - update={"facecolor": "gold", "edgecolor": "k", "linewidth": 1} - ) - elif isinstance(medium, Medium2D): - # 2d material - plot_params = plot_params.copy(update={"edgecolor": "k", "linewidth": 1}) - else: - # regular medium - facecolor = MEDIUM_CMAP[(mat_index - 1) % len(MEDIUM_CMAP)] - plot_params = plot_params.copy(update={"facecolor": facecolor}) - - return plot_params - - @staticmethod - def _add_cbar(eps_min: float, eps_max: float, ax: Ax = None) -> None: - """Add a colorbar to eps plot.""" - norm = mpl.colors.Normalize(vmin=eps_min, vmax=eps_max) - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.15) - mappable = mpl.cm.ScalarMappable(norm=norm, cmap=STRUCTURE_EPS_CMAP) - plt.colorbar(mappable, cax=cax, label=r"$\epsilon_r$") + return self.scene.plot_structures(x=x, y=y, z=z, ax=ax, hlim=hlim_new, vlim=vlim_new) @equal_aspect @add_ax_if_none @@ -1937,292 +1722,27 @@ def plot_structures_eps( The supplied or created matplotlib axes. """ - structures = self.structures - - # alpha is None just means plot without any transparency - if alpha is None: - alpha = 1 - - if alpha <= 0: - return ax - - if alpha < 1 and not isinstance(self.medium, AbstractCustomMedium): - axis, position = Box.parse_xyz_kwargs(x=x, y=y, z=z) - center = Box.unpop_axis(position, (0, 0), axis=axis) - size = Box.unpop_axis(0, (inf, inf), axis=axis) - plane = Box(center=center, size=size) - medium_shapes = self._filter_structures_plane(structures=structures, plane=plane) - else: - structures = [self.background_structure] + list(structures) - medium_shapes = self._get_structures_2dbox( - structures=structures, x=x, y=y, z=z, hlim=hlim, vlim=vlim - ) - - eps_min, eps_max = self.eps_bounds(freq=freq) - for medium, shape in medium_shapes: - # if the background medium is custom medium, it needs to be rendered separately - if medium == self.medium and alpha < 1 and not isinstance(medium, AbstractCustomMedium): - continue - # no need to add patches for custom medium - if not isinstance(medium, AbstractCustomMedium): - ax = self._plot_shape_structure_eps( - freq=freq, - alpha=alpha, - medium=medium, - eps_min=eps_min, - eps_max=eps_max, - reverse=reverse, - shape=shape, - ax=ax, - ) - else: - # For custom medium, apply pcolormesh clipped by the shape. - self._pcolormesh_shape_custom_medium_structure_eps( - x, y, z, freq, alpha, medium, eps_min, eps_max, reverse, shape, ax - ) - - if cbar: - self._add_cbar(eps_min=eps_min, eps_max=eps_max, ax=ax) - ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) - - # clean up the axis display - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - ax = self.add_ax_labels_lims(axis=axis, ax=ax) - ax.set_title(f"cross section at {'xyz'[axis]}={position:.2f}") - - return ax - - def eps_bounds(self, freq: float = None) -> Tuple[float, float]: - """Compute range of (real) permittivity present in the simulation at frequency "freq".""" - - medium_list = [self.medium] + list(self.mediums) - medium_list = [medium for medium in medium_list if not isinstance(medium, PECMedium)] - # regular medium - eps_list = [ - medium.eps_model(freq).real - for medium in medium_list - if not isinstance(medium, AbstractCustomMedium) and not isinstance(medium, Medium2D) - ] - eps_min = min(eps_list, default=1) - eps_max = max(eps_list, default=1) - # custom medium, the min and max in the supplied dataset over all components and - # spatial locations. - for mat in [medium for medium in medium_list if isinstance(medium, AbstractCustomMedium)]: - eps_dataarray = mat.eps_dataarray_freq(freq) - eps_min = min( - eps_min, - min(np.min(eps_comp.real.values.ravel()) for eps_comp in eps_dataarray), - ) - eps_max = max( - eps_max, - max(np.max(eps_comp.real.values.ravel()) for eps_comp in eps_dataarray), - ) - return eps_min, eps_max - - def _pcolormesh_shape_custom_medium_structure_eps( - self, - x: float, - y: float, - z: float, - freq: float, - alpha: float, - medium: Medium, - eps_min: float, - eps_max: float, - reverse: bool, - shape: Shapely, - ax: Ax, - ): - """ - Plot shape made of custom medium with ``pcolormesh``. - """ - coords = "xyz" - normal_axis_ind, normal_position = self.parse_xyz_kwargs(x=x, y=y, z=z) - normal_axis, plane_axes = self.pop_axis(coords, normal_axis_ind) - - # First, obtain `span_inds` of grids for interpolating permittivity in the - # bounding box of the shape - shape_bounds = shape.bounds - rmin, rmax = [*shape_bounds[:2]], [*shape_bounds[2:]] - rmin.insert(normal_axis_ind, normal_position) - rmax.insert(normal_axis_ind, normal_position) - span_inds = self.grid.discretize_inds(Box.from_bounds(rmin=rmin, rmax=rmax), extend=True) - # filter negative or too large inds - n_grid = [len(grid_comp) for grid_comp in self.grid.boundaries.to_list] - span_inds = [ - (max(fmin, 0), min(fmax, n_grid[f_ind])) for f_ind, (fmin, fmax) in enumerate(span_inds) - ] - - # assemble the coordinate in the 2d plane - plane_coord = [] - for plane_axis in range(2): - ind_axis = "xyz".index(plane_axes[plane_axis]) - plane_coord.append(self.grid.boundaries.to_list[ind_axis][slice(*span_inds[ind_axis])]) - - # prepare `Coords` for interpolation - coord_dict = { - plane_axes[0]: plane_coord[0], - plane_axes[1]: plane_coord[1], - normal_axis: [normal_position], - } - coord_shape = Coords(**coord_dict) - # interpolate permittivity and take the average over components - eps_shape = np.mean(medium.eps_diagonal_on_grid(frequency=freq, coords=coord_shape), axis=0) - # remove the normal_axis and take real part - eps_shape = eps_shape.real.mean(axis=normal_axis_ind) - # reverse - if reverse: - eps_shape = eps_min + eps_max - eps_shape - - # pcolormesh - plane_xp, plane_yp = np.meshgrid(plane_coord[0], plane_coord[1], indexing="ij") - ax.pcolormesh( - plane_xp, - plane_yp, - eps_shape, - clip_path=(polygon_path(shape), ax.transData), - cmap=STRUCTURE_EPS_CMAP, - vmin=eps_min, - vmax=eps_max, - alpha=alpha, - clip_box=ax.bbox, + hlim, vlim = Scene._get_plot_lims( + bounds=self.simulation_bounds, x=x, y=y, z=z, hlim=hlim, vlim=vlim ) - def _get_structure_eps_plot_params( - self, - medium: Medium, - freq: float, - eps_min: float, - eps_max: float, - reverse: bool = False, - alpha: float = None, - ) -> PlotParams: - """Constructs the plot parameters for a given medium in simulation.plot_eps().""" - - plot_params = plot_params_structure.copy(update={"linewidth": 0}) - if alpha is not None: - plot_params = plot_params.copy(update={"alpha": alpha}) - - if isinstance(medium, PECMedium): - # perfect electrical conductor - plot_params = plot_params.copy( - update={"facecolor": "gold", "edgecolor": "k", "linewidth": 1} - ) - elif isinstance(medium, Medium2D): - # 2d material - plot_params = plot_params.copy(update={"edgecolor": "k", "linewidth": 1}) - else: - # regular medium - eps_medium = medium.eps_model(frequency=freq).real - delta_eps = eps_medium - eps_min - delta_eps_max = eps_max - eps_min + 1e-5 - eps_fraction = delta_eps / delta_eps_max - color = eps_fraction if reverse else 1 - eps_fraction - plot_params = plot_params.copy(update={"facecolor": str(color)}) - - return plot_params - - def _plot_shape_structure_eps( - self, - freq: float, - medium: Medium, - shape: Shapely, - eps_min: float, - eps_max: float, - ax: Ax, - reverse: bool = False, - alpha: float = None, - ) -> Ax: - """Plot a structure's cross section shape for a given medium, grayscale for permittivity.""" - plot_params = self._get_structure_eps_plot_params( - medium=medium, freq=freq, eps_min=eps_min, eps_max=eps_max, alpha=alpha, reverse=reverse + return self.scene.plot_structures_eps( + freq=freq, + cbar=cbar, + alpha=alpha, + ax=ax, + x=x, + y=y, + z=z, + hlim=hlim, + vlim=vlim, + grid=self.grid, + reverse=reverse, ) - ax = self.plot_shape(shape=shape, plot_params=plot_params, ax=ax) - return ax - - @equal_aspect - @add_ax_if_none - def plot_sources( - self, - x: float = None, - y: float = None, - z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, - alpha: float = None, - ax: Ax = None, - ) -> Ax: - """Plot each of simulation's sources on a plane defined by one nonzero x,y,z coordinate. - Parameters - ---------- - x : float = None - position of plane in x direction, only one of x, y, z must be specified to define plane. - y : float = None - position of plane in y direction, only one of x, y, z must be specified to define plane. - z : float = None - position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None - The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None - The z range if plotting on xz or yz planes, y plane if plotting on xy plane. - alpha : float = None - Opacity of the sources, If ``None`` uses Tidy3d default. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. - """ - bounds = self.bounds - for source in self.sources: - ax = source.plot(x=x, y=y, z=z, alpha=alpha, ax=ax, sim_bounds=bounds) - ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) - return ax - - @equal_aspect - @add_ax_if_none - def plot_monitors( - self, - x: float = None, - y: float = None, - z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, - alpha: float = None, - ax: Ax = None, - ) -> Ax: - """Plot each of simulation's monitors on a plane defined by one nonzero x,y,z coordinate. - - Parameters - ---------- - x : float = None - position of plane in x direction, only one of x, y, z must be specified to define plane. - y : float = None - position of plane in y direction, only one of x, y, z must be specified to define plane. - z : float = None - position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None - The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None - The z range if plotting on xz or yz planes, y plane if plotting on xy plane. - alpha : float = None - Opacity of the sources, If ``None`` uses Tidy3d default. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. - """ - bounds = self.bounds - for monitor in self.monitors: - ax = monitor.plot(x=x, y=y, z=z, alpha=alpha, ax=ax, sim_bounds=bounds) - ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) - return ax + def eps_bounds(self, freq: float = None) -> Tuple[float, float]: + """Compute range of (real) permittivity present in the simulation at frequency "freq".""" + return self.scene.eps_bounds(freq=freq) @cached_property def num_pml_layers(self) -> List[Tuple[float, float]]: @@ -2270,12 +1790,9 @@ def bounds_pml(self) -> Tuple[Tuple[float, float, float], Tuple[float, float, fl return (bounds_min, bounds_max) @cached_property - def simulation_geometry(self) -> Box: - """The entire simulation domain including PML layers. It is identical to - ``sim.geometry`` in the absence of PML. - """ - rmin, rmax = self.bounds_pml - return Box.from_bounds(rmin=rmin, rmax=rmax) + def simulation_bounds(self) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]: + """Simulation bounds including the PML regions.""" + return self.bounds_pml @equal_aspect @add_ax_if_none @@ -2315,7 +1832,9 @@ def plot_pml( pml_boxes = self._make_pml_boxes(normal_axis=normal_axis) for pml_box in pml_boxes: pml_box.plot(x=x, y=y, z=z, ax=ax, **plot_params_pml.to_kwargs()) - ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) + ax = Scene._set_plot_bounds( + bounds=self.simulation_bounds, ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim + ) return ax def _make_pml_boxes(self, normal_axis: Axis) -> List[Box]: @@ -2350,77 +1869,6 @@ def _make_pml_box(self, pml_axis: Axis, pml_height: float, sign: int) -> Box: return pml_box - @equal_aspect - @add_ax_if_none - def plot_symmetries( - self, - x: float = None, - y: float = None, - z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, - ax: Ax = None, - ) -> Ax: - """Plot each of simulation's symmetries on a plane defined by one nonzero x,y,z coordinate. - - Parameters - ---------- - x : float = None - position of plane in x direction, only one of x, y, z must be specified to define plane. - y : float = None - position of plane in y direction, only one of x, y, z must be specified to define plane. - z : float = None - position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None - The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None - The z range if plotting on xz or yz planes, y plane if plotting on xy plane. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. - """ - - normal_axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z) - - for sym_axis, sym_value in enumerate(self.symmetry): - if sym_value == 0 or sym_axis == normal_axis: - continue - sym_box = self._make_symmetry_box(sym_axis=sym_axis) - plot_params = self._make_symmetry_plot_params(sym_value=sym_value) - ax = sym_box.plot(x=x, y=y, z=z, ax=ax, **plot_params.to_kwargs()) - ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) - return ax - - def _make_symmetry_plot_params(self, sym_value: Symmetry) -> PlotParams: - """Make PlotParams for symmetry.""" - - plot_params = plot_params_symmetry.copy() - - if sym_value == 1: - plot_params = plot_params.copy( - update={"facecolor": "lightsteelblue", "edgecolor": "lightsteelblue", "hatch": "++"} - ) - elif sym_value == -1: - plot_params = plot_params.copy( - update={"facecolor": "goldenrod", "edgecolor": "goldenrod", "hatch": "--"} - ) - - return plot_params - - def _make_symmetry_box(self, sym_axis: Axis) -> Box: - """Construct a :class:`.Box` representing the symmetry to be plotted.""" - sym_box = self.simulation_geometry - size = list(sym_box.size) - size[sym_axis] /= 2 - center = list(sym_box.center) - center[sym_axis] -= size[sym_axis] / 2 - - return Box(size=size, center=center) - @add_ax_if_none def plot_grid( self, @@ -2492,7 +1940,9 @@ def plot_grid( ) ax.add_patch(rect) - ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim) + ax = Scene._set_plot_bounds( + bounds=self.simulation_bounds, ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim + ) return ax @@ -2613,217 +2063,6 @@ def set_plot_params(boundary_edge, lim, side, thickness): return ax - def _set_plot_bounds( - self, - ax: Ax, - x: float = None, - y: float = None, - z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, - ) -> Ax: - """Sets the xy limits of the simulation at a plane, useful after plotting. - - Parameters - ---------- - ax : matplotlib.axes._subplots.Axes - Matplotlib axes to set bounds on. - x : float = None - position of plane in x direction, only one of x, y, z must be specified to define plane. - y : float = None - position of plane in y direction, only one of x, y, z must be specified to define plane. - z : float = None - position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None - The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None - The z range if plotting on xz or yz planes, y plane if plotting on xy plane. - Returns - ------- - matplotlib.axes._subplots.Axes - The axes after setting the boundaries. - """ - - axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z) - _, (xmin, ymin) = self.pop_axis(self.bounds_pml[0], axis=axis) - _, (xmax, ymax) = self.pop_axis(self.bounds_pml[1], axis=axis) - - if hlim is not None: - (xmin, xmax) = hlim - if vlim is not None: - (ymin, ymax) = vlim - - if xmin != xmax: - ax.set_xlim(xmin, xmax) - if ymin != ymax: - ax.set_ylim(ymin, ymax) - - return ax - - @staticmethod - def _get_structures_plane( - structures: List[Structure], x: float = None, y: float = None, z: float = None - ) -> List[Tuple[Medium, Shapely]]: - """Compute list of shapes to plot on plane specified by {x,y,z}. - - Parameters - ---------- - structures : List[:class:`.Structure`] - list of structures to filter on the plane. - x : float = None - position of plane in x direction, only one of x, y, z must be specified to define plane. - y : float = None - position of plane in y direction, only one of x, y, z must be specified to define plane. - z : float = None - position of plane in z direction, only one of x, y, z must be specified to define plane. - - Returns - ------- - List[Tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] - List of shapes and mediums on the plane. - """ - medium_shapes = [] - for structure in structures: - intersections = structure.geometry.intersections_plane(x=x, y=y, z=z) - if len(intersections) > 0: - for shape in intersections: - shape = Geometry.evaluate_inf_shape(shape) - medium_shapes.append((structure.medium, shape)) - return medium_shapes - - def _get_structures_2dbox( - self, - structures: List[Structure], - x: float = None, - y: float = None, - z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, - ) -> List[Tuple[Medium, Shapely]]: - """Compute list of shapes to plot on 2d box specified by (x_min, x_max), (y_min, y_max). - - Parameters - ---------- - structures : List[:class:`.Structure`] - list of structures to filter on the plane. - x : float = None - position of plane in x direction, only one of x, y, z must be specified to define plane. - y : float = None - position of plane in y direction, only one of x, y, z must be specified to define plane. - z : float = None - position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None - The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None - The z range if plotting on xz or yz planes, y plane if plotting on xy plane. - - Returns - ------- - List[Tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] - List of shapes and mediums on the plane. - """ - # if no hlim and/or vlim given, the bounds will then be the usual pml bounds - axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z) - _, (hmin, vmin) = self.pop_axis(self.bounds_pml[0], axis=axis) - _, (hmax, vmax) = self.pop_axis(self.bounds_pml[1], axis=axis) - - if hlim is not None: - (hmin, hmax) = hlim - if vlim is not None: - (vmin, vmax) = vlim - - # get center and size with h, v - h_center = (hmin + hmax) / 2.0 - v_center = (vmin + vmax) / 2.0 - h_size = (hmax - hmin) or inf - v_size = (vmax - vmin) or inf - - axis, center_normal = self.parse_xyz_kwargs(x=x, y=y, z=z) - center = self.unpop_axis(center_normal, (h_center, v_center), axis=axis) - size = self.unpop_axis(0.0, (h_size, v_size), axis=axis) - plane = Box(center=center, size=size) - - medium_shapes = [] - for structure in structures: - intersections = plane.intersections_with(structure.geometry) - for shape in intersections: - if not shape.is_empty: - shape = Box.evaluate_inf_shape(shape) - medium_shapes.append((structure.medium, shape)) - return medium_shapes - - @staticmethod - def _filter_structures_plane( - structures: List[Structure], plane: Box - ) -> List[Tuple[Medium, Shapely]]: - """Compute list of shapes to plot on plane specified by {x,y,z}. - Overlaps are removed or merged depending on medium. - - Parameters - ---------- - structures : List[:class:`.Structure`] - list of structures to filter on the plane. - x : float = None - position of plane in x direction, only one of x, y, z must be specified to define plane. - y : float = None - position of plane in y direction, only one of x, y, z must be specified to define plane. - z : float = None - position of plane in z direction, only one of x, y, z must be specified to define plane. - - Returns - ------- - List[Tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] - List of shapes and mediums on the plane after merging. - """ - - shapes = [] - for structure in structures: - - # get list of Shapely shapes that intersect at the plane - shapes_plane = plane.intersections_with(structure.geometry) - - # Append each of them and their medium information to the list of shapes - for shape in shapes_plane: - shapes.append((structure.medium, shape, shape.bounds)) - - background_shapes = [] - for medium, shape, bounds in shapes: - - minx, miny, maxx, maxy = bounds - - # loop through background_shapes (note: all background are non-intersecting or merged) - for index, (_medium, _shape, _bounds) in enumerate(background_shapes): - - _minx, _miny, _maxx, _maxy = _bounds - - # do a bounding box check to see if any intersection to do anything about - if minx > _maxx or _minx > maxx or miny > _maxy or _miny > maxy: - continue - - # look more closely to see if intersected. - if _shape.is_empty or not shape.intersects(_shape): - continue - - diff_shape = _shape - shape - - # different medium, remove intersection from background shape - if medium != _medium and len(diff_shape.bounds) > 0: - background_shapes[index] = (_medium, diff_shape, diff_shape.bounds) - - # same medium, add diff shape to this shape and mark background shape for removal - else: - shape = shape | diff_shape - background_shapes[index] = None - - # after doing this with all background shapes, add this shape to the background - background_shapes.append((medium, shape, shape.bounds)) - - # remove any existing background shapes that have been marked as 'None' - background_shapes = [b for b in background_shapes if b is not None] - - # filter out any remaining None or empty shapes (shapes with area completely removed) - return [(medium, shape) for (medium, shape, _) in background_shapes if shape] - @cached_property def frequency_range(self) -> FreqBound: """Range of frequencies spanning all sources' frequency dependence. @@ -3231,7 +2470,7 @@ def make_eps_data(coords: Coords): """returns epsilon data on grid of points defined by coords""" arrays = (np.array(coords.x), np.array(coords.y), np.array(coords.z)) eps_background = get_eps( - structure=self.background_structure, frequency=freq, coords=coords + structure=self.scene.background_structure, frequency=freq, coords=coords ) shape = tuple(len(array) for array in arrays) eps_array = eps_background * np.ones(shape, dtype=complex) @@ -3539,15 +2778,6 @@ def perturbed_mediums_copy( return Simulation.parse_obj(sim_dict) - @cached_property - def scene(self) -> Scene: - """Return a :class:.`Scene` instance based on the current simulation.""" - - return Scene( - structures=self.structures, - medium=self.medium, - ) - @classmethod def from_scene(cls, scene: Scene, **kwargs) -> Simulation: """Create a simulation from a :class:.`Scene` instance. Must provide additional parameters diff --git a/tidy3d/components/source.py b/tidy3d/components/source.py index e94eed9591..903e986f8c 100644 --- a/tidy3d/components/source.py +++ b/tidy3d/components/source.py @@ -10,10 +10,11 @@ import numpy as np from .base import Tidy3dBaseModel, cached_property +from .base_sim.source import AbstractSource from .types import Coordinate, Direction, Polarization, Ax, FreqBound from .types import ArrayFloat1D, Axis, PlotVal, ArrayComplex1D, TYPE_TAG_STR -from .validators import assert_plane, assert_volumetric, validate_name_str, get_value +from .validators import assert_plane, assert_volumetric, get_value from .validators import warn_if_dataset_none, assert_single_freq_in_range from .data.dataset import FieldDataset, TimeDataset from .data.data_array import TimeDataArray @@ -470,7 +471,7 @@ def amp_time(self, time: float) -> complex: """ Source objects """ -class Source(Box, ABC): +class Source(Box, AbstractSource, ABC): """Abstract base class for all sources.""" source_time: SourceTimeType = pydantic.Field( @@ -480,15 +481,11 @@ class Source(Box, ABC): discriminator=TYPE_TAG_STR, ) - name: str = pydantic.Field(None, title="Name", description="Optional name for the source.") - @cached_property def plot_params(self) -> PlotParams: """Default parameters for plotting a Source object.""" return plot_params_source - _name_validator = validate_name_str() - @cached_property def geometry(self) -> Box: """:class:`Box` representation of source."""