Skip to content

Commit 99b58d3

Browse files
authored
refactor/test multistage load for params and outputs (#3949)
* refactor multistage load for params and outputs
* tests: load params
* tests: output loading from pipeline file
* fix test
* fix typo in name
* split params load
* rename params func s/inject_values/fill_values
* fix tests
* simplify loads_params and output.load_from_pipeline
* address @pared's suggestions for tests
1 parent 5bbaa28 commit 99b58d3

File tree

8 files changed

+280
-136
lines changed

8 files changed

+280
-136
lines changed

dvc/dependency/__init__.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,26 +87,29 @@ def loads_from(stage, s_list, erepo=None):
8787
return ret
8888

8989

90-
def _parse_params(path_params):
91-
path, _, params_str = path_params.rpartition(":")
92-
params = params_str.split(",")
93-
return path, params
90+
def _merge_params(s_list):
    """Group the params listed in ``s_list`` by the params file they name.

    Each item of ``s_list`` is either a string, which is filed under
    ``ParamsDependency.DEFAULT_PARAMS_FILE``, or a dict mapping a params
    file path to a list of param names. Entries that refer to the same
    file are concatenated in the order they are encountered.

    Returns:
        A ``defaultdict(list)`` keyed by params file path.

    Raises:
        ValueError: if an item is neither ``str`` nor ``dict``, or if a
            dict value is not a ``list``.
    """
    default_file = ParamsDependency.DEFAULT_PARAMS_FILE
    merged = defaultdict(list)
    for item in s_list:
        if isinstance(item, str):
            merged[default_file].append(item)
        elif isinstance(item, dict):
            for path, names in item.items():
                if not isinstance(names, list):
                    raise ValueError(
                        "Expected list of params for custom params file "
                        f"'{path}', got '{type(names).__name__}'."
                    )
                merged[path].extend(names)
        else:
            raise ValueError(
                "Only list of str/dict is supported. Got: "
                f"'{type(item).__name__}'."
            )
    return merged
94109

95110

96111
def loads_params(stage, s_list):
97-
# Creates an object for each unique file that is referenced in the list
98-
params_by_path = defaultdict(list)
99-
for s in s_list:
100-
path, params = _parse_params(s)
101-
params_by_path[path].extend(params)
102-
103-
d_list = []
104-
for path, params in params_by_path.items():
105-
d_list.append(
106-
{
107-
BaseOutput.PARAM_PATH: path,
108-
ParamsDependency.PARAM_PARAMS: params,
109-
}
110-
)
111-
112-
return loadd_from(stage, d_list)
112+
d = _merge_params(s_list)
113+
return [
114+
ParamsDependency(stage, path, params) for path, params in d.items()
115+
]

dvc/dependency/param.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, stage, path, params):
4040
info=info,
4141
)
4242

43-
def _dyn_load(self, values=None):
43+
def fill_values(self, values=None):
4444
"""Load params values dynamically."""
4545
if not values:
4646
return

dvc/output/__init__.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from collections import defaultdict
12
from urllib.parse import urlparse
23

4+
from funcy import collecting, project
35
from voluptuous import And, Any, Coerce, Length, Lower, Required, SetTo
46

57
from dvc.output.base import BaseOutput
@@ -58,7 +60,9 @@
5860
SCHEMA[BaseOutput.PARAM_PERSIST] = bool
5961

6062

61-
def _get(stage, p, info, cache, metric, plot=False, persist=False):
63+
def _get(
64+
stage, p, info=None, cache=True, metric=False, plot=False, persist=False
65+
):
6266
parsed = urlparse(p)
6367

