Skip to content

Commit

Permalink
add rng to inferences
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed May 9, 2022
1 parent a8a7a31 commit 03573b2
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 32 deletions.
11 changes: 6 additions & 5 deletions src/cond_ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ function augmented_f(icnf::CondFFJORD{T}, mode::TestMode, ys::Union{AbstractMatr
f_aug
end

function augmented_f(icnf::CondFFJORD{T}, mode::TrainMode, ys::Union{AbstractMatrix{T}, CuArray}, sz::Tuple{T2, T2})::Function where {T <: AbstractFloat, T2 <: Integer}
function augmented_f(icnf::CondFFJORD{T}, mode::TrainMode, ys::Union{AbstractMatrix{T}, CuArray}, sz::Tuple{T2, T2}; rng::Union{AbstractRNG, Nothing}=nothing)::Function where {T <: AbstractFloat, T2 <: Integer}
move = MLJFlux.Mover(icnf.acceleration)
ϵ = randn(T, sz) |> move
ϵ = isnothing(rng) ? randn(T, sz) : randn(rng, T, sz)
ϵ = ϵ |> move

function f_aug(u, p, t)
m = Chain(
Expand All @@ -90,7 +91,7 @@ function augmented_f(icnf::CondFFJORD{T}, mode::TrainMode, ys::Union{AbstractMat
f_aug
end

function inference(icnf::CondFFJORD{T}, mode::TestMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat}
function inference(icnf::CondFFJORD{T}, mode::TestMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
ys = ys |> move
Expand All @@ -105,12 +106,12 @@ function inference(icnf::CondFFJORD{T}, mode::TestMode, xs::AbstractMatrix{T}, y
logp̂x
end

function inference(icnf::CondFFJORD{T}, mode::TrainMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat}
function inference(icnf::CondFFJORD{T}, mode::TrainMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
ys = ys |> move
zrs = zeros(T, 1, size(xs, 2)) |> move
f_aug = augmented_f(icnf, mode, ys, size(xs))
f_aug = augmented_f(icnf, mode, ys, size(xs); rng)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
Expand Down
4 changes: 2 additions & 2 deletions src/cond_planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function augmented_f(icnf::CondPlanar{T}, ys::Union{AbstractMatrix{T}, CuArray},
f_aug
end

function inference(icnf::CondPlanar{T}, mode::TestMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat}
function inference(icnf::CondPlanar{T}, mode::TestMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
ys = ys |> move
Expand All @@ -90,7 +90,7 @@ function inference(icnf::CondPlanar{T}, mode::TestMode, xs::AbstractMatrix{T}, y
logp̂x
end

function inference(icnf::CondPlanar{T}, mode::TrainMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat}
function inference(icnf::CondPlanar{T}, mode::TrainMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
ys = ys |> move
Expand Down
11 changes: 6 additions & 5 deletions src/cond_rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ function augmented_f(icnf::CondRNODE{T}, mode::TestMode, ys::Union{AbstractMatri
f_aug
end

function augmented_f(icnf::CondRNODE{T}, mode::TrainMode, ys::Union{AbstractMatrix{T}, CuArray}, sz::Tuple{T2, T2})::Function where {T <: AbstractFloat, T2 <: Integer}
function augmented_f(icnf::CondRNODE{T}, mode::TrainMode, ys::Union{AbstractMatrix{T}, CuArray}, sz::Tuple{T2, T2}; rng::Union{AbstractRNG, Nothing}=nothing)::Function where {T <: AbstractFloat, T2 <: Integer}
move = MLJFlux.Mover(icnf.acceleration)
ϵ = randn(T, sz) |> move
ϵ = isnothing(rng) ? randn(T, sz) : randn(rng, T, sz)
ϵ = ϵ |> move

function f_aug(u, p, t)
m = Chain(
Expand All @@ -92,7 +93,7 @@ function augmented_f(icnf::CondRNODE{T}, mode::TrainMode, ys::Union{AbstractMatr
f_aug
end

function inference(icnf::CondRNODE{T}, mode::TestMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat}
function inference(icnf::CondRNODE{T}, mode::TestMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
ys = ys |> move
Expand All @@ -107,12 +108,12 @@ function inference(icnf::CondRNODE{T}, mode::TestMode, xs::AbstractMatrix{T}, ys
logp̂x
end

function inference(icnf::CondRNODE{T}, mode::TrainMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p)::Tuple where {T <: AbstractFloat}
function inference(icnf::CondRNODE{T}, mode::TrainMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::Tuple where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
ys = ys |> move
zrs = zeros(T, 3, size(xs, 2)) |> move
f_aug = augmented_f(icnf, mode, ys, size(xs))
f_aug = augmented_f(icnf, mode, ys, size(xs); rng)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
Expand Down
8 changes: 4 additions & 4 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ default_optimizer = Dict(

abstract type AbstractICNF{T} <: InfinitesimalContinuousNormalizingFlows where {T <: AbstractFloat} end

function inference(icnf::AbstractICNF{T}, mode::TestMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat} end
function inference(icnf::AbstractICNF{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat} end
function inference(icnf::AbstractICNF{T}, mode::TestMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat} end
function inference(icnf::AbstractICNF{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat} end

function generate(icnf::AbstractICNF{T}, mode::TestMode, n::Integer, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractMatrix{T} where {T <: AbstractFloat} end
function generate(icnf::AbstractICNF{T}, mode::TrainMode, n::Integer, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractMatrix{T} where {T <: AbstractFloat} end
Expand Down Expand Up @@ -118,8 +118,8 @@ end

abstract type AbstractCondICNF{T} <: InfinitesimalContinuousNormalizingFlows where {T <: AbstractFloat} end

function inference(icnf::AbstractCondICNF{T}, mode::TestMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat} end
function inference(icnf::AbstractCondICNF{T}, mode::TrainMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat} end
function inference(icnf::AbstractCondICNF{T}, mode::TestMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat} end
function inference(icnf::AbstractCondICNF{T}, mode::TrainMode, xs::AbstractMatrix{T}, ys::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat} end

function generate(icnf::AbstractCondICNF{T}, mode::TestMode, ys::AbstractMatrix{T}, n::Integer, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractMatrix{T} where {T <: AbstractFloat} end
function generate(icnf::AbstractCondICNF{T}, mode::TrainMode, ys::AbstractMatrix{T}, n::Integer, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractMatrix{T} where {T <: AbstractFloat} end
Expand Down
11 changes: 6 additions & 5 deletions src/ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ function augmented_f(icnf::FFJORD{T}, mode::TestMode)::Function where {T <: Abst
f_aug
end

function augmented_f(icnf::FFJORD{T}, mode::TrainMode, sz::Tuple{T2, T2})::Function where {T <: AbstractFloat, T2 <: Integer}
function augmented_f(icnf::FFJORD{T}, mode::TrainMode, sz::Tuple{T2, T2}; rng::Union{AbstractRNG, Nothing}=nothing)::Function where {T <: AbstractFloat, T2 <: Integer}
move = MLJFlux.Mover(icnf.acceleration)
ϵ = randn(T, sz) |> move
ϵ = isnothing(rng) ? randn(T, sz) : randn(rng, T, sz)
ϵ = ϵ |> move

function f_aug(u, p, t)
m = icnf.re(p)
Expand All @@ -86,7 +87,7 @@ function augmented_f(icnf::FFJORD{T}, mode::TrainMode, sz::Tuple{T2, T2})::Funct
f_aug
end

function inference(icnf::FFJORD{T}, mode::TestMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat}
function inference(icnf::FFJORD{T}, mode::TestMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
zrs = zeros(T, 1, size(xs, 2)) |> move
Expand All @@ -100,11 +101,11 @@ function inference(icnf::FFJORD{T}, mode::TestMode, xs::AbstractMatrix{T}, p::Ab
logp̂x
end

function inference(icnf::FFJORD{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat}
function inference(icnf::FFJORD{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
zrs = zeros(T, 1, size(xs, 2)) |> move
f_aug = augmented_f(icnf, mode, size(xs))
f_aug = augmented_f(icnf, mode, size(xs); rng)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
Expand Down
12 changes: 6 additions & 6 deletions src/planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ struct PlanarNN
h::Function
end

function PlanarNN(nvars::Integer, h::Function=tanh; cond=false)
u = randn(nvars)
w = randn(cond ? nvars*2 : nvars)
b = randn(1)
function PlanarNN(nvars::Integer, h::Function=tanh; cond=false, rng::Union{AbstractRNG, Nothing}=nothing)
u = isnothing(rng) ? randn(nvars) : randn(rng, nvars)
w = isnothing(rng) ? randn(cond ? nvars*2 : nvars) : randn(rng, cond ? nvars*2 : nvars)
b = isnothing(rng) ? randn(1) : randn(rng, 1)
PlanarNN(u, w, b, h)
end

Expand Down Expand Up @@ -96,7 +96,7 @@ function augmented_f(icnf::Planar{T}, sz::Tuple{T2, T2})::Function where {T <: A
f_aug
end

function inference(icnf::Planar{T}, mode::TestMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat}
function inference(icnf::Planar{T}, mode::TestMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
zrs = zeros(T, 1, size(xs, 2)) |> move
Expand All @@ -110,7 +110,7 @@ function inference(icnf::Planar{T}, mode::TestMode, xs::AbstractMatrix{T}, p::Ab
logp̂x
end

function inference(icnf::Planar{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat}
function inference(icnf::Planar{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
zrs = zeros(T, 1, size(xs, 2)) |> move
Expand Down
11 changes: 6 additions & 5 deletions src/rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ function augmented_f(icnf::RNODE{T}, mode::TestMode)::Function where {T <: Abstr
f_aug
end

function augmented_f(icnf::RNODE{T}, mode::TrainMode, sz::Tuple{T2, T2})::Function where {T <: AbstractFloat, T2 <: Integer}
function augmented_f(icnf::RNODE{T}, mode::TrainMode, sz::Tuple{T2, T2}; rng::Union{AbstractRNG, Nothing}=nothing)::Function where {T <: AbstractFloat, T2 <: Integer}
move = MLJFlux.Mover(icnf.acceleration)
ϵ = randn(T, sz) |> move
ϵ = isnothing(rng) ? randn(T, sz) : randn(rng, T, sz)
ϵ = ϵ |> move

function f_aug(u, p, t)
m = icnf.re(p)
Expand All @@ -88,7 +89,7 @@ function augmented_f(icnf::RNODE{T}, mode::TrainMode, sz::Tuple{T2, T2})::Functi
f_aug
end

function inference(icnf::RNODE{T}, mode::TestMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p)::AbstractVector where {T <: AbstractFloat}
function inference(icnf::RNODE{T}, mode::TestMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractVector where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
zrs = zeros(T, 1, size(xs, 2)) |> move
Expand All @@ -102,11 +103,11 @@ function inference(icnf::RNODE{T}, mode::TestMode, xs::AbstractMatrix{T}, p::Abs
logp̂x
end

function inference(icnf::RNODE{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p)::Tuple where {T <: AbstractFloat}
function inference(icnf::RNODE{T}, mode::TrainMode, xs::AbstractMatrix{T}, p::AbstractVector=icnf.p; rng::Union{AbstractRNG, Nothing}=nothing)::Tuple where {T <: AbstractFloat}
move = MLJFlux.Mover(icnf.acceleration)
xs = xs |> move
zrs = zeros(T, 3, size(xs, 2)) |> move
f_aug = augmented_f(icnf, mode, size(xs))
f_aug = augmented_f(icnf, mode, size(xs); rng)
prob = ODEProblem{false}(f_aug, vcat(xs, zrs), icnf.tspan, p)
sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
fsol = sol[:, :, end]
Expand Down

0 comments on commit 03573b2

Please sign in to comment.