diff --git a/.travis.yml b/.travis.yml
index 6f3991439..d2aee5a26 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -5,6 +5,7 @@ os:
 julia:
   - 1.0
   - 1.1
+  - 1.2
   - nightly
 matrix:
   allow_failures:
@@ -14,7 +15,7 @@ notifications:
 # uncomment the following lines to override the default test script
 #script:
 #  - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
-#  - julia -e 'Pkg.clone(pwd()); Pkg.build("ChainRules"); Pkg.test("ChainRules"; coverage=true)'
+#  - julia -e 'Pkg.clone(pwd()); Pkg.build("AbstractChainRules"); Pkg.test("AbstractChainRules"; coverage=true)'
 after_success:
   # push coverage results to Coveralls
   - julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Coveralls.submit(Coveralls.process_folder())'
diff --git a/LICENSE.md b/LICENSE.md
index 8b768e457..f7be5b148 100644
--- a/LICENSE.md
+++ b/LICENSE.md
@@ -1,4 +1,4 @@
-The ChainRules.jl package is licensed under the MIT "Expat" License:
+The AbstractChainRules.jl package is licensed under the MIT "Expat" License:
 
 > Copyright (c) 2018: Jarrett Revels.
 >
diff --git a/Project.toml b/Project.toml
index 7dc3acd5d..5ad611461 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,23 +1,18 @@
-name = "ChainRules"
-uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "0.0.1"
+name = "AbstractChainRules"
+uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+version = "0.1.0"
 
 [deps]
 Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
-SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
-Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+
 
 [compat]
 Cassette = "^0.2"
-FDM = "^0.6"
 julia = "^1.0"
 
 [extras]
-FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
-Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 
 [targets]
-test = ["FDM", "Random", "Test"]
+test = ["Test", "LinearAlgebra"]
diff --git a/README.md b/README.md
index 14b798d75..6f90df675 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
-# ChainRules
+# AbstractChainRules
 
-[![Travis](https://travis-ci.org/JuliaDiff/ChainRules.jl.svg?branch=master)](https://travis-ci.org/JuliaDiff/ChainRules.jl)
-[![Coveralls](https://coveralls.io/repos/github/JuliaDiff/ChainRules.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaDiff/ChainRules.jl?branch=master)
-[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaDiff.github.io/ChainRules.jl/latest)
+[![Travis](https://travis-ci.org/JuliaDiff/AbstractChainRules.jl.svg?branch=master)](https://travis-ci.org/JuliaDiff/AbstractChainRules.jl)
+[![Coveralls](https://coveralls.io/repos/github/JuliaDiff/AbstractChainRules.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaDiff/AbstractChainRules.jl?branch=master)
+[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaDiff.github.io/AbstractChainRules.jl/latest)
 
-The ChainRules package provides a variety of common utilities that can be used by downstream automatic differentiation (AD) tools to define and execute forward-, reverse-, and mixed-mode primitives.
+The AbstractChainRules package provides a variety of common utilities that can be used by downstream automatic differentiation (AD) tools to define and execute forward-, reverse-, and mixed-mode primitives.
 
 This package is a WIP; the framework is essentially there, but there are a bunch of TODOs, virtually no tests, etc. PRs welcome!
 Documentation is incoming, which should help if you'd like to contribute.
@@ -18,4 +18,4 @@ Here are some of the basic goals for the package:
 - Control-inverted design: rule authors can fully specify derivatives in a concise manner
   while naturally allowing the caller to compute only what they need.
 
-The ChainRules source code follows the [YASGuide](https://github.com/jrevels/YASGuide).
+The AbstractChainRules source code follows the [YASGuide](https://github.com/jrevels/YASGuide).
diff --git a/REQUIRE b/REQUIRE
deleted file mode 100644
index fa5c8e5b5..000000000
--- a/REQUIRE
+++ /dev/null
@@ -1,4 +0,0 @@
-julia 1.0.0
-SpecialFunctions
-NaNMath
-Cassette 0.2.0 0.3.0
diff --git a/docs/make.jl b/docs/make.jl
index 15b00f2c3..ee82df724 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,11 +1,11 @@
-using ChainRules
+using AbstractChainRules
 using Documenter
 
-makedocs(modules=[ChainRules],
-         sitename="ChainRules",
+makedocs(modules=[AbstractChainRules],
+         sitename="AbstractChainRules",
          authors="Jarrett Revels and other contributors",
          pages=["Introduction" => "index.md",
                 "Getting Started" => "getting_started.md",
-                "ChainRules API Documentation" => "api.md"])
+                "AbstractChainRules API Documentation" => "api.md"])
 
-deploydocs(repo="github.com/JuliaDiff/ChainRules.jl.git")
+deploydocs(repo="github.com/JuliaDiff/AbstractChainRules.jl.git")
diff --git a/docs/src/api.md b/docs/src/api.md
index b9e9817e6..e52854daf 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -1,24 +1,24 @@
-# ChainRules API Documentation
+# AbstractChainRules API Documentation
 
 ```@docs
-ChainRules.frule
-ChainRules.rrule
-ChainRules.AbstractRule
-ChainRules.Rule
-ChainRules.DNERule
-ChainRules.WirtingerRule
-ChainRules.accumulate
-ChainRules.accumulate!
-ChainRules.store!
+AbstractChainRules.frule
+AbstractChainRules.rrule
+AbstractChainRules.AbstractRule
+AbstractChainRules.Rule
+AbstractChainRules.DNERule
+AbstractChainRules.WirtingerRule
+AbstractChainRules.accumulate
+AbstractChainRules.accumulate!
+AbstractChainRules.store!
 ```
 
 ```@docs
-ChainRules.AbstractDifferential
-ChainRules.extern
-ChainRules.Casted
-ChainRules.Wirtinger
-ChainRules.Thunk
-ChainRules.Zero
-ChainRules.DNE
-ChainRules.One
+AbstractChainRules.AbstractDifferential
+AbstractChainRules.extern
+AbstractChainRules.Casted
+AbstractChainRules.Wirtinger
+AbstractChainRules.Thunk
+AbstractChainRules.Zero
+AbstractChainRules.DNE
+AbstractChainRules.One
 ```
diff --git a/docs/src/index.md b/docs/src/index.md
index 725deb0ee..6cc7331d4 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -1,10 +1,10 @@
 ```@meta
-DocTestSetup = :(using ChainRules)
-CurrentModule = ChainRules
+DocTestSetup = :(using AbstractChainRules)
+CurrentModule = AbstractChainRules
 ```
 
-# ChainRules
+# AbstractChainRules
 
-Hello! Welcome to ChainRules's documentation.
+Hello! Welcome to AbstractChainRules's documentation.
 
-For an initial overview of ChainRules, please see the README. Otherwise, feel free to peruse available documentation via the sidebar.
+For an initial overview of AbstractChainRules, please see the README. Otherwise, feel free to peruse available documentation via the sidebar.
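To make the renamed API concrete, here is a minimal sketch of the rule-definition workflow, using only the exports declared in `src/AbstractChainRules.jl` below (`@scalar_rule`, `frule`, `rrule`). `myexp2` is a hypothetical example function, not part of this diff; it mirrors the `dummy_identity` pattern added to `test/rules.jl`.

```julia
using AbstractChainRules

# Hypothetical user-defined scalar function and its rule.
# Inside @scalar_rule, Ω is bound to the primal output, here exp(2x).
myexp2(x) = exp(2x)
@scalar_rule(myexp2(x), 2 * Ω)   # d/dx exp(2x) = 2 * exp(2x)

y, dy = frule(myexp2, 1.5)       # forward mode: y == myexp2(1.5)
dy(1.0) ≈ 2 * y                  # push a unit tangent through the rule

z, dz = rrule(myexp2, 1.5)       # reverse mode: z == myexp2(1.5)
dz(1.0) ≈ 2 * z                  # pull a unit cotangent back through the rule
```

A caller that does not need the derivative with respect to a particular input simply never invokes the corresponding rule, which is the control-inverted design described in the README above.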
diff --git a/src/AbstractChainRules.jl b/src/AbstractChainRules.jl new file mode 100644 index 000000000..de0a35eed --- /dev/null +++ b/src/AbstractChainRules.jl @@ -0,0 +1,11 @@ +module AbstractChainRules +using Cassette +using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable + +export AbstractRule, Rule, frule, rrule +export @scalar_rule, @thunk +export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule + +include("differentials.jl") +include("rules.jl") +end # module diff --git a/src/ChainRules.jl b/src/ChainRules.jl deleted file mode 100644 index 2277699c7..000000000 --- a/src/ChainRules.jl +++ /dev/null @@ -1,34 +0,0 @@ -module ChainRules - -using Cassette -using LinearAlgebra -using LinearAlgebra.BLAS -using Statistics -using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable - -if VERSION < v"1.3.0-DEV.142" - # In prior versions, the BLAS submodule also exported `dot`, which caused a conflict - # with its parent module. To get around this, we can simply create a hard binding for - # the one we want to use without qualification. - import LinearAlgebra: dot -end - -import NaNMath, SpecialFunctions - -export AbstractRule, Rule, frule, rrule - -include("differentials.jl") -include("rules.jl") -include("rules/base.jl") -include("rules/array.jl") -include("rules/broadcast.jl") -include("rules/mapreduce.jl") -include("rules/linalg/utils.jl") -include("rules/linalg/blas.jl") -include("rules/linalg/dense.jl") -include("rules/linalg/structured.jl") -include("rules/linalg/factorization.jl") -include("rules/nanmath.jl") -include("rules/specialfunctions.jl") - -end # module diff --git a/src/differentials.jl b/src/differentials.jl index 689f0bbe6..2da7f15d1 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -140,7 +140,7 @@ function mul_wirtinger(a::Wirtinger, b::Wirtinger) such that we assume the chain rule application is of the form `f_a ∘ f_b` instead of `f_b ∘ f_a`. However, picking such a convention is likely to lead to silently incorrect derivatives due to commutativity assumptions - in downstream generic code that deals with the reals. Thus, ChainRules + in downstream generic code that deals with the reals. Thus, AbstractChainRules makes this operation an error instead. """) end diff --git a/src/rules.jl b/src/rules.jl index b07883c25..4b204b1c3 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -12,7 +12,7 @@ return that calculated differential value. For example: ```julia-repl -julia> using ChainRules: frule, rrule, AbstractRule +julia> using AbstractChainRules: frule, rrule, AbstractRule julia> x, y = rand(2); @@ -61,7 +61,7 @@ Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsErro """ accumulate(Δ, rule::AbstractRule, args...) -Return `Δ + rule(args...)` evaluated in a manner that supports ChainRules' +Return `Δ + rule(args...)` evaluated in a manner that supports AbstractChainRules' various `AbstractDifferential` types. This method intended to be customizable for specific rules/input types. For @@ -112,32 +112,6 @@ See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) """ store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...))) -# Special purpose updating for operations which can be done in-place. This function is -# just internal and free-form; it is not a method of `accumulate!` directly as it does -# not adhere to the expected method signature form, i.e. `accumulate!(value, rule, args)`. 
-# Instead it's `_update!(old, new, extrastuff...)` and is not specific to any particular -# rule. - -_update!(x, y) = x + y -_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y - -_update!(x, ::Zero) = x -_update!(::Zero, y) = y -_update!(::Zero, ::Zero) = Zero() - -function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns - return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns)) -end - -function _update!(x::NamedTuple, y, p::Symbol) - new = NamedTuple{(p,)}((_update!(getproperty(x, p), y),)) - return merge(x, new) -end - -function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns - return _update!(x, getproperty(y, p), p) -end - ##### ##### `Rule` ##### @@ -232,7 +206,7 @@ rules, where e.g. `frule` is used within a `rrule` definition. For example, broadcasted functions may not themselves be forward-mode *primitives*, but are often forward-mode *differentiable*. -ChainRules, by design, is decoupled from any specific AD implementation. How, +AbstractChainRules, by design, is decoupled from any specific AD implementation. How, then, do we know which AD to fall back to when there isn't a primitive defined? Well, if you're a greedy AD implementation, you can just overload `frule` and/or @@ -244,12 +218,12 @@ It turns out, Cassette solves this problem nicely by allowing AD authors to overload the fallbacks w.r.t. their own context. Example using ForwardDiff: ``` -using ChainRules, ForwardDiff, Cassette +using AbstractChainRules, ForwardDiff, Cassette Cassette.@context MyChainRuleCtx # ForwardDiff, itself, can call `my_frule` instead of -# `frule` to utilize the ForwardDiff-injected ChainRules +# `frule` to utilize the ForwardDiff-injected AbstractChainRules # infrastructure my_frule(args...) = Cassette.overdub(MyChainRuleCtx(), frule, args...) @@ -377,23 +351,6 @@ See also: [`frule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref) """ rrule(::Any, ::Vararg{Any}; kwargs...) = nothing -@noinline function _throw_checked_rrule_error(f, args...; kwargs...) - io = IOBuffer() - print(io, "can't differentiate `", f, '(') - join(io, map(arg->string("::", typeof(arg)), args), ", ") - if !isempty(kwargs) - print(io, ";") - join(io, map(((k, v),)->string(k, "=", v), kwargs), ", ") - end - print(io, ")`; no matching `rrule` is defined") - throw(ArgumentError(String(take!(io)))) -end - -function _checked_rrule(f, args...; kwargs...) - r = rrule(f, args...; kwargs...) - r isa Nothing && _throw_checked_rrule_error(f, args...; kwargs...) - return r -end ##### ##### macros @@ -410,7 +367,7 @@ A convenience macro that generates simple scalar forward or reverse rules using the provided partial derivatives. Specifically, generates the corresponding methods for `frule` and `rrule`: - function ChainRules.frule(::typeof(f), x₁::Number, x₂::Number, ...) + function AbstractChainRules.frule(::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) return Ω, (Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...), @@ -418,7 +375,7 @@ methods for `frule` and `rrule`: ...) end - function ChainRules.rrule(::typeof(f), x₁::Number, x₂::Number, ...) + function AbstractChainRules.rrule(::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) return Ω, (Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...), @@ -453,7 +410,7 @@ is equivalent to: (∂f₂_∂x₁, ∂f₂_∂x₂, ...), ...) -For examples, see ChainRules' `rules` directory. 
+For examples, see AbstractChainRules' `rules` directory. See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref) """ @@ -493,12 +450,12 @@ macro scalar_rule(call, maybe_setup, partials...) forward_rules = length(forward_rules) == 1 ? forward_rules[1] : Expr(:tuple, forward_rules...) reverse_rules = length(reverse_rules) == 1 ? reverse_rules[1] : Expr(:tuple, reverse_rules...) return quote - function ChainRules.frule(::typeof($f), $(inputs...)) + function AbstractChainRules.frule(::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) return $(esc(:Ω)), $forward_rules end - function ChainRules.rrule(::typeof($f), $(inputs...)) + function AbstractChainRules.rrule(::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) return $(esc(:Ω)), $reverse_rules diff --git a/src/rules/array.jl b/src/rules/array.jl deleted file mode 100644 index b071f0447..000000000 --- a/src/rules/array.jl +++ /dev/null @@ -1,58 +0,0 @@ -##### -##### `reshape` -##### - -function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) - return reshape(A, dims), (Rule(Ȳ->reshape(Ȳ, dims)), DNERule()) -end - -function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) - Y, (rule, _) = rrule(reshape, A, dims) - return Y, (rule, fill(DNERule(), length(dims))...) -end - -##### -##### `hcat` (🐈) -##### - -function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...) - Y = hcat(A, Bs...) - Xs = (A, Bs...) - rules = ntuple(length(Bs) + 1) do i - l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0) - u = l + size(Xs[i], 2) - dim = u > l + 1 ? (l+1:u) : u - # NOTE: The copy here is defensive, since `selectdim` returns a view which we can - # materialize with `copy` - Rule(Ȳ->copy(selectdim(Ȳ, 2, dim))) - end - return Y, rules -end - -##### -##### `vcat` -##### - -function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...) - Y = vcat(A, Bs...) - n = size(A, 1) - ∂A = Rule(Ȳ->copy(selectdim(Ȳ, 1, 1:n))) - ∂Bs = ntuple(length(Bs)) do i - l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0) - u = l + size(Bs[i], 1) - Rule(Ȳ->copy(selectdim(Ȳ, 1, l+1:u))) - end - return Y, (∂A, ∂Bs...) -end - -##### -##### `fill` -##### - -function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) - return fill(value, dims), (Rule(sum), DNERule()) -end - -function rrule(::typeof(fill), value::Any, dims::Int...) - return fill(value, dims), (Rule(sum), ntuple(_->DNERule(), length(dims))...) 
-end diff --git a/src/rules/base.jl b/src/rules/base.jl deleted file mode 100644 index 106707c91..000000000 --- a/src/rules/base.jl +++ /dev/null @@ -1,83 +0,0 @@ -@scalar_rule(abs2(x), Wirtinger(x', x)) -@scalar_rule(log(x), inv(x)) -@scalar_rule(log10(x), inv(x) / log(oftype(x, 10))) -@scalar_rule(log2(x), inv(x) / log(oftype(x, 2))) -@scalar_rule(log1p(x), inv(x + 1)) -@scalar_rule(expm1(x), exp(x)) -@scalar_rule(sin(x), cos(x)) -@scalar_rule(cos(x), -sin(x)) -@scalar_rule(sinpi(x), π * cospi(x)) -@scalar_rule(cospi(x), -π * sinpi(x)) -@scalar_rule(sind(x), (π / oftype(x, 180)) * cosd(x)) -@scalar_rule(cosd(x), -(π / oftype(x, 180)) * sind(x)) -@scalar_rule(asin(x), inv(sqrt(1 - x^2))) -@scalar_rule(acos(x), -inv(sqrt(1 - x^2))) -@scalar_rule(atan(x), inv(1 + x^2)) -@scalar_rule(asec(x), inv(abs(x) * sqrt(x^2 - 1))) -@scalar_rule(acsc(x), -inv(abs(x) * sqrt(x^2 - 1))) -@scalar_rule(acot(x), -inv(1 + x^2)) -@scalar_rule(asind(x), oftype(x, 180) / π / sqrt(1 - x^2)) -@scalar_rule(acosd(x), -oftype(x, 180) / π / sqrt(1 - x^2)) -@scalar_rule(atand(x), oftype(x, 180) / π / (1 + x^2)) -@scalar_rule(asecd(x), oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1)) -@scalar_rule(acscd(x), -oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1)) -@scalar_rule(acotd(x), -oftype(x, 180) / π / (1 + x^2)) -@scalar_rule(sinh(x), cosh(x)) -@scalar_rule(cosh(x), sinh(x)) -@scalar_rule(tanh(x), sech(x)^2) -@scalar_rule(coth(x), -(csch(x)^2)) -@scalar_rule(asinh(x), inv(sqrt(x^2 + 1))) -@scalar_rule(acosh(x), inv(sqrt(x^2 - 1))) -@scalar_rule(atanh(x), inv(1 - x^2)) -@scalar_rule(asech(x), -inv(x * sqrt(1 - x^2))) -@scalar_rule(acsch(x), -inv(abs(x) * sqrt(1 + x^2))) -@scalar_rule(acoth(x), inv(1 - x^2)) -@scalar_rule(deg2rad(x), π / oftype(x, 180)) -@scalar_rule(rad2deg(x), oftype(x, 180) / π) -@scalar_rule(conj(x), Wirtinger(Zero(), One())) -@scalar_rule(adjoint(x), Wirtinger(Zero(), One())) -@scalar_rule(transpose(x), One()) -@scalar_rule(abs(x), sign(x)) -@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DNE())) -@scalar_rule(+(x), One()) -@scalar_rule(-(x), -1) -@scalar_rule(+(x, y), (One(), One())) -@scalar_rule(-(x, y), (One(), -1)) -@scalar_rule(/(x, y), (inv(y), -(x / y / y))) -@scalar_rule(\(x, y), (-(y / x / x), inv(x))) -@scalar_rule(^(x, y), (y * x^(y - 1), Ω * log(x))) -@scalar_rule(inv(x), -abs2(Ω)) -@scalar_rule(sqrt(x), inv(2 * Ω)) -@scalar_rule(cbrt(x), inv(3 * Ω^2)) -@scalar_rule(exp(x), Ω) -@scalar_rule(exp2(x), Ω * log(oftype(x, 2))) -@scalar_rule(exp10(x), Ω * log(oftype(x, 10))) -@scalar_rule(tan(x), 1 + Ω^2) -@scalar_rule(sec(x), Ω * tan(x)) -@scalar_rule(csc(x), -Ω * cot(x)) -@scalar_rule(cot(x), -(1 + Ω^2)) -@scalar_rule(tand(x), (π / oftype(x, 180)) * (1 + Ω^2)) -@scalar_rule(secd(x), (π / oftype(x, 180)) * Ω * tand(x)) -@scalar_rule(cscd(x), -(π / oftype(x, 180)) * Ω * cotd(x)) -@scalar_rule(cotd(x), -(π / oftype(x, 180)) * (1 + Ω^2)) -@scalar_rule(sech(x), -tanh(x) * Ω) -@scalar_rule(csch(x), -coth(x) * Ω) -@scalar_rule(hypot(x, y), (y / Ω, x / Ω)) -@scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx) -@scalar_rule(atan(y, x), @setup(u = hypot(x, y)), (x / u, y / u)) -@scalar_rule(max(x, y), @setup(gt = x > y), (gt, !gt)) -@scalar_rule(min(x, y), @setup(gt = x > y), (!gt, gt)) -@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16)), - (ifelse(isint, nan, one(u)), ifelse(isint, nan, -floor(u)))) -@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16)), - (ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u)))) - -# product rule requires special care for arguments 
where `mul` is non-commutative - -frule(::typeof(*), x, y) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy) - -rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ)) - -frule(::typeof(identity), x) = x, Rule(identity) - -rrule(::typeof(identity), x) = x, Rule(identity) diff --git a/src/rules/broadcast.jl b/src/rules/broadcast.jl deleted file mode 100644 index f3685f5ca..000000000 --- a/src/rules/broadcast.jl +++ /dev/null @@ -1,24 +0,0 @@ -#= -TODO: This partial derivative extraction should be doable without the extra -temporaries utilized here, but AFAICT such an approach is hard to write -without relying on inference hacks unless we have something akin to -https://github.com/JuliaLang/julia/issues/22129. -=# -function _cast_diff(f, x) - element_rule = u -> begin - fu, du = frule(f, u) - fu, extern(du(One())) - end - results = broadcast(element_rule, x) - return first.(results), last.(results) -end - -function frule(::typeof(broadcast), f, x) - Ω, ∂x = _cast_diff(f, x) - return Ω, Rule((_, Δx) -> Δx * cast(∂x)) -end - -function rrule(::typeof(broadcast), f, x) - values, derivs = _cast_diff(f, x) - return values, (DNERule(), Rule(ΔΩ -> ΔΩ * cast(derivs))) -end diff --git a/src/rules/linalg/blas.jl b/src/rules/linalg/blas.jl deleted file mode 100644 index fb5b23f4b..000000000 --- a/src/rules/linalg/blas.jl +++ /dev/null @@ -1,124 +0,0 @@ -#= -These implementations were ported from the wonderful DiffLinearAlgebra -package (https://github.com/invenia/DiffLinearAlgebra.jl). -=# - -using LinearAlgebra: BlasFloat - -_zeros(x) = fill!(similar(x), zero(eltype(x))) - -_rule_via(∂) = Rule(ΔΩ -> isa(ΔΩ, Zero) ? ΔΩ : ∂(extern(ΔΩ))) - -##### -##### `BLAS.dot` -##### - -frule(::typeof(BLAS.dot), x, y) = frule(dot, x, y) - -rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y) - -function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) - Ω = BLAS.dot(n, X, incx, Y, incy) - ∂X = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) - ∂Y = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) - return Ω, (DNERule(), _rule_via(∂X), DNERule(), _rule_via(∂Y), DNERule()) -end - -##### -##### `BLAS.nrm2` -##### - -function frule(::typeof(BLAS.nrm2), x) - Ω = BLAS.nrm2(x) - return Ω, Rule(Δx -> sum(Δx * cast(@thunk(x * inv(Ω))))) -end - -function rrule(::typeof(BLAS.nrm2), x) - Ω = BLAS.nrm2(x) - return Ω, Rule(ΔΩ -> ΔΩ * @thunk(x * inv(Ω))) -end - -function rrule(::typeof(BLAS.nrm2), n, X, incx) - Ω = BLAS.nrm2(n, X, incx) - ∂X = ΔΩ -> scal!(n, ΔΩ / Ω, blascopy!(n, X, incx, _zeros(X), incx), incx) - return Ω, (DNERule(), _rule_via(∂X), DNERule()) -end - -##### -##### `BLAS.asum` -##### - -frule(::typeof(BLAS.asum), x) = (BLAS.asum(x), Rule(Δx -> sum(cast(sign, x) * Δx))) - -rrule(::typeof(BLAS.asum), x) = (BLAS.asum(x), Rule(ΔΩ -> ΔΩ * cast(sign, x))) - -function rrule(::typeof(BLAS.asum), n, X, incx) - Ω = BLAS.asum(n, X, incx) - ∂X = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, sign.(X), incx, _zeros(X), incx), incx) - return Ω, (DNERule(), _rule_via(∂X), DNERule()) -end - -##### -##### `BLAS.gemv` -##### - -function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, - x::AbstractVector{T}) where T<:BlasFloat - y = gemv(tA, α, A, x) - if uppercase(tA) === 'N' - ∂A = Rule(ȳ -> α * ȳ * x', (Ā, ȳ) -> ger!(α, ȳ, x, Ā)) - ∂x = Rule(ȳ -> gemv('T', α, A, ȳ), (x̄, ȳ) -> gemv!('T', α, A, ȳ, one(T), x̄)) - else - ∂A = Rule(ȳ -> α * x * ȳ', (Ā, ȳ) -> ger!(α, x, ȳ, Ā)) - ∂x = Rule(ȳ -> gemv('N', α, A, ȳ), (x̄, ȳ) -> gemv!('N', α, A, ȳ, one(T), x̄)) - end - return y, (DNERule(), Rule(ȳ -> 
dot(ȳ, y) / α), ∂A, ∂x) -end - -function rrule(::typeof(gemv), tA::Char, A::AbstractMatrix{T}, - x::AbstractVector{T}) where T<:BlasFloat - y, (dtA, _, dA, dx) = rrule(gemv, tA, one(T), A, x) - return y, (dtA, dA, dx) -end - -##### -##### `BLAS.gemm` -##### - -function rrule(::typeof(gemm), tA::Char, tB::Char, α::T, - A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat - C = gemm(tA, tB, α, A, B) - β = one(T) - if uppercase(tA) === 'N' - if uppercase(tB) === 'N' - ∂A = Rule(C̄ -> gemm('N', 'T', α, C̄, B), - (Ā, C̄) -> gemm!('N', 'T', α, C̄, B, β, Ā)) - ∂B = Rule(C̄ -> gemm('T', 'N', α, A, C̄), - (B̄, C̄) -> gemm!('T', 'N', α, A, C̄, β, B̄)) - else - ∂A = Rule(C̄ -> gemm('N', 'N', α, C̄, B), - (Ā, C̄) -> gemm!('N', 'N', α, C̄, B, β, Ā)) - ∂B = Rule(C̄ -> gemm('T', 'N', α, C̄, A), - (B̄, C̄) -> gemm!('T', 'N', α, C̄, A, β, B̄)) - end - else - if uppercase(tB) === 'N' - ∂A = Rule(C̄ -> gemm('N', 'T', α, B, C̄), - (Ā, C̄) -> gemm!('N', 'T', α, B, C̄, β, Ā)) - ∂B = Rule(C̄ -> gemm('N', 'N', α, A, C̄), - (B̄, C̄) -> gemm!('N', 'N', α, A, C̄, β, B̄)) - else - ∂A = Rule(C̄ -> gemm('T', 'T', α, B, C̄), - (Ā, C̄) -> gemm!('T', 'T', α, B, C̄, β, Ā)) - ∂B = Rule(C̄ -> gemm('T', 'T', α, C̄, A), - (B̄, C̄) -> gemm!('T', 'T', α, C̄, A, β, B̄)) - end - end - return C, (DNERule(), DNERule(), Rule(C̄ -> dot(C̄, C) / α), ∂A, ∂B) -end - -function rrule(::typeof(gemm), tA::Char, tB::Char, - A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat - C, (dtA, dtB, _, dA, dB) = rrule(gemm, tA, tB, one(T), A, B) - return C, (dtA, dtB, dA, dB) -end diff --git a/src/rules/linalg/dense.jl b/src/rules/linalg/dense.jl deleted file mode 100644 index 9eb3ee168..000000000 --- a/src/rules/linalg/dense.jl +++ /dev/null @@ -1,140 +0,0 @@ -using LinearAlgebra: AbstractTriangular - -# Matrix wrapper types that we know are square and are thus potentially invertible. For -# these we can use simpler definitions for `/` and `\`. 
-const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}} - -##### -##### `dot` -##### - -function frule(::typeof(dot), x, y) - return dot(x, y), Rule((Δx, Δy) -> sum(Δx * cast(y)) + sum(cast(x) * Δy)) -end - -function rrule(::typeof(dot), x, y) - return dot(x, y), (Rule(ΔΩ -> ΔΩ * cast(y)), Rule(ΔΩ -> cast(x) * ΔΩ)) -end - -##### -##### `inv` -##### - -function frule(::typeof(inv), x::AbstractArray) - Ω = inv(x) - m = @thunk(-Ω) - return Ω, Rule(Δx -> m * Δx * Ω) -end - -function rrule(::typeof(inv), x::AbstractArray) - Ω = inv(x) - m = @thunk(-Ω') - return Ω, Rule(ΔΩ -> m * ΔΩ * Ω') -end - -##### -##### `det` -##### - -function frule(::typeof(det), x) - Ω, m = det(x), @thunk(inv(x)) - return Ω, Rule(Δx -> Ω * tr(extern(m * Δx))) -end - -function rrule(::typeof(det), x) - Ω, m = det(x), @thunk(inv(x)') - return Ω, Rule(ΔΩ -> Ω * ΔΩ * m) -end - -##### -##### `logdet` -##### - -function frule(::typeof(logdet), x) - Ω, m = logdet(x), @thunk(inv(x)) - return Ω, Rule(Δx -> tr(extern(m * Δx))) -end - -function rrule(::typeof(logdet), x) - Ω, m = logdet(x), @thunk(inv(x)') - return Ω, Rule(ΔΩ -> ΔΩ * m) -end - -##### -##### `trace` -##### - -frule(::typeof(tr), x) = (tr(x), Rule(Δx -> tr(extern(Δx)))) - -rrule(::typeof(tr), x) = (tr(x), Rule(ΔΩ -> Diagonal(fill(ΔΩ, size(x, 1))))) - -##### -##### `*` -##### - -function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}) - return A * B, (Rule(Ȳ -> Ȳ * B'), Rule(Ȳ -> A' * Ȳ)) -end - -##### -##### `/` -##### - -function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real} - Y = A / B - S = T.name.wrapper - ∂A = Rule(Ȳ -> Ȳ / B') - ∂B = Rule(Ȳ -> S(-Y' * (Ȳ / B'))) - return Y, (∂A, ∂B) -end - -function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) - Aᵀ, dA = rrule(adjoint, A) - Bᵀ, dB = rrule(adjoint, B) - Cᵀ, (dBᵀ, dAᵀ) = rrule(\, Bᵀ, Aᵀ) - C, dC = rrule(adjoint, Cᵀ) - ∂A = Rule(dA∘dAᵀ∘dC) - ∂B = Rule(dA∘dBᵀ∘dC) - return C, (∂A, ∂B) -end - -##### -##### `\` -##### - -function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real} - Y = A \ B - S = T.name.wrapper - ∂A = Rule(Ȳ -> S(-(A' \ Ȳ) * Y')) - ∂B = Rule(Ȳ -> A' \ Ȳ) - return Y, (∂A, ∂B) -end - -function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) - Y = A \ B - ∂A = Rule() do Ȳ - B̄ = A' \ Ȳ - Ā = -B̄ * Y' - _add!(Ā, (B - A * Y) * B̄' / A') - _add!(Ā, A' \ Y * (Ȳ' - B̄'A)) - Ā - end - ∂B = Rule(Ȳ -> A' \ Ȳ) - return Y, (∂A, ∂B) -end - -##### -##### `norm` -##### - -function rrule(::typeof(norm), A::AbstractArray{<:Real}, p::Real=2) - y = norm(A, p) - u = y^(1-p) - ∂A = Rule(ȳ -> ȳ .* u .* abs.(A).^p ./ A) - ∂p = Rule(ȳ -> ȳ * (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p) - return y, (∂A, ∂p) -end - -function rrule(::typeof(norm), x::Real, p::Real=2) - return norm(x, p), (Rule(ȳ -> ȳ * sign(x)), Rule(_ -> zero(x))) -end diff --git a/src/rules/linalg/factorization.jl b/src/rules/linalg/factorization.jl deleted file mode 100644 index 72527fcc6..000000000 --- a/src/rules/linalg/factorization.jl +++ /dev/null @@ -1,244 +0,0 @@ -using LinearAlgebra: checksquare -using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! 
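For reference, a sketch of how the dense-matrix rules deleted above were consumed, based on the `rrule(*, A, B)` definition in the removed `src/rules/linalg/dense.jl`. Array sizes are illustrative, and the snippet targets the pre-rename package: after this diff these rules no longer exist and the generic fallback returns `nothing`.

```julia
# Sketch of consuming the removed matrix-product rule (pre-rename ChainRules).
using ChainRules: rrule   # the package as it existed before this diff
using LinearAlgebra

A, B = rand(3, 2), rand(2, 5)
Y, (dA, dB) = rrule(*, A, B)   # Y == A * B, plus one Rule per argument

Ȳ = rand(3, 5)                 # cotangent of the output
dA(Ȳ) == Ȳ * B'                # from Rule(Ȳ -> Ȳ * B')
dB(Ȳ) == A' * Ȳ                # from Rule(Ȳ -> A' * Ȳ)
```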
- -##### -##### `svd` -##### - -function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) - F = svd(X) - ∂X = Rule() do Ȳ::NamedTuple{(:U,:S,:V)} - svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V) - end - return F, ∂X -end - -function rrule(::typeof(getproperty), F::SVD, x::Symbol) - if x === :U - rule = Ȳ->(U=Ȳ, S=zero(F.S), V=zero(F.V)) - elseif x === :S - rule = Ȳ->(U=zero(F.U), S=Ȳ, V=zero(F.V)) - elseif x === :V - rule = Ȳ->(U=zero(F.U), S=zero(F.S), V=Ȳ) - elseif x === :Vt - # TODO: This could be made to work, but it'd be a pain - throw(ArgumentError("Vt is unsupported; use V and transpose the result")) - end - update = (X̄::NamedTuple{(:U,:S,:V)}, Ȳ)->_update!(X̄, rule(Ȳ), x) - return getproperty(F, x), (Rule(rule, update), DNERule()) -end - -function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix) - # Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default - U = USV.U - s = USV.S - V = USV.V - Vt = USV.Vt - - k = length(s) - T = eltype(s) - F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k] - - # We do a lot of matrix operations here, so we'll try to be memory-friendly and do - # as many of the computations in-place as possible. Benchmarking shows that the in- - # place functions here are significantly faster than their out-of-place, naively - # implemented counterparts, and allocate no additional memory. - Ut = U' - FUᵀŪ = _mulsubtrans!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU) - FVᵀV̄ = _mulsubtrans!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV) - ImUUᵀ = _eyesubx!(U*Ut) # I - UUᵀ - ImVVᵀ = _eyesubx!(V*Vt) # I - VVᵀ - - S = Diagonal(s) - S̄ = Diagonal(s̄) - - Ā = _add!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt - _add!(Ā, U * S̄ * Vt) - _add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ)) - - return Ā -end - -##### -##### `cholesky` -##### - -function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) - F = cholesky(X) - ∂X = Rule(Ȳ->chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true)) - return F, ∂X -end - -function rrule(::typeof(getproperty), F::Cholesky, x::Symbol) - if x === :U - if F.uplo === 'U' - ∂F = Ȳ->UpperTriangular(Ȳ) - else - ∂F = Ȳ->LowerTriangular(Ȳ') - end - elseif x === :L - if F.uplo === 'L' - ∂F = Ȳ->LowerTriangular(Ȳ) - else - ∂F = Ȳ->UpperTriangular(Ȳ') - end - end - return getproperty(F, x), (Rule(∂F), DNERule()) -end - -# See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular, -# for derivations. Here we're implementing the algorithms and their transposes. - -""" - level2partition(A::AbstractMatrix, j::Integer, upper::Bool) - -Returns views to various bits of the lower triangle of `A` according to the -`level2partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then -the transposed views are returned from the upper triangle of `A`. - -[1]: "Differentiation of the Cholesky decomposition", Murray 2016 -""" -function level2partition(A::AbstractMatrix, j::Integer, upper::Bool) - n = checksquare(A) - @boundscheck checkbounds(1:n, j) - if upper - r = view(A, 1:j-1, j) - d = view(A, j, j) - B = view(A, 1:j-1, j+1:n) - c = view(A, j, j+1:n) - else - r = view(A, j, 1:j-1) - d = view(A, j, j) - B = view(A, j+1:n, 1:j-1) - c = view(A, j+1:n, j) - end - return r, d, B, c -end - -""" - level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool) - -Returns views to various bits of the lower triangle of `A` according to the -`level3partition` procedure defined in [1] if `upper` is `false`. 
If `upper` is `true` then -the transposed views are returned from the upper triangle of `A`. - -[1]: "Differentiation of the Cholesky decomposition", Murray 2016 -""" -function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool) - n = checksquare(A) - @boundscheck checkbounds(1:n, j) - if upper - R = view(A, 1:j-1, j:k) - D = view(A, j:k, j:k) - B = view(A, 1:j-1, k+1:n) - C = view(A, j:k, k+1:n) - else - R = view(A, j:k, 1:j-1) - D = view(A, j:k, j:k) - B = view(A, k+1:n, 1:j-1) - C = view(A, k+1:n, j:k) - end - return R, D, B, C -end - -""" - chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool) - -Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner. -If `upper` is `false`, then the sensitivites are computed from and stored in the lower triangle -of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the -upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output -`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and -`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`. -""" -function chol_unblocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool) where T<:Real - n = checksquare(Σ̄) - j = n - @inbounds for _ in 1:n - r, d, B, c = level2partition(L, j, upper) - r̄, d̄, B̄, c̄ = level2partition(Σ̄, j, upper) - - # d̄ <- d̄ - c'c̄ / d. - d̄[1] -= dot(c, c̄) / d[1] - - # [d̄ c̄'] <- [d̄ c̄'] / d. - d̄ ./= d - c̄ ./= d - - # r̄ <- r̄ - [d̄ c̄'] [r' B']'. - r̄ = axpy!(-Σ̄[j,j], r, r̄) - r̄ = gemv!(upper ? 'n' : 'T', -one(T), B, c̄, one(T), r̄) - - # B̄ <- B̄ - c̄ r. - B̄ = upper ? ger!(-one(T), r, c̄, B̄) : ger!(-one(T), c̄, r, B̄) - d̄ ./= 2 - j -= 1 - end - return (upper ? triu! : tril!)(Σ̄) -end - -function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool) - return chol_unblocked_rev!(copy(Σ̄), L, upper) -end - -""" - chol_blocked_rev!(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool) - -Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly -procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities -of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used -to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be -indicated by passing `upper = true`. 
-""" -function chol_blocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, nb::Integer, upper::Bool) where T<:Real - n = checksquare(Σ̄) - tmp = Matrix{T}(undef, nb, nb) - k = n - if upper - @inbounds for _ in 1:nb:n - j = max(1, k - nb + 1) - R, D, B, C = level3partition(L, j, k, true) - R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, true) - - C̄ = trsm!('L', 'U', 'N', 'N', one(T), D, C̄) - gemm!('N', 'N', -one(T), R, C̄, one(T), B̄) - gemm!('N', 'T', -one(T), C, C̄, one(T), D̄) - chol_unblocked_rev!(D̄, D, true) - gemm!('N', 'T', -one(T), B, C̄, one(T), R̄) - if size(D̄, 1) == nb - tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) - gemm!('N', 'N', -one(T), R, tmp, one(T), R̄) - else - gemm!('N', 'N', -one(T), R, D̄ + D̄', one(T), R̄) - end - - k -= nb - end - return triu!(Σ̄) - else - @inbounds for _ in 1:nb:n - j = max(1, k - nb + 1) - R, D, B, C = level3partition(L, j, k, false) - R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, false) - - C̄ = trsm!('R', 'L', 'N', 'N', one(T), D, C̄) - gemm!('N', 'N', -one(T), C̄, R, one(T), B̄) - gemm!('T', 'N', -one(T), C̄, C, one(T), D̄) - chol_unblocked_rev!(D̄, D, false) - gemm!('T', 'N', -one(T), C̄, B, one(T), R̄) - if size(D̄, 1) == nb - tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) - gemm!('N', 'N', -one(T), tmp, R, one(T), R̄) - else - gemm!('N', 'N', -one(T), D̄ + D̄', R, one(T), R̄) - end - - k -= nb - end - return tril!(Σ̄) - end -end - -function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool) - return chol_blocked_rev!(copy(Σ̄), L, nb, upper) -end diff --git a/src/rules/linalg/structured.jl b/src/rules/linalg/structured.jl deleted file mode 100644 index d2ee20309..000000000 --- a/src/rules/linalg/structured.jl +++ /dev/null @@ -1,47 +0,0 @@ -# Structured matrices - -##### -##### `Diagonal` -##### - -rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag) - -rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal) - -##### -##### `Symmetric` -##### - -rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back) - -_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ) -_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ - -##### -##### `Adjoint` -##### - -# TODO: Deal with complex-valued arrays as well -rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), Rule(adjoint) -rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), Rule(vec∘adjoint) - -rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), Rule(adjoint) -rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), Rule(vec∘adjoint) - -##### -##### `Transpose` -##### - -rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), Rule(transpose) -rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), Rule(vec∘transpose) - -rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), Rule(transpose) -rrule(::typeof(transpose), A::AbstractVector) = transpose(A), Rule(vec∘transpose) - -##### -##### Triangular matrices -##### - -rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), Rule(Matrix) - -rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), Rule(Matrix) diff --git a/src/rules/linalg/utils.jl b/src/rules/linalg/utils.jl deleted file mode 100644 index ed9a9cb10..000000000 --- a/src/rules/linalg/utils.jl +++ /dev/null @@ -1,32 +0,0 @@ -# Some utility functions for optimizing linear algebra operations that aren't specific -# to any particular rule definition - -# F .* (X - X'), overwrites X 
-function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real - k = size(X, 1) - @inbounds for j = 1:k, i = 1:j # Iterate the upper triangle - if i == j - X[i,i] = zero(T) - else - X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j]) - end - end - X -end - -# I - X, overwrites X -function _eyesubx!(X::AbstractMatrix) - n, m = size(X) - @inbounds for j = 1:m, i = 1:n - X[i,j] = (i == j) - X[i,j] - end - X -end - -# X + Y, overwrites X -function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real - @inbounds for i = eachindex(X, Y) - X[i] += Y[i] - end - X -end diff --git a/src/rules/mapreduce.jl b/src/rules/mapreduce.jl deleted file mode 100644 index 69a9b9b2b..000000000 --- a/src/rules/mapreduce.jl +++ /dev/null @@ -1,86 +0,0 @@ -##### -##### `map` -##### - -function rrule(::typeof(map), f, xs...) - y = map(f, xs...) - ∂xs = ntuple(length(xs)) do i - Rule() do ȳ - map(ȳ, xs...) do ȳi, xis... - _, ∂xis = _checked_rrule(f, xis...) - extern(∂xis[i](ȳi)) - end - end - end - return y, (DNERule(), ∂xs...) -end - -##### -##### `mapreduce`, `mapfoldl`, `mapfoldr` -##### - -for mf in (:mapreduce, :mapfoldl, :mapfoldr) - sig = :(rrule(::typeof($mf), f, op, x::AbstractArray{<:Real})) - call = :($mf(f, op, x)) - if mf === :mapreduce - insert!(sig.args, 2, Expr(:parameters, Expr(:kw, :dims, :(:)))) - insert!(call.args, 2, Expr(:parameters, Expr(:kw, :dims, :dims))) - end - body = quote - y = $call - ∂x = Rule() do ȳ - broadcast(x, ȳ) do xi, ȳi - _, ∂xi = _checked_rrule(f, xi) - extern(∂xi(ȳi)) - end - end - return y, (DNERule(), DNERule(), ∂x) - end - eval(Expr(:function, sig, body)) -end - -##### -##### `sum` -##### - -frule(::typeof(sum), x) = (sum(x), Rule(sum)) - -rrule(::typeof(sum), x) = (sum(x), Rule(cast)) - -function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:) - y, (_, _, ∂x) = rrule(mapreduce, f, Base.add_sum, x; dims=dims) - return y, (DNERule(), ∂x) -end - -function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:) - y, (_, ∂x) = rrule(sum, identity, x; dims=dims) - return y, ∂x -end - -function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:) - y = sum(abs2, x; dims=dims) - ∂x = Rule(ȳ -> 2ȳ .* x) - return y, (DNERule(), ∂x) -end - -##### -##### `mean` -##### - -_denom(x, dims::Colon) = length(x) -_denom(x, dims::Integer) = size(x, dims) -_denom(x, dims) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1) - -# TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36 - -function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:) - _, dx = rrule(sum, x; dims=dims) - n = _denom(x, dims) - return mean(x; dims=dims), Rule(ȳ -> dx(ȳ) / n) -end - -function rrule(::typeof(mean), f, x::AbstractArray{<:Real}) - _, (_, dx) = rrule(sum, f, x) - n = _denom(x, :) - return mean(f, x), (DNERule(), Rule(ȳ -> dx(ȳ) / n)) -end diff --git a/src/rules/nanmath.jl b/src/rules/nanmath.jl deleted file mode 100644 index 279963ed8..000000000 --- a/src/rules/nanmath.jl +++ /dev/null @@ -1,13 +0,0 @@ -@scalar_rule(NaNMath.sin(x), NaNMath.cos(x)) -@scalar_rule(NaNMath.cos(x), -NaNMath.sin(x)) -@scalar_rule(NaNMath.asin(x), inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2)))) -@scalar_rule(NaNMath.acos(x), -inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2)))) -@scalar_rule(NaNMath.acosh(x), inv(NaNMath.sqrt(NaNMath.pow(x, 2) - 1))) -@scalar_rule(NaNMath.atanh(x), inv(1 - NaNMath.pow(x, 2))) -@scalar_rule(NaNMath.log(x), inv(x)) -@scalar_rule(NaNMath.log2(x), inv(x) / NaNMath.log(oftype(x, 2))) -@scalar_rule(NaNMath.log10(x), inv(x) / 
NaNMath.log(oftype(x, 10))) -@scalar_rule(NaNMath.log1p(x), inv(x + 1)) -@scalar_rule(NaNMath.lgamma(x), SpecialFunctions.digamma(x)) -@scalar_rule(NaNMath.sqrt(x), inv(2 * Ω)) -@scalar_rule(NaNMath.pow(x, y), (y * NaNMath.pow(x, y - 1), Ω * NaNMath.log(x))) diff --git a/src/rules/specialfunctions.jl b/src/rules/specialfunctions.jl deleted file mode 100644 index 8019bd531..000000000 --- a/src/rules/specialfunctions.jl +++ /dev/null @@ -1,20 +0,0 @@ -@scalar_rule(SpecialFunctions.lgamma(x), SpecialFunctions.digamma(x)) -@scalar_rule(SpecialFunctions.erf(x), (2 / sqrt(π)) * exp(-x * x)) -@scalar_rule(SpecialFunctions.erfc(x), -(2 / sqrt(π)) * exp(-x * x)) -@scalar_rule(SpecialFunctions.erfi(x), (2 / sqrt(π)) * exp(x * x)) -@scalar_rule(SpecialFunctions.digamma(x), SpecialFunctions.trigamma(x)) -@scalar_rule(SpecialFunctions.trigamma(x), SpecialFunctions.polygamma(2, x)) -@scalar_rule(SpecialFunctions.airyai(x), SpecialFunctions.airyaiprime(x)) -@scalar_rule(SpecialFunctions.airyaiprime(x), x * SpecialFunctions.airyai(x)) -@scalar_rule(SpecialFunctions.airybi(x), SpecialFunctions.airybiprime(x)) -@scalar_rule(SpecialFunctions.airybiprime(x), x * SpecialFunctions.airybi(x)) -@scalar_rule(SpecialFunctions.besselj0(x), -SpecialFunctions.besselj1(x)) -@scalar_rule(SpecialFunctions.bessely0(x), -SpecialFunctions.bessely1(x)) -@scalar_rule(SpecialFunctions.invdigamma(x), inv(SpecialFunctions.trigamma(SpecialFunctions.invdigamma(x)))) -@scalar_rule(SpecialFunctions.besselj1(x), (SpecialFunctions.besselj0(x) - SpecialFunctions.besselj(2, x)) / 2) -@scalar_rule(SpecialFunctions.bessely1(x), (SpecialFunctions.bessely0(x) - SpecialFunctions.bessely(2, x)) / 2) -@scalar_rule(SpecialFunctions.gamma(x), Ω * SpecialFunctions.digamma(x)) -@scalar_rule(SpecialFunctions.erfinv(x), (sqrt(π) / 2) * exp(Ω^2)) -@scalar_rule(SpecialFunctions.erfcinv(x), -(sqrt(π) / 2) * exp(Ω^2)) -@scalar_rule(SpecialFunctions.erfcx(x), (2 * x * Ω) - (2 / sqrt(π))) -@scalar_rule(SpecialFunctions.dawson(x), 1 - (2 * x * Ω)) diff --git a/test/rules.jl b/test/rules.jl index 1fab10f6a..128b73ced 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -3,6 +3,10 @@ cool(x, y) = x + y + 1 _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) +# a rule we define so we can test rules +dummy_identity(x) = x +@scalar_rule(dummy_identity(x), One()) + @testset "rules" begin @testset "frule and rrule" begin @test frule(cool, 1) === nothing @@ -10,9 +14,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test rrule(cool, 1) === nothing @test rrule(cool, 1; iscool=true) === nothing - ChainRules.@scalar_rule(Main.cool(x), one(x)) + AbstractChainRules.@scalar_rule(Main.cool(x), one(x)) @test hasmethod(rrule, Tuple{typeof(cool),Number}) - ChainRules.@scalar_rule(Main.cool(x::String), "wow such dfdx") + AbstractChainRules.@scalar_rule(Main.cool(x::String), "wow such dfdx") @test hasmethod(rrule, Tuple{typeof(cool),String}) # Ensure those are the *only* methods that have been defined cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool)) @@ -28,7 +32,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test rr(1) == 1 end @testset "iterating and indexing rules" begin - _, rule = frule(+, 1) + _, rule = frule(dummy_identity, 1) i = 0 for r in rule @test r === rule @@ -38,36 +42,4 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test rule[1] == rule @test_throws BoundsError rule[2] end - @testset "helper functions" begin - # Hits fallback, since we can't update `Diagonal`s in place - X = 
Diagonal([1, 1]) - Y = copy(X) - @test ChainRules._update!(X, [1 2; 3 4]) == [2 2; 3 5] - @test X == Y # no change to X - - X = [1 2; 3 4] - Y = copy(X) - @test ChainRules._update!(X, Diagonal([1, 1])) == [2 2; 3 5] - @test X != Y # X has been updated - - # Reusing above X - @test ChainRules._update!(X, Zero()) === X - @test ChainRules._update!(Zero(), X) === X - @test ChainRules._update!(Zero(), Zero()) === Zero() - - X = (A=[1 0; 0 1], B=[2 2; 2 2]) - Y = deepcopy(X) - @test ChainRules._update!(X, Y) == (A=[2 0; 0 2], B=[4 4; 4 4]) - @test X.A != Y.A - @test X.B != Y.B - - try - # We defined a 2-arg method for `cool` but no `rrule` - ChainRules._checked_rrule(cool, 1.0, 2.0) - catch e - @test e isa ArgumentError - @test e.msg == "can't differentiate `cool(::Float64, ::Float64)`; no " * - "matching `rrule` is defined" - end - end end diff --git a/test/rules/array.jl b/test/rules/array.jl deleted file mode 100644 index 3a70b9ecd..000000000 --- a/test/rules/array.jl +++ /dev/null @@ -1,56 +0,0 @@ -@testset "reshape" begin - rng = MersenneTwister(1) - A = randn(rng, 4, 5) - B, (dA, dd) = rrule(reshape, A, (5, 4)) - @test B == reshape(A, (5, 4)) - @test dd isa ChainRules.DNERule - Ȳ = randn(rng, 4, 5) - Ā = dA(Ȳ) - @test Ā == reshape(Ȳ, (5, 4)) - - B, (dA, dd1, dd2) = rrule(reshape, A, 5, 4) - @test B == reshape(A, 5, 4) - @test dd1 isa ChainRules.DNERule - @test dd2 isa ChainRules.DNERule - Ȳ = randn(rng, 4, 5) - Ā = dA(Ȳ) - @test Ā == reshape(Ȳ, 5, 4) -end - -@testset "hcat" begin - rng = MersenneTwister(2) - A = randn(rng, 3, 2) - B = randn(rng, 3) - C = randn(rng, 3, 3) - H, (dA, dB, dC) = rrule(hcat, A, B, C) - @test H == hcat(A, B, C) - H̄ = randn(rng, 3, 6) - @test dA(H̄) ≈ view(H̄, :, 1:2) - @test dB(H̄) ≈ view(H̄, :, 3) - @test dC(H̄) ≈ view(H̄, :, 4:6) -end - -@testset "vcat" begin - rng = MersenneTwister(3) - A = randn(rng, 2, 4) - B = randn(rng, 1, 4) - C = randn(rng, 3, 4) - V, (dA, dB, dC) = rrule(vcat, A, B, C) - @test V == vcat(A, B, C) - V̄ = randn(rng, 6, 4) - @test dA(V̄) ≈ view(V̄, 1:2, :) - @test dB(V̄) ≈ view(V̄, 3:3, :) - @test dC(V̄) ≈ view(V̄, 4:6, :) -end - -@testset "fill" begin - y, (dv, dd) = rrule(fill, 44, 4) - @test y == [44, 44, 44, 44] - @test dd isa ChainRules.DNERule - @test dv(ones(Int, 4)) == 4 - - y, (dv, dd) = rrule(fill, 2.0, (3, 3, 3)) - @test y == fill(2.0, (3, 3, 3)) - @test dd isa ChainRules.DNERule - @test dv(ones(3, 3, 3)) ≈ 27.0 -end diff --git a/test/rules/base.jl b/test/rules/base.jl deleted file mode 100644 index fd93ce9d5..000000000 --- a/test/rules/base.jl +++ /dev/null @@ -1,107 +0,0 @@ -function test_scalar(f, f′, xs...) - for r = (rrule, frule) - rr = r(f, xs...) - @test rr !== nothing - fx, ∂x = rr - @test fx == f(xs...) - @test ∂x(1) ≈ f′(xs...) 
atol=1e-5 - end -end - -@testset "base" begin - @testset "Trig" begin - @testset "Basics" for x = (Float64(π), Complex(π, π/2)) - test_scalar(sin, cos, x) - test_scalar(cos, x -> -sin(x), x) - test_scalar(tan, x -> 1 + tan(x)^2, x) - test_scalar(sec, x -> sec(x) * tan(x), x) - test_scalar(csc, x -> -csc(x) * cot(x), x) - test_scalar(cot, x -> -1 - cot(x)^2, x) - test_scalar(sinpi, x -> π * cospi(x), x) - test_scalar(cospi, x -> -π * sinpi(x), x) - end - @testset "Hyperbolic" for x = (Float64(π), Complex(π, π/2)) - test_scalar(sinh, cosh, x) - test_scalar(cosh, sinh, x) - test_scalar(tanh, x -> sech(x)^2, x) - test_scalar(sech, x -> -tanh(x) * sech(x), x) - test_scalar(csch, x -> -coth(x) * csch(x), x) - test_scalar(coth, x -> -csch(x)^2, x) - end - @testset "Degrees" begin - x = 45.0 - test_scalar(sind, x -> (π / 180) * cosd(x), x) - test_scalar(cosd, x -> (-π / 180) * sind(x), x) - test_scalar(tand, x -> (π / 180) * (1 + tand(x)^2), x) - test_scalar(secd, x -> (π / 180) * secd(x) * tand(x), x) - test_scalar(cscd, x -> (-π / 180) * cscd(x) * cotd(x), x) - test_scalar(cotd, x -> (-π / 180) * (1 + cotd(x)^2), x) - end - @testset "Inverses" for x = (1.0, Complex(1.0, 0.25)) - test_scalar(asin, x -> 1 / sqrt(1 - x^2), x) - test_scalar(acos, x -> -1 / sqrt(1 - x^2), x) - test_scalar(atan, x -> 1 / (1 + x^2), x) - test_scalar(asec, x -> 1 / (abs(x) * sqrt(x^2 - 1)), x) - test_scalar(acsc, x -> -1 / (abs(x) * sqrt(x^2 - 1)), x) - test_scalar(acot, x -> -1 / (1 + x^2), x) - end - @testset "Inverse hyperbolic" for x = (0.0, Complex(0.0, 0.25)) - test_scalar(asinh, x -> 1 / sqrt(x^2 + 1), x) - test_scalar(acosh, x -> 1 / sqrt(x^2 - 1), x + 1) # +1 accounts for domain - test_scalar(atanh, x -> 1 / (1 - x^2), x) - test_scalar(asech, x -> -1 / x / sqrt(1 - x^2), x) - test_scalar(acsch, x -> -1 / abs(x) / sqrt(1 + x^2), x) - test_scalar(acoth, x -> 1 / (1 - x^2), x + 1) - end - @testset "Inverse degrees" begin - x = 1.0 - test_scalar(asind, x -> 180 / π / sqrt(1 - x^2), x) - test_scalar(acosd, x -> -180 / π / sqrt(1 - x^2), x) - test_scalar(atand, x -> 180 / π / (1 + x^2), x) - test_scalar(asecd, x -> 180 / π / abs(x) / sqrt(x^2 - 1), x) - test_scalar(acscd, x -> -180 / π / abs(x) / sqrt(x^2 - 1), x) - test_scalar(acotd, x -> -180 / π / (1 + x^2), x) - end - # TODO: atan2 sincos - end - @testset "Misc. 
Tests" begin - @testset "*(x, y)" begin - x, y = rand(3, 2), rand(2, 5) - z, (dx, dy) = rrule(*, x, y) - - @test z == x * y - - z̄ = rand(3, 5) - - @test dx(z̄) == extern(accumulate(zeros(3, 2), dx, z̄)) - @test dy(z̄) == extern(accumulate(zeros(2, 5), dy, z̄)) - - test_accumulation(rand(3, 2), dx, z̄, z̄ * y') - test_accumulation(rand(2, 5), dy, z̄, x' * z̄) - end - @testset "hypot(x, y)" begin - x, y = rand(2) - h, dxy = frule(hypot, x, y) - - @test extern(dxy(One(), Zero())) === y / h - @test extern(dxy(Zero(), One())) === x / h - - cx, cy = cast((One(), Zero())), cast((Zero(), One())) - dx, dy = extern(dxy(cx, cy)) - @test dx === y / h - @test dy === x / h - - cx, cy = cast((rand(), Zero())), cast((Zero(), rand())) - dx, dy = extern(dxy(cx, cy)) - @test dx === y / h * cx.value[1] - @test dy === x / h * cy.value[2] - end - end - @testset "identity" begin - rng = MersenneTwister(1) - n = 4 - rrule_test(identity, randn(rng), (randn(rng), randn(rng))) - rrule_test(identity, randn(rng, 4), (randn(rng, 4), randn(rng, 4))) - end -end -# TODO: Non-trig stuff diff --git a/test/rules/blas.jl b/test/rules/blas.jl deleted file mode 100644 index 04e421149..000000000 --- a/test/rules/blas.jl +++ /dev/null @@ -1,27 +0,0 @@ -@testset "BLAS" begin - @testset "gemm" begin - rng = MersenneTwister(1) - dims = 3:5 - for m in dims, n in dims, p in dims, tA in ('N', 'T'), tB in ('N', 'T') - α = randn(rng) - A = randn(rng, tA === 'N' ? (m, n) : (n, m)) - B = randn(rng, tB === 'N' ? (n, p) : (p, n)) - C = gemm(tA, tB, α, A, B) - ȳ = randn(rng, size(C)...) - rrule_test(gemm, ȳ, (tA, nothing), (tB, nothing), (α, randn(rng)), - (A, randn(rng, size(A))), (B, randn(rng, size(B)))) - end - end - @testset "gemv" begin - rng = MersenneTwister(2) - for n in 3:5, m in 3:5, t in ('N', 'T') - α = randn(rng) - A = randn(rng, m, n) - x = randn(rng, t === 'N' ? n : m) - y = α * (t === 'N' ? A : A') * x - ȳ = randn(rng, size(y)...) - rrule_test(gemv, ȳ, (t, nothing), (α, randn(rng)), (A, randn(rng, size(A))), - (x, randn(rng, size(x)))) - end - end -end diff --git a/test/rules/broadcast.jl b/test/rules/broadcast.jl deleted file mode 100644 index 81c0c578e..000000000 --- a/test/rules/broadcast.jl +++ /dev/null @@ -1,20 +0,0 @@ -@testset "broadcast" begin - @testset "Misc. 
Tests" begin - @testset "sin.(x)" begin - x = rand(3, 3) - y, (dsin, dx) = rrule(broadcast, sin, x) - - @test y == sin.(x) - @test extern(dx(One())) == cos.(x) - - x̄, ȳ = rand(), rand() - @test extern(accumulate(x̄, dx, ȳ)) == x̄ .+ ȳ .* cos.(x) - - x̄, ȳ = Zero(), rand(3, 3) - @test extern(accumulate(x̄, dx, ȳ)) == ȳ .* cos.(x) - - x̄, ȳ = Zero(), cast(rand(3, 3)) - @test extern(accumulate(x̄, dx, ȳ)) == extern(ȳ) .* cos.(x) - end - end -end diff --git a/test/rules/linalg/dense.jl b/test/rules/linalg/dense.jl deleted file mode 100644 index dcc861b5b..000000000 --- a/test/rules/linalg/dense.jl +++ /dev/null @@ -1,120 +0,0 @@ -function generate_well_conditioned_matrix(rng, N) - A = randn(rng, N, N) - return A * A' + I -end - -@testset "linalg" begin - @testset "dot" begin - @testset "Vector" begin - rng, M = MersenneTwister(123456), 3 - x, y = randn(rng, M), randn(rng, M) - ẋ, ẏ = randn(rng, M), randn(rng, M) - x̄, ȳ = randn(rng, M), randn(rng, M) - frule_test(dot, (x, ẋ), (y, ẏ)) - rrule_test(dot, randn(rng), (x, x̄), (y, ȳ)) - end - @testset "Matrix" begin - rng, M, N = MersenneTwister(123456), 3, 4 - x, y = randn(rng, M, N), randn(rng, M, N) - ẋ, ẏ = randn(rng, M, N), randn(rng, M, N) - x̄, ȳ = randn(rng, M, N), randn(rng, M, N) - frule_test(dot, (x, ẋ), (y, ẏ)) - rrule_test(dot, randn(rng), (x, x̄), (y, ȳ)) - end - @testset "Array{T, 3}" begin - rng, M, N, P = MersenneTwister(123456), 3, 4, 5 - x, y = randn(rng, M, N, P), randn(rng, M, N, P) - ẋ, ẏ = randn(rng, M, N, P), randn(rng, M, N, P) - x̄, ȳ = randn(rng, M, N, P), randn(rng, M, N, P) - frule_test(dot, (x, ẋ), (y, ẏ)) - rrule_test(dot, randn(rng), (x, x̄), (y, ȳ)) - end - end - @testset "inv" begin - rng, N = MersenneTwister(123456), 3 - B = generate_well_conditioned_matrix(rng, N) - frule_test(inv, (B, randn(rng, N, N))) - rrule_test(inv, randn(rng, N, N), (B, randn(rng, N, N))) - end - @testset "det" begin - rng, N = MersenneTwister(123456), 3 - B = generate_well_conditioned_matrix(rng, N) - frule_test(det, (B, randn(rng, N, N))) - rrule_test(det, randn(rng), (B, randn(rng, N, N))) - end - @testset "logdet" begin - rng, N = MersenneTwister(123456), 3 - B = generate_well_conditioned_matrix(rng, N) - frule_test(logdet, (B, randn(rng, N, N))) - rrule_test(logdet, randn(rng), (B, randn(rng, N, N))) - end - @testset "tr" begin - rng, N = MersenneTwister(123456), 4 - frule_test(tr, (randn(rng, N, N), randn(rng, N, N))) - rrule_test(tr, randn(rng), (randn(rng, N, N), randn(rng, N, N))) - end - @testset "*" begin - rng = MersenneTwister(123456) - dims = [3,4,5] - for n in dims, m in dims, p in dims - n > 3 && n == m == p && continue # don't need to test square case multiple times - A = randn(rng, m, n) - B = randn(rng, n, p) - Ȳ = randn(rng, m, p) - rrule_test(*, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, n, p))) - end - end - @testset "$f" for f in [/, \] - rng = MersenneTwister(42) - for n in 3:5, m in 3:5 - A = randn(rng, m, n) - B = randn(rng, m, n) - Ȳ = randn(rng, size(f(A, B))) - rrule_test(f, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, m, n))) - end - # Vectors - x = randn(rng, 10) - y = randn(rng, 10) - ȳ = randn(rng, size(f(x, y))...) - rrule_test(f, ȳ, (x, randn(rng, 10)), (y, randn(rng, 10))) - if f == (/) - @testset "$T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular) - RHS = T(randn(rng, T == Diagonal ? 10 : (10, 10))) - Y = randn(rng, 5, 10) - Ȳ = randn(rng, size(f(Y, RHS))...) 
- rrule_test(f, Ȳ, (Y, randn(rng, size(Y))), (RHS, randn(rng, size(RHS)))) - end - else - @testset "$T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular) - LHS = T(randn(rng, T == Diagonal ? 10 : (10, 10))) - y = randn(rng, 10) - ȳ = randn(rng, size(f(LHS, y))...) - rrule_test(f, ȳ, (LHS, randn(rng, size(LHS))), (y, randn(rng, 10))) - Y = randn(rng, 10, 10) - Ȳ = randn(rng, 10, 10) - rrule_test(f, Ȳ, (LHS, randn(rng, size(LHS))), (Y, randn(rng, size(Y)))) - end - @testset "Matrix $f Vector" begin - X = randn(rng, 10, 4) - y = randn(rng, 10) - ȳ = randn(rng, size(f(X, y))...) - rrule_test(f, ȳ, (X, randn(rng, size(X))), (y, randn(rng, 10))) - end - @testset "Vector $f Matrix" begin - x = randn(rng, 10) - Y = randn(rng, 10, 4) - ȳ = randn(rng, size(f(x, Y))...) - rrule_test(f, ȳ, (x, randn(rng, size(x))), (Y, randn(rng, size(Y)))) - end - end - end - @testset "norm" begin - rng = MersenneTwister(3) - for dims in [(), (5,), (3, 2), (7, 3, 2)] - A = randn(rng, dims...) - p = randn(rng) - ȳ = randn(rng) - rrule_test(norm, ȳ, (A, randn(rng, dims...)), (p, randn(rng))) - end - end -end diff --git a/test/rules/linalg/factorization.jl b/test/rules/linalg/factorization.jl deleted file mode 100644 index 68641c429..000000000 --- a/test/rules/linalg/factorization.jl +++ /dev/null @@ -1,101 +0,0 @@ -using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblocked_rev - -@testset "Factorizations" begin - @testset "svd" begin - rng = MersenneTwister(2) - for n in [4, 6, 10], m in [3, 5, 10] - X = randn(rng, n, m) - F, dX = rrule(svd, X) - for p in [:U, :S, :V] - Y, (dF, dp) = rrule(getproperty, F, p) - @test dp isa ChainRules.DNERule - Ȳ = randn(rng, size(Y)...) - X̄_ad = dX(dF(Ȳ)) - X̄_fd = j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X) - @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6 - end - @test_throws ArgumentError rrule(getproperty, F, :Vt) - end - @testset "accumulate!" begin - X = [1.0 2.0; 3.0 4.0; 5.0 6.0] - F, dX = rrule(svd, X) - X̄ = (U=zeros(3, 2), S=zeros(2), V=zeros(2, 2)) - for p in [:U, :S, :V] - Y, (dF, _) = rrule(getproperty, F, p) - Ȳ = ones(size(Y)...) - ChainRules.accumulate!(X̄, dF, Ȳ) - end - @test X̄.U ≈ ones(3, 2) atol=1e-6 - @test X̄.S ≈ ones(2) atol=1e-6 - @test X̄.V ≈ ones(2, 2) atol=1e-6 - end - @testset "Helper functions" begin - X = randn(rng, 10, 10) - Y = randn(rng, 10, 10) - @test ChainRules._mulsubtrans!(copy(X), Y) ≈ Y .* (X - X') - @test ChainRules._eyesubx!(copy(X)) ≈ I - X - @test ChainRules._add!(copy(X), Y) ≈ X + Y - end - end - @testset "cholesky" begin - rng = MersenneTwister(4) - @testset "the thing" begin - X = generate_well_conditioned_matrix(rng, 10) - V = generate_well_conditioned_matrix(rng, 10) - F, dX = rrule(cholesky, X) - for p in [:U, :L] - Y, (dF, dp) = rrule(getproperty, F, p) - @test dp isa ChainRules.DNERule - Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(rng, size(Y))) - # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` - # machinery from FDM because that isn't set up to respect necessary special - # properties of the input. In the case of the Cholesky factorization, we - # need the input to be Hermitian. 
-                X̄_ad = dot(dX(dF(Ȳ)), V)
-                X̄_fd = central_fdm(5, 1)() do ε
-                    dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
-                end
-                @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6
-            end
-        end
-        @testset "helper functions" begin
-            A = randn(rng, 5, 5)
-            r, d, B2, c = level2partition(A, 4, false)
-            R, D, B3, C = level3partition(A, 4, 4, false)
-            @test all(r .== R')
-            @test all(d .== D)
-            @test B2[1] == B3[1]
-            @test all(c .== C)
-
-            # Check that level 2 partition with `upper == true` is consistent with `false`
-            rᵀ, dᵀ, B2ᵀ, cᵀ = level2partition(transpose(A), 4, true)
-            @test r == rᵀ
-            @test d == dᵀ
-            @test B2' == B2ᵀ
-            @test c == cᵀ
-
-            # Check that level 3 partition with `upper == true` is consistent with `false`
-            R, D, B3, C = level3partition(A, 2, 4, false)
-            Rᵀ, Dᵀ, B3ᵀ, Cᵀ = level3partition(transpose(A), 2, 4, true)
-            @test transpose(R) == Rᵀ
-            @test transpose(D) == Dᵀ
-            @test transpose(B3) == B3ᵀ
-            @test transpose(C) == Cᵀ
-
-            A = Matrix(LowerTriangular(randn(rng, 10, 10)))
-            Ā = Matrix(LowerTriangular(randn(rng, 10, 10)))
-            # NOTE: BLAS gets angry if we don't materialize the Transpose objects first
-            B = Matrix(transpose(A))
-            B̄ = Matrix(transpose(Ā))
-            @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 1, false)
-            @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 3, false)
-            @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 5, false)
-            @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 10, false)
-            @test chol_unblocked_rev(Ā, A, false) ≈ transpose(chol_unblocked_rev(B̄, B, true))
-
-            @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 1, true)
-            @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 5, true)
-            @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 10, true)
-        end
-    end
-end
diff --git a/test/rules/linalg/structured.jl b/test/rules/linalg/structured.jl
deleted file mode 100644
index 2d82a2d3f..000000000
--- a/test/rules/linalg/structured.jl
+++ /dev/null
@@ -1,33 +0,0 @@
-@testset "Structured Matrices" begin
-    @testset "Diagonal" begin
-        rng, N = MersenneTwister(123456), 3
-        rrule_test(Diagonal, randn(rng, N, N), (randn(rng, N), randn(rng, N)))
-        D = Diagonal(randn(rng, N))
-        rrule_test(Diagonal, D, (randn(rng, N), randn(rng, N)))
-        # Concrete type instead of UnionAll
-        rrule_test(typeof(D), D, (randn(rng, N), randn(rng, N)))
-    end
-    @testset "diag" begin
-        rng, N = MersenneTwister(123456), 7
-        rrule_test(diag, randn(rng, N), (randn(rng, N, N), randn(rng, N, N)))
-        rrule_test(diag, randn(rng, N), (Diagonal(randn(rng, N)), randn(rng, N, N)))
-        rrule_test(diag, randn(rng, N), (randn(rng, N, N), Diagonal(randn(rng, N))))
-        rrule_test(diag, randn(rng, N), (Diagonal(randn(rng, N)), Diagonal(randn(rng, N))))
-    end
-    @testset "Symmetric" begin
-        rng, N = MersenneTwister(123456), 3
-        rrule_test(Symmetric, randn(rng, N, N), (randn(rng, N, N), randn(rng, N, N)))
-    end
-    @testset "$f" for f in (Adjoint, adjoint, Transpose, transpose)
-        rng = MersenneTwister(32)
-        n = 5
-        m = 3
-        rrule_test(f, randn(rng, m, n), (randn(rng, n, m), randn(rng, n, m)))
-        rrule_test(f, randn(rng, 1, n), (randn(rng, n), randn(rng, n)))
-    end
-    @testset "$T" for T in (UpperTriangular, LowerTriangular)
-        rng = MersenneTwister(33)
-        n = 5
-        rrule_test(T, T(randn(rng, n, n)), (randn(rng, n, n), randn(rng, n, n)))
-    end
-end
diff --git a/test/rules/mapreduce.jl b/test/rules/mapreduce.jl
deleted file mode 100644
index 8ee0a1e1d..000000000
--- a/test/rules/mapreduce.jl
+++ /dev/null
@@ -1,78 +0,0 @@
-@testset "Maps and Reductions" begin
-    @testset "map" begin
-        rng = MersenneTwister(42)
-        n = 10
-        x = randn(rng, n)
-        vx = randn(rng, n)
-        ȳ = randn(rng, n)
-        rrule_test(map, ȳ, (sin, nothing), (x, vx))
-        rrule_test(map, ȳ, (+, nothing), (x, vx), (randn(rng, n), randn(rng, n)))
-    end
-    @testset "mapreduce" begin
-        rng = MersenneTwister(6)
-        n = 10
-        x = randn(rng, n)
-        vx = randn(rng, n)
-        ȳ = randn(rng)
-        rrule_test(mapreduce, ȳ, (sin, nothing), (+, nothing), (x, vx))
-        # With keyword arguments (not yet supported in rrule_test)
-        X = randn(rng, n, n)
-        y, (_, _, dx) = rrule(mapreduce, abs2, +, X; dims=2)
-        ȳ = randn(rng, size(y))
-        x̄_ad = dx(ȳ)
-        x̄_fd = j′vp(central_fdm(5, 1), x->mapreduce(abs2, +, x; dims=2), ȳ, X)
-        @test x̄_ad ≈ x̄_fd atol=1e-9 rtol=1e-9
-    end
-    @testset "$f" for f in (mapfoldl, mapfoldr)
-        rng = MersenneTwister(10)
-        n = 7
-        x = randn(rng, n)
-        vx = randn(rng, n)
-        ȳ = randn(rng)
-        rrule_test(f, ȳ, (cos, nothing), (+, nothing), (x, vx))
-    end
-    @testset "sum" begin
-        @testset "Vector" begin
-            rng, M = MersenneTwister(123456), 3
-            frule_test(sum, (randn(rng, M), randn(rng, M)))
-            rrule_test(sum, randn(rng), (randn(rng, M), randn(rng, M)))
-        end
-        @testset "Matrix" begin
-            rng, M, N = MersenneTwister(123456), 3, 4
-            frule_test(sum, (randn(rng, M, N), randn(rng, M, N)))
-            rrule_test(sum, randn(rng), (randn(rng, M, N), randn(rng, M, N)))
-        end
-        @testset "Array{T, 3}" begin
-            rng, M, N, P = MersenneTwister(123456), 3, 7, 11
-            frule_test(sum, (randn(rng, M, N, P), randn(rng, M, N, P)))
-            rrule_test(sum, randn(rng), (randn(rng, M, N, P), randn(rng, M, N, P)))
-        end
-        @testset "function argument" begin
-            rng = MersenneTwister(1)
-            n = 8
-            rrule_test(sum, randn(rng), (cos, nothing), (randn(rng, n), randn(rng, n)))
-            rrule_test(sum, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
-        end
-        @testset "keyword arguments" begin
-            rng = MersenneTwister(33)
-            n = 4
-            X = randn(rng, n, n)
-            y, dX = rrule(sum, X; dims=2)
-            ȳ = randn(rng, size(y))
-            x̄_ad = dX(ȳ)
-            x̄_fd = j′vp(central_fdm(5, 1), x->sum(x, dims=2), ȳ, X)
-            @test x̄_ad ≈ x̄_fd atol=1e-9 rtol=1e-9
-        end
-    end
-    @testset "mean" begin
-        rng = MersenneTwister(999)
-        n = 9
-        rrule_test(mean, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
-        X = randn(rng, n, n)
-        y, dX = rrule(mean, X; dims=1)
-        ȳ = randn(rng, size(y))
-        X̄_ad = dX(ȳ)
-        X̄_fd = j′vp(central_fdm(5, 1), x->mean(x, dims=1), ȳ, X)
-        @test X̄_ad ≈ X̄_fd rtol=1e-9 atol=1e-9
-    end
-end
diff --git a/test/rules/nanmath.jl b/test/rules/nanmath.jl
deleted file mode 100644
index e69de29bb..000000000
diff --git a/test/rules/specialfunctions.jl b/test/rules/specialfunctions.jl
deleted file mode 100644
index e69de29bb..000000000
diff --git a/test/runtests.jl b/test/runtests.jl
index cd8a25b73..6a757f3a4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,30 +1,16 @@
 # TODO: more tests!
 
-using ChainRules, Test, FDM, LinearAlgebra, LinearAlgebra.BLAS, Random, Statistics
-using ChainRules: extern, accumulate, accumulate!, store!, @scalar_rule,
+using AbstractChainRules, Test
+using LinearAlgebra: Diagonal
+using AbstractChainRules: extern, accumulate, accumulate!, store!, @scalar_rule,
     Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
     Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted,
     mul_casted, DNE, Thunk, Casted, DNERule
 using Base.Broadcast: broadcastable
-import LinearAlgebra: dot
 
-include("test_util.jl")
+#include("test_util.jl")
 
-@testset "ChainRules" begin
+@testset "AbstractChainRules" begin
     include("differentials.jl")
     include("rules.jl")
-    @testset "rules" begin
-        include(joinpath("rules", "base.jl"))
-        include(joinpath("rules", "array.jl"))
-        include(joinpath("rules", "mapreduce.jl"))
-        @testset "linalg" begin
-            include(joinpath("rules", "linalg", "dense.jl"))
-            include(joinpath("rules", "linalg", "structured.jl"))
-            include(joinpath("rules", "linalg", "factorization.jl"))
-        end
-        include(joinpath("rules", "broadcast.jl"))
-        include(joinpath("rules", "blas.jl"))
-        include(joinpath("rules", "nanmath.jl"))
-        include(joinpath("rules", "specialfunctions.jl"))
-    end
 end
diff --git a/test/test_util.jl b/test/test_util.jl
index db8a7b520..35cdd27c8 100644
--- a/test/test_util.jl
+++ b/test/test_util.jl
@@ -18,7 +18,7 @@ end
 function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
     xs, ẋs = collect(zip(xẋs...))
-    Ω, dΩ_rule = ChainRules.frule(f, xs...)
+    Ω, dΩ_rule = AbstractChainRules.frule(f, xs...)
     @test f(xs...) == Ω
 
     dΩ_ad, dΩ_fd = dΩ_rule(ẋs...), jvp(fdm, xs->f(xs...), (xs, ẋs))
@@ -38,14 +38,14 @@ All keyword arguments except for `fdm` are passed to `isapprox`.
 """
 function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
     # Check correctness of evaluation.
-    fx, dx = ChainRules.rrule(f, x)
+    fx, dx = AbstractChainRules.rrule(f, x)
     @test fx ≈ f(x)
 
     # Correctness testing via finite differencing.
     x̄_ad, x̄_fd = dx(ȳ), j′vp(fdm, f, ȳ, x)
     @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
 
-    # Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct.
+    # Assuming x̄_ad to be correct, check that other AbstractChainRules mechanisms are correct.
     test_accumulation(x̄, dx, ȳ, x̄_ad)
     test_accumulation(Zero(), dx, ȳ, x̄_ad)
 end
@@ -97,7 +97,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
         end
     end
 
-    # Assuming the above to be correct, check that other ChainRules mechanisms are correct.
+    # Assuming the above to be correct, check that other AbstractChainRules mechanisms are correct.
     for (x̄, rule, x̄_ad) in zip(x̄s, rules, x̄s_ad)
         x̄ === nothing && continue
        test_accumulation(x̄, rule, ȳ, x̄_ad)
@@ -119,7 +119,7 @@ function Base.isapprox(d_ad::Thunk, d_fd; kwargs...)
 end
 
 function test_accumulation(x̄, dx, ȳ, partial)
-    @test all(extern(ChainRules.add(x̄, partial)) .≈ extern(x̄) .+ extern(partial))
+    @test all(extern(AbstractChainRules.add(x̄, partial)) .≈ extern(x̄) .+ extern(partial))
     test_accumulate(x̄, dx, ȳ, partial)
     test_accumulate!(x̄, dx, ȳ, partial)
     test_store!(x̄, dx, ȳ, partial)