Skip to content

Commit

Permalink
Merge pull request #90 from AUTODIAL/enh/optimizer-fitter
Browse files Browse the repository at this point in the history
  • Loading branch information
ma-sadeghi authored Feb 13, 2024
2 parents b847865 + 203fd11 commit 0628cc4
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 65 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ on: pull_request

jobs:
test:
runs-on: macos-14
runs-on: ubuntu-latest
defaults:
run:
shell: sh -l {0}
shell: bash

steps:
- uses: actions/checkout@v4
- name: "Set up Python"
uses: conda-incubator/setup-miniconda@v3
with:
python-version: '3.10'
installer-url: https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Mambaforge-23.11.0-0-MacOSX-arm64.sh
miniforge-version: latest
- name: "Install AutoEIS"
run: |
python -m pip install --upgrade pip
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ examples/**/*.png
examples/**/*.csv
examples/**/*.txt
results/
*.pkl

# Local development
benchmark_jaxfit.py
Expand Down
66 changes: 54 additions & 12 deletions autoeis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pandas as pd
import psutil
from impedance.validation import linKK
from jax import config
from mpire import WorkerPool
from numpyro.infer import MCMC, NUTS
from scipy.optimize import curve_fit
Expand All @@ -34,6 +35,8 @@
from autoeis import io, julia_helpers, metrics, parser, utils
from autoeis.models import circuit_regression, circuit_regression_wrapped # noqa: F401

# Enforce double precision, otherwise circuit fitter fails (who knows what else!)
config.update("jax_enable_x64", True)
# AutoEIS datasets are not small-enough that CPU is much faster than GPU
numpyro.set_platform("cpu")

Expand All @@ -57,7 +60,7 @@


