Skip to content

Implement foreach ... in loop in dvc.yaml #4734

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from collections import defaultdict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from itertools import starmap
from typing import TYPE_CHECKING
Expand All @@ -25,6 +26,10 @@
WDIR_KWD = "wdir"
DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE
PARAMS_KWD = "params"
FOREACH_KWD = "foreach"
IN_KWD = "in"

DEFAULT_SENTINEL = object()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we use None instead of DEFAULT_SENTINEL? Seems like it's used only to check whether the key has been provided?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not trying to assume to much about the parametrized data.



class DataResolver:
Expand All @@ -50,6 +55,11 @@ def __init__(self, repo: "Repo", wdir: PathInfo, d: dict):

def _resolve_entry(self, name: str, definition):
context = Context.clone(self.global_ctx)
if FOREACH_KWD in definition:
assert IN_KWD in definition
return self._foreach(
context, name, definition[FOREACH_KWD], definition[IN_KWD]
)
return self._resolve_stage(context, name, definition)

def resolve(self):
Expand Down Expand Up @@ -114,3 +124,21 @@ def _resolve_wdir(self, context: Context, wdir: str = None) -> PathInfo:
return self.wdir
wdir = resolve(wdir, context)
return self.wdir / str(wdir)

def _foreach(self, context: Context, name: str, foreach_data, in_data):
def each_iter(value, key=DEFAULT_SENTINEL):
c = Context.clone(context)
c["item"] = value
if key is not DEFAULT_SENTINEL:
c["key"] = key
suffix = str(key if key is not DEFAULT_SENTINEL else value)
return self._resolve_stage(c, f"{name}-{suffix}", in_data)

iterable = resolve(foreach_data, context)
if isinstance(iterable, Sequence):
gen = (each_iter(v) for v in iterable)
elif isinstance(iterable, Mapping):
gen = (each_iter(v, k) for k, v in iterable.items())
else:
raise Exception(f"got type of {type(iterable)}")
return join(gen)
42 changes: 22 additions & 20 deletions dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dvc import dependency, output
from dvc.output import CHECKSUMS_SCHEMA, BaseOutput
from dvc.parsing import USE_KWD, VARS_KWD
from dvc.parsing import FOREACH_KWD, IN_KWD, USE_KWD, VARS_KWD
from dvc.stage.params import StageParams

STAGES = "stages"
Expand Down Expand Up @@ -48,26 +48,28 @@

PARAM_PSTAGE_NON_DEFAULT_SCHEMA = {str: [str]}

SINGLE_PIPELINE_STAGE_SCHEMA = {
str: {
StageParams.PARAM_CMD: str,
Optional(StageParams.PARAM_WDIR): str,
Optional(StageParams.PARAM_DEPS): [str],
Optional(StageParams.PARAM_PARAMS): [
Any(str, PARAM_PSTAGE_NON_DEFAULT_SCHEMA)
],
Optional(StageParams.PARAM_FROZEN): bool,
Optional(StageParams.PARAM_META): object,
Optional(StageParams.PARAM_ALWAYS_CHANGED): bool,
Optional(StageParams.PARAM_OUTS): [
Any(str, OUT_PSTAGE_DETAILED_SCHEMA)
],
Optional(StageParams.PARAM_METRICS): [
Any(str, OUT_PSTAGE_DETAILED_SCHEMA)
],
Optional(StageParams.PARAM_PLOTS): [Any(str, PLOT_PSTAGE_SCHEMA)],
}
STAGE_DEFINITION = {
StageParams.PARAM_CMD: str,
Optional(StageParams.PARAM_WDIR): str,
Optional(StageParams.PARAM_DEPS): [str],
Optional(StageParams.PARAM_PARAMS): [
Any(str, PARAM_PSTAGE_NON_DEFAULT_SCHEMA)
],
Optional(StageParams.PARAM_FROZEN): bool,
Optional(StageParams.PARAM_META): object,
Optional(StageParams.PARAM_ALWAYS_CHANGED): bool,
Optional(StageParams.PARAM_OUTS): [Any(str, OUT_PSTAGE_DETAILED_SCHEMA)],
Optional(StageParams.PARAM_METRICS): [
Any(str, OUT_PSTAGE_DETAILED_SCHEMA)
],
Optional(StageParams.PARAM_PLOTS): [Any(str, PLOT_PSTAGE_SCHEMA)],
}

FOREACH_IN = {
Required(FOREACH_KWD): Any(dict, list, str),
Required(IN_KWD): STAGE_DEFINITION,
}
SINGLE_PIPELINE_STAGE_SCHEMA = {str: Any(STAGE_DEFINITION, FOREACH_IN)}
MULTI_STAGE_SCHEMA = {
STAGES: SINGLE_PIPELINE_STAGE_SCHEMA,
USE_KWD: str,
Expand Down
69 changes: 69 additions & 0 deletions tests/func/test_stage_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,72 @@ def test_with_templated_wdir(tmp_dir, dvc):
}
},
)


def test_simple_foreach_loop(tmp_dir, dvc):
iterable = ["foo", "bar", "baz"]
d = {
"stages": {
"build": {
"foreach": iterable,
"in": {"cmd": "python script.py ${item}"},
}
}
}

resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
assert resolver.resolve() == {
"stages": {
f"build-{item}": {"cmd": f"python script.py {item}"}
for item in iterable
}
}


def test_foreach_loop_dict(tmp_dir, dvc):
iterable = {"models": {"us": {"thresh": 10}, "gb": {"thresh": 15}}}
d = {
"stages": {
"build": {
"foreach": iterable["models"],
"in": {"cmd": "python script.py ${item.thresh}"},
}
}
}

resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
assert resolver.resolve() == {
"stages": {
f"build-{key}": {"cmd": f"python script.py {item['thresh']}"}
for key, item in iterable["models"].items()
}
}


def test_foreach_loop_templatized(tmp_dir, dvc):
params = {"models": {"us": {"thresh": 10}}}
vars_ = {"models": {"gb": {"thresh": 15}}}
dump_yaml(tmp_dir / DEFAULT_PARAMS_FILE, params)
d = {
"vars": vars_,
"stages": {
"build": {
"foreach": "${models}",
"in": {"cmd": "python script.py --thresh ${item.thresh}"},
}
},
}

resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
assert_stage_equal(
resolver.resolve(),
{
"stages": {
"build-gb": {"cmd": "python script.py --thresh 15"},
"build-us": {
"cmd": "python script.py --thresh 10",
"params": ["models.us.thresh"],
},
}
},
)