Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BREAKING] Move sparse functionality into package extensions #448

Merged
merged 20 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -23,6 +21,8 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -38,6 +38,8 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
DifferentiationInterfaceSymbolicsExt = "Symbolics"
DifferentiationInterfaceTapirExt = "Tapir"
DifferentiationInterfaceTrackerExt = "Tracker"
Expand Down
7 changes: 4 additions & 3 deletions DifferentiationInterface/docs/src/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,17 @@ For this to work, three ingredients are needed (read [this survey](https://epubs
- [`DenseSparsityDetector`](@ref) from DifferentiationInterface.jl (beware that this detector only gives a locally valid pattern)
3. A coloring algorithm: [`GreedyColoringAlgorithm`](@extref SparseMatrixColorings.GreedyColoringAlgorithm) from [SparseMatrixColorings.jl](https://github.com/gdalle/SparseMatrixColorings.jl) is the only one we support.

!!! warning
Generic sparse AD is now located in a package extension which depends on SparseMatrixColorings.jl.

These ingredients can be combined within the [`AutoSparse`](@extref ADTypes.AutoSparse) wrapper, which DifferentiationInterface.jl re-exports.
Note that for sparse Hessians, you need to put the `SecondOrder` backend inside `AutoSparse`, and not the other way around.
`AutoSparse` backends only support operators [`jacobian`](@ref) and [`hessian`](@ref) (as well as their variants).

The preparation step of `jacobian` or `hessian` with an `AutoSparse` backend can be long, because it needs to detect the sparsity pattern and color the resulting sparse matrix.
But after preparation, the more zeros are present in the matrix, the greater the speedup will be compared to dense differentiation.

!!! danger
`AutoSparse` backends only support operators [`jacobian`](@ref) and [`hessian`](@ref) (as well as their variants).

!!! warning
The result of preparation for an `AutoSparse` backend cannot be reused if the sparsity pattern changes.

!!! info
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Package extension activated when SparseArrays.jl is loaded alongside
# DifferentiationInterface.jl. It supplies the `DenseSparsityDetector`
# implementations of the ADTypes sparsity-detection interface
# (`jacobian_sparsity` / `hessian_sparsity`), defined in `sparsity_detector.jl`.
module DifferentiationInterfaceSparseArraysExt

using ADTypes: ADTypes
using Compat
using DifferentiationInterface
# Internal names needed by the sparsity detector implementations.
using DifferentiationInterface:
    DenseSparsityDetector, PushforwardFast, PushforwardSlow, basis, pushforward_performance
import DifferentiationInterface as DI
using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse

include("sparsity_detector.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
## Direct

"""
    ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:direct})

Detect the Jacobian sparsity pattern of `f` at `x` by materializing the full
dense Jacobian and keeping every entry whose magnitude exceeds `detector.atol`.
"""
function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:direct})
    dense_jac = jacobian(f, detector.backend, x)
    return sparse(abs.(dense_jac) .> detector.atol)
end

"""
    ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:direct})

Detect the Jacobian sparsity pattern of the in-place function `f!` at `x` by
materializing the full dense Jacobian and thresholding it at `detector.atol`.
"""
function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:direct})
    dense_jac = jacobian(f!, y, detector.backend, x)
    return sparse(abs.(dense_jac) .> detector.atol)
end

"""
    ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:direct})

Detect the Hessian sparsity pattern of `f` at `x` by materializing the full
dense Hessian and keeping every entry whose magnitude exceeds `detector.atol`.
"""
function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:direct})
    dense_hess = hessian(f, detector.backend, x)
    return sparse(abs.(dense_hess) .> detector.atol)
end

## Iterative

"""
    ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})

Detect the Jacobian sparsity pattern of `f` at `x` without ever storing the
dense Jacobian: probe it one column at a time with pushforwards (forward mode)
or one row at a time with pullbacks (reverse mode), recording the coordinates
of every entry whose magnitude exceeds `detector.atol`.
"""
function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})
    backend = detector.backend
    atol = detector.atol
    y = f(x)
    n, m = length(x), length(y)
    rows = Int[]
    cols = Int[]
    if pushforward_performance(backend) isa PushforwardFast
        # Forward mode: one JVP per input index produces a Jacobian column.
        buffer = similar(y)
        prep = prepare_pushforward_same_point(
            f, backend, x, basis(backend, x, first(CartesianIndices(x)))
        )
        for (col, j) in enumerate(CartesianIndices(x))
            pushforward!(f, buffer, prep, backend, x, basis(backend, x, j))
            for row in LinearIndices(buffer)
                abs(buffer[row]) > atol || continue
                push!(rows, row)
                push!(cols, col)
            end
        end
    else
        # Reverse mode: one VJP per output index produces a Jacobian row.
        buffer = similar(x)
        prep = prepare_pullback_same_point(
            f, backend, x, basis(backend, y, first(CartesianIndices(y)))
        )
        for (row, i) in enumerate(CartesianIndices(y))
            pullback!(f, buffer, prep, backend, x, basis(backend, y, i))
            for col in LinearIndices(buffer)
                abs(buffer[col]) > atol || continue
                push!(rows, row)
                push!(cols, col)
            end
        end
    end
    return sparse(rows, cols, fill(true, length(rows)), m, n)
