Skip to content

Commit 4382b80

Browse files
authored
dvcfile: support remote per output (#6486)
Related to #2095
1 parent 95ba68d commit 4382b80

File tree

8 files changed: +121 additions, -20 deletions (lines changed)

dvc/output.py

Lines changed: 30 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@ def loadd_from(stage, d_list):
8181
desc = d.pop(Output.PARAM_DESC, False)
8282
isexec = d.pop(Output.PARAM_ISEXEC, False)
8383
live = d.pop(Output.PARAM_LIVE, False)
84+
remote = d.pop(Output.PARAM_REMOTE, None)
8485
ret.append(
8586
_get(
8687
stage,
@@ -94,6 +95,7 @@ def loadd_from(stage, d_list):
9495
desc=desc,
9596
isexec=isexec,
9697
live=live,
98+
remote=remote,
9799
)
98100
)
99101
return ret
@@ -109,6 +111,7 @@ def loads_from(
109111
checkpoint=False,
110112
isexec=False,
111113
live=False,
114+
remote=None,
112115
):
113116
return [
114117
_get(
@@ -122,6 +125,7 @@ def loads_from(
122125
checkpoint=checkpoint,
123126
isexec=isexec,
124127
live=live,
128+
remote=remote,
125129
)
126130
for s in s_list
127131
]
@@ -162,7 +166,6 @@ def load_from_pipeline(stage, data, typ="outs"):
162166
metric = typ == stage.PARAM_METRICS
163167
plot = typ == stage.PARAM_PLOTS
164168
live = typ == stage.PARAM_LIVE
165-
166169
if live:
167170
# `live` is single object
168171
data = [data]
@@ -185,6 +188,7 @@ def load_from_pipeline(stage, data, typ="outs"):
185188
Output.PARAM_CACHE,
186189
Output.PARAM_PERSIST,
187190
Output.PARAM_CHECKPOINT,
191+
Output.PARAM_REMOTE,
188192
],
189193
)
190194

@@ -255,6 +259,7 @@ class Output:
255259
PARAM_LIVE = "live"
256260
PARAM_LIVE_SUMMARY = "summary"
257261
PARAM_LIVE_HTML = "html"
262+
PARAM_REMOTE = "remote"
258263