# TODO: Breaks when data is noisy -> use curve_fit to extrapolate R0
def compute_ohmic_resistance(Z: np.ndarray[complex], freq: np.ndarray[float]) -> float:
def compute_ohmic_resistance(freq: np.ndarray[float], Z: np.ndarray[complex]) -> float:
"""Extracts the ohmic resistance from impedance data.
Parameters
Expand Down Expand Up @@ -223,7 +226,7 @@ def preprocess_impedance_data(

# Find the ohmic resistance
try:
ohmic_resistance = compute_ohmic_resistance(Z_mask, freq_mask)
ohmic_resistance = compute_ohmic_resistance(freq_mask, Z_mask)
ohmic_resistance_found = True
except ValueError:
log.error("Ohmic resistance not found. Check data or increase KK threshold.")
Expand Down Expand Up @@ -489,15 +492,47 @@ def merge_identical_circuits(circuits: "pd.DataFrame") -> "pd.DataFrame":

def perform_bayesian_inference(
circuits: Union[pd.DataFrame, list[str], str],
Z: np.ndarray[complex],
freq: np.ndarray[float],
Z: np.ndarray[complex],
p0: Union[np.ndarray, dict, list[dict], list[np.ndarray]] = None,
num_warmup=2500,
num_samples=1000,
num_chains=1,
seed: Union[int, jax.Array] = None,
progress_bar: bool = True,
) -> list[Union[numpyro.infer.mcmc.MCMC, None]]:
refine_p0: bool = False,
) -> list[tuple[Union[numpyro.infer.mcmc.MCMC, None], int]]:
"""Performs Bayesian inference on the circuits based on impedance data.
Parameters
----------
circuits : pd.DataFrame or list[str]
Dataframe containing circuits or list of circuit strings.
Z : np.ndarray[complex]
Complex impedance data.
freq: np.ndarray[float]
Frequency data.
p0 : Union[np.ndarray[float], dict[str, float]], optional
Initial guess for the circuit parameters (default is None).
num_warmup : int, optional
Number of warmup samples for the MCMC (default is 2500).
num_samples : int, optional
Number of samples for the MCMC (default is 1000).
num_chains : int, optional
Number of MCMC chains (default is 1).
seed : int, optional
Random seed for reproducibility (default is None).
progress_bar : bool, optional
If True, a progress bar will be displayed (default is True).
refine_p0 : bool, optional
If True, the initial guess for the circuit parameters will be refined
using the circuit fitter (default is False).
Returns
-------
list[tuple[numpyro.infer.mcmc.MCMC, int]]
List of MCMC objects and exit codes (0 if successful, -1 if failed).
"""
# Ensure inputs are lists
if isinstance(circuits, str):
circuits = [circuits]
Expand All @@ -521,14 +556,20 @@ def perform_bayesian_inference(
num_params = len(parser.get_parameter_labels(circuit))
assert len(p0_) == num_params, f"Invalid p0 length: {p0_}"

if refine_p0:
for i, (circuit, p0_) in enumerate(zip(circuits, p0)):
p0[i] = utils.fit_circuit_parameters(circuit, freq, Z, p0=p0_)
# If circuit fitter didn't converge, use the initial guess
p0[i] = p0_ if p0[i] is None else p0[i]

# Short-circuit if no circuits are provided
if len(circuits) == 0:
log.warning("'circuits' dataframe is empty!")
return None

bi_kwargs = {
"Z": Z,
"freq": freq,
"Z": Z,
"num_warmup": num_warmup,
"num_samples": num_samples,
"num_chains": num_chains,
Expand All @@ -545,8 +586,8 @@ def perform_bayesian_inference(

def _perform_bayesian_inference(
circuit: str,
Z: np.ndarray[complex],
freq: np.ndarray[float],
Z: np.ndarray[complex],
p0: Union[np.ndarray[float], dict[str, float]] = None,
num_warmup=2500,
num_samples=1000,
Expand Down Expand Up @@ -586,9 +627,10 @@ def _perform_bayesian_inference(
key = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(key)

# TODO: Remove this, circuit fitting must be done in the public API
# Deal with initial values for the circuit parameters
if p0 is None:
p0 = utils.fit_circuit_parameters(circuit, Z, freq)
p0 = utils.fit_circuit_parameters(circuit, freq, Z)
assert isinstance(p0, dict), "p0 must be a dictionary"

circuit_fn = utils.generate_circuit_fn(circuit, jit=True)
Expand All @@ -607,8 +649,8 @@ def _perform_bayesian_inference(
}
mcmc = MCMC(nuts_kernel, **kwargs_mcmc)
kwargs_inference = {
"Z": Z,
"freq": freq,
"Z": Z,
"priors": priors,
"circuit_fn": circuit_fn,
}
Expand All @@ -624,8 +666,8 @@ def _perform_bayesian_inference(

def _perform_bayesian_inference_batch(
circuits: list[str],
Z: np.ndarray[complex],
freq: np.ndarray[float],
Z: np.ndarray[complex],
p0: list[dict[str, float]] = None,
num_warmup=2500,
num_samples=1000,
Expand Down Expand Up @@ -662,8 +704,8 @@ def _perform_bayesian_inference_batch(
# Multiprocessing requires all inputs to be iterables of the same length
bi_kwargs = {
"circuits": circuits,
"Z": [Z] * N,
"freq": [freq] * N,
"Z": [Z] * N,
"p0": p0 if isinstance(p0, list) else [p0] * N,
"num_warmup": [num_warmup] * N,
"num_samples": [num_samples] * N,
Expand Down Expand Up @@ -726,8 +768,8 @@ def filter_implausible_circuits(circuits: pd.DataFrame) -> pd.DataFrame:


def perform_full_analysis(
Z: np.ndarray[complex],
freq: np.ndarray[float],
Z: np.ndarray[complex],
iters: int = 100,
parallel: bool = True,
linKK_threshold: float = 5e-2,
Expand Down Expand Up @@ -774,7 +816,7 @@ def perform_full_analysis(

# Perform Bayesian inference on the filtered ECMs
kwargs_mcmc = {"num_warmup": num_warmup, "num_samples": num_samples}
mcmcs = perform_bayesian_inference(circuits, Z, freq, **kwargs_mcmc)
mcmcs = perform_bayesian_inference(circuits, freq, Z, **kwargs_mcmc)

# Add the results to the circuits dataframe as a new column
chains, status = zip(*mcmcs)
Expand Down
8 changes: 4 additions & 4 deletions autoeis/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def circuit_regression(
):
"""NumpyRo model for Bayesian inference of circuit component values."""
# Sample each element of X separately
X = jnp.array([numpyro.sample(k, v) for k, v in priors.items()])
p = jnp.array([numpyro.sample(k, v) for k, v in priors.items()])
# Predict Z using the model
circuit_fn = utils.generate_circuit_fn(circuit)
circuit_fn = jax.jit(circuit_fn)
Z_pred = circuit_fn(X, freq)
Z_pred = circuit_fn(freq, p)
# Define observation model for real and imaginary parts of Z
sigma_real = numpyro.sample("sigma_real", dist.Exponential(rate=1.0))
numpyro.sample("obs_real", dist.Normal(Z_pred.real, sigma_real), obs=Z.real)
Expand All @@ -47,9 +47,9 @@ def circuit_regression_wrapped(
):
"""NumpyRo model for Bayesian inference of circuit component values."""
# Sample each element of X separately
X = jnp.array([numpyro.sample(k, v) for k, v in priors.items()])
p = jnp.array([numpyro.sample(k, v) for k, v in priors.items()])
# Predict Z using the model
Z_pred = circuit_fn(X, freq)
Z_pred = circuit_fn(freq, p)
# Define observation model for real and imaginary parts of Z
sigma_real = numpyro.sample("sigma_real", dist.Exponential(rate=1.0))
numpyro.sample("obs_real", dist.Normal(Z_pred.real, sigma_real), obs=Z.real)
Expand Down
Loading

0 comments on commit 0628cc4

Please sign in to comment.