Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more sax models #364

Merged
merged 4 commits into from
Mar 24, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 99 additions & 14 deletions gplugins/sax/models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from __future__ import annotations

from functools import cache
import inspect
from collections.abc import Callable
from functools import cache, partial

import jax
import jax.numpy as jnp
import sax
from numpy.typing import NDArray
from sax import SDict
from sax.utils import reciprocal

nm = 1e-3

FloatArray = NDArray[jnp.floating]
Float = float | FloatArray

################
# PassThrus
################
Expand Down Expand Up @@ -295,7 +301,59 @@ def coupler_single_wavelength(*, coupling: float = 0.5) -> SDict:
)


def mmi1x2() -> SDict:
################
# MMIs
################


def _mmi_amp(
wl: Float = 1.55, wl0: Float = 1.55, fwhm: Float = 0.2, loss_dB: Float = 0.3
):
max_power = 10 ** (-abs(loss_dB) / 10)
f = 1 / wl
f0 = 1 / wl0
f1 = 1 / (wl0 + fwhm / 2)
f2 = 1 / (wl0 - fwhm / 2)
_fwhm = f2 - f1

sigma = _fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))
power = jnp.exp(-((f - f0) ** 2) / (2 * sigma**2))
power = max_power * power / power.max() / 2
return jnp.sqrt(power)


def mmi1x2(
wl: Float = 1.55, wl0: Float = 1.55, fwhm: Float = 0.2, loss_dB: Float = 0.3
) -> sax.SDict:
Comment on lines +327 to +329
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code_refinement): Consider validating input parameters for mmi1x2.

Given the physical nature of the parameters, it might be beneficial to validate their ranges to ensure they are within physically plausible limits.

thru = _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB)
return sax.reciprocal(
{
("o1", "o2"): thru,
("o1", "o3"): thru,
}
)


def mmi2x2(
wl: Float = 1.55,
wl0: Float = 1.55,
fwhm: Float = 0.2,
loss_dB: Float = 0.3,
shift: Float = 0.005,
) -> sax.SDict:
thru = _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB)
cross = 1j * _mmi_amp(wl=wl, wl0=wl0 + shift, fwhm=fwhm, loss_dB=loss_dB)
return sax.reciprocal(
{
("o1", "o3"): thru,
("o1", "o4"): cross,
("o2", "o3"): cross,
("o2", "o4"): thru,
}
)


def mmi1x2_ideal() -> SDict:
"""Returns an ideal 1x2 splitter."""
return reciprocal(
{
Expand All @@ -305,7 +363,7 @@ def mmi1x2() -> SDict:
)


def mmi2x2(*, coupling: float = 0.5) -> SDict:
def mmi2x2_ideal(*, coupling: float = 0.5) -> SDict:
"""Returns an ideal 2x2 splitter.

Args:
Expand All @@ -323,21 +381,48 @@ def mmi2x2(*, coupling: float = 0.5) -> SDict:
)


models = dict(
straight=straight,
bend_euler=bend,
mmi1x2=mmi1x2,
mmi2x2=mmi2x2,
attenuator=attenuator,
taper=straight,
phase_shifter=phase_shifter,
grating_coupler=grating_coupler,
coupler=coupler,
)
################
# Crossings
################


@jax.jit
def crossing(wl: Float = 1.5) -> sax.SDict:
one = jnp.ones_like(jnp.asarray(wl))
return sax.reciprocal(
{
("o1", "o3"): one,
("o2", "o4"): one,
}
)


################
# Models Dict
################
def get_models() -> dict[str, Callable[..., sax.SDict]]:
models = {}
for name, func in list(globals().items()):
if not callable(func):
continue
_func = func
while isinstance(_func, partial):
_func = _func.func
try:
sig = inspect.signature(_func)
except ValueError:
continue
if str(sig.return_annotation) in {"sax.SDict", "SDict"} and not name.startswith(
"_"
):
models[name] = func
return models


if __name__ == "__main__":
import gplugins.sax as gs

models = get_models()

gs.plot_model(grating_coupler)
# gs.plot_model(coupler)
Loading