diff --git a/src/laptrack/_cost_matrix.py b/src/laptrack/_cost_matrix.py index 3959c7e0..4a4ff35c 100644 --- a/src/laptrack/_cost_matrix.py +++ b/src/laptrack/_cost_matrix.py @@ -43,12 +43,20 @@ def build_frame_cost_matrix( if track_start_cost is None: if len(C.data) > 0: - track_start_cost = np.max(C.data) * 1.05 + max_val = np.max(C.data) + if max_val > 0: + track_start_cost = max_val * 1.05 + else: + track_start_cost = EPSILON else: track_start_cost = 1.05 if track_end_cost is None: if len(C.data) > 0: - track_end_cost = np.max(C.data) * 1.05 + max_val = np.max(C.data) + if max_val > 0: + track_end_cost = max_val * 1.05 + else: + track_end_cost = EPSILON else: track_end_cost = 1.05 @@ -159,6 +167,8 @@ def build_segment_cost_matrix( ) * alternative_cost_factor ) + if alternative_cost == 0: + alternative_cost = EPSILON if track_start_cost is None: track_start_cost = alternative_cost if track_end_cost is None: diff --git a/tests/data/same_position_example.csv b/tests/data/same_position_example.csv new file mode 100644 index 00000000..7a43c6a4 --- /dev/null +++ b/tests/data/same_position_example.csv @@ -0,0 +1,76 @@ +,seconds,h +0,0,0 +1,1,0 +2,2,3 +3,3,5 +4,4,6 +5,5,6 +6,6,8 +7,7,9 +8,8,11 +9,9,12 +10,10,13 +11,11,14 +12,12,15 +13,13,16 +14,14,17 +15,15,19 +16,16,20 +17,17,20 +18,18,23 +19,19,23 +20,20,25 +21,21,27 +22,22,28 +23,23,29 +24,24,29 +25,25,0 +26,25,30 +27,26,0 +28,26,31 +29,27,3 +30,27,32 +31,28,5 +32,28,34 +33,29,7 +34,29,34 +35,30,8 +36,30,35 +37,31,9 +38,31,35 +39,32,12 +40,32,36 +41,33,13 +42,33,38 +43,34,14 +44,34,39 +45,35,15 +46,35,41 +47,36,16 +48,36,42 +49,37,17 +50,37,43 +51,38,19 +52,38,44 +53,39,20 +54,39,45 +55,40,21 +56,40,46 +57,41,22 +58,41,47 +59,42,22 +60,42,48 +61,43,23 +62,43,50 +63,44,24 +64,44,50 +65,45,25 +66,45,51 +67,46,26 +68,46,52 +69,47,27 +70,47,54 +71,48,28 +72,48,54 +73,49,28 +74,49,56 diff --git a/tests/test_overlap_tracking.py b/tests/test_overlap_tracking.py index e2f0a623..d0745624 100644 --- a/tests/test_overlap_tracking.py +++ b/tests/test_overlap_tracking.py @@ -123,9 +123,13 @@ def metric(c1, c2, params): track_df2, split_df2, merge_df2 = olt.predict_overlap_dataframe(labels) track_df1 = track_df1.droplevel("index").set_index(["label"], append=True) - assert all(track_df1[["tree_id", "track_id"]] == track_df2[["tree_id", "track_id"]]) - assert all(split_df1 == split_df2) - assert all(merge_df1 == merge_df2) + assert ( + (track_df1[["tree_id", "track_id"]] == track_df2[["tree_id", "track_id"]]) + .all() + .all() + ) + assert (split_df1 == split_df2).all().all() + assert (merge_df1 == merge_df2).all().all() """ diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 0db9c56a..b1cd2363 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -133,9 +133,9 @@ def test_reproducing_trackmate(testdata, parallel_backend) -> None: assert not any(merge_df.duplicated()) track_df2, split_df2, merge_df2 = convert_tree_to_dataframe(track_tree, coords) track_df2 = track_df2.rename(columns={"coord-0": "x", "coord-1": "y"}) - assert all(track_df == track_df2) - assert all(split_df == split_df2) - assert all(merge_df == merge_df2) + assert (track_df == track_df2).all().all() + assert (split_df == split_df2).all().all() + assert (merge_df == merge_df2).all().all() # check index offset track_df3, split_df3, merge_df3 = lt.predict_dataframe( @@ -155,18 +155,18 @@ def test_reproducing_trackmate(testdata, parallel_backend) -> None: if not merge_df4.empty: merge_df4["parent_track_id"] = merge_df4["parent_track_id"] + 2 merge_df4["child_track_id"] = merge_df4["child_track_id"] + 2 - assert all(track_df3 == track_df2) - assert all(split_df3 == split_df2) - assert all(merge_df3 == merge_df2) + assert (track_df3 == track_df4).all().all() + assert (split_df3 == split_df4).all().all() + assert (merge_df3 == merge_df4).all().all() track_df, split_df, merge_df = lt.predict_dataframe( df, ["x", "y"], only_coordinate_cols=False ) assert all(track_df["frame_y"] == track_df2.index.get_level_values("frame")) track_df = track_df.drop(columns=["frame_y"]) - assert all(track_df == track_df2) - assert all(split_df == split_df2) - assert all(merge_df == merge_df2) + assert (track_df == track_df2).all().all() + assert (split_df == split_df2).all().all() + assert (merge_df == merge_df2).all().all() @pytest.fixture(params=[2, 3, 4]) @@ -199,6 +199,37 @@ def test_tracking_zero_distance() -> None: assert set(edges) == set([((0, 0), (1, 0)), ((0, 1), (1, 1))]) +def test_tracking_zero_distance2(shared_datadir: str) -> None: + data = pd.read_csv(path.join(shared_datadir, "same_position_example.csv")) + lt = LapTrack( + track_cost_cutoff=15**2, + splitting_cost_cutoff=15**2, + merging_cost_cutoff=15**2, + ) + + track_df1, split_df1, merge_df1 = lt.predict_dataframe( + data, + coordinate_cols=["h"], + frame_col="seconds", + only_coordinate_cols=False, + ) + + data2 = data.copy() + data2["h"] += np.random.random(len(data2)) * 1e-2 + track_df2, split_df2, merge_df2 = lt.predict_dataframe( + data2, + coordinate_cols=["h"], + frame_col="seconds", + only_coordinate_cols=False, + ) + + assert ( + (track_df1[["track_id", "tree_id"]] == track_df2[["track_id", "tree_id"]]) + .all() + .all() + ) + + def test_tracking_not_connected() -> None: coords = [np.array([[10, 10], [12, 11]]), np.array([[50, 50], [53, 51]])] lt = LapTrack(