Skip to content

Commit

Permalink
Cleaning (#298)
Browse files Browse the repository at this point in the history
* use `let`

* remove `fnn`
  • Loading branch information
prbzrg authored Sep 28, 2023
1 parent d952851 commit ee77711
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 75 deletions.
3 changes: 0 additions & 3 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ function construct(
rng::AbstractRNG = Random.default_rng(),
)
steerdist = Uniform{data_type}(-steer_rate, steer_rate)
_fnn(x, ps, st) = first(nn(x, ps, st))

aicnf{
data_type,
Expand All @@ -49,7 +48,6 @@ function construct(
typeof(autodiff_backend),
typeof(sol_kwargs),
typeof(rng),
typeof(_fnn),
}(
nn,
nvars,
Expand All @@ -62,7 +60,6 @@ function construct(
autodiff_backend,
sol_kwargs,
rng,
_fnn,
)
end

Expand Down
4 changes: 2 additions & 2 deletions src/base_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ end
mz, J = AbstractDifferentiation.value_and_jacobian(
icnf.differentiation_backend,
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z,
)
Expand All @@ -179,7 +179,7 @@ end
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz, J = jacobian_batched(icnf, let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end, z)
trace_J = transpose(tr.(eachslice(J; dims = 3)))
cat(mz, -trace_J; dims = 1)
Expand Down
4 changes: 2 additions & 2 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ end
mz, J = AbstractDifferentiation.value_and_jacobian(
icnf.differentiation_backend,
let p = p, st = st
x -> icnf._fnn(x, p, st)
x -> first(icnf.nn(x, p, st))
end,
z,
)
Expand All @@ -168,7 +168,7 @@ end
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz, J = jacobian_batched(icnf, let p = p, st = st
x -> icnf._fnn(x, p, st)
x -> first(icnf.nn(x, p, st))
end, z)
trace_J = transpose(tr.(eachslice(J; dims = 3)))
cat(mz, -trace_J; dims = 1)
Expand Down
14 changes: 6 additions & 8 deletions src/cond_ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ struct CondFFJORD{
AUTODIFF_BACKEND <: ADTypes.AbstractADType,
SOL_KWARGS <: Dict,
RNG <: AbstractRNG,
_FNN <: Function,
} <: AbstractCondICNF{T, CM, AUGMENTED, STEER}
nn::NN
nvars::NVARS
Expand All @@ -32,7 +31,6 @@ struct CondFFJORD{
autodiff_backend::AUTODIFF_BACKEND
sol_kwargs::SOL_KWARGS
rng::RNG
_fnn::_FNN
end

@views function augmented_f(
Expand All @@ -50,7 +48,7 @@ end
v_pb = AbstractDifferentiation.value_and_pullback_function(
icnf.differentiation_backend,
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z,
)
Expand All @@ -73,7 +71,7 @@ end
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz, back = Zygote.pullback(let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end, z)
ϵJ = only(back(ϵ))
trace_J = sum(ϵJ .* ϵ; dims = 1)
Expand All @@ -92,10 +90,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = VecJac(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand All @@ -117,10 +115,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = JacVec(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand Down
16 changes: 7 additions & 9 deletions src/cond_planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ struct CondPlanar{
AUTODIFF_BACKEND <: ADTypes.AbstractADType,
SOL_KWARGS <: Dict,
RNG <: AbstractRNG,
_FNN <: Function,
} <: AbstractCondICNF{T, CM, AUGMENTED, STEER}
nn::NN
nvars::NVARS
Expand All @@ -32,7 +31,6 @@ struct CondPlanar{
autodiff_backend::AUTODIFF_BACKEND
sol_kwargs::SOL_KWARGS
rng::RNG
_fnn::_FNN
end

@views function augmented_f(
Expand All @@ -47,7 +45,7 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1)]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
trace_J =
p.u transpose(
only(
Expand Down Expand Up @@ -75,7 +73,7 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1)]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
trace_J =
p.u transpose(
only(
Expand Down Expand Up @@ -104,7 +102,7 @@ end
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz, back = Zygote.pullback(let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end, z)
ϵJ = only(back(ϵ))
trace_J = sum(ϵJ .* ϵ; dims = 1)
Expand All @@ -123,10 +121,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = VecJac(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand All @@ -148,10 +146,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz = icnf._fnn(cat(z, ys; dims = 1), p, st)
mz = first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = JacVec(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand Down
17 changes: 6 additions & 11 deletions src/cond_rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ struct CondRNODE{
AUTODIFF_BACKEND <: ADTypes.AbstractADType,
SOL_KWARGS <: Dict,
RNG <: AbstractRNG,
_FNN <: Function,
} <: AbstractCondICNF{T, CM, AUGMENTED, STEER}
nn::NN
nvars::NVARS
Expand All @@ -32,7 +31,6 @@ struct CondRNODE{
autodiff_backend::AUTODIFF_BACKEND
sol_kwargs::SOL_KWARGS
rng::RNG
_fnn::_FNN
λ₁::T
λ₂::T
end
Expand Down Expand Up @@ -71,7 +69,6 @@ function construct(
λ₂::AbstractFloat = convert(data_type, 1e-2),
)
steerdist = Uniform{data_type}(-steer_rate, steer_rate)
_fnn(x, ps, st) = first(nn(x, ps, st))

aicnf{
data_type,
Expand All @@ -88,7 +85,6 @@ function construct(
typeof(autodiff_backend),
typeof(sol_kwargs),
typeof(rng),
typeof(_fnn),
}(
nn,
nvars,
Expand All @@ -101,7 +97,6 @@ function construct(
autodiff_backend,
sol_kwargs,
rng,
_fnn,
λ₁,
λ₂,
)
Expand All @@ -122,7 +117,7 @@ end
v_pb = AbstractDifferentiation.value_and_pullback_function(
icnf.differentiation_backend,
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z,
)
Expand All @@ -147,7 +142,7 @@ end
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
ż, back = Zygote.pullback(let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end, z)
ϵJ = only(back(ϵ))
= sum(ϵJ .* ϵ; dims = 1)
Expand All @@ -168,10 +163,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
= icnf._fnn(cat(z, ys; dims = 1), p, st)
= first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = VecJac(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand All @@ -195,10 +190,10 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
= icnf._fnn(cat(z, ys; dims = 1), p, st)
= first(icnf.nn(cat(z, ys; dims = 1), p, st))
Jf = JacVec(
let ys = ys, p = p, st = st
x -> icnf._fnn(cat(x, ys; dims = 1), p, st)
x -> first(icnf.nn(cat(x, ys; dims = 1), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
Expand Down
23 changes: 17 additions & 6 deletions src/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
st = gdev(st)
end
optfunc = OptimizationFunction(
(ps_, θ, xs_, ys_) -> model.loss(model.m, TrainMode(), xs_, ys_, ps_, st),
let mm = model.m, md = TrainMode(), st = st
(ps_, θ, xs_, ys_) -> model.loss(mm, md, xs_, ys_, ps_, st)
end,
model.adtype,
)
optprob = OptimizationProblem(optfunc, ps)
Expand Down Expand Up @@ -99,7 +101,9 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
optprob_re,
opt,
data;
callback = (ps_, l_) -> callback_f(ps_, l_, model.m, prgr, itr_n),
callback = let mm = model.m, prgr = prgr, itr_n = itr_n
(ps_, l_) -> callback_f(ps_, l_, mm, prgr, itr_n)
end,
)
ProgressMeter.finish!(prgr)

Expand Down Expand Up @@ -144,8 +148,11 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew)

if model.compute_mode <: VectorMode
tst = @timed logp̂x = broadcast(
((x, y),) -> first(inference(model.m, TestMode(), x, y, ps, st)),
zip(eachcol(xnew), eachcol(ynew)),
let mm = model.m, md = TestMode(), ps = ps, st = st
(x, y) -> first(inference(mm, md, x, y, ps, st))
end,
eachcol(xnew),
eachcol(ynew),
)
elseif model.compute_mode <: MatrixMode
tst = @timed logp̂x = first(inference(model.m, TestMode(), xnew, ynew, ps, st))
Expand Down Expand Up @@ -218,7 +225,9 @@ function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real})
end
function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real})
if d.m isa AbstractCondICNF{<:AbstractFloat, <:VectorMode}
broadcast(x -> Distributions._logpdf(d, x), eachcol(A))
broadcast(let d = d
x -> Distributions._logpdf(d, x)
end, eachcol(A))
elseif d.m isa AbstractCondICNF{<:AbstractFloat, <:MatrixMode}
first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st))
else
Expand All @@ -236,7 +245,9 @@ function Distributions._rand!(rng::AbstractRNG, d::CondICNFDist, x::AbstractVect
end
function Distributions._rand!(rng::AbstractRNG, d::CondICNFDist, A::AbstractMatrix{<:Real})
if d.m isa AbstractCondICNF{<:AbstractFloat, <:VectorMode}
A .= hcat(broadcast(x -> Distributions._rand!(rng, d, x), eachcol(A))...)
A .= hcat(broadcast(let rng = rng, d = d
x -> Distributions._rand!(rng, d, x)
end, eachcol(A))...)
elseif d.m isa AbstractCondICNF{<:AbstractFloat, <:MatrixMode}
A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2))
else
Expand Down
20 changes: 15 additions & 5 deletions src/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
st = gdev(st)
end
optfunc = OptimizationFunction(
(ps_, θ, xs_) -> model.loss(model.m, TrainMode(), xs_, ps_, st),
let mm = model.m, md = TrainMode(), st = st
(ps_, θ, xs_) -> model.loss(mm, md, xs_, ps_, st)
end,
model.adtype,
)
optprob = OptimizationProblem(optfunc, ps)
Expand Down Expand Up @@ -97,7 +99,9 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
optprob_re,
opt,
data;
callback = (ps_, l_) -> callback_f(ps_, l_, model.m, prgr, itr_n),
callback = let mm = model.m, prgr = prgr, itr_n = itr_n
(ps_, l_) -> callback_f(ps_, l_, mm, prgr, itr_n)
end,
)
ProgressMeter.finish!(prgr)
else
Expand Down Expand Up @@ -138,7 +142,9 @@ function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew)

if model.compute_mode <: VectorMode
tst = @timed logp̂x = broadcast(
x -> first(inference(model.m, TestMode(), x, ps, st)),
let mm = model.m, md = TestMode(), ps = ps, st = st
x -> first(inference(model.m, TestMode(), x, ps, st))
end,
eachcol(xnew),
)
elseif model.compute_mode <: MatrixMode
Expand Down Expand Up @@ -205,7 +211,9 @@ function Distributions._logpdf(d::ICNFDist, x::AbstractVector{<:Real})
end
function Distributions._logpdf(d::ICNFDist, A::AbstractMatrix{<:Real})
if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
broadcast(x -> Distributions._logpdf(d, x), eachcol(A))
broadcast(let d = d
x -> Distributions._logpdf(d, x)
end, eachcol(A))
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
first(inference(d.m, d.mode, A, d.ps, d.st))
else
Expand All @@ -223,7 +231,9 @@ function Distributions._rand!(rng::AbstractRNG, d::ICNFDist, x::AbstractVector{<
end
function Distributions._rand!(rng::AbstractRNG, d::ICNFDist, A::AbstractMatrix{<:Real})
if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
A .= hcat(broadcast(x -> Distributions._rand!(rng, d, x), eachcol(A))...)
A .= hcat(broadcast(let rng = rng, d = d
x -> Distributions._rand!(rng, d, x)
end, eachcol(A))...)
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2))
else
Expand Down
Loading

0 comments on commit ee77711

Please sign in to comment.