diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 55a3454d..f5345e0e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,6 +32,8 @@ repos: - types-setuptools - pandas-stubs - types-attrs + - types-PyYAML + - types-requests - repo: https://github.com/mgedmin/check-manifest rev: "0.49" hooks: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 95ece7ad..27161124 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -259,9 +259,9 @@ by the [German Neuroinformatics Node](https://www.g-node.org/). GIN has a GitHub-like interface and git-like [CLI](gin:G-Node/Info/wiki/GIN+CLI+Setup#quickstart) functionalities. -Currently the data repository contains sample pose estimation data files -stored in the `poses` folder. Each file name starts with either "DLC" or "SLEAP", -depending on the pose estimation software used to generate the data. +Currently, the data repository contains sample pose estimation data files +stored in the `poses` folder. Metadata for these files, including information +about their provenance, is stored in the `poses_files_metadata.yaml` file. ### Fetching data To fetch the data from GIN, we use the [pooch](https://www.fatiando.org/pooch/latest/index.html) @@ -269,15 +269,16 @@ Python package, which can download data from pre-specified URLs and store them locally for all subsequent uses. It also provides some nice utilities, like verification of sha256 hashes and decompression of archives. -The relevant functionality is implemented in the `movement.datasets.py` module. +The relevant functionality is implemented in the `movement.sample_data.py` module. The most important parts of this module are: -1. The `POSE_DATA` download manager object, which contains a list of stored files and their known hashes. -2. The `list_pose_data()` function, which returns a list of the available files in the data repository. -3. 
The `fetch_pose_data_path()` function, which downloads a file (if not already cached locally) and returns the local path to it. +1. The `SAMPLE_DATA` download manager object. +2. The `list_sample_data()` function, which returns a list of the available files in the data repository. +3. The `fetch_sample_data_path()` function, which downloads a file (if not already cached locally) and returns the local path to it. +4. The `fetch_sample_data()` function, which downloads a file and loads it into movement directly, returning an `xarray.Dataset` object. By default, the downloaded files are stored in the `~/.movement/data` folder. -This can be changed by setting the `DATA_DIR` variable in the `movement.datasets.py` module. +This can be changed by setting the `DATA_DIR` variable in the `movement.sample_data.py` module. ### Adding new data Only core movement developers may add new files to the external data repository. @@ -287,9 +288,8 @@ To add a new file, you will need to: 2. Ask to be added as a collaborator on the [movement data repository](gin:neuroinformatics/movement-test-data) (if not already) 3. Download the [GIN CLI](gin:G-Node/Info/wiki/GIN+CLI+Setup#quickstart) and set it up with your GIN credentials, by running `gin login` in a terminal. 4. Clone the movement data repository to your local machine, by running `gin get neuroinformatics/movement-test-data` in a terminal. -5. Add your new files and commit them with `gin commit -m `. -6. Upload the commited changes to the GIN repository, by running `gin upload`. Latest changes to the repository can be pulled via `gin download`. `gin sync` will synchronise the latest changes bidirectionally. -7. Determine the sha256 checksum hash of each new file, by running `sha256sum ` in a terminal. Alternatively, you can use `pooch` to do this for you: `python -c "import pooch; pooch.file_hash('/path/to/file')"`. 
If you wish to generate a text file containing the hashes of all the files in a given folder, you can use `python -c "import pooch; pooch.make_registry('/path/to/folder', 'sha256_registry.txt')`. -8. Update the `movement.datasets.py` module on the [movement GitHub repository](movement-github:) by adding the new files to the `POSE_DATA` registry. Make sure to include the correct sha256 hash, as determined in the previous step. Follow all the usual [guidelines for contributing code](#contributing-code). Make sure to test whether the new files can be fetched successfully (see [fetching data](#fetching-data) above) before submitting your pull request. - -You can also perform steps 3-6 via the GIN web interface, if you prefer to avoid using the CLI. +5. Add your new files to `/movement-test-data/poses/`. +6. Determine the sha256 checksum hash of each new file by running `sha256sum ` in a terminal. Alternatively, you can use `pooch` to do this for you: `python -c "import pooch; hash = pooch.file_hash('/path/to/file'); print(hash)"`. If you wish to generate a text file containing the hashes of all the files in a given folder, you can use `python -c "import pooch; pooch.make_registry('/path/to/folder', 'sha256_registry.txt')"`. +7. Add metadata for your new files to `poses_files_metadata.yaml`, including their sha256 hashes. +8. Commit your changes using `gin commit -m `. +9. Upload the committed changes to the GIN repository by running `gin upload`. Latest changes to the repository can be pulled via `gin download`. `gin sync` will synchronise the latest changes bidirectionally. diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst index c191e086..293e98b5 100644 --- a/docs/source/api_index.rst +++ b/docs/source/api_index.rst @@ -33,14 +33,15 @@ Input/Output ValidPosesCSV ValidPoseTracks -Datasets --------- -.. currentmodule:: movement.datasets +Sample Data ----------- +.. currentmodule:: movement.sample_data ..
autosummary:: :toctree: api - list_pose_data - fetch_pose_data_path + list_sample_data + fetch_sample_data_path + fetch_sample_data Logging ------- diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md index bf7b7a26..7b7d6e52 100644 --- a/docs/source/getting_started.md +++ b/docs/source/getting_started.md @@ -53,7 +53,7 @@ Please see the [contributing guide](target-contributing) for more information. ## Loading data You can load predicted pose tracks from the pose estimation software packages -[DeepLabCut](dlc:) or [SLEAP](sleap:). +[DeepLabCut](dlc:), [SLEAP](sleap:), or [LightningPose](lp:). First import the `movement.io.load_poses` module: @@ -114,27 +114,36 @@ You can also try movement out on some sample data included in the package. You can view the available sample data files with: ```python -from movement import datasets +from movement import sample_data -file_names = datasets.list_pose_data() +file_names = sample_data.list_sample_data() print(file_names) ``` + This will print a list of file names containing sample pose data. -The files are prefixed with the name of the pose estimation software package, -either "DLC" or "SLEAP". +Each file is prefixed with the name of the pose estimation software package +that was used to generate it - either "DLC", "SLEAP", or "LP". To get the path to one of the sample files, you can use the `fetch_pose_data_path` function: ```python -file_path = datasets.fetch_pose_data_path("DLC_two-mice.predictions.csv") +file_path = sample_data.fetch_sample_data_path("DLC_two-mice.predictions.csv") ``` The first time you call this function, it will download the corresponding file to your local machine and save it in the `~/.movement/data` directory. On subsequent calls, it will simply return the path to that local file. -You can feed the path to the `from_dlc_file` or `from_sleap_file` functions -and load the data, as shown above.
+You can feed the path to the `from_dlc_file`, `from_sleap_file`, or +`from_lp_file` functions and load the data, as shown above. + +Alternatively, you can skip the `fetch_sample_data_path()` step and load the +data directly using the `fetch_sample_data()` function: + +```python +ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv") +``` + ::: ## Working with movement datasets diff --git a/examples/load_and_explore_poses.py b/examples/load_and_explore_poses.py index de5f5196..4a68ee7e 100644 --- a/examples/load_and_explore_poses.py +++ b/examples/load_and_explore_poses.py @@ -10,7 +10,7 @@ # ------- from matplotlib import pyplot as plt -from movement import datasets +from movement import sample_data from movement.io import load_poses # %% @@ -18,14 +18,14 @@ # ------------------------ # Print a list of available datasets: -for file_name in datasets.list_pose_data(): +for file_name in sample_data.list_sample_data(): print(file_name) # %% # Fetch the path to an example dataset. # Feel free to replace this with the path to your own dataset. # e.g., ``file_path = "/path/to/my/data.h5"``) -file_path = datasets.fetch_pose_data_path( +file_path = sample_data.fetch_sample_data_path( "SLEAP_three-mice_Aeon_proofread.analysis.h5" ) diff --git a/movement/datasets.py b/movement/datasets.py deleted file mode 100644 index 6853d952..00000000 --- a/movement/datasets.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Module for fetching and loading datasets. - -This module provides functions for fetching and loading data used in tests, -examples, and tutorials. The data are stored in a remote repository on GIN -and are downloaded to the user's local machine the first time they are used. 
-""" - -from pathlib import Path - -import pooch - -# URL to the remote data repository on GIN -# noinspection PyInterpreter -DATA_URL = ( - "https://gin.g-node.org/neuroinformatics/movement-test-data/raw/master" -) - -# Save data in ¬/.movement/data -DATA_DIR = Path("~", ".movement", "data").expanduser() -# Create the folder if it doesn't exist -DATA_DIR.mkdir(parents=True, exist_ok=True) - -# Create a download manager for the pose data -POSE_DATA = pooch.create( - path=DATA_DIR / "poses", - base_url=f"{DATA_URL}/poses/", - retry_if_failed=0, - registry={ - "DLC_single-wasp.predictions.h5": "931dddb6ef5e08db6054d3757a441ee31dd0d9ff5a10802ad8405d6c4e7e274e", # noqa: E501 - "DLC_single-wasp.predictions.csv": "9b194cc930c2e2e0d33c816320d029f889b306d53ff2fe95ff408e99c9cdea23", # noqa: E501 - "DLC_single-mouse_EPM.predictions.h5": "0ddc2b08c9401435929783b22ea31b3673ceb80c3a02c5f3531bb1cfd78deea5", # noqa: E501 - "DLC_two-mice.predictions.csv": "6e891ab4e14a3ad74451ed0899f770b47aa3eb959d9824d3782827297d7c75e0", # noqa: E501 - "LP_mouse-face_AIND.predictions.csv": "13620f20a3fbb20c9ef0ad6fd8898944d2cfe3e83c5fe18946b7592e0a718da2", # noqa: E501 - "LP_mouse-twoview_AIND.predictions.csv": "22cded593ff77226dffab3c10f87874e597d86d70d62a81e5c4c690a96fd1e49", # noqa: E501 - "SLEAP_single-mouse_EPM.analysis.h5": "0df0a09c2493a1d9964ba98cbf751eda62743f1d688ae82b6df7b0f77169ed47", # noqa: E501 - "SLEAP_single-mouse_EPM.predictions.slp": "ca620db6123635761ddf69947f72f653d14a59137b355bd2d8f7c2f1be67e474", # noqa: E501 - "SLEAP_two-mice_social-interaction.analysis.h5": "f7f1e59d4b2c34712089f8aaf2390272291d93e6991c1abe32d9ce798a6234f9", # noqa: E501 - "SLEAP_two-mice_social-interaction.predictions.slp": "45881affde9704c045e70b8d4b3f6bbb8d9bd8ef9f4cdea6d173cfe35857549b", # noqa: E501 - "SLEAP_three-mice_Aeon_proofread.analysis.h5": "82ebd281c406a61536092863bc51d1a5c7c10316275119f7daf01c1ff33eac2a", # noqa: E501 - "SLEAP_three-mice_Aeon_proofread.predictions.slp": 
"7b7436a52dfd5f4d80d7c66919ad1a1732e5435fe33faf9011ec5f7b7074e788", # noqa: E501 - "SLEAP_three-mice_Aeon_mixed-labels.analysis.h5": "899651ec027eb8fd6181246f89142ad1c0a40f14394fc8144d44ea093c0e137d", # noqa: E501 - "SLEAP_three-mice_Aeon_mixed-labels.predictions.slp": "6d3f2c5446e9c12aabf28d5a9470835736b0419dcefebba89c305114f83b82d1", # noqa: E501 - }, -) - - -def list_pose_data() -> list[str]: - """Find available sample pose data in the *movement* data repository. - - Returns - ------- - filenames : list of str - List of filenames for available pose data.""" - return list(POSE_DATA.registry.keys()) - - -def fetch_pose_data_path(filename: str) -> Path: - """Fetch sample pose data from the *movement* data repository. - - The data are downloaded to the user's local machine the first time they are - used and are stored in a local cache directory. The function returns the - path to the downloaded file, not the contents of the file itself. - - Parameters - ---------- - filename : str - Name of the file to fetch. - - Returns - ------- - path : pathlib.Path - Path to the downloaded file. - """ - return Path(POSE_DATA.fetch(filename, progressbar=True)) diff --git a/movement/sample_data.py b/movement/sample_data.py new file mode 100644 index 00000000..03b59d35 --- /dev/null +++ b/movement/sample_data.py @@ -0,0 +1,188 @@ +"""Module for fetching and loading sample datasets. + +This module provides functions for fetching and loading sample data used in +tests, examples, and tutorials. The data are stored in a remote repository +on GIN and are downloaded to the user's local machine the first time they +are used. 
+""" + +import logging +from pathlib import Path + +import pooch +import xarray +import yaml +from requests.exceptions import RequestException + +from movement.io import load_poses +from movement.logging import log_error, log_warning + +logger = logging.getLogger(__name__) + +# URL to the remote data repository on GIN +# noinspection PyInterpreter +DATA_URL = ( + "https://gin.g-node.org/neuroinformatics/movement-test-data/raw/master" +) + +# Save data in ~/.movement/data +DATA_DIR = Path("~", ".movement", "data").expanduser() +# Create the folder if it doesn't exist +DATA_DIR.mkdir(parents=True, exist_ok=True) + + +def _download_metadata_file(file_name: str, data_dir: Path = DATA_DIR) -> Path: + """Download the yaml file containing sample metadata from the *movement* + data repository and save it in the specified directory with a temporary + filename - temp_{file_name} - to avoid overwriting any existing files. + + Parameters + ---------- + file_name : str + Name of the metadata file to fetch. + data_dir : pathlib.Path, optional + Directory to store the metadata file in. Defaults to the constant + ``DATA_DIR``. Can be overridden for testing purposes. + + Returns + ------- + path : pathlib.Path + Path to the downloaded file. + """ + local_file_path = pooch.retrieve( + url=f"{DATA_URL}/{file_name}", + known_hash=None, + path=data_dir, + fname=f"temp_{file_name}", + progressbar=False, + ) + logger.debug( + f"Successfully downloaded sample metadata file {file_name} " + f"from {DATA_URL} to {data_dir}" + ) + return Path(local_file_path) + + +def _fetch_metadata(file_name: str, data_dir: Path = DATA_DIR) -> list[dict]: + """Download the yaml file containing metadata from the *movement* sample + data repository and load it as a list of dictionaries. + + Parameters + ---------- + file_name : str + Name of the metadata file to fetch. + data_dir : pathlib.Path, optional + Directory to store the metadata file in. Defaults to + the constant ``DATA_DIR``. 
Can be overridden for testing purposes. + + Returns + ------- + list[dict] + A list of dictionaries containing metadata for each sample file. + """ + + local_file_path = Path(data_dir / file_name) + failed_msg = "Failed to download the newest sample metadata file." + + # try downloading the newest metadata file + try: + downloaded_file_path = _download_metadata_file(file_name, data_dir) + # if download succeeds, replace any existing local metadata file + downloaded_file_path.replace(local_file_path) + # if download fails, try loading an existing local metadata file, + # otherwise raise an error + except RequestException as exc_info: + if local_file_path.is_file(): + log_warning( + f"{failed_msg} Will use the existing local version instead." + ) + else: + raise log_error(RequestException, failed_msg) from exc_info + + with open(local_file_path, "r") as metadata_file: + metadata = yaml.safe_load(metadata_file) + return metadata + + +metadata = _fetch_metadata("poses_files_metadata.yaml") + +# Create a download manager for the pose data +SAMPLE_DATA = pooch.create( + path=DATA_DIR / "poses", + base_url=f"{DATA_URL}/poses/", + retry_if_failed=0, + registry={file["file_name"]: file["sha256sum"] for file in metadata}, +) + + +def list_sample_data() -> list[str]: + """Find available sample pose data in the *movement* data repository. + + Returns + ------- + filenames : list of str + List of filenames for available pose data.""" + return list(SAMPLE_DATA.registry.keys()) + + +def fetch_sample_data_path(filename: str) -> Path: + """Download sample pose data from the *movement* data repository and return + its local filepath. + + The data are downloaded to the user's local machine the first time they are + used and are stored in a local cache directory. The function returns the + path to the downloaded file, not the contents of the file itself. + + Parameters + ---------- + filename : str + Name of the file to fetch. 
+ + Returns + ------- + path : pathlib.Path + Path to the downloaded file. + """ + try: + return Path(SAMPLE_DATA.fetch(filename, progressbar=True)) + except ValueError: + raise log_error( + ValueError, + f"File '{filename}' is not in the registry. Valid " + f"filenames are: {list_sample_data()}", + ) + + +def fetch_sample_data( + filename: str, +) -> xarray.Dataset: + """Download and return sample pose data from the *movement* data + repository. + + The data are downloaded to the user's local machine the first time they are + used and are stored in a local cache directory. Returns sample pose data as + an xarray Dataset. + + Parameters + ---------- + filename : str + Name of the file to fetch. + + Returns + ------- + ds : xarray.Dataset + Pose data contained in the fetched sample file. + """ + + file_path = fetch_sample_data_path(filename) + file_metadata = next( + file for file in metadata if file["file_name"] == filename + ) + + if file_metadata["source_software"] == "SLEAP": + ds = load_poses.from_sleap_file(file_path, fps=file_metadata["fps"]) + elif file_metadata["source_software"] == "DeepLabCut": + ds = load_poses.from_dlc_file(file_path, fps=file_metadata["fps"]) + elif file_metadata["source_software"] == "LightningPose": + ds = load_poses.from_lp_file(file_path, fps=file_metadata["fps"]) + return ds diff --git a/pyproject.toml b/pyproject.toml index c41c4e99..dfebb7d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "tqdm", "sleap-io", "xarray", + "PyYAML", ] classifiers = [ @@ -54,6 +55,8 @@ dev = [ "pandas-stubs", "types-attrs", "check-manifest", + "types-PyYAML", + "types-requests", ] [build-system] diff --git a/tests/conftest.py b/tests/conftest.py index e4baa670..411c9679 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,29 +8,17 @@ import pytest import xarray as xr -from movement.datasets import fetch_pose_data_path from movement.io import PosesAccessor from movement.logging import configure_logging
+from movement.sample_data import fetch_sample_data_path, list_sample_data def pytest_configure(): """Perform initial configuration for pytest. Fetches pose data file paths as a dictionary for tests.""" - pytest.POSE_DATA = { - file_name: fetch_pose_data_path(file_name) - for file_name in [ - "DLC_single-wasp.predictions.h5", - "DLC_single-wasp.predictions.csv", - "DLC_two-mice.predictions.csv", - "SLEAP_single-mouse_EPM.analysis.h5", - "SLEAP_single-mouse_EPM.predictions.slp", - "SLEAP_three-mice_Aeon_proofread.analysis.h5", - "SLEAP_three-mice_Aeon_proofread.predictions.slp", - "SLEAP_three-mice_Aeon_mixed-labels.analysis.h5", - "SLEAP_three-mice_Aeon_mixed-labels.predictions.slp", - "LP_mouse-face_AIND.predictions.csv", - "LP_mouse-twoview_AIND.predictions.csv", - ] + pytest.POSE_DATA_PATHS = { + file_name: fetch_sample_data_path(file_name) + for file_name in list_sample_data() } @@ -186,7 +174,9 @@ def new_csv_file(tmp_path): @pytest.fixture def dlc_style_df(): """Return a valid DLC-style DataFrame.""" - return pd.read_hdf(pytest.POSE_DATA.get("DLC_single-wasp.predictions.h5")) + return pd.read_hdf( + pytest.POSE_DATA_PATHS.get("DLC_single-wasp.predictions.h5") + ) @pytest.fixture( @@ -201,7 +191,7 @@ def dlc_style_df(): ) def sleap_file(request): """Return the file path for a SLEAP .h5 or .slp file.""" - return pytest.POSE_DATA.get(request.param) + return pytest.POSE_DATA_PATHS.get(request.param) @pytest.fixture diff --git a/tests/test_integration/test_io.py b/tests/test_integration/test_io.py index 260ef7b1..316485c4 100644 --- a/tests/test_integration/test_io.py +++ b/tests/test_integration/test_io.py @@ -2,7 +2,7 @@ import numpy as np import pytest import xarray as xr -from pytest import POSE_DATA +from pytest import POSE_DATA_PATHS from movement.io import load_poses, save_poses @@ -56,7 +56,7 @@ def test_to_sleap_analysis_file_returns_same_h5_file_content( """Test that saving pose tracks (loaded from a SLEAP analysis file) to a SLEAP-style .h5 analysis 
file returns the same file contents.""" - sleap_h5_file_path = POSE_DATA.get(sleap_h5_file) + sleap_h5_file_path = POSE_DATA_PATHS.get(sleap_h5_file) ds = load_poses.from_sleap_file(sleap_h5_file_path, fps=fps) save_poses.to_sleap_analysis_file(ds, new_h5_file) @@ -85,7 +85,7 @@ def test_to_sleap_analysis_file_source_file(self, file, new_h5_file): """Test that saving pose tracks (loaded from valid source files) to a SLEAP-style .h5 analysis file stores the .slp labels path only when the source file is a .slp file.""" - file_path = POSE_DATA.get(file) + file_path = POSE_DATA_PATHS.get(file) if file.startswith("DLC"): ds = load_poses.from_dlc_file(file_path) else: diff --git a/tests/test_unit/test_load_poses.py b/tests/test_unit/test_load_poses.py index dedc5b5b..37f76f6a 100644 --- a/tests/test_unit/test_load_poses.py +++ b/tests/test_unit/test_load_poses.py @@ -2,7 +2,7 @@ import numpy as np import pytest import xarray as xr -from pytest import POSE_DATA +from pytest import POSE_DATA_PATHS from sleap_io.io.slp import read_labels, write_labels from sleap_io.model.labels import LabeledFrame, Labels @@ -15,7 +15,9 @@ class TestLoadPoses: @pytest.fixture def sleap_slp_file_without_tracks(self, tmp_path): """Mock and return the path to a SLEAP .slp file without tracks.""" - sleap_file = POSE_DATA.get("SLEAP_single-mouse_EPM.predictions.slp") + sleap_file = POSE_DATA_PATHS.get( + "SLEAP_single-mouse_EPM.predictions.slp" + ) labels = read_labels(sleap_file) file_path = tmp_path / "track_is_none.slp" lfs = [] @@ -43,7 +45,7 @@ def sleap_slp_file_without_tracks(self, tmp_path): @pytest.fixture def sleap_h5_file_without_tracks(self, tmp_path): """Mock and return the path to a SLEAP .h5 file without tracks.""" - sleap_file = POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5") + sleap_file = POSE_DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5") file_path = tmp_path / "track_is_none.h5" with h5py.File(sleap_file, "r") as f1, h5py.File(file_path, "w") as f2: for key in 
list(f1.keys()): @@ -112,7 +114,7 @@ def test_load_from_sleap_file_without_tracks( sleap_file_without_tracks ) ds_from_tracked = load_poses.from_sleap_file( - POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5") + POSE_DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5") ) # Check if the "individuals" coordinate matches # the assigned default "individuals_0" @@ -144,8 +146,8 @@ def test_load_from_sleap_slp_file_or_h5_file_returns_same( ): """Test that loading pose tracks from SLEAP .slp and .h5 files return the same Dataset.""" - slp_file_path = POSE_DATA.get(slp_file) - h5_file_path = POSE_DATA.get(h5_file) + slp_file_path = POSE_DATA_PATHS.get(slp_file) + h5_file_path = POSE_DATA_PATHS.get(h5_file) ds_from_slp = load_poses.from_sleap_file(slp_file_path) ds_from_h5 = load_poses.from_sleap_file(h5_file_path) xr.testing.assert_allclose(ds_from_h5, ds_from_slp) @@ -161,7 +163,7 @@ def test_load_from_sleap_slp_file_or_h5_file_returns_same( def test_load_from_dlc_file(self, file_name): """Test that loading pose tracks from valid DLC files returns a proper Dataset.""" - file_path = POSE_DATA.get(file_name) + file_path = POSE_DATA_PATHS.get(file_name) ds = load_poses.from_dlc_file(file_path) self.assert_dataset(ds, file_path, "DeepLabCut") @@ -174,8 +176,8 @@ def test_load_from_dlc_df(self, dlc_style_df): def test_load_from_dlc_file_csv_or_h5_file_returns_same(self): """Test that loading pose tracks from DLC .csv and .h5 files return the same Dataset.""" - csv_file_path = POSE_DATA.get("DLC_single-wasp.predictions.csv") - h5_file_path = POSE_DATA.get("DLC_single-wasp.predictions.h5") + csv_file_path = POSE_DATA_PATHS.get("DLC_single-wasp.predictions.csv") + h5_file_path = POSE_DATA_PATHS.get("DLC_single-wasp.predictions.h5") ds_from_csv = load_poses.from_dlc_file(csv_file_path) ds_from_h5 = load_poses.from_dlc_file(h5_file_path) xr.testing.assert_allclose(ds_from_h5, ds_from_csv) @@ -193,7 +195,7 @@ def test_load_from_dlc_file_csv_or_h5_file_returns_same(self): def 
test_fps_and_time_coords(self, fps, expected_fps, expected_time_unit): """Test that time coordinates are set according to the provided fps.""" ds = load_poses.from_sleap_file( - POSE_DATA.get("SLEAP_three-mice_Aeon_proofread.analysis.h5"), + POSE_DATA_PATHS.get("SLEAP_three-mice_Aeon_proofread.analysis.h5"), fps=fps, ) assert ds.time_unit == expected_time_unit @@ -216,7 +218,7 @@ def test_fps_and_time_coords(self, fps, expected_fps, expected_time_unit): def test_load_from_lp_file(self, file_name): """Test that loading pose tracks from valid LightningPose (LP) files returns a proper Dataset.""" - file_path = POSE_DATA.get(file_name) + file_path = POSE_DATA_PATHS.get(file_name) ds = load_poses.from_lp_file(file_path) self.assert_dataset(ds, file_path, "LightningPose") @@ -224,7 +226,7 @@ def test_load_from_lp_or_dlc_file_returns_same(self): """Test that loading a single-animal DeepLabCut-style .csv file using either the `from_lp_file` or `from_dlc_file` function returns the same Dataset (except for the source_software).""" - file_path = POSE_DATA.get("LP_mouse-face_AIND.predictions.csv") + file_path = POSE_DATA_PATHS.get("LP_mouse-face_AIND.predictions.csv") ds_drom_lp = load_poses.from_lp_file(file_path) ds_from_dlc = load_poses.from_dlc_file(file_path) xr.testing.assert_allclose(ds_from_dlc, ds_drom_lp) @@ -234,6 +236,6 @@ def test_load_from_lp_or_dlc_file_returns_same(self): def test_load_multi_animal_from_lp_file_raises(self): """Test that loading a multi-animal .csv file using the `from_lp_file` function raises a ValueError.""" - file_path = POSE_DATA.get("DLC_two-mice.predictions.csv") + file_path = POSE_DATA_PATHS.get("DLC_two-mice.predictions.csv") with pytest.raises(ValueError): load_poses.from_lp_file(file_path) diff --git a/tests/test_unit/test_sample_data.py b/tests/test_unit/test_sample_data.py new file mode 100644 index 00000000..09bcafad --- /dev/null +++ b/tests/test_unit/test_sample_data.py @@ -0,0 +1,121 @@ +"""Test suite for the sample_data 
module.""" + +from unittest.mock import MagicMock, patch + +import pooch +import pytest +from requests.exceptions import RequestException +from xarray import Dataset + +from movement.sample_data import ( + _fetch_metadata, + fetch_sample_data, + list_sample_data, +) + + +@pytest.fixture(scope="module") +def valid_file_names_with_fps(): + """Return a dict containing one valid file name and the corresponding fps + for each supported pose estimation tool.""" + return { + "SLEAP_single-mouse_EPM.analysis.h5": 30, + "DLC_single-wasp.predictions.h5": 40, + "LP_mouse-face_AIND.predictions.csv": 60, + } + + +def validate_metadata(metadata: list[dict]) -> None: + """Assert that the metadata is in the expected format.""" + metadata_fields = [ + "file_name", + "sha256sum", + "source_software", + "fps", + "species", + "number_of_individuals", + "shared_by", + "video_frame_file", + "note", + ] + check_yaml_msg = "Check the format of the metadata yaml file." + assert isinstance( + metadata, list + ), f"Expected metadata to be a list. {check_yaml_msg}" + assert all( + isinstance(file, dict) for file in metadata + ), f"Expected metadata entries to be dicts. {check_yaml_msg}" + assert all( + set(file.keys()) == set(metadata_fields) for file in metadata + ), f"Expected all metadata entries to have the same keys. 
{check_yaml_msg}" + + # check that filenames are unique + file_names = [file["file_name"] for file in metadata] + assert len(file_names) == len(set(file_names)) + + # check that the first 3 fields are present and are strings + required_fields = metadata_fields[:3] + assert all( + (isinstance(file[field], str)) + for file in metadata + for field in required_fields + ) + + +# Mock pooch.retrieve with RequestException as side_effect +mock_retrieve = MagicMock(pooch.retrieve, side_effect=RequestException) + + +@pytest.mark.parametrize("download_fails", [True, False]) +@pytest.mark.parametrize("local_exists", [True, False]) +def test_fetch_metadata(tmp_path, caplog, download_fails, local_exists): + """Test the fetch_metadata function with different combinations of + failed download and pre-existing local file. The expected behavior is + that the function will try to download the metadata file, and if that + fails, it will try to load an existing local file. If neither succeeds, + an error is raised.""" + metadata_file_name = "poses_files_metadata.yaml" + local_file_path = tmp_path / metadata_file_name + + with patch("movement.sample_data.DATA_DIR", tmp_path): + # simulate the existence of a local metadata file + if local_exists: + local_file_path.touch() + + if download_fails: + # simulate a failed download + with patch("movement.sample_data.pooch.retrieve", mock_retrieve): + if local_exists: + _fetch_metadata(metadata_file_name) + # check that a warning was logged + assert ( + "Will use the existing local version instead" + in caplog.records[-1].getMessage() + ) + else: + with pytest.raises( + RequestException, match="Failed to download" + ): + _fetch_metadata(metadata_file_name, data_dir=tmp_path) + else: + metadata = _fetch_metadata(metadata_file_name, data_dir=tmp_path) + assert local_file_path.is_file() + validate_metadata(metadata) + + +def test_list_sample_data(valid_file_names_with_fps): + assert isinstance(list_sample_data(), list) + assert all( + file in 
list_sample_data() for file in valid_file_names_with_fps + ) + + +def test_fetch_sample_data(valid_file_names_with_fps): + # test with valid files + for file, fps in valid_file_names_with_fps.items(): + ds = fetch_sample_data(file) + assert isinstance(ds, Dataset) and ds.fps == fps + + # Test with an invalid file + with pytest.raises(ValueError): + fetch_sample_data("nonexistent_file") diff --git a/tests/test_unit/test_save_poses.py b/tests/test_unit/test_save_poses.py index c65a88d8..e2a4f43e 100644 --- a/tests/test_unit/test_save_poses.py +++ b/tests/test_unit/test_save_poses.py @@ -5,7 +5,7 @@ import pandas as pd import pytest import xarray as xr -from pytest import POSE_DATA +from pytest import POSE_DATA_PATHS from movement.io import load_poses, save_poses @@ -65,25 +65,25 @@ def output_file_params(self, request): (np.array([1, 2, 3]), pytest.raises(ValueError)), # incorrect type ( load_poses.from_dlc_file( - POSE_DATA.get("DLC_single-wasp.predictions.h5") + POSE_DATA_PATHS.get("DLC_single-wasp.predictions.h5") ), does_not_raise(), ), # valid dataset ( load_poses.from_dlc_file( - POSE_DATA.get("DLC_two-mice.predictions.csv") + POSE_DATA_PATHS.get("DLC_two-mice.predictions.csv") ), does_not_raise(), ), # valid dataset ( load_poses.from_sleap_file( - POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5") + POSE_DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5") ), does_not_raise(), ), # valid dataset ( load_poses.from_sleap_file( - POSE_DATA.get( + POSE_DATA_PATHS.get( "SLEAP_three-mice_Aeon_proofread.predictions.slp" ) ), @@ -91,7 +91,7 @@ def output_file_params(self, request): ), # valid dataset ( load_poses.from_lp_file( - POSE_DATA.get("LP_mouse-face_AIND.predictions.csv") + POSE_DATA_PATHS.get("LP_mouse-face_AIND.predictions.csv") ), does_not_raise(), ), # valid dataset