From a6701d7d1734f12bb20e8fc5d32711801d2af24c Mon Sep 17 00:00:00 2001
From: Ruslan Kuprieiev
Date: Thu, 20 Apr 2023 03:57:48 +0300
Subject: [PATCH] fetch: use index fetch

---
 dvc/repo/__init__.py                 |   8 ++
 dvc/repo/fetch.py                    | 170 +++++-----------------------
 dvc/repo/index.py                    |  13 ++
 dvc/repo/pull.py                     |   6 -
 dvc/repo/worktree.py                 |  26 ----
 pyproject.toml                       |   2 +-
 tests/func/test_data_cloud.py        |  12 +-
 tests/func/test_import.py            |   5 +-
 tests/func/test_import_url.py        |   3 +-
 tests/func/test_virtual_directory.py |   2 +-
 10 files changed, 62 insertions(+), 185 deletions(-)

diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py
index e51fe3febb..5d36b049a1 100644
--- a/dvc/repo/__init__.py
+++ b/dvc/repo/__init__.py
@@ -367,6 +367,14 @@ def data_index(self) -> "DataIndex":

         return self._data_index

+    def drop_data_index(self) -> None:
+        try:
+            self.data_index.delete_node(("tree",))
+        except KeyError:
+            pass
+        self.data_index.commit()
+        self._reset()
+
     def __repr__(self):
         return f"{self.__class__.__name__}: '{self.root_dir}'"

diff --git a/dvc/repo/fetch.py b/dvc/repo/fetch.py
index 75acadb2ea..77d562bcbc 100644
--- a/dvc/repo/fetch.py
+++ b/dvc/repo/fetch.py
@@ -1,20 +1,9 @@
 import logging
-from contextlib import suppress
-from typing import TYPE_CHECKING, Optional, Sequence

-from dvc.config import NoRemoteError
 from dvc.exceptions import DownloadError
-from dvc.fs import Schemes

 from . import locked

-if TYPE_CHECKING:
-    from dvc.data_cloud import Remote
-    from dvc.repo import Repo
-    from dvc.types import TargetType
-    from dvc_data.hashfile.db import HashFileDB
-    from dvc_data.hashfile.transfer import TransferResult
-
 logger = logging.getLogger(__name__)


@@ -31,7 +20,6 @@ def fetch(  # noqa: C901, PLR0913
     all_commits=False,
     run_cache=False,
     revs=None,
-    odb: Optional["HashFileDB"] = None,
 ) -> int:
     """Download data items from a cloud and imported repositories

@@ -45,18 +33,11 @@ def fetch(  # noqa: C901, PLR0913
         config.NoRemoteError: thrown when downloading only local files and no
             remote is configured
     """
-    from dvc.repo.imports import save_imports
-    from dvc_data.hashfile.transfer import TransferResult
+    from dvc_data.index.fetch import fetch as ifetch

     if isinstance(targets, str):
         targets = [targets]

-    worktree_remote: Optional["Remote"] = None
-    with suppress(NoRemoteError):
-        _remote = self.cloud.get_remote(name=remote)
-        if _remote.worktree or _remote.fs.version_aware:
-            worktree_remote = _remote
-
     failed_count = 0
     transferred_count = 0

@@ -66,133 +47,38 @@ def fetch(  # noqa: C901, PLR0913
     except DownloadError as exc:
         failed_count += exc.amount

-    no_remote_msg: Optional[str] = None
-    result = TransferResult(set(), set())
-    try:
-        if worktree_remote is not None:
-            transferred_count += _fetch_worktree(
-                self,
-                worktree_remote,
-                revs=revs,
-                all_branches=all_branches,
-                all_tags=all_tags,
-                all_commits=all_commits,
-                targets=targets,
-                jobs=jobs,
-                with_deps=with_deps,
-                recursive=recursive,
-            )
-        else:
-            d, f = _fetch(
-                self,
+    def _indexes():
+        for _ in self.brancher(
+            revs=revs,
+            all_branches=all_branches,
+            all_tags=all_tags,
+            all_commits=all_commits,
+        ):
+            yield self.index.targets_view(
                 targets,
-                all_branches=all_branches,
-                all_tags=all_tags,
-                all_commits=all_commits,
                 with_deps=with_deps,
-                force=True,
-                remote=remote,
-                jobs=jobs,
                 recursive=recursive,
-                revs=revs,
-                odb=odb,
-            )
-            result.transferred.update(d)
-            result.failed.update(f)
-    except NoRemoteError as exc:
-        no_remote_msg = str(exc)
-
-    for rev in self.brancher(
-        revs=revs,
-        all_branches=all_branches,
-        all_tags=all_tags,
-        all_commits=all_commits,
-    ):
-        imported = save_imports(
-            self,
-            targets,
-            unpartial=not rev or rev == "workspace",
-            recursive=recursive,
-        )
-        result.transferred.update(imported)
-        result.failed.difference_update(imported)
-
-    failed_count += len(result.failed)
+            ).data["repo"]
+    saved_remote = self.config["core"].get("remote")
+    try:
+        if remote:
+            self.config["core"]["remote"] = remote
+
+        fetch_transferred, fetch_failed = ifetch(
+            _indexes(), jobs=jobs
+        )  # pylint: disable=assignment-from-no-return
+    finally:
+        if remote:
+            self.config["core"]["remote"] = saved_remote
+
+    if fetch_transferred:
+        # NOTE: dropping cached index to force reloading from newly saved cache
+        self.drop_data_index()
+
+    transferred_count += fetch_transferred
+    failed_count += fetch_failed

     if failed_count:
-        if no_remote_msg:
-            logger.error(no_remote_msg)
         raise DownloadError(failed_count)

-    transferred_count += len(result.transferred)
     return transferred_count
-
-
-def _fetch(
-    repo: "Repo",
-    targets: "TargetType",
-    remote: Optional[str] = None,
-    jobs: Optional[int] = None,
-    odb: Optional["HashFileDB"] = None,
-    **kwargs,
-) -> "TransferResult":
-    from dvc_data.hashfile.transfer import TransferResult
-
-    result = TransferResult(set(), set())
-    used = repo.used_objs(
-        targets,
-        remote=remote,
-        jobs=jobs,
-        **kwargs,
-    )
-    if odb:
-        all_ids = set()
-        for _odb, obj_ids in used.items():
-            all_ids.update(obj_ids)
-        d, f = repo.cloud.pull(
-            all_ids,
-            jobs=jobs,
-            remote=remote,
-            odb=odb,
-        )
-        result.transferred.update(d)
-        result.failed.update(f)
-    else:
-        for src_odb, obj_ids in sorted(
-            used.items(),
-            key=lambda item: item[0] is not None
-            and item[0].fs.protocol == Schemes.MEMORY,
-        ):
-            d, f = repo.cloud.pull(
-                obj_ids,
-                jobs=jobs,
-                remote=remote,
-                odb=src_odb,
-            )
-            result.transferred.update(d)
-            result.failed.update(f)
-    return result
-
-
-def _fetch_worktree(
-    repo: "Repo",
-    remote: "Remote",
-    revs: Optional[Sequence[str]] = None,
-    all_branches: bool = False,
-    all_tags: bool = False,
-    all_commits: bool = False,
-    targets: Optional["TargetType"] = None,
-    jobs: Optional[int] = None,
-    **kwargs,
-) -> int:
-    from dvc.repo.worktree import fetch_worktree
-
-    downloaded = 0
-    for _ in repo.brancher(
-        revs=revs,
-        all_branches=all_branches,
-        all_tags=all_tags,
-        all_commits=all_commits,
-    ):
-        downloaded += fetch_worktree(repo, remote, targets=targets, jobs=jobs, **kwargs)
-    return downloaded
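The fetch.py rewrite above is the core of the patch: instead of computing used objects per-ODB and pulling them, fetch() now builds one data-index view per revision and hands the whole batch to dvc-data's index-based fetch, temporarily swapping core.remote in the config rather than threading remote= through helpers. A minimal sketch of the same flow from outside the Repo method (assumes dvc-data>=1.1.0 and a repo with a default remote configured; locking and error handling omitted):

# Sketch of the index-based fetch flow introduced above. brancher()
# switches the underlying fs per revision, so repo.index is rebuilt on
# each iteration; only the workspace is walked in this sketch.
from dvc.repo import Repo
from dvc_data.index.fetch import fetch as ifetch

repo = Repo(".")

def indexes():
    for _ in repo.brancher(revs=None):
        yield repo.index.targets_view(None).data["repo"]

transferred, failed = ifetch(indexes(), jobs=4)
print(f"transferred={transferred}, failed={failed}")

When anything was transferred, drop_data_index() (added in __init__.py above) deletes the cached ("tree",) node so the next read reloads from the freshly written cache.
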
diff --git a/dvc/repo/index.py b/dvc/repo/index.py
index 5a1b12138d..fc55dc33a3 100644
--- a/dvc/repo/index.py
+++ b/dvc/repo/index.py
@@ -178,6 +178,19 @@ def _load_storage_from_out(storage_map, key, out):

     if out.stage.is_import:
         dep = out.stage.deps[0]
+        if not out.hash_info:
+            from fsspec.utils import tokenize
+
+            # partial import
+            storage_map.add_cache(
+                FileStorage(
+                    key,
+                    out.cache.fs,
+                    out.cache.fs.path.join(
+                        out.cache.path, "fs", dep.fs.protocol, tokenize(dep.fs_path)
+                    ),
+                )
+            )
         storage_map.add_remote(FileStorage(key, dep.fs, dep.fs_path))

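For context on the new cache registration: a partial (no-download) import has no out.hash_info yet, so the hunk above keys its cache directory by the source path rather than by a content hash. A rough illustration of the path it builds; the values stand in for out.cache.path, dep.fs.protocol, and dep.fs_path, and only tokenize() is the real fsspec helper:

# Illustration (made-up values) of the partial-import cache location.
import os
from fsspec.utils import tokenize

cache_path = os.path.join(".dvc", "cache")   # out.cache.path (typical default)
protocol = "s3"                              # dep.fs.protocol
fs_path = "bucket/data/file.csv"             # dep.fs_path

# tokenize() hashes the path deterministically, so repeated fetches of
# the same source URL reuse a single cache directory.
print(os.path.join(cache_path, "fs", protocol, tokenize(fs_path)))
# -> .dvc/cache/fs/s3/<deterministic-hash>
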
diff --git a/dvc/repo/pull.py b/dvc/repo/pull.py
index 300847581e..dbb75261be 100644
--- a/dvc/repo/pull.py
+++ b/dvc/repo/pull.py
@@ -1,12 +1,8 @@
 import logging
-from typing import TYPE_CHECKING, Optional

 from dvc.repo import locked
 from dvc.utils import glob_targets

-if TYPE_CHECKING:
-    from dvc_objects.db import ObjectDB
-
 logger = logging.getLogger(__name__)


@@ -24,7 +20,6 @@ def pull(  # noqa: PLR0913
     all_commits=False,
     run_cache=False,
     glob=False,
-    odb: Optional["ObjectDB"] = None,
     allow_missing=False,
 ):
     if isinstance(targets, str):
@@ -42,7 +37,6 @@ def pull(  # noqa: PLR0913
         with_deps=with_deps,
         recursive=recursive,
         run_cache=run_cache,
-        odb=odb,
     )
     stats = self.checkout(
         targets=expanded_targets,
diff --git a/dvc/repo/worktree.py b/dvc/repo/worktree.py
index 45fc6dff30..3f86f2ee87 100644
--- a/dvc/repo/worktree.py
+++ b/dvc/repo/worktree.py
@@ -104,32 +104,6 @@ def _get_remote(
     return repo.cloud.get_remote(name, command)


-def fetch_worktree(
-    repo: "Repo",
-    remote: "Remote",
-    targets: Optional["TargetType"] = None,
-    jobs: Optional[int] = None,
-    **kwargs: Any,
-) -> int:
-    from dvc_data.index import save
-
-    transferred = 0
-    for remote_name, view in worktree_view_by_remotes(
-        repo.index, push=True, targets=targets, **kwargs
-    ):
-        remote_obj = _get_remote(repo, remote_name, remote, "fetch")
-        index = view.data["repo"]
-        total = len(index)
-        with Callback.as_tqdm_callback(
-            unit="file",
-            desc=f"Fetching from remote {remote_obj.name!r}",
-            disable=total == 0,
-        ) as cb:
-            cb.set_size(total)
-            transferred += save(index, callback=cb, jobs=jobs, storage="remote")
-    return transferred
-
-
 def push_worktree(
     repo: "Repo",
     remote: "Remote",
diff --git a/pyproject.toml b/pyproject.toml
index 3763dde1c0..a2370adc6c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,7 +34,7 @@ dependencies = [
     "configobj>=5.0.6",
     "distro>=1.3",
     "dpath<3,>=2.1.0",
-    "dvc-data>=1.0.3,<1.1.0",
+    "dvc-data>=1.1.0,<1.2.0",
     "dvc-http>=2.29.0",
     "dvc-render>=0.3.1,<1",
     "dvc-studio-client>=0.9.2,<1",
diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py
index 125a47dbde..070828e6e4 100644
--- a/tests/func/test_data_cloud.py
+++ b/tests/func/test_data_cloud.py
@@ -159,8 +159,8 @@ def test_missing_cache(tmp_dir, dvc, local_remote, caplog):
         "Some of the cache files do not exist "
         "neither locally nor on remote. Missing cache files:\n"
     )
-    foo = "name: bar, md5: 37b51d194a7513e45b56f6524f2d51f2\n"
-    bar = "name: foo, md5: acbd18db4cc2f85cedef654fccc4a4d8\n"

     caplog.clear()
     dvc.push()
@@ -198,7 +198,7 @@ def test_verify_hashes(tmp_dir, scm, dvc, mocker, tmp_path_factory, local_remote
     dvc.pull()

     # NOTE: 1 is for index.data_tree building
-    assert hash_spy.call_count == 1
+    assert hash_spy.call_count == 2

     # Removing cache will invalidate existing state entries
     dvc.cache.local.clear()
@@ -206,7 +206,7 @@ def test_verify_hashes(tmp_dir, scm, dvc, mocker, tmp_path_factory, local_remote
     dvc.config["remote"]["upstream"]["verify"] = True

     dvc.pull()
-    assert hash_spy.call_count == 6
+    assert hash_spy.call_count == 8


 @flaky(max_runs=3, min_passes=1)
@@ -268,7 +268,7 @@ def test_pull_partial_import(tmp_dir, dvc, local_workspace):
     stage = dvc.imp_url("remote://workspace/file", os.fspath(dst), no_download=True)

     result = dvc.pull("file")
-    assert result["fetched"] == 1
+    assert result["fetched"] == 0
     assert dst.exists()

     assert stage.outs[0].get_hash().value == "d10b4c3ff123b26dc068d43a8bef2d23"
@@ -483,7 +483,7 @@ def test_pull_partial(tmp_dir, dvc, local_remote):

     clean(["foo"], dvc)

     stats = dvc.pull(os.path.join("foo", "bar"))
-    assert stats["fetched"] == 1
+    assert stats["fetched"] == 3
     assert (tmp_dir / "foo").read_text() == {"bar": {"baz": "baz"}}
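The adjusted expectations above follow from the new accounting rather than a behavior regression: the index-based fetch counts each object it transfers, directory manifests (.dir files) included, so directory pulls report higher numbers, while a no-download partial import now reports fetched == 0 because its data is served through the source storage registered in index.py above instead of being copied during fetch. Roughly, for a tracked directory (hash values invented):

# Why directory pulls now count more than one object: the cache holds a
# .dir manifest plus one object per file, each counted when transferred.
dir_objects = [
    "1f69c66c5e687dd19d83d866a4df0c10.dir",  # manifest listing foo's contents
    "aec07f1c97a06b4249e109a3e7c5a04f",      # content of foo/bar/baz
]
print(len(dir_objects))

The exact pinned numbers (2 vs. 8 hash calls, fetched == 3) come out of dvc-data 1.1's bookkeeping; the tests freeze them to catch regressions.
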
+    foo = "md5: 37b51d194a7513e45b56f6524f2d51f2\n"
+    bar = "md5: acbd18db4cc2f85cedef654fccc4a4d8\n"
diff --git a/tests/func/test_import.py b/tests/func/test_import.py
index f67b9ceaeb..6a81dff9f5 100644
--- a/tests/func/test_import.py
+++ b/tests/func/test_import.py
@@ -241,12 +241,13 @@ def test_pull_import_no_download(tmp_dir, scm, dvc, erepo_dir):
     dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported", no_download=True)

     dvc.pull(["foo_imported.dvc"])
-    assert (tmp_dir / "foo_imported").exists
+    assert (tmp_dir / "foo_imported").exists()
     assert (tmp_dir / "foo_imported" / "bar").read_bytes() == b"bar"
     assert (tmp_dir / "foo_imported" / "baz").read_bytes() == b"baz contents"

-    stage = load_file(dvc, "foo_imported.dvc").stage
+    dvc.commit(force=True)

+    stage = load_file(dvc, "foo_imported.dvc").stage
     if os.name == "nt":
         expected_hash = "2e798234df5f782340ac3ce046f8dfae.dir"
     else:
diff --git a/tests/func/test_import_url.py b/tests/func/test_import_url.py
index cdace1c857..02eecbbfd1 100644
--- a/tests/func/test_import_url.py
+++ b/tests/func/test_import_url.py
@@ -232,8 +232,9 @@ def test_partial_import_pull(tmp_dir, scm, dvc, local_workspace):

     assert dst.exists()

-    stage = load_file(dvc, "file.dvc").stage
+    dvc.commit(force=True)

+    stage = load_file(dvc, "file.dvc").stage
     assert stage.outs[0].hash_info.value == "d10b4c3ff123b26dc068d43a8bef2d23"
     assert stage.outs[0].meta.size == 12

diff --git a/tests/func/test_virtual_directory.py b/tests/func/test_virtual_directory.py
index 6aec736448..b609e5718b 100644
--- a/tests/func/test_virtual_directory.py
+++ b/tests/func/test_virtual_directory.py
@@ -182,7 +182,7 @@ def test_partial_checkout_and_update(M, tmp_dir, dvc, remote):

     assert dvc.pull("dir/subdir") == M.dict(
         added=[join("dir", "")],
-        fetched=1,
+        fetched=3,
     )
     assert (tmp_dir / "dir").read_text() == {"subdir": {"lorem": "lorem"}}
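
The three test tweaks above share one pattern: after pulling a no-download import, the data and its hashes live in the workspace and cache, but the .dvc file is only updated on commit, so the tests now call dvc.commit(force=True) before reloading the stage and asserting on hashes. A sketch of that sequence (assumes a repo with a partial import tracked by file.dvc, as in test_partial_import_pull; load_file comes from dvc.dvcfile, as in the tests):

# Sequence the updated tests follow.
from dvc.dvcfile import load_file
from dvc.repo import Repo

repo = Repo(".")
repo.pull("file")        # materializes the partially imported data
repo.commit(force=True)  # writes the now-computed hash into file.dvc
stage = load_file(repo, "file.dvc").stage
print(stage.outs[0].hash_info.value)

Also note the drive-by fix in test_import.py: `assert (...).exists` was always truthy (a bound method object), so the assertion now actually calls exists().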