Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Belos: reuse stuff #8855

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions packages/belos/src/BelosCGIter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class CGIter : virtual public CGIteration<ScalarType,MV,OP> {
state.P = P_;
state.AP = AP_;
state.Z = Z_;
state.S = S_;
return state;
}

Expand Down Expand Up @@ -253,7 +254,7 @@ class CGIter : virtual public CGIteration<ScalarType,MV,OP> {
// Internal methods
//
//! Method for initalizing the state storage needed by CG.
void setStateSize();
void setStateSize(CGIterationState<ScalarType,MV>& newstate, bool leftAndRightPrec);

//
// Classes inputed through constructor that define the linear problem to be solved.
Expand Down Expand Up @@ -312,6 +313,8 @@ class CGIter : virtual public CGIteration<ScalarType,MV,OP> {

Teuchos::RCP<MV> S_;

Teuchos::RCP<MV> tmp_;

};

//////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -339,7 +342,7 @@ class CGIter : virtual public CGIteration<ScalarType,MV,OP> {
//////////////////////////////////////////////////////////////////////////////////////////////////
// Setup the state storage.
template <class ScalarType, class MV, class OP>
void CGIter<ScalarType,MV,OP>::setStateSize ()
void CGIter<ScalarType,MV,OP>::setStateSize (CGIterationState<ScalarType,MV>& newstate, bool leftAndRightPrec)
{
if (!stateStorageInitialized_) {

Expand All @@ -351,23 +354,46 @@ class CGIter : virtual public CGIteration<ScalarType,MV,OP> {
return;
}
else {

// Initialize the state storage
// If the subspace has not be initialized before, generate it using the LHS or RHS from lp_.
if (R_ == Teuchos::null) {
// Get the multivector that is not null.
Teuchos::RCP<const MV> tmp = ( (rhsMV!=Teuchos::null)? rhsMV: lhsMV );
TEUCHOS_TEST_FOR_EXCEPTION(tmp == Teuchos::null,std::invalid_argument,
Teuchos::RCP<const MV> tmp = ( (rhsMV!=Teuchos::null)? rhsMV: lhsMV );

if (!newstate.S.is_null() && MVT::GetNumberVecs(*newstate.S) == 2) {
S_ = Teuchos::rcp_const_cast<MV>(newstate.S);
R_ = Teuchos::rcp_const_cast<MV>(newstate.R);
Z_ = Teuchos::rcp_const_cast<MV>(newstate.Z);
} else {
TEUCHOS_TEST_FOR_EXCEPTION(tmp == Teuchos::null,std::invalid_argument,
"Belos::CGIter::setStateSize(): linear problem does not specify multivectors to clone from.");
S_ = MVT::Clone( *tmp, 2 );
std::vector<int> index(1,0);
index[0] = 0;
R_ = MVT::CloneViewNonConst( *S_, index );
index[0] = 1;
Z_ = MVT::CloneViewNonConst( *S_, index );
P_ = MVT::Clone( *tmp, 1 );
AP_ = MVT::Clone( *tmp, 1 );

}
if (!newstate.P.is_null() && MVT::GetNumberVecs(*newstate.P) == 1) {
P_ = Teuchos::rcp_const_cast<MV>(newstate.P);
} else {
TEUCHOS_TEST_FOR_EXCEPTION(tmp == Teuchos::null,std::invalid_argument,
"Belos::CGIter::setStateSize(): linear problem does not specify multivectors to clone from.");
P_ = MVT::Clone( *tmp, 1 );
}
if (!newstate.AP.is_null() && MVT::GetNumberVecs(*newstate.AP) == 1) {
AP_ = Teuchos::rcp_const_cast<MV>(newstate.AP);
} else {
TEUCHOS_TEST_FOR_EXCEPTION(tmp == Teuchos::null,std::invalid_argument,
"Belos::CGIter::setStateSize(): linear problem does not specify multivectors to clone from.");
AP_ = MVT::Clone( *tmp, 1 );
}
if (leftAndRightPrec) {
if (!newstate.tmp.is_null() && MVT::GetNumberVecs(*newstate.tmp) == 1) {
tmp_ = Teuchos::rcp_const_cast<MV>(newstate.tmp);
} else {
TEUCHOS_TEST_FOR_EXCEPTION(tmp == Teuchos::null,std::invalid_argument,
"Belos::CGIter::setStateSize(): linear problem does not specify multivectors to clone from.");
tmp_ = MVT::Clone( *tmp, 1 );
}
}

// Tracking information for condition number estimation
Expand All @@ -389,8 +415,7 @@ class CGIter : virtual public CGIteration<ScalarType,MV,OP> {
void CGIter<ScalarType,MV,OP>::initializeCG(CGIterationState<ScalarType,MV>& newstate)
{
// Initialize the state storage if it isn't already.
if (!stateStorageInitialized_)
setStateSize();
setStateSize(newstate, !lp_->getLeftPrec().is_null() && !lp_->getRightPrec().is_null() );

TEUCHOS_TEST_FOR_EXCEPTION(!stateStorageInitialized_,std::invalid_argument,
"Belos::CGIter::initialize(): Cannot initialize state storage!");
Expand Down Expand Up @@ -418,8 +443,8 @@ class CGIter : virtual public CGIteration<ScalarType,MV,OP> {
if ( lp_->getLeftPrec() != Teuchos::null ) {
lp_->applyLeftPrec( *R_, *Z_ );
if ( lp_->getRightPrec() != Teuchos::null ) {
Teuchos::RCP<MV> tmp = MVT::CloneCopy( *Z_ );
lp_->applyRightPrec( *tmp, *Z_ );
MVT::Assign( *Z_, *tmp_ );
lp_->applyRightPrec( *tmp_, *Z_ );
}
}
else if ( lp_->getRightPrec() != Teuchos::null ) {
Expand Down Expand Up @@ -520,8 +545,8 @@ class CGIter : virtual public CGIteration<ScalarType,MV,OP> {
if ( lp_->getLeftPrec() != Teuchos::null ) {
lp_->applyLeftPrec( *R_, *Z_ );
if ( lp_->getRightPrec() != Teuchos::null ) {
Teuchos::RCP<MV> tmp = MVT::CloneCopy( *Z_);
lp_->applyRightPrec( *tmp, *Z_ );
MVT::Assign( *Z_, *tmp_ );
lp_->applyRightPrec( *tmp_, *Z_ );
}
}
else if ( lp_->getRightPrec() != Teuchos::null ) {
Expand Down
20 changes: 19 additions & 1 deletion packages/belos/src/BelosCGIteration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,28 @@ namespace Belos {

/*! \brief The matrix A applied to current decent direction vector */
Teuchos::RCP<const MV> AP;

/*! \brief The current [residual, preconditioned residual]. */
Teuchos::RCP<const MV> S;

/*! \brief Temporary vector needed when left and right preconditioning is used. */
Teuchos::RCP<const MV> tmp;

CGIterationState() : R(Teuchos::null), Z(Teuchos::null),
P(Teuchos::null), AP(Teuchos::null)
P(Teuchos::null), AP(Teuchos::null), S(Teuchos::null),
tmp(Teuchos::null)
{}

CGIterationState&
operator=(const CGIterationState& rhs) {
R = rhs.R;
Z = rhs.Z;
P = rhs.P;
AP = rhs.AP;
S = rhs.S;
tmp = rhs.tmp;
return *this;
}
};

//! @name CGIteration Exceptions
Expand Down
11 changes: 9 additions & 2 deletions packages/belos/src/BelosLinearProblem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,15 @@ namespace Belos {
solutionUpdated_ = false;
}
else {
curX_ = MVT::CloneViewNonConst( *X_, rhsIndex_ );
curB_ = MVT::CloneView( *B_, rhsIndex_ );
bool trivialSubview = ((Teuchos::as<int>(blocksize_) == MVT::GetNumberVecs( *X_ )) &&
(Teuchos::as<int>(blocksize_) == MVT::GetNumberVecs( *B_ )));
if (trivialSubview) {
curX_ = X_;
curB_ = B_;
} else {
curX_ = MVT::CloneViewNonConst( *X_, rhsIndex_ );
curB_ = MVT::CloneView( *B_, rhsIndex_ );
}
}
//
// Increment the number of linear systems that have been loaded into this object.
Expand Down
20 changes: 15 additions & 5 deletions packages/belos/src/BelosPseudoBlockCGSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ namespace Belos {

// Internal state variables.
bool isSet_;

Teuchos::RCP<CGIterationState<ScalarType,MV> > state_;
};


Expand Down Expand Up @@ -832,13 +834,21 @@ ReturnType PseudoBlockCGSolMgr<ScalarType,MV,OP,true>::solve ()
// Reset the number of calls that the status test output knows about.
outputTest_->resetNumCalls();

// Get a new state struct and initialize the solver.
if (state_.is_null())
state_ = Teuchos::rcp(new CGIterationState<ScalarType,MV>());

// Get the current residual for this block of linear systems.
Teuchos::RCP<MV> R_0 = MVT::CloneViewNonConst( *(Teuchos::rcp_const_cast<MV>(problem_->getInitResVec())), currIdx );
if (state_->R.is_null() || MVT::GetNumberVecs( *state_->R ) != Teuchos::as<int>(currIdx.size()) ) {
state_->R = MVT::CloneViewNonConst( *(Teuchos::rcp_const_cast<MV>(problem_->getInitResVec())), currIdx );
} else {
MVT::SetBlock( *Teuchos::rcp_const_cast<MV>(problem_->getInitResVec()), currIdx, *Teuchos::rcp_const_cast<MV>(state_->R) );
}

// Get a new state struct and initialize the solver.
CGIterationState<ScalarType,MV> newState;
newState.R = R_0;
block_cg_iter->initializeCG(newState);
block_cg_iter->initializeCG(*state_);
*state_ = block_cg_iter->getState();

Teuchos::RCP<MV> R_0;

while(1) {

Expand Down