Skip to content

Commit

Permalink
sparse * sparse fixes (#55)
Browse files Browse the repository at this point in the history
* sp*sp: overload for BlasFloat eltype only

* tests: fix special{sparse}*special{sparse}

wrap tested matrices into Special() calll

* SpecialMatrices Tuple

* declare sp*sp for all pairs of special matrices

* cosmetic fix

Co-authored-by: Kristoffer Carlsson <kcarlsson89@gmail.com>

---------

Co-authored-by: Alexey Stukalov <astukalol@seer.bio>
Co-authored-by: Kristoffer Carlsson <kcarlsson89@gmail.com>
  • Loading branch information
3 people authored Jan 13, 2025
1 parent fec49ef commit 78c0f41
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
52 changes: 41 additions & 11 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ SimpleOrSpecialOrAdjMat{T, M} = Union{M,
AdjOrTranspMat{T, <:SpecialMat{T,<:M}},
SpecialMat{T,<:AdjOrTranspMat{T,<:M}}}

const SpecialMatrices = (LowerTriangular, UpperTriangular,
UnitLowerTriangular, UnitUpperTriangular,
Symmetric, Hermitian)

# unwraps matrix A from Adjoint/Transpose transform
unwrap_trans(A::AbstractMatrix) = A
unwrap_trans(A::Union{Adjoint, Transpose}) = unwrap_trans(parent(A))
Expand Down Expand Up @@ -219,17 +223,45 @@ end
# sparse * sparse overloads, have to be more specific than
# the ones in SparseArrays.jl to avoid ambiguity

(*)(A::SparseMat{T}, B::SparseMat{T}) where T =
spmatmul_sparse(A, B)
for Amat in (nothing, SpecialMatrices...), Bmat in (nothing, SpecialMatrices...)
Atype = !isnothing(Amat) ? :($Amat{T,S}) : :S
tAtype = !isnothing(Amat) ? :($Amat{T, <:AdjOrTranspMat{T, S}}) : nothing
Btype = !isnothing(Bmat) ? :($Bmat{T,S}) : :S
tBtype = !isnothing(Bmat) ? :($Bmat{T, <:AdjOrTranspMat{T, S}}) : nothing

@eval (*)(A::$Atype, B::$Btype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

@eval (*)(A::AdjOrTranspMat{T, $Atype}, B::$Btype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

(*)(A::AdjOrTranspMat{T, S}, B::S) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
@eval (*)(A::$Atype, B::AdjOrTranspMat{T, $Btype}) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

(*)(A::S, B::AdjOrTranspMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
@eval (*)(A::AdjOrTranspMat{T, $Atype}, B::AdjOrTranspMat{T, $Btype}) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

if tAtype !== nothing
@eval (*)(A::$tAtype, B::$Btype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

@eval (*)(A::$tAtype, B::AdjOrTranspMat{T, $Btype}) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
end

(*)(A::AdjOrTranspMat{T, S}, B::AdjOrTranspMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
if tBtype !== nothing
@eval (*)(A::$Atype, B::$tBtype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

@eval (*)(A::AdjOrTranspMat{T, $Atype}, B::$tBtype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
end

if tAtype !== nothing && tBtype !== nothing
@eval (*)(A::$tAtype, B::$tBtype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
end
end

if VERSION < v"1.11" # in 1.11 these wrappers are already defined in LinearAlgebra

Expand All @@ -245,9 +277,7 @@ function (\)(A::Union{S, AdjOrTranspMat{T, S}}, B::StridedMatrix{T}) where {T <:
return ldiv!(C, A, B)
end

for mat in (LowerTriangular, UpperTriangular,
UnitLowerTriangular, UnitUpperTriangular,
Symmetric, Hermitian)
for mat in SpecialMatrices

@eval function (\)(A::Union{$mat{T, S}, AdjOrTranspMat{T, $mat{T, S}}, $mat{T, <:AdjOrTranspMat{T, S}}},
x::StridedVector{T}
Expand Down
4 changes: 2 additions & 2 deletions test/test_BLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -728,8 +728,8 @@ end
n = rand(10:50)
spf = 0.1 + 0.8 * rand()

spA = convert_to_Aclass(sparserandn(SPMT{T, IT}, n, n, spf))
spB = convert_to_Bclass(sparserandn(SPMT{T, IT}, n, n, spf))
spA = Aclass(convert_to_Aclass(sparserandn(SPMT{T, IT}, n, n, spf)))
spB = Bclass(convert_to_Bclass(sparserandn(SPMT{T, IT}, n, n, spf)))
A = convert(Matrix, spA)
B = convert(Matrix, spB)

Expand Down

0 comments on commit 78c0f41

Please sign in to comment.