diff --git a/openneuro/download.py b/openneuro/download.py
index 34fd4f1..4e8c9d4 100644
--- a/openneuro/download.py
+++ b/openneuro/download.py
@@ -4,7 +4,7 @@
 from difflib import get_close_matches
 import hashlib
 import asyncio
-from pathlib import Path
+from pathlib import Path, PurePosixPath
 import string
 import json
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
@@ -544,7 +544,8 @@ def _iterate_filenames(
     dataset_id: str,
     tag: str,
     max_retries: int,
-    root: str = ''
+    root: str = '',
+    include: Iterable[str] = tuple(),
 ) -> Generator[dict, None, None]:
     """Iterate over all files in a dataset, yielding filenames."""
     directories = list()
@@ -557,6 +558,37 @@
         yield entity
 
     for directory in directories:
+        # Only bother with directories that are in the include list
+        if include:
+            # Take the example:
+            #
+            # --include="sub-CON001/*.eeg"
+            #
+            # or
+            #
+            # --include="sub-CON001"
+            #
+            # or
+            #
+            # --include="sub-CON001/*"
+            #
+            # All three of these should traverse `sub-CON001` and its
+            # subdirectories.
+            n_parts = len(PurePosixPath(root).parts)
+            include_paths = [PurePosixPath(inc) for inc in include]
+            dir_include = [  # for stuff like sub-CON001/*
+                '/'.join(inc.parts[:n_parts] + ('*',))
+                for inc in include_paths
+                if len(inc.parts) >= n_parts
+            ] + [  # and stuff like sub-CON001/*.eeg
+                '/'.join(inc.parts[:n_parts - 1] + ('*',))
+                for inc in include_paths
+                if len(inc.parts) >= n_parts - 1 and len(inc.parts) > 1
+            ]  # we want to traverse sub-CON001 in both cases
+            matches_include, _ = _match_include_exclude(
+                directory['filename'], include=dir_include, exclude=[])
+            if dir_include and not any(matches_include):
+                continue
         # Query filenames
         this_dir = directory['filename']
         metadata = _get_download_metadata(
@@ -572,11 +604,27 @@
             tag=tag,
             max_retries=max_retries,
             root=this_dir,
+            include=include,
         )
         for path in dir_iterator:
             yield path
 
 
+def _match_include_exclude(
+    filename: str,
+    *,
+    include: Iterable[str],
+    exclude: Iterable[str],
+) -> Tuple[List[bool], List[bool]]:
+    """Return which include and exclude patterns a filename matches."""
+    matches_keep = [filename.startswith(i) or fnmatch.fnmatch(filename, i)
+                    for i in include]
+    matches_remove = [filename.startswith(e) or
+                      fnmatch.fnmatch(filename, e)
+                      for e in exclude]
+    return matches_keep, matches_remove
+
+
 def download(*,
              dataset: str,
              tag: Optional[str] = None,
@@ -681,7 +729,8 @@
 
     for file in tqdm(
         _iterate_filenames(
-            these_files, dataset_id=dataset, tag=tag, max_retries=max_retries
+            these_files, dataset_id=dataset, tag=tag, max_retries=max_retries,
+            include=include,
         ),
         desc=_unicode(
             f'Traversing directories for {dataset}', end='', emoji='📁'
@@ -703,12 +752,9 @@
             include_counts[include.index(filename)] += 1
             continue
 
-        matches_keep = [filename.startswith(i) or fnmatch.fnmatch(filename, i)
-                        for i in include]
-        matches_remove = [filename.startswith(e) or
-                          fnmatch.fnmatch(filename, e)
-                          for e in exclude]
-        if (not include or any(matches_keep)) and not any(matches_remove):
+        matches_keep, matches_exclude = _match_include_exclude(
+            filename, include=include, exclude=exclude)
+        if (not include or any(matches_keep)) and not any(matches_exclude):
             files.append(file)
         # Keep track of include matches.
         if any(matches_keep):
@@ -727,7 +773,7 @@
             else:
                 extra = (
                     'There were no similar filenames found in the '
-                    'metadata.'
+                    'metadata. '
                 )
             raise RuntimeError(
                 f'Could not find path in the dataset:\n- {this}\n{extra}'
diff --git a/openneuro/tests/test_download.py b/openneuro/tests/test_download.py
index 259a29d..3d2e7ee 100644
--- a/openneuro/tests/test_download.py
+++ b/openneuro/tests/test_download.py
@@ -8,23 +8,27 @@
 dataset_id_aws = 'ds000246'
 tag_aws = '1.0.0'
 include_aws = 'sub-0001/anat'
+exclude_aws = []
 
 dataset_id_on = 'ds000117'
+tag_on = None
 include_on = 'sub-16/ses-meg'
+exclude_on = '*.fif'  # save GBs of downloads
 
 invalid_tag = 'abcdefg'
 
 
 @pytest.mark.parametrize(
-    ('dataset_id', 'tag', 'include'),
+    ('dataset_id', 'tag', 'include', 'exclude'),
     [
-        (dataset_id_aws, tag_aws, include_aws),
-        (dataset_id_on, None, include_on)
+        (dataset_id_aws, tag_aws, include_aws, exclude_aws),
+        (dataset_id_on, tag_on, include_on, exclude_on),
     ]
 )
-def test_download(tmp_path: Path, dataset_id, tag, include):
+def test_download(tmp_path: Path, dataset_id, tag, include, exclude):
     """Test downloading some files."""
-    download(dataset=dataset_id, tag=tag, target_dir=tmp_path, include=include)
+    download(dataset=dataset_id, tag=tag, target_dir=tmp_path, include=include,
+             exclude=exclude)
 
 
 def test_download_invalid_tag(tmp_path: Path, dataset_id=dataset_id_aws,
@@ -49,14 +53,16 @@
 
     # Download from a different revision / tag
     new_tag = '00001'
+    include = ['CHANGES']
     with pytest.raises(FileExistsError, match=f'revision {tag} exists'):
-        download(dataset=dataset, tag=new_tag, target_dir=tmp_path)
+        download(dataset=dataset, tag=new_tag, target_dir=tmp_path,
+                 include=include)
 
     # Try to "resume" from a different dataset
     new_dataset = 'ds000117'
     with pytest.raises(RuntimeError,
                        match='existing dataset.*appears to be different'):
-        download(dataset=new_dataset, target_dir=tmp_path)
+        download(dataset=new_dataset, target_dir=tmp_path, include=include)
 
     # Remove "DatasetDOI" from JSON
     json_path = tmp_path / 'dataset_description.json'
@@ -100,15 +106,21 @@
 
     # Now inject a `doi:` prefix into the DOI
     dataset_description_path = tmp_path / 'dataset_description.json'
-    dataset_description = json.loads(
+    dataset_description_text = \
         dataset_description_path.read_text(encoding='utf-8')
-    )
+    dataset_description = json.loads(dataset_description_text)
+    # Make sure json.dumps() round-trips to the same text (if OpenNeuro ever
+    # changes the indent from 4 to 8, for example, we might try to resume our
+    # download of the file and things will break in a challenging way)
+    dataset_description_rt = json.dumps(dataset_description, indent=4)
+    assert dataset_description_text == dataset_description_rt
+    # Make sure the problematic prefix isn't already there, then add it
     assert not dataset_description['DatasetDOI'].startswith('doi:')
     dataset_description['DatasetDOI'] = (
         'doi:' + dataset_description['DatasetDOI']
     )
     dataset_description_path.write_text(
-        data=json.dumps(dataset_description, indent=2),
+        data=json.dumps(dataset_description, indent=4),
         encoding='utf-8'
     )
diff --git a/pyproject.toml b/pyproject.toml
index 90c796a..542251a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,3 +46,6 @@
 build-backend = "setuptools.build_meta"
 
 [tool.setuptools_scm]  # can be left blank
+
+[tool.pytest.ini_options]
+addopts = "-ra -vv --tb=short --durations=10"
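
Reviewer note: the directory-pruning logic added to `_iterate_filenames` is the densest part of this patch, so here is a minimal standalone sketch of the idea. `_should_traverse` is a hypothetical helper written only for illustration (it is not part of the patch); `_match_include_exclude` is copied from the diff above, and the sample subject paths are made up:

import fnmatch
from pathlib import PurePosixPath
from typing import Iterable, List, Tuple


def _match_include_exclude(
    filename: str,
    *,
    include: Iterable[str],
    exclude: Iterable[str],
) -> Tuple[List[bool], List[bool]]:
    """Return which include and exclude patterns a filename matches."""
    matches_keep = [filename.startswith(i) or fnmatch.fnmatch(filename, i)
                    for i in include]
    matches_remove = [filename.startswith(e) or
                      fnmatch.fnmatch(filename, e)
                      for e in exclude]
    return matches_keep, matches_remove


def _should_traverse(dirname: str, *, root: str, include: List[str]) -> bool:
    """Decide whether a directory is worth traversing for these includes."""
    # Rewrite each include into directory-level wildcard patterns, exactly
    # as the patch does inline: 'sub-CON001/*.eeg' must still let us descend
    # into 'sub-CON001' even though the directory itself doesn't match the
    # glob.
    n_parts = len(PurePosixPath(root).parts)
    include_paths = [PurePosixPath(inc) for inc in include]
    dir_include = [  # for stuff like sub-CON001/*
        '/'.join(inc.parts[:n_parts] + ('*',))
        for inc in include_paths
        if len(inc.parts) >= n_parts
    ] + [  # and stuff like sub-CON001/*.eeg
        '/'.join(inc.parts[:n_parts - 1] + ('*',))
        for inc in include_paths
        if len(inc.parts) >= n_parts - 1 and len(inc.parts) > 1
    ]
    matches_include, _ = _match_include_exclude(
        dirname, include=dir_include, exclude=[])
    return not dir_include or any(matches_include)


# Directories on the path to the requested files are still traversed:
assert _should_traverse('sub-CON001/ses-01', root='sub-CON001',
                        include=['sub-CON001/*.eeg'])
# A non-matching subject is pruned, so its subtree is never queried:
assert not _should_traverse('sub-OTHER/anat', root='sub-OTHER',
                            include=['sub-CON001'])

The pruning only has to be conservative: at shallow depths the rewritten patterns can degenerate to '*' and let extra directories through, and the precise include/exclude filtering in download() still decides which files are actually fetched.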