Skip to content

Commit

Permalink
reverse tspan (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg authored Jun 27, 2022
1 parent ff81582 commit 41ac7a7
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 24 deletions.
8 changes: 4 additions & 4 deletions src/cond_ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ function inference(icnf::CondFFJORD{T}, mode::TestMode, xs::AbstractMatrix{T}, y
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -104,7 +104,7 @@ function inference(icnf::CondFFJORD{T}, mode::TrainMode, xs::AbstractMatrix{T},
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys, size(xs); rng)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -120,7 +120,7 @@ function generate(icnf::CondFFJORD{T}, mode::TestMode, ys::AbstractMatrix{T}, n:
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -134,7 +134,7 @@ function generate(icnf::CondFFJORD{T}, mode::TrainMode, ys::AbstractMatrix{T}, n
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys, size(new_xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand Down
8 changes: 4 additions & 4 deletions src/cond_planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ function inference(icnf::CondPlanar{T}, mode::TestMode, xs::AbstractMatrix{T}, y
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, ys, size(xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -89,7 +89,7 @@ function inference(icnf::CondPlanar{T}, mode::TrainMode, xs::AbstractMatrix{T},
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, ys, size(xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -105,7 +105,7 @@ function generate(icnf::CondPlanar{T}, mode::TestMode, ys::AbstractMatrix{T}, n:
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, ys, size(new_xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -119,7 +119,7 @@ function generate(icnf::CondPlanar{T}, mode::TrainMode, ys::AbstractMatrix{T}, n
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, ys, size(new_xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand Down
8 changes: 4 additions & 4 deletions src/cond_rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ function inference(icnf::CondRNODE{T}, mode::TestMode, xs::AbstractMatrix{T}, ys
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -106,7 +106,7 @@ function inference(icnf::CondRNODE{T}, mode::TrainMode, xs::AbstractMatrix{T}, y
zrs = zeros(T, 3, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys, size(xs); rng)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 3, :]
Expand All @@ -124,7 +124,7 @@ function generate(icnf::CondRNODE{T}, mode::TestMode, ys::AbstractMatrix{T}, n::
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -138,7 +138,7 @@ function generate(icnf::CondRNODE{T}, mode::TrainMode, ys::AbstractMatrix{T}, n:
zrs = zeros(T, 3, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, ys, size(new_xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 3, :]
Expand Down
8 changes: 4 additions & 4 deletions src/ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function inference(icnf::FFJORD{T}, mode::TestMode, xs::AbstractMatrix{T}, p::Ab
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -98,7 +98,7 @@ function inference(icnf::FFJORD{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::A
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, size(xs); rng)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -113,7 +113,7 @@ function generate(icnf::FFJORD{T}, mode::TestMode, n::Integer, p::AbstractVector
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -126,7 +126,7 @@ function generate(icnf::FFJORD{T}, mode::TrainMode, n::Integer, p::AbstractVecto
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, size(new_xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand Down
8 changes: 4 additions & 4 deletions src/planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function inference(icnf::Planar{T}, mode::TestMode, xs::AbstractMatrix{T}, p::Ab
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, size(xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -108,7 +108,7 @@ function inference(icnf::Planar{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::A
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, size(xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -123,7 +123,7 @@ function generate(icnf::Planar{T}, mode::TestMode, n::Integer, p::AbstractVector
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, size(new_xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -136,7 +136,7 @@ function generate(icnf::Planar{T}, mode::TrainMode, n::Integer, p::AbstractVecto
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, size(new_xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand Down
8 changes: 4 additions & 4 deletions src/rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function inference(icnf::RNODE{T}, mode::TestMode, xs::AbstractMatrix{T}, p::Abs
zrs = zeros(T, 1, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -100,7 +100,7 @@ function inference(icnf::RNODE{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::Ab
zrs = zeros(T, 3, size(xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, size(xs); rng)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(xs, zrs), reverse(icnf.tspan), p)
prob = ODEProblem{false}(func, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 3, :]
Expand All @@ -117,7 +117,7 @@ function generate(icnf::RNODE{T}, mode::TestMode, n::Integer, p::AbstractVector=
zrs = zeros(T, 1, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode)
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
fsol = sol[:, :, end]
z = fsol[1:end - 1, :]
Expand All @@ -130,7 +130,7 @@ function generate(icnf::RNODE{T}, mode::TrainMode, n::Integer, p::AbstractVector
zrs = zeros(T, 3, size(new_xs, 2)) |> icnf.array_mover
f_aug = augmented_f(icnf, mode, size(new_xs))
func = ODEFunction{false, true}(f_aug)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), icnf.tspan, p)
prob = ODEProblem{false}(func, vcat(new_xs, zrs), reverse(icnf.tspan), p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
z = fsol[1:end - 3, :]
Expand Down

0 comments on commit 41ac7a7

Please sign in to comment.