diff --git a/src/inference/inference.jl b/src/inference/inference.jl index d7521fef8..9901a30dc 100644 --- a/src/inference/inference.jl +++ b/src/inference/inference.jl @@ -291,7 +291,7 @@ function infer(; error("""`data` and `datastream` keyword arguments cannot be used together. """) elseif isnothing(data) && isnothing(predictvars) && isnothing(datastream) error("""One of the keyword arguments `data` or `predictvars` or `datastream` must be specified""") - elseif !isnothing(initmessages) && !isnothing(initmarginals) + elseif !isnothing(initmessages) || !isnothing(initmarginals) error( """`initmessages` and `initmarginals` keyword arguments have been deprecated and removed. Use the `@initialization` macro and the `initialization` keyword instead.""" ) diff --git a/src/model/plugins/initialization_plugin.jl b/src/model/plugins/initialization_plugin.jl index 333261c07..31bfa9272 100644 --- a/src/model/plugins/initialization_plugin.jl +++ b/src/model/plugins/initialization_plugin.jl @@ -104,7 +104,13 @@ function Base.show(io::IO, init::SubModelInit) print(IOContext(io, (:indent => get(io, :indent, 0) + 2), (:head => false)), "Init for submodel ", getsubmodel(init), " = ", getinitobjects(init)) end -Base.push!(m::InitSpecification, o::InitObject) = push!(m.init_objects, o) +function Base.push!(m::InitSpecification, o::InitObject) + if getvardescriptor(o) ∈ getvardescriptor.(getinitobjects(m)) + @warn "Variable $(getvardescriptor(getvardescriptor(o))) is initialized multiple times. The last initialization will be used." + filter!(x -> getvardescriptor(getvardescriptor(x)) ≠ getvardescriptor(getvardescriptor(o)), m.init_objects) + end + push!(m.init_objects, o) +end Base.push!(m::InitSpecification, o::SubModelInit) = push!(m.submodel_init, o) default_init(any) = EmptyInit diff --git a/test/inference/inference_tests.jl b/test/inference/inference_tests.jl index c2d741acf..4c417bbae 100644 --- a/test/inference/inference_tests.jl +++ b/test/inference/inference_tests.jl @@ -763,3 +763,22 @@ end model = beta_bernoulli(), data = (y = [1],), autoupdates = autoupdates, initialization = @initialization(q(t) = Beta(1, 1)) ) end + +@testitem "`infer` should throw an error if `initmessages` or `initmarginals` keywords are used" begin + @model function beta_bernoulli(a, b, y) + t ~ Beta(a, b) + y ~ Bernoulli(t) + end + + @test_throws "`initmessages` and `initmarginals` keyword arguments have been deprecated and removed. Use the `@initialization` macro and the `initialization` keyword instead." infer( + model = beta_bernoulli(), data = (y = 1,), initmessages = (t = Normal(0.0, 1.0)), initmarginals = (t = Normal(0.0, 1.0)) + ) + + @test_throws "`initmessages` and `initmarginals` keyword arguments have been deprecated and removed. Use the `@initialization` macro and the `initialization` keyword instead." infer( + model = beta_bernoulli(), data = (y = 1,), initmarginals = (t = Normal(0.0, 1.0)) + ) + + @test_throws "`initmessages` and `initmarginals` keyword arguments have been deprecated and removed. Use the `@initialization` macro and the `initialization` keyword instead." infer( + model = beta_bernoulli(), data = (y = 1,), initmessages = (t = Normal(0.0, 1.0)) + ) +end diff --git a/test/model/initialization_plugin_tests.jl b/test/model/initialization_plugin_tests.jl index c752fdb29..22ff47dd9 100644 --- a/test/model/initialization_plugin_tests.jl +++ b/test/model/initialization_plugin_tests.jl @@ -693,3 +693,17 @@ end model = GraphPPL.create_model(GraphPPL.with_plugins(model_with_init(), GraphPPL.PluginsCollection(RxInfer.InitializationPlugin()))) @test GraphPPL.getextra(model[model[][:x]], RxInfer.InitMarExtraKey) == NormalMeanVariance(0, 1e12) end + +@testitem "throw warning if double init" begin + using RxInfer + + @test_logs (:warn, "Variable u is initialized multiple times. The last initialization will be used.") @initialization begin + q(u) = NormalMeanVariance(0, 1) + q(u) = NormalMeanVariance(0, 1) + end + + @test_nowarn @initialization begin + q(u) = NormalMeanVariance(0, 1) + μ(u) = NormalMeanVariance(0, 1) + end +end