From 0bc503d1a0713e409120e09eb485f092bbb86879 Mon Sep 17 00:00:00 2001 From: Thijs van de Laar Date: Mon, 9 May 2022 15:03:13 +0200 Subject: [PATCH 1/6] error on incomplete posterior factorization coverage --- src/algorithms/inference_algorithm.jl | 21 +++++++++++++------ src/algorithms/posterior_factorization.jl | 15 ++++++++++++- src/factor_graph.jl | 17 ++++++++------- src/message_passing.jl | 2 +- .../test_posterior_factorization.jl | 10 +++++++++ 5 files changed, 49 insertions(+), 16 deletions(-) diff --git a/src/algorithms/inference_algorithm.jl b/src/algorithms/inference_algorithm.jl index bbe972f3..dd96a880 100644 --- a/src/algorithms/inference_algorithm.jl +++ b/src/algorithms/inference_algorithm.jl @@ -49,10 +49,10 @@ end """ Create a message passing algorithm to infer marginals over a posterior distribution """ -function messagePassingAlgorithm(target_variables::Vector{Variable}=Variable[], # Quantities of interest - pfz::PosteriorFactorization=currentPosteriorFactorization(); +function messagePassingAlgorithm(target_variables::Vector{Variable}, # Quantities of interest + pfz::PosteriorFactorization; ep_sites=Tuple[], - id=Symbol(""), + id=Symbol(""), free_energy=false) if isempty(pfz.posterior_factors) # If no factorization is defined @@ -92,11 +92,20 @@ function messagePassingAlgorithm(target_variables::Vector{Variable}=Variable[], return algo end -messagePassingAlgorithm(target_variable::Variable, - pfz::PosteriorFactorization=currentPosteriorFactorization(); +messagePassingAlgorithm(target_variables::Vector{Variable}; ep_sites=Tuple[], id=Symbol(""), - free_energy=false) = messagePassingAlgorithm([target_variable], pfz; ep_sites=ep_sites, id=id, free_energy=free_energy) + free_energy=false) = messagePassingAlgorithm(target_variables, currentPosteriorFactorization(); ep_sites=ep_sites, id=id, free_energy=free_energy) + +messagePassingAlgorithm(target_variable::Variable; + ep_sites=Tuple[], + id=Symbol(""), + free_energy=false) = messagePassingAlgorithm([target_variable], currentPosteriorFactorization(); ep_sites=ep_sites, id=id, free_energy=free_energy) + +messagePassingAlgorithm(pfz::PosteriorFactorization=currentPosteriorFactorization(); + ep_sites=Tuple[], + id=Symbol(""), + free_energy=false) = messagePassingAlgorithm(Variable[], pfz; ep_sites=ep_sites, id=id, free_energy=free_energy) function interfaceToScheduleEntry(algo::InferenceAlgorithm) mapping = Dict{Interface, ScheduleEntry}() diff --git a/src/algorithms/posterior_factorization.jl b/src/algorithms/posterior_factorization.jl index 16eb140f..7f44b442 100644 --- a/src/algorithms/posterior_factorization.jl +++ b/src/algorithms/posterior_factorization.jl @@ -85,9 +85,23 @@ function PosteriorFactorization(args::Vararg{Union{T, Set{T}, Vector{T}} where T PosteriorFactor(arg, id=ids[i]) end end + + # Verify that all stochastic edges are covered by a posterior factor + uncovered_variables = uncoveredVariables(pfz) + isempty(uncovered_variables) || error("Edges for stochastic variables $([var.id for var in uncovered_variables]) must be covered by a posterior factor") + return pfz end +function uncoveredVariables(pfz::PosteriorFactorization) + stochastic_edges = setdiff(Set(pfz.graph.edges), pfz.deterministic_edges) + covered_stochastic_edges = Set(Iterators.flatten([pf.internal_edges for (_, pf) in pfz.posterior_factors])) + uncovered_edges = setdiff(stochastic_edges, covered_stochastic_edges) + uncovered_variables = Set([edge.variable for edge in uncovered_edges]) + + return uncovered_variables +end + iterate(pfz::PosteriorFactorization) = iterate(pfz.posterior_factors) iterate(pfz::PosteriorFactorization, state) = iterate(pfz.posterior_factors, state) values(pfz::PosteriorFactorization) = values(pfz.posterior_factors) @@ -216,7 +230,6 @@ Return the local stochastic regions around `node` function localStochasticRegions(node::FactorNode, pfz::PosteriorFactorization) regions = Region[] for interface in node.interfaces - partner = ultimatePartner(interface) if !(interface.edge in pfz.deterministic_edges) # If edge is stochastic push!(regions, region(node, interface.edge)) end diff --git a/src/factor_graph.jl b/src/factor_graph.jl index 02ddff1c..1d333945 100644 --- a/src/factor_graph.jl +++ b/src/factor_graph.jl @@ -43,6 +43,7 @@ function FactorGraph() Dict{Type, Int}(), Dict{Clamp, Tuple{Symbol, Int}}())) end + """ Automatically generate a unique id based on the current counter value for the element type. """ @@ -108,8 +109,8 @@ function nodes(edgeset::Set{Edge}) # Return all nodes connected to edgeset connected_nodes = Set{FactorNode}() for edge in edgeset - (edge.a == nothing) || push!(connected_nodes, edge.a.node) - (edge.b == nothing) || push!(connected_nodes, edge.b.node) + (edge.a === nothing) || push!(connected_nodes, edge.a.node) + (edge.b === nothing) || push!(connected_nodes, edge.b.node) end return connected_nodes @@ -160,7 +161,7 @@ interface.partner. In case of a Terminal node, it finds the first non-Terminal partner on a higher level factor graph. """ function ultimatePartner(interface::Interface) - if (interface.partner != nothing) && isa(interface.partner.node, Terminal) + if (interface.partner !== nothing) && isa(interface.partner.node, Terminal) return ultimatePartner(interface.partner.node.outer_interface) else return interface.partner @@ -193,8 +194,8 @@ function deterministicEdgeSet(root_edge::Edge) # Collect all interfaces in edge set interfaces = Interface[] for edge in edge_set - (edge.a == nothing) || push!(interfaces, edge.a) - (edge.b == nothing) || push!(interfaces, edge.b) + (edge.a === nothing) || push!(interfaces, edge.a) + (edge.b === nothing) || push!(interfaces, edge.b) end # Identify a schedule that propagates through the entire edge set @@ -210,8 +211,8 @@ function deterministicEdgeSet(root_edge::Edge) # Determine deterministic edges deterministic_edges = Set{Edge}() for edge in edge_set - deterministic_a = (edge.a != nothing) && is_deterministic[edge.a] - deterministic_b = (edge.b != nothing) && is_deterministic[edge.b] + deterministic_a = (edge.a !== nothing) && is_deterministic[edge.a] + deterministic_b = (edge.b !== nothing) && is_deterministic[edge.b] # An edge is deterministic if any of its interfaces are deterministic if deterministic_a || deterministic_b @@ -242,7 +243,7 @@ function isDeterministic(interface::Interface, is_deterministic::Dict{Interface, for iface in node.interfaces if iface != interface partner = ultimatePartner(iface) - if partner == nothing # Dangling edge + if partner === nothing # Dangling edge push!(inbounds_deterministic, false) # Unconstrained inbound is considered stochastic (uninformative) else push!(inbounds_deterministic, is_deterministic[partner]) diff --git a/src/message_passing.jl b/src/message_passing.jl index 8d90f5a4..07805b63 100644 --- a/src/message_passing.jl +++ b/src/message_passing.jl @@ -274,7 +274,7 @@ function interfaceToScheduleEntry(schedule::Schedule) for entry in schedule interface = entry.interface mapping[interface] = entry - while (interface.partner != nothing) && isa(interface.partner.node, Terminal) + while (interface.partner !== nothing) && isa(interface.partner.node, Terminal) interface = interface.partner.node.outer_interface mapping[interface] = entry end diff --git a/test/algorithms/test_posterior_factorization.jl b/test/algorithms/test_posterior_factorization.jl index 96027dae..d53561c5 100644 --- a/test/algorithms/test_posterior_factorization.jl +++ b/test/algorithms/test_posterior_factorization.jl @@ -81,4 +81,14 @@ end @test deterministicEdges(fg) == Set{Edge}([e_1, e_2, e_3]) end +@testset "Graph with incomplete posterior factorization coverage" begin + fg = FactorGraph() + + @RV a ~ GaussianMeanVariance(0.0, 1.0) + @RV b ~ GaussianMeanVariance(a, 1.0) + @RV c ~ GaussianMeanVariance(b, 1.0) + + @test_throws Exception PosteriorFactorization(a, b) +end + end # module \ No newline at end of file From 22ce1adaec1735dc9a2ec572eee7cb78fd55ebc9 Mon Sep 17 00:00:00 2001 From: Thijs van de Laar Date: Mon, 9 May 2022 15:32:20 +0200 Subject: [PATCH 2/6] add product special cases --- src/factor_nodes/categorical.jl | 11 +++++++++++ src/factor_nodes/dirichlet.jl | 9 +++++++++ test/factor_nodes/test_categorical.jl | 1 + test/factor_nodes/test_dirichlet.jl | 1 + 4 files changed, 22 insertions(+) diff --git a/src/factor_nodes/categorical.jl b/src/factor_nodes/categorical.jl index 4701698e..9ce8ae64 100644 --- a/src/factor_nodes/categorical.jl +++ b/src/factor_nodes/categorical.jl @@ -110,6 +110,17 @@ function prod!( x::Distribution{Univariate, Categorical}, return z end +@symmetrical function prod!(x::Distribution{Univariate, Categorical}, + y::Distribution{Univariate, Bernoulli}, + z::Distribution{Univariate, Categorical}=Distribution(Univariate, Categorical, p=ones(size(x.params[:p]))./length(x.params[:p]))) + + z.params[:p][:] = clamp.(x.params[:p] .* [y.params[:p], 1.0-y.params[:p]], tiny, Inf) # Soften vanishing probabilities + norm = sum(z.params[:p]) + z.params[:p] = z.params[:p]./norm + + return z +end + @symmetrical function prod!(x::Distribution{Univariate, Categorical}, y::Distribution{Multivariate, PointMass}, z::Distribution{Multivariate, PointMass}=Distribution(Multivariate, PointMass, m=[0.0])) diff --git a/src/factor_nodes/dirichlet.jl b/src/factor_nodes/dirichlet.jl index 05d67f96..a58a694a 100644 --- a/src/factor_nodes/dirichlet.jl +++ b/src/factor_nodes/dirichlet.jl @@ -84,6 +84,15 @@ function prod!( x::Distribution{V, Dirichlet}, return z end +@symmetrical function prod!(x::Distribution{Multivariate, Dirichlet}, + y::Distribution{Univariate, Beta}, + z::Distribution{Multivariate, Dirichlet}=Distribution(Multivariate, Dirichlet, a=ones(2))) + + z.params[:a] = x.params[:a] + [y.params[:a], y.params[:b]] .- 1.0 + + return z +end + @symmetrical function prod!(x::Distribution{Multivariate, Dirichlet}, y::Distribution{Multivariate, PointMass}, z::Distribution{Multivariate, PointMass}=Distribution(Multivariate, PointMass, m=[NaN])) diff --git a/test/factor_nodes/test_categorical.jl b/test/factor_nodes/test_categorical.jl index 5e780ae7..2670b7e2 100644 --- a/test/factor_nodes/test_categorical.jl +++ b/test/factor_nodes/test_categorical.jl @@ -44,6 +44,7 @@ end @testset "prod!" begin @test Distribution(Categorical, p=[0.2, 0.8])*Distribution(Categorical, p=[0.8, 0.2]) == Distribution(Categorical, p=[0.5, 0.5]) + @test Distribution(Categorical, p=[0.2, 0.8])*Distribution(Bernoulli, p=0.8) == Distribution(Categorical, p=[0.5, 0.5]) @test Distribution(Categorical, p=[0.25, 0.5, 0.25]) * Distribution(Categorical, p=[1/3, 1/3, 1/3]) == Distribution(Categorical, p=[0.25, 0.5, 0.25]) @test Distribution(Categorical, p=[0.0, 0.5, 0.5]) * Distribution(Categorical, p=[1.0, 0.0, 0.0]) == Distribution(Categorical, p=ones(3)/3) end diff --git a/test/factor_nodes/test_dirichlet.jl b/test/factor_nodes/test_dirichlet.jl index 8d665c9f..6c53d6d2 100644 --- a/test/factor_nodes/test_dirichlet.jl +++ b/test/factor_nodes/test_dirichlet.jl @@ -39,6 +39,7 @@ end @testset "prod!" begin # Multivariate @test Distribution(Multivariate, Dirichlet, a=[2.0, 2.0]) * Distribution(Multivariate, Dirichlet, a=[2.0, 3.0]) == Distribution(Multivariate, Dirichlet, a=[3.0, 4.0]) + @test Distribution(Multivariate, Dirichlet, a=[2.0, 2.0]) * Distribution(Univariate, Beta, a=2.0, b=3.0) == Distribution(Multivariate, Dirichlet, a=[3.0, 4.0]) @test Distribution(Multivariate, Dirichlet, a=[1.0, 2.0, 3.0]) * Distribution(Multivariate, PointMass, m=[0.1, 0.8, 0.1]) == Distribution(Multivariate, PointMass, m=[0.1, 0.8, 0.1]) @test Distribution(Multivariate, PointMass, m=[0.1, 0.8, 0.1]) * Distribution(Multivariate, Dirichlet, a=[1.0, 2.0, 3.0]) == Distribution(Multivariate, PointMass, m=[0.1, 0.8, 0.1]) From eb50b0b373f0a2126c20f12de02758edde683db4 Mon Sep 17 00:00:00 2001 From: Thijs van de Laar Date: Mon, 9 May 2022 16:14:58 +0200 Subject: [PATCH 3/6] resolve key nothing error for missing inbound message --- src/algorithms/expectation_propagation.jl | 4 ++-- src/algorithms/naive_variational_bayes.jl | 4 ++-- src/algorithms/posterior_factor.jl | 2 +- src/algorithms/structured_variational_bayes.jl | 4 ++-- src/algorithms/sum_product.jl | 8 ++++---- src/interface.jl | 2 +- src/message_passing.jl | 2 +- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/algorithms/expectation_propagation.jl b/src/algorithms/expectation_propagation.jl index fd078faf..54751ae8 100644 --- a/src/algorithms/expectation_propagation.jl +++ b/src/algorithms/expectation_propagation.jl @@ -11,7 +11,7 @@ messagePassingSchedule(variable::Variable) = messagePassingSchedule([variable]) function inferUpdateRule!(entry::ScheduleEntry, rule_type::Type{T}, - inferred_outbound_types::Dict{Interface, <:Type} + inferred_outbound_types::Dict ) where T<:ExpectationPropagationRule # Collect inbound types inbound_types = collectInboundTypes(entry, rule_type, inferred_outbound_types) @@ -41,7 +41,7 @@ end function collectInboundTypes(entry::ScheduleEntry, ::Type{T}, - inferred_outbound_types::Dict{Interface, <:Type} + inferred_outbound_types::Dict ) where T<:ExpectationPropagationRule inbound_message_types = Type[] for node_interface in entry.interface.node.interfaces diff --git a/src/algorithms/naive_variational_bayes.jl b/src/algorithms/naive_variational_bayes.jl index 4a9c89a7..2b039250 100644 --- a/src/algorithms/naive_variational_bayes.jl +++ b/src/algorithms/naive_variational_bayes.jl @@ -12,7 +12,7 @@ Infer the update rule that computes the message for `entry`, as dependent on the """ function inferUpdateRule!( entry::ScheduleEntry, rule_type::Type{T}, - inferred_outbound_types::Dict{Interface, Type}) where T<:NaiveVariationalRule + inferred_outbound_types::Dict) where T<:NaiveVariationalRule # Collect inbound types inbound_types = collectInboundTypes(entry, rule_type, inferred_outbound_types) @@ -42,7 +42,7 @@ Returns a vector with inbound types that correspond with required interfaces. """ function collectInboundTypes( entry::ScheduleEntry, ::Type{T}, - inferred_outbound_types::Dict{Interface, Type}) where T<:NaiveVariationalRule + inferred_outbound_types::Dict) where T<:NaiveVariationalRule inbound_types = Type[] for node_interface in entry.interface.node.interfaces if node_interface === entry.interface diff --git a/src/algorithms/posterior_factor.jl b/src/algorithms/posterior_factor.jl index 659c018f..5d75428a 100644 --- a/src/algorithms/posterior_factor.jl +++ b/src/algorithms/posterior_factor.jl @@ -71,7 +71,7 @@ function messagePassingSchedule(pf::PosteriorFactor) end end - breaker_types = breakerTypes(collect(pf.breaker_interfaces)) + breaker_types = breakerTypes(collect(pf.breaker_interfaces)) # CONTINUE inferUpdateRules!(schedule, inferred_outbound_types=breaker_types) return schedule diff --git a/src/algorithms/structured_variational_bayes.jl b/src/algorithms/structured_variational_bayes.jl index 8da7d2f5..97dc8128 100644 --- a/src/algorithms/structured_variational_bayes.jl +++ b/src/algorithms/structured_variational_bayes.jl @@ -12,7 +12,7 @@ Infer the update rule that computes the message for `entry`, as dependent on the """ function inferUpdateRule!(entry::ScheduleEntry, rule_type::Type{T}, - inferred_outbound_types::Dict{Interface, Type} + inferred_outbound_types::Dict ) where T<:StructuredVariationalRule # Collect inbound types inbound_types = collectInboundTypes(entry, rule_type, inferred_outbound_types) @@ -43,7 +43,7 @@ Returns a vector with inbound types that correspond with required interfaces. """ function collectInboundTypes(entry::ScheduleEntry, ::Type{T}, - inferred_outbound_types::Dict{Interface, Type} + inferred_outbound_types::Dict ) where T<:StructuredVariationalRule inbound_types = Type[] entry_posterior_factor = posteriorFactor(entry.interface.edge) # Collect posterior factor for outbound edge diff --git a/src/algorithms/sum_product.jl b/src/algorithms/sum_product.jl index 530f7abd..259cf164 100644 --- a/src/algorithms/sum_product.jl +++ b/src/algorithms/sum_product.jl @@ -14,10 +14,10 @@ message out of the specified `outbound_interface`. """ function internalSumProductSchedule(cnode::CompositeFactor, outbound_interface::Interface, - inferred_outbound_types::Dict{Interface, <:Type}) + inferred_outbound_types::Dict) # Collect types of messages towards the CompositeFactor - msg_types = Dict{Interface, Type}() + msg_types = Dict{Union{Interface, Nothing}, Type}(nothing => Nothing) # Initialize with fallback for (idx, terminal) in enumerate(cnode.terminals) (cnode.interfaces[idx] === outbound_interface) && continue # don't need incoming msg on outbound interface msg_types[terminal.interfaces[1]] = inferred_outbound_types[cnode.interfaces[idx].partner] @@ -51,7 +51,7 @@ end function inferUpdateRule!(entry::ScheduleEntry, rule_type::Type{T}, - inferred_outbound_types::Dict{Interface, <:Type} + inferred_outbound_types::Dict ) where T<:SumProductRule # Collect inbound types inbound_types = collectInboundTypes(entry, rule_type, inferred_outbound_types) @@ -85,7 +85,7 @@ end function collectInboundTypes(entry::ScheduleEntry, ::Type{T}, - inferred_outbound_types::Dict{Interface, <:Type} + inferred_outbound_types::Dict ) where T<:SumProductRule inbound_message_types = Type[] for node_interface in entry.interface.node.interfaces diff --git a/src/interface.jl b/src/interface.jl index 17a54f3b..da873f49 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -65,7 +65,7 @@ end Constructs breaker types dictionary for breaker sites """ function breakerTypes(breaker_sites::Vector{Interface}) - breaker_types = Dict{Interface, Type}() # Initialize Interface to Message dictionary + breaker_types = Dict{Union{Interface, Nothing}, Type}(nothing => Nothing) # Initialize with fallback for site in breaker_sites (breaker_type, _) = breakerParameters(site) breaker_types[site] = breaker_type diff --git a/src/message_passing.jl b/src/message_passing.jl index 07805b63..635f8346 100644 --- a/src/message_passing.jl +++ b/src/message_passing.jl @@ -198,7 +198,7 @@ end """ inferUpdateRules!(schedule) infers specific message update rules for all schedule entries. """ -function inferUpdateRules!(schedule::Schedule; inferred_outbound_types=Dict{Interface, Type}()) +function inferUpdateRules!(schedule::Schedule; inferred_outbound_types=Dict{Union{Nothing, Interface}, Type}(nothing => Nothing)) for entry in schedule (entry.message_update_rule == Nothing) && error("No msg update rule type specified for $(entry)") if !isconcretetype(entry.message_update_rule) From 5589a8c1308bf43d22c9140e7b7d64e60383779f Mon Sep 17 00:00:00 2001 From: Thijs van de Laar Date: Tue, 10 May 2022 10:43:39 +0200 Subject: [PATCH 4/6] update error messages regarding applicable rules --- src/algorithms/expectation_propagation.jl | 8 +---- src/algorithms/inference_algorithm.jl | 29 +++++++++++++++++++ src/algorithms/joint_marginals.jl | 10 +------ src/algorithms/naive_variational_bayes.jl | 8 +---- .../structured_variational_bayes.jl | 14 +++------ src/algorithms/sum_product.jl | 19 ++++-------- src/distribution.jl | 4 +++ src/interface.jl | 2 +- src/message_passing.jl | 8 +++-- 9 files changed, 53 insertions(+), 49 deletions(-) diff --git a/src/algorithms/expectation_propagation.jl b/src/algorithms/expectation_propagation.jl index 54751ae8..a7c6b197 100644 --- a/src/algorithms/expectation_propagation.jl +++ b/src/algorithms/expectation_propagation.jl @@ -28,13 +28,7 @@ function inferUpdateRule!(entry::ScheduleEntry, end # Select and set applicable rule - if isempty(applicable_rules) - error("No applicable $(rule_type) update for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", "))") - elseif length(applicable_rules) > 1 - error("Multiple applicable $(rule_type) updates for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", ")): $(join(applicable_rules, ", "))") - else - entry.message_update_rule = first(applicable_rules) - end + entry.message_update_rule = selectApplicableRule(rule_type, entry, inbound_types, applicable_rules) return entry end diff --git a/src/algorithms/inference_algorithm.jl b/src/algorithms/inference_algorithm.jl index dd96a880..60aee9f8 100644 --- a/src/algorithms/inference_algorithm.jl +++ b/src/algorithms/inference_algorithm.jl @@ -125,4 +125,33 @@ function targetToMarginalEntry(algo::InferenceAlgorithm) end return mapping +end + +"""Select an applicable update rule for scheduling""" +function selectApplicableRule(rule_type::Type, + entry::ScheduleEntry, # Message update + inbound_types::Vector{<:Type}, + applicable_rules::Vector{<:Type}) + + if isempty(applicable_rules) + error("No applicable $(rule_type) for interface :$(handle(entry.interface)) of $(entry.interface.node.id) with inbound types:\n[$(join(format.(inbound_types), ", "))]") + elseif length(applicable_rules) > 1 + error("Multiple applicable $(rule_type) for interface :$(handle(entry.interface)) of $(entry.interface.node.id) with inbound types:\n[$(join(format.(inbound_types), ", "))]\n[$(join(applicable_rules, ", "))]") + else + return first(applicable_rules) + end +end + +function selectApplicableRule(rule_type::Type, + cluster::Cluster, # Marginal update + inbound_types::Vector{<:Type}, + applicable_rules::Vector{<:Type}) + + if isempty(applicable_rules) + error("No applicable $(rule_type) for $(cluster.node.id) with inbound types:\n[$(join(format.(inbound_types), ", "))]") + elseif length(applicable_rules) > 1 + error("Multiple applicable $(rule_type) for $(cluster.node.id) with inbound types:\n[$(join(format.(inbound_types), ", "))]\n[$(join(applicable_rules, ", "))]") + else + return first(applicable_rules) + end end \ No newline at end of file diff --git a/src/algorithms/joint_marginals.jl b/src/algorithms/joint_marginals.jl index bf47a58a..cd618639 100644 --- a/src/algorithms/joint_marginals.jl +++ b/src/algorithms/joint_marginals.jl @@ -42,15 +42,7 @@ function inferMarginalRule(cluster::Cluster, inbound_types::Vector{<:Type}) end # Select and set applicable rule - if isempty(applicable_rules) - error("No applicable marginal update rule for $(typeof(cluster.node)) node with inbound types: $(join(inbound_types, ", "))") - elseif length(applicable_rules) > 1 - error("Multiple applicable marginal update rules for $(typeof(cluster.node)) node with inbound types: $(join(inbound_types, ", ")): $(join(applicable_rules, ", "))") - else - marginal_update_rule = first(applicable_rules) - end - - return marginal_update_rule + return selectApplicableRule(rule_type, cluster, inbound_types, applicable_rules) end """ diff --git a/src/algorithms/naive_variational_bayes.jl b/src/algorithms/naive_variational_bayes.jl index 2b039250..5f1311c8 100644 --- a/src/algorithms/naive_variational_bayes.jl +++ b/src/algorithms/naive_variational_bayes.jl @@ -25,13 +25,7 @@ function inferUpdateRule!( entry::ScheduleEntry, end # Select and set applicable rule - if isempty(applicable_rules) - error("No applicable $(rule_type) update for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", "))") - elseif length(applicable_rules) > 1 - error("Multiple applicable $(rule_type) updates for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", ")): $(join(applicable_rules, ", "))") - else - entry.message_update_rule = first(applicable_rules) - end + entry.message_update_rule = selectApplicableRule(rule_type, entry, inbound_types, applicable_rules) return entry end diff --git a/src/algorithms/structured_variational_bayes.jl b/src/algorithms/structured_variational_bayes.jl index 97dc8128..52aa0a83 100644 --- a/src/algorithms/structured_variational_bayes.jl +++ b/src/algorithms/structured_variational_bayes.jl @@ -11,9 +11,9 @@ abstract type StructuredVariationalRule{factor_type} <: MessageUpdateRule end Infer the update rule that computes the message for `entry`, as dependent on the inbound types """ function inferUpdateRule!(entry::ScheduleEntry, - rule_type::Type{T}, - inferred_outbound_types::Dict - ) where T<:StructuredVariationalRule + rule_type::Type{<:StructuredVariationalRule}, + inferred_outbound_types::Dict) + # Collect inbound types inbound_types = collectInboundTypes(entry, rule_type, inferred_outbound_types) @@ -26,13 +26,7 @@ function inferUpdateRule!(entry::ScheduleEntry, end # Select and set applicable rule - if isempty(applicable_rules) - error("No applicable $(rule_type) update for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", "))") - elseif length(applicable_rules) > 1 - error("Multiple applicable $(rule_type) updates for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", ")): $(join(applicable_rules, ", "))") - else - entry.message_update_rule = first(applicable_rules) - end + entry.message_update_rule = selectApplicableRule(rule_type, entry, inbound_types, applicable_rules) return entry end diff --git a/src/algorithms/sum_product.jl b/src/algorithms/sum_product.jl index 259cf164..3a9b4967 100644 --- a/src/algorithms/sum_product.jl +++ b/src/algorithms/sum_product.jl @@ -17,7 +17,7 @@ function internalSumProductSchedule(cnode::CompositeFactor, inferred_outbound_types::Dict) # Collect types of messages towards the CompositeFactor - msg_types = Dict{Union{Interface, Nothing}, Type}(nothing => Nothing) # Initialize with fallback + msg_types = Dict{Union{Interface, Nothing}, Type}(nothing => Missing) # Initialize with fallback for (idx, terminal) in enumerate(cnode.terminals) (cnode.interfaces[idx] === outbound_interface) && continue # don't need incoming msg on outbound interface msg_types[terminal.interfaces[1]] = inferred_outbound_types[cnode.interfaces[idx].partner] @@ -65,19 +65,12 @@ function inferUpdateRule!(entry::ScheduleEntry, end # Select and set applicable rule - if isempty(applicable_rules) - if isa(entry.interface.node, CompositeFactor) - # No 'shortcut rule' available for CompositeFactor. - # Try to fall back to msg passing on the internal graph. - entry.internal_schedule = internalSumProductSchedule(entry.interface.node, entry.interface, inferred_outbound_types) - entry.message_update_rule = entry.internal_schedule[end].message_update_rule - else - error("No applicable $(rule_type) update for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", "))") - end - elseif length(applicable_rules) > 1 - error("Multiple applicable $(rule_type) updates for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", ")): $(join(applicable_rules, ", "))") + if isempty(applicable_rules) && isa(entry.interface.node, CompositeFactor) + # No shortcut rule available for CompositeFactor, fall back to message passing on the internal graph. + entry.internal_schedule = internalSumProductSchedule(entry.interface.node, entry.interface, inferred_outbound_types) + entry.message_update_rule = entry.internal_schedule[end].message_update_rule else - entry.message_update_rule = first(applicable_rules) + entry.message_update_rule = selectApplicableRule(rule_type, entry, inbound_types, applicable_rules) end return entry diff --git a/src/distribution.jl b/src/distribution.jl index 05110c98..c957644a 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -37,6 +37,10 @@ end const P = Distribution const ProbabilityDistribution = Distribution # For backwards compatibility +"""Distribution formatting for concise printing""" +format(::Type{<:Distribution{V}}) where V<:VariateType = "Distribution{$V}" +format(::Type{<:Distribution{V, F}}) where {V<:VariateType, F<:FactorFunction} = "Distribution{$V, $F}" + """Sample multiple realizations from a probability distribution""" sample(dist::Distribution, n_samples::Int64) = [sample(dist) for i in 1:n_samples] # TODO: individual samples can be optimized diff --git a/src/interface.jl b/src/interface.jl index da873f49..21e9174e 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -65,7 +65,7 @@ end Constructs breaker types dictionary for breaker sites """ function breakerTypes(breaker_sites::Vector{Interface}) - breaker_types = Dict{Union{Interface, Nothing}, Type}(nothing => Nothing) # Initialize with fallback + breaker_types = Dict{Union{Interface, Nothing}, Type}(nothing => Missing) # Initialize with fallback for site in breaker_sites (breaker_type, _) = breakerParameters(site) breaker_types[site] = breaker_type diff --git a/src/message_passing.jl b/src/message_passing.jl index 635f8346..179264ed 100644 --- a/src/message_passing.jl +++ b/src/message_passing.jl @@ -27,7 +27,11 @@ end """Shorthand notation for Message definition""" const M = Message -family(msg_type::Type{Message{F}}) where F<:FactorFunction = F +family(::Type{Message{F}}) where F<:FactorFunction = F + +"""Message formatting for concise printing""" +format(::Type{<:Message{F}}) where F<:FactorFunction = "Message{$F}" +format(::Type{<:Message{F, V}}) where {F<:FactorFunction, V<:VariateType} = "Message{$F, $V}" function show(io::IO, msg::Message) if isdefined(msg, :scaling_factor) @@ -198,7 +202,7 @@ end """ inferUpdateRules!(schedule) infers specific message update rules for all schedule entries. """ -function inferUpdateRules!(schedule::Schedule; inferred_outbound_types=Dict{Union{Nothing, Interface}, Type}(nothing => Nothing)) +function inferUpdateRules!(schedule::Schedule; inferred_outbound_types=Dict{Union{Nothing, Interface}, Type}(nothing => Missing)) for entry in schedule (entry.message_update_rule == Nothing) && error("No msg update rule type specified for $(entry)") if !isconcretetype(entry.message_update_rule) From c7916e9f2a64b3a64bd6dfcb670a5740059555cb Mon Sep 17 00:00:00 2001 From: Thijs van de Laar Date: Tue, 10 May 2022 11:37:29 +0200 Subject: [PATCH 5/6] implement structured updates for Gaussian{Moments} --- src/algorithms/joint_marginals.jl | 3 +- .../julia/update_rules/gaussian_moments.jl | 52 ++++++++++++++++++- .../julia/update_rules/gaussian_precision.jl | 12 ++--- src/update_rules/gaussian_moments.jl | 20 +++++++ test/factor_nodes/test_gaussian_moments.jl | 38 +++++++++++++- 5 files changed, 113 insertions(+), 12 deletions(-) diff --git a/src/algorithms/joint_marginals.jl b/src/algorithms/joint_marginals.jl index cd618639..8b18f797 100644 --- a/src/algorithms/joint_marginals.jl +++ b/src/algorithms/joint_marginals.jl @@ -35,7 +35,8 @@ Infer the rule that computes the joint marginal over `cluster` function inferMarginalRule(cluster::Cluster, inbound_types::Vector{<:Type}) # Find applicable rule(s) applicable_rules = Type[] - for rule in leaftypes(MarginalRule{typeof(cluster.node)}) + rule_type = MarginalRule{typeof(cluster.node)} + for rule in leaftypes(rule_type) if isApplicable(rule, inbound_types) push!(applicable_rules, rule) end diff --git a/src/engines/julia/update_rules/gaussian_moments.jl b/src/engines/julia/update_rules/gaussian_moments.jl index 644d2c1c..9d9ede51 100644 --- a/src/engines/julia/update_rules/gaussian_moments.jl +++ b/src/engines/julia/update_rules/gaussian_moments.jl @@ -10,7 +10,11 @@ ruleSPGaussianMomentsMSNP, ruleSPGaussianMomentsOutNGS, ruleSPGaussianMomentsMGNS, ruleVBGaussianMomentsM, -ruleVBGaussianMomentsOut +ruleVBGaussianMomentsOut, +ruleSVBGaussianMomentsOutVGD, +ruleSVBGaussianMomentsMGVD, +ruleMGaussianMomentsGGD, +ruleMGaussianMomentsGGN ruleSPGaussianMomentsOutNPP(msg_out::Nothing, msg_mean::Message{PointMass, V}, @@ -92,4 +96,48 @@ ruleVBGaussianMomentsM(dist_out::Distribution{V}, ruleVBGaussianMomentsOut(dist_out::Any, dist_mean::Distribution{V}, dist_var::Distribution) where V<:VariateType = - Message(V, Gaussian{Moments}, m=unsafeMean(dist_mean), v=unsafeMean(dist_var)) \ No newline at end of file + Message(V, Gaussian{Moments}, m=unsafeMean(dist_mean), v=unsafeMean(dist_var)) + +ruleSVBGaussianMomentsOutVGD(dist_out::Any, # Only implemented for PointMass variance + msg_mean::Message{<:Gaussian, V}, + dist_var::Distribution{<:VariateType, PointMass}) where V<:VariateType = + Message(V, Gaussian{Moments}, m=unsafeMean(msg_mean.dist), v=unsafeCov(msg_mean.dist) + unsafeMean(dist_var)) + +ruleSVBGaussianMomentsMGVD(msg_out::Message{F, V}, # Only implemented for PointMass variance + dist_mean::Any, + dist_var::Distribution{<:VariateType, PointMass}) where {F<:Gaussian, V<:VariateType} = + Message(V, Gaussian{Moments}, m=unsafeMean(msg_out.dist), v=unsafeCov(msg_out.dist) + unsafeMean(dist_var)) + +function ruleMGaussianMomentsGGD( # Only implemented for PointMass variance + msg_out::Message{<:Gaussian, V}, + msg_mean::Message{<:Gaussian, V}, + dist_var::Distribution{<:VariateType, PointMass}) where V<:VariateType + + d_mean = convert(Distribution{V, Gaussian{Canonical}}, msg_mean.dist) + d_out = convert(Distribution{V, Gaussian{Canonical}}, msg_out.dist) + + xi_y = d_out.params[:xi] + W_y = d_out.params[:w] + xi_m = d_mean.params[:xi] + W_m = d_mean.params[:w] + W_bar = cholinv(unsafeMean(dist_var)) + + return Distribution(Multivariate, Gaussian{Canonical}, xi=[xi_y; xi_m], w=[W_y+W_bar -W_bar; -W_bar W_m+W_bar]) +end + +function ruleMGaussianMomentsGGN( + msg_out::Message{<:Gaussian, V}, + msg_mean::Message{<:Gaussian, V}, + msg_var::Message{PointMass}) where V<:VariateType + + d_mean = convert(Distribution{V, Gaussian{Canonical}}, msg_mean.dist) + d_out = convert(Distribution{V, Gaussian{Canonical}}, msg_out.dist) + + xi_y = d_out.params[:xi] + W_y = d_out.params[:w] + xi_m = d_mean.params[:xi] + W_m = d_mean.params[:w] + W_bar = cholinv(msg_var.dist.params[:m]) + + return Distribution(Multivariate, Gaussian{Canonical}, xi=[xi_y; xi_m], w=[W_y+W_bar -W_bar; -W_bar W_m+W_bar]) +end \ No newline at end of file diff --git a/src/engines/julia/update_rules/gaussian_precision.jl b/src/engines/julia/update_rules/gaussian_precision.jl index 7ff58bcd..d4e6a4e5 100644 --- a/src/engines/julia/update_rules/gaussian_precision.jl +++ b/src/engines/julia/update_rules/gaussian_precision.jl @@ -82,14 +82,10 @@ function ruleSVBGaussianPrecisionW( end end -function ruleSVBGaussianPrecisionMGVD(msg_out::Message{F, V}, - dist_mean::Any, - dist_prec::Distribution) where {F<:Gaussian, V<:VariateType} - - d_out = convert(Distribution{V, Gaussian{Moments}}, msg_out.dist) - - Message(V, Gaussian{Moments}, m=d_out.params[:m], v=d_out.params[:v] + cholinv(unsafeMean(dist_prec))) -end +ruleSVBGaussianPrecisionMGVD(msg_out::Message{F, V}, + dist_mean::Any, + dist_prec::Distribution) where {F<:Gaussian, V<:VariateType} = + Message(V, Gaussian{Moments}, m=unsafeMean(msg_out.dist), v=unsafeCov(msg_out.dist) + cholinv(unsafeMean(dist_prec))) function ruleMGaussianPrecisionGGD( msg_out::Message{<:Gaussian, V}, diff --git a/src/update_rules/gaussian_moments.jl b/src/update_rules/gaussian_moments.jl index e1615ead..a2f8795b 100644 --- a/src/update_rules/gaussian_moments.jl +++ b/src/update_rules/gaussian_moments.jl @@ -57,3 +57,23 @@ :outbound_type => Message{Gaussian{Moments}}, :inbound_types => (Nothing, Distribution, Distribution), :name => VBGaussianMomentsOut) + +@structuredVariationalRule(:node_type => Gaussian{Moments}, + :outbound_type => Message{Gaussian{Moments}}, + :inbound_types => (Nothing, Message{Gaussian}, Distribution), + :name => SVBGaussianMomentsOutVGD) + +@structuredVariationalRule(:node_type => Gaussian{Moments}, + :outbound_type => Message{Gaussian{Moments}}, + :inbound_types => (Message{Gaussian}, Nothing, Distribution), + :name => SVBGaussianMomentsMGVD) + +@marginalRule(:node_type => Gaussian{Moments}, + :inbound_types => (Message{Gaussian}, Message{Gaussian}, Distribution), + :name => MGaussianMomentsGGD) + +@marginalRule(:node_type => Gaussian{Moments}, + :inbound_types => (Message{Gaussian}, Message{Gaussian}, Nothing), # Variance is marginalized out + :name => MGaussianMomentsGGN) + + diff --git a/test/factor_nodes/test_gaussian_moments.jl b/test/factor_nodes/test_gaussian_moments.jl index b4040d0e..2e78c0e1 100644 --- a/test/factor_nodes/test_gaussian_moments.jl +++ b/test/factor_nodes/test_gaussian_moments.jl @@ -3,7 +3,7 @@ module GaussianMomentsTest using Test using ForneyLab using ForneyLab: outboundType, isApplicable, isProper, unsafeMean, unsafeMode, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeMeanPrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision -using ForneyLab: SPGaussianMomentsOutNGS, SPGaussianMomentsOutNPP,SPGaussianMomentsMSNP, SPGaussianMomentsMPNP, SPGaussianMomentsOutNGP, SPGaussianMomentsMGNP, SPGaussianMomentsVGGN, SPGaussianMomentsVPGN, SPGaussianMomentsOutNSP, VBGaussianMomentsM, VBGaussianMomentsOut, bootstrap +using ForneyLab: SPGaussianMomentsOutNGS, SPGaussianMomentsOutNPP, SPGaussianMomentsMSNP, SPGaussianMomentsMPNP, SPGaussianMomentsOutNGP, SPGaussianMomentsMGNP, SPGaussianMomentsVGGN, SPGaussianMomentsVPGN, SPGaussianMomentsOutNSP, VBGaussianMomentsM, VBGaussianMomentsOut, SVBGaussianMomentsMGVD, SVBGaussianMomentsOutVGD, MGaussianMomentsGGD, MGaussianMomentsGGN, bootstrap @testset "default Gaussian{Moments} node definition" begin fg = FactorGraph() @@ -205,6 +205,42 @@ end @test ruleVBGaussianMomentsOut(nothing, Distribution(Multivariate, Gaussian{Moments}, m=[1.0], v=mat(2.0)), Distribution(MatrixVariate, PointMass, m=mat(3.0))) == Message(Multivariate, Gaussian{Moments}, m=[1.0], v=mat(3.0)) end +@testset "SVBGaussianMomentsMGVD" begin + @test SVBGaussianMomentsMGVD <: StructuredVariationalRule{Gaussian{Moments}} + @test outboundType(SVBGaussianMomentsMGVD) == Message{Gaussian{Moments}} + @test isApplicable(SVBGaussianMomentsMGVD, [Message{Gaussian}, Nothing, Distribution]) + + @test ruleSVBGaussianMomentsMGVD(Message(Univariate, Gaussian{Moments}, m=3.0, v=4.0), nothing, Distribution(Univariate, PointMass, m=2.0)) == Message(Univariate, Gaussian{Moments}, m=3.0, v=6.0) + @test ruleSVBGaussianMomentsMGVD(Message(Multivariate, Gaussian{Moments}, m=[3.0], v=mat(4.0)), nothing, Distribution(MatrixVariate, PointMass, m=mat(2.0))) == Message(Multivariate, Gaussian{Moments}, m=[3.0], v=mat(6.0)) +end + +@testset "SVBGaussianMomentsOutVGD" begin + @test SVBGaussianMomentsOutVGD <: StructuredVariationalRule{Gaussian{Moments}} + @test outboundType(SVBGaussianMomentsOutVGD) == Message{Gaussian{Moments}} + @test isApplicable(SVBGaussianMomentsOutVGD, [Nothing, Message{Gaussian}, Distribution]) + + @test ruleSVBGaussianMomentsOutVGD(nothing, Message(Univariate, Gaussian{Moments}, m=3.0, v=4.0), Distribution(Univariate, PointMass, m=2.0)) == Message(Univariate, Gaussian{Moments}, m=3.0, v=6.0) + @test ruleSVBGaussianMomentsOutVGD(nothing, Message(Multivariate, Gaussian{Moments}, m=[3.0], v=mat(4.0)), Distribution(MatrixVariate, PointMass, m=mat(2.0))) == Message(Multivariate, Gaussian{Moments}, m=[3.0], v=mat(6.0)) +end + +@testset "MGaussianMomentsGGD" begin + @test MGaussianMomentsGGD <: MarginalRule{Gaussian{Moments}} + @test isApplicable(MGaussianMomentsGGD, [Message{Gaussian}, Message{Gaussian}, Distribution]) + @test !isApplicable(MGaussianMomentsGGD, [Message{Gaussian}, Message{Gaussian}, Nothing]) + + @test ruleMGaussianMomentsGGD(Message(Univariate, Gaussian{Precision}, m=1.0, w=2.0), Message(Univariate, Gaussian{Precision}, m=3.0, w=4.0), Distribution(Univariate, PointMass, m=2.0)) == Distribution(Multivariate, Gaussian{Moments}, m=[1.3636363636363638, 2.8181818181818175], v=[0.4090909090909091 0.04545454545454545; 0.04545454545454545 0.22727272727272724]) + @test ruleMGaussianMomentsGGD(Message(Multivariate, Gaussian{Precision}, m=[1.0], w=mat(2.0)), Message(Multivariate, Gaussian{Precision}, m=[3.0], w=mat(4.0)), Distribution(MatrixVariate, PointMass, m=mat(2.0))) == Distribution(Multivariate, Gaussian{Moments}, m=[1.3636363636363638, 2.8181818181818175], v=[0.4090909090909091 0.04545454545454545; 0.04545454545454545 0.22727272727272724]) +end + +@testset "MGaussianMomentsGGN" begin + @test MGaussianMomentsGGN <: MarginalRule{Gaussian{Moments}} + @test isApplicable(MGaussianMomentsGGN, [Message{Gaussian}, Message{Gaussian}, Nothing]) + @test !isApplicable(MGaussianMomentsGGN, [Message{Gaussian}, Message{Gaussian}, Distribution]) + + @test ruleMGaussianMomentsGGN(Message(Univariate, Gaussian{Precision}, m=1.0, w=2.0), Message(Univariate, Gaussian{Precision}, m=3.0, w=4.0), Message(Univariate, PointMass, m=2.0)) == Distribution(Multivariate, Gaussian{Moments}, m=[1.3636363636363638, 2.8181818181818175], v=[0.4090909090909091 0.04545454545454545; 0.04545454545454545 0.22727272727272724]) + @test ruleMGaussianMomentsGGN(Message(Multivariate, Gaussian{Precision}, m=[1.0], w=mat(2.0)), Message(Multivariate, Gaussian{Precision}, m=[3.0], w=mat(4.0)), Message(MatrixVariate, PointMass, m=mat(2.0))) == Distribution(Multivariate, Gaussian{Moments}, m=[1.3636363636363638, 2.8181818181818175], v=[0.4090909090909091 0.04545454545454545; 0.04545454545454545 0.22727272727272724]) +end + @testset "averageEnergy and differentialEntropy" begin @test differentialEntropy(Distribution(Univariate, Gaussian{Moments}, m=0.0, v=2.0)) == averageEnergy(Gaussian{Moments}, Distribution(Univariate, Gaussian{Moments}, m=0.0, v=2.0), Distribution(Univariate, PointMass, m=0.0), Distribution(Univariate, PointMass, m=2.0)) @test differentialEntropy(Distribution(Univariate, Gaussian{Moments}, m=0.0, v=2.0)) == differentialEntropy(Distribution(Multivariate, Gaussian{Moments}, m=[0.0], v=mat(2.0))) From c1a905a4966983e9fdfb43f8b36106ebe9f97d50 Mon Sep 17 00:00:00 2001 From: Thijs van de Laar Date: Tue, 10 May 2022 11:51:31 +0200 Subject: [PATCH 6/6] finalize --- src/algorithms/posterior_factor.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/posterior_factor.jl b/src/algorithms/posterior_factor.jl index 5d75428a..659c018f 100644 --- a/src/algorithms/posterior_factor.jl +++ b/src/algorithms/posterior_factor.jl @@ -71,7 +71,7 @@ function messagePassingSchedule(pf::PosteriorFactor) end end - breaker_types = breakerTypes(collect(pf.breaker_interfaces)) # CONTINUE + breaker_types = breakerTypes(collect(pf.breaker_interfaces)) inferUpdateRules!(schedule, inferred_outbound_types=breaker_types) return schedule