general maintenance (#2315)
* general maintenance

* update changelog and check linters

* keep everything bokeh2

* fix squeeze behaviour

* black

* add unconstrained groups to list of recognized groups

* attempt fixing benchmarks
OriolAbril authored Feb 22, 2024
1 parent 7c1637f commit 2631d13
Showing 28 changed files with 83 additions and 89 deletions.
6 changes: 3 additions & 3 deletions .azure-pipelines/azure-pipelines-base.yml
@@ -10,10 +10,10 @@ jobs:
timeoutInMinutes: 360
strategy:
matrix:
-Python_39:
-python.version: "3.9"
+Python_312:
+python.version: "3.12"
PyPIGithub: false
-name: "Python 3.9"
+name: "Python 3.12"
Python_311:
python.version: "3.11"
PyPIGithub: false
2 changes: 1 addition & 1 deletion .azure-pipelines/azure-pipelines-benchmarks.yml
@@ -29,7 +29,7 @@ jobs:
python -m pip install wheel
python -m pip install --no-cache-dir -r requirements.txt
python -m pip install --no-cache-dir -r requirements-optional.txt
-python -m pip install asv
+python -m pip install asv!=0.6.2
displayName: 'Install requirements'
- script: |
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@
### New features

### Maintenance and fixes
+- Fix deprecations introduced in latest pandas and xarray versions, and prepare for numpy 2.0 ones ([2315](https://github.com/arviz-devs/arviz/pull/2315))

- Refactor ECDF code ([2311](https://github.com/arviz-devs/arviz/pull/2311))

4 changes: 3 additions & 1 deletion arviz/data/inference_data.py
@@ -64,6 +64,8 @@
"observed_data",
"constant_data",
"predictions_constant_data",
"unconstrained_posterior",
"unconstrained_prior",
]

WARMUP_TAG = "warmup_"
@@ -1492,7 +1494,7 @@ def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
import numpy as np
rng = np.random.default_rng(73)
-ary = rng.normal(size=(post.dims["chain"], post.dims["draw"], obs.dims["match"]))
+ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
idata.add_groups(
log_likelihood={"home_points": ary},
dims={"home_points": ["match"]},
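The `dims` → `sizes` change above follows xarray's deprecation of using `Dataset.dims` as a name-to-length mapping (it is slated to return only the dimension names, as `DataArray.dims` already does). A minimal sketch of the replacement pattern, with a made-up dataset standing in for the posterior group:

```python
import numpy as np
import xarray as xr

# Made-up dataset shaped like an ArviZ posterior group.
post = xr.Dataset(
    {"mu": (("chain", "draw"), np.zeros((4, 100)))},
    coords={"chain": np.arange(4), "draw": np.arange(100)},
)

# Deprecated usage: treating Dataset.dims as a name -> length mapping.
# n_chains = post.dims["chain"]

# Replacement: Dataset.sizes is the stable name -> length mapping
# on both Dataset and DataArray.
n_chains = post.sizes["chain"]
n_draws = post.sizes["draw"]
print(n_chains * n_draws)  # 400 draws per variable
```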
9 changes: 7 additions & 2 deletions arviz/plots/backends/bokeh/bpvplot.py
@@ -171,8 +171,13 @@ def plot_bpv(
ax_i.line(0, 0, legend_label=f"bpv={p_value:.2f}", alpha=0)

if plot_mean:
-ax_i.circle(
-obs_vals.mean(), 0, fill_color=color, line_color="black", size=markersize
+ax_i.scatter(
+obs_vals.mean(),
+0,
+fill_color=color,
+line_color="black",
+size=markersize,
+marker="circle",
)

_title = Title()
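The `circle`/`triangle` → `scatter` changes in this and the following backend modules switch to the generic glyph API: newer Bokeh deprecates the per-marker helpers, while `scatter(..., marker=...)` already works on Bokeh 2, matching the "keep everything bokeh2" goal of the commit. A minimal sketch of the pattern with made-up data:

```python
from bokeh.plotting import figure

p = figure()

# Older style, deprecated in newer Bokeh releases:
# p.circle([1, 2, 3], [4, 5, 6], size=6, fill_color="grey", line_color="black")
# p.triangle([1, 2, 3], [4, 5, 6], size=6, line_color="grey")

# Pattern used here: one scatter() call with an explicit marker name.
p.scatter([1, 2, 3], [4, 5, 6], size=6, marker="circle",
          fill_color="grey", line_color="black")
p.scatter([1, 2, 3], [2, 3, 4], size=6, marker="triangle",
          fill_color="grey", line_color="grey", line_width=2)
```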
11 changes: 7 additions & 4 deletions arviz/plots/backends/bokeh/compareplot.py
@@ -69,13 +69,14 @@ def plot_compare(
err_ys.append((y, y))

# plot them
-dif_tri = ax.triangle(
+dif_tri = ax.scatter(
comp_df[information_criterion].iloc[1:],
yticks_pos[1::2],
line_color=plot_kwargs.get("color_dse", "grey"),
fill_color=plot_kwargs.get("color_dse", "grey"),
line_width=2,
size=6,
marker="triangle",
)
dif_line = ax.multi_line(err_xs, err_ys, line_color=plot_kwargs.get("color_dse", "grey"))

@@ -85,13 +86,14 @@
ax.yaxis.ticker = yticks_pos[::2]
ax.yaxis.major_label_overrides = dict(zip(yticks_pos[::2], yticks_labels))

-elpd_circ = ax.circle(
+elpd_circ = ax.scatter(
comp_df[information_criterion],
yticks_pos[::2],
line_color=plot_kwargs.get("color_ic", "black"),
fill_color=None,
line_width=2,
size=6,
marker="circle",
)
elpd_label = [elpd_circ]

@@ -110,7 +112,7 @@

labels.append(("ELPD", elpd_label))

-scale = comp_df["scale"][0]
+scale = comp_df["scale"].iloc[0]

if insample_dev:
p_ic = comp_df[f"p_{information_criterion.split('_')[1]}"]
@@ -120,13 +122,14 @@
correction = -p_ic
elif scale == "deviance":
correction = -(2 * p_ic)
-insample_circ = ax.circle(
+insample_circ = ax.scatter(
comp_df[information_criterion] + correction,
yticks_pos[::2],
line_color=plot_kwargs.get("color_insample_dev", "black"),
fill_color=plot_kwargs.get("color_insample_dev", "black"),
line_width=2,
size=6,
marker="circle",
)
labels.append(("In-sample ELPD", [insample_circ]))

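The `comp_df["scale"][0]` → `comp_df["scale"].iloc[0]` change avoids pandas' deprecated fallback where an integer key on a Series is tried as a label first and as a position second. A minimal sketch with a made-up comparison table (the real `az.compare` output is indexed by model name):

```python
import pandas as pd

# Made-up stand-in for the DataFrame returned by az.compare.
comp_df = pd.DataFrame(
    {"elpd_loo": [-10.2, -11.5], "scale": ["log", "log"]},
    index=["model_a", "model_b"],
)

# Deprecated: integer key treated as a positional fallback on a labelled Series.
# scale = comp_df["scale"][0]

# Explicitly positional, works regardless of the index labels.
scale = comp_df["scale"].iloc[0]
print(scale)  # "log"
```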
4 changes: 2 additions & 2 deletions arviz/plots/backends/bokeh/forestplot.py
@@ -640,15 +640,15 @@ def iterator(self):
grouped_data = [[(0, datum)] for datum in self.data]
skip_dims = self.combine_dims.union({"chain"})
else:
-grouped_data = [datum.groupby("chain") for datum in self.data]
+grouped_data = [datum.groupby("chain", squeeze=False) for datum in self.data]
skip_dims = self.combine_dims

label_dict = OrderedDict()
selection_list = []
for name, grouped_datum in zip(self.model_names, grouped_data):
for _, sub_data in grouped_datum:
datum_iter = xarray_var_iter(
-sub_data,
+sub_data.squeeze(),
var_names=[self.var_name],
skip_dims=skip_dims,
reverse_selections=True,
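The `squeeze=False` plus explicit `.squeeze()` combination in both forestplot backends tracks xarray's deprecation of the implicit `squeeze=True` default in `groupby`: groups now keep the grouped dimension with length 1, and the code drops it explicitly where the old shape is expected. A minimal sketch with a made-up dataset:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"theta": (("chain", "draw"), np.zeros((2, 5)))},
    coords={"chain": [0, 1], "draw": np.arange(5)},
)

for _, sub in ds.groupby("chain", squeeze=False):
    # With squeeze=False each group keeps a length-1 "chain" dimension,
    # which is what future xarray versions will always do.
    assert sub.sizes["chain"] == 1
    # Dropping the length-1 dimension restores the shape the old
    # squeeze=True default produced for the downstream iteration code.
    sub = sub.squeeze()
    assert "chain" not in sub.dims
```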
2 changes: 1 addition & 1 deletion arviz/plots/backends/matplotlib/compareplot.py
@@ -84,7 +84,7 @@ def plot_compare(
else:
ax.set_yticks(yticks_pos[::2])

-scale = comp_df["scale"][0]
+scale = comp_df["scale"].iloc[0]

if insample_dev:
p_ic = comp_df[f"p_{information_criterion.split('_')[1]}"]
4 changes: 2 additions & 2 deletions arviz/plots/backends/matplotlib/forestplot.py
@@ -536,15 +536,15 @@ def iterator(self):
grouped_data = [[(0, datum)] for datum in self.data]
skip_dims = self.combine_dims.union({"chain"})
else:
-grouped_data = [datum.groupby("chain") for datum in self.data]
+grouped_data = [datum.groupby("chain", squeeze=False) for datum in self.data]
skip_dims = self.combine_dims

label_dict = OrderedDict()
selection_list = []
for name, grouped_datum in zip(self.model_names, grouped_data):
for _, sub_data in grouped_datum:
datum_iter = xarray_var_iter(
-sub_data,
+sub_data.squeeze(),
var_names=[self.var_name],
skip_dims=skip_dims,
reverse_selections=True,
2 changes: 1 addition & 1 deletion arviz/plots/backends/matplotlib/traceplot.py
@@ -430,7 +430,7 @@ def plot_trace(
Line2D(
[], [], label=chain_id, **dealiase_sel_kwargs(legend_kwargs, chain_prop, chain_id)
)
-for chain_id in range(data.dims["chain"])
+for chain_id in range(data.sizes["chain"])
]
if combined:
handles.insert(
13 changes: 7 additions & 6 deletions arviz/plots/bfplot.py
@@ -38,7 +38,7 @@ def plot_bf(
algorithm presented in [1]_.
Parameters
------------
+----------
idata : InferenceData
Any object that can be converted to an :class:`arviz.InferenceData` object
Refer to documentation of :func:`arviz.convert_to_dataset` for details.
Expand All @@ -52,16 +52,16 @@ def plot_bf(
Tuple of valid Matplotlib colors. First element for the prior, second for the posterior.
figsize : (float, float), optional
Figure size. If `None` it will be defined automatically.
-textsize: float, optional
+textsize : float, optional
Text size scaling factor for labels, titles and lines. If `None` it will be auto
scaled based on `figsize`.
-plot_kwargs : dicts, optional
+plot_kwargs : dict, optional
Additional keywords passed to :func:`matplotlib.pyplot.plot`.
-hist_kwargs : dicts, optional
+hist_kwargs : dict, optional
Additional keywords passed to :func:`arviz.plot_dist`. Only works for discrete variables.
ax : axes, optional
:class:`matplotlib.axes.Axes` or :class:`bokeh.plotting.Figure`.
-backend :{"matplotlib", "bokeh"}, default "matplotlib"
+backend : {"matplotlib", "bokeh"}, default "matplotlib"
Select plotting backend.
backend_kwargs : dict, optional
These are kwargs specific to the backend being used, passed to
@@ -78,7 +78,7 @@
References
----------
.. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio:
The case of computing Bayes factors for regression parameters.
The case of computing Bayes factors for regression parameters.
Examples
--------
@@ -92,6 +92,7 @@
>>> idata = az.from_dict(posterior={"a":np.random.normal(1, 0.5, 5000)},
... prior={"a":np.random.normal(0, 1, 5000)})
>>> az.plot_bf(idata, var_name="a", ref_val=0)
"""
posterior = extract(idata, var_names=var_name).values

4 changes: 2 additions & 2 deletions arviz/plots/bpvplot.py
@@ -230,11 +230,11 @@ def plot_bpv(

if flatten_pp is None:
if flatten is None:
-flatten_pp = list(predictive_dataset.dims.keys())
+flatten_pp = list(predictive_dataset.dims)
else:
flatten_pp = flatten
if flatten is None:
-flatten = list(observed.dims.keys())
+flatten = list(observed.dims)

if coords is None:
coords = {}
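Dropping `.keys()` here (and in the matching `plot_ppc` code below) prepares for the same xarray change: iterating `Dataset.dims` yields dimension names whether it stays mapping-like or becomes set-like, so only the explicit `.keys()` call would break. A minimal sketch with a made-up observed-data group:

```python
import numpy as np
import xarray as xr

# Made-up stand-in for an observed_data group.
observed = xr.Dataset({"y": (("match",), np.zeros(10))})

# Relies on Dataset.dims being a mapping, which is being phased out:
# flatten = list(observed.dims.keys())

# Iterating .dims directly yields names under both the old and new behaviour.
flatten = list(observed.dims)
print(flatten)  # ['match']
```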
4 changes: 2 additions & 2 deletions arviz/plots/compareplot.py
@@ -90,10 +90,10 @@ def plot_compare(
References
----------
.. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
cross-validation and WAIC https://arxiv.org/abs/1507.04544
cross-validation and WAIC https://arxiv.org/abs/1507.04544
.. [2] McElreath R. (2022). Statistical Rethinking A Bayesian Course with Examples in
R and Stan, Second edition, CRC Press.
R and Stan, Second edition, CRC Press.
Examples
--------
2 changes: 1 addition & 1 deletion arviz/plots/elpdplot.py
@@ -98,7 +98,7 @@ def plot_elpd(
References
----------
.. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
cross-validation and WAIC https://arxiv.org/abs/1507.04544
cross-validation and WAIC https://arxiv.org/abs/1507.04544
Examples
--------
4 changes: 2 additions & 2 deletions arviz/plots/essplot.py
@@ -202,8 +202,8 @@ def plot_ess(

data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
var_names = _var_names(var_names, data, filter_vars)
-n_draws = data.dims["draw"]
-n_samples = n_draws * data.dims["chain"]
+n_draws = data.sizes["draw"]
+n_samples = n_draws * data.sizes["chain"]

ess_tail_dataset = None
mean_ess = None
2 changes: 1 addition & 1 deletion arviz/plots/pairplot.py
@@ -229,7 +229,7 @@ def plot_pair(
)

if gridsize == "auto":
-gridsize = int(dataset.dims["draw"] ** 0.35)
+gridsize = int(dataset.sizes["draw"] ** 0.35)

numvars = len(flat_var_names)

4 changes: 2 additions & 2 deletions arviz/plots/ppcplot.py
@@ -269,11 +269,11 @@ def plot_ppc(

if flatten_pp is None:
if flatten is None:
-flatten_pp = list(predictive_dataset.dims.keys())
+flatten_pp = list(predictive_dataset.dims)
else:
flatten_pp = flatten
if flatten is None:
-flatten = list(observed_data.dims.keys())
+flatten = list(observed_data.dims)

if coords is None:
coords = {}
4 changes: 2 additions & 2 deletions arviz/stats/density_utils.py
@@ -231,8 +231,8 @@ def _fixed_point(t, N, k_sq, a_sq):
Z. I. Botev, J. F. Grotowski, and D. P. Kroese.
Ann. Statist. 38 (2010), no. 5, 2916--2957.
"""
-k_sq = np.asfarray(k_sq, dtype=np.float64)
-a_sq = np.asfarray(a_sq, dtype=np.float64)
+k_sq = np.asarray(k_sq, dtype=np.float64)
+a_sq = np.asarray(a_sq, dtype=np.float64)

l = 7
f = np.sum(np.power(k_sq, l) * a_sq * np.exp(-k_sq * np.pi**2 * t))
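`np.asfarray` is removed in NumPy 2.0, and the explicit-dtype `np.asarray` call above is the drop-in replacement that also works on NumPy 1.x. A minimal sketch:

```python
import numpy as np

k_sq = [1, 2, 3]  # any array-like of numbers

# Removed in NumPy 2.0:
# k_sq = np.asfarray(k_sq)

# Equivalent on both NumPy 1.x and 2.x:
k_sq = np.asarray(k_sq, dtype=np.float64)
print(k_sq.dtype)  # float64
```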
4 changes: 2 additions & 2 deletions arviz/stats/diagnostics.py
@@ -457,10 +457,10 @@ def ks_summary(pareto_tail_indices):
"""
_numba_flag = Numba.numba_flag
if _numba_flag:
-bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf])
+bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
kcounts, *_ = _histogram(pareto_tail_indices, bins)
else:
-kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.Inf, 0.5, 0.7, 1, np.Inf])
+kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.inf, 0.5, 0.7, 1, np.inf])
kprop = kcounts / len(pareto_tail_indices) * 100
df_k = pd.DataFrame(
dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
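Likewise, the capitalised `np.Inf` alias is removed in NumPy 2.0; only `np.inf` remains. A minimal sketch of the Pareto-k binning above, using plain `np.histogram` in place of ArviZ's internal `_histogram` wrapper and made-up k values:

```python
import numpy as np

pareto_k = np.array([0.2, 0.65, 0.9, 1.3])  # made-up Pareto shape estimates

# np.Inf (and np.NaN / np.NAN) are gone in NumPy 2.0; the lowercase names stay.
bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
kcounts, _ = np.histogram(pareto_k, bins=bins)
print(kcounts)  # [1 1 1 1] -> (good), (ok), (bad), (very bad)
```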