Skip to content

Commit

Permalink
Updated datasets.py module name, docstrings, and functions
Browse files Browse the repository at this point in the history
  • Loading branch information
b-peri committed Jan 4, 2024
1 parent 8865eba commit 687a1ed
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 18 deletions.
6 changes: 3 additions & 3 deletions examples/load_and_explore_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,22 @@
# -------
from matplotlib import pyplot as plt

from movement import datasets
from movement import sample_datasets
from movement.io import load_poses

# %%
# Fetch an example dataset
# ------------------------
# Print a list of available datasets:

for file_name in datasets.list_pose_data():
for file_name in sample_datasets.list_sample_data():
print(file_name)

# %%
# Fetch the path to an example dataset.
# Feel free to replace this with the path to your own dataset.
# e.g., ``file_path = "/path/to/my/data.h5"``)
file_path = datasets.fetch_pose_data_path(
file_path = sample_datasets.fetch_sample_data_path(
"SLEAP_three-mice_Aeon_proofread.analysis.h5"
)

Expand Down
35 changes: 22 additions & 13 deletions movement/datasets.py → movement/sample_datasets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Module for fetching and loading datasets.
"""Module for fetching and loading sample datasets.
This module provides functions for fetching and loading data used in tests,
examples, and tutorials. The data are stored in a remote repository on GIN
and are downloaded to the user's local machine the first time they are used.
This module provides functions for fetching and loading sample data used in
tests, examples, and tutorials. The data are stored in a remote repository
on GIN and are downloaded to the user's local machine the first time they
are used.
"""

from pathlib import Path
Expand Down Expand Up @@ -36,20 +37,20 @@
)

with open(METADATA_PATH, "r") as sample_info:
metadata = yaml.safe_load(sample_info)
METADATA = yaml.safe_load(sample_info)

sample_registry = {file["file_name"]: file["sha256sum"] for file in metadata}
SAMPLE_REGISTRY = {file["file_name"]: file["sha256sum"] for file in METADATA}

# Create a download manager for the pose data
POSE_DATA = pooch.create(
path=DATA_DIR / "poses",
base_url=f"{DATA_URL}/poses/",
retry_if_failed=0,
registry=sample_registry,
registry=SAMPLE_REGISTRY,
)


def list_pose_data() -> list[str]:
def list_sample_data() -> list[str]:
"""Find available sample pose data in the *movement* data repository.
Returns
Expand All @@ -59,7 +60,7 @@ def list_pose_data() -> list[str]:
return list(POSE_DATA.registry.keys())


def fetch_pose_data_path(filename: str) -> Path:
def fetch_sample_data_path(filename: str) -> Path:
"""Fetch sample pose data from the *movement* data repository.
The data are downloaded to the user's local machine the first time they are
Expand All @@ -79,7 +80,9 @@ def fetch_pose_data_path(filename: str) -> Path:
return Path(POSE_DATA.fetch(filename, progressbar=True))


def fetch_pose_data(filename: str) -> xarray.Dataset:
def fetch_sample_data(
filename: str,
) -> xarray.Dataset: # TODO: Add LightningPose
"""Fetch sample pose data from the *movement* data repository.
The data are downloaded to the user's local machine the first time they are
Expand All @@ -97,9 +100,15 @@ def fetch_pose_data(filename: str) -> xarray.Dataset:
Pose data contained in the fetched sample file.
"""

file_path = fetch_pose_data_path(filename)
if filename.startswith("SLEAP"):
file_path = fetch_sample_data_path(filename)
file_metadata = next(
file for file in METADATA if file["file_name"] == filename
)

if file_metadata["source_software"] == "SLEAP":
ds = load_poses.from_sleap_file(file_path)
elif filename.startswith("DLC"):
elif file_metadata["source_software"] == "DeepLabCut":
ds = load_poses.from_dlc_file(file_path)
elif file_metadata["source_software"] == "LightningPose":
pass
return ds
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
import pytest
import xarray as xr

from movement.datasets import fetch_pose_data_path
from movement.io import PosesAccessor
from movement.logging import configure_logging
from movement.sample_datasets import fetch_sample_data_path


def pytest_configure():
"""Perform initial configuration for pytest.
Fetches pose data file paths as a dictionary for tests."""
pytest.POSE_DATA = {
file_name: fetch_pose_data_path(file_name)
file_name: fetch_sample_data_path(file_name)
for file_name in [
"DLC_single-wasp.predictions.h5",
"DLC_single-wasp.predictions.csv",
Expand Down

0 comments on commit 687a1ed

Please sign in to comment.