Don't load Yota at all #166

Merged · 8 commits · Feb 6, 2024
4 changes: 1 addition & 3 deletions Project.toml
@@ -14,15 +14,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ChainRulesCore = "1"
Functors = "0.4"
Statistics = "1"
Yota = "0.8.2"
Zygote = "0.6.40"
julia = "1.6"

[extras]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Yota = "cd998857-8626-517d-b929-70ad188a48f0"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "StaticArrays", "Yota", "Zygote"]
test = ["Test", "StaticArrays", "Zygote"]
19 changes: 0 additions & 19 deletions docs/src/index.md
@@ -80,25 +80,6 @@ Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.


## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)

Yota is another modern automatic differentiation package, an alternative to Zygote.

Its main function is `Yota.grad`, which returns both the loss and the gradient (like `Zygote.withgradient`),
but the gradient also includes a component for the loss function itself.
To extract what Optimisers.jl needs, you can write (for the Flux model above):

```julia
using Yota

loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
    sum(m(x))
end;

# Alternatively, this may save computing ∇image:
loss, (_, ∇model) = grad(m -> sum(m(image)), model);
```
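
For comparison, here is a minimal sketch of the same extraction written with `Zygote.withgradient`; it assumes the `model` and `image` variables from the Flux example referenced above (not shown in this diff):

```julia
using Zygote

# withgradient returns a NamedTuple (val = loss, grad = (∇model, ∇image));
# unlike Yota.grad, there is no separate gradient component for the loss function.
loss, (∇model, ∇image) = Zygote.withgradient((m, x) -> sum(m(x)), model, image);

# Or, differentiating only with respect to the model may save computing ∇image:
loss, (∇model,) = Zygote.withgradient(m -> sum(m(image)), model);
```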

## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

The main design difference of Lux from Flux is that the tree of parameters is separate from
6 changes: 3 additions & 3 deletions test/destructure.jl
@@ -98,14 +98,14 @@ end
sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1])
end[1] == [378, 378, 378]

@test_broken gradient([1,2,3.0]) do v
VERSION >= v"1.10" && @test gradient([1,2,3.0]) do v
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1])
end[1] ≈ [8,16,24]
# Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
# Diffractor error in perform_optic_transform
end

VERSION < v"1.9-" && @testset "using Yota" begin
false && @testset "using Yota" begin
@test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
@test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
@test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
@@ -175,7 +175,7 @@
# Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
end

VERSION < v"1.9-" && @testset "using Yota" begin
false && @testset "using Yota" begin
re1 = destructure(m1)[2]
@test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
re2 = destructure(m2)[2]
4 changes: 2 additions & 2 deletions test/rules.jl
@@ -230,7 +230,7 @@ end
end
end

VERSION < v"1.9-" && @testset "using Yota" begin
false && @testset "using Yota" begin
@testset "$(name(o))" for o in RULES
w′ = (abc = (α = rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
w = (abc = (α = 5rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
@@ -266,4 +266,4 @@

tree, x4 = Optimisers.update(tree, x3, g4)
@test x4 ≈ x3
end
end
15 changes: 9 additions & 6 deletions test/runtests.jl
@@ -1,5 +1,5 @@
using Optimisers
using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
using ChainRulesCore, Functors, StaticArrays, Zygote
using LinearAlgebra, Statistics, Test, Random
using Optimisers: @.., @lazy
using Base.Broadcast: broadcasted, instantiate, Broadcasted
@@ -38,12 +38,15 @@ function Optimisers.apply!(o::BiRule, state, x, dx, dx2)
return state, dx
end

# Make Yota's output look like Zygote's:
# if VERSION < v"1.9-"
# using Yota
# end
# # Make Yota's output look like Zygote's:

Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
y2z(::AbstractZero) = nothing # we don't care about different flavours of zero
y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples!
y2z(x) = x
# Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
# y2z(::AbstractZero) = nothing # we don't care about different flavours of zero
# y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples!
# y2z(x) = x

@testset verbose=true "Optimisers.jl" begin
@testset verbose=true "Features" begin