-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #56 from TuringLang/ml/models
- Loading branch information
Showing
14 changed files
with
413 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Correlated Gaussian | ||
|
||
This example will explore a highly-correlated Gaussian using [`Models.CorrelatedGaussian`](@ref). This model uses a conjuage Gaussian prior, see the docstring for the mathematical definition. | ||
|
||
## Setup | ||
|
||
For this example, you'll need to add the following packages | ||
```julia | ||
julia>]add Distributions MCMCChains Measurements NestedSamplers StatsBase StatsPlots | ||
``` | ||
|
||
```@setup correlated | ||
using AbstractMCMC | ||
using Random | ||
AbstractMCMC.setprogress!(false) | ||
Random.seed!(8452) | ||
``` | ||
|
||
## Define model | ||
|
||
```@example correlated | ||
using NestedSamplers | ||
# set up a 4-dimensional Gaussian | ||
D = 4 | ||
model, logz = Models.CorrelatedGaussian(D) | ||
nothing; # hide | ||
``` | ||
|
||
let's take a look at a couple of parameters to see what the likelihood surface looks like | ||
|
||
```@example correlated | ||
using StatsPlots | ||
θ1 = range(-1, 1, length=1000) | ||
θ2 = range(-1, 1, length=1000) | ||
logf = [model.loglike([t1, t2, 0, 0]) for t2 in θ2, t1 in θ1] | ||
heatmap( | ||
θ1, θ2, exp.(logf), | ||
aspect_ratio=1, | ||
xlims=extrema(θ1), | ||
ylims=extrema(θ2), | ||
xlabel="θ1", | ||
ylabel="θ2" | ||
) | ||
``` | ||
|
||
## Sample | ||
|
||
```@example correlated | ||
using MCMCChains | ||
using StatsBase | ||
# using single Ellipsoid for bounds | ||
# using Gibbs-style slicing for proposing new points | ||
sampler = Nested(D, 50 * (D + 1); | ||
bounds=Bounds.Ellipsoid, | ||
proposal=Proposals.Slice() | ||
) | ||
names = ["θ_$i" for i in 1:D] | ||
chain, state = sample(model, sampler; dlogz=0.01, param_names=names) | ||
# resample chain using statistical weights | ||
chain_resampled = sample(chain, Weights(vec(chain[:weights])), length(chain)); | ||
nothing # hide | ||
``` | ||
|
||
## Results | ||
|
||
```@example correlated | ||
chain_resampled | ||
``` | ||
|
||
```@example correlated | ||
corner(chain_resampled) | ||
``` | ||
|
||
```@example correlated | ||
using Measurements | ||
logz_est = state.logz ± state.logzerr | ||
diff = logz_est - logz | ||
print("logz: ", logz, "\nestimate: ", logz_est, "\ndiff: ", diff) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Gaussian Shells | ||
|
||
This example will explore the classic Gaussian shells model using [`Models.GaussianShells`](@ref). | ||
|
||
## Setup | ||
|
||
For this example, you'll need to add the following packages | ||
```julia | ||
julia>]add Distributions MCMCChains Measurements NestedSamplers StatsBase StatsPlots | ||
``` | ||
|
||
```@setup shells | ||
using AbstractMCMC | ||
using Random | ||
AbstractMCMC.setprogress!(false) | ||
Random.seed!(8452) | ||
``` | ||
|
||
## Define model | ||
|
||
```@example shells | ||
using NestedSamplers | ||
model, logz = Models.GaussianShells() | ||
nothing; # hide | ||
``` | ||
|
||
let's take a look at a couple of parameters to see what the likelihood surface looks like | ||
|
||
```@example shells | ||
using StatsPlots | ||
x = range(-6, 6, length=1000) | ||
y = range(-6, 6, length=1000) | ||
logf = [model.loglike([xi, yi]) for yi in y, xi in x] | ||
heatmap( | ||
x, y, exp.(logf), | ||
aspect_ratio=1, | ||
xlims=extrema(x), | ||
ylims=extrema(y), | ||
xlabel="x", | ||
ylabel="y", | ||
size=(400, 400) | ||
) | ||
``` | ||
|
||
## Sample | ||
|
||
```@example shells | ||
using MCMCChains | ||
using StatsBase | ||
# using multi-ellipsoid for bounds | ||
# using default rejection sampler for proposals | ||
sampler = Nested(2, 1000) | ||
chain, state = sample(model, sampler; dlogz=0.05, param_names=["x", "y"]) | ||
# resample chain using statistical weights | ||
chain_resampled = sample(chain, Weights(vec(chain[:weights])), length(chain)); | ||
nothing # hide | ||
``` | ||
|
||
## Results | ||
|
||
```@example shells | ||
chain_resampled | ||
``` | ||
|
||
```@example shells | ||
marginalkde(chain[:x], chain[:y]) | ||
``` | ||
|
||
```@example shells | ||
density(chain_resampled) | ||
vline!([-5.5, -1.5, 1.5, 5.5], c=:black, ls=:dash, sp=1) | ||
vline!([-2, 2], c=:black, ls=:dash, sp=2) | ||
``` | ||
|
||
```@example shells | ||
using Measurements | ||
logz_est = state.logz ± state.logzerr | ||
diff = logz_est - logz | ||
print("logz: ", logz, "\nestimate: ", logz_est, "\ndiff: ", diff) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
""" | ||
This module contains various statistical models in the form of [`NestedModel`](@ref)s. These models can be used for examples and for testing. | ||
* [`Models.GaussianShells`](@ref) | ||
* [`Models.CorrelatedGaussian`](@ref) | ||
""" | ||
module Models | ||
|
||
using ..NestedSamplers | ||
|
||
using Distributions | ||
using LinearAlgebra | ||
using StatsFuns | ||
|
||
include("shells.jl") | ||
include("correlated.jl") | ||
|
||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
|
||
@doc raw""" | ||
Models.CorrelatedGaussian(ndims) | ||
Creates a highly-correlated Gaussian with the given dimensionality. | ||
```math | ||
\mathbf\theta \sim \mathcal{N}\left(2\mathbf{1}, \mathbf{I}\right) | ||
``` | ||
```math | ||
\Sigma_{ij} = \begin{cases} 1 &\quad i=j \\ 0.95 &\quad i\neq j \end{cases} | ||
``` | ||
```math | ||
\mathcal{L}(\mathbf\theta) = \mathcal{N}\left(\mathbf\theta | \mathbf{0}, \mathbf\Sigma \right) | ||
``` | ||
the analytical evidence of the model is | ||
```math | ||
Z = \mathcal{N}\left(2\mathbf{1} | \mathbf{0}, \mathbf\Sigma + \mathbf{I} \right) | ||
``` | ||
## Examples | ||
```jldoctest | ||
julia> model, lnZ = Models.CorrelatedGaussian(10); | ||
julia> lnZ | ||
-12.482738597926607 | ||
``` | ||
""" | ||
function CorrelatedGaussian(ndims) | ||
priors = fill(Normal(2, 1), ndims) | ||
Σ = fill(0.95, ndims, ndims) | ||
Σ[diagind(Σ)] .= 1 | ||
cent_dist = MvNormal(Σ) | ||
loglike(X) = logpdf(cent_dist, X) | ||
|
||
model = NestedModel(loglike, priors) | ||
true_lnZ = logpdf(MvNormal(fill(2, ndims), Σ + I), zeros(ndims)) | ||
return model, true_lnZ | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
using StatsFuns | ||
|
||
""" | ||
Models.GaussianShells() | ||
2-D Gaussian shells centered at `[-3.5, 0]` and `[3.5, 0]` with a radius of 2 and a shell width of 0.1 | ||
# Examples | ||
```jldoctest | ||
julia> model, lnZ = Models.GaussianShells(); | ||
julia> lnZ | ||
-1.75 | ||
``` | ||
""" | ||
function GaussianShells() | ||
μ1 = [-3.5, 0] | ||
μ2 = [3.5, 0] | ||
|
||
prior(X) = 12 .* X .- 6 | ||
loglike(X) = logaddexp(logshell(X, μ1), logshell(X, μ2)) | ||
|
||
lnZ = -1.75 | ||
return NestedModel(loglike, prior), lnZ | ||
end | ||
|
||
function logshell(X, μ, radius=2, width=0.1) | ||
d = LinearAlgebra.norm(X - μ) | ||
norm = -log(sqrt(2 * π * width^2)) | ||
return norm - (d - radius)^2 / (2 * width^2) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.