Skip to content

Commit

Permalink
Add deterministic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xjing76 committed Oct 13, 2021
1 parent fd288a9 commit 2478e41
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
26 changes: 14 additions & 12 deletions pymc3_hmm/step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from aesara.graph.op import get_test_value as test_value
from aesara.graph.opt import OpRemove, pre_greedy_local_optimizer
from aesara.graph.optdb import Query
from aesara.scalar.basic import Mul
from aesara.scalar.basic import Dot
from aesara.sparse.basic import StructuredDot
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.subtensor import AdvancedIncSubtensor1
Expand All @@ -29,7 +29,7 @@
from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.subtensor import AdvancedIncSubtensor1
from theano.tensor.var import TensorConstant
from theano.scalar.basic import Mul
from theano.tensor.basic import Dot
from theano.sparse.basic import StructuredDot

import pymc3 as pm
Expand Down Expand Up @@ -549,18 +549,16 @@ def __init__(self, vars, values=None, model=None):
if hasattr(i, "distribution") and isinstance(i.distribution, pm.Normal):
mu = i.distribution.mu
elif isinstance(i, pm.model.DeterministicWrapper):
mu = i
mu = i.owner.inputs[0]
else:
continue
dense_dot = isinstance(mu.owner.op, Elemwise) and isinstance(
mu.owner.op.scalar_op, Mul
)
dense_dot = isinstance(mu.owner.op, Dot)
sparse_dot = isinstance(mu.owner.op, StructuredDot)
if not (
mu.owner
and (dense_dot or sparse_dot)
and beta in mu.owner.inputs[1].owner.inputs
):

dense_inputs = dense_dot and beta in mu.owner.inputs
sparse_inputs = sparse_dot and beta in mu.owner.inputs[1].owner.inputs

if not (mu.owner and dense_inputs or sparse_inputs):
continue
if i in model.observed_RVs:

Expand All @@ -570,7 +568,10 @@ def y_fn(x):
else:
y_fn = model.fn(i)

X_fn = model.fn(mu.owner.inputs[0])
if dense_inputs:
X_fn = model.fn(mu.owner.inputs[1])
else:
X_fn = model.fn(mu.owner.inputs[0])

self.vars = [beta]

Expand All @@ -585,6 +586,7 @@ def y_fn(x):
self.X_fn = X_fn

def step(self, point):
# breakpoint()
y = self.y_fn(point)
X = self.X_fn(point)

Expand Down
14 changes: 13 additions & 1 deletion tests/test_step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def test_Hsstep():
M = X.shape[1]
with pm.Model():
beta = HorseShoe("beta", tau=1, shape=M)
pm.Normal("y", mu=X * beta, sigma=1, observed=y)
pm.Normal("y", mu=beta.dot(X), sigma=1, observed=y)
hsstep = HSStep([beta])
trace = pm.sample(
draws=20, tune=0, step=hsstep, chains=1, return_inferencedata=True
Expand All @@ -412,6 +412,18 @@ def test_Hsstep():

assert beta_samples.shape == (20, M)

with pm.Model():
beta = HorseShoe("beta", tau=1, shape=M)
mu = pm.Deterministic("mu", beta.dot(X))
pm.Normal("y", mu=mu, sigma=1, observed=y)
hsstep = HSStep([beta])
trace = pm.sample(
draws=20, tune=0, step=hsstep, chains=1, return_inferencedata=True
)

beta_samples = trace.posterior["beta"][0].values
assert beta_samples.shape == (20, M)

# test case for sparse matrix
X = sp.sparse.csr_matrix(X)

Expand Down

0 comments on commit 2478e41

Please sign in to comment.