diff --git a/Project.toml b/Project.toml index 600e1c9ed..349cadfdb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,17 +1,19 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.5.2" +version = "0.5.3" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" [compat] -julia = "^1.0" MuladdMacro = "0.2.1" +StaticArrays = "0.11, 0.12" +julia = "^1.0" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "LinearAlgebra"] +test = ["Test", "LinearAlgebra", "StaticArrays"] diff --git a/src/differentials/does_not_exist.jl b/src/differentials/does_not_exist.jl index 1ef76cae6..e5cc79d09 100644 --- a/src/differentials/does_not_exist.jl +++ b/src/differentials/does_not_exist.jl @@ -33,6 +33,7 @@ function extern(x::DoesNotExist) end Base.Broadcast.broadcastable(::DoesNotExist) = Ref(DoesNotExist()) +Base.Broadcast.broadcasted(::Type{DoesNotExist}) = DoesNotExist() Base.iterate(x::DoesNotExist) = (x, nothing) Base.iterate(::DoesNotExist, ::Any) = nothing diff --git a/src/differentials/one.jl b/src/differentials/one.jl index 5c528f35e..141a0bc6b 100644 --- a/src/differentials/one.jl +++ b/src/differentials/one.jl @@ -8,6 +8,7 @@ struct One <: AbstractDifferential end extern(x::One) = true # true is a strong 1. Base.Broadcast.broadcastable(::One) = Ref(One()) +Base.Broadcast.broadcasted(::Type{One}) = One() Base.iterate(x::One) = (x, nothing) Base.iterate(::One, ::Any) = nothing diff --git a/src/differentials/zero.jl b/src/differentials/zero.jl index 2249a0b95..3903d3263 100644 --- a/src/differentials/zero.jl +++ b/src/differentials/zero.jl @@ -8,6 +8,7 @@ struct Zero <: AbstractDifferential end extern(x::Zero) = false # false is a strong 0. E.g. `false * NaN = 0.0` Base.Broadcast.broadcastable(::Zero) = Ref(Zero()) +Base.Broadcast.broadcasted(::Type{Zero}) = Zero() Base.iterate(x::Zero) = (x, nothing) Base.iterate(::Zero, ::Any) = nothing diff --git a/test/rules.jl b/test/rules.jl index 1845f5739..cf5063994 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -1,5 +1,6 @@ ####### # Demo setup +using StaticArrays: @SVector cool(x) = x + 1 cool(x, y) = x + y + 1 @@ -11,6 +12,9 @@ dummy_identity(x) = x nice(x) = 1 @scalar_rule(nice(x), Zero()) +very_nice(x, y) = x + y +@scalar_rule(very_nice(x, y), (One(), One())) + ####### _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @@ -46,4 +50,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test nice_pushforward === 0 rrx, nice_pullback = rrule(nice, 1) @test (NO_FIELDS, 0) === nice_pullback(1) + + sx = @SVector [1, 2] + sy = @SVector [3, 4] + # This actually is testing that @scalar_rule and `One()` play nice together, w.r.t broadcasting + @inferred frule(very_nice, 1, 2, Zero(), sx, sy) end