From cc11124ce5a7f58b3a89461f13f3708928d98d04 Mon Sep 17 00:00:00 2001 From: lochhh Date: Tue, 21 Nov 2023 17:17:26 +0000 Subject: [PATCH] Draft export to sleap h5 --- movement/io/save_poses.py | 105 +++++++++++++++- tests/test_unit/test_save_poses.py | 190 ++++++++++++++++++++--------- 2 files changed, 239 insertions(+), 56 deletions(-) diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py index 3437dc51e..c6f9385a0 100644 --- a/movement/io/save_poses.py +++ b/movement/io/save_poses.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Literal, Union +import h5py import numpy as np import pandas as pd import xarray as xr @@ -112,6 +113,7 @@ def to_dlc_df( to_dlc_file : Save the xarray dataset containing pose tracks directly to a DeepLabCut-style .h5 or .csv file. """ + _validate_dataset(ds) scorer = ["movement"] bodyparts = ds.coords["keypoints"].data.tolist() @@ -187,7 +189,7 @@ def to_dlc_file( Examples -------- >>> from movement.io import save_poses, load_poses - >>> ds = load_poses.from_sleap("/path/to/file_sleap.analysis.h5") + >>> ds = load_poses.from_sleap_file("/path/to/file_sleap.analysis.h5") >>> save_poses.to_dlc_file(ds, "/path/to/file_dlc.h5") """ @@ -224,6 +226,105 @@ def to_dlc_file( logger.info(f"Saved PoseTracks dataset to {file.path}.") +def to_sleap_analysis_file( + ds: xr.Dataset, file_path: Union[str, Path] +) -> None: + """Save the xarray dataset containing pose tracks to a SLEAP-style + .h5 analysis file. + + Parameters + ---------- + ds : xarray.Dataset + Dataset containing pose tracks, confidence scores, and metadata. + file_path : pathlib.Path or str + Path to the file to save the poses to. The file extension + must be .h5 (recommended) or .csv. + + Examples + -------- + >>> from movement.io import save_poses, load_poses + >>> ds = load_poses.from_dlc_file("path/to/file.h5") + >>> save_poses.to_sleap_analysis_file( + ... ds, "/path/to/file_sleap.analysis.h5" + ... ) + """ + + file = _validate_file_path(file_path=file_path, expected_suffix=[".h5"]) + _validate_dataset(ds) + + ds = _remove_unoccupied_tracks(ds) + + # Target shapes: + # "track_occupancy" n_frames * n_individuals + # "tracks" n_individuals * n_space * n_keypoints * n_frames + # "track_names" n_individuals + # "point_scores" n_individuals * n_keypoints * n_frames + # "instance_scores" n_individuals * n_frames + # "tracking_scores" n_individuals * n_frames + individual_names = ds.individuals.values.tolist() + n_individuals = len(individual_names) + keypoint_names = ds.keypoints.values.tolist() + # Compute frame indices from fps, if set + if ds.fps is not None: + frame_idxs = np.rint(ds.time.values * ds.fps).astype(int).tolist() + else: + frame_idxs = ds.time.values.astype(int).tolist() + n_frames = frame_idxs[-1] - frame_idxs[0] + 1 + pos_x = ds.pose_tracks.sel(space="x").values + # Mask denoting which individuals are present in each frame + track_occupancy = (~np.all(np.isnan(pos_x), axis=2)).astype(int) + tracks = np.transpose(ds.pose_tracks.data, (1, 3, 2, 0)) + point_scores = np.transpose(ds.confidence.data, (1, 2, 0)) + instance_scores = np.full((n_individuals, n_frames), np.nan, dtype=float) + tracking_scores = np.full((n_individuals, n_frames), np.nan, dtype=float) + + data_dict = dict( + track_names=individual_names, + node_names=keypoint_names, + tracks=tracks, + track_occupancy=track_occupancy, + point_scores=point_scores, + instance_scores=instance_scores, + tracking_scores=tracking_scores, + labels_path=ds.source_file, + edge_names=[], + edge_inds=[], + video_path="", + video_ind=0, + provenance="{}", + ) + with h5py.File(file.path, "w") as f: + for key, val in data_dict.items(): + if isinstance(val, np.ndarray): + f.create_dataset( + key, + data=val, + compression="gzip", + compression_opts=9, + ) + else: + f.create_dataset(key, data=val) + logger.info(f"Saved PoseTracks dataset to {file.path}.") + + +def _remove_unoccupied_tracks(ds: xr.Dataset): + """Remove tracks that are completely unoccupied in the xarray dataset. + + Parameters + ---------- + ds : xarray.Dataset + Dataset containing pose tracks, confidence scores, and metadata. + + Returns + ------- + xarray.Dataset + The input dataset without the unoccupied tracks. + """ + + all_nan = ds.pose_tracks.isnull().all(dim=["keypoints", "space", "time"]) + return ds.where(~all_nan, drop=True) + + def _validate_file_path( file_path: Union[str, Path], expected_suffix: list[str] ) -> ValidFile: @@ -250,6 +351,7 @@ def _validate_file_path( ValueError If the file does not have the expected suffix. """ + try: file = ValidFile( file_path, @@ -275,6 +377,7 @@ def _validate_dataset(ds: xr.Dataset) -> None: ValueError If `ds` is not an xarray Dataset with valid PoseTracks. """ + if not isinstance(ds, xr.Dataset): raise log_error( ValueError, f"Expected an xarray Dataset, but got {type(ds)}." diff --git a/tests/test_unit/test_save_poses.py b/tests/test_unit/test_save_poses.py index 31b31cb23..23ed84a8a 100644 --- a/tests/test_unit/test_save_poses.py +++ b/tests/test_unit/test_save_poses.py @@ -1,6 +1,7 @@ from contextlib import nullcontext as does_not_raise from pathlib import Path +import h5py import numpy as np import pandas as pd import pytest @@ -13,6 +14,47 @@ class TestSavePoses: """Test suite for the save_poses module.""" + output_files = [ + { + "file_fixture": "fake_h5_file", + "to_dlc_file_expected_exception": pytest.raises(FileExistsError), + "to_sleap_file_expected_exception": pytest.raises(FileExistsError), + # invalid file path + }, + { + "file_fixture": "directory", + "to_dlc_file_expected_exception": pytest.raises(IsADirectoryError), + "to_sleap_file_expected_exception": pytest.raises( + IsADirectoryError + ), + # invalid file path + }, + { + "file_fixture": "new_file_wrong_ext", + "to_dlc_file_expected_exception": pytest.raises(ValueError), + "to_sleap_file_expected_exception": pytest.raises(ValueError), + # invalid file path + }, + { + "file_fixture": "new_csv_file", + "to_dlc_file_expected_exception": does_not_raise(), + "to_sleap_file_expected_exception": pytest.raises(ValueError), + # valid file path for dlc, invalid for sleap + }, + { + "file_fixture": "new_h5_file", + "to_dlc_file_expected_exception": does_not_raise(), + "to_sleap_file_expected_exception": does_not_raise(), + # valid file path + }, + ] + + @pytest.fixture(params=output_files) + def output_file_params(self, request): + """Return a dictionary containing parameters for testing saving + valid pose datasets to DeepLabCut- or SLEAP-style files.""" + return request.param + @pytest.fixture def not_a_dataset(self): """Return an invalid pose tracks dataset.""" @@ -34,20 +76,31 @@ def new_file_wrong_ext(self, tmp_path): return tmp_path / "new_file_wrong_ext.txt" @pytest.fixture - def new_dlc_h5_file(self, tmp_path): - """Return the file path for a new DeepLabCut .h5 file.""" - return tmp_path / "new_dlc_file.h5" + def new_h5_file(self, tmp_path): + """Return the file path for a new .h5 file.""" + return tmp_path / "new_file.h5" @pytest.fixture - def new_dlc_csv_file(self, tmp_path): - """Return the file path for a new DeepLabCut .csv file.""" - return tmp_path / "new_dlc_file.csv" + def new_csv_file(self, tmp_path): + """Return the file path for a new .csv file.""" + return tmp_path / "new_file.csv" @pytest.fixture def missing_dim_dataset(self, valid_pose_dataset): """Return a pose tracks dataset missing a dimension.""" return valid_pose_dataset.drop_dims("time") + @pytest.fixture( + params=[ + "not_a_dataset", + "empty_dataset", + "missing_var_dataset", + "missing_dim_dataset", + ] + ) + def invalid_pose_dataset(self, request): + return request.getfixturevalue(request.param) + @pytest.mark.parametrize( "ds, expected_exception", [ @@ -95,52 +148,23 @@ def test_to_dlc_df(self, ds, expected_exception): "coords", ] - @pytest.mark.parametrize( - "file_fixture, expected_exception", - [ - ( - "fake_h5_file", - pytest.raises(FileExistsError), - ), # invalid file path - ( - "directory", - pytest.raises(IsADirectoryError), - ), # invalid file path - ( - "new_file_wrong_ext", - pytest.raises(ValueError), - ), # invalid file path - ("new_dlc_h5_file", does_not_raise()), # valid file path - ("new_dlc_csv_file", does_not_raise()), # valid file path - ], - ) def test_to_dlc_file_valid_dataset( - self, file_fixture, expected_exception, valid_pose_dataset, request + self, output_file_params, valid_pose_dataset, request ): """Test that saving a valid pose dataset to a valid/invalid DeepLabCut-style file returns the appropriate errors.""" - with expected_exception: + with output_file_params.get("to_dlc_file_expected_exception"): + file_fixture = output_file_params.get("file_fixture") val = request.getfixturevalue(file_fixture) file_path = val.get("file_path") if isinstance(val, dict) else val save_poses.to_dlc_file(valid_pose_dataset, file_path) - @pytest.mark.parametrize( - "invalid_pose_dataset", - [ - "not_a_dataset", - "empty_dataset", - "missing_var_dataset", - "missing_dim_dataset", - ], - ) - def test_to_dlc_file_invalid_dataset( - self, invalid_pose_dataset, request, tmp_path - ): + def test_to_dlc_file_invalid_dataset(self, invalid_pose_dataset, tmp_path): """Test that saving an invalid pose dataset to a valid DeepLabCut-style file returns the appropriate errors.""" with pytest.raises(ValueError): save_poses.to_dlc_file( - request.getfixturevalue(invalid_pose_dataset), + invalid_pose_dataset, tmp_path / "test.h5", split_individuals=False, ) @@ -172,16 +196,13 @@ def test_to_dlc_df_split_individuals( self, valid_pose_dataset, split_individuals, - request, ): """Test that the 'split_individuals' argument affects the behaviour of the 'to_dlc_df` function as expected """ df = save_poses.to_dlc_df(valid_pose_dataset, split_individuals) # Get the names of the individuals in the dataset - ds = request.getfixturevalue("valid_pose_dataset") - ind_names = ds.individuals.values - + ind_names = valid_pose_dataset.individuals.values if split_individuals is False: # this should produce a single df in multi-animal DLC format assert isinstance(df, pd.DataFrame) @@ -219,33 +240,92 @@ def test_to_dlc_df_split_individuals( def test_to_dlc_file_split_individuals( self, valid_pose_dataset, - new_dlc_h5_file, + new_h5_file, split_individuals, expected_exception, - request, ): """Test that the 'split_individuals' argument affects the behaviour - of the 'to_dlc_file` function as expected + of the 'to_dlc_file` function as expected. """ - with expected_exception: save_poses.to_dlc_file( valid_pose_dataset, - new_dlc_h5_file, + new_h5_file, split_individuals, ) - ds = request.getfixturevalue("valid_pose_dataset") - + # Get the names of the individuals in the dataset + ind_names = valid_pose_dataset.individuals.values # "auto" becomes False, default valid dataset is multi-individual if split_individuals in [False, "auto"]: # this should save only one file - assert new_dlc_h5_file.is_file() - new_dlc_h5_file.unlink() + assert new_h5_file.is_file() + new_h5_file.unlink() elif split_individuals is True: # this should save one file per individual - for ind in ds.individuals.values: + for ind in ind_names: file_path_ind = Path( - f"{new_dlc_h5_file.with_suffix('')}_{ind}.h5" + f"{new_h5_file.with_suffix('')}_{ind}.h5" ) assert file_path_ind.is_file() file_path_ind.unlink() + + def test_to_sleap_analysis_file_valid_dataset( + self, output_file_params, valid_pose_dataset, request + ): + """Test that saving a valid pose dataset to a valid/invalid + SLEAP-style file returns the appropriate errors.""" + with output_file_params.get("to_sleap_file_expected_exception"): + file_fixture = output_file_params.get("file_fixture") + val = request.getfixturevalue(file_fixture) + file_path = val.get("file_path") if isinstance(val, dict) else val + save_poses.to_sleap_analysis_file(valid_pose_dataset, file_path) + + def test_to_sleap_analysis_file_invalid_dataset( + self, invalid_pose_dataset, new_h5_file + ): + """Test that saving an invalid pose dataset to a valid + SLEAP-style file returns the appropriate errors.""" + with pytest.raises(ValueError): + save_poses.to_sleap_analysis_file( + invalid_pose_dataset, + new_h5_file, + ) + + @pytest.mark.parametrize( + "sleap_h5_file", + [ + "SLEAP_single-mouse_EPM.analysis.h5", + "SLEAP_three-mice_Aeon_proofread.analysis.h5", + "SLEAP_three-mice_Aeon_mixed-labels.analysis.h5", + ], + ) + def test_to_sleap_analysis_file_returns_same_h5_file_content( + self, sleap_h5_file, new_h5_file + ): + """Test that saving pose tracks from a SLEAP analysis file + to a SLEAP-style .h5 analysis file returns the same file + contents.""" + sleap_h5_file_path = POSE_DATA.get(sleap_h5_file) + ds = load_poses.from_sleap_file(sleap_h5_file_path) + save_poses.to_sleap_analysis_file(ds, new_h5_file) + + with h5py.File(sleap_h5_file_path, "r") as file_in, h5py.File( + new_h5_file, "r" + ) as file_out: + assert set(file_in.keys()) == set(file_out.keys()) + keys = [ + "track_occupancy", + "tracks", + "point_scores", + ] + for key in keys: + np.testing.assert_allclose(file_in[key][:], file_out[key][:]) + + def test_remove_unoccupied_tracks(self, valid_pose_dataset): + """Test that removing unoccupied tracks from a valid pose dataset + returns the expected result.""" + new_individuals = [f"ind{i}" for i in range(1, 4)] + # Add new individual with NaN data + ds = valid_pose_dataset.reindex(individuals=new_individuals) + ds = save_poses._remove_unoccupied_tracks(ds) + xr.testing.assert_equal(ds, valid_pose_dataset)