diff --git a/Manifest.toml b/Manifest.toml
index 7210fa33bf..8cd785e521 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -256,6 +256,13 @@ git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23"
 uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 version = "1.3.3"
 
+[[ParameterSchedulers]]
+git-tree-sha1 = "ab80539e1061e586a49300813039f5c11d3ba8e8"
+repo-rev = "darsnack/rm-optim"
+repo-url = "https://github.com/darsnack/ParameterSchedulers.jl.git"
+uuid = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
+version = "0.2.0"
+
 [[Pkg]]
 deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
 uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
diff --git a/Project.toml b/Project.toml
index b51f143252..39c3c3afa9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -14,6 +14,7 @@ Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/src/Flux.jl b/src/Flux.jl
index 5e6776d601..179a68d80c 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -24,6 +24,8 @@ include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
 using .Optimise: skip
+using .Optimise: Schedule
+export Schedule
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, OADAM,
   ADAMW, RADAM, AdaBelief, InvDecay, ExpDecay,
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index e2485a05d0..d552a5b388 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -11,4 +11,9 @@ export train!, update!,
 include("optimisers.jl")
 include("train.jl")
 
+module Schedule
+using ParameterSchedulers
+using ParameterSchedulers: next!
+end
+
 end
diff --git a/test/optimise.jl b/test/optimise.jl
index 04cbf6f6c0..27b7897a8f 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -44,6 +44,19 @@ end
   end
 end
 
+@testset "Scheduler" begin
+  schedule = Schedule.Exp(λ = 0.1, γ = 0.5)
+  opt = Descent()
+  scheduler = Schedule.Scheduler(schedule, opt)
+  m = Chain(Dense(10, 5), Dense(5, 2, tanh))
+  ps = params(m)
+  for t in 1:10
+    gs = gradient(() -> sum(m(rand(10))), ps)
+    Optimise.update!(scheduler, ps, gs)
+    @test opt.eta ≈ schedule[t]
+  end
+end
+
 @testset "Training Loop" begin
   i = 0
   l = 1
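
For context, a minimal usage sketch of the new `Schedule` submodule in a training loop, mirroring the test above. The `Exp` schedule and `Scheduler` wrapper come from ParameterSchedulers.jl as exercised in the patch; the hyperparameter values and toy model here are illustrative assumptions, not part of the PR:

```julia
using Flux
using Flux: Schedule

# Exponential schedule: the learning rate starts at λ and decays
# geometrically with ratio γ at each step.
schedule = Schedule.Exp(λ = 0.01, γ = 0.9)

# Scheduler wraps a plain optimiser; on each update! it overwrites the
# optimiser's learning rate (opt.eta) with the next schedule value before
# applying the usual Descent step.
opt = Descent()
scheduler = Schedule.Scheduler(schedule, opt)

m = Chain(Dense(10, 5), Dense(5, 2, tanh))
ps = params(m)

for step in 1:10
  gs = gradient(() -> sum(m(rand(Float32, 10))), ps)
  Flux.Optimise.update!(scheduler, ps, gs)
end
```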