Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move sparse AD code to SparseMatrixColorings extension #417

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -23,6 +22,7 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -38,6 +38,7 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
DifferentiationInterfaceSymbolicsExt = "Symbolics"
DifferentiationInterfaceTapirExt = "Tapir"
DifferentiationInterfaceTrackerExt = "Tracker"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
module DifferentiationInterfaceSparseMatrixColoringsExt

using ADTypes: AutoSparse, dense_ad
using ADTypes: coloring_algorithm, sparsity_detector, jacobian_sparsity, hessian_sparsity
using Compat
using DifferentiationInterface
using DifferentiationInterface:
Batch,
GradientExtras,
JacobianExtras,
HessianExtras,
HVPExtras,
hvp_batched,
make_seed,
maybe_inner,
pick_batchsize,
pushforward_batched,
pullback_batched,
prepare_gradient,
prepare_hvp_batched_same_point,
prepare_pushforward_batched_same_point,
prepare_pullback_batched_same_point
import DifferentiationInterface as DI
using SparseMatrixColorings:
AbstractColoringResult,
ColoringProblem,
GreedyColoringAlgorithm,
coloring,
column_colors,
row_colors,
column_groups,
row_groups,
decompress,
decompress!

include("fallbacks.jl")
include("jacobian.jl")
include("hessian.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
DI.check_available(backend::AutoSparse) = DI.check_available(dense_ad(backend))
DI.twoarg_support(backend::AutoSparse) = DI.twoarg_support(dense_ad(backend))

function DI.pushforward_performance(backend::AutoSparse)
return DI.pushforward_performance(dense_ad(backend))
end

DI.pullback_performance(backend::AutoSparse) = DI.pullback_performance(dense_ad(backend))
DI.hvp_mode(backend::AutoSparse{<:SecondOrder}) = DI.hvp_mode(dense_ad(backend))
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end

## Hessian, one argument

function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
function DI.prepare_hessian(f::F, backend::AutoSparse, x) where {F}
dense_backend = dense_ad(backend)
sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
problem = ColoringProblem{:symmetric,:column}()
Expand Down Expand Up @@ -65,7 +65,9 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
)
end

function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}) where {F,B}
function DI.hessian(
f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}
) where {F,B}
@compat (; coloring_result, batched_seeds, hvp_batched_extras) = extras
dense_backend = dense_ad(backend)
Ng = length(column_groups(coloring_result))
Expand All @@ -86,7 +88,7 @@ function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}) w
return decompress(compressed_matrix, coloring_result)
end

function hessian!(
function DI.hessian!(
f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras{B}
) where {F,B}
@compat (;
Expand Down Expand Up @@ -125,7 +127,7 @@ function hessian!(
return hess
end

function value_gradient_and_hessian!(
function DI.value_gradient_and_hessian!(
f::F, grad, hess, backend::AutoSparse, x, extras::SparseHessianExtras
) where {F}
y, _ = value_and_gradient!(
Expand All @@ -135,7 +137,7 @@ function value_gradient_and_hessian!(
return y, grad, hess
end

function value_gradient_and_hessian(
function DI.value_gradient_and_hessian(
f::F, backend::AutoSparse, x, extras::SparseHessianExtras
) where {F}
y, grad = value_and_gradient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,16 @@ function PullbackSparseJacobianExtras{B}(;
)
end

function prepare_jacobian(f::F, backend::AutoSparse, x) where {F}
function DI.prepare_jacobian(f::F, backend::AutoSparse, x) where {F}
y = f(x)
return prepare_sparse_jacobian_aux(
(f,), backend, x, y, pushforward_performance(backend)
(f,), backend, x, y, DI.pushforward_performance(backend)
)
end

function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
function DI.prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
return prepare_sparse_jacobian_aux(
(f!, y), backend, x, y, pushforward_performance(backend)
(f!, y), backend, x, y, DI.pushforward_performance(backend)
)
end

Expand Down Expand Up @@ -149,49 +149,51 @@ end

## One argument

function jacobian(f::F, backend::AutoSparse, x, extras::SparseJacobianExtras) where {F}
function DI.jacobian(f::F, backend::AutoSparse, x, extras::SparseJacobianExtras) where {F}
return sparse_jacobian_aux((f,), backend, x, extras)
end

function jacobian!(
function DI.jacobian!(
f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
return sparse_jacobian_aux!((f,), jac, backend, x, extras)
end

function value_and_jacobian(
function DI.value_and_jacobian(
f::F, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
return f(x), jacobian(f, backend, x, extras)
end

function value_and_jacobian!(
function DI.value_and_jacobian!(
f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
return f(x), jacobian!(f, jac, backend, x, extras)
end

## Two arguments

function jacobian(f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras) where {F}
function DI.jacobian(
f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
return sparse_jacobian_aux((f!, y), backend, x, extras)
end

function jacobian!(
function DI.jacobian!(
f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
return sparse_jacobian_aux!((f!, y), jac, backend, x, extras)
end

function value_and_jacobian(
function DI.value_and_jacobian(
f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
jac = jacobian(f!, y, backend, x, extras)
f!(y, x)
return y, jac
end

function value_and_jacobian!(
function DI.value_and_jacobian!(
f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras
) where {F}
jacobian!(f!, y, jac, backend, x, extras)
Expand Down
20 changes: 2 additions & 18 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ module DifferentiationInterface
using ADTypes: ADTypes, AbstractADType
using ADTypes: mode, ForwardMode, ForwardOrReverseMode, ReverseMode, SymbolicMode
using ADTypes: AutoSparse, dense_ad
using ADTypes: coloring_algorithm
using ADTypes: sparsity_detector, jacobian_sparsity, hessian_sparsity
using ADTypes: coloring_algorithm, sparsity_detector, jacobian_sparsity, hessian_sparsity
using ADTypes:
AutoChainRules,
AutoDiffractor,
Expand All @@ -33,18 +32,7 @@ using DocStringExtensions
using FillArrays: OneElement
using LinearAlgebra: Symmetric, Transpose, dot, parent, transpose
using PackageExtensionCompat: @require_extensions
using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse
using SparseMatrixColorings:
AbstractColoringResult,
ColoringProblem,
GreedyColoringAlgorithm,
coloring,
column_colors,
row_colors,
column_groups,
row_groups,
decompress,
decompress!
using SparseArrays: sparse

abstract type Extras end

Expand Down Expand Up @@ -73,10 +61,6 @@ include("second_order/hessian.jl")

include("fallbacks/no_extras.jl")

include("sparse/fallbacks.jl")
include("sparse/jacobian.jl")
include("sparse/hessian.jl")

include("misc/differentiate_with.jl")
include("misc/sparsity_detector.jl")
include("misc/from_primitive.jl")
Expand Down
5 changes: 0 additions & 5 deletions DifferentiationInterface/src/sparse/fallbacks.jl

This file was deleted.

Loading