Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleaning #298

Merged
merged 2 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ function construct(
rng::AbstractRNG = Random.default_rng(),
)
steerdist = Uniform{data_type}(-steer_rate, steer_rate)
_fnn(x, ps, st) = first(nn(x, ps, st))

aicnf{
data_type,
Expand All @@ -49,7 +48,6 @@ function construct(
typeof(autodiff_backend),
typeof(sol_kwargs),
typeof(rng),
typeof(_fnn),
}(
nn,
nvars,
Expand All @@ -62,7 +60,6 @@ function construct(
autodiff_backend,
sol_kwargs,
rng,
_fnn,
)
end

Expand Down
4 changes: 2 additions & 2 deletions src/base_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ end
mz, J = AbstractDifferentiation.value_and_jacobian(
icnf.differentiation_backend,
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z,
)
Expand All @@ -179,7 +179,7 @@ end
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz, J = jacobian_batched(icnf, let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end, z)
trace_J = transpose(tr.(eachslice(J; dims = 3)))
cat(mz, -trace_J; dims = 1)
Expand Down
4 changes: 2 additions & 2 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ end
mz, J = AbstractDifferentiation.value_and_jacobian(
icnf.differentiation_backend,
let p = p, st = st
x -> icnf._fnn(x, p, st)
x -> first(icnf.nn(x, p, st))
end,
z,
)
Expand All @@ -168,7 +168,7 @@ end
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz, J = jacobian_batched(icnf, let p = p, st = st
x -> icnf._fnn(x, p, st)
x -> first(icnf.nn(x, p, st))
end, z)
trace_J = transpose(tr.(eachslice(J; dims = 3)))
cat(mz, -trace_J; dims = 1)
Expand Down
14 changes: 6 additions & 8 deletions src/cond_ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ struct CondFFJORD{
AUTODIFF_BACKEND <: ADTypes.AbstractADType,
SOL_KWARGS <: Dict,
RNG <: AbstractRNG,
_FNN <: Function,
} <: AbstractCondICNF{T, CM, AUGMENTED, STEER}
nn::NN
nvars::NVARS
Expand All @@ -32,7 +31,6 @@ struct CondFFJORD{
autodiff_backend::AUTODIFF_BACKEND
sol_kwargs::SOL_KWARGS
rng::RNG
_fnn::_FNN
end

@views function augmented_f(
Expand All @@ -50,7 +48,7 @@ end
v_pb = AbstractDifferentiation.value_and_pullback_function(
icnf.differentiation_backend,
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z,
)
Expand All @@ -73,7 +71,7 @@ end
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz, back = Zygote.pullback(let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end, z)
ϵJ = only(back(ϵ))
trace_J = sum(ϵJ .* ϵ; dims = 1)
Expand All @@ -92,10 +90,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = VecJac(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand All @@ -117,10 +115,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = JacVec(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand Down
16 changes: 7 additions & 9 deletions src/cond_planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ struct CondPlanar{
AUTODIFF_BACKEND <: ADTypes.AbstractADType,
SOL_KWARGS <: Dict,
RNG <: AbstractRNG,
_FNN <: Function,
} <: AbstractCondICNF{T, CM, AUGMENTED, STEER}
nn::NN
nvars::NVARS
Expand All @@ -32,7 +31,6 @@ struct CondPlanar{
autodiff_backend::AUTODIFF_BACKEND
sol_kwargs::SOL_KWARGS
rng::RNG
_fnn::_FNN
end

@views function augmented_f(
Expand All @@ -47,7 +45,7 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1)]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
trace_J =
p.u ⋅ transpose(
only(
Expand Down Expand Up @@ -75,7 +73,7 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1)]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
trace_J =
p.u ⋅ transpose(
only(
Expand Down Expand Up @@ -104,7 +102,7 @@ end
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz, back = Zygote.pullback(let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end, z)
ϵJ = only(back(ϵ))
trace_J = sum(ϵJ .* ϵ; dims = 1)
Expand All @@ -123,10 +121,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = VecJac(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand All @@ -148,10 +146,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = JacVec(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand Down
17 changes: 6 additions & 11 deletions src/cond_rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ struct CondRNODE{
AUTODIFF_BACKEND <: ADTypes.AbstractADType,
SOL_KWARGS <: Dict,
RNG <: AbstractRNG,
_FNN <: Function,
} <: AbstractCondICNF{T, CM, AUGMENTED, STEER}
nn::NN
nvars::NVARS
Expand All @@ -32,7 +31,6 @@ struct CondRNODE{
autodiff_backend::AUTODIFF_BACKEND
sol_kwargs::SOL_KWARGS
rng::RNG
_fnn::_FNN
λ₁::T
λ₂::T
end
Expand Down Expand Up @@ -71,7 +69,6 @@ function construct(
λ₂::AbstractFloat = convert(data_type, 1e-2),
)
steerdist = Uniform{data_type}(-steer_rate, steer_rate)
_fnn(x, ps, st) = first(nn(x, ps, st))

aicnf{
data_type,
Expand All @@ -88,7 +85,6 @@ function construct(
typeof(autodiff_backend),
typeof(sol_kwargs),
typeof(rng),
typeof(_fnn),
}(
nn,
nvars,
Expand All @@ -101,7 +97,6 @@ function construct(
autodiff_backend,
sol_kwargs,
rng,
_fnn,
λ₁,
λ₂,
)
Expand All @@ -122,7 +117,7 @@ end
v_pb = AbstractDifferentiation.value_and_pullback_function(
icnf.differentiation_backend,
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z,
)
Expand All @@ -147,7 +142,7 @@ end
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
ż, back = Zygote.pullback(let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end, z)
ϵJ = only(back(ϵ))
= sum(ϵJ .* ϵ; dims = 1)
Expand All @@ -168,10 +163,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
= icnf._fnn(cat(z, ys; dims = 1), p, st)
= first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = VecJac(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand All @@ -195,10 +190,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
= icnf._fnn(cat(z, ys; dims = 1), p, st)
= first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = JacVec(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand Down
23 changes: 17 additions & 6 deletions src/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
st = gdev(st)
end
optfunc = OptimizationFunction(
(ps_, θ, xs_, ys_) -> model.loss(model.m, TrainMode(), xs_, ys_, ps_, st),
let mm = model.m, md = TrainMode(), st = st
(ps_, θ, xs_, ys_) -> model.loss(mm, md, xs_, ys_, ps_, st)
end,
model.adtype,
)
optprob = OptimizationProblem(optfunc, ps)
Expand Down Expand Up @@ -99,7 +101,9 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
optprob_re,
opt,
data;
callback = (ps_, l_) -> callback_f(ps_, l_, model.m, prgr, itr_n),
callback = let mm = model.m, prgr = prgr, itr_n = itr_n
(ps_, l_) -> callback_f(ps_, l_, mm, prgr, itr_n)
end,
)
ProgressMeter.finish!(prgr)

Expand Down Expand Up @@ -144,8 +148,11 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew)

if model.compute_mode <: VectorMode
tst = @timed logp̂x = broadcast(
((x, y),) -> first(inference(model.m, TestMode(), x, y, ps, st)),
zip(eachcol(xnew), eachcol(ynew)),
let mm = model.m, md = TestMode(), ps = ps, st = st
(x, y) -> first(inference(mm, md, x, y, ps, st))
end,
eachcol(xnew),
eachcol(ynew),
)
elseif model.compute_mode <: MatrixMode
tst = @timed logp̂x = first(inference(model.m, TestMode(), xnew, ynew, ps, st))
Expand Down Expand Up @@ -218,7 +225,9 @@ function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real})
end
function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real})
if d.m isa AbstractCondICNF{<:AbstractFloat, <:VectorMode}
broadcast(x -> Distributions._logpdf(d, x), eachcol(A))
broadcast(let d = d
x -> Distributions._logpdf(d, x)
end, eachcol(A))
elseif d.m isa AbstractCondICNF{<:AbstractFloat, <:MatrixMode}
first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st))
else
Expand All @@ -236,7 +245,9 @@ function Distributions._rand!(rng::AbstractRNG, d::CondICNFDist, x::AbstractVect
end
function Distributions._rand!(rng::AbstractRNG, d::CondICNFDist, A::AbstractMatrix{<:Real})
if d.m isa AbstractCondICNF{<:AbstractFloat, <:VectorMode}
A .= hcat(broadcast(x -> Distributions._rand!(rng, d, x), eachcol(A))...)
A .= hcat(broadcast(let rng = rng, d = d
x -> Distributions._rand!(rng, d, x)
end, eachcol(A))...)
elseif d.m isa AbstractCondICNF{<:AbstractFloat, <:MatrixMode}
A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2))
else
Expand Down
20 changes: 15 additions & 5 deletions src/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
st = gdev(st)
end
optfunc = OptimizationFunction(
(ps_, θ, xs_) -> model.loss(model.m, TrainMode(), xs_, ps_, st),
let mm = model.m, md = TrainMode(), st = st
(ps_, θ, xs_) -> model.loss(mm, md, xs_, ps_, st)
end,
model.adtype,
)
optprob = OptimizationProblem(optfunc, ps)
Expand Down Expand Up @@ -97,7 +99,9 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
optprob_re,
opt,
data;
callback = (ps_, l_) -> callback_f(ps_, l_, model.m, prgr, itr_n),
callback = let mm = model.m, prgr = prgr, itr_n = itr_n
(ps_, l_) -> callback_f(ps_, l_, mm, prgr, itr_n)
end,
)
ProgressMeter.finish!(prgr)
else
Expand Down Expand Up @@ -138,7 +142,9 @@ function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew)

if model.compute_mode <: VectorMode
tst = @timed logp̂x = broadcast(
x -> first(inference(model.m, TestMode(), x, ps, st)),
let mm = model.m, md = TestMode(), ps = ps, st = st
x -> first(inference(model.m, TestMode(), x, ps, st))
end,
eachcol(xnew),
)
elseif model.compute_mode <: MatrixMode
Expand Down Expand Up @@ -205,7 +211,9 @@ function Distributions._logpdf(d::ICNFDist, x::AbstractVector{<:Real})
end
function Distributions._logpdf(d::ICNFDist, A::AbstractMatrix{<:Real})
if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
broadcast(x -> Distributions._logpdf(d, x), eachcol(A))
broadcast(let d = d
x -> Distributions._logpdf(d, x)
end, eachcol(A))
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
first(inference(d.m, d.mode, A, d.ps, d.st))
else
Expand All @@ -223,7 +231,9 @@ function Distributions._rand!(rng::AbstractRNG, d::ICNFDist, x::AbstractVector{<
end
function Distributions._rand!(rng::AbstractRNG, d::ICNFDist, A::AbstractMatrix{<:Real})
if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
A .= hcat(broadcast(x -> Distributions._rand!(rng, d, x), eachcol(A))...)
A .= hcat(broadcast(let rng = rng, d = d
x -> Distributions._rand!(rng, d, x)
end, eachcol(A))...)
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2))
else
Expand Down
Loading