Skip to content

dvc.yaml: introduce set keyword #4757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 62 additions & 7 deletions dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,21 @@
from collections.abc import Mapping, Sequence
from copy import deepcopy
from itertools import starmap
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union

from funcy import first, join

from dvc.dependency.param import ParamsDependency
from dvc.path_info import PathInfo
from dvc.utils.serialize import dumps_yaml

from .context import Context
from .interpolate import resolve
from .interpolate import (
_get_matches,
_is_exact_string,
_is_interpolated_string,
_resolve_str,
resolve,
)

if TYPE_CHECKING:
from dvc.repo import Repo
Expand All @@ -28,8 +33,10 @@
# Keys recognised inside a dvc.yaml stage entry.
PARAMS_KWD = "params"
FOREACH_KWD = "foreach"
IN_KWD = "in"
# Key whose mapping is handed to DataResolver.set_context_from().
SET_KWD = "set"

# Unique marker object, compared by identity, for "no value was passed".
DEFAULT_SENTINEL = object()
# Annotation alias for values that may be either a sequence or a mapping.
SeqOrMap = Union[Sequence, Mapping]


class DataResolver:
Expand All @@ -56,6 +63,7 @@ def __init__(self, repo: "Repo", wdir: PathInfo, d: dict):
def _resolve_entry(self, name: str, definition):
context = Context.clone(self.global_ctx)
if FOREACH_KWD in definition:
self.set_context_from(context, definition.get(SET_KWD, {}))
assert IN_KWD in definition
return self._foreach(
context, name, definition[FOREACH_KWD], definition[IN_KWD]
Expand All @@ -65,11 +73,12 @@ def _resolve_entry(self, name: str, definition):
def resolve(self):
stages = self.data.get(STAGES_KWD, {})
data = join(starmap(self._resolve_entry, stages.items()))
logger.trace("Resolved dvc.yaml:\n%s", dumps_yaml(data))
logger.trace("Resolved dvc.yaml:\n%s", data)
return {STAGES_KWD: data}

def _resolve_stage(self, context: Context, name: str, definition) -> dict:
definition = deepcopy(definition)
self.set_context_from(context, definition.pop(SET_KWD, {}))
wdir = self._resolve_wdir(context, definition.get(WDIR_KWD))
if self.wdir != wdir:
logger.debug(
Expand Down Expand Up @@ -135,10 +144,56 @@ def each_iter(value, key=DEFAULT_SENTINEL):
return self._resolve_stage(c, f"{name}-{suffix}", in_data)

iterable = resolve(foreach_data, context)

assert isinstance(iterable, (Sequence, Mapping)) and not isinstance(
iterable, str
), f"got type of {type(iterable)}"
if isinstance(iterable, Sequence):
gen = (each_iter(v) for v in iterable)
elif isinstance(iterable, Mapping):
gen = (each_iter(v, k) for k, v in iterable.items())
else:
raise Exception(f"got type of {type(iterable)}")
gen = (each_iter(v, k) for k, v in iterable.items())
return join(gen)

@classmethod
def set_context_from(cls, context: Context, to_set):
    """Add every entry of ``to_set`` to ``context``.

    Entries are processed in iteration order, so an entry may refer to a
    key that an earlier entry has just added.

    * A string is first checked by ``_check_joined_with_interpolation``
      and then stored as whatever ``_resolve_str(..., unwrap=False)``
      returns for it.
    * A non-string sequence or mapping is checked by
      ``_check_nested_collection`` and ``_check_interpolation_collection``
      and then stored unchanged.
    * Any other value is stored unchanged.

    Args:
        context: the context that receives the new keys (mutated in place).
        to_set: mapping of key to the value to store under that key.

    Raises:
        ValueError: if a key is already present in ``context``, or if one
            of the validation helpers rejects the value.
    """
    for key, raw in to_set.items():
        if key in context:
            raise ValueError(f"Cannot set '{key}', key already exists")

        if isinstance(raw, str):
            cls._check_joined_with_interpolation(key, raw)
            context[key] = _resolve_str(raw, context, unwrap=False)
            continue

        if isinstance(raw, (Sequence, Mapping)):
            cls._check_nested_collection(key, raw)
            cls._check_interpolation_collection(key, raw)
        context[key] = raw

@staticmethod
def _check_nested_collection(key: str, value: SeqOrMap):
    """Refuse a collection that itself contains a list or a dict.

    For a mapping the values are inspected, otherwise the elements.
    Strings are sequences too, but they do not count as nested.

    Args:
        key: name being set; only used in the error message.
        value: the sequence or mapping assigned to ``key``.

    Raises:
        ValueError: if any inspected item is a non-string sequence or a
            mapping.
    """
    members = value.values() if isinstance(value, Mapping) else value
    for member in members:
        if isinstance(member, str):
            continue
        if isinstance(member, (Mapping, Sequence)):
            raise ValueError(f"Cannot set '{key}', has nested dict/list")

@staticmethod
def _check_interpolation_collection(key: str, value: SeqOrMap):
    """Refuse a collection in which any item is an interpolated string.

    For a mapping the values are inspected, otherwise the elements.

    Args:
        key: name being set; only used in the error message.
        value: the sequence or mapping assigned to ``key``.

    Raises:
        ValueError: if ``_is_interpolated_string`` is true for any item.
    """
    members = value.values() if isinstance(value, Mapping) else value
    for member in members:
        if _is_interpolated_string(member):
            raise ValueError(
                f"Cannot set '{key}', "
                "having interpolation inside "
                f"'{type(value).__name__}' is not supported."
            )

@staticmethod
def _check_joined_with_interpolation(key: str, value: str):
    """Refuse a string that mixes interpolation with anything else.

    A string with no interpolation at all passes, and so does a string
    that consists of exactly one interpolation expression. Everything
    else (an expression surrounded by literal text, or more than one
    expression) is rejected.

    Args:
        key: name being set; only used in the error message.
        value: the string assigned to ``key``.

    Raises:
        ValueError: if ``value`` contains interpolation but is not exactly
            one interpolation expression.
    """
    matches = _get_matches(value)
    if matches and not _is_exact_string(value, matches):
        # Adjacent string literals are concatenated verbatim, so the
        # trailing space on the second literal is required; without it
        # the message read "...interpolated stringis not supported".
        raise ValueError(
            f"Cannot set '{key}', "
            "joining string with interpolated string "
            "is not supported"
        )
8 changes: 7 additions & 1 deletion dvc/parsing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class Value:
def __repr__(self):
return repr(self.value)

def __str__(self) -> str:
return str(self.value)

def get_sources(self):
return {self.meta.source: self.meta.path()}

Expand Down Expand Up @@ -103,7 +106,10 @@ def __iter__(self):
return iter(self.data)

def __eq__(self, o):
return o.data == self.data
container = type(self)
if isinstance(o, container):
return o.data == self.data
return container(o) == self

def select(self, key: str):
index, *rems = key.split(sep=".", maxsplit=1)
Expand Down
22 changes: 16 additions & 6 deletions dvc/parsing/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,27 @@
re.VERBOSE,
)

UNWRAP_DEFAULT = True

def _get_matches(template):

def _get_matches(template: str):
return list(KEYCRE.finditer(template))


def _is_interpolated_string(val):
return bool(_get_matches(val)) if isinstance(val, str) else False


def _unwrap(value):
    """Return the ``.value`` of a ``Value`` instance, or the argument as is."""
    return value.value if isinstance(value, Value) else value


def _resolve_value(match, context: Context):
def _resolve_value(match, context: Context, unwrap=UNWRAP_DEFAULT):
_, _, inner = match.groups()
value = context.select(inner)
return _unwrap(value)
return _unwrap(value) if unwrap else value


def _str_interpolate(template, matches, context):
Expand All @@ -42,12 +48,16 @@ def _str_interpolate(template, matches, context):
return buf + template[index:]


def _resolve_str(src: str, context):
def _is_exact_string(src: str, matches):
return len(matches) == 1 and src == matches[0].group(0)


def _resolve_str(src: str, context, unwrap=UNWRAP_DEFAULT):
matches = _get_matches(src)
if len(matches) == 1 and src == matches[0].group(0):
if _is_exact_string(src, matches):
# replace "${enabled}", if `enabled` is a boolean, with its actual
# value rather than its string counterpart.
return _resolve_value(matches[0], context)
return _resolve_value(matches[0], context, unwrap=unwrap)

# but not "${num} days"
src = _str_interpolate(src, matches, context)
Expand Down
4 changes: 3 additions & 1 deletion dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dvc import dependency, output
from dvc.output import CHECKSUMS_SCHEMA, BaseOutput
from dvc.parsing import FOREACH_KWD, IN_KWD, USE_KWD, VARS_KWD
from dvc.parsing import FOREACH_KWD, IN_KWD, SET_KWD, USE_KWD, VARS_KWD
from dvc.stage.params import StageParams

STAGES = "stages"
Expand Down Expand Up @@ -50,6 +50,7 @@

STAGE_DEFINITION = {
StageParams.PARAM_CMD: str,
Optional(SET_KWD): dict,
Optional(StageParams.PARAM_WDIR): str,
Optional(StageParams.PARAM_DEPS): [str],
Optional(StageParams.PARAM_PARAMS): [
Expand All @@ -66,6 +67,7 @@
}

# Schema for a stage entry written in the `foreach` form: the SET_KWD key
# is marked Optional and must be a dict, the FOREACH_KWD key is marked
# Required and may be a dict, a list or a str, and the IN_KWD key is marked
# Required and must satisfy STAGE_DEFINITION.
FOREACH_IN = {
    Optional(SET_KWD): dict,
    Required(FOREACH_KWD): Any(dict, list, str),
    Required(IN_KWD): STAGE_DEFINITION,
}
Expand Down
98 changes: 98 additions & 0 deletions tests/func/test_stage_resolver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from copy import deepcopy
from math import pi

import pytest

Expand Down Expand Up @@ -302,3 +303,100 @@ def test_foreach_loop_templatized(tmp_dir, dvc):
}
},
)


