From ed0cd77cbb3131dd53a9e5f498e38d87c82dbf5b Mon Sep 17 00:00:00 2001
From: Hossein Pourbozorg
Date: Sun, 28 Nov 2021 03:49:31 +0330
Subject: [PATCH] add DistributionsAD

---
 Project.toml   |  2 ++
 src/ICNF.jl    |  2 ++
 src/ffjord.jl  | 12 ++++++------
 src/rnode.jl   | 12 ++++++------
 test/ffjord.jl | 20 ++++++++++----------
 test/rnode.jl  | 20 ++++++++++----------
 6 files changed, 36 insertions(+), 32 deletions(-)

diff --git a/Project.toml b/Project.toml
index 5d1188ab..3613aeaa 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,6 +9,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
@@ -17,6 +18,7 @@ MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/src/ICNF.jl b/src/ICNF.jl
index fc10260d..af8ce31c 100644
--- a/src/ICNF.jl
+++ b/src/ICNF.jl
@@ -6,12 +6,14 @@ module ICNF
     DataFrames,
     DiffEqSensitivity,
     Distributions,
+    DistributionsAD,
     Flux,
     MLJBase,
     MLJFlux,
     MLJModelInterface,
     OrdinaryDiffEq,
     Parameters,
+    SciMLBase,
     ScientificTypes,
     Zygote,
     LinearAlgebra,
diff --git a/src/ffjord.jl b/src/ffjord.jl
index 6b5d00c7..364e2219 100644
--- a/src/ffjord.jl
+++ b/src/ffjord.jl
@@ -13,11 +13,11 @@ Implementations of FFJORD from
     basedist::Distribution = MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars)))
     tspan::Tuple{T, T} = convert.(T, (0, 1))

-    solver_test::OrdinaryDiffEqAlgorithm = default_solver_test
-    solver_train::OrdinaryDiffEqAlgorithm = default_solver_train
+    solver_test::SciMLBase.AbstractODEAlgorithm = default_solver_test
+    solver_train::SciMLBase.AbstractODEAlgorithm = default_solver_train

     sensealg_test::SciMLBase.AbstractSensitivityAlgorithm = default_sensealg
-    sensealg_train::SciMLBase.AbstractSensitivityAlgorithm = sensealg_test
+    sensealg_train::SciMLBase.AbstractSensitivityAlgorithm = default_sensealg

     acceleration::AbstractResource = default_acceleration

@@ -32,11 +32,11 @@ function FFJORD{T}(
         basedist::Distribution=MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars))),
         tspan::Tuple{T, T}=convert.(T, (0, 1)),

-        solver_test::OrdinaryDiffEqAlgorithm=default_solver_test,
-        solver_train::OrdinaryDiffEqAlgorithm=default_solver_train,
+        solver_test::SciMLBase.AbstractODEAlgorithm=default_solver_test,
+        solver_train::SciMLBase.AbstractODEAlgorithm=default_solver_train,

         sensealg_test::SciMLBase.AbstractSensitivityAlgorithm=default_sensealg,
-        sensealg_train::SciMLBase.AbstractSensitivityAlgorithm=sensealg_test,
+        sensealg_train::SciMLBase.AbstractSensitivityAlgorithm=default_sensealg,

         acceleration::AbstractResource=default_acceleration,
 ) where {T <: AbstractFloat}
diff --git a/src/rnode.jl b/src/rnode.jl
index c9fd4ef0..ebec1042 100644
--- a/src/rnode.jl
+++ b/src/rnode.jl
@@ -13,11 +13,11 @@ Implementations of RNODE from
     basedist::Distribution = MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars)))
     tspan::Tuple{T, T} = convert.(T, (0, 1))

-    solver_test::OrdinaryDiffEqAlgorithm = default_solver_test
-    solver_train::OrdinaryDiffEqAlgorithm = default_solver_train
+    solver_test::SciMLBase.AbstractODEAlgorithm = default_solver_test
+    solver_train::SciMLBase.AbstractODEAlgorithm = default_solver_train

     sensealg_test::SciMLBase.AbstractSensitivityAlgorithm = default_sensealg
-    sensealg_train::SciMLBase.AbstractSensitivityAlgorithm = sensealg_test
+    sensealg_train::SciMLBase.AbstractSensitivityAlgorithm = default_sensealg

     acceleration::AbstractResource = default_acceleration

@@ -32,11 +32,11 @@ function RNODE{T}(
         basedist::Distribution=MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars))),
         tspan::Tuple{T, T}=convert.(T, (0, 1)),

-        solver_test::OrdinaryDiffEqAlgorithm=default_solver_test,
-        solver_train::OrdinaryDiffEqAlgorithm=default_solver_train,
+        solver_test::SciMLBase.AbstractODEAlgorithm=default_solver_test,
+        solver_train::SciMLBase.AbstractODEAlgorithm=default_solver_train,

         sensealg_test::SciMLBase.AbstractSensitivityAlgorithm=default_sensealg,
-        sensealg_train::SciMLBase.AbstractSensitivityAlgorithm=sensealg_test,
+        sensealg_train::SciMLBase.AbstractSensitivityAlgorithm=default_sensealg,

         acceleration::AbstractResource=default_acceleration,
 ) where {T <: AbstractFloat}
diff --git a/test/ffjord.jl b/test/ffjord.jl
index 91954f47..43230917 100644
--- a/test/ffjord.jl
+++ b/test/ffjord.jl
@@ -3,28 +3,28 @@
     cr in [CPU1(), CUDALibs()],
     tp in [Float64, Float32, Float16],
     nvars in 1:3

-    ffjord = FFJORD{tp}(Dense(nvars, nvars), nvars; acceleration=cr)
-    ufd = copy(ffjord.p)
+    icnf = FFJORD{tp}(Dense(nvars, nvars), nvars; acceleration=cr)
+    ufd = copy(icnf.p)
     n = 8
     r = rand(tp, nvars, n)

-    @test !isnothing(inference(ffjord, TestMode(), r))
-    @test !isnothing(inference(ffjord, TrainMode(), r))
+    @test !isnothing(inference(icnf, TestMode(), r))
+    @test !isnothing(inference(icnf, TrainMode(), r))

-    @test !isnothing(generate(ffjord, TestMode(), n))
-    @test !isnothing(generate(ffjord, TrainMode(), n))
+    @test !isnothing(generate(icnf, TestMode(), n))
+    @test !isnothing(generate(icnf, TrainMode(), n))

-    @test !isnothing(ffjord(r))
-    @test !isnothing(loss_f(ffjord)(r))
+    @test !isnothing(icnf(r))
+    @test !isnothing(loss_f(icnf)(r))

-    d = ICNFDistribution(; m=ffjord)
+    d = ICNFDistribution(; m=icnf)
     @test !isnothing(logpdf(d, r))
     @test !isnothing(pdf(d, r))
     @test !isnothing(rand(d, n))

     df = DataFrame(r', :auto)
-    model = ICNFModel(; m=ffjord, n_epochs=8)
+    model = ICNFModel(; m=icnf, n_epochs=8)
     mach = machine(model, df)
     fit!(mach)
     fd = MLJBase.fitted_params(mach).learned_parameters
diff --git a/test/rnode.jl b/test/rnode.jl
index 32e8d14e..e199d58e 100644
--- a/test/rnode.jl
+++ b/test/rnode.jl
@@ -3,28 +3,28 @@
     cr in [CPU1(), CUDALibs()],
     tp in [Float64, Float32, Float16],
     nvars in 1:3

-    rnode = RNODE{tp}(Dense(nvars, nvars), nvars; acceleration=cr)
-    ufd = copy(rnode.p)
+    icnf = RNODE{tp}(Dense(nvars, nvars), nvars; acceleration=cr)
+    ufd = copy(icnf.p)
     n = 8
     r = rand(tp, nvars, n)

-    @test !isnothing(inference(rnode, TestMode(), r))
-    @test !isnothing(inference(rnode, TrainMode(), r))
+    @test !isnothing(inference(icnf, TestMode(), r))
+    @test !isnothing(inference(icnf, TrainMode(), r))

-    @test !isnothing(generate(rnode, TestMode(), n))
-    @test !isnothing(generate(rnode, TrainMode(), n))
+    @test !isnothing(generate(icnf, TestMode(), n))
+    @test !isnothing(generate(icnf, TrainMode(), n))

-    @test !isnothing(rnode(r))
-    @test !isnothing(loss_f(rnode)(r))
+    @test !isnothing(icnf(r))
+    @test !isnothing(loss_f(icnf)(r))

-    d = ICNFDistribution(; m=rnode)
+    d = ICNFDistribution(; m=icnf)
     @test !isnothing(logpdf(d, r))
     @test !isnothing(pdf(d, r))
     @test !isnothing(rand(d, n))

     df = DataFrame(r', :auto)
-    model = ICNFModel(; m=rnode, n_epochs=8)
+    model = ICNFModel(; m=icnf, n_epochs=8)
     mach = machine(model, df)
     fit!(mach)
     fd = MLJBase.fitted_params(mach).learned_parameters
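
For reference, below is a minimal usage sketch of the API exercised by the test changes in this patch. It assumes FFJORD, inference, TestMode, ICNFDistribution, and ICNFModel are exported by ICNF.jl exactly as used in test/ffjord.jl; the network size, batch size, and epoch count are illustrative only.

    using ICNF, Flux, DataFrames, MLJBase, ComputationalResources, Distributions

    nvars = 3
    # Build an FFJORD flow over `nvars` variables with a small Flux network on the CPU.
    icnf = FFJORD{Float32}(Dense(nvars, nvars), nvars; acceleration=CPU1())

    r = rand(Float32, nvars, 8)                 # a batch of 8 samples
    inference(icnf, TestMode(), r)              # run the flow in test mode

    d = ICNFDistribution(; m=icnf)              # expose the flow as a distribution
    logpdf(d, r)

    df = DataFrame(r', :auto)
    model = ICNFModel(; m=icnf, n_epochs=8)     # MLJ wrapper, as in the tests
    mach = machine(model, df)
    fit!(mach)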