Skip to content

Commit

Permalink
Merge pull request #313 from ReactiveBayes/initialization_macro_bug_fix
Browse files Browse the repository at this point in the history
Initialization macro bug fix
  • Loading branch information
wouterwln authored Jun 4, 2024
2 parents 736e74b + c9ea043 commit 9dd4d41
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/inference/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ function infer(;
error("""`data` and `datastream` keyword arguments cannot be used together. """)
elseif isnothing(data) && isnothing(predictvars) && isnothing(datastream)
error("""One of the keyword arguments `data` or `predictvars` or `datastream` must be specified""")
elseif !isnothing(initmessages) && !isnothing(initmarginals)
elseif !isnothing(initmessages) || !isnothing(initmarginals)
error(
"""`initmessages` and `initmarginals` keyword arguments have been deprecated and removed. Use the `@initialization` macro and the `initialization` keyword instead."""
)
Expand Down
8 changes: 7 additions & 1 deletion src/model/plugins/initialization_plugin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,13 @@ function Base.show(io::IO, init::SubModelInit)
print(IOContext(io, (:indent => get(io, :indent, 0) + 2), (:head => false)), "Init for submodel ", getsubmodel(init), " = ", getinitobjects(init))
end

Base.push!(m::InitSpecification, o::InitObject) = push!(m.init_objects, o)
# Add a variable-level initialization to the specification, enforcing
# last-write-wins semantics: if the same variable was already initialized,
# warn and drop the earlier entry before pushing the new one.
#
# NOTE(review): the scraped source had the comparison operators stripped
# (Unicode `∈` / `≠` lost in extraction); they are restored here.
# NOTE(review): `getvardescriptor(getvardescriptor(o))` (applied twice) looks
# suspicious — confirm this is the intended accessor for the variable name
# shown in the warning, and not e.g. `getvariable(getvardescriptor(o))`.
function Base.push!(m::InitSpecification, o::InitObject)
    if getvardescriptor(o) ∈ getvardescriptor.(getinitobjects(m))
        @warn "Variable $(getvardescriptor(getvardescriptor(o))) is initialized multiple times. The last initialization will be used."
        # Keep only entries whose variable descriptor differs from the incoming one.
        filter!(x -> getvardescriptor(getvardescriptor(x)) ≠ getvardescriptor(getvardescriptor(o)), m.init_objects)
    end
    return push!(m.init_objects, o)
end
# Submodel-scoped initializations are collected separately from variable-level
# ones; no duplicate check is performed here (contrast with the InitObject method).
Base.push!(m::InitSpecification, o::SubModelInit) = push!(m.submodel_init, o)

# Fallback initialization for any model: no messages/marginals are initialized.
default_init(any) = EmptyInit
Expand Down
19 changes: 19 additions & 0 deletions test/inference/inference_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -763,3 +763,22 @@ end
model = beta_bernoulli(), data = (y = [1],), autoupdates = autoupdates, initialization = @initialization(q(t) = Beta(1, 1))
)
end

@testitem "`infer` should throw an error if `initmessages` or `initmarginals` keywords are used" begin
    # Minimal conjugate model, used only to drive the `infer` call.
    @model function beta_bernoulli(a, b, y)
        t ~ Beta(a, b)
        y ~ Bernoulli(t)
    end

    # Expected deprecation error message (shared by all three cases below).
    deprecation_msg = "`initmessages` and `initmarginals` keyword arguments have been deprecated and removed. Use the `@initialization` macro and the `initialization` keyword instead."

    # Both deprecated keywords at once. Note the trailing comma in `(t = ...,)`:
    # without it `(t = ...)` is a parenthesized assignment, not a NamedTuple.
    @test_throws deprecation_msg infer(
        model = beta_bernoulli(), data = (y = 1,), initmessages = (t = Normal(0.0, 1.0),), initmarginals = (t = Normal(0.0, 1.0),)
    )

    # `initmarginals` alone must also trigger the deprecation error.
    @test_throws deprecation_msg infer(
        model = beta_bernoulli(), data = (y = 1,), initmarginals = (t = Normal(0.0, 1.0),)
    )

    # `initmessages` alone must also trigger the deprecation error.
    @test_throws deprecation_msg infer(
        model = beta_bernoulli(), data = (y = 1,), initmessages = (t = Normal(0.0, 1.0),)
    )
end
14 changes: 14 additions & 0 deletions test/model/initialization_plugin_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -693,3 +693,17 @@ end
model = GraphPPL.create_model(GraphPPL.with_plugins(model_with_init(), GraphPPL.PluginsCollection(RxInfer.InitializationPlugin())))
@test GraphPPL.getextra(model[model[][:x]], RxInfer.InitMarExtraKey) == NormalMeanVariance(0, 1e12)
end

@testitem "throw warning if double init" begin
    using RxInfer

    # Initializing the same marginal q(u) twice must emit exactly this warning;
    # the implementation keeps only the last specification.
    @test_logs (:warn, "Variable u is initialized multiple times. The last initialization will be used.") @initialization begin
        q(u) = NormalMeanVariance(0, 1)
        q(u) = NormalMeanVariance(0, 1)
    end

    # A marginal init q(u) and a message init μ(u) target different objects,
    # so no warning should be produced.
    @test_nowarn @initialization begin
        q(u) = NormalMeanVariance(0, 1)
        μ(u) = NormalMeanVariance(0, 1)
    end
end

0 comments on commit 9dd4d41

Please sign in to comment.