Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

various additions and improvements to data #83

Merged
merged 1 commit into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 129 additions & 77 deletions tidy3d/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,21 @@ def decode_bytes_array(array_of_bytes: Numpy) -> List[str]:
return list_of_str


""" Base Classes """
""" xarray subclasses """


class Tidy3dDataArray(xr.DataArray):
"""Subclass of xarray's DataArray that implements some custom functions."""

__slots__ = ()

@property
def abs(self):
"""Absolute value of complex-valued data."""
return abs(self)


""" Base data classes """


class Tidy3dData(Tidy3dBaseModel):
Expand All @@ -63,7 +77,7 @@ class Config: # pylint: disable=too-few-public-methods
json_encoders = { # how to write certain types to json files
np.ndarray: numpy_encoding, # use custom encoding defined in .types
np.int64: lambda x: int(x), # pylint: disable=unnecessary-lambda
xr.DataArray: lambda x: None, # dont write
Tidy3dDataArray: lambda x: None, # dont write
xr.Dataset: lambda x: None, # dont write
}

Expand Down Expand Up @@ -106,7 +120,7 @@ class MonitorData(Tidy3dData, ABC):
"""

@property
def data(self) -> xr.DataArray:
def data(self) -> Tidy3dDataArray:
# pylint:disable=line-too-long
"""Returns an xarray representation of the montitor data.

Expand All @@ -120,7 +134,7 @@ def data(self) -> xr.DataArray:

data_dict = self.dict()
coords = {dim: data_dict[dim] for dim in self._dims}
return xr.DataArray(self.values, coords=coords)
return Tidy3dDataArray(self.values, coords=coords)

def __eq__(self, other) -> bool:
"""Check equality against another MonitorData instance.
Expand Down Expand Up @@ -204,8 +218,8 @@ def data(self) -> xr.Dataset:
data_arrays = {name: arr.data for name, arr in self.data_dict.items()}

# make an xarray dataset
# return xr.Dataset(data_arrays) # datasets are annoying
return data_arrays
return xr.Dataset(data_arrays) # datasets are annoying
# return data_arrays

def __eq__(self, other):
"""Check for equality against other :class:`CollectionData` object."""
Expand Down Expand Up @@ -488,9 +502,11 @@ class SimulationData(Tidy3dBaseModel):
@property
def log(self):
"""Prints the server-side log."""
print(self.log_string if self.log_string else "no log stored")
if not self.log_string:
raise DataError("No log stored in SimulationData.")
return self.log_string

def __getitem__(self, monitor_name: str) -> Union[xr.DataArray, xr.Dataset]:
def __getitem__(self, monitor_name: str) -> Union[Tidy3dDataArray, xr.Dataset]:
"""Get the :class:`MonitorData` xarray representation by name (``sim_data[monitor_name]``).

Parameters
Expand All @@ -508,75 +524,111 @@ def __getitem__(self, monitor_name: str) -> Union[xr.DataArray, xr.Dataset]:
raise DataError(f"monitor {monitor_name} not found")
return monitor_data.data

# @add_ax_if_none
# def plot_field(
# self,
# field_monitor_name: str,
# field_name: str,
# x: float = None,
# y: float = None,
# z: float = None,
# freq: float = None,
# time: float = None,
# eps_alpha: pydantic.confloat(ge=0.0, le=1.0) = 0.5,
# ax: Ax = None,
# **kwargs,
# ) -> Ax:
# """Plot the field data for a monitor with simulation plot overlayed.

# Parameters
# ----------
# field_monitor_name : ``str``
# Name of :class:`FieldMonitor` or :class:`FieldTimeData` to plot.
# field_name : ``str``
# Name of `field` in monitor to plot (eg. 'Ex').
# x : ``float``, optional
# Position of plane in x direction.
# y : ``float``, optional
# Position of plane in y direction.
# z : ``float``, optional
# Position of plane in z direction.
# freq: ``float``, optional
# if monitor is a :class:`FieldMonitor`, specifies the frequency (Hz) to plot the field.
# time: ``float``, optional
# if monitor is a :class:`FieldTimeMonitor`, specifies the time (sec) to plot the field.
# cbar: `bool``, optional
# if True (default), will include colorbar
# ax : ``matplotlib.axes._subplots.Axes``, optional
# matplotlib axes to plot on, if not specified, one is created.
# **patch_kwargs
# Optional keyword arguments passed to ``add_artist(patch, **patch_kwargs)``.

