Skip to content

Commit

Permalink
Return InferenceData by default
Browse files Browse the repository at this point in the history
Also removes some unnecessary XFAIL marks.

Closes #4372, #4740

Co-authored-by: Oriol Abril <oriol.abril.pla@gmail.com>
  • Loading branch information
michaelosthege and OriolAbril committed Jun 7, 2021
1 parent 660b95b commit 0923d25
Show file tree
Hide file tree
Showing 24 changed files with 225 additions and 231 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- ArviZ `plots` and `stats` *wrappers* were removed. The functions are now just available by their original names (see [#4549](https://github.com/pymc-devs/pymc3/pull/4471) and `3.11.2` release notes).
- The GLM submodule has been removed, please use [Bambi](https://bambinos.github.io/bambi/) instead.
- The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`.
- `pm.sample` now returns results as `InferenceData` instead of `MultiTrace` by default (see [#4744](https://github.com/pymc-devs/pymc3/pull/4744)).
- ...

### New Features
Expand Down
16 changes: 8 additions & 8 deletions benchmarks/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def track_glm_hierarchical_ess(self, init):
init=init, chains=self.chains, progressbar=False, random_seed=123
)
t0 = time.time()
trace = pm.sample(
idata = pm.sample(
draws=self.draws,
step=step,
cores=4,
Expand All @@ -192,7 +192,7 @@ def track_glm_hierarchical_ess(self, init):
compute_convergence_checks=False,
)
tot = time.time() - t0
ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
ess = float(az.ess(idata, var_names=["mu_a"])["mu_a"].values)
return ess / tot

def track_marginal_mixture_model_ess(self, init):
Expand All @@ -203,7 +203,7 @@ def track_marginal_mixture_model_ess(self, init):
)
start = [{k: v for k, v in start.items()} for _ in range(self.chains)]
t0 = time.time()
trace = pm.sample(
idata = pm.sample(
draws=self.draws,
step=step,
cores=4,
Expand All @@ -214,7 +214,7 @@ def track_marginal_mixture_model_ess(self, init):
compute_convergence_checks=False,
)
tot = time.time() - t0
ess = az.ess(trace, var_names=["mu"])["mu"].values.min() # worst case
ess = az.ess(idata, var_names=["mu"])["mu"].values.min() # worst case
return ess / tot


Expand All @@ -235,7 +235,7 @@ def track_glm_hierarchical_ess(self, step):
if step is not None:
step = step()
t0 = time.time()
trace = pm.sample(
idata = pm.sample(
draws=self.draws,
step=step,
cores=4,
Expand All @@ -245,7 +245,7 @@ def track_glm_hierarchical_ess(self, step):
compute_convergence_checks=False,
)
tot = time.time() - t0
ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
ess = float(az.ess(idata, var_names=["mu_a"])["mu_a"].values)
return ess / tot


Expand Down Expand Up @@ -302,9 +302,9 @@ def freefall(y, t, p):
Y = pm.Normal("Y", mu=ode_solution, sd=sigma, observed=y)

t0 = time.time()
trace = pm.sample(500, tune=1000, chains=2, cores=2, random_seed=0)
idata = pm.sample(500, tune=1000, chains=2, cores=2, random_seed=0)
tot = time.time() - t0
ess = az.ess(trace)
ess = az.ess(idata)
return np.mean([ess.sigma, ess.gamma]) / tot


Expand Down
8 changes: 4 additions & 4 deletions docs/source/Advanced_usage_of_Aesara_in_PyMC3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ be time consuming if the number of datasets is large)::
pm.Normal('y', mu=mu, sigma=1, observed=data)

# Generate one trace for each dataset
traces = []
idatas = []
for data_vals in observed_data:
# Switch out the observed dataset
data.set_value(data_vals)
with model:
traces.append(pm.sample())
idatas.append(pm.sample())

We can also sometimes use shared variables to work around limitations
in the current PyMC3 api. A common task in Machine Learning is to predict
Expand All @@ -63,7 +63,7 @@ variable for our observations::
pm.Bernoulli('obs', p=logistic, observed=y)

# fit the model
trace = pm.sample()
idata = pm.sample()

# Switch out the observations and use `sample_posterior_predictive` to predict
x_shared.set_value([-1, 0, 1.])
Expand Down Expand Up @@ -220,4 +220,4 @@ We can now define our model using this new `Op`::
mu = pm.Deterministic('mu', at_mu_from_theta(theta))
pm.Normal('y', mu=mu, sigma=0.1, observed=[0.2, 0.21, 0.3])

trace = pm.sample()
idata = pm.sample()
2 changes: 1 addition & 1 deletion docs/source/Gaussian_Processes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ other implementations. The first block fits the GP prior. We denote

f = gp.marginal_likelihood("f", X, y, noise)

trace = pm.sample(1000)
idata = pm.sample(1000)


To construct the conditional distribution of :code:`gp1` or :code:`gp2`, we
Expand Down
4 changes: 2 additions & 2 deletions docs/source/about.rst
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,9 @@ Save this file, then from a python shell (or another file in the same directory)
with bioassay_model:

# Draw samples
trace = pm.sample(1000, tune=2000, cores=2)
idata = pm.sample(1000, tune=2000, cores=2)
# Plot two parameters
az.plot_forest(trace, var_names=['alpha', 'beta'], r_hat=True)
az.plot_forest(idata, var_names=['alpha', 'beta'], r_hat=True)

This example will generate 1000 posterior samples on each of two cores using the NUTS algorithm, preceded by 2000 tuning samples (these are good default numbers for most models).

Expand Down
4 changes: 2 additions & 2 deletions pymc3/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,12 +498,12 @@ class Data:
... pm.Normal('y', mu=mu, sigma=1, observed=data)
>>> # Generate one trace for each dataset
>>> traces = []
>>> idatas = []
>>> for data_vals in observed_data:
... with model:
... # Switch out the observed dataset
... model.set_data('data', data_vals)
... traces.append(pm.sample())
... idatas.append(pm.sample())
To set the value of the data container variable, check out
:func:`pymc3.model.set_data()`.
Expand Down
14 changes: 8 additions & 6 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1691,14 +1691,15 @@ class OrderedLogistic(Categorical):
cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
transform=pm.distributions.transforms.ordered)
y_ = pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y)
tr = pm.sample(1000)
idata = pm.sample(1000)
# Plot the results
plt.hist(cluster1, 30, alpha=0.5);
plt.hist(cluster2, 30, alpha=0.5);
plt.hist(cluster3, 30, alpha=0.5);
plt.hist(tr["cutpoints"][:,0], 80, alpha=0.2, color='k');
plt.hist(tr["cutpoints"][:,1], 80, alpha=0.2, color='k');
posterior = idata.posterior.stack(sample=("chain", "draw"))
plt.hist(posterior["cutpoints"][0], 80, alpha=0.2, color='k');
plt.hist(posterior["cutpoints"][1], 80, alpha=0.2, color='k');
"""

Expand Down Expand Up @@ -1782,14 +1783,15 @@ class OrderedProbit(Categorical):
cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
transform=pm.distributions.transforms.ordered)
y_ = pm.OrderedProbit("y", cutpoints=cutpoints, eta=x, observed=y)
tr = pm.sample(1000)
idata = pm.sample(1000)
# Plot the results
plt.hist(cluster1, 30, alpha=0.5);
plt.hist(cluster2, 30, alpha=0.5);
plt.hist(cluster3, 30, alpha=0.5);
plt.hist(tr["cutpoints"][:,0], 80, alpha=0.2, color='k');
plt.hist(tr["cutpoints"][:,1], 80, alpha=0.2, color='k');
posterior = idata.posterior.stack(sample=("chain", "draw"))
plt.hist(posterior["cutpoints"][0], 80, alpha=0.2, color='k');
plt.hist(posterior["cutpoints"][1], 80, alpha=0.2, color='k');
"""

Expand Down
2 changes: 1 addition & 1 deletion pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def __init__(
normal_dist.logp,
observed=np.random.randn(100),
)
trace = pm.sample(100)
idata = pm.sample(100)
.. code-block:: python
Expand Down
4 changes: 2 additions & 2 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,15 +1696,15 @@ def set_data(new_data, model=None):
... y = pm.Data('y', [1., 2., 3.])
... beta = pm.Normal('beta', 0, 1)
... obs = pm.Normal('obs', x * beta, 1, observed=y)
... trace = pm.sample(1000, tune=1000)
... idata = pm.sample(1000, tune=1000)
Set the value of `x` to predict on new data.
.. code:: ipython
>>> with model:
... pm.set_data({'x': [5., 6., 9.]})
... y_test = pm.sample_posterior_predictive(trace)
... y_test = pm.sample_posterior_predictive(idata)
>>> y_test['obs'].mean(axis=0)
array([4.6088569 , 5.54128318, 8.32953844])
"""
Expand Down
22 changes: 6 additions & 16 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import aesara.gradient as tg
import numpy as np
import packaging
import xarray

from aesara.compile.mode import Mode
Expand Down Expand Up @@ -355,7 +354,7 @@ def sample(
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
init methods.
return_inferencedata : bool, default=False
return_inferencedata : bool, default=True
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
Defaults to `False`, but we'll switch to `True` in an upcoming release.
idata_kwargs : dict, optional
Expand Down Expand Up @@ -430,9 +429,9 @@ def sample(
In [2]: with pm.Model() as model: # context management
...: p = pm.Beta("p", alpha=alpha, beta=beta)
...: y = pm.Binomial("y", n=n, p=p, observed=h)
...: trace = pm.sample()
...: idata = pm.sample()
In [3]: az.summary(trace, kind="stats")
In [3]: az.summary(idata, kind="stats")
Out[3]:
mean sd hdi_3% hdi_97%
Expand Down Expand Up @@ -471,6 +470,9 @@ def sample(
if not isinstance(random_seed, abc.Iterable):
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")

if return_inferencedata is None:
return_inferencedata = True

if not discard_tuned_samples and not return_inferencedata:
warnings.warn(
"Tuning samples will be included in the returned `MultiTrace` object, which can lead to"
Expand All @@ -480,18 +482,6 @@ def sample(
stacklevel=2,
)

if return_inferencedata is None:
v = packaging.version.parse(pm.__version__)
if v.release[0] > 3 or v.release[1] >= 10: # type: ignore
warnings.warn(
"In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. "
"You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.",
FutureWarning,
stacklevel=2,
)
# set the default
return_inferencedata = False

if start is not None:
for start_vals in start:
_check_start_shape(model, start_vals)
Expand Down
4 changes: 2 additions & 2 deletions pymc3/step_methods/mlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,11 @@ class MLDA(ArrayStepShared):
... y = pm.Normal("y", mu=x, sigma=1, observed=datum)
... step_method = pm.MLDA(coarse_models=[coarse_model],
... subsampling_rates=5)
... trace = pm.sample(500, chains=2,
... idata = pm.sample(500, chains=2,
... tune=100, step=step_method,
... random_seed=123)
...
... az.summary(trace, kind="stats")
... az.summary(idata, kind="stats")
mean sd hdi_3% hdi_97%
x 0.99 0.987 -0.734 2.992
Expand Down
21 changes: 8 additions & 13 deletions pymc3/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ def test_sample(self):
pm.Normal("obs", b * x_shared, np.sqrt(1e-2), observed=y)

prior_trace0 = pm.sample_prior_predictive(1000)
trace = pm.sample(1000, init=None, tune=1000, chains=1)
pp_trace0 = pm.sample_posterior_predictive(trace, 1000)
idata = pm.sample(1000, init=None, tune=1000, chains=1)
pp_trace0 = pm.sample_posterior_predictive(idata, 1000)

x_shared.set_value(x_pred)
prior_trace1 = pm.sample_prior_predictive(1000)
pp_trace1 = pm.sample_posterior_predictive(trace, samples=1000)
pp_trace1 = pm.sample_posterior_predictive(idata, samples=1000)

assert prior_trace0["b"].shape == (1000,)
assert prior_trace0["obs"].shape == (1000, 100)
Expand Down Expand Up @@ -101,23 +101,21 @@ def test_sample_after_set_data(self):
init=None,
tune=1000,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)
# Predict on new data.
new_x = [5.0, 6.0, 9.0]
new_y = [5.0, 6.0, 9.0]
with model:
pm.set_data(new_data={"x": new_x, "y": new_y})
new_trace = pm.sample(
new_idata = pm.sample(
1000,
init=None,
tune=1000,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)
pp_trace = pm.sample_posterior_predictive(new_trace, 1000)
pp_trace = pm.sample_posterior_predictive(new_idata, 1000)

assert pp_trace["obs"].shape == (1000, 3)
np.testing.assert_allclose(new_y, pp_trace["obs"].mean(axis=0), atol=1e-1)
Expand All @@ -134,12 +132,11 @@ def test_shared_data_as_index(self):
pm.Normal("obs", alpha[index], np.sqrt(1e-2), observed=y)

prior_trace = pm.sample_prior_predictive(1000, var_names=["alpha"])
trace = pm.sample(
idata = pm.sample(
1000,
init=None,
tune=1000,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)

Expand All @@ -148,10 +145,10 @@ def test_shared_data_as_index(self):
new_y = [5.0, 6.0, 9.0]
with model:
pm.set_data(new_data={"index": new_index, "y": new_y})
pp_trace = pm.sample_posterior_predictive(trace, 1000, var_names=["alpha", "obs"])
pp_trace = pm.sample_posterior_predictive(idata, 1000, var_names=["alpha", "obs"])

assert prior_trace["alpha"].shape == (1000, 3)
assert trace["alpha"].shape == (1000, 3)
assert idata.posterior["alpha"].shape == (1, 1000, 3)
assert pp_trace["alpha"].shape == (1000, 3)
assert pp_trace["obs"].shape == (1000, 3)

Expand Down Expand Up @@ -233,7 +230,6 @@ def test_set_data_to_non_data_container_variables(self):
init=None,
tune=1000,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)
with pytest.raises(TypeError) as error:
Expand All @@ -253,7 +249,6 @@ def test_model_to_graphviz_for_model_with_data_container(self):
init=None,
tune=1000,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)

Expand Down
Loading

0 comments on commit 0923d25

Please sign in to comment.