I am looking to create a function that uses a truncated SVD and take its gradient with respect to a complex input (the output is real).
I find that with Zygote this works fine for real inputs but fails for complex inputs (e.g. if you uncomment the complex `X` line in the example below), even though I believe Zygote supports gradients of functions with complex inputs.
Please could you suggest how to resolve this.
Thanks,
Joe
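```julia
using LinearAlgebra, Zygote

# Real test matrix: the gradient below works.
X = kron(rand(Float64, 4, 4), rand(Float64, 4, 4)) + kron(rand(Float64, 4, 4), rand(Float64, 4, 4))
# Complex test matrix: uncommenting this line makes the gradient fail.
# X = kron(rand(ComplexF64, 4, 4), rand(ComplexF64, 4, 4)) + kron(rand(ComplexF64, 4, 4), rand(ComplexF64, 4, 4))

# Sum of the two largest singular values (a rank-2 truncated SVD).
function foo(X)
    F = svd(X)
    return abs(sum(F.S[1:2]))
end

G = foo'(X)  # Zygote's adjoint syntax, equivalent to gradient(foo, X)[1]
```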
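For reference, the gradient being asked for can be written down by hand. This is a sketch assuming the singular values at the truncation boundary are distinct and nonzero (here $\sigma_2 > \sigma_3$ and $\sigma_2 > 0$), and using the usual convention that for a real-valued $f$ the gradient with respect to complex $X$ is $\partial f/\partial \operatorname{Re} X + i\,\partial f/\partial \operatorname{Im} X$. $U_k$ and $V_k$ denote the first $k$ columns of $U$ and $V$; since $\sigma_i \ge 0$, the `abs` in `foo` does not change the value.

$$
\begin{aligned}
d\sigma_i &= \operatorname{Re}\left(u_i^{H}\, dX\, v_i\right), \\
f(X) &= \Bigl|\sum_{i=1}^{k} \sigma_i\Bigr| = \sum_{i=1}^{k} \sigma_i, \\
df &= \sum_{i=1}^{k} \operatorname{Re}\left(u_i^{H}\, dX\, v_i\right)
    = \operatorname{Re}\operatorname{tr}\Bigl(\bigl(U_k V_k^{H}\bigr)^{H} dX\Bigr)
\quad\Longrightarrow\quad \nabla_X f = U_k V_k^{H}.
\end{aligned}
$$

The arbitrary phase in each pair $(u_i, v_i)$ cancels in $u_i v_i^{H}$, so this gradient is well defined under the stated assumptions.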
Support for complex inputs is very much a function-by-function thing in Zygote. Some functions may work without dedicated AD rules, but in the case of svd that likely isn't true. Zygote uses the SVD rule from https://github.com/JuliaDiff/ChainRules.jl and I think that only handles real-valued inputs, so I'd recommend opening a feature request there.
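Until such a rule exists, one possible workaround is to write a custom rule for the specific quantity being differentiated rather than for `svd` itself. The sketch below assumes ChainRulesCore.jl is installed (Zygote picks up `ChainRulesCore.rrule` definitions), introduces a hypothetical helper `sum_top_svals`, and assumes $\sigma_k > \sigma_{k+1}$ and $\sigma_k > 0$. It relies on the identity $d\sigma_i = \operatorname{Re}(u_i^{H}\,dX\,v_i)$, whose sum over $i \le k$ gives the gradient $U_k V_k^{H}$.

```julia
using LinearAlgebra, Zygote, ChainRulesCore

# Hypothetical helper (not part of any library): sum of the k largest singular values.
# svdvals returns singular values in descending order.
sum_top_svals(X::AbstractMatrix, k::Integer) = sum(svdvals(X)[1:k])

# Custom reverse rule for real or complex X, assuming σ_k > σ_{k+1} and σ_k > 0.
function ChainRulesCore.rrule(::typeof(sum_top_svals), X::AbstractMatrix, k::Integer)
    F = svd(X)
    y = sum(F.S[1:k])
    function sum_top_svals_pullback(ȳ)
        # dσ_i = real(u_i' * dX * v_i), so the gradient of the sum is U_k * V_k'
        X̄ = unthunk(ȳ) * F.U[:, 1:k] * F.V[:, 1:k]'
        return NoTangent(), X̄, NoTangent()
    end
    return y, sum_top_svals_pullback
end

# The function from the question, rewritten on top of the helper.
foo(X) = abs(sum_top_svals(X, 2))

X = kron(rand(ComplexF64, 4, 4), rand(ComplexF64, 4, 4)) + kron(rand(ComplexF64, 4, 4), rand(ComplexF64, 4, 4))
G = Zygote.gradient(foo, X)[1]
```

The same helper also works for real `X`. Where the assumptions above fail (repeated or zero singular values at the truncation boundary), the quantity is not differentiable and the returned matrix should not be trusted as a gradient.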