Skip to content

Commit

Permalink
run 1 layer hgf in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Sep 18, 2024
1 parent d882504 commit 0466876
Showing 1 changed file with 58 additions and 7 deletions.
65 changes: 58 additions & 7 deletions test/inference/inference_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ end

n = 5 # Number of test cases

distribution = NormalMeanVariance(10.0, 1.0)
distribution = NormalMeanVariance(0.0, 1.0)
dataset = rand(distribution, n)

@model function gcv(y, x, z, κ, ω)
Expand Down Expand Up @@ -262,7 +262,58 @@ end
return @call_rule GCV(, Marginalisation) (q_y = q_y, q_x = q_x, q_z = q_z, q_ω = q_ω, meta = meta)
end

@model function hgf(y)
@average_energy typeof(gcv) (q_y::Any, q_x::Any, q_z::Any, q_κ::Any, q_ω::Any, meta::Union{<:GCVMetadata, Nothing}) =
begin
y_mean, y_var = mean_var(q_y)
x_mean, x_var = mean_var(q_x)
z_mean, z_var = mean_var(q_z)
κ_mean, κ_var = mean_var(q_κ)
ω_mean, ω_var = mean_var(q_ω)

ksi = (κ_mean^2) * z_var + (z_mean^2) * κ_var + κ_var * z_var
psi = (y_mean - x_mean)^2 + y_var + x_var
A = exp(-ω_mean + ω_var / 2)
B = exp(-κ_mean * z_mean + ksi / 2)

(log(2π) + (z_mean * κ_mean + ω_mean) + (psi * A * B)) / 2
end

@model function hgf_1(y)
ω ~ NormalMeanVariance(0, 1)
κ ~ NormalMeanVariance(1, 1)
x_0 ~ NormalMeanVariance(0, 1)
z[1] ~ NormalMeanVariance(0, 1)
x[1] ~ gcv(x = x_0, z = z[1], κ = κ, ω = ω)
y[1] ~ NormalMeanVariance(x[1], 1)

for i in 2:length(y)
z[i] ~ NormalMeanPrecision(z[i - 1], 1)
x[i] ~ gcv(x = x[i - 1], z = z[i], κ = κ, ω = ω)
y[i] ~ NormalMeanVariance(x[i], 1)
end
end

@initialization function hgf_1_initialization()
q(ω) = NormalMeanVariance(0, 1)
q(κ) = NormalMeanVariance(1, 1)
q(z) = NormalMeanVariance(0, 1)
q(x) = NormalMeanVariance(0, 1)
end

result_1 = infer(
model = hgf_1(),
data = (y = dataset,),
initialization = hgf_1_initialization(),
constraints = MeanField(),
allow_node_contraction = true,
free_energy = true
)

@test all(!isnan, mean.(result_1.posteriors[:x]))
@test all(!isnan, var.(result_1.posteriors[:x]))
@test all(<=(0), diff(result_1.free_energy))

@model function hgf_2(y)

# Specify priors
ω_1 ~ NormalMeanVariance(0, 1)
Expand All @@ -283,7 +334,7 @@ end
end
end

initialization = @initialization begin
@initialization function hgf_2_initialization()
q(ω_1) = vague(NormalMeanVariance)
q(κ_1) = vague(NormalMeanVariance)
q(ω_2) = vague(NormalMeanVariance)
Expand All @@ -293,15 +344,15 @@ end
q(x_3) = vague(NormalMeanVariance)
end

result = infer(
model = hgf(),
result_2 = infer(
model = hgf_2(),
data = (y = dataset,),
initialization = initialization,
initialization = hgf_2_initialization(),
constraints = MeanField(),
allow_node_contraction = true
)

@test result.posteriors[:x_1] isa Vector{<:NormalMeanVariance}
@test result_2.posteriors[:x_1] isa Vector{<:NormalMeanVariance}
end

@testitem "Test warn argument in `infer()`" begin
Expand Down

0 comments on commit 0466876

Please sign in to comment.