Skip to content

Commit

Permalink
Fitter improvement
Browse files Browse the repository at this point in the history
Change DispersionFitter to a pydantic.BaseModel object.
Make eps_data, lossy, freqs, frequency_range into @Property types
More error handling in StableDispersiveFitter
Allow to load dispersive data directly by providing URL to txt or csv file
Allow to filter wavelength range for fitting
Add tunable eps_inf as a variable for optimization
  • Loading branch information
weiliangjin2021 committed Feb 15, 2022
1 parent 91807c0 commit 74c5eb6
Show file tree
Hide file tree
Showing 7 changed files with 645 additions and 279 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- ``Selmeier.from_dispersion()`` method to quickly make a single-pole fit for lossless weakly dispersive materials.
- Stable dispersive material fits via webservice.
- Allow to load dispersive data directly by providing URL to txt or csv file
- Validates simulation based on discretized size.

### Changed
Expand Down
37 changes: 30 additions & 7 deletions tests/_test_fit_web.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import numpy as np
import sys

sys.path.append("../tidy3d_client_revamp")
from tidy3d.plugins import StableDispersionFitter
from tidy3d.plugins import StableDispersionFitter, AdvancedFitterParam


def test_dispersion_load_list():
"""performs a fit on some random data"""
num_data = 10
n_data = np.random.random(num_data)
wvls = np.linspace(1, 2, num_data)
fitter = StableDispersionFitter(wvls, n_data)
fitter = StableDispersionFitter(wvl_um=wvls, n_data=n_data)

num_poles = 3
num_tries = 10
Expand All @@ -27,7 +25,7 @@ def test_dispersion_load_file():
fitter = StableDispersionFitter.from_file("tests/data/nk_data.csv", skiprows=1, delimiter=",")

