Skip to content

Commit fcf2e35

Browse files
authored
tree: get rid of explicit tree arguments (#4310)
* cache: don't make save() compute the hash — this is incorrect and leads to weird bugs like #4305.
* tree: get rid of explicit tree arguments.
1 parent dfdb24d commit fcf2e35

File tree

5 files changed

+46
-36
lines changed

5 files changed

+46
-36
lines changed

dvc/cache/base.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -276,8 +276,6 @@ def save(self, path_info, tree, hash_info, save_link=True, **kwargs):
276276
self.tree.scheme,
277277
)
278278

279-
if not hash_info:
280-
hash_info = self.tree.save_info(path_info, tree=tree, **kwargs)
281279
hash_ = hash_info[self.tree.PARAM_CHECKSUM]
282280
return self._save(path_info, tree, hash_, save_link, **kwargs)
283281

dvc/external_repo.py

Lines changed: 21 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -90,15 +90,26 @@ def use_cache(self, cache):
9090
"""Use the specified cache in place of default tmpdir cache for
9191
download operations.
9292
"""
93-
if hasattr(self, "cache"):
94-
save_cache = self.cache.local
95-
self.cache.local = cache
93+
has_cache = hasattr(self, "cache")
94+
95+
if has_cache:
96+
save_cache = self.cache.local # pylint: disable=E0203
97+
self.cache.local = cache # pylint: disable=E0203
98+
else:
99+
from collections import namedtuple
100+
101+
mock_cache = namedtuple("MockCache", ["local"])(local=cache)
102+
self.cache = mock_cache # pylint: disable=W0201
103+
96104
self._local_cache = cache
97105

98106
yield
99107

100-
if hasattr(self, "cache"):
108+
if has_cache:
101109
self.cache.local = save_cache
110+
else:
111+
del self.cache
112+
102113
self._local_cache = None
103114

