Skip to content

Commit b0ae5e2

Browse files
authored
Support parametrization from params or vars section in dvc.yaml (#4751)
* Implement importing from params * Change varname * Add tests * fix pylint issues
1 parent e480550 commit b0ae5e2

File tree

7 files changed

+636
-67
lines changed

7 files changed

+636
-67
lines changed

dvc/dvcfile.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from dvc.exceptions import DvcException
1010
from dvc.parsing import DataResolver
11+
from dvc.path_info import PathInfo
1112
from dvc.stage import serialize
1213
from dvc.stage.exceptions import (
1314
StageFileBadNameError,
@@ -231,7 +232,9 @@ def stages(self):
231232

232233
if self.repo.config["feature"]["parametrization"]:
233234
with log_durations(logger.debug, "resolving values"):
234-
resolver = DataResolver(data)
235+
resolver = DataResolver(
236+
self.repo, PathInfo(self.path).parent, data
237+
)
235238
data = resolver.resolve()
236239

237240
lockfile_data = self._lockfile.load()

dvc/parsing/__init__.py

Lines changed: 89 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,105 @@
11
import logging
2+
import os
3+
from copy import deepcopy
24
from itertools import starmap
5+
from typing import TYPE_CHECKING
36

4-
from funcy import join
7+
from funcy import first, join
8+
9+
from dvc.dependency.param import ParamsDependency
10+
from dvc.path_info import PathInfo
11+
from dvc.utils.serialize import dumps_yaml
512

613
from .context import Context
714
from .interpolate import resolve
815

16+
if TYPE_CHECKING:
17+
from dvc.repo import Repo
18+
919
logger = logging.getLogger(__name__)
1020

11-
STAGES = "stages"
21+
STAGES_KWD = "stages"
22+
USE_KWD = "use"
23+
VARS_KWD = "vars"
24+
WDIR_KWD = "wdir"
25+
DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE
26+
PARAMS_KWD = "params"
1227

1328

1429
class DataResolver:
    """Resolve interpolated values in the ``stages`` section of a dvc.yaml.

    A global context is built once from the file named by the ``use`` key
    (falling back to the default params file) merged with the ``vars``
    section.  Every stage is then resolved against its own clone of that
    context, extended with the params files the stage itself refers to.
    """

    def __init__(self, repo: "Repo", wdir: PathInfo, d: dict):
        """Prepare the global context for the dvc.yaml data ``d``.

        Args:
            repo: repository whose ``tree`` is used to load params files.
            wdir: directory that relative paths are resolved against.
            d: parsed content of the dvc.yaml file.
        """
        self.repo = repo
        self.wdir = wdir
        self.data: dict = d

        import_path: PathInfo = wdir / d.get(USE_KWD, DEFAULT_PARAMS_FILE)
        extra_vars = d.get(VARS_KWD, {})

        if not os.path.exists(import_path):
            # A missing file is not an error: start from an empty context.
            self.global_ctx_source = None
            self.global_ctx = Context()
            logger.debug(
                "%s does not exist, it won't be used in parametrization",
                import_path,
            )
        else:
            self.global_ctx_source = import_path
            self.global_ctx = Context.load_from(repo.tree, str(import_path))

        # merge_update() is called without overwrite, so a key defined in
        # both the imported file and ``vars`` raises ValueError.
        self.global_ctx.merge_update(extra_vars)

    def _resolve_entry(self, name: str, definition):
        """Resolve one stage against a private copy of the global context."""
        return self._resolve_stage(
            Context.clone(self.global_ctx), name, definition
        )

    def resolve(self):
        """Return ``{"stages": ...}`` with every stage definition resolved."""
        stage_definitions = self.data.get(STAGES_KWD, {})
        resolved = join(
            starmap(self._resolve_entry, stage_definitions.items())
        )
        logger.trace("Resolved dvc.yaml:\n%s", dumps_yaml(resolved))
        return {STAGES_KWD: resolved}

    def _resolve_stage(self, context: Context, name: str, definition) -> dict:
        """Resolve a single stage definition.

        The stage's own params files are merged into ``context`` first;
        every value read from a file-backed source while resolving is then
        appended to the stage's ``params`` list.

        Returns:
            A one-item dict mapping ``name`` to the resolved definition.
        """
        # Work on a copy so the caller's data is never modified.
        definition = deepcopy(definition)
        stage_wdir = self._resolve_wdir(context, definition.get(WDIR_KWD))
        if self.wdir != stage_wdir:
            logger.debug(
                "Stage %s has different wdir than dvc.yaml file", name
            )

        extra_contexts = []
        default_params = stage_wdir / DEFAULT_PARAMS_FILE
        # Skip the default params file when it already is the global source.
        if self.global_ctx_source != default_params:
            if not os.path.exists(default_params):
                logger.debug(
                    "%s does not exist for stage %s", default_params, name
                )
            else:
                extra_contexts.append(
                    Context.load_from(self.repo.tree, str(default_params))
                )

        # Dict entries under ``params`` are keyed by a params file path;
        # load each such file relative to the stage's working directory.
        extra_contexts.extend(
            Context.load_from(self.repo.tree, str(stage_wdir / first(entry)))
            for entry in definition.get(PARAMS_KWD, [])
            if entry and isinstance(entry, dict)
        )

        context.merge_update(*extra_contexts)

        logger.trace(  # pytype: disable=attribute-error
            "Context during resolution of stage %s:\n%s", name, context
        )

        with context.track():
            resolved = resolve(definition, context)

        all_params = resolved.get(PARAMS_KWD, []) + context.tracked
        if all_params:
            resolved[PARAMS_KWD] = all_params
        return {name: resolved}

    def _resolve_wdir(self, context: Context, wdir: str = None) -> PathInfo:
        """Return the stage's working directory as a ``PathInfo``.

        A falsy ``wdir`` means the stage runs in the resolver's own wdir;
        otherwise ``wdir`` is interpolated and joined onto it.
        """
        if wdir:
            return self.wdir / str(resolve(wdir, context))
        return self.wdir

dvc/parsing/context.py

Lines changed: 193 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,208 @@
1-
from collections.abc import Collection, Mapping, Sequence
1+
import os
2+
from collections import defaultdict
3+
from collections.abc import Mapping, MutableMapping, MutableSequence
4+
from contextlib import contextmanager
5+
from copy import deepcopy
6+
from dataclasses import dataclass, field, replace
7+
from typing import Any, List, Optional, Sequence, Union
28

3-
# for testing purpose
4-
# FIXME: after implementing of reading of "params".
5-
TEST_DATA = {
6-
"__test__": {
7-
"dict": {"one": 1, "two": 2, "three": "three", "four": "4"},
8-
"list": [1, 2, 3, 4, 3.14],
9-
"set": {1, 2, 3},
10-
"tuple": (1, 2),
11-
"bool": True,
12-
"none": None,
13-
"float": 3.14,
14-
"nomnom": 1000,
15-
}
16-
}
9+
from funcy import identity
1710

11+
from dvc.utils.serialize import LOADERS
1812

19-
class Context:
20-
def __init__(self, data=None):
21-
self.data = data or TEST_DATA
2213

23-
def select(self, key):
24-
return _get_value(self.data, key)
14+
def _merge(into, update, overwrite):
15+
for key, val in update.items():
16+
if isinstance(into.get(key), Mapping) and isinstance(val, Mapping):
17+
_merge(into[key], val, overwrite)
18+
else:
19+
if key in into and not overwrite:
20+
raise ValueError(
21+
f"Cannot overwrite as key {key} already exists in {into}"
22+
)
23+
into[key] = val
2524

2625

27-
def _get_item(data, idx):
28-
if isinstance(data, Sequence):
29-
idx = int(idx)
26+
@dataclass
class Meta:
    """Origin of a context node: where it came from and how to reach it."""

    # File the node was loaded from; None for values without a file origin.
    source: Optional[str] = None
    # Keys/indices leading from the root of the source to this node.
    dpaths: List[str] = field(default_factory=list)

    @staticmethod
    def update_path(meta: "Meta", path: Union[str, int]):
        """Return a copy of ``meta`` whose path is extended by ``path``.

        ``meta`` itself is left untouched; ``path`` is stored as a string.
        """
        dpaths = meta.dpaths[:] + [str(path)]
        return replace(meta, dpaths=dpaths)

    def __str__(self):
        """Render as ``<source>:<dotted path>``.

        ``<local>`` stands in for a missing source.  The separator is
        always emitted, so file-backed nodes read ``params.yaml:foo.bar``
        rather than having the file name and path run together.
        """
        return f"{self.source or '<local>'}:{self.path()}"

    def path(self):
        """Return the node's path as a dot-separated string."""
        return ".".join(self.dpaths)
43+
44+
45+
def _default_meta():
    """Build the ``Meta`` used for values created without explicit origin."""
    sourceless = Meta(source=None)
    return sourceless
47+
48+
49+
@dataclass
class Value:
    """A scalar leaf of the context together with its origin metadata."""

    # The wrapped primitive.
    value: Any
    # Origin information; excluded from equality checks and from repr.
    meta: Meta = field(
        default_factory=_default_meta, compare=False, repr=False
    )

    def __repr__(self):
        """Show only the wrapped value, hiding the wrapper."""
        wrapped = self.value
        return repr(wrapped)

    def get_sources(self):
        """Return ``{source: dotted path}`` describing where this came from."""
        origin = self.meta
        return {origin.source: origin.path()}
4561

46-
try:
47-
value = _get_item(value, attr)
48-
except KeyError:
62+
63+
class Container:
    """Shared behaviour of the context collections (``CtxList``/``CtxDict``).

    Items stored through ``__setitem__`` are converted into ``Value`` or
    nested containers carrying a ``Meta`` that records their path.
    """

    meta: Meta
    data: Union[list, dict]
    # Turns the textual key used in ``select`` into the real index type.
    _key_transform = staticmethod(identity)

    def __init__(self, meta) -> None:
        self.meta = meta or Meta(source=None)

    def _convert(self, key, value):
        """Wrap ``value`` so that it can live inside a container.

        Primitives become ``Value``; lists and dicts become ``CtxList`` and
        ``CtxDict``; already wrapped objects are returned unchanged.

        Raises:
            TypeError: ``value`` is of a type that cannot be stored.
        """
        meta = Meta.update_path(self.meta, key)
        if value is None or isinstance(value, (int, float, str, bytes, bool)):
            return Value(value, meta=meta)
        if isinstance(value, (CtxList, CtxDict, Value)):
            return value
        if isinstance(value, (list, dict)):
            container = CtxDict if isinstance(value, dict) else CtxList
            return container(value, meta=meta)

        msg = (
            "Unsupported value of type "
            f"'{type(value).__name__}' in '{meta}'"
        )
        raise TypeError(msg)

    def __repr__(self):
        return repr(self.data)

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        self.data[key] = self._convert(key, value)

    def __delitem__(self, key):
        del self.data[key]

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data)

    def __eq__(self, o):
        """Compare by underlying data.

        Objects without a ``data`` attribute (``None``, plain dicts, ...)
        yield ``NotImplemented`` instead of raising ``AttributeError``, so
        ``==`` evaluates to ``False`` for them.
        """
        try:
            other_data = o.data
        except AttributeError:
            return NotImplemented
        return other_data == self.data

    def select(self, key: str):
        """Return the node addressed by the dot-separated ``key``.

        Raises:
            ValueError: a path component is missing, or the path tries to
                descend into a scalar value.
        """
        index, *rems = key.split(sep=".", maxsplit=1)
        index = index.strip()
        index = self._key_transform(index)
        try:
            d = self.data[index]
        except LookupError as exc:
            raise ValueError(
                f"Could not find '{index}' in {self.data}"
            ) from exc

        if not rems:
            return d

        # Only containers can be traversed further; without this check a
        # scalar ``Value`` would fail with an unrelated AttributeError.
        if not isinstance(d, Container):
            raise ValueError(
                f"Could not find '{rems[0]}' in {d!r}: "
                f"'{index}' is not a collection"
            )
        return d.select(rems[0])

    def get_sources(self):
        """Return the sources to track for this node (none by default)."""
        return {}
122+
123+
124+
class CtxList(Container, MutableSequence):
    """List-like context node whose items are converted on insertion."""

    # ``select`` receives indices as text; lists need them as integers.
    _key_transform = staticmethod(int)

    def __init__(self, values: Sequence, meta: Meta = None):
        """Create the list from ``values``, wrapping each item in order."""
        super().__init__(meta=meta)
        self.data: list = []
        for item in values:
            self.append(item)

    def insert(self, index: int, value):
        """Insert the converted ``value`` before position ``index``."""
        wrapped = self._convert(index, value)
        self.data.insert(index, wrapped)

    def get_sources(self):
        """Return ``{source: dotted path}`` for the list as a whole."""
        origin = self.meta
        return {origin.source: origin.path()}
137+
138+
139+
class CtxDict(Container, MutableMapping):
    """Dict-like context node restricted to string keys."""

    def __init__(self, mapping: Mapping = None, meta: Meta = None, **kwargs):
        """Populate from ``mapping`` (when non-empty) and then ``kwargs``."""
        super().__init__(meta=meta)

        self.data: dict = {}
        if mapping:
            self.update(mapping)
        self.update(kwargs)

    def __setitem__(self, key, value):
        """Store the converted ``value``; non-string keys are ignored.

        Interpolation can only address string keys, so entries under any
        other kind of key are silently dropped.
        """
        if isinstance(key, str):
            return super().__setitem__(key, value)
        return None

    def merge_update(self, *args, overwrite=False):
        """Merge each mapping in ``args`` into this dict, in order.

        The merge operates on the underlying ``data`` dict; see ``_merge``
        for how nested mappings and conflicting keys are handled.
        """
        for other in args:
            _merge(self.data, other, overwrite=overwrite)
158+
159+
160+
class Context(CtxDict):
    """Top level mutable dict, with some helpers to create context and track.

    While tracking is enabled, every node returned by ``select`` that has a
    file-backed source is recorded and later reported through ``tracked``.
    """

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``CtxDict``; tracking starts off."""
        super().__init__(*args, **kwargs)
        self._track = False
        # Maps a source file to the set of dotted keys read from it.
        self._tracked_data = defaultdict(set)

    @contextmanager
    def track(self):
        """Record the sources of nodes selected inside the ``with`` block.

        Tracking is switched off again even when the block raises, so a
        failed resolution cannot leave the context recording forever.
        """
        self._track = True
        try:
            yield
        finally:
            self._track = False

    def _track_data(self, node):
        """Remember where ``node`` came from, if tracking is enabled."""
        if not self._track:
            return

        for source, keys in node.get_sources().items():
            # Nodes without a source file (e.g. local vars) are not tracked.
            if not source:
                continue
            params_file = self._tracked_data[source]
            keys = [keys] if isinstance(keys, str) else keys
            params_file.update(keys)

    @property
    def tracked(self):
        """Return tracked keys as a list of ``{file: [keys]}`` dicts.

        Keys are sorted so that the result does not depend on set
        iteration order.
        """
        return [
            {file: sorted(keys)} for file, keys in self._tracked_data.items()
        ]

    def select(self, key: str):
        """Look up ``key`` like ``CtxDict.select`` and track the result."""
        node = super().select(key)
        self._track_data(node)
        return node

    @classmethod
    def load_from(cls, tree, file: str) -> "Context":
        """Build a context from ``file``, read through ``tree``.

        The loader is picked from ``LOADERS`` by the file extension, and
        ``file`` is recorded as the source of every loaded value.
        """
        _, ext = os.path.splitext(file)
        loader = LOADERS[ext]

        meta = Meta(source=file)
        return cls(loader(file, tree=tree), meta=meta)

    @classmethod
    def clone(cls, ctx: "Context") -> "Context":
        """Clones given context."""
        return cls(deepcopy(ctx.data))

0 commit comments

Comments
 (0)