Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify dispersion fitters with a web.run interface #909

Merged
merged 10 commits into from
May 31, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Add `Medium2D` to full simulation in tests.
- `DispersionFitter` and `StableDispersionFitter` unified in a single `DispersionFitter` interface.
- `StableDispersionFitter` deprecated, with stable fitter now being run instead through `plugins.dispersion.web.run(DispersionFitter)`.

### Fixed
- Plotting 2D materials in `SimulationData.plot_field` and other circumstances.
Expand Down
125 changes: 84 additions & 41 deletions tests/test_plugins/test_dispersion_fitter.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,77 @@
import pytest
import numpy as np
import matplotlib.pyplot as plt
import pydantic
import pytest
import responses

import tidy3d as td

from tidy3d.plugins.dispersion import DispersionFitter
from tidy3d.plugins.mode import ModeSolver
from tidy3d.plugins.mode.solver import compute_modes
from tidy3d import FieldData, ScalarFieldDataArray, FieldMonitor
from tidy3d.plugins.smatrix.smatrix import Port, ComponentModeler
from tidy3d.plugins.smatrix.smatrix import ComponentModeler
from ..utils import clear_tmp, run_emulated
from tidy3d.plugins.dispersion.web import run as run_fitter


@pytest.fixture
def random_data():
data_points = 11
wvl_um = np.linspace(1, 2, data_points)
n_data = np.random.random(data_points)
k_data = np.random.random(data_points)
return wvl_um, n_data, k_data


@pytest.fixture
def mock_remote_api(monkeypatch):
def mock_url(*args, **kwargs):
return "http://monkeypatched.com"

monkeypatch.setattr("tidy3d.plugins.dispersion.web.FitterData._set_url", mock_url)
responses.add(responses.GET, f"{mock_url()}/health", status=200)
responses.add(
responses.POST,
f"{mock_url()}/dispersion/fit",
json={"message": td.PoleResidue().json(), "rms": 1e-16},
status=200,
)

def _test_coeffs():

def test_coeffs():
"""make sure pack_coeffs and unpack_coeffs are reciprocal"""
num_poles = 10
coeffs = np.random.random(4 * num_poles)
a, c = _unpack_coeffs(coeffs)
coeffs_ = _pack_coeffs(a, c)
a_, c_ = _unpack_coeffs(coeffs_)
a, c = DispersionFitter._unpack_coeffs(coeffs)
coeffs_ = DispersionFitter._pack_coeffs(a, c)
a_, c_ = DispersionFitter._unpack_coeffs(coeffs_)
assert np.allclose(coeffs, coeffs_)
assert np.allclose(a, a_)
assert np.allclose(c, c_)


def _test_pole_coeffs():
def test_pole_coeffs():
"""make sure coeffs_to_poles and poles_to_coeffs are reciprocal"""
num_poles = 10
coeffs = np.random.random(4 * num_poles)
poles = _coeffs_to_poles(coeffs)
coeffs_ = _poles_to_coeffs(poles)
poles_ = _coeffs_to_poles(coeffs_)
poles = DispersionFitter._coeffs_to_poles(coeffs)
coeffs_ = DispersionFitter._poles_to_coeffs(poles)
poles_ = DispersionFitter._coeffs_to_poles(coeffs_)
assert np.allclose(coeffs, coeffs_)
assert np.allclose(poles, poles_)


@clear_tmp
def test_dispersion():
"""performs a fit on some random data"""
num_data = 10
n_data = np.random.random(num_data)
wvls = np.linspace(1, 2, num_data)
fitter = DispersionFitter(wvl_um=wvls, n_data=n_data)
@responses.activate
def test_lossless_dispersion(random_data, mock_remote_api):
"""perform fitting on random data"""
wvl_um, n_data, _ = random_data
fitter = DispersionFitter(wvl_um=wvl_um.tolist(), n_data=tuple(n_data))
medium, rms = fitter._fit_single()
medium, rms = fitter.fit(num_tries=2)
medium.to_file("tests/tmp/medium_fit.json")
medium, rms = run_fitter(fitter)


k_data = np.random.random(num_data)
fitter = DispersionFitter(wvl_um=wvls, n_data=n_data, k_data=k_data)
@responses.activate
def test_lossy_dispersion(random_data, mock_remote_api):
"""perform fitting on random lossy data"""
wvl_um, n_data, k_data = random_data
fitter = DispersionFitter(wvl_um=wvl_um, n_data=n_data, k_data=k_data)
medium, rms = fitter._fit_single()
medium, rms = fitter.fit(num_tries=2)
medium, rms = run_fitter(fitter)


def test_dispersion_load():
Expand All @@ -57,22 +80,42 @@ def test_dispersion_load():
medium, rms = fitter.fit(num_tries=20)


def test_dispersion_plot():
def test_dispersion_plot(random_data):
"""plots a medium fit from file"""
fitter = DispersionFitter.from_file("tests/data/nk_data.csv", skiprows=1, delimiter=",")
medium, rms = fitter.fit(num_tries=20)
wvl_um, n_data, k_data = random_data

fitter = DispersionFitter(wvl_um=wvl_um, n_data=n_data)
fitter.plot()
medium, rms = fitter.fit(num_tries=2)
fitter.plot(medium)

fitter = DispersionFitter(wvl_um=wvl_um, n_data=n_data, k_data=k_data)
fitter.plot()
medium, rms = fitter.fit(num_tries=2)
fitter.plot(medium)


def test_dispersion_set_wvg_range():
def test_dispersion_set_wvg_range(random_data):
"""set wavelength range function"""
num_data = 50
n_data = np.random.random(num_data)
wvls = np.linspace(1, 2, num_data)
fitter = DispersionFitter(wvl_um=wvls, n_data=n_data)

wvl_min = np.random.random(1)[0] * 0.5 + 1
wvl_max = wvl_min + 0.5
fitter = fitter.copy(update=dict(wvl_range=[wvl_min, wvl_max]))
assert len(fitter.freqs) < num_data
wvl_um, n_data, k_data = random_data
fitter = DispersionFitter(wvl_um=wvl_um, n_data=n_data)

wvl_range = [1.2, 1.8]
fitter = fitter.copy(update={"wvl_range": wvl_range})
assert len(fitter.freqs) == 7
medium, rms = fitter.fit(num_tries=2)

wvl_range = [1.2, 2.8]
fitter = fitter.copy(update={"wvl_range": wvl_range, "k_data": k_data})
assert len(fitter.freqs) == 9
medium, rms = fitter.fit(num_tries=2)

wvl_range = [0.2, 1.8]
fitter = fitter.copy(update={"wvl_range": wvl_range})
assert len(fitter.freqs) == 9
medium, rms = fitter.fit(num_tries=2)

wvl_range = [0.2, 2.8]
fitter = fitter.copy(update={"wvl_range": wvl_range, "k_data": k_data})
assert len(fitter.freqs) == 11
medium, rms = fitter.fit(num_tries=2)
2 changes: 1 addition & 1 deletion tidy3d/plugins/dispersion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" Imports from dispersion fitter plugin. """

from .fit import DispersionFitter
from .fit_web import StableDispersionFitter, AdvancedFitterParam
from .web import AdvancedFitterParam, StableDispersionFitter
Loading