From 8eaba053581a9132cceb64a5fd79941f45f0e6b7 Mon Sep 17 00:00:00 2001 From: Jadon Clugston <34165782+jClugstor@users.noreply.github.com> Date: Fri, 8 Nov 2024 11:17:37 -0500 Subject: [PATCH] Fix exponentiation for `NaNMath.pow` (#717) * fix NaNMath exponentiation * reuse code * fix * add tests * Update src/dual.jl Co-authored-by: David Widmann * import NaNMath * oops, no begin * Update test/GradientTest.jl Co-authored-by: David Widmann --------- Co-authored-by: David Widmann --- src/dual.jl | 6 +++--- test/DerivativeTest.jl | 12 ++++++++++++ test/GradientTest.jl | 11 +++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index ca3a2cbe..7e8ec110 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -552,7 +552,7 @@ end # exponentiation # #----------------# -for f in (:(Base.:^), :(NaNMath.pow)) +for (f, log) in ((:(Base.:^), :(Base.log)), (:(NaNMath.pow), :(NaNMath.log))) @eval begin @define_binary_dual_op( $f, @@ -565,7 +565,7 @@ for f in (:(Base.:^), :(NaNMath.pow)) elseif iszero(vx) && vy > 0 logval = zero(vx) else - logval = expv * log(vx) + logval = expv * ($log)(vx) end new_partials = _mul_partials(partials(x), partials(y), powval, logval) return Dual{Txy}(expv, new_partials) @@ -583,7 +583,7 @@ for f in (:(Base.:^), :(NaNMath.pow)) begin v = value(y) expv = ($f)(x, v) - deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x) + deriv = (iszero(x) && v > 0) ? zero(expv) : expv*($log)(x) return Dual{Ty}(expv, deriv * partials(y)) end ) diff --git a/test/DerivativeTest.jl b/test/DerivativeTest.jl index 4b7463c8..4de1a6de 100644 --- a/test/DerivativeTest.jl +++ b/test/DerivativeTest.jl @@ -1,6 +1,7 @@ module DerivativeTest import Calculus +import NaNMath using Test using Random @@ -93,6 +94,17 @@ end @test (x -> ForwardDiff.derivative(y -> x^y, 1.5))(0.0) === 0.0 end +@testset "exponentiation with NaNMath" begin + @test isnan(ForwardDiff.derivative(x -> NaNMath.pow(NaN, x), 1.0)) + @test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x,NaN), 1.0)) + @test !isnan(ForwardDiff.derivative(x -> NaNMath.pow(1.0, x),1.0)) + @test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x,0.5), -1.0)) + + @test isnan(ForwardDiff.derivative(x -> x^NaN, 2.0)) + @test ForwardDiff.derivative(x -> x^2.0,2.0) == 4.0 + @test_throws DomainError ForwardDiff.derivative(x -> x^0.5, -1.0) +end + @testset "dimension error for derivative" begin @test_throws DimensionMismatch ForwardDiff.derivative(sum, fill(2pi, 3)) end diff --git a/test/GradientTest.jl b/test/GradientTest.jl index 5adfc8c7..a386c479 100644 --- a/test/GradientTest.jl +++ b/test/GradientTest.jl @@ -1,6 +1,7 @@ module GradientTest import Calculus +import NaNMath using Test using LinearAlgebra @@ -200,6 +201,16 @@ end @test ForwardDiff.gradient(L -> logdet(L), Matrix(L)) ≈ [1.0 -1.3333333333333337; 0.0 1.666666666666667] end +@testset "gradient for exponential with NaNMath" begin + @test isnan(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[1]), [NaN, 1.0])[1]) + @test ForwardDiff.gradient(x -> NaNMath.pow(x[1], x[2]), [1.0, 1.0]) == [1.0, 0.0] + @test isnan(ForwardDiff.gradient((x) -> NaNMath.pow(x[1], x[2]), [-1.0, 0.5])[1]) + + @test isnan(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0])[1]) + @test ForwardDiff.gradient(x -> x[1]^x[2], [1.0, 1.0]) == [1.0, 0.0] + @test_throws DomainError ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5]) +end + @testset "branches in mul!" begin a, b = rand(3,3), rand(3,3)