Skip to content

Commit

Permalink
Merge pull request #24 from prbzrg/add-d-ad
Browse files Browse the repository at this point in the history
add DistributionsAD
  • Loading branch information
prbzrg authored Nov 28, 2021
2 parents 3cf0b39 + ed0cd77 commit 2456f74
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 32 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Expand All @@ -17,6 +18,7 @@ MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand Down
2 changes: 2 additions & 0 deletions src/ICNF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ module ICNF
DataFrames,
DiffEqSensitivity,
Distributions,
DistributionsAD,
Flux,
MLJBase,
MLJFlux,
MLJModelInterface,
OrdinaryDiffEq,
Parameters,
SciMLBase,
ScientificTypes,
Zygote,
LinearAlgebra,
Expand Down
12 changes: 6 additions & 6 deletions src/ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ Implementations of FFJORD from
basedist::Distribution = MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars)))
tspan::Tuple{T, T} = convert.(T, (0, 1))

solver_test::OrdinaryDiffEqAlgorithm = default_solver_test
solver_train::OrdinaryDiffEqAlgorithm = default_solver_train
solver_test::SciMLBase.AbstractODEAlgorithm = default_solver_test
solver_train::SciMLBase.AbstractODEAlgorithm = default_solver_train

sensealg_test::SciMLBase.AbstractSensitivityAlgorithm = default_sensealg
sensealg_train::SciMLBase.AbstractSensitivityAlgorithm = sensealg_test
sensealg_train::SciMLBase.AbstractSensitivityAlgorithm = default_sensealg

acceleration::AbstractResource = default_acceleration

Expand All @@ -32,11 +32,11 @@ function FFJORD{T}(
basedist::Distribution=MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars))),
tspan::Tuple{T, T}=convert.(T, (0, 1)),

solver_test::OrdinaryDiffEqAlgorithm=default_solver_test,
solver_train::OrdinaryDiffEqAlgorithm=default_solver_train,
solver_test::SciMLBase.AbstractODEAlgorithm=default_solver_test,
solver_train::SciMLBase.AbstractODEAlgorithm=default_solver_train,

sensealg_test::SciMLBase.AbstractSensitivityAlgorithm=default_sensealg,
sensealg_train::SciMLBase.AbstractSensitivityAlgorithm=sensealg_test,
sensealg_train::SciMLBase.AbstractSensitivityAlgorithm=default_sensealg,

acceleration::AbstractResource=default_acceleration,
) where {T <: AbstractFloat}
Expand Down
12 changes: 6 additions & 6 deletions src/rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ Implementations of RNODE from
basedist::Distribution = MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars)))
tspan::Tuple{T, T} = convert.(T, (0, 1))

solver_test::OrdinaryDiffEqAlgorithm = default_solver_test
solver_train::OrdinaryDiffEqAlgorithm = default_solver_train
solver_test::SciMLBase.AbstractODEAlgorithm = default_solver_test
solver_train::SciMLBase.AbstractODEAlgorithm = default_solver_train

sensealg_test::SciMLBase.AbstractSensitivityAlgorithm = default_sensealg
sensealg_train::SciMLBase.AbstractSensitivityAlgorithm = sensealg_test
sensealg_train::SciMLBase.AbstractSensitivityAlgorithm = default_sensealg

acceleration::AbstractResource = default_acceleration

Expand All @@ -32,11 +32,11 @@ function RNODE{T}(
basedist::Distribution=MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars))),
tspan::Tuple{T, T}=convert.(T, (0, 1)),

solver_test::OrdinaryDiffEqAlgorithm=default_solver_test,
solver_train::OrdinaryDiffEqAlgorithm=default_solver_train,
solver_test::SciMLBase.AbstractODEAlgorithm=default_solver_test,
solver_train::SciMLBase.AbstractODEAlgorithm=default_solver_train,

sensealg_test::SciMLBase.AbstractSensitivityAlgorithm=default_sensealg,
sensealg_train::SciMLBase.AbstractSensitivityAlgorithm=sensealg_test,
sensealg_train::SciMLBase.AbstractSensitivityAlgorithm=default_sensealg,

acceleration::AbstractResource=default_acceleration,
) where {T <: AbstractFloat}
Expand Down
20 changes: 10 additions & 10 deletions test/ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,28 @@
cr in [CPU1(), CUDALibs()],
tp in [Float64, Float32, Float16],
nvars in 1:3
ffjord = FFJORD{tp}(Dense(nvars, nvars), nvars; acceleration=cr)
ufd = copy(ffjord.p)
icnf = FFJORD{tp}(Dense(nvars, nvars), nvars; acceleration=cr)
ufd = copy(icnf.p)
n = 8
r = rand(tp, nvars, n)

@test !isnothing(inference(ffjord, TestMode(), r))
@test !isnothing(inference(ffjord, TrainMode(), r))
@test !isnothing(inference(icnf, TestMode(), r))
@test !isnothing(inference(icnf, TrainMode(), r))

@test !isnothing(generate(ffjord, TestMode(), n))
@test !isnothing(generate(ffjord, TrainMode(), n))
@test !isnothing(generate(icnf, TestMode(), n))
@test !isnothing(generate(icnf, TrainMode(), n))

@test !isnothing(ffjord(r))
@test !isnothing(loss_f(ffjord)(r))
@test !isnothing(icnf(r))
@test !isnothing(loss_f(icnf)(r))

d = ICNFDistribution(; m=ffjord)
d = ICNFDistribution(; m=icnf)

@test !isnothing(logpdf(d, r))
@test !isnothing(pdf(d, r))
@test !isnothing(rand(d, n))

df = DataFrame(r', :auto)
model = ICNFModel(; m=ffjord, n_epochs=8)
model = ICNFModel(; m=icnf, n_epochs=8)
mach = machine(model, df)
fit!(mach)
fd = MLJBase.fitted_params(mach).learned_parameters
Expand Down
20 changes: 10 additions & 10 deletions test/rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,28 @@
cr in [CPU1(), CUDALibs()],
tp in [Float64, Float32, Float16],
nvars in 1:3
rnode = RNODE{tp}(Dense(nvars, nvars), nvars; acceleration=cr)
ufd = copy(rnode.p)
icnf = RNODE{tp}(Dense(nvars, nvars), nvars; acceleration=cr)
ufd = copy(icnf.p)
n = 8
r = rand(tp, nvars, n)

@test !isnothing(inference(rnode, TestMode(), r))
@test !isnothing(inference(rnode, TrainMode(), r))
@test !isnothing(inference(icnf, TestMode(), r))
@test !isnothing(inference(icnf, TrainMode(), r))

@test !isnothing(generate(rnode, TestMode(), n))
@test !isnothing(generate(rnode, TrainMode(), n))
@test !isnothing(generate(icnf, TestMode(), n))
@test !isnothing(generate(icnf, TrainMode(), n))

@test !isnothing(rnode(r))
@test !isnothing(loss_f(rnode)(r))
@test !isnothing(icnf(r))
@test !isnothing(loss_f(icnf)(r))

d = ICNFDistribution(; m=rnode)
d = ICNFDistribution(; m=icnf)

@test !isnothing(logpdf(d, r))
@test !isnothing(pdf(d, r))
@test !isnothing(rand(d, n))

df = DataFrame(r', :auto)
model = ICNFModel(; m=rnode, n_epochs=8)
model = ICNFModel(; m=icnf, n_epochs=8)
mach = machine(model, df)
fit!(mach)
fd = MLJBase.fitted_params(mach).learned_parameters
Expand Down

0 comments on commit 2456f74

Please sign in to comment.