Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUSPARSE: Eagerly combine duplicate element on construction. #2213

Merged
merged 2 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 103 additions & 13 deletions lib/cusparse/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,113 @@ function SparseArrays.sparse(x::DenseCuMatrix; fmt=:csc)
end
end

SparseArrays.sparse(I::CuVector, J::CuVector, V::CuVector; kws...) =
sparse(I, J, V, maximum(I), maximum(J); kws...)
function SparseArrays.sparse(I::CuVector, J::CuVector, V::CuVector, args...; kwargs...)
sparse(Cint.(I), Cint.(J), V, args...; kwargs...)
end

SparseArrays.sparse(I::CuVector, J::CuVector, V::CuVector, m, n; kws...) =
sparse(Cint.(I), Cint.(J), V, m, n; kws...)
function SparseArrays.sparse(I::CuVector{Cint}, J::CuVector{Cint}, V::CuVector{Tv},
m=maximum(I), n=maximum(J);
fmt=:csc, combine=nothing) where Tv
# we use COO as an intermediate format, as it's easy to construct from I/J/V.
coo = CuSparseMatrixCOO{Tv}(I, J, V, (m, n))

function SparseArrays.sparse(I::CuVector{Cint}, J::CuVector{Cint}, V::CuVector{Tv}, m, n;
fmt=:csc) where Tv
# find groups of values that correspond to the same position in the matrix.
# if there's no duplicates, `groups` will just be a vector of ones.
# otherwise, it will contain the number of duplicates for each group,
# with subsequent values that are part of the group set to zero.
coo = sort_coo(coo, 'R')
groups = similar(I, Int)
function find_groups(groups, I, J)
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
if i > length(groups)
return
end
len = 0

coo = CuSparseMatrixCOO{Tv}(I, J, V, (m, n))
if fmt == :csc
# check if we're at the start of a new group
@inbounds if i == 1 || I[i] != I[i-1] || J[i] != J[i-1]
len = 1
while i+len <= length(groups) && I[i] == I[i+len] && J[i] == J[i+len]
len += 1
end
end

@inbounds groups[i] = len

return
end
kernel = @cuda launch=false find_groups(groups, coo.rowInd, coo.colInd)
config = launch_configuration(kernel.fun)
threads = min(length(groups), config.threads)
blocks = cld(length(groups), threads)
kernel(groups, coo.rowInd, coo.colInd; threads, blocks)

# if we got any group of more than one element, we need to combine them.
# this may actually not be required, as some CUSPARSE functions can handle
# duplicate entries, but it's not clear which ones do and which ones don't.
# also, to ensure matrix display is correct, combine values eagerly.
ngroups = mapreduce(!iszero, +, groups)
if ngroups != length(groups)
if combine === nothing
combine = if Tv === Bool
|
else
+
end
end

total_lengths = cumsum(groups) # TODO: add and use `scan!(; exclusive=true)`
I = similar(I, ngroups)
J = similar(J, ngroups)
V = similar(V, ngroups)

# use one thread per value, and if it's at the start of a group,
# combine (if needed) all values and update the output vectors.
function combine_groups(groups, total_lengths, oldI, oldJ, oldV, newI, newJ, newV, combine)
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
if i > length(groups)
return
end

# check if we're at the start of a group
@inbounds if groups[i] != 0
# get an exclusive offset from the inclusive cumsum
offset = total_lengths[i] - groups[i] + 1

# copy values
newI[i] = oldI[offset]
newJ[i] = oldJ[offset]
newV[i] = if groups[i] == 1
oldV[offset]
else
# combine all values in the group
val = oldV[offset]
j = 1
while j < groups[i]
val = combine(val, oldV[offset+j])
j += 1
end
val
end
end

return
end
kernel = @cuda launch=false combine_groups(groups, total_lengths, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine)
config = launch_configuration(kernel.fun)
threads = min(length(groups), config.threads)
blocks = cld(length(groups), threads)
kernel(groups, total_lengths, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine; threads, blocks)

coo = CuSparseMatrixCOO{Tv}(I, J, V, (m, n))
end

if fmt == :coo
return coo
elseif fmt == :csc
return CuSparseMatrixCSC(coo)
elseif fmt == :csr
return CuSparseMatrixCSR(coo)
elseif fmt == :coo
# The COO format is assumed to be sorted by row.
return sort_coo(coo, 'R')
else
error("Format :$fmt not available, use :csc, :csr, or :coo.")
end
Expand Down Expand Up @@ -231,7 +321,7 @@ for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR, :CuSparseMatrix
$SparseMatrixType(S::Diagonal{Tv, <:CuArray}) where Tv = $SparseMatrixType{Tv}(S)
$SparseMatrixType{Tv}(S::Diagonal) where {Tv} = $SparseMatrixType{Tv, Cint}(S)
end

if SparseMatrixType == :CuSparseMatrixCOO
@eval function $SparseMatrixType{Tv, Ti}(S::Diagonal) where {Tv, Ti}
m = size(S, 1)
Expand All @@ -242,7 +332,7 @@ for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR, :CuSparseMatrix
m = size(S, 1)
return $SparseMatrixType{Tv, Ti}(CuVector(1:(m+1)), CuVector(1:m), convert(CuVector{Tv}, S.diag), (m, m))
end
end
end
end

# by flipping rows and columns, we can use that to get CSC to CSR too
Expand Down
24 changes: 24 additions & 0 deletions test/libraries/cusparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1078,3 +1078,27 @@ end
end
end
end

@testset "duplicate entries" begin
# already sorted
let
I = [1, 3, 4, 4]
J = [1, 2, 3, 3]
V = [1f0, 2f0, 3f0, 10f0]
coo = sparse(cu(I), cu(J), cu(V); fmt=:coo)
@test Array(coo.rowInd) == [1, 3, 4]
@test Array(coo.colInd) == [1, 2, 3]
@test Array(coo.nzVal) == [1f0, 2f0, 13f0]
end

# out of order
let
I = [4, 1, 3, 4]
J = [3, 1, 2, 3]
V = [10f0, 1f0, 2f0, 3f0]
coo = sparse(cu(I), cu(J), cu(V); fmt=:coo)
@test Array(coo.rowInd) == [1, 3, 4]
@test Array(coo.colInd) == [1, 2, 3]
@test Array(coo.nzVal) == [1f0, 2f0, 13f0]
end
end