diff --git a/Project.toml b/Project.toml index aee97ae05..b638ea65a 100644 --- a/Project.toml +++ b/Project.toml @@ -34,7 +34,7 @@ Boltz = "1" ChainRulesCore = "1" ComponentArrays = "0.15.17" ConcreteStructs = "0.2" -DataInterpolations = "5, 6" +DataInterpolations = "6.4" DelayDiffEq = "5.47.3" DiffEqCallbacks = "3.6.2" Distances = "0.10.11" @@ -54,9 +54,9 @@ LuxLib = "1.2" NNlib = "0.9.22" OneHotArrays = "0.2.5" Optimisers = "0.3" -Optimization = "3.25.0" -OptimizationOptimJL = "0.3.0" -OptimizationOptimisers = "0.2.1" +Optimization = "4" +OptimizationOptimJL = "0.4" +OptimizationOptimisers = "0.3" OrdinaryDiffEq = "6.76.0" Printf = "1.10" Random = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index 09aa6590e..ee0dbb7d9 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -57,10 +57,10 @@ MLUtils = "0.4" NNlib = "0.9" OneHotArrays = "0.2" Optimisers = "0.3" -Optimization = "3.9" -OptimizationOptimJL = "0.2, 0.3" -OptimizationOptimisers = "0.2" -OptimizationPolyalgorithms = "0.2" +Optimization = "4" +OptimizationOptimJL = "0.4" +OptimizationOptimisers = "0.3" +OptimizationPolyalgorithms = "0.3" OrdinaryDiffEq = "6.31" Plots = "1.36" Printf = "1" diff --git a/docs/src/examples/augmented_neural_ode.md b/docs/src/examples/augmented_neural_ode.md index a0f9a6a0f..6096f597b 100644 --- a/docs/src/examples/augmented_neural_ode.md +++ b/docs/src/examples/augmented_neural_ode.md @@ -69,13 +69,13 @@ function plot_contour(model, ps, st, npoints = 300) return contour(x, y, sol; fill = true, linewidth = 0.0) end -loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2) +loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2) dataloader = concentric_sphere( 2, (0.0f0, 2.0f0), (3.0f0, 4.0f0), 2000, 2000; batch_size = 256) iter = 0 -cb = function (ps, l) +cb = function (state, l) global iter iter += 1 if iter % 10 == 0 @@ -87,15 +87,15 @@ end model, ps, st = construct_model(1, 2, 64, 0) opt = OptimizationOptimisers.Adam(0.005) -loss_node(model, dataloader.data[1], dataloader.data[2], ps, st) +loss_node(model, (dataloader.data[1], dataloader.data[2]), ps, st) println("Training Neural ODE") optfunc = OptimizationFunction( - (x, p, data, target) -> loss_node(model, data, target, x, st), + (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) -optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) -res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) +res = solve(optprob, opt; callback = cb, epochs = 100) plt_node = plot_contour(model, res.u, st) @@ -106,10 +106,10 @@ println() println("Training Augmented Neural ODE") optfunc = OptimizationFunction( - (x, p, data, target) -> loss_node(model, data, target, x, st), + (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) -optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) -res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) +res = solve(optprob, opt; callback = cb, epochs = 100) plot_contour(model, res.u, st) ``` @@ -229,7 +229,7 @@ We use the L2 distance between the model prediction `model(x)` and the actual pr optimization objective. 
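As a quick self-contained illustration of that objective (the `toy_model` below is a hypothetical stand-in for the Lux network used in this tutorial, not part of it), the loss takes a `(features, labels)` batch tuple and reduces the squared error with `mean`:

```julia
using Statistics

# Hypothetical stand-in for the Lux model: collapses each 2-D point to a single
# score and follows the same `(output, state)` return convention.
toy_model(x, ps, st) = (sum(x; dims = 1), st)

loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2)

x = randn(Float32, 2, 8)   # a batch of eight 2-D points
y = ones(Float32, 1, 8)    # their (made-up) labels
loss_node(toy_model, (x, y), nothing, nothing)
```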
```@example augneuralode -loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2) +loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2) ``` #### Dataset @@ -248,7 +248,7 @@ Additionally, we define a callback function which displays the total loss at spe ```@example augneuralode iter = 0 -cb = function (ps, l) +cb = function (state, l) global iter iter += 1 if iter % 10 == 0 @@ -276,10 +276,10 @@ for `20` epochs. model, ps, st = construct_model(1, 2, 64, 0) optfunc = OptimizationFunction( - (x, p, data, target) -> loss_node(model, data, target, x, st), + (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) -optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) -res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) +res = solve(optprob, opt; callback = cb, epochs = 100) plot_contour(model, res.u, st) ``` @@ -297,10 +297,10 @@ a function which can be expressed by the neural ode. For more details and proofs model, ps, st = construct_model(1, 2, 64, 1) optfunc = OptimizationFunction( - (x, p, data, target) -> loss_node(model, data, target, x, st), + (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) -optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) -res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) +res = solve(optprob, opt; callback = cb, epochs = 100) plot_contour(model, res.u, st) ``` diff --git a/docs/src/examples/hamiltonian_nn.md b/docs/src/examples/hamiltonian_nn.md index 9c1716bad..87e2750f2 100644 --- a/docs/src/examples/hamiltonian_nn.md +++ b/docs/src/examples/hamiltonian_nn.md @@ -14,7 +14,7 @@ Before getting to the explanation, here's some code to start with. 
We will follo ```@example hamiltonian_cp using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random, - ComponentArrays, Optimization, OptimizationOptimisers, IterTools + ComponentArrays, Optimization, OptimizationOptimisers, MLUtils t = range(0.0f0, 1.0f0; length = 1024) π_32 = Float32(π) @@ -23,37 +23,33 @@ p_t = reshape(cos.(2π_32 * t), 1, :) dqdt = 2π_32 .* p_t dpdt = -2π_32 .* q_t -data = vcat(q_t, p_t) -target = vcat(dqdt, dpdt) +data = cat(q_t, p_t; dims = 1) +target = cat(dqdt, dpdt; dims = 1) B = 256 -NEPOCHS = 100 -dataloader = ncycle( - ((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))), - selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2))))) - for i in 1:(size(data, 2) ÷ B)), - NEPOCHS) - -hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote()) +NEPOCHS = 500 +dataloader = DataLoader((data, target); batchsize = B) + +hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote()) ps, st = Lux.setup(Xoshiro(0), hnn) ps_c = ps |> ComponentArray opt = OptimizationOptimisers.Adam(0.01f0) -function loss_function(ps, data, target) +function loss_function(ps, databatch) + data, target = databatch pred, st_ = hnn(data, ps, st) - return mean(abs2, pred .- target), pred + return mean(abs2, pred .- target) end -function callback(ps, loss, pred) +function callback(state, loss) println("[Hamiltonian NN] Loss: ", loss) return false end -opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target), - Optimization.AutoForwardDiff()) -opt_prob = OptimizationProblem(opt_func, ps_c) +opt_func = OptimizationFunction(loss_function, Optimization.AutoForwardDiff()) +opt_prob = OptimizationProblem(opt_func, ps_c, dataloader) -res = Optimization.solve(opt_prob, opt, dataloader; callback) +res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS) ps_trained = res.u @@ -75,7 +71,7 @@ The HNN predicts the gradients ``(\dot q, \dot p)`` given ``(q, p)``. Hence, we ```@example hamiltonian using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random, - ComponentArrays, Optimization, OptimizationOptimisers, IterTools + ComponentArrays, Optimization, OptimizationOptimisers, MLUtils t = range(0.0f0, 1.0f0; length = 1024) π_32 = Float32(π) @@ -87,12 +83,8 @@ dpdt = -2π_32 .* q_t data = cat(q_t, p_t; dims = 1) target = cat(dqdt, dpdt; dims = 1) B = 256 -NEPOCHS = 100 -dataloader = ncycle( - ((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))), - selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2))))) - for i in 1:(size(data, 2) ÷ B)), - NEPOCHS) +NEPOCHS = 500 +dataloader = DataLoader((data, target); batchsize = B) ``` ### Training the HamiltonianNN @@ -100,27 +92,28 @@ dataloader = ncycle( We parameterize the with a small MultiLayered Perceptron. HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ForwardDiff in the training loop to compute the gradients of the HNN Layer for Optimization. 
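For intuition about what the layer is asked to learn, here is a toy snippet with a closed-form Hamiltonian (not the neural network trained below): the targets are the symplectic gradient ``(\dot q, \dot p) = (\partial H / \partial p, -\partial H / \partial q)``.

```julia
using Zygote

# Harmonic oscillator: H(q, p) = (q² + p²) / 2, so (q̇, ṗ) = (p, -q).
H(q, p) = (q^2 + p^2) / 2

function hamiltonian_field(q, p)
    dHdq, dHdp = Zygote.gradient(H, q, p)
    return dHdp, -dHdq
end

hamiltonian_field(1.0f0, 0.5f0)   # returns (0.5f0, -1.0f0)
```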
```@example hamiltonian -hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote()) +hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote()) ps, st = Lux.setup(Xoshiro(0), hnn) ps_c = ps |> ComponentArray +hnn_stateful = StatefulLuxLayer{true}(hnn, ps_c, st) -opt = OptimizationOptimisers.Adam(0.01f0) +opt = OptimizationOptimisers.Adam(0.005f0) -function loss_function(ps, data, target) - pred, st_ = hnn(data, ps, st) - return mean(abs2, pred .- target), pred +function loss_function(ps, databatch) + (data, target) = databatch + pred = hnn_stateful(data, ps) + return mean(abs2, pred .- target) end -function callback(ps, loss, pred) +function callback(state, loss) println("[Hamiltonian NN] Loss: ", loss) return false end -opt_func = OptimizationFunction( - (ps, _, data, target) -> loss_function(ps, data, target), Optimization.AutoZygote()) -opt_prob = OptimizationProblem(opt_func, ps_c) +opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps_c, dataloader) -res = solve(opt_prob, opt, dataloader; callback) +res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS) ps_trained = res.u ``` diff --git a/docs/src/examples/mnist_conv_neural_ode.md b/docs/src/examples/mnist_conv_neural_ode.md index f67262780..80f68bfe6 100644 --- a/docs/src/examples/mnist_conv_neural_ode.md +++ b/docs/src/examples/mnist_conv_neural_ode.md @@ -89,30 +89,30 @@ end # burn in accuracy accuracy(m, ((img, lab),), ps, st) -function loss_function(ps, x, y) +function loss_function(ps, data) + (x, y) = data pred, _ = m(x, ps, st) - return logitcrossentropy(pred, y), pred + return logitcrossentropy(pred, y) end # burn in loss -loss_function(ps, img, lab) +loss_function(ps, (img, lab)) opt = OptimizationOptimisers.Adam(0.005) iter = 0 -opt_func = OptimizationFunction( - (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote()) -opt_prob = OptimizationProblem(opt_func, ps); +opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps, dataloader); -function callback(ps, l, pred) +function callback(state, l) global iter += 1 iter % 10 == 0 && - @info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))" + @info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))" return false end # Train the NN-ODE and monitor the loss and weights. 
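# With the dataloader attached to the `OptimizationProblem`, `solve` makes one pass
# over it per epoch, handing each `(img, lab)` batch to `loss_function` as its
# second argument; the callback receives an optimization state whose `u` field
# holds the current parameters.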
-res = Optimization.solve(opt_prob, opt, dataloader; maxiters = 5, callback) +res = Optimization.solve(opt_prob, opt; epochs = 5, callback) acc = accuracy(m, dataloader, res.u, st) acc # hide ``` diff --git a/docs/src/examples/mnist_neural_ode.md b/docs/src/examples/mnist_neural_ode.md index 349dbcff2..158cc7a67 100644 --- a/docs/src/examples/mnist_neural_ode.md +++ b/docs/src/examples/mnist_neural_ode.md @@ -81,29 +81,29 @@ end accuracy(m, ((x_train1, y_train1),), ps, st) # burn in accuracy -function loss_function(ps, x, y) +function loss_function(ps, data) + (x, y) = data pred, st_ = m(x, ps, st) - return logitcrossentropy(pred, y), pred + return logitcrossentropy(pred, y) end -loss_function(ps, x_train1, y_train1) # burn in loss +loss_function(ps, (x_train1, y_train1)) # burn in loss opt = OptimizationOptimisers.Adam(0.05) iter = 0 -opt_func = OptimizationFunction( - (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote()) -opt_prob = OptimizationProblem(opt_func, ps) +opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps, dataloader) -function callback(ps, l, pred) +function callback(state, l) global iter += 1 iter % 10 == 0 && - @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))" + @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))" return false end # Train the NN-ODE and monitor the loss and weights. -res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5) +res = Optimization.solve(opt_prob, opt; callback, epochs = 5) accuracy(m, dataloader, res.u, st) ``` @@ -285,12 +285,13 @@ final output of our model. `logitcrossentropy` takes in the prediction from our model `model(x)` and compares it to actual output `y`: ```@example mnist -function loss_function(ps, x, y) +function loss_function(ps, data) + (x, y) = data pred, st_ = m(x, ps, st) - return logitcrossentropy(pred, y), pred + return logitcrossentropy(pred, y) end -loss_function(ps, x_train1, y_train1) # burn in loss +loss_function(ps, (x_train1, y_train1)) # burn in loss ``` #### Optimizer @@ -309,14 +310,13 @@ This callback function is used to print both the training and testing accuracy a ```@example mnist iter = 0 -opt_func = OptimizationFunction( - (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote()) -opt_prob = OptimizationProblem(opt_func, ps) +opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps, dataloader) -function callback(ps, l, pred) +function callback(state, l) global iter += 1 iter % 10 == 0 && - @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))" + @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))" return false end ``` @@ -329,6 +329,6 @@ for Neural ODE is given by `nn_ode.p`: ```@example mnist # Train the NN-ODE and monitor the loss and weights. 
-res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5) +res = Optimization.solve(opt_prob, opt; callback, epochs = 5) accuracy(m, dataloader, res.u, st) ``` diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md index 2b20219f5..6d08f5eed 100644 --- a/docs/src/examples/multiple_shooting.md +++ b/docs/src/examples/multiple_shooting.md @@ -48,9 +48,32 @@ ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) nn = Chain(x -> x .^ 3, Dense(2, 16, tanh), Dense(16, 2)) p_init, st = Lux.setup(rng, nn) +ps = ComponentArray(p_init) +pd, pax = getdata(ps), getaxes(ps) + neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps) prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan, ComponentArray(p_init)) +# Define parameters for Multiple Shooting +group_size = 3 +continuity_term = 200 + +function loss_function(data, pred) + return sum(abs2, data - pred) +end + +l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, + Tsit5(), group_size; continuity_term) + +function loss_multiple_shooting(p) + ps = ComponentArray(p, pax) + + loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, + Tsit5(), group_size; continuity_term) + global preds = currpred + return loss +end + function plot_multiple_shoot(plt, preds, group_size) step = group_size - 1 ranges = group_ranges(datasize, group_size) @@ -62,15 +85,17 @@ end anim = Plots.Animation() iter = 0 -callback = function (p, l, preds; doplot = true) +function callback(state, l; doplot = true, prob_node = prob_node) display(l) global iter iter += 1 if doplot && iter % 1 == 0 # plot the original data plt = scatter(tsteps, ode_data[1, :]; label = "Data") - # plot the different predictions for individual shoot + l1, preds = multiple_shoot( + ComponentArray(state.u, pax), ode_data, tsteps, prob_node, loss_function, + Tsit5(), group_size; continuity_term) plot_multiple_shoot(plt, preds, group_size) frame(anim) @@ -79,27 +104,10 @@ callback = function (p, l, preds; doplot = true) return false end -# Define parameters for Multiple Shooting -group_size = 3 -continuity_term = 200 - -function loss_function(data, pred) - return sum(abs2, data - pred) -end - -ps = ComponentArray(p_init) -pd, pax = getdata(ps), getaxes(ps) - -function loss_multiple_shooting(p) - ps = ComponentArray(p, pax) - return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, - Tsit5(), group_size; continuity_term) -end - adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype) optprob = Optimization.OptimizationProblem(optf, pd) -res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback) +res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 300) gif(anim, "multiple_shooting.gif"; fps = 15) ``` @@ -119,14 +127,16 @@ pd, pax = getdata(ps), getaxes(ps) function loss_single_shooting(p) ps = ComponentArray(p, pax) - return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, + loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size; continuity_term) + global preds = currpred + return loss end adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_single_shooting(x), adtype) optprob = Optimization.OptimizationProblem(optf, pd) -res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback) +res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 300) 
gif(anim, "single_shooting.gif"; fps = 15) ``` diff --git a/docs/src/examples/neural_gde.md b/docs/src/examples/neural_gde.md index e50c70245..ce0e2c730 100644 --- a/docs/src/examples/neural_gde.md +++ b/docs/src/examples/neural_gde.md @@ -4,7 +4,7 @@ This tutorial has not been ran or updated in awhile. -This tutorial has been adapted from [here](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/examples/neural_ode_cora.jl). +This tutorial has been adapted from [here](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/GraphNeuralNetworks/examples/neural_ode_cora.jl). In this tutorial, we will use Graph Differential Equations (GDEs) to perform classification on the [CORA Dataset](https://paperswithcode.com/dataset/cora). We shall be using the Graph Neural Networks primitives from the package [GraphNeuralNetworks](https://github.com/CarloLucibello/GraphNeuralNetworks.jl). diff --git a/docs/src/examples/neural_ode.md b/docs/src/examples/neural_ode.md index c1dda9248..699eea94f 100644 --- a/docs/src/examples/neural_ode.md +++ b/docs/src/examples/neural_ode.md @@ -40,15 +40,16 @@ end function loss_neuralode(p) pred = predict_neuralode(p) loss = sum(abs2, ode_data .- pred) - return loss, pred + return loss end # Do not plot by default for the documentation # Users should change doplot=true to see the plots callbacks -callback = function (p, l, pred; doplot = false) +function callback(state, l; doplot = false) println(l) # plot current prediction against data if doplot + pred = predict_neuralode(state.u) plt = scatter(tsteps, ode_data[1, :]; label = "data") scatter!(plt, tsteps, pred[1, :]; label = "prediction") display(plot(plt)) @@ -57,7 +58,7 @@ callback = function (p, l, pred; doplot = false) end pinit = ComponentArray(p) -callback(pinit, loss_neuralode(pinit)...; doplot = true) +callback((; u = pinit), loss_neuralode(pinit); doplot = true) # use Optimization.jl to solve the problem adtype = Optimization.AutoZygote() @@ -73,7 +74,7 @@ optprob2 = remake(optprob; u0 = result_neuralode.u) result_neuralode2 = Optimization.solve( optprob2, Optim.BFGS(; initial_stepnorm = 0.01); callback, allow_f_increases = false) -callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true) +callback((; u = result_neuralode2.u), loss_neuralode(result_neuralode2.u); doplot = true) ``` ![Neural ODE](https://user-images.githubusercontent.com/1814174/88589293-e8207f80-d026-11ea-86e2-8a3feb8252ca.gif) @@ -134,7 +135,7 @@ end function loss_neuralode(p) pred = predict_neuralode(p) loss = sum(abs2, ode_data .- pred) - return loss, pred + return loss end ``` @@ -143,10 +144,11 @@ it would show every step and overflow the documentation, but for your use case s ```@example neuralode # Callback function to observe training -callback = function (p, l, pred; doplot = false) +callback = function (state, l; doplot = false) println(l) # plot current prediction against data if doplot + pred = predict_neuralode(state.u) plt = scatter(tsteps, ode_data[1, :]; label = "data") scatter!(plt, tsteps, pred[1, :]; label = "prediction") display(plot(plt)) @@ -155,7 +157,7 @@ callback = function (p, l, pred; doplot = false) end pinit = ComponentArray(p) -callback(pinit, loss_neuralode(pinit)...) +callback((; u = pinit), loss_neuralode(pinit)) ``` We then train the neural network to learn the ODE. 
@@ -198,8 +200,8 @@ result_neuralode2 = Optimization.solve(optprob2, Optim.BFGS(; initial_stepnorm = And then we use the callback with `doplot=true` to see the final plot: ```@example neuralode -callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true) +callback((; u = result_neuralode2.u), loss_neuralode(result_neuralode2.u); doplot = true) plt = scatter(tsteps, ode_data[1, :]; label = "data") # hide -scatter!(plt, tsteps, loss_neuralode(result_neuralode2.u)[2][1, :]; label = "prediction") # hide +scatter!(plt, tsteps, predict_neuralode(result_neuralode2.u)[1, :]; label = "prediction") # hide plt # hide ``` diff --git a/docs/src/examples/neural_ode_weather_forecast.md b/docs/src/examples/neural_ode_weather_forecast.md index 71553a504..8f69acfe5 100644 --- a/docs/src/examples/neural_ode_weather_forecast.md +++ b/docs/src/examples/neural_ode_weather_forecast.md @@ -122,8 +122,8 @@ function train_one_round(node, p, state, y, opt, maxiters, rng, y0 = y[:, 1]; kw end function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...) - log_results(ps, losses) = (p, loss) -> begin - push!(ps, copy(p.u)) + log_results(ps, losses) = (state, loss) -> begin + push!(ps, copy(state.u)) push!(losses, loss) false end diff --git a/docs/src/examples/neural_sde.md b/docs/src/examples/neural_sde.md index 88eddd091..555b1453f 100644 --- a/docs/src/examples/neural_sde.md +++ b/docs/src/examples/neural_sde.md @@ -86,8 +86,8 @@ Let's see what that looks like: # Get the prediction using the correct initial condition prediction0 = neuralsde(u0, ps, st)[1] -drift_model = StatefulLuxLayer{true}(drift_dudt, nothing, st.drift) -diffusion_model = StatefulLuxLayer{true}(diffusion_dudt, nothing, st.diffusion) +drift_model = StatefulLuxLayer{true}(drift_dudt, ps.drift, st.drift) +diffusion_model = StatefulLuxLayer{true}(diffusion_dudt, ps.diffusion, st.diffusion) drift_(u, p, t) = drift_model(u, p.drift) diffusion_(u, p, t) = diffusion_model(u, p.diffusion) @@ -110,7 +110,7 @@ mean and variance from `n` runs at each time point and uses the distance from the data values: ```@example nsde -neuralsde_model = StatefulLuxLayer{true}(neuralsde, nothing, st) +neuralsde_model = StatefulLuxLayer{true}(neuralsde, ps, st) function predict_neuralsde(p, u = u0) return Array(neuralsde_model(u, p)) @@ -119,21 +119,28 @@ end function loss_neuralsde(p; n = 100) u = repeat(reshape(u0, :, 1), 1, n) samples = predict_neuralsde(p, u) - means = mean(samples; dims = 2) - vars = var(samples; dims = 2, mean = means)[:, 1, :] - means = means[:, 1, :] - loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars) - return loss, means, vars + currmeans = mean(samples; dims = 2) + currvars = var(samples; dims = 2, mean = currmeans)[:, 1, :] + currmeans = currmeans[:, 1, :] + loss = sum(abs2, sde_data - currmeans) + sum(abs2, sde_data_vars - currvars) + global means = currmeans + global vars = currvars + return loss end ``` ```@example nsde list_plots = [] iter = 0 +u = repeat(reshape(u0, :, 1), 1, 100) +samples = predict_neuralsde(ps, u) +means = mean(samples; dims = 2) +vars = var(samples; dims = 2, mean = means)[:, 1, :] +means = means[:, 1, :] # Callback function to observe training -callback = function (p, loss, means, vars; doplot = false) - global list_plots, iter +callback = function (state, loss; doplot = false) + global list_plots, iter, means, vars if iter == 0 list_plots = [] @@ -174,15 +181,21 @@ We resume the training with a larger `n`. 
(WARNING - this step is a couple of orders of magnitude longer than the previous one). ```@example nsde +opt = OptimizationOptimisers.Adam(0.001) optf2 = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x; n = 100), adtype) optprob2 = Optimization.OptimizationProblem(optf2, result1.u) -result2 = Optimization.solve(optprob2, opt; callback, maxiters = 20) +result2 = Optimization.solve(optprob2, opt; callback, maxiters = 100) ``` And now we plot the solution to an ensemble of the trained neural SDE: ```@example nsde -_, means, vars = loss_neuralsde(result2.u; n = 1000) +n = 1000 +u = repeat(reshape(u0, :, 1), 1, n) +samples = predict_neuralsde(result2.u) +currmeans = mean(samples; dims = 2) +currvars = var(samples; dims = 2, mean = currmeans)[:, 1, :] +currmeans = currmeans[:, 1, :] plt2 = Plots.scatter(tsteps, sde_data'; yerror = sde_data_vars', label = "data", title = "Neural SDE: After Training", xlabel = "Time") diff --git a/docs/src/examples/normalizing_flows.md b/docs/src/examples/normalizing_flows.md index 6d4898ae3..aa88af6cd 100644 --- a/docs/src/examples/normalizing_flows.md +++ b/docs/src/examples/normalizing_flows.md @@ -28,7 +28,7 @@ function loss(θ) return -mean(logpx) end -function cb(p, l) +function cb(state, l) @info "FFJORD Training" loss=l return false end @@ -95,7 +95,7 @@ function loss(θ) return -mean(logpx) end -function cb(p, l) +function cb(state, l) @info "FFJORD Training" loss=loss(p) return false end diff --git a/docs/src/examples/physical_constraints.md b/docs/src/examples/physical_constraints.md index 723344619..430ec4e1e 100644 --- a/docs/src/examples/physical_constraints.md +++ b/docs/src/examples/physical_constraints.md @@ -50,7 +50,7 @@ end function loss_stiff_ndae(p) pred = predict_stiff_ndae(p) loss = sum(abs2, Array(sol_stiff) .- pred) - return loss, pred + return loss end # callback = function (state, l, pred) #callback function to observe training @@ -172,7 +172,7 @@ from these predictions. In this case, we use **least squares** as our loss. function loss_stiff_ndae(p) pred = predict_stiff_ndae(p) loss = sum(abs2, sol_stiff .- pred) - return loss, pred + return loss end l1 = first(loss_stiff_ndae(ComponentArray(pinit))) @@ -191,7 +191,7 @@ The optimizer is `BFGS`(see below). The callback function displays the loss during training. ```@example dae2 -callback = function (state, l, pred) #callback function to observe training +callback = function (state, l) #callback function to observe training display(l) return false end diff --git a/test/neural_ode_mm_tests.jl b/test/neural_ode_mm_tests.jl index 964ea42e9..0a1acc83a 100644 --- a/test/neural_ode_mm_tests.jl +++ b/test/neural_ode_mm_tests.jl @@ -32,10 +32,10 @@ function loss(p) pred = first(ndae(u₀, p, st)) loss = sum(abs2, Array(sol) .- pred) - return loss, pred + return loss end - cb = function (p, l, pred) + cb = function (state, l) @info "[NeuralODEMM] Loss: $l" return false end
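For reference, the calling convention these examples share can be exercised on its own; in this minimal sketch the linear-regression data, the `w`/`b` parameters, and the hyperparameters are placeholders rather than anything from the package:

```julia
using Optimization, OptimizationOptimisers, ComponentArrays, MLUtils, Statistics

# Synthetic data for y ≈ 2x + 1.
x = randn(Float32, 1, 128)
y = 2.0f0 .* x .+ 1.0f0
dataloader = DataLoader((x, y); batchsize = 32)

# The objective takes the parameters and one (features, targets) batch.
loss(ps, batch) = mean(abs2, ps.w .* batch[1] .+ ps.b .- batch[2])

# The callback takes an optimization state; the current parameters are `state.u`.
callback(state, l) = (println("loss = ", l); false)

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, ComponentArray(; w = 0.0f0, b = 0.0f0), dataloader)
res = solve(prob, OptimizationOptimisers.Adam(0.05); epochs = 10, callback)
```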