6468
if parsed.scheme == "remote":
@@ -135,3 +139,45 @@ def loads_from(
135139
)
136140
for s in s_list
137141
]
142+
143+
144+
def _split_dict(d, keys):
145+
return project(d, keys), project(d, d.keys() - keys)
146+
147+
148+
def _merge_data(s_list):
149+
d = defaultdict(dict)
150+
for key in s_list:
151+
if isinstance(key, str):
152+
d[key].update({})
153+
continue
154+
if not isinstance(key, dict):
155+
raise ValueError(f"'{type(key).__name__}' not supported.")
156+
157+
for k, flags in key.items():
158+
if not isinstance(flags, dict):
159+
raise ValueError(
160+
f"Expected dict for '{k}', got: '{type(flags).__name__}'"
161+
)
162+
d[k].update(flags)
163+
return d
164+
165+
166+
@collecting
def load_from_pipeline(stage, s_list, typ="outs"):
    """Yield one output object per path listed under ``typ``.

    ``s_list`` is merged with ``_merge_data`` so that each path appears
    once with its combined flags. For plots, flags whose keys appear in
    ``dvc.schema.PLOT_PROPS`` are split off and passed as the ``plot``
    value; when there are none, ``plot`` is simply ``True``. Only the
    ``cache`` and ``persist`` flags are forwarded to ``_get`` as keyword
    arguments.

    Raises:
        ValueError: if ``typ`` is not one of ``stage.PARAM_OUTS``,
            ``stage.PARAM_METRICS`` or ``stage.PARAM_PLOTS``.
    """
    allowed = (stage.PARAM_OUTS, stage.PARAM_METRICS, stage.PARAM_PLOTS)
    if typ not in allowed:
        raise ValueError(f"'{typ}' key is not allowed for pipeline files.")

    is_metric = typ == stage.PARAM_METRICS
    is_plot = typ == stage.PARAM_PLOTS

    for path, flags in _merge_data(s_list).items():
        plot_props = {}
        if is_plot:
            from dvc.schema import PLOT_PROPS

            plot_props, flags = _split_dict(flags, keys=PLOT_PROPS.keys())
        extra = project(flags, ["cache", "persist"])
        yield _get(
            stage,
            path,
            {},
            plot=plot_props or is_plot,
            metric=is_metric,
            **extra,
        )

dvc/repo/run.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22

3-
from funcy import concat, first
3+
from funcy import concat, first, lfilter
44