259264
METRIC_SCHEMA = Any(
260265
None,
@@ -283,6 +288,7 @@ def __init__(
283288
live=False,
284289
desc=None,
285290
isexec=False,
291+
remote=None,
286292
):
287293
self.repo = stage.repo if stage else None
288294

@@ -326,7 +332,7 @@ def __init__(
326332
self.obj = None
327333
self.isexec = False if self.IS_DEPENDENCY else isexec
328334

329-
self.def_remote = None
335+
self.remote = remote
330336

331337
def _parse_path(self, fs, path_info):
332338
if fs.scheme != "local":
@@ -843,7 +849,8 @@ def get_dir_cache(self, **kwargs):
843849
try:
844850
objects.check(self.odb, obj)
845851
except FileNotFoundError:
846-
self.repo.cloud.pull([obj.hash_info], **kwargs)
852+
remote = self.repo.cloud.get_remote_odb(self.remote)
853+
self.repo.cloud.pull([obj.hash_info], odb=remote, **kwargs)
847854

848855
if self.obj:
849856
return self.obj
@@ -855,9 +862,9 @@ def get_dir_cache(self, **kwargs):
855862

856863
return self.obj
857864

858-
def collect_used_dir_cache(
865+
def _collect_used_dir_cache(
859866
self, remote=None, force=False, jobs=None, filter_info=None
860-
) -> Dict[Optional["ObjectDB"], Set["HashInfo"]]:
867+
) -> Optional["Tree"]:
861868
"""Fetch dir cache and return used object IDs for this out."""
862869

863870
try:
@@ -878,13 +885,13 @@ def collect_used_dir_cache(
878885
"unable to fully collect used cache"
879886
" without cache for directory '{}'".format(self)
880887
)
881-
return {}
888+
return None
882889

883890
obj = self.get_obj()
884891
if filter_info and filter_info != self.path_info:
885892
prefix = filter_info.relative_to(self.path_info).parts
886893
obj = obj.filter(prefix)
887-
return {None: set(self._named_obj_ids(obj))}
894+
return obj
888895

889896
def get_used_objs(
890897
self, **kwargs
@@ -917,22 +924,31 @@ def get_used_objs(
917924
return {}
918925

919926
if self.is_dir_checksum:
920-
return self.collect_used_dir_cache(**kwargs)
927+
obj = self._collect_used_dir_cache(**kwargs)
928+
else:
929+
obj = self.get_obj(filter_info=kwargs.get("filter_info"))
930+
if not obj:
931+
obj = self.odb.get(self.hash_info)
921932

922-
obj = self.get_obj(filter_info=kwargs.get("filter_info"))
923933
if not obj:
924-
obj = self.odb.get(self.hash_info)
934+
return {}
935+
936+
if self.remote:
937+
remote = self.repo.cloud.get_remote_odb(name=self.remote)
938+
else:
939+
remote = None
925940

926-
return {None: set(self._named_obj_ids(obj))}
941+
return {remote: self._named_obj_ids(obj)}
927942

928943
def _named_obj_ids(self, obj):
929944
name = str(self)
930945
obj.hash_info.obj_name = name
931-
yield obj.hash_info
946+
oids = {obj.hash_info}
932947
if isinstance(obj, Tree):
933948
for key, entry_obj in obj:
934949
entry_obj.hash_info.obj_name = self.fs.sep.join([name, *key])
935-
yield entry_obj.hash_info
950+
oids.add(entry_obj.hash_info)
951+
return oids
936952

937953
def get_used_external(
938954
self, **kwargs
@@ -1033,4 +1049,5 @@ def is_plot(self) -> bool:
10331049
Output.PARAM_CACHE: bool,
10341050
Output.PARAM_METRIC: Output.METRIC_SCHEMA,
10351051
Output.PARAM_DESC: str,
1052+
Output.PARAM_REMOTE: str,
10361053
}

dvc/schema.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@
4949
Output.PARAM_PERSIST: bool,
5050
Output.PARAM_CHECKPOINT: bool,
5151
Output.PARAM_DESC: str,
52+
Output.PARAM_REMOTE: str,
5253
}
5354
}
5455

dvc/stage/serialize.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
PARAM_PERSIST = Output.PARAM_PERSIST
2929
PARAM_CHECKPOINT = Output.PARAM_CHECKPOINT
3030
PARAM_DESC = Output.PARAM_DESC
31+
PARAM_REMOTE = Output.PARAM_REMOTE
3132

3233
DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE
3334

@@ -52,6 +53,8 @@ def _get_flags(out):
5253
yield from out.plot.items()
5354
if out.live and isinstance(out.live, dict):
5455
yield from out.live.items()
56+
if out.remote:
57+
yield PARAM_REMOTE, out.remote
5558

5659

5760
def _serialize_out(out):

tests/func/test_data_cloud.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -609,3 +609,47 @@ def test_pull_no_00_prefix(tmp_dir, dvc, remote, monkeypatch):
609609
stats = dvc.pull()
610610
assert stats["fetched"] == 2
611611
assert set(stats["added"]) == {"foo", "bar"}
612+
613+
614+
def test_output_remote(tmp_dir, dvc, make_remote):
615+
from dvc.utils.serialize import modify_yaml
616+
617+
make_remote("default", default=True)
618+
make_remote("for_foo", default=False)
619+
make_remote("for_data", default=False)
620+
621+
tmp_dir.dvc_gen("foo", "foo")
622+
tmp_dir.dvc_gen("bar", "bar")
623+
tmp_dir.dvc_gen("data", {"one": "one", "two": "two"})
624+
625+
with modify_yaml("foo.dvc") as d:
626+
d["outs"][0]["remote"] = "for_foo"
627+
628+
with modify_yaml("data.dvc") as d:
629+
d["outs"][0]["remote"] = "for_data"
630+
631+
dvc.push()
632+
633+
default = dvc.cloud.get_remote_odb("default")
634+
for_foo = dvc.cloud.get_remote_odb("for_foo")
635+
for_data = dvc.cloud.get_remote_odb("for_data")
636+
637+
assert set(default.all()) == {"37b51d194a7513e45b56f6524f2d51f2"}
638+
assert set(for_foo.all()) == {"acbd18db4cc2f85cedef654fccc4a4d8"}
639+
assert set(for_data.all()) == {
640+
"f97c5d29941bfb1b2fdab0874906ab82",
641+
"6b18131dc289fd37006705affe961ef8.dir",
642+
"b8a9f715dbb64fd5c56e7783c6820a61",
643+
}
644+
645+
clean(["foo", "bar", "data"], dvc)
646+
647+
dvc.pull()
648+
649+
assert set(dvc.odb.local.all()) == {
650+
"37b51d194a7513e45b56f6524f2d51f2",
651+
"acbd18db4cc2f85cedef654fccc4a4d8",
652+
"f97c5d29941bfb1b2fdab0874906ab82",
653+
"6b18131dc289fd37006705affe961ef8.dir",
654+
"b8a9f715dbb64fd5c56e7783c6820a61",
655+
}

tests/func/test_remote.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -200,7 +200,7 @@ def unreliable_upload(self, fobj, to_info, **kwargs):
200200
dvc.push()
201201
remove(dvc.odb.local.cache_dir)
202202

203-
baz.collect_used_dir_cache()
203+
baz._collect_used_dir_cache()
204204
with patch.object(LocalFileSystem, "upload", side_effect=Exception):
205205
with pytest.raises(DownloadError) as download_error_info:
206206
dvc.pull()

tests/remotes/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
)
1717
from .hdfs import HDFS, hadoop, hdfs, hdfs_server, real_hdfs # noqa: F401
1818
from .http import HTTP, http, http_server # noqa: F401
19-
from .local import Local, local_cloud, local_remote # noqa: F401
19+
from .local import Local, local_cloud, local_remote, make_local # noqa: F401
2020
from .oss import ( # noqa: F401
2121
OSS,
2222
TEST_OSS_REPO_BUCKET,
@@ -100,6 +100,24 @@ def docker_services(
100100
return Services(executor)
101101

102102

103+
@pytest.fixture
104+
def make_cloud(request):
105+
def _make_cloud(typ):
106+
return request.getfixturevalue(f"make_{typ}")()
107+
108+
return _make_cloud
109+
110+
111+
@pytest.fixture
112+
def make_remote(tmp_dir, dvc, make_cloud):
113+
def _make_remote(name, typ="local", **kwargs):
114+
cloud = make_cloud(typ)
115+
tmp_dir.add_remote(name=name, config=cloud.config, **kwargs)
116+
return cloud
117+
118+
return _make_remote
119+
120+
103121
@pytest.fixture
104122
def remote(tmp_dir, dvc, request):
105123
cloud = request.param

tests/remotes/local.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,19 @@ def get_url():
1717

1818

1919
@pytest.fixture
20-
def local_cloud(make_tmp_dir):
21-
ret = make_tmp_dir("local-cloud")
22-
ret.url = str(ret)
23-
ret.config = {"url": ret.url}
24-
return ret
20+
def make_local(make_tmp_dir):
21+
def _make_local():
22+
ret = make_tmp_dir("local-cloud")
23+
ret.url = str(ret)
24+
ret.config = {"url": ret.url}
25+
return ret
26+
27+
return _make_local
28+
29+
30+
@pytest.fixture
31+
def local_cloud(make_local):
32+
return make_local()
2533

2634

2735
@pytest.fixture

tests/unit/output/test_load.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,16 @@ def test_load_remote_files_from_pipeline(dvc):
7979
assert not out.hash_info
8080

8181

82+
def test_load_remote(dvc):
83+
stage = Stage(dvc)
84+
(foo, bar) = output.load_from_pipeline(
85+
stage,
86+
["foo", {"bar": {"remote": "myremote"}}],
87+
)
88+
assert foo.remote is None
89+
assert bar.remote == "myremote"
90+
91+
8292
@pytest.mark.parametrize("typ", [None, "", "illegal"])
8393
def test_load_from_pipeline_error_on_typ(dvc, typ):
8494
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)