Skip to content

Commit

Permalink
shift port not source, use local grid for shift dist, reorg sims crea…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
tylerflex authored and momchil-flex committed Apr 22, 2022
1 parent a6dbde9 commit 40f89c0
Showing 1 changed file with 42 additions and 22 deletions.
64 changes: 42 additions & 22 deletions tidy3d/plugins/smatrix/smatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pydantic as pd
import numpy as np

from ...constants import HERTZ, C_0, MICROMETER
from ...constants import HERTZ, C_0
from ...components.simulation import Simulation
from ...components.geometry import Box
from ...components.mode import ModeSpec
Expand All @@ -28,7 +28,7 @@ class Port(Box):

direction: Direction = pd.Field(
...,
title="SHIFT_VALUEDirection",
title="Direction",
description="'+' or '-', defining which direction is considered 'input'.",
)
mode_spec: ModeSpec = pd.Field(
Expand Down Expand Up @@ -93,15 +93,6 @@ class ComponentModeler(Tidy3dBaseModel):
description="Batch Data of task used to compute S matrix. Set internally.",
)

shift_value: pd.PositiveFloat = pd.Field(
0.1,
title="Shift Value",
description="Distance between the source and monitor of a given port. "
"Should be greater than one grid cell for best results, while being "
"small enough that the waveguide cross section does not change within the distance.",
units=MICROMETER,
)

@pd.validator("simulation", always=True)
def _sim_has_no_sources(cls, val):
"""Make sure simulation has no sources as they interfere with tool."""
Expand Down Expand Up @@ -146,11 +137,41 @@ def plot_sim(self, x: float = None, y: float = None, z: float = None, ax: Ax = N
return sim_plot.plot(x=x, y=y, z=z, ax=ax)

def _shift_value_signed(self, port: Port) -> float:
"""How far (signed) to shift the monitor from the source."""
return self.shift_value if port.direction == "+" else -1 * self.shift_value
"""How far (signed) to shift the source from the monitor."""

# get the grid boundaries and sizes along port normal from the simulation
normal_axis = port.size.index(0.0)
grid = self.simulation.grid
grid_boundaries = grid.boundaries.to_list[normal_axis]
grid_centers = grid.centers.to_list[normal_axis]

# get the index of the grid cell where the port lies
port_position = port.center[normal_axis]
port_index = np.argwhere(port_position > grid_boundaries)[-1]

# shift the port to the left
if port.direction == "+":
shifted_index = port_index - 2
if shifted_index < 0:
raise SetupError(
f"Port {port.name} normal is too close to boundary "
f"on -{'xyz'[normal_axis]} side."
)

# shift the port to the right
else:
shifted_index = port_index + 2
if shifted_index >= len(grid_centers):
raise SetupError(
f"Port {port.name} normal is too close to boundary "
f"on +{'xyz'[normal_axis]} side."
)

new_pos = grid_centers[shifted_index]
return new_pos - port_position

def _shift_port(self, port: Port) -> Port:
"""Generate a new port shifted by one grid cell in normal direction."""
"""Generate a new port shifted by the shift amount in normal direction."""

shift_value = self._shift_value_signed(port)
center_shifted = list(port.center)
Expand All @@ -165,26 +186,25 @@ def _task_name(self, port_source: Port, mode_index: int) -> str:

def _make_sims(self) -> Dict[str, Simulation]:
"""Generate all the :class:`Simulation` objects for the S matrix calculation."""

mode_monitors = [self._to_monitor(port) for port in self.ports]
sim_dict = {}
for port_source in self.ports:
port_source = self._shift_port(port_source)
for mode_source in self._to_sources(port_source):
sim_copy = self.simulation.copy(deep=True)
sim_copy.sources = [mode_source]
for port_monitor in self.ports:
if port_source == port_monitor:
port_monitor = self._shift_port(port_source)
mode_monitor = self._to_monitor(port_monitor)
sim_copy.monitors.append(mode_monitor)
task_name = self._task_name(port_source, mode_source.mode_index)
sim_dict[task_name] = sim_copy
sim_copy.monitors += mode_monitors
task_name = self._task_name(port_source, mode_source.mode_index)
sim_dict[task_name] = sim_copy
return sim_dict

def _run_sims(
self, sim_dict: Dict[str, Simulation], folder_name: str, path_dir: str
) -> "BatchData":
"""Run :class:`Simulations` for each port and return the batch after saving."""
batch = Batch(simulations=sim_dict, folder_name=folder_name)

batch = Batch(simulations=sim_dict, folder_name=folder_name)
batch.upload()
batch.start()
batch.monitor()
Expand Down

0 comments on commit 40f89c0

Please sign in to comment.