# Returns
# -------
# ``matplotlib.axes._subplots.Axes``
# The supplied or created matplotlib axes.

# TODO: fully test and finalize arguments.
# """

# if field_monitor_name not in self.monitor_data:
# raise DataError(f"field_monitor_name {field_monitor_name} not found in SimulationData.")

# monitor_data = self.monitor_data.get(field_monitor_name)

# if not isinstance(monitor_data, FieldData):
# raise DataError(f"field_monitor_name {field_monitor_name} not a FieldData instance.")

# if field_name not in monitor_data.data_dict:
# raise DataError(f"field_name {field_name} not found in {field_monitor_name}.")

# xr_data = monitor_data.data_dict.get(field_name)
# if isinstance(monitor_data, FieldData):
# field_data = xr_data.sel(f=freq)
# else:
# field_data = xr_data.sel(t=time)

# ax = field_data.sel(x=x, y=y, z=z).real.plot.pcolormesh(ax=ax)
# ax = self.simulation.plot_structures_eps(
# freq=freq, cbar=False, x=x, y=y, z=z, alpha=eps_alpha, ax=ax
# )
# return ax
@add_ax_if_none
def plot_field(
self,
field_monitor_name: str,
field_name: str,
x: float = None,
y: float = None,
z: float = None,
val: Literal["real", "imag", "abs"] = "real",
freq: float = None,
time: float = None,
cbar: bool = None,
eps_alpha: float = 0.2,
ax: Ax = None,
**kwargs,
) -> Ax:
"""Plot the field data for a monitor with simulation plot overlayed.

Parameters
----------
field_monitor_name : str
Name of :class:`FieldMonitor` or :class:`FieldTimeData` to plot.
field_name : str
Name of `field` in monitor to plot (eg. 'Ex').
x : float = None
Position of plane in x direction.
y : float = None
Position of plane in y direction.
z : float = None
Position of plane in z direction.
val : Literal['real', 'imag', 'abs'] = 'real'
What part of the field to plot (in )
freq: float = None
If monitor is a :class:`FieldMonitor`, specifies the frequency (Hz) to plot the field.
time: float = None
if monitor is a :class:`FieldTimeMonitor`, specifies the time (sec) to plot the field.
cbar: bool = True
if True (default), will include colorbar
eps_alpha : float = 0.2
Opacity of the structure permittivity.
Must be between 0 and 1 (inclusive).
ax : matplotlib.axes._subplots.Axes = None
matplotlib axes to plot on, if not specified, one is created.
**patch_kwargs
Optional keyword arguments passed to ``add_artist(patch, **patch_kwargs)``.

Returns
-------
matplotlib.axes._subplots.Axes
The supplied or created matplotlib axes.
"""

# get the monitor data
if field_monitor_name not in self.monitor_data:
raise DataError(f"Monitor named '{field_monitor_name}' not found.")
monitor_data = self.monitor_data.get(field_monitor_name)
if not isinstance(monitor_data, FieldData):
raise DataError(f"field_monitor_name '{field_monitor_name}' not a FieldData instance.")

# get the field data component
if field_name not in monitor_data.data_dict:
raise DataError(f"field_name {field_name} not found in {field_monitor_name}.")
xr_data = monitor_data.data_dict.get(field_name).data

# select the frequency or time value
if "f" in monitor_data.coords:
if freq is None:
raise DataError("'freq' must be supplied to plot a FieldMonitor.")
field_data = xr_data.interp(f=freq)
elif "t" in monitor_data.coords:
if time is None:
raise DataError("'time' must be supplied to plot a FieldMonitor.")
field_data = xr_data.interp(t=time)
else:
raise DataError("Field data has neither time nor frequency data, something went wrong.")

