Skip to content

Commit

Permalink
Make One, Zero, and DNE inferable when broadcasting with Static…
Browse files Browse the repository at this point in the history
…Arrays (#96)

* Make `One`, `Zero`, and `DNE` inferable when broadcasting with StaticArrays

* New release

* Remove superfluous [deps]

* Update test/rules.jl

Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au>

Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
  • Loading branch information
YingboMa and oxinabox authored Jan 13, 2020
1 parent 932b704 commit c782e40
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 3 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.5.2"
version = "0.5.3"

[deps]
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"

[compat]
julia = "^1.0"
MuladdMacro = "0.2.1"
StaticArrays = "0.11, 0.12"
julia = "^1.0"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "LinearAlgebra"]
test = ["Test", "LinearAlgebra", "StaticArrays"]
1 change: 1 addition & 0 deletions src/differentials/does_not_exist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ function extern(x::DoesNotExist)
end

Base.Broadcast.broadcastable(::DoesNotExist) = Ref(DoesNotExist())
Base.Broadcast.broadcasted(::Type{DoesNotExist}) = DoesNotExist()

Base.iterate(x::DoesNotExist) = (x, nothing)
Base.iterate(::DoesNotExist, ::Any) = nothing
1 change: 1 addition & 0 deletions src/differentials/one.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ struct One <: AbstractDifferential end
extern(x::One) = true # true is a strong 1.

Base.Broadcast.broadcastable(::One) = Ref(One())
Base.Broadcast.broadcasted(::Type{One}) = One()

Base.iterate(x::One) = (x, nothing)
Base.iterate(::One, ::Any) = nothing
1 change: 1 addition & 0 deletions src/differentials/zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ struct Zero <: AbstractDifferential end
extern(x::Zero) = false # false is a strong 0. E.g. `false * NaN = 0.0`

Base.Broadcast.broadcastable(::Zero) = Ref(Zero())
Base.Broadcast.broadcasted(::Type{Zero}) = Zero()

Base.iterate(x::Zero) = (x, nothing)
Base.iterate(::Zero, ::Any) = nothing
Expand Down
9 changes: 9 additions & 0 deletions test/rules.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#######
# Demo setup
using StaticArrays: @SVector

cool(x) = x + 1
cool(x, y) = x + y + 1
Expand All @@ -11,6 +12,9 @@ dummy_identity(x) = x
nice(x) = 1
@scalar_rule(nice(x), Zero())

very_nice(x, y) = x + y
@scalar_rule(very_nice(x, y), (One(), One()))

#######

_second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
Expand Down Expand Up @@ -46,4 +50,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@test nice_pushforward === 0
rrx, nice_pullback = rrule(nice, 1)
@test (NO_FIELDS, 0) === nice_pullback(1)

sx = @SVector [1, 2]
sy = @SVector [3, 4]
# This actually is testing that @scalar_rule and `One()` play nice together, w.r.t broadcasting
@inferred frule(very_nice, 1, 2, Zero(), sx, sy)
end

2 comments on commit c782e40

@YingboMa
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/7871

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.3 -m "<description of version>" c782e40c0f6b46a13029edc87b3fec889f807e80
git push origin v0.5.3

Please sign in to comment.