Skip to content

Commit 710f210

Browse files
committed
remote: support legacy cache push/fetch
1 parent 86b475a commit 710f210

File tree

3 files changed

+160
-42
lines changed

3 files changed

+160
-42
lines changed

dvc/data_cloud.py

Lines changed: 112 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
"""Manages dvc remotes that user can use with push/pull/status commands."""
22

33
import logging
4-
from typing import TYPE_CHECKING, Iterable, Optional
4+
from typing import TYPE_CHECKING, Iterable, Optional, Set, Tuple
55

66
from dvc.config import NoRemoteError, RemoteConfigError
77
from dvc.utils.objects import cached_property
88
from dvc_data.hashfile.db import get_index
9+
from dvc_data.hashfile.transfer import TransferResult
910

1011
if TYPE_CHECKING:
1112
from dvc.fs import FileSystem
1213
from dvc_data.hashfile.db import HashFileDB
1314
from dvc_data.hashfile.hash_info import HashInfo
1415
from dvc_data.hashfile.status import CompareStatusResult
15-
from dvc_data.hashfile.transfer import TransferResult
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -50,6 +50,21 @@ def legacy_odb(self) -> "HashFileDB":
5050
return get_odb(self.fs, path, hash_name="md5-dos2unix", **self.config)
5151

5252

53+
def _split_legacy_hash_infos(
54+
hash_infos: Iterable["HashInfo"],
55+
) -> Tuple[Set["HashInfo"], Set["HashInfo"]]:
56+
from dvc.cachemgr import LEGACY_HASH_NAMES
57+
58+
legacy = set()
59+
default = set()
60+
for hi in hash_infos:
61+
if hi.name in LEGACY_HASH_NAMES:
62+
legacy.add(hi)
63+
else:
64+
default.add(hi)
65+
return legacy, default
66+
67+
5368
class DataCloud:
5469
"""Class that manages dvc remotes.
5570
@@ -167,14 +182,40 @@ def push(
167182
By default remote from core.remote config option is used.
168183
odb: optional ODB to push to. Overrides remote.
169184
"""
170-
odb = odb or self.get_remote_odb(remote, "push")
185+
if odb is not None:
186+
return self._push(objs, jobs=jobs, odb=odb)
187+
legacy_objs, default_objs = _split_legacy_hash_infos(objs)
188+
result = TransferResult(set(), set())
189+
if legacy_objs:
190+
odb = self.get_remote_odb(remote, "push", hash_name="md5-dos2unix")
191+
t, f = self._push(legacy_objs, jobs=jobs, odb=odb)
192+
result.transferred.update(t)
193+
result.failed.update(f)
194+
if default_objs:
195+
odb = self.get_remote_odb(remote, "push")
196+
t, f = self._push(default_objs, jobs=jobs, odb=odb)
197+
result.transferred.update(t)
198+
result.failed.update(f)
199+
return result
200+
201+
def _push(
202+
self,
203+
objs: Iterable["HashInfo"],
204+
*,
205+
jobs: Optional[int] = None,
206+
odb: "HashFileDB",
207+
) -> "TransferResult":
208+
if odb.hash_name == "md5-dos2unix":
209+
cache = self.repo.cache.legacy
210+
else:
211+
cache = self.repo.cache.local
171212
return self.transfer(
172-
self.repo.cache.local,
213+
cache,
173214
odb,
174215
objs,
175216
jobs=jobs,
176217
dest_index=get_index(odb),
177-
cache_odb=self.repo.cache.local,
218+
cache_odb=cache,
178219
validate_status=self._log_missing,
179220
)
180221