# select the cross section data
axis, pos = self.simulation.parse_xyz_kwargs(x=x, y=y, z=z)
axis_label = "xyz"[axis]
sel_kwarg = {axis_label: pos}
try:
field_data = field_data.sel(**sel_kwarg)
except Exception as e:
raise DataError(f"Could not select data at {axis_label}={pos}.") from e

# select the field value
if val not in ("real", "imag", "abs"):
raise DataError(f"'val' must be one of ``{'real', 'imag', 'abs'}``, given {val}")
if val == "real":
field_data = field_data.real
elif val == "imag":
field_data = field_data.imag
elif val == "real":
field_data = abs(field_data)

# plot the field
xy_coords = list("xyz")
xy_coords.pop(axis)
field_data.plot(ax=ax, x=xy_coords[0], y=xy_coords[1])

# plot the simulation epsilon
ax = self.simulation.plot_structures_eps(
freq=freq, cbar=cbar, x=x, y=y, z=z, alpha=eps_alpha, ax=ax
)
return ax

def export(self, fname: str) -> None:
"""Export :class:`SimulationData` to single hdf5 file including monitor data.
Expand Down
10 changes: 5 additions & 5 deletions tidy3d/components/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def plot(
# pylint:disable=line-too-long

# find shapes that intersect self at plane
axis, position = self._parse_xyz_kwargs(x=x, y=y, z=z)
axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z)
shapes_intersect = self.intersections(x=x, y=y, z=z)

# for each intersection, plot the shape
Expand Down Expand Up @@ -328,7 +328,7 @@ def unpop_axis(ax_coord: Any, plane_coords: Tuple[Any, Any], axis: int) -> Tuple
return tuple(coords)

@staticmethod
def _parse_xyz_kwargs(**xyz) -> Tuple[Axis, float]:
def parse_xyz_kwargs(**xyz) -> Tuple[Axis, float]:
"""Turns x,y,z kwargs into index of the normal axis and position along that axis.

Parameters
Expand Down Expand Up @@ -380,7 +380,7 @@ def intersections(self, x: float = None, y: float = None, z: float = None):
For more details refer to
`Shapely's Documentaton <https://shapely.readthedocs.io/en/stable/project.html>`_.
"""
axis, position = self._parse_xyz_kwargs(x=x, y=y, z=z)
axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z)
if axis == self.axis:
z0, _ = self.pop_axis(self.center, axis=self.axis)
if (position < z0 - self.length / 2) or (position > z0 + self.length / 2):
Expand Down Expand Up @@ -529,7 +529,7 @@ def intersections(self, x: float = None, y: float = None, z: float = None):
For more details refer to
`Shapely's Documentaton <https://shapely.readthedocs.io/en/stable/project.html>`_.
"""
axis, position = self._parse_xyz_kwargs(x=x, y=y, z=z)
axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z)
z0, (x0, y0) = self.pop_axis(self.center, axis=axis)
Lz, (Lx, Ly) = self.pop_axis(self.size, axis=axis)
dz = np.abs(z0 - position)
Expand Down Expand Up @@ -647,7 +647,7 @@ def intersections(self, x: float = None, y: float = None, z: float = None):
For more details refer to
`Shapely's Documentaton <https://shapely.readthedocs.io/en/stable/project.html>`_.
"""
axis, position = self._parse_xyz_kwargs(x=x, y=y, z=z)
axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z)
z0, (x0, y0) = self.pop_axis(self.center, axis=axis)
intersect_dist = self._intersect_dist(position, z0)
if not intersect_dist:
Expand Down
9 changes: 5 additions & 4 deletions tidy3d/components/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .viz import add_ax_if_none
from .validators import validate_name_str

from ..constants import C_0, inf, pec_val
from ..constants import C_0, pec_val
from ..log import log


Expand All @@ -23,15 +23,16 @@ class AbstractMedium(ABC, Tidy3dBaseModel):

Parameters
----------
frequeuncy_range : Tuple[float, float] = (-inf, inf)
frequeuncy_range : Tuple[float, float] = None
Range of validity for the medium in Hz.
If None, then all frequencies are valid.
If simulation or plotting functions use frequency out of this range, a warning is thrown.
name : str = None
Optional name for the medium.
"""

