Skip to content

Commit

Permalink
improved stopping of self tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
rbouckaert committed Oct 24, 2024
1 parent cd16a69 commit 4e4dbc2
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,18 @@ public void restore() {
* the BEAGLE library instance
*/
private Beagle beagle;

public String getProcessor() {
InstanceDetails details = beagle.getDetails();
for (BeagleFlag flag : BeagleFlag.values()) {
if (flag.isSet(details.getFlags())) {
if (flag.getMeaning().contains("processor")) {
return flag.getMeaning();
}
}
}
return null;
}

/**
* Cached log likelihood for each partition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@




import java.util.ArrayList;
import java.util.List;
import java.util.Random;
Expand All @@ -22,8 +23,8 @@
import beast.base.evolution.branchratemodel.BranchRateModel;
import beast.base.evolution.likelihood.BeagleTreeLikelihood.PartialsRescalingScheme;
import beast.base.evolution.likelihood.GenericTreeLikelihood;
import beast.base.evolution.likelihood.TreeLikelihood;
import beast.base.evolution.sitemodel.SiteModel;
import beast.base.evolution.tree.Tree;
import beast.base.evolution.tree.TreeInterface;
import beast.base.inference.Distribution;
import beast.base.inference.State;
Expand All @@ -50,7 +51,7 @@ public class SelfTuningCompoundDistribution extends Distribution {

final public Input<Long> swithcCountInput = new Input<>("switchCount", "number of milli seconds to calculate likelihood before switching configuration", 500l);
final public Input<Long> reconfigCountInput = new Input<>("reconfigCount", "number of times to calculate likelihood before self tuning again", 100000l);

final public Input<Integer> stopAfterSamerResultsInput = new Input<>("stopAfterSamerResults", "number of times the same configuration is optimal in a row before stopping to tune", 3);

final public Input<Boolean> includeMPTLInput = new Input<>("includeMPTL", "include multi-partition (BEAGLE 3) tree likelihood in configurations", true);
final public Input<Boolean> includeSPTLInput = new Input<>("includeSPTL", "include single-partition (BEAGLE 2) tree likelihood in configurations", true);
Expand Down Expand Up @@ -137,18 +138,24 @@ void reset() {

@Override
public String toString() {
return "Multi Partition";
String hardware = mpTreeLikelihood.getProcessor();
return "Multipartition " + hardware;
}

}

private List<Configuration> configurations;
private Configuration bestConfigurationSoFar;
private Configuration bestConfigurationSoFar, lastBestConfiguration = null;
private Configuration currentConfiguration = null;

// represent current configuration
private long switchTime = 0;
private long switchCount;


private boolean keepTuning;
private boolean initialMeasurement;
private int sameOptimumCount;


/**
Expand All @@ -168,6 +175,10 @@ public void initAndValidate() {
throw new IllegalArgumentException("minThreads must be at least 1");
}

if (stopAfterSamerResultsInput.get() < 1) {
throw new IllegalArgumentException("stopAfterSamerResults must be at least 1");
}

useThreads = useThreadsInput.get() && (ProgramStatus.m_nThreads > 1);
maxNrOfThreads = useThreads ? ProgramStatus.m_nThreads : 1;
if (useThreads && maxNrOfThreadsInput.get() > 0) {
Expand All @@ -194,6 +205,10 @@ public void initAndValidate() {
currentConfiguration = initConfigurations(mpTreeLikelihood);
Log.warning("Starting with " + currentConfiguration.toString());

keepTuning = configurations.size() > 1;
sameOptimumCount = 0;
initialMeasurement = true;

switchTime = System.currentTimeMillis();
}

Expand Down Expand Up @@ -286,6 +301,7 @@ private Configuration initConfigurations(MultiPartitionTreeLikelihood mpTreeLike
rd.getNumber();

configurations = new ArrayList<>();

if (mpTreeLikelihood != null) {
configurations.add(new MultiPartitionConfiguration(mpTreeLikelihood));
}
Expand All @@ -297,16 +313,25 @@ private Configuration initConfigurations(MultiPartitionTreeLikelihood mpTreeLike
}
}


return configurations.get(0);
}


private boolean switchConfiguration() {
if (bestConfigurationSoFar != null) {
if (bestConfigurationSoFar != null || !keepTuning) {
// all configurations tried, and best one found
return false;
}

if (initialMeasurement) {
currentConfiguration.nrOfSamples = 1;
switchTime = System.currentTimeMillis();
initialMeasurement = false;
Log.warning("Start timing " + currentConfiguration.toString());
return true;
}

long endTime = System.currentTimeMillis();
currentConfiguration.totalRunTime += endTime - switchTime;

Expand All @@ -318,14 +343,27 @@ private boolean switchConfiguration() {
bestConfigurationSoFar = cfg0;
for (Configuration cfg : configurations) {
double score = cfg.totalRunTime / cfg.nrOfSamples;
Log.warning(cfg.toString() + ": " + cfg.totalRunTime + "/" + cfg.nrOfSamples + " " + cfg.totalRunTime / cfg.nrOfSamples);
Log.warning(cfg.toString() + ": " + cfg.totalRunTime + "/" + cfg.nrOfSamples + " = " + cfg.totalRunTime / cfg.nrOfSamples);
if (score < best) {
bestConfigurationSoFar = cfg;
best = score;
}
}

currentConfiguration = bestConfigurationSoFar;

if (lastBestConfiguration == bestConfigurationSoFar) {
// stop tuning if consecutive tuning sessions gave the same configuration
sameOptimumCount++;
if (sameOptimumCount == stopAfterSamerResultsInput.get()-1) {
keepTuning = false;
}
} else {
sameOptimumCount = 0;
keepTuning = true;
}
lastBestConfiguration = bestConfigurationSoFar;

} else {
// continue with next configuration
currentConfiguration = configurations.get(i);
Expand All @@ -343,7 +381,11 @@ private boolean switchConfiguration() {
currentConfiguration.nrOfSamples = 1;
// currentConfiguration.totalRunTime = 0;

Log.warning("Switching to " + currentConfiguration.toString());
if (keepTuning) {
Log.warning((i != configurations.size() ? "Start timing ": "Using ") + currentConfiguration.toString());
} else {
Log.warning("Settling for " + currentConfiguration.toString() + " -- tuning finished");
}
switchTime = System.currentTimeMillis();
return true;
}
Expand Down Expand Up @@ -376,6 +418,12 @@ public double calculateLogP() {
}

private void restartTuning() {
if (!keepTuning) {
return;
}

lastBestConfiguration = bestConfigurationSoFar;

bestConfigurationSoFar = null;

currentConfiguration = configurations.get(0);
Expand All @@ -392,7 +440,7 @@ private void restartTuning() {
cfg.nrOfSamples = 0;
cfg.totalRunTime = 0;
}
Log.warning("Switching to " + currentConfiguration.toString());
Log.warning("Start timing " + currentConfiguration.toString());
switchTime = System.currentTimeMillis();
}

Expand Down

0 comments on commit 4e4dbc2

Please sign in to comment.