Skip to content

Commit

Permalink
make localClock return to a tree with branch rates #461
Browse files Browse the repository at this point in the history
  • Loading branch information
EvaLiyt committed Jun 4, 2024
1 parent f3b9c8a commit 07db152
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import java.util.Arrays;

public class LocalClock extends DeterministicFunction<Double[]> {
public class LocalClock extends DeterministicFunction<TimeTree> {
public static final String treeName = "tree";
public static final String cladeArrayName = "clades";
public static final String cladeRateArrayName = "cladeRates";
Expand All @@ -19,7 +19,6 @@ public class LocalClock extends DeterministicFunction<Double[]> {

public static final boolean DEFAULT_INCLUDE_STEM = true;

// TODO: add option for includeStem=[false]
public LocalClock(
@ParameterInfo(name = treeName, description = "the tree used to calculate branch rates" ) Value<TimeTree> tree,
@ParameterInfo(name = cladeArrayName, description = "the array of the node names") Value<Object[]> clades,
Expand All @@ -30,6 +29,7 @@ public LocalClock(
if (clades == null) throw new IllegalArgumentException("The clades can't be null!");
if (cladeRates == null) throw new IllegalArgumentException("The clade rates can't be null!");
if (rootRate == null) throw new IllegalArgumentException("The root rate can't be null!");
if (cladeRates.value().length != clades.value().length) throw new IllegalArgumentException("The clade rates should match the given clades!");
setParam(treeName, tree);
setParam(cladeArrayName, clades);
setParam(cladeRateArrayName, cladeRates);
Expand All @@ -48,21 +48,27 @@ public LocalClock(
}
}

@GeneratorInfo(name = "localClock", description = "Apply local clock in a phylogenetic tree to generate the " +
@GeneratorInfo(name = "localClock", description = "Apply local clock in a phylogenetic tree to generate a tree with " +
"branch rates. The order of elements in clades and cladeRates array should match. The clades" +
" should not be overlapped with each other.")
@Override
public Value<Double[]> apply() {
public Value<TimeTree> apply() {
// get parameters
TimeTree tree = getTree().value();
TimeTree originalTree = getTree().value();
Object[] clades = getClades().value();
Double[] cladeRates = getCladeRates().value();
Double rootRate = getRootRate().value();
Boolean includeStem = getIncludeStem().value();

// make a deep copy of the original tree
TimeTree tree = new TimeTree(originalTree);

// set the rates within specified clades
for (int i = 0; i < clades.length; i++){
TimeTreeNode clade = (TimeTreeNode) clades[i];
TimeTreeNode oldClade = (TimeTreeNode) clades[i];
// get the clade in the deep copy tree
TimeTreeNode clade = tree.getNodeByIndex(oldClade.getIndex());

double rate = cladeRates[i];
if (includeStem == null || includeStem) {
setRate(clade, rate, true);
Expand All @@ -71,27 +77,19 @@ public Value<Double[]> apply() {
}
}

// initialise the branch rate array
Double[] branchRates = new Double[tree.branchCount()];

for (TimeTreeNode node : tree.getNodes()){ // set the branch rate for rest of the tree
if (! Arrays.asList(cladeRates).contains(node.getBranchRate())){
node.setBranchRate(rootRate);
}

if (! node.isRoot()) { // write the branch rate into the array
int cladeNumber = node.getIndex();
branchRates[cladeNumber] = node.getBranchRate();
}
}

// return to the branch rate list
return new Value<>(branchRates, this);
// return to the tree with branch rates
return new Value<>(null, tree, this);
}

// public for unit test
public void setRate(TimeTreeNode node, double rate, boolean includeNode) {
if (includeNode) {
public void setRate(TimeTreeNode node, double rate, boolean includeStem) {
if (includeStem) {
node.setBranchRate(rate);
}
if (node.getChildCount() == 2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Objects;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class LocalClockTest {
Expand Down Expand Up @@ -56,9 +59,22 @@ void setRate() {

@Test
void apply() {
TimeTreeNode node = tree.getNodes().get(4);
TimeTreeNode node2 = tree.getNodes().get(3);
TimeTreeNode[] clades = {node, node2};
TimeTreeNode node1 = null;
TimeTreeNode node2 = null;

for (int i = 0; i<tree.getNodes().size(); i++){
if (Objects.equals(tree.getNodes().get(i).getId(), "4")){ //node2 is the leaf node 4
node2 = tree.getNodes().get(i);
} else if (tree.getNodes().get(i).getAllLeafNodes().size() == 2){ //node1 is the parent of (2,3)
node1 = tree.getNodes().get(i);
}
}

assertEquals(true, node1 != null);
assertEquals(true, node2 != null);


TimeTreeNode[] clades = {node1, node2};
Double[] cladeRates = {0.4,0.3};
double rootRate = 0.2;

Expand All @@ -69,20 +85,39 @@ void apply() {
Value<Boolean> includeStemValue = new Value<>("includeStem" , Boolean.TRUE);

LocalClock instance = new LocalClock(treeValue, cladesValue, cladeRatesValue, rootRateValue, includeStemValue);
Value<Double[]> observe = instance.apply();

Double[] expect = {0.4, 0.4, rootRate, 0.3, 0.4, rootRate};
Value<Double[]> expectValue = new Value<>(null, expect);
for (int i = 0; i<expect.length; i++){
assertEquals(expectValue.value()[i], observe.value()[i]);
Value<TimeTree> observe = instance.apply();
List<TimeTreeNode> allNodes = observe.value().getNodes();

for (int i = 0; i<allNodes.size() - 1; i++){
TimeTreeNode node = allNodes.get(i);
if (node.getId() != null){
if (node.getId().equals("2") || node.getId().equals("3")){ //leaf node 2 and 3 should have branch rate 0.4
assertEquals(0.4, node.getBranchRate());
} else if (node.getId().equals("1")){ // not specified, should be rootRate
assertEquals(rootRate, node.getBranchRate());
} else if (node.getId().equals("4")) { // node2, should be 0.3
assertEquals(0.3, node.getBranchRate());
}
} else if (node.getAllLeafNodes().size() == 2){ // the node should be the parent of 2 and 3
assertEquals(0.4, node.getBranchRate());
}else if (node.getAllLeafNodes().size() == 3){ // the node should be the parent of ((2,3),1)
assertEquals(rootRate, node.getBranchRate());
}
}
}


@Test
void applyExcludeStem() {
TimeTreeNode node = tree.getNodes().get(4);
TimeTreeNode[] clades = {node};
TimeTreeNode clade = null;

for (int i = 0; i<tree.getNodes().size(); i++) {
if (tree.getNodes().get(i).getAllLeafNodes().size() == 2) { //node is the parent of (2,3)
clade = tree.getNodes().get(i);
}
}

TimeTreeNode[] clades = {clade};
Double[] cladeRates = {0.4};
double rootRate = 0.2;

Expand All @@ -93,12 +128,20 @@ void applyExcludeStem() {
Value<Boolean> includeStemValue = new Value<>("includeStem" , Boolean.FALSE);

LocalClock instance = new LocalClock(treeValue, cladesValue, cladeRatesValue, rootRateValue, includeStemValue);
Value<Double[]> observe = instance.apply();

Double[] expect = {0.4, 0.4, rootRate, rootRate, rootRate, rootRate};
Value<Double[]> expectValue = new Value<>(null, expect);
for (int i = 0; i<expect.length; i++){
assertEquals(expectValue.value()[i], observe.value()[i]);
Value<TimeTree> observe = instance.apply();
List<TimeTreeNode> allNodes = observe.value().getNodes();

for (int i = 0; i<allNodes.size() - 1; i++){
TimeTreeNode node = allNodes.get(i);
if (node.getId() != null){
if (node.getId().equals("2") || node.getId().equals("3")){ //leaf node 2 and 3 should have branch rate 0.4
assertEquals(0.4, node.getBranchRate());
} else if (node.getId().equals("1") || node.getId().equals("4")){ // not specified, should be rootRate
assertEquals(rootRate, node.getBranchRate());
}
} else if (node.getAllLeafNodes().size() == 2 || node.getAllLeafNodes().size() == 3){ // internal nodes should all be root rate
assertEquals(rootRate, node.getBranchRate());
}
}
}
}

0 comments on commit 07db152

Please sign in to comment.