Skip to content
This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
Fixing bugs in filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
rmcantin committed Jan 16, 2018
1 parent 93943e3 commit 4b9d322
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
8 changes: 5 additions & 3 deletions src/bayesoptbase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ namespace bayesopt

mModel->addSample(xNext,yNext);

// Update surrogate model
bool retrain = ((mParameters.n_iter_relearn > 0) &&
((mCurrentIter + 1) % mParameters.n_iter_relearn == 0));

if (mUseRobust)
{
mFilter->addSample(xNext,yNext);
Expand All @@ -156,13 +160,11 @@ namespace bayesopt
{
//TODO: Check that the copy is safe.
mModel->copyData(mFilter->filterPoints());
retrain = true;
}

}

// Update surrogate model
bool retrain = ((mParameters.n_iter_relearn > 0) &&
((mCurrentIter + 1) % mParameters.n_iter_relearn == 0));

if (retrain) // Full update
{
Expand Down
11 changes: 10 additions & 1 deletion src/robust_filtering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,15 @@ namespace bayesopt

const Dataset* RobustFiltering::filterPoints()
{
mFilteredData.reset();
mFilteredData.reset(new Dataset());
vecOfvec XX = mRobustModel->getData()->mX;
vectord YY = mRobustModel->getData()->mY;

size_t n_points = mRobustModel->getData()->getNSamples();

mRobustModel->updateHyperParameters();
mRobustModel->fitSurrogateModel();

for(size_t i = 0; i < n_points; ++i)
{
ProbabilityDistribution* pd = mRobustModel->getPrediction(XX[i]);
Expand All @@ -54,6 +57,12 @@ namespace bayesopt
utils::append(mFilteredData->mY, YY[i]);
}
}
if (mFilteredData->getNSamples() <= n_points * 0.5)
{
mFilteredData->mX = XX;
mFilteredData->mY = YY;
}

return mFilteredData.get();
}

Expand Down

0 comments on commit 4b9d322

Please sign in to comment.