Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

apply phase to FieldData and SimulationData.plot_field #1271

Merged
merged 1 commit into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `ComponentModeler.plot_sim_eps()` method to plot the simulation permittivity and ports.
- Support for 2D PEC materials.
- Ability to downsample recorded near fields to speed up server-side far field projections.
- `FieldData.apply_phase(phase)` to multiply field data by a phase.
- Optional `phase` argument to `SimulationData.plot_field` that applies a phase to complex-valued fields.

### Changed
- Indent for the json string of Tidy3D models has been changed to `None` when used internally; kept as `indent=4` for writing to `json` and `yaml` files.
Expand Down
17 changes: 17 additions & 0 deletions tests/test_data/test_monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,20 @@ def test_outer_dot():
_ = field_data.outer_dot(mode_data)
_ = mode_data.outer_dot(field_data)
_ = field_data.outer_dot(field_data)


@pytest.mark.parametrize("phase_shift", np.linspace(0, 2 * np.pi, 10))
def test_field_data_phase(phase_shift):
def get_combined_phase(data):
field_sum = 0.0
for fld_cmp in data.field_components.values():
field_sum += np.sum(fld_cmp.values)
return np.angle(field_sum)

fld_data1 = make_field_data()
fld_data2 = fld_data1.apply_phase(phase_shift)

phase1 = get_combined_phase(fld_data1)
phase2 = get_combined_phase(fld_data2)

assert np.allclose(phase2, np.angle(np.exp(1j * (phase1 + phase_shift))))
21 changes: 14 additions & 7 deletions tests/test_data/test_sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,38 +119,45 @@ def test_centers():
_ = sim_data.at_centers(mon.name)


def test_plot():
@pytest.mark.parametrize("phase", [0, 1.0])
def test_plot(phase):
sim_data = make_sim_data()

# plot regular field data
for field_cmp in sim_data.simulation.get_monitor_by_name("field").fields:
field_data = sim_data["field"].field_components[field_cmp]
for axis_name in "xyz":
xyz_kwargs = {axis_name: field_data.coords[axis_name][0]}
_ = sim_data.plot_field("field", field_cmp, val="imag", f=1e14, **xyz_kwargs)
_ = sim_data.plot_field(
"field", field_cmp, val="imag", f=1e14, phase=phase, **xyz_kwargs
)
plt.close()
for axis_name in "xyz":
xyz_kwargs = {axis_name: 0}
_ = sim_data.plot_field("field", "int", f=1e14, **xyz_kwargs)
_ = sim_data.plot_field("field", "int", f=1e14, phase=phase, **xyz_kwargs)
plt.close()

# plot field time data
for field_cmp in sim_data.simulation.get_monitor_by_name("field_time").fields:
field_data = sim_data["field_time"].field_components[field_cmp]
for axis_name in "xyz":
xyz_kwargs = {axis_name: field_data.coords[axis_name][0]}
_ = sim_data.plot_field("field_time", field_cmp, val="real", t=0.0, **xyz_kwargs)
_ = sim_data.plot_field(
"field_time", field_cmp, val="real", phase=phase, t=0.0, **xyz_kwargs
)
plt.close()
for axis_name in "xyz":
xyz_kwargs = {axis_name: 0}
_ = sim_data.plot_field("field_time", "int", t=0.0, **xyz_kwargs)
_ = sim_data.plot_field("field_time", "int", t=0.0, phase=phase, **xyz_kwargs)
plt.close()

# plot mode field data
for field_cmp in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
_ = sim_data.plot_field("mode_solver", field_cmp, val="real", f=1e14, mode_index=1)
_ = sim_data.plot_field(
"mode_solver", field_cmp, val="real", f=1e14, mode_index=1, phase=phase
)
plt.close()
_ = sim_data.plot_field("mode_solver", "int", f=1e14, mode_index=1)
_ = sim_data.plot_field("mode_solver", "int", f=1e14, mode_index=1, phase=phase)
plt.close()


Expand Down
19 changes: 19 additions & 0 deletions tidy3d/components/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ class AbstractFieldDataset(Dataset, ABC):
def field_components(self) -> Dict[str, DataArray]:
"""Maps the field components to thier associated data."""

def apply_phase(self, phase: float) -> AbstractFieldDataset:
"""Create a copy where all elements are phase-shifted by a value (in radians)."""
if phase == 0.0:
return self
phasor = np.exp(1j * phase)
field_components_shifted = {}
for fld_name, fld_cmp in self.field_components.items():
fld_cmp_shifted = phasor * fld_cmp
field_components_shifted[fld_name] = fld_cmp_shifted
return self.updated_copy(**field_components_shifted)

