Skip to content

Commit

Permalink
Move sparse AD code to SparseMatrixColorings extension
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Aug 19, 2024
1 parent 8077dce commit ff044d2
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 41 deletions.
3 changes: 2 additions & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -23,6 +22,7 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -38,6 +38,7 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
DifferentiationInterfaceSymbolicsExt = "Symbolics"
DifferentiationInterfaceTapirExt = "Tapir"
DifferentiationInterfaceTrackerExt = "Tracker"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
module DifferentiationInterfaceSparseMatrixColoringsExt

using ADTypes: AutoSparse, dense_ad
using ADTypes: coloring_algorithm, sparsity_detector, jacobian_sparsity, hessian_sparsity
using Compat
using DifferentiationInterface
using DifferentiationInterface:
Batch,
GradientExtras,
JacobianExtras,
HessianExtras,
HVPExtras,
hvp_batched,
make_seed,
maybe_inner,
pick_batchsize,
pushforward_batched,
pullback_batched,
prepare_gradient,
prepare_hvp_batched_same_point,
prepare_pushforward_batched_same_point,
prepare_pullback_batched_same_point
import DifferentiationInterface as DI
using SparseMatrixColorings:
AbstractColoringResult,
ColoringProblem,
GreedyColoringAlgorithm,
coloring,
column_colors,
row_colors,
column_groups,
row_groups,
decompress,
decompress!

include("fallbacks.jl")
include("jacobian.jl")
include("hessian.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
DI.check_available(backend::AutoSparse) = DI.check_available(dense_ad(backend))
DI.twoarg_support(backend::AutoSparse) = DI.twoarg_support(dense_ad(backend))

function DI.pushforward_performance(backend::AutoSparse)
return DI.pushforward_performance(dense_ad(backend))
end

DI.pullback_performance(backend::AutoSparse) = DI.pullback_performance(dense_ad(backend))
DI.hvp_mode(backend::AutoSparse{<:SecondOrder}) = DI.hvp_mode(dense_ad(backend))
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end

## Hessian, one argument

function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
function DI.prepare_hessian(f::F, backend::AutoSparse, x) where {F}
dense_backend = dense_ad(backend)
sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
problem = ColoringProblem{:symmetric,:column}()
Expand Down Expand Up @@ -65,7 +65,9 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
)
end

function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}) where {F,B}
function DI.hessian(
f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}
) where {F,B}
@compat (; coloring_result, batched_seeds, hvp_batched_extras) = extras
dense_backend = dense_ad(backend)
Ng = length(column_groups(coloring_result))
Expand All @@ -86,7 +88,7 @@ function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}) w
return decompress(compressed_matrix, coloring_result)
end

function hessian!(
function DI.hessian!(
f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras{B}
) where {F,B}
@compat (;
Expand Down Expand Up @@ -125,7 +127,7 @@ function hessian!(
return hess
end

function value_gradient_and_hessian!(
function DI.value_gradient_and_hessian!(
f::F, grad, hess, backend::AutoSparse, x, extras::SparseHessianExtras
) where {F}
y, _ = value_and_gradient!(
Expand All @@ -135,7 +137,7 @@ function value_gradient_and_hessian!(
return y, grad, hess
end

function value_gradient_and_hessian(
function DI.value_gradient_and_hessian(
f::F, backend::AutoSparse, x, extras::SparseHessianExtras
) where {F}
y, grad = value_and_gradient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,16 @@ function PullbackSparseJacobianExtras{B}(;
)
end

function prepare_jacobian(f::F, backend::AutoSparse, x) where {F}
function DI.prepare_jacobian(f::F, backend::AutoSparse, x) where {F}
y = f(x)
return prepare_sparse_jacobian_aux(
(f,), backend, x, y, pushforward_performance(backend)
(f,), backend, x, y, DI.pushforward_performance(backend)
)
end

function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
function DI.prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
return prepare_sparse_jacobian_aux(
(f!, y), backend, x, y, pushforward_performance(backend)
(f!, y), backend, x, y, DI.pushforward_performance(backend)
)
end

Expand Down Expand Up @@ -149,49 +149,51 @@ end

## One argument

function jacobian(f::F, backend::AutoSparse, x, extras::SparseJacobianExtras) where {F}
function DI.jacobian(f::F, backend::AutoSparse, x, extras::SparseJacobianExtras) where {F}
return sparse_jacobian_aux((f,), backend, x, extras)
end

function jacobian!(
function DI.jacobian!(
f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
return sparse_jacobian_aux!((f,), jac, backend, x, extras)
end

function value_and_jacobian(
function DI.value_and_jacobian(
f::F, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
return f(x), jacobian(f, backend, x, extras)
end

function value_and_jacobian!(
function DI.value_and_jacobian!(
f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
return f(x), jacobian!(f, jac, backend, x, extras)
end

## Two arguments

function jacobian(f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras) where {F}
function DI.jacobian(
f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
return sparse_jacobian_aux((f!, y), backend, x, extras)
end

function jacobian!(
function DI.jacobian!(
f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
return sparse_jacobian_aux!((f!, y), jac, backend, x, extras)
end

function value_and_jacobian(
function DI.value_and_jacobian(
f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
jac = jacobian(f!, y, backend, x, extras)
f!(y, x)
return y, jac
end

function value_and_jacobian!(
function DI.value_and_jacobian!(
f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
jacobian!(f!, y, jac, backend, x, extras)
Expand Down
20 changes: 2 additions & 18 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ module DifferentiationInterface
using ADTypes: ADTypes, AbstractADType
using ADTypes: mode, ForwardMode, ForwardOrReverseMode, ReverseMode, SymbolicMode
using ADTypes: AutoSparse, dense_ad
using ADTypes: coloring_algorithm
using ADTypes: sparsity_detector, jacobian_sparsity, hessian_sparsity
using ADTypes: coloring_algorithm, sparsity_detector, jacobian_sparsity, hessian_sparsity
using ADTypes:
AutoChainRules,
AutoDiffractor,
Expand All @@ -33,18 +32,7 @@ using DocStringExtensions
using FillArrays: OneElement
using LinearAlgebra: Symmetric, Transpose, dot, parent, transpose
using PackageExtensionCompat: @require_extensions
using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse
using SparseMatrixColorings:
AbstractColoringResult,
ColoringProblem,
GreedyColoringAlgorithm,
coloring,
column_colors,
row_colors,
column_groups,
row_groups,
decompress,
decompress!
using SparseArrays: sparse

abstract type Extras end

Expand Down Expand Up @@ -73,10 +61,6 @@ include("second_order/hessian.jl")

include("fallbacks/no_extras.jl")

include("sparse/fallbacks.jl")
include("sparse/jacobian.jl")
include("sparse/hessian.jl")

include("misc/differentiate_with.jl")
include("misc/sparsity_detector.jl")
include("misc/from_primitive.jl")
Expand Down
5 changes: 0 additions & 5 deletions DifferentiationInterface/src/sparse/fallbacks.jl

This file was deleted.

0 comments on commit ff044d2

Please sign in to comment.