Skip to content

Commit

Permalink
fix tspan (#68)
Browse files Browse the repository at this point in the history
* fix tspan

* add ODEFunction

* fix type unstable nn
  • Loading branch information
prbzrg authored Jun 21, 2022
1 parent 7735e04 commit ff81582
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 24 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Hossein Pourbozorg <prbzrg@gmail.com> and contributors"]
version = "0.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Expand Down Expand Up @@ -33,6 +34,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Adapt = "3.3"
CUDA = "3.7"
ComputationalResources = "0.3"
DataFrames = "1.2"
Expand Down
1 change: 1 addition & 0 deletions src/ICNF.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ICNF

using
Adapt,
CUDA,
ComputationalResources,
DataFrames,
Expand Down
13 changes: 9 additions & 4 deletions src/cond_ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ function CondFFJORD{T}(
acceleration::AbstractResource=default_acceleration,
) where {T <: AbstractFloat}
array_mover = make_mover(acceleration, T)
nn = fmap(x -> adapt(T, x), nn)
p, re = destructure(nn)
CondFFJORD{T}(
re, p |> array_mover, nvars, basedist, tspan,
Expand Down Expand Up @@ -87,7 +88,8 @@ function inference(icnf::CondFFJORD{T}, mode::TestMode, xs::AbstractMatrix{T}, y
ys = ys |> icnf.array_mover
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -101,7 +103,8 @@ function inference(icnf::CondFFJORD{T}, mode::TrainMode, xs::AbstractMatrix{T},
ys = ys |> icnf.array_mover
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys, size(xs); rng)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -116,7 +119,8 @@ function generate(icnf::CondFFJORD{T}, mode::TestMode, ys::AbstractMatrix{T}, n:
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys)
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -129,7 +133,8 @@ function generate(icnf::CondFFJORD{T}, mode::TrainMode, ys::AbstractMatrix{T}, n
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys, size(new_xs))
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand Down
13 changes: 9 additions & 4 deletions src/cond_planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ function CondPlanar{T}(
acceleration::AbstractResource=default_acceleration,
) where {T <: AbstractFloat}
array_mover = make_mover(acceleration, T)
nn = fmap(x -> adapt(T, x), nn)
p, re = destructure(nn)
CondPlanar{T}(
re, p |> array_mover, nvars, basedist, tspan,
Expand Down Expand Up @@ -72,7 +73,8 @@ function inference(icnf::CondPlanar{T}, mode::TestMode, xs::AbstractMatrix{T}, y
ys = ys |> icnf.array_mover
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, ys, size(xs))
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -86,7 +88,8 @@ function inference(icnf::CondPlanar{T}, mode::TrainMode, xs::AbstractMatrix{T},
ys = ys |> icnf.array_mover
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, ys, size(xs))
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -101,7 +104,8 @@ function generate(icnf::CondPlanar{T}, mode::TestMode, ys::AbstractMatrix{T}, n:
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, ys, size(new_xs))
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -114,7 +118,8 @@ function generate(icnf::CondPlanar{T}, mode::TrainMode, ys::AbstractMatrix{T}, n
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, ys, size(new_xs))
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand Down
13 changes: 9 additions & 4 deletions src/cond_rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ function CondRNODE{T}(
acceleration::AbstractResource=default_acceleration,
) where {T <: AbstractFloat}
array_mover = make_mover(acceleration, T)
nn = fmap(x -> adapt(T, x), nn)
p, re = destructure(nn)
CondRNODE{T}(
re, p |> array_mover, nvars, basedist, tspan,
Expand Down Expand Up @@ -89,7 +90,8 @@ function inference(icnf::CondRNODE{T}, mode::TestMode, xs::AbstractMatrix{T}, ys
ys = ys |> icnf.array_mover
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -103,7 +105,8 @@ function inference(icnf::CondRNODE{T}, mode::TrainMode, xs::AbstractMatrix{T}, y
ys = ys |> icnf.array_mover
zrs = zeros(T, 3, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys, size(xs); rng)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 3, :]
Expand All @@ -120,7 +123,8 @@ function generate(icnf::CondRNODE{T}, mode::TestMode, ys::AbstractMatrix{T}, n::
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys)
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -133,7 +137,8 @@ function generate(icnf::CondRNODE{T}, mode::TrainMode, ys::AbstractMatrix{T}, n:
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 3, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys, size(new_xs))
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 3, :]
Expand Down
13 changes: 9 additions & 4 deletions src/ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ function FFJORD{T}(
acceleration::AbstractResource=default_acceleration,
) where {T <: AbstractFloat}
array_mover = make_mover(acceleration, T)
nn = fmap(x -> adapt(T, x), nn)
p, re = destructure(nn)
FFJORD{T}(
re, p |> array_mover, nvars, basedist, tspan,
Expand Down Expand Up @@ -82,7 +83,8 @@ function inference(icnf::FFJORD{T}, mode::TestMode, xs::AbstractMatrix{T}, p::Ab
xs = xs |> icnf.array_mover
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -95,7 +97,8 @@ function inference(icnf::FFJORD{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::A
xs = xs |> icnf.array_mover
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, size(xs); rng)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -109,7 +112,8 @@ function generate(icnf::FFJORD{T}, mode::TestMode, n::Integer, p::AbstractVector
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode)
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -121,7 +125,8 @@ function generate(icnf::FFJORD{T}, mode::TrainMode, n::Integer, p::AbstractVecto
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, size(new_xs))
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand Down
13 changes: 9 additions & 4 deletions src/planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ function Planar{T}(
acceleration::AbstractResource=default_acceleration,
) where {T <: AbstractFloat}
array_mover = make_mover(acceleration, T)
nn = fmap(x -> adapt(T, x), nn)
p, re = destructure(nn)
Planar{T}(
re, p |> array_mover, nvars, basedist, tspan,
Expand Down Expand Up @@ -92,7 +93,8 @@ function inference(icnf::Planar{T}, mode::TestMode, xs::AbstractMatrix{T}, p::Ab
xs = xs |> icnf.array_mover
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, size(xs))
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -105,7 +107,8 @@ function inference(icnf::Planar{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::A
xs = xs |> icnf.array_mover
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, size(xs))
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -119,7 +122,8 @@ function generate(icnf::Planar{T}, mode::TestMode, n::Integer, p::AbstractVector
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, size(new_xs))
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -131,7 +135,8 @@ function generate(icnf::Planar{T}, mode::TrainMode, n::Integer, p::AbstractVecto
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, size(new_xs))
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand Down
13 changes: 9 additions & 4 deletions src/rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ function RNODE{T}(
acceleration::AbstractResource=default_acceleration,
) where {T <: AbstractFloat}
array_mover = make_mover(acceleration, T)
nn = fmap(x -> adapt(T, x), nn)
p, re = destructure(nn)
RNODE{T}(
re, p |> array_mover, nvars, basedist, tspan,
Expand Down Expand Up @@ -84,7 +85,8 @@ function inference(icnf::RNODE{T}, mode::TestMode, xs::AbstractMatrix{T}, p::Abs
xs = xs |> icnf.array_mover
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -97,7 +99,8 @@ function inference(icnf::RNODE{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::Ab
xs = xs |> icnf.array_mover
zrs = zeros(T, 3, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, size(xs); rng)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 3, :]
Expand All @@ -113,7 +116,8 @@ function generate(icnf::RNODE{T}, mode::TestMode, n::Integer, p::AbstractVector=
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode)
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -125,7 +129,8 @@ function generate(icnf::RNODE{T}, mode::TrainMode, n::Integer, p::AbstractVector
new_xs = new_xs |> icnf.array_mover
zrs = zeros(T, 3, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, size(new_xs))
prob = ODEProblem{false}(f_aug, vcat(new_xs, zrs), reverse(icnf.tspan), p)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 3, :]
Expand Down

0 comments on commit ff81582

Please sign in to comment.