Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactored plotting and plotly plugin and simdata viewer app. #281

Merged
merged 8 commits into from
Mar 31, 2022
6 changes: 4 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ def prepend_tmp(path):
geometry=Box(size=(1, 1, 1), center=(0, 0, 0)),
medium=Medium(permittivity=1.0, conductivity=3.0),
),
Structure(geometry=Sphere(radius=1.4, center=(1.0, 0.0, 1.0)), medium=Medium()),
Structure(
geometry=Sphere(radius=1.4, center=(1.0, 0.0, 1.0)), medium=Medium(permittivity=6.0)
),
Structure(
geometry=Cylinder(radius=1.4, length=2.0, center=(1.0, 0.0, -1.0), axis=1),
medium=Medium(),
medium=Medium(permittivity=5.0),
),
Structure(
geometry=PolySlab(
Expand Down
12 changes: 2 additions & 10 deletions tidy3d/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,6 @@ def plot_field_array(
eps_alpha: float = 0.2,
robust: bool = True,
ax: Ax = None,
**patch_kwargs,
) -> Ax:
"""Plot the field data for a monitor with simulation plot overlayed.

Expand All @@ -886,8 +885,6 @@ def plot_field_array(
This helps in visualizing the field patterns especially in the presence of a source.
ax : matplotlib.axes._subplots.Axes = None
matplotlib axes to plot on, if not specified, one is created.
**patch_kwargs
Optional keyword arguments passed to ``add_artist(patch, **patch_kwargs)``.

Returns
-------
Expand Down Expand Up @@ -919,10 +916,10 @@ def plot_field_array(

if val == "abs":
cmap = "magma"
eps_reverse = False
eps_reverse = True
else:
cmap = "RdBu"
eps_reverse = True
eps_reverse = False

# plot the field
xy_coord_labels = list("xyz")
Expand All @@ -938,7 +935,6 @@ def plot_field_array(
reverse=eps_reverse,
ax=ax,
**{axis_label: position},
**patch_kwargs,
)

# set the limits based on the xarray coordinates min and max
Expand Down Expand Up @@ -1067,7 +1063,6 @@ def plot_field(
eps_alpha: float = 0.2,
robust: bool = True,
ax: Ax = None,
**patch_kwargs,
) -> Ax:
"""Plot the field data for a monitor with simulation plot overlayed.

Expand Down Expand Up @@ -1103,8 +1098,6 @@ def plot_field(
This helps in visualizing the field patterns especially in the presence of a source.
ax : matplotlib.axes._subplots.Axes = None
matplotlib axes to plot on, if not specified, one is created.
**patch_kwargs
Optional keyword arguments passed to ``add_artist(patch, **patch_kwargs)``.

Returns
-------
Expand Down Expand Up @@ -1167,7 +1160,6 @@ def plot_field(
eps_alpha=eps_alpha,
robust=robust,
ax=ax,
**patch_kwargs,
)

def normalize(self, normalize_index: int = 0):
Expand Down
114 changes: 94 additions & 20 deletions tidy3d/components/geometry.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
# pylint:disable=too-many-lines
# pylint:disable=too-many-lines, too-many-arguments
"""Defines spatial extent of objects."""

from abc import ABC, abstractmethod
from typing import List, Tuple, Union, Any
from typing import List, Tuple, Union, Any, Callable

import pydantic
import numpy as np

from shapely.geometry import Point, Polygon, box
from shapely.geometry import Point, Polygon, box, MultiPolygon
from descartes import PolygonPatch

from .base import Tidy3dBaseModel
from .types import Bound, Size, Coordinate, Axis, Coordinate2D, tidynumpy, Array
from .types import Vertices, Ax, Shapely
from .viz import add_ax_if_none, equal_aspect
from .viz import PLOT_BUFFER, ARROW_LENGTH_FACTOR, ARROW_WIDTH_FACTOR, MAX_ARROW_WIDTH_FACTOR
from .viz import PlotParams, plot_params_geometry
from ..log import Tidy3dKeyError, SetupError, ValidationError
from ..constants import MICROMETER, LARGE_NUMBER, RADIAN

Expand All @@ -26,6 +27,11 @@
class Geometry(Tidy3dBaseModel, ABC):
"""Abstract base class, defines where something exists in space."""

@property
def plot_params(self):
"""Default parameters for plotting a Geometry object."""
return plot_params_geometry

center: Coordinate = pydantic.Field(
(0.0, 0.0, 0.0),
title="Center",
Expand Down Expand Up @@ -195,7 +201,6 @@ def _pop_bounds(self, axis: Axis) -> Tuple[Coordinate2D, Tuple[Coordinate2D, Coo
def plot(
self, x: float = None, y: float = None, z: float = None, ax: Ax = None, **patch_kwargs
) -> Ax:
# pylint:disable=line-too-long
"""Plot geometry cross section at single (x,y,z) coordinate.

Parameters
Expand All @@ -218,24 +223,93 @@ def plot(
matplotlib.axes._subplots.Axes
The supplied or created matplotlib axes.
"""
# pylint:disable=line-too-long

# find shapes that intersect self at plane
axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z)
shapes_intersect = self.intersections(x=x, y=y, z=z)

plot_params = self.plot_params.include_kwargs(**patch_kwargs)

# for each intersection, plot the shape
for shape in shapes_intersect:
_shape = self.evaluate_inf_shape(shape)
patch = PolygonPatch(_shape, **patch_kwargs)
ax.add_artist(patch)
ax = self.plot_shape(shape, plot_params=plot_params, ax=ax)

# clean up the axis display
ax = self.add_ax_labels_lims(axis=axis, ax=ax)
ax.set_aspect("equal")
ax.set_title(f"cross section at {'xyz'[axis]}={position:.2f}")
return ax

def plot_shape(self, shape: Shapely, plot_params: PlotParams, ax: Ax) -> Ax:
"""Defines how a shape is plotted on a matplotlib axes."""
_shape = self.evaluate_inf_shape(shape)
patch = PolygonPatch(_shape, **plot_params.to_kwargs())
ax.add_artist(patch)
return ax

@classmethod
def strip_coords(
cls, shape: Shapely
) -> Tuple[List[float], List[float], Tuple[List[float], List[float]]]:
"""Get the exterior and list of interior xy coords for a shape.

Parameters
----------
shape: shapely.geometry.base.BaseGeometry
The shape that you want to strip coordinates from.

Returns
-------
Tuple[List[float], List[float], Tuple[List[float], List[float]]]
List of exterior xy coordinates
and a list of lists of the interior xy coordinates of the "holes" in the shape.
"""

if isinstance(shape, Polygon):
ext_coords = shape.exterior.coords[:]
list_int_coords = [interior.coords[:] for interior in shape.interiors]
elif isinstance(shape, MultiPolygon):
all_ext_coords = []
list_all_int_coords = []
for _shape in shape.geoms:
all_ext_coords.append(_shape.exterior.coords[:])
all_int_coords = [_interior.coords[:] for _interior in _shape.interiors]
list_all_int_coords.append(all_int_coords)
ext_coords = np.concatenate(all_ext_coords, axis=0)
list_int_coords = [
np.concatenate(all_int_coords, axis=0) for all_int_coords in list_all_int_coords
]
return ext_coords, list_int_coords

@classmethod
def map_to_coords(cls, func: Callable[[float], float], shape: Shapely) -> Shapely:
"""Maps a function to each coordinate in shape.

Parameters
----------
func : Callable[[float], float]
Takes old coordinate and returns new coordinate.
shape: shapely.geometry.base.BaseGeometry
The shape to map this function to.

Returns
-------
shapely.geometry.base.BaseGeometry
A new copy of the input shape with the mapping applied to the coordinates.
"""

if not isinstance(shape, (Polygon, MultiPolygon)):
return shape

def apply_func(coords):
return [(func(coord_x), func(coord_y)) for (coord_x, coord_y) in coords]

ext_coords, list_int_coords = cls.strip_coords(shape)
new_ext_coords = apply_func(ext_coords)
list_new_int_coords = [apply_func(int_coords) for int_coords in list_int_coords]

return Polygon(new_ext_coords, holes=list_new_int_coords)

def _get_plot_labels(self, axis: Axis) -> Tuple[str, str]:
"""Returns planar coordinate x and y axis labels for cross section plots.

Expand Down Expand Up @@ -293,7 +367,7 @@ def add_ax_labels_lims(self, axis: Axis, ax: Ax, buffer: float = PLOT_BUFFER) ->
(xmin, xmax), (ymin, ymax) = self._get_plot_limits(axis=axis, buffer=buffer)

# note: axes limits dont like inf values, so we need to evaluate them first if present
xmin, xmax, ymin, ymax = self._evaluate_infs(xmin, xmax, ymin, ymax)
xmin, xmax, ymin, ymax = (self._evaluate_inf(v) for v in (xmin, xmax, ymin, ymax))

ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
Expand All @@ -302,23 +376,18 @@ def add_ax_labels_lims(self, axis: Axis, ax: Ax, buffer: float = PLOT_BUFFER) ->
return ax

@staticmethod
def _evaluate_infs(*values):
def _evaluate_inf(v):
"""Processes values and evaluates any infs into large (signed) numbers."""
return map(lambda v: v if not np.isinf(v) else np.sign(v) * LARGE_NUMBER, values)
return v if not np.isinf(v) else np.sign(v) * LARGE_NUMBER / 2.0

@classmethod
def evaluate_inf_shape(cls, shape: "shapely.Geometry") -> "shapely.Geometry":
def evaluate_inf_shape(cls, shape: Shapely) -> Shapely:
"""Returns a copy of shape with inf vertices replaced by large numbers if polygon."""

if not isinstance(shape, Polygon):
return shape

coords = shape.exterior.coords[:]
new_coords = []
for (coord_x, coord_y) in coords:
new_coord = tuple(cls._evaluate_infs(coord_x, coord_y))
new_coords.append(new_coord)
return Polygon(new_coords)
return cls.map_to_coords(cls._evaluate_inf, shape)

@staticmethod
def pop_axis(coord: Tuple[Any, Any, Any], axis: int) -> Tuple[Any, Tuple[Any, Any]]:
Expand Down Expand Up @@ -792,8 +861,13 @@ def _plot_arrow( # pylint:disable=too-many-arguments, too-many-locals
arrow_axis = [component == 0 for component in direction]
arrow_length, arrow_width = self._arrow_dims(ax, length_factor, width_factor)

# only add arrow if the plotting plane is perpendicular to the source
if arrow_axis.count(0.0) > 1 or arrow_axis.index(0.0) != plot_axis:
# conditions to check to determine whether to plot arrow
arrow_intersecting_plane = len(self.intersections(x=x, y=y, z=z)) > 0
arrow_perp_to_screen = arrow_axis.index(0.0) != plot_axis
arrow_not_cartesian_axis = arrow_axis.count(0.0) > 1

# plot if arrow in plotting plane and some non-zero component can be displayed.
if arrow_intersecting_plane and (arrow_not_cartesian_axis or arrow_perp_to_screen):
_, (x0, y0) = self.pop_axis(self.center, axis=plot_axis)
_, (dx, dy) = self.pop_axis(direction, axis=plot_axis)

Expand Down
17 changes: 6 additions & 11 deletions tidy3d/components/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .geometry import Box
from .validators import assert_plane
from .mode import ModeSpec
from .viz import add_ax_if_none, equal_aspect, MonitorParams, ARROW_COLOR_MONITOR, ARROW_ALPHA
from .viz import PlotParams, plot_params_monitor, ARROW_COLOR_MONITOR, ARROW_ALPHA
from ..log import SetupError
from ..constants import HERTZ, SECOND

Expand All @@ -28,18 +28,13 @@ class Monitor(Box, ABC):
min_length=1,
)

@equal_aspect
@add_ax_if_none
def plot( # pylint:disable=duplicate-code
self, x: float = None, y: float = None, z: float = None, ax: Ax = None, **kwargs
) -> Ax:

kwargs = MonitorParams().update_params(**kwargs)
ax = self.geometry.plot(x=x, y=y, z=z, ax=ax, **kwargs)
return ax
@property
def plot_params(self) -> PlotParams:
"""Default parameters for plotting a Monitor object."""
return plot_params_monitor

@property
def geometry(self):
def geometry(self) -> Box:
""":class:`Box` representation of monitor.

Returns
Expand Down
Loading