Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add FFJORD #6

Merged
merged 1 commit into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ICNF.code-workspace
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"folders": [
{
"path": "."
}
],
"settings": {}
}
18 changes: 18 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,23 @@ uuid = "9bd0f7d2-bd29-441d-bcde-0d11364d2762"
authors = ["Hossein Pourbozorg <prbzrg@gmail.com> and contributors"]
version = "0.1.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
julia = "1.6"
23 changes: 22 additions & 1 deletion src/ICNF.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,26 @@
module ICNF

using
CUDA,
ComputationalResources,
DataFrames,
DiffEqSensitivity,
Distributions,
Flux,
MLJBase,
MLJFlux,
MLJModelInterface,
OrdinaryDiffEq,
Parameters,
ScientificTypes,
Zygote,
LinearAlgebra,
Random,
Statistics

# Load order matters: `core.jl` declares the abstract types and interface
# stubs that `ffjord.jl` and `metrics.jl` extend.
include("core.jl")
include("ffjord.jl")
include("metrics.jl")
include("utils.jl")

end
122 changes: 122 additions & 0 deletions src/core.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Public API surface of the package.
export
inference, generate, TestMode, TrainMode,
loss, cb,
ICNFModel, ICNFDistribution

# Type hierarchy for flow-based models, from the most general notion of a
# flow down to infinitesimal (ODE-defined) continuous normalizing flows.
abstract type Flows end
abstract type NormalizingFlows <: Flows end
abstract type ContinuousNormalizingFlows <: NormalizingFlows end
abstract type InfinitesimalContinuousNormalizingFlows <: ContinuousNormalizingFlows end
# NOTE(review): the `where {T <: AbstractFloat}` clause is vestigial —
# `AbstractICNF` has no type parameter, so `T` binds nothing. A parametric
# `AbstractICNF{T}` would let subtypes advertise their float precision.
abstract type AbstractICNF <: InfinitesimalContinuousNormalizingFlows where {T <: AbstractFloat} end

# Dispatch tags selecting the deterministic (test) vs the stochastic-trace
# (train) code paths.
abstract type Mode end
struct TestMode <: Mode end
struct TrainMode <: Mode end

# Interface stubs: concrete flows (e.g. `FFJORD`) add real methods for these.
# The bodies are intentionally empty; calling one of these fallbacks is a
# programming error (the implicit `nothing` return cannot satisfy the
# `::AbstractVector{T}` / `::AbstractMatrix{T}` annotations).
function inference(icnf::AbstractICNF, mode::TestMode, xs::AbstractMatrix{T})::AbstractVector{T} where {T <: AbstractFloat} end
function inference(icnf::AbstractICNF, mode::TrainMode, xs::AbstractMatrix{T})::AbstractVector{T} where {T <: AbstractFloat} end

