Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error messages #204

Merged
merged 6 commits into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions src/algorithms/expectation_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ messagePassingSchedule(variable::Variable) = messagePassingSchedule([variable])

function inferUpdateRule!(entry::ScheduleEntry,
rule_type::Type{T},
inferred_outbound_types::Dict{Interface, <:Type}
inferred_outbound_types::Dict
) where T<:ExpectationPropagationRule
# Collect inbound types
inbound_types = collectInboundTypes(entry, rule_type, inferred_outbound_types)
Expand All @@ -28,20 +28,14 @@ function inferUpdateRule!(entry::ScheduleEntry,
end

# Select and set applicable rule
if isempty(applicable_rules)
error("No applicable $(rule_type) update for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", "))")
elseif length(applicable_rules) > 1
error("Multiple applicable $(rule_type) updates for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", ")): $(join(applicable_rules, ", "))")
else
entry.message_update_rule = first(applicable_rules)
end
entry.message_update_rule = selectApplicableRule(rule_type, entry, inbound_types, applicable_rules)

return entry
end

function collectInboundTypes(entry::ScheduleEntry,
::Type{T},
inferred_outbound_types::Dict{Interface, <:Type}
inferred_outbound_types::Dict
) where T<:ExpectationPropagationRule
inbound_message_types = Type[]
for node_interface in entry.interface.node.interfaces
Expand Down
50 changes: 44 additions & 6 deletions src/algorithms/inference_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ end
"""
Create a message passing algorithm to infer marginals over a posterior distribution
"""
function messagePassingAlgorithm(target_variables::Vector{Variable}=Variable[], # Quantities of interest
pfz::PosteriorFactorization=currentPosteriorFactorization();
function messagePassingAlgorithm(target_variables::Vector{Variable}, # Quantities of interest
pfz::PosteriorFactorization;
ep_sites=Tuple[],
id=Symbol(""),
id=Symbol(""),
free_energy=false)

if isempty(pfz.posterior_factors) # If no factorization is defined
Expand Down Expand Up @@ -92,11 +92,20 @@ function messagePassingAlgorithm(target_variables::Vector{Variable}=Variable[],
return algo
end

messagePassingAlgorithm(target_variable::Variable,
pfz::PosteriorFactorization=currentPosteriorFactorization();
messagePassingAlgorithm(target_variables::Vector{Variable};
ep_sites=Tuple[],
id=Symbol(""),
free_energy=false) = messagePassingAlgorithm([target_variable], pfz; ep_sites=ep_sites, id=id, free_energy=free_energy)
free_energy=false) = messagePassingAlgorithm(target_variables, currentPosteriorFactorization(); ep_sites=ep_sites, id=id, free_energy=free_energy)

# Convenience method for a single target variable: wraps the variable in a vector and
# forwards it, together with the current posterior factorization, to the main method.
function messagePassingAlgorithm(target_variable::Variable;
                                 ep_sites=Tuple[],
                                 id=Symbol(""),
                                 free_energy=false)

    return messagePassingAlgorithm([target_variable],
                                   currentPosteriorFactorization();
                                   ep_sites=ep_sites,
                                   id=id,
                                   free_energy=free_energy)
end

# Convenience method without explicit targets: forwards an empty target vector together
# with `pfz` (which defaults to the current posterior factorization) to the main method.
function messagePassingAlgorithm(pfz::PosteriorFactorization=currentPosteriorFactorization();
                                 ep_sites=Tuple[],
                                 id=Symbol(""),
                                 free_energy=false)

    no_targets = Variable[]
    return messagePassingAlgorithm(no_targets, pfz; ep_sites=ep_sites, id=id, free_energy=free_energy)
end

function interfaceToScheduleEntry(algo::InferenceAlgorithm)
mapping = Dict{Interface, ScheduleEntry}()
Expand All @@ -116,4 +125,33 @@ function targetToMarginalEntry(algo::InferenceAlgorithm)
end

return mapping
end

"""Select an applicable update rule for scheduling"""
function selectApplicableRule(rule_type::Type,
                              entry::ScheduleEntry, # Message update
                              inbound_types::Vector{<:Type},
                              applicable_rules::Vector{<:Type})

    # Exactly one candidate: this is the rule to use
    (length(applicable_rules) == 1) && return first(applicable_rules)

    # No candidate at all: report the interface handle, node id and inbound types.
    # The message is only built on this error path.
    if isempty(applicable_rules)
        error("No applicable $(rule_type) for interface :$(handle(entry.interface)) of $(entry.interface.node.id) with inbound types:\n[$(join(format.(inbound_types), ", "))]")
    end

    # Several candidates: same report, followed by the list of competing rules
    error("Multiple applicable $(rule_type) for interface :$(handle(entry.interface)) of $(entry.interface.node.id) with inbound types:\n[$(join(format.(inbound_types), ", "))]\n[$(join(applicable_rules, ", "))]")
end

"""
Select the single applicable rule for a marginal update on `cluster`.
Raises an error that names the cluster's node id and the inbound types
when either no rule or more than one rule is applicable.
"""
function selectApplicableRule(rule_type::Type,
                              cluster::Cluster, # Marginal update
                              inbound_types::Vector{<:Type},
                              applicable_rules::Vector{<:Type})

    # Exactly one candidate: this is the rule to use
    (length(applicable_rules) == 1) && return first(applicable_rules)

    # No candidate at all; the message is only built on this error path
    if isempty(applicable_rules)
        error("No applicable $(rule_type) for $(cluster.node.id) with inbound types:\n[$(join(format.(inbound_types), ", "))]")
    end

    # Several candidates: also list the competing rules
    error("Multiple applicable $(rule_type) for $(cluster.node.id) with inbound types:\n[$(join(format.(inbound_types), ", "))]\n[$(join(applicable_rules, ", "))]")
end
13 changes: 3 additions & 10 deletions src/algorithms/joint_marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,15 @@ Infer the rule that computes the joint marginal over `cluster`
function inferMarginalRule(cluster::Cluster, inbound_types::Vector{<:Type})
# Find applicable rule(s)
applicable_rules = Type[]
for rule in leaftypes(MarginalRule{typeof(cluster.node)})
rule_type = MarginalRule{typeof(cluster.node)}
for rule in leaftypes(rule_type)
if isApplicable(rule, inbound_types)
push!(applicable_rules, rule)
end
end

# Select and set applicable rule
if isempty(applicable_rules)
error("No applicable marginal update rule for $(typeof(cluster.node)) node with inbound types: $(join(inbound_types, ", "))")
elseif length(applicable_rules) > 1
error("Multiple applicable marginal update rules for $(typeof(cluster.node)) node with inbound types: $(join(inbound_types, ", ")): $(join(applicable_rules, ", "))")
else
marginal_update_rule = first(applicable_rules)
end

return marginal_update_rule
return selectApplicableRule(rule_type, cluster, inbound_types, applicable_rules)
end

"""
Expand Down
12 changes: 3 additions & 9 deletions src/algorithms/naive_variational_bayes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Infer the update rule that computes the message for `entry`, as dependent on the
"""
function inferUpdateRule!( entry::ScheduleEntry,
rule_type::Type{T},
inferred_outbound_types::Dict{Interface, Type}) where T<:NaiveVariationalRule
inferred_outbound_types::Dict) where T<:NaiveVariationalRule
# Collect inbound types
inbound_types = collectInboundTypes(entry, rule_type, inferred_outbound_types)

Expand All @@ -25,13 +25,7 @@ function inferUpdateRule!( entry::ScheduleEntry,
end

# Select and set applicable rule
if isempty(applicable_rules)
error("No applicable $(rule_type) update for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", "))")
elseif length(applicable_rules) > 1
error("Multiple applicable $(rule_type) updates for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", ")): $(join(applicable_rules, ", "))")
else
entry.message_update_rule = first(applicable_rules)
end
entry.message_update_rule = selectApplicableRule(rule_type, entry, inbound_types, applicable_rules)

return entry
end
Expand All @@ -42,7 +36,7 @@ Returns a vector with inbound types that correspond with required interfaces.
"""
function collectInboundTypes( entry::ScheduleEntry,
::Type{T},
inferred_outbound_types::Dict{Interface, Type}) where T<:NaiveVariationalRule
inferred_outbound_types::Dict) where T<:NaiveVariationalRule
inbound_types = Type[]
for node_interface in entry.interface.node.interfaces
if node_interface === entry.interface
Expand Down
15 changes: 14 additions & 1 deletion src/algorithms/posterior_factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,23 @@ function PosteriorFactorization(args::Vararg{Union{T, Set{T}, Vector{T}} where T
PosteriorFactor(arg, id=ids[i])
end
end

# Verify that all stochastic edges are covered by a posterior factor
uncovered_variables = uncoveredVariables(pfz)
isempty(uncovered_variables) || error("Edges for stochastic variables $([var.id for var in uncovered_variables]) must be covered by a posterior factor")

return pfz
end

"""
Return the set of variables that belong to edges of `pfz.graph` which are neither
listed in `pfz.deterministic_edges` nor internal to any posterior factor of `pfz`.
"""
function uncoveredVariables(pfz::PosteriorFactorization)
    # Edges internal to any of the posterior factors count as covered
    covered = Set(Iterators.flatten([pf.internal_edges for (_, pf) in pfz.posterior_factors]))

    # Every graph edge that is not marked deterministic is treated as stochastic
    stochastic = setdiff(Set(pfz.graph.edges), pfz.deterministic_edges)

    # Map the remaining (uncovered) edges to their variables
    return Set([edge.variable for edge in setdiff(stochastic, covered)])
end

iterate(pfz::PosteriorFactorization) = iterate(pfz.posterior_factors)
iterate(pfz::PosteriorFactorization, state) = iterate(pfz.posterior_factors, state)
values(pfz::PosteriorFactorization) = values(pfz.posterior_factors)
Expand Down Expand Up @@ -216,7 +230,6 @@ Return the local stochastic regions around `node`
function localStochasticRegions(node::FactorNode, pfz::PosteriorFactorization)
regions = Region[]
for interface in node.interfaces
partner = ultimatePartner(interface)
if !(interface.edge in pfz.deterministic_edges) # If edge is stochastic
push!(regions, region(node, interface.edge))
end
Expand Down
16 changes: 5 additions & 11 deletions src/algorithms/structured_variational_bayes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ abstract type StructuredVariationalRule{factor_type} <: MessageUpdateRule end
Infer the update rule that computes the message for `entry`, as dependent on the inbound types
"""
function inferUpdateRule!(entry::ScheduleEntry,
rule_type::Type{T},
inferred_outbound_types::Dict{Interface, Type}
) where T<:StructuredVariationalRule
rule_type::Type{<:StructuredVariationalRule},
inferred_outbound_types::Dict)

# Collect inbound types
inbound_types = collectInboundTypes(entry, rule_type, inferred_outbound_types)

Expand All @@ -26,13 +26,7 @@ function inferUpdateRule!(entry::ScheduleEntry,
end

# Select and set applicable rule
if isempty(applicable_rules)
error("No applicable $(rule_type) update for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", "))")
elseif length(applicable_rules) > 1
error("Multiple applicable $(rule_type) updates for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", ")): $(join(applicable_rules, ", "))")
else
entry.message_update_rule = first(applicable_rules)
end
entry.message_update_rule = selectApplicableRule(rule_type, entry, inbound_types, applicable_rules)

return entry
end
Expand All @@ -43,7 +37,7 @@ Returns a vector with inbound types that correspond with required interfaces.
"""
function collectInboundTypes(entry::ScheduleEntry,
::Type{T},
inferred_outbound_types::Dict{Interface, Type}
inferred_outbound_types::Dict
) where T<:StructuredVariationalRule
inbound_types = Type[]
entry_posterior_factor = posteriorFactor(entry.interface.edge) # Collect posterior factor for outbound edge
Expand Down
25 changes: 9 additions & 16 deletions src/algorithms/sum_product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ message out of the specified `outbound_interface`.
"""
function internalSumProductSchedule(cnode::CompositeFactor,
outbound_interface::Interface,
inferred_outbound_types::Dict{Interface, <:Type})
inferred_outbound_types::Dict)

# Collect types of messages towards the CompositeFactor
msg_types = Dict{Interface, Type}()
msg_types = Dict{Union{Interface, Nothing}, Type}(nothing => Missing) # Initialize with fallback
for (idx, terminal) in enumerate(cnode.terminals)
(cnode.interfaces[idx] === outbound_interface) && continue # don't need incoming msg on outbound interface
msg_types[terminal.interfaces[1]] = inferred_outbound_types[cnode.interfaces[idx].partner]
Expand Down Expand Up @@ -51,7 +51,7 @@ end

function inferUpdateRule!(entry::ScheduleEntry,
rule_type::Type{T},
inferred_outbound_types::Dict{Interface, <:Type}
inferred_outbound_types::Dict
) where T<:SumProductRule
# Collect inbound types
inbound_types = collectInboundTypes(entry, rule_type, inferred_outbound_types)
Expand All @@ -65,27 +65,20 @@ function inferUpdateRule!(entry::ScheduleEntry,
end

# Select and set applicable rule
if isempty(applicable_rules)
if isa(entry.interface.node, CompositeFactor)
# No 'shortcut rule' available for CompositeFactor.
# Try to fall back to msg passing on the internal graph.
entry.internal_schedule = internalSumProductSchedule(entry.interface.node, entry.interface, inferred_outbound_types)
entry.message_update_rule = entry.internal_schedule[end].message_update_rule
else
error("No applicable $(rule_type) update for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", "))")
end
elseif length(applicable_rules) > 1
error("Multiple applicable $(rule_type) updates for $(typeof(entry.interface.node)) node with inbound types: $(join(inbound_types, ", ")): $(join(applicable_rules, ", "))")
if isempty(applicable_rules) && isa(entry.interface.node, CompositeFactor)
# No shortcut rule available for CompositeFactor, fall back to message passing on the internal graph.
entry.internal_schedule = internalSumProductSchedule(entry.interface.node, entry.interface, inferred_outbound_types)
entry.message_update_rule = entry.internal_schedule[end].message_update_rule
else
entry.message_update_rule = first(applicable_rules)
entry.message_update_rule = selectApplicableRule(rule_type, entry, inbound_types, applicable_rules)
end

return entry
end

function collectInboundTypes(entry::ScheduleEntry,
::Type{T},
inferred_outbound_types::Dict{Interface, <:Type}
inferred_outbound_types::Dict
) where T<:SumProductRule
inbound_message_types = Type[]
for node_interface in entry.interface.node.interfaces
Expand Down
4 changes: 4 additions & 0 deletions src/distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ end
const P = Distribution
const ProbabilityDistribution = Distribution # For backwards compatibility

"""Distribution formatting for concise printing"""
function format(::Type{<:Distribution{V}}) where V<:VariateType
    # Less specific method: prints the variate type only
    return "Distribution{$V}"
end

function format(::Type{<:Distribution{V, F}}) where {V<:VariateType, F<:FactorFunction}
    # More specific method: prints the variate type and the factor function type
    return "Distribution{$V, $F}"
end

"""Sample multiple realizations from a probability distribution"""
sample(dist::Distribution, n_samples::Int64) = [sample(dist) for i in 1:n_samples] # TODO: individual samples can be optimized

Expand Down
52 changes: 50 additions & 2 deletions src/engines/julia/update_rules/gaussian_moments.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ ruleSPGaussianMomentsMSNP,
ruleSPGaussianMomentsOutNGS,
ruleSPGaussianMomentsMGNS,
ruleVBGaussianMomentsM,
ruleVBGaussianMomentsOut
ruleVBGaussianMomentsOut,
ruleSVBGaussianMomentsOutVGD,
ruleSVBGaussianMomentsMGVD,
ruleMGaussianMomentsGGD,
ruleMGaussianMomentsGGN

ruleSPGaussianMomentsOutNPP(msg_out::Nothing,
msg_mean::Message{PointMass, V},
Expand Down Expand Up @@ -92,4 +96,48 @@ ruleVBGaussianMomentsM(dist_out::Distribution{V},
ruleVBGaussianMomentsOut(dist_out::Any,
dist_mean::Distribution{V},
dist_var::Distribution) where V<:VariateType =
Message(V, Gaussian{Moments}, m=unsafeMean(dist_mean), v=unsafeMean(dist_var))
Message(V, Gaussian{Moments}, m=unsafeMean(dist_mean), v=unsafeMean(dist_var))

function ruleSVBGaussianMomentsOutVGD(dist_out::Any, # Only implemented for PointMass variance
                                      msg_mean::Message{<:Gaussian, V},
                                      dist_var::Distribution{<:VariateType, PointMass}) where V<:VariateType

    d_mean = msg_mean.dist

    # Mean is taken from the inbound mean message; its covariance is increased
    # by the mean of the (PointMass) variance distribution
    m_out = unsafeMean(d_mean)
    v_out = unsafeCov(d_mean) + unsafeMean(dist_var)

    return Message(V, Gaussian{Moments}, m=m_out, v=v_out)
end

function ruleSVBGaussianMomentsMGVD(msg_out::Message{F, V}, # Only implemented for PointMass variance
                                    dist_mean::Any,
                                    dist_var::Distribution{<:VariateType, PointMass}) where {F<:Gaussian, V<:VariateType}

    d_out = msg_out.dist

    # Mean is taken from the inbound out message; its covariance is increased
    # by the mean of the (PointMass) variance distribution
    m_mean = unsafeMean(d_out)
    v_mean = unsafeCov(d_out) + unsafeMean(dist_var)

    return Message(V, Gaussian{Moments}, m=m_mean, v=v_mean)
end

function ruleMGaussianMomentsGGD( # Only implemented for PointMass variance
    msg_out::Message{<:Gaussian, V},
    msg_mean::Message{<:Gaussian, V},
    dist_var::Distribution{<:VariateType, PointMass}) where V<:VariateType

    # Bring both inbound Gaussians to the canonical (xi, w) parameterization
    canonical_type = Distribution{V, Gaussian{Canonical}}
    mean_can = convert(canonical_type, msg_mean.dist)
    out_can = convert(canonical_type, msg_out.dist)

    xi_out = out_can.params[:xi]
    w_out = out_can.params[:w]
    xi_mean = mean_can.params[:xi]
    w_mean = mean_can.params[:w]

    # Inverse of the mean of the variance distribution
    w_bar = cholinv(unsafeMean(dist_var))

    # Joint over (out, mean): stacked xi and a 2x2 block precision matrix.
    # Entries are parenthesised so the row/column layout is unambiguous.
    xi_joint = [xi_out; xi_mean]
    w_joint = [(w_out + w_bar) (-w_bar);
               (-w_bar) (w_mean + w_bar)]

    return Distribution(Multivariate, Gaussian{Canonical}, xi=xi_joint, w=w_joint)
end

function ruleMGaussianMomentsGGN(
    msg_out::Message{<:Gaussian, V},
    msg_mean::Message{<:Gaussian, V},
    msg_var::Message{PointMass}) where V<:VariateType

    # Bring both inbound Gaussians to the canonical (xi, w) parameterization
    canonical_type = Distribution{V, Gaussian{Canonical}}
    mean_can = convert(canonical_type, msg_mean.dist)
    out_can = convert(canonical_type, msg_out.dist)

    xi_out = out_can.params[:xi]
    w_out = out_can.params[:w]
    xi_mean = mean_can.params[:xi]
    w_mean = mean_can.params[:w]

    # Inverse of the `:m` parameter carried by the PointMass variance message
    w_bar = cholinv(msg_var.dist.params[:m])

    # Joint over (out, mean): stacked xi and a 2x2 block precision matrix.
    # Entries are parenthesised so the row/column layout is unambiguous.
    xi_joint = [xi_out; xi_mean]
    w_joint = [(w_out + w_bar) (-w_bar);
               (-w_bar) (w_mean + w_bar)]

    return Distribution(Multivariate, Gaussian{Canonical}, xi=xi_joint, w=w_joint)
end
12 changes: 4 additions & 8 deletions src/engines/julia/update_rules/gaussian_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,10 @@ function ruleSVBGaussianPrecisionW(
end
end

# Builds a Gaussian{Moments} message from the inbound `msg_out` and `dist_prec`;
# `dist_mean` is accepted but not used in the body.
# NOTE(review): a second definition with the same signature follows directly below this
# one; if both are loaded, the later definition is the one that takes effect.
function ruleSVBGaussianPrecisionMGVD(msg_out::Message{F, V},
dist_mean::Any,
dist_prec::Distribution) where {F<:Gaussian, V<:VariateType}

# Convert the inbound distribution to the moments parameterization to read :m and :v
d_out = convert(Distribution{V, Gaussian{Moments}}, msg_out.dist)

# Keep the mean; add the inverse of the mean of `dist_prec` to the variance
Message(V, Gaussian{Moments}, m=d_out.params[:m], v=d_out.params[:v] + cholinv(unsafeMean(dist_prec)))
end
function ruleSVBGaussianPrecisionMGVD(msg_out::Message{F, V},
                                      dist_mean::Any,
                                      dist_prec::Distribution) where {F<:Gaussian, V<:VariateType}

    d_out = msg_out.dist

    # Mean is taken from the inbound out message; its covariance is increased
    # by the inverse of the mean of the precision distribution
    m_mean = unsafeMean(d_out)
    v_mean = unsafeCov(d_out) + cholinv(unsafeMean(dist_prec))

    return Message(V, Gaussian{Moments}, m=m_mean, v=v_mean)
end

function ruleMGaussianPrecisionGGD(
msg_out::Message{<:Gaussian, V},
Expand Down
Loading