From c7a1177aad4a6ee28a2e9babdafc5cc2458dea5a Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Sun, 31 Mar 2024 22:11:04 -0700 Subject: [PATCH 1/2] Add coalescing_pairs method and tests; remove CoalescenceTimeDistribution Add test against worked example Remove unused imports More efficient windowing Raise errors with invalid inputs Misc. fixes and tests; add time discretisation argument Test nonsuccinct case Fix negative times for time windows Delete accidental line copy Remove unneeded if --- python/tests/test_coalrate.py | 715 +++++++++++ python/tests/test_coaltime_distribution.py | 1313 -------------------- python/tskit/stats.py | 710 +---------- python/tskit/trees.py | 209 +++- 4 files changed, 909 insertions(+), 2038 deletions(-) create mode 100644 python/tests/test_coalrate.py delete mode 100644 python/tests/test_coaltime_distribution.py diff --git a/python/tests/test_coalrate.py b/python/tests/test_coalrate.py new file mode 100644 index 0000000000..95c7b124df --- /dev/null +++ b/python/tests/test_coalrate.py @@ -0,0 +1,715 @@ +# MIT License +# +# Copyright (c) 2024 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for coalescence rate calculation in tskit. +""" +import itertools + +import msprime +import numpy as np +import pytest + +import tests +import tskit + + +def naive_pair_coalescence_counts(ts, sample_set_0, sample_set_1): + """ + Count pairwise coalescences tree by tree, by enumerating nodes in each + tree. For a binary node, the number of pairs of samples that coalesce in a + given node is the product of the number of samples subtended by the left + and right child. For higher arities, the count is summed over all possible + pairs of children. + """ + output = np.zeros(ts.num_nodes) + for t in ts.trees(): + sample_counts = np.zeros((ts.num_nodes, 2), dtype=np.int32) + pair_counts = np.zeros(ts.num_nodes) + for p in t.postorder(): + samples = list(t.samples(p)) + sample_counts[p, 0] = np.intersect1d(samples, sample_set_0).size + sample_counts[p, 1] = np.intersect1d(samples, sample_set_1).size + for i, j in itertools.combinations(t.children(p), 2): + pair_counts[p] += sample_counts[i, 0] * sample_counts[j, 1] + pair_counts[p] += sample_counts[i, 1] * sample_counts[j, 0] + output += pair_counts * t.span + return output + + +def convert_to_nonsuccinct(ts): + """ + Give the edges and internal nodes in each tree distinct IDs + """ + tables = tskit.TableCollection(sequence_length=ts.sequence_length) + for _ in range(ts.num_populations): + tables.populations.add_row() + nodes_count = 0 + for n in ts.samples(): + tables.nodes.add_row( + time=ts.nodes_time[n], + flags=ts.nodes_flags[n], + population=ts.nodes_population[n], + ) + nodes_count += 1 + for t in ts.trees(): + nodes_map = {n: n for n in ts.samples()} + for n in t.nodes(): + if t.num_samples(n) > 1: + tables.nodes.add_row( + 
time=ts.nodes_time[n], + flags=ts.nodes_flags[n], + population=ts.nodes_population[n], + ) + nodes_map[n] = nodes_count + nodes_count += 1 + for n in t.nodes(): + if t.edge(n) != tskit.NULL: + tables.edges.add_row( + parent=nodes_map[t.parent(n)], + child=nodes_map[n], + left=t.interval.left, + right=t.interval.right, + ) + tables.sort() + ts_unroll = tables.tree_sequence() + assert nodes_count == ts_unroll.num_nodes + return ts_unroll + + +class TestCoalescingPairsOneTree: + """ + Test against worked example (single tree) + """ + + def example_ts(self): + """ + 10.0┊ 13 ┊ + ┊ ┏━━┻━━┓ ┊ + 8.0┊ 12 ┃ ┊ + ┊ ┏━┻━┓ ┃ ┊ + 6.0┊ 11 ┃ ┃ ┊ + ┊ ┏━━╋━┓ ┃ ┃ ┊ + 2.0┊ 10 ┃ ┃ ┃ 9 ┊ + ┊ ┏┻┓ ┃ ┃ ┃ ┏┻━┓ ┊ + 1.0┊ ┃ ┃ ┃ ┃ ┃ 8 ┃ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ + 0.0┊ 0 7 4 5 6 1 2 3 ┊ + ┊ A A A A B B B B ┊ + """ + tables = tskit.TableCollection(sequence_length=100) + tables.nodes.set_columns( + time=np.array([0] * 8 + [1, 2, 2, 6, 8, 10]), + flags=np.repeat([1, 0], [8, 6]).astype("uint32"), + ) + tables.edges.set_columns( + left=np.repeat([0], 13), + right=np.repeat([100], 13), + parent=np.array( + [8, 8, 9, 9, 10, 10, 11, 11, 11, 12, 12, 13, 13], dtype="int32" + ), + child=np.array([1, 2, 3, 8, 0, 7, 4, 5, 10, 6, 11, 9, 12], dtype="int32"), + ) + tables.populations.add_row() + tables.populations.add_row() + tables.nodes.population = np.array( + [0, 1, 1, 1, 0, 0, 1, 0] + [tskit.NULL] * 6, dtype="int32" + ) + return tables.tree_sequence() + + def test_total_pairs(self): + """ + ┊ 15 pairs ┊ + ┊ ┏━━┻━━┓ ┊ + ┊ 4 ┃ ┊ + ┊ ┏━┻━┓ ┃ ┊ + ┊ 5 ┃ ┃ ┊ + ┊ ┏━━╋━┓ ┃ ┃ ┊ + ┊ 1 ┃ ┃ ┃ 2 ┊ + ┊ ┏┻┓ ┃ ┃ ┃ ┏┻━┓ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ 1 ┃ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ + ┊ 0 0 0 0 0 0 0 0 ┊ + """ + ts = self.example_ts() + check = np.array([0.0] * 8 + [1, 2, 1, 5, 4, 15]) + implm = ts.pair_coalescence_counts() + np.testing.assert_allclose(implm, check) + + def test_population_pairs(self): + """ + ┊ AA 0 pairs ┊ AB 12 pairs ┊ BB 3 pairs ┊ + ┊ ┏━━┻━━┓ ┊ ┏━━┻━━┓ ┊ ┏━━┻━━┓ ┊ + ┊ 0 ┃ ┊ 4 ┃ ┊ 0 ┃ ┊ + ┊ ┏━┻━┓ ┃ ┊ ┏━┻━┓ ┃ ┊ ┏━┻━┓ ┃ ┊ + ┊ 
5 ┃ ┃ ┊ 0 ┃ ┃ ┊ 0 ┃ ┃ ┊ + ┊ ┏━━╋━┓ ┃ ┃ ┊ ┏━━╋━┓ ┃ ┃ ┊ ┏━━╋━┓ ┃ ┃ ┊ + ┊ 1 ┃ ┃ ┃ 0 ┊ 0 ┃ ┃ ┃ 0 ┊ 0 ┃ ┃ ┃ 2 ┊ + ┊ ┏┻┓ ┃ ┃ ┃ ┏┻━┓ ┊ ┏┻┓ ┃ ┃ ┃ ┏┻━┓ ┊ ┏┻┓ ┃ ┃ ┃ ┏┻━┓ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ 0 ┃ ┊ ┃ ┃ ┃ ┃ ┃ 0 ┃ ┊ ┃ ┃ ┃ ┃ ┃ 1 ┃ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ + ┊ A A A A B B B B ┊ A A A A B B B B ┊ A A A A B B B B ┊ + """ + ts = self.example_ts() + ss0 = np.flatnonzero(ts.nodes_population == 0) + ss1 = np.flatnonzero(ts.nodes_population == 1) + indexes = [(0, 0), (0, 1), (1, 1)] + implm = ts.pair_coalescence_counts(sample_sets=[ss0, ss1], indexes=indexes) + check = np.full(implm.shape, np.nan) + check[:, 0] = np.array([0.0] * 8 + [0, 0, 1, 5, 0, 0]) + check[:, 1] = np.array([0.0] * 8 + [0, 0, 0, 0, 4, 12]) + check[:, 2] = np.array([0.0] * 8 + [1, 2, 0, 0, 0, 3]) + np.testing.assert_allclose(implm, check) + + def test_internal_samples(self): + """ + ┊ Not ┊ 24 pairs ┊ + ┊ ┏━━┻━━┓ ┊ ┏━━┻━━┓ ┊ + ┊ N ┃ ┊ 5 ┃ ┊ + ┊ ┏━┻━┓ ┃ ┊ ┏━┻━┓ ┃ ┊ + ┊ S ┃ ┃ ┊ 5 ┃ ┃ ┊ + ┊ ┏━━╋━┓ ┃ ┃ ┊ ┏━━╋━┓ ┃ ┃ ┊ + ┊ N ┃ ┃ ┃ Samp ┊ 1 ┃ ┃ ┃ 2 ┊ + ┊ ┏┻┓ ┃ ┃ ┃ ┏┻━┓ ┊ ┏┻┓ ┃ ┃ ┃ ┏┻━┓ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ N ┃ ┊ ┃ ┃ ┃ ┃ ┃ 1 ┃ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ + ┊ S S S S S S S S ┊ 0 0 0 0 0 0 0 0 ┊ + """ + ts = self.example_ts() + tables = ts.dump_tables() + nodes_flags = tables.nodes.flags.copy() + nodes_flags[9] = tskit.NODE_IS_SAMPLE + nodes_flags[11] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = nodes_flags + ts = tables.tree_sequence() + assert ts.num_samples == 10 + implm = ts.pair_coalescence_counts(span_normalise=False) + check = np.array([0] * 8 + [1, 2, 1, 5, 5, 24]) * ts.sequence_length + np.testing.assert_allclose(implm, check) + + def test_windows(self): + ts = self.example_ts() + check = np.array([0.0] * 8 + [1, 2, 1, 5, 4, 15]) * ts.sequence_length / 2 + implm = ts.pair_coalescence_counts( + windows=np.linspace(0, ts.sequence_length, 3), span_normalise=False + ) + np.testing.assert_allclose(implm[0], check) + np.testing.assert_allclose(implm[1], check) + + def 
test_time_windows(self): + """ + ┊ 15 pairs ┊ + ┊ ┏━━┻━━┓ ┊ + ┊ 4 ┃ ┊ + 7.0┊-----┏━┻━┓---┃----┊ + ┊ 5 ┃ ┃ ┊ + 5.0┊--┏━━╋━┓-┃---┃----┊ + ┊ 1 ┃ ┃ ┃ 2 ┊ + ┊ ┏┻┓ ┃ ┃ ┃ ┏┻━┓ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ 1 ┃ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ + 0.0┊ 0 0 0 0 0 0 0 0 ┊ + """ + ts = self.example_ts() + time_windows = np.array([0.0, 5.0, 7.0, np.inf]) + check = np.array([4, 5, 19]) * ts.sequence_length + implm = ts.pair_coalescence_counts( + span_normalise=False, time_windows=time_windows + ) + np.testing.assert_allclose(implm, check) + + +class TestCoalescingPairsTwoTree: + """ + Test against worked example (two trees) + """ + + def example_ts(self, S, L): + """ + 0 S L + 4.0┊ 7 ┊ 7 ┊ + ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ + 3.0┊ ┃ 6 ┊ ┃ ┃ ┊ + ┊ ┃ ┏━┻┓ ┊ ┃ ┃ ┊ + 2.0┊ ┃ ┃ 5 ┊ ┃ 5 ┊ + ┊ ┃ ┃ ┏┻┓ ┊ ┃ ┏┻━┓ ┊ + 1.0┊ ┃ ┃ ┃ ┃ ┊ ┃ 4 ┃ ┊ + ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┏┻┓ ┃ ┊ + 0.0┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + A A B B A A B B + """ + tables = tskit.TableCollection(sequence_length=L) + tables.nodes.set_columns( + time=np.array([0, 0, 0, 0, 1.0, 2.0, 3.0, 4.0]), + flags=np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype="uint32"), + ) + tables.edges.set_columns( + left=np.array([S, S, 0, 0, S, 0, 0, 0, S, 0]), + right=np.array([L, L, S, L, L, S, S, L, L, S]), + parent=np.array([4, 4, 5, 5, 5, 6, 6, 7, 7, 7], dtype="int32"), + child=np.array([1, 2, 2, 3, 4, 1, 5, 0, 5, 6], dtype="int32"), + ) + return tables.tree_sequence() + + def test_total_pairs(self): + """ + ┊ 3 pairs 3 ┊ + ┊ ┏━┻━┓ ┏━┻━┓ ┊ + ┊ ┃ 2 ┃ ┃ ┊ + ┊ ┃ ┏━┻┓ ┃ ┃ ┊ + ┊ ┃ ┃ 1 ┃ 2 ┊ + ┊ ┃ ┃ ┏┻┓ ┃ ┏┻━┓ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ 1 ┃ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ + ┊ 0 0 0 0 0 0 0 0 ┊ + 0 S L + """ + L, S = 1e8, 1.0 + ts = self.example_ts(S, L) + implm = ts.pair_coalescence_counts(span_normalise=False) + check = np.array([0] * 4 + [1 * (L - S), 2 * (L - S) + 1 * S, 2 * S, 3 * L]) + np.testing.assert_allclose(implm, check) + + def test_population_pairs(self): + """ + ┊AA ┊AB ┊BB ┊ + ┊ 1 pairs 1 ┊ 2 pairs 2 ┊ 0 pairs 0 ┊ + ┊ ┏━┻━┓ ┏━┻━┓ ┊ ┏━┻━┓ ┏━┻━┓ ┊ ┏━┻━┓ ┏━┻━┓ ┊ + ┊ ┃ 0 ┃ ┃ ┊ ┃ 2 ┃ ┃ ┊ ┃ 0 ┃ ┃ ┊ + ┊ ┃ ┏━┻┓ ┃ ┃ ┊ ┃ 
┏━┻┓ ┃ ┃ ┊ ┃ ┏━┻┓ ┃ ┃ ┊ + ┊ ┃ ┃ 0 ┃ 0 ┊ ┃ ┃ 0 ┃ 1 ┊ ┃ ┃ 1 ┃ 1 ┊ + ┊ ┃ ┃ ┏┻┓ ┃ ┏┻━┓ ┊ ┃ ┃ ┏┻┓ ┃ ┏┻━┓ ┊ ┃ ┃ ┏┻┓ ┃ ┏┻━┓ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ 0 ┃ ┊ ┃ ┃ ┃ ┃ ┃ 1 ┃ ┊ ┃ ┃ ┃ ┃ ┃ 0 ┃ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ + ┊ A A B B A A B B ┊ A A B B A A B B ┊ A A B B A A B B ┊ + 0 S L S L S L + """ + L, S = 1e8, 1.0 + ts = self.example_ts(S, L) + indexes = [(0, 0), (0, 1), (1, 1)] + implm = ts.pair_coalescence_counts( + sample_sets=[[0, 1], [2, 3]], indexes=indexes, span_normalise=False + ) + check = np.empty(implm.shape) + check[:, 0] = np.array([0] * 4 + [0, 0, 0, 1 * L]) + check[:, 1] = np.array([0] * 4 + [1 * (L - S), 1 * (L - S), 2 * S, 2 * L]) + check[:, 2] = np.array([0] * 4 + [0, 1 * L, 0, 0]) + np.testing.assert_allclose(implm, check) + + def test_internal_samples(self): + """ + ┊ Not N ┊ 4 pairs 4 ┊ + ┊ ┏━┻━┓ ┏━┻━┓ ┊ ┏━┻━┓ ┏━┻━┓ ┊ + ┊ ┃ N ┃ ┃ ┊ ┃ 3 ┃ ┃ ┊ + ┊ ┃ ┏━┻┓ ┃ ┃ ┊ ┃ ┏━┻┓ ┃ ┃ ┊ + ┊ ┃ ┃ Samp ┃ S ┊ ┃ ┃ 1 ┃ 2 ┊ + ┊ ┃ ┃ ┏┻┓ ┃ ┏┻━┓ ┊ ┃ ┃ ┏┻┓ ┃ ┏┻━┓ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ N ┃ ┊ ┃ ┃ ┃ ┃ ┃ 1 ┃ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ + ┊ S S S S S S S S ┊ 0 0 0 0 0 0 0 0 ┊ + """ + L, S = 200, 100 + ts = self.example_ts(S, L) + tables = ts.dump_tables() + nodes_flags = tables.nodes.flags.copy() + nodes_flags[5] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = nodes_flags + ts = tables.tree_sequence() + assert ts.num_samples == 5 + implm = ts.pair_coalescence_counts(span_normalise=False) + check = np.array([0.0] * 4 + [(L - S), S + 2 * (L - S), 3 * S, 4 * L]) + np.testing.assert_allclose(implm, check) + + def test_windows(self): + """ + ┊ 3 pairs 3 ┊ + ┊ ┏━┻━┓ ┏━┻━┓ ┊ + ┊ ┃ 2 ┃ ┃ ┊ + ┊ ┃ ┏━┻┓ ┃ ┃ ┊ + ┊ ┃ ┃ 1 ┃ 2 ┊ + ┊ ┃ ┃ ┏┻┓ ┃ ┏┻━┓ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ 1 ┃ ┊ + ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ + ┊ 0 0 0 0 0 0 0 0 ┊ + 0 S L + """ + L, S = 200, 100 + ts = self.example_ts(S, L) + windows = np.array(list(ts.breakpoints())) + check_0 = np.array([0.0] * 4 + [0, 1, 2, 3]) * S + check_1 = np.array([0.0] * 4 + [1, 2, 0, 3]) * (L - S) + implm = 
ts.pair_coalescence_counts(windows=windows, span_normalise=False) + np.testing.assert_allclose(implm[0], check_0) + np.testing.assert_allclose(implm[1], check_1) + + def test_time_windows(self): + """ + ┊ 3 pairs 3 ┊ + 3.5┊-┏━┻━┓---┊-┏━┻━┓---┊ + ┊ ┃ 2 ┊ ┃ ┃ ┊ + ┊ ┃ ┏━┻┓ ┊ ┃ ┃ ┊ + ┊ ┃ ┃ 1 ┊ ┃ 2 ┊ + 1.5┊-┃-┃-┏┻┓-┊-┃--┏┻━┓-┊ + ┊ ┃ ┃ ┃ ┃ ┊ ┃ 1 ┃ ┊ + ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┏┻┓ ┃ ┊ + 0.0┊ 0 0 0 0 ┊ 0 0 0 0 ┊ + 0 S L + """ + L, S = 200, 100 + ts = self.example_ts(S, L) + time_windows = np.array([0.0, 1.5, 3.5, np.inf]) + windows = np.array(list(ts.breakpoints())) + check_0 = np.array([0.0, 3.0, 3.0]) * S + check_1 = np.array([1.0, 2.0, 3.0]) * (L - S) + implm = ts.pair_coalescence_counts( + span_normalise=False, + windows=windows, + time_windows=time_windows, + ) + np.testing.assert_allclose(implm[0], check_0) + np.testing.assert_allclose(implm[1], check_1) + + +class TestCoalescingPairsSimulated: + """ + Test against a naive implementation on simulated data. + """ + + @tests.cached_example + def example_ts(self): + n = 10 + model = msprime.BetaCoalescent(alpha=1.5) # polytomies + tables = msprime.sim_ancestry( + samples=n, + recombination_rate=1e-8, + sequence_length=1e6, + population_size=1e4, + random_seed=1024, + model=model, + ).dump_tables() + tables.populations.add_row(metadata={"name": "foo", "description": "bar"}) + tables.nodes.population = np.repeat( + [0, 1, tskit.NULL], [n, n, tables.nodes.num_rows - 2 * n] + ).astype("int32") + ts = tables.tree_sequence() + assert ts.num_trees > 1 + return ts + + @staticmethod + def _check_total_pairs(ts, windows): + samples = list(ts.samples()) + implm = ts.pair_coalescence_counts(windows=windows, span_normalise=False) + dim = (windows.size - 1, ts.num_nodes) + check = np.full(dim, np.nan) + for w, (a, b) in enumerate(zip(windows[:-1], windows[1:])): + tsw = ts.keep_intervals(np.array([[a, b]]), simplify=False) + check[w] = naive_pair_coalescence_counts(tsw, samples, samples) / 2 + np.testing.assert_allclose(implm, check) + + 
@staticmethod + def _check_subset_pairs(ts, windows): + ss0 = np.flatnonzero(ts.nodes_population == 0) + ss1 = np.flatnonzero(ts.nodes_population == 1) + idx = [(0, 1), (1, 1), (0, 0)] + implm = ts.pair_coalescence_counts( + sample_sets=[ss0, ss1], indexes=idx, windows=windows, span_normalise=False + ) + dim = (windows.size - 1, ts.num_nodes, len(idx)) + check = np.full(dim, np.nan) + for w, (a, b) in enumerate(zip(windows[:-1], windows[1:])): + tsw = ts.keep_intervals(np.array([[a, b]]), simplify=False) + check[w, :, 0] = naive_pair_coalescence_counts(tsw, ss0, ss1) + check[w, :, 1] = naive_pair_coalescence_counts(tsw, ss1, ss1) / 2 + check[w, :, 2] = naive_pair_coalescence_counts(tsw, ss0, ss0) / 2 + np.testing.assert_allclose(implm, check) + + def test_sequence(self): + ts = self.example_ts() + windows = np.array([0.0, ts.sequence_length]) + self._check_total_pairs(ts, windows) + self._check_subset_pairs(ts, windows) + + def test_missing_interval(self): + """ + test case where three segments have all samples missing + """ + ts = self.example_ts() + windows = np.array([0.0, ts.sequence_length]) + intervals = np.array([[0.0, 0.1], [0.4, 0.6], [0.9, 1.0]]) * ts.sequence_length + ts = ts.delete_intervals(intervals) + self._check_total_pairs(ts, windows) + self._check_subset_pairs(ts, windows) + + def test_missing_leaves(self): + """ + test case where 1/2 of samples are missing + """ + t = self.example_ts().dump_tables() + ss0 = np.flatnonzero(t.nodes.population == 0) + remove = np.in1d(t.edges.child, ss0) + assert np.any(remove) + t.edges.set_columns( + left=t.edges.left[~remove], + right=t.edges.right[~remove], + parent=t.edges.parent[~remove], + child=t.edges.child[~remove], + ) + t.sort() + ts = t.tree_sequence() + windows = np.array([0.0, ts.sequence_length]) + self._check_total_pairs(ts, windows) + self._check_subset_pairs(ts, windows) + + def test_missing_roots(self): + """ + test case where all trees have multiple roots + """ + ts = self.example_ts() + ts = 
ts.decapitate(np.quantile(ts.nodes_time, 0.75)) + windows = np.array([0.0, ts.sequence_length]) + self._check_total_pairs(ts, windows) + self._check_subset_pairs(ts, windows) + + def test_windows(self): + ts = self.example_ts() + windows = np.linspace(0.0, ts.sequence_length, 9) + self._check_total_pairs(ts, windows) + self._check_subset_pairs(ts, windows) + + def test_windows_are_trees(self): + """ + test case where window breakpoints coincide with tree breakpoints + """ + ts = self.example_ts() + windows = np.array(list(ts.breakpoints())) + self._check_total_pairs(ts, windows) + self._check_subset_pairs(ts, windows) + + def test_windows_inside_trees(self): + """ + test case where windows are nested within trees + """ + ts = self.example_ts() + windows = np.array(list(ts.breakpoints())) + windows = np.sort(np.append(windows[:-1] / 2 + windows[1:] / 2, windows)) + self._check_total_pairs(ts, windows) + self._check_subset_pairs(ts, windows) + + def test_nonsuccinct_sequence(self): + """ + test case where each tree has distinct nodes + """ + ts = convert_to_nonsuccinct(self.example_ts()) + windows = np.linspace(0, ts.sequence_length, 9) + self._check_total_pairs(ts, windows) + self._check_subset_pairs(ts, windows) + + def test_span_normalise(self): + """ + test case where span is normalised + """ + ts = self.example_ts() + windows = np.array([0.0, 0.33, 1.0]) * ts.sequence_length + window_size = np.diff(windows) + implm = ts.pair_coalescence_counts(windows=windows, span_normalise=False) + check = ts.pair_coalescence_counts(windows=windows) * window_size[:, np.newaxis] + np.testing.assert_allclose(implm, check) + + def test_internal_nodes_are_samples(self): + """ + test case where some samples are descendants of other samples + """ + ts = self.example_ts() + tables = ts.dump_tables() + nodes_flags = tables.nodes.flags.copy() + nodes_sample = np.arange(ts.num_samples, ts.num_nodes, 10) + nodes_flags[nodes_sample] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = 
nodes_flags + ts_modified = tables.tree_sequence() + assert ts_modified.num_samples > ts.num_samples + windows = np.linspace(0.0, 1.0, 9) * ts_modified.sequence_length + self._check_total_pairs(ts_modified, windows) + self._check_subset_pairs(ts_modified, windows) + + def test_time_windows(self): + ts = self.example_ts() + total_pair_count = np.sum(ts.pair_coalescence_counts(span_normalise=False)) + samples = list(ts.samples()) + time_windows = np.quantile(ts.nodes_time, [0.0, 0.25, 0.5, 0.75]) + time_windows = np.append(time_windows, np.inf) + implm = ts.pair_coalescence_counts( + span_normalise=False, time_windows=time_windows + ) + assert np.isclose(np.sum(implm), total_pair_count) + check = naive_pair_coalescence_counts(ts, samples, samples).squeeze() / 2 + nodes_map = np.searchsorted(time_windows, ts.nodes_time, side="right") - 1 + check = np.bincount(nodes_map, weights=check) + np.testing.assert_allclose(implm, check) + + def test_time_windows_truncated(self): + """ + test case where some nodes fall outside of time bins + """ + ts = self.example_ts() + total_pair_count = np.sum(ts.pair_coalescence_counts(span_normalise=False)) + samples = list(ts.samples()) + time_windows = np.quantile(ts.nodes_time, [0.5, 0.75]) + assert time_windows[0] > 0.0 + time_windows = np.append(time_windows, np.inf) + implm = ts.pair_coalescence_counts( + span_normalise=False, time_windows=time_windows + ) + assert np.sum(implm) < total_pair_count + check = naive_pair_coalescence_counts(ts, samples, samples).squeeze() / 2 + nodes_map = np.searchsorted(time_windows, ts.nodes_time, side="right") - 1 + oob = np.logical_or(nodes_map < 0, nodes_map >= time_windows.size) + check = np.bincount(nodes_map[~oob], weights=check[~oob]) + np.testing.assert_allclose(implm, check) + + def test_diversity(self): + """ + test that weighted mean of node times equals branch diversity + """ + ts = self.example_ts() + windows = np.linspace(0.0, ts.sequence_length, 9) + check = ts.diversity(mode="branch", 
windows=windows) + implm = ts.pair_coalescence_counts(windows=windows) + implm = 2 * (implm @ ts.nodes_time) / implm.sum(axis=1) + np.testing.assert_allclose(implm, check) + + def test_divergence(self): + """ + test that weighted mean of node times equals branch divergence + """ + ts = self.example_ts() + ss0 = np.flatnonzero(ts.nodes_population == 0) + ss1 = np.flatnonzero(ts.nodes_population == 1) + windows = np.linspace(0.0, ts.sequence_length, 9) + check = ts.divergence(sample_sets=[ss0, ss1], mode="branch", windows=windows) + implm = ts.pair_coalescence_counts(sample_sets=[ss0, ss1], windows=windows) + implm = 2 * (implm @ ts.nodes_time) / implm.sum(axis=1) + np.testing.assert_allclose(implm, check) + + +class TestCoalescingPairsUsage: + """ + Test invalid inputs + """ + + @tests.cached_example + def example_ts(self): + return msprime.sim_ancestry( + samples=10, + recombination_rate=1e-8, + sequence_length=1e5, + population_size=1e4, + random_seed=1024, + ) + + def test_oor_windows(self): + ts = self.example_ts() + with pytest.raises(ValueError, match="must be sequence boundary"): + ts.pair_coalescence_counts( + windows=np.array([0.0, 2.0]) * ts.sequence_length + ) + + def test_unsorted_windows(self): + ts = self.example_ts() + with pytest.raises(ValueError, match="must be strictly increasing"): + ts.pair_coalescence_counts( + windows=np.array([0.0, 0.3, 0.2, 1.0]) * ts.sequence_length + ) + + def test_bad_windows(self): + ts = self.example_ts() + with pytest.raises(ValueError, match="must be an array of breakpoints"): + ts.pair_coalescence_counts(windows="whatever") + with pytest.raises(ValueError, match="must be an array of breakpoints"): + ts.pair_coalescence_counts(windows=np.array([0.0])) + + def test_empty_sample_sets(self): + ts = self.example_ts() + with pytest.raises(ValueError, match="contain at least one element"): + ts.pair_coalescence_counts(sample_sets=[[0, 1, 2], []]) + + def test_oob_sample_sets(self): + ts = self.example_ts() + with 
pytest.raises(ValueError, match="is out of bounds"): + ts.pair_coalescence_counts(sample_sets=[[0, ts.num_nodes]]) + + def test_nonbinary_indexes(self): + ts = self.example_ts() + with pytest.raises(ValueError, match="must be length two"): + ts.pair_coalescence_counts(indexes=[(0, 0, 0)]) + + def test_oob_indexes(self): + ts = self.example_ts() + with pytest.raises(ValueError, match="is out of bounds"): + ts.pair_coalescence_counts(indexes=[(0, 1)]) + + def test_no_indexes(self): + ts = self.example_ts() + ss = [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + with pytest.raises(ValueError, match="more than two sample sets"): + ts.pair_coalescence_counts(sample_sets=ss) + + def test_uncalibrated_time(self): + tables = self.example_ts().dump_tables() + tables.time_units = tskit.TIME_UNITS_UNCALIBRATED + ts = tables.tree_sequence() + with pytest.raises(ValueError, match="requires calibrated node times"): + ts.pair_coalescence_counts(time_windows=np.array([0.0, np.inf])) + + def test_bad_time_windows(self): + ts = self.example_ts() + with pytest.raises(ValueError, match="must be an array of breakpoints"): + ts.pair_coalescence_counts(time_windows="whatever") + with pytest.raises(ValueError, match="must be an array of breakpoints"): + ts.pair_coalescence_counts(time_windows=np.array([0.0])) + + def test_unsorted_time_windows(self): + ts = self.example_ts() + time_windows = np.array([0.0, 12.0, 6.0, np.inf]) + with pytest.raises(ValueError, match="must be strictly increasing"): + ts.pair_coalescence_counts(time_windows=time_windows) + + def test_output_dim(self): + """ + test that output dimensions corresponding to None arguments are dropped + """ + ts = self.example_ts() + ss = [[0, 1, 2], [3, 4, 5]] + implm = ts.pair_coalescence_counts(sample_sets=ss, windows=None, indexes=None) + assert implm.shape == (ts.num_nodes,) + windows = np.linspace(0.0, ts.sequence_length, 2) + implm = ts.pair_coalescence_counts( + sample_sets=ss, windows=windows, indexes=None + ) + assert implm.shape == 
(1, ts.num_nodes) + indexes = [(0, 1)] + implm = ts.pair_coalescence_counts( + sample_sets=ss, windows=windows, indexes=indexes + ) + assert implm.shape == (1, ts.num_nodes, 1) + implm = ts.pair_coalescence_counts( + sample_sets=ss, windows=None, indexes=indexes + ) + assert implm.shape == (ts.num_nodes, 1) diff --git a/python/tests/test_coaltime_distribution.py b/python/tests/test_coaltime_distribution.py deleted file mode 100644 index 715677d99e..0000000000 --- a/python/tests/test_coaltime_distribution.py +++ /dev/null @@ -1,1313 +0,0 @@ -# MIT License -# -# Copyright (c) 2018-2023 Tskit Developers -# Copyright (C) 2016 University of Oxford -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -""" -Test cases for coalescence time distribution objects in tskit. 
-""" -import msprime -import numpy as np -import pytest - -import tests -import tskit - - -class TestCoalescenceTimeDistribution: - """ - Tree sequences used in tests of classes `CoalescenceTimeTable` and - `CoalescenceTimeDistribution` - """ - - @tests.cached_example - def ts_multimerger_six_leaves(self): - """ - 29.00┊ 9 ┊ - ┊ ┏━━┻━━┓ ┊ - 8.00 ┊ ┃ 8 ┊ - ┊ ┃ ┏━┻━━┓ ┊ - 5.00 ┊ ┃ 7 ┃ ┊ - ┊ ┃ ┏━╋━┓ ┃ ┊ - 1.00 ┊ ┃ ┃ ┃ ┃ 6 ┊ - ┊ ┃ ┃ ┃ ┃ ┏┻┓ ┊ - 0.00 ┊ 0 1 2 4 3 5 ┊ - 0 100 - """ - tables = tskit.TableCollection(sequence_length=100) - tables.nodes.set_columns( - time=np.array([0, 0, 0, 0, 0, 0, 1, 5, 8, 29]), - flags=np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0], dtype="uint32"), - ) - tables.edges.set_columns( - left=np.repeat([0.0], 9), - right=np.repeat([100.0], 9), - parent=np.array([6, 6, 7, 7, 7, 8, 8, 9, 9], dtype="int32"), - child=np.array([3, 5, 1, 2, 4, 6, 7, 0, 8], dtype="int32"), - ) - return tables.tree_sequence() - - @tests.cached_example - def ts_multimerger_eight_leaves(self): - """ - 10.00┊ 13 ┊ - ┊ ┏━━┻━━┓ ┊ - 8.00 ┊ 12 ┃ ┊ - ┊ ┏━┻━┓ ┃ ┊ - 6.00 ┊ 11 ┃ ┃ ┊ - ┊ ┏━━╋━┓ ┃ ┃ ┊ - 2.00 ┊ 10 ┃ ┃ ┃ 9 ┊ - ┊ ┏┻┓ ┃ ┃ ┃ ┏┻━┓ ┊ - 1.00 ┊ ┃ ┃ ┃ ┃ ┃ 8 ┃ ┊ - ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ - 0.00 ┊ 0 7 4 5 6 1 2 3 ┊ - 0 100 - """ - tables = tskit.TableCollection(sequence_length=100) - tables.nodes.set_columns( - time=np.array([0] * 8 + [1, 2, 2, 6, 8, 10]), - flags=np.repeat([1, 0], [8, 6]).astype("uint32"), - ) - tables.edges.set_columns( - left=np.repeat([0], 13), - right=np.repeat([100], 13), - parent=np.array( - [8, 8, 9, 9, 10, 10, 11, 11, 11, 12, 12, 13, 13], dtype="int32" - ), - child=np.array([1, 2, 3, 8, 0, 7, 4, 5, 10, 6, 11, 9, 12], dtype="int32"), - ) - return tables.tree_sequence() - - @tests.cached_example - def ts_two_trees_four_leaves(self): - """ - 1.74┊ 7 ┊ 7 ┊ - ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ - 0.73┊ ┃ 6 ┊ ┃ ┃ ┊ - ┊ ┃ ┏━┻┓ ┊ ┃ ┃ ┊ - 0.59┊ ┃ ┃ 5 ┊ ┃ 5 ┊ - ┊ ┃ ┃ ┏┻┓ ┊ ┃ ┏┻━┓ ┊ - 0.54┊ ┃ ┃ ┃ ┃ ┊ ┃ 4 ┃ ┊ - ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┏┻┓ ┃ ┊ - 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ - 0.00 0.88 1.00 - """ - 
tables = tskit.TableCollection(sequence_length=1) - tables.nodes.set_columns( - time=np.array([0, 0, 0, 0, 0.54, 0.59, 0.73, 1.74]), - flags=np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype="uint32"), - ) - tables.edges.set_columns( - left=np.array([0.88, 0.88, 0, 0, 0.88, 0, 0, 0, 0.88, 0]), - right=np.array([1, 1, 0.88, 1, 1, 0.88, 0.88, 1, 1, 0.88]), - parent=np.array([4, 4, 5, 5, 5, 6, 6, 7, 7, 7], dtype="int32"), - child=np.array([1, 2, 2, 3, 4, 1, 5, 0, 5, 6], dtype="int32"), - ) - return tables.tree_sequence() - - @tests.cached_example - def ts_five_trees_three_leaves(self): - tables = tskit.TableCollection(sequence_length=1) - tables.nodes.set_columns( - time=np.array([0.0, 0.0, 0.0, 0.05, 0.15, 1.13, 4.21, 7.53]), - flags=np.array([1, 1, 1, 0, 0, 0, 0, 0], dtype="uint32"), - ) - tables.edges.set_columns( - left=np.array([0.4, 0.4, 0, 0, 0.4, 0, 0, 0.2, 0.2, 0.1, 0.3, 0.1, 0.3]), - right=np.array([1, 1, 1, 0.4, 1, 0.1, 0.1, 0.3, 0.3, 0.2, 0.4, 0.2, 0.4]), - parent=np.array([3, 3, 4, 4, 4, 5, 5, 6, 6, 7, 7, 7, 7], dtype="int32"), - child=np.array([0, 2, 1, 2, 3, 0, 4, 0, 4, 0, 0, 4, 4], dtype="int32"), - ) - return tables.tree_sequence() - - @tests.cached_example - def ts_eight_trees_two_leaves(self): - tables = tskit.TableCollection(sequence_length=8) - tables.nodes.set_columns( - time=np.array([0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]), - flags=np.array([1, 1, 0, 0, 0, 0, 0, 0, 0], dtype="uint32"), - ) - tables.edges.set_columns( - left=np.array([6, 6, 7, 7, 5, 5, 1, 1, 0, 2, 0, 2, 3, 3, 4, 4]), - right=np.array([7, 7, 8, 8, 6, 6, 2, 2, 1, 3, 1, 3, 4, 4, 5, 5]), - parent=np.array( - [2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 8, 8], - dtype="int32", - ), - child=np.array( - [0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1], - dtype="int32", - ), - ) - return tables.tree_sequence() - - @tests.cached_example - def ts_two_trees_ten_leaves(self): - tables = tskit.TableCollection(sequence_length=2) - tables.nodes.set_columns( - time=np.array([0.0, 0.0, 1.0, 2.0, 3.0, 
4.0, 5.0, 6.0, 7.0]), - flags=np.array([1, 1, 0, 0, 0, 0, 0, 0, 0], dtype="uint32"), - ) - tables.edges.set_columns( - left=np.array([6, 6, 7, 7, 5, 5, 1, 1, 0, 2, 0, 2, 3, 3, 4, 4]), - right=np.array([7, 7, 8, 8, 6, 6, 2, 2, 1, 3, 1, 3, 4, 4, 5, 5]), - parent=np.array( - [2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 8, 8], - dtype="int32", - ), - child=np.array( - [0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1], - dtype="int32", - ), - ) - return tables.tree_sequence() - - @tests.cached_example - def ts_many_edge_diffs(self): - ts = msprime.sim_ancestry( - samples=80, - ploidy=1, - sequence_length=4, - recombination_rate=10, - random_seed=1234, - ) - return ts - - -class TestUnweightedCoalescenceTimeTable(TestCoalescenceTimeDistribution): - """ - Block Weight C.Wght Quant - ------------------------- - 29.00┊ 9 ┊ 0 1 4 1.00 - ┊ ┏━━┻━━┓ ┊ - 8.00 ┊ ┃ 8 ┊ 0 1 3 0.75 - ┊ ┃ ┏━┻━━┓ ┊ - 5.00 ┊ ┃ 7 ┃ ┊ 0 1 2 0.50 - ┊ ┃ ┏━╋━┓ ┃ ┊ - 1.00 ┊ ┃ ┃ ┃ ┃ 6 ┊ 0 1 1 0.25 - ┊ ┃ ┃ ┃ ┃ ┏┻┓ ┊ - 0.00 ┊ 0 1 2 4 3 5 ┊ 0 0 0 0.00 < to catch OOR - 0 100 - Uniform weights on nodes - """ - - def coalescence_time_distribution(self): - ts = self.ts_multimerger_six_leaves() - distr = ts.coalescence_time_distribution(span_normalise=False) - return distr - - def test_time(self): - t = np.array([0, 1, 5, 8, 29]) - distr = self.coalescence_time_distribution() - tt = distr.tables[0].time - np.testing.assert_allclose(t, tt) - - def test_block(self): - b = np.array([0, 0, 0, 0, 0]) - distr = self.coalescence_time_distribution() - tb = distr.tables[0].block - np.testing.assert_allclose(b, tb) - - def test_weights(self): - w = np.array([[0, 1, 1, 1, 1]]).T - distr = self.coalescence_time_distribution() - tw = distr.tables[0].weights - np.testing.assert_allclose(w, tw) - - def test_cum_weights(self): - c = np.array([[0, 1, 2, 3, 4]]).T - distr = self.coalescence_time_distribution() - tc = distr.tables[0].cum_weights - np.testing.assert_allclose(c, tc) - - def test_quantile(self): - q = np.array([[0, 0.25, 
0.50, 0.75, 1]]).T - distr = self.coalescence_time_distribution() - tq = distr.tables[0].quantile - np.testing.assert_allclose(q, tq) - - -class TestPairWeightedCoalescenceTimeTable(TestCoalescenceTimeDistribution): - """ - Weights - (A,A) (A,B) (A,C) (B,B) (B,C) (C,C) - ----------------------------------- - 29.00┊ 9 ┊ 1 2 2 0 0 0 - ┊ ┏━━┻━━┓ ┊ - 8.00 ┊ ┃ 8 ┊ 0 1 1 1 2 1 - ┊ ┃ ┏━┻━━┓ ┊ - 5.00 ┊ ┃ 7 ┃ ┊ 0 1 1 0 1 0 - ┊ ┃ ┏━╋━┓ ┃ ┊ - 1.00 ┊ ┃ ┃ ┃ ┃ 6 ┊ 0 0 0 0 1 0 - ┊ ┃ ┃ ┃ ┃ ┏┻┓ ┊ - 0.00 ┊ 0 1 2 4 3 5 ┊ - 0 100 - Pop.┊ A A B C B C ┊ - Weights are number of pairs of a given population labelling that coalesce - in node - """ - - def coalescence_time_distribution(self): - ts = self.ts_multimerger_six_leaves() - sample_sets = [[0, 1], [2, 3], [4, 5]] - distr = ts.coalescence_time_distribution( - sample_sets=sample_sets, - weight_func="pair_coalescence_events", - span_normalise=False, - ) - return distr - - def test_time(self): - t = np.array([0, 1, 5, 8, 29]) - distr = self.coalescence_time_distribution() - tt = distr.tables[0].time - np.testing.assert_allclose(t, tt) - - def test_block(self): - b = np.array([0, 0, 0, 0, 0]) - distr = self.coalescence_time_distribution() - tb = distr.tables[0].block - np.testing.assert_allclose(b, tb) - - def test_weights(self): - w = np.array( - [ - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0], - [0, 1, 1, 0, 1, 0], - [0, 1, 1, 1, 2, 1], - [1, 2, 2, 0, 0, 0], - ] - ) - distr = self.coalescence_time_distribution() - tw = distr.tables[0].weights - np.testing.assert_allclose(w, tw) - - def test_cum_weights(self): - c = np.array( - [ - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0], - [0, 1, 1, 0, 2, 0], - [0, 2, 2, 1, 4, 1], - [1, 4, 4, 1, 4, 1], - ] - ) - distr = self.coalescence_time_distribution() - tc = distr.tables[0].cum_weights - np.testing.assert_allclose(c, tc) - - def test_quantile(self): - q = np.array( - [ - [0.0, 0.00, 0.00, 0.00, 0.00, 0.0], - [0.0, 0.00, 0.00, 0.00, 0.25, 0.0], - [0.0, 0.25, 0.25, 0.00, 0.50, 0.0], - [0.0, 0.50, 
0.50, 1.00, 1.00, 1.0], - [1.0, 1.00, 1.00, 1.00, 1.00, 1.0], - ] - ) - distr = self.coalescence_time_distribution() - tq = distr.tables[0].quantile - np.testing.assert_allclose(q, tq) - - -class TestTrioFirstWeightedCoalescenceTimeTable(TestCoalescenceTimeDistribution): - """ - Weights - AAA AAB AAC ABA ABB ABC ACA ACB ACC - ----------------------------------- - 10.00┊ 13 ┊ 0 0 0 0 0 0 0 0 0 - ┊ ┏━━┻━━┓ ┊ - 8.00 ┊ 12 ┃ ┊ 0 0 0 0 0 0 2 1 0 - ┊ ┏━┻━┓ ┃ ┊ - 6.00 ┊ 11 ┃ ┃ ┊ 0 0 0 4 2 2 0 0 0 - ┊ ┏━━╋━┓ ┃ ┃ ┊ - 2.00 ┊ 10 ┃ ┃ ┃ 9 ┊(10) 0 0 0 0 0 0 2 3 1 - ┊ ┏┻┓ ┃ ┃ ┃ ┏┻━┓ ┊( 9) 0 0 0 2 4 4 0 0 0 - 1.00 ┊ ┃ ┃ ┃ ┃ ┃ 8 ┃ ┊ 1 3 2 0 0 0 0 0 0 - ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ - 0.00 ┊ 0 7 4 5 6 1 2 3 ┊ - 0 100 BBA BBB BBC BCA BCB BCC CCA CCB CCC - Pop.┊ A C B B C A A B ┊ ----------------------------------- - 10.00┊ 13 ┊ 0 0 0 0 0 0 0 0 0 <- removed - ┊ ┏━━┻━━┓ ┊ - 8.00 ┊ 12 ┃ ┊ 0 0 0 4 2 0 2 1 0 - ┊ ┏━┻━┓ ┃ ┊ - 6.00 ┊ 11 ┃ ┃ ┊ 2 1 1 4 2 2 0 0 0 - ┊ ┏━━╋━┓ ┃ ┃ ┊ - 2.00 ┊ 10 ┃ ┃ ┃ 9 ┊(10) 0 0 0 0 0 0 0 0 0 - ┊ ┏┻┓ ┃ ┃ ┃ ┏┻━┓ ┊( 9) 0 0 0 0 0 0 0 0 0 - 1.00 ┊ ┃ ┃ ┃ ┃ ┃ 8 ┃ ┊ 0 0 0 0 0 0 0 0 0 - ┊ ┃ ┃ ┃ ┃ ┃ ┏┻┓ ┃ ┊ ^empty - 0.00 ┊ 0 7 4 5 6 1 2 3 ┊ - Pop. 
┊ A C B B C A A B ┊ - Weights are number of trios of a given population labelling with first coalescence - in node; shorthand in table columns for newick is ABC = ((A,B):node,C) - """ - - def coalescence_time_distribution(self): - ts = self.ts_multimerger_eight_leaves() - sample_sets = [[0, 1, 2], [3, 4, 5], [6, 7]] - distr = ts.coalescence_time_distribution( - sample_sets=sample_sets, - weight_func="trio_first_coalescence_events", - span_normalise=False, - ) - return distr - - def test_time(self): - t = np.array([0.0, 1.0, 2.0, 2.0, 6.0, 8.00]) - distr = self.coalescence_time_distribution() - tt = distr.tables[0].time - np.testing.assert_allclose(t, tt) - - def test_block(self): - b = np.array([0, 0, 0, 0, 0, 0]) - distr = self.coalescence_time_distribution() - tb = distr.tables[0].block - np.testing.assert_allclose(b, tb) - - def test_weights(self): - w = np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 2, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 4, 2, 2, 0, 0, 0, 2, 1, 1, 4, 2, 2, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 4, 2, 0, 2, 1, 0], - ] - ) - distr = self.coalescence_time_distribution() - tw = distr.tables[0].weights - np.testing.assert_allclose(w, tw) - - def test_cum_weights(self): - c = np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 2, 2, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 2, 2, 4, 4, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 2, 6, 6, 6, 2, 3, 1, 2, 1, 1, 4, 2, 2, 0, 0, 0], - [1, 3, 2, 6, 6, 6, 4, 4, 1, 2, 1, 1, 8, 4, 2, 2, 1, 0], - ] - ) - distr = self.coalescence_time_distribution() - tc = distr.tables[0].cum_weights - np.testing.assert_allclose(c, tc) - - def test_quantile(self): - q = np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 2, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 2, 2, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 2, 2, 4, 4, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 2, 6, 6, 6, 2, 3, 1, 2, 1, 1, 4, 2, 2, 0, 0], - [1, 3, 2, 6, 6, 6, 4, 4, 1, 2, 1, 1, 8, 4, 2, 2, 1], - ], - dtype="float", - ) - q /= q[-1, :] - distr = self.coalescence_time_distribution() - tq = distr.tables[0].quantile - np.testing.assert_allclose(q, tq[:, :-1]) and np.all(np.isnan(tq[:, -1])) - - -class TestSingleBlockCoalescenceTimeTable(TestCoalescenceTimeDistribution): - """ - Cum. - Wght Wght Qntl - ----------------- - 1.74┊ 7 ┊ 7 ┊ 2 6 1.00 - ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ - 0.73┊ ┃ 6 ┊ ┃ ┃ ┊ 1 4 0.67 - ┊ ┃ ┏━┻┓ ┊ ┃ ┃ ┊ - 0.59┊ ┃ ┃ 5 ┊ ┃ 5 ┊ 2 3 0.50 - ┊ ┃ ┃ ┏┻┓ ┊ ┃ ┏┻━┓ ┊ - 0.54┊ ┃ ┃ ┃ ┃ ┊ ┃ 4 ┃ ┊ 1 1 0.17 - ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┏┻┓ ┃ ┊ - 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ - 0.00 0.88 1.00 - Uniform weights on nodes summed over trees - """ - - def coalescence_time_distribution(self): - ts = self.ts_two_trees_four_leaves() - distr = ts.coalescence_time_distribution( - span_normalise=False, - ) - return distr - - def test_time(self): - t = np.array([0.0, 0.54, 0.59, 0.73, 1.74]) - distr = self.coalescence_time_distribution() - tt = distr.tables[0].time - np.testing.assert_allclose(t, tt) - - def test_block(self): - b = np.array([0, 0, 0, 0, 0]) - distr = self.coalescence_time_distribution() - tb = distr.tables[0].block - np.testing.assert_allclose(b, tb) - - def test_weights(self): - w = np.array([[0, 1, 2, 1, 2]]).T - distr = self.coalescence_time_distribution() - tw = distr.tables[0].weights - np.testing.assert_allclose(w, tw) - - def test_cum_weights(self): - c = np.array([[0, 1, 3, 4, 6]]).T - distr = self.coalescence_time_distribution() - tc = distr.tables[0].cum_weights - np.testing.assert_allclose(c, tc) - np.testing.assert_allclose(c, tc) - - def test_quantile(self): - q = np.array([[0.0, 1 / 6, 3 / 6, 4 / 6, 1.0]]).T - distr = self.coalescence_time_distribution() - tq = distr.tables[0].quantile - 
np.testing.assert_allclose(q, tq) - - -class TestWindowedCoalescenceTimeTable(TestCoalescenceTimeDistribution): - """ - 0.00 0.50 1.00 Window 0 Window 1 - Wndw┊ 0 ┊ 1 ┊ Time Wght Blck Time Wght Blck - Blck┊ 0 ┊ 0 ┊ 1 ┊ ---------------- ---------------- - 1.74┊ 7 ┊ 7 ┊ 1.74 1 0 1.74 1 1 - ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ 1.74 1 0 - 0.73┊ ┃ 6 ┊ ┃ ┃ ┊ 0.73 1 0 0.73 1 0 - ┊ ┃ ┏━┻┓ ┊ ┃ ┃ ┊ - 0.59┊ ┃ ┃ 5 ┊ ┃ 5 ┊ 0.59 1 0 0.59 1 1 - ┊ ┃ ┃ ┏┻┓ ┊ ┃ ┏┻━┓ ┊ 0.59 1 0 - 0.54┊ ┃ ┃ ┃ ┃ ┊ ┃ 4 ┃ ┊ 0.54 1 1 - ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┏┻┓ ┃ ┊ - 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0.00 0 0 0.00 0 0 0) - assert np.all(~np.isnan(boot_distr.tables[0].quantile[:, 0])) - - -class TestCoalescenceTimeDistributionTableResize(TestCoalescenceTimeDistribution): - """ - If the initial allocation for the table is exceeded, the number of rows is - increased. - """ - - def coalescence_time_distribution(self): - ts = self.ts_five_trees_three_leaves() - distr = ts.coalescence_time_distribution( - blocks_per_window=ts.num_trees, - span_normalise=False, - ) - return distr - - def test_table_resize(self): - distr = self.coalescence_time_distribution() - assert distr.tables[0].num_records > distr.buffer_size + 1 - - -class TestCoalescenceTimeDistributionBlocking(TestCoalescenceTimeDistribution): - """ - Test assignment of blocks per window and trees per block. If window breaks - fall on recombination breakpoints, and the number of trees is divisible by - the number of windows, then there should be an equal number of trees per - window. 
- """ - - def coalescence_time_distribution(self): - # 2 trees/block, 2 blocks/window, 2 windows/ts - ts = self.ts_eight_trees_two_leaves() - bk = [t.interval.left for t in ts.trees()][::4] + [ts.sequence_length] - - def count_root_init(node, sample_sets): - all_samples = [i for s in sample_sets for i in s] - state = np.array([[node == i for i in all_samples]], dtype=np.float64) - return (state,) - - def count_root_update(child_state): - state = np.sum(child_state, axis=0, keepdims=True) - is_root = np.array([[np.all(state > 0)]], dtype=np.float64) - return is_root, (state,) - - distr = ts.coalescence_time_distribution( - weight_func=(count_root_init, count_root_update), - window_breaks=np.array(bk), - blocks_per_window=2, - span_normalise=False, - ) - return distr - - def test_blocks_per_window(self): - distr = self.coalescence_time_distribution() - bpw = np.array([i.num_blocks for i in distr.tables]) - np.testing.assert_allclose(bpw, 2) - - def test_trees_per_window(self): - distr = self.coalescence_time_distribution() - tpw = np.array([np.sum(distr.tables[i].weights) for i in range(2)]) - np.testing.assert_allclose(tpw, 4) - - def test_trees_per_block(self): - distr = self.coalescence_time_distribution() - tpb = [] - for table in distr.tables: - for block in range(2): - tpb += [np.sum(table.weights[table.block == block])] - np.testing.assert_allclose(tpb, 2) - - -class TestCoalescenceTimeDistributionBlockedVsUnblocked( - TestCoalescenceTimeDistribution -): - """ - Test that methods give the same result regardless of how trees are blocked. 
- """ - - def coalescence_time_distribution(self, num_blocks=1): - ts = self.ts_many_edge_diffs() - sample_sets = [list(range(10)), list(range(20, 40)), list(range(70, 80))] - distr = ts.coalescence_time_distribution( - sample_sets=sample_sets, - weight_func="pair_coalescence_events", - blocks_per_window=num_blocks, - span_normalise=True, - ) - return distr - - def test_ecdf(self): - distr_noblock = self.coalescence_time_distribution(num_blocks=1) - distr_block = self.coalescence_time_distribution(num_blocks=10) - t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) - np.testing.assert_allclose(distr_noblock.ecdf(t), distr_block.ecdf(t)) - - def test_num_coalesced(self): - distr_noblock = self.coalescence_time_distribution(num_blocks=1) - distr_block = self.coalescence_time_distribution(num_blocks=10) - t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) - np.testing.assert_allclose( - distr_noblock.num_coalesced(t), distr_block.num_coalesced(t) - ) - - def test_num_uncoalesced(self): - distr_noblock = self.coalescence_time_distribution(num_blocks=1) - distr_block = self.coalescence_time_distribution(num_blocks=10) - t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) - np.testing.assert_allclose( - distr_noblock.num_uncoalesced(t), distr_block.num_uncoalesced(t) - ) - - def test_quantile(self): - distr_noblock = self.coalescence_time_distribution(num_blocks=1) - distr_block = self.coalescence_time_distribution(num_blocks=10) - q = np.linspace(0, 1, 11) - np.testing.assert_allclose(distr_noblock.quantile(q), distr_block.quantile(q)) - - def test_mean(self): - distr_noblock = self.coalescence_time_distribution(num_blocks=1) - distr_block = self.coalescence_time_distribution(num_blocks=10) - t = distr_noblock.tables[0].time[-1] / 2 - np.testing.assert_allclose( - distr_noblock.mean(since=t), distr_block.mean(since=t) - ) - - def test_coalescence_rate_in_intervals(self): - distr_noblock = self.coalescence_time_distribution(num_blocks=1) - 
distr_block = self.coalescence_time_distribution(num_blocks=10) - t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) - np.testing.assert_allclose( - distr_noblock.coalescence_rate_in_intervals(t), - distr_block.coalescence_rate_in_intervals(t), - ) - - def test_coalescence_probability_in_intervals(self): - distr_noblock = self.coalescence_time_distribution(num_blocks=1) - distr_block = self.coalescence_time_distribution(num_blocks=10) - t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) - np.testing.assert_allclose( - distr_noblock.coalescence_probability_in_intervals(t), - distr_block.coalescence_probability_in_intervals(t), - ) - - -class TestCoalescenceTimeDistributionRunningUpdate(TestCoalescenceTimeDistribution): - """ - When traversing trees, weights are updated for nodes whose descendant subtree - has changed. This is done by taking the parents of added edges, and tracing - ancestors down to the root. This class tests that this "running update" - scheme produces the correct result. - """ - - def coalescence_time_distribution_running(self, ts, brk, sets=2): - n = ts.num_samples // sets - smp_set = [list(range(i, i + n)) for i in range(0, ts.num_samples, n)] - distr = ts.coalescence_time_distribution( - sample_sets=smp_set, - window_breaks=brk, - weight_func="trio_first_coalescence_events", - span_normalise=False, - ) - return distr - - def coalescence_time_distribution_split(self, ts, brk, sets=2): - n = ts.num_samples // sets - smp_set = [list(range(i, i + n)) for i in range(0, ts.num_samples, n)] - distr_by_win = [] - for left, right in zip(brk[:-1], brk[1:]): - ts_trim = ts.keep_intervals([[left, right]]).trim() - distr_by_win += [ - ts_trim.coalescence_time_distribution( - sample_sets=smp_set, - weight_func="trio_first_coalescence_events", - span_normalise=False, - ) - ] - return distr_by_win - - def test_many_edge_diffs(self): - """ - Test that ts windowed by tree gives same result as set of single trees. 
- """ - ts = self.ts_many_edge_diffs() - brk = np.array([t.interval.left for t in ts.trees()] + [ts.sequence_length]) - distr = self.coalescence_time_distribution_running(ts, brk) - distr_win = self.coalescence_time_distribution_split(ts, brk) - time_breaks = np.array([np.inf]) - updt = distr.num_coalesced(time_breaks) - sepr = np.zeros(updt.shape) - for i, d in enumerate(distr_win): - c = d.num_coalesced(time_breaks) - sepr[:, :, i] = c.reshape((c.shape[0], 1)) - np.testing.assert_allclose(sepr, updt) - - def test_missing_trees(self): - """ - Test that ts with half of each tree masked gives same result as unmasked ts. - """ - ts = self.ts_many_edge_diffs() - brk = np.array([t.interval.left for t in ts.trees()] + [ts.sequence_length]) - mask = np.array( - [ - [tr.interval.left, (tr.interval.right + tr.interval.left) / 2] - for tr in ts.trees() - ] - ) - ts_mask = ts.delete_intervals(mask) - distr = self.coalescence_time_distribution_running(ts, brk) - distr_mask = self.coalescence_time_distribution_running(ts_mask, brk) - time_breaks = np.array([np.inf]) - updt = distr.num_coalesced(time_breaks) - updt_mask = distr_mask.num_coalesced(time_breaks) - np.testing.assert_allclose(updt, updt_mask) - - def test_unary_nodes(self): - """ - Test that ts with unary nodes gives same result as ts with unary nodes removed. 
- """ - ts = self.ts_many_edge_diffs() - ts_unary = ts.simplify( - samples=list(range(ts.num_samples // 2)), keep_unary=True - ) - ts_nounary = ts.simplify( - samples=list(range(ts.num_samples // 2)), keep_unary=False - ) - brk = np.array([t.interval.left for t in ts.trees()] + [ts.sequence_length]) - distr_unary = self.coalescence_time_distribution_running(ts_unary, brk) - distr_nounary = self.coalescence_time_distribution_running(ts_nounary, brk) - time_breaks = np.array([np.inf]) - updt_unary = distr_unary.num_coalesced(time_breaks) - updt_nounary = distr_nounary.num_coalesced(time_breaks) - np.testing.assert_allclose(updt_unary, updt_nounary) - - -class TestSpanNormalisedCoalescenceTimeTable(TestCoalescenceTimeDistribution): - """ - Cum. - Wght Wght - --------------- - 1.74┊ 7 ┊ 7 ┊ 0.88+0.12 3.00 - ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ - 0.73┊ ┃ 6 ┊ ┃ ┃ ┊ 0.88 2.00 - ┊ ┃ ┏━┻┓ ┊ ┃ ┃ ┊ - 0.59┊ ┃ ┃ 5 ┊ ┃ 5 ┊ 0.88+0.12 1.12 - ┊ ┃ ┃ ┏┻┓ ┊ ┃ ┏┻━┓ ┊ - 0.54┊ ┃ ┃ ┃ ┃ ┊ ┃ 4 ┃ ┊ 0.12 0.12 - ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┏┻┓ ┃ ┊ - 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ - 0.00 0.88 1.00 - SpnW┊ 0.88 ┊ 0.12 ┊ - Uniform weights on nodes summed over trees, weighted by tree span - """ - - def coalescence_time_distribution(self, mask_half_of_each_tree=False): - """ - Methods should give the same result if half of each tree is masked, - because "span weights" are normalised using the accessible (nonmissing) - portion of the tree sequence. 
- """ - ts = self.ts_two_trees_four_leaves() - if mask_half_of_each_tree: - mask = np.array( - [ - [t.interval.left, (t.interval.right + t.interval.left) / 2] - for t in ts.trees() - ] - ) - ts = ts.delete_intervals(mask) - distr = ts.coalescence_time_distribution( - span_normalise=True, - ) - return distr - - @pytest.mark.parametrize("with_missing_data", [True, False]) - def test_weights(self, with_missing_data): - w = np.array([[0, 0.12, 1.0, 0.88, 1.0]]).T - distr = self.coalescence_time_distribution(with_missing_data) - tw = distr.tables[0].weights - np.testing.assert_allclose(w, tw) - - @pytest.mark.parametrize("with_missing_data", [True, False]) - def test_cum_weights(self, with_missing_data): - c = np.array([[0, 0.12, 1.12, 2.00, 3.00]]).T - distr = self.coalescence_time_distribution(with_missing_data) - tc = distr.tables[0].cum_weights - np.testing.assert_allclose(c, tc) - - -class TestWindowedSpanNormalisedCoalescenceTimeTable(TestCoalescenceTimeDistribution): - """ - 0.00 0.50 1.00 Window 0 Window 1 - Wndw┊ 0 ┊ 1 ┊ Time Wght Time Wght - Blck┊ 0 ┊ 0 ┊ 1 ┊ ---------- ---------- - 1.74┊ 7 ┊ 7 ┊ 1.74 1 1.74 0.24 - ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ 1.74 0.76 - 0.73┊ ┃ 6 ┊ ┃ ┃ ┊ 0.73 1 0.73 0.76 - ┊ ┃ ┏━┻┓ ┊ ┃ ┃ ┊ - 0.59┊ ┃ ┃ 5 ┊ ┃ 5 ┊ 0.59 1 0.59 0.24 - ┊ ┃ ┃ ┏┻┓ ┊ ┃ ┏┻━┓ ┊ 0.59 0.76 - 0.54┊ ┃ ┃ ┃ ┃ ┊ ┃ 4 ┃ ┊ 0.54 0.24 - ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┏┻┓ ┃ ┊ - 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0.00 0 0.00 0 0 - self.num_records = sum(not_empty) - self.time = time[not_empty] - self.block = block[not_empty] - self.weights = weights[not_empty, :] - # add left boundary at time 0 - self.num_records += 1 - self.time = np.pad(self.time, (0, 1)) - self.block = np.pad(self.block, (0, 1)) - self.weights = np.pad(self.weights, ((0, 1), (0, 0))) - # sort by node time - time_order = np.argsort(self.time) - self.time = self.time[time_order] - self.block = self.block[time_order] - self.weights = self.weights[time_order, :] - # calculate quantiles - self.num_blocks = 1 + np.max(self.block) if self.num_records > 0 
else 0 - self.block_multiplier = np.ones(self.num_blocks) - self.cum_weights = np.cumsum(self.weights, 0) - self.quantile = np.empty((self.num_records, self.num_weights)) - self.quantile[:] = np.nan - for i in range(self.num_weights): - if self.cum_weights[-1, i] > 0: - self.quantile[:, i] = self.cum_weights[:, i] / self.cum_weights[-1, i] - - def resample_blocks(self, block_multiplier): - assert block_multiplier.shape[0] == self.num_blocks - assert np.sum(block_multiplier) == self.num_blocks - self.block_multiplier = block_multiplier - for i in range(self.num_weights): - self.cum_weights[:, i] = np.cumsum( - self.weights[:, i] * self.block_multiplier[self.block], 0 - ) - if self.cum_weights[-1, i] > 0: - self.quantile[:, i] = self.cum_weights[:, i] / self.cum_weights[-1, i] - else: - self.quantile[:, i] = np.nan - - -class CoalescenceTimeDistribution: - """ - Class to precompute a table of sorted/weighted node times, from which to calculate - the empirical distribution function and estimate coalescence rates in time windows. - - To compute weights efficiently requires an update operation of the form: - - ``output[parent], state[parent] = update(state[children])`` - - where ``output`` are the weights associated with the node, and ``state`` - are values that are needed to compute ``output`` that are recursively - calculated along the tree. The value of ``state`` on the leaves is - initialized via, - - ``state[sample] = initialize(sample, sample_sets)`` - """ - - @staticmethod - def _count_coalescence_events(): - """ - Count the number of samples that coalesce in ``node``, within each - set of samples in ``sample_sets``. 
- """ - - def initialize(node, sample_sets): - singles = np.array([[node in s for s in sample_sets]], dtype=np.float64) - return (singles,) - - def update(singles_per_child): - singles = np.sum(singles_per_child, axis=0, keepdims=True) - is_ancestor = (singles > 0).astype(np.float64) - return is_ancestor, (singles,) - - return (initialize, update) - - @staticmethod - def _count_pair_coalescence_events(): - """ - Count the number of pairs that coalesce in ``node``, within and between the - sets of samples in ``sample_sets``. The count of pairs with members that - belong to sets :math:`a` and :math:`b` is: - - .. math: - - \\sum_{i \\neq j} (C_i(a) C_j(b) + C_i(b) C_j(a))/(1 - \\mathbb{I}[a = b]) - - where :math:`C_i(a)` is the number of samples from set :math:`a` - descended from child :math:`i`. The values in the output are ordered - canonically; e.g. if ``len(sample_sets) == 2`` then the values would - correspond to counts of pairs with set labels ``[(0,0), (0,1), (1,1)]``. - """ - - def initialize(node, sample_sets): - singles = np.array([[node in s for s in sample_sets]], dtype=np.float64) - return (singles,) - - def update(singles_per_child): - C = singles_per_child.shape[0] # number of children - S = singles_per_child.shape[1] # number of sample sets - singles = np.sum(singles_per_child, axis=0, keepdims=True) - pairs = np.zeros((1, int(S * (S + 1) / 2)), dtype=np.float64) - for a, b in itertools.combinations(range(C), 2): - for i, (j, k) in enumerate( - itertools.combinations_with_replacement(range(S), 2) - ): - pairs[0, i] += ( - singles_per_child[a, j] * singles_per_child[b, k] - + singles_per_child[a, k] * singles_per_child[b, j] - ) / (1 + int(j == k)) - return pairs, (singles,) - - return (initialize, update) - - @staticmethod - def _count_trio_first_coalescence_events(): - """ - Count the number of pairs that coalesce in node with an outgroup, - within and between the sets of samples in ``sample_sets``. 
In other - words, count topologies of the form ``((A,B):node,C)`` where ``A,B,C`` - are labels and `node` is the node ID. The count of pairs with members - that belong to sets :math:`a` and :math:`b` with outgroup :math:`c` is: - - .. math: - - \\sum_{i \\neq j} (C_i(a) C_j(b) + C_i(b) C_j(a)) \\times - O(c) / (1 - \\mathbb{I}[a = b]) - - where :math:`C_i(a)` is the number of samples from set :math:`a` - descended from child :math:`i` of the node, and :math:`O(c)` is the - number of samples from set :math:`c` that are *not* descended from the - node. The values in the output are ordered canonically by pair then - outgroup; e.g. if ``len(sample_sets) == 2`` then the values would - correspond to counts of pairs with set labels, - ``[((0,0),0), ((0,0),1), ..., ((0,1),0), ((0,1),1), ...]``. - """ - - def initialize(node, sample_sets): - S = len(sample_sets) - totals = np.array([[len(s) for s in sample_sets]], dtype=np.float64) - singles = np.array([[node in s for s in sample_sets]], dtype=np.float64) - pairs = np.zeros((1, int(S * (S + 1) / 2)), dtype=np.float64) - return ( - totals, - singles, - pairs, - ) - - def update(totals_per_child, singles_per_child, pairs_per_child): - C = totals_per_child.shape[0] # number of children - S = totals_per_child.shape[1] # number of sample sets - totals = np.mean(totals_per_child, axis=0, keepdims=True) - singles = np.sum(singles_per_child, axis=0, keepdims=True) - pairs = np.zeros((1, int(S * (S + 1) / 2)), dtype=np.float64) - for a, b in itertools.combinations(range(C), 2): - pair_iterator = itertools.combinations_with_replacement(range(S), 2) - for i, (j, k) in enumerate(pair_iterator): - pairs[0, i] += ( - singles_per_child[a, j] * singles_per_child[b, k] - + singles_per_child[a, k] * singles_per_child[b, j] - ) / (1 + int(j == k)) - outgr = totals - singles - trios = np.zeros((1, pairs.size * outgr.size), dtype=np.float64) - trio_iterator = itertools.product(range(pairs.size), range(outgr.size)) - for i, (j, k) in 
enumerate(trio_iterator): - trios[0, i] += pairs[0, j] * outgr[0, k] - return trios, ( - totals, - singles, - pairs, - ) - - return (initialize, update) - - def _update_running_with_edge_diff( - self, tree, edge_diff, running_output, running_state, running_index - ): - """ - Update ``running_output`` and ``running_state`` to reflect ``tree``, - using edge differences ``edge_diff`` with the previous tree. - The dict ``running_index`` maps node IDs onto rows of the running arrays. - """ - - assert edge_diff.interval == tree.interval - - # empty rows in the running arrays - available_rows = {i for i in range(self.running_array_size)} - available_rows -= set(running_index.values()) - - # find internal nodes that have been removed from tree or are unary - removed_nodes = set() - for i in edge_diff.edges_out: - for j in [i.child, i.parent]: - if tree.num_children(j) < 2 and not tree.is_sample(j): - removed_nodes.add(j) - - # find non-unary nodes where descendant subtree has been altered - modified_nodes = { - i.parent for i in edge_diff.edges_in if tree.num_children(i.parent) > 1 - } - for i in copy.deepcopy(modified_nodes): - while tree.parent(i) != tskit.NULL and not tree.parent(i) in modified_nodes: - i = tree.parent(i) - if tree.num_children(i) > 1: - modified_nodes.add(i) - - # clear running state/output for nodes that are no longer in tree - for i in removed_nodes: - if i in running_index: - running_state[running_index[i], :] = 0 - running_output[running_index[i], :] = 0 - available_rows.add(running_index.pop(i)) - - # recalculate state/output for nodes whose descendants have changed - for i in sorted(modified_nodes, key=lambda node: tree.time(node)): - children = [] - for c in tree.children(i): # skip unary children - while tree.num_children(c) == 1: - (c,) = tree.children(c) - children.append(c) - child_index = [running_index[c] for c in children] - - inputs = ( - running_state[child_index][:, state_index] - for state_index in self.state_indices - ) - output, 
state = self._update(*inputs) - - # update running state/output arrays - if i not in running_index: - running_index[i] = available_rows.pop() - running_output[running_index[i], :] = output - for state_index, x in zip(self.state_indices, state): - running_state[running_index[i], state_index] = x - - # track the number of times the weight function was called - self.weight_func_evals += len(modified_nodes) - - def _build_ecdf_table_for_window( - self, - left, - right, - tree, - edge_diffs, - running_output, - running_state, - running_index, - ): - """ - Construct ECDF table for genomic interval [left, right]. Update - ``tree``; ``edge_diffs``; and ``running_output``, ``running_state``, - `running_idx``; for input for next window. Trees are counted as - belonging to any interval with which they overlap, and thus can be used - in several intervals. Thus, the concatenation of ECDF tables across - multiple intervals is not the same as the ECDF table for the union of - those intervals. Trees within intervals are chunked into roughly - equal-sized blocks for bootstrapping. - """ - - assert tree.interval.left <= left and right > left - - # TODO: if bootstrapping, block span needs to be tracked - # and used to renormalise each replicate. This should be - # done by the bootstrapping machinery, not here. 
- - # assign trees in window to equal-sized blocks with unique id - tree_offset = tree.index - if right >= tree.tree_sequence.sequence_length: - tree.last() - else: - # tree.seek(right) won't work if `right` is recomb breakpoint - while tree.interval.right < right: - tree.next() - tree_idx = np.arange(tree_offset, tree.index + 1) - tree_offset - num_blocks = min(self.num_blocks, len(tree_idx)) - tree_blocks = np.floor_divide(num_blocks * tree_idx, len(tree_idx)) - - # calculate span weights - tree.seek_index(tree_offset) - tree_span = [min(tree.interval.right, right) - max(tree.interval.left, left)] - while tree.index < tree_offset + tree_idx[-1]: - tree.next() - tree_span.append( - min(tree.interval.right, right) - max(tree.interval.left, left) - ) - tree_span = np.array(tree_span) - total_span = np.sum(tree_span) - assert np.isclose( - total_span, min(right, tree.tree_sequence.sequence_length) - left - ) - - # storage if using single window, block for entire tree sequence - buffer_size = self.buffer_size - table_size = buffer_size - time = np.zeros(table_size) - block = np.zeros(table_size, dtype=np.int32) - weights = np.zeros((table_size, self.num_weights)) - - # assemble table of coalescence times in window - num_record = 0 - accessible_span = 0.0 - span_weight = 1.0 - indices = np.zeros(tree.tree_sequence.num_nodes, dtype=np.int32) - 1 - last_block = np.zeros(tree.tree_sequence.num_nodes, dtype=np.int32) - 1 - tree.seek_index(tree_offset) - while tree.index != tskit.NULL: - if tree.interval.right > left: - current_block = tree_blocks[tree.index - tree_offset] - if self.span_normalise: - span_weight = tree_span[tree.index - tree_offset] / total_span - - # TODO: shouldn't need to loop over all keys (nodes) for every tree - internal_nodes = np.array( - [i for i in running_index.keys() if not tree.is_sample(i)], - dtype=np.int32, - ) - - if internal_nodes.size > 0: - accessible_span += tree_span[tree.index - tree_offset] - rows_in_running = np.array( - 
[running_index[i] for i in internal_nodes], dtype=np.int32 - ) - nodes_to_add = internal_nodes[ - last_block[internal_nodes] != current_block - ] - if nodes_to_add.size > 0: - table_idx = np.arange( - num_record, num_record + len(nodes_to_add) - ) - last_block[nodes_to_add] = current_block - indices[nodes_to_add] = table_idx - if table_size < num_record + len(nodes_to_add): - table_size += buffer_size - time = np.pad(time, (0, buffer_size)) - block = np.pad(block, (0, buffer_size)) - weights = np.pad(weights, ((0, buffer_size), (0, 0))) - time[table_idx] = [tree.time(i) for i in nodes_to_add] - block[table_idx] = current_block - num_record += len(nodes_to_add) - weights[indices[internal_nodes], :] += ( - span_weight * running_output[rows_in_running, :] - ) - - if tree.interval.right < right: - # if current tree does not cross window boundary, move to next - tree.next() - self._update_running_with_edge_diff( - tree, next(edge_diffs), running_output, running_state, running_index - ) - else: - # use current tree as initial tree for next window - break - - # reweight span so that weights are averaged over nonmissing trees - if self.span_normalise: - weights *= total_span / accessible_span - - return CoalescenceTimeTable(time, block, weights) - - def _generate_ecdf_tables(self, ts, window_breaks): - """ - Return generator for ECDF tables across genomic windows defined by - ``window_breaks``. - - ..note:: This could be used in methods in place of loops over - pre-assembled tables. 
- """ - - tree = ts.first() - edge_diffs = ts.edge_diffs() - - # initialize running arrays for first tree - running_index = {i: n for i, n in enumerate(tree.samples())} - running_output = np.zeros( - (self.running_array_size, self.num_weights), - dtype=np.float64, - ) - running_state = np.zeros( - (self.running_array_size, self.num_states), - dtype=np.float64, - ) - for node in tree.samples(): - state = self._initialize(node, self.sample_sets) - for state_index, x in zip(self.state_indices, state): - running_state[running_index[node], state_index] = x - - self._update_running_with_edge_diff( - tree, next(edge_diffs), running_output, running_state, running_index - ) - - for left, right in zip(window_breaks[:-1], window_breaks[1:]): - yield self._build_ecdf_table_for_window( - left, - right, - tree, - edge_diffs, - running_output, - running_state, - running_index, - ) - - def __init__( - self, - ts, - sample_sets=None, - weight_func=None, - window_breaks=None, - blocks_per_window=None, - span_normalise=True, - ): - assert isinstance(ts, trees.TreeSequence) - - if sample_sets is None: - sample_sets = [list(ts.samples())] - assert all([isinstance(i, list) for i in sample_sets]) - assert all([i in ts.samples() for j in sample_sets for i in j]) - self.sample_sets = sample_sets - - if weight_func is None or weight_func == "coalescence_events": - self._initialize, self._update = self._count_coalescence_events() - elif weight_func == "pair_coalescence_events": - self._initialize, self._update = self._count_pair_coalescence_events() - elif weight_func == "trio_first_coalescence_events": - self._initialize, self._update = self._count_trio_first_coalescence_events() - else: - # user supplies pair of callables ``(initialize, update)`` - assert isinstance(weight_func, tuple) - assert len(weight_func) == 2 - self._initialize, self._update = weight_func - assert callable(self._initialize) - assert callable(self._update) - - # check initialization operation - _state = 
self._initialize(0, self.sample_sets) - assert isinstance(_state, tuple) - self.num_states = 0 - self.state_indices = [] - for x in _state: - # ``assert is_row_vector(x)`` - assert isinstance(x, np.ndarray) - assert x.ndim == 2 - assert x.shape[0] == 1 - index = list(range(self.num_states, self.num_states + x.size)) - self.state_indices.append(index) - self.num_states += x.size - - # check update operation - _weights, _state = self._update(*_state) - assert isinstance(_state, tuple) - for state_index, x in zip(self.state_indices, _state): - # ``assert is_row_vector(x, len(state_index))`` - assert isinstance(x, np.ndarray) - assert x.ndim == 2 - assert x.shape[0] == 1 - assert x.size == len(state_index) - # ``assert is_row_vector(_weights)`` - assert isinstance(_weights, np.ndarray) - assert _weights.ndim == 2 - assert _weights.shape[0] == 1 - self.num_weights = _weights.size - - if window_breaks is None: - window_breaks = np.array([0.0, ts.sequence_length]) - assert isinstance(window_breaks, np.ndarray) - assert window_breaks.ndim == 1 - assert np.min(window_breaks) >= 0.0 - assert np.max(window_breaks) <= ts.sequence_length - window_breaks = np.sort(np.unique(window_breaks)) - self.windows = [ - trees.Interval(left, right) - for left, right in zip(window_breaks[:-1], window_breaks[1:]) - ] - self.num_windows = len(self.windows) - - if blocks_per_window is None: - blocks_per_window = 1 - assert isinstance(blocks_per_window, int) - assert blocks_per_window > 0 - self.num_blocks = blocks_per_window - - assert isinstance(span_normalise, bool) - self.span_normalise = span_normalise - - self.buffer_size = ts.num_nodes - self.running_array_size = ts.num_samples * 2 - 1 # assumes no unary nodes - self.weight_func_evals = 0 - self.tables = [table for table in self._generate_ecdf_tables(ts, window_breaks)] - - # TODO - # - # def __str__(self): - # return self.useful_text_summary() - # - # def __repr_html__(self): - # return self.useful_html_summary() - - def copy(self): - 
return copy.deepcopy(self) - - def ecdf(self, times): - """ - Returns the empirical distribution function evaluated at the time - points in ``times``. - - The output array has shape ``(self.num_weights, len(times), - self.num_windows)``. - """ - - assert isinstance(times, np.ndarray) - assert times.ndim == 1 - - values = np.empty((self.num_weights, len(times), self.num_windows)) - values[:] = np.nan - for k, table in enumerate(self.tables): - indices = np.searchsorted(table.time, times, side="right") - 1 - assert all([0 <= i < table.num_records for i in indices]) - values[:, :, k] = table.quantile[indices, :].T - return values - - def quantile(self, quantiles): - """ - Return interpolated quantiles of weighted coalescence times. - """ - - assert isinstance(quantiles, np.ndarray) - assert quantiles.ndim == 1 - assert np.all(np.logical_and(quantiles >= 0, quantiles <= 1)) - - values = np.empty((self.num_weights, quantiles.size, self.num_windows)) - values[:] = np.nan - for k, table in enumerate(self.tables): - # retrieve ECDF for each unique timepoint in table - last_index = np.flatnonzero(table.time[:-1] != table.time[1:]) - time = np.append(table.time[last_index], table.time[-1]) - ecdf = np.append( - table.quantile[last_index, :], table.quantile[[-1]], axis=0 - ) - for i in range(self.num_weights): - if not np.isnan(ecdf[-1, i]): - # interpolation requires strictly increasing arguments, so - # retrieve leftmost x for step-like F(x), including F(0) = 0. 
- assert ecdf[-1, i] == 1.0 - assert ecdf[0, i] == 0.0 - delta = ecdf[1:, i] - ecdf[:-1, i] - first_index = 1 + np.flatnonzero(delta > 0) - - n_eff = first_index.size - weight = delta[first_index - 1] - cum_weight = np.roll(ecdf[first_index, i], 1) - cum_weight[0] = 0 - midpoint = np.arange(n_eff) * weight + (n_eff - 1) * cum_weight - assert midpoint[0] == 0 - assert midpoint[-1] == n_eff - 1 - values[i, :, k] = np.interp( - quantiles * (n_eff - 1), midpoint, time[first_index] - ) - return values - - def num_coalesced(self, times): - """ - Returns number of coalescence events that have occured by the time - points in ``times``. - - The output array has shape ``(self.num_weights, len(times), - self.num_windows)``. - """ - - assert isinstance(times, np.ndarray) - assert times.ndim == 1 - - values = self.ecdf(times) - for k, table in enumerate(self.tables): - weight_totals = table.cum_weights[-1, :].reshape(values.shape[0], 1) - values[:, :, k] *= np.tile(weight_totals, (1, values.shape[1])) - return values - - def num_uncoalesced(self, times): - """ - Returns the number of coalescence events remaining by the time points - in ``times``. - - The output array has shape ``(self.num_weights, len(times), - self.num_windows)``. - """ - - values = 1.0 - self.ecdf(times) - for k, table in enumerate(self.tables): - weight_totals = table.cum_weights[-1, :].reshape(values.shape[0], 1) - values[:, :, k] *= np.tile(weight_totals, (1, values.shape[1])) - return values - - def mean(self, since=0.0): - """ - Returns the average time between ``since`` and the coalescence events - that occurred after ``since``. - - Note that ``1/self.mean(left)`` is an estimate of the coalescence rate - over the interval (left, infinity). - - The output array has shape ``(self.num_weights, self.num_windows)``. - - ..note:: Check for overflow in ``np.average``. 
- """ - - assert isinstance(since, float) and since >= 0.0 - - values = np.empty((self.num_weights, self.num_windows)) - values[:] = np.nan - for k, table in enumerate(self.tables): - index = np.searchsorted(table.time, since, side="right") - if index == table.num_records: - values[:, k] = np.nan - else: - for i in range(self.num_weights): - multiplier = table.block_multiplier[table.block[index:]] - weights = table.weights[index:, i] * multiplier - if np.any(weights > 0): - values[i, k] = np.average( - table.time[index:] - since, - weights=weights, - ) - return values - - def coalescence_probability_in_intervals(self, time_breaks): - """ - Returns the proportion of coalescence events occurring in the time - intervals defined by ``time_breaks``, out of events that have not - yet occurred by the intervals' left boundaries. - - The output array has shape ``(self.num_weights, len(time_breaks)-1, - self.num_windows)``. - """ - - assert isinstance(time_breaks, np.ndarray) - - time_breaks = np.sort(np.unique(time_breaks)) - num_coalesced = self.num_coalesced(time_breaks) - num_uncoalesced = self.num_uncoalesced(time_breaks) - numer = num_coalesced[:, 1:, :] - num_coalesced[:, :-1, :] - denom = num_uncoalesced[:, :-1, :] - return numer / np.where(np.isclose(denom, 0.0), np.nan, denom) - - def coalescence_rate_in_intervals(self, time_breaks): - """ - Returns the interval-censored Kaplan-Meier estimate of the hazard rate for - coalesence events within the time intervals defined by ``time_breaks``. The - estimator is, - - .. math:: - \\hat{c}_{l,r} = \\begin{cases} - \\log(1 - x_{l,r}/k_{l})/(l - r) & \\mathrm{if~} x_{l,r} < k_{l} \\\\ - \\hat{c}_{l,r} = k_{l} / t_{l,r} & \\mathrm{if~} x_{l,r} = k_{l} - \\end{cases} - - and is undefined where :math:`k_{l} = 0`. 
Here, :math:`x_{l,r}` is the - number of events occuring in time interval :math:`(l, r]`, - :math:`k_{l}` is the number of events remaining at time :math:`l`, and - :math:`t_{l,r}` is the sum of event times occurring in the interval - :math:`(l, r]`. - - The output array has shape ``(self.num_weights, len(time_breaks)-1, - self.num_windows)``. - """ - - assert isinstance(time_breaks, np.ndarray) - - time_breaks = np.sort(np.unique(time_breaks)) - phi = self.coalescence_probability_in_intervals(time_breaks) - duration = np.reshape(time_breaks[1:] - time_breaks[:-1], (1, phi.shape[1], 1)) - numer = -np.log(1.0 - np.where(np.isclose(phi, 1.0), np.nan, phi)) - denom = np.tile(duration, (self.num_weights, 1, self.num_windows)) - for i, j, k in np.argwhere(np.isclose(phi, 1.0)): - numer[i, j, k] = 1.0 - denom[i, j, k] = self.mean(time_breaks[j])[i, k] - return numer / denom - - def block_bootstrap(self, num_replicates=1, random_seed=None): - """ - Return a generator that produces ``num_replicates`` copies of the - object where blocks within genomic windows are randomly resampled. - - ..note:: Copying could be expensive. 
-        """
-
-        rng = np.random.default_rng(random_seed)
-        for _i in range(num_replicates):
-            replicate = self.copy()
-            for table in replicate.tables:
-                block_multiplier = rng.multinomial(
-                    table.num_blocks, [1.0 / table.num_blocks] * table.num_blocks
-                )
-                table.resample_blocks(block_multiplier)
-            yield replicate
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index c198bd9237..cd08c40c2f 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -46,7 +46,6 @@
 import tskit.combinatorics as combinatorics
 import tskit.drawing as drawing
 import tskit.metadata as metadata_module
-import tskit.stats as stats
 import tskit.tables as tables
 import tskit.text_formats as text_formats
 import tskit.util as util
@@ -9291,25 +9290,203 @@ def ibd_segments(
             store_pairs=store_pairs,
         )
 
-    def coalescence_time_distribution(
+    def pair_coalescence_counts(
         self,
-        *,
         sample_sets=None,
-        weight_func=None,
-        window_breaks=None,
-        blocks_per_window=None,
-        span_normalise=False,
+        indexes=None,
+        windows=None,
+        span_normalise=True,
+        time_windows="nodes",
     ):
-        # TODO docstring, not yet in API
+        """
+        Calculate the number of coalescing sample pairs per node, summed over
+        trees and weighted by tree span.
 
-        return stats.CoalescenceTimeDistribution(
-            self,
-            sample_sets=sample_sets,
-            weight_func=weight_func,
-            window_breaks=window_breaks,
-            blocks_per_window=blocks_per_window,
-            span_normalise=span_normalise,
-        )
+        The number of coalescing pairs may be calculated within or between the
+        non-overlapping lists of samples contained in `sample_sets`. In the
+        latter case, pairs are counted if they have exactly one member in each
+        of two sample sets. If `sample_sets` is omitted, a single group
+        containing all samples is assumed.
+
+        The argument `indexes` may be used to specify which pairs of sample
+        sets to compute the statistic between, and in what order. If
+        `indexes=None`, then `indexes` is assumed to equal `[(0,0)]` for a
+        single sample set and `[(0,1)]` for two sample sets. For more than two
+        sample sets, `indexes` must be explicitly passed.
+
+        The argument `time_windows` may be used to count coalescence
+        events within time intervals (if an array of breakpoints is supplied)
+        rather than for individual nodes (the default).
+
+        The output array has dimension `(windows, nodes, indexes)` with
+        dimensions dropped when the corresponding argument is set to None.
+
+        :param list sample_sets: A list of lists of Node IDs, specifying the
+            groups of nodes to compute the statistic with, or None.
+        :param list indexes: A list of 2-tuples, or None.
+        :param list windows: An increasing list of breakpoints between the
+            sequence windows to compute the statistic in, or None.
+        :param bool span_normalise: Whether to divide the result by the span of
+            the window (defaults to True).
+        :param time_windows: Either a string "nodes" or an increasing
+            list of breakpoints between time intervals.
+        """
+
+        if sample_sets is None:
+            sample_sets = [list(self.samples())]
+        for s in sample_sets:
+            if len(s) == 0:
+                raise ValueError("Sample sets must contain at least one element")
+            if not (min(s) >= 0 and max(s) < self.num_nodes):
+                raise ValueError("Sample is out of bounds")
+
+        drop_right_dimension = False
+        if indexes is None:
+            drop_right_dimension = True
+            if len(sample_sets) == 1:
+                indexes = [(0, 0)]
+            elif len(sample_sets) == 2:
+                indexes = [(0, 1)]
+            else:
+                raise ValueError(
+                    "Must specify indexes if there are more than two sample sets"
+                )
+        for i in indexes:
+            if not len(i) == 2:
+                raise ValueError("Sample set indexes must be length two")
+            if not (min(i) >= 0 and max(i) < len(sample_sets)):
+                raise ValueError("Sample set index is out of bounds")
+
+        drop_left_dimension = False
+        if windows is None:
+            drop_left_dimension = True
+            windows = np.array([0.0, self.sequence_length])
+        if not (isinstance(windows, np.ndarray) and windows.size > 1):
+            raise ValueError("Windows must be an array of breakpoints")
+        if not (windows[0] == 0.0 and windows[-1] == self.sequence_length):
+            raise ValueError("First and last window breaks must be sequence boundary")
+        if not np.all(np.diff(windows) > 0):
+            raise ValueError("Window breaks must be strictly increasing")
+
+        if isinstance(time_windows, str) and time_windows == "nodes":
+            nodes_map = np.arange(self.num_nodes)
+            output_size = self.num_nodes
+        else:
+            if not (isinstance(time_windows, np.ndarray) and time_windows.size > 1):
+                raise ValueError("Time windows must be an array of breakpoints")
+            if not np.all(np.diff(time_windows) > 0):
+                raise ValueError("Time windows must be strictly increasing")
+            if self.time_units == tskit.TIME_UNITS_UNCALIBRATED:
+                raise ValueError("Time windows requires calibrated node times")
+            output_size = time_windows.size - 1
+            nodes_map = np.searchsorted(time_windows, self.nodes_time, side="right") - 1
+            nodes_oob = np.logical_or(nodes_map < 0, nodes_map >= output_size)
+            nodes_map[nodes_oob] = tskit.NULL  # outside all time windows
+
+        num_nodes = self.num_nodes
+        num_edges = self.num_edges
+        num_windows = windows.size - 1
+        num_sample_sets = len(sample_sets)
+        num_indexes = len(indexes)
+
+        edges_child = self.edges_child
+        edges_parent = self.edges_parent
+        insert_index = self.indexes_edge_insertion_order
+        remove_index = self.indexes_edge_removal_order
+        insert_position = self.edges_left[insert_index]
+        remove_position = self.edges_right[remove_index]
+        sequence_length = self.sequence_length
+
+        windows_span = np.zeros(num_windows)
+        nodes_parent = np.full(num_nodes, tskit.NULL)
+        nodes_sample = np.zeros((num_nodes, num_sample_sets))
+        coalescing_pairs = np.zeros((num_windows, output_size, num_indexes))
+
+        for i, s in enumerate(sample_sets):
+            nodes_sample[s, i] = 1
+        sample_counts = nodes_sample.copy()
+        position = 0.0
+        w, a, b = 0, 0, 0
+        while position < sequence_length:
+            remainder = sequence_length - position
+
+            while b < num_edges and remove_position[b] == position:  # edges out
+                e = remove_index[b]
+                p = edges_parent[e]
+                c = edges_child[e]
+                nodes_parent[c] = tskit.NULL
+                inside = sample_counts[c]
+                while p != tskit.NULL:
+                    u = nodes_map[p]
+                    if u != tskit.NULL:
+                        outside = sample_counts[p] - sample_counts[c] - nodes_sample[p]
+                        for i, (j, k) in enumerate(indexes):
+                            weight = inside[j] * outside[k] + inside[k] * outside[j]
+                            coalescing_pairs[w, u, i] -= weight * remainder
+                    c, p = p, nodes_parent[p]
+                p = edges_parent[e]
+                while p != tskit.NULL:
+                    sample_counts[p] -= inside
+                    p = nodes_parent[p]
+                b += 1
+
+            while a < num_edges and insert_position[a] == position:  # edges in
+                e = insert_index[a]
+                p = edges_parent[e]
+                c = edges_child[e]
+                nodes_parent[c] = p
+                inside = sample_counts[c]
+                while p != tskit.NULL:
+                    sample_counts[p] += inside
+                    p = nodes_parent[p]
+                p = edges_parent[e]
+                while p != tskit.NULL:
+                    u = nodes_map[p]
+                    if u != tskit.NULL:
+                        outside = sample_counts[p] - sample_counts[c] - nodes_sample[p]
+                        for i, (j, k) in enumerate(indexes):
+                            weight = inside[j] * outside[k] + inside[k] * outside[j]
+                            coalescing_pairs[w, u, i] += weight * remainder
+                    c, p = p, nodes_parent[p]
+                a += 1
+
+            position = sequence_length
+            if b < num_edges:
+                position = min(position, remove_position[b])
+            if a < num_edges:
+                position = min(position, insert_position[a])
+
+            while w < num_windows and windows[w + 1] <= position:  # flush window
+                windows_span[w] -= position - windows[w + 1]
+                if w + 1 < num_windows:
+                    windows_span[w + 1] += position - windows[w + 1]
+                remainder = sequence_length - windows[w + 1]
+                for c, p in enumerate(nodes_parent):
+                    u = nodes_map[p]
+                    if p == tskit.NULL or u == tskit.NULL:
+                        continue
+                    inside = sample_counts[c]
+                    outside = sample_counts[p] - sample_counts[c] - nodes_sample[p]
+                    for i, (j, k) in enumerate(indexes):
+                        weight = inside[j] * outside[k] + inside[k] * outside[j]
+                        coalescing_pairs[w, u, i] -= weight * remainder / 2
+                        if w + 1 < num_windows:
+                            coalescing_pairs[w + 1, u, i] += weight * remainder / 2
+                w += 1
+
+        for i, (j, k) in enumerate(indexes):
+            if j == k:
+                coalescing_pairs[:, :, i] /= 2
+        if span_normalise:
+            for w, s in enumerate(np.diff(windows)):
+                coalescing_pairs[w] /= s
+
+        if drop_right_dimension:
+            coalescing_pairs = coalescing_pairs[..., 0]
+        if drop_left_dimension:
+            coalescing_pairs = coalescing_pairs[0]
+
+        return coalescing_pairs
 
     def impute_unknown_mutations_time(
         self,

From 77bbe724ad1412c4379d3ad0a2acd20ed621a912 Mon Sep 17 00:00:00 2001
From: Nate Pope
Date: Wed, 10 Apr 2024 17:18:03 -0700
Subject: [PATCH 2/2] Minor change to trigger codecov

---
 python/tskit/trees.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index cd08c40c2f..85ea7934d8 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -9328,8 +9328,8 @@ def pair_coalescence_counts(
             sequence windows to compute the statistic in, or None.
:param bool span_normalise: Whether to divide the result by the span of the window (defaults to True). - :param time_windows: Either a string "nodes" or an increasing - list of breakpoints between time intervals. + :param time_windows: Either "nodes" or an increasing list of + breakpoints between time intervals. """ if sample_sets is None: