Skip to content

Commit

Permalink
update defaults (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg authored Aug 14, 2022
1 parent 65c5758 commit 4ef6577
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 13 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ version = "0.1.0"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Expand All @@ -36,12 +36,12 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Adapt = "3.3"
CUDA = "3.7"
DataFrames = "1.2"
DifferentialEquations = "7.0"
Distributions = "0.25"
DistributionsAD = "0.6"
FillArrays = "0.13"
Flux = "0.13"
IterTools = "1.3"
LineSearches = "7.1"
MLJBase = "0.20"
MLJModelInterface = "1.3"
MLUtils = "0.2"
Expand Down
2 changes: 1 addition & 1 deletion src/ICNF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ module ICNF
Adapt,
CUDA,
DataFrames,
DifferentialEquations,
Distributions,
DistributionsAD,
FillArrays,
Flux,
IterTools,
LineSearches,
MLJBase,
MLJModelInterface,
MLUtils,
Expand Down
2 changes: 1 addition & 1 deletion src/cond_ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function CondFFJORD{T, AT}(
nvars::Integer,
;
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),
tspan::Tuple{T, T}=convert(Tuple{T, T}, default_tspan),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
solvealg_train::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
2 changes: 1 addition & 1 deletion src/cond_planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function CondPlanar{T, AT}(
nvars::Integer,
;
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),
tspan::Tuple{T, T}=convert(Tuple{T, T}, default_tspan),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
solvealg_train::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
2 changes: 1 addition & 1 deletion src/cond_rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function CondRNODE{T, AT}(
nvars::Integer,
;
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),
tspan::Tuple{T, T}=convert(Tuple{T, T}, default_tspan),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
solvealg_train::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
13 changes: 10 additions & 3 deletions src/defaults.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
default_tspan = (0, 1)
default_solvealg = Tsit5(
;
thread=OrdinaryDiffEq.True(),
)
default_sensealg = InterpolatingAdjoint(
;
autodiff=true,
checkpointing=false,
noisemixing=false,
chunk_size=0,
autojacvec=ZygoteVJP(),
)
default_optimizer = Dict(
FluxOptApp => Flux.AMSGrad(),
OptimOptApp => BFGS(),
SciMLOptApp => Optimisers.AMSGrad(),
FluxOptApp => Flux.AMSGrad(0.001, (0.9, 0.999), eps()),
OptimOptApp => BFGS(
alphaguess=InitialHagerZhang(),
linesearch=HagerZhang(),
manifold=Flat(),
),
SciMLOptApp => Optimisers.AMSGrad(0.001, (0.9, 0.999), eps()),
)
2 changes: 1 addition & 1 deletion src/ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function FFJORD{T, AT}(
nvars::Integer,
;
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),
tspan::Tuple{T, T}=convert(Tuple{T, T}, default_tspan),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
solvealg_train::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
2 changes: 1 addition & 1 deletion src/planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function Planar{T, AT}(
nvars::Integer,
;
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),
tspan::Tuple{T, T}=convert(Tuple{T, T}, default_tspan),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
solvealg_train::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
2 changes: 1 addition & 1 deletion src/rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function RNODE{T, AT}(
nvars::Integer,
;
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),
tspan::Tuple{T, T}=convert(Tuple{T, T}, default_tspan),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
solvealg_train::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
2 changes: 1 addition & 1 deletion test/smoke_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Optimization.AutoTracker(),
Optimization.AutoFiniteDiff(),
]
go_mds = Any[ICNF.default_optimizer[FluxOptApp], ICNF.default_optimizer[OptimOptApp]]
go_mds = Any[ICNF.default_optimizer[FluxOptApp], ICNF.default_optimizer[OptimOptApp], ICNF.default_optimizer[SciMLOptApp]]
nvars_ = (1:2)
n_epochs = 2
batch_size = 8
Expand Down

0 comments on commit 4ef6577

Please sign in to comment.