name: str = None
frequency_range: Tuple[FreqBound, FreqBound] = (-inf, inf)
frequency_range: Tuple[FreqBound, FreqBound] = None

_name_validator = validate_name_str()

Expand Down Expand Up @@ -88,7 +89,7 @@ def _eps_model(self, frequency: float) -> complex:
"""New eps_model function."""

# if frequency is none, don't check, return original function
if frequency is None:
if frequency is None or self.frequency_range is None:
return eps_model(self, frequency)

fmin, fmax = self.frequency_range
Expand Down
4 changes: 2 additions & 2 deletions tidy3d/components/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def plot_grid(self, x: float = None, y: float = None, z: float = None, ax: Ax =
The supplied or created matplotlib axes.
"""
cell_boundaries = self.grid.boundaries
axis, _ = self._parse_xyz_kwargs(x=x, y=y, z=z)
axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z)
_, (axis_x, axis_y) = self.pop_axis([0, 1, 2], axis=axis)
boundaries_x = cell_boundaries.dict()["xyz"[axis_x]]
boundaries_y = cell_boundaries.dict()["xyz"[axis_y]]
Expand Down Expand Up @@ -668,7 +668,7 @@ def _set_plot_bounds(self, ax: Ax, x: float = None, y: float = None, z: float =
The axes after setting the boundaries.
"""

axis, _ = self._parse_xyz_kwargs(x=x, y=y, z=z)
axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z)
_, ((xmin, ymin), (xmax, ymax)) = self._pop_bounds(axis=axis)
_, (pml_thick_x, pml_thick_y) = self.pop_axis(self.pml_thicknesses, axis=axis)

Expand Down
8 changes: 4 additions & 4 deletions tidy3d/components/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ class SourceParams(PatchParamSwitcher):

def get_plot_params(self) -> PatchParams:
"""Returns :class:`PatchParams` based on user-supplied args."""
return PatchParams(alpha=0.7, facecolor="blueviolet", edgecolor="blueviolet")
return PatchParams(alpha=0.4, facecolor="blueviolet", edgecolor="blueviolet")


class MonitorParams(PatchParamSwitcher):
"""Patch plotting parameters for :class:`Monitor`."""

def get_plot_params(self) -> PatchParams:
"""Returns :class:`PatchParams` based on user-supplied args."""
return PatchParams(alpha=0.7, facecolor="crimson", edgecolor="crimson")
return PatchParams(alpha=0.4, facecolor="crimson", edgecolor="crimson")


class StructMediumParams(PatchParamSwitcher):
Expand Down Expand Up @@ -140,9 +140,9 @@ class SymParams(PatchParamSwitcher):
def get_plot_params(self) -> PatchParams:
"""Returns :class:`PatchParams` based on user-supplied args."""
if self.sym_value == 1:
return PatchParams(alpha=0.5, facecolor="lightsteelblue", edgecolor="lightsteelblue")
return PatchParams(alpha=0.3, facecolor="lightsteelblue", edgecolor="lightsteelblue")
if self.sym_value == -1:
return PatchParams(alpha=0.5, facecolor="lightgreen", edgecolor="lightgreen")
return PatchParams(alpha=0.3, facecolor="lightgreen", edgecolor="lightgreen")
return PatchParams()


Expand Down
2 changes: 1 addition & 1 deletion tidy3d/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
HBAR = 6.582119569e-16

# infinity (very large)
inf = 1e20
inf = 1e10

# floating point precisions
dp_eps = np.finfo(np.float64).eps
Expand Down
Loading