Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor field names #472

Merged
merged 3 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
- id: black
- id: black-jupyter
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
rev: 7.1.1
hooks:
- id: flake8
types: [python]
Expand Down
10 changes: 5 additions & 5 deletions docs/examples/3D_tracking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,12 @@
"source": [
"max_distance = 15\n",
"lt = LapTrack(\n",
" track_dist_metric=\"sqeuclidean\", # The similarity metric for particles. See `scipy.spatial.distance.cdist` for allowed values.\n",
" splitting_dist_metric=\"sqeuclidean\", # The similarity metric for splits.\n",
" metric=\"sqeuclidean\", # The similarity metric for particles. See `scipy.spatial.distance.cdist` for allowed values.\n",
" splitting_metric=\"sqeuclidean\", # The similarity metric for splits.\n",
" # the square of the cutoff distance for the \"sqeuclidean\" metric\n",
" track_cost_cutoff=max_distance**2,\n",
" gap_closing_cost_cutoff=max_distance**2,\n",
" splitting_cost_cutoff=max_distance**2,\n",
" cutoff=max_distance**2,\n",
" gap_closing_cutoff=max_distance**2,\n",
" splitting_cutoff=max_distance**2,\n",
")"
]
},
Expand Down
12 changes: 6 additions & 6 deletions docs/examples/api_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,13 @@
"source": [
"max_distance = 15\n",
"lt = LapTrack(\n",
" track_dist_metric=\"sqeuclidean\", # The similarity metric for particles. See `scipy.spatial.distance.cdist` for allowed values.\n",
" splitting_dist_metric=\"sqeuclidean\",\n",
" merging_dist_metric=\"sqeuclidean\",\n",
" metric=\"sqeuclidean\", # The similarity metric for particles. See `scipy.spatial.distance.cdist` for allowed values.\n",
" splitting_metric=\"sqeuclidean\",\n",
" merging_metric=\"sqeuclidean\",\n",
" # the square of the cutoff distance for the \"sqeuclidean\" metric\n",
" track_cost_cutoff=max_distance**2,\n",
" splitting_cost_cutoff=max_distance**2, # or False for non-splitting case\n",
" merging_cost_cutoff=max_distance**2, # or False for non-merging case\n",
" cutoff=max_distance**2,\n",
" splitting_cutoff=max_distance**2, # or False for non-splitting case\n",
" merging_cutoff=max_distance**2, # or False for non-merging case\n",
")"
]
},
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/bright_spots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
}
],
"source": [
"lt = LapTrack(track_cost_cutoff=5**2)\n",
"lt = LapTrack(cutoff=5**2)\n",
"track_df, _, _ = lt.predict_dataframe(spots_df, [\"y\", \"x\"], only_coordinate_cols=False)\n",
"track_df = track_df.reset_index()\n",
"viewer.add_tracks(track_df[[\"track_id\", \"frame\", \"y\", \"x\"]], tail_length=50)"
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/cell_segmentation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
"metadata": {},
"outputs": [],
"source": [
"lt = LapTrack(track_cost_cutoff=15**2, splitting_cost_cutoff=30**2)\n",
"lt = LapTrack(cutoff=15**2, splitting_cutoff=30**2)\n",
"track_df, split_df, merge_df = lt.predict_dataframe(\n",
" regionprops_df.copy(),\n",
" coordinate_cols=[\"centroid-0\", \"centroid-1\"],\n",
Expand Down
12 changes: 6 additions & 6 deletions docs/examples/custom_metric.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The defined `metric` function is used for the frame-to-frame linking (`track_dist_metric`), gap closing (`gap_closing_dist_metric`) and the splitting connection (`splitting_dist_metric`)."
"The defined `metric` function is used for the frame-to-frame linking (`metric`), gap closing (`gap_closing_metric`) and the splitting connection (`splitting_metric`)."
]
},
{
Expand All @@ -459,12 +459,12 @@
"outputs": [],
"source": [
"lt = LapTrack(\n",
" track_dist_metric=metric,\n",
" track_cost_cutoff=0.9,\n",
" gap_closing_dist_metric=metric,\n",
" metric=metric,\n",
" cutoff=0.9,\n",
" gap_closing_metric=metric,\n",
" gap_closing_max_frame_count=1,\n",
" splitting_dist_metric=metric,\n",
" splitting_cost_cutoff=0.9,\n",
" splitting_metric=metric,\n",
" splitting_cutoff=0.9,\n",
")"
]
},
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/napari_interactive_fix.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@
}
],
"source": [
"lt = LapTrack(track_cost_cutoff=100**2, splitting_cost_cutoff=20**2)\n",
"lt = LapTrack(cutoff=100**2, splitting_cutoff=20**2)\n",
"track_df, split_df, _ = lt.predict_dataframe(\n",
" regionprops_df,\n",
" coordinate_cols=[\"centroid-0\", \"centroid-1\"],\n",
Expand Down Expand Up @@ -743,7 +743,7 @@
"metadata": {},
"outputs": [],
"source": [
"lt = LapTrack(track_cost_cutoff=100**2, splitting_cost_cutoff=20**2)\n",
"lt = LapTrack(cutoff=100**2, splitting_cutoff=20**2)\n",
"new_track_df, _new_split_df, _ = lt.predict_dataframe(\n",
" new_regionprops_df,\n",
" coordinate_cols=[\"centroid-0\", \"centroid-1\"],\n",
Expand Down
20 changes: 10 additions & 10 deletions docs/examples/overlap_tracking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,12 @@
"\n",
"```python\n",
"OverLapTrack(\n",
" track_cost_cutoff=0.9,\n",
" track_dist_metric_coefs = (1.0, -1.0, 0.0, 0.0, 0.0),\n",
" gap_closing_dist_metric_coefs = (1.0, -1.0, 0.0, 0.0, 0.0),\n",
" cutoff=0.9,\n",
" metric_coefs = (1.0, -1.0, 0.0, 0.0, 0.0),\n",
" gap_closing_metric_coefs = (1.0, -1.0, 0.0, 0.0, 0.0),\n",
" gap_closing_max_frame_count=1,\n",
" splitting_cost_cutoff=0.9,\n",
" splitting_dist_metric_coefs = (1.0, 0.0, 0.0, 0.0, -1.0),\n",
" splitting_cutoff=0.9,\n",
" splitting_metric_coefs = (1.0, 0.0, 0.0, 0.0, -1.0),\n",
")\n",
"```"
]
Expand All @@ -157,12 +157,12 @@
"outputs": [],
"source": [
"olt = OverLapTrack(\n",
" track_cost_cutoff=0.9,\n",
" track_dist_metric_coefs=(1.0, -1.0, 0.0, 0.0, 0.0),\n",
" gap_closing_dist_metric_coefs=(1.0, -1.0, 0.0, 0.0, 0.0),\n",
" cutoff=0.9,\n",
" metric_coefs=(1.0, -1.0, 0.0, 0.0, 0.0),\n",
" gap_closing_metric_coefs=(1.0, -1.0, 0.0, 0.0, 0.0),\n",
" gap_closing_max_frame_count=1,\n",
" splitting_cost_cutoff=0.9,\n",
" splitting_dist_metric_coefs=(1.0, 0.0, 0.0, 0.0, -1.0),\n",
" splitting_cutoff=0.9,\n",
" splitting_metric_coefs=(1.0, 0.0, 0.0, 0.0, -1.0),\n",
")\n",
"track_df, split_df, merge_df = olt.predict_overlap_dataframe(labels)"
]
Expand Down
28 changes: 11 additions & 17 deletions src/laptrack/_overlap_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,27 @@
class OverLapTrack(LapTrack):
"""Tracking by label overlaps."""

track_dist_metric_coefs: CoefType = Field(
metric_coefs: CoefType = Field(
(1.0, 0.0, 0.0, 0.0, -1.0),
description="The coefficients to calculate the distance for the overlapping labels."
+ "Must be tuple of 5 floats of `(offset, overlap_coef, iou_coef, ratio_1_coef, ratio_2_coef)`."
+ "The distance is calculated by"
+ "`offset + overlap_coef * overlap + iou_coef * iou + ratio_1_coef * ratio_1 + ratio_2_coef * ratio_2`.",
)
gap_closing_dist_metric_coefs: CoefType = Field(
gap_closing_metric_coefs: CoefType = Field(
(1.0, 0.0, 0.0, 0.0, -1.0),
description="The coefficients to calculate the distance for the overlapping labels."
+ "See `track_dist_metric_coefs` for details.",
+ "See `metric_coefs` for details.",
)
splitting_dist_metric_coefs: CoefType = Field(
splitting_metric_coefs: CoefType = Field(
(1.0, 0.0, 0.0, 0.0, -1.0),
description="The coefficients to calculate the distance for the overlapping labels."
+ "See `track_dist_metric_coefs` for details.",
+ "See `metric_coefs` for details.",
)
merging_dist_metric_coefs: CoefType = Field(
merging_metric_coefs: CoefType = Field(
(1.0, 0.0, 0.0, 0.0, -1.0),
description="The coefficients to calculate the distance for the overlapping labels."
+ "See `track_dist_metric_coefs` for details.",
+ "See `metric_coefs` for details.",
)

def predict_overlap_dataframe(self, labels: Union[IntArray, List[IntArray]]):
Expand Down Expand Up @@ -97,16 +97,10 @@ def metric(c1, c2, params):
)
return distance

self.track_dist_metric = partial(metric, params=self.track_dist_metric_coefs)
self.gap_closing_dist_metric = partial(
metric, params=self.gap_closing_dist_metric_coefs
)
self.splitting_dist_metric = partial(
metric, params=self.splitting_dist_metric_coefs
)
self.merging_dist_metric = partial(
metric, params=self.merging_dist_metric_coefs
)
self.metric = partial(metric, params=self.metric_coefs)
self.gap_closing_metric = partial(metric, params=self.gap_closing_metric_coefs)
self.splitting_metric = partial(metric, params=self.splitting_metric_coefs)
self.merging_metric = partial(metric, params=self.merging_metric_coefs)

track_df, split_df, merge_df = super().predict_dataframe(
lo.frame_label_df, ["frame", "label"], only_coordinate_cols=False
Expand Down
76 changes: 48 additions & 28 deletions src/laptrack/_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from enum import Enum
from functools import partial
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
Expand All @@ -26,7 +27,7 @@
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from pydantic import BaseModel, Field, Extra
from pydantic import BaseModel, Field, model_validator


from ._cost_matrix import build_frame_cost_matrix, build_segment_cost_matrix
Expand Down Expand Up @@ -95,26 +96,38 @@ def _get_segment_df(coords, track_tree):
return segments_df


class LapTrack(BaseModel, extra=Extra.forbid):
_ALIAS_FIELDS = {
"track_dist_metric": "metric",
"track_cost_cutoff": "cutoff",
"gap_closing_dist_metric": "gap_closing_metric",
"gap_closing_cost_cutoff": "gap_closing_cutoff",
"splitting_dist_metric": "splitting_metric",
"splitting_cost_cutoff": "splitting_cutoff",
"merging_dist_metric": "merging_metric",
"merging_cost_cutoff": "merging_cutoff",
}


class LapTrack(BaseModel, extra="forbid"):
"""Tracking class for LAP tracker with parameters."""

track_dist_metric: Union[str, Callable] = Field(
metric: Union[str, Callable] = Field(
"sqeuclidean",
description="The metric for calculating track linking cost. "
+ "See documentation for `scipy.spatial.distance.cdist` for accepted values.",
)
track_cost_cutoff: float = Field(
cutoff: float = Field(
15**2,
description="The cost cutoff for the connected points in the track. "
+ "For default cases with `dist_metric='sqeuclidean'`, "
+ "this value should be squared maximum distance.",
)
gap_closing_dist_metric: Union[str, Callable] = Field(
gap_closing_metric: Union[str, Callable] = Field(
"sqeuclidean",
description="The metric for calculating gap closing cost. "
+ "See documentation for `scipy.spatial.distance.cdist` for accepted values.",
)
gap_closing_cost_cutoff: Union[Literal[False], float] = Field(
gap_closing_cutoff: Union[Literal[False], float] = Field(
15**2,
description="The cost cutoff for gap closing. "
+ "For default cases with `dist_metric='sqeuclidean'`, "
Expand All @@ -125,26 +138,24 @@ class LapTrack(BaseModel, extra=Extra.forbid):
2, description="The maximum frame gaps, by default 2."
)

splitting_dist_metric: Union[str, Callable] = Field(
splitting_metric: Union[str, Callable] = Field(
"sqeuclidean",
description="The metric for calculating splitting cost. "
+ "See `track_dist_metric`.",
description="The metric for calculating splitting cost. " + "See `metric`.",
)
splitting_cost_cutoff: Union[Literal[False], float] = Field(
splitting_cutoff: Union[Literal[False], float] = Field(
False,
description="The cost cutoff for splitting. "
+ "See `gap_closing_cost_cutoff`. "
+ "See `gap_closing_cutoff`. "
+ "If False, no splitting is allowed.",
)
merging_dist_metric: Union[str, Callable] = Field(
merging_metric: Union[str, Callable] = Field(
"sqeuclidean",
description="The metric for calculating merging cost. "
+ "See `track_dist_metric`",
description="The metric for calculating merging cost. " + "See `metric`",
)
merging_cost_cutoff: Union[Literal[False], float] = Field(
merging_cutoff: Union[Literal[False], float] = Field(
False,
description="The cost cutoff for merging. "
+ "See `gap_closing_cost_cutoff`. "
+ "See `gap_closing_cutoff`. "
+ "If False, no merging is allowed.",
)

Expand Down Expand Up @@ -197,6 +208,19 @@ class LapTrack(BaseModel, extra=Extra.forbid):
exclude=True,
)

@model_validator(mode="before")
@classmethod
def _check_deprecated_names(cls, data: Any) -> Any:
if isinstance(data, dict):
for old_name, new_name in _ALIAS_FIELDS.items():
if old_name in data:
warnings.warn(
f"Use of `{old_name}` is deprecated, use `{new_name}` instead.",
DeprecationWarning,
)
data[new_name] = data.pop(old_name)
return data

def _predict_links(
self, coords, segment_connected_edges, split_merge_edges
) -> nx.Graph:
Expand Down Expand Up @@ -238,11 +262,11 @@ def _predict_link_single_frame(

force_end_indices = [e[0][1] for e in edges_list if e[0][0] == frame]
force_start_indices = [e[1][1] for e in edges_list if e[1][0] == frame + 1]
dist_matrix = cdist(coord1, coord2, metric=self.track_dist_metric)
dist_matrix = cdist(coord1, coord2, metric=self.metric)
dist_matrix[force_end_indices, :] = np.inf
dist_matrix[:, force_start_indices] = np.inf

ind = np.where(dist_matrix < self.track_cost_cutoff)
ind = np.where(dist_matrix < self.cutoff)
dist_matrix = coo_matrix_builder(
dist_matrix.shape,
row=ind[0],
Expand Down Expand Up @@ -317,7 +341,7 @@ def _get_gap_closing_matrix(
the cost matrix for gap closing candidates

"""
if self.gap_closing_cost_cutoff:
if self.gap_closing_cutoff:

def to_gap_closing_candidates(row, segments_df):
# if the index is in force_end_indices, do not add to gap closing candidates
Expand Down Expand Up @@ -348,11 +372,11 @@ def to_gap_closing_candidates(row, segments_df):
target_dist_matrix = cdist(
[target_coord],
np.stack(df["first_frame_coords"].values),
metric=self.gap_closing_dist_metric,
metric=self.gap_closing_metric,
)
assert target_dist_matrix.shape[0] == 1
indices2 = np.where(
target_dist_matrix[0] < self.gap_closing_cost_cutoff
target_dist_matrix[0] < self.gap_closing_cutoff
)[0]
return (
df.index[indices2].values,
Expand Down Expand Up @@ -581,11 +605,7 @@ def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges)
the updated track tree
"""
edges = list(split_edges) + list(merge_edges)
if (
self.gap_closing_cost_cutoff
or self.splitting_cost_cutoff
or self.merging_cost_cutoff
):
if self.gap_closing_cutoff or self.splitting_cutoff or self.merging_cutoff:
segments_df = _get_segment_df(coords, track_tree)
force_end_nodes = [tuple(map(int, e[0])) for e in edges]
force_start_nodes = [tuple(map(int, e[1])) for e in edges]
Expand All @@ -603,8 +623,8 @@ def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges)
# compute candidate for splitting and merging
for prefix, cutoff, dist_metric in zip(
["first", "last"],
[self.splitting_cost_cutoff, self.merging_cost_cutoff],
[self.splitting_dist_metric, self.merging_dist_metric],
[self.splitting_cutoff, self.merging_cutoff],
[self.splitting_metric, self.merging_metric],
):
(
segments_df,
Expand Down
Loading
Loading