Skip to content

Commit

Permalink
remove zeros_T_AT and rename cuda ext (#349)
Browse files Browse the repository at this point in the history
* remove `zeros_T_AT`

* Format .jl files (#350)

Co-authored-by: prbzrg <prbzrg@users.noreply.github.com>

* remove `rand_cstm_AT`

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: prbzrg <prbzrg@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 19, 2023
1 parent e2b720b commit 9b4662d
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 66 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
CUDAExtICNF = "CUDA"
ContinuousNormalizingFlowsCUDAExt = "CUDA"

[compat]
ADTypes = "0.2"
Expand Down
5 changes: 0 additions & 5 deletions ext/CUDAExtICNF/CUDAExtICNF.jl

This file was deleted.

22 changes: 0 additions & 22 deletions ext/CUDAExtICNF/base_cuda.jl

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module ContinuousNormalizingFlowsCUDAExt

using ContinuousNormalizingFlows, CUDA
using ContinuousNormalizingFlows.ComputationalResources

@inline function ContinuousNormalizingFlows.rng_AT(::CUDALibs)
CURAND.default_rng()
end

end
17 changes: 0 additions & 17 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,23 +105,6 @@ end
Random.default_rng()
end

@inline function zeros_T_AT(
::AbstractResource,
::AbstractFlows{T},
dims...,
) where {T <: AbstractFloat}
zeros(T, dims...)
end

@inline function rand_cstm_AT(
::AbstractResource,
icnf::AbstractFlows{T},
cstm::Any,
dims...,
) where {T <: AbstractFloat}
convert.(T, rand(icnf.rng, cstm, dims...))
end

@views function inference_sol(
icnf::AbstractFlows{T, <:VectorMode, INPLACE},
mode::Mode,
Expand Down
18 changes: 11 additions & 7 deletions src/base_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ export inference, generate, loss
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
zrs = zeros_T_AT(icnf.resource, icnf, n_aug_input + n_aug + 1)
zrs = similar(xs, n_aug_input + n_aug + 1)
@ignore_derivatives fill!(zrs, zero(T))
ϵ = randn(icnf.rng, T, icnf.nvars + n_aug_input)
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand Down Expand Up @@ -38,7 +39,8 @@ end
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
zrs = zeros_T_AT(icnf.resource, icnf, n_aug_input + n_aug + 1, size(xs, 2))
zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2))
@ignore_derivatives fill!(zrs, zero(T))
ϵ = randn(icnf.rng, T, icnf.nvars + n_aug_input, size(xs, 2))
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand All @@ -65,9 +67,10 @@ end
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = rand_cstm_AT(icnf.resource, icnf, icnf.basedist)
zrs = zeros_T_AT(icnf.resource, icnf, n_aug + 1)
ϵ = randn(icnf.rng, T, icnf.nvars + n_aug_input)
new_xs = oftype(ϵ, rand(icnf.rng, icnf.basedist))
zrs = similar(new_xs, n_aug + 1)
@ignore_derivatives fill!(zrs, zero(T))
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
INPLACE,
Expand All @@ -94,9 +97,10 @@ end
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = rand_cstm_AT(icnf.resource, icnf, icnf.basedist, n)
zrs = zeros_T_AT(icnf.resource, icnf, n_aug + 1, size(new_xs, 2))
ϵ = randn(icnf.rng, T, icnf.nvars + n_aug_input, size(new_xs, 2))
ϵ = randn(icnf.rng, T, icnf.nvars + n_aug_input, n)
new_xs = oftype(ϵ, rand(icnf.rng, icnf.basedist, n))
zrs = similar(new_xs, n_aug + 1, n)
@ignore_derivatives fill!(zrs, zero(T))
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
INPLACE,
Expand Down
18 changes: 11 additions & 7 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ export inference, generate, loss
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
zrs = zeros_T_AT(icnf.resource, icnf, n_aug_input + n_aug + 1)
zrs = similar(xs, n_aug_input + n_aug + 1)
@ignore_derivatives fill!(zrs, zero(T))
ϵ = randn(icnf.rng, T, icnf.nvars + n_aug_input)
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand All @@ -36,7 +37,8 @@ end
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
zrs = zeros_T_AT(icnf.resource, icnf, n_aug_input + n_aug + 1, size(xs, 2))
zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2))
@ignore_derivatives fill!(zrs, zero(T))
ϵ = randn(icnf.rng, T, icnf.nvars + n_aug_input, size(xs, 2))
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
Expand All @@ -62,9 +64,10 @@ end
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = rand_cstm_AT(icnf.resource, icnf, icnf.basedist)
zrs = zeros_T_AT(icnf.resource, icnf, n_aug + 1)
ϵ = randn(icnf.rng, T, icnf.nvars + n_aug_input)
new_xs = oftype(ϵ, rand(icnf.rng, icnf.basedist))
zrs = similar(new_xs, n_aug + 1)
@ignore_derivatives fill!(zrs, zero(T))
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
INPLACE,
Expand All @@ -90,9 +93,10 @@ end
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = rand_cstm_AT(icnf.resource, icnf, icnf.basedist, n)
zrs = zeros_T_AT(icnf.resource, icnf, n_aug + 1, size(new_xs, 2))
ϵ = randn(icnf.rng, T, icnf.nvars + n_aug_input, size(new_xs, 2))
ϵ = randn(icnf.rng, T, icnf.nvars + n_aug_input, n)
new_xs = oftype(ϵ, rand(icnf.rng, icnf.basedist, n))
zrs = similar(new_xs, n_aug + 1, n)
@ignore_derivatives fill!(zrs, zero(T))
ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
ifelse(
INPLACE,
Expand Down
17 changes: 10 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
function jacobian_batched(
@views function jacobian_batched(
icnf::AbstractFlows{T, <:SDVecJacMatrixMode},
f,
xs::AbstractMatrix{<:Real},
) where {T <: AbstractFloat}
y = f(xs)
z = zeros_T_AT(icnf.resource, icnf, size(xs))
z = similar(xs)
@ignore_derivatives fill!(z, zero(T))
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
Jf = VecJac(f, xs; autodiff = icnf.autodiff_backend)
for i in axes(xs, 1)
Expand All @@ -15,13 +16,14 @@ function jacobian_batched(
y, copy(res)
end

function jacobian_batched(
@views function jacobian_batched(
icnf::AbstractFlows{T, <:SDJacVecMatrixMode},
f,
xs::AbstractMatrix{<:Real},
) where {T <: AbstractFloat}
y = f(xs)
z = zeros_T_AT(icnf.resource, icnf, size(xs))
z = similar(xs)
@ignore_derivatives fill!(z, zero(T))
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
Jf = JacVec(f, xs; autodiff = icnf.autodiff_backend)
for i in axes(xs, 1)
Expand All @@ -32,13 +34,14 @@ function jacobian_batched(
y, copy(res)
end

function jacobian_batched(
icnf::AbstractFlows{T, <:ZygoteMatrixMode},
@views function jacobian_batched(
::AbstractFlows{T, <:ZygoteMatrixMode},
f,
xs::AbstractMatrix{<:Real},
) where {T <: AbstractFloat}
y, VJ = Zygote.pullback(f, xs)
z = zeros_T_AT(icnf.resource, icnf, size(xs))
z = similar(xs)
@ignore_derivatives fill!(z, zero(T))
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
for i in axes(xs, 1)
@ignore_derivatives z[i, :] .= one(T)
Expand Down

0 comments on commit 9b4662d

Please sign in to comment.