Skip to content

Commit

Permalink
solver classes unification
Browse files Browse the repository at this point in the history
  • Loading branch information
dbochkov-flexcompute committed Oct 19, 2023
1 parent cb9af20 commit 9511a2d
Show file tree
Hide file tree
Showing 14 changed files with 295 additions and 1,308 deletions.
3 changes: 1 addition & 2 deletions tests/test_components/test_heat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from tidy3d import TemperatureMonitor
from tidy3d import TemperatureData

from tidy3d.exceptions import DataError
from ..utils import STL_GEO, assert_log_level, log_capture


Expand Down Expand Up @@ -349,7 +348,7 @@ def test_sim_data():
_ = heat_sim_data.plot_field("test", z=0)
plt.close()

with pytest.raises(DataError):
with pytest.raises(KeyError):
_ = heat_sim_data.plot_field("test3", x=0)

with pytest.raises(pd.ValidationError):
Expand Down
61 changes: 34 additions & 27 deletions tests/test_components/test_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import tidy3d as td
from tidy3d.components.simulation import MAX_NUM_MEDIUMS
from tidy3d.components.scene import MAX_NUM_MEDIUMS, MAX_GEOMETRY_COUNT
from ..utils import assert_log_level, log_capture, SIM_FULL

SCENE = td.Scene()
Expand Down Expand Up @@ -187,32 +187,6 @@ def test_names_unique():
)


def test_allow_gain():
"""Test if simulation allows gain."""

medium = td.Medium(permittivity=2.0)
medium_gain = td.Medium(permittivity=2.0, allow_gain=True)
medium_ani = td.AnisotropicMedium(xx=medium, yy=medium, zz=medium)
medium_gain_ani = td.AnisotropicMedium(xx=medium, yy=medium_gain, zz=medium)

# Test simulation medium
scene = td.Scene(medium=medium)
assert not scene.allow_gain
scene = scene.updated_copy(medium=medium_gain)
assert scene.allow_gain

# Test structure with anisotropic gain medium
struct = td.Structure(geometry=td.Box(center=(0, 0, 0), size=(1, 1, 1)), medium=medium_ani)
struct_gain = struct.updated_copy(medium=medium_gain_ani)
scene = td.Scene(
medium=medium,
structures=[struct],
)
assert not scene.allow_gain
scene = scene.updated_copy(structures=[struct_gain])
assert scene.allow_gain


def test_perturbed_mediums_copy():

# Non-dispersive
Expand Down Expand Up @@ -272,3 +246,36 @@ def test_perturbed_mediums_copy():

assert isinstance(new_scene.medium, td.CustomMedium)
assert isinstance(new_scene.structures[0].medium, td.CustomPoleResidue)


def test_max_geometry_validation():
too_many = [td.Box(size=(1, 1, 1)) for _ in range(MAX_GEOMETRY_COUNT + 1)]

fine = [
td.Structure(
geometry=td.ClipOperation(
operation="union",
geometry_a=td.Box(size=(1, 1, 1)),
geometry_b=td.GeometryGroup(geometries=too_many),
),
medium=td.Medium(permittivity=2.0),
),
td.Structure(
geometry=td.GeometryGroup(geometries=too_many),
medium=td.Medium(permittivity=2.0),
),
]
_ = td.Scene(structures=fine)

