From 74c5eb6070e8bfbe5ca3b9885eb9639cc27f53af Mon Sep 17 00:00:00 2001 From: Weiliang Jin Date: Wed, 2 Feb 2022 15:09:56 -0800 Subject: [PATCH] Fitter improvement Change DispersionFitter to a pydantic.BaseModel object. Make eps_data, lossy, freqs, frequency_range into @property types More error handling in StableDispersiveFitter Allow to load dispersive data directly by providing URL to txt or csv file Allow to filter wavelength range for fitting Add tunable eps_inf as a variable for optimization --- CHANGELOG.md | 2 + tests/_test_fit_web.py | 37 +- tests/test_plugins.py | 23 +- tests/test_plugins_web.py | 46 ++- tidy3d/plugins/__init__.py | 2 +- tidy3d/plugins/dispersion/fit.py | 591 +++++++++++++++++++-------- tidy3d/plugins/dispersion/fit_web.py | 223 ++++++---- 7 files changed, 645 insertions(+), 279 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 41d1a5494..d7d192e14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - ``Selmeier.from_dispersion()`` method to quickly make a single-pole fit for lossless weakly dispersive materials. +- Stable dispersive material fits via webservice. +- Allow to load dispersive data directly by providing URL to txt or csv file - Validates simulation based on discretized size. ### Changed diff --git a/tests/_test_fit_web.py b/tests/_test_fit_web.py index 0bf7014a0..02b791b63 100644 --- a/tests/_test_fit_web.py +++ b/tests/_test_fit_web.py @@ -1,8 +1,6 @@ import numpy as np -import sys -sys.path.append("../tidy3d_client_revamp") -from tidy3d.plugins import StableDispersionFitter +from tidy3d.plugins import StableDispersionFitter, AdvancedFitterParam def test_dispersion_load_list(): @@ -10,7 +8,7 @@ def test_dispersion_load_list(): num_data = 10 n_data = np.random.random(num_data) wvls = np.linspace(1, 2, num_data) - fitter = StableDispersionFitter(wvls, n_data) + fitter = StableDispersionFitter(wvl_um=wvls, n_data=n_data) num_poles = 3 num_tries = 10 @@ -27,7 +25,7 @@ def test_dispersion_load_file(): fitter = StableDispersionFitter.from_file("tests/data/nk_data.csv", skiprows=1, delimiter=",") num_poles = 3 - num_tries = 30 + num_tries = 10 tolerance_rms = 1e-3 best_medium, best_rms = fitter.fit( num_tries=num_tries, num_poles=num_poles, tolerance_rms=tolerance_rms @@ -36,5 +34,30 @@ def test_dispersion_load_file(): print(best_medium.eps_model(1e12)) -if __name__ == "__main__": - test_dispersion_load_file() +def test_dispersion_load_url(): + + url_csv = "https://refractiveindex.info/data_csv.php?datafile=data/main/Ag/Johnson.yml" + fitter = StableDispersionFitter.from_url(url_csv) + + num_poles = 2 + num_tries = 10 + tolerance_rms = 1e-3 + best_medium, best_rms = fitter.fit( + num_tries=num_tries, + num_poles=num_poles, + tolerance_rms=tolerance_rms, + advanced_param=AdvancedFitterParam(constraint="hard", bound_eps_inf=10), + ) + print(best_rms) + print(best_medium.eps_inf) + + fitter.wvl_range = [1.0, 1.3] + print(len(fitter.freqs)) + best_medium, best_rms = fitter.fit( + num_tries=num_tries, + num_poles=num_poles, + tolerance_rms=tolerance_rms, + advanced_param=AdvancedFitterParam(constraint="hard", bound_eps_inf=10), + ) + print(best_rms) + print(best_medium.eps_inf) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 0b55c58ad..d6fb2ede2 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -5,8 +5,6 @@ import tidy3d as td from tidy3d.plugins import DispersionFitter -from tidy3d.plugins.dispersion.fit import _poles_to_coeffs, _coeffs_to_poles -from tidy3d.plugins.dispersion.fit import _pack_coeffs, _unpack_coeffs from tidy3d.plugins import ModeSolver from tidy3d.plugins import Near2Far @@ -84,11 +82,14 @@ def test_dispersion(): num_data = 10 n_data = np.random.random(num_data) wvls = np.linspace(1, 2, num_data) - fitter = DispersionFitter(wvls, n_data) - medium, rms = fitter.fit_single() + fitter = DispersionFitter(wvl_um=wvls, n_data=n_data) + medium, rms = fitter._fit_single() medium, rms = fitter.fit(num_tries=2) medium.to_file("tests/tmp/medium_fit.json") + k_data = np.random.random(num_data) + fitter = DispersionFitter(wvl_um=wvls, n_data=n_data, k_data=k_data) + def test_dispersion_load(): """loads dispersion model from nk data file""" @@ -101,3 +102,17 @@ def test_dispersion_plot(): fitter = DispersionFitter.from_file("tests/data/nk_data.csv", skiprows=1, delimiter=",") medium, rms = fitter.fit(num_tries=20) fitter.plot(medium) + + +def test_dispersion_set_wvg_range(): + """set wavelength range function""" + num_data = 50 + n_data = np.random.random(num_data) + wvls = np.linspace(1, 2, num_data) + fitter = DispersionFitter(wvl_um=wvls, n_data=n_data) + + wvl_min = np.random.random(1)[0] * 0.5 + 1 + wvl_max = wvl_min + 0.5 + fitter.wvl_range = [wvl_min, wvl_max] + assert len(fitter.freqs) < num_data + medium, rms = fitter.fit(num_tries=2) diff --git a/tests/test_plugins_web.py b/tests/test_plugins_web.py index 213b59aa2..008fb7952 100644 --- a/tests/test_plugins_web.py +++ b/tests/test_plugins_web.py @@ -1,6 +1,6 @@ import numpy as np -from tidy3d.plugins import StableDispersionFitter +from tidy3d.plugins import StableDispersionFitter, DispersionFitter def test_dispersion_load_list(): @@ -8,29 +8,33 @@ def test_dispersion_load_list(): num_data = 10 n_data = np.random.random(num_data) wvls = np.linspace(1, 2, num_data) - fitter = StableDispersionFitter(wvls, n_data) - - # num_poles = 3 - # num_tries = 50 - # tolerance_rms = 1e-3 - # local_run = True - # best_medium, best_rms = fitter.fit( - # num_tries=num_tries, num_poles=num_poles, tolerance_rms=tolerance_rms, local_run=True - # ) - # print(best_rms) - # print(best_medium.eps_model(1e12)) + fitter = StableDispersionFitter(wvl_um=wvls, n_data=n_data) def test_dispersion_load_file(): """loads dispersion model from nk data file""" fitter = StableDispersionFitter.from_file("tests/data/nk_data.csv", skiprows=1, delimiter=",") - # num_poles = 3 - # num_tries = 50 - # tolerance_rms = 1e-3 - # local_run = True - # best_medium, best_rms = fitter.fit( - # num_tries=num_tries, num_poles=num_poles, tolerance_rms=tolerance_rms, local_run=True - # ) - # print(best_rms) - # print(best_medium.eps_model(1e12)) + +def test_dispersion_load_url(): + """performs a fit on some random data""" + + # both n and k + url_csv = "https://refractiveindex.info/data_csv.php?datafile=data/main/Ag/Johnson.yml" + url_txt = "https://refractiveindex.info/data_txt.php?datafile=data/main/Ag/Johnson.yml" + fitter = DispersionFitter.from_url(url_csv, delimiter=",") + fitter = StableDispersionFitter.from_url(url_csv, delimiter=",") + fitter_txt = DispersionFitter.from_url(url_txt, delimiter="\t") + fitter_txt = StableDispersionFitter.from_url(url_txt, delimiter="\t") + fitter_txt.wvl_range = [0.3, 0.8] + assert len(fitter_txt.freqs) < len(fitter.freqs) + + # only k + url_csv = "https://refractiveindex.info/data_csv.php?datafile=data/main/N2/Peck-0C.yml" + url_txt = "https://refractiveindex.info/data_txt.php?datafile=data/main/N2/Peck-0C.yml" + fitter = DispersionFitter.from_url(url_csv, delimiter=",") + fitter_txt = DispersionFitter.from_url(url_txt, delimiter="\t") + + +if __name__ == "__main__": + test_dispersion_load_url() diff --git a/tidy3d/plugins/__init__.py b/tidy3d/plugins/__init__.py index 1b278045a..6c2bfb7a5 100644 --- a/tidy3d/plugins/__init__.py +++ b/tidy3d/plugins/__init__.py @@ -1,6 +1,6 @@ # import the specific classes / functions needed for the plugins from .dispersion.fit import DispersionFitter -from .dispersion.fit_web import StableDispersionFitter +from .dispersion.fit_web import StableDispersionFitter, AdvancedFitterParam from .mode.mode_solver import ModeSolver from .near2far.near2far import Near2Far diff --git a/tidy3d/plugins/dispersion/fit.py b/tidy3d/plugins/dispersion/fit.py index cf28c8d98..b0d64ceb3 100644 --- a/tidy3d/plugins/dispersion/fit.py +++ b/tidy3d/plugins/dispersion/fit.py @@ -1,188 +1,291 @@ """Fit PoleResidue Dispersion models to optical NK data """ -from typing import Tuple +from typing import Tuple, List +import csv +import codecs +import requests import nlopt import numpy as np from rich.progress import Progress +from pydantic import BaseModel, Field, validator from ...components import PoleResidue, AbstractMedium -from ...constants import C_0, HBAR +from ...constants import C_0, HBAR, MICROMETER from ...components.viz import add_ax_if_none -from ...components.types import Ax, Numpy -from ...log import log - - -def _unpack_complex(complex_num): - """Returns real and imaginary parts from complex number. - - Parameters - ---------- - complex_num : complex - Complex number. - - Returns - ------- - Tuple[float, float] - Real and imaginary parts of the complex number. - """ - return complex_num.real, complex_num.imag - - -def _pack_complex(real_part, imag_part): - """Returns complex number from real and imaginary parts. - - Parameters - ---------- - real_part : float - Real part of the complex number. - imag_part : float - Imaginary part of the complex number. - - Returns - ------- - complex - The complex number. - """ - return real_part + 1j * imag_part - - -def _unpack_coeffs(coeffs): - """Unpacks coefficient vector into complex pole parameters. - - Parameters - ---------- - coeffs : np.ndarray[real] - Array of real coefficients for the pole residue fit. - - Returns - ------- - Tuple[np.ndarray[complex], np.ndarray[complex]] - "a" and "c" poles for the PoleResidue model. - """ - assert len(coeffs) % 4 == 0, "len(coeffs) must be multiple of 4." - num_poles = len(coeffs) // 4 - indices = 4 * np.arange(num_poles) - - a_real = coeffs[indices + 0] - a_imag = coeffs[indices + 1] - c_real = coeffs[indices + 2] - c_imag = coeffs[indices + 3] - - poles_a = _pack_complex(a_real, a_imag) - poles_c = _pack_complex(c_real, c_imag) - return poles_a, poles_c - - -def _pack_coeffs(pole_a, pole_c): - """Packs complex a and c pole parameters into coefficient array. - - Parameters - ---------- - pole_a : np.ndarray[complex] - Array of complex "a" poles for the PoleResidue dispersive model. - pole_c : np.ndarray[complex] - Array of complex "c" poles for the PoleResidue dispersive model. - - Returns - ------- - np.ndarray[float] - Array of real coefficients for the pole residue fit. - """ - a_real, a_imag = _unpack_complex(pole_a) - c_real, c_imag = _unpack_complex(pole_c) - stacked_coeffs = np.stack((a_real, a_imag, c_real, c_imag), axis=1) - return stacked_coeffs.flatten() - - -def _coeffs_to_poles(coeffs): - """Converts model coefficients to poles. - - Parameters - ---------- - coeffs : np.ndarray[float] - Array of real coefficients for the pole residue fit. - - Returns - ------- - List[Tuple[complex, complex]] - List of complex poles (a, c) - """ - coeffs_scaled = coeffs / HBAR - poles_a, poles_c = _unpack_coeffs(coeffs_scaled) - poles = [(complex(a), complex(c)) for (a, c) in zip(poles_a, poles_c)] - # poles = [((a.real, a.imag), (c.real, c.imag)) for (a, c) in zip(poles_a, poles_c)] - return poles - - -def _poles_to_coeffs(poles): - """Converts poles to model coefficients. - - Parameters - ---------- - poles : List[Tuple[complex, complex]] - List of complex poles (a, c) - - Returns - ------- - np.ndarray[float] - Array of real coefficients for the pole residue fit. - """ - poles_a, poles_c = np.array([[a, c] for (a, c) in poles]).T - coeffs = _pack_coeffs(poles_a, poles_c) - return coeffs * HBAR - - -class DispersionFitter: +from ...components.types import Ax, Numpy, NumpyArray, ArrayLike +from ...log import log, ValidationError, WebError, SetupError + + +class DispersionFitter(BaseModel): """Tool for fitting refractive index data to get a dispersive ``Medium``.""" - def __init__(self, wvl_um: Numpy, n_data: Numpy, k_data: Numpy = None): - """Make a ``DispersionFitter`` with raw wavelength-nk data. + wvl_um: ArrayLike = Field( + ..., + title="Wavelength data", + description="Wavelength data in micrometers.", + unit=MICROMETER, + ) + + n_data: ArrayLike = Field( + ..., + title="Index of refraction data", + description="Real part of the complex index of refraction.", + ) + + k_data: ArrayLike = Field( + None, + title="Extinction coefficient data", + description="Imaginary part of the complex index of refraction.", + ) + + wvl_range: Tuple[float, float] = Field( + [None, None], + title="Wavelength range [wvl_min,wvl_max] for fitting", + description="Truncate the wavelength-nk data to wavelength range " + "[wvl_min,wvl_max] for fitting", + unit=MICROMETER, + ) + + @validator("wvl_um", always=True) + def _setup_wvl(cls, val): + """Convert wvl_um to a numpy array""" + if len(val) < 1: + raise ValidationError("The length of data cannot be empty.") + return np.array(val) + + @validator("n_data", always=True) + def _ndata_length_match_wvl(cls, val, values): + """Validate n_data""" + _val = np.array(val) + if _val.shape != values["wvl_um"].shape: + raise ValidationError("The length of n_data doesn't match wvl_um.") + return _val + + @validator("k_data", always=True) + def _kdata_setup_and_length_match(cls, val, values): + """ + validate the length of k_data, or setup k if it's None + """ + if val is None: + return np.zeros_like(values["wvl_um"]) + _val = np.array(val) + if _val.shape != values["wvl_um"].shape: + raise ValidationError("The length of k_data doesn't match wvl_um.") + return _val + + def _filter_wvl_range( + self, wvl_min: float = None, wvl_max: float = None + ) -> Tuple[NumpyArray, NumpyArray, NumpyArray]: + """ + Filter the wavelength-nk data to wavelength range [wvl_min,wvl_max] + for fitting. Parameters ---------- - wvl_um : Numpy - Wavelength data in micrometers. - n_data : Numpy - Real part of refractive index in micrometers. - k_data : Numpy, optional - Imaginary part of refractive index in micrometers. + wvl_min : float, optional + The beginning of wavelength range. Unit: micron + wvl_max : float, optional + The end of wavelength range. Unit: micron + + Returns + ------- + Tuple[NumpyArray,NumpyArray,NumpyArray] + Filtered wvl_um, n_data, k_data + """ - self._validate_data(wvl_um, n_data, k_data) - self.wvl_um = wvl_um - self.n_data = n_data - self.k_data = k_data - self.lossy = True + ind_select = np.ones(self.wvl_um.shape, dtype=bool) + if wvl_min is not None: + ind_select = np.logical_and(self.wvl_um >= wvl_min, ind_select) + + if wvl_max is not None: + ind_select = np.logical_and(self.wvl_um <= wvl_max, ind_select) + + if not np.any(ind_select): + raise SetupError("No data within [wvl_min,wvl_max]") + + return self.wvl_um[ind_select], self.n_data[ind_select], self.k_data[ind_select] - # handle lossless case + @property + def lossy(self) -> bool: + """Find out if the medium is lossy or lossless + based on the filtered input data. + + Returns + ------- + bool + True for lossy medium; False for lossless medium + """ + _, _, k_data = self._filter_wvl_range(wvl_min=self.wvl_range[0], wvl_max=self.wvl_range[1]) if k_data is None: - self.k_data = np.zeros_like(n_data) - self.lossy = False - self.eps_data = AbstractMedium.nk_to_eps_complex(n=self.n_data, k=self.k_data) - self.freqs = C_0 / wvl_um - self.frequency_range = (np.min(self.freqs), np.max(self.freqs)) + return False + if not np.any(k_data): + return False + return True + + @property + def eps_data(self) -> complex: + """Convert filtered input n(k) data into complex permittivity. + + Returns + ------- + complex + Complex-valued relative permittivty. + """ + _, n_data, k_data = self._filter_wvl_range( + wvl_min=self.wvl_range[0], wvl_max=self.wvl_range[1] + ) + return AbstractMedium.nk_to_eps_complex(n=n_data, k=k_data) + + @property + def freqs(self) -> NumpyArray: + """Convert filtered input wavelength data to frequency. + + Returns + ------- + NumpyArray + Frequency array converted from filtered input wavelength data + """ + + wvl_um, _, _ = self._filter_wvl_range(wvl_min=self.wvl_range[0], wvl_max=self.wvl_range[1]) + return C_0 / wvl_um + + @property + def frequency_range(self) -> Tuple[float, float]: + """Frequency range of filtered input data + + Returns + ------- + Tuple[float, float] + The minimal frequency and the maximal frequency + """ + + return (np.min(self.freqs), np.max(self.freqs)) + + @staticmethod + def _unpack_complex(complex_num): + """Returns real and imaginary parts from complex number. + + Parameters + ---------- + complex_num : complex + Complex number. + + Returns + ------- + Tuple[float, float] + Real and imaginary parts of the complex number. + """ + return complex_num.real, complex_num.imag + + @staticmethod + def _pack_complex(real_part, imag_part): + """Returns complex number from real and imaginary parts. + + Parameters + ---------- + real_part : float + Real part of the complex number. + imag_part : float + Imaginary part of the complex number. + + Returns + ------- + complex + The complex number. + """ + return real_part + 1j * imag_part + + @staticmethod + def _unpack_coeffs(coeffs): + """Unpacks coefficient vector into complex pole parameters. + + Parameters + ---------- + coeffs : np.ndarray[real] + Array of real coefficients for the pole residue fit. + + Returns + ------- + Tuple[np.ndarray[complex], np.ndarray[complex]] + "a" and "c" poles for the PoleResidue model. + """ + assert len(coeffs) % 4 == 0, "len(coeffs) must be multiple of 4." + num_poles = len(coeffs) // 4 + indices = 4 * np.arange(num_poles) + + a_real = coeffs[indices + 0] + a_imag = coeffs[indices + 1] + c_real = coeffs[indices + 2] + c_imag = coeffs[indices + 3] + + poles_a = DispersionFitter._pack_complex(a_real, a_imag) + poles_c = DispersionFitter._pack_complex(c_real, c_imag) + return poles_a, poles_c + + @staticmethod + def _pack_coeffs(pole_a, pole_c): + """Packs complex a and c pole parameters into coefficient array. + + Parameters + ---------- + pole_a : np.ndarray[complex] + Array of complex "a" poles for the PoleResidue dispersive model. + pole_c : np.ndarray[complex] + Array of complex "c" poles for the PoleResidue dispersive model. + + Returns + ------- + np.ndarray[float] + Array of real coefficients for the pole residue fit. + """ + a_real, a_imag = DispersionFitter._unpack_complex(pole_a) + c_real, c_imag = DispersionFitter._unpack_complex(pole_c) + stacked_coeffs = np.stack((a_real, a_imag, c_real, c_imag), axis=1) + return stacked_coeffs.flatten() @staticmethod - def _validate_data(wvl_um: Numpy, n_data: Numpy, k_data: Numpy = None): - """make sure raw data is correctly shaped. + def _coeffs_to_poles(coeffs): + """Converts model coefficients to poles. Parameters ---------- - wvl_um : Numpy - Wavelength data in micrometers. - n_data : Numpy - Real part of refractive index in micrometers. - k_data : Numpy, optional - Imaginary part of refractive index in micrometers. + coeffs : np.ndarray[float] + Array of real coefficients for the pole residue fit. + + Returns + ------- + List[Tuple[complex, complex]] + List of complex poles (a, c) """ - assert wvl_um.shape == n_data.shape - if k_data is not None: - assert wvl_um.shape == k_data.shape + coeffs_scaled = coeffs / HBAR + poles_a, poles_c = DispersionFitter._unpack_coeffs(coeffs_scaled) + poles = [(complex(a), complex(c)) for (a, c) in zip(poles_a, poles_c)] + # poles = [((a.real, a.imag), (c.real, c.imag)) for (a, c) in zip(poles_a, poles_c)] + return poles @staticmethod - def eV_to_Hz(f_eV: float): + def _poles_to_coeffs(poles): + """Converts poles to model coefficients. + + Parameters + ---------- + poles : List[Tuple[complex, complex]] + List of complex poles (a, c) + + Returns + ------- + np.ndarray[float] + Array of real coefficients for the pole residue fit. + """ + poles_a, poles_c = np.array([[a, c] for (a, c) in poles]).T + coeffs = DispersionFitter._pack_coeffs(poles_a, poles_c) + return coeffs * HBAR + + @staticmethod + def _eV_to_Hz(f_eV: float): # pylint:disable=invalid-name """convert frequency in unit of eV to Hz Parameters @@ -194,7 +297,7 @@ def eV_to_Hz(f_eV: float): return f_eV / HBAR / 2 / np.pi @staticmethod - def Hz_to_eV(f_Hz: float): + def _Hz_to_eV(f_Hz: float): # pylint:disable=invalid-name """convert frequency in unit of Hz to eV Parameters @@ -207,9 +310,9 @@ def Hz_to_eV(f_Hz: float): def fit( self, - num_poles: int = 3, - num_tries: int = 100, - tolerance_rms: float = 0.0, + num_poles: int = 1, + num_tries: int = 50, + tolerance_rms: float = 1e-2, ) -> Tuple[PoleResidue, float]: """Fits data a number of times and returns best results. @@ -224,7 +327,7 @@ def fit( Returns ------- - Tuple[``PoleResidue``, float] + Tuple[:class:``PoleResidue``, float] Best results of multiple fits: (dispersive medium, RMS error). """ @@ -240,7 +343,7 @@ def fit( while not progress.finished: - medium, rms_error = self.fit_single(num_poles=num_poles) + medium, rms_error = self._fit_single(num_poles=num_poles) # if improvement, set the best RMS and coeffs if rms_error < best_rms: @@ -277,10 +380,10 @@ def _make_medium(self, coeffs): ``PoleResidue`` Dispersive medium corresponding to this set of ``coeffs``. """ - poles_complex = _coeffs_to_poles(coeffs) + poles_complex = DispersionFitter._coeffs_to_poles(coeffs) return PoleResidue(poles=poles_complex, frequency_range=self.frequency_range) - def fit_single( + def _fit_single( self, num_poles: int = 3, ) -> Tuple[PoleResidue, float]: @@ -293,7 +396,7 @@ def fit_single( Returns ------- - Tuple[``PoleResidue``, float] + Tuple[:class:``PoleResidue``, float] Results of single fit: (dispersive medium, RMS error). """ @@ -316,9 +419,9 @@ def constraint(coeffs, _grad): float Value of constraint. """ - poles_a, poles_c = _unpack_coeffs(coeffs) - a_real, a_imag = _unpack_complex(poles_a) - c_real, c_imag = _unpack_complex(poles_c) + poles_a, poles_c = DispersionFitter._unpack_coeffs(coeffs) + a_real, a_imag = DispersionFitter._unpack_complex(poles_a) + c_real, c_imag = DispersionFitter._unpack_complex(poles_c) prstar = a_real * c_real + a_imag * c_imag res = 2 * prstar * a_real - c_real * (a_real * a_real + a_imag * a_imag) res[res >= 0] = 0 @@ -426,7 +529,7 @@ def plot( """ if wvl_um is None: - wvl_um = self.wvl_um + wvl_um = C_0 / self.freqs freqs = C_0 / wvl_um eps_model = medium.eps_model(freqs) @@ -448,9 +551,151 @@ def plot( return ax + @staticmethod + def _validate_url_load(data_load: List): + """Validate if the loaded data from URL is valid + The data list should be in this format: + [["wl", "n"], + [float, float], + . . + . . + . . + (if lossy) + ["wl", "k"], + [float, float], + . . + . . + . .]] + + Parameters + ---------- + data_load : List + Loaded data from URL + + Raises + ------ + ValidationError + Or other exceptions + """ + has_k = 0 + + if data_load[0][0] != "wl" or data_load[0][1] != "n": + raise ValidationError( + "Invalid URL. The file should begin with ['wl','n']. " + "Or make sure that you have supplied an appropriate delimiter." + ) + + for row in data_load[1:]: + if row[0] == "wl": + if row[1] == "k": + has_k += 1 + else: + raise ValidationError( + "Invalid URL. The file is not well formatted for ['wl', 'k'] data." + ) + else: + # make sure the rest is float type + try: + nk_tmp = [float(x) for x in row] # pylint:disable=unused-variable + except Exception as e: + raise ValidationError("Invalid URL. Float data cannot be recognized.") from e + + if has_k > 1: + raise ValidationError("Invalid URL. Too many k labels.") + + @classmethod + def from_url(cls, url_file: str, delimiter: str = ","): + """loads ``DispersionFitter`` from url linked to a csv/txt file that + contains wavelength (micron), n, and optionally k data. Preferred from + refractiveindex.info. + + Hint + ---- + The data file from url should be in this format (delimiter not displayed + here, and note that the strings such as "wl", "n" need to be included + in the file): + + * For lossless media:: + + wl n + [float] [float] + . . + . . + . . + + * For lossy media:: + + wl n + [float] [float] + . . + . . + . . + wl k + [float] [float] + . . + . . + . . + + Parameters + ---------- + url_file : str + Url link to the data file. + e.g. "https://refractiveindex.info/data_csv.php?datafile=data/main/Ag/Johnson.yml" + delimiter : str = "," + E.g. in refractiveindex.info, it'll be "," for csv file, and "\\\\t" for txt file. + + Returns + ------- + DispersionFitter + A ``DispersionFitter`` instance. + """ + + resp = requests.get(url_file) + + try: + resp.raise_for_status() + except Exception as e: # pylint:disable=broad-except + raise WebError("Connection to the website failed. Please provide a valid URL.") from e + + data_url = list( + csv.reader(codecs.iterdecode(resp.iter_lines(), "utf-8"), delimiter=delimiter) + ) + data_url = list(data_url) + + # first validate data + cls._validate_url_load(data_url) + + # parsing the data + n_lam = [] + k_lam = [] # the two variables contain [wvl_um, n(k)] + has_k = 0 # whether k is in the data + + for row in data_url[1:]: + if has_k == 1: + k_lam.append([float(x) for x in row]) + else: + if row[0] == "wl": + has_k += 1 + else: + n_lam.append([float(x) for x in row]) + + n_lam = np.array(n_lam) + k_lam = np.array(k_lam) + + # for data containing k + if has_k == 1: + # now let's make sure wvl_um in n_lam and k_lam match + if not np.allclose(n_lam[:, 0], k_lam[:, 0]): + raise ValidationError( + "Invalid URL. Both n and k should be provided at each wavelength." + ) + + return cls(wvl_um=n_lam[:, 0], n_data=n_lam[:, 1], k_data=k_lam[:, 1]) + return cls(wvl_um=n_lam[:, 0], n_data=n_lam[:, 1]) + @classmethod - def from_file(cls, fname, **loadtxt_kwargs): - """Loads ``DispersionFitter`` from file contining wavelength, n, k data. + def from_file(cls, fname: str, **loadtxt_kwargs): + """Loads ``DispersionFitter`` from file containing wavelength, n, k data. Parameters ---------- diff --git a/tidy3d/plugins/dispersion/fit_web.py b/tidy3d/plugins/dispersion/fit_web.py index 8f2d4d583..351025cce 100644 --- a/tidy3d/plugins/dispersion/fit_web.py +++ b/tidy3d/plugins/dispersion/fit_web.py @@ -1,18 +1,62 @@ """Fit PoleResidue Dispersion models to optical NK data based on web service """ from typing import Tuple, List -import numpy as np +from enum import Enum import requests from pydantic import BaseModel, PositiveInt, NonNegativeFloat, PositiveFloat, Field from ...components.types import Literal from ...components import PoleResidue -from ...constants import MICROMETER -from ...log import log +from ...constants import MICROMETER, HERTZ +from ...log import log, WebError, Tidy3dError from .fit import DispersionFitter +BOUND_MAX_FACTOR = 10 -class FitterData(BaseModel): + +class AdvancedFitterParam(BaseModel): + """Advanced fitter parameters""" + + bound_amp: NonNegativeFloat = Field( + None, + title="Upper bound of oscillator strength", + description="Upper bound of oscillator strength in the model " + "(The default 'None' will trigger automatic setup based on the " + "frequency range of interest).", + unis=HERTZ, + ) + bound_f: NonNegativeFloat = Field( + None, + title="Upper bound of pole frequency", + description="Upper bound of pole frequency in the model " + "(The default 'None' will trigger automatic setup based on the " + "frequency range of interest).", + units=HERTZ, + ) + bound_eps_inf: float = Field( + 1.0, + title="Upper bound of epsilon at infinity frequency", + description="Upper bound of epsilon at infinity frequency. It must be no less than 1.", + ge=1, + ) + constraint: Literal["hard", "soft"] = Field( + "hard", + title="Type of constraint for stability", + description="Stability constraint: 'hard' constraints are generally recommended since " + "they are faster to compute per iteration, and they often require fewer iterations to " + "converge since the search space is smaller. But sometimes the search space is " + "so restrictive that all good solutions are missed, then please try the 'soft' constraints " + "for larger search space. However, both constraints improve stability equally well.", + ) + nlopt_maxeval: PositiveInt = Field( + 5000, + title="Number of inner iterations", + description="Number of iterations in each inner optimization.", + ) + + +# FitterData will be used internally +class FitterData(AdvancedFitterParam): """Data class for request body of Fitter where dipsersion data is input through list""" wvl_um: List[float] = Field( @@ -57,26 +101,22 @@ class FitterData(BaseModel): description="Upper bound of pole frequency in the model.", units="eV", ) - constraint: Literal["hard", "soft"] = Field( - "hard", - title="Type of constraint for stability", - description="Stability constraint enfored on each pole (hard)," - " or the summed contribution (soft).", - ) - nlopt_maxeval: PositiveInt = Field( - 1000, - title="Number of inner iterations", - description="Number of iterations in each optimization", - ) URL_ENV = { - "local": "http://127.0.0.1:8000/dispersion/fit", - "dev": "https://tidy3d-service.dev-simulation.cloud/dispersion/fit", - "prod": "https://tidy3d-service.simulation.cloud/dispersion/fit", + "local": "http://127.0.0.1:8000", + "dev": "https://tidy3d-service.dev-simulation.cloud", + "prod": "https://tidy3d-service.simulation.cloud", } +class ExceptionCodes(Enum): + """HTTP exception codes to handle individually.""" + + GATEWAY_TIMEOUT = 504 + NOT_FOUND = 404 + + class StableDispersionFitter(DispersionFitter): """Stable fitter based on web service""" @@ -102,15 +142,49 @@ def _set_url(config_env: Literal["default", "dev", "prod", "local"] = "default") return URL_ENV[_env] - def fit( # pylint:disable=arguments-differ, too-many-arguments, too-many-locals + @staticmethod + def _setup_server(url_server: str): + """set up web server access + + Parameters + ---------- + url_server : str + URL for the server + """ + + from ...web.auth import ( # pylint:disable=import-outside-toplevel, unused-import + get_credentials, + ) + from ...web.httputils import ( # pylint:disable=import-outside-toplevel + get_headers, + ) + + # get_credentials() + access_token = get_headers() + headers = {"Authorization": access_token["Authorization"]} + + # test connection + resp = requests.get(url_server + "/health") + try: + resp.raise_for_status() + except Exception as e: + raise WebError("Connection to the server failed. Please try again.") from e + + # test authorization + resp = requests.get(url_server + "/health/access", headers=headers) + try: + resp.raise_for_status() + except Exception as e: + raise WebError("Authorization to the server failed. Please try again.") from e + + return headers + + def fit( # pylint:disable=arguments-differ self, num_poles: PositiveInt = 1, num_tries: PositiveInt = 50, - tolerance_rms: NonNegativeFloat = 0.0, - bound_amp: PositiveFloat = DispersionFitter.eV_to_Hz(100.0), - bound_f: PositiveFloat = DispersionFitter.eV_to_Hz(100.0), - constraint: Literal["hard", "soft"] = "hard", - nlopt_maxeval: PositiveInt = 1000, + tolerance_rms: NonNegativeFloat = 1e-2, + advanced_param: AdvancedFitterParam = AdvancedFitterParam(), ) -> Tuple[PoleResidue, float]: """Fits data a number of times and returns best results. @@ -122,71 +196,74 @@ def fit( # pylint:disable=arguments-differ, too-many-arguments, too-many-locals Number of optimizations to run with random initial guess. tolerance_rms : NonNegativeFloat, optional RMS error below which the fit is successful and the result is returned. - bound_amp : PositiveFloat, optional - Bound on the amplitude of poles, namely on Re[c], Im[c], Re[a] - bound_f : PositiveFloat, optional - Bound on the frequency of poles, namely on Im[a]. Unit: Hz - constraint : Literal["hard", "soft"], optional - 'hard' or 'soft' - nlopt_maxeval : PositiveInt, optional - Maxeval in each optimization - - Returned - ------------------ - Tuple[``PoleResidue``, float] + advanced_param : :class:`AdvancedFitterParam`, optional + Other advanced parameters. + + Returns + ------- + Tuple[:class:``PoleResidue``, float] Best results of multiple fits: (dispersive medium, RMS error). """ - from ...web.auth import ( - get_credentials, - ) # pylint:disable=import-outside-toplevel, unused-import - from ...web.httputils import get_headers # pylint:disable=import-outside-toplevel + # get url + url_server = self._set_url("default") + headers = self._setup_server(url_server) - _url = self._set_url("default") + # set up bound_f, bound_amp + if advanced_param.bound_f is None: + advanced_param.bound_f = self.frequency_range[1] * BOUND_MAX_FACTOR + if advanced_param.bound_amp is None: + advanced_param.bound_amp = self.frequency_range[1] * BOUND_MAX_FACTOR - # get_credentials() - access_token = get_headers() - headers = {"Authorization": access_token["Authorization"]} + wvl_um, n_data, k_data = self._filter_wvl_range( + wvl_min=self.wvl_range[0], wvl_max=self.wvl_range[1] + ) web_data = FitterData( - wvl_um=self.wvl_um.tolist(), - n_data=self.n_data.tolist(), - k_data=self.k_data.tolist(), + wvl_um=wvl_um.tolist(), + n_data=n_data.tolist(), + k_data=k_data.tolist(), num_poles=num_poles, num_tries=num_tries, tolerance_rms=tolerance_rms, - bound_amp=self.Hz_to_eV(bound_amp), - bound_f=self.Hz_to_eV(bound_f), - constraint=constraint, - nlopt_maxeval=nlopt_maxeval, + bound_amp=self._Hz_to_eV(advanced_param.bound_amp), + bound_f=self._Hz_to_eV(advanced_param.bound_f), + bound_eps_inf=advanced_param.bound_eps_inf, + constraint=advanced_param.constraint, + nlopt_maxeval=advanced_param.nlopt_maxeval, ) - r_post = requests.post( - _url, + resp = requests.post( + url_server + "/dispersion/fit", headers=headers, data=web_data.json(), ) try: - run_result = r_post.json() - best_medium = PoleResidue.parse_raw(run_result["message"]) - best_rms = float(run_result["rms"]) + resp.raise_for_status() + except Exception as e: + if resp.status_code == ExceptionCodes.GATEWAY_TIMEOUT.value: + raise Tidy3dError( + "Fitter failed due to timeout. Try to decrease " + "the number of tries, the number of inner iterations, " + "to relax RMS tolerance, or to use the 'hard' constraint." + ) from e - if best_rms < tolerance_rms: - log.info(f"\tfound optimal fit with RMS error = {best_rms:.2e}, returning") - else: - log.warning( - f"\twarning: did not find fit " - f"with RMS error under tolerance_rms of {tolerance_rms:.2e}" - ) - log.info(f"\treturning best fit with RMS error {best_rms:.2e}") - return best_medium, best_rms - - except Exception as e: # pylint:disable=broad-except - log.warning(e) - log.error( - "Fitter failed due to timeout. Try to decrease " - "the number of tries, the number of internal iterations, " - "or relax RMS tolerance." + raise WebError( + "Fitter failed. Try again, or tune the parameters, or contact us for more help." + ) from e + + run_result = resp.json() + best_medium = PoleResidue.parse_raw(run_result["message"]) + best_rms = float(run_result["rms"]) + + if best_rms < tolerance_rms: + log.info(f"\tfound optimal fit with RMS error = {best_rms:.2e}, returning") + else: + log.warning( + f"\twarning: did not find fit " + f"with RMS error under tolerance_rms of {tolerance_rms:.2e}" ) - return PoleResidue(), np.inf + log.info(f"\treturning best fit with RMS error {best_rms:.2e}") + + return best_medium, best_rms