Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wouterwln committed Sep 18, 2024
1 parent a631308 commit 98de244
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 13 deletions.
6 changes: 3 additions & 3 deletions src/model/graphppl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,14 @@ function GraphPPL.default_parametrization(backend::ReactiveMPGraphPPLBackend, no
end

function GraphPPL.instantiate(::Type{ReactiveMPGraphPPLBackend})
return ReactiveMPGraphPPLBackend(Static.False)
return ReactiveMPGraphPPLBackend(Static.False())

Check warning on line 228 in src/model/graphppl.jl

View check run for this annotation

Codecov / codecov/patch

src/model/graphppl.jl#L228

Added line #L228 was not covered by tests
end

function GraphPPL.instantiate(::Type{ReactiveMPGraphPPLBackend{Static.True}})
return ReactiveMPGraphPPLBackend(Static.True)
return ReactiveMPGraphPPLBackend(Static.True())

Check warning on line 232 in src/model/graphppl.jl

View check run for this annotation

Codecov / codecov/patch

src/model/graphppl.jl#L231-L232

Added lines #L231 - L232 were not covered by tests
end
function GraphPPL.instantiate(::Type{ReactiveMPGraphPPLBackend{Static.False}})
return ReactiveMPGraphPPLBackend(Static.False)
return ReactiveMPGraphPPLBackend(Static.False())
end

# Node specific aliases
Expand Down
6 changes: 3 additions & 3 deletions test/inference/inference_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ end
import RxInfer: ReactiveMPGraphPPLBackend
import Static

n = 5 # Number of test cases
n = 6 # Number of test cases

distribution = NormalMeanVariance(0.0, 1.0)
dataset = rand(distribution, n)
Expand Down Expand Up @@ -332,13 +332,13 @@ end
q(ω_2) = vague(NormalMeanVariance)
q(κ_2) = vague(NormalMeanVariance)
q(x_1) = vague(NormalMeanVariance)
q(x_2) = vague(NormalMeanVariance)
q(x_2[1:2:n]) = vague(NormalMeanVariance)
q(x_3) = vague(NormalMeanVariance)
end

result_2 = infer(model = hgf_2(), data = (y = dataset,), initialization = hgf_2_initialization(), constraints = MeanField(), allow_node_contraction = true)

@test result_2.posteriors[:x_1] isa Vector{<:NormalMeanVariance}
@test result_2.posteriors[:x_1] isa Vector{<:NormalDistributionsFamily}
end

@testitem "Test warn argument in `infer()`" begin
Expand Down
5 changes: 3 additions & 2 deletions test/model/graphppl_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,13 @@ end
import RxInfer: ReactiveMPGraphPPLBackend
import ReactiveMP: @node
import GraphPPL
using Static

struct CustomStochasticNode end

@node CustomStochasticNode Stochastic [out, (x, aliases = [xx]), (y, aliases = [yy]), z]

backend = ReactiveMPGraphPPLBackend()
backend = ReactiveMPGraphPPLBackend(Static.False())

@test GraphPPL.NodeBehaviour(backend, CustomStochasticNode) === GraphPPL.Stochastic()
@test GraphPPL.NodeType(backend, CustomStochasticNode) === GraphPPL.Atomic()
Expand Down Expand Up @@ -135,7 +136,7 @@ end

@node CustomStochasticNode Stochastic [out, (x, aliases = [xx]), (y, aliases = [yy]), z]

backend = ReactiveMPGraphPPLBackend(true)
backend = ReactiveMPGraphPPLBackend(Static.True())

@test GraphPPL.NodeBehaviour(backend, CustomStochasticNode) === GraphPPL.Stochastic()
@test GraphPPL.NodeType(backend, CustomStochasticNode) === GraphPPL.Atomic()
Expand Down
10 changes: 5 additions & 5 deletions test/model/initialization_plugin_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ end

@model function gcv(κ, ω, z, x, y)
log_σ := κ * z + ω
y ~ Normal(x, exp(log_σ))
y ~ NormalMeanVariance(x, exp(log_σ))
end

@model function gcv_collection()
κ ~ Normal(0, 1)
ω ~ Normal(0, 1)
z ~ Normal(0, 1)
κ ~ NormalMeanVariance(0, 1)
ω ~ NormalMeanVariance(0, 1)
z ~ NormalMeanVariance(0, 1)
for i in 1:10
x[i] ~ Normal(0, 1)
x[i] ~ NormalMeanVariance(0, 1)
y[i] ~ gcv= κ, ω = ω, z = z, x = x[i])
end
end
Expand Down

0 comments on commit 98de244

Please sign in to comment.