@@ -194,14 +235,41 @@ def pull(
194235
By default remote from core.remote config option is used.
195236
odb: optional ODB to pull from. Overrides remote.
196237
"""
197-
odb = odb or self.get_remote_odb(remote, "pull")
238+
if odb is not None:
239+
return self._pull(objs, jobs=jobs, odb=odb)
240+
legacy_objs, default_objs = _split_legacy_hash_infos(objs)
241+
result = TransferResult(set(), set())
242+
if legacy_objs:
243+
odb = self.get_remote_odb(remote, "pull", hash_name="md5-dos2unix")
244+
assert odb.hash_name == "md5-dos2unix"
245+
t, f = self._pull(legacy_objs, jobs=jobs, odb=odb)
246+
result.transferred.update(t)
247+
result.failed.update(f)
248+
if default_objs:
249+
odb = self.get_remote_odb(remote, "pull")
250+
t, f = self._pull(default_objs, jobs=jobs, odb=odb)
251+
result.transferred.update(t)
252+
result.failed.update(f)
253+
return result
254+
255+
def _pull(
256+
self,
257+
objs: Iterable["HashInfo"],
258+
*,
259+
jobs: Optional[int] = None,
260+
odb: "HashFileDB",
261+
) -> "TransferResult":
262+
if odb.hash_name == "md5-dos2unix":
263+
cache = self.repo.cache.legacy
264+
else:
265+
cache = self.repo.cache.local
198266
return self.transfer(
199267
odb,
200-
self.repo.cache.local,
268+
cache,
201269
objs,
202270
jobs=jobs,
203271
src_index=get_index(odb),
204-
cache_odb=self.repo.cache.local,
272+
cache_odb=cache,
205273
verify=odb.verify,
206274
validate_status=self._log_missing,
207275
)
@@ -223,17 +291,49 @@ def status(
223291
is used.
224292
odb: optional ODB to check status from. Overrides remote.
225293
"""
294+
from dvc_data.hashfile.status import CompareStatusResult
295+
296+
if odb is not None:
297+
return self._status(objs, jobs=jobs, odb=odb)
298+
result = CompareStatusResult(set(), set(), set(), set())
299+
legacy_objs, default_objs = _split_legacy_hash_infos(objs)
300+
if legacy_objs:
301+
odb = self.get_remote_odb(remote, "status", hash_name="md5-dos2unix")
302+
assert odb.hash_name == "md5-dos2unix"
303+
o, m, n, d = self._status(legacy_objs, jobs=jobs, odb=odb)
304+
result.ok.update(o)
305+
result.missing.update(m)
306+
result.new.update(n)
307+
result.deleted.update(d)
308+
if default_objs:
309+
odb = self.get_remote_odb(remote, "status")
310+
o, m, n, d = self._status(default_objs, jobs=jobs, odb=odb)
311+
result.ok.update(o)
312+
result.missing.update(m)
313+
result.new.update(n)
314+
result.deleted.update(d)
315+
return result
316+
317+
def _status(
318+
self,
319+
objs: Iterable["HashInfo"],
320+
*,
321+
jobs: Optional[int] = None,
322+
odb: "HashFileDB",
323+
):
226324
from dvc_data.hashfile.status import compare_status
227325

228-
if not odb:
229-
odb = self.get_remote_odb(remote, "status")
326+
if odb.hash_name == "md5-dos2unix":
327+
cache = self.repo.cache.legacy
328+
else:
329+
cache = self.repo.cache.local
230330
return compare_status(
231-
self.repo.cache.local,
331+
cache,
232332
odb,
233333
objs,
234334
jobs=jobs,
235335
dest_index=get_index(odb),
236-
cache_odb=self.repo.cache.local,
336+
cache_odb=cache,
237337
)
238338

239339
def get_url_for(self, remote, checksum):

dvc/repo/imports.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import os
3+
from functools import partial
34
from tempfile import TemporaryDirectory
45
from typing import TYPE_CHECKING, List, Set, Tuple, Union
56

@@ -17,17 +18,18 @@
1718

1819
def unfetched_view(
1920
index: "Index", targets: "TargetType", unpartial: bool = False, **kwargs
20-
) -> Tuple["IndexView", List["Dependency"]]:
21+
) -> Tuple["IndexView", "IndexView", List["Dependency"]]:
2122
"""Return index view of imports which have not been fetched.
2223
2324
Returns:
24-
Tuple in the form (view, changed_deps) where changed_imports is a list
25-
of import dependencies that cannot be fetched due to changed data
26-
source.
25+
Tuple in the form (legacy_view, view, changed_deps) where changed_deps is a
26+
list of import dependencies that cannot be fetched due to changed data source.
2727
"""
28+
from dvc.cachemgr import LEGACY_HASH_NAMES
29+
2830
changed_deps: List["Dependency"] = []
2931

