Skip to content

Commit

Permalink
Fix negative times for time windows
Browse files Browse the repository at this point in the history
  • Loading branch information
nspope committed Apr 5, 2024
1 parent ef9b885 commit 15fca84
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 58 deletions.
78 changes: 52 additions & 26 deletions python/tests/test_coalrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def test_windows(self):
np.testing.assert_allclose(implm[0], check)
np.testing.assert_allclose(implm[1], check)

def test_time_discretisation(self):
def test_time_windows(self):
"""
┊ 15 pairs ┊
┊ ┏━━┻━━┓ ┊
Expand All @@ -229,10 +229,10 @@ def test_time_discretisation(self):
0.0┊ 0 0 0 0 0 0 0 0 ┊
"""
ts = self.example_ts()
time_discretisation = np.array([0.0, 5.0, 7.0, np.inf])
time_windows = np.array([0.0, 5.0, 7.0, np.inf])
check = np.array([4, 5, 19]) * ts.sequence_length
implm = ts.pair_coalescence_counts(
span_normalise=False, time_discretisation=time_discretisation
span_normalise=False, time_windows=time_windows
)
np.testing.assert_allclose(implm, check)

Expand Down Expand Up @@ -360,7 +360,7 @@ def test_windows(self):
np.testing.assert_allclose(implm[0], check_0)
np.testing.assert_allclose(implm[1], check_1)

def test_time_discretisation(self):
def test_time_windows(self):
"""
┊ 3 pairs 3 ┊
3.5┊-┏━┻━┓---┊-┏━┻━┓---┊
Expand All @@ -375,14 +375,14 @@ def test_time_discretisation(self):
"""
L, S = 200, 100
ts = self.example_ts(S, L)
time_discretisation = np.array([0.0, 1.5, 3.5, np.inf])
time_windows = np.array([0.0, 1.5, 3.5, np.inf])
windows = np.array(list(ts.breakpoints()))
check_0 = np.array([0.0, 3.0, 3.0]) * S
check_1 = np.array([1.0, 2.0, 3.0]) * (L - S)
implm = ts.pair_coalescence_counts(
span_normalise=False,
windows=windows,
time_discretisation=time_discretisation,
time_windows=time_windows,
)
np.testing.assert_allclose(implm[0], check_0)
np.testing.assert_allclose(implm[1], check_1)
Expand Down Expand Up @@ -460,7 +460,7 @@ def test_missing_interval(self):

def test_missing_leaves(self):
"""
test case where two segments have 1/2 of samples missing
test case where 1/2 of samples are missing
"""
t = self.example_ts().dump_tables()
ss0 = np.flatnonzero(t.nodes.population == 0)
Expand Down Expand Up @@ -523,6 +523,9 @@ def test_nonsuccinct_sequence(self):
self._check_subset_pairs(ts, windows)

def test_span_normalise(self):
"""
test case where span is normalised
"""
ts = self.example_ts()
windows = np.array([0.0, 0.33, 1.0]) * ts.sequence_length
window_size = np.diff(windows)
Expand All @@ -546,21 +549,41 @@ def test_internal_nodes_are_samples(self):
self._check_total_pairs(ts_modified, windows)
self._check_subset_pairs(ts_modified, windows)

def test_time_discretisation(self):
def test_time_windows(self):
ts = self.example_ts()
total_pair_count = np.sum(ts.pair_coalescence_counts(span_normalise=False))
samples = list(ts.samples())
time_discretisation = np.quantile(ts.nodes_time, [0.0, 0.25, 0.5, 0.75])
time_discretisation = np.append(time_discretisation, np.inf)
time_windows = np.quantile(ts.nodes_time, [0.0, 0.25, 0.5, 0.75])
time_windows = np.append(time_windows, np.inf)
implm = ts.pair_coalescence_counts(
span_normalise=False, time_discretisation=time_discretisation
span_normalise=False, time_windows=time_windows
)
assert np.isclose(np.sum(implm), total_pair_count)
check = naive_pair_coalescence_counts(ts, samples, samples).squeeze() / 2
nodes_map = (
np.searchsorted(time_discretisation, ts.nodes_time, side="right") - 1
)
nodes_map = np.searchsorted(time_windows, ts.nodes_time, side="right") - 1
check = np.bincount(nodes_map, weights=check)
np.testing.assert_allclose(implm, check)

def test_time_windows_truncated(self):
"""
test case where some nodes fall outside of time bins
"""
ts = self.example_ts()
total_pair_count = np.sum(ts.pair_coalescence_counts(span_normalise=False))
samples = list(ts.samples())
time_windows = np.quantile(ts.nodes_time, [0.5, 0.75])
assert time_windows[0] > 0.0
time_windows = np.append(time_windows, np.inf)
implm = ts.pair_coalescence_counts(
span_normalise=False, time_windows=time_windows
)
assert np.sum(implm) < total_pair_count
check = naive_pair_coalescence_counts(ts, samples, samples).squeeze() / 2
nodes_map = np.searchsorted(time_windows, ts.nodes_time, side="right") - 1
oob = np.logical_or(nodes_map < 0, nodes_map >= time_windows.size)
check = np.bincount(nodes_map[~oob], weights=check[~oob])
np.testing.assert_allclose(implm, check)

