Skip to content

Commit

Permalink
Merge pull request #477 from yfukai/refactor_field_names
Browse files Browse the repository at this point in the history
Updated overlap tracking field names
  • Loading branch information
yfukai authored Jan 4, 2025
2 parents 2378765 + e02200f commit e8534ed
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 19 deletions.
23 changes: 23 additions & 0 deletions src/laptrack/_overlap_tracking.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# from functools import cache
import warnings
from functools import partial
from typing import Any
from typing import List
from typing import Tuple
from typing import Union

from pydantic import Field
from pydantic import model_validator

from ._tracking import LapTrack
from ._typing_utils import IntArray
Expand All @@ -14,6 +17,13 @@
float, float, float, float, float
] # offset, overlap, iou, ratio_1, ratio_2

_ALIAS_FIELDS = {
"track_dist_metric_coefs": "metric_coefs",
"gap_closing_dist_metric_coefs": "gap_closing_metric_coefs",
"splitting_dist_metric_coefs": "splitting_metric_coefs",
"merging_dist_metric_coefs": "merging_metric_coefs",
}


class OverLapTrack(LapTrack):
"""Tracking by label overlaps."""
Expand Down Expand Up @@ -41,6 +51,19 @@ class OverLapTrack(LapTrack):
+ "See `metric_coefs` for details.",
)

@model_validator(mode="before")
@classmethod
def _check_deprecated_names(cls, data: Any) -> Any:
if isinstance(data, dict):
for old_name, new_name in _ALIAS_FIELDS.items():
if old_name in data:
warnings.warn(
f"Use of `{old_name}` is deprecated, use `{new_name}` instead.",
DeprecationWarning,
)
data[new_name] = data.pop(old_name)
return data

def predict_overlap_dataframe(self, labels: Union[IntArray, List[IntArray]]):
"""Predicts tracks with label overlaps.
Expand Down
47 changes: 28 additions & 19 deletions tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from laptrack import LapTrack
from laptrack import laptrack
from laptrack import OverLapTrack
from laptrack._overlap_tracking import _ALIAS_FIELDS as _ALIAS_FIELDS_OVERLAP
from laptrack._tracking import _ALIAS_FIELDS
from laptrack.data_conversion import convert_tree_to_dataframe

Expand Down Expand Up @@ -429,25 +431,32 @@ def test_no_connected_node(tracker_class) -> None:

def test_alias_and_deprecation_warning():
# Check that the alias works and issues a warning
for old_name, new_name in _ALIAS_FIELDS.items():
with warnings.catch_warnings(record=True) as w:
if "metric" in old_name:
test_value = "euclidean"
else:
test_value = 20**2
warnings.simplefilter("always") # Ensure all warnings are captured
lt = LapTrack(**{old_name: test_value})
assert (
getattr(lt, new_name) == test_value
) # Validate the correct field is populated

# Check if a DeprecationWarning was raised
assert len(w) == 1 # Only one warning should be present
assert issubclass(w[0].category, DeprecationWarning)
assert (
f"Use of `{old_name}` is deprecated, use `{new_name}` instead."
== str(w[0].message)
)
for fields, cls in [
(_ALIAS_FIELDS, LapTrack),
(_ALIAS_FIELDS_OVERLAP, OverLapTrack),
]:
for old_name, new_name in fields.items():
with warnings.catch_warnings(record=True) as w:
if cls == LapTrack:
if "metric" in old_name:
test_value = "euclidean"
else:
test_value = 20**2
else:
test_value = (1.0, 0.0, 0.0, 0.0, -1.0)
warnings.simplefilter("always") # Ensure all warnings are captured
lt = cls(**{old_name: test_value})
assert (
getattr(lt, new_name) == test_value
) # Validate the correct field is populated

# Check if a DeprecationWarning was raised
assert len(w) == 1 # Only one warning should be present
assert issubclass(w[0].category, DeprecationWarning)
assert (
f"Use of `{old_name}` is deprecated, use `{new_name}` instead."
== str(w[0].message)
)


# # %%
Expand Down

0 comments on commit e8534ed

Please sign in to comment.