# Draw `n` samples from the model; pass `rng` for reproducible generation.
function generate(icnf::AbstractICNF, mode::TestMode, n::Integer; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractMatrix{T} where {T <: AbstractFloat} end
function generate(icnf::AbstractICNF, mode::TrainMode, n::Integer; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractMatrix{T} where {T <: AbstractFloat} end

# Package-wide defaults.
#
# Marked `const`: non-const module-level globals are `Any`-typed in Julia,
# which defeats type inference in every function that reads them.
const default_acceleration = CPU1()
# High-order explicit RK method for accurate test-time likelihoods.
const default_solver_test = Feagin14()
# Fast general-purpose solver (threaded tableau) for training.
const default_solver_train = Tsit5(; thread=OrdinaryDiffEq.True())
# Adjoint sensitivity with Zygote vector-Jacobian products, used to
# backpropagate through the ODE solve.
const default_sensealg = InterpolatingAdjoint(
    ;
    autodiff=true,
    chunk_size=0,
    autojacvec=ZygoteVJP(),
)

# Flux interface

# Interface stub for the per-model loss constructor; `FFJORD` provides the
# real method. NOTE(review): the `where {T <: AbstractFloat}` clause binds
# nothing in this signature and could be dropped.
function loss(icnf::AbstractICNF; agg::Function=mean)::Function where {T <: AbstractFloat} end

# Make any ICNF callable like a Flux layer: `m(x)` is test-mode density
# evaluation on the columns of `x`.
function (m::AbstractICNF)(x::AbstractMatrix{T})::AbstractVector{T} where {T <: AbstractFloat}
    return inference(m, TestMode(), x)
end

# Build a zero-argument callback closure (suitable for `Flux.throttle`) that
# logs the current loss evaluated on the first batch of `data`.
function cb(icnf::AbstractICNF, data::AbstractVector{T2}; agg::Function=mean)::Function where {T <: AbstractFloat, T2 <: AbstractMatrix{T}}
    lossfn = loss(icnf; agg)
    batch = first(data)
    function report()::Nothing
        @info "loss = $(lossfn(batch))"
    end
    return report
end

# MLJ interface

abstract type MLJICNF <: MLJModelInterface.Unsupervised end

# MLJ model wrapper around an ICNF. `@with_kw` (Parameters.jl) generates a
# keyword constructor with the defaults below.
# NOTE(review): the `where {T ...}` clause after the supertype binds nothing
# in this definition; only `T2` is a real type parameter.
@with_kw mutable struct ICNFModel{T2} <: MLJICNF where {T <: AbstractFloat, T2 <: AbstractICNF}
m::T2 = FFJORD{Float64}(Dense(1, 1), 1)  # the wrapped flow

optimizer::Flux.Optimise.AbstractOptimiser = AMSGrad()
n_epochs::Integer = 128  # passes over the training data

batch_size::Integer = 32  # observations (columns) per minibatch

cb_timeout::Integer = 16  # seconds between throttled loss reports
end

function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
    # MLJ tables are observation-per-row; the flow consumes column vectors,
    # so transpose and materialize.
    x = collect(MLJModelInterface.matrix(X)')

    # Minibatch by column ranges. Indexing with a range always yields a
    # Matrix (even for a single remaining column) and avoids the previous
    # `hcat(nx...)`, which splatted an arbitrary-length batch into varargs.
    data = [x[:, idxs] for idxs in Base.Iterators.partition(axes(x, 2), model.batch_size)]

    # Train with a throttled loss-logging callback.
    Flux.Optimise.@epochs model.n_epochs Flux.Optimise.train!(loss(model.m), Flux.params(model.m), data, model.optimizer; cb=Flux.throttle(cb(model.m, data), model.cb_timeout))

    # Nothing to return: the trained parameters live inside `model.m` itself.
    fitresult = nothing
    cache = nothing
    report = nothing
    return fitresult, cache, report
end

# Transform new observations into their estimated densities under the flow.
function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew)
    # Rows -> columns, matching the layout used in `fit`.
    columns = collect(MLJModelInterface.matrix(Xnew)')
    # Test-mode inference yields per-observation log-densities.
    logdensities = inference(model.m, TestMode(), columns)
    # Report plain densities under the column name `px`.
    return DataFrame(px=exp.(logdensities))
end

# Expose the learned flat parameter vector of the wrapped flow to MLJ.
MLJModelInterface.fitted_params(model::ICNFModel, fitresult) =
    (learned_parameters=model.m.p,)

# Register package-level metadata with MLJ's model registry.
MLJBase.metadata_pkg(
ICNFModel,
package_name="ICNF",
package_uuid="9bd0f7d2-bd29-441d-bcde-0d11364d2762",
package_url="https://github.com/impICNF/ICNF.jl",
is_pure_julia=true,
package_license="MIT",
is_wrapper=false,
)

# Declare the model's scientific-type contract for MLJ: tables of continuous
# columns in, a table of continuous densities out.
MLJBase.metadata_model(
ICNFModel,
input_scitype=Table{AbstractVector{ScientificTypes.Continuous}},
target_scitype=Table{AbstractVector{ScientificTypes.Continuous}},
output_scitype=Table{AbstractVector{ScientificTypes.Continuous}},
supports_weights=false,
docstring="ICNFModel",
load_path="ICNF.ICNFModel",
)

# Distributions interface

# Wrap an ICNF as a `ContinuousMultivariateDistribution` so it composes with
# the Distributions.jl API (`logpdf`, `rand`, ...).
# NOTE(review): the `where {T ...}` clause binds nothing here; only `T2` is a
# real type parameter.
@with_kw struct ICNFDistribution{T2} <: ContinuousMultivariateDistribution where {T <: AbstractFloat, T2 <: AbstractICNF}
m::T2  # the wrapped flow
end

Base.length(d::ICNFDistribution) = d.m.nvars  # dimensionality of one sample
Base.eltype(d::ICNFDistribution) = eltype(d.m.p)
# Vector logpdf delegates to the batched (matrix) method on a 1-column matrix.
Distributions._logpdf(d::ICNFDistribution, x::AbstractVector) = first(Distributions._logpdf(d, hcat(x)))
Distributions._logpdf(d::ICNFDistribution, A::AbstractMatrix) = inference(d.m, TestMode(), A)
# `size(x, 2) == 1` for a vector: draw one sample and write it in place.
Distributions._rand!(rng::AbstractRNG, d::ICNFDistribution, x::AbstractVector) = (x[:] = generate(d.m, TestMode(), size(x, 2); rng))
# One sample per column of `A`; linear assignment works because both sides
# are column-major with matching sizes.
Distributions._rand!(rng::AbstractRNG, d::ICNFDistribution, A::AbstractMatrix) = (A[:] = generate(d.m, TestMode(), size(A, 2); rng))
152 changes: 152 additions & 0 deletions src/ffjord.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
export FFJORD

"""
    FFJORD{T}

Free-form continuous normalizing flow with precision `T`.

Implementations of

[Grathwohl, Will, Ricky TQ Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. "Ffjord: Free-form continuous dynamics for scalable reversible generative models." arXiv preprint arXiv:1810.01367 (2018).](https://arxiv.org/abs/1810.01367)
"""
@with_kw struct FFJORD{T} <: AbstractICNF where {T <: AbstractFloat}
re::Function  # rebuilds the neural net from a flat parameter vector
p::AbstractVector{T}  # flat trainable parameters (from `Flux.destructure`)

nvars::Integer  # data dimensionality
basedist::Distribution = MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars)))  # standard-normal base
tspan::Tuple{T, T} = convert.(T, (0, 1))  # ODE integration interval

