Skip to content

Commit

Permalink
added plot_length_units to simulation and scene classes
Browse files Browse the repository at this point in the history
  • Loading branch information
dmarek-flex committed Jul 11, 2024
1 parent 93b4d68 commit 8dcaf4c
Show file tree
Hide file tree
Showing 15 changed files with 256 additions and 49 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Introduce RF material library. Users can now export `rf_material_library` from `tidy3d.plugins.microwave`.
- Users can specify the background medium for a structure in automatic differentiation by supplying `Structure.autograd_background_permittivity`.
- `DirectivityMonitor` to compute antenna directivity.
- Added `plot_length_units` to `Simulation` and `Scene` to allow for specifying units, which improves axis labels and scaling when plotting.

### Changed

Expand Down
8 changes: 5 additions & 3 deletions scripts/make_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,16 @@ def main(args):
sim_string = re.sub(pattern, "(", sim_string)

# write sim_string to a temporary file
with tempfile.NamedTemporaryFile(delete=False, mode="w+", suffix=".py") as temp_file:
with tempfile.NamedTemporaryFile(
delete=False, mode="w+", suffix=".py", encoding="utf-8"
) as temp_file:
temp_file.write(sim_string)
temp_file_path = temp_file.name
try:
# run ruff to format the temporary file
subprocess.run(["ruff", "format", temp_file_path], check=True)
# read the formatted content back
with open(temp_file_path) as temp_file:
with open(temp_file_path, encoding="utf-8") as temp_file:
sim_string = temp_file.read()
except subprocess.CalledProcessError:
raise RuntimeError(
Expand All @@ -87,7 +89,7 @@ def main(args):
# remove the temporary file
os.remove(temp_file_path)

with open(out_file, "w+") as f:
with open(out_file, "w+", encoding="utf-8") as f:
f.write(sim_string)


Expand Down
5 changes: 5 additions & 0 deletions tests/test_components/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def test_plot(component):
plt.close()


def test_plot_with_units():
_ = BOX.plot(z=0, ax=AX, plot_length_units="nm")
plt.close()


def test_base_inside():
assert td.Geometry.inside(GEO, x=0, y=0, z=0)
assert np.all(td.Geometry.inside(GEO, np.array([0, 0]), np.array([0, 0]), np.array([0, 0])))
Expand Down
5 changes: 5 additions & 0 deletions tests/test_components/test_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def test_structure_alpha():
plt.close()


def test_plot_with_units():
scene_with_units = SCENE_FULL.updated_copy(plot_length_units="nm")
scene_with_units.plot(x=-0.5)


def test_filter_structures():
s1 = td.Structure(geometry=td.Box(size=(1, 1, 1)), medium=SCENE.medium)
s2 = td.Structure(geometry=td.Box(size=(1, 1, 1), center=(1, 1, 1)), medium=SCENE.medium)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_components/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,11 @@ def test_plot():
plt.close()


def test_plot_with_units():
sim_with_units = SIM_FULL.updated_copy(plot_length_units="nm")
sim_with_units.plot(x=-0.5)


def test_plot_1d_sim():
mesh1d = td.UniformGrid(dl=2e-4)
grid_spec = td.GridSpec(grid_x=mesh1d, grid_y=mesh1d, grid_z=mesh1d)
Expand Down
23 changes: 22 additions & 1 deletion tests/test_components/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import matplotlib.pyplot as plt
import pytest
import tidy3d as td
from tidy3d.components.viz import Polygon
from tidy3d.components.viz import Polygon, set_default_labels_and_title
from tidy3d.constants import inf
from tidy3d.exceptions import Tidy3dKeyError


def test_make_polygon_dict():
Expand Down Expand Up @@ -79,3 +80,23 @@ def test_2d_boundary_plot():

# should have a non-infinite size as x is specified
assert pml_box.size[0] != inf


def test_set_default_labels_title():
"""
Ensure labels are correctly added to axes, and test that plot_units are validated.
"""
box = td.Box(center=(0, 0, 0), size=(0.01, 0.01, 0.01))
ax = box.plot(z=0)
axis_labels = box._get_plot_labels(2)

ax = set_default_labels_and_title(axis_labels=axis_labels, axis=2, position=0, ax=ax)

ax = set_default_labels_and_title(
axis_labels=axis_labels, axis=2, position=0, ax=ax, plot_length_units="nm"
)

with pytest.raises(Tidy3dKeyError):
ax = set_default_labels_and_title(
axis_labels=axis_labels, axis=2, position=0, ax=ax, plot_length_units="inches"
)
41 changes: 37 additions & 4 deletions tidy3d/components/base_sim/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Tuple
from typing import Optional, Tuple

import autograd.numpy as anp
import pydantic.v1 as pd
Expand All @@ -16,9 +16,14 @@
from ..medium import Medium, MediumType3D
from ..scene import Scene
from ..structure import Structure
from ..types import TYPE_TAG_STR, Ax, Axis, Bound, Symmetry
from ..types import TYPE_TAG_STR, Ax, Axis, Bound, LengthUnit, Symmetry
from ..validators import assert_objects_in_sim_bounds, assert_unique_names
from ..viz import PlotParams, add_ax_if_none, equal_aspect, plot_params_symmetry
from ..viz import (
PlotParams,
add_ax_if_none,
equal_aspect,
plot_params_symmetry,
)
from .monitor import AbstractMonitor


Expand Down Expand Up @@ -102,6 +107,14 @@ class AbstractSimulation(Box, ABC):
description="String specifying the front end version number.",
)

