-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support (un)whiten
and (inv)quad
with static arrays
#183
Changes from all commits
af62114
f9b4a37
5b376f2
0aca2f3
964e70c
61f801f
3f17bde
3b050b9
e21b0b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -91,45 +91,38 @@ LinearAlgebra.sqrt(a::PDiagMat) = PDiagMat(map(sqrt, a.diag)) | |
|
||
### whiten and unwhiten | ||
|
||
function whiten!(r::StridedVector, a::PDiagMat, x::StridedVector) | ||
n = a.dim | ||
@check_argdims length(r) == length(x) == n | ||
v = a.diag | ||
for i = 1:n | ||
r[i] = x[i] / sqrt(v[i]) | ||
end | ||
return r | ||
function whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) | ||
@check_argdims axes(r) == axes(x) | ||
@check_argdims a.dim == size(x, 1) | ||
return r .= x ./ sqrt.(a.diag) | ||
end | ||
|
||
function unwhiten!(r::StridedVector, a::PDiagMat, x::StridedVector) | ||
n = a.dim | ||
@check_argdims length(r) == length(x) == n | ||
v = a.diag | ||
for i = 1:n | ||
r[i] = x[i] * sqrt(v[i]) | ||
end | ||
return r | ||
function unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) | ||
@check_argdims axes(r) == axes(x) | ||
@check_argdims a.dim == size(x, 1) | ||
return r .= x .* sqrt.(a.diag) | ||
end | ||
|
||
function whiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix) | ||
r .= x ./ sqrt.(a.diag) | ||
return r | ||
function whiten(a::PDiagMat, x::AbstractVecOrMat) | ||
@check_argdims a.dim == size(x, 1) | ||
return x ./ sqrt.(a.diag) | ||
end | ||
|
||
function unwhiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix) | ||
r .= x .* sqrt.(a.diag) | ||
return r | ||
function unwhiten(a::PDiagMat, x::AbstractVecOrMat) | ||
@check_argdims a.dim == size(x, 1) | ||
return x .* sqrt.(a.diag) | ||
end | ||
|
||
|
||
whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x ./ sqrt.(a.diag) | ||
unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x .* sqrt.(a.diag) | ||
|
||
|
||
### quadratic forms | ||
|
||
quad(a::PDiagMat, x::AbstractVector) = wsumsq(a.diag, x) | ||
invquad(a::PDiagMat, x::AbstractVector) = invwsumsq(a.diag, x) | ||
function quad(a::PDiagMat, x::AbstractVecOrMat) | ||
@check_argdims a.dim == size(x, 1) | ||
if x isa AbstractVector | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if it is helpful/desirable here, but generally I tried to reduce the number of methods to lower the probability for method ambiguity errors. Maybe it's mostly useful for the generic code path in src/generics.jl. |
||
return wsumsq(a.diag, x) | ||
else | ||
# map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives | ||
# do NOT return a `SVector` for inputs `x::SMatrix`. | ||
return vec(sum(abs2.(x) .* a.diag; dims = 1)) | ||
Comment on lines
+121
to
+123
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bit unsatisfying - is there any way we could avoid unnecessary allocations but still make StaticArray return the expected types? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a version with reduced allocations specialized for |
||
end | ||
end | ||
|
||
function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix) | ||
ad = a.diag | ||
|
@@ -145,8 +138,18 @@ function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix) | |
r | ||
end | ||
|
||
function invquad(a::PDiagMat, x::AbstractVecOrMat) | ||
@check_argdims a.dim == size(x, 1) | ||
if x isa AbstractVector | ||
return invwsumsq(a.diag, x) | ||
else | ||
# map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives | ||
# do NOT return a `SVector` for inputs `x::SMatrix`. | ||
return vec(sum(abs2.(x) ./ a.diag; dims = 1)) | ||
end | ||
end | ||
|
||
function invquad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix) | ||
m, n = size(x) | ||
ad = a.diag | ||
@check_argdims eachindex(ad) == axes(x, 1) | ||
@check_argdims eachindex(r) == axes(x, 2) | ||
|
@@ -186,3 +189,18 @@ function Xt_invA_X(a::PDiagMat, x::AbstractMatrix) | |
z = x ./ sqrt.(a.diag) | ||
transpose(z) * z | ||
end | ||
|
||
### Specializations for `Array` arguments with reduced allocations | ||
|
||
function quad(a::PDiagMat{<:Real,<:Vector}, x::Matrix) | ||
@check_argdims a.dim == size(x, 1) | ||
T = typeof(zero(eltype(a)) * abs2(zero(eltype(x)))) | ||
return quad!(Vector{T}(undef, size(x, 2)), a, x) | ||
end | ||
|
||
function invquad(a::PDiagMat{<:Real,<:Vector}, x::Matrix) | ||
@check_argdims a.dim == size(x, 1) | ||
T = typeof(abs2(zero(eltype(x))) / zero(eltype(a))) | ||
return invquad!(Vector{T}(undef, size(x, 2)), a, x) | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why the restriction on the type parameter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently,
AbstractPDMat
s areReal
matrices:PDMats.jl/src/PDMats.jl
Line 36 in 9572e79