Skip to content

Commit d201fad

Browse files
authored
tests: serialize pipeline file (#3985)
serialize: make params ordered
1 parent 5be12b5 commit d201fad

9 files changed

+259
-51
lines changed

dvc/dependency/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,7 @@ def loadd_from(stage, d_list):
8080

8181

8282
def loads_from(stage, s_list, erepo=None):
83+
assert isinstance(s_list, list)
8384
ret = []
8485
for s in s_list:
8586
info = {RepoDependency.PARAM_REPO: erepo} if erepo else {}

dvc/dvcfile.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,8 @@
55
from voluptuous import MultipleInvalid
66

77
import dvc.prompt as prompt
8-
from dvc import serialize
98
from dvc.exceptions import DvcException
9+
from dvc.stage import serialize
1010
from dvc.stage.exceptions import (
1111
StageFileAlreadyExistsError,
1212
StageFileBadNameError,

dvc/stage/cache.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,13 @@
77
from voluptuous import Invalid
88

99
from dvc.schema import COMPILED_LOCK_FILE_STAGE_SCHEMA
10-
from dvc.serialize import to_single_stage_lockfile
11-
from dvc.stage.loader import StageLoader
1210
from dvc.utils import dict_sha256, relpath
1311
from dvc.utils.fs import makedirs
1412
from dvc.utils.yaml import dump_yaml
1513

14+
from .loader import StageLoader
15+
from .serialize import to_single_stage_lockfile
16+
1617
logger = logging.getLogger(__name__)
1718

1819

@@ -78,7 +79,7 @@ def _load(self, stage):
7879
return None
7980

8081
def _create_stage(self, cache, wdir=None):
81-
from dvc.stage import create_stage, PipelineStage
82+
from . import create_stage, PipelineStage
8283

8384
stage = create_stage(
8485
PipelineStage,

dvc/serialize.py renamed to dvc/stage/serialize.py

Lines changed: 42 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -3,15 +3,16 @@
33
from operator import attrgetter
44
from typing import TYPE_CHECKING, List
55

6-
from funcy import lsplit, rpartial
6+
from funcy import post_processing
77

88
from dvc.dependency import ParamsDependency
99
from dvc.output import BaseOutput
10-
from dvc.stage.params import StageParams
11-
from dvc.stage.utils import resolve_wdir
1210
from dvc.utils.collections import apply_diff
1311
from dvc.utils.yaml import parse_yaml_for_update
1412

13+
from .params import StageParams
14+
from .utils import resolve_wdir, split_params_deps
15+
1516
if TYPE_CHECKING:
1617
from dvc.stage import PipelineStage, Stage
1718

@@ -32,23 +33,34 @@
3233
sort_by_path = partial(sorted, key=attrgetter("def_path"))
3334

3435

35-
def _get_out(out):
36-
res = OrderedDict()
36+
@post_processing(OrderedDict)
37+
def _get_flags(out):
3738
if not out.use_cache:
38-
res[PARAM_CACHE] = False
39+
yield PARAM_CACHE, False
3940
if out.persist:
40-
res[PARAM_PERSIST] = True
41+
yield PARAM_PERSIST, True
4142
if out.plot and isinstance(out.plot, dict):
42-
res.update(out.plot)
43-
return out.def_path if not res else {out.def_path: res}
43+
# notice `out.plot` is not sorted
44+
# `out.plot` is in the same order as is in the file when read
45+
# and, should be dumped as-is without any sorting
46+
yield from out.plot.items()
4447

4548

46-
def _get_outs(outs):
47-
return [_get_out(out) for out in sort_by_path(outs)]
49+
def _serialize_out(out):
50+
flags = _get_flags(out)
51+
return out.def_path if not flags else {out.def_path: flags}
4852

4953

50-
def get_params_deps(stage: "PipelineStage"):
51-
return lsplit(rpartial(isinstance, ParamsDependency), stage.deps)
54+
def _serialize_outs(outputs: List[BaseOutput]):
55+
outs, metrics, plots = [], [], []
56+
for out in sort_by_path(outputs):
57+
bucket = outs
58+
if out.plot:
59+
bucket = plots
60+
elif out.metric:
61+
bucket = metrics
62+
bucket.append(_serialize_out(out))
63+
return outs, metrics, plots
5264

5365

5466
def _serialize_params(params: List[ParamsDependency]):
@@ -70,16 +82,15 @@ def _serialize_params(params: List[ParamsDependency]):
7082
dump = param_dep.dumpd()
7183
path, params = dump[PARAM_PATH], dump[PARAM_PARAMS]
7284
if isinstance(params, dict):
73-
k = list(params.keys())
85+
k = sorted(params.keys())
7486
if not k:
7587
continue
76-
key_vals[path] = OrderedDict(
77-
[(key, params[key]) for key in sorted(k)]
78-
)
88+
key_vals[path] = OrderedDict([(key, params[key]) for key in k])
7989
else:
8090
assert isinstance(params, list)
81-
k = params
82-
key_vals = OrderedDict()
91+
# no params values available here, entry will be skipped for lock
92+
k = sorted(params)
93+
8394
# params from default file is always kept at the start of the `params:`
8495
if path == DEFAULT_PARAMS_FILE:
8596
keys = k + keys
@@ -93,28 +104,20 @@ def _serialize_params(params: List[ParamsDependency]):
93104

94105

95106
def to_pipeline_file(stage: "PipelineStage"):
96-
params, deps = get_params_deps(stage)
97-
serialized_params, _ = _serialize_params(params)
107+
wdir = resolve_wdir(stage.wdir, stage.path)
108+
params, deps = split_params_deps(stage)
109+
deps = sorted([d.def_path for d in deps])
110+
params, _ = _serialize_params(params)
98111

112+
outs, metrics, plots = _serialize_outs(stage.outs)
99113
res = [
100114
(stage.PARAM_CMD, stage.cmd),
101-
(stage.PARAM_WDIR, resolve_wdir(stage.wdir, stage.path)),
102-
(stage.PARAM_DEPS, sorted([d.def_path for d in deps])),
103-
(stage.PARAM_PARAMS, serialized_params),
104-
(
105-
PARAM_OUTS,
106-
_get_outs(
107-
[out for out in stage.outs if not (out.metric or out.plot)]
108-
),
109-
),
110-
(
111-
stage.PARAM_METRICS,
112-
_get_outs([out for out in stage.outs if out.metric]),
113-
),
114-
(
115-
stage.PARAM_PLOTS,
116-
_get_outs([out for out in stage.outs if out.plot]),
117-
),
115+
(stage.PARAM_WDIR, wdir),
116+
(stage.PARAM_DEPS, deps),
117+
(stage.PARAM_PARAMS, params),
118+
(stage.PARAM_OUTS, outs),
119+
(stage.PARAM_METRICS, metrics),
120+
(stage.PARAM_PLOTS, plots),
118121
(stage.PARAM_FROZEN, stage.frozen),
119122
(stage.PARAM_ALWAYS_CHANGED, stage.always_changed),
120123
]
@@ -127,7 +130,7 @@ def to_single_stage_lockfile(stage: "Stage") -> dict:
127130
assert stage.cmd
128131

129132
res = OrderedDict([("cmd", stage.cmd)])
130-
params, deps = get_params_deps(stage)
133+
params, deps = split_params_deps(stage)
131134
deps, outs = [
132135
[
133136
OrderedDict(

dvc/stage/utils.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,9 +2,12 @@
22
import pathlib
33
from itertools import product
44

5+
from funcy import lsplit, rpartial
6+
57
from dvc import dependency, output
68
from dvc.utils.fs import path_isin
79

10+
from ..dependency import ParamsDependency
811
from ..remote import LocalRemote, S3Remote
912
from ..utils import dict_md5, format_link, relpath
1013
from .exceptions import (
@@ -196,3 +199,7 @@ def get_dump(stage):
196199
}.items()
197200
if value
198201
}
202+
203+
204+
def split_params_deps(stage):
205+
return lsplit(rpartial(isinstance, ParamsDependency), stage.deps)

tests/func/test_lockfile.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import yaml
55

66
from dvc.dvcfile import PIPELINE_LOCK
7-
from dvc.serialize import get_params_deps
7+
from dvc.stage.utils import split_params_deps
88
from dvc.utils.fs import remove
99
from dvc.utils.yaml import parse_yaml_for_update
1010
from tests.func.test_run_multistage import supported_params
@@ -137,7 +137,7 @@ def test_params_dump(tmp_dir, dvc, run_head):
137137
assert not dvc.reproduce(stage.addressing)
138138

139139
# let's change the order of params and dump them in pipeline file
140-
params, _ = get_params_deps(stage)
140+
params, _ = split_params_deps(stage)
141141
for param in params:
142142
param.params.reverse()
143143

tests/func/test_repro_multistage.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -509,7 +509,7 @@ def test_cyclic_graph_error(tmp_dir, dvc, run_copy):
509509
def test_repro_multiple_params(tmp_dir, dvc):
510510
from tests.func.test_run_multistage import supported_params
511511

512-
from dvc.serialize import get_params_deps
512+
from dvc.stage.utils import split_params_deps
513513

514514
with (tmp_dir / "params2.yaml").open("w+") as f:
515515
yaml.dump(supported_params, f)
@@ -529,7 +529,7 @@ def test_repro_multiple_params(tmp_dir, dvc):
529529
cmd="cat params2.yaml params.yaml > bar",
530530
)
531531

532-
params, deps = get_params_deps(stage)
532+
params, deps = split_params_deps(stage)
533533
assert len(params) == 2
534534
assert len(deps) == 1
535535
assert len(stage.outs) == 1

tests/unit/stage/test_loader_pipeline_file.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@
55
import pytest
66

77
from dvc.dvcfile import PIPELINE_FILE, Dvcfile
8-
from dvc.serialize import get_params_deps
98
from dvc.stage import PipelineStage, create_stage
109
from dvc.stage.loader import StageLoader
10+
from dvc.stage.serialize import split_params_deps
1111

1212

1313
@pytest.fixture
@@ -61,7 +61,7 @@ def test_fill_from_lock_params(dvc, lock_data):
6161
"ipsum": "ipsum"
6262
},
6363
}
64-
params_deps = get_params_deps(stage)[0]
64+
params_deps = split_params_deps(stage)[0]
6565
assert set(params_deps[0].params) == {"lorem", "lorem.ipsum"}
6666
assert set(params_deps[1].params) == {"ipsum", "foobar"}
6767
assert not params_deps[0].info
@@ -81,7 +81,7 @@ def test_fill_from_lock_missing_params_section(dvc, lock_data):
8181
outs=["bar"],
8282
params=["lorem", "lorem.ipsum", {"myparams.yaml": ["ipsum"]}],
8383
)
84-
params_deps = get_params_deps(stage)[0]
84+
params_deps = split_params_deps(stage)[0]
8585
StageLoader.fill_from_lock(stage, lock_data)
8686
assert not params_deps[0].info and not params_deps[1].info
8787

@@ -180,7 +180,7 @@ def test_load_stage_with_params(dvc, stage_data, lock_data):
180180
dvcfile = Dvcfile(dvc, PIPELINE_FILE)
181181
stage = StageLoader.load_stage(dvcfile, "stage-1", stage_data, lock_data)
182182

183-
params, deps = get_params_deps(stage)
183+
params, deps = split_params_deps(stage)
184184
assert deps[0].def_path == "foo" and stage.outs[0].def_path == "bar"
185185
assert params[0].def_path == "params.yaml"
186186
assert params[0].info == {"lorem": "ipsum"}

0 commit comments

Comments
 (0)