Skip to content

Commit

Permalink
add FieldData.apply_phase and phase kwarg to SimulationData.plot_field
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Nov 27, 2023
1 parent 5cc45f7 commit 4f8ace2
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `ComponentModeler.plot_sim_eps()` method to plot the simulation permittivity and ports.
- Support for 2D PEC materials.
- Ability to downsample recorded near fields to speed up server-side far field projections.
- `FieldData.apply_phase(phase)` to multiply field data by a phase.
- Optional `phase` argument to `SimulationData.plot_field` that applies a phase to complex-valued fields.

### Changed
- Indent for the json string of Tidy3D models has been changed to `None` when used internally; kept as `indent=4` for writing to `json` and `yaml` files.
Expand Down
17 changes: 17 additions & 0 deletions tests/test_data/test_monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,20 @@ def test_outer_dot():
_ = field_data.outer_dot(mode_data)
_ = mode_data.outer_dot(field_data)
_ = field_data.outer_dot(field_data)


@pytest.mark.parametrize("phase_shift", np.linspace(0, 2 * np.pi, 10))
def test_field_data_phase(phase_shift):
def get_combined_phase(data):
field_sum = 0.0
for fld_cmp in data.field_components.values():
field_sum += np.sum(fld_cmp.values)
return np.angle(field_sum)

fld_data1 = make_field_data()
fld_data2 = fld_data1.apply_phase(phase_shift)

phase1 = get_combined_phase(fld_data1)
phase2 = get_combined_phase(fld_data2)

assert np.allclose(phase2, np.angle(np.exp(1j * (phase1 + phase_shift))))
21 changes: 14 additions & 7 deletions tests/test_data/test_sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,38 +119,45 @@ def test_centers():
_ = sim_data.at_centers(mon.name)


def test_plot():
@pytest.mark.parametrize("phase", [0, 1.0])
def test_plot(phase):
sim_data = make_sim_data()

# plot regular field data
for field_cmp in sim_data.simulation.get_monitor_by_name("field").fields:
field_data = sim_data["field"].field_components[field_cmp]
for axis_name in "xyz":
xyz_kwargs = {axis_name: field_data.coords[axis_name][0]}
_ = sim_data.plot_field("field", field_cmp, val="imag", f=1e14, **xyz_kwargs)
_ = sim_data.plot_field(
"field", field_cmp, val="imag", f=1e14, phase=phase, **xyz_kwargs
)
plt.close()
for axis_name in "xyz":
xyz_kwargs = {axis_name: 0}
_ = sim_data.plot_field("field", "int", f=1e14, **xyz_kwargs)
_ = sim_data.plot_field("field", "int", f=1e14, phase=phase, **xyz_kwargs)
plt.close()

# plot field time data
for field_cmp in sim_data.simulation.get_monitor_by_name("field_time").fields:
field_data = sim_data["field_time"].field_components[field_cmp]
for axis_name in "xyz":
xyz_kwargs = {axis_name: field_data.coords[axis_name][0]}
_ = sim_data.plot_field("field_time", field_cmp, val="real", t=0.0, **xyz_kwargs)
_ = sim_data.plot_field(
"field_time", field_cmp, val="real", phase=phase, t=0.0, **xyz_kwargs
)
plt.close()
for axis_name in "xyz":
xyz_kwargs = {axis_name: 0}
_ = sim_data.plot_field("field_time", "int", t=0.0, **xyz_kwargs)
_ = sim_data.plot_field("field_time", "int", t=0.0, phase=phase, **xyz_kwargs)
plt.close()

# plot mode field data
for field_cmp in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
_ = sim_data.plot_field("mode_solver", field_cmp, val="real", f=1e14, mode_index=1)
_ = sim_data.plot_field(
"mode_solver", field_cmp, val="real", f=1e14, mode_index=1, phase=phase
)
plt.close()
_ = sim_data.plot_field("mode_solver", "int", f=1e14, mode_index=1)
_ = sim_data.plot_field("mode_solver", "int", f=1e14, mode_index=1, phase=phase)
plt.close()


Expand Down
19 changes: 19 additions & 0 deletions tidy3d/components/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ class AbstractFieldDataset(Dataset, ABC):
def field_components(self) -> Dict[str, DataArray]:
"""Maps the field components to thier associated data."""

def apply_phase(self, phase: float) -> AbstractFieldDataset:
"""Create a copy where all elements are phase-shifted by a value (in radians)."""
if phase == 0.0:
return self
phasor = np.exp(1j * phase)
field_components_shifted = {}
for fld_name, fld_cmp in self.field_components.items():
fld_cmp_shifted = phasor * fld_cmp
field_components_shifted[fld_name] = fld_cmp_shifted
return self.updated_copy(**field_components_shifted)

