Support Distributions 0.22 #23

Merged
mohamed82008 merged 7 commits into master from mt/bump_Distributions
Jan 16, 2020


mohamed82008
Member

No description provided.

@devmotion
Member

Unfortunately, due to lines such as

MvNormal(m::TrackedVector{<:Real}, A::UniformScaling{<:TrackedReal}) = MvNormal(m, A.λ)
MvNormal(m::AbstractVector{<:Real}, A::UniformScaling{<:TrackedReal}) = MvNormal(m, A.λ)
MvNormal(m::TrackedVector{<:Real}, A::UniformScaling{<:Real}) = MvNormal(m, A.λ)
the breaking change introduced in Distributions 0.22 will lead to different behaviour of MvNormal constructors in Distributions and DistributionsAD. I don't really see how one could support both Distributions 0.21 and 0.22, so I guess the way forward would be to fix the implementation of MvNormal in DistributionsAD and drop support for Distributions 0.21.

@mohamed82008
Member Author

Sounds good.

@mohamed82008
Member Author

Btw, IIUC the different behavior is not really an issue in this case. For one, we are not committing any more type piracy now than we were earlier. The only difference is that DistributionsAD will use a diagonal covariance for UniformScaling here where Distributions uses a ScalMat. Is that about right? I can switch to a ScalMat-like implementation here instead; I'm just checking that this is the only "problem".

@mohamed82008
Member Author

Oh I see the problem, it's the squaring, sorry.
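
To make the squaring concrete: forwarding A.λ as a scalar treats it as a standard deviation, which yields covariance λ²·I, whereas reading the UniformScaling as the covariance itself yields λ·I. The sketch below uses only LinearAlgebra and does not call either package; cov_from_std and cov_from_scaling are hypothetical helper names introduced for illustration.

```julia
using LinearAlgebra

# Hypothetical helper: a scalar σ read as a standard deviation gives covariance σ² * I.
cov_from_std(σ::Real, d::Int) = Matrix(abs2(σ) * I, d, d)

# Hypothetical helper: a UniformScaling read as the covariance itself gives λ * I.
cov_from_scaling(A::UniformScaling, d::Int) = Matrix(A, d, d)

λ = 2.0
Σ_std = cov_from_std(λ, 2)              # what forwarding A.λ as a scalar amounts to
Σ_scaling = cov_from_scaling(λ * I, 2)  # the covariance-matrix reading

println(Σ_std)
println(Σ_scaling)
```

The two matrices agree only when λ is 0 or 1, so the two constructors cannot be kept consistent without changing one of the conventions.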

@mohamed82008
Member Author

While testing the matrix case, I found and fixed some other bugs, including a performance bug that was probably the reason why Multi of MvNormal was slow (CC: @xukai92). I also found two failing test cases for MvNormal whose cause I don't know yet. You can find them in broken_mult_cont_dists in runtests.jl.

@mohamed82008
Member Author

mohamed82008 commented Jan 16, 2020

Here is a MWE of a failed test case if you want to investigate:

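using Distributions, DistributionsAD, ForwardDiff, Tracker

# Entries 1:4 of x form the 2×2 covariance C; entries 5:8 form a 2×2 matrix
# whose columns are the points at which the log-density is evaluated.
function f(x)
    @assert length(x) == 8
    C = reshape(x[1:4], 2, 2)
    val = reshape(x[5:8], 2, 2)
    return sum(logpdf(MvNormal(C), val))
end

x = [1.0, 0.0, 0.0, 1.0]  # identity covariance, flattened
x = [x; ones(4)]
Tracker.data(Tracker.gradient(f, x)[1])  # gradient via Tracker
ForwardDiff.gradient(f, x)               # gradient via ForwardDiff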

CC: @willtebbutt

@mohamed82008
Member Author

I will add more tests to raise the coverage and then merge if there are no objections. We can look into the test failures later if we can't figure them out soon.

@devmotion
Member

I don't have time right now to look into these test failures, but since the main motivation for this PR initially was to support Distributions 0.22, I think it can be merged.

As a general comment, I'm not sure if one should actually define three different Tracker-compatible Gaussian distributions, it somehow doesn't feel "right". It feels like they differ only in the type/structure of the covariance matrix, and hence IMO the implementation should also handle this at the level of the covariance matrices, similar to how MvNormal supports ScalMat, PDiagMat, and PDMat as covariance matrices and FillArray as mean vector. But that could/should also be part of a separate PR IMO.
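
As a rough sketch of what handling this at the level of the covariance matrices could look like, the example below defines a single Gaussian type that is generic over its covariance representation and dispatches only on the two operations that depend on the structure. Everything here is hypothetical and for illustration: ScalCov, DiagCov, and DenseCov are stand-ins for ScalMat, PDiagMat, and PDMat, and SketchMvNormal, logdetcov, invquad, and sketch_logpdf are made-up names, not the API of Distributions or DistributionsAD.

```julia
using LinearAlgebra

# Hypothetical covariance wrappers (stand-ins for ScalMat, PDiagMat, PDMat).
struct ScalCov{T<:Real}
    dim::Int
    value::T
end
struct DiagCov{V<:AbstractVector{<:Real}}
    diag::V
end
struct DenseCov{M<:AbstractMatrix{<:Real}}
    mat::M
end

# Only these two operations depend on the covariance structure.
logdetcov(c::ScalCov) = c.dim * log(c.value)
logdetcov(c::DiagCov) = sum(log, c.diag)
logdetcov(c::DenseCov) = logdet(cholesky(Symmetric(c.mat)))

invquad(c::ScalCov, z) = sum(abs2, z) / c.value
invquad(c::DiagCov, z) = sum(abs2.(z) ./ c.diag)
invquad(c::DenseCov, z) = dot(z, cholesky(Symmetric(c.mat)) \ z)

# One Gaussian type, generic over the mean and covariance representation.
struct SketchMvNormal{M<:AbstractVector{<:Real},C}
    μ::M
    Σ::C
end

# Log-density written once, in terms of the two structure-dependent operations.
function sketch_logpdf(d::SketchMvNormal, x::AbstractVector{<:Real})
    z = x .- d.μ
    return -(length(x) * log(2π) + logdetcov(d.Σ) + invquad(d.Σ, z)) / 2
end

# The same covariance 2 * I in three representations.
x_obs = [1.0, -1.0]
lp_scal = sketch_logpdf(SketchMvNormal(zeros(2), ScalCov(2, 2.0)), x_obs)
lp_diag = sketch_logpdf(SketchMvNormal(zeros(2), DiagCov([2.0, 2.0])), x_obs)
lp_dense = sketch_logpdf(SketchMvNormal(zeros(2), DenseCov([2.0 0.0; 0.0 2.0])), x_obs)
println((lp_scal, lp_diag, lp_dense))
```

With this layout, AD support would only need to cover the structure-dependent operations for each covariance type rather than a separate distribution type per structure.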

@mohamed82008
Member Author

> As a general comment, I'm not sure if one should actually define three different Tracker-compatible Gaussian distributions, it somehow doesn't feel "right".

I agree, let's do that in another PR.

mohamed82008 merged commit 5757d2d into master on Jan 16, 2020.
The delete-merged-branch bot deleted the mt/bump_Distributions branch on January 16, 2020, 10:37.

@devmotion
Member

Regarding the incorrect gradients @mohamed82008 mentioned in his comment above: the problem with ForwardDiff seems to be that basically all linear algebra implementations that exploit the sparsity of matrices (e.g., by calling convert(UpperTriangular, ...)) are AD-unsafe (see, e.g., JuliaDiff/ForwardDiff.jl#197 (comment) and the other comments in that issue). You can observe this in the following small example, which could be related to the problem mentioned above:

using ForwardDiff
using LinearAlgebra

@show ForwardDiff.gradient(x -> sum(x \ ones(2)), Matrix(I, 2, 2))
@show ForwardDiff.gradient(x -> sum(cholesky(x) \ ones(2)), Matrix(I, 2, 2))
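
One way to see why such shortcuts are risky for AD, as a rough sketch and not a description of ForwardDiff's internals: once a matrix is wrapped as UpperTriangular, its lower triangle is never read, so any derivative information attached to those entries is dropped. The example below uses finite differences and only LinearAlgebra; full_solve, upper_solve, and finite_diff are hypothetical names introduced for illustration.

```julia
using LinearAlgebra

# Hypothetical names, for illustration only.
full_solve(x) = sum(x \ ones(2))                    # depends on every entry of x
upper_solve(x) = sum(UpperTriangular(x) \ ones(2))  # reads only the upper triangle

# Forward finite difference of g at x with respect to entry (i, j).
function finite_diff(g, x, i, j; h = 1e-6)
    xh = copy(x)
    xh[i, j] += h
    return (g(xh) - g(x)) / h
end

x0 = Matrix{Float64}(I, 2, 2)
d_full = finite_diff(full_solve, x0, 2, 1)    # sensitivity to the (2, 1) entry
d_upper = finite_diff(upper_solve, x0, 2, 1)  # the structured path ignores that entry
println((d_full, d_upper))
```

At the identity both functions return the same value, but the sensitivity to the (2, 1) entry is close to -1 for the full solve and zero for the structured one. An AD tool that differentiates through the structured path would therefore report a zero partial derivative where the true one is nonzero.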
