Skip to content

Commit

Permalink
Share the ForwardDiff tag between models and fix a typo in precompila…
Browse files Browse the repository at this point in the history
…tion

Currently, every model gets its own ForwardDiff tag which means that every model
also have a unique type of their dual numbers. This causes every function called
with dual numbers to have to be recompiled for every model.

In this PR, we define a shared tag in ModelBasedEcon that all models use.
This means that we can push the precompile generation for many functions
from the model into ModelBaseEcon itself which changes the cost of them
from O(1) to O(n_models).

This PR also corrects a mismatch in the `precompile` call and the call to `ForwardDiff`.
In the precompile calls `MyTag` was used as the type to precompile for which means that the calls
to `GradientConfig` should have used `MyTag()` (so that the type of the tag was `MyTag`.) Now,
when `MyTag` was used to the `GradientConfig` call the type of it is actually `DataType` which
means that the types in the `precompile` call was different compared to the types actually
encountered at runtime.

Using the following benchmark script:

```julia
unique!(push!(LOAD_PATH, realpath("./models"))) # hide

@time using ModelBaseEcon
using Random # See JuliaLang/julia#48810

@time using FRBUS_VAR

m = FRBUS_VAR.model
nrows = 1 + m.maxlag + m.maxlead
ncols = length(m.allvars)
pt = zeros(nrows, ncols);
@time @eval eval_RJ(pt, m);

using BenchmarkTools
@Btime eval_RJ(pt, m);
```

this PR has the following changes:

- Loading ModelBaseEcon: 0.641551s -> 0.645943s
- Loading model 0.053s -> 0.032s
- First call `eval_RJ`: 5.50s -> 0.64s
- Benchmark `eval_RJ`:  597.966μs -> 573.923μs
  • Loading branch information
KristofferC committed Feb 27, 2023
1 parent 8631f7c commit 17184b9
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 35 deletions.
3 changes: 2 additions & 1 deletion src/ModelBaseEcon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
ModelBaseEcon
This package is part of the StateSpaceEcon ecosystem.
This package is part of the StateSpaceEcon ecosystem.
It provides the basic elements needed for model definition.
StateSpaceEcon works with model objects defined with ModelBaseEcon.
"""
Expand Down Expand Up @@ -44,6 +44,7 @@ include("metafuncs.jl")
include("model.jl")
include("export_model.jl")
include("linearize.jl")
include("precompile.jl")

"""
@using_example name
Expand Down
46 changes: 12 additions & 34 deletions src/evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
###########################################################
# Part 1: Helper functions

struct ModelBaseEconTag end

"""
precompilefuncs(resid, RJ, ::Val{N}, tag) where N
Expand All @@ -19,36 +20,16 @@ with the dual-number arithmetic required by ForwardDiff.
Internal function. Do not call directly
"""
function precompilefuncs(resid, RJ, ::Val{N}, tag) where {N}
function precompilefuncs(resid, RJ, ::Val{N}) where {N}
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing

# tag = MyTag # ForwardDiff.Tag{resid,Float64}
dual = ForwardDiff.Dual{tag,Float64,N}
duals = Array{dual,1}
cfg = ForwardDiff.GradientConfig{tag,Float64,N,duals}
mdr = DiffResults.MutableDiffResult{1,Float64,Tuple{Array{Float64,1}}}
tagtype = ModelBaseEconTag
dual = ForwardDiff.Dual{tagtype,Float64,N}
duals = Vector{dual}

precompile(resid, (Array{Float64,1},)) || error("precompile")
precompile(resid, (Vector{Float64},)) || error("precompile")
precompile(resid, (duals,)) || error("precompile")
precompile(RJ, (Array{Float64,1},)) || error("precompile")

for pred in Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger, :-, :+, :log, :exp]
pred (:iseven, :isodd) || precompile(getfield(Base, pred), (Float64,)) || error("precompile")
precompile(getfield(Base, pred), (dual,)) || error("precompile")
end

for pred in Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=), :+, :-, :*, :/, :^]
precompile(getfield(Base, pred), (Float64, Float64)) || error("precompile")
precompile(getfield(Base, pred), (dual, Float64)) || error("precompile")
precompile(getfield(Base, pred), (Float64, dual)) || error("precompile")
precompile(getfield(Base, pred), (dual, dual)) || error("precompile")
end

# precompile(ForwardDiff.extract_gradient!, (Type{tag}, mdr, dual)) || error("precompile")
# precompile(ForwardDiff.vector_mode_gradient!, (mdr, typeof(resid), Array{Float64,1}, cfg)) || error("precompile")

# precompile(Tuple{typeof(ForwardDiff.extract_gradient!), Type{tag}, mdr, dual}) || error("precompile")
# precompile(Tuple{typeof(ForwardDiff.vector_mode_gradient!), mdr, resid, Array{Float64, 1}, cfg}) || error("precompile")
precompile(RJ, (Vector{Float64},)) || error("precompile")

return nothing
end
Expand Down Expand Up @@ -81,9 +62,7 @@ function funcsyms(mod::Module)
return fn1, fn2
end

# Can be changed to MAX_CHUNK_SIZE::Bool = 4 when support for Julia 1.7
# is dropped.
const MAX_CHUNK_SIZE = Ref(4)
const MAX_CHUNK_SIZE = 4

# Used to avoid specialzing the ForwardDiff functions on
# every equation.
Expand Down Expand Up @@ -117,7 +96,7 @@ function makefuncs(expr, tssyms, sssyms, psyms, mod)
fn1, fn2 = funcsyms(mod)
x = gensym("x")
nargs = length(tssyms) + length(sssyms)
chunk = min(nargs, MAX_CHUNK_SIZE[])
chunk = min(nargs, MAX_CHUNK_SIZE)
return quote
function (ee::EquationEvaluator{$(QuoteNode(fn1))})($x::Vector{<:Real})
($(tssyms...), $(sssyms...),) = $x
Expand All @@ -127,7 +106,7 @@ function makefuncs(expr, tssyms, sssyms, psyms, mod)
const $fn1 = EquationEvaluator{$(QuoteNode(fn1))}(UInt(0),
$(@__MODULE__).LittleDict(Symbol[$(QuoteNode.(psyms)...)], fill(nothing, $(length(psyms)))))
const $fn2 = EquationGradient($FunctionWrapper($fn1), $nargs, Val($chunk))
$(@__MODULE__).precompilefuncs($fn1, $fn2, Val($chunk), MyTag)
$(@__MODULE__).precompilefuncs($fn1, $fn2, Val($chunk))
($fn1, $fn2)
end
end
Expand All @@ -151,9 +130,8 @@ together with a `DiffResult` and a `GradientConfig` used by `ForwardDiff`. Its
call is defined here and computes the residual and the gradient.
"""
function initfuncs(mod::Module)
if :MyTag names(mod; all=true)
if :EquationEvaluator names(mod; all=true)
mod.eval(quote
struct MyTag end
struct EquationEvaluator{FN} <: Function
rev::Ref{UInt}
params::$(@__MODULE__).LittleDict{Symbol,Any}
Expand All @@ -165,7 +143,7 @@ function initfuncs(mod::Module)
end
EquationGradient(fn1::Function, nargs::Int, ::Val{N}) where {N} = EquationGradient(fn1,
$(@__MODULE__).DiffResults.DiffResult(zero(Float64), zeros(Float64, nargs)),
$(@__MODULE__).ForwardDiff.GradientConfig(fn1, zeros(Float64, nargs), $(@__MODULE__).ForwardDiff.Chunk{N}(), MyTag))
$(@__MODULE__).ForwardDiff.GradientConfig(fn1, zeros(Float64, nargs), $(@__MODULE__).ForwardDiff.Chunk{N}(), $ModelBaseEconTag()))
function (s::EquationGradient)(x::Vector{Float64})
$(@__MODULE__).ForwardDiff.gradient!(s.dr, s.fn1, x, s.cfg)
return s.dr.value, s.dr.derivs[1]
Expand Down
41 changes: 41 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@

"""
precompilefuncs(N::Int)
Pre-compiles functions used by models for a `ForwardDiff.Dual` numbers
with chunk size `N`.
!!! warning
Internal function. Do not call directly
"""
function precompile_funcs(N::Int)
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing

tag = ModelBaseEconTag
dual = ForwardDiff.Dual{tag,Float64,N}
duals = Vector{dual}
cfg = ForwardDiff.GradientConfig{tag,Float64,N,duals}
mdr = DiffResults.MutableDiffResult{1,Float64,Tuple{Vector{Float64}}}

for pred in Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger, :-, :+, :log, :exp]
pred (:iseven, :isodd) || precompile(getfield(Base, pred), (Float64,)) || error("precompile")
precompile(getfield(Base, pred), (dual,)) || error("precompile")
end

for pred in Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=), :+, :-, :*, :/, :^]
precompile(getfield(Base, pred), (Float64, Float64)) || error("precompile")
precompile(getfield(Base, pred), (dual, Float64)) || error("precompile")
precompile(getfield(Base, pred), (Float64, dual)) || error("precompile")
precompile(getfield(Base, pred), (dual, dual)) || error("precompile")
end

precompile(ForwardDiff.extract_gradient!, (Type{tag}, mdr, dual)) || error("precompile")
precompile(ForwardDiff.vector_mode_gradient!, (mdr, FunctionWrapper, Vector{Float64}, cfg)) || error("precompile")

return nothing
end

for i in 1:MAX_CHUNK_SIZE
precompile_funcs(i)
end

0 comments on commit 17184b9

Please sign in to comment.