Skip to content

Commit

Permalink
Merge pull request kokkos#2180 from MalachiTimothyPhillips/malachi/fi…
Browse files Browse the repository at this point in the history
…x-divide-by-zero-in-trsv

Add early return if numRows == 0 in trsv to avoid integer divide-by-zero error
  • Loading branch information
ndellingwood authored Apr 18, 2024
2 parents 47a1849 + 35e115a commit 83374bf
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions sparse/impl/KokkosSparse_trsv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ struct TrsvWrap {
static void lowerTriSolveCsrUnitDiag(RangeMultiVectorType X,
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
const lno_t numVecs = X.extent(1);
Expand All @@ -211,7 +213,9 @@ struct TrsvWrap {

static void lowerTriSolveCsr(RangeMultiVectorType X, const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
const lno_t numVecs = X.extent(1);
Expand Down Expand Up @@ -254,7 +258,9 @@ struct TrsvWrap {
static void upperTriSolveCsrUnitDiag(RangeMultiVectorType X,
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
const lno_t numVecs = X.extent(1);
Expand Down Expand Up @@ -304,7 +310,9 @@ struct TrsvWrap {

static void upperTriSolveCsr(RangeMultiVectorType X, const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
const lno_t numVecs = X.extent(1);
Expand Down Expand Up @@ -371,7 +379,9 @@ struct TrsvWrap {
static void upperTriSolveCscUnitDiag(RangeMultiVectorType X,
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -422,7 +432,9 @@ struct TrsvWrap {

static void upperTriSolveCsc(RangeMultiVectorType X, const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -481,7 +493,9 @@ struct TrsvWrap {
static void lowerTriSolveCscUnitDiag(RangeMultiVectorType X,
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -510,7 +524,9 @@ struct TrsvWrap {
static void upperTriSolveCscUnitDiagConj(RangeMultiVectorType X,
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -562,7 +578,9 @@ struct TrsvWrap {
static void upperTriSolveCscConj(RangeMultiVectorType X,
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -620,7 +638,9 @@ struct TrsvWrap {

static void lowerTriSolveCsc(RangeMultiVectorType X, const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -657,7 +677,9 @@ struct TrsvWrap {
static void lowerTriSolveCscUnitDiagConj(RangeMultiVectorType X,
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -686,7 +708,9 @@ struct TrsvWrap {
static void lowerTriSolveCscConj(RangeMultiVectorType X,
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
const lno_t numRows = A.numRows();
if (numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down

0 comments on commit 83374bf

Please sign in to comment.