From aa0c96b520b87802b2f9b15dba52401a7ea370e1 Mon Sep 17 00:00:00 2001 From: K Devine Date: Tue, 24 Apr 2018 21:13:56 -0600 Subject: [PATCH] Zoltan2: Remove call to soon-to-be-deprecated getNodeNumDiags PR #2634; part of #2630. Authored by @kddevin; reviewed by @mhoemmen. --- .../zoltan2/test/unit/models/GraphModel.cpp | 74 +++++++++++++++++-- 1 file changed, 67 insertions(+), 7 deletions(-) diff --git a/packages/zoltan2/test/unit/models/GraphModel.cpp b/packages/zoltan2/test/unit/models/GraphModel.cpp index 20cb08cb7fb2..0a5dfe4ffbac 100644 --- a/packages/zoltan2/test/unit/models/GraphModel.cpp +++ b/packages/zoltan2/test/unit/models/GraphModel.cpp @@ -127,6 +127,63 @@ void printGraph(zlno_t nrows, const zgno_t *v, comm->barrier(); } +///////////////////////////////////////////////////////////////////////////// + +template +void computeNumDiags( + RCP &M, + size_t &numLocalDiags, + size_t &numGlobalDiags +) +{ + // See specializations below +} + +template <> +void computeNumDiags( + RCP &M, + size_t &numLocalDiags, + size_t &numGlobalDiags +) +{ + typedef typename tcrsGraph_t::global_ordinal_type gno_t; + + size_t maxnnz = M->getNodeMaxNumRowEntries(); + Teuchos::Array colGids(maxnnz); + + numLocalDiags = 0; + numGlobalDiags = 0; + + int nLocalRows = M->getNodeNumRows(); + for (int i = 0; i < nLocalRows; i++) { + + gno_t rowGid = M->getRowMap()->getGlobalElement(i); + size_t nnz; + M->getGlobalRowCopy(rowGid, colGids(), nnz); + + for (size_t j = 0; j < nnz; j++) { + if (rowGid == colGids[j]) { + numLocalDiags++; + break; + } + } + } + Teuchos::reduceAll(*(M->getComm()), Teuchos::REDUCE_SUM, 1, + &numLocalDiags, &numGlobalDiags); +} + +template <> +void computeNumDiags( + RCP &M, + size_t &numLocalDiags, + size_t &numGlobalDiags +) +{ + RCP graph = M->getCrsGraph(); + computeNumDiags(graph, numLocalDiags, numGlobalDiags); +} + + ///////////////////////////////////////////////////////////////////////////// template void testAdapter( @@ -217,8 +274,11 @@ void testAdapter( tmi.setCoordinateInput(via); } - int numLocalDiags = M->getNodeNumDiags(); - int numGlobalDiags = M->getGlobalNumDiags(); + size_t numLocalDiags = 0; + size_t numGlobalDiags = 0; + if (removeSelfEdges) { + computeNumDiags(M, numLocalDiags, numGlobalDiags); + } const RCP rowMap = M->getRowMap(); const RCP colMap = M->getColMap(); @@ -228,7 +288,7 @@ void testAdapter( int *numNbors = new int [nLocalRows]; int *numLocalNbors = new int [nLocalRows]; bool *haveDiag = new bool [nLocalRows]; - zgno_t totalLocalNbors = 0; + size_t totalLocalNbors = 0; for (zlno_t i=0; i < nLocalRows; i++){ numLocalNbors[i] = 0; @@ -281,7 +341,7 @@ void testAdapter( if (model->getLocalNumVertices() != size_t(nLocalRows)) fail = 1; TEST_FAIL_AND_EXIT(*comm, !fail, "getGlobalNumVertices", 1) - size_t num = (removeSelfEdges ? (totalLocalNbors - numLocalDiags) + size_t num = (removeSelfEdges ? totalLocalNbors - numLocalDiags : totalLocalNbors); if (model->getLocalNumEdges() != num) fail = 1; TEST_FAIL_AND_EXIT(*comm, !fail, "getLocalNumEdges", 1) @@ -293,7 +353,7 @@ void testAdapter( if (model->getGlobalNumVertices() != size_t(nGlobalRows)) fail = 1; TEST_FAIL_AND_EXIT(*comm, !fail, "getGlobalNumVertices", 1) - size_t num = (removeSelfEdges ? (nLocalNZ-numLocalDiags) : nLocalNZ); + size_t num = (removeSelfEdges ? nLocalNZ-numLocalDiags : nLocalNZ); if (model->getLocalNumEdges() != num) fail = 1; TEST_FAIL_AND_EXIT(*comm, !fail, "getLocalNumEdges", 1) @@ -508,8 +568,8 @@ void testAdapter( TEST_FAIL_AND_EXIT(*comm, numLocalNeighbors==num, "getLocalEdgeList sum size", 1) - fail = ((removeSelfEdges ? size_t(totalLocalNbors-numLocalDiags) - : size_t(totalLocalNbors)) + fail = ((removeSelfEdges ? totalLocalNbors-numLocalDiags + : totalLocalNbors) != numLocalNeighbors); TEST_FAIL_AND_EXIT(*comm, !fail, "getLocalEdgeList total size", 1)