diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index fbfa93c443..2aea830820 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -312,31 +312,18 @@ def _unpack_args(self, tree=None): def _update_params(self, params: dict): """Update experiment params files with the specified values.""" + from benedict import benedict + from dvc.utils.serialize import MODIFIERS logger.debug("Using experiment params '%s'", params) - # recursive dict update - def _update(dict_, other): - for key, value in other.items(): - if isinstance(dict_, list) and key.isdigit(): - key = int(key) - if isinstance(value, Mapping): - if isinstance(dict_, list): - fallback_value = dict_[key] - else: - fallback_value = dict_.get(key, {}) - dict_[key] = _update(fallback_value, value) - else: - dict_[key] = value - return dict_ - for params_fname in params: path = PathInfo(self.exp_dvc.root_dir) / params_fname suffix = path.suffix.lower() modify_data = MODIFIERS[suffix] with modify_data(path, tree=self.exp_dvc.tree) as data: - _update(data, params[params_fname]) + benedict(data).merge(params[params_fname], overwrite=True) # Force params file changes to be staged in git # Otherwise in certain situations the changes to params file may be diff --git a/dvc/repo/experiments/run.py b/dvc/repo/experiments/run.py index 112ec8aebf..92a6f80c8b 100644 --- a/dvc/repo/experiments/run.py +++ b/dvc/repo/experiments/run.py @@ -12,7 +12,6 @@ def _parse_params(path_params: Iterable): from ruamel.yaml import YAMLError from dvc.dependency.param import ParamsDependency - from dvc.utils.flatten import unflatten from dvc.utils.serialize import loads_yaml ret = {} @@ -31,7 +30,7 @@ def _parse_params(path_params: Iterable): ) if not path: path = ParamsDependency.DEFAULT_PARAMS_FILE - ret[path] = unflatten(params) + ret[path] = params return ret diff --git a/setup.py b/setup.py index aac5a8bf97..d698315acd 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ def run(self): "shtab>=1.3.2,<2", "rich>=3.0.5", "dictdiffer>=0.8.1", + "python-benedict>=0.21.1", ] diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 49578a09d6..e8ba0d6f71 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -66,27 +66,45 @@ def test_update_with_pull(tmp_dir, scm, dvc, exp_stage, mocker): @pytest.mark.parametrize( - "change, expected", + "changes, expected", [ - ["foo.1.baz=3", "foo: [bar: 1, baz: 3]"], - ["foo.0=bar", "foo: [bar, baz: 2]"], - ["foo.1=- baz\n- goo", "foo: [bar: 1, [baz, goo]]"], + [["foo=baz"], "{foo: baz, goo: {bag: 3}, lorem: false}"], + [["foo=baz,goo=bar"], "{foo: baz, goo: bar, lorem: false}"], + [ + ["goo.bag=4"], + "{foo: [bar: 1, baz: 2], goo: {bag: 4}, lorem: false}", + ], + [["foo[0]=bar"], "{foo: [bar, baz: 2], goo: {bag: 3}, lorem: false}"], + [ + ["foo[1].baz=3"], + "{foo: [bar: 1, baz: 3], goo: {bag: 3}, lorem: false}", + ], + [ + ["foo[1]=- baz\n- goo"], + "{foo: [bar: 1, [baz, goo]], goo: {bag: 3}, lorem: false}", + ], + [ + ["lorem.ipsum=3"], + "{foo: [bar: 1, baz: 2], goo: {bag: 3}, lorem: {ipsum: 3}}", + ], ], ) -def test_modify_list_param(tmp_dir, scm, dvc, mocker, change, expected): +def test_modify_params(tmp_dir, scm, dvc, mocker, changes, expected): tmp_dir.gen("copy.py", COPY_SCRIPT) - tmp_dir.gen("params.yaml", "foo: [bar: 1, baz: 2]") + tmp_dir.gen( + "params.yaml", "{foo: [bar: 1, baz: 2], goo: {bag: 3}, lorem: false}" + ) stage = dvc.run( cmd="python copy.py params.yaml metrics.yaml", metrics_no_cache=["metrics.yaml"], - params=["foo"], + params=["foo", "goo", "lorem"], name="copy-file", ) scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) scm.commit("init") new_mock = mocker.spy(dvc.experiments, "new") - dvc.experiments.run(stage.addressing, params=[change]) + dvc.experiments.run(stage.addressing, params=changes) new_mock.assert_called_once() assert (