I am on Julia 1.8.5, Enzyme main (2ccf4b) and CUDA v4.0.1. This works okay:
using Enzyme, CUDA
CUDA.limit!(CUDA.CU_LIMIT_MALLOC_HEAP_SIZE, 1*1024^3)
function kernel!(a, b_ref)
    b = b_ref[]
    a[threadIdx().x] = a[threadIdx().x] * b
    return nothing
end

function grad_kernel!(a, da, b, db)
    Enzyme.autodiff_deferred(
        Reverse,
        kernel!,
        Const,
        Duplicated(a, da),
        Duplicated(b, db),
    )
    return nothing
end
a = CUDA.rand(256)
da = zero(a) .+ 1.0f0
b = CuArray([2.0f0])
db = CuArray([0.0f0])
CUDA.@sync @cuda threads=256 blocks=1 grad_kernel!(a, da, b, db)
println(db)
Float32[121.32266]
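As a sanity check, this value matches the analytic gradient: with da seeded to ones, the gradient of b is the sum of da[i] * a[i], i.e. sum(a) taken before the kernel overwrites a in place. A minimal check along those lines (same kernels as above, fresh arrays; I have not re-run this exact snippet):

a = CUDA.rand(256)
da = zero(a) .+ 1.0f0
b = CuArray([2.0f0])
db = CuArray([0.0f0])
expected = sum(a)  # record before launch, since kernel! overwrites a in place
CUDA.@sync @cuda threads=256 blocks=1 grad_kernel!(a, da, b, db)
println((expected, Array(db)[1]))  # the two values should agree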
However, if I try to use a Ref to avoid the array, the gradient is zero:
a = CUDA.rand(256)
da = zero(a) .+ 1.0f0
b = Ref(2.0f0)
db = Ref(0.0f0)
CUDA.@sync @cuda threads=256 blocks=1 grad_kernel!(a, da, b, db)
println(db)
Base.RefValue{Float32}(0.0f0)
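My guess (an assumption on my part, I have not dug into CUDA.jl's internals) is that @cuda's argument adaptation copies the RefValue by value onto the device, so the shadow the kernel writes into never aliases the host db. Inspecting the converted argument with cudaconvert, which applies the same conversion @cuda uses, should show it is no longer a plain Base.RefValue:

db = Ref(0.0f0)
# cudaconvert performs the same argument adaptation as @cuda; if the result
# does not share memory with the host Ref, device-side writes cannot reach it.
println(typeof(CUDA.cudaconvert(db)))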
This could also be achieved with Active, but then I need to reduce the gradients either inside or outside the kernel:
function kernel_2!(a, b)
    a[threadIdx().x] = a[threadIdx().x] * b
    return nothing
end

function grad_kernel_2!(a, da, b, db)
    grads = Enzyme.autodiff_deferred(
        Reverse,
        kernel_2!,
        Const,
        Duplicated(a, da),
        Active(b),
    )
    db[threadIdx().x] = grads[1][2]
    return nothing
end
a = CUDA.rand(256)
da = zero(a) .+ 1.0f0
b = 2.0f0
db = CUDA.zeros(256)
CUDA.@sync @cuda threads=256 blocks=1 grad_kernel_2!(a, da, b, db)
println(sum(db))
127.38301
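For completeness, the outside reduction can be moved into the kernel with an atomic accumulation into a one-element array. A rough sketch reusing kernel_2! from above (grad_kernel_3! and the single-element db are just for illustration, and I have not benchmarked this):

function grad_kernel_3!(a, da, b, db)
    grads = Enzyme.autodiff_deferred(
        Reverse,
        kernel_2!,
        Const,
        Duplicated(a, da),
        Active(b),
    )
    # accumulate each thread's scalar gradient into a single slot
    CUDA.@atomic db[1] += grads[1][2]
    return nothing
end

a = CUDA.rand(256)
da = zero(a) .+ 1.0f0
db = CUDA.zeros(1)
CUDA.@sync @cuda threads=256 blocks=1 grad_kernel_3!(a, da, 2.0f0, db)
println(db)  # should match sum(db) from the version above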
That wouldn't work with Ref, and also would break the adapt that happens now. But it may be something worth considering, yes, as it would make Ref behave more like users expect.