@pytest.mark.parametrize(
    "value", ["value", "To set or not to set", 3, pi, True, False, None]
)
def test_set(tmp_dir, dvc, value):
    """A scalar placed under ``set`` can be interpolated in the same stage.

    Embedded in a longer string it is formatted into the text; used as the
    whole field it keeps its original value. The ``set`` key itself does
    not appear in the resolved stage.
    """
    stage = {
        "set": {"item": value},
        "cmd": "python script.py --thresh ${item}",
        "always_changed": "${item}",
    }
    d = {"stages": {"build": stage}}

    resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
    resolved = resolver.resolve()

    expected_stage = {
        "cmd": f"python script.py --thresh {value}",
        "always_changed": value,
    }
    assert resolved == {"stages": {"build": expected_stage}}


@pytest.mark.parametrize(
    "coll", [["foo", "bar", "baz"], {"foo": "foo", "bar": "bar"}]
)
def test_coll(tmp_dir, dvc, coll):
    """A flat list or dict placed under ``set`` resolves to that collection."""
    stage = {
        "set": {"item": coll, "thresh": 10},
        "cmd": "python script.py --thresh ${thresh}",
        "outs": "${item}",
    }
    d = {"stages": {"build": stage}}

    resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
    resolved = resolver.resolve()

    expected_stage = {"cmd": "python script.py --thresh 10", "outs": coll}
    assert resolved == {"stages": {"build": expected_stage}}