num_poles = 3
num_tries = 30
num_tries = 10
tolerance_rms = 1e-3
best_medium, best_rms = fitter.fit(
num_tries=num_tries, num_poles=num_poles, tolerance_rms=tolerance_rms
Expand All @@ -36,5 +34,30 @@ def test_dispersion_load_file():
print(best_medium.eps_model(1e12))


if __name__ == "__main__":
test_dispersion_load_file()
def test_dispersion_load_url():

url_csv = "https://refractiveindex.info/data_csv.php?datafile=data/main/Ag/Johnson.yml"
fitter = StableDispersionFitter.from_url(url_csv)

num_poles = 2
num_tries = 10
tolerance_rms = 1e-3
best_medium, best_rms = fitter.fit(
num_tries=num_tries,
num_poles=num_poles,
tolerance_rms=tolerance_rms,
advanced_param=AdvancedFitterParam(constraint="hard", bound_eps_inf=10),
)
print(best_rms)
print(best_medium.eps_inf)

fitter.wvl_range = [1.0, 1.3]
print(len(fitter.freqs))
best_medium, best_rms = fitter.fit(
num_tries=num_tries,
num_poles=num_poles,
tolerance_rms=tolerance_rms,
advanced_param=AdvancedFitterParam(constraint="hard", bound_eps_inf=10),
)
print(best_rms)
print(best_medium.eps_inf)
23 changes: 19 additions & 4 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import tidy3d as td

from tidy3d.plugins import DispersionFitter
from tidy3d.plugins.dispersion.fit import _poles_to_coeffs, _coeffs_to_poles
from tidy3d.plugins.dispersion.fit import _pack_coeffs, _unpack_coeffs

from tidy3d.plugins import ModeSolver
from tidy3d.plugins import Near2Far
Expand Down Expand Up @@ -84,11 +82,14 @@ def test_dispersion():
num_data = 10
n_data = np.random.random(num_data)
wvls = np.linspace(1, 2, num_data)
fitter = DispersionFitter(wvls, n_data)
medium, rms = fitter.fit_single()
fitter = DispersionFitter(wvl_um=wvls, n_data=n_data)
medium, rms = fitter._fit_single()
medium, rms = fitter.fit(num_tries=2)
medium.to_file("tests/tmp/medium_fit.json")

k_data = np.random.random(num_data)
fitter = DispersionFitter(wvl_um=wvls, n_data=n_data, k_data=k_data)


def test_dispersion_load():
"""loads dispersion model from nk data file"""
Expand All @@ -101,3 +102,17 @@ def test_dispersion_plot():
fitter = DispersionFitter.from_file("tests/data/nk_data.csv", skiprows=1, delimiter=",")
medium, rms = fitter.fit(num_tries=20)
fitter.plot(medium)


def test_dispersion_set_wvg_range():
"""set wavelength range function"""
num_data = 50
n_data = np.random.random(num_data)
wvls = np.linspace(1, 2, num_data)
fitter = DispersionFitter(wvl_um=wvls, n_data=n_data)

wvl_min = np.random.random(1)[0] * 0.5 + 1
wvl_max = wvl_min + 0.5
fitter.wvl_range = [wvl_min, wvl_max]
assert len(fitter.freqs) < num_data
medium, rms = fitter.fit(num_tries=2)
46 changes: 25 additions & 21 deletions tests/test_plugins_web.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,40 @@
import numpy as np

from tidy3d.plugins import StableDispersionFitter
from tidy3d.plugins import StableDispersionFitter, DispersionFitter


def test_dispersion_load_list():
"""performs a fit on some random data"""
num_data = 10
n_data = np.random.random(num_data)
wvls = np.linspace(1, 2, num_data)
fitter = StableDispersionFitter(wvls, n_data)

# num_poles = 3
# num_tries = 50
# tolerance_rms = 1e-3
# local_run = True
# best_medium, best_rms = fitter.fit(
# num_tries=num_tries, num_poles=num_poles, tolerance_rms=tolerance_rms, local_run=True
# )
# print(best_rms)
# print(best_medium.eps_model(1e12))
fitter = StableDispersionFitter(wvl_um=wvls, n_data=n_data)


def test_dispersion_load_file():
"""loads dispersion model from nk data file"""
fitter = StableDispersionFitter.from_file("tests/data/nk_data.csv", skiprows=1, delimiter=",")

# num_poles = 3
# num_tries = 50
# tolerance_rms = 1e-3
# local_run = True
# best_medium, best_rms = fitter.fit(
# num_tries=num_tries, num_poles=num_poles, tolerance_rms=tolerance_rms, local_run=True
# )
# print(best_rms)
# print(best_medium.eps_model(1e12))

def test_dispersion_load_url():
"""performs a fit on some random data"""

# both n and k
url_csv = "https://refractiveindex.info/data_csv.php?datafile=data/main/Ag/Johnson.yml"
url_txt = "https://refractiveindex.info/data_txt.php?datafile=data/main/Ag/Johnson.yml"
fitter = DispersionFitter.from_url(url_csv, delimiter=",")
fitter = StableDispersionFitter.from_url(url_csv, delimiter=",")
fitter_txt = DispersionFitter.from_url(url_txt, delimiter="\t")
fitter_txt = StableDispersionFitter.from_url(url_txt, delimiter="\t")
fitter_txt.wvl_range = [0.3, 0.8]
assert len(fitter_txt.freqs) < len(fitter.freqs)

# only k
url_csv = "https://refractiveindex.info/data_csv.php?datafile=data/main/N2/Peck-0C.yml"
url_txt = "https://refractiveindex.info/data_txt.php?datafile=data/main/N2/Peck-0C.yml"
fitter = DispersionFitter.from_url(url_csv, delimiter=",")
fitter_txt = DispersionFitter.from_url(url_txt, delimiter="\t")


if __name__ == "__main__":
test_dispersion_load_url()
2 changes: 1 addition & 1 deletion tidy3d/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# import the specific classes / functions needed for the plugins

from .dispersion.fit import DispersionFitter
from .dispersion.fit_web import StableDispersionFitter
from .dispersion.fit_web import StableDispersionFitter, AdvancedFitterParam
from .mode.mode_solver import ModeSolver
from .near2far.near2far import Near2Far
Loading

0 comments on commit 74c5eb6

Please sign in to comment.