Skip to content

Commit 9ead641

Browse files
authored
remote: separate cloud remote and cloud cache classes (#4019)
* remote: move get_file_checksum() into tree
* makes RemoteTree and RepoTree consistent with regard to checksum calculation
* remote: save() now takes explicit tree parameter
* when tree is remote.tree, save will be a move + link operation (same as default existing behavior)
* when saving path from a different tree, save will be a copy operation
* tests: update for moved remote/tree functions
* remote: move get_checksum into tree
* remote: separate cloud remote and cache classes
* tests: update unit tests for remote/cache separation
* remote: cloud cache should extend cloud remote
* remote: move LocalRemote.get
* tests: update func tests
* dependency: update for moved get_checksum
* remote: fix state lookup bug
1 parent 652f5ab commit 9ead641

25 files changed

+886
-849
lines changed

dvc/cache.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ def _make_remote_property(name):
2828
"""
2929

3030
def getter(self):
31-
from dvc.remote import Remote
31+
from dvc.remote import Cache as CloudCache
3232

3333
remote = self.config.get(name)
3434
if not remote:
3535
return None
3636

37-
return Remote(self.repo, name=remote)
37+
return CloudCache(self.repo, name=remote)
3838

3939
getter.__name__ = name
4040
return cached_property(getter)
@@ -50,7 +50,7 @@ class Cache:
5050
CACHE_DIR = "cache"
5151

5252
def __init__(self, repo):
53-
from dvc.remote import Remote
53+
from dvc.remote import Cache as CloudCache
5454

5555
self.repo = repo
5656
self.config = config = repo.config["cache"]
@@ -62,7 +62,7 @@ def __init__(self, repo):
6262
else:
6363
settings = {**config, "url": config["dir"]}
6464

65-
self.local = Remote(repo, **settings)
65+
self.local = CloudCache(repo, **settings)
6666

6767
s3 = _make_remote_property("s3")
6868
gs = _make_remote_property("gs")

dvc/dependency/repo.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ def _get_checksum(self, locked=True):
6464

6565
# We are polluting our repo cache with some dir listing here
6666
if tree.isdir(path):
67-
return self.repo.cache.local.get_dir_checksum(
68-
path, tree=tree
69-
)
67+
return self.repo.cache.local.get_checksum(path, tree)
7068
return tree.get_file_checksum(path)
7169

7270
def status(self):

dvc/external_repo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,9 @@ def download_update(result):
126126
raise PathMissingError(path, self.url)
127127
save_info = self.local_cache.save(
128128
path,
129+
self.repo_tree,
129130
None,
130-
tree=self.repo_tree,
131+
save_link=False,
131132
download_callback=download_update,
132133
)
133134
save_infos.append(save_info)

dvc/output/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def save(self):
267267

268268
def commit(self):
269269
if self.use_cache:
270-
self.cache.save(self.path_info, self.info)
270+
self.cache.save(self.path_info, self.cache.tree, self.info)
271271

272272
def dumpd(self):
273273
ret = copy(self.info)

dvc/remote/__init__.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
import posixpath
22
from urllib.parse import urlparse
33

4-
from dvc.remote.azure import AzureRemote
4+
from dvc.remote.azure import AzureCache, AzureRemote
55
from dvc.remote.gdrive import GDriveRemote
6-
from dvc.remote.gs import GSRemote
7-
from dvc.remote.hdfs import HDFSRemote
6+
from dvc.remote.gs import GSCache, GSRemote
7+
from dvc.remote.hdfs import HDFSCache, HDFSRemote
88
from dvc.remote.http import HTTPRemote
99
from dvc.remote.https import HTTPSRemote
10-
from dvc.remote.local import LocalRemote
10+
from dvc.remote.local import LocalCache, LocalRemote
1111
from dvc.remote.oss import OSSRemote
12-
from dvc.remote.s3 import S3Remote
13-
from dvc.remote.ssh import SSHRemote
12+
from dvc.remote.s3 import S3Cache, S3Remote
13+
from dvc.remote.ssh import SSHCache, SSHRemote
14+
15+
CACHES = [
16+
AzureCache,
17+
GSCache,
18+
HDFSCache,
19+
S3Cache,
20+
SSHCache,
21+
# LocalCache is the default
22+
]
1423

1524
REMOTES = [
1625
AzureRemote,
@@ -26,21 +35,30 @@
2635
]
2736

2837

29-
def _get(remote_conf):
30-
for remote in REMOTES:
38+
def _get(remote_conf, remotes, default):
39+
for remote in remotes:
3140
if remote.supported(remote_conf):
3241
return remote
33-
return LocalRemote
42+
return default
3443

3544

36-
def Remote(repo, **kwargs):
45+
def _get_conf(repo, **kwargs):
3746
name = kwargs.get("name")
3847
if name:
3948
remote_conf = repo.config["remote"][name.lower()]
4049
else:
4150
remote_conf = kwargs
42-
remote_conf = _resolve_remote_refs(repo.config, remote_conf)
43-
return _get(remote_conf)(repo, remote_conf)
51+
return _resolve_remote_refs(repo.config, remote_conf)
52+
53+
54+
def Remote(repo, **kwargs):
55+
remote_conf = _get_conf(repo, **kwargs)
56+
return _get(remote_conf, REMOTES, LocalRemote)(repo, remote_conf)
57+
58+
59+
def Cache(repo, **kwargs):
60+
remote_conf = _get_conf(repo, **kwargs)
61+
return _get(remote_conf, CACHES, LocalCache)(repo, remote_conf)
4462

4563

4664
def _resolve_remote_refs(config, remote_conf):

dvc/remote/azure.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from dvc.path_info import CloudURLInfo
99
from dvc.progress import Tqdm
10-
from dvc.remote.base import BaseRemote, BaseRemoteTree
10+
from dvc.remote.base import BaseRemote, BaseRemoteTree, CacheMixin
1111
from dvc.scheme import Schemes
1212

1313
logger = logging.getLogger(__name__)
@@ -108,6 +108,9 @@ def remove(self, path_info):
108108
logger.debug(f"Removing {path_info}")
109109
self.blob_service.delete_blob(path_info.bucket, path_info.path)
110110

111+
def get_file_checksum(self, path_info):
112+
return self.get_etag(path_info)
113+
111114
def _upload(
112115
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
113116
):
@@ -134,10 +137,11 @@ def _download(
134137
class AzureRemote(BaseRemote):
135138
scheme = Schemes.AZURE
136139
REQUIRES = {"azure-storage-blob": "azure.storage.blob"}
140+
TREE_CLS = AzureRemoteTree
137141
PARAM_CHECKSUM = "etag"
138142
COPY_POLL_SECONDS = 5
139143
LIST_OBJECT_PAGE_SIZE = 5000
140-
TREE_CLS = AzureRemoteTree
141144

142-
def get_file_checksum(self, path_info):
143-
return self.tree.get_etag(path_info)
145+
146+
class AzureCache(AzureRemote, CacheMixin):
147+
pass

0 commit comments

Comments
 (0)