From 7b2d2d6f97da22db1860ecb7ebbb1a44e64c6027 Mon Sep 17 00:00:00 2001 From: Songchen Tan Date: Wed, 1 Feb 2023 10:02:14 -0500 Subject: [PATCH] Fix format and bump version --- Project.toml | 2 +- src/chainrules.jl | 36 +++++++++++++++++++++++++----------- src/scalar.jl | 4 +++- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 101868f..ca1f428 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TaylorDiff" uuid = "b36ab563-344f-407b-a36a-4f200bebf99c" authors = ["Songchen Tan "] -version = "0.1.3" +version = "0.2.0" [deps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" diff --git a/src/chainrules.jl b/src/chainrules.jl index fe6166e..0784edb 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -55,28 +55,40 @@ end ProjectTo(::T) where {T <: TaylorScalar} = ProjectTo{T}() (p::ProjectTo{T})(x::T) where {T <: TaylorScalar} = x -ProjectTo(x::AbstractArray{T}) where {T <: TaylorScalar} = ProjectTo{AbstractArray}(; element=ProjectTo(zero(T)), axes=axes(x)) +function ProjectTo(x::AbstractArray{T}) where {T <: TaylorScalar} + ProjectTo{AbstractArray}(; element = ProjectTo(zero(T)), axes = axes(x)) +end (p::ProjectTo{AbstractArray{T}})(x::AbstractArray{T}) where {T <: TaylorScalar} = x accum_sum(xs::AbstractArray{T}; dims = :) where {T <: TaylorScalar} = sum(xs, dims = dims) -TaylorNumeric{T<:TaylorScalar} = Union{T, AbstractArray{<:T}} +TaylorNumeric{T <: TaylorScalar} = Union{T, AbstractArray{<:T}} -@adjoint broadcasted(::typeof(+), xs::Union{Numeric, TaylorNumeric}...) = broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...) +@adjoint function broadcasted(::typeof(+), xs::Union{Numeric, TaylorNumeric}...) + broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...) +end -struct TaylorOneElement{T,N,I,A} <: AbstractArray{T,N} +struct TaylorOneElement{T, N, I, A} <: AbstractArray{T, N} val::T ind::I axes::A - TaylorOneElement(val::T, ind::I, axes::A) where {T<:TaylorScalar, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes) + function TaylorOneElement(val::T, ind::I, + axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int}, + A <: NTuple{N, AbstractUnitRange}} where {N} + new{T, N, I, A}(val, ind, axes) + end end Base.size(A::TaylorOneElement) = map(length, A.axes) Base.axes(A::TaylorOneElement) = A.axes -Base.getindex(A::TaylorOneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) +function Base.getindex(A::TaylorOneElement{T, N}, i::Vararg{Int, N}) where {T, N} + ifelse(i == A.ind, A.val, zero(T)) +end -∇getindex(x::AbstractArray{T, N}, inds) where {T <: TaylorScalar, N} = dy -> begin - dx = TaylorOneElement(dy, inds, axes(x)) - return (_project(x, dx), map(_->nothing, inds)...) +function ∇getindex(x::AbstractArray{T, N}, inds) where {T <: TaylorScalar, N} + dy -> begin + dx = TaylorOneElement(dy, inds, axes(x)) + return (_project(x, dx), map(_ -> nothing, inds)...) + end end @generated function mul_adjoint(Ω::TaylorScalar{T, N}, x::TaylorScalar{T, N}) where {T, N} @@ -93,12 +105,14 @@ rrule(::typeof(*), x::TaylorScalar) = rrule(identity, x) function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar) function times_pullback2(Ω̇) ΔΩ = unthunk(Ω̇) - return (NoTangent(), ProjectTo(x)(mul_adjoint(ΔΩ, y)), ProjectTo(y)(mul_adjoint(ΔΩ, x))) + return (NoTangent(), ProjectTo(x)(mul_adjoint(ΔΩ, y)), + ProjectTo(y)(mul_adjoint(ΔΩ, x))) end return x * y, times_pullback2 end -function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...) +function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, + more::TaylorScalar...) Ω2, back2 = rrule(*, x, y) Ω3, back3 = rrule(*, Ω2, z) Ω4, back4 = rrule(*, Ω3, more...) diff --git a/src/scalar.jl b/src/scalar.jl index e42d3fc..391b4e7 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -76,7 +76,9 @@ end # Number-like convention (I patched them after removing <: Number) convert(::Type{TaylorScalar{T, N}}, x::TaylorScalar{T, N}) where {T, N} = x -convert(::Type{TaylorScalar{T, N}}, x::S) where {T, S, N} = TaylorScalar{T, N}(convert(T, x)) +function convert(::Type{TaylorScalar{T, N}}, x::S) where {T, S, N} + TaylorScalar{T, N}(convert(T, x)) +end for op in (:+, :-, :*, :/) @eval @inline $op(a::TaylorScalar, b::Number) = $op(promote(a, b)...) @eval @inline $op(a::Number, b::TaylorScalar) = $op(promote(a, b)...)