Skip to content

Commit

Permalink
add FieldData.apply_phase and phase kwarg to SimulationData.plot_field
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Nov 22, 2023
1 parent be8d5e7 commit 5acdb50
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `SimulationData.mnt_data_from_file()` method to load only a single monitor data object from a simulation data `.hdf5` file.
- `_hash_self` to base model, uses `hashlib` to hash a Tidy3D component the same way every session.
- `ComponentModeler.plot_sim_eps()` method to plot the simulation permittivity and ports.
- `FieldData.apply_phase(phase)` to multiply field data by a phase.
- Optional `phase` argument to `SimulationData.plot_field` that applies a phase to complex-valued fields.

### Changed
- Indent for the json string of Tidy3D models has been changed to `None` when used internally; kept as `indent=4` for writing to `json` and `yaml` files.
Expand Down
17 changes: 17 additions & 0 deletions tests/test_data/test_monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,20 @@ def test_outer_dot():
_ = field_data.outer_dot(mode_data)
_ = mode_data.outer_dot(field_data)
_ = field_data.outer_dot(field_data)


@pytest.mark.parametrize("phase_shift", np.linspace(0, 2 * np.pi, 10))
def test_field_data_phase(phase_shift):
def get_combined_phase(data):
field_sum = 0.0
for fld_cmp in data.field_components.values():
field_sum += np.sum(fld_cmp.values)
return np.angle(field_sum)

fld_data1 = make_field_data()
fld_data2 = fld_data1.apply_phase(phase_shift)

phase1 = get_combined_phase(fld_data1)
phase2 = get_combined_phase(fld_data2)

assert np.allclose(phase2, np.angle(np.exp(1j * (phase1 + phase_shift))))
21 changes: 14 additions & 7 deletions tests/test_data/test_sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,38 +119,45 @@ def test_centers():
_ = sim_data.at_centers(mon.name)


def test_plot():
@pytest.mark.parametrize("phase", [0, 1.0])
def test_plot(phase):
sim_data = make_sim_data()

# plot regular field data
for field_cmp in sim_data.simulation.get_monitor_by_name("field").fields:
field_data = sim_data["field"].field_components[field_cmp]
for axis_name in "xyz":
xyz_kwargs = {axis_name: field_data.coords[axis_name][0]}
_ = sim_data.plot_field("field", field_cmp, val="imag", f=1e14, **xyz_kwargs)
_ = sim_data.plot_field(
"field", field_cmp, val="imag", f=1e14, phase=phase, **xyz_kwargs
)
plt.close()
for axis_name in "xyz":
xyz_kwargs = {axis_name: 0}
_ = sim_data.plot_field("field", "int", f=1e14, **xyz_kwargs)
_ = sim_data.plot_field("field", "int", f=1e14, phase=phase, **xyz_kwargs)
plt.close()

# plot field time data
for field_cmp in sim_data.simulation.get_monitor_by_name("field_time").fields:
field_data = sim_data["field_time"].field_components[field_cmp]
for axis_name in "xyz":
xyz_kwargs = {axis_name: field_data.coords[axis_name][0]}
_ = sim_data.plot_field("field_time", field_cmp, val="real", t=0.0, **xyz_kwargs)
_ = sim_data.plot_field(
"field_time", field_cmp, val="real", phase=phase, t=0.0, **xyz_kwargs
)
plt.close()
for axis_name in "xyz":
xyz_kwargs = {axis_name: 0}
_ = sim_data.plot_field("field_time", "int", t=0.0, **xyz_kwargs)
_ = sim_data.plot_field("field_time", "int", t=0.0, phase=phase, **xyz_kwargs)
plt.close()

# plot mode field data
for field_cmp in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
_ = sim_data.plot_field("mode_solver", field_cmp, val="real", f=1e14, mode_index=1)
_ = sim_data.plot_field(
"mode_solver", field_cmp, val="real", f=1e14, mode_index=1, phase=phase
)
plt.close()
_ = sim_data.plot_field("mode_solver", "int", f=1e14, mode_index=1)
_ = sim_data.plot_field("mode_solver", "int", f=1e14, mode_index=1, phase=phase)
plt.close()


Expand Down
17 changes: 17 additions & 0 deletions tidy3d/components/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ class AbstractFieldDataset(Dataset, ABC):
def field_components(self) -> Dict[str, DataArray]:
"""Maps the field components to thier associated data."""

def apply_phase(self, phase: float) -> AbstractFieldDataset:
"""Create a copy of this dataset where all elements are phase shifted by value (radians)."""
phasor = np.exp(1j * phase)
field_components_shifted = {}
for fld_name, fld_cmp in self.field_components.items():
fld_cmp_shifted = phasor * fld_cmp
field_components_shifted[fld_name] = fld_cmp_shifted
return self.updated_copy(**field_components_shifted)

@property
@abstractmethod
def grid_locations(self) -> Dict[str, str]:
Expand Down Expand Up @@ -265,6 +274,14 @@ class FieldTimeDataset(ElectromagneticFieldDataset):
description="Spatial distribution of the z-component of the magnetic field.",
)

def apply_phase(self, phase: float) -> AbstractFieldDataset:
"""Create a copy of this dataset where all elements are phase shifted by value (radians)."""

if phase != 0.0:
raise ValueError("Can't apply phase to time-domain field data, which is real-valued.")

return self


class ModeSolverDataset(ElectromagneticFieldDataset):
"""Dataset storing scalar components of E and H fields as a function of freq. and mode_index.
Expand Down
13 changes: 13 additions & 0 deletions tidy3d/components/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ def plot_field(
val: FieldVal = "real",
scale: PlotScale = "lin",
eps_alpha: float = 0.2,
phase: float = 0.0,
robust: bool = True,
vmin: float = None,
vmax: float = None,
Expand All @@ -466,6 +467,9 @@ def plot_field(
eps_alpha : float = 0.2
Opacity of the structure permittivity.
Must be between 0 and 1 (inclusive).
phase : float = 0.0
Optional phase (radians) to apply to the fields.
Only has an effect on frequency-domain fields.
robust : bool = True
If True and vmin or vmax are absent, uses the 2nd and 98th percentiles of the data
to compute the color limits. This helps in visualizing the field patterns especially
Expand Down Expand Up @@ -514,6 +518,15 @@ def plot_field(
field_component.name = field_name
field_data = self._field_component_value(field_component, val)

if phase != 0.0:
if np.any(np.iscomplex(field_data)):
field_data *= np.exp(1j * phase)
else:
log.warning(
f"Non-zero phase of {phase} specified but the data being plotted is "
"real-valued. The phase will be ignored in the plot."
)

if scale == "dB":
if val == "phase":
log.warning("Ploting phase component in log scale masks the phase sign.")
Expand Down

0 comments on commit 5acdb50

Please sign in to comment.