diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 5370b64f..851cc88e 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -86,7 +86,13 @@ const StaticMatrixLike{T} = Union{ Adjoint{T, <:StaticVecOrMat{T}}, Symmetric{T, <:StaticMatrix{<:Any, <:Any, T}}, Hermitian{T, <:StaticMatrix{<:Any, <:Any, T}}, - Diagonal{T, <:StaticVector{<:Any, T}} + Diagonal{T, <:StaticVector{<:Any, T}}, + # We specifically list *Triangular here rather than using + # AbstractTriangular to avoid ambiguities in size() etc. + UpperTriangular{T, <:StaticMatrix{<:Any, <:Any, T}}, + LowerTriangular{T, <:StaticMatrix{<:Any, <:Any, T}}, + UnitUpperTriangular{T, <:StaticMatrix{<:Any, <:Any, T}}, + UnitLowerTriangular{T, <:StaticMatrix{<:Any, <:Any, T}} } const StaticVecOrMatLike{T} = Union{StaticVector{<:Any, T}, StaticMatrixLike{T}} const StaticArrayLike{T} = Union{StaticVecOrMatLike{T}, StaticArray{<:Tuple, T}} diff --git a/src/abstractarray.jl b/src/abstractarray.jl index 0f9a4f31..d45aead4 100644 --- a/src/abstractarray.jl +++ b/src/abstractarray.jl @@ -1,13 +1,12 @@ -length(a::SA) where {SA <: StaticArrayLike} = length(SA) +length(a::StaticArrayLike) = prod(Size(a)) length(a::Type{SA}) where {SA <: StaticArrayLike} = prod(Size(SA)) -@pure size(::Type{SA}) where {SA <: StaticArrayLike} = get(Size(SA)) +@pure size(::Type{SA}) where {SA <: StaticArrayLike} = Tuple(Size(SA)) @inline function size(t::Type{<:StaticArrayLike}, d::Int) S = size(t) d > length(S) ? 1 : S[d] end -@inline size(a::StaticArrayLike) = size(typeof(a)) -@inline size(a::StaticArrayLike, d::Int) = size(typeof(a), d) +@inline size(a::StaticArrayLike) = Tuple(Size(a)) Base.axes(s::StaticArray) = _axes(Size(s)) @pure function _axes(::Size{sizes}) where {sizes} diff --git a/src/lu.jl b/src/lu.jl index bf9357d3..361aae25 100644 --- a/src/lu.jl +++ b/src/lu.jl @@ -11,6 +11,24 @@ Base.iterate(S::LU, ::Val{:U}) = (S.U, Val(:p)) Base.iterate(S::LU, ::Val{:p}) = (S.p, Val(:done)) Base.iterate(S::LU, ::Val{:done}) = nothing +@inline function Base.getproperty(F::LU, s::Symbol) + if s === :P + U = getfield(F, :U) + p = getfield(F, :p) + one(similar_type(p, Size(U)))[:,invperm(p)] + else + getfield(F, s) + end +end + +function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::LU) + println(io, LU) # Don't show full type - this will be in the factors + println(io, "L factor:") + show(io, mime, F.L) + println(io, "\nU factor:") + show(io, mime, F.U) +end + # LU decomposition function lu(A::StaticMatrix, pivot::Union{Val{false},Val{true}}=Val(true)) L, U, p = _lu(A, pivot) @@ -136,9 +154,6 @@ end :(SVector{$(M-1),Int}($(tuple(2:M...)))) end -# Base.lufact() interface is fairly inherently type unstable. Punt on -# implementing that, for now... - \(F::LU, v::AbstractVector) = F.U \ (F.L \ v[F.p]) \(F::LU, B::AbstractMatrix) = F.U \ (F.L \ B[F.p,:]) diff --git a/src/traits.jl b/src/traits.jl index d358d82d..93780c86 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -93,6 +93,7 @@ Size(::Type{Transpose{T, A}}) where {T, A <: AbstractVecOrMat{T}} = Size(Size(A) Size(::Type{Symmetric{T, A}}) where {T, A <: AbstractMatrix{T}} = Size(A) Size(::Type{Hermitian{T, A}}) where {T, A <: AbstractMatrix{T}} = Size(A) Size(::Type{Diagonal{T, A}}) where {T, A <: AbstractVector{T}} = Size(Size(A)[1], Size(A)[1]) +Size(::Type{<:LinearAlgebra.AbstractTriangular{T, A}}) where {T,A} = Size(A) @pure Size(::Type{<:AbstractArray{<:Any, N}}) where {N} = Size(ntuple(_ -> Dynamic(), N)) @@ -117,7 +118,7 @@ Length(::Size{S}) where {S} = _Length(S...) @inline _Length(S...) = Length{Dynamic()}() # Some @pure convenience functions for `Size` -@pure get(::Size{S}) where {S} = S +@pure (::Type{Tuple})(::Size{S}) where {S} = S @pure getindex(::Size{S}, i::Int) where {S} = i <= length(S) ? S[i] : 1 @@ -138,7 +139,7 @@ Base.LinearIndices(::Size{S}) where {S} = LinearIndices(S) @pure size_tuple(::Size{S}) where {S} = Tuple{S...} # Some @pure convenience functions for `Length` -@pure get(::Length{L}) where {L} = L +@pure (::Type{Int})(::Length{L}) where {L} = L @pure Base.:(==)(::Length{L}, l::Int) where {L} = L == l @pure Base.:(==)(l::Int, ::Length{L}) where {L} = l == L diff --git a/src/triangular.jl b/src/triangular.jl index 283c5a1f..499ca66b 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -1,5 +1,3 @@ -@inline Size(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) = Size(A.data) - @inline transpose(A::LinearAlgebra.LowerTriangular{<:Any,<:StaticMatrix}) = LinearAlgebra.UpperTriangular(transpose(A.data)) @inline adjoint(A::LinearAlgebra.LowerTriangular{<:Any,<:StaticMatrix}) = diff --git a/test/core.jl b/test/core.jl index 5b48ff46..a0b1f80c 100644 --- a/test/core.jl +++ b/test/core.jl @@ -155,6 +155,9 @@ @test Size(Adjoint(zero(SMatrix{2, 3}))) == Size(3, 2) @test Size(Diagonal(SVector(1, 2, 3))) == Size(3, 3) @test Size(Transpose(Diagonal(SVector(1, 2, 3)))) == Size(3, 3) + @test Size(UpperTriangular(zero(SMatrix{2, 2}))) == Size(2, 2) + @test Size(LowerTriangular(zero(SMatrix{2, 2}))) == Size(2, 2) + @test Size(LowerTriangular(Symmetric(zero(SMatrix{2, 2})))) == Size(2,2) end @testset "dimmatch" begin diff --git a/test/lu.jl b/test/lu.jl index 8ff15902..e96bd043 100644 --- a/test/lu.jl +++ b/test/lu.jl @@ -1,5 +1,14 @@ using StaticArrays, Test, LinearAlgebra +@testset "LU utils" begin + F = lu(SA[1 2; 3 4]) + + @test @inferred((F->F.p)(F)) === SA[2, 1] + @test @inferred((F->F.P)(F)) === SA[0 1; 1 0] + + @test occursin(r"^StaticArrays.LU.*L factor.*U factor"s, sprint(show, MIME("text/plain"), F)) +end + @testset "LU decomposition ($m×$n, pivot=$pivot)" for pivot in (true, false), m in [0:4..., 15], n in [0:4..., 15] a = SMatrix{m,n,Int}(1:(m*n)) l, u, p = @inferred(lu(a, Val{pivot}()))