end

"""
    ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:iterative})

Detect the Jacobian sparsity pattern of the in-place function `f!` at `x`
without storing the dense Jacobian: probe it one column at a time with
pushforwards (forward mode) or one row at a time with pullbacks (reverse mode),
recording the coordinates of entries whose magnitude exceeds `detector.atol`.
"""
function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:iterative})
    backend = detector.backend
    atol = detector.atol
    n, m = length(x), length(y)
    rows = Int[]
    cols = Int[]
    if pushforward_performance(backend) isa PushforwardFast
        # Forward mode: one JVP per input index produces a Jacobian column.
        buffer = similar(y)
        prep = prepare_pushforward_same_point(
            f!, y, backend, x, basis(backend, x, first(CartesianIndices(x)))
        )
        for (col, j) in enumerate(CartesianIndices(x))
            pushforward!(f!, y, buffer, prep, backend, x, basis(backend, x, j))
            for row in LinearIndices(buffer)
                abs(buffer[row]) > atol || continue
                push!(rows, row)
                push!(cols, col)
            end
        end
    else
        # Reverse mode: one VJP per output index produces a Jacobian row.
        buffer = similar(x)
        prep = prepare_pullback_same_point(
            f!, y, backend, x, basis(backend, y, first(CartesianIndices(y)))
        )
        for (row, i) in enumerate(CartesianIndices(y))
            pullback!(f!, y, buffer, prep, backend, x, basis(backend, y, i))
            for col in LinearIndices(buffer)
                abs(buffer[col]) > atol || continue
                push!(rows, row)
                push!(cols, col)
            end
        end
    end
    return sparse(rows, cols, fill(true, length(rows)), m, n)
end

"""
    ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})

Detect the Hessian sparsity pattern of `f` at `x` without storing the dense
Hessian: probe it one column at a time with Hessian-vector products, recording
the coordinates of entries whose magnitude exceeds `detector.atol`.
"""
function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})
    backend = detector.backend
    atol = detector.atol
    n = length(x)
    rows = Int[]
    cols = Int[]
    buffer = similar(x)
    prep = prepare_hvp_same_point(
        f, backend, x, basis(backend, x, first(CartesianIndices(x)))
    )
    for (col, j) in enumerate(CartesianIndices(x))
        # One HVP against the j-th basis vector yields the j-th Hessian column.
        hvp!(f, buffer, prep, backend, x, basis(backend, x, j))
        for row in LinearIndices(buffer)
            abs(buffer[row]) > atol || continue
            push!(rows, row)
            push!(cols, col)
        end
    end
    return sparse(rows, cols, fill(true, length(rows)), n, n)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Package extension activated when SparseMatrixColorings.jl is loaded alongside
# DifferentiationInterface.jl. It implements sparse Jacobians and Hessians for
# `AutoSparse` backends via matrix coloring and decompression
# (see `jacobian.jl` and `hessian.jl`).
module DifferentiationInterfaceSparseMatrixColoringsExt

using ADTypes:
    ADTypes,
    AbstractADType,
    AutoSparse,
    dense_ad,
    coloring_algorithm,
    sparsity_detector,
    jacobian_sparsity,
    hessian_sparsity
using Compat
using DifferentiationInterface
# Internal names needed by the sparse Jacobian/Hessian implementations.
using DifferentiationInterface:
    GradientExtras,
    HessianExtras,
    HVPExtras,
    JacobianExtras,
    PullbackExtras,
    PushforwardExtras,
    PushforwardFast,
    PushforwardSlow,
    Tangents,
    dense_ad,
    maybe_dense_ad,
    maybe_inner,
    maybe_outer,
    multibasis,
    pick_batchsize,
    pushforward_performance
import DifferentiationInterface as DI
using SparseMatrixColorings:
    AbstractColoringResult,
    ColoringProblem,
    GreedyColoringAlgorithm,
    coloring,
    column_colors,
    row_colors,
    column_groups,
    row_groups,
    decompress,
    decompress!

include("jacobian.jl")
include("hessian.jl")

end
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end

## Hessian, one argument

