diff --git a/src/bayesoptbase.cpp b/src/bayesoptbase.cpp index ac9edad..10f4067 100644 --- a/src/bayesoptbase.cpp +++ b/src/bayesoptbase.cpp @@ -144,6 +144,10 @@ namespace bayesopt mModel->addSample(xNext,yNext); + // Update surrogate model + bool retrain = ((mParameters.n_iter_relearn > 0) && + ((mCurrentIter + 1) % mParameters.n_iter_relearn == 0)); + if (mUseRobust) { mFilter->addSample(xNext,yNext); @@ -156,13 +160,11 @@ namespace bayesopt { //TODO: Check that the copy is safe. mModel->copyData(mFilter->filterPoints()); + retrain = true; } } - // Update surrogate model - bool retrain = ((mParameters.n_iter_relearn > 0) && - ((mCurrentIter + 1) % mParameters.n_iter_relearn == 0)); if (retrain) // Full update { diff --git a/src/robust_filtering.cpp b/src/robust_filtering.cpp index ea8595e..b7d646b 100644 --- a/src/robust_filtering.cpp +++ b/src/robust_filtering.cpp @@ -37,12 +37,15 @@ namespace bayesopt const Dataset* RobustFiltering::filterPoints() { - mFilteredData.reset(); + mFilteredData.reset(new Dataset()); vecOfvec XX = mRobustModel->getData()->mX; vectord YY = mRobustModel->getData()->mY; size_t n_points = mRobustModel->getData()->getNSamples(); + mRobustModel->updateHyperParameters(); + mRobustModel->fitSurrogateModel(); + for(size_t i = 0; i < n_points; ++i) { ProbabilityDistribution* pd = mRobustModel->getPrediction(XX[i]); @@ -54,6 +57,12 @@ namespace bayesopt utils::append(mFilteredData->mY, YY[i]); } } + if (mFilteredData->getNSamples() <= n_points * 0.5) + { + mFilteredData->mX = XX; + mFilteredData->mY = YY; + } + return mFilteredData.get(); }