55
from dvc.exceptions import InvalidArgumentError
66
from dvc.stage.exceptions import (
@@ -20,6 +20,19 @@ def is_valid_name(name: str):
2020
return not INVALID_STAGENAME_CHARS & set(name)
2121

2222

23+
def parse_params(path_params):
    """Parse ``[path:]name1,name2`` strings into a list of str/dict items.

    Each input string is split on its last ``":"``. The part after it is
    a comma-separated list of param names; the part before it, if any,
    is the params file path.

    * With no path (``"lr,epochs"``) the names are added to the result
      as plain strings.
    * With a path (``"file.yaml:lr"``) a ``{path: [names]}`` dict is
      added. ``"file.yaml:"`` yields ``{"file.yaml": []}``.

    Returns:
        A list mixing ``str`` and single-key ``dict`` items, in input
        order.
    """
    ret = []
    for path_param in path_params:
        path, _, params_str = path_param.rpartition(":")
        # Drop empty names, e.g. the one left by a trailing separator in
        # `-p "file1:"` or `-p "a,b,"`.
        params = [name for name in params_str.split(",") if name]
        if path:
            ret.append({path: params})
        else:
            ret.extend(params)
    return ret
34+
35+
2336
def _get_file_path(kwargs):
2437
from dvc.dvcfile import DVC_FILE_SUFFIX, DVC_FILE
2538

@@ -72,7 +85,10 @@ def run(self, fname=None, no_exec=False, single_stage=False, **kwargs):
7285
if not is_valid_name(stage_name):
7386
raise InvalidStageName
7487

75-
stage = create_stage(stage_cls, repo=self, path=path, **kwargs)
88+
params = parse_params(kwargs.pop("params", []))
89+
stage = create_stage(
90+
stage_cls, repo=self, path=path, params=params, **kwargs
91+
)
7692
if stage is None:
7793
return None
7894

dvc/stage/loader.py

Lines changed: 16 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import logging
22
import os
3-
from collections import defaultdict
43
from collections.abc import Mapping
54
from copy import deepcopy
65
from itertools import chain
76

8-
from funcy import first
7+
from funcy import lcat, project
98

109
from dvc import dependency, output
1110

1211
from ..dependency import ParamsDependency
12+
from . import fill_stage_dependencies
1313
from .exceptions import StageNameUnspecified, StageNotFound
1414

1515
logger = logging.getLogger(__name__)
@@ -32,6 +32,7 @@ def __init__(self, dvcfile, stages_data, lockfile_data=None):
3232

3333
@staticmethod
3434
def fill_from_lock(stage, lock_data):
35+
"""Fill values for params, checksums for outs and deps from lock."""
3536
from .params import StageParams
3637

3738
items = chain(
@@ -46,8 +47,8 @@ def fill_from_lock(stage, lock_data):
4647
for key, item in items:
4748
if isinstance(item, ParamsDependency):
4849
# load the params with values inside lock dynamically
49-
params = lock_data.get("params", {}).get(item.def_path, {})
50-
item._dyn_load(params)
50+
lock_params = lock_data.get(stage.PARAM_PARAMS, {})
51+
item.fill_values(lock_params.get(item.def_path, {}))
5152
continue
5253

5354
item.checksum = (
@@ -56,104 +57,6 @@ def fill_from_lock(stage, lock_data):
5657
.get(item.checksum_type)
5758
)
5859

59-
@classmethod
60-
def _load_params(cls, stage, pipeline_params):
61-
"""
62-
File in pipeline file is expected to be in following format:
63-
```
64-
params:
65-
- lr
66-
- train.epochs
67-
- params2.yaml: # notice the filename
68-
- process.threshold
69-
- process.bow
70-
```
71-
72-
and, in lockfile, we keep it as following format:
73-
```
74-
params:
75-
params.yaml:
76-
lr: 0.0041
77-
train.epochs: 100
78-
params2.yaml:
79-
process.threshold: 0.98
80-
process.bow:
81-
- 15000
82-
- 123
83-
```
84-
In the list of `params` inside pipeline file, if any of the items is
85-
dict-like, the key will be treated as a separate params file and its
86-
values to be part of that params file, else, the item is considered
87-
as part of the `params.yaml` which is a default file.
88-
89-
(From example above: `lr` is considered to be part of `params.yaml`
90-
whereas `process.bow` to be part of `params2.yaml`.)
91-
92-
We only load the keys here, lockfile bears the values which are used
93-
to compare between the actual params from the file in the workspace.
94-
"""
95-
res = defaultdict(list)
96-
for key in pipeline_params:
97-
if isinstance(key, str):
98-
path = DEFAULT_PARAMS_FILE
99-
res[path].append(key)
100-
elif isinstance(key, dict):
101-
path = first(key)
102-
res[path].extend(key[path])
103-
104-
stage.deps.extend(
105-
dependency.loadd_from(
106-
stage,
107-
[
108-
{"path": key, "params": params}
109-
for key, params in res.items()
110-
],
111-
)
112-
)
113-
114-
@classmethod
115-
def _load_outs(cls, stage, data, typ=None):
116-
from dvc.output.base import BaseOutput
117-
118-
d = []
119-
for key in data:
120-
if isinstance(key, str):
121-
entry = {BaseOutput.PARAM_PATH: key}
122-
if typ:
123-
entry[typ] = True
124-
d.append(entry)
125-
continue
126-
127-
assert isinstance(key, dict)
128-
assert len(key) == 1
129-
130-
path = first(key)
131-
extra = key[path]
132-
133-
if not typ:
134-
d.append({BaseOutput.PARAM_PATH: path, **extra})
135-
continue
136-
137-
entry = {BaseOutput.PARAM_PATH: path}
138-
139-
persist = extra.pop(BaseOutput.PARAM_PERSIST, False)
140-
if persist:
141-
entry[BaseOutput.PARAM_PERSIST] = persist
142-
143-
cache = extra.pop(BaseOutput.PARAM_CACHE, True)
144-
if not cache:
145-
entry[BaseOutput.PARAM_CACHE] = cache
146-
147-
entry[typ] = extra or True
148-
149-
d.append(entry)
150-
151-
stage.outs.extend(output.loadd_from(stage, d))
152-
153-
@classmethod
154-
def _load_deps(cls, stage, data):
155-
stage.deps.extend(dependency.loads_from(stage, data))
156-
15760
@classmethod
15861
def load_stage(cls, dvcfile, name, stage_data, lock_data):
15962
from . import PipelineStage, Stage, loads_from
@@ -163,13 +66,18 @@ def load_stage(cls, dvcfile, name, stage_data, lock_data):
16366
)
16467
stage = loads_from(PipelineStage, dvcfile.repo, path, wdir, stage_data)
16568
stage.name = name
166-
stage.deps, stage.outs = [], []
16769

168-
cls._load_outs(stage, stage_data.get("outs", []))
169-
cls._load_outs(stage, stage_data.get("metrics", []), "metric")
170-
cls._load_outs(stage, stage_data.get("plots", []), "plot")
171-
cls._load_deps(stage, stage_data.get("deps", []))
172-
cls._load_params(stage, stage_data.get("params", []))
70+
deps = project(stage_data, [stage.PARAM_DEPS, stage.PARAM_PARAMS])
71+
fill_stage_dependencies(stage, **deps)
72+
73+
outs = project(
74+
stage_data,
75+
[stage.PARAM_OUTS, stage.PARAM_METRICS, stage.PARAM_PLOTS],
76+
)
77+
stage.outs = lcat(
78+
output.load_from_pipeline(stage, data, typ=key)
79+
for key, data in outs.items()
80+
)
17381

17482
if lock_data:
17583
stage.cmd_changed = lock_data.get(

tests/unit/dependency/test_params.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,40 @@
1515

1616
def test_loads_params(dvc):
1717
stage = Stage(dvc)
18-
deps = loads_params(stage, ["foo", "bar,baz", "a_file:qux"])
19-
assert len(deps) == 2
18+
deps = loads_params(
19+
stage,
20+
[
21+
"foo",
22+
"bar",
23+
{"a_file": ["baz", "bat"]},
24+
{"b_file": ["cat"]},
25+
{},
26+
{"a_file": ["foobar"]},
27+
],
28+
)
29+
assert len(deps) == 3
2030

2131
assert isinstance(deps[0], ParamsDependency)
2232
assert deps[0].def_path == ParamsDependency.DEFAULT_PARAMS_FILE
23-
assert deps[0].params == ["foo", "bar", "baz"]
33+
assert deps[0].params == ["foo", "bar"]
2434
assert deps[0].info == {}
2535

2636
assert isinstance(deps[1], ParamsDependency)
2737
assert deps[1].def_path == "a_file"
28-
assert deps[1].params == ["qux"]
38+
assert deps[1].params == ["baz", "bat", "foobar"]
2939
assert deps[1].info == {}
3040

41+
assert isinstance(deps[2], ParamsDependency)
42+
assert deps[2].def_path == "b_file"
43+
assert deps[2].params == ["cat"]
44+
assert deps[2].info == {}
45+
46+
47+
@pytest.mark.parametrize("params", [[3], [{"b_file": "cat"}]])
def test_params_error(dvc, params):
    """``loads_params`` raises ``ValueError`` for malformed params lists.

    Covers an item that is neither str nor dict, and a custom params
    file whose value is not a list.
    """
    with pytest.raises(ValueError):
        loads_params(Stage(dvc), params)
51+
3152

3253
def test_loadd_from(dvc):
3354
stage = Stage(dvc)

0 commit comments

Comments
 (0)