Skip to content

Commit add12f9

Browse files
committed
Fix set not being auto-tracked
1 parent 856864d commit add12f9

File tree

3 files changed

+74
-32
lines changed

3 files changed

+74
-32
lines changed

dvc/parsing/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,7 @@ def _resolve_entry(self, name: str, definition):
6969
def resolve(self):
7070
stages = self.data.get(STAGES_KWD, {})
7171
data = join(starmap(self._resolve_entry, stages.items()))
72-
logger.trace( # pytype: disable=attribute-error
73-
"Resolved dvc.yaml:\n%s", dumps_yaml(data)
74-
)
72+
logger.trace("Resolved dvc.yaml:\n%s", dumps_yaml(data))
7573
return {STAGES_KWD: data}
7674

7775
def _resolve_stage(self, context: Context, name: str, definition) -> dict:
@@ -108,8 +106,9 @@ def _resolve_stage(self, context: Context, name: str, definition) -> dict:
108106
logger.trace( # pytype: disable=attribute-error
109107
"Context during resolution of stage %s:\n%s", name, context
110108
)
109+
111110
with context.track():
112-
stage_d = resolve(definition, context)
111+
stage_d = resolve(definition, context, unwrap=True)
113112

114113
params = stage_d.get(PARAMS_KWD, []) + context.tracked
115114

dvc/parsing/context.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,53 @@ def path(self):
4545
@dataclass
4646
class Value:
4747
value: Any
48-
meta: Optional[Meta] = field(compare=False, default=None, repr=False)
48+
meta: Meta = field(compare=False, repr=False)
4949

5050
def __repr__(self):
5151
return f"'{self}'"
5252

5353
def __str__(self) -> str:
5454
return str(self.value)
5555

56+
def get_sources(self):
57+
return {self.meta.source: self.meta.path()}
58+
59+
60+
class String:
61+
"""
62+
Wrapper around string, that can interpolate, and keep the
63+
original source of those interpolations.
64+
"""
65+
66+
def __init__(self, template, matches, context):
67+
68+
from .interpolate import _resolve_value
69+
70+
index, buf = 0, ""
71+
self.meta = defaultdict(set)
72+
for match in matches:
73+
start, end = match.span(0)
74+
val = _resolve_value(match, context)
75+
self._add_source(val)
76+
buf += template[index:start] + str(val)
77+
index = end
78+
value = buf + template[index:]
79+
self.value = value.replace(r"\${", "${")
80+
81+
def __repr__(self) -> str:
82+
return str(self.value)
83+
84+
def _add_source(self, val: Union[Value, "String"]):
85+
# string might have been built from multiple sources
86+
if isinstance(val, Value) and val.meta and val.meta.source:
87+
self.meta[val.meta.source].add(val.meta.path())
88+
if isinstance(val, String) and val.meta:
89+
for source, keys in self.meta.items():
90+
self.meta[source].update(keys)
91+
92+
def get_sources(self):
93+
return self.meta
94+
5695

5796
class Container:
5897
meta: Meta
@@ -66,7 +105,7 @@ def _convert(self, key, value):
66105
meta = Meta.update_path(self.meta, key)
67106
if value is None or isinstance(value, (int, float, str, bytes, bool)):
68107
return Value(value, meta=meta)
69-
elif isinstance(value, (CtxList, CtxDict, Value)):
108+
elif isinstance(value, (CtxList, CtxDict, Value, String)):
70109
return value
71110
elif isinstance(value, (list, dict)):
72111
container = CtxDict if isinstance(value, dict) else CtxList
@@ -108,6 +147,9 @@ def select(self, key: str):
108147
) from exc
109148
return d.select(rems[0]) if rems else d
110149

150+
def get_sources(self):
151+
return {}
152+
111153

112154
class CtxList(Container, MutableSequence):
113155
_key_transform = staticmethod(int)
@@ -120,6 +162,9 @@ def __init__(self, values: Sequence, meta: Meta = None):
120162
def insert(self, index: int, value):
121163
self.data.insert(index, self._convert(index, value))
122164

165+
def get_sources(self):
166+
return {self.meta.source: self.meta.path()}
167+
123168

124169
class CtxDict(Container, MutableMapping):
125170
def __init__(self, mapping: Mapping = None, meta: Meta = None, **kwargs):
@@ -158,10 +203,15 @@ def track(self):
158203
self._track = False
159204

160205
def _track_data(self, node):
161-
if isinstance(node, (Value, CtxList)):
162-
meta = node.meta
163-
if meta and meta.source and self._track:
164-
self._tracked_data[meta.source].add(meta.path())
206+
if not self._track:
207+
return
208+
209+
for source, keys in node.get_sources().items():
210+
if not source:
211+
continue
212+
params_file = self._tracked_data[source]
213+
keys = [keys] if isinstance(keys, str) else keys
214+
params_file.update(keys)
165215

166216
@property
167217
def tracked(self):

dvc/parsing/interpolate.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from funcy import rpartial
55

6-
from dvc.parsing.context import Context, Value
6+
from dvc.parsing.context import Context, String, Value
77

88
KEYCRE = re.compile(
99
r"""
@@ -16,6 +16,8 @@
1616
re.VERBOSE,
1717
)
1818

19+
UNWRAP_DEFAULT = False
20+
1921

2022
def _get_matches(template):
2123
return list(KEYCRE.finditer(template))
@@ -24,43 +26,34 @@ def _get_matches(template):
2426
def _resolve_value(match, context: Context):
2527
_, _, inner = match.groups()
2628
value = context.select(inner)
27-
if isinstance(value, Value):
28-
return value.value
2929
return value
3030

3131

32-
def _str_interpolate(template, matches, context):
33-
index, buf = 0, ""
34-
for match in matches:
35-
start, end = match.span(0)
36-
buf += template[index:start] + str(_resolve_value(match, context))
37-
index = end
38-
return buf + template[index:]
32+
def _unwrap(value):
33+
if isinstance(value, (Value, String)):
34+
return value.value
35+
return value
3936

4037

41-
def _resolve_str(src: str, context):
38+
def _resolve_str(src: str, context, unwrap=UNWRAP_DEFAULT):
4239
matches = _get_matches(src)
4340
if len(matches) == 1 and src == matches[0].group(0):
4441
# replace "${enabled}", if `enabled` is a boolean, with it's actual
4542
# value rather than it's string counterparts.
46-
return _resolve_value(matches[0], context)
47-
elif matches:
48-
# but not "${num} days"
49-
src = _str_interpolate(src, matches, context)
50-
51-
# regex already backtracks and avoids any `${` starting with
52-
# backslashes(`\`). We just need to replace those by `${`.
53-
return src.replace(r"\${", "${")
43+
value = _resolve_value(matches[0], context)
44+
else:
45+
value = String(src, matches, context)
46+
return _unwrap(value) if unwrap else value
5447

5548

56-
def resolve(src, context):
49+
def resolve(src, context, unwrap=UNWRAP_DEFAULT):
5750
Seq = (list, tuple, set)
5851

59-
apply_value = rpartial(resolve, context)
52+
apply_value = rpartial(resolve, context, unwrap=unwrap)
6053
if isinstance(src, Mapping):
6154
return {key: apply_value(value) for key, value in src.items()}
6255
elif isinstance(src, Seq):
6356
return type(src)(map(apply_value, src))
6457
elif isinstance(src, str):
65-
return _resolve_str(src, context)
58+
return _resolve_str(src, context, unwrap=unwrap)
6659
return src

0 commit comments

Comments
 (0)