From 267a9f4e6532a5fee3636630189c03a620f4d3f2 Mon Sep 17 00:00:00 2001
From: Saugat Pachhai (सौगात)
Date: Sun, 25 Aug 2024 19:37:43 +0545
Subject: [PATCH] dvcfs: optimize `get()` by reducing index.info() calls

---
 dvc/dependency/repo.py    |   8 +--
 dvc/fs/__init__.py        |  25 ++++---
 dvc/fs/dvc.py             | 143 +++++++++++++++++++++++++++++++++++++-
 tests/func/test_import.py |  47 ++++++++++---
 4 files changed, 197 insertions(+), 26 deletions(-)

diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py
index 5fa8d0e7e4..417f088afe 100644
--- a/dvc/dependency/repo.py
+++ b/dvc/dependency/repo.py
@@ -98,12 +98,13 @@ def download(self, to: "Output", jobs: Optional[int] = None):
         files = super().download(to=to, jobs=jobs)
 
         if not isinstance(to.fs, LocalFileSystem):
-            return files
+            return
 
         hashes: list[tuple[str, HashInfo, dict[str, Any]]] = []
-        for src_path, dest_path in files:
+        for src_path, dest_path, *rest in files:
             try:
-                hash_info = self.fs.info(src_path)["dvc_info"]["entry"].hash_info
+                info = rest[0] if rest else self.fs.info(src_path)
+                hash_info = info["dvc_info"]["entry"].hash_info
                 dest_info = to.fs.info(dest_path)
             except (KeyError, AttributeError):
                 # If no hash info found, just keep going and output will be hashed later
@@ -112,7 +113,6 @@ def download(self, to: "Output", jobs: Optional[int] = None):
             hashes.append((dest_path, hash_info, dest_info))
         cache = to.cache if to.use_cache else to.local_cache
         cache.state.save_many(hashes, to.fs)
-        return files
 
     def update(self, rev: Optional[str] = None):
         if rev:
diff --git a/dvc/fs/__init__.py b/dvc/fs/__init__.py
index 0c9cf567ac..4b739428c6 100644
--- a/dvc/fs/__init__.py
+++ b/dvc/fs/__init__.py
@@ -1,5 +1,5 @@
 import glob
-from typing import Optional
+from typing import Optional, Union
 from urllib.parse import urlparse
 
 from dvc.config import ConfigError as RepoConfigError
@@ -47,12 +47,24 @@ def download(
     fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None
-) -> list[tuple[str, str]]:
+) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
     from dvc.scm import lfs_prefetch
 
     from .callbacks import TqdmCallback
 
     with TqdmCallback(desc=f"Downloading {fs.name(fs_path)}", unit="files") as cb:
+        if isinstance(fs, DVCFileSystem):
+            lfs_prefetch(
+                fs,
+                [
+                    f"{fs.normpath(glob.escape(fs_path))}/**"
+                    if fs.isdir(fs_path)
+                    else glob.escape(fs_path)
+                ],
+            )
+            if not glob.has_magic(fs_path):
+                return fs._get(fs_path, to, batch_size=jobs, callback=cb)
+
         # NOTE: We use dvc-objects generic.copy over fs.get since it makes file
         # download atomic and avoids fsspec glob/regex path expansion.
         if fs.isdir(fs_path):
@@ -69,15 +81,6 @@ def download(
             from_infos = [fs_path]
             to_infos = [to]
 
-        if isinstance(fs, DVCFileSystem):
-            lfs_prefetch(
-                fs,
-                [
-                    f"{fs.normpath(glob.escape(fs_path))}/**"
-                    if fs.isdir(fs_path)
-                    else glob.escape(fs_path)
-                ],
-            )
         cb.set_size(len(from_infos))
         jobs = jobs or fs.jobs
         generic.copy(fs, from_infos, localfs, to_infos, callback=cb, batch_size=jobs)
diff --git a/dvc/fs/dvc.py b/dvc/fs/dvc.py
index 19e4c04654..f7b9ba0d47 100644
--- a/dvc/fs/dvc.py
+++ b/dvc/fs/dvc.py
@@ -6,13 +6,15 @@
 import threading
 from collections import deque
 from contextlib import ExitStack, suppress
+from glob import has_magic
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
-from fsspec.spec import AbstractFileSystem
+from fsspec.spec import DEFAULT_CALLBACK, AbstractFileSystem
 from funcy import wrap_with
 
 from dvc.log import logger
-from dvc_objects.fs.base import FileSystem
+from dvc.utils.threadpool import ThreadPoolExecutor
+from dvc_objects.fs.base import AnyFSPath, FileSystem
 
 from .data import DataFileSystem
 
@@ -20,6 +22,8 @@
     from dvc.repo import Repo
     from dvc.types import DictStrAny, StrPath
 
+    from .callbacks import Callback
+
 logger = logger.getChild(__name__)
 
 RepoFactory = Union[Callable[..., "Repo"], type["Repo"]]
@@ -474,9 +478,105 @@ def _info(  # noqa: C901
                 info["name"] = path
         return info
 
+    def get(
+        self,
+        rpath,
+        lpath,
+        recursive=False,
+        callback=DEFAULT_CALLBACK,
+        maxdepth=None,
+        batch_size=None,
+        **kwargs,
+    ):
+        self._get(
+            rpath,
+            lpath,
+            recursive=recursive,
+            callback=callback,
+            maxdepth=maxdepth,
+            batch_size=batch_size,
+            **kwargs,
+        )
+
+    def _get(  # noqa: C901
+        self,
+        rpath,
+        lpath,
+        recursive=False,
+        callback=DEFAULT_CALLBACK,
+        maxdepth=None,
+        batch_size=None,
+        **kwargs,
+    ) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
+        if (
+            isinstance(rpath, list)
+            or isinstance(lpath, list)
+            or has_magic(rpath)
+            or not self.exists(rpath)
+            or not recursive
+        ):
+            super().get(
+                rpath,
+                lpath,
+                recursive=recursive,
+                callback=callback,
+                maxdepth=maxdepth,
+                **kwargs,
+            )
+            return []
+
+        if os.path.isdir(lpath) or lpath.endswith(os.path.sep):
+            lpath = self.join(lpath, os.path.basename(rpath))
+
+        if self.isfile(rpath):
+            os.makedirs(os.path.dirname(lpath), exist_ok=True)
+            with callback.branched(rpath, lpath) as child:
+                self.get_file(rpath, lpath, callback=child, **kwargs)
+            return [(rpath, lpath)]
+
+        _files = []
+        _dirs: list[str] = []
+        for root, dirs, files in self.walk(rpath, maxdepth=maxdepth, detail=True):
+            if files:
+                callback.set_size((callback.size or 0) + len(files))
+
+            parts = self.relparts(root, rpath)
+            if parts in ((os.curdir,), ("",)):
+                parts = ()
+            dest_root = os.path.join(lpath, *parts)
+            _dirs.extend(f"{dest_root}{os.path.sep}{d}" for d in dirs)
+
+            key = self._get_key_from_relative(root)
+            _, dvc_fs, _ = self._get_subrepo_info(key)
+
+            for name, info in files.items():
+                src_path = f"{root}{self.sep}{name}"
+                dest_path = f"{dest_root}{os.path.sep}{name}"
+                _files.append((dvc_fs, src_path, dest_path, info))
+
+        os.makedirs(lpath, exist_ok=True)
+        for d in _dirs:
+            os.mkdir(d)
+
+        def _get_file(arg):
+            dvc_fs, src, dest, info = arg
+            dvc_info = info.get("dvc_info")
+            if dvc_info and dvc_fs:
+                dvc_path = dvc_info["name"]
+                dvc_fs.get_file(
+                    dvc_path, dest, callback=callback, info=dvc_info, **kwargs
+                )
+            else:
+                self.get_file(src, dest, callback=callback, **kwargs)
+            return src, dest, info
+
+        with ThreadPoolExecutor(max_workers=batch_size) as executor:
+            return list(executor.imap_unordered(_get_file, _files))
+
     def get_file(self, rpath, lpath, **kwargs):
         key = self._get_key_from_relative(rpath)
         fs_path = self._from_key(key)
+
         try:
             return self.repo.fs.get_file(fs_path, lpath, **kwargs)
         except FileNotFoundError:
@@ -553,6 +653,45 @@ def immutable(self):
     def getcwd(self):
         return self.fs.getcwd()
 
+    def _get(
+        self,
+        from_info: Union[AnyFSPath, list[AnyFSPath]],
+        to_info: Union[AnyFSPath, list[AnyFSPath]],
+        callback: "Callback" = DEFAULT_CALLBACK,
+        recursive: bool = False,
+        batch_size: Optional[int] = None,
+        **kwargs,
+    ) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
+        # FileSystem.get is non-recursive by default if arguments are lists;
+        # otherwise, it's recursive.
+        recursive = not (isinstance(from_info, list) and isinstance(to_info, list))
+        return self.fs._get(
+            from_info,
+            to_info,
+            callback=callback,
+            recursive=recursive,
+            batch_size=batch_size,
+            **kwargs,
+        )
+
+    def get(
+        self,
+        from_info: Union[AnyFSPath, list[AnyFSPath]],
+        to_info: Union[AnyFSPath, list[AnyFSPath]],
+        callback: "Callback" = DEFAULT_CALLBACK,
+        recursive: bool = False,
+        batch_size: Optional[int] = None,
+        **kwargs,
+    ) -> None:
+        self._get(
+            from_info,
+            to_info,
+            callback=callback,
+            batch_size=batch_size,
+            recursive=recursive,
+            **kwargs,
+        )
+
     @property
     def fsid(self) -> str:
         return self.fs.fsid
diff --git a/tests/func/test_import.py b/tests/func/test_import.py
index df78e7a699..47fdd66ccc 100644
--- a/tests/func/test_import.py
+++ b/tests/func/test_import.py
@@ -13,7 +13,7 @@
 from dvc.testing.tmp_dir import make_subrepo
 from dvc.utils.fs import remove
 from dvc_data.hashfile import hash
-from dvc_data.index.index import DataIndexDirError
+from dvc_data.index.index import DataIndex, DataIndexDirError
 
 
 def test_import(tmp_dir, scm, dvc, erepo_dir):
@@ -725,12 +725,41 @@ def test_import_invalid_configs(tmp_dir, scm, dvc, erepo_dir):
     )
 
 
-def test_import_no_hash(tmp_dir, scm, dvc, erepo_dir, mocker):
+@pytest.mark.parametrize(
+    "files,expected_info_calls",
+    [
+        ({"foo": "foo"}, {("foo",)}),
+        (
+            {
+                "dir": {
+                    "bar": "bar",
+                    "subdir": {"lorem": "ipsum", "nested": {"lorem": "lorem"}},
+                }
+            },
+            # info calls should be made only for directories
+            {("dir",), ("dir", "subdir"), ("dir", "subdir", "nested")},
+        ),
+    ],
+)
+def test_import_no_hash(
+    tmp_dir, scm, dvc, erepo_dir, mocker, files, expected_info_calls
+):
     with erepo_dir.chdir():
-        erepo_dir.dvc_gen("foo", "foo content", commit="create foo")
-
-    spy = mocker.spy(hash, "file_md5")
-    stage = dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported")
-    assert spy.call_count == 1
-    for call in spy.call_args_list:
-        assert stage.outs[0].fs_path != call.args[0]
+        erepo_dir.dvc_gen(files, commit="create foo")
+
+    file_md5_spy = mocker.spy(hash, "file_md5")
+    index_info_spy = mocker.spy(DataIndex, "info")
+    name = next(iter(files))
+
+    dvc.imp(os.fspath(erepo_dir), name, "out")
+
+    local_hashes = [
+        call.args[0]
+        for call in file_md5_spy.call_args_list
+        if call.args[1].protocol == "local"
+    ]
+    # no files should be hashed; existing metadata should be reused
+    assert not local_hashes
+    assert {
+        call.args[1] for call in index_info_spy.call_args_list
+    } == expected_info_calls
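
Testing sketch (not part of the patch): a minimal way to exercise the new fast
path end to end. The repository URL, target paths, and batch size below are
illustrative assumptions; `DVCFileSystem` is the public re-export from
`dvc.api`, and `get()`/`batch_size` follow the signature added above.

    from dvc.api import DVCFileSystem

    # Illustrative URL/rev; any DVC repository with a tracked directory works.
    fs = DVCFileSystem("https://github.com/iterative/example-get-started", rev="main")

    # A recursive directory download takes the fast path added in _get():
    # a single walk(rpath, detail=True) yields every file's info (including
    # its dvc_info) along with the listing, so no per-file index.info()
    # calls are needed, and files are fetched through a ThreadPoolExecutor
    # bounded by batch_size. List, glob, and non-recursive arguments still
    # fall back to fsspec's generic AbstractFileSystem.get().
    fs.get("data", "downloaded-data", recursive=True, batch_size=4)

Note that the public `get()` now returns None; the internal `_get()` returns
`(src, dest, info)` tuples, which `download()` forwards so that
`RepoDependency.download()` can reuse the already-fetched hash info instead
of calling `fs.info()` on every downloaded file. This is also why the test
above asserts that `DataIndex.info` is called only for directories.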