Skip to content

Commit

Permalink
Added fast dispersion fitter
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyflex committed Jul 24, 2023
1 parent 7a8b753 commit 7d92369
Show file tree
Hide file tree
Showing 4 changed files with 887 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Configuration option `config.log_suppression` can be used to control the suppression of log messages.
- `abort()` for `Job` and `mode solver`, Job or mode solver whose status is not success or error(e.g. running, draft) can be aborted, if Job or mode solver is abort, it can't be submitted, a new one needs to be created and submitted.
- `web.abort()` and `Job.abort()` methods allowing to abort running tasks without deleting them. If a task is aborted, it cannot be restarted later, a new one needs to be created and submitted.
- `FastDispersionFitter` for fast fitting of material dispersion data.

### Changed

Expand Down
28 changes: 27 additions & 1 deletion tests/test_plugins/test_dispersion_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import responses

import tidy3d as td
from tidy3d.plugins.dispersion import DispersionFitter
from tidy3d.plugins.dispersion import DispersionFitter, FastDispersionFitter
from tidy3d.plugins.dispersion import AdvancedFastFitterParam
from tidy3d.plugins.dispersion.web import run as run_fitter


advanced_param = AdvancedFastFitterParam(num_iters=1, passivity_num_iters=1)


@pytest.fixture
def random_data():
data_points = 11
Expand Down Expand Up @@ -63,6 +67,9 @@ def test_lossless_dispersion(random_data, mock_remote_api):
medium, rms = fitter.fit(num_tries=2)
medium, rms = run_fitter(fitter)

fitter = FastDispersionFitter(wvl_um=wvl_um.tolist(), n_data=tuple(n_data))
medium, rms = fitter.fit(advanced_param=advanced_param)


@responses.activate
def test_lossy_dispersion(random_data, mock_remote_api):
Expand All @@ -73,12 +80,18 @@ def test_lossy_dispersion(random_data, mock_remote_api):
medium, rms = fitter.fit(num_tries=2)
medium, rms = run_fitter(fitter)

fitter = FastDispersionFitter(wvl_um=wvl_um.tolist(), n_data=n_data, k_data=k_data)
medium, rms = fitter.fit(advanced_param=advanced_param)


def test_dispersion_load():
"""loads dispersion model from nk data file"""
fitter = DispersionFitter.from_file("tests/data/nk_data.csv", skiprows=1, delimiter=",")
medium, rms = fitter.fit(num_tries=20)

fitter = FastDispersionFitter.from_file("tests/data/nk_data.csv", skiprows=1, delimiter=",")
medium, rms = fitter.fit(advanced_param=advanced_param)


def test_dispersion_plot(random_data):
"""plots a medium fit from file"""
Expand All @@ -99,23 +112,36 @@ def test_dispersion_set_wvg_range(random_data):
"""set wavelength range function"""
wvl_um, n_data, k_data = random_data
fitter = DispersionFitter(wvl_um=wvl_um, n_data=n_data)
fastfitter = FastDispersionFitter(wvl_um=wvl_um, n_data=n_data)

wvl_range = [1.2, 1.8]
fitter = fitter.copy(update={"wvl_range": wvl_range})
assert len(fitter.freqs) == 7
medium, rms = fitter.fit(num_tries=2)
fastfitter = fastfitter.copy(update={"wvl_range": wvl_range})
assert len(fastfitter.freqs) == 7
medium, rms = fastfitter.fit(advanced_param=advanced_param)

wvl_range = [1.2, 2.8]
fitter = fitter.copy(update={"wvl_range": wvl_range, "k_data": k_data})
assert len(fitter.freqs) == 9
medium, rms = fitter.fit(num_tries=2)
fastfitter = fastfitter.copy(update={"wvl_range": wvl_range})
assert len(fastfitter.freqs) == 9
medium, rms = fastfitter.fit(advanced_param=advanced_param)

wvl_range = [0.2, 1.8]
fitter = fitter.copy(update={"wvl_range": wvl_range})
assert len(fitter.freqs) == 9
medium, rms = fitter.fit(num_tries=2)
fastfitter = fastfitter.copy(update={"wvl_range": wvl_range})
assert len(fastfitter.freqs) == 9
medium, rms = fastfitter.fit(advanced_param=advanced_param)

wvl_range = [0.2, 2.8]
fitter = fitter.copy(update={"wvl_range": wvl_range, "k_data": k_data})
assert len(fitter.freqs) == 11
medium, rms = fitter.fit(num_tries=2)
fastfitter = fastfitter.copy(update={"wvl_range": wvl_range})
assert len(fastfitter.freqs) == 11
medium, rms = fastfitter.fit(advanced_param=advanced_param)
1 change: 1 addition & 0 deletions tidy3d/plugins/dispersion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

from .fit import DispersionFitter
from .web import AdvancedFitterParam, StableDispersionFitter
from .fit_fast import FastDispersionFitter, AdvancedFastFitterParam
Loading

0 comments on commit 7d92369

Please sign in to comment.