Remove constraint on AbstractPDMat #1552
base: master
Changes from all commits
69859c1
8987628
246df4e
f006fab
7883de3
e416cb6
c77770c
@@ -48,11 +48,15 @@ struct MvNormal{T<:Real,Cov<:AbstractPDMat,Mean<:AbstractVector} <: AbstractMvNo
 end
 ```

-Here, the mean vector can be an instance of any `AbstractVector`. The covariance can be
-of any subtype of `AbstractPDMat`. Particularly, one can use `PDMat` for full covariance,
-`PDiagMat` for diagonal covariance, and `ScalMat` for the isotropic covariance -- those
-in the form of ``\\sigma^2 \\mathbf{I}``. (See the Julia package
+Here, the mean vector can be an instance of any `AbstractVector`.
+
+Special handling is included if the covariance is a subtype of `AbstractPDMat`.
+Particularly, one can use `PDMat` for full covariance, `PDiagMat` for diagonal covariance,
+and `ScalMat` for the isotropic covariance
+-- those in the form of ``\\sigma^2 \\mathbf{I}``. (See the Julia package
 [PDMats](https://github.com/JuliaStats/PDMats.jl/) for details).
+If you pass a dense `Matrix` for the covariance, it is automatically converted to a `PDMat`.
+For other matrix types, you have to convert them yourself.

 We also define a set of aliases for the types using different combinations of mean vectors and covariance:

@@ -166,9 +170,14 @@ Generally, users don't have to worry about these internal details.
 We provide a common constructor `MvNormal`, which will construct a distribution of
 appropriate type depending on the input arguments.
 """
-struct MvNormal{T<:Real,Cov<:AbstractPDMat,Mean<:AbstractVector} <: AbstractMvNormal
+struct MvNormal{T<:Real,Cov<:AbstractMatrix,Mean<:AbstractVector} <: AbstractMvNormal
     μ::Mean
     Σ::Cov
+
+    function MvNormal{T, Cov, Mean}(μ::Mean, Σ::Cov) where {T, Mean, Cov}
+        size(Σ, 1) == size(Σ, 2) == length(μ) || throw(DimensionMismatch("The dimensions of mu and Sigma are inconsistent."))
+        return new{T, Cov, Mean}(μ, Σ)
+    end
 end

 const MultivariateNormal = MvNormal # for the purpose of backward compatibility
@@ -182,14 +191,15 @@ const ZeroMeanDiagNormal{Axes} = MvNormal{Float64,PDiagMat{Float64,Vector{Float6
 const ZeroMeanFullNormal{Axes} = MvNormal{Float64,PDMat{Float64,Matrix{Float64}},Zeros{Float64,1,Axes}}

 ### Construction
-function MvNormal(μ::AbstractVector{T}, Σ::AbstractPDMat{T}) where {T<:Real}
-    dim(Σ) == length(μ) || throw(DimensionMismatch("The dimensions of mu and Sigma are inconsistent."))
-    MvNormal{T,typeof(Σ), typeof(μ)}(μ, Σ)
-end

+function MvNormal(μ::AbstractVector{T}, Σ::AbstractMatrix{T}) where {T<:Real}
+    MvNormal{T, typeof(Σ), typeof(μ)}(μ, Σ)
+end
 function MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractPDMat{<:Real})
     R = Base.promote_eltype(μ, Σ)
-    MvNormal(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, Σ))
+    μc = convert(AbstractArray{R}, μ)
+    Σc = convert(AbstractArray{R}, Σ)
+    MvNormal{R, typeof(Σc), typeof(μc)}(μc, Σc)
Review comment on lines +200 to +202:
This change is not needed it seems, is it?
Reply: It is so we can call the parameterized constructor.
 end

 # constructor with general covariance matrix
@@ -198,7 +208,7 @@ end

 Construct a multivariate normal distribution with mean `μ` and covariance matrix `Σ`.
 """
-MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractMatrix{<:Real}) = MvNormal(μ, PDMat(Σ))
+MvNormal(μ::AbstractVector{<:Real}, Σ::Matrix{<:Real}) = MvNormal(μ, PDMat(Σ))
 MvNormal(μ::AbstractVector{<:Real}, Σ::Diagonal{<:Real}) = MvNormal(μ, PDiagMat(Σ.diag))
 MvNormal(μ::AbstractVector{<:Real}, Σ::Union{Symmetric{<:Real,<:Diagonal{<:Real}},Hermitian{<:Real,<:Diagonal{<:Real}}}) = MvNormal(μ, PDiagMat(Σ.data.diag))
 MvNormal(μ::AbstractVector{<:Real}, Σ::UniformScaling{<:Real}) =
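The constructor layering in the hunks above can be sketched in isolation. This is a minimal illustration, not the Distributions.jl code: `ToyMvNormal` is a hypothetical stand-in type, and only the standard library is used. It shows an inner constructor that owns the dimension check, a same-element-type outer method, and a promoting outer method that calls the parameterized constructor directly, which is the pattern the reply in the review thread refers to.

```julia
using LinearAlgebra

# Hypothetical stand-in for the relaxed `MvNormal`; not the Distributions.jl type.
struct ToyMvNormal{T<:Real,Cov<:AbstractMatrix,Mean<:AbstractVector}
    μ::Mean
    Σ::Cov

    # The inner constructor is the single place where dimensions are validated.
    function ToyMvNormal{T,Cov,Mean}(μ::Mean, Σ::Cov) where {T<:Real,Cov<:AbstractMatrix,Mean<:AbstractVector}
        size(Σ, 1) == size(Σ, 2) == length(μ) ||
            throw(DimensionMismatch("The dimensions of mu and Sigma are inconsistent."))
        return new{T,Cov,Mean}(μ, Σ)
    end
end

# Matching element types: store the arguments as they are.
function ToyMvNormal(μ::AbstractVector{T}, Σ::AbstractMatrix{T}) where {T<:Real}
    return ToyMvNormal{T,typeof(Σ),typeof(μ)}(μ, Σ)
end

# Mixed element types: promote both, then call the parameterized constructor.
function ToyMvNormal(μ::AbstractVector{<:Real}, Σ::AbstractMatrix{<:Real})
    R = Base.promote_eltype(μ, Σ)
    μc = convert(AbstractArray{R}, μ)
    Σc = convert(AbstractArray{R}, Σ)
    return ToyMvNormal{R,typeof(Σc),typeof(μc)}(μc, Σc)
end

# A Float32 mean with a Float64 diagonal covariance is promoted to Float64.
d = ToyMvNormal(Float32[0, 0], Diagonal([1.0, 2.0]))
```

Calling the parameterized constructor from the promoting method avoids a second trip through outer-constructor dispatch after the conversion.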
@@ -19,7 +19,7 @@ which is also a subtype of `AbstractMvNormal` to represent a multivariate normal
 canonical parameters. Particularly, `MvNormalCanon` is defined as:

 ```julia
-struct MvNormalCanon{T<:Real,P<:AbstractPDMat,V<:AbstractVector} <: AbstractMvNormal
+struct MvNormalCanon{T<:Real,P<:AbstractMatrix,V<:AbstractVector} <: AbstractMvNormal
     μ::V # the mean vector
     h::V # potential vector, i.e. inv(Σ) * μ
     J::P # precision matrix, i.e. inv(Σ)
@@ -40,10 +40,19 @@ const ZeroMeanIsoNormalCanon{Axes} = MvNormalCanon{Float64, ScalMat{Float64},

 **Note:** `MvNormalCanon` share the same set of methods as `MvNormal`.
 """
-struct MvNormalCanon{T<:Real,P<:AbstractPDMat,V<:AbstractVector} <: AbstractMvNormal
+struct MvNormalCanon{T<:Real,P<:AbstractMatrix,V<:AbstractVector} <: AbstractMvNormal
     μ::V # the mean vector
     h::V # potential vector, i.e. inv(Σ) * μ
     J::P # precision matrix, i.e. inv(Σ)
+
+    function MvNormalCanon{T,P,V}(μ::V, h::AbstractVector, J::P) where {T<:Real, V<:AbstractVector{T}, P}
+        length(μ) == length(h) == dim(J) || throw(DimensionMismatch("Inconsistent argument dimensions"))
+        if typeof(μ) === typeof(h)
+            return new{T,typeof(J),typeof(μ)}(μ, h, J)
+        else
+            return new{T,typeof(J),Vector{T}}(collect(μ), collect(h), J)
+        end
Review comment on lines +50 to +54:
This should be handled by an outer constructor, shouldn't it? It seems a bit weird to lie about the type of the constructed distribution and get something else than
Reply: that does seem odd.
+    end
 end

 const FullNormalCanon = MvNormalCanon{Float64,PDMat{Float64,Matrix{Float64}},Vector{Float64}}
@@ -56,26 +65,17 @@ const ZeroMeanIsoNormalCanon{Axes} = MvNormalCanon{Float64,ScalMat{Float64},Zer


 ### Constructors
-function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::AbstractPDMat{T}) where {T<:Real}
-    length(μ) == length(h) == dim(J) || throw(DimensionMismatch("Inconsistent argument dimensions"))
-    if typeof(μ) === typeof(h)
-        return MvNormalCanon{T,typeof(J),typeof(μ)}(μ, h, J)
-    else
-        return MvNormalCanon{T,typeof(J),Vector{T}}(collect(μ), collect(h), J)
-    end
-end
-
-function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::AbstractPDMat) where {T<:Real}
Review comment: I guess we want to keep this function but change
+function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::P) where {T<:Real, P}
     R = promote_type(T, eltype(J))
-    MvNormalCanon(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))
+    MvNormalCanon{T,P,typeof(μ)}(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))
Review comment: For the old function here (relaxed to
 end

-function MvNormalCanon(μ::AbstractVector{<:Real}, h::AbstractVector{<:Real}, J::AbstractPDMat)
+function MvNormalCanon(μ::AbstractVector{<:Real}, h::AbstractVector{<:Real}, J::AbstractMatrix{<:Real})
     R = Base.promote_eltype(μ, h, J)
     MvNormalCanon(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))
 end

-function MvNormalCanon(h::AbstractVector{<:Real}, J::AbstractPDMat)
+function MvNormalCanon(h::AbstractVector{<:Real}, J::AbstractMatrix{<:Real})
     length(h) == dim(J) || throw(DimensionMismatch("Inconsistent argument dimensions"))
     R = Base.promote_eltype(h, J)
     hh = convert(AbstractArray{R}, h)
@@ -89,7 +89,7 @@ end
 Construct a multivariate normal distribution with potential vector `h` and precision matrix
 `J`.
 """
-MvNormalCanon(h::AbstractVector{<:Real}, J::AbstractMatrix{<:Real}) = MvNormalCanon(h, PDMat(J))
+MvNormalCanon(h::AbstractVector{<:Real}, J::Matrix{<:Real}) = MvNormalCanon(h, PDMat(J))
 MvNormalCanon(h::AbstractVector{<:Real}, J::Diagonal{<:Real}) = MvNormalCanon(h, PDiagMat(J.diag))
 MvNormalCanon(μ::AbstractVector{<:Real}, J::Union{Symmetric{<:Real,<:Diagonal{<:Real}},Hermitian{<:Real,<:Diagonal{<:Real}}}) = MvNormalCanon(μ, PDiagMat(J.data.diag))
 function MvNormalCanon(h::AbstractVector{<:Real}, J::UniformScaling{<:Real})
@@ -170,7 +170,7 @@ sqmahal!(r::AbstractVector, d::MvNormalCanon, x::AbstractMatrix) = quad!(r, d.J,

 # Sampling (for GenericMvNormal)

-unwhiten_winv!(J::AbstractPDMat, x::AbstractVecOrMat) = unwhiten!(inv(J), x)
+unwhiten_winv!(J::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(inv(J), x)
 unwhiten_winv!(J::PDiagMat, x::AbstractVecOrMat) = whiten!(J, x)
 unwhiten_winv!(J::ScalMat, x::AbstractVecOrMat) = whiten!(J, x)
 if isdefined(PDMats, :PDSparseMat)
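The review comment on lines +50 to +54 asks for the `collect` fallback to live in an outer constructor. One way to read that suggestion is sketched below. `ToyCanon` is a hypothetical stand-in, not the Distributions.jl type, and the layout is an assumption about what the reviewer has in mind: the inner constructor returns exactly the requested type, and the outer constructor decides which vector type to request.

```julia
using LinearAlgebra

# Hypothetical stand-in for `MvNormalCanon`; not the Distributions.jl type.
struct ToyCanon{T<:Real,P<:AbstractMatrix,V<:AbstractVector}
    μ::V
    h::V
    J::P

    # Inner constructor: validates sizes and returns exactly ToyCanon{T,P,V}.
    function ToyCanon{T,P,V}(μ::V, h::V, J::P) where {T<:Real,P<:AbstractMatrix,V<:AbstractVector{T}}
        length(μ) == length(h) == size(J, 1) == size(J, 2) ||
            throw(DimensionMismatch("Inconsistent argument dimensions"))
        return new{T,P,V}(μ, h, J)
    end
end

# Outer constructor: picks the vector type before calling the inner one.
function ToyCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::AbstractMatrix{T}) where {T<:Real}
    if typeof(μ) === typeof(h)
        return ToyCanon{T,typeof(J),typeof(μ)}(μ, h, J)
    else
        # Different vector types: copy both into plain `Vector{T}` storage.
        return ToyCanon{T,typeof(J),Vector{T}}(collect(μ), collect(h), J)
    end
end

μ = [1.0, 2.0]
h = view([0.5, 0.25, 0.0], 1:2)   # a SubArray, so typeof(μ) !== typeof(h)
J = Diagonal([2.0, 4.0])
d2 = ToyCanon(μ, h, J)
```

With this split, `ToyCanon{T,P,V}(...)` never returns a value whose type parameters differ from the ones written at the call site.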
Review comment: I guess here we might want to change it to
Reply: The tests pass right now. It took some effort to make sure the constructors cover the right cases and don't cause a stack overflow or an ambiguity error. I could give them another pass now that it is working, but I wouldn't want to just go and relax the ones there right now willy-nilly.
Reply: I think without this relaxation you can provoke a test error when you use something like `MvNormal(::Vector{Float32}, ::BlockDiagonal{Float64})`.
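The failure mode named in the last comment can be reproduced without any packages. The sketch below is an illustration under stated assumptions: `strict` and `relaxed` are hypothetical functions that only report the element type a constructor would use, and `SymTridiagonal` stands in for `BlockDiagonal`, which lives outside the standard library. The text of the elided suggestion is not recoverable here, so the sketch shows only why a mixed-element-type call fails when no promoting method accepts a general `AbstractMatrix`.

```julia
using LinearAlgebra

# Only a same-element-type method exists for general matrices.
strict(μ::AbstractVector{T}, Σ::AbstractMatrix{T}) where {T<:Real} = T

# Same method plus a promoting fallback for mixed element types.
relaxed(μ::AbstractVector{T}, Σ::AbstractMatrix{T}) where {T<:Real} = T
relaxed(μ::AbstractVector{<:Real}, Σ::AbstractMatrix{<:Real}) = Base.promote_eltype(μ, Σ)

μ = Float32[0, 0]
Σ = SymTridiagonal([2.0, 2.0], [0.5])   # Float64 structured matrix, stand-in for BlockDiagonal

# Without the fallback, dispatch finds no applicable method.
mixed_fails = try
    strict(μ, Σ)
    false
catch err
    err isa MethodError
end

promoted = relaxed(μ, Σ)
```

Because the same-type method is more specific than the fallback, adding the fallback in this sketch does not introduce an ambiguity between the two `relaxed` methods.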