From 687a1edb17689e60232eb468d02f79626e3eb519 Mon Sep 17 00:00:00 2001 From: b-peri Date: Thu, 4 Jan 2024 14:30:59 +0000 Subject: [PATCH] Updated `datasets.py` module name, docstrings, and functions --- examples/load_and_explore_poses.py | 6 ++-- movement/{datasets.py => sample_datasets.py} | 35 ++++++++++++-------- tests/conftest.py | 4 +-- 3 files changed, 27 insertions(+), 18 deletions(-) rename movement/{datasets.py => sample_datasets.py} (70%) diff --git a/examples/load_and_explore_poses.py b/examples/load_and_explore_poses.py index de5f5196e..1b71d870d 100644 --- a/examples/load_and_explore_poses.py +++ b/examples/load_and_explore_poses.py @@ -10,7 +10,7 @@ # ------- from matplotlib import pyplot as plt -from movement import datasets +from movement import sample_datasets from movement.io import load_poses # %% @@ -18,14 +18,14 @@ # ------------------------ # Print a list of available datasets: -for file_name in datasets.list_pose_data(): +for file_name in sample_datasets.list_sample_data(): print(file_name) # %% # Fetch the path to an example dataset. # Feel free to replace this with the path to your own dataset. # e.g., ``file_path = "/path/to/my/data.h5"``) -file_path = datasets.fetch_pose_data_path( +file_path = sample_datasets.fetch_sample_data_path( "SLEAP_three-mice_Aeon_proofread.analysis.h5" ) diff --git a/movement/datasets.py b/movement/sample_datasets.py similarity index 70% rename from movement/datasets.py rename to movement/sample_datasets.py index 90bd6f3e5..34dbb4e29 100644 --- a/movement/datasets.py +++ b/movement/sample_datasets.py @@ -1,8 +1,9 @@ -"""Module for fetching and loading datasets. +"""Module for fetching and loading sample datasets. -This module provides functions for fetching and loading data used in tests, -examples, and tutorials. The data are stored in a remote repository on GIN -and are downloaded to the user's local machine the first time they are used. +This module provides functions for fetching and loading sample data used in +tests, examples, and tutorials. The data are stored in a remote repository +on GIN and are downloaded to the user's local machine the first time they +are used. """ from pathlib import Path @@ -36,20 +37,20 @@ ) with open(METADATA_PATH, "r") as sample_info: - metadata = yaml.safe_load(sample_info) + METADATA = yaml.safe_load(sample_info) -sample_registry = {file["file_name"]: file["sha256sum"] for file in metadata} +SAMPLE_REGISTRY = {file["file_name"]: file["sha256sum"] for file in METADATA} # Create a download manager for the pose data POSE_DATA = pooch.create( path=DATA_DIR / "poses", base_url=f"{DATA_URL}/poses/", retry_if_failed=0, - registry=sample_registry, + registry=SAMPLE_REGISTRY, ) -def list_pose_data() -> list[str]: +def list_sample_data() -> list[str]: """Find available sample pose data in the *movement* data repository. Returns @@ -59,7 +60,7 @@ def list_pose_data() -> list[str]: return list(POSE_DATA.registry.keys()) -def fetch_pose_data_path(filename: str) -> Path: +def fetch_sample_data_path(filename: str) -> Path: """Fetch sample pose data from the *movement* data repository. The data are downloaded to the user's local machine the first time they are @@ -79,7 +80,9 @@ def fetch_pose_data_path(filename: str) -> Path: return Path(POSE_DATA.fetch(filename, progressbar=True)) -def fetch_pose_data(filename: str) -> xarray.Dataset: +def fetch_sample_data( + filename: str, +) -> xarray.Dataset: # TODO: Add LightningPose """Fetch sample pose data from the *movement* data repository. The data are downloaded to the user's local machine the first time they are @@ -97,9 +100,15 @@ def fetch_pose_data(filename: str) -> xarray.Dataset: Pose data contained in the fetched sample file. """ - file_path = fetch_pose_data_path(filename) - if filename.startswith("SLEAP"): + file_path = fetch_sample_data_path(filename) + file_metadata = next( + file for file in METADATA if file["file_name"] == filename + ) + + if file_metadata["source_software"] == "SLEAP": ds = load_poses.from_sleap_file(file_path) - elif filename.startswith("DLC"): + elif file_metadata["source_software"] == "DeepLabCut": ds = load_poses.from_dlc_file(file_path) + elif file_metadata["source_software"] == "LightningPose": + pass return ds diff --git a/tests/conftest.py b/tests/conftest.py index e4baa670a..e4ef0a55b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,16 +8,16 @@ import pytest import xarray as xr -from movement.datasets import fetch_pose_data_path from movement.io import PosesAccessor from movement.logging import configure_logging +from movement.sample_datasets import fetch_sample_data_path def pytest_configure(): """Perform initial configuration for pytest. Fetches pose data file paths as a dictionary for tests.""" pytest.POSE_DATA = { - file_name: fetch_pose_data_path(file_name) + file_name: fetch_sample_data_path(file_name) for file_name in [ "DLC_single-wasp.predictions.h5", "DLC_single-wasp.predictions.csv",