@property
@abstractmethod
def grid_locations(self) -> Dict[str, str]:
Expand Down Expand Up @@ -265,6 +276,14 @@ class FieldTimeDataset(ElectromagneticFieldDataset):
description="Spatial distribution of the z-component of the magnetic field.",
)

def apply_phase(self, phase: float) -> AbstractFieldDataset:
"""Create a copy where all elements are phase-shifted by a value (in radians)."""

if phase != 0.0:
raise ValueError("Can't apply phase to time-domain field data, which is real-valued.")

return self


class ModeSolverDataset(ElectromagneticFieldDataset):
"""Dataset storing scalar components of E and H fields as a function of freq. and mode_index.
Expand Down
31 changes: 27 additions & 4 deletions tidy3d/components/data/sim_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
""" Simulation Level Data """
from __future__ import annotations
from typing import Callable, Tuple
from typing import Callable, Tuple, Union

import pathlib
import xarray as xr
Expand Down Expand Up @@ -290,7 +290,9 @@ def get_poynting_vector(self, field_monitor_name: str) -> xr.Dataset:

return xr.Dataset(poynting_components)

def _get_scalar_field(self, field_monitor_name: str, field_name: str, val: FieldVal):
def _get_scalar_field(
self, field_monitor_name: str, field_name: str, val: FieldVal, phase: float = 0.0
):
"""return ``xarray.DataArray`` of the scalar field of a given monitor at Yee cell centers.

Parameters
Expand All @@ -301,6 +303,8 @@ def _get_scalar_field(self, field_monitor_name: str, field_name: str, val: Field
Name of the derived field component: one of `('E', 'H', 'S', 'Sx', 'Sy', 'Sz')`.
val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
Which part of the field to plot.
phase : float = 0.0
Optional phase to apply to result

Returns
-------
Expand All @@ -320,6 +324,8 @@ def _get_scalar_field(self, field_monitor_name: str, field_name: str, val: Field
else:
dataset = self.at_boundaries(field_monitor_name)

dataset = self.apply_phase(data=dataset, phase=phase)

if field_name in ("E", "H", "S"):
# Gather vector components
required_components = [field_name + c for c in "xyz"]
Expand Down Expand Up @@ -435,13 +441,27 @@ def mnt_data_from_file(cls, fname: str, mnt_name: str, **parse_obj_kwargs) -> Mo

raise ValueError(f"No monitor with name '{mnt_name}' found in data file.")

@staticmethod
def apply_phase(data: Union[xr.DataArray, xr.Dataset], phase: float = 0.0) -> xr.DataArray:
"""Apply a phase to xarray data."""
if phase != 0.0:
if np.any(np.iscomplex(data.values)):
data *= np.exp(1j * phase)
else:
log.warning(
f"Non-zero phase of {phase} specified but the data being plotted is "
"real-valued. The phase will be ignored in the plot."
)
return data

def plot_field(
self,
field_monitor_name: str,
field_name: str,
val: FieldVal = "real",
scale: PlotScale = "lin",
eps_alpha: float = 0.2,
phase: float = 0.0,
robust: bool = True,
vmin: float = None,
vmax: float = None,
Expand All @@ -466,6 +486,9 @@ def plot_field(
eps_alpha : float = 0.2
Opacity of the structure permittivity.
Must be between 0 and 1 (inclusive).
phase : float = 0.0
Optional phase (radians) to apply to the fields.
Only has an effect on frequency-domain fields.
robust : bool = True
If True and vmin or vmax are absent, uses the 2nd and 98th percentiles of the data
to compute the color limits. This helps in visualizing the field patterns especially
Expand All @@ -492,7 +515,6 @@ def plot_field(
"""

# get the DataArray corresponding to the monitor_name and field_name

# deprecated intensity
if field_name == "int":
log.warning(
Expand All @@ -504,14 +526,15 @@ def plot_field(

if field_name in ("E", "H") or field_name[0] == "S":
# Derived fields
field_data = self._get_scalar_field(field_monitor_name, field_name, val)
field_data = self._get_scalar_field(field_monitor_name, field_name, val, phase=phase)
else:
# Direct field component (e.g. Ex)
field_monitor_data = self.load_field_monitor(field_monitor_name)
if field_name not in field_monitor_data.field_components:
raise DataError(f"field_name '{field_name}' not found in data.")
field_component = field_monitor_data.field_components[field_name]
field_component.name = field_name
field_component = self.apply_phase(data=field_component, phase=phase)
field_data = self._field_component_value(field_component, val)

if scale == "dB":
Expand Down