plot_length_units: Optional[LengthUnit] = pd.Field(
"μm",
title="Plot Units",
description="When set to a supported ``LengthUnit``, "
"plots will be produced with proper scaling of axes and "
"include the desired unit specifier in labels.",
)

""" Validating setup """

# make sure all names are unique
Expand Down Expand Up @@ -157,7 +170,9 @@ def validate_pre_upload(self) -> None:
def scene(self) -> Scene:
"""Scene instance associated with the simulation."""

return Scene(medium=self.medium, structures=self.structures)
return Scene(
medium=self.medium, structures=self.structures, plot_length_units=self.plot_length_units
)

def get_monitor_by_name(self, name: str) -> AbstractMonitor:
"""Return monitor named 'name'."""
Expand Down Expand Up @@ -241,6 +256,12 @@ def plot(
)
ax = self.plot_boundaries(ax=ax, x=x, y=y, z=z)
ax = self.plot_symmetries(ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim)

# Add the default axis labels, tick labels, and title
ax = Box.add_ax_labels_and_title(
ax=ax, x=x, y=y, z=z, plot_length_units=self.plot_length_units
)

return ax

@equal_aspect
Expand Down Expand Up @@ -285,6 +306,10 @@ def plot_sources(
ax = Scene._set_plot_bounds(
bounds=self.simulation_bounds, ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim
)
# Add the default axis labels, tick labels, and title
ax = Box.add_ax_labels_and_title(
ax=ax, x=x, y=y, z=z, plot_length_units=self.plot_length_units
)
return ax

@equal_aspect
Expand Down Expand Up @@ -329,6 +354,10 @@ def plot_monitors(
ax = Scene._set_plot_bounds(
bounds=self.simulation_bounds, ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim
)
# Add the default axis labels, tick labels, and title
ax = Box.add_ax_labels_and_title(
ax=ax, x=x, y=y, z=z, plot_length_units=self.plot_length_units
)
return ax

@equal_aspect
Expand Down Expand Up @@ -376,6 +405,10 @@ def plot_symmetries(
ax = Scene._set_plot_bounds(
bounds=self.simulation_bounds, ax=ax, x=x, y=y, z=z, hlim=hlim, vlim=vlim
)
# Add the default axis labels, tick labels, and title
ax = Box.add_ax_labels_and_title(
ax=ax, x=x, y=y, z=z, plot_length_units=self.plot_length_units
)
return ax

def _make_symmetry_plot_params(self, sym_value: Symmetry) -> PlotParams:
Expand Down
70 changes: 60 additions & 10 deletions tidy3d/components/geometry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
ClipOperationType,
Coordinate,
Coordinate2D,
LengthUnit,
MatrixReal4x4,
PlanePosition,
Shapely,
Expand All @@ -51,6 +52,7 @@
equal_aspect,
plot_params_geometry,
polygon_patch,
set_default_labels_and_title,
)

POLY_GRID_SIZE = 1e-12
Expand Down Expand Up @@ -438,7 +440,13 @@ def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> Geomet
@equal_aspect
@add_ax_if_none
def plot(
self, x: float = None, y: float = None, z: float = None, ax: Ax = None, **patch_kwargs
self,
x: float = None,
y: float = None,
z: float = None,
ax: Ax = None,
plot_length_units: LengthUnit = None,
**patch_kwargs,
) -> Ax:
"""Plot geometry cross section at single (x,y,z) coordinate.
Expand All @@ -452,6 +460,8 @@ def plot(
Position of plane in z direction, only one of x,y,z can be specified to define plane.
ax : matplotlib.axes._subplots.Axes = None
Matplotlib axes to plot on, if not specified, one is created.
plot_length_units : LengthUnit = None
Specify units to use for axis labels, tick labels, and the title.
**patch_kwargs
Optional keyword arguments passed to the matplotlib patch plotting of structure.
For details on accepted values, refer to
Expand All @@ -474,9 +484,10 @@ def plot(
ax = self.plot_shape(shape, plot_params=plot_params, ax=ax)

# clean up the axis display
ax = self.add_ax_labels_lims(axis=axis, ax=ax)
ax = self.add_ax_lims(axis=axis, ax=ax)
ax.set_aspect("equal")
ax.set_title(f"cross section at {'xyz'[axis]}={position:.2f}")
# Add the default axis labels, tick labels, and title
ax = Box.add_ax_labels_and_title(ax=ax, x=x, y=y, z=z, plot_length_units=plot_length_units)
return ax

def plot_shape(self, shape: Shapely, plot_params: PlotParams, ax: Ax) -> Ax:
Expand Down Expand Up @@ -523,7 +534,8 @@ def _do_not_intersect(bounds_a, bounds_b, shape_a, shape_b):

return False

def _get_plot_labels(self, axis: Axis) -> Tuple[str, str]:
@staticmethod
def _get_plot_labels(axis: Axis) -> Tuple[str, str]:
"""Returns planar coordinate x and y axis labels for cross section plots.
Parameters
Expand All @@ -536,7 +548,7 @@ def _get_plot_labels(self, axis: Axis) -> Tuple[str, str]:
str, str
Labels of plot, packaged as ``(xlabel, ylabel)``.
"""
_, (xlabel, ylabel) = self.pop_axis("xyz", axis=axis)
_, (xlabel, ylabel) = Geometry.pop_axis("xyz", axis=axis)
return xlabel, ylabel

def _get_plot_limits(
Expand All @@ -559,8 +571,8 @@ def _get_plot_limits(
_, ((xmin, ymin), (xmax, ymax)) = self._pop_bounds(axis=axis)
return (xmin - buffer, xmax + buffer), (ymin - buffer, ymax + buffer)

def add_ax_labels_lims(self, axis: Axis, ax: Ax, buffer: float = PLOT_BUFFER) -> Ax:
"""Sets the x,y labels based on ``axis`` and the extends based on ``self.bounds``.
def add_ax_lims(self, axis: Axis, ax: Ax, buffer: float = PLOT_BUFFER) -> Ax:
"""Sets the x,y limits based on ``self.bounds``.
Parameters
----------
Expand All @@ -576,16 +588,54 @@ def add_ax_labels_lims(self, axis: Axis, ax: Ax, buffer: float = PLOT_BUFFER) ->
matplotlib.axes._subplots.Axes
The supplied or created matplotlib axes.
"""
xlabel, ylabel = self._get_plot_labels(axis=axis)
(xmin, xmax), (ymin, ymax) = self._get_plot_limits(axis=axis, buffer=buffer)

# note: axes limits dont like inf values, so we need to evaluate them first if present
xmin, xmax, ymin, ymax = self._evaluate_inf((xmin, xmax, ymin, ymax))

ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
return ax

@staticmethod
def add_ax_labels_and_title(
ax: Ax,
x: float = None,
y: float = None,
z: float = None,
plot_length_units: LengthUnit = None,
) -> Ax:
"""Sets the axis labels, tick labels, and title based on ``axis``
and an optional ``plot_length_units`` argument.
Parameters
----------
ax : matplotlib.axes._subplots.Axes
Matplotlib axes to add labels and limits on.
x : float = None
Position of plane in x direction, only one of x,y,z can be specified to define plane.
y : float = None
Position of plane in y direction, only one of x,y,z can be specified to define plane.
z : float = None
Position of plane in z direction, only one of x,y,z can be specified to define plane.
plot_length_units : LengthUnit = None
When set to a supported ``LengthUnit``, plots will be produced with annotated axes
and title with the proper units.
Returns
-------
matplotlib.axes._subplots.Axes
The supplied matplotlib axes.
"""
axis, position = Box.parse_xyz_kwargs(x=x, y=y, z=z)
axis_labels = Box._get_plot_labels(axis)
ax = set_default_labels_and_title(
axis_labels=axis_labels,
axis=axis,
position=position,
ax=ax,
plot_length_units=plot_length_units,
)
return ax

@staticmethod
Expand Down
18 changes: 10 additions & 8 deletions tidy3d/components/heat_charge/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,11 +657,12 @@ def plot_boundaries(
ax = self._plot_boundary_condition(shape=shape, boundary_spec=bc_spec, ax=ax)

# clean up the axis display
axis, position = Box.parse_xyz_kwargs(x=x, y=y, z=z)
ax = self.add_ax_labels_lims(axis=axis, ax=ax)
ax.set_title(f"cross section at {'xyz'[axis]}={position:.2f}")

ax = self.add_ax_lims(axis=axis, ax=ax)
ax = Scene._set_plot_bounds(bounds=self.simulation_bounds, ax=ax, x=x, y=y, z=z)
# Add the default axis labels, tick labels, and title
ax = Box.add_ax_labels_and_title(
ax=ax, x=x, y=y, z=z, plot_length_units=self.plot_length_units
)

return ax

Expand Down Expand Up @@ -1064,11 +1065,12 @@ def plot_sources(
)

# clean up the axis display
axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z)
ax = self.add_ax_labels_lims(axis=axis, ax=ax)
ax.set_title(f"cross section at {'xyz'[axis]}={position:.2f}")

ax = self.add_ax_lims(axis=axis, ax=ax)
ax = Scene._set_plot_bounds(bounds=self.simulation_bounds, ax=ax, x=x, y=y, z=z)
# Add the default axis labels, tick labels, and title
ax = Box.add_ax_labels_and_title(
ax=ax, x=x, y=y, z=z, plot_length_units=self.plot_length_units
)
return ax

def _add_source_cbar(self, ax: Ax, property: str = "heat_conductivity"):
Expand Down
Loading

0 comments on commit 8dcaf4c

Please sign in to comment.