Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save import hash info to state #10531

Merged
merged 4 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dvc/dependency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def update(self, rev=None):
self.fs_path = self.fs.version_path(self.fs_path, self.meta.version_id)

def download(self, to, jobs=None):
fs_download(self.fs, self.fs_path, to.fs_path, jobs=jobs)
return fs_download(self.fs, self.fs_path, to.fs_path, jobs=jobs)

def save(self):
super().save()
Expand Down
44 changes: 20 additions & 24 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import voluptuous as vol

from dvc.prompt import confirm
from dvc.utils import as_posix

from .base import Dependency
Expand All @@ -12,6 +11,7 @@
from dvc.fs import DVCFileSystem
from dvc.output import Output
from dvc.stage import Stage
from dvc_data.hashfile.hash_info import HashInfo


class RepoDependency(Dependency):
Expand Down Expand Up @@ -94,29 +94,25 @@ def dumpd(self, **kwargs) -> dict[str, Union[str, dict[str, str]]]:
}

def download(self, to: "Output", jobs: Optional[int] = None):
from dvc_data.hashfile.build import build
from dvc_data.hashfile.checkout import CheckoutError, checkout

try:
repo = self._make_fs(locked=True).repo

_, _, obj = build(
repo.cache.local,
self.fs_path,
repo.dvcfs,
repo.cache.local.fs.PARAM_CHECKSUM,
)
checkout(
to.fs_path,
to.fs,
obj,
self.repo.cache.local,
ignore=None,
state=self.repo.state,
prompt=confirm,
)
except (CheckoutError, FileNotFoundError):
super().download(to=to, jobs=jobs)
from dvc.fs import LocalFileSystem

files = super().download(to=to, jobs=jobs)
if not isinstance(to.fs, LocalFileSystem):
return files

hashes: list[tuple[str, HashInfo, dict[str, Any]]] = []
for src_path, dest_path in files:
try:
hash_info = self.fs.info(src_path)["dvc_info"]["entry"].hash_info
skshetry marked this conversation as resolved.
Show resolved Hide resolved
dest_info = to.fs.info(dest_path)
except (OSError, KeyError, AttributeError):
skshetry marked this conversation as resolved.
Show resolved Hide resolved
# If no hash info found, just keep going and output will be hashed later
continue
if hash_info:
hashes.append((dest_path, hash_info, dest_info))
cache = to.cache if to.use_cache else to.local_cache
cache.state.save_many(hashes, to.fs)
return files

def update(self, rev: Optional[str] = None):
if rev:
Expand Down
6 changes: 3 additions & 3 deletions dvc/fs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

def download(
fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None
) -> int:
) -> list[tuple[str, str]]:
from dvc.scm import lfs_prefetch

from .callbacks import TqdmCallback
Expand All @@ -61,7 +61,7 @@ def download(
]
if not from_infos:
localfs.makedirs(to, exist_ok=True)
return 0
return []
to_infos = [
localfs.join(to, *fs.relparts(info, fs_path)) for info in from_infos
]
Expand All @@ -81,7 +81,7 @@ def download(
cb.set_size(len(from_infos))
jobs = jobs or fs.jobs
generic.copy(fs, from_infos, localfs, to_infos, callback=cb, batch_size=jobs)
return len(to_infos)
return list(zip(from_infos, to_infos))


def parse_external_url(url, fs_config=None, config=None):
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def download(

out = resolve_output(path, out, force=force)
fs = self.repo.dvcfs
count = fs_download(fs, path, os.path.abspath(out), jobs=jobs)
count = len(fs_download(fs, path, os.path.abspath(out), jobs=jobs))
return count, out

@staticmethod
Expand Down
16 changes: 7 additions & 9 deletions tests/func/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

from dvc.cachemgr import CacheManager
from dvc.config import NoRemoteError
from dvc.dependency import base
from dvc.dvcfile import load_file
from dvc.fs import system
from dvc.scm import Git
from dvc.stage.exceptions import StagePathNotFoundError
from dvc.testing.tmp_dir import make_subrepo
from dvc.utils.fs import remove
from dvc_data.hashfile import hash
from dvc_data.index.index import DataIndexDirError


Expand Down Expand Up @@ -725,14 +725,12 @@ def test_import_invalid_configs(tmp_dir, scm, dvc, erepo_dir):
)


def test_reimport(tmp_dir, scm, dvc, erepo_dir, mocker):
def test_import_no_hash(tmp_dir, scm, dvc, erepo_dir, mocker):
with erepo_dir.chdir():
erepo_dir.dvc_gen("foo", "foo content", commit="create foo")

spy = mocker.spy(base, "fs_download")
dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported")
assert spy.called

spy.reset_mock()
dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported", force=True)
assert not spy.called
spy = mocker.spy(hash, "file_md5")
stage = dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported")
assert spy.call_count == 1
for call in spy.call_args_list:
assert stage.outs[0].fs_path != call.args[0]
Loading