Skip to content

Commit

Permalink
Fix CUDA gradients
Browse files Browse the repository at this point in the history
fixes #195

This really needs tests, as the non-`fast_scalar_indexing` code path was completely broken.
  • Loading branch information
oscardssmith authored Nov 18, 2024
1 parent c98f1ab commit 6fd0ea4
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,8 @@ function finite_difference_gradient!(
fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3
if fdtype != Val(:complex) && ArrayInterface.fast_scalar_indexing(c2)
@. c2 = compute_epsilon(fdtype, one(eltype(x)), relstep, absstep, dir)
copyto!(c1, x)
end
copyto!(c3, x)
copyto!(c1, x)
if fdtype == Val(:forward)
@inbounds for i ∈ eachindex(x)
if ArrayInterface.fast_scalar_indexing(c2)
Expand Down Expand Up @@ -273,6 +272,7 @@ function finite_difference_gradient!(
end
end
elseif fdtype == Val(:central)
copyto!(c3, x)
@inbounds for i ∈ eachindex(x)
if ArrayInterface.fast_scalar_indexing(c2)
epsilon = ArrayInterface.allowed_getindex(c2, i) * dir
Expand All @@ -296,9 +296,8 @@ function finite_difference_gradient!(
ArrayInterface.allowed_setindex!(c3, x_old, i)
end
elseif fdtype == Val(:complex) && returntype <: Real
copyto!(c1, x)
epsilon_complex = eps(real(eltype(x)))
# we use c1 here to avoid typing issues with x
epsilon_complex = eps(real(eltype(x)))
@inbounds for i ∈ eachindex(x)
c1_old = ArrayInterface.allowed_getindex(c1, i)
ArrayInterface.allowed_setindex!(c1, c1_old + im * epsilon_complex, i)
Expand Down

0 comments on commit 6fd0ea4

Please sign in to comment.