def test_diversity(self):
"""
test that weighted mean of node times equals branch diversity
Expand Down Expand Up @@ -615,6 +638,13 @@ def test_unsorted_windows(self):
windows=np.array([0.0, 0.3, 0.2, 1.0]) * ts.sequence_length
)

def test_bad_windows(self):
ts = self.example_ts()
with pytest.raises(ValueError, match="must be an array of breakpoints"):
ts.pair_coalescence_counts(windows="whatever")
with pytest.raises(ValueError, match="must be an array of breakpoints"):
ts.pair_coalescence_counts(windows=np.array([0.0]))

def test_empty_sample_sets(self):
ts = self.example_ts()
with pytest.raises(ValueError, match="contain at least one element"):
Expand Down Expand Up @@ -646,24 +676,20 @@ def test_uncalibrated_time(self):
tables.time_units = tskit.TIME_UNITS_UNCALIBRATED
ts = tables.tree_sequence()
with pytest.raises(ValueError, match="requires calibrated node times"):
ts.pair_coalescence_counts(time_discretisation=np.array([0.0, np.inf]))
ts.pair_coalescence_counts(time_windows=np.array([0.0, np.inf]))

def test_bad_time_discretisation(self):
def test_bad_time_windows(self):
ts = self.example_ts()
with pytest.raises(ValueError, match="must be an array of breakpoints"):
ts.pair_coalescence_counts(time_discretisation="whatever")

def test_oor_time_discretisation(self):
ts = self.example_ts()
time_discretisation = np.array([-3.0, 12.0])
with pytest.raises(ValueError, match="start at zero and end at infinity"):
ts.pair_coalescence_counts(time_discretisation=time_discretisation)
ts.pair_coalescence_counts(time_windows="whatever")
with pytest.raises(ValueError, match="must be an array of breakpoints"):
ts.pair_coalescence_counts(time_windows=np.array([0.0]))

def test_unsorted_time_discretisation(self):
def test_unsorted_time_windows(self):
ts = self.example_ts()
time_discretisation = np.array([0.0, 12.0, 6.0, np.inf])
time_windows = np.array([0.0, 12.0, 6.0, np.inf])
with pytest.raises(ValueError, match="must be strictly increasing"):
ts.pair_coalescence_counts(time_discretisation=time_discretisation)
ts.pair_coalescence_counts(time_windows=time_windows)

def test_output_dim(self):
"""
Expand Down
67 changes: 35 additions & 32 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -9296,7 +9296,7 @@ def pair_coalescence_counts(
indexes=None,
windows=None,
span_normalise=True,
time_discretisation="nodes",
time_windows="nodes",
):
"""
Calculate the number of coalescing sample pairs per node, summed over
Expand All @@ -9314,7 +9314,7 @@ def pair_coalescence_counts(
single sample set and `[(0,1)]` for two sample sets. For more than two
sample sets, `indexes` must be explicitly passed.
The argument `time_discretisation` may be used to count coalescence
The argument `time_windows` may be used to count coalescence
events within time intervals (if an array of breakpoints is supplied)
rather than for individual nodes (the default).
Expand All @@ -9328,7 +9328,7 @@ def pair_coalescence_counts(
sequence windows to compute the statistic in, or None.
:param bool span_normalise: Whether to divide the result by the span of
the window (defaults to True).
:param time_discretisation: Either a string "nodes" or an increasing
:param time_windows: Either a string "nodes" or an increasing
list of breakpoints between time intervals.
"""

