Skip to content

Commit

Permalink
Zoltan2: Remove call to soon-to-be-deprecated getNodeNumDiags
Browse files Browse the repository at this point in the history
PR #2634; part of #2630.
Authored by @kddevin; reviewed by @mhoemmen.
  • Loading branch information
kddevin authored and mhoemmen committed Apr 25, 2018
1 parent e24a1e6 commit aa0c96b
Showing 1 changed file with 67 additions and 7 deletions.
74 changes: 67 additions & 7 deletions packages/zoltan2/test/unit/models/GraphModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,63 @@ void printGraph(zlno_t nrows, const zgno_t *v,
comm->barrier();
}

/////////////////////////////////////////////////////////////////////////////

template <typename MatrixOrGraph>
void computeNumDiags(
RCP<const MatrixOrGraph> &M,
size_t &numLocalDiags,
size_t &numGlobalDiags
)
{
// See specializations below
}

template <>
void computeNumDiags<tcrsGraph_t>(
RCP<const tcrsGraph_t> &M,
size_t &numLocalDiags,
size_t &numGlobalDiags
)
{
typedef typename tcrsGraph_t::global_ordinal_type gno_t;

size_t maxnnz = M->getNodeMaxNumRowEntries();
Teuchos::Array<gno_t> colGids(maxnnz);

numLocalDiags = 0;
numGlobalDiags = 0;

int nLocalRows = M->getNodeNumRows();
for (int i = 0; i < nLocalRows; i++) {

gno_t rowGid = M->getRowMap()->getGlobalElement(i);
size_t nnz;
M->getGlobalRowCopy(rowGid, colGids(), nnz);

for (size_t j = 0; j < nnz; j++) {
if (rowGid == colGids[j]) {
numLocalDiags++;
break;
}
}
}
Teuchos::reduceAll<int, size_t>(*(M->getComm()), Teuchos::REDUCE_SUM, 1,
&numLocalDiags, &numGlobalDiags);
}

template <>
void computeNumDiags<tcrsMatrix_t>(
RCP<const tcrsMatrix_t> &M,
size_t &numLocalDiags,
size_t &numGlobalDiags
)
{
RCP<const tcrsGraph_t> graph = M->getCrsGraph();
computeNumDiags<tcrsGraph_t>(graph, numLocalDiags, numGlobalDiags);
}


/////////////////////////////////////////////////////////////////////////////
template <typename BaseAdapter, typename Adapter, typename MatrixOrGraph>
void testAdapter(
Expand Down Expand Up @@ -217,8 +274,11 @@ void testAdapter(
tmi.setCoordinateInput(via);
}

int numLocalDiags = M->getNodeNumDiags();
int numGlobalDiags = M->getGlobalNumDiags();
size_t numLocalDiags = 0;
size_t numGlobalDiags = 0;
if (removeSelfEdges) {
computeNumDiags<MatrixOrGraph>(M, numLocalDiags, numGlobalDiags);
}

const RCP<const tmap_t> rowMap = M->getRowMap();
const RCP<const tmap_t> colMap = M->getColMap();
Expand All @@ -228,7 +288,7 @@ void testAdapter(
int *numNbors = new int [nLocalRows];
int *numLocalNbors = new int [nLocalRows];
bool *haveDiag = new bool [nLocalRows];
zgno_t totalLocalNbors = 0;
size_t totalLocalNbors = 0;

for (zlno_t i=0; i < nLocalRows; i++){
numLocalNbors[i] = 0;
Expand Down Expand Up @@ -281,7 +341,7 @@ void testAdapter(
if (model->getLocalNumVertices() != size_t(nLocalRows)) fail = 1;
TEST_FAIL_AND_EXIT(*comm, !fail, "getGlobalNumVertices", 1)

size_t num = (removeSelfEdges ? (totalLocalNbors - numLocalDiags)
size_t num = (removeSelfEdges ? totalLocalNbors - numLocalDiags
: totalLocalNbors);
if (model->getLocalNumEdges() != num) fail = 1;
TEST_FAIL_AND_EXIT(*comm, !fail, "getLocalNumEdges", 1)
Expand All @@ -293,7 +353,7 @@ void testAdapter(
if (model->getGlobalNumVertices() != size_t(nGlobalRows)) fail = 1;
TEST_FAIL_AND_EXIT(*comm, !fail, "getGlobalNumVertices", 1)

size_t num = (removeSelfEdges ? (nLocalNZ-numLocalDiags) : nLocalNZ);
size_t num = (removeSelfEdges ? nLocalNZ-numLocalDiags : nLocalNZ);
if (model->getLocalNumEdges() != num) fail = 1;
TEST_FAIL_AND_EXIT(*comm, !fail, "getLocalNumEdges", 1)

Expand Down Expand Up @@ -508,8 +568,8 @@ void testAdapter(
TEST_FAIL_AND_EXIT(*comm, numLocalNeighbors==num,
"getLocalEdgeList sum size", 1)

fail = ((removeSelfEdges ? size_t(totalLocalNbors-numLocalDiags)
: size_t(totalLocalNbors))
fail = ((removeSelfEdges ? totalLocalNbors-numLocalDiags
: totalLocalNbors)
!= numLocalNeighbors);
TEST_FAIL_AND_EXIT(*comm, !fail, "getLocalEdgeList total size", 1)

Expand Down

0 comments on commit aa0c96b

Please sign in to comment.