Skip to content

Commit

Permalink
add progress bar to callback (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg authored Mar 29, 2023
1 parent 9dad6a2 commit 373f192
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 7 deletions.
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Expand All @@ -26,6 +27,7 @@ NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Expand All @@ -35,8 +37,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractDifferentiation = "0.4, 0.5"
CUDA = "3.13, 4.0"
AbstractDifferentiation = "0.5"
CUDA = "4.0"
ComponentArrays = "0.13"
ComputationalResources = "0.3"
DataFrames = "1.2"
Expand All @@ -56,9 +58,10 @@ NNlibCUDA = "0.2"
Optimisers = "0.2"
Optimization = "3.5"
OptimizationOptimisers = "0.1"
ProgressMeter = "1.7"
SciMLBase = "1.27"
SciMLSensitivity = "7.0"
ScientificTypes = "3.0"
SparseDiffTools = "1.30, 2"
SparseDiffTools = "1.31, 2.0"
Zygote = "0.6"
julia = "1.6"
2 changes: 2 additions & 0 deletions src/ICNF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ using AbstractDifferentiation,
Optimisers,
Optimization,
OptimizationOptimisers,
ProgressMeter,
SciMLBase,
SciMLSensitivity,
ScientificTypes,
SparseDiffTools,
Zygote,
Dates,
LinearAlgebra,
Random,
Statistics,
Expand Down
7 changes: 5 additions & 2 deletions src/core.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# SciML interface

function callback_f(icnf::AbstractFlows)::Function
function callback_f(icnf::AbstractFlows, n::Integer)::Function
prgr = Progress(n; dt=eps(), desc="Training: ", showspeed=true)
itr_n = 1
function f(ps, l)
@info "Training" loss = l
ProgressMeter.next!(prgr; showvalues = [(:loss_value, l), (:iteration, itr_n), (:last_update, Dates.now())])
itr_n += one(itr_n)
false
end
f
Expand Down
2 changes: 1 addition & 1 deletion src/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
end
ncdata = ncycle(data, model.n_epochs)
_loss = loss_f(model.m, model.loss, st)
_callback = callback_f(model.m)
optfunc = OptimizationFunction(_loss, model.adtype)
optprob = OptimizationProblem(optfunc, ps)
tst = @timed for opt in model.optimizers
optprob_re = remake(optprob; u0 = ps)
if model.have_callback
_callback = callback_f(model.m, length(ncdata))
res = solve(optprob_re, opt, ncdata; callback = _callback)
else
res = solve(optprob_re, opt, ncdata)
Expand Down
2 changes: 1 addition & 1 deletion src/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,12 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
end
ncdata = ncycle(data, model.n_epochs)
_loss = loss_f(model.m, model.loss, st)
_callback = callback_f(model.m)
optfunc = OptimizationFunction(_loss, model.adtype)
optprob = OptimizationProblem(optfunc, ps)
tst = @timed for opt in model.optimizers
optprob_re = remake(optprob; u0 = ps)
if model.have_callback
_callback = callback_f(model.m, length(ncdata))
res = solve(optprob_re, opt, ncdata; callback = _callback)
else
res = solve(optprob_re, opt, ncdata)
Expand Down

0 comments on commit 373f192

Please sign in to comment.