From 6fd0ea4bf2d3ca43d0bd2531436dfd4a0d8f4058 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Mon, 18 Nov 2024 16:46:39 -0500 Subject: [PATCH] fix `Cuda` gradients fixes https://github.com/JuliaDiff/FiniteDiff.jl/issues/195 This really needs some tests as the non-fast_scalar_indexing path was just completely broken. --- src/gradients.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/gradients.jl b/src/gradients.jl index 54fec28..acc56d0 100644 --- a/src/gradients.jl +++ b/src/gradients.jl @@ -239,9 +239,8 @@ function finite_difference_gradient!( fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3 if fdtype != Val(:complex) && ArrayInterface.fast_scalar_indexing(c2) @. c2 = compute_epsilon(fdtype, one(eltype(x)), relstep, absstep, dir) - copyto!(c1, x) end - copyto!(c3, x) + copyto!(c1, x) if fdtype == Val(:forward) @inbounds for i ∈ eachindex(x) if ArrayInterface.fast_scalar_indexing(c2) @@ -273,6 +272,7 @@ function finite_difference_gradient!( end end elseif fdtype == Val(:central) + copyto!(c3, x) @inbounds for i ∈ eachindex(x) if ArrayInterface.fast_scalar_indexing(c2) epsilon = ArrayInterface.allowed_getindex(c2, i) * dir @@ -296,9 +296,8 @@ function finite_difference_gradient!( ArrayInterface.allowed_setindex!(c3, x_old, i) end elseif fdtype == Val(:complex) && returntype <: Real - copyto!(c1, x) - epsilon_complex = eps(real(eltype(x))) # we use c1 here to avoid typing issues with x + epsilon_complex = eps(real(eltype(x))) @inbounds for i ∈ eachindex(x) c1_old = ArrayInterface.allowed_getindex(c1, i) ArrayInterface.allowed_setindex!(c1, c1_old + im * epsilon_complex, i)