From 1f2338c5244b08c9475136aa6df279b2cb7a5413 Mon Sep 17 00:00:00 2001 From: Dhruv <49231411+DhruvSkyy@users.noreply.github.com> Date: Thu, 16 Nov 2023 18:45:18 +0000 Subject: [PATCH] Save DLC multi-animal pose tracks to single-animal files (#83) Co-authored-by: niksirbi --- movement/io/save_poses.py | 230 ++++++++++++++++++++++------- tests/conftest.py | 17 ++- tests/test_integration/test_io.py | 10 +- tests/test_unit/test_save_poses.py | 109 +++++++++++++- 4 files changed, 306 insertions(+), 60 deletions(-) diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py index e1be3f70..81e74b2a 100644 --- a/movement/io/save_poses.py +++ b/movement/io/save_poses.py @@ -1,50 +1,33 @@ import logging from pathlib import Path -from typing import Union +from typing import Literal, Union import numpy as np import pandas as pd import xarray as xr from movement.io.validators import ValidFile +from movement.logging import log_error logger = logging.getLogger(__name__) -def to_dlc_df(ds: xr.Dataset) -> pd.DataFrame: - """Convert an xarray dataset containing pose tracks into a - DeepLabCut-style pandas DataFrame with multi-index columns. +def _xarray_to_dlc_df(ds: xr.Dataset, columns: pd.MultiIndex) -> pd.DataFrame: + """Takes an xarray dataset and DLC-style multi-index columns and outputs + a pandas dataframe. Parameters ---------- - ds : xarray Dataset + ds : xarray.Dataset Dataset containing pose tracks, confidence scores, and metadata. + columns : pandas.MultiIndex + DLC-style multi-index columns Returns ------- - pandas DataFrame - - Notes - ----- - The DataFrame will have a multi-index column with the following levels: - "scorer", "individuals", "bodyparts", "coords" (even if there is only - one individual present). Regardless of the provenance of the - points-wise confidence scores, they will be referred to as - "likelihood", and stored in the "coords" level (as DeepLabCut expects). 
- - See Also - -------- - to_dlc_file : Save the xarray dataset containing pose tracks directly - to a DeepLabCut-style ".h5" or ".csv" file. + pandas.DataFrame """ - if not isinstance(ds, xr.Dataset): - error_msg = f"Expected an xarray Dataset, but got {type(ds)}. " - logger.error(error_msg) - raise ValueError(error_msg) - - ds.poses.validate() # validate the dataset - # Concatenate the pose tracks and confidence scores into one array tracks_with_scores = np.concatenate( ( @@ -54,44 +37,164 @@ def to_dlc_df(ds: xr.Dataset) -> pd.DataFrame: axis=-1, ) - # Create the DLC-style multi-index columns - # Use the DLC terminology: scorer, individuals, bodyparts, coords - scorer = ["movement"] - individuals = ds.coords["individuals"].data.tolist() - bodyparts = ds.coords["keypoints"].data.tolist() - # The confidence scores in DLC are referred to as "likelihood" - coords = ds.coords["space"].data.tolist() + ["likelihood"] - - index_levels = ["scorer", "individuals", "bodyparts", "coords"] - columns = pd.MultiIndex.from_product( - [scorer, individuals, bodyparts, coords], names=index_levels - ) + # Create DataFrame with multi-index columns df = pd.DataFrame( data=tracks_with_scores.reshape(ds.dims["time"], -1), index=np.arange(ds.dims["time"], dtype=int), columns=columns, dtype=float, ) - logger.info("Converted PoseTracks dataset to DLC-style DataFrame.") + return df -def to_dlc_file(ds: xr.Dataset, file_path: Union[str, Path]) -> None: +def _auto_split_individuals(ds: xr.Dataset) -> bool: + """Returns True if there is only one individual in the dataset, + else returns False.""" + + n_individuals = ds.sizes["individuals"] + return True if n_individuals == 1 else False + + +def _save_dlc_df(filepath: Path, df: pd.DataFrame) -> None: + """Given a filepath, will save the dataframe as either a .h5 or .csv. + + Parameters + ---------- + filepath : pathlib.Path + Path of the file to save the dataframe to. The file extension + must be either .h5 (recommended) or .csv. 
+ df : pandas.DataFrame + Pandas Dataframe to save + """ + + if filepath.suffix == ".csv": + df.to_csv(filepath, sep=",") + else: # at this point it can only be .h5 (because of validation) + df.to_hdf(filepath, key="df_with_missing") + + +def to_dlc_df( + ds: xr.Dataset, split_individuals: bool = False +) -> Union[pd.DataFrame, dict[str, pd.DataFrame]]: + """Convert an xarray dataset containing pose tracks into a single + DeepLabCut-style pandas DataFrame or a dictionary of DataFrames + per individual, depending on the 'split_individuals' argument. + + Parameters + ---------- + ds : xarray.Dataset + Dataset containing pose tracks, confidence scores, and metadata. + split_individuals : bool, optional + If True, return a dictionary of pandas DataFrames per individual, + with individual names as keys and DataFrames as values. + If False, return a single pandas DataFrame for all individuals. + Default is False. + + Returns + ------- + pandas.DataFrame or dict + DeepLabCut-style pandas DataFrame or dictionary of DataFrames. + + Notes + ----- + The DataFrame(s) will have a multi-index column with the following levels: + "scorer", "bodyparts", "coords" (if split_individuals is True), + or "scorer", "individuals", "bodyparts", "coords" + (if split_individuals is False). + + Regardless of the provenance of the points-wise confidence scores, + they will be referred to as "likelihood", and stored in + the "coords" level (as DeepLabCut expects). + + See Also + -------- + to_dlc_file : Save the xarray dataset containing pose tracks directly + to a DeepLabCut-style .h5 or .csv file. + """ + if not isinstance(ds, xr.Dataset): + raise log_error( + ValueError, f"Expected an xarray Dataset, but got {type(ds)}." 
+ ) + + ds.poses.validate() # validate the dataset + + scorer = ["movement"] + bodyparts = ds.coords["keypoints"].data.tolist() + coords = ds.coords["space"].data.tolist() + ["likelihood"] + individuals = ds.coords["individuals"].data.tolist() + + if split_individuals: + df_dict = {} + + for individual in individuals: + individual_data = ds.sel(individuals=individual) + + index_levels = ["scorer", "bodyparts", "coords"] + columns = pd.MultiIndex.from_product( + [scorer, bodyparts, coords], names=index_levels + ) + + df = _xarray_to_dlc_df(individual_data, columns) + df_dict[individual] = df + + logger.info( + "Converted PoseTracks dataset to DeepLabCut-style DataFrames " + "per individual." + ) + return df_dict + else: + index_levels = ["scorer", "individuals", "bodyparts", "coords"] + columns = pd.MultiIndex.from_product( + [scorer, individuals, bodyparts, coords], names=index_levels + ) + + df_all = _xarray_to_dlc_df(ds, columns) + + logger.info("Converted PoseTracks dataset to DLC-style DataFrame.") + return df_all + + +def to_dlc_file( + ds: xr.Dataset, + file_path: Union[str, Path], + split_individuals: Union[bool, Literal["auto"]] = "auto", +) -> None: """Save the xarray dataset containing pose tracks to a - DeepLabCut-style ".h5" or ".csv" file. + DeepLabCut-style .h5 or .csv file. Parameters ---------- - ds : xarray Dataset + ds : xarray.Dataset Dataset containing pose tracks, confidence scores, and metadata. - file_path : pathlib Path or str + file_path : pathlib.Path or str Path to the file to save the DLC poses to. The file extension - must be either ".h5" (recommended) or ".csv". + must be either .h5 (recommended) or .csv. + split_individuals : bool or "auto", optional + If True, each individual will be saved to a separate file, + formatted as in a single-animal DeepLabCut project - i.e. without + the "individuals" column level. The individual's name will be appended + to the file path, just before the file extension, i.e. + "/path/to/filename_individual1.h5". 
+ If False, all individuals will be saved to the same file, + formatted as in a multi-animal DeepLabCut project - i.e. the columns + will include the "individuals" level. The file path will not be + modified. + If "auto", the argument's value will be determined based on the number of + individuals in the dataset: True if there is only one, and + False if there are more than one. This is the default. See Also -------- - to_dlc_df : Convert an xarray dataset containing pose tracks into a - DeepLabCut-style pandas DataFrame with multi-index columns. + to_dlc_df : Convert an xarray dataset containing pose tracks into a single + DeepLabCut-style pandas DataFrame or a dictionary of DataFrames + per individual. + + Examples + -------- + >>> from movement.io import save_poses, load_poses + >>> ds = load_poses.from_sleap_file("/path/to/file_sleap.analysis.h5") + >>> save_poses.to_dlc_file(ds, "/path/to/file_dlc.h5") """ try: @@ -104,9 +207,32 @@ def to_dlc_file(ds: xr.Dataset, file_path: Union[str, Path]) -> None: logger.error(error) raise error - df = to_dlc_df(ds) # convert to pandas DataFrame - if file.path.suffix == ".csv": - df.to_csv(file.path, sep=",") - else: # file.path.suffix == ".h5" - df.to_hdf(file.path, key="df_with_missing") - logger.info(f"Saved PoseTracks dataset to {file.path}.") + # Sets default behaviour for the function + if split_individuals == "auto": + split_individuals = _auto_split_individuals(ds) + + elif not isinstance(split_individuals, bool): + raise log_error( + ValueError, + "Expected 'split_individuals' to be a boolean or 'auto', but got " + f"{type(split_individuals)}.", + ) + + if split_individuals: + # split the dataset into a dictionary of dataframes per individual + df_dict = to_dlc_df(ds, split_individuals=True) + + for key, df in df_dict.items(): + # the key is the individual's name + filepath = f"{file.path.with_suffix('')}_{key}{file.path.suffix}" + if isinstance(df, pd.DataFrame): + _save_dlc_df(Path(filepath), df) + logger.info( + f"Saved 
PoseTracks data for individual {key} to {file.path}." + ) + else: + # convert the dataset to a single dataframe for all individuals + df_all = to_dlc_df(ds, split_individuals=False) + if isinstance(df_all, pd.DataFrame): + _save_dlc_df(file.path, df_all) + logger.info(f"Saved PoseTracks dataset to {file.path}.") diff --git a/tests/conftest.py b/tests/conftest.py index bc7a47e3..59716a9a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -202,10 +202,19 @@ def _valid_tracks_array(array_type): @pytest.fixture -def valid_pose_dataset(valid_tracks_array): +def valid_pose_dataset(valid_tracks_array, request): """Return a valid pose tracks dataset.""" dim_names = PosesAccessor.dim_names - tracks_array = valid_tracks_array("multi_track_array") + + # create a multi_track_array by default unless overridden via param + try: + array_format = request.param + except AttributeError: + array_format = "multi_track_array" + + tracks_array = valid_tracks_array(array_format) + n_individuals, n_keypoints = tracks_array.shape[1:3] + return xr.Dataset( data_vars={ "pose_tracks": xr.DataArray(tracks_array, dims=dim_names), @@ -216,8 +225,8 @@ def valid_pose_dataset(valid_tracks_array): }, coords={ "time": np.arange(tracks_array.shape[0]), - "individuals": ["ind1", "ind2"], - "keypoints": ["key1", "key2"], + "individuals": [f"ind{i}" for i in range(1, n_individuals + 1)], + "keypoints": [f"key{i}" for i in range(1, n_keypoints + 1)], "space": ["x", "y"], }, attrs={ diff --git a/tests/test_integration/test_io.py b/tests/test_integration/test_io.py index 2070ae9e..54527892 100644 --- a/tests/test_integration/test_io.py +++ b/tests/test_integration/test_io.py @@ -17,13 +17,15 @@ def test_load_and_save_to_dlc_df(self, dlc_style_df): """Test that loading pose tracks from a DLC-style DataFrame and converting back to a DataFrame returns the same data values.""" ds = load_poses.from_dlc_df(dlc_style_df) - df = save_poses.to_dlc_df(ds) + df = save_poses.to_dlc_df(ds, split_individuals=False) 
np.testing.assert_allclose(df.values, dlc_style_df.values) def test_save_and_load_dlc_file(self, dlc_output_file, valid_pose_dataset): """Test that saving pose tracks to DLC .h5 and .csv files and then loading them back in returns the same Dataset.""" - save_poses.to_dlc_file(valid_pose_dataset, dlc_output_file) + save_poses.to_dlc_file( + valid_pose_dataset, dlc_output_file, split_individuals=False + ) ds = load_poses.from_dlc_file(dlc_output_file) xr.testing.assert_allclose(ds, valid_pose_dataset) @@ -32,6 +34,8 @@ def test_convert_sleap_to_dlc_file(self, sleap_file, dlc_output_file): when converted to DLC .h5 and .csv files and re-loaded return the same Datasets.""" sleap_ds = load_poses.from_sleap_file(sleap_file) - save_poses.to_dlc_file(sleap_ds, dlc_output_file) + save_poses.to_dlc_file( + sleap_ds, dlc_output_file, split_individuals=False + ) dlc_ds = load_poses.from_dlc_file(dlc_output_file) xr.testing.assert_allclose(sleap_ds, dlc_ds) diff --git a/tests/test_unit/test_save_poses.py b/tests/test_unit/test_save_poses.py index 2f086b4a..31b31cb2 100644 --- a/tests/test_unit/test_save_poses.py +++ b/tests/test_unit/test_save_poses.py @@ -1,4 +1,5 @@ from contextlib import nullcontext as does_not_raise +from pathlib import Path import numpy as np import pandas as pd @@ -83,7 +84,7 @@ def test_to_dlc_df(self, ds, expected_exception): """Test that converting a valid/invalid xarray dataset to a DeepLabCut-style pandas DataFrame returns the expected result.""" with expected_exception as e: - df = save_poses.to_dlc_df(ds) + df = save_poses.to_dlc_df(ds, split_individuals=False) if e is None: # valid input assert isinstance(df, pd.DataFrame) assert isinstance(df.columns, pd.MultiIndex) @@ -141,4 +142,110 @@ def test_to_dlc_file_invalid_dataset( save_poses.to_dlc_file( request.getfixturevalue(invalid_pose_dataset), tmp_path / "test.h5", + split_individuals=False, ) + + @pytest.mark.parametrize( + "valid_pose_dataset, split_value", + [("single_track_array", True), 
("multi_track_array", False)], + indirect=["valid_pose_dataset"], + ) + def test_auto_split_individuals(self, valid_pose_dataset, split_value): + """Test that setting 'split_individuals' to 'auto' yields True + for single-individual datasets and False for multi-individual ones.""" + assert ( + save_poses._auto_split_individuals(valid_pose_dataset) + == split_value + ) + + @pytest.mark.parametrize( + "valid_pose_dataset, split_individuals", + [ + ("single_track_array", True), # single-individual, split + ("multi_track_array", False), # multi-individual, no split + ("single_track_array", False), # single-individual, no split + ("multi_track_array", True), # multi-individual, split + ], + indirect=["valid_pose_dataset"], + ) + def test_to_dlc_df_split_individuals( + self, + valid_pose_dataset, + split_individuals, + request, + ): + """Test that the 'split_individuals' argument affects the behaviour + of the 'to_dlc_df` function as expected + """ + df = save_poses.to_dlc_df(valid_pose_dataset, split_individuals) + # Get the names of the individuals in the dataset + ds = request.getfixturevalue("valid_pose_dataset") + ind_names = ds.individuals.values + + if split_individuals is False: + # this should produce a single df in multi-animal DLC format + assert isinstance(df, pd.DataFrame) + assert df.columns.names == [ + "scorer", + "individuals", + "bodyparts", + "coords", + ] + assert all( + [ind in df.columns.get_level_values("individuals")] + for ind in ind_names + ) + elif split_individuals is True: + # this should produce a dict of dfs in single-animal DLC format + assert isinstance(df, dict) + for ind in ind_names: + assert ind in df.keys() + assert isinstance(df[ind], pd.DataFrame) + assert df[ind].columns.names == [ + "scorer", + "bodyparts", + "coords", + ] + + @pytest.mark.parametrize( + "split_individuals, expected_exception", + [ + (True, does_not_raise()), + (False, does_not_raise()), + ("auto", does_not_raise()), + ("1", pytest.raises(ValueError, 
match="boolean or 'auto'")), + ], + ) + def test_to_dlc_file_split_individuals( + self, + valid_pose_dataset, + new_dlc_h5_file, + split_individuals, + expected_exception, + request, + ): + """Test that the 'split_individuals' argument affects the behaviour + of the 'to_dlc_file` function as expected + """ + + with expected_exception: + save_poses.to_dlc_file( + valid_pose_dataset, + new_dlc_h5_file, + split_individuals, + ) + ds = request.getfixturevalue("valid_pose_dataset") + + # "auto" becomes False, default valid dataset is multi-individual + if split_individuals in [False, "auto"]: + # this should save only one file + assert new_dlc_h5_file.is_file() + new_dlc_h5_file.unlink() + elif split_individuals is True: + # this should save one file per individual + for ind in ds.individuals.values: + file_path_ind = Path( + f"{new_dlc_h5_file.with_suffix('')}_{ind}.h5" + ) + assert file_path_ind.is_file() + file_path_ind.unlink()