Skip to content

Commit

Permalink
Draft export to sleap h5
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Nov 21, 2023
1 parent 848d1a1 commit cc11124
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 56 deletions.
105 changes: 104 additions & 1 deletion movement/io/save_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
from typing import Literal, Union

import h5py
import numpy as np
import pandas as pd
import xarray as xr
Expand Down Expand Up @@ -112,6 +113,7 @@ def to_dlc_df(
to_dlc_file : Save the xarray dataset containing pose tracks directly
to a DeepLabCut-style .h5 or .csv file.
"""

_validate_dataset(ds)
scorer = ["movement"]
bodyparts = ds.coords["keypoints"].data.tolist()
Expand Down Expand Up @@ -187,7 +189,7 @@ def to_dlc_file(
Examples
--------
>>> from movement.io import save_poses, load_poses
>>> ds = load_poses.from_sleap("/path/to/file_sleap.analysis.h5")
>>> ds = load_poses.from_sleap_file("/path/to/file_sleap.analysis.h5")
>>> save_poses.to_dlc_file(ds, "/path/to/file_dlc.h5")
"""

Expand Down Expand Up @@ -224,6 +226,105 @@ def to_dlc_file(
logger.info(f"Saved PoseTracks dataset to {file.path}.")


def to_sleap_analysis_file(
    ds: xr.Dataset, file_path: Union[str, Path]
) -> None:
    """Save the xarray dataset containing pose tracks to a SLEAP-style
    .h5 analysis file.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing pose tracks, confidence scores, and metadata.
    file_path : pathlib.Path or str
        Path to the file to save the poses to. The file extension
        must be .h5.

    Examples
    --------
    >>> from movement.io import save_poses, load_poses
    >>> ds = load_poses.from_dlc_file("path/to/file.h5")
    >>> save_poses.to_sleap_analysis_file(
    ...     ds, "/path/to/file_sleap.analysis.h5"
    ... )
    """

    file = _validate_file_path(file_path=file_path, expected_suffix=[".h5"])
    _validate_dataset(ds)

    # Drop individuals whose data is entirely missing; SLEAP analysis
    # files only list occupied tracks.
    ds = _remove_unoccupied_tracks(ds)

    # Target shapes:
    # "track_occupancy" n_frames * n_individuals
    # "tracks" n_individuals * n_space * n_keypoints * n_frames
    # "track_names" n_individuals
    # "point_scores" n_individuals * n_keypoints * n_frames
    # "instance_scores" n_individuals * n_frames
    # "tracking_scores" n_individuals * n_frames
    individual_names = ds.individuals.values.tolist()
    n_individuals = len(individual_names)
    keypoint_names = ds.keypoints.values.tolist()
    # Compute frame indices from fps, if set
    if ds.fps is not None:
        frame_idxs = np.rint(ds.time.values * ds.fps).astype(int).tolist()
    else:
        frame_idxs = ds.time.values.astype(int).tolist()
    # NOTE(review): n_frames is derived from the first/last frame index;
    # if the time axis is non-contiguous this may disagree with
    # tracks.shape[-1] below — confirm upstream guarantees contiguity.
    n_frames = frame_idxs[-1] - frame_idxs[0] + 1
    pos_x = ds.pose_tracks.sel(space="x").values
    # Mask denoting which individuals are present in each frame
    track_occupancy = (~np.all(np.isnan(pos_x), axis=2)).astype(int)
    # Reorder (time, individuals, keypoints, space) axes to match the
    # SLEAP target shapes listed above.
    tracks = np.transpose(ds.pose_tracks.data, (1, 3, 2, 0))
    point_scores = np.transpose(ds.confidence.data, (1, 2, 0))
    # Instance/tracking scores are not stored in the dataset; fill with NaN.
    instance_scores = np.full((n_individuals, n_frames), np.nan, dtype=float)
    tracking_scores = np.full((n_individuals, n_frames), np.nan, dtype=float)

    data_dict = dict(
        track_names=individual_names,
        node_names=keypoint_names,
        tracks=tracks,
        track_occupancy=track_occupancy,
        point_scores=point_scores,
        instance_scores=instance_scores,
        tracking_scores=tracking_scores,
        labels_path=ds.source_file,
        edge_names=[],
        edge_inds=[],
        video_path="",
        video_ind=0,
        provenance="{}",
    )
    with h5py.File(file.path, "w") as f:
        for key, val in data_dict.items():
            # Only array-valued entries benefit from (and accept) gzip
            # chunked compression; scalars/strings/lists are written as-is.
            if isinstance(val, np.ndarray):
                f.create_dataset(
                    key,
                    data=val,
                    compression="gzip",
                    compression_opts=9,
                )
            else:
                f.create_dataset(key, data=val)
    logger.info(f"Saved PoseTracks dataset to {file.path}.")


def _remove_unoccupied_tracks(ds: xr.Dataset):
    """Drop individuals ("tracks") whose pose data is entirely missing.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing pose tracks, confidence scores, and metadata.

    Returns
    -------
    xarray.Dataset
        The input dataset without the unoccupied tracks.
    """

    # An individual is occupied if at least one pose value is present
    # anywhere across the keypoints, space, and time dimensions.
    occupied = ds.pose_tracks.notnull().any(
        dim=["keypoints", "space", "time"]
    )
    return ds.where(occupied, drop=True)


def _validate_file_path(
file_path: Union[str, Path], expected_suffix: list[str]
) -> ValidFile:
Expand All @@ -250,6 +351,7 @@ def _validate_file_path(
ValueError
If the file does not have the expected suffix.
"""

try:
file = ValidFile(
file_path,
Expand All @@ -275,6 +377,7 @@ def _validate_dataset(ds: xr.Dataset) -> None:
ValueError
If `ds` is not an xarray Dataset with valid PoseTracks.
"""

if not isinstance(ds, xr.Dataset):
raise log_error(
ValueError, f"Expected an xarray Dataset, but got {type(ds)}."
Expand Down
Loading

0 comments on commit cc11124

Please sign in to comment.