From 44e988f3445861931dbd6b1e1da5f954dd98d5de Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 9 Aug 2022 10:17:05 +0900 Subject: [PATCH 01/14] added failing tests --- tests/test_tracking.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 53670fbd..f77cd262 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -199,3 +199,16 @@ def test_no_accepting_wrong_argments() -> None: lt = LapTrack(hogehoge=True) with pytest.raises(ValidationError): lt = LapTrack(fugafuga=True) + + +def test_connected_edges() -> None: + coords = [np.array([[10, 10], [12, 11]]), np.array([[10, 10], [13, 11]])] + lt = LapTrack( + gap_closing_cost_cutoff=100, + splitting_cost_cutoff=100, + merging_cost_cutoff=100, + ) # type: ignore + connected_edges = [((0, 0), (1, 1))] + track_tree = lt.predict(coords, connected_edges=connected_edges) + edges = track_tree.edges() + assert set(edges) == set([((0, 0), (1, 1)), ((0, 1), (1, 0))]) From dc0dd208f113ce1a1cf7c786f55488301605c851 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 9 Aug 2022 10:37:29 +0900 Subject: [PATCH 02/14] wrote frame linking version --- src/laptrack/_coo_matrix_builder.py | 14 ++++++++++++ src/laptrack/_cost_matrix.py | 5 +++-- src/laptrack/_tracking.py | 35 +++++++++++++++++++++++------ 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/src/laptrack/_coo_matrix_builder.py b/src/laptrack/_coo_matrix_builder.py index 1cb21db6..988b9451 100644 --- a/src/laptrack/_coo_matrix_builder.py +++ b/src/laptrack/_coo_matrix_builder.py @@ -8,6 +8,7 @@ import numpy as np import numpy.typing as npt from scipy.sparse import coo_matrix +from scipy.sparse import csr_matrix from ._typing_utils import Float from ._typing_utils import Int @@ -159,6 +160,19 @@ def to_coo_matrix(self) -> coo_matrix: shape=(self.n_row, self.n_col), ) + def to_csr_matrix(self) -> coo_matrix: + """Generate `csr_matrix`. + + Returns + ------- + matrix : csr_matrix + the generated csr_matrix + """ + return csr_matrix( + (self.data, (self.row, self.col)), + shape=(self.n_row, self.n_col), + ) + def __setitem__( self, index: Tuple[Union[Int, IndexType], Union[Int, IndexType]], value ): diff --git a/src/laptrack/_cost_matrix.py b/src/laptrack/_cost_matrix.py index 21af793e..05f566eb 100644 --- a/src/laptrack/_cost_matrix.py +++ b/src/laptrack/_cost_matrix.py @@ -3,6 +3,7 @@ import numpy as np from scipy.sparse import coo_matrix +from scipy.sparse import csr_matrix from ._coo_matrix_builder import coo_matrix_builder from ._typing_utils import Float @@ -16,7 +17,7 @@ def build_frame_cost_matrix( *, track_start_cost: Optional[Float], track_end_cost: Optional[Float], -) -> coo_matrix: +) -> csr_matrix: """Build sparce array for frame-linking cost matrix. Parameters @@ -57,7 +58,7 @@ def build_frame_cost_matrix( C.data = C.data + EPSILON - return C.to_coo_matrix() + return C.to_csr_matrix() def build_segment_cost_matrix( diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index 096b130c..1fb6869a 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -280,11 +280,12 @@ class LapTrackBase(BaseModel, ABC, extra=Extra.forbid): + "See `numpy.percentile` for accepted values.", ) - def _link_frames(self, coords) -> nx.Graph: + def _link_frames(self, coords, connected_edges_list) -> nx.Graph: """Link particles between frames according to the cost function Args: coords (List[np.ndarray]): the input coordinates + connected_edges_list (List[List[Tuple[Tuple[int, int],Tuple[int, int]]]]): the connected edges list Returns: nx.Graph: the resulted tree @@ -311,6 +312,10 @@ def _link_frames(self, coords) -> nx.Graph: track_start_cost=self.track_start_cost, track_end_cost=self.track_end_cost, ) + if len(connected_edges_list[frame]) > 0: + min_val = -cost_matrix.max() * 1.05 + for n1, n2 in connected_edges_list[frame]: + cost_matrix[n1[1], n2[1]] = min_val xs, _ = lap_optimization(cost_matrix) count1 = dist_matrix.shape[0] @@ -385,10 +390,10 @@ def _link_gap_split_merge_from_matrix( return track_tree @abstractmethod - def _predict_gap_split_merge(self, coords, track_tree): + def _predict_gap_split_merge(self, coords, track_tree, connected_edges_list): ... - def predict(self, coords) -> nx.Graph: + def predict(self, coords, connected_edges=None) -> nx.Graph: """Predict the tracking graph from coordinates Args: @@ -410,14 +415,28 @@ def predict(self, coords) -> nx.Graph: if any(list(map(lambda coord: coord.shape[1] != coord_dim, coords))): raise ValueError("the second dimension in coords must have the same size") + ######## initialize connected edges ######## + connected_edges_list = [[]] * len(coords) + if connected_edges is not None: + connected_edges = [ + e if e[0][0] < e[1][0] else (e[1], e[0]) for e in connected_edges + ] + for frame in range(len(coords)): + connected_edges_list[frame] = [ + e for e in connected_edges if e[0][0] == frame + ] + assert connected_edges_list[-1] == [] + ####### Particle-particle tracking ####### - track_tree = self._link_frames(coords) - track_tree = self._predict_gap_split_merge(coords, track_tree) + track_tree = self._link_frames(coords, connected_edges_list) + track_tree = self._predict_gap_split_merge( + coords, track_tree, connected_edges_list + ) return track_tree class LapTrack(LapTrackBase): - def _predict_gap_split_merge(self, coords, track_tree): + def _predict_gap_split_merge(self, coords, track_tree, connected_edges_list): """one-step fitting, as TrackMate and K. Jaqaman et al., Nat Methods 5, 695 (2008). Args: @@ -426,6 +445,8 @@ def _predict_gap_split_merge(self, coords, track_tree): The array index means (sample, dimension). track_tree : nx.Graph the track tree + connected_edges_list (List[List[Tuple[Tuple[int, int],Tuple[int, int]]]]): + the connected edges list Returns: track_tree : nx.Graph @@ -501,7 +522,7 @@ def _get_segment_connecting_matrix(self, segments_df): self.segment_connecting_cost_cutoff, ) - def _predict_gap_split_merge(self, coords, track_tree): + def _predict_gap_split_merge(self, coords, track_tree, connected_edges_list): # "multi-step" type of fitting (Y. T. Fukai (2022)) segments_df = _get_segment_df(coords, track_tree) From 86b79eb399592195b9c05ffe87561b4dc10e8ab9 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 9 Aug 2022 10:48:11 +0900 Subject: [PATCH 03/14] tests running. --- src/laptrack/_cost_matrix.py | 5 ++--- src/laptrack/_tracking.py | 38 ++++++++++++++++++------------------ 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/laptrack/_cost_matrix.py b/src/laptrack/_cost_matrix.py index 05f566eb..9aed90db 100644 --- a/src/laptrack/_cost_matrix.py +++ b/src/laptrack/_cost_matrix.py @@ -2,7 +2,6 @@ from typing import Union import numpy as np -from scipy.sparse import coo_matrix from scipy.sparse import csr_matrix from ._coo_matrix_builder import coo_matrix_builder @@ -72,7 +71,7 @@ def build_segment_cost_matrix( alternative_cost_factor: Float = 1.05, alternative_cost_percentile: Float = 90, alternative_cost_percentile_interpolation: str = "lower", -) -> Optional[coo_matrix]: +) -> Optional[csr_matrix]: """Build sparce array for segment-linking cost matrix. Parameters @@ -172,4 +171,4 @@ def build_segment_cost_matrix( C[upper_left_cols + M + N1, upper_left_rows + M + N2] = min_val C.data = C.data + EPSILON - return C.to_coo_matrix() + return C.to_csr_matrix() diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index 1fb6869a..9884ab21 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -280,12 +280,12 @@ class LapTrackBase(BaseModel, ABC, extra=Extra.forbid): + "See `numpy.percentile` for accepted values.", ) - def _link_frames(self, coords, connected_edges_list) -> nx.Graph: + def _link_frames(self, coords, connected_edges) -> nx.Graph: """Link particles between frames according to the cost function Args: coords (List[np.ndarray]): the input coordinates - connected_edges_list (List[List[Tuple[Tuple[int, int],Tuple[int, int]]]]): the connected edges list + connected_edges (EdgesType): the connected edges list Returns: nx.Graph: the resulted tree @@ -296,6 +296,18 @@ def _link_frames(self, coords, connected_edges_list) -> nx.Graph: for j in range(coord.shape[0]): track_tree.add_node((frame, j)) + # initialize connected edges + connected_edges_list = [[]] * len(coords) + if connected_edges is not None: + connected_edges = [ + e if e[0][0] < e[1][0] else (e[1], e[0]) for e in connected_edges + ] + for frame in range(len(coords)): + connected_edges_list[frame] = [ + e for e in connected_edges if e[0][0] == frame + ] + assert connected_edges_list[-1] == [] + # linking between frames for frame, (coord1, coord2) in enumerate(zip(coords[:-1], coords[1:])): dist_matrix = cdist(coord1, coord2, metric=self.track_dist_metric) @@ -359,6 +371,8 @@ def _link_gap_split_merge_from_matrix( ) if not cost_matrix is None: + # FIXME connected_edges_list + xs, ys = lap_optimization(cost_matrix) M = gap_closing_dist_matrix.shape[0] @@ -415,28 +429,14 @@ def predict(self, coords, connected_edges=None) -> nx.Graph: if any(list(map(lambda coord: coord.shape[1] != coord_dim, coords))): raise ValueError("the second dimension in coords must have the same size") - ######## initialize connected edges ######## - connected_edges_list = [[]] * len(coords) - if connected_edges is not None: - connected_edges = [ - e if e[0][0] < e[1][0] else (e[1], e[0]) for e in connected_edges - ] - for frame in range(len(coords)): - connected_edges_list[frame] = [ - e for e in connected_edges if e[0][0] == frame - ] - assert connected_edges_list[-1] == [] - ####### Particle-particle tracking ####### - track_tree = self._link_frames(coords, connected_edges_list) - track_tree = self._predict_gap_split_merge( - coords, track_tree, connected_edges_list - ) + track_tree = self._link_frames(coords, connected_edges) + track_tree = self._predict_gap_split_merge(coords, track_tree, connected_edges) return track_tree class LapTrack(LapTrackBase): - def _predict_gap_split_merge(self, coords, track_tree, connected_edges_list): + def _predict_gap_split_merge(self, coords, track_tree, connected_edges): """one-step fitting, as TrackMate and K. Jaqaman et al., Nat Methods 5, 695 (2008). Args: From e381693ed3190f1eacedf5894e33e27915bdce2c Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 9 Aug 2022 11:36:10 +0900 Subject: [PATCH 04/14] frame link version working --- src/laptrack/_cost_matrix.py | 16 ++++++++++------ src/laptrack/_tracking.py | 21 +++++++++++++++++---- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/laptrack/_cost_matrix.py b/src/laptrack/_cost_matrix.py index 9aed90db..3532bbfb 100644 --- a/src/laptrack/_cost_matrix.py +++ b/src/laptrack/_cost_matrix.py @@ -6,6 +6,7 @@ from ._coo_matrix_builder import coo_matrix_builder from ._typing_utils import Float +from ._typing_utils import FloatArray from ._typing_utils import Matrix EPSILON = 1e-6 @@ -14,8 +15,8 @@ def build_frame_cost_matrix( dist_matrix: coo_matrix_builder, *, - track_start_cost: Optional[Float], - track_end_cost: Optional[Float], + track_start_cost: Optional[Union[Float, FloatArray]], + track_end_cost: Optional[Union[Float, FloatArray]], ) -> csr_matrix: """Build sparce array for frame-linking cost matrix. @@ -23,9 +24,9 @@ def build_frame_cost_matrix( ---------- dist_matrix : Matrix or `_utils.coo_matrix_builder` The distance matrix for points at time t and t+1. - track_start_cost : Float, optional + track_start_cost : Float or FloatArray, optional The cost for starting the track (b in Jaqaman et al 2008 NMeth) - track_end_cost : Float, optional + track_end_cost : Float or FloatArray, optional The cost for ending the track (d in Jaqaman et al 2008 NMeth) Returns @@ -50,8 +51,11 @@ def build_frame_cost_matrix( else: track_end_cost = 1.05 - C[np.arange(M, M + N), np.arange(N)] = np.ones(N) * track_end_cost - C[np.arange(M), np.arange(N, N + M)] = np.ones(M) * track_start_cost + track_end_costs = np.ones(N) * track_end_cost + track_start_costs = np.ones(M) * track_start_cost + + C[np.arange(M, M + N), np.arange(N)] = track_end_costs + C[np.arange(M), np.arange(N, N + M)] = track_start_costs min_val = np.min(C.data) if len(C.data) > 0 else 0 C[dist_matrix.col + M, dist_matrix.row + N] = min_val diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index 9884ab21..ba1d7be8 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -310,7 +310,14 @@ def _link_frames(self, coords, connected_edges) -> nx.Graph: # linking between frames for frame, (coord1, coord2) in enumerate(zip(coords[:-1], coords[1:])): + force_end_indices = [e[0][1] for e in connected_edges_list[frame]] + force_start_indices = [ + e[1][1] for e in connected_edges_list[frame] if e[1][0] == frame + 1 + ] dist_matrix = cdist(coord1, coord2, metric=self.track_dist_metric) + dist_matrix[force_end_indices, :] = np.inf + dist_matrix[:, force_start_indices] = np.inf + ind = np.where(dist_matrix < self.track_cost_cutoff) dist_matrix = coo_matrix_builder( dist_matrix.shape, @@ -319,15 +326,13 @@ def _link_frames(self, coords, connected_edges) -> nx.Graph: data=dist_matrix[(*ind,)], dtype=dist_matrix.dtype, ) + cost_matrix = build_frame_cost_matrix( dist_matrix, track_start_cost=self.track_start_cost, track_end_cost=self.track_end_cost, ) - if len(connected_edges_list[frame]) > 0: - min_val = -cost_matrix.max() * 1.05 - for n1, n2 in connected_edges_list[frame]: - cost_matrix[n1[1], n2[1]] = min_val + print(cost_matrix.todense()) xs, _ = lap_optimization(cost_matrix) count1 = dist_matrix.shape[0] @@ -335,6 +340,8 @@ def _link_frames(self, coords, connected_edges) -> nx.Graph: connections = [(i, xs[i]) for i in range(count1) if xs[i] < count2] # track_start=[i for i in range(count1) if xs[i]>count2] # track_end=[i for i in range(count2) if ys[i]>count1] + for edge in connected_edges_list[frame]: + track_tree.add_edge(*edge) for connection in connections: track_tree.add_edge((frame, connection[0]), (frame + 1, connection[1])) return track_tree @@ -485,6 +492,12 @@ def _predict_gap_split_merge(self, coords, track_tree, connected_edges): merging_dist_matrix = dist_matrices["last"] splitting_all_candidates = middle_points["first"] merging_all_candidates = middle_points["last"] + print("========") + print(segments_df) + print(splitting_all_candidates) + print(merging_all_candidates) + print("========") + track_tree = self._link_gap_split_merge_from_matrix( segments_df, track_tree, From 32f8f4281ec7a9b2da7f339f96e829eb025cc17f Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 9 Aug 2022 11:42:56 +0900 Subject: [PATCH 05/14] added failing tests... --- tests/test_tracking.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_tracking.py b/tests/test_tracking.py index f77cd262..6a323096 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -212,3 +212,19 @@ def test_connected_edges() -> None: track_tree = lt.predict(coords, connected_edges=connected_edges) edges = track_tree.edges() assert set(edges) == set([((0, 0), (1, 1)), ((0, 1), (1, 0))]) + + +def test_connected_edges_splitting() -> None: + coords = [ + np.array([[10, 10], [11, 11], [13, 12]]), + np.array([[10, 10], [13, 11], [13, 15]]), + ] + lt = LapTrack( + gap_closing_cost_cutoff=100, + splitting_cost_cutoff=100, + merging_cost_cutoff=100, + ) # type: ignore + connected_edges = [((0, 0), (1, 1)), ((0, 0), (1, 2))] + track_tree = lt.predict(coords, connected_edges=connected_edges) + edges = track_tree.edges() + assert set(edges) == set([((0, 0), (1, 1)), ((0, 0), (1, 2)), ((0, 1), (1, 0))]) From a262045b7b9a4c0fd3a3bfc1de71eb6887ba026b Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 9 Aug 2022 11:51:10 +0900 Subject: [PATCH 06/14] tests running but this strategy is not good maybe ... --- tests/test_tracking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 6a323096..81f6db28 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -222,7 +222,7 @@ def test_connected_edges_splitting() -> None: lt = LapTrack( gap_closing_cost_cutoff=100, splitting_cost_cutoff=100, - merging_cost_cutoff=100, + merging_cost_cutoff=False, ) # type: ignore connected_edges = [((0, 0), (1, 1)), ((0, 0), (1, 2))] track_tree = lt.predict(coords, connected_edges=connected_edges) From e18a2999e9728b4321b4854549aeb9857e4e1995 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 9 Aug 2022 22:26:22 +0900 Subject: [PATCH 07/14] test working to middle --- src/laptrack/_tracking.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index ba1d7be8..010478e6 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -436,8 +436,33 @@ def predict(self, coords, connected_edges=None) -> nx.Graph: if any(list(map(lambda coord: coord.shape[1] != coord_dim, coords))): raise ValueError("the second dimension in coords must have the same size") + if connected_edges: + connected_edges = [ + (n1, n2) if n1[0] < n2[0] else (n2, n1) for n1, n2 in connected_edges + ] + tree = nx.from_edgelist(connected_edges) + split_edges = [] + merge_edges = [] + for m in tree.nodes(): + successors = [n for n in tree.neighbors(m) if n[0] > m[0]] + if len(successors) > 1: + assert len(successors) == 2, "splitting into >2 nodes" + split_edges.append((m, successors[0])) + predecessors = [n for n in tree.neighbors(m) if n[0] < m[0]] + if len(predecessors) > 1: + assert len(predecessors) == 2, "merging of >2 nodes" + merge_edges.append((m, predecessors[0])) + segment_connected_edges = list( + set(connected_edges) - set(split_edges) - set(merge_edges) + ) + else: + segment_connected_edges = None + split_edges = None + merge_edges = None + ####### Particle-particle tracking ####### - track_tree = self._link_frames(coords, connected_edges) + print(segment_connected_edges) + track_tree = self._link_frames(coords, segment_connected_edges) track_tree = self._predict_gap_split_merge(coords, track_tree, connected_edges) return track_tree From 069d809c0919f7a1719ede57ff8c5218696dfe37 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 10 Aug 2022 09:52:11 +0900 Subject: [PATCH 08/14] test working... --- src/laptrack/_tracking.py | 42 +++++++++++++++------------------------ 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index 010478e6..f93ad6ab 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -280,12 +280,15 @@ class LapTrackBase(BaseModel, ABC, extra=Extra.forbid): + "See `numpy.percentile` for accepted values.", ) - def _link_frames(self, coords, connected_edges) -> nx.Graph: + def _link_frames( + self, coords, segment_connected_edges, split_merge_edges + ) -> nx.Graph: """Link particles between frames according to the cost function Args: coords (List[np.ndarray]): the input coordinates - connected_edges (EdgesType): the connected edges list + segment_connected_edges (EdgesType): the connected edges list that will be connected in this step + split_merge_edges (EdgesType): the connected edges list that will be connected in split and merge step Returns: nx.Graph: the resulted tree @@ -296,24 +299,11 @@ def _link_frames(self, coords, connected_edges) -> nx.Graph: for j in range(coord.shape[0]): track_tree.add_node((frame, j)) - # initialize connected edges - connected_edges_list = [[]] * len(coords) - if connected_edges is not None: - connected_edges = [ - e if e[0][0] < e[1][0] else (e[1], e[0]) for e in connected_edges - ] - for frame in range(len(coords)): - connected_edges_list[frame] = [ - e for e in connected_edges if e[0][0] == frame - ] - assert connected_edges_list[-1] == [] - # linking between frames + edges_list = list(segment_connected_edges) + list(split_merge_edges) for frame, (coord1, coord2) in enumerate(zip(coords[:-1], coords[1:])): - force_end_indices = [e[0][1] for e in connected_edges_list[frame]] - force_start_indices = [ - e[1][1] for e in connected_edges_list[frame] if e[1][0] == frame + 1 - ] + force_end_indices = [e[0][1] for e in edges_list if e[0][0] == frame] + force_start_indices = [e[1][1] for e in edges_list if e[1][0] == frame + 1] dist_matrix = cdist(coord1, coord2, metric=self.track_dist_metric) dist_matrix[force_end_indices, :] = np.inf dist_matrix[:, force_start_indices] = np.inf @@ -340,10 +330,9 @@ def _link_frames(self, coords, connected_edges) -> nx.Graph: connections = [(i, xs[i]) for i in range(count1) if xs[i] < count2] # track_start=[i for i in range(count1) if xs[i]>count2] # track_end=[i for i in range(count2) if ys[i]>count1] - for edge in connected_edges_list[frame]: - track_tree.add_edge(*edge) for connection in connections: track_tree.add_edge((frame, connection[0]), (frame + 1, connection[1])) + track_tree.add_edges_from(segment_connected_edges) return track_tree def _get_gap_closing_matrix(self, segments_df): @@ -451,18 +440,19 @@ def predict(self, coords, connected_edges=None) -> nx.Graph: predecessors = [n for n in tree.neighbors(m) if n[0] < m[0]] if len(predecessors) > 1: assert len(predecessors) == 2, "merging of >2 nodes" - merge_edges.append((m, predecessors[0])) + merge_edges.append((predecessors[0], m)) segment_connected_edges = list( set(connected_edges) - set(split_edges) - set(merge_edges) ) else: - segment_connected_edges = None - split_edges = None - merge_edges = None + segment_connected_edges = [] + split_edges = [] + merge_edges = [] ####### Particle-particle tracking ####### - print(segment_connected_edges) - track_tree = self._link_frames(coords, segment_connected_edges) + track_tree = self._link_frames( + coords, segment_connected_edges, list(split_edges) + list(merge_edges) + ) track_tree = self._predict_gap_split_merge(coords, track_tree, connected_edges) return track_tree From e647c5e71579aff9b5395d1f23f5b05e9c9ff1c8 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 10 Aug 2022 10:21:26 +0900 Subject: [PATCH 09/14] updated gap closing --- src/laptrack/_tracking.py | 65 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 5 deletions(-) diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index f93ad6ab..628ffbb6 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -39,6 +39,21 @@ def _get_segment_df(coords, track_tree): + """Create segment dataframe from track tree. + + Parameters + ---------- + coords : + coordinates + track_tree : nx.Graph + the track tree + + Returns + ------- + pd.DataFrame + the segment dataframe, with columns "segment", "first_frame", "first_index", + "first_frame_coords", "last_frame", "last_index", "last_frame_coords" + """ # linking between tracks segments = list(nx.connected_components(track_tree)) first_nodes = np.array( @@ -66,15 +81,53 @@ def _get_segment_df(coords, track_tree): def _get_segment_end_connecting_matrix( - segments_df, max_frame_count, dist_metric, cost_cutoff + segments_df, + max_frame_count, + dist_metric, + cost_cutoff, + force_end_indices=[], + force_start_indices=[], ): + """Generate the cost matrix for connecting segment ends. + + Parameters + ---------- + segments_df : pd.DataFrame + must have columns "first_frame", "first_index", "first_crame_coords", "last_frame", "last_index", "last_frame_coords" + max_frame_count : int + connecting cost is set to infinity if the distance between the two ends is larger than this value + dist_metric : + the distance metric + cost_cutoff : float + the cutoff value for the cost + force_end_indices : list of int + the indices of the segments_df that is forced to be end for future connection + force_start_indices : list of int + the indices of the segments_df that is forced to be start for future connection + + Returns + ------- + segments_df: pd.DataFrame + the segments dataframe with additional column "gap_closing_candidates" + (index of the candidate row of segments_df, the associated costs) + """ if cost_cutoff: def to_gap_closing_candidates(row): + # if the index is in force_end_indices, do not add to gap closing candidates + if row.name in force_end_indices: + return [], [] + target_coord = row["last_frame_coords"] frame_diff = segments_df["first_frame"] - row["last_frame"] + + # only take the elements that are within the frame difference range. + # segments in df is later than the candidate segment (row) indices = (1 <= frame_diff) & (frame_diff <= max_frame_count) df = segments_df[indices] + df = df.drop( + index=force_start_indices, errors="ignore" + ) # do not connect to the segments that is forced to be start # note: can use KDTree if metric is distance, # but might not be appropriate for general metrics # https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa @@ -400,7 +453,7 @@ def _link_gap_split_merge_from_matrix( return track_tree @abstractmethod - def _predict_gap_split_merge(self, coords, track_tree, connected_edges_list): + def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges): ... def predict(self, coords, connected_edges=None) -> nx.Graph: @@ -453,12 +506,14 @@ def predict(self, coords, connected_edges=None) -> nx.Graph: track_tree = self._link_frames( coords, segment_connected_edges, list(split_edges) + list(merge_edges) ) - track_tree = self._predict_gap_split_merge(coords, track_tree, connected_edges) + track_tree = self._predict_gap_split_merge( + coords, track_tree, split_edges, merge_edges + ) return track_tree class LapTrack(LapTrackBase): - def _predict_gap_split_merge(self, coords, track_tree, connected_edges): + def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges): """one-step fitting, as TrackMate and K. Jaqaman et al., Nat Methods 5, 695 (2008). Args: @@ -550,7 +605,7 @@ def _get_segment_connecting_matrix(self, segments_df): self.segment_connecting_cost_cutoff, ) - def _predict_gap_split_merge(self, coords, track_tree, connected_edges_list): + def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges): # "multi-step" type of fitting (Y. T. Fukai (2022)) segments_df = _get_segment_df(coords, track_tree) From c82ec8c0b9e60077e5110e372f55792a2163f71f Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 10 Aug 2022 10:47:46 +0900 Subject: [PATCH 10/14] adding logic for split and merge... --- src/laptrack/_tracking.py | 54 +++++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index 628ffbb6..2407d5be 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -85,8 +85,9 @@ def _get_segment_end_connecting_matrix( max_frame_count, dist_metric, cost_cutoff, - force_end_indices=[], - force_start_indices=[], + *, + force_end_nodes=[], + force_start_nodes=[], ): """Generate the cost matrix for connecting segment ends. @@ -115,7 +116,7 @@ def _get_segment_end_connecting_matrix( def to_gap_closing_candidates(row): # if the index is in force_end_indices, do not add to gap closing candidates - if row.name in force_end_indices: + if (row["last_frame"], row["last_index"]) in force_end_nodes: return [], [] target_coord = row["last_frame_coords"] @@ -125,9 +126,13 @@ def to_gap_closing_candidates(row): # segments in df is later than the candidate segment (row) indices = (1 <= frame_diff) & (frame_diff <= max_frame_count) df = segments_df[indices] - df = df.drop( - index=force_start_indices, errors="ignore" - ) # do not connect to the segments that is forced to be start + force_start = df.apply( + lambda row: (row["first_frame"], row["first_index"]) + in force_start_nodes, + axis=1, + ) + df = df[~force_start] + # do not connect to the segments that is forced to be start # note: can use KDTree if metric is distance, # but might not be appropriate for general metrics # https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa @@ -172,23 +177,52 @@ def _get_splitting_merging_candidates( cutoff, prefix, dist_metric, + *, + force_end_nodes=[], + force_start_nodes=[], ): if cutoff: def to_candidates(row): + # if the prefix is first, this means the row is the track start, and the target is the track end + other_frame = row[f"{prefix}_frame"] + (-1 if prefix == "first" else 1) target_coord = row[f"{prefix}_frame_coords"] - frame = row[f"{prefix}_frame"] + (-1 if prefix == "first" else 1) + row_no_connection_nodes = ( + force_start_nodes if prefix == "first" else force_end_nodes + ) + other_no_connection_nodes = ( + force_end_nodes if prefix == "first" else force_start_nodes + ) + other_no_connection_indices = [ + n[1] for n in other_no_connection_nodes if n[0] == other_frame + ] + + if ( + row[f"{prefix}_frame"], + row[f"{prefix}_index"], + ) in row_no_connection_nodes: + return ( + [], + [], + ) # do not connect to the segments that is forced to be start or end # note: can use KDTree if metric is distance, # but might not be appropriate for general metrics # https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa - if frame < 0 or len(coords) <= frame: + if other_frame < 0 or len(coords) <= other_frame: return [], [] target_dist_matrix = cdist( - [target_coord], coords[frame], metric=dist_metric + [target_coord], coords[other_frame], metric=dist_metric ) assert target_dist_matrix.shape[0] == 1 + target_dist_matrix[ + 0, other_no_connection_indices + ] = ( + np.inf + ) # do not connect to the segments that is forced to be start or end indices = np.where(target_dist_matrix[0] < cutoff)[0] - return [(frame, index) for index in indices], target_dist_matrix[0][indices] + return [(other_frame, index) for index in indices], target_dist_matrix[0][ + indices + ] segments_df[f"{prefix}_candidates"] = segments_df.apply(to_candidates, axis=1) else: From 8c11717dc1fae4b2ed920107af796686b2c0a80b Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 10 Aug 2022 11:02:59 +0900 Subject: [PATCH 11/14] tests all green --- src/laptrack/_tracking.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index 2407d5be..de8ce745 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -422,12 +422,16 @@ def _link_frames( track_tree.add_edges_from(segment_connected_edges) return track_tree - def _get_gap_closing_matrix(self, segments_df): + def _get_gap_closing_matrix( + self, segments_df, *, force_end_nodes=[], force_start_nodes=[] + ): return _get_segment_end_connecting_matrix( segments_df, self.gap_closing_max_frame_count, self.gap_closing_dist_metric, self.gap_closing_cost_cutoff, + force_end_nodes=force_end_nodes, + force_start_nodes=force_start_nodes, ) def _link_gap_split_merge_from_matrix( @@ -563,16 +567,21 @@ def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges) track_tree : nx.Graph the updated track tree """ + edges = list(split_edges) + list(merge_edges) if ( self.gap_closing_cost_cutoff or self.splitting_cost_cutoff or self.merging_cost_cutoff ): segments_df = _get_segment_df(coords, track_tree) + force_end_nodes = [tuple(map(int, e[0])) for e in edges] + force_start_nodes = [tuple(map(int, e[1])) for e in edges] # compute candidate for gap closing segments_df, gap_closing_dist_matrix = self._get_gap_closing_matrix( - segments_df + segments_df, + force_end_nodes=force_end_nodes, + force_start_nodes=force_start_nodes, ) middle_points: Dict = {} @@ -589,18 +598,19 @@ def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges) dist_matrices[prefix], middle_points[prefix], ) = _get_splitting_merging_candidates( - segments_df, coords, cutoff, prefix, dist_metric + segments_df, + coords, + cutoff, + prefix, + dist_metric, + force_end_nodes=force_end_nodes, + force_start_nodes=force_start_nodes, ) splitting_dist_matrix = dist_matrices["first"] merging_dist_matrix = dist_matrices["last"] splitting_all_candidates = middle_points["first"] merging_all_candidates = middle_points["last"] - print("========") - print(segments_df) - print(splitting_all_candidates) - print(merging_all_candidates) - print("========") track_tree = self._link_gap_split_merge_from_matrix( segments_df, @@ -611,7 +621,7 @@ def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges) splitting_all_candidates, merging_all_candidates, ) - + track_tree.add_edges_from(edges) return track_tree From a73e5f0245b9ad707730ee5a654d9dfd909df639 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 10 Aug 2022 11:08:42 +0900 Subject: [PATCH 12/14] failing test added --- src/laptrack/_tracking.py | 2 ++ tests/test_tracking.py | 22 +++++++++------------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index de8ce745..d90c5ca5 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -651,6 +651,8 @@ def _get_segment_connecting_matrix(self, segments_df): def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges): # "multi-step" type of fitting (Y. T. Fukai (2022)) + + edges = list(split_edges) + list(merge_edges) segments_df = _get_segment_df(coords, track_tree) ###### gap closing step ###### diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 81f6db28..e3cac713 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -90,16 +90,10 @@ def testdata(request, shared_datadir: str): return params, coords, edges_set -def test_reproducing_trackmate(testdata) -> None: +@pytest.mark.parametrize("tracker_class", [LapTrack, LapTrackMulti]) +def test_reproducing_trackmate(testdata, tracker_class) -> None: params, coords, edges_set = testdata - lt = LapTrack(**params) - track_tree = lt.predict(coords) - assert edges_set == set(track_tree.edges) - - -def test_multi_algorithm_reproducing_trackmate(testdata) -> None: - params, coords, edges_set = testdata - lt = LapTrackMulti(**params) + lt = tracker_class(**params) track_tree = lt.predict(coords) assert edges_set == set(track_tree.edges) @@ -201,9 +195,10 @@ def test_no_accepting_wrong_argments() -> None: lt = LapTrack(fugafuga=True) -def test_connected_edges() -> None: +@pytest.mark.parametrize("tracker_class", [LapTrack, LapTrackMulti]) +def test_connected_edges(tracker_class) -> None: coords = [np.array([[10, 10], [12, 11]]), np.array([[10, 10], [13, 11]])] - lt = LapTrack( + lt = tracker_class( gap_closing_cost_cutoff=100, splitting_cost_cutoff=100, merging_cost_cutoff=100, @@ -214,12 +209,13 @@ def test_connected_edges() -> None: assert set(edges) == set([((0, 0), (1, 1)), ((0, 1), (1, 0))]) -def test_connected_edges_splitting() -> None: +@pytest.mark.parametrize("tracker_class", [LapTrack, LapTrackMulti]) +def test_connected_edges_splitting(tracker_class) -> None: coords = [ np.array([[10, 10], [11, 11], [13, 12]]), np.array([[10, 10], [13, 11], [13, 15]]), ] - lt = LapTrack( + lt = tracker_class( gap_closing_cost_cutoff=100, splitting_cost_cutoff=100, merging_cost_cutoff=False, From ed7765827627e1e11263898427db01b08b79fad8 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 10 Aug 2022 11:11:39 +0900 Subject: [PATCH 13/14] all tests running --- src/laptrack/_tracking.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index d90c5ca5..fe2e34ad 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -641,20 +641,27 @@ class LapTrackMulti(LapTrackBase): description="if True, remove segment connections if splitting did not happen.", ) - def _get_segment_connecting_matrix(self, segments_df): + def _get_segment_connecting_matrix( + self, segments_df, force_end_nodes=[], force_start_nodes=[] + ): return _get_segment_end_connecting_matrix( segments_df, 1, # only arrow 1-frame difference self.segment_connecting_metric, self.segment_connecting_cost_cutoff, + force_end_nodes=force_end_nodes, + force_start_nodes=force_start_nodes, ) def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges): # "multi-step" type of fitting (Y. T. Fukai (2022)) - edges = list(split_edges) + list(merge_edges) segments_df = _get_segment_df(coords, track_tree) + edges = list(split_edges) + list(merge_edges) + force_end_nodes = [tuple(map(int, e[0])) for e in edges] + force_start_nodes = [tuple(map(int, e[1])) for e in edges] + ###### gap closing step ###### ###### split - merge step 1 ###### @@ -665,7 +672,11 @@ def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges) segment_connected_edges = [] for mode, get_matrix_fn in get_matrix_fns.items(): - segments_df, gap_closing_dist_matrix = get_matrix_fn(segments_df) + segments_df, gap_closing_dist_matrix = get_matrix_fn( + segments_df, + force_end_nodes=force_end_nodes, + force_start_nodes=force_start_nodes, + ) cost_matrix = build_frame_cost_matrix( gap_closing_dist_matrix, track_start_cost=self.segment_start_cost, @@ -790,7 +801,13 @@ def _dist_metric(c1, c2): dist_matrices[prefix], middle_points[prefix], ) = _get_splitting_merging_candidates( - segments_df, _coords, cutoff, prefix, _dist_metric + segments_df, + _coords, + cutoff, + prefix, + _dist_metric, + force_end_nodes=force_end_nodes, + force_start_nodes=force_start_nodes, ) splitting_dist_matrix = dist_matrices["first"] @@ -814,6 +831,7 @@ def _dist_metric(c1, c2): track_tree = _remove_no_split_merge_links( track_tree.copy(), segment_connected_edges ) + track_tree.add_edges_from(edges) return track_tree From ed107c83630f9092b9abeade9740e7a8407f55f0 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 10 Aug 2022 11:21:59 +0900 Subject: [PATCH 14/14] solved typing problem --- poetry.lock | 93 +++++++++++++++++++++++++++-- pyproject.toml | 2 + src/laptrack/_coo_matrix_builder.py | 14 ----- src/laptrack/_cost_matrix.py | 10 ++-- 4 files changed, 96 insertions(+), 23 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1af99088..0ef98031 100644 --- a/poetry.lock +++ b/poetry.lock @@ -14,6 +14,17 @@ category = "dev" optional = false python-versions = "*" +[[package]] +name = "argcomplete" +version = "2.0.0" +description = "Bash tab completion for argparse" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.extras] +test = ["wheel", "pexpect", "flake8", "coverage"] + [[package]] name = "argon2-cffi" version = "21.3.0" @@ -205,6 +216,20 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +[[package]] +name = "colorlog" +version = "6.6.0" +description = "Add colours to the output of Python's logging module." +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} + +[package.extras] +development = ["black", "flake8", "mypy", "pytest", "types-colorama"] + [[package]] name = "coverage" version = "6.4.3" @@ -269,7 +294,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "dparse" -version = "0.5.1" +version = "0.5.2" description = "A parser for Python dependency files" category = "dev" optional = false @@ -277,11 +302,11 @@ python-versions = ">=3.5" [package.dependencies] packaging = "*" -pyyaml = "*" toml = "*" [package.extras] pipenv = ["pipenv"] +conda = ["pyyaml"] [[package]] name = "entrypoints" @@ -649,6 +674,20 @@ python-versions = "*" six = "*" tornado = {version = "*", markers = "python_version > \"2.7\""} +[[package]] +name = "lxml" +version = "4.9.1" +description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*" + +[package.extras] +cssselect = ["cssselect (>=0.7)"] +html5 = ["html5lib"] +htmlsoup = ["beautifulsoup4"] +source = ["Cython (>=0.29.7)"] + [[package]] name = "markupsafe" version = "2.1.1" @@ -740,7 +779,7 @@ test = ["black", "check-manifest", "flake8", "ipykernel", "ipython (<8.0.0)", "i [[package]] name = "nbconvert" -version = "6.5.0" +version = "6.5.1" description = "Converting Jupyter Notebooks" category = "dev" optional = false @@ -754,6 +793,7 @@ entrypoints = ">=0.2.2" jinja2 = ">=3.0" jupyter-core = ">=4.7" jupyterlab-pygments = "*" +lxml = "*" MarkupSafe = ">=2.0" mistune = ">=0.8.1,<2" nbclient = ">=0.5.0" @@ -849,6 +889,37 @@ docs = ["sphinx", "nbsphinx", "sphinxcontrib-github-alt", "sphinx-rtd-theme", "m json-logging = ["json-logging"] test = ["pytest", "coverage", "requests", "testpath", "nbval", "selenium", "pytest-cov", "requests-unixsocket"] +[[package]] +name = "nox" +version = "2022.8.7" +description = "Flexible test automation." +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +argcomplete = ">=1.9.4,<3.0" +colorlog = ">=2.6.1,<7.0.0" +packaging = ">=20.9" +py = ">=1.4,<2.0.0" +virtualenv = ">=14" + +[package.extras] +tox_to_nox = ["jinja2", "tox"] + +[[package]] +name = "nox-poetry" +version = "1.0.1" +description = "nox-poetry" +category = "dev" +optional = false +python-versions = ">=3.7,<4.0" + +[package.dependencies] +nox = ">=2020.8.22" +packaging = ">=20.9" +tomlkit = ">=0.7" + [[package]] name = "numpy" version = "1.23.1" @@ -1551,6 +1622,14 @@ category = "dev" optional = false python-versions = ">=3.7" +[[package]] +name = "tomlkit" +version = "0.11.3" +description = "Style preserving TOML library" +category = "dev" +optional = false +python-versions = ">=3.6,<4.0" + [[package]] name = "tornado" version = "6.2" @@ -1690,11 +1769,12 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.8.1,<3.11" -content-hash = "5d62274bc91a3594be5dc389612bf3d007456916980004ba31cbe3f92841e545" +content-hash = "0f854b89e63c9d5b78d82ad33f1ce9e109f2149ebbd289c66727d0ccb27370c3" [metadata.files] alabaster = [] appnope = [] +argcomplete = [] argon2-cffi = [] argon2-cffi-bindings = [] asttokens = [] @@ -1711,6 +1791,7 @@ cfgv = [] charset-normalizer = [] click = [] colorama = [] +colorlog = [] coverage = [] cycler = [] debugpy = [] @@ -1745,6 +1826,7 @@ jupyterlab-pygments = [] jupyterlab-widgets = [] kiwisolver = [] livereload = [] +lxml = [] markupsafe = [] matplotlib = [] matplotlib-inline = [] @@ -1758,6 +1840,8 @@ nest-asyncio = [] networkx = [] nodeenv = [] notebook = [] +nox = [] +nox-poetry = [] numpy = [] packaging = [] pandas = [] @@ -1816,6 +1900,7 @@ terminado = [] tinycss2 = [] toml = [] tomli = [] +tomlkit = [] tornado = [] traitlets = [] typeguard = [] diff --git a/pyproject.toml b/pyproject.toml index 548600e1..a7ff35e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ matplotlib = "^3.5.1" sphinxcontrib-napoleon = "^0.7" autodoc_pydantic = "^1.7.2" pre-commit = "^2.20.0" +nox = "^2022.8.7" +nox-poetry = "^1.0.1" [tool.poetry.scripts] laptrack = "laptrack.__main__:main" diff --git a/src/laptrack/_coo_matrix_builder.py b/src/laptrack/_coo_matrix_builder.py index 988b9451..1cb21db6 100644 --- a/src/laptrack/_coo_matrix_builder.py +++ b/src/laptrack/_coo_matrix_builder.py @@ -8,7 +8,6 @@ import numpy as np import numpy.typing as npt from scipy.sparse import coo_matrix -from scipy.sparse import csr_matrix from ._typing_utils import Float from ._typing_utils import Int @@ -160,19 +159,6 @@ def to_coo_matrix(self) -> coo_matrix: shape=(self.n_row, self.n_col), ) - def to_csr_matrix(self) -> coo_matrix: - """Generate `csr_matrix`. - - Returns - ------- - matrix : csr_matrix - the generated csr_matrix - """ - return csr_matrix( - (self.data, (self.row, self.col)), - shape=(self.n_row, self.n_col), - ) - def __setitem__( self, index: Tuple[Union[Int, IndexType], Union[Int, IndexType]], value ): diff --git a/src/laptrack/_cost_matrix.py b/src/laptrack/_cost_matrix.py index 3532bbfb..87439337 100644 --- a/src/laptrack/_cost_matrix.py +++ b/src/laptrack/_cost_matrix.py @@ -2,7 +2,7 @@ from typing import Union import numpy as np -from scipy.sparse import csr_matrix +from scipy.sparse import coo_matrix from ._coo_matrix_builder import coo_matrix_builder from ._typing_utils import Float @@ -17,7 +17,7 @@ def build_frame_cost_matrix( *, track_start_cost: Optional[Union[Float, FloatArray]], track_end_cost: Optional[Union[Float, FloatArray]], -) -> csr_matrix: +) -> coo_matrix: """Build sparce array for frame-linking cost matrix. Parameters @@ -61,7 +61,7 @@ def build_frame_cost_matrix( C.data = C.data + EPSILON - return C.to_csr_matrix() + return C.to_coo_matrix() def build_segment_cost_matrix( @@ -75,7 +75,7 @@ def build_segment_cost_matrix( alternative_cost_factor: Float = 1.05, alternative_cost_percentile: Float = 90, alternative_cost_percentile_interpolation: str = "lower", -) -> Optional[csr_matrix]: +) -> Optional[coo_matrix]: """Build sparce array for segment-linking cost matrix. Parameters @@ -175,4 +175,4 @@ def build_segment_cost_matrix( C[upper_left_cols + M + N1, upper_left_rows + M + N2] = min_val C.data = C.data + EPSILON - return C.to_csr_matrix() + return C.to_coo_matrix()