Skip to content

Commit

Permalink
Merge pull request #196 from JuliaDiff/os/fix-Cuda-gradient
Browse files Browse the repository at this point in the history
fix `Cuda` gradients
  • Loading branch information
oscardssmith authored Nov 18, 2024
2 parents c98f1ab + 6fd0ea4 commit 5963b5d
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,8 @@ function finite_difference_gradient!(
fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3
if fdtype != Val(:complex) && ArrayInterface.fast_scalar_indexing(c2)
@. c2 = compute_epsilon(fdtype, one(eltype(x)), relstep, absstep, dir)
copyto!(c1, x)
end
copyto!(c3, x)
copyto!(c1, x)
if fdtype == Val(:forward)
@inbounds for i ∈ eachindex(x)
if ArrayInterface.fast_scalar_indexing(c2)
Expand Down Expand Up @@ -273,6 +272,7 @@ function finite_difference_gradient!(
end
end
elseif fdtype == Val(:central)
copyto!(c3, x)
@inbounds for i ∈ eachindex(x)
if ArrayInterface.fast_scalar_indexing(c2)
epsilon = ArrayInterface.allowed_getindex(c2, i) * dir
Expand All @@ -296,9 +296,8 @@ function finite_difference_gradient!(
ArrayInterface.allowed_setindex!(c3, x_old, i)
end
elseif fdtype == Val(:complex) && returntype <: Real
copyto!(c1, x)
epsilon_complex = eps(real(eltype(x)))
# we use c1 here to avoid typing issues with x
epsilon_complex = eps(real(eltype(x)))
@inbounds for i ∈ eachindex(x)
c1_old = ArrayInterface.allowed_getindex(c1, i)
ArrayInterface.allowed_setindex!(c1, c1_old + im * epsilon_complex, i)
Expand Down

0 comments on commit 5963b5d

Please sign in to comment.