Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more sax models #364

Merged
merged 4 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions gplugins/klayout/dataprep/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(self, gdspath, cell_name: str | None = None) -> None:
self.layout = lib.cell_by_name(cell_name) if cell_name else lib.top_cell()
self.lib = lib
self.regions = {}
self.cell = lib[lib.top_cell().cell_index()]

def __getitem__(self, layer: tuple[int, int]) -> Region:
_assert_is_layer(layer)
Expand Down Expand Up @@ -142,13 +143,9 @@ def write_gds(
else:
c.write(gdspath)

def plot(self, **kwargs):
def plot(self) -> kf.KCell:
"""Plot regions."""
gdspath = GDSDIR_TEMP / "out.gds"
self.write_gds(gdspath=gdspath, **kwargs)
gf.clear_cache()
c = gf.import_gds(gdspath)
return c.plot()
return self.cell

def get_kcell(
self, keep_original: bool = True, cellname: str = "Unnamed"
Expand Down
135 changes: 121 additions & 14 deletions gplugins/sax/models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from __future__ import annotations

from functools import cache
import inspect
from collections.abc import Callable, Iterable
from functools import cache, partial
from inspect import getmembers

import jax
import jax.numpy as jnp
import sax
from numpy.typing import NDArray
from sax import SDict
from sax.utils import reciprocal

nm = 1e-3

FloatArray = NDArray[jnp.floating]
Float = float | FloatArray

################
# PassThrus
################
Expand Down Expand Up @@ -170,6 +177,7 @@ def grating_coupler(
https://github.com/flaport/photontorch/blob/master/photontorch/components/gratingcouplers.py

Args:
wl: wavelength.
wl0: center wavelength.
loss: in dB.
reflection: from waveguide side.
Expand Down Expand Up @@ -295,7 +303,68 @@ def coupler_single_wavelength(*, coupling: float = 0.5) -> SDict:
)


def mmi1x2() -> SDict:
################
# MMIs
################


def _mmi_amp(
wl: Float = 1.55, wl0: Float = 1.55, fwhm: Float = 0.2, loss_dB: Float = 0.3
):
max_power = 10 ** (-abs(loss_dB) / 10)
f = 1 / wl
f0 = 1 / wl0
f1 = 1 / (wl0 + fwhm / 2)
f2 = 1 / (wl0 - fwhm / 2)
_fwhm = f2 - f1

sigma = _fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))
power = jnp.exp(-((f - f0) ** 2) / (2 * sigma**2))
power = max_power * power / power.max() / 2
return jnp.sqrt(power)


def mmi1x2(
wl: Float = 1.55, wl0: Float = 1.55, fwhm: Float = 0.2, loss_dB: Float = 0.3
) -> sax.SDict:
Comment on lines +327 to +329
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code_refinement): Consider validating input parameters for mmi1x2.

Given the physical nature of the parameters, it might be beneficial to validate their ranges to ensure they are within physically plausible limits.

thru = _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB)
return sax.reciprocal(
{
("o1", "o2"): thru,
("o1", "o3"): thru,
}
)


def mmi2x2(
wl: Float = 1.55,
wl0: Float = 1.55,
fwhm: Float = 0.2,
loss_dB: Float = 0.3,
shift: Float = 0.005,
) -> sax.SDict:
"""Returns 2x2 MMI model.

Args:
wl: wavelength.
wl0: center wavelength.
fwhm: full width half maximum.
loss_dB: loss in dB.
shift: wavelength shift.
"""
thru = _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB)
cross = 1j * _mmi_amp(wl=wl, wl0=wl0 + shift, fwhm=fwhm, loss_dB=loss_dB)
return sax.reciprocal(
{
("o1", "o3"): thru,
("o1", "o4"): cross,
("o2", "o3"): cross,
("o2", "o4"): thru,
}
)


