Implement Kernel Embeddings (#36)
* Change function names to be more clear about what they do

* Fixed tests

* Patch bump
theogf authored Dec 7, 2021
1 parent a505e45 commit 326570b
Showing 13 changed files with 99 additions and 48 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "BayesianQuadrature"
uuid = "609f5bd8-aef1-42b2-b90e-083e3346dba9"
authors = ["Theo Galy-Fajou <theo.galyfajou@gmail.com> and contributors"]
version = "0.2.0"
version = "0.2.1"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
4 changes: 3 additions & 1 deletion src/BayesianQuadrature.jl
@@ -13,6 +13,8 @@ export PriorSampling
export BayesModel
export prior, integrand, logintegrand, logprior, logjoint
export BQ # Short version for calling BayesianQuadrature
export KernelEmbedding
export kernel_mean, kernel_variance

const BQ = BayesianQuadrature

@@ -42,7 +44,7 @@ abstract type AbstractBQModel{Tp,Ti} <: AbstractMCMC.AbstractModel end

include("bayesquads/abstractbq.jl")
include("samplers/abstractbqsampler.jl")
include("kernelmeans/kernels.jl")
include("kernelembeddings/kernelembedding.jl")
include("interface.jl")
include("models.jl")
include("utils.jl")
6 changes: 1 addition & 5 deletions src/bayesquads/abstractbq.jl
@@ -27,7 +27,7 @@ end

function get_kernel_params(k::TransformedKernel; kwargs...)
check_transform(k.transform)
return get_kernel_params(k.kernel; kwargs..., l=param(k.transform))
return get_kernel_params(k.kernel; kwargs..., l=transform_param(k.transform))
end

function check_transform(transform)
@@ -44,7 +44,3 @@ function kernel(b::AbstractBQ)
return b.σ * (b.kernel ∘ ScaleTransform(inv.(b.l)))
end
end

Λ(l::Real) = abs2(l) * I
Λ(l::AbstractVector) = Diagonal(abs2.(l))
Λ(l::LowerTriangular) = l * l'
5 changes: 3 additions & 2 deletions src/bayesquads/bayesquad.jl
@@ -34,8 +34,9 @@ function quadrature(
isempty(samples) && error("The collection of samples is empty")
y = integrand(model).(samples)
K = kernelpdmat(kernel(bquad), samples)
z = calc_z(samples, p_0(model), bquad)
C = calc_C(p_0(model), bquad)
ke = KernelEmbedding(bquad, p_0(model))
z = kernel_mean(ke, samples)
C = kernel_variance(ke)
var = evaluate_var(z, K, C)
if var < 0
if var > -1e-5
7 changes: 4 additions & 3 deletions src/bayesquads/logbayesquad.jl
@@ -41,22 +41,23 @@ function quadrature(
f = exp.(logf) # Compute integrand on samples

x_c = sample_candidates(bquad, samples, bquad.n_candidates) # Sample candidates around the samples
ke = KernelEmbedding(bquad, p_0(model))

gp = create_gp(bquad, samples)
f_c_0 = mean.(predict(gp, f, x_c)) # Predict integrand on x_c
logf_c_0 = mean.(predict(gp, logf, x_c)) # Predict log-integrand on x_c
Δ_c = exp.(logf_c_0) - f_c_0 # Compute difference of predictions

z = calc_z(samples, p_0(model), bquad) # Compute mean for the basic BQ
z = kernel_mean(ke, samples) # Compute mean for the basic BQ
K = kernelpdmat(kernel(bquad), samples) # and the kernel matrix

z_c = calc_z(x_c, p_0(model), bquad) # Compute mean for the ΔlogBQ
z_c = kernel_mean(ke, x_c) # Compute mean for the ΔlogBQ
K_c = kernelpdmat(kernel(bquad), x_c) # and the kernel matrix for the candidates

m_evidence = evaluate_mean(z, K, f) # Compute m(Z|samples)
m_correction = evaluate_mean(z_c, K_c, Δ_c) #

C = calc_C(p_0(model), bquad) # Compute the C component for the variance
C = kernel_variance(ke) # Compute the kernel variance

var_evidence = evaluate_var(z, K, C)
var_correction = evaluate_var(z_c, K_c, C)
62 changes: 62 additions & 0 deletions src/kernelembeddings/kernelembedding.jl
@@ -0,0 +1,62 @@
abstract type AbstractKernelEmbedding end

measure(ke::AbstractKernelEmbedding) = ke.measure

struct KernelEmbedding{Tk<:Kernel,Tm,Tσ<:Real,Tl} <: AbstractKernelEmbedding
kernel::Tk # Kernel function
σ::Tσ # Kernel variance
l::Tl # Kernel lengthscale
measure::Tm # Measure
end

function KernelEmbedding(bquad::AbstractBQ, prior)
return KernelEmbedding(bquad.kernel, bquad.σ, bquad.l, prior)
end

scale(ke::KernelEmbedding) = ke.σ

@doc raw"""
kernel_mean(ke::KernelEmbedding{Kernel,Measure}, samples::AbstractVector)
Compute the kernel mean of the kernel embedding `ke` for each one of
the `samples` $$x_i$$:
```math
z_i = \int k(x, x_i)d\mu(x)
```
"""
kernel_mean

function kernel_mean(ke::KernelEmbedding{<:SqExponentialKernel,<:AbstractMvNormal}, samples::AbstractVector)
z = samples .- Ref(mean(measure(ke)))
B = Λ(ke.l)
return scale(ke) / sqrt(det(inv(B) * cov(measure(ke)) + I)) *
exp.(- PDMats.invquad.(Ref(PDMat(B + cov(measure(ke)))), z) / 2)
end

@doc raw"""
kernel_variance(ke::KernelEmbedding{Kernel,Measure})
Compute the kernel variance of the given kernel embedding:
```math
C = \int\int k(x,x')d\mu(x)d\mu(x')
```
"""
kernel_variance


function kernel_variance(ke::KernelEmbedding{<:SqExponentialKernel,<:AbstractMvNormal})
B = Λ(ke.l)
return scale(ke) / sqrt(det(2 * inv(B) * cov(measure(ke)) + I))
end


# Turn the lengthscale into the squared-lengthscale matrix Λ
Λ(l::Real) = abs2(l) * I
Λ(l::AbstractVector) = Diagonal(abs2.(l))
Λ(l::LowerTriangular) = Cholesky(l, :L, 1)
Λ(l::AbstractMatrix) = l * l'

# Turns a transform into a lengthscale
transform_param(t::ScaleTransform) = inv(first(t.s))
transform_param(t::ARDTransform) = inv.(t.v)
transform_param(t::LinearTransform) = inv(t.A)
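
For the `SqExponentialKernel`-with-Gaussian-measure dispatch above, the two methods evaluate the standard closed-form Gaussian integrals. Writing the measure as $\mu = \mathcal{N}(m, \Sigma)$ and $\Lambda = Λ(l)$ for the squared-lengthscale matrix, the method bodies compute

```math
z_i = \sigma \,\det\!\left(\Lambda^{-1}\Sigma + I\right)^{-1/2}
      \exp\!\left(-\tfrac{1}{2}(x_i - m)^\top (\Lambda + \Sigma)^{-1} (x_i - m)\right),
\qquad
C = \sigma \,\det\!\left(2\,\Lambda^{-1}\Sigma + I\right)^{-1/2},
```

which is exactly what `kernel_mean` and `kernel_variance` return for this kernel/measure pair.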
5 changes: 0 additions & 5 deletions src/kernelmeans/kernels.jl

This file was deleted.

18 changes: 0 additions & 18 deletions src/kernelmeans/sekernel.jl

This file was deleted.

6 changes: 0 additions & 6 deletions test/bayesquads/bayesquad.jl
@@ -1,11 +1,5 @@
@testset "bayesquad" begin
s = 2.0
l = [1.0, 2.0]
L = LowerTriangular(rand(2, 2))
k = SqExponentialKernel()
@test BQ.Λ(s) ≈ s^2 * I
@test BQ.Λ(l) ≈ Diagonal(abs2.(l))
@test BQ.Λ(L) ≈ L * L'

σ = 4.0
@test BQ.scale(BayesQuad(σ * k)) == σ
22 changes: 22 additions & 0 deletions test/kernelembeddings/kernelembedding.jl
@@ -0,0 +1,22 @@
@testset "kernelembeddings" begin
rng = MersenneTwister(42)
N = 3
measure = MvNormal(ones(3), ones(3))
k = SqExponentialKernel()
l = 2.0
σ = 0.5
ke = KernelEmbedding(k, σ, l, measure)
kernel = σ * with_lengthscale(k, l)
@test BQ.scale(ke) == σ

sample = [rand(rng, 3)]
@test kernel_mean(ke, sample) ≈ [mean(kernel.(sample, eachcol(rand(rng, measure, 10000))))] atol=1e-2
@test kernel_variance(ke) ≈ mean(kernel.(eachcol(rand(rng, measure, 10000)), eachcol(rand(rng, measure, 10000)))) atol=1e-2

s = 2.0
l = [1.0, 2.0]
L = LowerTriangular(rand(2, 2))
@test BQ.Λ(s) ≈ s^2 * I
@test BQ.Λ(l) ≈ Diagonal(abs2.(l))
@test Matrix(BQ.Λ(L)) ≈ L * L'
end
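
For a quick feel of the new API outside the test harness, here is a minimal standalone sketch mirroring the testset above. The exported names (`KernelEmbedding`, `kernel_mean`, `kernel_variance`) come from this commit; the Monte Carlo comparison, sample count, and tolerance are illustrative choices, not anything prescribed by the package.

```julia
using BayesianQuadrature, KernelFunctions, Distributions
using LinearAlgebra, Random, Statistics

rng = MersenneTwister(42)
measure = MvNormal(ones(3), Diagonal(ones(3)))   # Gaussian measure μ = N(1, I)
k = SqExponentialKernel()
σ, l = 0.5, 2.0
ke = KernelEmbedding(k, σ, l, measure)

# Closed-form kernel mean z_1 = ∫ k(x, x_1) dμ(x) at a single point x_1
x = [rand(rng, 3)]
z = kernel_mean(ke, x)

# Crude Monte Carlo check of the same integral
kern = σ * with_lengthscale(k, l)
xs = eachcol(rand(rng, measure, 100_000))
z_mc = mean(kern.(Ref(first(x)), xs))
@show isapprox(only(z), z_mc; atol=1e-2)

# Closed-form kernel variance C = ∫∫ k(x, x′) dμ(x) dμ(x′)
C = kernel_variance(ke)
```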
2 changes: 0 additions & 2 deletions test/kernelmeans/kernels.jl

This file was deleted.

1 change: 0 additions & 1 deletion test/kernelmeans/sekernel.jl

This file was deleted.

7 changes: 3 additions & 4 deletions test/runtests.jl
@@ -25,9 +25,8 @@ include("testing_tools.jl")
include(joinpath("samplers", "priorsampling.jl"))
end

@info "Testing kernel means"
@testset "Kernel Means" begin
include(joinpath("kernelmeans", "kernels.jl"))
include(joinpath("kernelmeans", "sekernel.jl"))
@info "Testing kernel embeddings"
@testset "Kernel Embeddings" begin
include(joinpath("kernelembeddings", "kernelembedding.jl"))
end
end

2 comments on commit 326570b

@theogf (Owner, Author) commented on 326570b, Dec 7, 2021

@JuliaRegistrator commented on 326570b

Registration pull request created: JuliaRegistries/General/50102

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.1 -m "<description of version>" 326570b3faaa4d0cef8a1d33367492ff1d529605
git push origin v0.2.1
