Skip to content

Commit

Permalink
Refactor file path and dataset validation
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Nov 22, 2023
1 parent 7e12f08 commit 8837210
Showing 1 changed file with 60 additions and 16 deletions.
76 changes: 60 additions & 16 deletions movement/io/save_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,7 @@ def to_dlc_df(
to_dlc_file : Save the xarray dataset containing pose tracks directly
to a DeepLabCut-style .h5 or .csv file.
"""
if not isinstance(ds, xr.Dataset):
raise log_error(
ValueError, f"Expected an xarray Dataset, but got {type(ds)}."
)

ds.poses.validate() # validate the dataset

_validate_dataset(ds)
scorer = ["movement"]
bodyparts = ds.coords["keypoints"].data.tolist()
coords = ds.coords["space"].data.tolist() + ["likelihood"]
Expand Down Expand Up @@ -197,15 +191,7 @@ def to_dlc_file(
>>> save_poses.to_dlc_file(ds, "/path/to/file_dlc.h5")
"""

try:
file = ValidFile(
file_path,
expected_permission="w",
expected_suffix=[".csv", ".h5"],
)
except (OSError, ValueError) as error:
logger.error(error)
raise error
file = _validate_file_path(file_path, expected_suffix=[".csv", ".h5"])

# Sets default behaviour for the function
if split_individuals == "auto":
Expand Down Expand Up @@ -236,3 +222,61 @@ def to_dlc_file(
if isinstance(df_all, pd.DataFrame):
_save_dlc_df(file.path, df_all)
logger.info(f"Saved PoseTracks dataset to {file.path}.")


def _validate_file_path(
file_path: Union[str, Path], expected_suffix: list[str]
) -> ValidFile:
"""Validate the input file path by checking that the file has
write permission and expected suffix(es). If the file is not valid,
an appropriate error is raised.
Parameters
----------
file_path : pathlib.Path or str
Path to the file to validate.
expected_suffix : list of str
Expected suffix(es) for the file.
Returns
-------
ValidFile
The validated file.
Raises
------
OSError
If the file cannot be written.
ValueError
If the file does not have the expected suffix.
"""
try:
file = ValidFile(
file_path,
expected_permission="w",
expected_suffix=expected_suffix,
)
except (OSError, ValueError) as error:
logger.error(error)
raise error
return file


def _validate_dataset(ds: xr.Dataset) -> None:
"""Validate the input dataset is an xarray Dataset with valid PoseTracks.
Parameters
----------
ds : xr.Dataset
Dataset to validate.
Raises
------
ValueError
If `ds` is not an xarray Dataset with valid PoseTracks.
"""
if not isinstance(ds, xr.Dataset):
raise log_error(
ValueError, f"Expected an xarray Dataset, but got {type(ds)}."
)
ds.poses.validate() # validate the dataset

0 comments on commit 8837210

Please sign in to comment.