Cut down to bare minimum for AbstractChainRules #1

Merged
merged 4 commits into from
Aug 2, 2019
3 changes: 2 additions & 1 deletion .travis.yml
@@ -5,6 +5,7 @@ os:
julia:
- 1.0
- 1.1
- 1.2
- nightly
matrix:
allow_failures:
@@ -14,7 +15,7 @@ notifications:
# uncomment the following lines to override the default test script
#script:
# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
# - julia -e 'Pkg.clone(pwd()); Pkg.build("ChainRules"); Pkg.test("ChainRules"; coverage=true)'
# - julia -e 'Pkg.clone(pwd()); Pkg.build("AbstractChainRules"); Pkg.test("AbstractChainRules"; coverage=true)'
after_success:
# push coverage results to Coveralls
- julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Coveralls.submit(Coveralls.process_folder())'
2 changes: 1 addition & 1 deletion LICENSE.md
@@ -1,4 +1,4 @@
The ChainRules.jl package is licensed under the MIT "Expat" License:
The AbstractChainRules.jl package is licensed under the MIT "Expat" License:

> Copyright (c) 2018: Jarrett Revels.
>
17 changes: 6 additions & 11 deletions Project.toml
@@ -1,23 +1,18 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.0.1"
name = "AbstractChainRules"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.1.0"

[deps]
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"


[compat]
Cassette = "^0.2"
FDM = "^0.6"
julia = "^1.0"

[extras]
FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[targets]
test = ["FDM", "Random", "Test"]
test = ["Test", "LinearAlgebra"]
12 changes: 6 additions & 6 deletions README.md
@@ -1,10 +1,10 @@
# ChainRules
# AbstractChainRules

[![Travis](https://travis-ci.org/JuliaDiff/ChainRules.jl.svg?branch=master)](https://travis-ci.org/JuliaDiff/ChainRules.jl)
[![Coveralls](https://coveralls.io/repos/github/JuliaDiff/ChainRules.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaDiff/ChainRules.jl?branch=master)
[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaDiff.github.io/ChainRules.jl/latest)
[![Travis](https://travis-ci.org/JuliaDiff/AbstractChainRules.jl.svg?branch=master)](https://travis-ci.org/JuliaDiff/AbstractChainRules.jl)
[![Coveralls](https://coveralls.io/repos/github/JuliaDiff/AbstractChainRules.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaDiff/AbstractChainRules.jl?branch=master)
[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaDiff.github.io/AbstractChainRules.jl/latest)

The ChainRules package provides a variety of common utilities that can be used by downstream automatic differentiation (AD) tools to define and execute forward-, reverse-, and mixed-mode primitives.
The AbstractChainRules package provides a variety of common utilities that can be used by downstream automatic differentiation (AD) tools to define and execute forward-, reverse-, and mixed-mode primitives.

This package is a WIP; the framework is essentially there, but there are a bunch of TODOs, virtually no tests, etc. PRs welcome! Documentation is incoming, which should help if you'd like to contribute.

@@ -18,4 +18,4 @@ Here are some of the basic goals for the package:

- Control-inverted design: rule authors can fully specify derivatives in a concise manner while naturally allowing the caller to compute only what they need.

The ChainRules source code follows the [YASGuide](https://github.com/jrevels/YASGuide).
The AbstractChainRules source code follows the [YASGuide](https://github.com/jrevels/YASGuide).
4 changes: 0 additions & 4 deletions REQUIRE

This file was deleted.

10 changes: 5 additions & 5 deletions docs/make.jl
@@ -1,11 +1,11 @@
using ChainRules
using AbstractChainRules
using Documenter

makedocs(modules=[ChainRules],
sitename="ChainRules",
makedocs(modules=[AbstractChainRules],
sitename="AbstractChainRules",
authors="Jarrett Revels and other contributors",
pages=["Introduction" => "index.md",
"Getting Started" => "getting_started.md",
"ChainRules API Documentation" => "api.md"])
"AbstractChainRules API Documentation" => "api.md"])

deploydocs(repo="github.com/JuliaDiff/ChainRules.jl.git")
deploydocs(repo="github.com/JuliaDiff/AbstractChainRules.jl.git")
36 changes: 18 additions & 18 deletions docs/src/api.md
@@ -1,24 +1,24 @@
# ChainRules API Documentation
# AbstractChainRules API Documentation

```@docs
ChainRules.frule
ChainRules.rrule
ChainRules.AbstractRule
ChainRules.Rule
ChainRules.DNERule
ChainRules.WirtingerRule
ChainRules.accumulate
ChainRules.accumulate!
ChainRules.store!
AbstractChainRules.frule
AbstractChainRules.rrule
AbstractChainRules.AbstractRule
AbstractChainRules.Rule
AbstractChainRules.DNERule
AbstractChainRules.WirtingerRule
AbstractChainRules.accumulate
AbstractChainRules.accumulate!
AbstractChainRules.store!
```

```@docs
ChainRules.AbstractDifferential
ChainRules.extern
ChainRules.Casted
ChainRules.Wirtinger
ChainRules.Thunk
ChainRules.Zero
ChainRules.DNE
ChainRules.One
AbstractChainRules.AbstractDifferential
AbstractChainRules.extern
AbstractChainRules.Casted
AbstractChainRules.Wirtinger
AbstractChainRules.Thunk
AbstractChainRules.Zero
AbstractChainRules.DNE
AbstractChainRules.One
```
10 changes: 5 additions & 5 deletions docs/src/index.md
@@ -1,10 +1,10 @@
```@meta
DocTestSetup = :(using ChainRules)
CurrentModule = ChainRules
DocTestSetup = :(using AbstractChainRules)
CurrentModule = AbstractChainRules
```

# ChainRules
# AbstractChainRules

Hello! Welcome to ChainRules's documentation.
Hello! Welcome to AbstractChainRules's documentation.

For an initial overview of ChainRules, please see the README. Otherwise, feel free to peruse available documentation via the sidebar.
For an initial overview of AbstractChainRules, please see the README. Otherwise, feel free to peruse available documentation via the sidebar.
11 changes: 11 additions & 0 deletions src/AbstractChainRules.jl
@@ -0,0 +1,11 @@
module AbstractChainRules
using Cassette
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable

export AbstractRule, Rule, frule, rrule
export @scalar_rule, @thunk
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule

include("differentials.jl")
include("rules.jl")
end # module
34 changes: 0 additions & 34 deletions src/ChainRules.jl

This file was deleted.

2 changes: 1 addition & 1 deletion src/differentials.jl
@@ -140,7 +140,7 @@ function mul_wirtinger(a::Wirtinger, b::Wirtinger)
such that we assume the chain rule application is of the form `f_a ∘ f_b`
instead of `f_b ∘ f_a`. However, picking such a convention is likely to
lead to silently incorrect derivatives due to commutativity assumptions
in downstream generic code that deals with the reals. Thus, ChainRules
in downstream generic code that deals with the reals. Thus, AbstractChainRules
makes this operation an error instead.
""")
end
63 changes: 10 additions & 53 deletions src/rules.jl
@@ -12,7 +12,7 @@ return that calculated differential value.
For example:

```julia-repl
julia> using ChainRules: frule, rrule, AbstractRule
julia> using AbstractChainRules: frule, rrule, AbstractRule

julia> x, y = rand(2);

@@ -61,7 +61,7 @@ Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsErro
"""
accumulate(Δ, rule::AbstractRule, args...)

Return `Δ + rule(args...)` evaluated in a manner that supports ChainRules'
Return `Δ + rule(args...)` evaluated in a manner that supports AbstractChainRules'
various `AbstractDifferential` types.

This method is intended to be customizable for specific rules/input types. For
@@ -112,32 +112,6 @@ See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
"""
store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...)))

# Special purpose updating for operations which can be done in-place. This function is
# just internal and free-form; it is not a method of `accumulate!` directly as it does
# not adhere to the expected method signature form, i.e. `accumulate!(value, rule, args)`.
# Instead it's `_update!(old, new, extrastuff...)` and is not specific to any particular
# rule.

_update!(x, y) = x + y
_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y

_update!(x, ::Zero) = x
_update!(::Zero, y) = y
_update!(::Zero, ::Zero) = Zero()

function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns
return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns))
end

function _update!(x::NamedTuple, y, p::Symbol)
new = NamedTuple{(p,)}((_update!(getproperty(x, p), y),))
return merge(x, new)
end

function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns
return _update!(x, getproperty(y, p), p)
end

#####
##### `Rule`
#####
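
The `_update!` methods removed in this hunk rely on dispatch to choose between out-of-place addition, in-place broadcasting, and short-circuiting on `Zero`. A minimal standalone sketch of that dispatch pattern follows. `ToyZero` and `toy_update!` are hypothetical stand-ins written for illustration, not names from the package, and the `NamedTuple` methods are left out.

```julia
# Hypothetical stand-in for the package's `Zero` differential.
struct ToyZero end

# Generic fallback: out-of-place addition.
toy_update!(x, y) = x + y

# Arrays with matching element type and rank are updated in place.
toy_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y

# Adding a zero differential is a no-op on either side.
toy_update!(x, ::ToyZero) = x
toy_update!(::ToyZero, y) = y
toy_update!(::ToyZero, ::ToyZero) = ToyZero()

a = [1.0, 2.0]
b = toy_update!(a, [1.0, 1.0])   # mutates `a` and returns it
c = toy_update!(a, ToyZero())    # returns `a` unchanged
```

The `(::ToyZero, ::ToyZero)` method is there to avoid a dispatch ambiguity between the two single-`ToyZero` methods.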
@@ -232,7 +206,7 @@ rules, where e.g. `frule` is used within a `rrule` definition. For example,
broadcasted functions may not themselves be forward-mode *primitives*, but are
often forward-mode *differentiable*.

ChainRules, by design, is decoupled from any specific AD implementation. How,
AbstractChainRules, by design, is decoupled from any specific AD implementation. How,
then, do we know which AD to fall back to when there isn't a primitive defined?

Well, if you're a greedy AD implementation, you can just overload `frule` and/or
@@ -244,12 +218,12 @@ It turns out, Cassette solves this problem nicely by allowing AD authors to
overload the fallbacks w.r.t. their own context. Example using ForwardDiff:

```
using ChainRules, ForwardDiff, Cassette
using AbstractChainRules, ForwardDiff, Cassette

Cassette.@context MyChainRuleCtx

# ForwardDiff, itself, can call `my_frule` instead of
# `frule` to utilize the ForwardDiff-injected ChainRules
# `frule` to utilize the ForwardDiff-injected AbstractChainRules
# infrastructure
my_frule(args...) = Cassette.overdub(MyChainRuleCtx(), frule, args...)

@@ -377,23 +351,6 @@ See also: [`frule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref)
"""
rrule(::Any, ::Vararg{Any}; kwargs...) = nothing

@noinline function _throw_checked_rrule_error(f, args...; kwargs...)
io = IOBuffer()
print(io, "can't differentiate `", f, '(')
join(io, map(arg->string("::", typeof(arg)), args), ", ")
if !isempty(kwargs)
print(io, ";")
join(io, map(((k, v),)->string(k, "=", v), kwargs), ", ")
end
print(io, ")`; no matching `rrule` is defined")
throw(ArgumentError(String(take!(io))))
end

function _checked_rrule(f, args...; kwargs...)
r = rrule(f, args...; kwargs...)
r isa Nothing && _throw_checked_rrule_error(f, args...; kwargs...)
return r
end

#####
##### macros
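
The fallback `rrule(::Any, ::Vararg{Any}; kwargs...) = nothing` signals a missing rule by returning `nothing`, and the `_checked_rrule` helper removed in this hunk turned that result into an error. A rough standalone sketch of the pattern is below. It assumes hypothetical names `toy_rrule` and `checked_toy_rrule`, drops keyword arguments, and uses a plain closure where the package would use `Rule`.

```julia
# Hypothetical stand-ins; not the AbstractChainRules API.

# Catch-all fallback: no rule is known, so return `nothing`.
toy_rrule(::Any, ::Vararg{Any}) = nothing

# A specific rule: returns the primal value and a closure for the pullback.
toy_rrule(::typeof(sin), x::Number) = (sin(x), ΔΩ -> cos(x) * ΔΩ)

# Checked wrapper: converts the `nothing` fallback into an error.
function checked_toy_rrule(f, args...)
    r = toy_rrule(f, args...)
    r === nothing && throw(ArgumentError("no matching rule is defined for `$f`"))
    return r
end

Ω, pullback = checked_toy_rrule(sin, 0.0)
```

Callers that can handle a missing rule use the unchecked function and test for `nothing`; callers that cannot use the checked wrapper.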
@@ -410,15 +367,15 @@ A convenience macro that generates simple scalar forward or reverse rules using
the provided partial derivatives. Specifically, generates the corresponding
methods for `frule` and `rrule`:

function ChainRules.frule(::typeof(f), x₁::Number, x₂::Number, ...)
function AbstractChainRules.frule(::typeof(f), x₁::Number, x₂::Number, ...)
Ω = f(x₁, x₂, ...)
\$(statement₁, statement₂, ...)
return Ω, (Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
...)
end

function ChainRules.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
function AbstractChainRules.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
Ω = f(x₁, x₂, ...)
\$(statement₁, statement₂, ...)
return Ω, (Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
@@ -453,7 +410,7 @@ is equivalent to:
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
...)

For examples, see ChainRules' `rules` directory.
For examples, see AbstractChainRules' `rules` directory.

See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref)
"""
@@ -493,12 +450,12 @@ macro scalar_rule(call, maybe_setup, partials...)
forward_rules = length(forward_rules) == 1 ? forward_rules[1] : Expr(:tuple, forward_rules...)
reverse_rules = length(reverse_rules) == 1 ? reverse_rules[1] : Expr(:tuple, reverse_rules...)
return quote
function ChainRules.frule(::typeof($f), $(inputs...))
function AbstractChainRules.frule(::typeof($f), $(inputs...))
$(esc(:Ω)) = $call
$(setup_stmts...)
return $(esc(:Ω)), $forward_rules
end
function ChainRules.rrule(::typeof($f), $(inputs...))
function AbstractChainRules.rrule(::typeof($f), $(inputs...))
$(esc(:Ω)) = $call
$(setup_stmts...)
return $(esc(:Ω)), $reverse_rules