Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type stability of Jacobians and Hessian, fix test scenarios #337

Merged
merged 6 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
version:
- '1'
- '1.6'
- '~1.11.0-0'
# - '~1.11.0-0'
group:
- Formalities
- Internals
Expand Down Expand Up @@ -118,7 +118,7 @@ jobs:
version:
- '1'
- '1.6'
- '~1.11.0-0'
# - '~1.11.0-0'
group:
- Formalities
- Zero
Expand Down
72 changes: 43 additions & 29 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,18 @@ abstract type JacobianExtras <: Extras end

struct NoJacobianExtras <: JacobianExtras end

struct PushforwardJacobianExtras{B,D,E<:PushforwardExtras,Y} <: JacobianExtras
struct PushforwardJacobianExtras{B,D,R,E<:PushforwardExtras} <: JacobianExtras
batched_seeds::Vector{Batch{B,D}}
batched_results::Vector{Batch{B,R}}
pushforward_batched_extras::E
y_example::Y
N::Int
end

struct PullbackJacobianExtras{B,D,E<:PullbackExtras,Y} <: JacobianExtras
struct PullbackJacobianExtras{B,D,R,E<:PullbackExtras} <: JacobianExtras
batched_seeds::Vector{Batch{B,D}}
batched_results::Vector{Batch{B,R}}
pullback_batched_extras::E
y_example::Y
M::Int
end

function prepare_jacobian(f::F, backend::AbstractADType, x) where {F}
Expand All @@ -85,14 +87,15 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) wh
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
a in 1:div(N, B, RoundUp)
])
batched_results = Batch.([ntuple(b -> similar(y), Val(B)) for _ in batched_seeds])
pushforward_batched_extras = prepare_pushforward_batched(
f_or_f!y..., backend, x, batched_seeds[1]
)
D = eltype(seeds)
D = eltype(batched_seeds[1])
R = eltype(batched_results[1])
E = typeof(pushforward_batched_extras)
Y = typeof(y)
return PushforwardJacobianExtras{B,D,E,Y}(
batched_seeds, pushforward_batched_extras, copy(y)
return PushforwardJacobianExtras{B,D,R,E}(
batched_seeds, batched_results, pushforward_batched_extras, N
)
end

Expand All @@ -105,13 +108,16 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardSlow) wh
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B)) for
a in 1:div(M, B, RoundUp)
])
batched_results = Batch.([ntuple(b -> similar(x), Val(B)) for _ in batched_seeds])
pullback_batched_extras = prepare_pullback_batched(
f_or_f!y..., backend, x, batched_seeds[1]
)
D = eltype(seeds)
D = eltype(batched_seeds[1])
R = eltype(batched_results[1])
E = typeof(pullback_batched_extras)
Y = typeof(y)
return PullbackJacobianExtras{B,D,E,Y}(batched_seeds, pullback_batched_extras, copy(y))
return PullbackJacobianExtras{B,D,R,E}(
batched_seeds, batched_results, pullback_batched_extras, M
)
end

## One argument
Expand Down Expand Up @@ -209,8 +215,7 @@ end
function jacobian_aux(
f_or_f!y::FY, backend, x::AbstractArray, extras::PushforwardJacobianExtras{B}
) where {FY,B}
@compat (; batched_seeds, pushforward_batched_extras, y_example) = extras
N = length(x)
@compat (; batched_seeds, pushforward_batched_extras, N) = extras

pushforward_batched_extras_same = prepare_pushforward_batched_same_point(
f_or_f!y..., backend, x, batched_seeds[1], pushforward_batched_extras
Expand All @@ -233,8 +238,7 @@ end
function jacobian_aux(
f_or_f!y::FY, backend, x::AbstractArray, extras::PullbackJacobianExtras{B}
) where {FY,B}
@compat (; batched_seeds, pullback_batched_extras, y_example) = extras
M = length(y_example)
@compat (; batched_seeds, pullback_batched_extras, M) = extras

pullback_batched_extras_same = prepare_pullback_batched_same_point(
f_or_f!y..., backend, x, batched_seeds[1], extras.pullback_batched_extras
Expand All @@ -261,27 +265,32 @@ function jacobian_aux!(
x::AbstractArray,
extras::PushforwardJacobianExtras{B},
) where {FY,B}
@compat (; batched_seeds, pushforward_batched_extras, y_example) = extras
N = length(x)
@compat (; batched_seeds, batched_results, pushforward_batched_extras, N) = extras

pushforward_batched_extras_same = prepare_pushforward_batched_same_point(
f_or_f!y..., backend, x, batched_seeds[1], pushforward_batched_extras
)

for a in eachindex(batched_seeds)
dy_batch_elements = ntuple(Val(B)) do b
reshape(view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N), size(y_example))
end
for a in eachindex(batched_seeds, batched_results)
pushforward_batched!(
f_or_f!y...,
Batch(dy_batch_elements),
batched_results[a],
backend,
x,
batched_seeds[a],
pushforward_batched_extras_same,
)
end

for a in eachindex(batched_results)
for b in eachindex(batched_results[a].elements)
copyto!(
view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N),
vec(batched_results[a].elements[b]),
)
end
end

