Skip to content

Commit

Permalink
made plots fast again by storing re-computed properties
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Mar 31, 2022
1 parent 395a317 commit 24d8631
Showing 1 changed file with 35 additions and 16 deletions.
51 changes: 35 additions & 16 deletions tidy3d/components/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,9 +671,14 @@ def plot_structures(
"""

medium_shapes = self._filter_structures_plane(self.structures, x=x, y=y, z=z)
medium_map = self.medium_map

for (medium, shape) in medium_shapes:
if medium != self.medium:
ax = self._plot_shape_structure(medium=medium, shape=shape, ax=ax)
mat_index = medium_map[medium]
ax = self._plot_shape_structure(
medium=medium, mat_index=mat_index, shape=shape, ax=ax
)

ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z)

Expand All @@ -684,20 +689,18 @@ def plot_structures(

return ax

def _plot_shape_structure(self, medium: Medium, shape: Shapely, ax: Ax) -> Ax:
def _plot_shape_structure(self, medium: Medium, mat_index: int, shape: Shapely, ax: Ax) -> Ax:
"""Plot a structure's cross section shape for a given medium."""
plot_params_struct = self._get_structure_plot_params(medium=medium)
plot_params_struct = self._get_structure_plot_params(medium=medium, mat_index=mat_index)
ax = self.plot_shape(shape=shape, plot_params=plot_params_struct, ax=ax)
return ax

def _get_structure_plot_params(self, medium: Medium) -> PlotParams:
def _get_structure_plot_params(self, mat_index: int, medium: Medium) -> PlotParams:
"""Constructs the plot parameters for a given medium in simulation.plot()."""

plot_params = plot_params_structure.copy(deep=True)
plot_params.linewidth = 0

mat_index = self.medium_map[medium]

if mat_index == 0 or medium == self.medium:
# background medium
plot_params.facecolor = "white"
Expand Down Expand Up @@ -772,7 +775,14 @@ def plot_structures_eps( # pylint: disable=too-many-arguments,too-many-locals
for (medium, shape) in medium_shapes:
if medium != self.medium:
ax = self._plot_shape_structure_eps(
freq=freq, alpha=alpha, medium=medium, reverse=reverse, shape=shape, ax=ax
freq=freq,
alpha=alpha,
medium=medium,
eps_min=eps_min,
eps_max=eps_max,
reverse=reverse,
shape=shape,
ax=ax,
)

if cbar:
Expand All @@ -789,15 +799,21 @@ def plot_structures_eps( # pylint: disable=too-many-arguments,too-many-locals
def eps_bounds(self, freq: float = None) -> Tuple[float, float]:
"""Compute range of (real) permittivity present in the simulation at frequency "freq"."""

medium_list = [self.medium] + [structure.medium for structure in self.structures]
medium_list = [self.medium] + self.mediums
medium_list = [medium for medium in medium_list if not isinstance(medium, PECMedium)]
eps_list = [medium.eps_model(freq).real for medium in medium_list]
eps_min = min(1, min(eps_list))
eps_max = max(1, max(eps_list))
return eps_min, eps_max

def _get_structure_eps_plot_params(
self, medium: Medium, freq: float, reverse: bool = False, alpha: float = None
self,
medium: Medium,
freq: float,
eps_min: float,
eps_max: float,
reverse: bool = False,
alpha: float = None,
) -> PlotParams:
"""Constructs the plot parameters for a given medium in simulation.plot_eps()."""

Expand All @@ -817,7 +833,6 @@ def _get_structure_eps_plot_params(
plot_params.linewidth = 1
else:
# regular medium
eps_min, eps_max = self.eps_bounds(freq=freq)
eps_medium = medium.eps_model(frequency=freq).real
delta_eps = eps_medium - eps_min
delta_eps_max = eps_max - eps_min + 1e-5
Expand All @@ -832,13 +847,15 @@ def _plot_shape_structure_eps(
freq: float,
medium: Medium,
shape: Shapely,
eps_min: float,
eps_max: float,
ax: Ax,
reverse: bool = False,
alpha: float = None,
) -> Ax:
"""Plot a structure's cross section shape for a given medium, grayscale for permittivity."""
plot_params = self._get_structure_eps_plot_params(
medium=medium, freq=freq, alpha=alpha, reverse=reverse
medium=medium, freq=freq, eps_min=eps_min, eps_max=eps_max, alpha=alpha, reverse=reverse
)
ax = self.plot_shape(shape=shape, plot_params=plot_params, ax=ax)
return ax
Expand Down Expand Up @@ -1055,7 +1072,9 @@ def _make_symmetry_box(self, sym_axis: Axis) -> Box:
return Box(size=size, center=center)

@add_ax_if_none
def plot_grid(self, x: float = None, y: float = None, z: float = None, ax: Ax = None) -> Ax:
def plot_grid( # pylint:disable=too-many-locals
self, x: float = None, y: float = None, z: float = None, ax: Ax = None
) -> Ax:
"""Plot the cell boundaries as lines on a plane defined by one nonzero x,y,z coordinate.
Parameters
Expand All @@ -1082,14 +1101,14 @@ def plot_grid(self, x: float = None, y: float = None, z: float = None, ax: Ax =
_, (xmin, ymin) = self.pop_axis(self.bounds_pml[0], axis=axis)
_, (xmax, ymax) = self.pop_axis(self.bounds_pml[1], axis=axis)
segs_x = [((bound, ymin), (bound, ymax)) for bound in boundaries_x]
line_segments_x = mpl.collections.LineCollection(segs_x, linewidths=0.2, colors='k')
line_segments_x = mpl.collections.LineCollection(segs_x, linewidths=0.2, colors="k")
segs_y = [((xmin, bound), (xmax, bound)) for bound in boundaries_y]
line_segments_y = mpl.collections.LineCollection(segs_y, linewidths=0.2, colors='k')
line_segments_y = mpl.collections.LineCollection(segs_y, linewidths=0.2, colors="k")

ax.add_collection(line_segments_x)
ax.add_collection(line_segments_y)
ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z)

return ax

def _set_plot_bounds(self, ax: Ax, x: float = None, y: float = None, z: float = None) -> Ax:
Expand Down

0 comments on commit 24d8631

Please sign in to comment.