diff --git a/src/kfactory/layout.py b/src/kfactory/layout.py index b63c3564..692273c9 100644 --- a/src/kfactory/layout.py +++ b/src/kfactory/layout.py @@ -67,7 +67,7 @@ get_cell_name, ) from .settings import Info, KCellSettings, KCellSettingsUnits -from .typings import K, KCellParams, MetaData, T +from .typings import K, KCellParams, KCellSpec, MetaData, T from .utilities import load_layout_options, save_layout_options kcl: KCLayout @@ -1638,6 +1638,57 @@ def get_dcross_section( def __repr__(self) -> str: return f"{self.__class__.__name__}({self.name}, n={len(self.kcells)})" + @overload + def get_component( + self, spec: KCellSpec, *, output_type: type[K], **cell_kwargs: Any + ) -> K: ... + + @overload + def get_component( + self, + spec: int, + ) -> KCell: ... + + @overload + def get_component( + self, + spec: str, + **cell_kwargs: Any, + ) -> ProtoTKCell[Any]: ... + + @overload + def get_component( + self, + spec: Callable[..., K], + **cell_kwargs: Any, + ) -> K: ... + @overload + def get_component(self, spec: K) -> K: ... + + def get_component( + self, + spec: KCellSpec, + *, + output_type: type[K] | None = None, + **cell_kwargs: Any, + ) -> ProtoTKCell[Any]: + if output_type: + return output_type(base=self.get_component(spec, **cell_kwargs).base) + if callable(spec): + return spec(**cell_kwargs) + if isinstance(spec, dict): + settings = spec.get("settings", {}).copy() + settings.update(cell_kwargs) + return self.factories[spec["component"]](**settings) + elif isinstance(spec, str): + return self.factories[spec](**cell_kwargs) + if cell_kwargs: + raise ValueError( + "Cell kwargs are not allowed for retrieving static cells by integer " + "or the cell itself." + ) + return self.kcells[spec] if isinstance(spec, int) else spec + KCLayout.model_rebuild() TVCell.model_rebuild() diff --git a/src/kfactory/typings.py b/src/kfactory/typings.py index 88449c71..5772ffdb 100644 --- a/src/kfactory/typings.py +++ b/src/kfactory/typings.py @@ -1,6 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Annotated, Any, ParamSpec, TypeAlias, TypeVar +from collections.abc import Callable +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + NotRequired, + ParamSpec, + TypeAlias, + TypedDict, + TypeVar, +) import klayout.db as kdb import klayout.lay as lay @@ -12,7 +22,6 @@ from .layout import Constants from .port import ProtoPort - T = TypeVar("T") K = TypeVar("K", bound="ProtoTKCell[Any]") KC = TypeVar("KC", bound="ProtoTKCell[Any]", covariant=True) @@ -26,6 +35,12 @@ TBaseCell = TypeVar("TBaseCell", bound="BaseKCell", covariant=True) KCellParams = ParamSpec("KCellParams") + +class KCellSpecDict(TypedDict): + component: str + settings: NotRequired[dict[str, Any]] + + AnyTrans = TypeVar( "AnyTrans", bound=kdb.Trans | kdb.DTrans | kdb.ICplxTrans | kdb.DCplxTrans ) @@ -106,3 +121,6 @@ """Database unit or micrometer.""" Angle: TypeAlias = int """Integer in the range of `[0,1,2,3]` which are increments in 90°.""" +KCellSpec: TypeAlias = ( + "int | str | KCellSpecDict | ProtoTKCell[Any] | Callable[..., ProtoTKCell[Any]]" +) diff --git a/tests/test_layout.py b/tests/test_layout.py index 524b7c3a..f45b741a 100644 --- a/tests/test_layout.py +++ b/tests/test_layout.py @@ -270,5 +270,77 @@ def test_kclayout_assign(kcl: kf.KCLayout, layers: Layers) -> None: assert len(list(kcl2.layout.each_cell())) == 1 +def test_get_component(layers: Layers) -> None: + # normal functions + kf.kcl.get_component( + "straight", width=1000, length=10_000, layer=kf.kdb.LayerInfo(1, 0) + ) + kf.kcl.get_component( + kf.cells.straight.straight, width=1, length=10, layer=layers.WG + ) + kf.kcl.get_component( + kf.cells.straight.straight(width=1, length=10, layer=layers.WG) + ) + kf.kcl.get_component( + kf.cells.straight.straight(width=1, length=10, layer=layers.WG).cell_index() + ) + + # output_type functions + kf.kcl.get_component( + "straight", + width=1000, + length=10_000, + layer=kf.kdb.LayerInfo(1, 0), + output_type=kf.DKCell, + ) + kf.kcl.get_component( + { + "component": "straight", + "settings": dict(width=1000, length=10_000, layer=kf.kdb.LayerInfo(1, 0)), + }, + output_type=kf.DKCell, + ) + kf.kcl.get_component( + kf.cells.straight.straight, + width=1, + length=10, + layer=layers.WG, + output_type=kf.DKCell, + ) + kf.kcl.get_component( + kf.cells.straight.straight(width=1, length=10, layer=layers.WG), + output_type=kf.DKCell, + ) + kf.kcl.get_component( + kf.cells.straight.straight(width=1, length=10, layer=layers.WG).cell_index(), + output_type=kf.DKCell, + ) + + # raises errors + with pytest.raises(ValueError): + kf.kcl.get_component( + kf.cells.straight.straight(width=1, length=10, layer=layers.WG), + output_type=kf.DKCell, + width=1, + length=10, + layer=layers.WG, + ) + with pytest.raises(ValueError): + kf.kcl.get_component( + kf.cells.straight.straight( + width=1, length=10, layer=layers.WG + ).cell_index(), + output_type=kf.DKCell, + width=1, + length=10, + layer=layers.WG, + ) + with pytest.raises(TypeError): + kf.kcl.get_component( + {"component": "straight"}, + output_type=kf.DKCell, + ) + + if __name__ == "__main__": pytest.main(["-s", __file__])