AutoSparse should only support Jacobians and Hessians (#277)
* AutoSparse only does Jacobians and Hessians

* Use dense backend explicitly

* Re-export checks

* Typo

* Fix scenarios

* Fix tests

* Type stab

* Types

* Zygote on GPU

* Sparsity on 1.6

* Typo

* Typo

* Doc
gdalle authored May 28, 2024
1 parent 170a729 commit ef2bcb5
Showing 25 changed files with 271 additions and 240 deletions.
18 changes: 15 additions & 3 deletions DifferentiationInterface/Project.toml
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.4.2"
version = "0.5.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -73,7 +73,7 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
# DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
@@ -95,4 +95,16 @@
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Aqua", "DataFrames", "DifferentiationInterfaceTest", "JET", "JuliaFormatter", "Pkg", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "Test"]
test = [
"ADTypes",
"Aqua",
"DataFrames",
# "DifferentiationInterfaceTest",
"JET",
"JuliaFormatter",
"Pkg",
"SparseArrays",
"SparseConnectivityTracer",
"SparseMatrixColorings",
"Test",
]
3 changes: 3 additions & 0 deletions DifferentiationInterface/docs/src/api.md
@@ -101,6 +101,8 @@
hessian!
check_available
check_twoarg
check_hessian
DifferentiationInterface.outer
DifferentiationInterface.inner
```

### Backend switch
@@ -116,4 +118,5 @@
The following is not part of the public API.
```@autodocs
Modules = [DifferentiationInterface]
Public = false
Filter = t -> !(Symbol(t) in [:outer, :inner])
```
19 changes: 10 additions & 9 deletions DifferentiationInterface/docs/src/operators.md
@@ -148,9 +148,11 @@
backend = SecondOrder(outer_backend, inner_backend)

The inner backend will be called first, and the outer backend will differentiate the generated code.

!!! warning
There are many possible backend combinations, a lot of which will fail.
Usually, the most efficient approach for Hessians is forward-over-reverse, i.e. a forward-mode outer backend and a reverse-mode inner backend.
There are many possible backend combinations, a lot of which will fail.
Usually, the most efficient approach for Hessians is forward-over-reverse, i.e. a forward-mode outer backend and a reverse-mode inner backend.

!!! danger
`SecondOrder` backends do not support first-order operators.

!!! warning
Preparation does not yet work for the inner differentiation step of a `SecondOrder`, only the outer differentiation is prepared.
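
For concreteness, here is a minimal sketch of the forward-over-reverse pattern described above (assuming the ForwardDiff.jl and Zygote.jl extensions are loaded; `f` and `x` are made up for illustration):

```julia
using DifferentiationInterface
import ForwardDiff, Zygote

f(x) = sum(abs2, x)
x = rand(3)

# Forward-mode outer backend differentiating a reverse-mode inner gradient
backend = SecondOrder(AutoForwardDiff(), AutoZygote())

hessian(f, backend, x)  # ≈ 2I, since f is a sum of squares
```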
@@ -164,23 +166,22 @@
For this to work, three ingredients are needed (read [this survey](https://epubs.siam.org/doi/10.1137/S0036144504444711) to understand why):

1. An underlying (dense) backend
2. A sparsity pattern detector like [`TracerSparsityDetector`](@extref SparseConnectivityTracer.TracerSparsityDetector) from [SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl)
3. A coloring algorithm like [`GreedyColoringAlgorithm`](@extref SparseMatrixColorings.GreedyColoringAlgorithm) from [SparseMatrixColorings.jl](https://github.com/gdalle/SparseMatrixColorings.jl)

These ingredients can be combined within the [`AutoSparse`](@extref ADTypes.AutoSparse) wrapper, which Differentiation.jl re-exports.
These ingredients can be combined within the [`AutoSparse`](@extref ADTypes.AutoSparse) wrapper, which DifferentiationInterface.jl re-exports.
Note that for sparse Hessians, you need to put the `SecondOrder` backend inside `AutoSparse`, and not the other way around.

The preparation step of `jacobian` or `hessian` with an `AutoSparse` backend can be long, because it needs to detect the sparsity pattern and color the resulting sparse matrix.
But after preparation, the more zeros are present in the matrix, the greater the speedup will be compared to dense differentiation.
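
As an illustration of the prepare-then-reuse workflow (a sketch assuming ForwardDiff.jl and SparseConnectivityTracer.jl are installed; the function `f` is made up):

```julia
using DifferentiationInterface
using SparseConnectivityTracer: TracerSparsityDetector
import ForwardDiff

f(x) = diff(x .^ 2)  # Jacobian has only two nonzero diagonals

backend = AutoSparse(
    AutoForwardDiff();
    sparsity_detector=TracerSparsityDetector(),
    coloring_algorithm=GreedyColoringAlgorithm(),
)

x = rand(10)
extras = prepare_jacobian(f, backend, x)  # slow: pattern detection + coloring
jacobian(f, backend, x, extras)           # fast: a few compressed dense sweeps
```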

!!! danger
`AutoSparse` backends only support operators [`jacobian`](@ref) and [`hessian`](@ref) (as well as their variants).

!!! warning
The result of preparation for an `AutoSparse` backend cannot be reused if the sparsity pattern changes.

!!! info
The symbolic backends have built-in sparsity handling, so `AutoSparse(AutoSymbolics())` and `AutoSparse(AutoFastDifferentiation())` do not need additional configuration for pattern detection or coloring.
Symbolic backends have built-in sparsity handling, so `AutoSparse(AutoSymbolics())` and `AutoSparse(AutoFastDifferentiation())` do not need additional configuration for pattern detection or coloring.
However they still benefit from preparation.

!!! warning
At the moment, `AutoSparse` backends can be used with operators other than `jacobian` and `hessian`.
This possibility will be removed in the next breaking release.

## Going further

### Non-standard types
Expand Down
13 changes: 11 additions & 2 deletions DifferentiationInterface/src/DifferentiationInterface.jl
@@ -77,6 +77,8 @@
function __init__()
@require_extensions
end

## Exported

export SecondOrder

export value_and_pushforward!, value_and_pushforward
@@ -107,9 +109,8 @@
export check_available, check_twoarg, check_hessian

export DifferentiateWith

export GreedyColoringAlgorithm
## Re-exported from ADTypes

# Re-export backends from ADTypes
export AutoChainRules
export AutoDiffractor
export AutoEnzyme
@@ -126,4 +127,12 @@
export AutoZygote

export AutoSparse

## Re-exported from SparseMatrixColorings

export GreedyColoringAlgorithm

## Public but not exported

@compat public inner, outer

end # module
4 changes: 2 additions & 2 deletions DifferentiationInterface/src/second_order/hessian.jl
@@ -50,7 +50,7 @@
end
function hessian(
f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x)
) where {F}
new_backend = SecondOrder(backend)
new_backend = SecondOrder(backend, backend)
new_extras = prepare_hessian(f, new_backend, x)
return hessian(f, new_backend, x, new_extras)
end
@@ -75,7 +75,7 @@
function hessian!(
x,
extras::HessianExtras=prepare_hessian(f, backend, x),
) where {F}
new_backend = SecondOrder(backend)
new_backend = SecondOrder(backend, backend)
new_extras = prepare_hessian(f, new_backend, x)
return hessian!(f, hess, new_backend, x, new_extras)
end
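In practice, the fallback above makes the following two calls equivalent (a sketch with made-up `f` and `x`, assuming ForwardDiff.jl is loaded):

```julia
using DifferentiationInterface
import ForwardDiff

f(x) = sum(abs2, x)
x = rand(3)

# A single dense backend is now wrapped on both sides of the SecondOrder
H1 = hessian(f, AutoForwardDiff(), x)
H2 = hessian(f, SecondOrder(AutoForwardDiff(), AutoForwardDiff()), x)
H1 ≈ H2  # true
```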
4 changes: 2 additions & 2 deletions DifferentiationInterface/src/second_order/hvp.jl
@@ -136,7 +136,7 @@
end
function hvp(
f::F, backend::AbstractADType, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v)
) where {F}
new_backend = SecondOrder(backend)
new_backend = SecondOrder(backend, backend)
new_extras = prepare_hvp(f, new_backend, x, v)
return hvp(f, new_backend, x, v, new_extras)
end
@@ -175,7 +175,7 @@
end
function hvp!(
f::F, p, backend::AbstractADType, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v)
) where {F}
new_backend = SecondOrder(backend)
new_backend = SecondOrder(backend, backend)
new_extras = prepare_hvp(f, new_backend, x, v)
return hvp!(f, p, new_backend, x, v, new_extras)
end
32 changes: 18 additions & 14 deletions DifferentiationInterface/src/second_order/second_order.jl
@@ -3,9 +3,12 @@
Combination of two backends for second-order differentiation.
!!! danger
`SecondOrder` backends do not support first-order operators.
# Constructor
SecondOrder(outer, inner)
SecondOrder(outer_backend, inner_backend)
# Fields
@@ -18,26 +21,27 @@
struct SecondOrder{ADO<:AbstractADType,ADI<:AbstractADType} <: AbstractADType
inner::ADI
end

SecondOrder(backend::AbstractADType) = SecondOrder(backend, backend)

inner(backend::SecondOrder) = backend.inner
outer(backend::SecondOrder) = backend.outer

function Base.show(io::IO, backend::SecondOrder)
return print(io, "SecondOrder($(outer(backend)) / $(inner(backend)))")
end

"""
inner(backend::SecondOrder)
Return the inner backend of a [`SecondOrder`](@ref) object, tasked with differentiation at the first order.
"""
inner(backend::SecondOrder) = backend.inner

"""
outer(backend::SecondOrder)
Return the outer backend of a [`SecondOrder`](@ref) object, tasked with differentiation at the second order.
"""
outer(backend::SecondOrder) = backend.outer

"""
mode(backend::SecondOrder)
Return the _outer_ mode of the second-order backend.
"""
ADTypes.mode(backend::SecondOrder) = mode(outer(backend))

function twoarg_support(backend::SecondOrder)
if Bool(twoarg_support(inner(backend))) && Bool(twoarg_support(outer(backend)))
return TwoArgSupported()
else
return TwoArgNotSupported()
end
end
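
A quick sketch of these accessors in action (assuming ForwardDiff.jl and Zygote.jl are loaded; `inner` and `outer` are public but not exported):

```julia
using DifferentiationInterface
using DifferentiationInterface: inner, outer
import ADTypes
import ForwardDiff, Zygote

backend = SecondOrder(AutoForwardDiff(), AutoZygote())

outer(backend)        # AutoForwardDiff(), differentiates at the second order
inner(backend)        # AutoZygote(), differentiates at the first order
ADTypes.mode(backend) # forward mode, inherited from the outer backend
```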
110 changes: 5 additions & 105 deletions DifferentiationInterface/src/sparse/fallbacks.jl
@@ -1,105 +1,5 @@
## Traits

for trait in (
:check_available,
:twoarg_support,
:pushforward_performance,
:pullback_performance,
:hvp_mode,
)
@eval $trait(backend::AutoSparse) = $trait(dense_ad(backend))
end

## Operators

for op in (:pushforward, :pullback, :hvp)
op! = Symbol(op, "!")
valop = Symbol("value_and_", op)
valop! = Symbol("value_and_", op, "!")
prep = Symbol("prepare_", op)
prepsame = Symbol("prepare_", op, "_same_point")
E = if op == :pushforward
:PushforwardExtras
elseif op == :pullback
:PullbackExtras
elseif op == :hvp
:HVPExtras
end

## One argument
@eval begin
$prep(f::F, ba::AutoSparse, x, v) where {F} = $prep(f, dense_ad(ba), x, v)
$prepsame(f::F, ba::AutoSparse, x, v) where {F} = $prepsame(f, dense_ad(ba), x, v)
$prepsame(f::F, ba::AutoSparse, x, v, ex::$E) where {F} =
$prepsame(f, dense_ad(ba), x, v, ex)
$op(f::F, ba::AutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) where {F} =
$op(f, dense_ad(ba), x, v, ex)
$valop(f::F, ba::AutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) where {F} =
$valop(f, dense_ad(ba), x, v, ex)
$op!(f::F, res, ba::AutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) where {F} =
$op!(f, res, dense_ad(ba), x, v, ex)
$valop!(f::F, res, ba::AutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) where {F} =
$valop!(f, res, dense_ad(ba), x, v, ex)
end

## Two arguments
@eval begin
$prep(f!::F, y, ba::AutoSparse, x, v) where {F} = $prep(f!, y, dense_ad(ba), x, v)
$prepsame(f!::F, y, ba::AutoSparse, x, v) where {F} =
$prepsame(f!, y, dense_ad(ba), x, v)
$prepsame(f!::F, y, ba::AutoSparse, x, v, ex::$E) where {F} =
$prepsame(f!, y, dense_ad(ba), x, v, ex)
$op(f!::F, y, ba::AutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) where {F} =
$op(f!, y, dense_ad(ba), x, v, ex)
$valop(f!::F, y, ba::AutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) where {F} =
$valop(f!, y, dense_ad(ba), x, v, ex)
$op!(f!::F, y, res, ba::AutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) where {F} =
$op!(f!, y, res, dense_ad(ba), x, v, ex)
$valop!(
f!::F, y, res, ba::AutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)
) where {F} = $valop!(f!, y, res, dense_ad(ba), x, v, ex)
end
end

for op in (:derivative, :gradient, :second_derivative)
op! = Symbol(op, "!")
valop = Symbol("value_and_", op)
valop! = Symbol("value_and_", op, "!")
prep = Symbol("prepare_", op)
E = if op == :derivative
:DerivativeExtras
elseif op == :gradient
:GradientExtras
elseif op == :second_derivative
:SecondDerivativeExtras
end

## One argument
@eval begin
$prep(f::F, ba::AutoSparse, x) where {F} = $prep(f, dense_ad(ba), x)
$op(f::F, ba::AutoSparse, x, ex::$E=$prep(f, ba, x)) where {F} =
$op(f, dense_ad(ba), x, ex)
$valop(f::F, ba::AutoSparse, x, ex::$E=$prep(f, ba, x)) where {F} =
$valop(f, dense_ad(ba), x, ex)
$op!(f::F, res, ba::AutoSparse, x, ex::$E=$prep(f, ba, x)) where {F} =
$op!(f, res, dense_ad(ba), x, ex)
$valop!(f::F, res, ba::AutoSparse, x, ex::$E=$prep(f, ba, x)) where {F} =
$valop!(f, res, dense_ad(ba), x, ex)
end

## Two arguments
if op in (:derivative,)
@eval begin
$prep(f!::F, y, ba::AutoSparse, x) where {F} = $prep(f!, y, dense_ad(ba), x)
$op(f!::F, y, ba::AutoSparse, x, ex::$E=$prep(f!, y, ba, x)) where {F} =
$op(f!, y, dense_ad(ba), x, ex)
$valop(f!::F, y, ba::AutoSparse, x, ex::$E=$prep(f!, y, ba, x)) where {F} =
$valop(f!, y, dense_ad(ba), x, ex)
$op!(f!::F, y, res, ba::AutoSparse, x, ex::$E=$prep(f!, y, ba, x)) where {F} =
$op!(f!, y, res, dense_ad(ba), x, ex)
$valop!(
f!::F, y, res, ba::AutoSparse, x, ex::$E=$prep(f!, y, ba, x)
) where {F} = $valop!(f!, y, res, dense_ad(ba), x, ex)
end
end
end
check_available(backend::AutoSparse) = check_available(dense_ad(backend))
twoarg_support(backend::AutoSparse) = twoarg_support(dense_ad(backend))
pushforward_performance(backend::AutoSparse) = pushforward_performance(dense_ad(backend))
pullback_performance(backend::AutoSparse) = pullback_performance(dense_ad(backend))
hvp_mode(backend::AutoSparse{<:SecondOrder}) = hvp_mode(dense_ad(backend))
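
A sketch of what this trait forwarding means in practice (assuming ForwardDiff.jl and SparseConnectivityTracer.jl are installed):

```julia
using DifferentiationInterface
using SparseConnectivityTracer: TracerSparsityDetector
import ForwardDiff

dense_backend = AutoForwardDiff()
sparse_backend = AutoSparse(
    dense_backend;
    sparsity_detector=TracerSparsityDetector(),
    coloring_algorithm=GreedyColoringAlgorithm(),
)

# Availability (and the other traits above) come from the wrapped dense backend
check_available(sparse_backend) == check_available(dense_backend)  # true
```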
13 changes: 8 additions & 5 deletions DifferentiationInterface/src/sparse/hessian.jl
@@ -17,6 +17,7 @@
end
## Hessian, one argument

function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
dense_backend = dense_ad(backend)
initial_sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
sparsity = col_major(initial_sparsity)
colors = symmetric_coloring(sparsity, coloring_algorithm(backend))
@@ -26,7 +27,7 @@
function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
seed[group] .= one(eltype(x))
seed
end
hvp_extras = prepare_hvp(f, backend, x, first(seeds))
hvp_extras = prepare_hvp(f, dense_backend, x, first(seeds))
products = map(seeds) do _
similar(x)
end
@@ -36,9 +37,10 @@
end

function hessian!(f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras) where {F}
@compat (; sparsity, compressed, colors, seeds, products, hvp_extras) = extras
hvp_extras_same = prepare_hvp_same_point(f, backend, x, seeds[1], hvp_extras)
dense_backend = dense_ad(backend)
hvp_extras_same = prepare_hvp_same_point(f, dense_backend, x, seeds[1], hvp_extras)
for k in eachindex(seeds, products)
hvp!(f, products[k], backend, x, seeds[k], hvp_extras_same)
hvp!(f, products[k], dense_backend, x, seeds[k], hvp_extras_same)
copyto!(view(compressed, :, k), vec(products[k]))
end
decompress_symmetric!(hess, sparsity, compressed, colors)
@@ -47,9 +49,10 @@
end

function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras) where {F}
@compat (; sparsity, compressed, colors, seeds, products, hvp_extras) = extras
hvp_extras_same = prepare_hvp_same_point(f, backend, x, seeds[1], hvp_extras)
dense_backend = dense_ad(backend)
hvp_extras_same = prepare_hvp_same_point(f, dense_backend, x, seeds[1], hvp_extras)
compressed = stack(eachindex(seeds, products); dims=2) do k
vec(hvp(f, backend, x, seeds[k], hvp_extras_same))
vec(hvp(f, dense_backend, x, seeds[k], hvp_extras_same))
end
return decompress_symmetric(sparsity, compressed, colors)
end
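
Putting it together, a sketch of a sparse Hessian call exercising this code path (assuming the listed packages are installed and the chosen backends support HVPs; `f` is made up):

```julia
using DifferentiationInterface
using SparseConnectivityTracer: TracerSparsityDetector
import ForwardDiff, Zygote

f(x) = sum(abs2, diff(x))  # Hessian is tridiagonal

# Note the nesting: SecondOrder goes *inside* AutoSparse
backend = AutoSparse(
    SecondOrder(AutoForwardDiff(), AutoZygote());
    sparsity_detector=TracerSparsityDetector(),
    coloring_algorithm=GreedyColoringAlgorithm(),
)

x = rand(10)
hessian(f, backend, x)  # 10×10 matrix with a tridiagonal sparsity pattern
```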