diff --git a/poetry.lock b/poetry.lock
index da03f094..ee09c3c3 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -161,7 +161,7 @@ dev = ["build (==0.8.0)", "flake8 (==4.0.1)", "hashin (==0.17.0)", "pip-tools (=
 
 [[package]]
 name = "certifi"
-version = "2022.6.15"
+version = "2022.9.24"
 description = "Python package for providing Mozilla's CA Bundle."
 category = "dev"
 optional = false
@@ -230,9 +230,27 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""}
 [package.extras]
 development = ["black", "flake8", "mypy", "pytest", "types-colorama"]
 
+[[package]]
+name = "contourpy"
+version = "1.0.5"
+description = "Python library for calculating contours of 2D quadrilateral grids"
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+numpy = ">=1.16"
+
+[package.extras]
+test-no-codebase = ["pillow", "matplotlib", "pytest"]
+test-minimal = ["pytest"]
+test = ["isort", "flake8", "pillow", "matplotlib", "pytest"]
+docs = ["sphinx-rtd-theme", "sphinx", "docutils (<0.18)"]
+bokeh = ["selenium", "bokeh"]
+
 [[package]]
 name = "coverage"
-version = "6.4.4"
+version = "6.5.0"
 description = "Code coverage measurement for Python"
 category = "dev"
 optional = false
@@ -294,7 +312,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
 
 [[package]]
 name = "dparse"
-version = "0.5.2"
+version = "0.6.2"
 description = "A parser for Python dependency files"
 category = "dev"
 optional = false
@@ -318,15 +336,18 @@ python-versions = ">=3.6"
 
 [[package]]
 name = "executing"
-version = "1.0.0"
+version = "1.1.0"
 description = "Get the currently executing AST node of a frame, and other information"
 category = "dev"
 optional = false
 python-versions = "*"
 
+[package.extras]
+tests = ["rich", "littleutils", "pytest", "asttokens"]
+
 [[package]]
 name = "fastjsonschema"
-version = "2.16.1"
+version = "2.16.2"
 description = "Fastest Python implementation of JSON schema"
 category = "dev"
 optional = false
@@ -349,7 +370,7 @@ testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pyt
 
 [[package]]
 name = "fonttools"
-version = "4.37.1"
+version = "4.37.4"
 description = "Tools to manipulate font files"
 category = "dev"
 optional = false
@@ -371,7 +392,7 @@ woff = ["zopfli (>=0.1.4)", "brotlicffi (>=0.8.0)", "brotli (>=1.0.1)"]
 
 [[package]]
 name = "identify"
-version = "2.5.5"
+version = "2.5.6"
 description = "File identification library for Python"
 category = "dev"
 optional = false
@@ -382,7 +403,7 @@ license = ["ukkonen"]
 
 [[package]]
 name = "idna"
-version = "3.3"
+version = "3.4"
 description = "Internationalized Domain Names in Applications (IDNA)"
 category = "dev"
 optional = false
@@ -398,7 +419,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
 
 [[package]]
 name = "importlib-metadata"
-version = "4.12.0"
+version = "5.0.0"
 description = "Read metadata from Python packages"
 category = "dev"
 optional = false
@@ -408,9 +429,9 @@ python-versions = ">=3.7"
 zipp = ">=0.5"
 
 [package.extras]
-docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"]
+docs = ["sphinx (>=3.5)", "jaraco.packaging (>=9)", "rst.linker (>=1.9)", "furo", "jaraco.tidelift (>=1.4)"]
 perf = ["ipython"]
-testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.3)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"]
+testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "flake8 (<5)", "pytest-cov", "pytest-enabler (>=1.3)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"]
 
 [[package]]
 name = "importlib-resources"
@@ -437,7 +458,7 @@ python-versions = "*"
 
 [[package]]
 name = "ipykernel"
-version = "6.15.2"
+version = "6.16.0"
 description = "IPython Kernel for Jupyter"
 category = "dev"
 optional = false
@@ -551,7 +572,7 @@ i18n = ["Babel (>=2.7)"]
 
 [[package]]
 name = "jsonschema"
-version = "4.15.0"
+version = "4.16.0"
 description = "An implementation of JSON Schema validation for Python"
 category = "dev"
 optional = false
@@ -673,20 +694,6 @@ python-versions = "*"
 six = "*"
 tornado = {version = "*", markers = "python_version > \"2.7\""}
 
-[[package]]
-name = "lxml"
-version = "4.9.1"
-description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API."
-category = "dev"
-optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*"
-
-[package.extras]
-cssselect = ["cssselect (>=0.7)"]
-html5 = ["html5lib"]
-htmlsoup = ["beautifulsoup4"]
-source = ["Cython (>=0.29.7)"]
-
 [[package]]
 name = "markupsafe"
 version = "2.1.1"
@@ -697,22 +704,23 @@ python-versions = ">=3.7"
 
 [[package]]
 name = "matplotlib"
-version = "3.5.3"
+version = "3.6.0"
 description = "Python plotting package"
 category = "dev"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 
 [package.dependencies]
+contourpy = ">=1.0.1"
 cycler = ">=0.10"
 fonttools = ">=4.22.0"
 kiwisolver = ">=1.0.1"
-numpy = ">=1.17"
+numpy = ">=1.19"
 packaging = ">=20.0"
 pillow = ">=6.2.0"
 pyparsing = ">=2.2.1"
 python-dateutil = ">=2.7"
-setuptools_scm = ">=4,<7"
+setuptools_scm = ">=7"
 
 [[package]]
 name = "matplotlib-inline"
@@ -760,7 +768,7 @@ python-versions = "*"
 
 [[package]]
 name = "nbclient"
-version = "0.6.7"
+version = "0.6.8"
 description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor."
category = "dev" optional = false @@ -778,7 +786,7 @@ test = ["black", "check-manifest", "flake8", "ipykernel", "ipython", "ipywidgets [[package]] name = "nbconvert" -version = "7.0.0" +version = "7.1.0" description = "Converting Jupyter Notebooks" category = "dev" optional = false @@ -792,7 +800,6 @@ importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.10\""} jinja2 = ">=3.0" jupyter-core = ">=4.7" jupyterlab-pygments = "*" -lxml = "*" markupsafe = ">=2.0" mistune = ">=2.0.3,<3" nbclient = ">=0.5.0" @@ -814,7 +821,7 @@ webpdf = ["pyppeteer (>=1,<1.1)"] [[package]] name = "nbformat" -version = "5.4.0" +version = "5.6.1" description = "The Jupyter Notebook format" category = "dev" optional = false @@ -827,11 +834,11 @@ jupyter-core = "*" traitlets = ">=5.1" [package.extras] -test = ["check-manifest", "testpath", "pytest", "pre-commit"] +test = ["check-manifest", "pep440", "pre-commit", "pytest", "testpath"] [[package]] name = "nest-asyncio" -version = "1.5.5" +version = "1.5.6" description = "Patch asyncio to allow nested event loops" category = "dev" optional = false @@ -839,7 +846,7 @@ python-versions = ">=3.5" [[package]] name = "networkx" -version = "2.8.6" +version = "2.8.7" description = "Python package for creating and manipulating graphs and networks" category = "main" optional = false @@ -847,7 +854,7 @@ python-versions = ">=3.8" [package.extras] default = ["numpy (>=1.19)", "scipy (>=1.8)", "matplotlib (>=3.4)", "pandas (>=1.3)"] -developer = ["pre-commit (>=2.20)", "mypy (>=0.961)"] +developer = ["pre-commit (>=2.20)", "mypy (>=0.981)"] doc = ["sphinx (>=5)", "pydata-sphinx-theme (>=0.9)", "sphinx-gallery (>=0.10)", "numpydoc (>=1.4)", "pillow (>=9.1)", "nb2plots (>=0.6)", "texext (>=0.6.6)"] extra = ["lxml (>=4.6)", "pygraphviz (>=1.9)", "pydot (>=1.4.2)", "sympy (>=1.10)"] test = ["pytest (>=7.1)", "pytest-cov (>=3.0)", "codecov (>=2.1)"] @@ -923,7 +930,7 @@ tomlkit = ">=0.7" [[package]] name = "numpy" -version = "1.23.2" +version = "1.23.3" description = "NumPy is the fundamental package for array computing with Python." 
category = "main" optional = false @@ -942,7 +949,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" [[package]] name = "pandas" -version = "1.4.4" +version = "1.5.0" description = "Powerful data structures for data analysis, time series, and statistics" category = "main" optional = false @@ -950,16 +957,14 @@ python-versions = ">=3.8" [package.dependencies] numpy = [ - {version = ">=1.18.5", markers = "platform_machine != \"aarch64\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, - {version = ">=1.19.2", markers = "platform_machine == \"aarch64\" and python_version < \"3.10\""}, - {version = ">=1.20.0", markers = "platform_machine == \"arm64\" and python_version < \"3.10\""}, {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.20.3", markers = "python_version < \"3.10\""}, ] python-dateutil = ">=2.8.1" pytz = ">=2020.1" [package.extras] -test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] +test = ["pytest-xdist (>=1.31)", "pytest (>=6.0)", "hypothesis (>=5.5.3)"] [[package]] name = "pandocfilters" @@ -1141,14 +1146,14 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] name = "pydantic" -version = "1.9.2" +version = "1.10.2" description = "Data validation and settings management using python type hints" category = "main" optional = false -python-versions = ">=3.6.1" +python-versions = ">=3.7" [package.dependencies] -typing-extensions = ">=3.7.4.3" +typing-extensions = ">=4.1.0" [package.extras] dotenv = ["python-dotenv (>=0.10.4)"] @@ -1229,7 +1234,7 @@ six = ">=1.5" [[package]] name = "pytz" -version = "2022.2.1" +version = "2022.4" description = "World timezone definitions, modern and historical" category = "main" optional = false @@ -1245,7 +1250,7 @@ python-versions = "*" [[package]] name = "pywinpty" -version = "2.0.7" +version = "2.0.8" description = "Pseudo terminal support for Windows from Python." category = "dev" optional = false @@ -1261,7 +1266,7 @@ python-versions = ">=3.6" [[package]] name = "pyzmq" -version = "23.2.1" +version = "24.0.1" description = "Python bindings for 0MQ" category = "dev" optional = false @@ -1295,7 +1300,7 @@ test = ["flaky", "pytest", "pytest-qt"] [[package]] name = "qtpy" -version = "2.2.0" +version = "2.2.1" description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." 
category = "dev" optional = false @@ -1305,7 +1310,7 @@ python-versions = ">=3.7" packaging = "*" [package.extras] -test = ["pytest-qt", "pytest-cov (>=3.0.0)", "pytest (>=6,!=7.0.0,!=7.0.1)"] +test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] [[package]] name = "requests" @@ -1365,15 +1370,16 @@ nativelib = ["pywin32", "pyobjc-framework-cocoa"] [[package]] name = "setuptools-scm" -version = "6.4.2" +version = "7.0.5" description = "the blessed package to manage your versions by scm tags" category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] packaging = ">=20.0" tomli = ">=1.0.0" +typing-extensions = "*" [package.extras] test = ["pytest (>=6.2)", "virtualenv (>20)"] @@ -1564,7 +1570,7 @@ test = ["pytest"] [[package]] name = "stack-data" -version = "0.5.0" +version = "0.5.1" description = "Extract data from python stack frames and tracebacks for informative displays" category = "dev" optional = false @@ -1580,7 +1586,7 @@ tests = ["cython", "littleutils", "pygments", "typeguard", "pytest"] [[package]] name = "terminado" -version = "0.15.0" +version = "0.16.0" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." category = "dev" optional = false @@ -1627,7 +1633,7 @@ python-versions = ">=3.7" [[package]] name = "tomlkit" -version = "0.11.4" +version = "0.11.5" description = "Style preserving TOML library" category = "dev" optional = false @@ -1643,7 +1649,7 @@ python-versions = ">= 3.7" [[package]] name = "traitlets" -version = "5.3.0" +version = "5.4.0" description = "" category = "dev" optional = false @@ -1666,11 +1672,11 @@ test = ["pytest", "typing-extensions", "mypy"] [[package]] name = "typing-extensions" -version = "3.10.0.2" -description = "Backported and Experimental Type Hints for Python 3.5+" +version = "4.3.0" +description = "Backported and Experimental Type Hints for Python 3.7+" category = "main" optional = false -python-versions = "*" +python-versions = ">=3.7" [[package]] name = "urllib3" @@ -1769,7 +1775,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.8.1,<3.11" -content-hash = "74b6ca4c2e6dbaf3db65daac12106b013c5f444608487f61ca04715baf9ab91f" +content-hash = "58519c6472de2fa12ee07bcee5206e15dfffd0b344074503d95826ca9a22875f" [metadata.files] alabaster = [] @@ -1792,6 +1798,7 @@ charset-normalizer = [] click = [] colorama = [] colorlog = [] +contourpy = [] coverage = [] cycler = [] debugpy = [] @@ -1826,7 +1833,6 @@ jupyterlab-pygments = [] jupyterlab-widgets = [] kiwisolver = [] livereload = [] -lxml = [] markupsafe = [] matplotlib = [] matplotlib-inline = [] diff --git a/pyproject.toml b/pyproject.toml index 24c045b4..02dfc211 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,8 +30,8 @@ numpy = "^1.22.0" scipy = "^1.7.0" networkx = "^2.6.1" pandas = "^1.3.1" -typing-extensions = "^3.10.0" -pydantic = "^1.9.1" +typing-extensions = "^4.1.0" +pydantic = "^1.10.2" [tool.poetry.dev-dependencies] pytest = "^6.2.4" diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index aeaf5b09..657578c9 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -1,5 +1,6 @@ """Main module for tracking.""" import logging +import warnings from abc import ABC from abc import abstractmethod from inspect import Parameter @@ -36,7 +37,10 @@ from ._typing_utils import NumArray from ._typing_utils import Int from ._coo_matrix_builder import 
coo_matrix_builder -from .data_conversion import convert_dataframe_to_coords, convert_tree_to_dataframe +from .data_conversion import ( + convert_dataframe_to_coords_frame_index, + convert_tree_to_dataframe, +) logger = logging.getLogger(__name__) @@ -612,6 +616,7 @@ def predict_dataframe( coordinate_cols: List[str], frame_col: str = "frame", validate_frame: bool = True, + only_coordinate_cols: bool = True, ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """Shorthand for the tracking with the dataframe input / output. @@ -625,10 +630,12 @@ def predict_dataframe( The column name to use for the frame index. Defaults to "frame". validate_frame : bool, optional Whether to validate the frame. Defaults to True. + only_coordinate_cols : bool, optional + Whether to use only coordinate columns. Defaults to True. Returns ------- - df : pd.DataFrame + track_df : pd.DataFrame the track dataframe, with the following columns: - "frame" : the frame index - "index" : the coordinate index @@ -644,13 +651,24 @@ def predict_dataframe( - "parent_track_id" : the track id of the parent - "child_track_id" : the track id of the parent """ - coords = convert_dataframe_to_coords( + coords, frame_index = convert_dataframe_to_coords_frame_index( df, coordinate_cols, frame_col, validate_frame ) tree = self.predict(coords) - df, split_df, merge_df = convert_tree_to_dataframe(tree, coords) - df = df.rename(columns={f"coord-{i}": k for i, k in enumerate(coordinate_cols)}) - return df, split_df, merge_df + if only_coordinate_cols: + track_df, split_df, merge_df = convert_tree_to_dataframe(tree, coords) + track_df = track_df.rename( + columns={f"coord-{i}": k for i, k in enumerate(coordinate_cols)} + ) + warnings.warn( + "The parameter only_coordinate_cols will be False by default in the major release.", + FutureWarning, + ) + else: + track_df, split_df, merge_df = convert_tree_to_dataframe( + tree, dataframe=df, frame_index=frame_index + ) + return track_df, split_df, merge_df class LapTrack(LapTrackBase): diff --git a/src/laptrack/data_conversion.py b/src/laptrack/data_conversion.py index 503d7230..fc2648a4 100644 --- a/src/laptrack/data_conversion.py +++ b/src/laptrack/data_conversion.py @@ -32,7 +32,7 @@ def convert_dataframe_to_coords( frame_col : str, optional The column name to use for the frame index. Defaults to "frame". validate_frame : bool, optional - whether to validate the frame. Defaults to True. + Whether to validate the frame. Defaults to True. Returns ------- @@ -42,12 +42,67 @@ def convert_dataframe_to_coords( grps = list(df.groupby(frame_col, sort=True)) if validate_frame: assert np.array_equal(np.arange(df[frame_col].max() + 1), [g[0] for g in grps]) - coords = [grp[coordinate_cols].values for _frame, grp in grps] + coords = [grp[list(coordinate_cols)].values for _frame, grp in grps] return coords +def convert_dataframe_to_coords_frame_index( + df: pd.DataFrame, + coordinate_cols: List[str], + frame_col: str = "frame", + validate_frame: bool = True, +) -> Tuple[List[NumArray], List[Tuple[int, int]]]: + """ + Convert a track dataframe to a list of coordinates for input. + + Parameters + ---------- + df : pd.DataFrame + the input dataframe + coordinate_cols : List[str] + the list of columns to use for coordinates + frame_col : str, optional + The column name to use for the frame index. Defaults to "frame". + validate_frame : bool, optional + Whether to validate the frame. Defaults to True. 
+
+    Returns
+    -------
+    coords : List[np.ndarray]
+        the list of coordinates
+    frame_index : List[Tuple[int, int]]
+        the (frame, index) tuple for each original iloc of the dataframe.
+    """
+    assert "iloc__" not in df.columns
+    df = df.copy()
+    df["iloc__"] = np.arange(len(df), dtype=int)
+
+    coords = convert_dataframe_to_coords(
+        df, list(coordinate_cols) + ["iloc__"], frame_col, validate_frame
+    )
+
+    inverse_map = dict(
+        sum(
+            [
+                [(int(c2[-1]), (frame, index)) for index, c2 in enumerate(c)]
+                for frame, c in enumerate(coords)
+            ],
+            [],
+        )
+    )
+
+    ilocs = list(range(len(df)))
+    assert set(inverse_map.keys()) == set(ilocs)
+    frame_index = [inverse_map[i] for i in ilocs]
+
+    return [c[:, :-1] for c in coords], frame_index
+
+
 def convert_tree_to_dataframe(
-    tree: nx.DiGraph, coords: Optional[Sequence[NumArray]] = None
+    tree: nx.DiGraph,
+    coords: Optional[Sequence[NumArray]] = None,
+    dataframe: Optional[pd.DataFrame] = None,
+    frame_index: Optional[List[Tuple[int, int]]] = None,
 ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
     """Convert the track tree to dataframes.
 
@@ -57,10 +112,14 @@
     tree : nx.DiGraph
         The track tree, resulting from the tracking
     coords : Optional[Sequence[NumArray]]
         The coordinate values. If None, no coordinate values are appended to the dataframe.
+    dataframe : Optional[pd.DataFrame]
+        The input dataframe. If not None, `frame_index` must also be given. Ignored if `coords` is not None.
+    frame_index : Optional[List[Tuple[int, int]]]
+        the inverse map to map (frame, index) to the original iloc of the dataframe.
 
     Returns
     -------
-    df : pd.DataFrame
+    track_df : pd.DataFrame
         the track dataframe, with the following columns:
         - "frame" : the frame index
         - "index" : the coordinate index
@@ -89,20 +148,50 @@
                 }
             )
         )
-    df = pd.concat(df_data)
+    track_df = pd.concat(df_data)
     if coords is not None:
         # XXX there may exist faster impl.
         for i in range(coords[0].shape[1]):
-            df[f"coord-{i}"] = [
+            track_df[f"coord-{i}"] = [
                 coords[int(row["frame"])][int(row["index"]), i]
-                for _, row in df.iterrows()
+                for _, row in track_df.iterrows()
             ]
+    elif dataframe is not None:
+        assert len(track_df) == len(dataframe)
+        df_len = len(track_df)
+        if frame_index is None:
+            raise ValueError("frame_index must not be None if dataframe is not None")
+        frame_index_test = set(
+            [
+                tuple([int(v) for v in vv])
+                for vv in track_df[["frame", "index"]].to_numpy()
+            ]
+        )
+        assert (
+            set(list(frame_index)) == frame_index_test
+        ), "inverse map (frame,index) is incorrect"
+
+        assert "__frame" not in dataframe.columns
+        assert "__index" not in dataframe.columns
+        dataframe["__frame"] = [x[0] for x in frame_index]
+        dataframe["__index"] = [x[1] for x in frame_index]
+        track_df = pd.merge(
+            track_df,
+            dataframe,
+            left_on=["frame", "index"],
+            right_on=["__frame", "__index"],
+            how="outer",
+        )
+        assert len(track_df) == df_len
+        track_df = track_df.drop(columns=["__frame", "__index"]).rename(
+            columns={"frame_x": "frame", "index_x": "index"}
+        )
 
-    df = df.set_index(["frame", "index"])
+    track_df = track_df.set_index(["frame", "index"])
     connected_components = list(nx.connected_components(nx.Graph(tree)))
     for track_id, nodes in enumerate(connected_components):
         for (frame, index) in nodes:
-            df.loc[(frame, index), "tree_id"] = track_id
+            track_df.loc[(frame, index), "tree_id"] = track_id
             # tree.nodes[(frame, index)]["tree_id"] = track_id
 
     tree2 = tree.copy()
@@ -127,19 +216,19 @@
     connected_components = list(nx.connected_components(nx.Graph(tree2)))
     for track_id, nodes in enumerate(connected_components):
         for (frame, index) in nodes:
-            df.loc[(frame, index), "track_id"] = track_id
+            track_df.loc[(frame, index), "track_id"] = track_id
             # tree.nodes[(frame, index)]["track_id"] = track_id
 
     for k in ["tree_id", "track_id"]:
-        df[k] = df[k].astype(int)
+        track_df[k] = track_df[k].astype(int)
 
     split_df_data = []
     for (node, children) in splits:
         for child in children:
             split_df_data.append(
                 {
-                    "parent_track_id": df.loc[node, "track_id"],
-                    "child_track_id": df.loc[child, "track_id"],
+                    "parent_track_id": track_df.loc[node, "track_id"],
+                    "child_track_id": track_df.loc[child, "track_id"],
                 }
             )
     split_df = pd.DataFrame.from_records(split_df_data).astype(int)
@@ -149,10 +238,10 @@
         for parent in parents:
             merge_df_data.append(
                 {
-                    "parent_track_id": df.loc[parent, "track_id"],
-                    "child_track_id": df.loc[node, "track_id"],
+                    "parent_track_id": track_df.loc[parent, "track_id"],
+                    "child_track_id": track_df.loc[node, "track_id"],
                 }
             )
     merge_df = pd.DataFrame.from_records(merge_df_data).astype(int)
 
-    return df, split_df, merge_df
+    return track_df, split_df, merge_df
diff --git a/tests/test_data_conversion.py b/tests/test_data_conversion.py
index abeb7db4..35a2faff 100644
--- a/tests/test_data_conversion.py
+++ b/tests/test_data_conversion.py
@@ -22,11 +22,30 @@ def test_convert_dataframe_to_coords():
         np.array([[3, 3], [4, 4]]),
         np.array([[5, 5], [6, 6], [7, 7], [8, 8], [9, 9]]),
     ]
+    frame_index_target = [
+        (0, 0),
+        (0, 1),
+        (0, 2),
+        (1, 0),
+        (1, 1),
+        (2, 0),
+        (2, 1),
+        (2, 2),
+        (2, 3),
+        (2, 4),
+    ]
 
     coords = data_conversion.convert_dataframe_to_coords(df, ["x", "y"])
     assert len(coords) == len(df["frame"].unique())
     assert all([np.all(c1 == c2) for c1, c2 in zip(coords, coords_target)])
+    coords, frame_index = data_conversion.convert_dataframe_to_coords_frame_index(
+        df, ["x", "y"]
+    )
+    assert len(coords) == len(df["frame"].unique())
+    assert all([np.all(c1 == c2) for c1, c2 in zip(coords, coords_target)])
+    assert frame_index == frame_index_target
+
 
 @pytest.fixture
 def test_trees():
@@ -169,6 +188,30 @@
     )
 
 
+@pytest.mark.parametrize("track_class", [LapTrack, LapTrackMulti])
+def test_convert_tree_to_dataframe_frame_index(track_class):
+    df = pd.DataFrame(
+        {
+            "frame": [0, 0, 0, 1, 1, 2, 2, 2, 2, 2],
+            "x": [0.1, 1.1, 2.1, 0.05, 1.05, 0.1, 1.1, 7, 8, 9],
+            "y": [0.1, 1.1, 2.1, 0.05, 1.05, 0.1, 1.1, 7, 8, 9],
+            "z": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+        }
+    )
+    coords, frame_index = data_conversion.convert_dataframe_to_coords_frame_index(
+        df, ["x", "y"]
+    )
+    lt = track_class(gap_closing_max_frame_count=1)
+    tree = lt.predict(coords)
+    df, split_df, merge_df = data_conversion.convert_tree_to_dataframe(
+        tree, dataframe=df, frame_index=frame_index
+    )
+    assert all(df["frame_y"] == df.index.get_level_values("frame"))
+    assert len(np.unique(df.iloc[[0, 3, 5]]["tree_id"])) == 1
+    assert len(np.unique(df.iloc[[1, 4, 6]]["tree_id"])) == 1
+    assert len(np.unique(df["tree_id"])) > 1
+
+
 @pytest.mark.parametrize("track_class", [LapTrack, LapTrackMulti])
 def test_integration(track_class):
     df = pd.DataFrame(
diff --git a/tests/test_scores.py b/tests/test_scores.py
index b1107603..c07aaf3a 100644
--- a/tests/test_scores.py
+++ b/tests/test_scores.py
@@ -55,6 +55,8 @@ def test_scores(test_trees) -> None:
     }
     assert score == calc_scores(true_tree.edges, pred_tree.edges)
 
+    assert calc_scores(true_tree.edges, []).keys() == score.keys()
+
 
 def test_scores_no_track(test_trees) -> None:
     true_tree, pred_tree = test_trees
diff --git a/tests/test_tracking.py b/tests/test_tracking.py
index 264c9a5f..042e26f3 100644
--- a/tests/test_tracking.py
+++ b/tests/test_tracking.py
@@ -1,4 +1,5 @@
 """Test cases for the tracking."""
+import warnings
 from itertools import product
 from os import path
 
@@ -13,6 +14,8 @@ from laptrack import LapTrackMulti
 from laptrack.data_conversion import convert_tree_to_dataframe
 
+warnings.simplefilter("ignore", FutureWarning)
+
 
 DEFAULT_PARAMS = dict(
     track_dist_metric="sqeuclidean",
     splitting_dist_metric="sqeuclidean",
@@ -122,6 +125,15 @@ def test_reproducing_trackmate(testdata, tracker_class) -> None:
     assert all(split_df == split_df2)
     assert all(merge_df == merge_df2)
 
+    track_df, split_df, merge_df = lt.predict_dataframe(
+        df, ["x", "y"], only_coordinate_cols=False
+    )
+    assert all(track_df["frame_y"] == track_df2.index.get_level_values("frame"))
+    track_df = track_df.drop(columns=["frame_y"])
+    assert all(track_df == track_df2)
+    assert all(split_df == split_df2)
+    assert all(merge_df == merge_df2)
+
 
 @pytest.fixture(params=[2, 3, 4])
 def dist_metric(request):
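
Usage sketch (editorial addition, not part of the patch): the snippet below exercises both call paths this diff introduces, assuming this branch of laptrack is installed. The dataframe layout, the non-coordinate "z" column, and gap_closing_max_frame_count=1 mirror the tests above; everything else uses the LapTrack defaults.

    import pandas as pd

    from laptrack import LapTrack
    from laptrack import data_conversion

    # Two spots per frame; "z" stands in for any non-coordinate column
    # (e.g. an intensity measurement) that should survive the tracking.
    df = pd.DataFrame(
        {
            "frame": [0, 0, 1, 1, 2, 2],
            "x": [0.0, 1.0, 0.1, 1.1, 0.2, 1.2],
            "y": [0.0, 1.0, 0.1, 1.1, 0.2, 1.2],
            "z": [10, 20, 11, 21, 12, 22],
        }
    )
    lt = LapTrack(gap_closing_max_frame_count=1)

    # High-level path: with only_coordinate_cols=False the original columns
    # (here "z") are merged back into track_df, which is indexed by
    # ("frame", "index"); the merged-in frame column appears as "frame_y".
    track_df, split_df, merge_df = lt.predict_dataframe(
        df.copy(), ["x", "y"], only_coordinate_cols=False
    )

    # Equivalent low-level path through the new data_conversion helpers:
    # frame_index records the (frame, index) position of every original row,
    # which convert_tree_to_dataframe uses to merge the input columns back.
    coords, frame_index = data_conversion.convert_dataframe_to_coords_frame_index(
        df, ["x", "y"]
    )
    tree = lt.predict(coords)
    track_df2, split_df2, merge_df2 = data_conversion.convert_tree_to_dataframe(
        tree, dataframe=df, frame_index=frame_index
    )

Design note: as written in this diff, convert_tree_to_dataframe adds temporary "__frame"/"__index" helper columns to the dataframe passed in (and predict_dataframe with only_coordinate_cols=False hands it the caller's dataframe), which is why the sketch passes a copy where the original must stay untouched.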