Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DiffRules-based definitions for complex-valued functions #577

Merged
merged 4 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ForwardDiff"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.26"
version = "0.10.27"

[deps]
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
Expand Down
40 changes: 36 additions & 4 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,38 @@ macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_b
return esc(defs)
end

# Support complex-valued functions such as `hankelh1`
function dual_definition_retval(::Val{T}, val::Real, deriv::Real, partial::Partials) where {T}
return Dual{T}(val, deriv * partial)
end
function dual_definition_retval(::Val{T}, val::Real, deriv1::Real, partial1::Partials, deriv2::Real, partial2::Partials) where {T}
return Dual{T}(val, _mul_partials(partial1, partial2, deriv1, deriv2))
end
function dual_definition_retval(::Val{T}, val::Complex, deriv::Union{Real,Complex}, partial::Partials) where {T}
reval, imval = reim(val)
if deriv isa Real
p = deriv * partial
return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p)))
else
rederiv, imderiv = reim(deriv)
return Complex(Dual{T}(reval, rederiv * partial), Dual{T}(imval, imderiv * partial))
end
end
function dual_definition_retval(::Val{T}, val::Complex, deriv1::Union{Real,Complex}, partial1::Partials, deriv2::Union{Real,Complex}, partial2::Partials) where {T}
reval, imval = reim(val)
if deriv1 isa Real && deriv2 isa Real
p = _mul_partials(partial1, partial2, deriv1, deriv2)
return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p)))
else
rederiv1, imderiv1 = reim(deriv1)
rederiv2, imderiv2 = reim(deriv2)
return Complex(
Dual{T}(reval, _mul_partials(partial1, partial2, rederiv1, rederiv2)),
Dual{T}(imval, _mul_partials(partial1, partial2, imderiv1, imderiv2)),
)
end
end

function unary_dual_definition(M, f)
FD = ForwardDiff
Mf = M == :Base ? f : :($M.$f)
Expand All @@ -206,7 +238,7 @@ function unary_dual_definition(M, f)
@inline function $M.$f(d::$FD.Dual{T}) where T
x = $FD.value(d)
$work
return $FD.Dual{T}(val, deriv * $FD.partials(d))
return $FD.dual_definition_retval(Val{T}(), val, deriv, $FD.partials(d))
end
end
end
Expand Down Expand Up @@ -236,17 +268,17 @@ function binary_dual_definition(M, f)
begin
vx, vy = $FD.value(x), $FD.value(y)
$xy_work
return $FD.Dual{Txy}(val, $FD._mul_partials($FD.partials(x), $FD.partials(y), dvx, dvy))
return $FD.dual_definition_retval(Val{Txy}(), val, dvx, $FD.partials(x), dvy, $FD.partials(y))
end,
begin
vx = $FD.value(x)
$x_work
return $FD.Dual{Tx}(val, dvx * $FD.partials(x))
return $FD.dual_definition_retval(Val{Tx}(), val, dvx, $FD.partials(x))
end,
begin
vy = $FD.value(y)
$y_work
return $FD.Dual{Ty}(val, dvy * $FD.partials(y))
return $FD.dual_definition_retval(Val{Ty}(), val, dvy, $FD.partials(y))
end
)
end
Expand Down
48 changes: 38 additions & 10 deletions test/DualTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)

if V != Int
for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
if f in (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi)
if f in (:/, :rem2pi)
continue # Skip these rules
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
continue # Skip rules for methods not defined in the current scope
Expand All @@ -457,9 +457,20 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
end
@eval begin
x = rand() + $modifier
dx = $M.$f(Dual{TestTag()}(x, one(x)))
@test value(dx) == $M.$f(x)
@test partials(dx, 1) == $deriv
dx = @inferred $M.$f(Dual{TestTag()}(x, one(x)))
actualval = $M.$f(x)
@assert actualval isa Real || actualval isa Complex
if actualval isa Real
@test dx isa Dual{TestTag()}
@test value(dx) == actualval
@test partials(dx, 1) == $deriv
else
@test dx isa Complex{<:Dual{TestTag()}}
@test value(real(dx)) == real(actualval)
@test value(imag(dx)) == imag(actualval)
@test partials(real(dx), 1) == real($deriv)
@test partials(imag(dx), 1) == imag($deriv)
end
end
elseif arity == 2
derivs = DiffRules.diffrule(M, f, :x, :y)
Expand All @@ -472,14 +483,31 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
end
@eval begin
x, y = $x, $y
dx = $M.$f(Dual{TestTag()}(x, one(x)), y)
dy = $M.$f(x, Dual{TestTag()}(y, one(y)))
dx = @inferred $M.$f(Dual{TestTag()}(x, one(x)), y)
dy = @inferred $M.$f(x, Dual{TestTag()}(y, one(y)))
actualdx = $(derivs[1])
actualdy = $(derivs[2])
@test value(dx) == $M.$f(x, y)
@test value(dy) == value(dx)
@test partials(dx, 1) ≈ actualdx nans=true
@test partials(dy, 1) ≈ actualdy nans=true
actualval = $M.$f(x, y)
@assert actualval isa Real || actualval isa Complex
if actualval isa Real
@test dx isa Dual{TestTag()}
@test dy isa Dual{TestTag()}
@test value(dx) == actualval
@test value(dy) == actualval
@test partials(dx, 1) ≈ actualdx nans=true
@test partials(dy, 1) ≈ actualdy nans=true
else
@test dx isa Complex{<:Dual{TestTag()}}
@test dy isa Complex{<:Dual{TestTag()}}
@test real(value(dx)) == real(actualval)
@test real(value(dy)) == real(actualval)
@test imag(value(dx)) == imag(actualval)
@test imag(value(dy)) == imag(actualval)
@test partials(real(dx), 1) ≈ real(actualdx) nans=true
@test partials(real(dy), 1) ≈ real(actualdy) nans=true
@test partials(imag(dx), 1) ≈ imag(actualdx) nans=true
@test partials(imag(dy), 1) ≈ imag(actualdy) nans=true
end
end
end
end
Expand Down