solver_test::OrdinaryDiffEqAlgorithm = default_solver_test
solver_train::OrdinaryDiffEqAlgorithm = default_solver_train

sensealg_test::SciMLBase.AbstractSensitivityAlgorithm = default_sensealg
sensealg_train::SciMLBase.AbstractSensitivityAlgorithm = sensealg_test

acceleration::AbstractResource = CPU1()  # compute resource (e.g. `CPU1()`)

# trace_test
# trace_train
end

"""
    FFJORD{T}(nn, nvars; kwargs...)

Build an `FFJORD` flow around the Flux network `nn` acting on `nvars`
dimensional data: convert `nn`'s parameters to precision `T`, move the
network onto the configured compute resource, and flatten its parameters
via `Flux.destructure`.
"""
function FFJORD{T}(
    nn,
    nvars::Integer,
    ;
    basedist::Distribution=MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars))),
    tspan::Tuple{T, T}=convert.(T, (0, 1)),

    solver_test::OrdinaryDiffEqAlgorithm=default_solver_test,
    solver_train::OrdinaryDiffEqAlgorithm=default_solver_train,

    sensealg_test::SciMLBase.AbstractSensitivityAlgorithm=default_sensealg,
    sensealg_train::SciMLBase.AbstractSensitivityAlgorithm=sensealg_test,

    acceleration::AbstractResource=CPU1(),
) where {T <: AbstractFloat}
    move = MLJFlux.Mover(acceleration)
    # Convert parameters to the requested precision (fast paths for the two
    # common float types), then move the network onto the target device.
    converted =
        T <: Float64 ? f64(nn) :
        T <: Float32 ? f32(nn) :
        Flux.paramtype(T, nn)
    flat_p, rebuild = Flux.destructure(move(converted))
    return FFJORD{T}(;
        re=rebuild, p=flat_p, nvars, basedist, tspan,
        solver_test, solver_train,
        sensealg_test, sensealg_train,
        acceleration,
    )
