
Commit

chg: fix: reformat code; change how we sample indices to test avoid segfault when n < 10000
Chenhan Yu committed Mar 6, 2019
1 parent f95c22b commit e99a655
Showing 4 changed files with 93 additions and 130 deletions.
44 changes: 22 additions & 22 deletions gofmm/gofmm.hpp
@@ -2625,7 +2625,7 @@ hmlpError_t FindFarNodes( NODE *node, NODE *target )
NearNodes = &target->NearNodes;

/** If this node contains any Near( target ) or isn't skeletonized */
if ( !data.is_compressed || node->ContainAny( *NearNodes ) )
if ( !data.is_compressed || node->containAnyNodePointer( *NearNodes ) )
{
if ( !node->isLeaf() )
{
@@ -2650,7 +2650,7 @@ hmlpError_t FindFarNodes( NODE *node, NODE *target )
NearNodes = &target->NNNearNodes;

/** If this node contains any Near( target ) or isn't skeletonized */
if ( !data.is_compressed || node->ContainAny( *NearNodes ) )
if ( !data.is_compressed || node->containAnyNodePointer( *NearNodes ) )
{
if ( !node->isLeaf() )
{
@@ -2946,7 +2946,7 @@ hmlpError_t Evaluate( NODE *node, const size_t gid, Data<T> & potentials, const

assert( potentials.size() == nrhs );

if ( !data.is_compressed || node->ContainAny( neighbors ) )
if ( !data.is_compressed || node->containAnyGlobalIndex( neighbors ) )
{
auto I = vector<size_t>( 1, gid );
auto & J = node->gids;
@@ -3558,31 +3558,31 @@ tree::Tree< gofmm::Setup<SPDMATRIX, SPLITTER, T>, gofmm::NodeData<T>>
*/
template<typename T, typename SPDMATRIX>
tree::Tree<
gofmm::Setup<SPDMATRIX, centersplit<SPDMATRIX, 2, T>, T>,
gofmm::Setup<SPDMATRIX, centersplit<SPDMATRIX, 2, T>, T>,
gofmm::NodeData<T>>
*Compress( SPDMATRIX &K, T stol, T budget, size_t m, size_t k, size_t s )
{
using SPLITTER = centersplit<SPDMATRIX, 2, T>;
using RKDTSPLITTER = randomsplit<SPDMATRIX, 2, T>;
Data<pair<T, size_t>> NN;
/** GOFMM tree splitter */
/** GOFMM tree splitter */
SPLITTER splitter( K );
splitter.Kptr = &K;
splitter.metric = ANGLE_DISTANCE;
/** randomized tree splitter */
splitter.metric = ANGLE_DISTANCE;
/** randomized tree splitter */
RKDTSPLITTER rkdtsplitter( K );
rkdtsplitter.Kptr = &K;
rkdtsplitter.metric = ANGLE_DISTANCE;
rkdtsplitter.metric = ANGLE_DISTANCE;
size_t n = K.row();

/** creatgin configuration for all user-define arguments */
Configuration<T> config( ANGLE_DISTANCE, n, m, k, s, stol, budget );
/** creatgin configuration for all user-define arguments */
Configuration<T> config( ANGLE_DISTANCE, n, m, k, s, stol, budget );

/** call the complete interface and return tree_ptr */
/** call the complete interface and return tree_ptr */
return Compress<SPLITTER, RKDTSPLITTER>
( K, NN, //ANGLE_DISTANCE,
splitter, rkdtsplitter, //n, m, k, s, stol, budget,
config );
( K, NN, //ANGLE_DISTANCE,
splitter, rkdtsplitter, //n, m, k, s, stol, budget,
config );
}; /** end Compress */


@@ -3604,11 +3604,11 @@ tree::Tree<
using SPLITTER = centersplit<SPDMATRIX, 2, T>;
using RKDTSPLITTER = randomsplit<SPDMATRIX, 2, T>;
Data<pair<T, std::size_t>> NN;
/** GOFMM tree splitter */
/** GOFMM tree splitter */
SPLITTER splitter( K );
splitter.Kptr = &K;
splitter.metric = ANGLE_DISTANCE;
/** randomized tree splitter */
/** randomized tree splitter */
RKDTSPLITTER rkdtsplitter( K );
rkdtsplitter.Kptr = &K;
rkdtsplitter.metric = ANGLE_DISTANCE;
@@ -3654,11 +3654,11 @@ tree::Tree<
*/
template<typename T>
tree::Tree<
gofmm::Setup<SPDMatrix<T>, centersplit<SPDMatrix<T>, 2, T>, T>,
gofmm::Setup<SPDMatrix<T>, centersplit<SPDMatrix<T>, 2, T>, T>,
gofmm::NodeData<T>>
*Compress( SPDMatrix<T> &K, T stol, T budget )
{
return Compress<T, SPDMatrix<T>>( K, stol, budget );
return Compress<T, SPDMatrix<T>>( K, stol, budget );
}; /** end Compress() */
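
The convenience overload touched above takes only a matrix, a tolerance, and a budget. A hedged usage sketch follows; only the Compress( K, stol, budget ) signature comes from the diff, while the include path, enclosing namespace, and the SPDMatrix constructor are assumptions for illustration and not part of this commit.

#include <gofmm.hpp>                       /* assumed include path */

using namespace hmlp;                      /* assumed enclosing namespace */

int main()
{
  using T = double;
  size_t n = 5000;
  SPDMatrix<T> K( n, n );                  /* assumed constructor; fill K with an SPD kernel */
  T stol   = 1E-3;                         /* skeletonization tolerance */
  T budget = 0.01;                         /* fraction of direct (near) evaluation */
  auto *tree_ptr = gofmm::Compress<T>( K, stol, budget );
  /* tree_ptr points to a heap-allocated compressed tree; the caller frees it. */
  delete tree_ptr;
  return 0;
}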


@@ -3667,7 +3667,7 @@ tree::Tree<



template<typename NODE, typename T>
template<typename NODE, typename T>
void ComputeError( NODE *node, Data<T> potentials )
{
auto &K = *node->setup->K;
@@ -3681,7 +3681,7 @@ void ComputeError( NODE *node, Data<T> potentials )
auto Kab = K( amap, bmap );

auto nrm2 = hmlp_norm( potentials.row(), potentials.col(),
potentials.data(), potentials.row() );
potentials.data(), potentials.row() );

xgemm
(
@@ -3774,8 +3774,8 @@ hmlpError_t SelfTesting( TREE &tree, size_t ntest, size_t nrhs )
printf( "========================================================\n");
for ( size_t i = 0; i < ntest; i ++ )
{
//size_t tar = i * n / ntest;
size_t tar = i * 1000;
size_t tar = i * n / ntest;
//size_t tar = i * 1000;
Data<T> potentials;
/** ASKIT treecode with NN pruning. */
RETURN_IF_ERROR( Evaluate( tree, tar, potentials, EVALUATE_OPTION_NEIGHBOR_PRUNING ) );
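
The hunk above is the change the commit message refers to: test targets are now spread evenly over the n gids instead of being taken at a fixed stride of 1000, which could index past the end of the data whenever n < 10000. A minimal standalone sketch of the two schemes (simplified names, not part of the diff):

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
  std::size_t n = 5000, ntest = 100;
  std::vector<std::size_t> targets;
  for ( std::size_t i = 0; i < ntest; i ++ )
  {
    /* old scheme: i * 1000 walks past n once i >= n / 1000, so later targets
     * indexed out of range (the segfault in the commit message). */
    /* new scheme: i * n / ntest always stays in [0, n) whenever ntest <= n. */
    std::size_t tar = i * n / ntest;
    targets.push_back( tar );
  }
  printf( "largest sampled index = %zu (n = %zu)\n", targets.back(), n );
  return 0;
}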
98 changes: 49 additions & 49 deletions gofmm/gofmm_mpi.hpp
@@ -1635,8 +1635,8 @@ hmlpError_t FindFarNodes( const MortonHelper::Recursor r, NODE *target )
auto & compression_failure_frontier = target->setup->compression_failure_frontier_;

/** Recur to children if the current node contains near interactions. */
if ( MortonHelper::ContainAny( node_morton, NearMortonIDs ) ||
MortonHelper::ContainAny( node_morton, compression_failure_frontier ) )
if ( MortonHelper::containAny( node_morton, NearMortonIDs ) ||
MortonHelper::containAny( node_morton, compression_failure_frontier ) )
{
RETURN_IF_ERROR( FindFarNodes( MortonHelper::RecurLeft( r ), target ) );
RETURN_IF_ERROR( FindFarNodes( MortonHelper::RecurRight( r ), target ) );
@@ -3107,10 +3107,10 @@ void DistRowSamples( NODE *node, size_t nsamples )
vector<size_t> &I = node->data.candidate_rows;

/** Clean up candidates from previous iteration */
I.clear();
I.clear();

/** Fill-on snids first */
if ( rank == 0 )
if ( rank == 0 )
{
/** reserve space */
I.reserve( nsamples );
@@ -3119,68 +3119,68 @@
multimap<T, size_t> ordered_snids = gofmm::flip_map( snids );

for ( auto it = ordered_snids.begin();
it != ordered_snids.end(); it++ )
it != ordered_snids.end(); it++ )
{
/** (*it) has type pair<T, size_t> */
I.push_back( (*it).second );
if ( I.size() >= nsamples ) break;
}
}

/** buffer space */
vector<size_t> candidates( nsamples );
/** buffer space */
vector<size_t> candidates( nsamples );

size_t n_required = nsamples - I.size();
size_t n_required = nsamples - I.size();

/** bcast the termination criteria */
mpi::Bcast( &n_required, 1, 0, comm );
/** bcast the termination criteria */
mpi::Bcast( &n_required, 1, 0, comm );

while ( n_required )
{
if ( rank == 0 )
{
for ( size_t i = 0; i < nsamples; i ++ )
while ( n_required )
{
if ( rank == 0 )
{
for ( size_t i = 0; i < nsamples; i ++ )
{
auto important_sample = K.ImportantSample( 0 );
candidates[ i ] = important_sample.second;
}
}
}

/** Bcast candidates */
mpi::Bcast( candidates.data(), candidates.size(), 0, comm );
/** Bcast candidates */
mpi::Bcast( candidates.data(), candidates.size(), 0, comm );

/** validation */
vector<size_t> vconsensus( nsamples, 0 );
vector<size_t> validation = node->setup->ContainAny( candidates, node->getMortonID() );
/** validation */
vector<size_t> vconsensus( nsamples, 0 );
vector<size_t> validation = node->setup->ContainAny( candidates, node->getMortonID() );

/** reduce validation */
mpi::Reduce( validation.data(), vconsensus.data(), nsamples, MPI_SUM, 0, comm );
/** reduce validation */
mpi::Reduce( validation.data(), vconsensus.data(), nsamples, MPI_SUM, 0, comm );

if ( rank == 0 )
{
for ( size_t i = 0; i < nsamples; i ++ )
{
/** exit is there is enough samples */
if ( I.size() >= nsamples )
{
I.resize( nsamples );
break;
}
/** Push the candidate to I after validation */
if ( !vconsensus[ i ] )
{
if ( find( I.begin(), I.end(), candidates[ i ] ) == I.end() )
I.push_back( candidates[ i ] );
}
};
if ( rank == 0 )
{
for ( size_t i = 0; i < nsamples; i ++ )
{
/** exit is there is enough samples */
if ( I.size() >= nsamples )
{
I.resize( nsamples );
break;
}
/** Push the candidate to I after validation */
if ( !vconsensus[ i ] )
{
if ( find( I.begin(), I.end(), candidates[ i ] ) == I.end() )
I.push_back( candidates[ i ] );
}
};

/** Update n_required */
n_required = nsamples - I.size();
}
/** Update n_required */
n_required = nsamples - I.size();
}

/** Bcast the termination criteria */
mpi::Bcast( &n_required, 1, 0, comm );
}
/** Bcast the termination criteria */
mpi::Bcast( &n_required, 1, 0, comm );
}

}; /** end DistRowSamples() */
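
The DistRowSamples() hunk above is largely a re-indentation of the candidate-validation loop. Stripped of the MPI collectives (Bcast/Reduce) and with the kernel's importance sampler and the ownership test replaced by caller-supplied callables, the control flow looks roughly like the sketch below; sampleRows, drawCandidate, and owned are hypothetical names, and this single-process version only illustrates the loop structure, not the distributed consensus.

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

/* Keep drawing candidate row indices until nsamples unique, not-yet-owned
 * indices are collected. In the real code rank 0 proposes candidates, Bcast
 * distributes them, every rank validates them against its Morton id, and a
 * Reduce builds the consensus; here that is collapsed into owned().
 * Precondition: I.size() <= nsamples. */
std::vector<std::size_t> sampleRows(
    std::size_t nsamples,
    std::vector<std::size_t> I,                          /* pre-filled from neighbor ids */
    const std::function<std::size_t()> &drawCandidate,   /* e.g. an importance sample of K */
    const std::function<bool( std::size_t )> &owned )    /* node already contains this gid? */
{
  std::size_t n_required = nsamples - I.size();
  while ( n_required )
  {
    /* Propose a fresh batch of candidates. */
    std::vector<std::size_t> candidates( nsamples );
    for ( auto &c : candidates ) c = drawCandidate();
    /* Keep candidates that are neither owned nor already selected. */
    for ( auto c : candidates )
    {
      if ( I.size() >= nsamples ) { I.resize( nsamples ); break; }
      if ( !owned( c ) && std::find( I.begin(), I.end(), c ) == I.end() )
        I.push_back( c );
    }
    n_required = nsamples - I.size();
  }
  return I;
}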

@@ -4078,12 +4078,12 @@ mpitree::Tree<mpigofmm::Setup<SPDMATRIX, SPLITTER, T>, gofmm::NodeData<T>>

/** Initialize metric ball tree using approximate center split. */
auto *tree_ptr = new TREE( CommGOFMM );
auto &tree = *tree_ptr;
auto &tree = *tree_ptr;

/** Global configuration for the metric tree. */
/** Global configuration for the metric tree. */
tree.setup.FromConfiguration( config, K, splitter, &NN_cblk );

/** Metric ball tree partitioning. */
/** Metric ball tree partitioning. */
beg = omp_get_wtime();
tree.TreePartition();
tree_time = omp_get_wtime() - beg;
61 changes: 12 additions & 49 deletions gofmm/tree.hpp
@@ -191,7 +191,7 @@ class MortonHelper
* \return true if at lease one query contains the target.
*/
template<typename TQUERY>
static bool ContainAny( mortonType target, TQUERY &querys )
static bool containAny( mortonType target, const TQUERY& querys )
{
for ( auto & q : querys )
{
@@ -757,32 +757,31 @@ class Node : public ReadWrite
* needs to be accessed using gids.
*
*/
bool ContainAny( const std::vector<size_t> & queries )
bool containAnyGlobalIndex( const std::vector<indexType> & queries )
{
if ( !setup->morton.size() )
{
printf( "Morton id was not initialized.\n" );
exit( 1 );
throw std::out_of_range( "MortonID was not initialized" );
}
for ( size_t i = 0; i < queries.size(); i ++ )
for ( auto gid : queries )
{
if ( MortonHelper::IsMyParent( setup->morton[ queries[ i ] ], getMortonID() ) )
if ( MortonHelper::IsMyParent( setup->morton[ gid ], getMortonID() ) )
{
#ifdef DEBUG_TREE
printf( "\n" );
hmlp_print_binary( setup->morton[ queries[ i ] ] );
hmlp_print_binary( morton );
hmlp_print_binary( setup->morton[ queries[ gid ] ] );
hmlp_print_binary( morton_ );
printf( "\n" );
#endif
return true;
}
}
/* Other return false as not containing any index in queries. */
return false;

}; /** end ContainAny() */
}; /* end containAnyGlobalIndex() */


bool ContainAny( set<Node*> &querys )
bool containAnyNodePointer( set<Node*> &querys )
{
if ( !setup->morton.size() )
{
@@ -798,7 +797,7 @@ class Node : public ReadWrite
}
return false;

}; /** end ContainAny() */
}; /** end ContainAnyNodePointer() */


void Print()
@@ -980,43 +979,7 @@ class Setup
/** Tree splitter */
SPLITTER splitter;

/**
* @brief Check if this node contain any query using morton.
* Notice that queries[] contains gids; thus, morton[]
* needs to be accessed using gids.
*
*/
vector<size_t> ContainAny( vector<size_t> &queries, size_t target )
{
vector<size_t> validation( queries.size(), 0 );

if ( !morton.size() )
{
printf( "Morton id was not initialized.\n" );
exit( 1 );
}

for ( size_t i = 0; i < queries.size(); i ++ )
{
/** notice that setup->morton only contains local morton ids */
//auto it = this->setup->morton.find( queries[ i ] );

//if ( it != this->setup->morton.end() )
//{
// if ( tree::IsMyParent( *it, this->morton ) ) validation[ i ] = 1;
//}


//if ( tree::IsMyParent( morton[ queries[ i ] ], target ) )
if ( MortonHelper::IsMyParent( morton[ queries[ i ] ], target ) )
validation[ i ] = 1;

}
return validation;

}; /** end ContainAny() */

}; /** end class Setup */
}; /* end class Setup */


/** */
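
The tree.hpp changes above rename MortonHelper::ContainAny() to containAny(), split the old overloaded Node::ContainAny() into containAnyGlobalIndex() (gid-based, throwing instead of calling exit(1)) and containAnyNodePointer(), and drop the Setup-level variant. The test all of them share is "is this node a Morton ancestor of the query". Below is a simplified, hypothetical version of that check; the real MortonHelper::IsMyParent() packs the tree level into the low bits of the Morton id, whereas here the depth is passed explicitly and every name suffixed with Sketch is invented for illustration.

#include <cstdint>
#include <stdexcept>
#include <vector>

using mortonType = std::uint64_t;
using indexType  = std::size_t;

/* Hypothetical stand-in for MortonHelper::IsMyParent(): the node contains the
 * point iff the node's Morton bits are a prefix of the point's Morton bits
 * (one bit per level in a binary partition tree; assumes maxDepth <= 63 so the
 * shift is well-defined). */
bool isMyParentSketch( mortonType pointMorton, mortonType nodeMorton,
                       int nodeDepth, int maxDepth )
{
  int shift = maxDepth - nodeDepth;
  return ( pointMorton >> shift ) == ( nodeMorton >> shift );
}

/* Sketch of containAnyGlobalIndex(): true if any query gid falls under this node. */
bool containAnyGlobalIndexSketch( const std::vector<indexType> &queries,
                                  const std::vector<mortonType> &gidToMorton,
                                  mortonType nodeMorton, int nodeDepth, int maxDepth )
{
  if ( gidToMorton.empty() )
    throw std::out_of_range( "MortonID was not initialized" );
  for ( auto gid : queries )
    if ( isMyParentSketch( gidToMorton[ gid ], nodeMorton, nodeDepth, maxDepth ) )
      return true;
  return false;
}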

