diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 9d75e1d33..d69768d35 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -11,7 +11,6 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -23,6 +22,7 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -38,6 +38,7 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" DifferentiationInterfaceForwardDiffExt = "ForwardDiff" DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff" DifferentiationInterfaceReverseDiffExt = "ReverseDiff" +DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings" DifferentiationInterfaceSymbolicsExt = "Symbolics" DifferentiationInterfaceTapirExt = "Tapir" DifferentiationInterfaceTrackerExt = "Tracker" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl new file mode 100644 index 000000000..0b52f24c9 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl @@ -0,0 +1,40 @@ +module DifferentiationInterfaceSparseMatrixColoringsExt + +using ADTypes: AutoSparse, dense_ad +using ADTypes: coloring_algorithm, sparsity_detector, jacobian_sparsity, hessian_sparsity +using Compat +using DifferentiationInterface +using DifferentiationInterface: + Batch, + GradientExtras, + JacobianExtras, + HessianExtras, + HVPExtras, + hvp_batched, + make_seed, + maybe_inner, + pick_batchsize, + pushforward_batched, + pullback_batched, + prepare_gradient, + prepare_hvp_batched_same_point, + prepare_pushforward_batched_same_point, + prepare_pullback_batched_same_point +import DifferentiationInterface as DI +using SparseMatrixColorings: + AbstractColoringResult, + ColoringProblem, + GreedyColoringAlgorithm, + coloring, + column_colors, + row_colors, + column_groups, + row_groups, + decompress, + decompress! + +include("fallbacks.jl") +include("jacobian.jl") +include("hessian.jl") + +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/fallbacks.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/fallbacks.jl new file mode 100644 index 000000000..04001ed0c --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/fallbacks.jl @@ -0,0 +1,9 @@ +DI.check_available(backend::AutoSparse) = DI.check_available(dense_ad(backend)) +DI.twoarg_support(backend::AutoSparse) = DI.twoarg_support(dense_ad(backend)) + +function DI.pushforward_performance(backend::AutoSparse) + return DI.pushforward_performance(dense_ad(backend)) +end + +DI.pullback_performance(backend::AutoSparse) = DI.pullback_performance(dense_ad(backend)) +DI.hvp_mode(backend::AutoSparse{<:SecondOrder}) = DI.hvp_mode(dense_ad(backend)) diff --git a/DifferentiationInterface/src/sparse/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl similarity index 94% rename from DifferentiationInterface/src/sparse/hessian.jl rename to DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 828ea940d..49c35c1ef 100644 --- a/DifferentiationInterface/src/sparse/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -35,7 +35,7 @@ end ## Hessian, one argument -function prepare_hessian(f::F, backend::AutoSparse, x) where {F} +function DI.prepare_hessian(f::F, backend::AutoSparse, x) where {F} dense_backend = dense_ad(backend) sparsity = hessian_sparsity(f, x, sparsity_detector(backend)) problem = ColoringProblem{:symmetric,:column}() @@ -65,7 +65,9 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F} ) end -function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}) where {F,B} +function DI.hessian( + f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B} +) where {F,B} @compat (; coloring_result, batched_seeds, hvp_batched_extras) = extras dense_backend = dense_ad(backend) Ng = length(column_groups(coloring_result)) @@ -86,7 +88,7 @@ function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}) w return decompress(compressed_matrix, coloring_result) end -function hessian!( +function DI.hessian!( f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras{B} ) where {F,B} @compat (; @@ -125,7 +127,7 @@ function hessian!( return hess end -function value_gradient_and_hessian!( +function DI.value_gradient_and_hessian!( f::F, grad, hess, backend::AutoSparse, x, extras::SparseHessianExtras ) where {F} y, _ = value_and_gradient!( @@ -135,7 +137,7 @@ function value_gradient_and_hessian!( return y, grad, hess end -function value_gradient_and_hessian( +function DI.value_gradient_and_hessian( f::F, backend::AutoSparse, x, extras::SparseHessianExtras ) where {F} y, grad = value_and_gradient( diff --git a/DifferentiationInterface/src/sparse/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl similarity index 93% rename from DifferentiationInterface/src/sparse/jacobian.jl rename to DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index b40e37efd..4a3bfd32b 100644 --- a/DifferentiationInterface/src/sparse/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -64,16 +64,16 @@ function PullbackSparseJacobianExtras{B}(; ) end -function prepare_jacobian(f::F, backend::AutoSparse, x) where {F} +function DI.prepare_jacobian(f::F, backend::AutoSparse, x) where {F} y = f(x) return prepare_sparse_jacobian_aux( - (f,), backend, x, y, pushforward_performance(backend) + (f,), backend, x, y, DI.pushforward_performance(backend) ) end -function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F} +function DI.prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F} return prepare_sparse_jacobian_aux( - (f!, y), backend, x, y, pushforward_performance(backend) + (f!, y), backend, x, y, DI.pushforward_performance(backend) ) end @@ -149,23 +149,23 @@ end ## One argument -function jacobian(f::F, backend::AutoSparse, x, extras::SparseJacobianExtras) where {F} +function DI.jacobian(f::F, backend::AutoSparse, x, extras::SparseJacobianExtras) where {F} return sparse_jacobian_aux((f,), backend, x, extras) end -function jacobian!( +function DI.jacobian!( f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras ) where {F} return sparse_jacobian_aux!((f,), jac, backend, x, extras) end -function value_and_jacobian( +function DI.value_and_jacobian( f::F, backend::AutoSparse, x, extras::SparseJacobianExtras ) where {F} return f(x), jacobian(f, backend, x, extras) end -function value_and_jacobian!( +function DI.value_and_jacobian!( f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras ) where {F} return f(x), jacobian!(f, jac, backend, x, extras) @@ -173,17 +173,19 @@ end ## Two arguments -function jacobian(f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras) where {F} +function DI.jacobian( + f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras +) where {F} return sparse_jacobian_aux((f!, y), backend, x, extras) end -function jacobian!( +function DI.jacobian!( f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras ) where {F} return sparse_jacobian_aux!((f!, y), jac, backend, x, extras) end -function value_and_jacobian( +function DI.value_and_jacobian( f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras ) where {F} jac = jacobian(f!, y, backend, x, extras) @@ -191,7 +193,7 @@ function value_and_jacobian( return y, jac end -function value_and_jacobian!( +function DI.value_and_jacobian!( f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras ) where {F} jacobian!(f!, y, jac, backend, x, extras) diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index e7d7aa07c..c1c5adad0 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -12,8 +12,7 @@ module DifferentiationInterface using ADTypes: ADTypes, AbstractADType using ADTypes: mode, ForwardMode, ForwardOrReverseMode, ReverseMode, SymbolicMode using ADTypes: AutoSparse, dense_ad -using ADTypes: coloring_algorithm -using ADTypes: sparsity_detector, jacobian_sparsity, hessian_sparsity +using ADTypes: coloring_algorithm, sparsity_detector, jacobian_sparsity, hessian_sparsity using ADTypes: AutoChainRules, AutoDiffractor, @@ -33,18 +32,7 @@ using DocStringExtensions using FillArrays: OneElement using LinearAlgebra: Symmetric, Transpose, dot, parent, transpose using PackageExtensionCompat: @require_extensions -using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse -using SparseMatrixColorings: - AbstractColoringResult, - ColoringProblem, - GreedyColoringAlgorithm, - coloring, - column_colors, - row_colors, - column_groups, - row_groups, - decompress, - decompress! +using SparseArrays: sparse abstract type Extras end @@ -73,10 +61,6 @@ include("second_order/hessian.jl") include("fallbacks/no_extras.jl") -include("sparse/fallbacks.jl") -include("sparse/jacobian.jl") -include("sparse/hessian.jl") - include("misc/differentiate_with.jl") include("misc/sparsity_detector.jl") include("misc/from_primitive.jl") diff --git a/DifferentiationInterface/src/sparse/fallbacks.jl b/DifferentiationInterface/src/sparse/fallbacks.jl deleted file mode 100644 index e568cfda5..000000000 --- a/DifferentiationInterface/src/sparse/fallbacks.jl +++ /dev/null @@ -1,5 +0,0 @@ -check_available(backend::AutoSparse) = check_available(dense_ad(backend)) -twoarg_support(backend::AutoSparse) = twoarg_support(dense_ad(backend)) -pushforward_performance(backend::AutoSparse) = pushforward_performance(dense_ad(backend)) -pullback_performance(backend::AutoSparse) = pullback_performance(dense_ad(backend)) -hvp_mode(backend::AutoSparse{<:SecondOrder}) = hvp_mode(dense_ad(backend))