diff --git a/src/model/graphppl.jl b/src/model/graphppl.jl index 47e4364df..5e6bb569e 100644 --- a/src/model/graphppl.jl +++ b/src/model/graphppl.jl @@ -225,14 +225,14 @@ function GraphPPL.default_parametrization(backend::ReactiveMPGraphPPLBackend, no end function GraphPPL.instantiate(::Type{ReactiveMPGraphPPLBackend}) - return ReactiveMPGraphPPLBackend(Static.False) + return ReactiveMPGraphPPLBackend(Static.False()) end function GraphPPL.instantiate(::Type{ReactiveMPGraphPPLBackend{Static.True}}) - return ReactiveMPGraphPPLBackend(Static.True) + return ReactiveMPGraphPPLBackend(Static.True()) end function GraphPPL.instantiate(::Type{ReactiveMPGraphPPLBackend{Static.False}}) - return ReactiveMPGraphPPLBackend(Static.False) + return ReactiveMPGraphPPLBackend(Static.False()) end # Node specific aliases diff --git a/test/inference/inference_tests.jl b/test/inference/inference_tests.jl index 5781c2c93..e63d04278 100644 --- a/test/inference/inference_tests.jl +++ b/test/inference/inference_tests.jl @@ -227,7 +227,7 @@ end import RxInfer: ReactiveMPGraphPPLBackend import Static - n = 5 # Number of test cases + n = 6 # Number of test cases distribution = NormalMeanVariance(0.0, 1.0) dataset = rand(distribution, n) @@ -332,13 +332,13 @@ end q(ω_2) = vague(NormalMeanVariance) q(κ_2) = vague(NormalMeanVariance) q(x_1) = vague(NormalMeanVariance) - q(x_2) = vague(NormalMeanVariance) + q(x_2[1:2:n]) = vague(NormalMeanVariance) q(x_3) = vague(NormalMeanVariance) end result_2 = infer(model = hgf_2(), data = (y = dataset,), initialization = hgf_2_initialization(), constraints = MeanField(), allow_node_contraction = true) - @test result_2.posteriors[:x_1] isa Vector{<:NormalMeanVariance} + @test result_2.posteriors[:x_1] isa Vector{<:NormalDistributionsFamily} end @testitem "Test warn argument in `infer()`" begin diff --git a/test/model/graphppl_tests.jl b/test/model/graphppl_tests.jl index 761d307e2..ed7889842 100644 --- a/test/model/graphppl_tests.jl +++ b/test/model/graphppl_tests.jl @@ -92,12 +92,13 @@ end import RxInfer: ReactiveMPGraphPPLBackend import ReactiveMP: @node import GraphPPL + using Static struct CustomStochasticNode end @node CustomStochasticNode Stochastic [out, (x, aliases = [xx]), (y, aliases = [yy]), z] - backend = ReactiveMPGraphPPLBackend() + backend = ReactiveMPGraphPPLBackend(Static.False()) @test GraphPPL.NodeBehaviour(backend, CustomStochasticNode) === GraphPPL.Stochastic() @test GraphPPL.NodeType(backend, CustomStochasticNode) === GraphPPL.Atomic() @@ -135,7 +136,7 @@ end @node CustomStochasticNode Stochastic [out, (x, aliases = [xx]), (y, aliases = [yy]), z] - backend = ReactiveMPGraphPPLBackend(true) + backend = ReactiveMPGraphPPLBackend(Static.True()) @test GraphPPL.NodeBehaviour(backend, CustomStochasticNode) === GraphPPL.Stochastic() @test GraphPPL.NodeType(backend, CustomStochasticNode) === GraphPPL.Atomic() diff --git a/test/model/initialization_plugin_tests.jl b/test/model/initialization_plugin_tests.jl index 22ff47dd9..90e930be4 100644 --- a/test/model/initialization_plugin_tests.jl +++ b/test/model/initialization_plugin_tests.jl @@ -95,15 +95,15 @@ end @model function gcv(κ, ω, z, x, y) log_σ := κ * z + ω - y ~ Normal(x, exp(log_σ)) + y ~ NormalMeanVariance(x, exp(log_σ)) end @model function gcv_collection() - κ ~ Normal(0, 1) - ω ~ Normal(0, 1) - z ~ Normal(0, 1) + κ ~ NormalMeanVariance(0, 1) + ω ~ NormalMeanVariance(0, 1) + z ~ NormalMeanVariance(0, 1) for i in 1:10 - x[i] ~ Normal(0, 1) + x[i] ~ NormalMeanVariance(0, 1) y[i] ~ gcv(κ = κ, ω = ω, z = z, x = x[i]) end end