Simplify tutorial and README (#216)
* Simplify Tutorial and README

* Safer random test

---------

Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
adrhill and gdalle authored Apr 26, 2024
1 parent 0f71a42 commit 0047722
Showing 3 changed files with 53 additions and 41 deletions.
20 changes: 11 additions & 9 deletions DifferentiationInterface/README.md
@@ -20,7 +20,7 @@ This package provides a backend-agnostic syntax to differentiate functions of th

## Features

- First- and second-order operators
- First- and second-order operators (gradients, Jacobians, Hessians and [more](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/overview/))
- In-place and out-of-place differentiation
- Preparation mechanism (e.g. to create a config or tape)
- Thorough validation on standard inputs and outputs (numbers, vectors, matrices)
@@ -68,19 +68,21 @@ julia> Pkg.add(

## Example

```jldoctest readme
julia> import ADTypes, ForwardDiff
julia> using DifferentiationInterface
```julia
using DifferentiationInterface
import ForwardDiff, Enzyme, Zygote # import automatic differentiation backends you want to use

julia> backend = ADTypes.AutoForwardDiff();
f(x) = sum(abs2, x)

julia> f(x) = sum(abs2, x);
x = [1.0, 2.0, 3.0]

julia> value_and_gradient(f, backend, [1., 2., 3.])
(14.0, [2.0, 4.0, 6.0])
value_and_gradient(f, AutoForwardDiff(), x) # returns (14.0, [2.0, 4.0, 6.0]) using ForwardDiff.jl
value_and_gradient(f, AutoEnzyme(), x) # returns (14.0, [2.0, 4.0, 6.0]) using Enzyme.jl
value_and_gradient(f, AutoZygote(), x) # returns (14.0, [2.0, 4.0, 6.0]) using Zygote.jl
```
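
The same pattern extends to the second-order operators mentioned in the features list. A minimal sketch (not part of the diff), assuming ForwardDiff.jl is installed and that `hessian` is exported alongside `value_and_gradient`:

```julia
using DifferentiationInterface
import ForwardDiff

f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]

# For f(x) = sum(abs2, x) the Hessian is constant and equal to 2I (a 3×3 matrix).
hessian(f, AutoForwardDiff(), x)
```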

For more performance, take a look at the [DifferentiationInterface tutorial](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/tutorial/).

## Related packages

- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) is the original inspiration for DifferentiationInterface.jl.
70 changes: 40 additions & 30 deletions DifferentiationInterface/docs/src/tutorial.md
@@ -6,45 +6,50 @@ CurrentModule = Main

We present a typical workflow with DifferentiationInterface.jl and showcase its potential performance benefits.

```@repl tuto
```@example tuto
using DifferentiationInterface
import ADTypes, ForwardDiff, Enzyme
using BenchmarkTools
import ForwardDiff, Enzyme # ⚠️ import the backends you want to use ⚠️
```

## Computing a gradient
!!! tip
Importing backends with `import` instead of `using` avoids name conflicts and makes sure you are using operators from DifferentiationInterface.jl.
This is useful since most backends also export operators like `gradient` and `jacobian`.
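
As an illustration of this tip (a sketch, not part of the diff, assuming Zygote.jl is installed as in the README example): with `import`, the unqualified `gradient` below is unambiguously the DifferentiationInterface.jl operator, while the backend's own operator stays reachable through its module name.

```julia
using DifferentiationInterface   # exports `gradient`, `jacobian`, ...
import Zygote                    # Zygote also defines `gradient`, but it stays namespaced

f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]

gradient(f, AutoZygote(), x)     # DifferentiationInterface.jl's operator
Zygote.gradient(f, x)            # the backend's own operator, explicitly qualified
```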

A common use case of AD is optimizing real-valued functions with first- or second-order methods.
Let's define a simple objective

```@repl tuto
f(x::AbstractArray) = sum(abs2, x)
```
## Computing a gradient

and a random input vector
A common use case of automatic differentiation (AD) is optimizing real-valued functions with first- or second-order methods.
Let's define a simple objective and a random input vector

```@repl tuto
x = [1.0, 2.0, 3.0];
```@example tuto
f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]
nothing # hide
```

To compute its gradient, we need to choose a "backend", i.e. an AD package that DifferentiationInterface.jl will call under the hood.
Most backend types are defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl) and re-exported by DifferentiationInterface.jl.

[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) is very generic and efficient for low-dimensional inputs, so it's a good starting point:

```@repl tuto
backend = ADTypes.AutoForwardDiff()
```@example tuto
backend = AutoForwardDiff()
nothing # hide
```

Now you can use DifferentiationInterface.jl to get the gradient:

```@repl tuto
```@example tuto
gradient(f, backend, x)
```

Was that fast?
[BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl) helps you answer that question.

```@repl tuto
using BenchmarkTools
@btime gradient($f, $backend, $x);
```

@@ -58,13 +63,14 @@ Not bad, but you can do better.

## Overwriting a gradient

Since you know how much space your gradient will occupy, you can pre-allocate that memory and offer it to AD.
Since you know how much space your gradient will occupy (the same as your input `x`), you can pre-allocate that memory and offer it to AD.
Some backends get a speed boost from this trick.

```@repl tuto
grad = zero(x)
gradient!(f, grad, backend, x);
grad
```@example tuto
grad = similar(x)
gradient!(f, grad, backend, x)
grad # has been mutated
```

The bang indicates that one of the arguments of `gradient!` might be mutated.
@@ -76,24 +82,26 @@ More precisely, our convention is that _every positional argument between the fu

For some reason the in-place version is not much better than your first attempt.
However, it has one less allocation, which corresponds to the gradient vector you provided.
Don't worry, you're not done yet.
Don't worry, you can get even more performance.
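
For reference, a benchmark of the in-place call would look like the following sketch (illustrative, reusing `f`, `x`, and `backend` from the steps above and assuming BenchmarkTools.jl is loaded):

```julia
using BenchmarkTools

grad = similar(x)                            # pre-allocated storage for the gradient
@btime gradient!($f, $grad, $backend, $x);   # mutates grad, one allocation fewer than gradient
```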

## Preparing for multiple gradients

Internally, ForwardDiff.jl creates some data structures to keep track of things.
These objects can be reused between gradient computations, even on different input values.
We abstract away the preparation step behind a backend-agnostic syntax:

```@repl tuto
```@example tuto
extras = prepare_gradient(f, backend, x)
nothing # hide
```

You don't need to know what this object is, you just need to pass it to the gradient operator.

```@repl tuto
grad = zero(x);
gradient!(f, grad, backend, x, extras);
grad
```@example tuto
grad = similar(x)
gradient!(f, grad, backend, x, extras)
grad # has been mutated
```

Preparation makes the gradient computation much faster, and (in this case) allocation-free.
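
A sketch of how that claim could be checked with BenchmarkTools.jl (illustrative, reusing `f`, `grad`, `backend`, `x` and `extras` from above):

```julia
using BenchmarkTools

# Prepared, in-place gradient: expected to be both faster and allocation-free.
@btime gradient!($f, $grad, $backend, $x, $extras);
```
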
@@ -115,13 +123,14 @@ So let's try the state-of-the-art [Enzyme.jl](https://github.com/EnzymeAD/Enzyme

For this one, the backend definition is slightly more involved, because you need to feed the "mode" to the object from ADTypes.jl:

```@repl tuto
backend2 = ADTypes.AutoEnzyme(; mode=Enzyme.Reverse)
```@example tuto
backend2 = AutoEnzyme(; mode=Enzyme.Reverse)
nothing # hide
```

But once it is done, things run smoothly with exactly the same syntax:

```@repl tuto
```@example tuto
gradient(f, backend2, x)
```

@@ -136,4 +145,5 @@ And you can run the same benchmarks:
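
The benchmark lines themselves are collapsed in this diff; a hypothetical call mirroring the earlier ForwardDiff.jl one would be:

```julia
using BenchmarkTools

# Hypothetical benchmark, analogous to the earlier @btime call but with backend2 (Enzyme).
@btime gradient($f, $backend2, $x);
```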

Not only is it blazingly fast, you achieved this speedup without looking at the docs of either ForwardDiff.jl or Enzyme.jl!
In short, DifferentiationInterface.jl allows for easy testing and comparison of AD backends.
If you want to go further, check out the [DifferentiationTest.jl tutorial](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterfaceTest/dev/tutorial/).
If you want to go further, check out the [DifferentiationInterfaceTest.jl tutorial](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterfaceTest/dev/tutorial/).
It provides benchmarking utilities to compare backends and help you select the one that is best suited for your problem.
4 changes: 2 additions & 2 deletions DifferentiationInterface/test/coloring.jl
@@ -1,6 +1,6 @@
alg = DI.GreedyColoringAlgorithm()

A = sprand(Bool, 100, 200, 0.1)
A = sprand(Bool, 100, 200, 0.05)

column_colors = ADTypes.column_coloring(A, alg)
@test DI.check_structurally_orthogonal_columns(A, column_colors)
@@ -10,7 +10,7 @@ row_colors = ADTypes.row_coloring(A, alg)
@test DI.check_structurally_orthogonal_rows(A, row_colors)
@test maximum(row_colors) < size(A, 1) ÷ 2

S = Symmetric(sprand(Bool, 100, 100, 0.1)) + I
S = Symmetric(sprand(Bool, 100, 100, 0.05)) + I
symmetric_colors = ADTypes.symmetric_coloring(S, alg)
@test DI.check_symmetrically_structurally_orthogonal(S, symmetric_colors)
@test maximum(symmetric_colors) < size(A, 2) ÷ 2
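
To see why halving the density makes the random test safer (an illustrative sketch, not part of the diff): fewer structural nonzeros per row or column generally lets the greedy algorithm get away with fewer colors, so the `maximum(colors) < size ÷ 2` assertions are less likely to fail on an unlucky draw.

```julia
using SparseArrays, Statistics

# Rough check of how sparse the random test matrices are at each density.
for p in (0.1, 0.05)
    A = sprand(Bool, 100, 200, p)
    println("density = $p: mean nonzeros per row ≈ ", mean(sum(A; dims=2)))
end
```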
