Skip to content

Commit

Permalink
Fitter improvement
Browse files Browse the repository at this point in the history
More error handling in StableDispersiveFitter
Allow to load dispersive data directly by providing URL to txt or csv file
Add tunable eps_inf as a variable for optimization
  • Loading branch information
weiliangjin2021 committed Feb 6, 2022
1 parent d519c16 commit b6e470b
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 103 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- ``Selmeier.from_dispersion()`` method to quickly make a single-pole fit for lossless weakly dispersive materials.
- Stable dispersive material fits via webservice.
- Allow to load dispersive data directly by providing URL to txt or csv file
- Validates simulation based on discretized size.

### Changed
Expand Down
26 changes: 22 additions & 4 deletions tests/_test_fit_web.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import numpy as np
import sys

sys.path.append("../tidy3d_client_revamp")
from tidy3d.plugins import StableDispersionFitter


Expand All @@ -27,7 +25,7 @@ def test_dispersion_load_file():
fitter = StableDispersionFitter.from_file("tests/data/nk_data.csv", skiprows=1, delimiter=",")

num_poles = 3
num_tries = 30
num_tries = 10
tolerance_rms = 1e-3
best_medium, best_rms = fitter.fit(
num_tries=num_tries, num_poles=num_poles, tolerance_rms=tolerance_rms
Expand All @@ -36,5 +34,25 @@ def test_dispersion_load_file():
print(best_medium.eps_model(1e12))


def test_dispersion_load_url():

url_csv = "https://refractiveindex.info/data_csv.php?datafile=data/main/Ag/Johnson.yml"
fitter = StableDispersionFitter.from_url(url_csv)
# fitter = DispersionFitter.from_url(url_csv)
print(fitter.frequency_range[1]/1e12)

num_poles = 2
num_tries = 10
tolerance_rms = 1e-3
best_medium, best_rms = fitter.fit(
num_tries=num_tries,
num_poles=num_poles,
tolerance_rms=tolerance_rms,
)
print(best_rms)
print(best_medium.eps_model(1e12))
print(best_medium.eps_inf)


if __name__ == "__main__":
test_dispersion_load_file()
test_dispersion_load_url()
2 changes: 1 addition & 1 deletion tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_dispersion():
n_data = np.random.random(num_data)
wvls = np.linspace(1, 2, num_data)
fitter = DispersionFitter(wvls, n_data)
medium, rms = fitter.fit_single()
medium, rms = fitter._fit_single()
medium, rms = fitter.fit(num_tries=2)
medium.to_file("tests/tmp/medium_fit.json")

Expand Down
38 changes: 17 additions & 21 deletions tests/test_plugins_web.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from tidy3d.plugins import StableDispersionFitter
from tidy3d.plugins import StableDispersionFitter, DispersionFitter


def test_dispersion_load_list():
Expand All @@ -10,27 +10,23 @@ def test_dispersion_load_list():
wvls = np.linspace(1, 2, num_data)
fitter = StableDispersionFitter(wvls, n_data)

# num_poles = 3
# num_tries = 50
# tolerance_rms = 1e-3
# local_run = True
# best_medium, best_rms = fitter.fit(
# num_tries=num_tries, num_poles=num_poles, tolerance_rms=tolerance_rms, local_run=True
# )
# print(best_rms)
# print(best_medium.eps_model(1e12))


def test_dispersion_load_file():
"""loads dispersion model from nk data file"""
fitter = StableDispersionFitter.from_file("tests/data/nk_data.csv", skiprows=1, delimiter=",")

# num_poles = 3
# num_tries = 50
# tolerance_rms = 1e-3
# local_run = True
# best_medium, best_rms = fitter.fit(
# num_tries=num_tries, num_poles=num_poles, tolerance_rms=tolerance_rms, local_run=True
# )
# print(best_rms)
# print(best_medium.eps_model(1e12))
def test_dispersion_load_url():
"""performs a fit on some random data"""

# both n and k
url_csv = "https://refractiveindex.info/data_csv.php?datafile=data/main/Ag/Johnson.yml"
url_txt = "https://refractiveindex.info/data_txt.php?datafile=data/main/Ag/Johnson.yml"
fitter = DispersionFitter.from_url(url_csv, delimiter=",")
fitter = StableDispersionFitter.from_url(url_csv, delimiter=",")
fitter_txt = DispersionFitter.from_url(url_txt, delimiter="\t")
fitter_txt = StableDispersionFitter.from_url(url_txt, delimiter="\t")

# only k
url_csv = "https://refractiveindex.info/data_csv.php?datafile=data/main/N2/Peck-0C.yml"
url_txt = "https://refractiveindex.info/data_txt.php?datafile=data/main/N2/Peck-0C.yml"
fitter = DispersionFitter.from_url(url_csv, delimiter=",")
fitter_txt = DispersionFitter.from_url(url_txt, delimiter="\t")
2 changes: 1 addition & 1 deletion tidy3d/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# import the specific classes / functions needed for the plugins

from .dispersion.fit import DispersionFitter
from .dispersion.fit_web import StableDispersionFitter
from .dispersion.fit_web import StableDispersionFitter, AdvancedFitterParam
from .mode.mode_solver import ModeSolver
from .near2far.near2far import Near2Far
195 changes: 187 additions & 8 deletions tidy3d/plugins/dispersion/fit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Fit PoleResidue Dispersion models to optical NK data
"""

from typing import Tuple
from typing import Tuple, List
import csv
import codecs
import requests

import nlopt
import numpy as np
Expand All @@ -11,7 +14,7 @@
from ...constants import C_0, HBAR
from ...components.viz import add_ax_if_none
from ...components.types import Ax, Numpy
from ...log import log
from ...log import log, ValidationError, WebError


def _unpack_complex(complex_num):
Expand Down Expand Up @@ -160,10 +163,54 @@ def __init__(self, wvl_um: Numpy, n_data: Numpy, k_data: Numpy = None):
if k_data is None:
self.k_data = np.zeros_like(n_data)
self.lossy = False

self._set_rest_of_init()

def _set_rest_of_init(self):
"""Set up the rest of __init__ once wvl_um, n_data, k_data are set up"""

self.eps_data = AbstractMedium.nk_to_eps_complex(n=self.n_data, k=self.k_data)
self.freqs = C_0 / wvl_um
self.freqs = C_0 / self.wvl_um
self.frequency_range = (np.min(self.freqs), np.max(self.freqs))

# def _set_wvl_range(self, wvl_min: float, wvl_max: float):
# """Truncate the wavelength-nk data to wavelength range [wvl_min,wvl_max] for
# fitting. Data outside this range will be ignored.

# Parameters
# ----------
# wvl_min : float
# The beginning of wavelength range. Unit: micron
# wvl_max : float
# The end of wavelength range. Unit: micron
# """

# if wvl_min < np.min(self.wvl_um):
# log.warning(
# "wvl_min is smaller than the minimal wavelength provided "
# "in the data. Please increase wvl_min for accuracy."
# )
# if wvl_max > np.max(self.wvl_um):
# log.warning(
# "wvl_max is larger than the maximal wavelength provided "
# "in the data. Please decrease wvl_max for accuracy."
# )

# ind_min = self.wvl_um >= wvl_min
# ind_max = self.wvl_um <= wvl_max
# ind_final = np.logical_and(ind_min, ind_max)

# # update
# self.wvl_um = self.wvl_um[ind_final]
# self.n_data = self.n_data[ind_final]
# self.k_data = self.k_data[ind_final]

# if not np.all(ind_final):
# log.error("No data within [wvl_min,wvl_max]")
# return

# self.__init__(self.wvl_um, self.n_data, self.k_data)

@staticmethod
def _validate_data(wvl_um: Numpy, n_data: Numpy, k_data: Numpy = None):
"""make sure raw data is correctly shaped.
Expand All @@ -182,7 +229,7 @@ def _validate_data(wvl_um: Numpy, n_data: Numpy, k_data: Numpy = None):
assert wvl_um.shape == k_data.shape

@staticmethod
def eV_to_Hz(f_eV: float):
def _eV_to_Hz(f_eV: float): # pylint:disable=invalid-name
"""convert frequency in unit of eV to Hz
Parameters
Expand All @@ -194,7 +241,7 @@ def eV_to_Hz(f_eV: float):
return f_eV / HBAR / 2 / np.pi

@staticmethod
def Hz_to_eV(f_Hz: float):
def _Hz_to_eV(f_Hz: float): # pylint:disable=invalid-name
"""convert frequency in unit of Hz to eV
Parameters
Expand Down Expand Up @@ -240,7 +287,7 @@ def fit(

while not progress.finished:

medium, rms_error = self.fit_single(num_poles=num_poles)
medium, rms_error = self._fit_single(num_poles=num_poles)

# if improvement, set the best RMS and coeffs
if rms_error < best_rms:
Expand Down Expand Up @@ -280,7 +327,7 @@ def _make_medium(self, coeffs):
poles_complex = _coeffs_to_poles(coeffs)
return PoleResidue(poles=poles_complex, frequency_range=self.frequency_range)

def fit_single(
def _fit_single(
self,
num_poles: int = 3,
) -> Tuple[PoleResidue, float]:
Expand Down Expand Up @@ -448,9 +495,141 @@ def plot(

return ax

@staticmethod
def _validate_url_load(data_load: List):
"""Validate if the loaded data from URL is valid
The data list should be in this format:
[["wl", "n"],
[float, float],
. .
. .
. .
(if lossy)
["wl", "k"],
[float, float],
. .
. .
. .]]
Parameters
----------
data_load : List
Loaded data from URL
Raises
------
ValidationError
Or other exceptions
"""
has_k = 0

if data_load[0][0] != "wl" or data_load[0][1] != "n":
raise ValidationError(
"Invalid URL. The file should begin with ['wl','n']. "
"Or make sure that you have supplied an appropriate delimiter."
)

for row in data_load[1:]:
if row[0] == "wl":
if row[1] == "k":
has_k += 1
else:
raise ValidationError(
"Invalid URL. The file is not well formatted for ['wl', 'k'] data."
)
else:
# make sure the rest is float type
try:
nk_tmp = [float(x) for x in row] # pylint:disable=unused-variable
except Exception as e:
log.error(e)
raise ValidationError( # pylint:disable=raise-missing-from
"Invalid URL. Float data cannot be recognized."
)

if has_k > 1:
raise ValidationError("Invalid URL. Too many k labels.")

@classmethod
def from_url(cls, url_file: str, delimiter: str = ","):
"""loads ``DispersionFitter`` from url pointing to a csv/txt file
that contains wavelength (micron), n, and optionally k data.
Preferred from refractiveindex.info.
Example
-------
The data from url should be in this format:
wl n
float float
. .
. .
. .
(if lossy)
wl k
float float
. .
. .
. .
Parameters
----------
url_file : str
Url link to the csv file.
e.g. "https://refractiveindex.info/data_csv.php?datafile=data/main/Ag/Johnson.yml"
delimiter : str = ","
For refractiveindex.info, delimiter="," for csv, and "\t" for txt
"""

resp = requests.get(url_file)

try:
resp.raise_for_status()
except Exception as e: # pylint:disable=broad-except
log.error(e)
raise WebError( # pylint:disable=raise-missing-from
"Connection to the website failed. Please provide a valid URL."
)

data_url = list(
csv.reader(codecs.iterdecode(resp.iter_lines(), "utf-8"), delimiter=delimiter)
)
data_url = list(data_url)

# first validate data
cls._validate_url_load(data_url)

# parsing the data
n_lam = []
k_lam = [] # the two variables contain [wvl_um, n(k)]
has_k = 0 # whether k is in the data

for row in data_url[1:]:
if has_k == 1:
k_lam.append([float(x) for x in row])
else:
if row[0] == "wl":
has_k += 1
else:
n_lam.append([float(x) for x in row])

n_lam = np.array(n_lam)
k_lam = np.array(k_lam)

# for data containing k
if has_k == 1:
# now let's make sure wvl_um in n_lam and k_lam match
if not np.allclose(n_lam[:, 0], k_lam[:, 0]):
raise ValidationError(
"Invalid URL. Both n and k should be provided at each wavelength."
)

return cls(wvl_um=n_lam[:, 0], n_data=n_lam[:, 1], k_data=k_lam[:, 1])
return cls(wvl_um=n_lam[:, 0], n_data=n_lam[:, 1])

@classmethod
def from_file(cls, fname, **loadtxt_kwargs):
"""Loads ``DispersionFitter`` from file contining wavelength, n, k data.
"""Loads ``DispersionFitter`` from file containing wavelength, n, k data.
Parameters
----------
Expand Down
Loading

0 comments on commit b6e470b

Please sign in to comment.