104115
@cached_property
@@ -130,14 +141,17 @@ def download_update(result):
130141
for path in paths:
131142
if not self.repo_tree.exists(path):
132143
raise PathMissingError(path, self.url)
133-
save_info = self.local_cache.save(
144+
hash_info = self.repo_tree.save_info(
145+
path, download_callback=download_update
146+
)
147+
self.local_cache.save(
134148
path,
135149
self.repo_tree,
136-
None,
150+
hash_info,
137151
save_link=False,
138152
download_callback=download_update,
139153
)
140-
save_infos.append(save_info)
154+
save_infos.append(hash_info)
141155

142156
return sum(download_results), failed, save_infos
143157

dvc/repo/tree.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -245,6 +245,9 @@ class RepoTree(BaseTree): # pylint:disable=abstract-method
245245
Any kwargs will be passed to `DvcTree()`.
246246
"""
247247

248+
scheme = "local"
249+
PARAM_CHECKSUM = "md5"
250+
248251
def __init__(self, repo, **kwargs):
249252
super().__init__(repo, {"url": repo.root_dir})
250253
if hasattr(repo, "dvc_dir"):

dvc/tree/base.py

Lines changed: 18 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -236,22 +236,17 @@ def is_dir_hash(cls, hash_):
236236
return False
237237
return hash_.endswith(cls.CHECKSUM_DIR_SUFFIX)
238238

239-
def get_hash(self, path_info, tree=None, **kwargs):
239+
def get_hash(self, path_info, **kwargs):
240240
assert path_info and (
241241
isinstance(path_info, str) or path_info.scheme == self.scheme
242242
)
243243

244-
if not tree:
245-
tree = self
246-
247-
if not tree.exists(path_info):
244+
if not self.exists(path_info):
248245
return None
249246

250-
if tree == self:
251-
# pylint: disable=assignment-from-none
252-
hash_ = self.state.get(path_info)
253-
else:
254-
hash_ = None
247+
# pylint: disable=assignment-from-none
248+
hash_ = self.state.get(path_info)
249+
255250
# If we have dir hash in state db, but dir cache file is lost,
256251
# then we need to recollect the dir via .get_dir_hash() call below,
257252
# see https://github.com/iterative/dvc/issues/2219 for context
@@ -267,10 +262,10 @@ def get_hash(self, path_info, tree=None, **kwargs):
267262
if hash_:
268263
return hash_
269264

270-
if tree.isdir(path_info):
271-
hash_ = self.get_dir_hash(path_info, tree, **kwargs)
265+
if self.isdir(path_info):
266+
hash_ = self.get_dir_hash(path_info, **kwargs)
272267
else:
273-
hash_ = tree.get_file_hash(path_info)
268+
hash_ = self.get_file_hash(path_info)
274269

275270
if hash_ and self.exists(path_info):
276271
self.state.save(path_info, hash_)
@@ -280,11 +275,11 @@ def get_hash(self, path_info, tree=None, **kwargs):
280275
def get_file_hash(self, path_info):
281276
raise NotImplementedError
282277

283-
def get_dir_hash(self, path_info, tree, **kwargs):
278+
def get_dir_hash(self, path_info, **kwargs):
284279
if not self.cache:
285280
raise RemoteCacheRequiredError(path_info)
286281

287-
dir_info = self._collect_dir(path_info, tree, **kwargs)
282+
dir_info = self._collect_dir(path_info, **kwargs)
288283
return self._save_dir_info(dir_info, path_info)
289284

290285
def hash_to_path_info(self, hash_):
@@ -298,29 +293,26 @@ def path_to_hash(self, path):
298293

299294
return "".join(parts)
300295

301-
def save_info(self, path_info, tree=None, **kwargs):
302-
return {
303-
self.PARAM_CHECKSUM: self.get_hash(path_info, tree=tree, **kwargs)
304-
}
296+
def save_info(self, path_info, **kwargs):
297+
return {self.PARAM_CHECKSUM: self.get_hash(path_info, **kwargs)}
305298

306-
@staticmethod
307-
def _calculate_hashes(file_infos, tree):
299+
def _calculate_hashes(self, file_infos):
308300
file_infos = list(file_infos)
309301
with Tqdm(
310302
total=len(file_infos),
311303
unit="md5",
312304
desc="Computing file/dir hashes (only done once)",
313305
) as pbar:
314-
worker = pbar.wrap_fn(tree.get_file_hash)
315-
with ThreadPoolExecutor(max_workers=tree.hash_jobs) as executor:
306+
worker = pbar.wrap_fn(self.get_file_hash)
307+
with ThreadPoolExecutor(max_workers=self.hash_jobs) as executor:
316308
tasks = executor.map(worker, file_infos)
317309
hashes = dict(zip(file_infos, tasks))
318310
return hashes
319311

320-
def _collect_dir(self, path_info, tree, **kwargs):
312+
def _collect_dir(self, path_info, **kwargs):
321313
file_infos = set()
322314

323-
for fname in tree.walk_files(path_info, **kwargs):
315+
for fname in self.walk_files(path_info, **kwargs):
324316
if DvcIgnore.DVCIGNORE_FILE == fname.name:
325317
raise DvcIgnoreInCollectedDirError(fname.parent)
326318

@@ -329,7 +321,7 @@ def _collect_dir(self, path_info, tree, **kwargs):
329321
hashes = {fi: self.state.get(fi) for fi in file_infos}
330322
not_in_state = {fi for fi, hash_ in hashes.items() if hash_ is None}
331323

332-
new_hashes = self._calculate_hashes(not_in_state, tree)
324+
new_hashes = self._calculate_hashes(not_in_state)
333325
hashes.update(new_hashes)
334326

335327
result = [

tests/func/test_tree.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -217,7 +217,10 @@ def test_repotree_cache_save(tmp_dir, dvc, scm, erepo_dir, local_cloud):
217217
with erepo_dir.dvc.state:
218218
cache = dvc.cache.local
219219
with cache.tree.state:
220-
cache.save(PathInfo(erepo_dir / "dir"), tree, None)
220+
path_info = PathInfo(erepo_dir / "dir")
221+
hash_info = cache.tree.save_info(path_info)
222+
cache.save(path_info, tree, hash_info)
223+
221224
for hash_ in expected:
222225
assert os.path.exists(cache.tree.hash_to_path_info(hash_))
223226

0 commit comments

Comments (0)