@property
@abstractmethod
def grid_locations(self) -> Dict[str, str]:
Expand Down Expand Up @@ -265,6 +276,14 @@ class FieldTimeDataset(ElectromagneticFieldDataset):
description="Spatial distribution of the z-component of the magnetic field.",
)

def apply_phase(self, phase: float) -> AbstractFieldDataset:
"""Create a copy where all elements are phase-shifted by a value (in radians)."""

if phase != 0.0:
raise ValueError("Can't apply phase to time-domain field data, which is real-valued.")

return self


class ModeSolverDataset(ElectromagneticFieldDataset):
"""Dataset storing scalar components of E and H fields as a function of freq. and mode_index.
Expand Down
31 changes: 27 additions & 4 deletions tidy3d/components/data/sim_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
""" Simulation Level Data """
from __future__ import annotations
from typing import Callable, Tuple
from typing import Callable, Tuple, Union

import pathlib
import xarray as xr
Expand Down Expand Up @@ -290,7 +290,9 @@ def get_poynting_vector(self, field_monitor_name: str) -> xr.Dataset:

return xr.Dataset(poynting_components)

def _get_scalar_field(self, field_monitor_name: str, field_name: str, val: FieldVal):
def _get_scalar_field(
self, field_monitor_name: str, field_name: str, val: FieldVal, phase: float = 0.0
):
"""return ``xarray.DataArray`` of the scalar field of a given monitor at Yee cell centers.
Parameters
Expand All @@ -301,6 +303,8 @@ def _get_scalar_field(self, field_monitor_name: str, field_name: str, val: Field
Name of the derived field component: one of `('E', 'H', 'S', 'Sx', 'Sy', 'Sz')`.
val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
Which part of the field to plot.
phase : float = 0.0
Optional phase to apply to result
Returns
-------
Expand All @@ -320,6 +324,8 @@ def _get_scalar_field(self, field_monitor_name: str, field_name: str, val: Field
else:
dataset = self.at_boundaries(field_monitor_name)

dataset = self.apply_phase(data=dataset, phase=phase)

if field_name in ("E", "H", "S"):
# Gather vector components
required_components = [field_name + c for c in "xyz"]
Expand Down Expand Up @@ -435,13 +441,27 @@ def mnt_data_from_file(cls, fname: str, mnt_name: str, **parse_obj_kwargs) -> Mo

raise ValueError(f"No monitor with name '{mnt_name}' found in data file.")

@staticmethod
def apply_phase(data: Union[xr.DataArray, xr.Dataset], phase: float = 0.0) -> xr.DataArray:
"""Apply a phase to xarray data."""
if phase != 0.0:
if np.any(np.iscomplex(data.values)):
data *= np.exp(1j * phase)
else:
log.warning(
f"Non-zero phase of {phase} specified but the data being plotted is "
"real-valued. The phase will be ignored in the plot."
)
return data

def plot_field(
self,
field_monitor_name: str,
field_name: str,
val: FieldVal = "real",
scale: PlotScale = "lin",
eps_alpha: float = 0.2,
phase: float = 0.0,
robust: bool = True,
vmin: float = None,
vmax: float = None,
Expand All @@ -466,6 +486,9 @@ def plot_field(
eps_alpha : float = 0.2
Opacity of the structure permittivity.
Must be between 0 and 1 (inclusive).
phase : float = 0.0
Optional phase (radians) to apply to the fields.
Only has an effect on frequency-domain fields.
robust : bool = True
If True and vmin or vmax are absent, uses the 2nd and 98th percentiles of the data
to compute the color limits. This helps in visualizing the field patterns especially
Expand All @@ -492,7 +515,6 @@ def plot_field(
"""

# get the DataArray corresponding to the monitor_name and field_name

# deprecated intensity
if field_name == "int":
log.warning(
Expand All @@ -504,14 +526,15 @@ def plot_field(

if field_name in ("E", "H") or field_name[0] == "S":
# Derived fields
field_data = self._get_scalar_field(field_monitor_name, field_name, val)
field_data = self._get_scalar_field(field_monitor_name, field_name, val, phase=phase)
else:
# Direct field component (e.g. Ex)
field_monitor_data = self.load_field_monitor(field_monitor_name)
if field_name not in field_monitor_data.field_components:
raise DataError(f"field_name '{field_name}' not found in data.")
field_component = field_monitor_data.field_components[field_name]
field_component.name = field_name
field_component = self.apply_phase(data=field_component, phase=phase)
field_data = self._field_component_value(field_component, val)

if scale == "dB":
Expand Down

0 comments on commit 4f8ace2

Please sign in to comment.