-
Notifications
You must be signed in to change notification settings - Fork 248
/
Copy pathhorseshoe_regression.py
177 lines (131 loc) · 6.05 KB
/
horseshoe_regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""
Example: Horseshoe Regression
=============================
We demonstrate how to use NUTS to do sparse regression using
the Horseshoe prior [1] for both continuous- and binary-valued
responses. For a more complex modeling and inference approach
that also supports quadratic interaction terms in a way that
is efficient in high dimensions see examples/sparse_regression.py.
References:
[1] "Handling Sparsity via the Horseshoe,"
Carlos M. Carvalho, Nicholas G. Polson, James G. Scott.
"""
import argparse
import os
import time
import numpy as np
from scipy.special import expit
import jax.numpy as jnp
import jax.random as random
import numpyro
from numpyro.diagnostics import summary
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
# regression model with continuous-valued outputs/responses
def model_normal_likelihood(X, Y):
    """Horseshoe regression model with a Normal likelihood.

    Registers the sample sites ``lambdas``, ``tau``, ``unscaled_betas`` and
    ``prec_obs``, the deterministic site ``betas``, and the observed site ``Y``.

    :param X: covariate matrix; ``X.shape[1]`` sets the number of coefficients.
    :param Y: responses, passed as ``obs`` to the ``Y`` site.
    """
    num_features = X.shape[1]

    # horseshoe prior: one half-Cauchy scale per coefficient ("lambdas")
    # plus a single shared half-Cauchy scale ("tau")
    local_scales = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(num_features)))
    global_scale = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    # for a normal likelihood one would usually integrate the coefficients out
    # in practice (as is done for example in sparse_regression.py); that trick
    # does not carry over to other likelihoods (e.g. bernoulli), so the
    # coefficients are sampled explicitly here.
    raw_coefs = numpyro.sample(
        "unscaled_betas", dist.Normal(0.0, jnp.ones(num_features))
    )
    coefs = numpyro.deterministic("betas", global_scale * local_scales * raw_coefs)

    # the observation noise scale is derived from a Gamma-distributed precision
    noise_precision = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    noise_scale = 1.0 / jnp.sqrt(noise_precision)

    # condition on the data; the mean is linear in the scaled coefficients
    numpyro.sample("Y", dist.Normal(jnp.dot(X, coefs), noise_scale), obs=Y)
# regression model with binary-valued outputs/responses
def model_bernoulli_likelihood(X, Y):
    """Horseshoe regression model with a Bernoulli likelihood.

    Registers the sample sites ``lambdas``, ``tau`` and ``unscaled_betas``,
    the deterministic site ``betas``, and the observed site ``Y``.

    :param X: covariate matrix; ``X.shape[1]`` sets the number of coefficients.
    :param Y: responses, passed as ``obs`` to the ``Y`` site.
    """
    num_features = X.shape[1]

    # horseshoe prior: one half-Cauchy scale per coefficient ("lambdas")
    # plus a single shared half-Cauchy scale ("tau")
    local_scales = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(num_features)))
    global_scale = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    # sampling unit-scale coefficients and rescaling them afterwards is a
    # reparameterization (i.e. coordinate transformation) that improves
    # posterior geometry and makes NUTS sampling more efficient
    raw_coefs = numpyro.sample(
        "unscaled_betas", dist.Normal(0.0, jnp.ones(num_features))
    )
    coefs = numpyro.deterministic("betas", global_scale * local_scales * raw_coefs)

    # condition on the data; the linear predictor is used as the logits
    numpyro.sample("Y", dist.Bernoulli(logits=jnp.dot(X, coefs)), obs=Y)
# helper function for HMC inference
def run_inference(model, args, rng_key, X, Y):
    """Run NUTS on ``model`` and return summary statistics of the samples.

    Prints the MCMC summary (deterministic sites included) and the elapsed
    wall-clock time measured from the start of this function.

    :param model: model callable handed to the ``NUTS`` kernel.
    :param args: object providing ``num_warmup``, ``num_samples`` and
        ``num_chains`` attributes.
    :param rng_key: random key forwarded to ``MCMC.run``.
    :param X: covariates forwarded to the model.
    :param Y: responses forwarded to the model.
    :return: the result of ``summary(samples, group_by_chain=False)``.
    """
    start = time.time()

    # the progress bar is switched off when NUMPYRO_SPHINXBUILD is set
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    sampler = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    sampler.run(rng_key, X, Y)
    sampler.print_summary(exclude_deterministic=False)

    posterior_stats = summary(sampler.get_samples(), group_by_chain=False)

    print("\nMCMC elapsed time:", time.time() - start)
    return posterior_stats
# create artificial regression dataset with 3 non-zero regression coefficients
def get_data(N=50, D_X=3, sigma_obs=0.05, response="continuous"):
    """Create an artificial regression dataset with 3 non-zero coefficients.

    The generated data is deterministic: a private random number generator
    seeded with 0 is used, so NumPy's global random state is left untouched.

    :param N: number of data points.
    :param D_X: number of covariates; must be at least 3.
    :param sigma_obs: scale of the Gaussian noise added to continuous responses.
    :param response: either ``"continuous"`` or ``"binary"``.
    :return: tuple ``(X, Y)`` where ``X`` has shape ``(N, D_X)`` and ``Y`` has
        shape ``(N,)``.
    :raises ValueError: if ``response`` is not a supported value or ``D_X < 3``.
    """
    # validate with explicit exceptions: `assert` is stripped under `python -O`
    if response not in ("continuous", "binary"):
        raise ValueError(
            f"response must be 'continuous' or 'binary', got {response!r}"
        )
    if D_X < 3:
        raise ValueError(f"D_X must be at least 3, got {D_X}")

    # RandomState(0) yields the same stream as np.random.seed(0) followed by
    # np.random.* calls, without reseeding the process-wide generator
    rng = np.random.RandomState(0)
    X = rng.randn(N, D_X)

    # the response only depends on X_0, X_1, and X_2
    W = np.array([2.0, -1.0, 0.50])
    Y = jnp.dot(X[:, :3], W)
    Y -= jnp.mean(Y)  # center the noiseless response

    if response == "continuous":
        Y += sigma_obs * rng.randn(N)
    else:
        # binary: draw 0/1 outcomes with success probability expit(Y)
        Y = rng.binomial(1, expit(Y))

    # internal sanity checks on the generated shapes
    assert X.shape == (N, D_X)
    assert Y.shape == (N,)
    return X, Y
def _print_sparsity_report(summary_dict):
    """Print the leading ``lambdas`` medians and ``betas`` means of a summary."""
    # lambda should only be large for the first 3 dimensions, which
    # correspond to relevant covariates (see get_data)
    print("Posterior median over lambdas (leading 5 dimensions):")
    print(summary_dict["lambdas"]["median"][:5])
    print("Posterior mean over betas (leading 5 dimensions):")
    print(summary_dict["betas"]["mean"][:5])


def main(args):
    """Run the continuous-response and binary-response horseshoe experiments.

    :param args: parsed command line arguments; ``num_data`` is read here and
        the whole object is forwarded to ``run_inference``.
    """
    N, D_X = args.num_data, 32

    print("[Experiment with continuous-valued responses]")
    # first generate and analyze data with continuous-valued responses
    X, Y = get_data(N=N, D_X=D_X, response="continuous")
    # do inference; only the first key of the split is needed
    rng_key, _ = random.split(random.PRNGKey(0))
    _print_sparsity_report(
        run_inference(model_normal_likelihood, args, rng_key, X, Y)
    )

    print("[Experiment with binary-valued responses]")
    # next generate and analyze data with binary-valued responses
    # (note we use more data for the case of binary-valued responses,
    # since each response carries less information than a real number)
    X, Y = get_data(N=4 * N, D_X=D_X, response="binary")
    # do inference; only the first key of the split is needed
    rng_key, _ = random.split(random.PRNGKey(0))
    _print_sparsity_report(
        run_inference(model_bernoulli_likelihood, args, rng_key, X, Y)
    )
if __name__ == "__main__":
    # this example is written against one specific NumPyro release
    assert numpyro.__version__.startswith("0.17.0")

    parser = argparse.ArgumentParser(description="Horseshoe regression example")
    parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
    # the remaining integer options differ only in flag name and default
    for flag, default in (
        ("--num-warmup", 1000),
        ("--num-chains", 1),
        ("--num-data", 100),
    ):
        parser.add_argument(flag, nargs="?", default=default, type=int)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # platform and host device count are configured before main() is called
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)