function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
function DI.prepare_hessian(f::F, backend::AutoSparse, x) where {F}
dense_backend = dense_ad(backend)
sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
problem = ColoringProblem{:symmetric,:column}()
Expand Down Expand Up @@ -64,7 +64,9 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
)
end

function hessian(f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x) where {F,B}
function DI.hessian(
f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x
) where {F,B}
@compat (; coloring_result, batched_seeds, hvp_extras) = extras
dense_backend = dense_ad(backend)
Ng = length(column_groups(coloring_result))
Expand All @@ -85,7 +87,7 @@ function hessian(f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x) w
return decompress(compressed_matrix, coloring_result)
end

function hessian!(
function DI.hessian!(
f::F, hess, extras::SparseHessianExtras{B}, backend::AutoSparse, x
) where {F,B}
@compat (;
Expand Down Expand Up @@ -113,7 +115,7 @@ function hessian!(
return hess
end

function value_gradient_and_hessian!(
function DI.value_gradient_and_hessian!(
f::F, grad, hess, extras::SparseHessianExtras, backend::AutoSparse, x
) where {F}
y, _ = value_and_gradient!(
Expand All @@ -123,7 +125,7 @@ function value_gradient_and_hessian!(
return y, grad, hess
end

function value_gradient_and_hessian(
function DI.value_gradient_and_hessian(
f::F, extras::SparseHessianExtras, backend::AutoSparse, x
) where {F}
y, grad = value_and_gradient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ function PullbackSparseJacobianExtras{B}(;
)
end

function prepare_jacobian(f::F, backend::AutoSparse, x) where {F}
"""
    DI.prepare_jacobian(f, backend::AutoSparse, x)

Prepare a sparse Jacobian computation for the one-argument function `f`:
run one primal evaluation to learn the output, then delegate to the shared
auxiliary routine, which detects sparsity and colors the pattern.
"""
function DI.prepare_jacobian(f::F, backend::AutoSparse, x) where {F}
    output = f(x)
    performance = pushforward_performance(backend)
    return _prepare_sparse_jacobian_aux((f,), backend, x, output, performance)
end

function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
function DI.prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
return _prepare_sparse_jacobian_aux(
(f!, y), backend, x, y, pushforward_performance(backend)
)
Expand Down Expand Up @@ -137,49 +137,51 @@ end

## One argument

function jacobian(f::F, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F}
# Sparse one-argument Jacobian: delegate to the shared auxiliary routine,
# wrapping `f` in a one-element tuple to unify with the two-argument path.
DI.jacobian(f::F, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F} =
    _sparse_jacobian_aux((f,), extras, backend, x)

function jacobian!(
# Sparse in-place one-argument Jacobian: delegate to the shared auxiliary
# routine, wrapping `f` in a one-element tuple to unify with the two-argument path.
DI.jacobian!(f::F, jac, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F} =
    _sparse_jacobian_aux!((f,), jac, extras, backend, x)

function value_and_jacobian(
# Sparse one-argument value-and-Jacobian: the primal value comes from a direct
# call to `f`, the Jacobian from the prepared sparse machinery.
function DI.value_and_jacobian(
    f::F, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
    primal = f(x)
    J = jacobian(f, extras, backend, x)
    return primal, J
end

function value_and_jacobian!(
# Sparse in-place one-argument value-and-Jacobian: the primal value comes from
# a direct call to `f`, the Jacobian is written into `jac`.
function DI.value_and_jacobian!(
    f::F, jac, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
    primal = f(x)
    jacobian!(f, jac, extras, backend, x)
    return primal, jac
end

## Two arguments

function jacobian(f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F}
# Sparse two-argument (in-place function) Jacobian: delegate to the shared
# auxiliary routine, passing `(f!, y)` as a tuple to unify with the
# one-argument path.
DI.jacobian(f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F} =
    _sparse_jacobian_aux((f!, y), extras, backend, x)

function jacobian!(
# Sparse two-argument (in-place function) Jacobian, written into `jac`:
# delegate to the shared auxiliary routine with `(f!, y)` as a tuple.
function DI.jacobian!(
    f!::F, y, jac, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
    return _sparse_jacobian_aux!((f!, y), jac, extras, backend, x)
end

function value_and_jacobian(
# Sparse two-argument value-and-Jacobian. The Jacobian computation mutates `y`
# internally, so `f!` is re-run afterwards to leave `y` holding the primal
# value at `x`. Keep this statement order.
function DI.value_and_jacobian(
    f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
    J = jacobian(f!, y, extras, backend, x)
    f!(y, x)
    return y, J
end

function value_and_jacobian!(
function DI.value_and_jacobian!(
f!::F, y, jac, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
jacobian!(f!, y, jac, extras, backend, x)
Expand Down
Loading
Loading