Support Distributions 0.22 #23
Unfortunately, due to lines such as DistributionsAD.jl/src/multivariate.jl Lines 89 to 91 in ab14b91
Sounds good.
Btw IIUC, the different behavior is not really an issue in this case. For one, we are not committing any additional type piracy now than we were earlier. It's just going to use a diagonal covariance for
Oh I see the problem, it's the squaring, sorry.
So while testing the matrix case, I found some other bugs and fixed them including a performance bug that was probably the reason why
Here is a MWE of a failed test case if you want to investigate:

using Distributions, DistributionsAD, ForwardDiff, Tracker
function f(x)
    @assert length(x) == 8
    C = reshape(x[1:4], 2, 2)
    val = reshape(x[5:8], 2, 2)
    return sum(logpdf(MvNormal(C), val))
end
x = [1.0, 0.0, 0.0, 1.0]; x = [x; ones(4)];
Tracker.data(Tracker.gradient(f, x)[1])
ForwardDiff.gradient(f, x)

CC: @willtebbutt
I will add more tests to raise the coverage, then if there are no objections after that, I will merge. We can look into the test failures later if we can't figure it out soon.
I don't have time right now to look into these test failures, but since the main motivation for this PR initially was to support Distributions 0.22, I think it can be merged. As a general comment, I'm not sure if one should actually define three different Tracker-compatible Gaussian distributions, it somehow doesn't feel "right". It feels like they differ just in the type/structure of the covariance matrix, and hence IMO the implementation should also handle this on the level of the covariance matrices, similar to how
I agree, let's do that in another PR.
Regarding the incorrect gradients @mohamed82008 mentioned in his comment above, the problem with ForwardDiff seems to be that basically all linear algebra implementations that exploit the sparsity of matrices (by, e.g., calling

using ForwardDiff
using LinearAlgebra
@show ForwardDiff.gradient(x -> sum(x \ ones(2)), Matrix(I, 2, 2))
@show ForwardDiff.gradient(x -> sum(cholesky(x) \ ones(2)), Matrix(I, 2, 2))
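
As a reference point for the first of the two calls above, the expected gradient can be worked out by hand. This is only a sketch: it assumes the standard identity for the differential of a matrix inverse and treats all four entries of the matrix (written X here, with 1 denoting the vector of ones) as independent inputs.

$$
f(X) = \mathbf{1}^\top X^{-1} \mathbf{1}, \qquad
\mathrm{d}f = -\mathbf{1}^\top X^{-1}\,\mathrm{d}X\,X^{-1}\mathbf{1}, \qquad
\nabla_X f = -X^{-\top}\,\mathbf{1}\mathbf{1}^\top X^{-\top}.
$$

$$
\nabla_X f\,\Big|_{X = I} = -\mathbf{1}\mathbf{1}^\top =
\begin{pmatrix} -1 & -1 \\ -1 & -1 \end{pmatrix}.
$$

Under these assumptions every entry of the gradient at the identity is -1, so a result with zeros in the off-diagonal positions would suggest that the partial derivatives of those entries are being dropped. The `cholesky` variant reads only one triangle of its input, so its off-diagonal entries need not match this formula.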