Skip to content

Commit a00f113

Browse files
committed
Fix foreach to accept any kind of iterable
1 parent c44039e commit a00f113

File tree

5 files changed

+121
-138
lines changed

5 files changed

+121
-138
lines changed

dvc/parsing/__init__.py

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import os
3+
from collections.abc import Mapping, Sequence
34
from copy import deepcopy
45
from itertools import starmap
56
from typing import TYPE_CHECKING
@@ -11,7 +12,7 @@
1112
from dvc.path_info import PathInfo
1213
from dvc.utils.serialize import dumps_yaml
1314

14-
from .context import Context, CtxDict, CtxList
15+
from .context import Context
1516
from .interpolate import resolve
1617

1718
if TYPE_CHECKING:
@@ -29,21 +30,26 @@
2930
DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE
3031
PARAMS_KWD = "params"
3132

33+
DEFAULT_SENTINEL = object()
34+
3235

3336
class DataResolver:
34-
def __init__(self, repo: "Repo", yaml_wdir: PathInfo, d):
37+
def __init__(self, repo: "Repo", yaml_wdir: PathInfo, d: dict):
3538
to_import: PathInfo = yaml_wdir / d.get(USE_KWD, DEFAULT_PARAMS_FILE)
3639
vars_ = d.get(VARS_KWD, {})
3740
if os.path.exists(to_import):
38-
self.global_ctx = Context.load_from(
39-
repo.tree, str(to_import), vars_
40-
)
4141
self.global_ctx_source = to_import
42+
self.global_ctx = Context.load_from(repo.tree, str(to_import))
4243
else:
44+
self.global_ctx = Context()
4345
self.global_ctx_source = None
44-
self.global_ctx = Context.create(vars_)
46+
logger.debug(
47+
"%s does not exist, it won't be used in parametrization",
48+
to_import,
49+
)
4550

46-
self.data = d
51+
self.global_ctx.merge_update(vars_)
52+
self.data: dict = d
4753
self._yaml_wdir = yaml_wdir
4854
self.repo = repo
4955

@@ -63,23 +69,34 @@ def _resolve_entry(self, name: str, definition):
6369
def resolve(self):
6470
stages = self.data.get(STAGES_KWD, {})
6571
data = join(starmap(self._resolve_entry, stages.items()))
66-
logger.trace("Resolved dvc.yaml:\n%s", dumps_yaml(data))
67-
return {**self.data, STAGES_KWD: data}
72+
logger.trace( # pytype: disable=attribute-error
73+
"Resolved dvc.yaml:\n%s", dumps_yaml(data)
74+
)
75+
return {STAGES_KWD: data}
6876

69-
def _resolve_stage(self, context: Context, name, definition):
77+
def _resolve_stage(self, context: Context, name: str, definition) -> dict:
7078
definition = deepcopy(definition)
7179
self._set_context_from(context, definition.pop(SET_KWD, {}))
80+
7281
wdir = self._resolve_wdir(context, definition.get(WDIR_KWD))
73-
params_file = definition.get(PARAMS_KWD, [])
74-
contexts = []
82+
if self._yaml_wdir != wdir:
83+
logger.debug(
84+
"Stage %s has different wdir than dvc.yaml file", name
85+
)
7586

87+
contexts = []
7688
params_yaml_file = wdir / DEFAULT_PARAMS_FILE
77-
if (self.global_ctx_source != params_yaml_file) and os.path.exists(
78-
params_yaml_file
79-
):
80-
contexts.append(
81-
Context.load_from(self.repo.tree, str(params_yaml_file))
82-
)
89+
if self.global_ctx_source != params_yaml_file:
90+
if os.path.exists(params_yaml_file):
91+
contexts.append(
92+
Context.load_from(self.repo.tree, str(params_yaml_file))
93+
)
94+
else:
95+
logger.debug(
96+
"%s does not exist for stage %s", params_yaml_file, name
97+
)
98+
99+
params_file = definition.get(PARAMS_KWD, [])
83100
for item in params_file:
84101
if item and isinstance(item, dict):
85102
contexts.append(
@@ -88,38 +105,37 @@ def _resolve_stage(self, context: Context, name, definition):
88105

89106
context.merge_update(*contexts)
90107

91-
stage_d = resolve(definition, context)
108+
logger.trace( # pytype: disable=attribute-error
109+
"Context during resolution of stage %s:\n%s", name, context
110+
)
111+
with context.track():
112+
stage_d = resolve(definition, context)
113+
92114
params = stage_d.get(PARAMS_KWD, []) + context.tracked
93115

94116
if params:
95117
stage_d[PARAMS_KWD] = params
96118
return {name: stage_d}
97119

98120
def _foreach(self, context: Context, name, foreach_data, in_data):
99-
assert isinstance(foreach_data, str)
100-
iterables = resolve(foreach_data, context)
101-
102-
def each_iter(value):
121+
def each_iter(value, key=DEFAULT_SENTINEL):
103122
c = Context.clone(context)
104-
if isinstance(value, tuple):
105-
key, val = value
106-
else:
107-
key, val = None, value
108-
c["item"] = val
109-
if key is not None:
123+
c["item"] = value
124+
if key is not DEFAULT_SENTINEL:
110125
c["key"] = key
111-
suff = key or value
112-
return self._resolve_stage(c, f"{name}-{suff}", in_data)
113-
114-
if isinstance(iterables, (CtxList, list, tuple)):
115-
gen = map(each_iter, iterables)
116-
elif isinstance(iterables, (CtxDict, dict)):
117-
gen = map(each_iter, iterables.items())
126+
suffix = str(key if key is not DEFAULT_SENTINEL else value)
127+
return self._resolve_stage(c, f"{name}-{suffix}", in_data)
128+
129+
iterable = resolve(foreach_data, context)
130+
if isinstance(iterable, Sequence):
131+
gen = (each_iter(v) for v in iterable)
132+
elif isinstance(iterable, Mapping):
133+
gen = (each_iter(v, k) for k, v in iterable.items())
118134
else:
119-
raise Exception(f"got type of {type(iterables)}")
135+
raise Exception(f"got type of {type(iterable)}")
120136
return join(gen)
121137

122-
def _resolve_wdir(self, context, wdir):
138+
def _resolve_wdir(self, context: Context, wdir: str = None) -> PathInfo:
123139
if not wdir:
124140
return self._yaml_wdir
125141
wdir = resolve(wdir, context)

0 commit comments

Comments
 (0)