30-
def need_fetch(stage: "Stage") -> bool:
32+
def need_fetch(stage: "Stage", legacy: bool = False) -> bool:
3133
if not stage.is_import or (stage.is_partial_import and not unpartial):
3234
return False
3335

@@ -40,10 +42,19 @@ def need_fetch(stage: "Stage") -> bool:
4042
changed_deps.append(dep)
4143
return False
4244

43-
return True
45+
if out.hash_name in LEGACY_HASH_NAMES and legacy:
46+
return True
47+
if out.hash_name not in LEGACY_HASH_NAMES and not legacy:
48+
return True
49+
return False
4450

51+
legacy_unfetched = index.targets_view(
52+
targets,
53+
stage_filter=partial(need_fetch, legacy=True),
54+
**kwargs,
55+
)
4556
unfetched = index.targets_view(targets, stage_filter=need_fetch, **kwargs)
46-
return unfetched, changed_deps
57+
return legacy_unfetched, unfetched, changed_deps
4758

4859

4960
def partial_view(index: "Index", targets: "TargetType", **kwargs) -> "IndexView":
@@ -94,33 +105,36 @@ def save_imports(
94105

95106
downloaded: Set["HashInfo"] = set()
96107

97-
unfetched, changed = unfetched_view(
108+
legacy_unfetched, unfetched, changed = unfetched_view(
98109
repo.index, targets, unpartial=unpartial, **kwargs
99110
)
100111
for dep in changed:
101112
logger.warning(str(DataSourceChanged(f"{dep.stage} ({dep})")))
102113

103-
data_view = unfetched.data["repo"]
104-
if len(data_view):
105-
cache = repo.cache.local
106-
if not cache.fs.exists(cache.path):
107-
os.makedirs(cache.path)
108-
with TemporaryDirectory(dir=cache.path) as tmpdir:
109-
with Callback.as_tqdm_callback(
110-
desc="Downloading imports from source",
111-
unit="files",
112-
) as cb:
113-
checkout(data_view, tmpdir, cache.fs, callback=cb, storage="remote")
114-
md5(data_view)
115-
save(data_view, odb=cache, hardlink=True)
116-
117-
downloaded.update(
118-
entry.hash_info
119-
for _, entry in data_view.iteritems()
120-
if entry.meta is not None
121-
and not entry.meta.isdir
122-
and entry.hash_info is not None
123-
)
114+
for view, cache in [
115+
(legacy_unfetched, repo.cache.legacy),
116+
(unfetched, repo.cache.local),
117+
]:
118+
data_view = view.data["repo"]
119+
if len(data_view):
120+
if not cache.fs.exists(cache.path):
121+
os.makedirs(cache.path)
122+
with TemporaryDirectory(dir=cache.path) as tmpdir:
123+
with Callback.as_tqdm_callback(
124+
desc="Downloading imports from source",
125+
unit="files",
126+
) as cb:
127+
checkout(data_view, tmpdir, cache.fs, callback=cb, storage="remote")
128+
md5(data_view, name=cache.hash_name)
129+
save(data_view, odb=cache, hardlink=True)
130+
131+
downloaded.update(
132+
entry.hash_info
133+
for _, entry in data_view.iteritems()
134+
if entry.meta is not None
135+
and not entry.meta.isdir
136+
and entry.hash_info is not None
137+
)
124138

125139
if unpartial:
126140
unpartial_imports(partial_view(repo.index, targets, **kwargs))

dvc/repo/index.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def _load_data_from_outs(index, prefix, outs):
150150

151151

152152
def _load_storage_from_out(storage_map, key, out):
153+
from dvc.cachemgr import LEGACY_HASH_NAMES
153154
from dvc.config import NoRemoteError
154155
from dvc_data.index import FileStorage, ObjectStorage
155156

@@ -168,7 +169,10 @@ def _load_storage_from_out(storage_map, key, out):
168169
)
169170
)
170171
else:
171-
storage_map.add_remote(ObjectStorage(key, remote.odb, index=remote.index))
172+
odb = (
173+
remote.legacy_odb if out.hash_name in LEGACY_HASH_NAMES else remote.odb
174+
)
175+
storage_map.add_remote(ObjectStorage(key, odb, index=remote.index))
172176
except NoRemoteError:
173177
pass
174178

0 commit comments

Comments (0)