Skip to content

Commit

Permalink
Add Duplicated methods (#192)
Browse files Browse the repository at this point in the history
* add Duplicated methods

* add test

* test for shared params + minimal docs

* remove 1.6 CI

* indent by two spaces

* fix doctest
  • Loading branch information
mcabbott authored Nov 8, 2024
1 parent 38c9d62 commit 2639523
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1'
- 'nightly'
- "1.10"
os:
- ubuntu-latest
arch:
Expand Down
14 changes: 11 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
version = "0.4.1"
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
version = "0.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
OptimisersEnzymeCoreExt = "EnzymeCore"

[compat]
ChainRulesCore = "1"
EnzymeCore = "0.8.5"
Functors = "0.4.9, 0.5"
Statistics = "1"
Zygote = "0.6.40"
julia = "1.6"
julia = "1.10"

[extras]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "StaticArrays", "Zygote"]
test = ["Test", "EnzymeCore", "StaticArrays", "Zygote"]
9 changes: 9 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,12 @@ julia> Optimisers.update!(opt_state, x, g);
julia> opt_state # the state in `a` and `b` differ
(a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.09, 0.09], [0.000999, 0.000999], (0.729, 0.997003))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))))
```

## Usage with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)

Enzyme.jl is a new automatic differentiation package, an alternative to Zygote.jl.
It likes to store the model and the gradient together, as an object `Duplicated(x, dx)`.

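Optimisers.jl now has some methods to handle this:
* `update!(opt_state, Duplicated(model, grad))` uses the gradient to update both the model and the optimiser state in place, and returns `nothing`.
* `setup(::AbstractRule, ::Duplicated)` ignores the gradient and returns `setup(rule, model)`.

A `Duplicated` is only supported at top level: `setup` throws an `ArgumentError` if one appears deep inside another object.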
60 changes: 60 additions & 0 deletions ext/OptimisersEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
module OptimisersEnzymeCoreExt

import Optimisers: trainable, setup, update!, isnumeric, AbstractRule, _setup
import EnzymeCore: Duplicated, Const

using Functors: fmapstructure

# Only the primal value `val` of a `Duplicated` is exposed as trainable;
# a `Const` exposes no trainable fields at all.
trainable(x::Duplicated) = (; val = x.val)
trainable(x::Const) = (;)

"""
    setup(rule::AbstractRule, model_grad::Duplicated)

For use with Enzyme's `Duplicated`, this just calls `setup(rule, model_grad.val)`.
The gradient half `model_grad.dval` is not used.
"""
setup(rule::AbstractRule, model_grad::Duplicated) = setup(rule, model_grad.val)

# The top-level `setup` method above unwraps a `Duplicated` before anything else sees it,
# so this method is only reached when a `Duplicated` is stored inside another object.
_setup(rule, x::Duplicated; cache) = throw(ArgumentError(
  """Objects of type `Duplicated` are only supported by Optimisers.jl at top level,
  they may not appear deep inside other objects."""
))

"""
    update!(opt_state, model_grad::Duplicated)

For use with Enzyme's `Duplicated`, which holds both a model/parameters
and the corresponding gradient.

Calls the three-argument `update!(opt_state, model_grad.val, grad)`, where `grad` is
`model_grad.dval` with every non-numeric leaf replaced by `nothing`.
Returns `nothing`.

# Example

```jldoctest
julia> using Optimisers, EnzymeCore

julia> x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
Duplicated{Vector{Float16}}(Float16[1.0, 2.0, 3.0], Float16[1.0, 0.0, -4.0])

julia> st = Optimisers.setup(Momentum(1/9), x_dx)  # acts only on x not on dx
Leaf(Momentum(0.111111, 0.9), Float16[0.0, 0.0, 0.0])

julia> Optimisers.update!(st, x_dx)  # mutates both arguments

julia> x_dx
Duplicated{Vector{Float16}}(Float16[0.8887, 2.0, 3.445], Float16[1.0, 0.0, -4.0])

julia> st
Leaf(Momentum(0.111111, 0.9), Float16[0.1111, 0.0, -0.4443])
```
"""
function update!(opt_state, model_grad::Duplicated)
  # The three-argument method returns two values; both are discarded here.
  _, _ = update!(opt_state, model_grad.val, _grad_or_nothing(model_grad))
  return nothing
end

# This function strips the returned gradient to be Zygote-like,
# most importantly prune=nothing removes 2nd appearance of shared gradient to avoid double-counting.
_grad_or_nothing(dup::Duplicated) = fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
_grad_or_nothing(::Const) = nothing
# Leaf case: keep whatever `isnumeric` accepts, replace everything else by `nothing`.
_grad_or_nothing(x) = isnumeric(x) ? x : nothing

end
24 changes: 23 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Optimisers
using ChainRulesCore, Functors, StaticArrays, Zygote
using ChainRulesCore, Functors, StaticArrays, Zygote, EnzymeCore
using LinearAlgebra, Statistics, Test, Random
using Optimisers: @.., @lazy
using Base.Broadcast: broadcasted, instantiate, Broadcasted
Expand Down Expand Up @@ -534,6 +534,28 @@ end
@test Optimisers._norm(bc2, p) isa Float64
end
end

@testset "Enzyme Duplicated" begin
  # Basic usage: a `Duplicated` holding an array and its gradient.
  x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
  st = Optimisers.setup(Momentum(1/9), x_dx)  # acts only on x not on dx
  @test st isa Optimisers.Leaf
  @test nothing === Optimisers.update!(st, x_dx)  # mutates both arguments
  @test x_dx.val ≈ Float16[0.8887, 2.0, 3.445]

  # Shared parameters: the same array appears under two names.
  shared = [1.0]
  model = (x=shared, y=shared)
  grad = deepcopy(model)  # Enzyme produces something like this, grad.x === grad.y, already accumulated.
  dup = Duplicated(model, grad)  # pass `grad` as the shadow, so the test matches the comment above
  st2 = Optimisers.setup(Descent(0.1), model)
  Optimisers.update!(st2, dup)
  @test model.x ≈ [0.9]
  shared .= 1
  Optimisers.update!(st2, model, grad)
  model.x ≈ [0.8]  # This is wrong, but don't make it a test.
  # Ideally, perhaps the 3-arg update! could notice that grad.x===grad.y, and not accumulate the gradient in this case?

  @test_throws ArgumentError Optimisers.setup(Adam(), (; a=[1,2,3.], b=x_dx))  # Duplicated deep inside is not allowed
end
end
@testset verbose=true "Destructure" begin
include("destructure.jl")
Expand Down

2 comments on commit 2639523

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/119011

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.1 -m "<description of version>" 26395239c0307fadc4d1143d9ae1bc1a6cb2711e
git push origin v0.4.1

Please sign in to comment.