diff --git a/Project.toml b/Project.toml
index 21732a49..733aff9f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "NestedSamplers"
 uuid = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e"
 authors = ["Miles Lucas "]
-version = "0.6.2"
+version = "0.6.3"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
diff --git a/docs/src/examples/correlated.md b/docs/src/examples/correlated.md
index 9f851166..f2ab0a4b 100644
--- a/docs/src/examples/correlated.md
+++ b/docs/src/examples/correlated.md
@@ -52,7 +52,7 @@ using MCMCChains
 using StatsBase
 # using single Ellipsoid for bounds
 # using Gibbs-style slicing for proposing new points
-sampler = Nested(D, 50 * (D + 1);
+sampler = Nested(D, 50D;
     bounds=Bounds.Ellipsoid,
     proposal=Proposals.Slice()
 )
diff --git a/src/sample.jl b/src/sample.jl
index 3ee6a603..5a64d8d1 100644
--- a/src/sample.jl
+++ b/src/sample.jl
@@ -10,17 +10,17 @@ StatsBase.sample(rng::AbstractRNG, model::AbstractModel, sampler::Nested; kwargs
 StatsBase.sample(model::AbstractModel, sampler::Nested; kwargs...) =
     StatsBase.sample(GLOBAL_RNG, model, sampler; kwargs...)
 
-function nested_isdone(rng, model, smapler, samples, state, i; progress=true, maxiter=Inf, maxcall=Inf, dlogz=0.5, maxlogl=Inf, kwargs...)
+function nested_isdone(rng, model, sampler, samples, state, i; progress=true, maxiter=Inf, maxcall=Inf, dlogz=0.5, maxlogl=Inf, kwargs...)
     # 1) iterations exceeds maxiter
-    done_sampling = state.it > maxiter
+    done_sampling = state.it ≥ maxiter
     # 2) number of loglike calls has been exceeded
-    done_sampling |= state.ncall > maxcall
+    done_sampling |= state.ncall ≥ maxcall
     # 3) remaining fractional log-evidence below threshold
-    logz_remain = maximum(state.logl) + state.logvol
+    logz_remain = maximum(state.logl) - state.it / sampler.nactive
     delta_logz = logaddexp(state.logz, logz_remain) - state.logz
-    done_sampling |= delta_logz < dlogz
+    done_sampling |= delta_logz ≤ dlogz
     # 4) last dead point loglikelihood exceeds threshold
-    done_sampling |= state.logl_dead > maxlogl
+    done_sampling |= state.logl_dead ≥ maxlogl
     # 5) number of effective samples
     # TODO
 
diff --git a/src/staticsampler.jl b/src/staticsampler.jl
index 8153ca97..89f69dfc 100644
--- a/src/staticsampler.jl
+++ b/src/staticsampler.jl
@@ -1,6 +1,6 @@
 # Sampler and model implementations
 
-struct Nested{B,P <: AbstractProposal} <: AbstractSampler
+struct Nested{B, P <: AbstractProposal} <: AbstractSampler
     ndims::Int
     nactive::Int
     bounds::B
@@ -9,6 +9,7 @@ struct Nested{B,P <: AbstractProposal} <: AbstractSampler
     min_ncall::Int
     min_eff::Float64
     proposal::P
+    dlv::Float64
 end
 
 """
@@ -65,6 +66,8 @@ function Nested(ndims,
         end
     end
 
+    dlv = log(nactive + 1) - log(nactive)
+
     update_interval_frac = get(kwargs, :update_interval, default_update_interval(proposal, ndims))
     update_interval = round(Int, update_interval_frac * nactive)
     return Nested(ndims,
@@ -74,7 +77,8 @@ function Nested(ndims,
                   update_interval,
                   min_ncall,
                   min_eff,
-                  proposal)
+                  proposal,
+                  dlv)
 end
 
 default_update_interval(p::Proposals.Uniform, ndims) = 1.5
diff --git a/src/step.jl b/src/step.jl
index 702c8b3a..630eeed1 100644
--- a/src/step.jl
+++ b/src/step.jl
@@ -6,8 +6,8 @@ function step(rng, model, sampler::Nested; kwargs...)
 
     # Find least likely point
     logl_dead, idx_dead = findmin(logl)
-    u_dead = @view us[:, idx_dead]
-    v_dead = @view vs[:, idx_dead]
+    u_dead = us[:, idx_dead]
+    v_dead = vs[:, idx_dead]
 
     # update weight using trapezoidal rule
     logvol = log1mexp(-1 / sampler.nactive)
@@ -63,8 +63,8 @@ function step(rng, model, sampler, state; kwargs...)
     ## Replace least-likely active point
     # Find least likely point
     logl_dead, idx_dead = findmin(state.logl)
-    u_dead = @view state.us[:, idx_dead]
-    v_dead = @view state.vs[:, idx_dead]
+    u_dead = state.us[:, idx_dead]
+    v_dead = state.vs[:, idx_dead]
 
     # sample a new live point using bounds and proposal
     if has_bounds
@@ -117,13 +117,10 @@ function bundle_samples(samples,
     if add_live
         samples, state = add_live_points(samples, model, sampler, state)
     end
-    vals = mapreduce(t -> hcat(t.v..., t.logwt), vcat, samples)
-    # update weights based on evidence
-    @. vals[:, end, 1] = exp(vals[:, end, 1] - state.logz)
-    wsum = sum(vals[:, end, 1])
-    @. vals[:, end, 1] /= wsum
+    vals = mapreduce(t -> hcat(t.v..., exp(t.logwt - state.logz)), vcat, samples)
 
     if check_wsum
+        wsum = sum(vals[:, end, 1])
         err = !iszero(state.logzerr) ? 3 * state.logzerr : 1e-3
         isapprox(wsum, 1, atol=err) || @warn "Weights sum to $wsum instead of 1; possible bug"
     end
@@ -150,14 +147,14 @@ function bundle_samples(samples,
         samples, state = add_live_points(samples, model, sampler, state)
     end
 
-    wsum = sum(s -> exp(s.logwt - state.logz), samples)
+    vals = mapreduce(t -> hcat(t.v..., exp(t.logwt - state.logz)), vcat, samples)
 
     if check_wsum
+        wsum = sum(vals[:, end])
         err = !iszero(state.logzerr) ? 3 * state.logzerr : 1e-3
         isapprox(wsum, 1, atol=err) || @warn "Weights sum to $wsum instead of 1; possible bug"
     end
 
-    vals = mapreduce(t -> hcat(t.v..., t.logwt / wsum), vcat, samples)
     return vals, state
 end
 
@@ -201,11 +198,12 @@ function add_live_points(samples, model, sampler, state)
     prev_h = state.h
     local logl, logz, h, logzerr
 
+    N = length(samples)
     @inbounds for (i, idx) in enumerate(eachindex(state.logl))
         # get new point
-        u = @view state.us[:, idx]
-        v = @view state.vs[:, idx]
+        u = state.us[:, idx]
+        v = state.vs[:, idx]
         logl = state.logl[idx]
 
         # update sampler
@@ -219,7 +217,7 @@ function add_live_points(samples, model, sampler, state)
         prev_h = h
 
         sample = (u = u, v = v, logwt = logwt, logl = logl)
-        save!!(samples, sample, length(samples) + i, model, sampler)
+        save!!(samples, sample, N + i, model, sampler)
     end
 
     state = (it = state.it + sampler.nactive, us = state.us, vs = state.vs, logl = logl,
diff --git a/test/models.jl b/test/models.jl
index 8ae9d93c..1c012912 100644
--- a/test/models.jl
+++ b/test/models.jl
@@ -1,15 +1,11 @@
 const test_bounds = [Bounds.Ellipsoid, Bounds.MultiEllipsoid]
-const test_props = [Proposals.Uniform(), Proposals.RWalk(ratio=0.9, walks=75), Proposals.RStagger(ratio=0.9, walks=75, scale=0.5), Proposals.Slice(slices=10), Proposals.RSlice()]
+const test_props = [Proposals.Uniform(), Proposals.RWalk(ratio=0.9, walks=50), Proposals.RStagger(ratio=0.9, walks=75), Proposals.Slice(slices=10), Proposals.RSlice()]
 
 @testset "$(nameof(bound)), $(nameof(typeof(proposal)))" for bound in test_bounds, proposal in test_props
-    @testset "Correlated Gaussian Conjugate Prior - ndims=$D" for D in [2, 4, 8]
-        if D == 8 && (proposal isa Proposals.RWalk || proposal isa Proposals.RStagger)
-            # TODO evidence estimates are terrible for D=8
-            continue
-        end
+    @testset "Correlated Gaussian Conjugate Prior - ndims=$D" for D in [2, 4]
         model, logz = Models.CorrelatedGaussian(D)
-
+        # match JAXNS paper setup, generally
         sampler = Nested(D, 50D; bounds=bound, proposal=proposal)
 
         chain, state = sample(rng, model, sampler; dlogz=0.01)
@@ -17,15 +13,16 @@ const test_props = [Proposals.Uniform(), Proposals.RWalk(ratio=0.9, walks=75), P
         # test posteriors
         vals = Array(chain_res)
         means = mean(vals, dims=1)
-        tols = 3 .* std(vals, mean=means, dims=1) # 3-sigma
+        tols = 2std(vals, mean=means, dims=1) # 2-sigma
         μ = fill(2.0, D)
         Σ = fill(0.95, D, D)
         Σ[diagind(Σ)] .= 1
         expected = Σ * ((Σ + I) \ μ)
         @test all(@.(abs(means - expected) < tols))
 
-        # logz (5-sigma)
-        @test state.logz ≈ logz atol = 5state.logzerr
+        # logz
+        tol = 5state.logzerr
+        @test state.logz ≈ logz atol = tol
     end
 
     @testset "Gaussian Shells" begin
@@ -46,12 +43,12 @@ const test_props = [Proposals.Uniform(), Proposals.RWalk(ratio=0.9, walks=75), P
         inv_σ = diagm(0 => fill(1 / σ^2, 2))
 
         function logl(x)
-        dx1 = x .- μ1
-        dx2 = x .- μ2
-        f1 = -dx1' * (inv_σ * dx1) / 2
-        f2 = -dx2' * (inv_σ * dx2) / 2
-        return logaddexp(f1, f2)
-    end
+            dx1 = x .- μ1
+            dx2 = x .- μ2
+            f1 = -dx1' * (inv_σ * dx1) / 2
+            f2 = -dx2' * (inv_σ * dx2) / 2
+            return logaddexp(f1, f2)
+        end
 
         prior(X) = 10 .* X .- 5
         model = NestedModel(logl, prior)
@@ -60,22 +57,22 @@ const test_props = [Proposals.Uniform(), Proposals.RWalk(ratio=0.9, walks=75), P
 
         spl = Nested(2, 1000, bounds=bound, proposal=proposal)
 
-        chain, state = sample(rng, model, spl; dlogz=0.2)
+        chain, state = sample(rng, model, spl; dlogz=0.01)
         chain_res = sample(chain, Weights(vec(chain[:weights])), length(chain))
 
         diff = state.logz - analytic_logz
         atol = 5state.logzerr
         if diff > atol
-        @warn "logz estimate is poor" bound proposal error = diff tolerance = atol
-    end
+            @warn "logz estimate is poor" bound proposal error = diff tolerance = atol
+        end
 
-        @test state.logz ≈ analytic_logz atol = atol # within 3σ
+        @test state.logz ≈ analytic_logz atol = atol # within 1σ
 
         xmodes = sort!(findpeaks(chain_res[:, 1, 1])[1:2])
-        @test xmodes[1] ≈ -1 atol = 2σ
-        @test xmodes[2] ≈ 1 atol = 2σ
+        @test xmodes[1] ≈ -1 atol = σ
+        @test xmodes[2] ≈ 1 atol = σ
         ymodes = sort!(findpeaks(chain_res[:, 2, 1])[1:2])
-        @test ymodes[1] ≈ -1 atol = 2σ
-        @test ymodes[2] ≈ 1 atol = 2σ
+        @test ymodes[1] ≈ -1 atol = σ
+        @test ymodes[2] ≈ 1 atol = σ
     end
 end
diff --git a/test/sampler.jl b/test/sampler.jl
index 62de3662..44301c22 100644
--- a/test/sampler.jl
+++ b/test/sampler.jl
@@ -18,6 +18,7 @@ spl = Nested(3, 100)
 @test spl.update_interval == 150
 @test spl.enlarge == 1.25
 @test spl.min_ncall == 200
+@test spl.dlv ≈ log(101/100)
 
 spl = Nested(10, 1000)
 
diff --git a/test/sampling.jl b/test/sampling.jl
index d6303eab..6e7ddc80 100644
--- a/test/sampling.jl
+++ b/test/sampling.jl
@@ -32,14 +32,14 @@ end
     chains, state = sample(rng, model, spl; add_live=false, dlogz=1.0)
     logz_remain = maximum(state.logl) + state.logvol
     delta_logz = logaddexp(state.logz, logz_remain) - state.logz
-    @test delta_logz < 1.0
+    @test delta_logz ≤ 1.0
 
     chains, state = sample(rng, model, spl; add_live=false, maxiter=3)
-    @test state.it < 3
+    @test state.it == 3
 
     chains, state = sample(rng, model, spl; add_live=false, maxcall=10)
-    @test state.ncall < 10
+    @test state.ncall == 10
 
     chains, state = sample(rng, model, spl; add_live=false, maxlogl=0.2)
-    @test state.logl[1] > 0.2
+    @test state.logl[1] ≥ 0.2
 end