end

# Exact-trace augmented dynamics (test mode): the ODE state is the data
# stacked on one extra row that accumulates the negative Jacobian trace
# (the instantaneous change of log-density).
function augmented_f(icnf::FFJORD{T}, mode::TestMode)::Function where {T <: AbstractFloat}
    move = MLJFlux.Mover(icnf.acceleration)

    function dynamics(u, p, t)
        net = icnf.re(p)
        state = u[1:end - 1, :]
        dstate = net(state)
        # Full per-sample Jacobians; exact trace of each slice, as a
        # 1×batch row to stack under the state derivative.
        jac = jacobian_batched(net, state, move)
        traces = transpose(tr.(eachslice(jac, dims=3)))
        return vcat(dstate, -traces)
    end
    return dynamics
end

# Stochastic-trace augmented dynamics (train mode): Hutchinson's estimator
# sum(ϵ .* (Jᵀϵ)) with a single probe ϵ, drawn once and held fixed for the
# whole solve.
function augmented_f(icnf::FFJORD{T}, mode::TrainMode, sz::Tuple{Int64, Int64})::Function where {T <: AbstractFloat}
    move = MLJFlux.Mover(icnf.acceleration)
    ϵ = move(randn(T, sz))

    function dynamics(u, p, t)
        net = icnf.re(p)
        state = u[1:end - 1, :]
        dstate, pullback = Zygote.pullback(net, state)
        ϵJ = only(pullback(ϵ))
        estimate = sum(ϵJ .* ϵ, dims=1)
        return vcat(dstate, -estimate)
    end
    return dynamics
end

# Exact log-density of each column of `xs` under the flow (test mode).
function inference(icnf::FFJORD{T}, mode::TestMode, xs::AbstractMatrix{T})::AbstractVector{T} where {T <: AbstractFloat}
    move = MLJFlux.Mover(icnf.acceleration)
    samples = move(xs)
    # Extra zero row accumulates the log-density change along the trajectory.
    logdet0 = move(zeros(T, 1, size(samples, 2)))
    dynamics = augmented_f(icnf, mode)
    prob = ODEProblem{false}(dynamics, vcat(samples, logdet0), icnf.tspan, icnf.p)
    sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
    final = sol[:, :, end]
    z = final[1:end - 1, :]
    Δlogp = final[end, :]
    # Change of variables: log p̂(x) = log p_base(z) - Δlogp.
    return logpdf(icnf.basedist, z) - Δlogp
end

# Estimated log-density of each column of `xs` (train mode: stochastic
# trace estimator, differentiable via the training sensitivity algorithm).
function inference(icnf::FFJORD{T}, mode::TrainMode, xs::AbstractMatrix{T})::AbstractVector{T} where {T <: AbstractFloat}
    move = MLJFlux.Mover(icnf.acceleration)
    samples = move(xs)
    # Extra zero row accumulates the log-density change along the trajectory.
    logdet0 = move(zeros(T, 1, size(samples, 2)))
    dynamics = augmented_f(icnf, mode, size(samples))
    prob = ODEProblem{false}(dynamics, vcat(samples, logdet0), icnf.tspan, icnf.p)
    sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
    final = sol[:, :, end]
    z = final[1:end - 1, :]
    Δlogp = final[end, :]
    # Change of variables: log p̂(x) = log p_base(z) - Δlogp.
    return logpdf(icnf.basedist, z) - Δlogp
end