def mmi1x2_ideal() -> SDict:
"""Returns an ideal 1x2 splitter."""
return reciprocal(
{
Expand All @@ -305,7 +374,7 @@ def mmi1x2() -> SDict:
)


def mmi2x2(*, coupling: float = 0.5) -> SDict:
def mmi2x2_ideal(*, coupling: float = 0.5) -> SDict:
"""Returns an ideal 2x2 splitter.

Args:
Expand All @@ -323,21 +392,59 @@ def mmi2x2(*, coupling: float = 0.5) -> SDict:
)


models = dict(
straight=straight,
bend_euler=bend,
mmi1x2=mmi1x2,
mmi2x2=mmi2x2,
attenuator=attenuator,
taper=straight,
phase_shifter=phase_shifter,
grating_coupler=grating_coupler,
coupler=coupler,
)
################
# Crossings
################


@jax.jit
def crossing(wl: Float = 1.5) -> sax.SDict:
one = jnp.ones_like(jnp.asarray(wl))
return sax.reciprocal(
{
("o1", "o3"): one,
("o2", "o4"): one,
}
)


################
# Models Dict
################
def get_models(modules) -> dict[str, Callable[..., sax.SDict]]:
"""Returns all models in a module or list of modules."""
models = {}
modules = modules if isinstance(modules, Iterable) else [modules]

for module in modules:
for t in getmembers(module):
name = t[0]
func = t[1]
if not callable(func):
continue
_func = func
while isinstance(_func, partial):
_func = _func.func
try:
sig = inspect.signature(_func)
except ValueError:
continue
if str(sig.return_annotation) in {
"sax.SDict",
"SDict",
} and not name.startswith("_"):
models[name] = func
return models


if __name__ == "__main__":
import sys

import gplugins.sax as gs

models = get_models(sys.modules[__name__])
for i in models.keys():
print(i)

gs.plot_model(grating_coupler)
# gs.plot_model(coupler)
21 changes: 16 additions & 5 deletions notebooks/klayout_dataprep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
"import gdsfactory as gf\n",
"from gdsfactory.generic_tech.layer_map import LAYER\n",
"\n",
"import gplugins.klayout.dataprep as dp\n",
"\n",
"gf.CONF.display_type = \"klayout\""
"import gplugins.klayout.dataprep as dp"
]
},
{
Expand Down Expand Up @@ -74,6 +72,7 @@
"d[LAYER.N] = d[\n",
" LAYER.WG\n",
"].copy() # make sure you add the copy to create a copy of the layer\n",
"d.show()\n",
"d.plot()"
]
},
Expand All @@ -93,6 +92,7 @@
"outputs": [],
"source": [
"d[LAYER.N].clear()\n",
"d.show()\n",
"d.plot()"
]
},
Expand All @@ -114,6 +114,7 @@
"outputs": [],
"source": [
"d[LAYER.SLAB90] = d[LAYER.WG] + 2 # size layer by 4 um\n",
"d.show()\n",
"d.plot()"
]
},
Expand All @@ -134,7 +135,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"d[LAYER.SLAB90] += 2 # size layer by 4 um\n",
"d[LAYER.SLAB90] -= 2 # size layer by 2 um\n",
"d.plot()"
Expand Down Expand Up @@ -235,7 +235,6 @@
"\n",
"gdspath = \"mzi_fill.gds\"\n",
"c.write(gdspath)\n",
"c = gf.import_gds(gdspath)\n",
"c.plot()"
]
}
Expand All @@ -249,6 +248,18 @@
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
74 changes: 12 additions & 62 deletions notebooks/klayout_drc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"layer = LAYER.WG\n",
"\n",
"\n",
Expand Down Expand Up @@ -168,67 +167,6 @@
"c.show() # show in klayout\n",
"c.plot()"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"# Klayout connectivity checks\n",
"\n",
"You can you can to check for component overlap and unconnected pins using klayout DRC.\n",
"\n",
"\n",
"The easiest way is to write all the pins on the same layer and define the allowed pin widths.\n",
"This will check for disconnected pins or ports with width mismatch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"\n",
"import gplugins.klayout.drc.write_connectivity as wc\n",
"\n",
"nm = 1e-3\n",
"\n",
"rules = [\n",
" wc.write_connectivity_checks(pin_widths=[0.5, 0.9, 0.45], pin_layer=LAYER.PORT)\n",
"]\n",
"script = wc.write_drc_deck_macro(rules=rules, layers=None)"
]
},
{
"cell_type": "markdown",
"id": "8",
"metadata": {},
"source": [
"You can also define the connectivity checks per section"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"connectivity_checks = [\n",
" wc.ConnectivyCheck(cross_section=\"xs_sc\", pin_length=1 * nm, pin_layer=(1, 10)),\n",
" wc.ConnectivyCheck(\n",
" cross_section=\"xs_sc_auto_widen\", pin_length=1 * nm, pin_layer=(1, 10)\n",
" ),\n",
"]\n",
"rules = [\n",
" wc.write_connectivity_checks_per_section(connectivity_checks=connectivity_checks),\n",
"]\n",
"script = wc.write_drc_deck_macro(rules=rules, layers=None)"
]
}
],
"metadata": {
Expand All @@ -240,6 +178,18 @@
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
Loading