Skip to content

Commit

Permalink
Fix a bug in subexpression handling (#1979) (#1984)
Browse files Browse the repository at this point in the history
Subexpressions in some cases were not correctly linked to the
constraints that they appeared in via a transitive dependency (chain of
subexpressions). This could cause incorrect derivative computations for models
with this structure.

Thanks to @sanderclaeys for providing the test case.
  • Loading branch information
mlubin authored Jun 8, 2019
1 parent 90cdbf8 commit ed04a7a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 30 deletions.
73 changes: 43 additions & 30 deletions src/_Derivatives/subexpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,56 +15,69 @@ function list_subexpressions(nd::Vector{NodeData})
return sort(collect(indices))
end

# order the subexpressions which main_expressions depend on
# such that we can run forward mode in this order
function order_subexpressions(main_expressions::Vector{Vector{NodeData}},subexpressions::Vector{Vector{NodeData}})
nsub = length(subexpressions)
computed = falses(nsub)
# Order the subexpressions which main_expressions depend on such that we can
# run forward mode in this order.
function order_subexpressions(main_expressions::Vector{Vector{NodeData}},
subexpressions::Vector{Vector{NodeData}})
num_sub = length(subexpressions)
computed = falses(num_sub)
dependencies = Dict{Int,Vector{Int}}()
to_visit = collect(nsub+1:nsub+length(main_expressions))
depended_on_by = [Set{Int}() for i in 1:nsub]
to_visit = collect(num_sub + 1 : num_sub + length(main_expressions))
# For each subexpression k, the indices of the main expressions that depend
# on k, possibly transitively.
depended_on_by = [Set{Int}() for i in 1:num_sub]

while !isempty(to_visit)
idx = pop!(to_visit)
if idx > nsub
li = list_subexpressions(main_expressions[idx-nsub])
if idx > num_sub
subexpr = list_subexpressions(main_expressions[idx - num_sub])
else
computed[idx] && continue
li = list_subexpressions(subexpressions[idx])
subexpr = list_subexpressions(subexpressions[idx])
computed[idx] = true
end
dependencies[idx] = li
for k in li
if idx > nsub
push!(depended_on_by[k], idx-nsub)
else
union!(depended_on_by[k], depended_on_by[idx])
dependencies[idx] = subexpr
for k in subexpr
if idx > num_sub
push!(depended_on_by[k], idx - num_sub)
end
push!(to_visit,k)
push!(to_visit, k)
end
end

# now order dependencies
# Now order dependencies.
I = Int[]
J = Int[]
for (idx,li) in dependencies
for k in li
push!(I,idx)
push!(J,k)
for (idx, subexpr) in dependencies
for k in subexpr
push!(I, idx)
push!(J, k)
end
end
N = nsub+length(main_expressions)
sp = sparse(I,J,ones(length(I)),N,N)
N = num_sub + length(main_expressions)
sp = sparse(I, J, ones(length(I)), N, N)
cmap = Vector{Int}(undef, N)
order = reverse(Coloring.reverse_topological_sort_by_dfs(sp.rowval,sp.colptr,N,cmap)[1])
# remove the subexpressions which never appear anywhere
# and the indices of the main expressions
order_filtered = collect(filter(idx -> (idx <= nsub && computed[idx]), order))
# also generate an individual order for each main expression
order = reverse(Coloring.reverse_topological_sort_by_dfs(sp.rowval,
sp.colptr, N,
cmap)[1])
# Remove the subexpressions which never appear anywhere and the indices of
# the main expressions.
condition(idx) = idx <= num_sub && computed[idx]
order_filtered = collect(filter(condition, order))

# Propagate transitive dependencies.
for o in Iterators.reverse(order_filtered)
@assert !isempty(depended_on_by[o])
for k in list_subexpressions(subexpressions[o])
union!(depended_on_by[k], depended_on_by[o])
end
end

# Generate an individual order for each main expression.
individual_order = [Int[] for i in 1:length(main_expressions)]
for o in order_filtered
for i in depended_on_by[o]
push!(individual_order[i],o)
push!(individual_order[i], o)
end
end

Expand Down
10 changes: 10 additions & 0 deletions test/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ li,li_individual = order_subexpressions(Vector{NodeData}[nd_outer2], Vector{Node
@test li == [1,2]
@test li_individual[1] == [1,2]

@testset "order_subexpressions with nested subexpressions" begin
expr_order, expr_order_individual = order_subexpressions(
[[NodeData(SUBEXPRESSION, 1, -1)], [NodeData(SUBEXPRESSION, 1, -1)]],
[[NodeData(SUBEXPRESSION, 2, -1)], NodeData[]])
@test expr_order == [2, 1]
@test expr_order_individual[1] == [2, 1]
@test expr_order_individual[2] == [2, 1]
end


adj_outer = adjmat(nd_outer)
outer_storage = zeros(1)
outer_storage_partials = zeros(1)
Expand Down

0 comments on commit ed04a7a

Please sign in to comment.