diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index c9c0d6f79..fc3e24c22 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -671,9 +671,14 @@ def plot_structures( """ medium_shapes = self._filter_structures_plane(self.structures, x=x, y=y, z=z) + medium_map = self.medium_map + for (medium, shape) in medium_shapes: if medium != self.medium: - ax = self._plot_shape_structure(medium=medium, shape=shape, ax=ax) + mat_index = medium_map[medium] + ax = self._plot_shape_structure( + medium=medium, mat_index=mat_index, shape=shape, ax=ax + ) ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z) @@ -684,20 +689,18 @@ def plot_structures( return ax - def _plot_shape_structure(self, medium: Medium, shape: Shapely, ax: Ax) -> Ax: + def _plot_shape_structure(self, medium: Medium, mat_index: int, shape: Shapely, ax: Ax) -> Ax: """Plot a structure's cross section shape for a given medium.""" - plot_params_struct = self._get_structure_plot_params(medium=medium) + plot_params_struct = self._get_structure_plot_params(medium=medium, mat_index=mat_index) ax = self.plot_shape(shape=shape, plot_params=plot_params_struct, ax=ax) return ax - def _get_structure_plot_params(self, medium: Medium) -> PlotParams: + def _get_structure_plot_params(self, mat_index: int, medium: Medium) -> PlotParams: """Constructs the plot parameters for a given medium in simulation.plot().""" plot_params = plot_params_structure.copy(deep=True) plot_params.linewidth = 0 - mat_index = self.medium_map[medium] - if mat_index == 0 or medium == self.medium: # background medium plot_params.facecolor = "white" @@ -772,7 +775,14 @@ def plot_structures_eps( # pylint: disable=too-many-arguments,too-many-locals for (medium, shape) in medium_shapes: if medium != self.medium: ax = self._plot_shape_structure_eps( - freq=freq, alpha=alpha, medium=medium, reverse=reverse, shape=shape, ax=ax + freq=freq, + alpha=alpha, + medium=medium, + eps_min=eps_min, + eps_max=eps_max, + reverse=reverse, + shape=shape, + ax=ax, ) if cbar: @@ -789,7 +799,7 @@ def plot_structures_eps( # pylint: disable=too-many-arguments,too-many-locals def eps_bounds(self, freq: float = None) -> Tuple[float, float]: """Compute range of (real) permittivity present in the simulation at frequency "freq".""" - medium_list = [self.medium] + [structure.medium for structure in self.structures] + medium_list = [self.medium] + self.mediums medium_list = [medium for medium in medium_list if not isinstance(medium, PECMedium)] eps_list = [medium.eps_model(freq).real for medium in medium_list] eps_min = min(1, min(eps_list)) @@ -797,7 +807,13 @@ def eps_bounds(self, freq: float = None) -> Tuple[float, float]: return eps_min, eps_max def _get_structure_eps_plot_params( - self, medium: Medium, freq: float, reverse: bool = False, alpha: float = None + self, + medium: Medium, + freq: float, + eps_min: float, + eps_max: float, + reverse: bool = False, + alpha: float = None, ) -> PlotParams: """Constructs the plot parameters for a given medium in simulation.plot_eps().""" @@ -817,7 +833,6 @@ def _get_structure_eps_plot_params( plot_params.linewidth = 1 else: # regular medium - eps_min, eps_max = self.eps_bounds(freq=freq) eps_medium = medium.eps_model(frequency=freq).real delta_eps = eps_medium - eps_min delta_eps_max = eps_max - eps_min + 1e-5 @@ -832,13 +847,15 @@ def _plot_shape_structure_eps( freq: float, medium: Medium, shape: Shapely, + eps_min: float, + eps_max: float, ax: Ax, reverse: bool = False, alpha: float = None, ) -> Ax: """Plot a structure's cross section shape for a given medium, grayscale for permittivity.""" plot_params = self._get_structure_eps_plot_params( - medium=medium, freq=freq, alpha=alpha, reverse=reverse + medium=medium, freq=freq, eps_min=eps_min, eps_max=eps_max, alpha=alpha, reverse=reverse ) ax = self.plot_shape(shape=shape, plot_params=plot_params, ax=ax) return ax @@ -1055,7 +1072,9 @@ def _make_symmetry_box(self, sym_axis: Axis) -> Box: return Box(size=size, center=center) @add_ax_if_none - def plot_grid(self, x: float = None, y: float = None, z: float = None, ax: Ax = None) -> Ax: + def plot_grid( # pylint:disable=too-many-locals + self, x: float = None, y: float = None, z: float = None, ax: Ax = None + ) -> Ax: """Plot the cell boundaries as lines on a plane defined by one nonzero x,y,z coordinate. Parameters @@ -1082,14 +1101,14 @@ def plot_grid(self, x: float = None, y: float = None, z: float = None, ax: Ax = _, (xmin, ymin) = self.pop_axis(self.bounds_pml[0], axis=axis) _, (xmax, ymax) = self.pop_axis(self.bounds_pml[1], axis=axis) segs_x = [((bound, ymin), (bound, ymax)) for bound in boundaries_x] - line_segments_x = mpl.collections.LineCollection(segs_x, linewidths=0.2, colors='k') + line_segments_x = mpl.collections.LineCollection(segs_x, linewidths=0.2, colors="k") segs_y = [((xmin, bound), (xmax, bound)) for bound in boundaries_y] - line_segments_y = mpl.collections.LineCollection(segs_y, linewidths=0.2, colors='k') - + line_segments_y = mpl.collections.LineCollection(segs_y, linewidths=0.2, colors="k") + ax.add_collection(line_segments_x) ax.add_collection(line_segments_y) ax = self._set_plot_bounds(ax=ax, x=x, y=y, z=z) - + return ax def _set_plot_bounds(self, ax: Ax, x: float = None, y: float = None, z: float = None) -> Ax: