forked from bankofcanada/ModelBaseEcon.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Share the ForwardDiff tag between models and fix a typo in precompila…
…tion Currently, every model gets its own ForwardDiff tag which means that every model also have a unique type of their dual numbers. This causes every function called with dual numbers to have to be recompiled for every model. In this PR, we define a shared tag in ModelBasedEcon that all models use. This means that we can push the precompile generation for many functions from the model into ModelBaseEcon itself which changes the cost of them from O(1) to O(n_models). This PR also corrects a mismatch in the `precompile` call and the call to `ForwardDiff`. In the precompile calls `MyTag` was used as the type to precompile for which means that the calls to `GradientConfig` should have used `MyTag()` (so that the type of the tag was `MyTag`.) Now, when `MyTag` was used to the `GradientConfig` call the type of it is actually `DataType` which means that the types in the `precompile` call was different compared to the types actually encountered at runtime. Using the following benchmark script: ```julia unique!(push!(LOAD_PATH, realpath("./models"))) # hide @time using ModelBaseEcon using Random # See JuliaLang/julia#48810 @time using FRBUS_VAR m = FRBUS_VAR.model nrows = 1 + m.maxlag + m.maxlead ncols = length(m.allvars) pt = zeros(nrows, ncols); @time @eval eval_RJ(pt, m); using BenchmarkTools @Btime eval_RJ(pt, m); ``` this PR has the following changes: - Loading ModelBaseEcon: 0.641551s -> 0.645943s - Loading model 0.053s -> 0.032s - First call `eval_RJ`: 5.50s -> 0.64s - Benchmark `eval_RJ`: 597.966μs -> 573.923μs
- Loading branch information
KristofferC
committed
Feb 27, 2023
1 parent
8631f7c
commit 17184b9
Showing
3 changed files
with
55 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
|
||
""" | ||
precompilefuncs(N::Int) | ||
Pre-compiles functions used by models for a `ForwardDiff.Dual` numbers | ||
with chunk size `N`. | ||
!!! warning | ||
Internal function. Do not call directly | ||
""" | ||
function precompile_funcs(N::Int) | ||
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing | ||
|
||
tag = ModelBaseEconTag | ||
dual = ForwardDiff.Dual{tag,Float64,N} | ||
duals = Vector{dual} | ||
cfg = ForwardDiff.GradientConfig{tag,Float64,N,duals} | ||
mdr = DiffResults.MutableDiffResult{1,Float64,Tuple{Vector{Float64}}} | ||
|
||
for pred in Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger, :-, :+, :log, :exp] | ||
pred ∈ (:iseven, :isodd) || precompile(getfield(Base, pred), (Float64,)) || error("precompile") | ||
precompile(getfield(Base, pred), (dual,)) || error("precompile") | ||
end | ||
|
||
for pred in Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=), :+, :-, :*, :/, :^] | ||
precompile(getfield(Base, pred), (Float64, Float64)) || error("precompile") | ||
precompile(getfield(Base, pred), (dual, Float64)) || error("precompile") | ||
precompile(getfield(Base, pred), (Float64, dual)) || error("precompile") | ||
precompile(getfield(Base, pred), (dual, dual)) || error("precompile") | ||
end | ||
|
||
precompile(ForwardDiff.extract_gradient!, (Type{tag}, mdr, dual)) || error("precompile") | ||
precompile(ForwardDiff.vector_mode_gradient!, (mdr, FunctionWrapper, Vector{Float64}, cfg)) || error("precompile") | ||
|
||
return nothing | ||
end | ||
|
||
for i in 1:MAX_CHUNK_SIZE | ||
precompile_funcs(i) | ||
end |