# Sample `n` points: draw from the base distribution and integrate the
# dynamics backwards in time (test mode, exact trace). The accumulated
# log-det row is discarded.
function generate(icnf::FFJORD{T}, mode::TestMode, n::Integer; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractMatrix{T} where {T <: AbstractFloat}
    move = MLJFlux.Mover(icnf.acceleration)
    base_samples = isnothing(rng) ? rand(icnf.basedist, n) : rand(rng, icnf.basedist, n)
    base_samples = move(base_samples)
    logdet0 = move(zeros(T, 1, size(base_samples, 2)))
    dynamics = augmented_f(icnf, mode)
    prob = ODEProblem{false}(dynamics, vcat(base_samples, logdet0), reverse(icnf.tspan), icnf.p)
    sol = solve(prob, icnf.solver_test; sensealg=icnf.sensealg_test)
    final = sol[:, :, end]
    return final[1:end - 1, :]
end

# Sample `n` points in train mode (stochastic trace, training solver and
# sensitivity algorithm). The accumulated log-det row is discarded.
function generate(icnf::FFJORD{T}, mode::TrainMode, n::Integer; rng::Union{AbstractRNG, Nothing}=nothing)::AbstractMatrix{T} where {T <: AbstractFloat}
    move = MLJFlux.Mover(icnf.acceleration)
    base_samples = isnothing(rng) ? rand(icnf.basedist, n) : rand(rng, icnf.basedist, n)
    base_samples = move(base_samples)
    logdet0 = move(zeros(T, 1, size(base_samples, 2)))
    dynamics = augmented_f(icnf, mode, size(base_samples))
    prob = ODEProblem{false}(dynamics, vcat(base_samples, logdet0), reverse(icnf.tspan), icnf.p)
    sol = solve(prob, icnf.solver_train; sensealg=icnf.sensealg_train)
    final = sol[:, :, end]
    return final[1:end - 1, :]
end

# Only `p` is trainable; `re`, `nvars`, solvers, etc. stay out of gradient
# collection and device moves.
Flux.@functor FFJORD (p,)

# Negative log-likelihood training objective, aggregated over the batch
# by `agg` (default taken from the caller's keyword).
function loss(icnf::FFJORD{T}; agg::Function=mean)::Function where {T <: AbstractFloat}
    function objective(x::AbstractMatrix{T})::T
        return agg(-inference(icnf, TrainMode(), x))
    end
    return objective
end
4 changes: 4 additions & 0 deletions src/metrics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Aggregated (default: mean) test-mode log-likelihood of the columns of `xs`.
# NOTE(review): Distributions.jl also exports `loglikelihood`; because the
# package does `using Distributions`, this unqualified definition shadows
# that binding inside the module rather than extending it — confirm intended.
function loglikelihood(icnf::AbstractICNF, xs::AbstractMatrix{T}; agg::Function=mean)::T where {T <: AbstractFloat}
logp̂x = inference(icnf, TestMode(), xs)
agg(logp̂x)
end
11 changes: 11 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Compute the per-sample Jacobians of `f` at each column of `x` via reverse
# mode: one pullback call per output row, seeded with a reused one-hot row.
# Returns an (n_out × n_in × batch) array.
function jacobian_batched(f, x::AbstractMatrix{T}, move::MLJFlux.Mover)::AbstractArray{T, 3} where {T <: AbstractFloat}
    y, back = Zygote.pullback(f, x)
    # One-hot seed buffer, reused (set/unset) across rows.
    z = zeros(eltype(x), size(x)) |> move
    # Fix: the result was allocated with `zeros(size...)`, which is always
    # Float64 — for T ≠ Float64 that broke the declared AbstractArray{T,3}
    # return contract and forced a converting copy. Use eltype(x) instead.
    res = zeros(eltype(x), size(x, 1), size(x, 1), size(x, 2)) |> move
    for i in 1:size(y, 1)
        z[i, :] .= one(eltype(x))
        res[i, :, :] .= only(back(z))
        z[i, :] .= zero(eltype(x))
    end
    res
end
6 changes: 6 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3 changes: 3 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Build a synthetic `nvars`-dimensional test distribution: an independent
# product of 1-D Gaussian mixtures (`nmix` components each) with random
# means and scales drawn up to `scl`.
function gen_dist(scl=5, nvars=2, nmix=nvars)
    # Fix: the default was `nmix=nv`, but `nv` is not defined in this file,
    # so calling gen_dist() with defaults threw an UndefVarError.
    return Product([MixtureModel([Normal(scl*rand(), scl*rand()) for _ in 1:nmix]) for _ in 1:nvars])
end
Loading