Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for loading "trackless" SLEAP files #90

Merged
merged 2 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ValidPosesCSV,
ValidPoseTracks,
)
from movement.logging import log_error
from movement.logging import log_error, log_warning

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -116,13 +116,17 @@ def from_sleap_file(
when exporting ".h5" analysis files [2]_.

*movement* expects the tracks to be assigned and proofread before loading
them, meaning each track is interpreted as a single individual/animal.
them, meaning each track is interpreted as a single individual/animal. If
no tracks are found in the file, *movement* assumes that this is a
single-individual/animal track, and will assign a default individual name.
If multiple instances without tracks are present in a frame, the last
instance is selected [2]_.
Follow the SLEAP guide for tracking and proofreading [3]_.

References
----------
.. [1] https://sleap.ai/tutorials/analysis.html
.. [2] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L129-L150
.. [2] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L59
.. [3] https://sleap.ai/guides/proofreading.html

Examples
Expand Down Expand Up @@ -238,14 +242,21 @@ def _load_from_sleap_analysis_file(
tracks = f["tracks"][:].transpose((3, 0, 2, 1))
# Create an array of NaNs for the confidence scores
scores = np.full(tracks.shape[:-1], np.nan)
individual_names = [n.decode() for n in f["track_names"][:]] or None
if individual_names is None:
log_warning(
f"Could not find SLEAP Track in {file.path}. "
"Assuming single-individual dataset and assigning "
"default individual name."
niksirbi marked this conversation as resolved.
Show resolved Hide resolved
)
# If present, read the point-wise scores,
# and transpose to shape: (n_frames, n_tracks, n_keypoints)
if "point_scores" in f.keys():
scores = f["point_scores"][:].transpose((2, 0, 1))
return ValidPoseTracks(
tracks_array=tracks.astype(np.float32),
scores_array=scores.astype(np.float32),
individual_names=[n.decode() for n in f["track_names"][:]],
individual_names=individual_names,
keypoint_names=[n.decode() for n in f["node_names"][:]],
fps=fps,
)
Expand Down Expand Up @@ -274,10 +285,17 @@ def _load_from_sleap_labels_file(
file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"])
labels = read_labels(file.path.as_posix())
tracks_with_scores = _sleap_labels_to_numpy(labels)
individual_names = [track.name for track in labels.tracks] or None
if individual_names is None:
log_warning(
f"Could not find SLEAP Track in {file.path}. "
"Assuming single-individual dataset and assigning "
"default individual name."
niksirbi marked this conversation as resolved.
Show resolved Hide resolved
)
return ValidPoseTracks(
tracks_array=tracks_with_scores[:, :, :, :-1],
scores_array=tracks_with_scores[:, :, :, -1],
individual_names=[track.name for track in labels.tracks],
individual_names=individual_names,
keypoint_names=[kp.name for kp in labels.skeletons[0].nodes],
fps=fps,
)
Expand Down Expand Up @@ -309,7 +327,7 @@ def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray:

References
----------
.. [1] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L129-L150
.. [1] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L59
.. [2] https://github.com/talmolab/sleap-io
"""
# Select frames from the first video only
Expand All @@ -319,7 +337,8 @@ def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray:
first_frame = min(0, min(frame_idxs))
last_frame = max(0, max(frame_idxs))

n_tracks = len(labels.tracks)
n_tracks = len(labels.tracks) or 1 # If no tracks, assume 1 individual
individuals = labels.tracks or [None]
skeleton = labels.skeletons[-1] # Assume project only uses last skeleton
n_nodes = len(skeleton.nodes)
n_frames = int(last_frame - first_frame + 1)
Expand All @@ -329,21 +348,21 @@ def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray:
i = int(lf.frame_idx - first_frame)
user_instances = lf.user_instances
predicted_instances = lf.predicted_instances
for j, track in enumerate(labels.tracks):
for j, ind in enumerate(individuals):
user_track_instances = [
inst for inst in user_instances if inst.track == track
inst for inst in user_instances if inst.track == ind
]
predicted_track_instances = [
inst for inst in predicted_instances if inst.track == track
inst for inst in predicted_instances if inst.track == ind
]
# Use user-labelled instance if available
if user_track_instances:
inst = user_track_instances[0]
inst = user_track_instances[-1]
tracks[i, j] = np.hstack(
(inst.numpy(), np.full((n_nodes, 1), np.nan))
)
elif predicted_track_instances:
inst = predicted_track_instances[0]
inst = predicted_track_instances[-1]
tracks[i, j] = inst.numpy(scores=True)
return tracks

Expand Down
76 changes: 75 additions & 1 deletion tests/test_unit/test_load_poses.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,67 @@
import h5py
import numpy as np
import pytest
import xarray as xr
from pytest import POSE_DATA
from sleap_io.io.slp import read_labels, write_labels
from sleap_io.model.labels import LabeledFrame, Labels

from movement.io import PosesAccessor, load_poses


class TestLoadPoses:
"""Test suite for the load_poses module."""

@pytest.fixture
def sleap_slp_file_without_tracks(self, tmp_path):
"""Mock and return the path to a SLEAP .slp file without tracks."""
sleap_file = POSE_DATA.get("SLEAP_single-mouse_EPM.predictions.slp")
labels = read_labels(sleap_file)
file_path = tmp_path / "track_is_none.slp"
lfs = []
for lf in labels.labeled_frames:
instances = []
for inst in lf.instances:
inst.track = None
inst.tracking_score = 0
instances.append(inst)
lfs.append(
LabeledFrame(
video=lf.video, frame_idx=lf.frame_idx, instances=instances
)
)
write_labels(
file_path,
Labels(
labeled_frames=lfs,
videos=labels.videos,
skeletons=labels.skeletons,
),
)
return file_path

@pytest.fixture
def sleap_h5_file_without_tracks(self, tmp_path):
"""Mock and return the path to a SLEAP .h5 file without tracks."""
sleap_file = POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5")
file_path = tmp_path / "track_is_none.h5"
with h5py.File(sleap_file, "r") as f1, h5py.File(file_path, "w") as f2:
for key in list(f1.keys()):
if key == "track_names":
f2.create_dataset(key, data=[])
else:
f1.copy(key, f2, name=key)
return file_path

@pytest.fixture(
params=[
"sleap_h5_file_without_tracks",
"sleap_slp_file_without_tracks",
]
)
def sleap_file_without_tracks(self, request):
return request.getfixturevalue(request.param)

def assert_dataset(
self, dataset, file_path=None, expected_source_software=None
):
Expand Down Expand Up @@ -42,12 +95,33 @@ def assert_dataset(
)
assert dataset.fps is None

def test_load_from_slp_file(self, sleap_file):
def test_load_from_sleap_file(self, sleap_file):
"""Test that loading pose tracks from valid SLEAP files
returns a proper Dataset."""
ds = load_poses.from_sleap_file(sleap_file)
self.assert_dataset(ds, sleap_file, "SLEAP")

def test_load_from_sleap_file_without_tracks(
self, sleap_file_without_tracks
):
"""Test that loading pose tracks from valid SLEAP files
with tracks removed returns a dataset that matches the
original file, except for the individual names which are
set to default."""
ds_from_trackless = load_poses.from_sleap_file(
sleap_file_without_tracks
)
ds_from_tracked = load_poses.from_sleap_file(
POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5")
)
# Check if the "individuals" coordinate matches
# the assigned default "individuals_0"
assert ds_from_trackless.individuals == ["individual_0"]
xr.testing.assert_allclose(
ds_from_trackless.drop("individuals"),
ds_from_tracked.drop("individuals"),
)

@pytest.mark.parametrize(
"slp_file, h5_file",
[
Expand Down