Skip to content

Commit

Permalink
Merge pull request #9 from huanglangwen/master
Browse files Browse the repository at this point in the history
add findstructralnz for (bi/tri-)diagonal matrices
  • Loading branch information
ChrisRackauckas authored Jul 27, 2019
2 parents 2d00a0c + ff564e8 commit 29079ab
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 2 deletions.
86 changes: 86 additions & 0 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
module ArrayInterface

using Requires
using LinearAlgebra
using SparseArrays

export findstructralnz,has_sparsestruct

function ismutable end

Expand All @@ -16,6 +20,88 @@ ismutable(x) = ismutable(typeof(x))
ismutable(::Type{<:Array}) = true
ismutable(::Type{<:Number}) = false

"""
has_sparsestruct(x::AbstractArray)
determine whether `findstructralnz` accepts the parameter `x`
"""
has_sparsestruct(x)=false
has_sparsestruct(x::AbstractArray)=false
has_sparsestruct(x::SparseMatrixCSC)=true
has_sparsestruct(x::Diagonal)=true
has_sparsestruct(x::Bidiagonal)=true
has_sparsestruct(x::Tridiagonal)=true
has_sparsestruct(x::SymTridiagonal)=true

"""
findstructralnz(x::AbstractArray)
Return: (I,J) #indexable objects
Find sparsity pattern of special matrices, similar to first two elements of findnz(::SparseMatrixCSC)
"""
function findstructralnz(x::Diagonal)
n=size(x,1)
(1:n,1:n)
end

abstract type MatrixIndex end

struct BidiagonalIndex <: MatrixIndex
count::Int
isup::Bool
end

struct TridiagonalIndex <: MatrixIndex
count::Int
nsize::Int
isrow::Bool
end

Base.firstindex(ind::MatrixIndex)=1
Base.lastindex(ind::MatrixIndex)=ind.count
Base.length(ind::MatrixIndex)=ind.count
function Base.getindex(ind::BidiagonalIndex,i::Int)
1 <= i <= ind.count || throw(BoundsError(ind, i))
if ind.isup
ii=i+1
else
ii=i+1+1
end
convert(Int,floor(ii/2))
end

function Base.getindex(ind::TridiagonalIndex,i::Int)
1 <= i <= ind.count || throw(BoundsError(ind, i))
offsetu= ind.isrow ? 0 : 1
offsetl= ind.isrow ? 1 : 0
if 1 <= i <= ind.nsize
return i
elseif ind.nsize < i <= ind.nsize+ind.nsize-1
return i-ind.nsize+offsetu
else
return i-(ind.nsize+ind.nsize-1)+offsetl
end
end

function findstructralnz(x::Bidiagonal)
n=size(x,1)
isup= x.uplo=='U' ? true : false
rowind=BidiagonalIndex(n+n-1,isup)
colind=BidiagonalIndex(n+n-1,!isup)
(rowind,colind)
end

function findstructralnz(x::Union{Tridiagonal,SymTridiagonal})
n=size(x,1)
rowind=TridiagonalIndex(n+n-1+n-1,n,true)
colind=TridiagonalIndex(n+n-1+n-1,n,false)
(rowind,colind)
end

function findstructralnz(x::SparseMatrixCSC)
rowind,colind,_=findnz(x)
(rowind,colind)
end

function __init__()

Expand Down
35 changes: 33 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,36 @@ using ArrayInterface, Test
@test ArrayInterface.ismutable(rand(3))

using StaticArrays
ArrayInterface.ismutable(@SVector [1,2,3]) == false
ArrayInterface.ismutable(@MVector [1,2,3]) == true
@test ArrayInterface.ismutable(@SVector [1,2,3]) == false
@test ArrayInterface.ismutable(@MVector [1,2,3]) == true

using LinearAlgebra, SparseArrays
D=Diagonal([1,2,3,4])
@test has_sparsestruct(D)
rowind,colind=findstructralnz(D)
@test [D[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,2,3,4]
@test length(rowind)==4
@test length(rowind)==length(colind)

Bu = Bidiagonal([1,2,3,4], [7,8,9], :U)
@test has_sparsestruct(Bu)
rowind,colind=findstructralnz(Bu)
@test [Bu[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,7,2,8,3,9,4]
Bl = Bidiagonal([1,2,3,4], [7,8,9], :L)
@test has_sparsestruct(Bl)
rowind,colind=findstructralnz(Bl)
@test [Bl[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,7,2,8,3,9,4]

Tri=Tridiagonal([1,2,3],[1,2,3,4],[4,5,6])
@test has_sparsestruct(Tri)
rowind,colind=findstructralnz(Tri)
@test [Tri[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,2,3,4,4,5,6,1,2,3]
STri=SymTridiagonal([1,2,3,4],[5,6,7])
@test has_sparsestruct(STri)
rowind,colind=findstructralnz(STri)
@test [STri[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,2,3,4,5,6,7,5,6,7]

Sp=sparse([1,2,3],[1,2,3],[1,2,3])
@test has_sparsestruct(Sp)
rowind,colind=findstructralnz(Sp)
@test [Tri[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,2,3]

0 comments on commit 29079ab

Please sign in to comment.