Skip to content

Commit

Permalink
Add support for loading "trackless" SLEAP files (#90)
Browse files Browse the repository at this point in the history
* Add support for loading "trackless" SLEAP files

* Log warning when SLEAP tracks not found
  • Loading branch information
lochhh authored Nov 24, 2023
1 parent 7e12f08 commit 88b23de
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 13 deletions.
43 changes: 31 additions & 12 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ValidPosesCSV,
ValidPoseTracks,
)
from movement.logging import log_error
from movement.logging import log_error, log_warning

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -116,13 +116,17 @@ def from_sleap_file(
when exporting ".h5" analysis files [2]_.
*movement* expects the tracks to be assigned and proofread before loading
them, meaning each track is interpreted as a single individual/animal.
them, meaning each track is interpreted as a single individual/animal. If
no tracks are found in the file, *movement* assumes that this is a
single-individual/animal track, and will assign a default individual name.
If multiple instances without tracks are present in a frame, the last
instance is selected [2]_.
Follow the SLEAP guide for tracking and proofreading [3]_.
References
----------
.. [1] https://sleap.ai/tutorials/analysis.html
.. [2] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L129-L150
.. [2] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L59
.. [3] https://sleap.ai/guides/proofreading.html
Examples
Expand Down Expand Up @@ -238,14 +242,21 @@ def _load_from_sleap_analysis_file(
tracks = f["tracks"][:].transpose((3, 0, 2, 1))
# Create an array of NaNs for the confidence scores
scores = np.full(tracks.shape[:-1], np.nan)
individual_names = [n.decode() for n in f["track_names"][:]] or None
if individual_names is None:
log_warning(
f"Could not find SLEAP Track in {file.path}. "
"Assuming single-individual dataset and assigning "
"default individual name."
)
# If present, read the point-wise scores,
# and transpose to shape: (n_frames, n_tracks, n_keypoints)
if "point_scores" in f.keys():
scores = f["point_scores"][:].transpose((2, 0, 1))
return ValidPoseTracks(
tracks_array=tracks.astype(np.float32),
scores_array=scores.astype(np.float32),
individual_names=[n.decode() for n in f["track_names"][:]],
individual_names=individual_names,
keypoint_names=[n.decode() for n in f["node_names"][:]],
fps=fps,
)
Expand Down Expand Up @@ -274,10 +285,17 @@ def _load_from_sleap_labels_file(
file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"])
labels = read_labels(file.path.as_posix())
tracks_with_scores = _sleap_labels_to_numpy(labels)
individual_names = [track.name for track in labels.tracks] or None
if individual_names is None:
log_warning(
f"Could not find SLEAP Track in {file.path}. "
"Assuming single-individual dataset and assigning "
"default individual name."
)
return ValidPoseTracks(
tracks_array=tracks_with_scores[:, :, :, :-1],
scores_array=tracks_with_scores[:, :, :, -1],
individual_names=[track.name for track in labels.tracks],
individual_names=individual_names,
keypoint_names=[kp.name for kp in labels.skeletons[0].nodes],
fps=fps,
)
Expand Down Expand Up @@ -309,7 +327,7 @@ def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray:
References
----------
.. [1] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L129-L150
.. [1] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L59
.. [2] https://github.com/talmolab/sleap-io
"""
# Select frames from the first video only
Expand All @@ -319,7 +337,8 @@ def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray:
first_frame = min(0, min(frame_idxs))
last_frame = max(0, max(frame_idxs))

n_tracks = len(labels.tracks)
n_tracks = len(labels.tracks) or 1 # If no tracks, assume 1 individual
individuals = labels.tracks or [None]
skeleton = labels.skeletons[-1] # Assume project only uses last skeleton
n_nodes = len(skeleton.nodes)
n_frames = int(last_frame - first_frame + 1)
Expand All @@ -329,21 +348,21 @@ def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray:
i = int(lf.frame_idx - first_frame)
user_instances = lf.user_instances
predicted_instances = lf.predicted_instances
for j, track in enumerate(labels.tracks):
for j, ind in enumerate(individuals):
user_track_instances = [
inst for inst in user_instances if inst.track == track
inst for inst in user_instances if inst.track == ind
]
predicted_track_instances = [
inst for inst in predicted_instances if inst.track == track
inst for inst in predicted_instances if inst.track == ind
]
# Use user-labelled instance if available
if user_track_instances:
inst = user_track_instances[0]
inst = user_track_instances[-1]
tracks[i, j] = np.hstack(
(inst.numpy(), np.full((n_nodes, 1), np.nan))
)
elif predicted_track_instances:
inst = predicted_track_instances[0]
inst = predicted_track_instances[-1]
tracks[i, j] = inst.numpy(scores=True)
return tracks

Expand Down
76 changes: 75 additions & 1 deletion tests/test_unit/test_load_poses.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,67 @@
import h5py
import numpy as np
import pytest
import xarray as xr
from pytest import POSE_DATA
from sleap_io.io.slp import read_labels, write_labels
from sleap_io.model.labels import LabeledFrame, Labels

from movement.io import PosesAccessor, load_poses


class TestLoadPoses:
"""Test suite for the load_poses module."""

@pytest.fixture
def sleap_slp_file_without_tracks(self, tmp_path):
"""Mock and return the path to a SLEAP .slp file without tracks."""
sleap_file = POSE_DATA.get("SLEAP_single-mouse_EPM.predictions.slp")
labels = read_labels(sleap_file)
file_path = tmp_path / "track_is_none.slp"
lfs = []
for lf in labels.labeled_frames:
instances = []
for inst in lf.instances:
inst.track = None
inst.tracking_score = 0
instances.append(inst)
lfs.append(
LabeledFrame(
video=lf.video, frame_idx=lf.frame_idx, instances=instances
)
)
write_labels(
file_path,
Labels(
labeled_frames=lfs,
videos=labels.videos,
skeletons=labels.skeletons,
),
)
return file_path

@pytest.fixture
def sleap_h5_file_without_tracks(self, tmp_path):
"""Mock and return the path to a SLEAP .h5 file without tracks."""
sleap_file = POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5")
file_path = tmp_path / "track_is_none.h5"
with h5py.File(sleap_file, "r") as f1, h5py.File(file_path, "w") as f2:
for key in list(f1.keys()):
if key == "track_names":
f2.create_dataset(key, data=[])
else:
f1.copy(key, f2, name=key)
return file_path

@pytest.fixture(
params=[
"sleap_h5_file_without_tracks",
"sleap_slp_file_without_tracks",
]
)
def sleap_file_without_tracks(self, request):
return request.getfixturevalue(request.param)

def assert_dataset(
self, dataset, file_path=None, expected_source_software=None
):
Expand Down Expand Up @@ -42,12 +95,33 @@ def assert_dataset(
)
assert dataset.fps is None

def test_load_from_slp_file(self, sleap_file):
def test_load_from_sleap_file(self, sleap_file):
"""Test that loading pose tracks from valid SLEAP files
returns a proper Dataset."""
ds = load_poses.from_sleap_file(sleap_file)
self.assert_dataset(ds, sleap_file, "SLEAP")

def test_load_from_sleap_file_without_tracks(
self, sleap_file_without_tracks
):
"""Test that loading pose tracks from valid SLEAP files
with tracks removed returns a dataset that matches the
original file, except for the individual names which are
set to default."""
ds_from_trackless = load_poses.from_sleap_file(
sleap_file_without_tracks
)
ds_from_tracked = load_poses.from_sleap_file(
POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5")
)
# Check if the "individuals" coordinate matches
# the assigned default "individuals_0"
assert ds_from_trackless.individuals == ["individual_0"]
xr.testing.assert_allclose(
ds_from_trackless.drop("individuals"),
ds_from_tracked.drop("individuals"),
)

@pytest.mark.parametrize(
"slp_file, h5_file",
[
Expand Down

0 comments on commit 88b23de

Please sign in to comment.