Skip to content

Commit

Permalink
Refactor: Make transform configurations explicit, rework parsing of t…
Browse files Browse the repository at this point in the history
…ransforms (#119)

* Introduce transform configuration objects holding all transform parameters and make transforms depend on them instead of  accepting strings
  • Loading branch information
Beetelbrox authored Jun 7, 2024
1 parent 76cfe16 commit ea05b80
Show file tree
Hide file tree
Showing 16 changed files with 463 additions and 239 deletions.
15 changes: 14 additions & 1 deletion include/RbtAlignTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,20 @@ class RbtAlignTransform: public RbtBaseBiMolTransform {
static RbtString _COM;
static RbtString _AXES;

enum LigandCenterOfMassPlacementStrategy { COM_ALIGN, COM_RANDOM };
enum LigandAxesAlignmentStrategy { AXES_ALIGN, AXES_RANDOM };

struct Config {
LigandCenterOfMassPlacementStrategy center_of_mass_placement_strategy{
LigandCenterOfMassPlacementStrategy::COM_ALIGN};
LigandAxesAlignmentStrategy axes_alignment_strategy{LigandAxesAlignmentStrategy::AXES_ALIGN};
};

static const Config DEFAULT_CONFIG;

////////////////////////////////////////
// Constructors/destructors
RbtAlignTransform(const RbtString& strName = "ALIGN");
RbtAlignTransform(const RbtString& strName, const Config& config);
virtual ~RbtAlignTransform();

////////////////////////////////////////
Expand Down Expand Up @@ -65,6 +76,8 @@ class RbtAlignTransform: public RbtBaseBiMolTransform {
RbtCavityList m_cavities; // List of active site cavities to choose from
RbtIntList m_cumulSize; // Cumulative sizes, for weighted probabilities
RbtInt m_totalSize; // Total size of all cavities

const Config config;
};

// Useful typedefs
Expand Down
17 changes: 16 additions & 1 deletion include/RbtGATransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,24 @@ class RbtGATransform: public RbtBaseBiMolTransform {
// Output the best pose every _HISTORY_FREQ cycles.
static RbtString _HISTORY_FREQ;

struct Config {
RbtDouble population_size_fraction_as_new_individuals_per_cycle{0.5};
RbtDouble crossover_probability{0.4};
RbtBool cauchy_mutation_after_crossover{true};
RbtBool use_cauchy_distribution_for_mutations{false}; // We might want to make this an enum
RbtDouble relative_step_size{1.0};
RbtDouble equality_threshold{0.1};
RbtInt max_cycles{100};
RbtInt num_convergence_cycles{6};
RbtInt history_frequency{0};
};

static const Config DEFAULT_CONFIG;

////////////////////////////////////////
// Constructors/destructors
////////////////////////////////////////
RbtGATransform(const RbtString& strName = "GAGENRW");
RbtGATransform(const RbtString& strName, const Config& config);
virtual ~RbtGATransform();

protected:
Expand All @@ -67,6 +81,7 @@ class RbtGATransform: public RbtBaseBiMolTransform {

private:
RbtRand& m_rand;
const Config config;
};

#endif //_RBTGATRANSFORM_H_
11 changes: 11 additions & 0 deletions include/RbtParameterFileSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,20 @@ class RbtParameterFileSource: public RbtBaseFileSource {
RbtDouble GetParameterValue(const RbtString& strParamName);
// DM 12 Feb 1999 Get a particular named parameter value as a string
RbtString GetParameterValueAsString(const RbtString& strParamName);

RbtVariant GetParameterValueAsVariant(const RbtString& strParamName);
// DM 11 Feb 1999 Check if parameter is present
RbtBool isParameterPresent(const RbtString& strParamName);

template <typename ParamType>
ParamType GetParamOrDefault(const RbtString& paramName, const ParamType& defaultValue) {
if (isParameterPresent(paramName)) {
return GetParameterValueAsVariant(paramName);
} else {
return defaultValue;
}
}

// DM 11 Feb 1999 - section handling
// Parameters can be grouped into named sections
// such that the same parameter name can appear in multiple sections
Expand Down
9 changes: 9 additions & 0 deletions include/RbtRandLigTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@ class RbtRandLigTransform: public RbtBaseUniMolTransform {
// Parameter names
static RbtString _TORS_STEP;

struct Config {
RbtDouble torsion_step{180.0};
};

static const Config DEFAULT_CONFIG;

////////////////////////////////////////
// Constructors/destructors
RbtRandLigTransform(const RbtString& strName = "RANDLIG");
RbtRandLigTransform(const RbtString& strName, const Config& config);
virtual ~RbtRandLigTransform();

////////////////////////////////////////
Expand Down Expand Up @@ -60,6 +67,8 @@ class RbtRandLigTransform: public RbtBaseUniMolTransform {
//////////////
RbtRand& m_rand; // keep a reference to the singleton random number generator
RbtBondList m_rotableBonds;

const Config config;
};

// Useful typedefs
Expand Down
13 changes: 10 additions & 3 deletions include/RbtRandPopTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@ class RbtRandPopTransform: public RbtBaseBiMolTransform {
static RbtString _POP_SIZE;
static RbtString _SCALE_CHROM_LENGTH;

////////////////////////////////////////
// Constructors/destructors
RbtRandPopTransform(const RbtString& strName = "RANDPOP");
struct Config {
RbtInt population_size{50};
RbtBool scale_chromosome_length{true};
};

static const Config DEFAULT_CONFIG;

RbtRandPopTransform(const RbtString& strName, const Config& config);
virtual ~RbtRandPopTransform();

////////////////////////////////////////
Expand Down Expand Up @@ -59,6 +64,8 @@ class RbtRandPopTransform: public RbtBaseBiMolTransform {
// Private data
//////////////
RbtChromElementPtr m_chrom;

const Config config;
};

// Useful typedefs
Expand Down
20 changes: 19 additions & 1 deletion include/RbtSimAnnTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,25 @@ class RbtSimAnnTransform: public RbtBaseBiMolTransform {
static RbtString _PARTITION_FREQ;
static RbtString _HISTORY_FREQ;

struct Config {
RbtDouble initial_temp{1000.0};
RbtDouble final_temp{300.0};
RbtInt num_blocks{25};
RbtInt block_length{50};
RbtBool scale_chromosome_length{true};
RbtDouble step_size{1.0};
RbtDouble min_accuracy_rate{0.25};
RbtDouble partition_distance{0.0};
RbtInt partition_frequency{0};
RbtInt history_frequency{0};
};

static const Config DEFAULT_CONFIG;

////////////////////////////////////////
// Constructors/destructors
RbtSimAnnTransform(const RbtString& strName = "SIMANN");

RbtSimAnnTransform(const RbtString& strName, const Config& config);
virtual ~RbtSimAnnTransform();

////////////////////////////////////////
Expand Down Expand Up @@ -100,6 +116,8 @@ class RbtSimAnnTransform: public RbtBaseBiMolTransform {
RbtChromElementPtr m_chrom; // Current chromosome
RbtDoubleList m_minVector; // Chromosome vector corresponding to overall minimum score
RbtDoubleList m_lastGoodVector; // Saved chromosome before each MC mutation (to allow revert)

const Config config;
};

// Useful typedefs
Expand Down
17 changes: 14 additions & 3 deletions include/RbtSimplexTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,18 @@ class RbtSimplexTransform: public RbtBaseBiMolTransform {
// between cycles
static RbtString _CONVERGENCE;

////////////////////////////////////////
// Constructors/destructors
RbtSimplexTransform(const RbtString& strName = "SIMPLEX");
struct Config {
RbtInt max_calls{200};
RbtInt num_cycles{5};
RbtDouble stopping_step_length{10e-4};
RbtDouble convergence_threshold{0.001};
RbtDouble step_size{0.1};
RbtDouble partition_distribution{0.0};
};

static const Config DEFAULT_CONFIG;

RbtSimplexTransform(const RbtString& strName, const Config& config);
virtual ~RbtSimplexTransform();

////////////////////////////////////////
Expand Down Expand Up @@ -66,6 +75,8 @@ class RbtSimplexTransform: public RbtBaseBiMolTransform {
// Private data
//////////////
RbtChromElementPtr m_chrom;

const Config config;
};

// Useful typedefs
Expand Down
8 changes: 2 additions & 6 deletions include/RbtTransformFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ class RbtTransformFactory {
// Public methods
////////////////

// Creates a single transform object of type strTransformClass, and name strName
// e.g. strTransformClass = RbtSimAnnTransform
virtual RbtBaseTransform* Create(const RbtString& strTransformClass, const RbtString& strName);

// Creates an aggregate transform from a parameter file source
// Each component transform is in a named section, which should minimally contain a TRANSFORM parameter
// whose value is the class name to instantiate
Expand All @@ -64,9 +60,9 @@ class RbtTransformFactory {
// Private methods
/////////////////

RbtTransformFactory(const RbtTransformFactory&); // Copy constructor disabled by default

RbtTransformFactory(const RbtTransformFactory&); // Copy constructor disabled by default
RbtTransformFactory& operator=(const RbtTransformFactory&); // Copy assignment disabled by default
RbtBaseTransform* MakeTransformFromFile(RbtParameterFileSourcePtr paramsPtr, const RbtString& name);

protected:
////////////////////////////////////////
Expand Down
114 changes: 53 additions & 61 deletions src/lib/RbtAlignTransform.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ RbtString RbtAlignTransform::_CT("RbtAlignTransform");
RbtString RbtAlignTransform::_COM("COM");
RbtString RbtAlignTransform::_AXES("AXES");

////////////////////////////////////////
// Constructors/destructors
RbtAlignTransform::RbtAlignTransform(const RbtString& strName):
const RbtAlignTransform::Config
RbtAlignTransform::DEFAULT_CONFIG{}; // Empty initializer to fall back to default values

RbtAlignTransform::RbtAlignTransform(const RbtString& strName, const Config& config):
RbtBaseBiMolTransform(_CT, strName),
m_rand(Rbt::GetRbtRand()),
m_totalSize(0) {
// Add parameters
AddParameter(_COM, "ALIGN");
AddParameter(_AXES, "ALIGN");
m_totalSize(0),
config{config} {
#ifdef _DEBUG
cout << _CT << " parameterised constructor" << endl;
#endif //_DEBUG
Expand Down Expand Up @@ -89,64 +88,57 @@ void RbtAlignTransform::Execute() {
const RbtCoordList& coordList = spCavity->GetCoordList();
const RbtPrincipalAxes& prAxes = spCavity->GetPrincipalAxes();

// Get the alignment parameters
RbtString strPlaceCOM = GetParameter(_COM);
RbtString strPlaceAxes = GetParameter(_AXES);

// 1. Ligand translation
// A. Random
if ((strPlaceCOM == "RANDOM") && !coordList.empty()) {
// Select a coord at random
RbtInt iRand = m_rand.GetRandomInt(coordList.size());
RbtCoord asCavityCoord = coordList[iRand];
if (iTrace > 1) {
cout << "Translating ligand COM to active site coord #" << iRand << ": " << asCavityCoord << endl;
}
// Translate the ligand center of mass to the selected coord
spLigand->SetCenterOfMass(asCavityCoord);
}
// B. Active site center of mass
else if (strPlaceCOM == "ALIGN") {
if (iTrace > 1) {
cout << "Translating ligand COM to active site COM: " << prAxes.com << endl;
}
spLigand->SetCenterOfMass(prAxes.com);
switch (config.center_of_mass_placement_strategy) {
case COM_RANDOM:
if (!coordList.empty()) {
RbtInt iRand = m_rand.GetRandomInt(coordList.size());
RbtCoord asCavityCoord = coordList[iRand];
if (iTrace > 1) {
cout << "Translating ligand COM to active site coord #" << iRand << ": " << asCavityCoord << endl;
}
// Translate the ligand center of mass to the selected coord
spLigand->SetCenterOfMass(asCavityCoord);
}
break;
case COM_ALIGN:
if (iTrace > 1) cout << "Translating ligand COM to active site COM: " << prAxes.com << endl;
spLigand->SetCenterOfMass(prAxes.com);
break;
default:
throw RbtBadArgument(_WHERE_, "Bad ligand center of mass placement strategy");
break;
}

// 2. Ligand axes
// A. Random rotation around random axis
if (strPlaceAxes == "RANDOM") {
RbtDouble thetaDeg = 180.0 * m_rand.GetRandom01();
RbtCoord axis = m_rand.GetRandomUnitVector();
if (iTrace > 1) {
cout << "Rotating ligand by " << thetaDeg << " deg around axis=" << axis << " through COM" << endl;
}
spLigand->Rotate(axis, thetaDeg);
}
// B. Align ligand principal axes with principal axes of active site
else if (strPlaceAxes == "ALIGN") {
spLigand->AlignPrincipalAxes(prAxes, false); // false = don't translate COM as we've already done it above
if (iTrace > 1) {
cout << "Aligning ligand principal axes with active site principal axes" << endl;
}
// Make random 180 deg rotations around each of the principal axes
if (m_rand.GetRandom01() < 0.5) {
spLigand->Rotate(prAxes.axis1, 180.0, prAxes.com);
if (iTrace > 1) {
cout << "180 deg rotation around PA#1" << endl;
switch (config.axes_alignment_strategy) {
case AXES_RANDOM:
{ // Braces to prevent compiler complains due to variables initialized inside this case

RbtDouble thetaDeg = 180.0 * m_rand.GetRandom01();
RbtCoord axis = m_rand.GetRandomUnitVector();
if (iTrace > 1)
cout << "Rotating ligand by " << thetaDeg << " deg around axis=" << axis << " through COM" << endl;
spLigand->Rotate(axis, thetaDeg);
break;
}
case AXES_ALIGN:
spLigand->AlignPrincipalAxes(prAxes, false); // false = don't translate COM as we've already done it above
if (iTrace > 1) cout << "Aligning ligand principal axes with active site principal axes" << endl;
// Make random 180 deg rotations around each of the principal axes
if (m_rand.GetRandom01() < 0.5) {
spLigand->Rotate(prAxes.axis1, 180.0, prAxes.com);
if (iTrace > 1) cout << "180 deg rotation around PA#1" << endl;
}
}
if (m_rand.GetRandom01() < 0.5) {
spLigand->Rotate(prAxes.axis2, 180.0, prAxes.com);
if (iTrace > 1) {
cout << "180 deg rotation around PA#2" << endl;
if (m_rand.GetRandom01() < 0.5) {
spLigand->Rotate(prAxes.axis2, 180.0, prAxes.com);
if (iTrace > 1) cout << "180 deg rotation around PA#2" << endl;
}
}
if (m_rand.GetRandom01() < 0.5) {
spLigand->Rotate(prAxes.axis3, 180.0, prAxes.com);
if (iTrace > 1) {
cout << "180 deg rotation around PA#3" << endl;
if (m_rand.GetRandom01() < 0.5) {
spLigand->Rotate(prAxes.axis3, 180.0, prAxes.com);
if (iTrace > 1) cout << "180 deg rotation around PA#3" << endl;
}
}
break;
default:
throw RbtBadArgument(_WHERE_, "Bad ligand axes alignment strategy");
break;
}
}
Loading

0 comments on commit ea05b80

Please sign in to comment.