def test_set_with_foreach(tmp_dir, dvc):
    """A list defined under ``set`` can drive ``foreach`` in the same entry."""
    items = ["foo", "bar", "baz"]
    entry = {
        "set": {"items": items},
        "foreach": "${items}",
        "in": {"cmd": "command --value ${item}"},
    }
    d = {"stages": {"build": entry}}

    resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
    resolved = resolver.resolve()

    # One generated stage per element, named "build-<element>".
    expected_stages = {}
    for item in items:
        expected_stages[f"build-{item}"] = {"cmd": f"command --value {item}"}
    assert resolved == {"stages": expected_stages}


def test_set_with_foreach_and_on_stage_definition(tmp_dir, dvc):
    """``set`` works both next to ``foreach`` and inside the ``in`` body.

    The outer ``set`` aliases data loaded from params.json, and the inner
    one aliases a field of the current item; the resolved stages still
    list the original params.json paths under ``params``.
    """
    iterable = {"models": {"us": {"thresh": 10}, "gb": {"thresh": 15}}}
    dump_json(tmp_dir / "params.json", iterable)

    stage_body = {
        "set": {"thresh": "${item.thresh}"},
        "cmd": "command --value ${thresh}",
    }
    entry = {
        "set": {"data": "${models}"},
        "foreach": "${data}",
        "in": stage_body,
    }
    d = {"use": "params.json", "stages": {"build": entry}}

    resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
    resolved = resolver.resolve()

    expected_stages = {
        "build-us": {
            "cmd": "command --value 10",
            "params": [{"params.json": ["models.us.thresh"]}],
        },
        "build-gb": {
            "cmd": "command --value 15",
            "params": [{"params.json": ["models.gb.thresh"]}],
        },
    }
    assert resolved == {"stages": expected_stages}
Loading