Skip to content

Commit

Permalink
fix JuliaLang#28581 and specialize istriu/istril for structured matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
ranocha committed Mar 21, 2020
1 parent 6447534 commit 0908602
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 50 deletions.
40 changes: 38 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,44 @@ end

iszero(M::Bidiagonal) = iszero(M.dv) && iszero(M.ev)
isone(M::Bidiagonal) = all(isone, M.dv) && iszero(M.ev)
istriu(M::Bidiagonal) = M.uplo == 'U' || iszero(M.ev)
istril(M::Bidiagonal) = M.uplo == 'L' || iszero(M.ev)
function istriu(M::Bidiagonal, k::Integer=0)
if M.uplo == 'U'
if k <= 0
return true
elseif k == 1
return iszero(M.dv)
else # k >= 2
return iszero(M.dv) && iszero(M.ev)
end
else # M.uplo == 'L'
if k <= -1
return true
elseif k == 0
return iszero(M.ev)
else # k >= 1
return iszero(M.ev) && iszero(M.dv)
end
end
end
function istril(M::Bidiagonal, k::Integer=0)
if M.uplo == 'U'
if k >= 1
return true
elseif k == 0
return iszero(M.ev)
else # k <= -1
return iszero(M.ev) && iszero(M.dv)
end
else # M.uplo == 'L'
if k >= 0
return true
elseif k == -1
return iszero(M.dv)
else # k <= -2
return iszero(M.dv) && iszero(M.ev)
end
end
end
isdiag(M::Bidiagonal) = iszero(M.ev)

function tril!(M::Bidiagonal, k::Integer=0)
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ iszero(D::Diagonal) = all(iszero, D.diag)
isone(D::Diagonal) = all(isone, D.diag)
isdiag(D::Diagonal) = all(isdiag, D.diag)
isdiag(D::Diagonal{<:Number}) = true
istriu(D::Diagonal) = true
istril(D::Diagonal) = true
istriu(D::Diagonal, k::Integer=0) = k <= 0 || iszero(D.diag) ? true : false
istril(D::Diagonal, k::Integer=0) = k >= 0 || iszero(D.diag) ? true : false
function triu!(D::Diagonal,k::Integer=0)
n = size(D,1)
if !(-n + 1 <= k <= n + 1)
Expand Down
43 changes: 13 additions & 30 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ Tridiagonal(A::Diagonal) = Tridiagonal(fill!(similar(A.diag, length(A.diag)-1),
fill!(similar(A.diag, length(A.diag)-1), 0))

# conversions from Bidiagonal to other special matrix types
Diagonal(A::Bidiagonal) =
iszero(A.ev) ? Diagonal(A.dv) :
throw(ArgumentError("matrix cannot be represented as Diagonal"))
Diagonal(A::Bidiagonal) = Diagonal(A.dv)
SymTridiagonal(A::Bidiagonal) =
iszero(A.ev) ? SymTridiagonal(A.dv, A.ev) :
throw(ArgumentError("matrix cannot be represented as SymTridiagonal"))
Expand All @@ -23,58 +21,43 @@ Tridiagonal(A::Bidiagonal) =
A.uplo == 'U' ? A.ev : fill!(similar(A.ev), 0))

# conversions from SymTridiagonal to other special matrix types
Diagonal(A::SymTridiagonal) =
iszero(A.ev) ? Diagonal(A.dv) :
throw(ArgumentError("matrix cannot be represented as Diagonal"))
Diagonal(A::SymTridiagonal) = Diagonal(A.dv)
Bidiagonal(A::SymTridiagonal) =
iszero(A.ev) ? Bidiagonal(A.dv, A.ev, :U) :
throw(ArgumentError("matrix cannot be represented as Bidiagonal"))
Tridiagonal(A::SymTridiagonal) =
Tridiagonal(copy(A.ev), A.dv, A.ev)

# conversions from Tridiagonal to other special matrix types
Diagonal(A::Tridiagonal) =
iszero(A.dl) && iszero(A.du) ? Diagonal(A.d) :
throw(ArgumentError("matrix cannot be represented as Diagonal"))
Diagonal(A::Tridiagonal) = Diagonal(A.d)
Bidiagonal(A::Tridiagonal) =
iszero(A.dl) ? Bidiagonal(A.d, A.du, :U) :
iszero(A.du) ? Bidiagonal(A.d, A.dl, :L) :
throw(ArgumentError("matrix cannot be represented as Bidiagonal"))
SymTridiagonal(A::Tridiagonal) =
A.dl == A.du ? SymTridiagonal(A.d, A.dl) :
throw(ArgumentError("matrix cannot be represented as SymTridiagonal"))

