diff --git a/include/RbtAlignTransform.h b/include/RbtAlignTransform.h index cb87158..05e13a1 100644 --- a/include/RbtAlignTransform.h +++ b/include/RbtAlignTransform.h @@ -27,9 +27,20 @@ class RbtAlignTransform: public RbtBaseBiMolTransform { static RbtString _COM; static RbtString _AXES; + enum LigandCenterOfMassPlacementStrategy { COM_ALIGN, COM_RANDOM }; + enum LigandAxesAlignmentStrategy { AXES_ALIGN, AXES_RANDOM }; + + struct Config { + LigandCenterOfMassPlacementStrategy center_of_mass_placement_strategy{ + LigandCenterOfMassPlacementStrategy::COM_ALIGN}; + LigandAxesAlignmentStrategy axes_alignment_strategy{LigandAxesAlignmentStrategy::AXES_ALIGN}; + }; + + static const Config DEFAULT_CONFIG; + //////////////////////////////////////// // Constructors/destructors - RbtAlignTransform(const RbtString& strName = "ALIGN"); + RbtAlignTransform(const RbtString& strName, const Config& config); virtual ~RbtAlignTransform(); //////////////////////////////////////// @@ -65,6 +76,8 @@ class RbtAlignTransform: public RbtBaseBiMolTransform { RbtCavityList m_cavities; // List of active site cavities to choose from RbtIntList m_cumulSize; // Cumulative sizes, for weighted probabilities RbtInt m_totalSize; // Total size of all cavities + + const Config config; }; // Useful typedefs diff --git a/include/RbtGATransform.h b/include/RbtGATransform.h index 186dde2..8b8897d 100644 --- a/include/RbtGATransform.h +++ b/include/RbtGATransform.h @@ -43,10 +43,24 @@ class RbtGATransform: public RbtBaseBiMolTransform { // Output the best pose every _HISTORY_FREQ cycles. static RbtString _HISTORY_FREQ; + struct Config { + RbtDouble population_size_fraction_as_new_individuals_per_cycle{0.5}; + RbtDouble crossover_probability{0.4}; + RbtBool cauchy_mutation_after_crossover{true}; + RbtBool use_cauchy_distribution_for_mutations{false}; // We might want to make this an enum + RbtDouble relative_step_size{1.0}; + RbtDouble equality_threshold{0.1}; + RbtInt max_cycles{100}; + RbtInt num_convergence_cycles{6}; + RbtInt history_frequency{0}; + }; + + static const Config DEFAULT_CONFIG; + //////////////////////////////////////// // Constructors/destructors //////////////////////////////////////// - RbtGATransform(const RbtString& strName = "GAGENRW"); + RbtGATransform(const RbtString& strName, const Config& config); virtual ~RbtGATransform(); protected: @@ -67,6 +81,7 @@ class RbtGATransform: public RbtBaseBiMolTransform { private: RbtRand& m_rand; + const Config config; }; #endif //_RBTGATRANSFORM_H_ diff --git a/include/RbtParameterFileSource.h b/include/RbtParameterFileSource.h index b1b2f0b..81ded4e 100644 --- a/include/RbtParameterFileSource.h +++ b/include/RbtParameterFileSource.h @@ -42,9 +42,20 @@ class RbtParameterFileSource: public RbtBaseFileSource { RbtDouble GetParameterValue(const RbtString& strParamName); // DM 12 Feb 1999 Get a particular named parameter value as a string RbtString GetParameterValueAsString(const RbtString& strParamName); + + RbtVariant GetParameterValueAsVariant(const RbtString& strParamName); // DM 11 Feb 1999 Check if parameter is present RbtBool isParameterPresent(const RbtString& strParamName); + template + ParamType GetParamOrDefault(const RbtString& paramName, const ParamType& defaultValue) { + if (isParameterPresent(paramName)) { + return GetParameterValueAsVariant(paramName); + } else { + return defaultValue; + } + } + // DM 11 Feb 1999 - section handling // Parameters can be grouped into named sections // such that the same parameter name can appear in multiple sections diff --git a/include/RbtRandLigTransform.h b/include/RbtRandLigTransform.h index 479d80c..dd8a9a7 100644 --- a/include/RbtRandLigTransform.h +++ b/include/RbtRandLigTransform.h @@ -26,9 +26,16 @@ class RbtRandLigTransform: public RbtBaseUniMolTransform { // Parameter names static RbtString _TORS_STEP; + struct Config { + RbtDouble torsion_step{180.0}; + }; + + static const Config DEFAULT_CONFIG; + //////////////////////////////////////// // Constructors/destructors RbtRandLigTransform(const RbtString& strName = "RANDLIG"); + RbtRandLigTransform(const RbtString& strName, const Config& config); virtual ~RbtRandLigTransform(); //////////////////////////////////////// @@ -60,6 +67,8 @@ class RbtRandLigTransform: public RbtBaseUniMolTransform { ////////////// RbtRand& m_rand; // keep a reference to the singleton random number generator RbtBondList m_rotableBonds; + + const Config config; }; // Useful typedefs diff --git a/include/RbtRandPopTransform.h b/include/RbtRandPopTransform.h index 8e90c72..e082dc7 100644 --- a/include/RbtRandPopTransform.h +++ b/include/RbtRandPopTransform.h @@ -23,9 +23,14 @@ class RbtRandPopTransform: public RbtBaseBiMolTransform { static RbtString _POP_SIZE; static RbtString _SCALE_CHROM_LENGTH; - //////////////////////////////////////// - // Constructors/destructors - RbtRandPopTransform(const RbtString& strName = "RANDPOP"); + struct Config { + RbtInt population_size{50}; + RbtBool scale_chromosome_length{true}; + }; + + static const Config DEFAULT_CONFIG; + + RbtRandPopTransform(const RbtString& strName, const Config& config); virtual ~RbtRandPopTransform(); //////////////////////////////////////// @@ -59,6 +64,8 @@ class RbtRandPopTransform: public RbtBaseBiMolTransform { // Private data ////////////// RbtChromElementPtr m_chrom; + + const Config config; }; // Useful typedefs diff --git a/include/RbtSimAnnTransform.h b/include/RbtSimAnnTransform.h index 1d863ce..ba98d76 100644 --- a/include/RbtSimAnnTransform.h +++ b/include/RbtSimAnnTransform.h @@ -59,9 +59,25 @@ class RbtSimAnnTransform: public RbtBaseBiMolTransform { static RbtString _PARTITION_FREQ; static RbtString _HISTORY_FREQ; + struct Config { + RbtDouble initial_temp{1000.0}; + RbtDouble final_temp{300.0}; + RbtInt num_blocks{25}; + RbtInt block_length{50}; + RbtBool scale_chromosome_length{true}; + RbtDouble step_size{1.0}; + RbtDouble min_accuracy_rate{0.25}; + RbtDouble partition_distance{0.0}; + RbtInt partition_frequency{0}; + RbtInt history_frequency{0}; + }; + + static const Config DEFAULT_CONFIG; + //////////////////////////////////////// // Constructors/destructors - RbtSimAnnTransform(const RbtString& strName = "SIMANN"); + + RbtSimAnnTransform(const RbtString& strName, const Config& config); virtual ~RbtSimAnnTransform(); //////////////////////////////////////// @@ -100,6 +116,8 @@ class RbtSimAnnTransform: public RbtBaseBiMolTransform { RbtChromElementPtr m_chrom; // Current chromosome RbtDoubleList m_minVector; // Chromosome vector corresponding to overall minimum score RbtDoubleList m_lastGoodVector; // Saved chromosome before each MC mutation (to allow revert) + + const Config config; }; // Useful typedefs diff --git a/include/RbtSimplexTransform.h b/include/RbtSimplexTransform.h index b37a8bf..19f5fd2 100644 --- a/include/RbtSimplexTransform.h +++ b/include/RbtSimplexTransform.h @@ -31,9 +31,18 @@ class RbtSimplexTransform: public RbtBaseBiMolTransform { // between cycles static RbtString _CONVERGENCE; - //////////////////////////////////////// - // Constructors/destructors - RbtSimplexTransform(const RbtString& strName = "SIMPLEX"); + struct Config { + RbtInt max_calls{200}; + RbtInt num_cycles{5}; + RbtDouble stopping_step_length{10e-4}; + RbtDouble convergence_threshold{0.001}; + RbtDouble step_size{0.1}; + RbtDouble partition_distribution{0.0}; + }; + + static const Config DEFAULT_CONFIG; + + RbtSimplexTransform(const RbtString& strName, const Config& config); virtual ~RbtSimplexTransform(); //////////////////////////////////////// @@ -66,6 +75,8 @@ class RbtSimplexTransform: public RbtBaseBiMolTransform { // Private data ////////////// RbtChromElementPtr m_chrom; + + const Config config; }; // Useful typedefs diff --git a/include/RbtTransformFactory.h b/include/RbtTransformFactory.h index ea105ff..f19ebd8 100644 --- a/include/RbtTransformFactory.h +++ b/include/RbtTransformFactory.h @@ -37,10 +37,6 @@ class RbtTransformFactory { // Public methods //////////////// - // Creates a single transform object of type strTransformClass, and name strName - // e.g. strTransformClass = RbtSimAnnTransform - virtual RbtBaseTransform* Create(const RbtString& strTransformClass, const RbtString& strName); - // Creates an aggregate transform from a parameter file source // Each component transform is in a named section, which should minimally contain a TRANSFORM parameter // whose value is the class name to instantiate @@ -64,9 +60,9 @@ class RbtTransformFactory { // Private methods ///////////////// - RbtTransformFactory(const RbtTransformFactory&); // Copy constructor disabled by default - + RbtTransformFactory(const RbtTransformFactory&); // Copy constructor disabled by default RbtTransformFactory& operator=(const RbtTransformFactory&); // Copy assignment disabled by default + RbtBaseTransform* MakeTransformFromFile(RbtParameterFileSourcePtr paramsPtr, const RbtString& name); protected: //////////////////////////////////////// diff --git a/src/lib/RbtAlignTransform.cxx b/src/lib/RbtAlignTransform.cxx index 4f080da..ffb887e 100644 --- a/src/lib/RbtAlignTransform.cxx +++ b/src/lib/RbtAlignTransform.cxx @@ -21,15 +21,14 @@ RbtString RbtAlignTransform::_CT("RbtAlignTransform"); RbtString RbtAlignTransform::_COM("COM"); RbtString RbtAlignTransform::_AXES("AXES"); -//////////////////////////////////////// -// Constructors/destructors -RbtAlignTransform::RbtAlignTransform(const RbtString& strName): +const RbtAlignTransform::Config + RbtAlignTransform::DEFAULT_CONFIG{}; // Empty initializer to fall back to default values + +RbtAlignTransform::RbtAlignTransform(const RbtString& strName, const Config& config): RbtBaseBiMolTransform(_CT, strName), m_rand(Rbt::GetRbtRand()), - m_totalSize(0) { - // Add parameters - AddParameter(_COM, "ALIGN"); - AddParameter(_AXES, "ALIGN"); + m_totalSize(0), + config{config} { #ifdef _DEBUG cout << _CT << " parameterised constructor" << endl; #endif //_DEBUG @@ -89,64 +88,57 @@ void RbtAlignTransform::Execute() { const RbtCoordList& coordList = spCavity->GetCoordList(); const RbtPrincipalAxes& prAxes = spCavity->GetPrincipalAxes(); - // Get the alignment parameters - RbtString strPlaceCOM = GetParameter(_COM); - RbtString strPlaceAxes = GetParameter(_AXES); - - // 1. Ligand translation - // A. Random - if ((strPlaceCOM == "RANDOM") && !coordList.empty()) { - // Select a coord at random - RbtInt iRand = m_rand.GetRandomInt(coordList.size()); - RbtCoord asCavityCoord = coordList[iRand]; - if (iTrace > 1) { - cout << "Translating ligand COM to active site coord #" << iRand << ": " << asCavityCoord << endl; - } - // Translate the ligand center of mass to the selected coord - spLigand->SetCenterOfMass(asCavityCoord); - } - // B. Active site center of mass - else if (strPlaceCOM == "ALIGN") { - if (iTrace > 1) { - cout << "Translating ligand COM to active site COM: " << prAxes.com << endl; - } - spLigand->SetCenterOfMass(prAxes.com); + switch (config.center_of_mass_placement_strategy) { + case COM_RANDOM: + if (!coordList.empty()) { + RbtInt iRand = m_rand.GetRandomInt(coordList.size()); + RbtCoord asCavityCoord = coordList[iRand]; + if (iTrace > 1) { + cout << "Translating ligand COM to active site coord #" << iRand << ": " << asCavityCoord << endl; + } + // Translate the ligand center of mass to the selected coord + spLigand->SetCenterOfMass(asCavityCoord); + } + break; + case COM_ALIGN: + if (iTrace > 1) cout << "Translating ligand COM to active site COM: " << prAxes.com << endl; + spLigand->SetCenterOfMass(prAxes.com); + break; + default: + throw RbtBadArgument(_WHERE_, "Bad ligand center of mass placement strategy"); + break; } - // 2. Ligand axes - // A. Random rotation around random axis - if (strPlaceAxes == "RANDOM") { - RbtDouble thetaDeg = 180.0 * m_rand.GetRandom01(); - RbtCoord axis = m_rand.GetRandomUnitVector(); - if (iTrace > 1) { - cout << "Rotating ligand by " << thetaDeg << " deg around axis=" << axis << " through COM" << endl; - } - spLigand->Rotate(axis, thetaDeg); - } - // B. Align ligand principal axes with principal axes of active site - else if (strPlaceAxes == "ALIGN") { - spLigand->AlignPrincipalAxes(prAxes, false); // false = don't translate COM as we've already done it above - if (iTrace > 1) { - cout << "Aligning ligand principal axes with active site principal axes" << endl; - } - // Make random 180 deg rotations around each of the principal axes - if (m_rand.GetRandom01() < 0.5) { - spLigand->Rotate(prAxes.axis1, 180.0, prAxes.com); - if (iTrace > 1) { - cout << "180 deg rotation around PA#1" << endl; + switch (config.axes_alignment_strategy) { + case AXES_RANDOM: + { // Braces to prevent compiler complains due to variables initialized inside this case + + RbtDouble thetaDeg = 180.0 * m_rand.GetRandom01(); + RbtCoord axis = m_rand.GetRandomUnitVector(); + if (iTrace > 1) + cout << "Rotating ligand by " << thetaDeg << " deg around axis=" << axis << " through COM" << endl; + spLigand->Rotate(axis, thetaDeg); + break; + } + case AXES_ALIGN: + spLigand->AlignPrincipalAxes(prAxes, false); // false = don't translate COM as we've already done it above + if (iTrace > 1) cout << "Aligning ligand principal axes with active site principal axes" << endl; + // Make random 180 deg rotations around each of the principal axes + if (m_rand.GetRandom01() < 0.5) { + spLigand->Rotate(prAxes.axis1, 180.0, prAxes.com); + if (iTrace > 1) cout << "180 deg rotation around PA#1" << endl; } - } - if (m_rand.GetRandom01() < 0.5) { - spLigand->Rotate(prAxes.axis2, 180.0, prAxes.com); - if (iTrace > 1) { - cout << "180 deg rotation around PA#2" << endl; + if (m_rand.GetRandom01() < 0.5) { + spLigand->Rotate(prAxes.axis2, 180.0, prAxes.com); + if (iTrace > 1) cout << "180 deg rotation around PA#2" << endl; } - } - if (m_rand.GetRandom01() < 0.5) { - spLigand->Rotate(prAxes.axis3, 180.0, prAxes.com); - if (iTrace > 1) { - cout << "180 deg rotation around PA#3" << endl; + if (m_rand.GetRandom01() < 0.5) { + spLigand->Rotate(prAxes.axis3, 180.0, prAxes.com); + if (iTrace > 1) cout << "180 deg rotation around PA#3" << endl; } - } + break; + default: + throw RbtBadArgument(_WHERE_, "Bad ligand axes alignment strategy"); + break; } } diff --git a/src/lib/RbtGATransform.cxx b/src/lib/RbtGATransform.cxx index 56131cc..39e2289 100644 --- a/src/lib/RbtGATransform.cxx +++ b/src/lib/RbtGATransform.cxx @@ -30,18 +30,12 @@ RbtString RbtGATransform::_NCYCLES("NCYCLES"); RbtString RbtGATransform::_NCONVERGENCE("NCONVERGENCE"); RbtString RbtGATransform::_HISTORY_FREQ("HISTORY_FREQ"); -RbtGATransform::RbtGATransform(const RbtString& strName): +const RbtGATransform::Config RbtGATransform::DEFAULT_CONFIG{}; // Empty initializer to fall back to default values + +RbtGATransform::RbtGATransform(const RbtString& strName, const Config& config): RbtBaseBiMolTransform(_CT, strName), - m_rand(Rbt::GetRbtRand()) { - AddParameter(_NEW_FRACTION, 0.5); - AddParameter(_PCROSSOVER, 0.4); - AddParameter(_XOVERMUT, true); - AddParameter(_CMUTATE, false); - AddParameter(_STEP_SIZE, 1.0); - AddParameter(_EQUALITY_THRESHOLD, 0.1); - AddParameter(_NCYCLES, 100); - AddParameter(_NCONVERGENCE, 6); - AddParameter(_HISTORY_FREQ, 0); + m_rand(Rbt::GetRbtRand()), + config{config} { _RBTOBJECTCOUNTER_CONSTR_(_CT); } @@ -73,19 +67,9 @@ void RbtGATransform::Execute() { // the scoring function has changed pop->SetSF(pSF); - RbtDouble newFraction = GetParameter(_NEW_FRACTION); - RbtDouble pcross = GetParameter(_PCROSSOVER); - RbtBool xovermut = GetParameter(_XOVERMUT); - RbtBool cmutate = GetParameter(_CMUTATE); - RbtDouble relStepSize = GetParameter(_STEP_SIZE); - RbtDouble equalityThreshold = GetParameter(_EQUALITY_THRESHOLD); - RbtInt nCycles = GetParameter(_NCYCLES); - RbtInt nConvergence = GetParameter(_NCONVERGENCE); - RbtInt nHisFreq = GetParameter(_HISTORY_FREQ); - RbtInt popsize = pop->GetMaxSize(); - RbtInt nrepl = newFraction * popsize; - RbtBool bHistory = nHisFreq > 0; + RbtInt nrepl = config.population_size_fraction_as_new_individuals_per_cycle * popsize; + RbtBool bHistory = config.history_frequency > 0; RbtInt iTrace = GetTrace(); RbtDouble bestScore = pop->Best()->GetScore(); @@ -105,12 +89,19 @@ void RbtGATransform::Execute() { << setw(10) << pop->GetScoreVariance() << endl; } - for (RbtInt iCycle = 0; (iCycle < nCycles) && (iConvergence < nConvergence); ++iCycle) { - if (bHistory && ((iCycle % nHisFreq) == 0)) { + for (RbtInt iCycle = 0; (iCycle < config.max_cycles) && (iConvergence < config.num_convergence_cycles); ++iCycle) { + if (bHistory && ((iCycle % config.history_frequency) == 0)) { pop->Best()->GetChrom()->SyncToModel(); pWorkSpace->SaveHistory(true); } - pop->GAstep(nrepl, relStepSize, equalityThreshold, pcross, xovermut, cmutate); + pop->GAstep( + nrepl, + config.relative_step_size, + config.equality_threshold, + config.crossover_probability, + config.cauchy_mutation_after_crossover, + config.use_cauchy_distribution_for_mutations + ); RbtDouble score = pop->Best()->GetScore(); if (score > bestScore) { bestScore = score; diff --git a/src/lib/RbtParameterFileSource.cxx b/src/lib/RbtParameterFileSource.cxx index 2a1ec8b..422eb0a 100644 --- a/src/lib/RbtParameterFileSource.cxx +++ b/src/lib/RbtParameterFileSource.cxx @@ -85,6 +85,16 @@ RbtString RbtParameterFileSource::GetParameterValueAsString(const RbtString& str throw RbtFileMissingParameter(_WHERE_, strFullParamName + " parameter not found in " + GetFileName()); } +RbtVariant RbtParameterFileSource::GetParameterValueAsVariant(const RbtString& strParamName) { + Parse(); + RbtString strFullParamName = GetFullParameterName(strParamName); + RbtStringVariantMapConstIter iter = m_paramsMap.find(strFullParamName); + if (iter != m_paramsMap.end()) + return (*iter).second; + else + throw RbtFileMissingParameter(_WHERE_, strFullParamName + " parameter not found in " + GetFileName()); +} + // DM 11 Feb 1999 Check if parameter is present RbtBool RbtParameterFileSource::isParameterPresent(const RbtString& strParamName) { Parse(); diff --git a/src/lib/RbtRandLigTransform.cxx b/src/lib/RbtRandLigTransform.cxx index d7c4e0f..4ce1c74 100644 --- a/src/lib/RbtRandLigTransform.cxx +++ b/src/lib/RbtRandLigTransform.cxx @@ -17,13 +17,13 @@ RbtString RbtRandLigTransform::_CT("RbtRandLigTransform"); // Parameter names RbtString RbtRandLigTransform::_TORS_STEP("TORS_STEP"); -//////////////////////////////////////// -// Constructors/destructors -RbtRandLigTransform::RbtRandLigTransform(const RbtString& strName): +const RbtRandLigTransform::Config + RbtRandLigTransform::DEFAULT_CONFIG{}; // Empty initializer to fall back to default values + +RbtRandLigTransform::RbtRandLigTransform(const RbtString& strName, const Config& config): RbtBaseUniMolTransform(_CT, strName), - m_rand(Rbt::GetRbtRand()) { - // Add parameters - AddParameter(_TORS_STEP, 180); + m_rand(Rbt::GetRbtRand()), + config{config} { #ifdef _DEBUG cout << _CT << " parameterised constructor" << endl; #endif //_DEBUG @@ -56,9 +56,8 @@ void RbtRandLigTransform::SetupTransform() { void RbtRandLigTransform::Execute() { RbtModelPtr spLigand = GetLigand(); if (spLigand.Null()) return; - RbtDouble torsStep = GetParameter(_TORS_STEP); - for (RbtBondListIter iter = m_rotableBonds.begin(); iter != m_rotableBonds.end(); iter++) { - RbtDouble thetaDeg = 2.0 * torsStep * m_rand.GetRandom01() - torsStep; - spLigand->RotateBond(*iter, thetaDeg, false); + for (auto& rotable_bond: m_rotableBonds) { + RbtDouble thetaDeg = 2.0 * config.torsion_step * m_rand.GetRandom01() - config.torsion_step; + spLigand->RotateBond(rotable_bond, thetaDeg, false); } } diff --git a/src/lib/RbtRandPopTransform.cxx b/src/lib/RbtRandPopTransform.cxx index 9c473ec..a0532a2 100644 --- a/src/lib/RbtRandPopTransform.cxx +++ b/src/lib/RbtRandPopTransform.cxx @@ -20,9 +20,12 @@ RbtString RbtRandPopTransform::_CT("RbtRandPopTransform"); RbtString RbtRandPopTransform::_POP_SIZE("POP_SIZE"); RbtString RbtRandPopTransform::_SCALE_CHROM_LENGTH("SCALE_CHROM_LENGTH"); -RbtRandPopTransform::RbtRandPopTransform(const RbtString& strName): RbtBaseBiMolTransform(_CT, strName) { - AddParameter(_POP_SIZE, 50); - AddParameter(_SCALE_CHROM_LENGTH, true); +const RbtRandPopTransform::Config + RbtRandPopTransform::DEFAULT_CONFIG{}; // Empty initializer to fall back to default values + +RbtRandPopTransform::RbtRandPopTransform(const RbtString& strName, const Config& config): + RbtBaseBiMolTransform(_CT, strName), + config{config} { _RBTOBJECTCOUNTER_CONSTR_(_CT); } @@ -55,16 +58,13 @@ void RbtRandPopTransform::Execute() { if (pSF == NULL) { return; } - RbtInt popSize = GetParameter(_POP_SIZE); - RbtBool bScale = GetParameter(_SCALE_CHROM_LENGTH); - if (bScale) { + RbtInt population_size = config.population_size; + if (config.scale_chromosome_length) { RbtInt chromLength = m_chrom->GetLength(); - popSize *= chromLength; - } - if (GetTrace() > 3) { - cout << _CT << ": popSize=" << popSize << endl; + population_size *= chromLength; } - RbtPopulationPtr pop = new RbtPopulation(m_chrom, popSize, pSF); + if (GetTrace() > 3) cout << _CT << ": popSize=" << population_size << endl; + RbtPopulationPtr pop = new RbtPopulation(m_chrom, population_size, pSF); pop->Best()->GetChrom()->SyncToModel(); GetWorkSpace()->SetPopulation(pop); } diff --git a/src/lib/RbtSimAnnTransform.cxx b/src/lib/RbtSimAnnTransform.cxx index 8611184..0cf32bd 100644 --- a/src/lib/RbtSimAnnTransform.cxx +++ b/src/lib/RbtSimAnnTransform.cxx @@ -66,22 +66,14 @@ RbtString RbtSimAnnTransform::_PARTITION_DIST("PARTITION_DIST"); RbtString RbtSimAnnTransform::_PARTITION_FREQ("PARTITION_FREQ"); RbtString RbtSimAnnTransform::_HISTORY_FREQ("HISTORY_FREQ"); -//////////////////////////////////////// -// Constructors/destructors -RbtSimAnnTransform::RbtSimAnnTransform(const RbtString& strName): +const RbtSimAnnTransform::Config + RbtSimAnnTransform::DEFAULT_CONFIG{}; // Empty initializer to fall back to default values + +RbtSimAnnTransform::RbtSimAnnTransform(const RbtString& strName, const Config& config): RbtBaseBiMolTransform(_CT, strName), - m_rand(Rbt::GetRbtRand()) { - // Add parameters - AddParameter(_START_T, 1000.0); - AddParameter(_FINAL_T, 300.0); - AddParameter(_BLOCK_LENGTH, 50); - AddParameter(_SCALE_CHROM_LENGTH, true); - AddParameter(_NUM_BLOCKS, 25); - AddParameter(_STEP_SIZE, 1.0); - AddParameter(_MIN_ACC_RATE, 0.25); - AddParameter(_PARTITION_DIST, 0.0); - AddParameter(_PARTITION_FREQ, 0); - AddParameter(_HISTORY_FREQ, 0); + m_rand(Rbt::GetRbtRand()), + config{config} // For now simply copy the config structure. +{ m_spStats = RbtMCStatsPtr(new RbtMCStats()); #ifdef _DEBUG cout << _CT << " parameterised constructor" << endl; @@ -133,34 +125,17 @@ void RbtSimAnnTransform::Execute() { RbtInt iTrace = GetTrace(); pWorkSpace->ClearPopulation(); - // Get the cooling schedule params - RbtDouble t = GetParameter(_START_T); - RbtDouble tFinal = GetParameter(_FINAL_T); - RbtInt nBlocks = GetParameter(_NUM_BLOCKS); - RbtInt blockLen = GetParameter(_BLOCK_LENGTH); - RbtBool bScale = GetParameter(_SCALE_CHROM_LENGTH); - RbtDouble stepSize = GetParameter(_STEP_SIZE); - RbtDouble minAccRate = GetParameter(_MIN_ACC_RATE); - if (bScale) { + RbtInt block_length = config.block_length; + if (config.scale_chromosome_length) { RbtInt chromLength = m_chrom->GetLength(); - blockLen *= chromLength; - } - if (iTrace > 0) { - cout << _CT << ": Block length = " << blockLen << endl; + block_length *= chromLength; } - - // Cooling factor (check for nBlocks=1) - RbtDouble tFac = (nBlocks > 1) ? pow(tFinal / t, 1.0 / (nBlocks - 1)) : 1.0; - - // DM 15 Feb 1999 - don't initialise the Monte Carlo stats each block - // if we are doing a constant temperature run - RbtBool bInitBlock = (t != tFinal); + if (iTrace > 0) cout << _CT << ": Block length = " << block_length << endl; // Send the partitioning request separately based on the current partition distance // If partDist is zero, the partitioning automatically gets removed - RbtDouble partDist = GetParameter(_PARTITION_DIST); - m_spPartReq = RbtRequestPtr(new RbtSFPartitionRequest(partDist)); + m_spPartReq = RbtRequestPtr(new RbtSFPartitionRequest(config.partition_distance)); pSF->HandleRequest(m_spPartReq); // Update the chromosome to match the current model coords @@ -184,22 +159,31 @@ void RbtSimAnnTransform::Execute() { << setw(10) << "MAX" << endl; } - for (RbtInt iBlock = 1; iBlock <= nBlocks; iBlock++, t *= tFac) { + RbtDouble initial_temp = config.initial_temp; + RbtDouble step_size = config.step_size; + + // DM 15 Feb 1999 - don't initialise the Monte Carlo stats each block + // if we are doing a constant temperature run + RbtBool bInitBlock = (config.initial_temp != config.final_temp); + RbtDouble tFac = + (config.num_blocks > 1) ? pow(config.final_temp / config.initial_temp, 1.0 / (config.num_blocks - 1)) : 1.0; + + for (RbtInt iBlock = 1; iBlock <= config.num_blocks; iBlock++, initial_temp *= tFac) { if (bInitBlock) { m_spStats->InitBlock(pSF->Score()); } - MC(t, blockLen, stepSize); + MC(initial_temp, block_length, step_size); if (iTrace > 0) { - cout << setw(5) << iBlock << setw(10) << t << setw(10) << m_spStats->AccRate() << setw(10) << stepSize - << setw(10) << m_spStats->_blockInitial << setw(10) << m_spStats->_blockFinal << setw(10) + cout << setw(5) << iBlock << setw(10) << initial_temp << setw(10) << m_spStats->AccRate() << setw(10) + << step_size << setw(10) << m_spStats->_blockInitial << setw(10) << m_spStats->_blockFinal << setw(10) << m_spStats->Mean() << setw(10) << sqrt(m_spStats->Variance()) << setw(10) << m_spStats->_blockMin << setw(10) << m_spStats->_blockMax << endl; } // Halve the maximum step sizes for all enabled modes // if the acceptance rate is less than the threshold - if (m_spStats->AccRate() < minAccRate) { - stepSize *= 0.5; + if (m_spStats->AccRate() < config.min_accuracy_rate) { + step_size *= 0.5; // Reinitialise the stats (only need to do it here if bInitBlock is false // otherwise it will be done at the beginning of the next block) if (!bInitBlock) { @@ -222,11 +206,6 @@ void RbtSimAnnTransform::MC(RbtDouble t, RbtInt blockLen, RbtDouble stepSize) { RbtInt iTrace = GetTrace(); RbtDouble score = pSF->Score(); - RbtInt nHisFreq = GetParameter(_HISTORY_FREQ); - RbtBool bHistory = nHisFreq > 0; - RbtInt nPartFreq = GetParameter(_PARTITION_FREQ); - RbtBool bPartition = (nPartFreq > 0); - // Keep a record of the last good chromosome vector, for fast revert following a failed Metropolic test m_lastGoodVector.clear(); m_chrom->GetVector(m_lastGoodVector); @@ -258,13 +237,13 @@ void RbtSimAnnTransform::MC(RbtDouble t, RbtInt blockLen, RbtDouble stepSize) { // Gather the statistics m_spStats->Accumulate(score, bMetrop); // Render to the history file if appropriate (true = with component scores) - if (bHistory && (iStep % nHisFreq) == 0) { + if ((config.history_frequency > 0) && (iStep % config.history_frequency) == 0) { GetWorkSpace()->SaveHistory(true); } // Update the interaction lists if appropriate // We update the lists every nth accepted trial (as rejected trials don't change the coords) - if (bPartition && ((m_spStats->_accepted % nPartFreq) == 0)) { + if ((config.partition_frequency > 0) && ((m_spStats->_accepted % config.partition_frequency) == 0)) { pSF->HandleRequest(m_spPartReq); RbtDouble oldScore = score; score = pSF->Score(); diff --git a/src/lib/RbtSimplexTransform.cxx b/src/lib/RbtSimplexTransform.cxx index 8d73893..085d9ff 100644 --- a/src/lib/RbtSimplexTransform.cxx +++ b/src/lib/RbtSimplexTransform.cxx @@ -30,15 +30,12 @@ RbtString RbtSimplexTransform::_PARTITION_DIST("PARTITION_DIST"); RbtString RbtSimplexTransform::_STEP_SIZE("STEP_SIZE"); RbtString RbtSimplexTransform::_CONVERGENCE("CONVERGENCE"); -//////////////////////////////////////// -// Constructors/destructors -RbtSimplexTransform::RbtSimplexTransform(const RbtString& strName): RbtBaseBiMolTransform(_CT, strName) { - AddParameter(_MAX_CALLS, 200); - AddParameter(_NCYCLES, 5); - AddParameter(_STOPPING_STEP_LENGTH, 10e-4); - AddParameter(_PARTITION_DIST, 0.0); - AddParameter(_STEP_SIZE, 0.1); - AddParameter(_CONVERGENCE, 0.001); +const RbtSimplexTransform::Config + RbtSimplexTransform::DEFAULT_CONFIG{}; // Empty initializer to fall back to default values + +RbtSimplexTransform::RbtSimplexTransform(const RbtString& strName, const Config& config): + RbtBaseBiMolTransform(_CT, strName), + config{config} { #ifdef _DEBUG cout << _CT << " parameterised constructor" << endl; #endif //_DEBUG @@ -86,13 +83,7 @@ void RbtSimplexTransform::Execute() { RbtInt iTrace = GetTrace(); pWorkSpace->ClearPopulation(); - RbtInt maxcalls = GetParameter(_MAX_CALLS); - RbtInt ncycles = GetParameter(_NCYCLES); - RbtDouble stopping = GetParameter(_STOPPING_STEP_LENGTH); - RbtDouble convergence = GetParameter(_CONVERGENCE); - RbtDouble stepSize = GetParameter(_STEP_SIZE); - RbtDouble partDist = GetParameter(_PARTITION_DIST); - RbtRequestPtr spPartReq(new RbtSFPartitionRequest(partDist)); + RbtRequestPtr spPartReq(new RbtSFPartitionRequest(config.partition_distribution)); RbtRequestPtr spClearPartReq(new RbtSFPartitionRequest(0.0)); pSF->HandleRequest(spPartReq); @@ -104,20 +95,19 @@ void RbtSimplexTransform::Execute() { RbtInt nsv = sv.size(); RbtDouble* steps = new RbtDouble[nsv]; for (RbtInt i = 0; i < nsv; ++i) { - steps[i] = sv[i] * stepSize; + steps[i] = sv[i] * config.step_size; } // Set up the Simplex search object NMSearch* ssearch; - NMSearch::SetMaxCalls(maxcalls); - NMSearch::SetStoppingLength(stopping); + NMSearch::SetMaxCalls(config.max_calls); + NMSearch::SetStoppingLength(config.stopping_step_length); RbtInt calls = 0; RbtDouble initScore = pSF->Score(); // Current score RbtDouble min = initScore; RbtDoubleList vc; // Vector representation of chromosome - RbtInt N = ncycles; // Energy change between cycles - initialise so as not to terminate loop immediately - RbtDouble delta = -convergence - 1.0; + RbtDouble delta = -config.convergence_threshold - 1.0; if (iTrace > 0) { cout.precision(3); @@ -136,8 +126,8 @@ void RbtSimplexTransform::Execute() { } } - for (RbtInt i = 0; (i < N) && (delta < -convergence); i++) { - if (partDist > 0.0) { + for (RbtInt i = 0; (i < config.num_cycles) && (delta < -config.convergence_threshold); i++) { + if (config.partition_distribution > 0.0) { pSF->HandleRequest(spPartReq); } // Use a variable length simplex diff --git a/src/lib/RbtTransformFactory.cxx b/src/lib/RbtTransformFactory.cxx index 28c85fb..06b0f17 100644 --- a/src/lib/RbtTransformFactory.cxx +++ b/src/lib/RbtTransformFactory.cxx @@ -22,6 +22,33 @@ #include "RbtSimAnnTransform.h" #include "RbtSimplexTransform.h" +static RbtSimAnnTransform* MakeSimmulatedAnnealingTransformFromFile( + RbtParameterFileSourcePtr paramsPtr, const RbtString& name +); +static RbtGATransform* MakeGeneticAlgorithmTransformFromFile( + RbtParameterFileSourcePtr paramsPtr, const RbtString& name +); +static RbtAlignTransform* MakeLigandAlignTransformFromFile(RbtParameterFileSourcePtr paramsPtr, const RbtString& name); +static RbtNullTransform* MakeNullTransformFromFile(RbtParameterFileSourcePtr paramsPtr, const RbtString& name); +static RbtRandLigTransform* MakeRandomizeLigandTransformFromFile( + RbtParameterFileSourcePtr paramsPtr, const RbtString& name +); +static RbtRandPopTransform* MakeRandomizePopulationTransformFromFile( + RbtParameterFileSourcePtr paramsPtr, const RbtString& name +); +static RbtSimplexTransform* MakeSimplexTransformFromFile(RbtParameterFileSourcePtr paramsPtr, const RbtString& name); +static RbtTransformAgg* MakeAggregateTransformFromFile(RbtParameterFileSourcePtr paramsPtr, const RbtString& name); +static void RegisterScoreFunctionOverridesInTransform( + RbtBaseTransform* transform, RbtParameterFileSourcePtr paramsPtr +); + +static RbtAlignTransform::LigandCenterOfMassPlacementStrategy GetLigandCenterOfMassPlacementStrategyFromFile( + RbtParameterFileSourcePtr paramsPtr +); +static RbtAlignTransform::LigandAxesAlignmentStrategy GetLigandAxesAlignmentStrategyFromFile( + RbtParameterFileSourcePtr paramsPtr +); + // Parameter name which identifies a scoring function definition RbtString RbtTransformFactory::_TRANSFORM("TRANSFORM"); @@ -34,22 +61,6 @@ RbtTransformFactory::~RbtTransformFactory() {} //////////////////////////////////////// // Public methods //////////////// -// Creates a single transform object of type strTransformClass, and name strName -// e.g. strTransformClass = RbtSimAnnTransform -RbtBaseTransform* RbtTransformFactory::Create(const RbtString& strTransformClass, const RbtString& strName) { - // Component transforms - if (strTransformClass == RbtSimAnnTransform::_CT) return new RbtSimAnnTransform(strName); - if (strTransformClass == RbtGATransform::_CT) return new RbtGATransform(strName); - if (strTransformClass == RbtAlignTransform::_CT) return new RbtAlignTransform(strName); - if (strTransformClass == RbtNullTransform::_CT) return new RbtNullTransform(strName); - if (strTransformClass == RbtRandLigTransform::_CT) return new RbtRandLigTransform(strName); - if (strTransformClass == RbtRandPopTransform::_CT) return new RbtRandPopTransform(strName); - if (strTransformClass == RbtSimplexTransform::_CT) return new RbtSimplexTransform(strName); - // Aggregate transforms - if (strTransformClass == RbtTransformAgg::_CT) return new RbtTransformAgg(strName); - - throw RbtBadArgument(_WHERE_, "Unknown transform " + strTransformClass); -} // Creates an aggregate transform from a parameter file source // Each component transform is in a named section, which should minimally contain a TRANSFORM parameter @@ -62,10 +73,9 @@ RbtTransformAgg* RbtTransformFactory::CreateAggFromFile( ) { // Get list of transform objects to create RbtStringList transformList = Rbt::ConvertDelimitedStringToList(strTransformClasses); - // If strTransformClasses is empty, then default to reading all sections of the - // parameter file for valid transform definitions - // In this case we do not throw an error if a particular section - // is not a transform, we simply skip it + // If strTransformClasses is empty, then default to reading all sections of the parameter file for valid transform + // definitions In this case we do not throw an error if a particular section is not a transform, we simply skip it. + // This is not used anywhere but not sure if any user of the API needs it. RbtBool bThrowError(true); if (transformList.empty()) { transformList = spPrmSource->GetSectionList(); @@ -75,32 +85,205 @@ RbtTransformAgg* RbtTransformFactory::CreateAggFromFile( // Create empty aggregate RbtTransformAgg* pTransformAgg(new RbtTransformAgg(strName)); - for (RbtStringListConstIter tIter = transformList.begin(); tIter != transformList.end(); tIter++) { - spPrmSource->SetSection(*tIter); - // Check if this section is a valid scoring function definition + for (auto& sectionName: transformList) { + spPrmSource->SetSection(sectionName); if (spPrmSource->isParameterPresent(_TRANSFORM)) { - RbtString strTransformClass(spPrmSource->GetParameterValueAsString(_TRANSFORM)); - // Create new transform according to the string value of _TRANSFORM parameter - RbtBaseTransform* pTransform = Create(strTransformClass, *tIter); - // Set all the transform parameters from the rest of the parameters listed - RbtStringList prmList = spPrmSource->GetParameterList(); - for (RbtStringListConstIter prmIter = prmList.begin(); prmIter != prmList.end(); prmIter++) { - // Look for scoring function request (PARAM@SF) - // Only SetParamRequest currently supported - RbtStringList compList = Rbt::ConvertDelimitedStringToList(*prmIter, "@"); - if (compList.size() == 2) { - RbtRequestPtr spReq(new RbtSFSetParamRequest( - compList[1], compList[0], spPrmSource->GetParameterValueAsString(*prmIter) - )); - pTransform->AddSFRequest(spReq); - } else if ((*prmIter) != _TRANSFORM) { // Skip _TRANSFORM parameter - pTransform->SetParameter(*prmIter, spPrmSource->GetParameterValueAsString(*prmIter)); - } - } - pTransformAgg->Add(pTransform); + RbtBaseTransform* transform = MakeTransformFromFile(spPrmSource, sectionName); + // All examples + the docs instruct to use NullTransform to set the SF overrides, nevertheless + // it is possible to set the overrides in an arbitrary transform. + RegisterScoreFunctionOverridesInTransform(transform, spPrmSource); + pTransformAgg->Add(transform); } else if (bThrowError) { - throw RbtFileMissingParameter(_WHERE_, "Missing " + _TRANSFORM + " parameter in section " + (*tIter)); + throw RbtFileMissingParameter(_WHERE_, "Missing " + _TRANSFORM + " parameter in section " + (sectionName)); } } return pTransformAgg; } + +// Assumes that the fileSource has the appropriate Section set. +RbtBaseTransform* RbtTransformFactory::MakeTransformFromFile( + RbtParameterFileSourcePtr paramsPtr, const RbtString& name +) { + RbtString kind = paramsPtr->GetParameterValueAsString(_TRANSFORM); // Force a cast from RbtVariant to String + if (kind == RbtSimAnnTransform::_CT) + return MakeSimmulatedAnnealingTransformFromFile(paramsPtr, name); + else if (kind == RbtGATransform::_CT) + return MakeGeneticAlgorithmTransformFromFile(paramsPtr, name); + else if (kind == RbtAlignTransform::_CT) + return MakeLigandAlignTransformFromFile(paramsPtr, name); + else if (kind == RbtNullTransform::_CT) + return MakeNullTransformFromFile(paramsPtr, name); + else if (kind == RbtRandLigTransform::_CT) + return MakeRandomizeLigandTransformFromFile(paramsPtr, name); + else if (kind == RbtRandPopTransform::_CT) + return MakeRandomizePopulationTransformFromFile(paramsPtr, name); + else if (kind == RbtSimplexTransform::_CT) + return MakeSimplexTransformFromFile(paramsPtr, name); + else if (kind == RbtTransformAgg::_CT) + return MakeAggregateTransformFromFile(paramsPtr, name); + else + throw RbtBadArgument(_WHERE_, "Unknown transform: " + kind); +} + +static RbtSimAnnTransform* MakeSimmulatedAnnealingTransformFromFile( + RbtParameterFileSourcePtr paramsPtr, const RbtString& name +) { + const RbtSimAnnTransform::Config& default_config = RbtSimAnnTransform::DEFAULT_CONFIG; + RbtSimAnnTransform::Config config{ + .initial_temp = paramsPtr->GetParamOrDefault(RbtSimAnnTransform::_START_T, default_config.initial_temp), + .final_temp = paramsPtr->GetParamOrDefault(RbtSimAnnTransform::_FINAL_T, default_config.final_temp), + .num_blocks = paramsPtr->GetParamOrDefault(RbtSimAnnTransform::_NUM_BLOCKS, default_config.num_blocks), + .block_length = paramsPtr->GetParamOrDefault(RbtSimAnnTransform::_BLOCK_LENGTH, default_config.block_length), + .scale_chromosome_length = paramsPtr->GetParamOrDefault( + RbtSimAnnTransform::_SCALE_CHROM_LENGTH, default_config.scale_chromosome_length + ), + .step_size = paramsPtr->GetParamOrDefault(RbtSimAnnTransform::_STEP_SIZE, default_config.step_size), + .min_accuracy_rate = + paramsPtr->GetParamOrDefault(RbtSimAnnTransform::_MIN_ACC_RATE, default_config.min_accuracy_rate), + .partition_distance = + paramsPtr->GetParamOrDefault(RbtSimAnnTransform::_PARTITION_DIST, default_config.partition_distance), + .partition_frequency = + paramsPtr->GetParamOrDefault(RbtSimAnnTransform::_PARTITION_FREQ, default_config.partition_frequency), + .history_frequency = + paramsPtr->GetParamOrDefault(RbtSimAnnTransform::_HISTORY_FREQ, default_config.history_frequency), + }; + return new RbtSimAnnTransform(name, config); +} + +static RbtGATransform* MakeGeneticAlgorithmTransformFromFile( + RbtParameterFileSourcePtr paramsPtr, const RbtString& name +) { + const RbtGATransform::Config& default_config = RbtGATransform::DEFAULT_CONFIG; + RbtGATransform::Config config{ + .population_size_fraction_as_new_individuals_per_cycle = paramsPtr->GetParamOrDefault( + RbtGATransform::_NEW_FRACTION, default_config.population_size_fraction_as_new_individuals_per_cycle + ), + .crossover_probability = + paramsPtr->GetParamOrDefault(RbtGATransform::_PCROSSOVER, default_config.crossover_probability), + .cauchy_mutation_after_crossover = + paramsPtr->GetParamOrDefault(RbtGATransform::_XOVERMUT, default_config.cauchy_mutation_after_crossover), + .use_cauchy_distribution_for_mutations = paramsPtr->GetParamOrDefault( + RbtGATransform::_CMUTATE, default_config.use_cauchy_distribution_for_mutations + ), + .relative_step_size = + paramsPtr->GetParamOrDefault(RbtGATransform::_STEP_SIZE, default_config.relative_step_size), + .equality_threshold = + paramsPtr->GetParamOrDefault(RbtGATransform::_EQUALITY_THRESHOLD, default_config.equality_threshold), + .max_cycles = paramsPtr->GetParamOrDefault(RbtGATransform::_NCYCLES, default_config.max_cycles), + .num_convergence_cycles = + paramsPtr->GetParamOrDefault(RbtGATransform::_NCONVERGENCE, default_config.num_convergence_cycles), + .history_frequency = + paramsPtr->GetParamOrDefault(RbtGATransform::_HISTORY_FREQ, default_config.history_frequency), + }; + return new RbtGATransform(name, config); +} + +static RbtAlignTransform* MakeLigandAlignTransformFromFile( + RbtParameterFileSourcePtr paramsPtr, const RbtString& name +) { + RbtAlignTransform::Config config{ + .center_of_mass_placement_strategy = GetLigandCenterOfMassPlacementStrategyFromFile(paramsPtr), + .axes_alignment_strategy = GetLigandAxesAlignmentStrategyFromFile(paramsPtr), + }; + return new RbtAlignTransform(name, config); +} + +// Doesn't have any parameters but let's create it for simmetry; +static RbtNullTransform* MakeNullTransformFromFile(RbtParameterFileSourcePtr paramsPtr, const RbtString& name) { + return new RbtNullTransform(name); +} + +static RbtRandLigTransform* MakeRandomizeLigandTransformFromFile( + RbtParameterFileSourcePtr paramsPtr, const RbtString& name +) { + const RbtRandLigTransform::Config& default_config = RbtRandLigTransform::DEFAULT_CONFIG; + RbtRandLigTransform::Config config{ + .torsion_step = paramsPtr->GetParamOrDefault(RbtRandLigTransform::_TORS_STEP, default_config.torsion_step), + }; + return new RbtRandLigTransform(name, config); +} + +static RbtRandPopTransform* MakeRandomizePopulationTransformFromFile( + RbtParameterFileSourcePtr paramsPtr, const RbtString& name +) { + const RbtRandPopTransform::Config& default_config = RbtRandPopTransform::DEFAULT_CONFIG; + RbtRandPopTransform::Config config{ + .population_size = + paramsPtr->GetParamOrDefault(RbtRandPopTransform::_POP_SIZE, default_config.population_size), + .scale_chromosome_length = paramsPtr->GetParamOrDefault( + RbtRandPopTransform::_SCALE_CHROM_LENGTH, default_config.scale_chromosome_length + ), + }; + return new RbtRandPopTransform(name, config); +} + +static RbtSimplexTransform* MakeSimplexTransformFromFile(RbtParameterFileSourcePtr paramsPtr, const RbtString& name) { + const RbtSimplexTransform::Config& default_config = RbtSimplexTransform::DEFAULT_CONFIG; + RbtSimplexTransform::Config config{ + .max_calls = paramsPtr->GetParamOrDefault(RbtSimplexTransform::_MAX_CALLS, default_config.max_calls), + .num_cycles = paramsPtr->GetParamOrDefault(RbtSimplexTransform::_NCYCLES, default_config.num_cycles), + .stopping_step_length = paramsPtr->GetParamOrDefault( + RbtSimplexTransform::_STOPPING_STEP_LENGTH, default_config.stopping_step_length + ), + .convergence_threshold = + paramsPtr->GetParamOrDefault(RbtSimplexTransform::_CONVERGENCE, default_config.convergence_threshold), + .step_size = paramsPtr->GetParamOrDefault(RbtSimplexTransform::_STEP_SIZE, default_config.step_size), + .partition_distribution = + paramsPtr->GetParamOrDefault(RbtSimplexTransform::_PARTITION_DIST, default_config.partition_distribution), + }; + return new RbtSimplexTransform(name, config); +} + +static RbtTransformAgg* MakeAggregateTransformFromFile(RbtParameterFileSourcePtr paramsPtr, const RbtString& name) { + return new RbtTransformAgg(name); +} + +static void RegisterScoreFunctionOverridesInTransform( + RbtBaseTransform* transform, RbtParameterFileSourcePtr paramsPtr +) { + // Set all the transform parameters from the rest of the parameters listed + for (auto& paramName: paramsPtr->GetParameterList()) { + // Look for scoring function request (PARAM@SF). Only SetParamRequest currently supported + // Parameters of the individual transformers are explicitly set by their respective constructor functions + // So we only look for score function overrides. + RbtStringList compList = Rbt::ConvertDelimitedStringToList(paramName, "@"); + if (compList.size() == 2) { + RbtRequestPtr spReq( + new RbtSFSetParamRequest(compList[1], compList[0], paramsPtr->GetParameterValueAsString(paramName)) + ); + transform->AddSFRequest(spReq); + } + } +} + +static RbtAlignTransform::LigandCenterOfMassPlacementStrategy GetLigandCenterOfMassPlacementStrategyFromFile( + RbtParameterFileSourcePtr paramsPtr +) { + if (paramsPtr->isParameterPresent(RbtAlignTransform::_COM)) { + RbtString placement_strategy_val = paramsPtr->GetParameterValueAsString(RbtAlignTransform::_COM); + if (placement_strategy_val == "ALIGN") + return RbtAlignTransform::LigandCenterOfMassPlacementStrategy::COM_ALIGN; + else if (placement_strategy_val == "RANDOM") + return RbtAlignTransform::LigandCenterOfMassPlacementStrategy::COM_RANDOM; + else + throw RbtBadArgument( + _WHERE_, "Invalid ligand center of mass placement strategy: " + placement_strategy_val + ); + } else + return RbtAlignTransform::DEFAULT_CONFIG.center_of_mass_placement_strategy; +} + +static RbtAlignTransform::LigandAxesAlignmentStrategy GetLigandAxesAlignmentStrategyFromFile( + RbtParameterFileSourcePtr paramsPtr +) { + if (paramsPtr->isParameterPresent(RbtAlignTransform::_AXES)) { + RbtString alignment_strategy_val = paramsPtr->GetParameterValueAsString(RbtAlignTransform::_AXES); + if (alignment_strategy_val == "ALIGN") + return RbtAlignTransform::LigandAxesAlignmentStrategy::AXES_ALIGN; + else if (alignment_strategy_val == "RANDOM") + return RbtAlignTransform::LigandAxesAlignmentStrategy::AXES_RANDOM; + else + throw RbtBadArgument(_WHERE_, "Invalid ligand axes alignment strategy: " + alignment_strategy_val); + } else + return RbtAlignTransform::DEFAULT_CONFIG.axes_alignment_strategy; +} \ No newline at end of file