Expand Down Expand Up @@ -9361,32 +9361,29 @@ def pair_coalescence_counts(
if windows is None:
drop_left_dimension = True
windows = np.array([0.0, self.sequence_length])
if not (isinstance(windows, np.ndarray) and windows.size > 1):
raise ValueError("Windows must be an array of breakpoints")
if not (windows[0] == 0.0 and windows[-1] == self.sequence_length):
raise ValueError("First and last window breaks must be sequence boundary")
if not (windows[0] == 0.0 and windows[-1] == self.sequence_length):
raise ValueError("First and last window breaks must be sequence boundary")

Check warning on line 9369 in python/tskit/trees.py

View check run for this annotation

Codecov / codecov/patch

python/tskit/trees.py#L9369

Added line #L9369 was not covered by tests
if not np.all(np.diff(windows) > 0):
raise ValueError("Window breaks must be strictly increasing")

if isinstance(time_discretisation, str) and time_discretisation == "nodes":
if isinstance(time_windows, str) and time_windows == "nodes":
nodes_map = np.arange(self.num_nodes)
output_size = self.num_nodes
else:
if not isinstance(time_discretisation, np.ndarray):
raise ValueError("Time discretisation must be an array of breakpoints")
if not (
time_discretisation[0] == 0.0 and time_discretisation[-1] == np.inf
):
raise ValueError(
"Time discretisation must start at zero and end at infinity"
)
if not np.all(np.diff(time_discretisation) > 0):
raise ValueError("Time discretisation must be strictly increasing")
if not (isinstance(time_windows, np.ndarray) and time_windows.size > 1):
raise ValueError("Time windows must be an array of breakpoints")
if not np.all(np.diff(time_windows) > 0):
raise ValueError("Time windows must be strictly increasing")
if self.time_units == tskit.TIME_UNITS_UNCALIBRATED:
raise ValueError("Time discretisation requires calibrated node times")
nodes_map = (
np.searchsorted(time_discretisation, self.nodes_time, side="right") - 1
)
assert nodes_map[0] >= 0 and nodes_map[-1] < time_discretisation.size
output_size = time_discretisation.size - 1
raise ValueError("Time windows requires calibrated node times")
nodes_map = np.searchsorted(time_windows, self.nodes_time, side="right") - 1
nodes_oob = np.logical_or(nodes_map < 0, nodes_map >= time_windows.size)
nodes_map[nodes_oob] = tskit.NULL
output_size = time_windows.size - 1

num_nodes = self.num_nodes
num_edges = self.num_edges
Expand All @@ -9402,14 +9399,14 @@ def pair_coalescence_counts(
remove_position = self.edges_right[remove_index]
sequence_length = self.sequence_length

windows_span = np.zeros(num_windows)
nodes_parent = np.full(num_nodes, tskit.NULL)
nodes_sample = np.zeros((num_nodes, num_sample_sets))
coalescing_pairs = np.zeros((num_windows, output_size, num_indexes))

for i, s in enumerate(sample_sets):
nodes_sample[s, i] = 1
sample_counts = nodes_sample.copy()

position = 0.0
w, a, b = 0, 0, 0
while position < sequence_length:
Expand All @@ -9423,10 +9420,11 @@ def pair_coalescence_counts(
inside = sample_counts[c]
while p != tskit.NULL:
u = nodes_map[p]
outside = sample_counts[p] - sample_counts[c] - nodes_sample[p]
for i, (j, k) in enumerate(indexes):
weight = inside[j] * outside[k] + inside[k] * outside[j]
coalescing_pairs[w, u, i] -= weight * remainder
if u != tskit.NULL:
outside = sample_counts[p] - sample_counts[c] - nodes_sample[p]
for i, (j, k) in enumerate(indexes):
weight = inside[j] * outside[k] + inside[k] * outside[j]
coalescing_pairs[w, u, i] -= weight * remainder
c, p = p, nodes_parent[p]
p = edges_parent[e]
while p != tskit.NULL:
Expand All @@ -9446,10 +9444,11 @@ def pair_coalescence_counts(
p = edges_parent[e]
while p != tskit.NULL:
u = nodes_map[p]
outside = sample_counts[p] - sample_counts[c] - nodes_sample[p]
for i, (j, k) in enumerate(indexes):
weight = inside[j] * outside[k] + inside[k] * outside[j]
coalescing_pairs[w, u, i] += weight * remainder
if u != tskit.NULL:
outside = sample_counts[p] - sample_counts[c] - nodes_sample[p]
for i, (j, k) in enumerate(indexes):
weight = inside[j] * outside[k] + inside[k] * outside[j]
coalescing_pairs[w, u, i] += weight * remainder
c, p = p, nodes_parent[p]
a += 1

Expand All @@ -9460,11 +9459,14 @@ def pair_coalescence_counts(
position = min(position, insert_position[a])

while w < num_windows and windows[w + 1] <= position: # flush window
windows_span[w] -= position - windows[w + 1]
if w + 1 < num_windows:
windows_span[w + 1] += position - windows[w + 1]
remainder = sequence_length - windows[w + 1]
for c, p in enumerate(nodes_parent):
if p == tskit.NULL:
continue
u = nodes_map[p]
if p == tskit.NULL or u == tskit.NULL:
continue
inside = sample_counts[c]
outside = sample_counts[p] - sample_counts[c] - nodes_sample[p]
for i, (j, k) in enumerate(indexes):
Expand All @@ -9479,7 +9481,8 @@ def pair_coalescence_counts(
coalescing_pairs[:, :, i] /= 2
if span_normalise:
for w, s in enumerate(np.diff(windows)):
coalescing_pairs[w] /= s
if s > 0:
coalescing_pairs[w] /= s

if drop_right_dimension:
coalescing_pairs = coalescing_pairs[..., 0]
Expand Down

0 comments on commit 15fca84

Please sign in to comment.