return jac
end

Expand All @@ -292,26 +301,31 @@ function jacobian_aux!(
x::AbstractArray,
extras::PullbackJacobianExtras{B},
) where {FY,B}
@compat (; batched_seeds, pullback_batched_extras, y_example) = extras
M = length(y_example)
@compat (; batched_seeds, batched_results, pullback_batched_extras, M) = extras

pullback_batched_extras_same = prepare_pullback_batched_same_point(
f_or_f!y..., backend, x, batched_seeds[1], extras.pullback_batched_extras
)

for a in eachindex(batched_seeds)
dx_batch_elements = ntuple(Val(B)) do b
reshape(view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :), size(x))
end
for a in eachindex(batched_seeds, batched_results)
pullback_batched!(
f_or_f!y...,
Batch(dx_batch_elements),
batched_results[a],
backend,
x,
batched_seeds[a],
pullback_batched_extras_same,
)
end

for a in eachindex(batched_results)
for b in eachindex(batched_results[a].elements)
copyto!(
view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :),
vec(batched_results[a].elements[b]),
)
end
end

return jac
end
39 changes: 21 additions & 18 deletions DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ abstract type HessianExtras <: Extras end

struct NoHessianExtras <: HessianExtras end

struct HVPGradientHessianExtras{B,D,E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras
struct HVPGradientHessianExtras{B,D,R,E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras
batched_seeds::Vector{Batch{B,D}}
batched_results::Vector{Batch{B,R}}
hvp_batched_extras::E2
gradient_extras::E1
N::Int
end

function prepare_hessian(f::F, backend::AbstractADType, x) where {F}
Expand All @@ -64,12 +66,14 @@ function prepare_hessian(f::F, backend::AbstractADType, x) where {F}
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
a in 1:div(N, B, RoundUp)
])
batched_results = Batch.([ntuple(b -> similar(x), Val(B)) for _ in batched_seeds])
hvp_batched_extras = prepare_hvp_batched(f, backend, x, batched_seeds[1])
gradient_extras = prepare_gradient(f, maybe_inner(backend), x)
D = eltype(seeds)
D = eltype(batched_seeds[1])
R = eltype(batched_results[1])
E2, E1 = typeof(hvp_batched_extras), typeof(gradient_extras)
return HVPGradientHessianExtras{B,D,E2,E1}(
batched_seeds, hvp_batched_extras, gradient_extras
return HVPGradientHessianExtras{B,D,R,E2,E1}(
batched_seeds, batched_results, hvp_batched_extras, gradient_extras, N
)
end

Expand Down Expand Up @@ -100,8 +104,7 @@ end
function hessian(
f::F, backend::AbstractADType, x, extras::HVPGradientHessianExtras{B}
) where {F,B}
@compat (; batched_seeds, hvp_batched_extras) = extras
N = length(x)
@compat (; batched_seeds, hvp_batched_extras, N) = extras

hvp_batched_extras_same = prepare_hvp_batched_same_point(
f, backend, x, batched_seeds[1], hvp_batched_extras
Expand All @@ -122,27 +125,27 @@ end
function hessian!(
f::F, hess, backend::AbstractADType, x, extras::HVPGradientHessianExtras{B}
) where {F,B}
@compat (; batched_seeds, hvp_batched_extras) = extras
N = length(x)
@compat (; batched_seeds, batched_results, hvp_batched_extras, N) = extras

hvp_batched_extras_same = prepare_hvp_batched_same_point(
f, backend, x, batched_seeds[1], hvp_batched_extras
)

for a in eachindex(batched_seeds)
dg_batch_elements = ntuple(Val(B)) do b
reshape(view(hess, :, 1 + ((a - 1) * B + (b - 1)) % N), size(x))
end
for a in eachindex(batched_seeds, batched_results)
hvp_batched!(
f,
Batch(dg_batch_elements),
backend,
x,
batched_seeds[a],
hvp_batched_extras_same,
f, batched_results[a], backend, x, batched_seeds[a], hvp_batched_extras_same
)
end

for a in eachindex(batched_results)
for b in eachindex(batched_results[a].elements)
copyto!(
view(hess, :, 1 + ((a - 1) * B + (b - 1)) % N),
vec(batched_results[a].elements[b]),
)
end
end

return hess
end

Expand Down
38 changes: 28 additions & 10 deletions DifferentiationInterface/src/sparse/hessian.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
struct SparseHessianExtras{
B,S<:AbstractMatrix{Bool},C<:AbstractMatrix{<:Real},D,E2<:HVPExtras,E1<:GradientExtras
B,S<:AbstractMatrix{Bool},C<:AbstractMatrix{<:Real},D,R,E2<:HVPExtras,E1<:GradientExtras
} <: HessianExtras
sparsity::S
colors::Vector{Int}
groups::Vector{Vector{Int}}
compressed::C
batched_seeds::Vector{Batch{B,D}}
batched_results::Vector{Batch{B,R}}
hvp_batched_extras::E2
gradient_extras::E1
end
Expand All @@ -16,16 +17,18 @@ function SparseHessianExtras{B}(;
groups,
compressed::C,
batched_seeds::Vector{Batch{B,D}},
batched_results::Vector{Batch{B,R}},
hvp_batched_extras::E2,
gradient_extras::E1,
) where {B,S,C,D,E2,E1}
) where {B,S,C,D,R,E2,E1}
@assert size(sparsity, 1) == size(sparsity, 2) == size(compressed, 1) == length(colors)
return SparseHessianExtras{B,S,C,D,E2,E1}(
return SparseHessianExtras{B,S,C,D,R,E2,E1}(
sparsity,
colors,
groups,
compressed,
batched_seeds,
batched_results,
hvp_batched_extras,
gradient_extras,
)
Expand All @@ -48,6 +51,7 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B)) for
a in 1:div(Ng, B, RoundUp)
])
batched_results = Batch.([ntuple(b -> similar(x), Val(B)) for _ in batched_seeds])
hvp_batched_extras = prepare_hvp_batched(f, dense_backend, x, batched_seeds[1])
gradient_extras = prepare_gradient(f, maybe_inner(dense_backend), x)
return SparseHessianExtras{B}(;
Expand All @@ -56,6 +60,7 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
groups,
compressed,
batched_seeds,
batched_results,
hvp_batched_extras,
gradient_extras,
)
Expand Down Expand Up @@ -86,29 +91,42 @@ end
function hessian!(
f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras{B}
) where {F,B}
@compat (; sparsity, compressed, colors, groups, batched_seeds, hvp_batched_extras) =
extras
@compat (;
sparsity,
compressed,
colors,
groups,
batched_seeds,
batched_results,
hvp_batched_extras,
) = extras
dense_backend = dense_ad(backend)
Ng = length(groups)

hvp_batched_extras_same = prepare_hvp_batched_same_point(
f, dense_backend, x, batched_seeds[1], hvp_batched_extras
)

for a in 1:div(Ng, B, RoundUp)
dg_batch_elements = ntuple(Val(B)) do b
reshape(view(compressed, :, 1 + ((a - 1) * B + (b - 1)) % Ng), size(x))
end
for a in eachindex(batched_seeds, batched_results)
hvp_batched!(
f,
Batch(dg_batch_elements),
batched_results[a],
dense_backend,
x,
batched_seeds[a],
hvp_batched_extras_same,
)
end

for a in eachindex(batched_results)
for b in eachindex(batched_results[a].elements)
copyto!(
view(compressed, :, 1 + ((a - 1) * B + (b - 1)) % Ng),
vec(batched_results[a].elements[b]),
)
end
end

decompress_symmetric!(hess, sparsity, compressed, colors)
return hess
end
Expand Down
Loading
Loading