diff --git a/src/generics.jl b/src/generics.jl index 686d190..78e95d2 100644 --- a/src/generics.jl +++ b/src/generics.jl @@ -33,19 +33,6 @@ LinearAlgebra.checksquare(a::AbstractPDMat) = size(a, 1) ## whiten and unwhiten -whiten!(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(x, a, x) -unwhiten!(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(x, a, x) - -function whiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat) - v = _rcopy!(r, x) - ldiv!(chol_lower(cholesky(a)), v) -end - -function unwhiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat) - v = _rcopy!(r, x) - lmul!(chol_lower(cholesky(a)), v) -end - """ whiten(a::AbstractMatrix, x::AbstractVecOrMat) unwhiten(a::AbstractMatrix, x::AbstractVecOrMat) @@ -80,35 +67,41 @@ julia> W * W' 0.0 1.0 ``` """ -whiten(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(similar(x), a, x) -unwhiten(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(similar(x), a, x) +whiten(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = whiten(AbstractPDMat(a), x) +unwhiten(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = unwhiten(AbstractPDMat(a), x) +whiten!(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = whiten!(x, a, x) +unwhiten!(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = unwhiten!(x, a, x) + +function whiten!(r::AbstractVecOrMat, a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) + return whiten!(r, AbstractPDMat(a), x) +end +function unwhiten!(r::AbstractVecOrMat, a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) + return unwhiten!(r, AbstractPDMat(a), x) +end ## quad """ quad(a::AbstractMatrix, x::AbstractVecOrMat) -Return the value of the quadratic form defined by `a` applied to `x` +Return the value of the quadratic form defined by `a` applied to `x`. If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix the quadratic form is applied column-wise. """ -function quad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real} - @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) - quad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x) +function quad(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) + return quad(AbstractPDMat(a), x) end -quad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_upper(cholesky(a)) * x) -invquad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_lower(cholesky(a)) \ x) - """ quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) -Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x` +Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x`. """ -quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a * x) - +function quad!(r::AbstractArray, a::AbstractMatrix{<:Real}, x::AbstractMatrix) + return quad!(r, AbstractPDMat(a), x) +end """ invquad(a::AbstractMatrix, x::AbstractVecOrMat) @@ -120,10 +113,8 @@ For most `PDMat` types this is done in a way that does not require evaluation of If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix the quadratic form is applied column-wise. """ -invquad(a::AbstractMatrix, x::AbstractVecOrMat) = x' / a * x -function invquad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real} - @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) - invquad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x) +function invquad(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) + return invquad(AbstractPDMat(a), x) end """ @@ -131,4 +122,7 @@ end Overwrite `r` with the value of the quadratic form defined by `inv(a)` applied columnwise to `x` """ -invquad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a \ x) +function invquad!(r::AbstractArray, a::AbstractMatrix{<:Real}, x::AbstractMatrix) + return invquad!(r, AbstractPDMat(a), x) +end + diff --git a/src/pdiagmat.jl b/src/pdiagmat.jl index 752a2b6..de3dfa4 100644 --- a/src/pdiagmat.jl +++ b/src/pdiagmat.jl @@ -91,45 +91,38 @@ LinearAlgebra.sqrt(a::PDiagMat) = PDiagMat(map(sqrt, a.diag)) ### whiten and unwhiten -function whiten!(r::StridedVector, a::PDiagMat, x::StridedVector) - n = a.dim - @check_argdims length(r) == length(x) == n - v = a.diag - for i = 1:n - r[i] = x[i] / sqrt(v[i]) - end - return r +function whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) + return r .= x ./ sqrt.(a.diag) end - -function unwhiten!(r::StridedVector, a::PDiagMat, x::StridedVector) - n = a.dim - @check_argdims length(r) == length(x) == n - v = a.diag - for i = 1:n - r[i] = x[i] * sqrt(v[i]) - end - return r +function unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) + return r .= x .* sqrt.(a.diag) end -function whiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix) - r .= x ./ sqrt.(a.diag) - return r +function whiten(a::PDiagMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return x ./ sqrt.(a.diag) end - -function unwhiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix) - r .= x .* sqrt.(a.diag) - return r +function unwhiten(a::PDiagMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return x .* sqrt.(a.diag) end - -whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x ./ sqrt.(a.diag) -unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x .* sqrt.(a.diag) - - ### quadratic forms -quad(a::PDiagMat, x::AbstractVector) = wsumsq(a.diag, x) -invquad(a::PDiagMat, x::AbstractVector) = invwsumsq(a.diag, x) +function quad(a::PDiagMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + if x isa AbstractVector + return wsumsq(a.diag, x) + else + # map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives + # do NOT return a `SVector` for inputs `x::SMatrix`. + return vec(sum(abs2.(x) .* a.diag; dims = 1)) + end +end function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix) ad = a.diag @@ -145,8 +138,18 @@ function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix) r end +function invquad(a::PDiagMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + if x isa AbstractVector + return invwsumsq(a.diag, x) + else + # map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives + # do NOT return a `SVector` for inputs `x::SMatrix`. + return vec(sum(abs2.(x) ./ a.diag; dims = 1)) + end +end + function invquad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix) - m, n = size(x) ad = a.diag @check_argdims eachindex(ad) == axes(x, 1) @check_argdims eachindex(r) == axes(x, 2) @@ -186,3 +189,18 @@ function Xt_invA_X(a::PDiagMat, x::AbstractMatrix) z = x ./ sqrt.(a.diag) transpose(z) * z end + +### Specializations for `Array` arguments with reduced allocations + +function quad(a::PDiagMat{<:Real,<:Vector}, x::Matrix) + @check_argdims a.dim == size(x, 1) + T = typeof(zero(eltype(a)) * abs2(zero(eltype(x)))) + return quad!(Vector{T}(undef, size(x, 2)), a, x) +end + +function invquad(a::PDiagMat{<:Real,<:Vector}, x::Matrix) + @check_argdims a.dim == size(x, 1) + T = typeof(abs2(zero(eltype(x))) / zero(eltype(a))) + return invquad!(Vector{T}(undef, size(x, 2)), a, x) +end + diff --git a/src/pdmat.jl b/src/pdmat.jl index 4149df9..05c6d34 100644 --- a/src/pdmat.jl +++ b/src/pdmat.jl @@ -86,6 +86,78 @@ LinearAlgebra.eigmin(a::PDMat) = eigmin(a.mat) Base.kron(A::PDMat, B::PDMat) = PDMat(kron(A.mat, B.mat), Cholesky(kron(A.chol.U, B.chol.U), 'U', A.chol.info)) LinearAlgebra.sqrt(A::PDMat) = PDMat(sqrt(Hermitian(A.mat))) +### (un)whitening + +function whiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) + v = _rcopy!(r, x) + return ldiv!(chol_lower(cholesky(a)), v) +end +function unwhiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) + v = _rcopy!(r, x) + return lmul!(chol_lower(cholesky(a)), v) +end + +function whiten(a::PDMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return chol_lower(cholesky(a)) \ x +end +function unwhiten(a::PDMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return chol_lower(cholesky(a)) * x +end + +## quad/invquad + +function quad(a::PDMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + aU_x = chol_upper(cholesky(a)) * x + if x isa AbstractVector + return sum(abs2, aU_x) + else + return vec(sum(abs2, aU_x; dims = 1)) + end +end + +function quad!(r::AbstractArray, a::PDMat, x::AbstractMatrix) + @check_argdims axes(r) == axes(x, 2) + @check_argdims a.dim == size(x, 1) + aU = chol_upper(cholesky(a)) + z = similar(r, a.dim) # buffer to save allocations + @inbounds for i in axes(x, 2) + copyto!(z, view(x, :, i)) + lmul!(aU, z) + r[i] = sum(abs2, z) + end + return r +end + +function invquad(a::PDMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + inv_aL_x = chol_lower(cholesky(a)) \ x + if x isa AbstractVector + return sum(abs2, inv_aL_x) + else + return vec(sum(abs2, inv_aL_x; dims = 1)) + end +end + +function invquad!(r::AbstractArray, a::PDMat, x::AbstractMatrix) + @check_argdims axes(r) == axes(x, 2) + @check_argdims a.dim == size(x, 1) + aL = chol_lower(cholesky(a)) + z = similar(r, a.dim) # buffer to save allocations + @inbounds for i in axes(x, 2) + copyto!(z, view(x, :, i)) + ldiv!(aL, z) + r[i] = sum(abs2, z) + end + return r +end + ### tri products function X_A_Xt(a::PDMat, x::AbstractMatrix) @@ -111,3 +183,18 @@ function Xt_invA_X(a::PDMat, x::AbstractMatrix) z = chol_lower(a.chol) \ x return transpose(z) * z end + +### Specializations for `Array` arguments with reduced allocations + +function quad(a::PDMat{<:Real,<:Vector}, x::Matrix) + @check_argdims a.dim == size(x, 1) + T = typeof(zero(eltype(a)) * abs2(zero(eltype(x)))) + return quad!(Vector{T}(undef, size(x, 2)), a, x) +end + +function invquad(a::PDMat{<:Real,<:Vector}, x::Matrix) + @check_argdims a.dim == size(x, 1) + T = typeof(abs2(zero(eltype(x))) / zero(eltype(a))) + return invquad!(Vector{T}(undef, size(x, 2)), a, x) +end + diff --git a/src/pdsparsemat.jl b/src/pdsparsemat.jl index baedad6..e890dd4 100644 --- a/src/pdsparsemat.jl +++ b/src/pdsparsemat.jl @@ -78,37 +78,84 @@ LinearAlgebra.sqrt(A::PDSparseMat) = PDMat(sqrt(Hermitian(Matrix(A)))) ### whiten and unwhiten function whiten!(r::AbstractVecOrMat, a::PDSparseMat, x::AbstractVecOrMat) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) # Can't use `ldiv!` due to missing support in SparseArrays return copyto!(r, chol_lower(a.chol) \ x) end function unwhiten!(r::AbstractVecOrMat, a::PDSparseMat, x::AbstractVecOrMat) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) # `*` is not defined for `PtL` factor components, # so we can't use `chol_lower(a.chol) * x` C = a.chol PtL = sparse(C.L)[C.p, :] - # Can't use `lmul!` due to missing support in SparseArrays return copyto!(r, PtL * x) end +function whiten(a::PDSparseMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return chol_lower(cholesky(a)) \ x +end + +function unwhiten(a::PDSparseMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + # `*` is not defined for `PtL` factor components, + # so we can't use `chol_lower(a.chol) * x` + C = a.chol + PtL = sparse(C.L)[C.p, :] + return PtL * x +end ### quadratic forms -quad(a::PDSparseMat, x::AbstractVector) = dot(x, a * x) -invquad(a::PDSparseMat, x::AbstractVector) = dot(x, a \ x) +function quad(a::PDSparseMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + # https://github.com/JuliaLang/julia/commit/2425ae760fb5151c5c7dd0554e87c5fc9e24de73 + if VERSION < v"1.4.0-DEV.92" + z = a.mat * x + return x isa AbstractVector ? dot(x, z) : map(dot, eachcol(x), eachcol(z)) + else + return x isa AbstractVector ? dot(x, a.mat, x) : map(Base.Fix1(quad, a), eachcol(x)) + end +end function quad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix) - @check_argdims eachindex(r) == axes(x, 2) - for i in axes(x, 2) - r[i] = quad(a, x[:,i]) + @check_argdims axes(r) == axes(x, 2) + # https://github.com/JuliaLang/julia/commit/2425ae760fb5151c5c7dd0554e87c5fc9e24de73 + if VERSION < v"1.4.0-DEV.92" + z = similar(r, a.dim) # buffer to save allocations + @inbounds for i in axes(x, 2) + xi = view(x, :, i) + copyto!(z, xi) + lmul!(a.mat, z) + r[i] = dot(xi, z) + end + else + @inbounds for i in axes(x, 2) + xi = view(x, :, i) + r[i] = dot(xi, a.mat, xi) + end end return r end +function invquad(a::PDSparseMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + z = a.chol \ x + return x isa AbstractVector ? dot(x, z) : map(dot, eachcol(x), eachcol(z)) +end + function invquad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix) - @check_argdims eachindex(r) == axes(x, 2) - for i in axes(x, 2) - r[i] = invquad(a, x[:,i]) + @check_argdims axes(r) == axes(x, 2) + @check_argdims a.dim == size(x, 1) + z = similar(r, a.dim) # buffer to save allocations + @inbounds for i in axes(x, 2) + xi = view(x, :, i) + copyto!(z, xi) + ldiv!(a.chol, z) + r[i] = dot(xi, z) end return r end diff --git a/src/scalmat.jl b/src/scalmat.jl index 58ee3e1..3db91f4 100644 --- a/src/scalmat.jl +++ b/src/scalmat.jl @@ -76,23 +76,67 @@ LinearAlgebra.sqrt(a::ScalMat) = ScalMat(a.dim, sqrt(a.value)) ### whiten and unwhiten function whiten!(r::AbstractVecOrMat, a::ScalMat, x::AbstractVecOrMat) - @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) _ldiv!(r, sqrt(a.value), x) end function unwhiten!(r::AbstractVecOrMat, a::ScalMat, x::AbstractVecOrMat) - @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) mul!(r, x, sqrt(a.value)) end +function whiten(a::ScalMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return x / sqrt(a.value) +end +function unwhiten(a::ScalMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return sqrt(a.value) * x +end ### quadratic forms -quad(a::ScalMat, x::AbstractVector) = sum(abs2, x) * a.value -invquad(a::ScalMat, x::AbstractVector) = sum(abs2, x) / a.value +function quad(a::ScalMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + if x isa AbstractVector + return sum(abs2, x) * a.value + else + # map(Base.Fix1(quad, a), eachcol(x)) or similar alternatives + # do NOT return a `SVector` for inputs `x::SMatrix`. + wsq = let w = a.value + x -> w * abs2(x) + end + return vec(sum(wsq, x; dims=1)) + end +end + +function quad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) + @check_argdims eachindex(r) == axes(x, 2) + @check_argdims a.dim == size(x, 1) + return map!(Base.Fix1(quad, a), r, eachcol(x)) +end -quad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) = colwise_sumsq!(r, x, a.value) -invquad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) = colwise_sumsqinv!(r, x, a.value) +function invquad(a::ScalMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + if x isa AbstractVector + return sum(abs2, x) / a.value + else + # map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives + # do NOT return a `SVector` for inputs `x::SMatrix`. + wsq = let w = a.value + x -> abs2(x) / w + end + return vec(sum(wsq, x; dims=1)) + end +end + +function invquad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) + @check_argdims eachindex(r) == axes(x, 2) + @check_argdims a.dim == size(x, 1) + return map!(Base.Fix1(invquad, a), r, eachcol(x)) +end ### tri products diff --git a/src/utils.jl b/src/utils.jl index c85d938..398bbcc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -124,3 +124,8 @@ end else _ldiv!(Y::AbstractArray, s::Number, X::AbstractArray) = ldiv!(Y, s, X) end + +# https://github.com/JuliaLang/julia/pull/29749 +if VERSION < v"1.1.0-DEV.792" + eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2)) +end diff --git a/test/specialarrays.jl b/test/specialarrays.jl index b812a80..3b6417e 100644 --- a/test/specialarrays.jl +++ b/test/specialarrays.jl @@ -43,6 +43,30 @@ using StaticArrays @test A \ Y isa SMatrix{4, 10, Float64} @test A \ Y ≈ Matrix(A) \ Matrix(Y) + @test whiten(A, x) isa SVector{4, Float64} + @test whiten(A, x) ≈ cholesky(Matrix(A)).L \ Vector(x) + + @test whiten(A, Y) isa SMatrix{4, 10, Float64} + @test whiten(A, Y) ≈ cholesky(Matrix(A)).L \ Matrix(Y) + + @test unwhiten(A, x) isa SVector{4, Float64} + @test unwhiten(A, x) ≈ cholesky(Matrix(A)).L * Vector(x) + + @test unwhiten(A, Y) isa SMatrix{4, 10, Float64} + @test unwhiten(A, Y) ≈ cholesky(Matrix(A)).L * Matrix(Y) + + @test quad(A, x) isa Float64 + @test quad(A, x) ≈ Vector(x)' * Matrix(A) * Vector(x) + + @test quad(A, Y) isa SVector{10, Float64} + @test quad(A, Y) ≈ diag(Matrix(Y)' * Matrix(A) * Matrix(Y)) + + @test invquad(A, x) isa Float64 + @test invquad(A, x) ≈ Vector(x)' * (Matrix(A) \ Vector(x)) + + @test invquad(A, Y) isa SVector{10, Float64} + @test invquad(A, Y) ≈ diag(Matrix(Y)' * (Matrix(A) \ Matrix(Y))) + @test X_A_Xt(A, X) isa SMatrix{10, 10, Float64} @test X_A_Xt(A, X) ≈ Matrix(X) * Matrix(A) * Matrix(X)'