# conversions from AbstractTriangular to special matrix types
Diagonal(A::AbstractTriangular) =
isdiag(A) ? Diagonal(diag(A)) :
throw(ArgumentError("matrix cannot be represented as Diagonal"))
Bidiagonal(A::AbstractTriangular) =
isbanded(A, 0, 1) ? Bidiagonal(diag(A, 0), diag(A, 1), :U) : # is upper bidiagonal
isbanded(A, -1, 0) ? Bidiagonal(diag(A, 0), diag(A, -1), :L) : # is lower bidiagonal
throw(ArgumentError("matrix cannot be represented as Bidiagonal"))
SymTridiagonal(A::AbstractTriangular) = SymTridiagonal(Tridiagonal(A))
Tridiagonal(A::AbstractTriangular) =
isbanded(A, -1, 1) ? Tridiagonal(diag(A, -1), diag(A, 0), diag(A, 1)) : # is tridiagonal
throw(ArgumentError("matrix cannot be represented as Tridiagonal"))
UpperTriangular(A::Bidiagonal) =
A.uplo == 'U' ? UpperTriangular{eltype(A), typeof(A)}(A) :
throw(ArgumentError("matrix cannot be represented as UpperTriangular"))
LowerTriangular(A::Bidiagonal) =
A.uplo == 'L' ? LowerTriangular{eltype(A), typeof(A)}(A) :
throw(ArgumentError("matrix cannot be represented as LowerTriangular"))

const ConvertibleSpecialMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,AbstractTriangular}
const PossibleTriangularMatrix = Union{Diagonal, Bidiagonal, AbstractTriangular}