not_fine = [
td.Structure(
geometry=td.ClipOperation(
operation="difference",
geometry_a=td.Box(size=(1, 1, 1)),
geometry_b=td.GeometryGroup(geometries=too_many),
),
medium=td.Medium(permittivity=2.0),
),
]
with pytest.raises(pd.ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "):
_ = td.Scene(structures=not_fine)
28 changes: 3 additions & 25 deletions tests/test_components/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import tidy3d as td
from tidy3d.exceptions import SetupError, ValidationError, Tidy3dKeyError
from tidy3d.components import simulation
from tidy3d.components.simulation import MAX_NUM_MEDIUMS
from tidy3d.components.scene import MAX_NUM_MEDIUMS, MAX_GEOMETRY_COUNT
from ..utils import assert_log_level, SIM_FULL, log_capture, run_emulated
from tidy3d.constants import LARGE_NUMBER

Expand Down Expand Up @@ -492,7 +492,6 @@ def test_validate_zero_dim_boundaries(log_capture):
def test_validate_components_none():

assert SIM._structures_not_at_edges(val=None, values=SIM.dict()) is None
assert SIM._validate_num_mediums(val=None) is None
assert SIM._warn_monitor_mediums_frequency_range(val=None, values=SIM.dict()) is None
assert SIM._warn_monitor_simulation_frequency_range(val=None, values=SIM.dict()) is None
assert SIM._warn_grid_size_too_small(val=None, values=SIM.dict()) is None
Expand Down Expand Up @@ -537,7 +536,7 @@ def test_validate_mnt_size(monkeypatch, log_capture):

def test_max_geometry_validation():
gs = td.GridSpec(wavelength=1.0)
too_many = [td.Box(size=(1, 1, 1)) for _ in range(simulation.MAX_GEOMETRY_COUNT + 1)]
too_many = [td.Box(size=(1, 1, 1)) for _ in range(MAX_GEOMETRY_COUNT + 1)]

fine = [
td.Structure(
Expand Down Expand Up @@ -565,7 +564,7 @@ def test_max_geometry_validation():
medium=td.Medium(permittivity=2.0),
),
]
with pytest.raises(pydantic.ValidationError, match=f" {simulation.MAX_GEOMETRY_COUNT + 2} "):
with pytest.raises(pydantic.ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "):
_ = td.Simulation(size=(1, 1, 1), run_time=1, grid_spec=gs, structures=not_fine)


Expand All @@ -581,7 +580,6 @@ def test_plot_structure():

def test_plot_eps():
ax = SIM_FULL.plot_eps(x=0)
SIM_FULL._add_cbar(eps_min=1, eps_max=2, ax=ax)
plt.close()


Expand Down Expand Up @@ -724,26 +722,6 @@ def test_discretize_non_intersect(log_capture):
assert_log_level(log_capture, "ERROR")


def test_filter_structures():
s1 = td.Structure(geometry=td.Box(size=(1, 1, 1)), medium=SIM.medium)
s2 = td.Structure(geometry=td.Box(size=(1, 1, 1), center=(1, 1, 1)), medium=SIM.medium)
plane = td.Box(center=(0, 0, 1.5), size=(td.inf, td.inf, 0))
SIM._filter_structures_plane(structures=[s1, s2], plane=plane)


def test_get_structure_plot_params():
pp = SIM_FULL._get_structure_plot_params(mat_index=0, medium=SIM_FULL.medium)
assert pp.facecolor == "white"
pp = SIM_FULL._get_structure_plot_params(mat_index=1, medium=td.PEC)
assert pp.facecolor == "gold"
pp = SIM_FULL._get_structure_eps_plot_params(
medium=SIM_FULL.medium, freq=1, eps_min=1, eps_max=2
)
assert float(pp.facecolor) == 1.0
pp = SIM_FULL._get_structure_eps_plot_params(medium=td.PEC, freq=1, eps_min=1, eps_max=2)
assert pp.facecolor == "gold"


def test_warn_sim_background_medium_freq_range(log_capture):
_ = SIM.copy(
update=dict(
Expand Down
2 changes: 0 additions & 2 deletions tidy3d/components/base_sim/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ class AbstractSimulationData(Tidy3dBaseModel, ABC):

def __getitem__(self, monitor_name: str) -> AbstractMonitorData:
"""Get a :class:`.AbstractMonitorData` by name. Apply symmetry if applicable."""
if monitor_name not in self.monitor_data:
raise DataError(f"'{self.type}' does not contain data for monitor '{monitor_name}'.")
monitor_data = self.monitor_data[monitor_name]
return monitor_data.symmetry_expanded_copy

Expand Down
143 changes: 1 addition & 142 deletions tidy3d/components/base_sim/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,10 @@
import numpy as np

from ..types import ArrayFloat1D, Numpy
from ..types import Direction, Axis, BoxSurface
from ..types import Axis
from ..geometry.base import Box
from ..validators import assert_plane
from ..base import cached_property
from ..viz import PlotParams, plot_params_monitor
from ...constants import SECOND
from ...exceptions import SetupError
from ...log import log


class AbstractMonitor(Box, ABC):
Expand Down Expand Up @@ -94,140 +90,3 @@ def downsampled_num_cells(self, num_cells: Tuple[int, int, int]) -> Tuple[int, i
"""
arrs = [np.arange(ncells) for ncells in num_cells]
return tuple((self.downsample(arr, axis=dim).size for dim, arr in enumerate(arrs)))


class AbstractTimeMonitor(AbstractMonitor, ABC):
"""Abstract base class for transient monitors."""

start: pd.NonNegativeFloat = pd.Field(
0.0,
title="Start time",
description="Time at which to start monitor recording.",
units=SECOND,
)

stop: pd.NonNegativeFloat = pd.Field(
None,
title="Stop time",
description="Time at which to stop monitor recording. "
"If not specified, record until end of simulation.",
units=SECOND,
)

interval: pd.PositiveInt = pd.Field(
1,
title="Time interval",
description="Number of time step intervals between monitor recordings.",
)

@pd.validator("stop", always=True, allow_reuse=True)
def stop_greater_than_start(cls, val, values):
"""Ensure sure stop is greater than or equal to start."""
start = values.get("start")
if val and val < start:
raise SetupError("Monitor start time is greater than stop time.")
return val

def time_inds(self, tmesh: ArrayFloat1D) -> Tuple[int, int]:
"""Compute the starting and stopping index of the monitor in a given discrete time mesh."""

tmesh = np.array(tmesh)
tind_beg, tind_end = (0, 0)

if tmesh.size == 0:
return (tind_beg, tind_end)

# If monitor.stop is None, record until the end
t_stop = self.stop
if t_stop is None:
tind_end = int(tmesh.size)
t_stop = tmesh[-1]
else:
tend = np.nonzero(tmesh <= t_stop)[0]
if tend.size > 0:
tind_end = int(tend[-1] + 1)

# Step to compare to in order to handle t_start = t_stop
dt = 1e-20 if np.array(tmesh).size < 2 else tmesh[1] - tmesh[0]
# If equal start and stopping time, record one time step
if np.abs(self.start - t_stop) < dt and self.start <= tmesh[-1]:
tind_beg = max(tind_end - 1, 0)
else:
tbeg = np.nonzero(tmesh[:tind_end] >= self.start)[0]
tind_beg = tbeg[0] if tbeg.size > 0 else tind_end
return (tind_beg, tind_end)

def num_steps(self, tmesh: ArrayFloat1D) -> int:
"""Compute number of time steps for a time monitor."""

tind_beg, tind_end = self.time_inds(tmesh)
return int((tind_end - tind_beg) / self.interval)


class AbstractPlanarMonitor(AbstractMonitor, ABC):
""":class:`AbstractMonitor` that has a planar geometry."""

_plane_validator = assert_plane()

@cached_property
def normal_axis(self) -> Axis:
"""Axis normal to the monitor's plane."""
return self.size.index(0.0)


class AbstractSurfaceIntegrationMonitor(AbstractMonitor, ABC):
"""Abstract class for monitors that perform surface integrals during the solver run."""

normal_dir: Direction = pd.Field(
None,
title="Normal vector orientation",
description="Direction of the surface monitor's normal vector w.r.t. "
"the positive x, y or z unit vectors. Must be one of ``'+'`` or ``'-'``. "
"Applies to surface monitors only, and defaults to ``'+'`` if not provided.",
)

exclude_surfaces: Tuple[BoxSurface, ...] = pd.Field(
None,
title="Excluded surfaces",
description="Surfaces to exclude in the integration, if a volume monitor.",
)

@property
def integration_surfaces(self):
"""Surfaces of the monitor where fields will be recorded for subsequent integration."""
if self.size.count(0.0) == 0:
return self.surfaces_with_exclusion(**self.dict())
return [self]

@pd.root_validator(skip_on_failure=True)
def normal_dir_exists_for_surface(cls, values):
"""If the monitor is a surface, set default ``normal_dir`` if not provided.
If the monitor is a box, warn that ``normal_dir`` is relevant only for surfaces."""
normal_dir = values.get("normal_dir")
name = values.get("name")
size = values.get("size")
if size.count(0.0) != 1:
if normal_dir is not None:
log.warning(
"The ``normal_dir`` field is relevant only for surface monitors "
f"and will be ignored for monitor {name}, which is a box."
)
else:
if normal_dir is None:
values["normal_dir"] = "+"
return values

@pd.root_validator(skip_on_failure=True)
def check_excluded_surfaces(cls, values):
"""Error if ``exclude_surfaces`` is provided for a surface monitor."""
exclude_surfaces = values.get("exclude_surfaces")
if exclude_surfaces is None:
return values
name = values.get("name")
size = values.get("size")
if size.count(0.0) > 0:
raise SetupError(
f"Can't specify ``exclude_surfaces`` for surface monitor {name}; "
"valid for box monitors only."
)
return values
Loading

0 comments on commit 9511a2d

Please sign in to comment.