diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml index 76afc7bc..741dfd5f 100644 --- a/.github/workflows/conda.yml +++ b/.github/workflows/conda.yml @@ -4,10 +4,10 @@ on: pull_request jobs: test: - runs-on: macos-14 + runs-on: ubuntu-latest defaults: run: - shell: sh -l {0} + shell: bash steps: - uses: actions/checkout@v4 @@ -15,7 +15,7 @@ jobs: uses: conda-incubator/setup-miniconda@v3 with: python-version: '3.10' - installer-url: https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Mambaforge-23.11.0-0-MacOSX-arm64.sh + miniforge-version: latest - name: "Install AutoEIS" run: | python -m pip install --upgrade pip diff --git a/.gitignore b/.gitignore index a63b470b..8361e0de 100644 --- a/.gitignore +++ b/.gitignore @@ -166,6 +166,7 @@ examples/**/*.png examples/**/*.csv examples/**/*.txt results/ +*.pkl # Local development benchmark_jaxfit.py diff --git a/autoeis/core.py b/autoeis/core.py index 6fc66ee9..59a19c04 100644 --- a/autoeis/core.py +++ b/autoeis/core.py @@ -25,6 +25,7 @@ import pandas as pd import psutil from impedance.validation import linKK +from jax import config from mpire import WorkerPool from numpyro.infer import MCMC, NUTS from scipy.optimize import curve_fit @@ -34,6 +35,8 @@ from autoeis import io, julia_helpers, metrics, parser, utils from autoeis.models import circuit_regression, circuit_regression_wrapped # noqa: F401 +# Enforce double precision, otherwise circuit fitter fails (who knows what else!) +config.update("jax_enable_x64", True) # AutoEIS datasets are not small-enough that CPU is much faster than GPU numpyro.set_platform("cpu") @@ -57,7 +60,7 @@ # TODO: Breaks when data is noisy -> use curve_fit to extrapolate R0 -def compute_ohmic_resistance(Z: np.ndarray[complex], freq: np.ndarray[float]) -> float: +def compute_ohmic_resistance(freq: np.ndarray[float], Z: np.ndarray[complex]) -> float: """Extracts the ohmic resistance from impedance data. 
Parameters @@ -223,7 +226,7 @@ def preprocess_impedance_data( # Find the ohmic resistance try: - ohmic_resistance = compute_ohmic_resistance(Z_mask, freq_mask) + ohmic_resistance = compute_ohmic_resistance(freq_mask, Z_mask) ohmic_resistance_found = True except ValueError: log.error("Ohmic resistance not found. Check data or increase KK threshold.") @@ -489,15 +492,47 @@ def merge_identical_circuits(circuits: "pd.DataFrame") -> "pd.DataFrame": def perform_bayesian_inference( circuits: Union[pd.DataFrame, list[str], str], - Z: np.ndarray[complex], freq: np.ndarray[float], + Z: np.ndarray[complex], p0: Union[np.ndarray, dict, list[dict], list[np.ndarray]] = None, num_warmup=2500, num_samples=1000, num_chains=1, seed: Union[int, jax.Array] = None, progress_bar: bool = True, -) -> list[Union[numpyro.infer.mcmc.MCMC, None]]: + refine_p0: bool = False, +) -> list[tuple[Union[numpyro.infer.mcmc.MCMC, None], int]]: + """Performs Bayesian inference on the circuits based on impedance data. + + Parameters + ---------- + circuits : pd.DataFrame or list[str] + Dataframe containing circuits or list of circuit strings. + freq : np.ndarray[float] + Frequency data. + Z : np.ndarray[complex] + Complex impedance data. + p0 : Union[np.ndarray[float], dict[str, float]], optional + Initial guess for the circuit parameters (default is None). + num_warmup : int, optional + Number of warmup samples for the MCMC (default is 2500). + num_samples : int, optional + Number of samples for the MCMC (default is 1000). + num_chains : int, optional + Number of MCMC chains (default is 1). + seed : int, optional + Random seed for reproducibility (default is None). + progress_bar : bool, optional + If True, a progress bar will be displayed (default is True). + refine_p0 : bool, optional + If True, the initial guess for the circuit parameters will be refined + using the circuit fitter (default is False). 
+ + Returns + ------- + list[tuple[numpyro.infer.mcmc.MCMC, int]] + List of MCMC objects and exit codes (0 if successful, -1 if failed). + """ # Ensure inputs are lists if isinstance(circuits, str): circuits = [circuits] @@ -521,14 +556,20 @@ def perform_bayesian_inference( num_params = len(parser.get_parameter_labels(circuit)) assert len(p0_) == num_params, f"Invalid p0 length: {p0_}" + if refine_p0: + for i, (circuit, p0_) in enumerate(zip(circuits, p0)): + p0[i] = utils.fit_circuit_parameters(circuit, freq, Z, p0=p0_) + # If circuit fitter didn't converge, use the initial guess + p0[i] = p0_ if p0[i] is None else p0[i] + # Short-circuit if no circuits are provided if len(circuits) == 0: log.warning("'circuits' dataframe is empty!") return None bi_kwargs = { - "Z": Z, "freq": freq, + "Z": Z, "num_warmup": num_warmup, "num_samples": num_samples, "num_chains": num_chains, @@ -545,8 +586,8 @@ def perform_bayesian_inference( def _perform_bayesian_inference( circuit: str, - Z: np.ndarray[complex], freq: np.ndarray[float], + Z: np.ndarray[complex], p0: Union[np.ndarray[float], dict[str, float]] = None, num_warmup=2500, num_samples=1000, @@ -586,9 +627,10 @@ def _perform_bayesian_inference( key = jax.random.PRNGKey(seed) key, subkey = jax.random.split(key) + # TODO: Remove this, circuit fitting must be done in the public API # Deal with initial values for the circuit parameters if p0 is None: - p0 = utils.fit_circuit_parameters(circuit, Z, freq) + p0 = utils.fit_circuit_parameters(circuit, freq, Z) assert isinstance(p0, dict), "p0 must be a dictionary" circuit_fn = utils.generate_circuit_fn(circuit, jit=True) @@ -607,8 +649,8 @@ def _perform_bayesian_inference( } mcmc = MCMC(nuts_kernel, **kwargs_mcmc) kwargs_inference = { - "Z": Z, "freq": freq, + "Z": Z, "priors": priors, "circuit_fn": circuit_fn, } @@ -624,8 +666,8 @@ def _perform_bayesian_inference( def _perform_bayesian_inference_batch( circuits: list[str], - Z: np.ndarray[complex], freq: np.ndarray[float], + Z: 
np.ndarray[complex], p0: list[dict[str, float]] = None, num_warmup=2500, num_samples=1000, @@ -662,8 +704,8 @@ def _perform_bayesian_inference_batch( # Multiprocessing requires all inputs to be iterables of the same length bi_kwargs = { "circuits": circuits, - "Z": [Z] * N, "freq": [freq] * N, + "Z": [Z] * N, "p0": p0 if isinstance(p0, list) else [p0] * N, "num_warmup": [num_warmup] * N, "num_samples": [num_samples] * N, @@ -726,8 +768,8 @@ def filter_implausible_circuits(circuits: pd.DataFrame) -> pd.DataFrame: def perform_full_analysis( - Z: np.ndarray[complex], freq: np.ndarray[float], + Z: np.ndarray[complex], iters: int = 100, parallel: bool = True, linKK_threshold: float = 5e-2, @@ -774,7 +816,7 @@ def perform_full_analysis( # Perform Bayesian inference on the filtered ECMs kwargs_mcmc = {"num_warmup": num_warmup, "num_samples": num_samples} - mcmcs = perform_bayesian_inference(circuits, Z, freq, **kwargs_mcmc) + mcmcs = perform_bayesian_inference(circuits, freq, Z, **kwargs_mcmc) # Add the results to the circuits dataframe as a new column chains, status = zip(*mcmcs) diff --git a/autoeis/models.py b/autoeis/models.py index 84c82816..ef30c6ef 100644 --- a/autoeis/models.py +++ b/autoeis/models.py @@ -27,11 +27,11 @@ def circuit_regression( ): """NumpyRo model for Bayesian inference of circuit component values.""" # Sample each element of X separately - X = jnp.array([numpyro.sample(k, v) for k, v in priors.items()]) + p = jnp.array([numpyro.sample(k, v) for k, v in priors.items()]) # Predict Z using the model circuit_fn = utils.generate_circuit_fn(circuit) circuit_fn = jax.jit(circuit_fn) - Z_pred = circuit_fn(X, freq) + Z_pred = circuit_fn(freq, p) # Define observation model for real and imaginary parts of Z sigma_real = numpyro.sample("sigma_real", dist.Exponential(rate=1.0)) numpyro.sample("obs_real", dist.Normal(Z_pred.real, sigma_real), obs=Z.real) @@ -47,9 +47,9 @@ def circuit_regression_wrapped( ): """NumpyRo model for Bayesian inference of circuit 
component values.""" # Sample each element of X separately - X = jnp.array([numpyro.sample(k, v) for k, v in priors.items()]) + p = jnp.array([numpyro.sample(k, v) for k, v in priors.items()]) # Predict Z using the model - Z_pred = circuit_fn(X, freq) + Z_pred = circuit_fn(freq, p) # Define observation model for real and imaginary parts of Z sigma_real = numpyro.sample("sigma_real", dist.Exponential(rate=1.0)) numpyro.sample("obs_real", dist.Normal(Z_pred.real, sigma_real), obs=Z.real) diff --git a/autoeis/utils.py b/autoeis/utils.py index 9178f3a5..90d76a5f 100644 --- a/autoeis/utils.py +++ b/autoeis/utils.py @@ -23,8 +23,8 @@ import sys from collections.abc import Iterable from contextlib import contextmanager -from functools import partial, wraps -from typing import Callable, Union +from functools import wraps +from typing import Union import jax # NOQA: F401 import jax.numpy as jnp @@ -32,9 +32,11 @@ import numpyro.distributions as dist import pandas as pd from impedance.models.circuits import CustomCircuit +from impedance.models.circuits.fitting import set_default_bounds from numpy import pi # NOQA: F401 from rich.logging import RichHandler from scipy import stats +from scipy.optimize import curve_fit # from tensorflow_probability import distributions as tfdist # NOQA: F401 import __main__ @@ -58,6 +60,8 @@ def get_logger(name: str) -> logging.Logger: return logger +log = get_logger(__name__) + # <<< Logging utils @@ -156,7 +160,7 @@ def wrapper(*args, **kwargs): try: result = func(*args, **kwargs) except TimeoutException: - print("Didn't converge in time!") + log.warning(f"{func.__name__} didn't converge in time!") result = None finally: signal.alarm(0) @@ -203,10 +207,25 @@ def is_notebook(): # >>> Circuit utils -def fit_circuit_parameters( +def parse_initial_guess( + p0: Union[np.ndarray, dict[str, float], list[float]], + circuit: str, +) -> np.ndarray: + """Parses the initial guess for circuit parameters.""" + num_params = parser.count_parameters(circuit) 
+ if p0 is None: + return np.random.rand(num_params) + elif isinstance(p0, dict): + return np.fromiter(p0.values(), dtype=float) + elif isinstance(p0, (list, np.ndarray)): + return np.array(p0) + raise ValueError(f"Invalid initial guess: {p0}") + + +def fit_circuit_parameters_legacy( circuit: str, - Z: np.ndarray[complex], freq: np.ndarray[float], + Z: np.ndarray[complex], p0: Union[np.ndarray[float], dict[str, float]] = None, iters: int = 1, maxfev: int = 1000, @@ -215,10 +234,7 @@ def fit_circuit_parameters( # NOTE: Each circuit eval ~ 1 ms, so 1000 evals ~ 1 s # Deal with initial guess num_params = parser.count_parameters(circuit) - if p0 is None: - p0 = np.random.rand(num_params) - elif isinstance(p0, dict): - p0 = list(p0.values()) + p0 = parse_initial_guess(p0, circuit) assert len(p0) == num_params, "Wrong number of parameters in initial guess." # Fit circuit parameters @@ -246,24 +262,86 @@ def fit_circuit_parameters( return dict(zip(labels, p0)) +def fit_circuit_parameters( + circuit: str, + freq: np.ndarray[float], + Z: np.ndarray[complex], + p0: Union[np.ndarray[float], dict[str, float]] = None, + iters: int = 1, + maxfev: int = 1000, + ftol: float = 1e-13, +) -> dict[str, float]: + """Fits a circuit to impedance data and returns the parameters.""" + # Define objective function + Zc = np.hstack([Z.real, Z.imag]) + fn = generate_circuit_fn(circuit, jit=True, concat=True) + # Format obj function as f(freq, *p) not f(freq, p) for curve_fit + obj_fn = lambda freq, *p: fn(freq, p) # noqa: E731 + + # >>> Alternatively, use impedance.py to create the objective function + # from impedance.models.circuits.fitting import wrapCircuit + # circuit_impy = parser.convert_to_impedance_format(circuit) + # obj_fn = wrapCircuit(circuit_impy, constants={}) + # <<< + + # Sanitize initial guess + num_params = parser.count_parameters(circuit) + p0 = parse_initial_guess(p0, circuit) + assert len(p0) == num_params, "Wrong number of parameters in initial guess." 
+ + # Assemble kwargs for curve_fit + circuit_impy = parser.convert_to_impedance_format(circuit) + bounds = set_default_bounds(circuit_impy) + kwargs = {"p0": p0, "bounds": bounds, "maxfev": maxfev, "ftol": ftol} + + # Fit circuit parameters by brute force + err_min = np.inf + for _ in range(iters): + try: + popt, pcov = curve_fit(obj_fn, freq, Zc, **kwargs) + except RuntimeError: + continue + err = np.mean((obj_fn(freq, *popt) - Zc) ** 2) + if err < err_min: + err_min = err + p0 = popt + kwargs["p0"] = np.random.rand(num_params) + + if err_min == np.inf: + raise RuntimeError("Failed to fit the circuit parameters.") + + variables = parser.get_parameter_labels(circuit) + return dict(zip(variables, p0)) + + # FIXME: Timeout logic doesn't work on Windows -> module 'signal' has no attribute 'SIGALRM'. if os.name != "nt": - fit_circuit_parameters = timeout(300)(fit_circuit_parameters) + fit_circuit_parameters = timeout(15)(fit_circuit_parameters) -def eval_circuit(circuit: str, x: np.ndarray[float], f: np.ndarray[float]) -> np.ndarray: +def eval_circuit( + circuit: str, f: Union[np.ndarray, float], p: np.ndarray +) -> np.ndarray[complex]: """Converts a circuit string to a function of (params, freq) and evaluates it.""" Z_expr = parser.generate_mathematical_expr(circuit) return eval(Z_expr) -def generate_circuit_fn(circuit: str, jit=False): - T = Callable[[np.ndarray, np.ndarray], np.ndarray] - fn: T = partial(eval_circuit, circuit) - return jax.jit(fn) if jit else fn +def generate_circuit_fn(circuit: str, jit=False, concat=False): + def Z_complex(freq: np.ndarray, p: Union[np.ndarray, float]) -> np.ndarray[complex]: + return eval_circuit(circuit, freq, p) + + def Z_concat(freq: np.ndarray, p: Union[np.ndarray, float]) -> np.ndarray: + Z = Z_complex(freq, p) + return jnp.hstack([Z.real, Z.imag]) + + fn = Z_concat if concat else Z_complex + fn = jax.jit(fn) if jit else fn + + return fn -def generate_circuit_fn_impedance_backend(circuit: str) -> callable: +def 
generate_circuit_fn_impedance_backend(circuit: str): """Converts a circuit string to a function using impedance.py.""" num_params = parser.count_parameters(circuit) # Convert circuit string to impedance.py format @@ -272,8 +350,8 @@ def generate_circuit_fn_impedance_backend(circuit: str) -> callable: p0 = np.full(num_params, np.nan) circuit = CustomCircuit(circuit, initial_guess=p0) - def func(params, freq): - circuit.parameters_ = params + def func(freq: Union[np.ndarray, float], p: np.ndarray) -> np.ndarray: + circuit.parameters_ = p return circuit.predict(freq) return func @@ -329,8 +407,8 @@ def x0(circuit: str) -> np.ndarray[float]: return np.array(x0) freq = np.logspace(-3, 3, 10) - Z1 = generate_circuit_fn(circuit1)(x0(circuit1), freq) - Z2 = generate_circuit_fn(circuit2)(x0(circuit2), freq) + Z1 = generate_circuit_fn(circuit1)(freq, x0(circuit1)) + Z2 = generate_circuit_fn(circuit2)(freq, x0(circuit2)) return np.allclose(Z1, Z2) diff --git a/examples/autoeis_demo.ipynb b/examples/autoeis_demo.ipynb index ab22f12d..50af5725 100644 --- a/examples/autoeis_demo.ipynb +++ b/examples/autoeis_demo.ipynb @@ -526,7 +526,7 @@ } ], "source": [ - "mcmc_results = ae.core.perform_bayesian_inference(circuits, Z, freq)\n", + "mcmc_results = ae.core.perform_bayesian_inference(circuits, freq, Z)\n", "mcmcs, status = zip(*mcmc_results)" ] }, @@ -765,7 +765,7 @@ " percentiles = [10, 50, 90]\n", " params_list = [[np.percentile(samples[v], p) for v in variables] for p in percentiles]\n", " circuit_fn = ae.utils.generate_circuit_fn(circuit)\n", - " Zsim_list = [circuit_fn(params, freq) for params in params_list]\n", + " Zsim_list = [circuit_fn(freq, params) for params in params_list]\n", " # Plot Nyquist plot\n", " fig, ax = plt.subplots(figsize=(5.5, 4))\n", " for p, Zsim in zip(percentiles, Zsim_list):\n", diff --git a/tests/test_core.py b/tests/test_core.py index 32a49369..bbf83b45 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -6,14 +6,25 @@ from autoeis import 
core, io, utils +def test_bayesian_inference_batch(): + Z, freq = io.load_test_dataset() + # Only test first three circuits to save time in CI + circuits = io.load_test_circuits(filtered=True).iloc[:3] + mcmc_results = core.perform_bayesian_inference(circuits, freq, Z, refine_p0=True) + assert len(mcmc_results) == len(circuits) + for mcmc, exist_code in mcmc_results: + assert exist_code in [-1, 0] + assert isinstance(mcmc, numpyro.infer.mcmc.MCMC) + + def test_compute_ohmic_resistance(): circuit_string = "R1-[P2,P3-R4]" circuit_fn = utils.generate_circuit_fn_impedance_backend(circuit_string) R1 = 250 parameters = np.array([R1, 1e-3, 0.1, 5e-5, 0.8, 10]) freq = np.logspace(-3, 3, 1000) - Z = circuit_fn(parameters, freq) - R = core.compute_ohmic_resistance(Z, freq) + Z = circuit_fn(freq, parameters) + R = core.compute_ohmic_resistance(freq, Z) np.testing.assert_allclose(R, R1, rtol=0.15) @@ -23,8 +34,8 @@ def test_compute_ohmic_resistance_missing_high_freq(): R1 = 250 parameters = np.array([R1, 1e-3, 0.1, 5e-5, 0.8, 10]) freq = np.logspace(-3, 0, 1000) - Z = circuit_fn(parameters, freq) - R = core.compute_ohmic_resistance(Z, freq) + Z = circuit_fn(freq, parameters) + R = core.compute_ohmic_resistance(freq, Z) # When high frequency measurements are missing, Re(Z) @ max(freq) is good approximation Zreal_at_high_freq = Z.real[np.argmax(freq)] np.testing.assert_allclose(R, Zreal_at_high_freq) @@ -81,27 +92,16 @@ def test_bayesian_inference_single(): "num_samples": 1000, "progress_bar": False, } - mcmc, exist_code = core._perform_bayesian_inference( - circuit, Z, freq, p0, **kwargs_mcmc - ) - assert exist_code in [-1, 0] + mcmcs = core.perform_bayesian_inference(circuit, freq, Z, p0, **kwargs_mcmc) + mcmc, exit_code = mcmcs[0] + assert exit_code in [-1, 0] assert isinstance(mcmc, numpyro.infer.mcmc.MCMC) -def test_bayesian_inference_batch(): - Z, freq = io.load_test_dataset() - circuits = io.load_test_circuits(filtered=True) - mcmc_results = 
core.perform_bayesian_inference(circuits, Z, freq) - assert len(mcmc_results) == len(circuits) - for mcmc, exist_code in mcmc_results: - assert exist_code in [-1, 0] - assert isinstance(mcmc, numpyro.infer.mcmc.MCMC) - - @pytest.mark.skip(reason="This test is too slow!") def test_perform_full_analysis(): Z, freq = io.load_test_dataset() - results = core.perform_full_analysis(Z, freq) + results = core.perform_full_analysis(freq, Z) required_columns = [ "circuitstring", "Parameters", diff --git a/tests/test_utils.py b/tests/test_utils.py index 6f89d7c3..6f278725 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,15 +11,15 @@ # Simulated EIS data circuit_string = "R1-[P2,R3]" -p0_dict = {"R1": 250, "P2w": 1e-3, "P2n": 0.5, "R3": 10} +p0_dict = {"R1": 250, "P2w": 1e-3, "P2n": 0.5, "R3": 10.0} p0_vals = list(p0_dict.values()) circuit_fn_gt = utils.generate_circuit_fn_impedance_backend(circuit_string) freq = np.logspace(-3, 3, 1000) -Z = circuit_fn_gt(p0_vals, freq) +Z = circuit_fn_gt(freq, p0_vals) def test_fit_circuit_parameters_without_x0(): - p_dict = utils.fit_circuit_parameters(circuit_string, Z, freq, iters=5) + p_dict = utils.fit_circuit_parameters(circuit_string, freq, Z, iters=10) p_fit = list(p_dict.values()) assert np.allclose(p_fit, p0_vals, rtol=0.01) @@ -27,7 +27,7 @@ def test_fit_circuit_parameters_without_x0(): def test_fit_circuit_parameters_with_x0(): # Add some noise to the initial guess to test robustness p0 = p0_vals + np.random.rand(len(p0_vals)) * p0_vals * 0.5 - p_dict = utils.fit_circuit_parameters(circuit_string, Z, freq, p0) + p_dict = utils.fit_circuit_parameters(circuit_string, freq, Z, p0) p_fit = list(p_dict.values()) assert np.allclose(p_fit, p0_vals, rtol=0.01) @@ -38,7 +38,7 @@ def test_generate_circuit_fn(): freq = np.array([1, 10, 100]) p = np.random.rand(num_params) circuit_fn = utils.generate_circuit_fn(circuit) - Z_py = circuit_fn(p, freq) + Z_py = circuit_fn(freq, p) Main = julia_helpers.init_julia() ec = 
julia_helpers.import_backend(Main) Z_jl = np.array([ec.get_target_impedance(circuit, p, f) for f in freq])