convert(T::Type{<:Diagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m)
convert(T::Type{<:SymTridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m)
convert(T::Type{<:Tridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m)
convert(T::Type{<:Diagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m :
isdiag(m) ? T(m) : throw(ArgumentError("matrix cannot be represented as Diagonal"))
convert(T::Type{<:SymTridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m :
issymmetric(m) && isbanded(m, -1, 1) ? T(m) : throw(ArgumentError("matrix cannot be represented as SymTridiagonal"))
convert(T::Type{<:Tridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m :
isbanded(m, -1, 1) ? T(m) : throw(ArgumentError("matrix cannot be represented as Tridiagonal"))

convert(T::Type{<:LowerTriangular}, m::Union{LowerTriangular,UnitLowerTriangular}) = m isa T ? m : T(m)
convert(T::Type{<:UpperTriangular}, m::Union{UpperTriangular,UnitUpperTriangular}) = m isa T ? m : T(m)

convert(T::Type{<:LowerTriangular}, m::PossibleTriangularMatrix) = m isa T ? m : T(m)
convert(T::Type{<:UpperTriangular}, m::PossibleTriangularMatrix) = m isa T ? m : T(m)
convert(T::Type{<:LowerTriangular}, m::PossibleTriangularMatrix) = m isa T ? m :
istril(m) ? T(m) : throw(ArgumentError("matrix cannot be represented as LowerTriangular"))
convert(T::Type{<:UpperTriangular}, m::PossibleTriangularMatrix) = m isa T ? m :
istriu(m) ? T(m) : throw(ArgumentError("matrix cannot be represented as UpperTriangular"))

# Constructs two method definitions taking into account (assumed) commutativity
# e.g. @commutative f(x::S, y::T) where {S,T} = x+y is the same is defining
Expand Down
29 changes: 20 additions & 9 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ similar(A::Union{Adjoint{Ti,Tv}, Transpose{Ti,Tv}}, ::Type{T}) where {T,Ti,Tv<:U
UnitLowerTriangular(similar(parent(parent(A)), T))


LowerTriangular(U::UpperTriangular) = throw(ArgumentError(
"cannot create a LowerTriangular matrix from an UpperTriangular input"))
UpperTriangular(U::LowerTriangular) = throw(ArgumentError(
"cannot create an UpperTriangular matrix from a LowerTriangular input"))

"""
LowerTriangular(A::AbstractMatrix)
Expand Down Expand Up @@ -288,10 +283,26 @@ function Base.replace_in_print_matrix(A::Union{LowerTriangular,UnitLowerTriangul
return i >= j ? s : Base.replace_with_centered_mark(s)
end

istril(A::LowerTriangular) = true
istril(A::UnitLowerTriangular) = true
istriu(A::UpperTriangular) = true
istriu(A::UnitUpperTriangular) = true
function istril(A::Union{LowerTriangular,UnitLowerTriangular}, k::Integer=0)
k >= 0 && return true
m, n = size(A)
for j in max(1, k + 2):n
for i in 1:min(j - k - 1, m)
iszero(A[i, j]) || return false
end
end
return true
end
function istriu(A::Union{UpperTriangular,UnitUpperTriangular}, k::Integer=0)
k <= 0 && return true
m, n = size(A)
for j in 1:min(n, m + k - 1)
for i in max(1, j - k + 1):m
iszero(A[i, j]) || return false
end
end
return true
end
istril(A::Adjoint) = istriu(A.parent)
istril(A::Transpose) = istriu(A.parent)
istriu(A::Adjoint) = istril(A.parent)
Expand Down
36 changes: 32 additions & 4 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,16 @@ end

#tril and triu

istriu(M::SymTridiagonal) = iszero(M.ev)
istril(M::SymTridiagonal) = iszero(M.ev)
function istriu(M::SymTridiagonal, k::Integer=0)
if k <= -1
return true
elseif k == 0
return iszero(M.ev)
else # k >= 1
return iszero(M.ev) && iszero(M.dv)
end
end
istril(M::SymTridiagonal, k::Integer) = istriu(M, k)
iszero(M::SymTridiagonal) = iszero(M.ev) && iszero(M.dv)
isone(M::SymTridiagonal) = iszero(M.ev) && all(isone, M.dv)
isdiag(M::SymTridiagonal) = iszero(M.ev)
Expand Down Expand Up @@ -654,8 +662,28 @@ end

iszero(M::Tridiagonal) = iszero(M.dl) && iszero(M.d) && iszero(M.du)
isone(M::Tridiagonal) = iszero(M.dl) && all(isone, M.d) && iszero(M.du)
istriu(M::Tridiagonal) = iszero(M.dl)
istril(M::Tridiagonal) = iszero(M.du)
function istriu(M::Tridiagonal, k::Integer=0)
if k <= -1
return true
elseif k == 0
return iszero(M.dl)
elseif k == 1
return iszero(M.dl) && iszero(M.d)
else # k >= 2
return iszero(M.dl) && iszero(M.d) && iszero(M.du)
end
end
function istril(M::Tridiagonal, k::Integer=0)
if k >= 1
return true
elseif k == 0
return iszero(M.du)
elseif k == -1
return iszero(M.du) && iszero(M.d)
else # k <= -2
return iszero(M.du) && iszero(M.d) && iszero(M.dl)
end
end
isdiag(M::Tridiagonal) = iszero(M.dl) && iszero(M.du)

function tril!(M::Tridiagonal, k::Integer=0)
Expand Down
18 changes: 18 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,16 @@ Random.seed!(1)
bidiagcopy(dv, ev, uplo) = Bidiagonal(copy(dv), copy(ev), uplo)

@test istril(Bidiagonal(dv,ev,:L))
@test istril(Bidiagonal(dv,ev,:L), -1)
@test !istril(Bidiagonal(dv,ev,:L), 1)
@test istril(Bidiagonal(zerosdv,ev,:L), 1)
@test !istril(Bidiagonal(zerosdv,ev,:L), 2)
@test istril(Bidiagonal(zerosdv,zerosev,:L), 2)
@test !istril(Bidiagonal(dv,ev,:U))
@test istril(Bidiagonal(dv,ev,:U), -1)
@test !istril(Bidiagonal(dv,ev,:U), 1)
@test !istril(Bidiagonal(zerosdv,ev,:U), 1)
@test istril(Bidiagonal(zerosdv,zerosev,:U), 1)
@test tril!(bidiagcopy(dv,ev,:U),-1) == Bidiagonal(zerosdv,zerosev,:U)
@test tril!(bidiagcopy(dv,ev,:L),-1) == Bidiagonal(zerosdv,ev,:L)
@test tril!(bidiagcopy(dv,ev,:U),-2) == Bidiagonal(zerosdv,zerosev,:U)
Expand All @@ -145,7 +154,16 @@ Random.seed!(1)
@test_throws ArgumentError tril!(bidiagcopy(dv, ev, :U), n)

@test istriu(Bidiagonal(dv,ev,:U))
@test istriu(Bidiagonal(dv,ev,:U), -1)
@test !istriu(Bidiagonal(dv,ev,:U), 1)
@test istriu(Bidiagonal(zerosdv,ev,:U), 1)
@test !istriu(Bidiagonal(zerosdv,ev,:U), 2)
@test !istriu(Bidiagonal(zerosdv,zerosev,:U), 2)
@test !istriu(Bidiagonal(dv,ev,:L))
@test istriu(Bidiagonal(dv,ev,:L), -1)
@test !istriu(Bidiagonal(dv,ev,:L), 1)
@test !istriu(Bidiagonal(zerosdv,ev,:L), 1)
@test istriu(Bidiagonal(zerosdv,zerosev,:L), 1)
@test triu!(bidiagcopy(dv,ev,:L),1) == Bidiagonal(zerosdv,zerosev,:L)
@test triu!(bidiagcopy(dv,ev,:U),1) == Bidiagonal(zerosdv,ev,:U)
@test triu!(bidiagcopy(dv,ev,:U),2) == Bidiagonal(zerosdv,zerosev,:U)
Expand Down
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ Random.seed!(1)
@test isdiag(Diagonal([[1 0; 0 1], [1 0; 0 1]]))
@test !isdiag(Diagonal([[1 0; 0 1], [1 0; 1 1]]))
@test istriu(D)
@test istriu(D, -1)
@test !istriu(D, 1)
@test istriu(Diagonal(zero(diag(D))), 1)
@test istril(D)
@test istril(D, -1)
@test !istril(D, 1)
@test istril(Diagonal(zero(diag(D))), 1)
if elty <: Real
@test ishermitian(D)
end
Expand Down
27 changes: 27 additions & 0 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,33 @@ Random.seed!(1)
for newtype in [Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal]
@test_throws ArgumentError convert(newtype,A)
end


# test operations/constructors (not conversions) permitted in the docs
dl = [1., 1.]
d = [-2., -2., -2.]
T = Tridiagonal(dl, d, -dl)
S = SymTridiagonal(d, dl)
Bu = Bidiagonal(d, dl, :U)
Bl = Bidiagonal(d, dl, :L)
D = Diagonal(d)
M = [-2. 0. 0.; 1. -2. 0.; -1. 1. -2.]
U = UpperTriangular(M)
L = LowerTriangular(Matrix(M'))

for A in (T, S, Bu, Bl, D, U, L, M)
Adense = Matrix(A)
B = Symmetric(A)
Bdense = Matrix(B)
for (C,Cdense) in ((A,Adense), (B,Bdense))
@test Diagonal(C) == Diagonal(Cdense)
@test Bidiagonal(C, :U) == Bidiagonal(Cdense, :U)
@test Bidiagonal(C, :L) == Bidiagonal(Cdense, :L)
@test Tridiagonal(C) == Tridiagonal(Cdense)
@test UpperTriangular(C) == UpperTriangular(Cdense)
@test LowerTriangular(C) == LowerTriangular(Cdense)
end
end
end

@testset "Binary ops among special types" begin
Expand Down
5 changes: 2 additions & 3 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -577,9 +577,8 @@ let n = 5
@test_throws DimensionMismatch rdiv!(A, transpose(UnitUpperTriangular(B)))
end

# Test that UpperTriangular(LowerTriangular) throws. See #16201
@test_throws ArgumentError LowerTriangular(UpperTriangular(randn(3,3)))
@test_throws ArgumentError UpperTriangular(LowerTriangular(randn(3,3)))
@test isdiag(LowerTriangular(UpperTriangular(randn(3,3))))
@test isdiag(UpperTriangular(LowerTriangular(randn(3,3))))

# Issue 16196
@test UpperTriangular(Matrix(1.0I, 3, 3)) \ view(fill(1., 3), [1,2,3]) == fill(1., 3)
Expand Down

0 comments on commit 0908602

Please sign in to comment.