Skip to content

Commit

Permalink
Merge pull request #406 from yfukai/allow_no_coord_frame
Browse files Browse the repository at this point in the history
  • Loading branch information
yfukai authored Apr 2, 2024
2 parents cf0323e + 59551ce commit bb757dd
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 29 deletions.
34 changes: 22 additions & 12 deletions src/laptrack/_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
convert_dataframe_to_coords_frame_index,
convert_tree_to_dataframe,
)
from .utils import _coord_is_empty

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -219,8 +220,9 @@ def _predict_links(
# initialize tree
track_tree = nx.Graph()
for frame, coord in enumerate(coords):
for j in range(coord.shape[0]):
track_tree.add_node((frame, j))
if not _coord_is_empty(coord):
for j in range(coord.shape[0]):
track_tree.add_node((frame, j))

# linking between frames

Expand All @@ -231,6 +233,9 @@ def _predict_link_single_frame(
coord1: np.ndarray,
coord2: np.ndarray,
) -> List[EdgeType]:
if _coord_is_empty(coord1) or _coord_is_empty(coord2):
return []

force_end_indices = [e[0][1] for e in edges_list if e[0][0] == frame]
force_start_indices = [e[1][1] for e in edges_list if e[1][0] == frame + 1]
dist_matrix = cdist(coord1, coord2, metric=self.track_dist_metric)
Expand Down Expand Up @@ -431,7 +436,11 @@ def to_candidates(row, coords):
# note: can use KDTree if metric is distance,
# but might not be appropriate for general metrics
# https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa
if other_frame < 0 or len(coords) <= other_frame:
if (
other_frame < 0
or len(coords) <= other_frame
or _coord_is_empty(coords[other_frame])
):
return [], []
target_dist_matrix = cdist(
[target_coord], coords[other_frame], metric=dist_metric
Expand Down Expand Up @@ -658,12 +667,16 @@ def predict(
The graph for the tracks, whose nodes are `(frame, index)`.
The edge direction represents the time order.
"""
if any(list(map(lambda coord: coord.ndim != 2, coords))):
raise ValueError("the elements in coords must be 2-dim.")
###### Check the input format ######
nonempty_coords = [coord for coord in coords if not _coord_is_empty(coord)]

if any(list(map(lambda coord: coord.ndim != 2, nonempty_coords))):
raise ValueError("the elements in coords must be 2-dim or an empty array.")
coord_dim = coords[0].shape[1]
if any(list(map(lambda coord: coord.shape[1] != coord_dim, coords))):
if any(list(map(lambda coord: coord.shape[1] != coord_dim, nonempty_coords))):
raise ValueError("the second dimension in coords must have the same size")

###### Check the connected edges format #######
if connected_edges:
connected_edges = [
(n1, n2) if n1[0] < n2[0] else (n2, n1) for n1, n2 in connected_edges
Expand All @@ -690,15 +703,15 @@ def predict(
split_edges = []
merge_edges = []

####### Particle-particle tracking #######
###### Particle-particle tracking ######
track_tree = self._predict_links(
coords, segment_connected_edges, list(split_edges) + list(merge_edges)
)
track_tree = self._predict_gap_split_merge(
coords, track_tree, split_edges, merge_edges
)

# convert to directed graph
# convert the result to a directed graph
edges = [
(n1, n2) if n1[0] < n2[0] else (n2, n1) for (n1, n2) in track_tree.edges()
]
Expand All @@ -713,7 +726,6 @@ def predict_dataframe(
df: pd.DataFrame,
coordinate_cols: List[str],
frame_col: str = "frame",
validate_frame: bool = True,
only_coordinate_cols: bool = True,
connected_edges: Optional[List[Tuple[Int, Int]]] = None,
index_offset: Int = 0,
Expand All @@ -728,8 +740,6 @@ def predict_dataframe(
The list of the columns to use for coordinates.
frame_col : str, optional
The column name to use for the frame index. Defaults to "frame".
validate_frame : bool, optional
Whether to validate the frame. Defaults to True.
only_coordinate_cols : bool, optional
Whether to use only the coordinate columns. Defaults to True.
connected_edges : Optional[List[Tuple[Int,Int]]]
Expand Down Expand Up @@ -761,7 +771,7 @@ def predict_dataframe(
- "child_track_id" : The track id of the child.
"""
coords, frame_index = convert_dataframe_to_coords_frame_index(
df, coordinate_cols, frame_col, validate_frame
df, coordinate_cols, frame_col
)
if connected_edges is not None:
connected_edges2 = [
Expand Down
39 changes: 22 additions & 17 deletions src/laptrack/data_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from ._typing_utils import Int
from ._typing_utils import NumArray
from .utils import _coord_is_empty

IntTuple = Tuple[Int, Int]

Expand All @@ -19,7 +20,6 @@ def convert_dataframe_to_coords(
df: pd.DataFrame,
coordinate_cols: List[str],
frame_col: str = "frame",
validate_frame: bool = True,
) -> List[NumArray]:
"""
Convert a track dataframe to a list of coordinates for input.
Expand All @@ -29,29 +29,30 @@ def convert_dataframe_to_coords(
df : pd.DataFrame
The input dataframe.
coordinate_cols : List[str]
The list of columns to use for coordinates.
The list of columns used for the coordinates.
frame_col : str, default "frame"
The column name to use for the frame index.
validate_frame : bool, default True
Whether to validate the frame.
The column name used for the integer frame index.
Returns
-------
coords : List[np.ndarray]
The list of the coordinates.
The list of the coordinates. Note that the first frame is the minimum frame index.
"""
grps = list(df.groupby(frame_col, sort=True))
if validate_frame:
assert np.array_equal(np.arange(df[frame_col].max() + 1), [g[0] for g in grps])
coords = [grp[list(coordinate_cols)].values for _frame, grp in grps]
coords_dict = {frame: grp[list(coordinate_cols)].values for frame, grp in grps}
min_frame = min(coords_dict.keys())
max_frame = max(coords_dict.keys())
coords = [
coords_dict.get(frame, np.array([]))
for frame in range(min_frame, max_frame + 1)
]
return coords


def convert_dataframe_to_coords_frame_index(
df: pd.DataFrame,
coordinate_cols: List[str],
frame_col: str = "frame",
validate_frame: bool = True,
) -> Tuple[List[NumArray], List[Tuple[int, int]]]:
"""
Convert a track dataframe to a list of coordinates for input with (frame,index) list.
Expand All @@ -64,8 +65,6 @@ def convert_dataframe_to_coords_frame_index(
The list of columns to use for coordinates.
frame_col : str, default "frame"
The column name to use for the frame index.
validate_frame : bool, default True
Whether to validate the frame.
Returns
-------
Expand All @@ -74,19 +73,22 @@ def convert_dataframe_to_coords_frame_index(
frame_index : List[Tuple[int, int]]
The (frame, index) list for the original iloc of the dataframe.
"""
assert "iloc__" not in df.columns
assert (
"iloc__" not in df.columns
), 'The column name "iloc__" is reserved and cannot be used.'
df = df.copy()
df["iloc__"] = np.arange(len(df), dtype=int)

coords = convert_dataframe_to_coords(
df, list(coordinate_cols) + ["iloc__"], frame_col, validate_frame
df, list(coordinate_cols) + ["iloc__"], frame_col
)

inverse_map = dict(
sum(
[
[(int(c2[-1]), (frame, index)) for index, c2 in enumerate(c)]
for frame, c in enumerate(coords)
[(int(c[-1]), (frame, index)) for index, c in enumerate(coord)]
for frame, coord in enumerate(coords)
if not _coord_is_empty(coord)
],
[],
)
Expand All @@ -96,7 +98,10 @@ def convert_dataframe_to_coords_frame_index(
assert set(inverse_map.keys()) == set(ilocs)
frame_index = [inverse_map[i] for i in ilocs]

return [c[:, :-1] for c in coords], frame_index
return [
coord[:, :-1] if not _coord_is_empty(coord) else np.array([])
for coord in coords
], frame_index


def convert_tree_to_dataframe(
Expand Down
6 changes: 6 additions & 0 deletions src/laptrack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import List
from typing import Tuple

import numpy as np

from ._typing_utils import EdgeType
from ._typing_utils import Int

Expand All @@ -22,3 +24,7 @@ def order_edges(edges: EdgeType) -> List[Tuple[Tuple[Int, Int], Tuple[Int, Int]]
"""
return [(n1, n2) if n1[0] < n2[0] else (n2, n1) for (n1, n2) in edges]


def _coord_is_empty(coord):
    """
    Return whether ``coord`` holds no coordinate values.

    Parameters
    ----------
    coord : array-like or None
        The coordinates for a single frame.

    Returns
    -------
    is_empty : bool
        True if ``coord`` is None or converts to an array with zero elements
        (e.g. ``[]``, ``np.array([])`` or an array of shape ``(0, n)``).
    """
    if coord is None:
        return True
    # np.asarray reuses an existing ndarray instead of copying it, which
    # np.array would do on every call.
    return np.asarray(coord).size == 0
50 changes: 50 additions & 0 deletions tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,56 @@ def test_tracking_zero_distance2(shared_datadir: str) -> None:
)


@pytest.mark.parametrize("splitting_cost_cutoff", [False, 15**2])
@pytest.mark.parametrize("merging_cost_cutoff", [False, 15**2])
def test_allow_frame_without_coords(splitting_cost_cutoff, merging_cost_cutoff) -> None:
    """Check that tracking accepts frames that contain no coordinates.

    The array input is exercised with an empty last frame and with an empty
    frame in the middle of the sequence. The dataframe input uses a frame
    column that skips one frame index (2, 4, 5).
    """
    tracker = LapTrack(
        track_cost_cutoff=15**2,
        gap_closing_cost_cutoff=False,
        splitting_cost_cutoff=splitting_cost_cutoff,
        merging_cost_cutoff=merging_cost_cutoff,
    )

    # --- array input: (list of per-frame coordinates, expected edge set) ---
    empty_last_frame = [np.array([[10, 10], [12, 11]]), np.array([])]
    empty_middle_frame = [
        np.array([[10, 10], [12, 11]]),
        np.array([]),
        np.array([[10, 10], [12, 11]]),
        np.array([[10, 10], [13, 11]]),
    ]
    cases = [
        (empty_last_frame, set()),
        # gap closing is disabled, so only frames 2 and 3 are linked
        (empty_middle_frame, {((2, 0), (3, 0)), ((2, 1), (3, 1))}),
    ]
    for frames, expected_edges in cases:
        predicted_tree = tracker.predict(frames)
        assert set(predicted_tree.edges()) == expected_edges

    # --- dataframe input: no rows exist for frame 3 ---
    spots = pd.DataFrame(
        {
            "x": [10, 12, 10, 12, 9, 11],
            "y": [10, 11, 10, 11, 9, 12],
            "frame": [2, 2, 4, 4, 5, 5],
        }
    )
    # Row positions (iloc) that are expected to end up in the same track.
    expected_members = [{0}, {1}, {2, 4}, {3, 5}]
    expected_tracks = [
        {tuple(spots.iloc[row][["frame", "x", "y"]].to_list()) for row in members}
        for members in expected_members
    ]

    track_df, split_df, merge_df = tracker.predict_dataframe(
        spots, ["x", "y"], only_coordinate_cols=False
    )
    track_df = track_df.set_index(["frame_y", "x", "y"])
    assert split_df.empty
    assert merge_df.empty
    for _, track_rows in track_df.groupby("track_id"):
        assert set(track_rows.index) in expected_tracks


def test_tracking_not_connected() -> None:
coords = [np.array([[10, 10], [12, 11]]), np.array([[50, 50], [53, 51]])]
lt = LapTrack(
Expand Down
10 changes: 10 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import numpy as np

from laptrack.utils import _coord_is_empty


def test_coord_is_empty():
    """Check which inputs ``_coord_is_empty`` reports as empty."""
    # None and zero-element inputs are empty.
    assert _coord_is_empty(None)
    assert _coord_is_empty([])
    assert _coord_is_empty(np.array([]))
    # A 2-D array with zero rows has no elements either.
    assert _coord_is_empty(np.empty((0, 2)))
    # Anything holding at least one value is not empty.
    assert not _coord_is_empty([1, 2])
    assert not _coord_is_empty(np.array([[1, 2]]))

0 comments on commit bb757dd

Please sign in to comment.