Skip to content

Commit e2b1fb1

Browse files
dtrifiroDaniele Trifirò
and
Daniele Trifirò
authored
add toml support for ParamsDependency (#4258)
* add toml support for ParamsDependency * add toml support to dvc show Co-authored-by: Daniele Trifirò <[email protected]>
1 parent 39c0bdb commit e2b1fb1

File tree

6 files changed

+56
-5
lines changed

6 files changed

+56
-5
lines changed

dvc/dependency/param.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections import defaultdict
33

44
import dpath.util
5+
import toml
56
import yaml
67
from voluptuous import Any
78

@@ -21,6 +22,8 @@ class ParamsDependency(LocalDependency):
2122
PARAM_PARAMS = "params"
2223
PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, list, None)}
2324
DEFAULT_PARAMS_FILE = "params.yaml"
25+
PARAMS_FILE_LOADERS = defaultdict(lambda: yaml.safe_load)
26+
PARAMS_FILE_LOADERS.update({".toml": toml.load})
2427

2528
def __init__(self, stage, path, params):
2629
info = {}
@@ -87,8 +90,10 @@ def read_params(self):
8790

8891
with self.repo.tree.open(self.path_info, "r") as fobj:
8992
try:
90-
config = yaml.safe_load(fobj)
91-
except yaml.YAMLError as exc:
93+
config = self.PARAMS_FILE_LOADERS[
94+
self.path_info.suffix.lower()
95+
](fobj)
96+
except (yaml.YAMLError, toml.TomlDecodeError) as exc:
9297
raise BadParamFileError(
9398
f"Unable to read parameters from '{self}'"
9499
) from exc

dvc/repo/params/show.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22

3+
import toml
34
import yaml
45

56
from dvc.dependency.param import ParamsDependency
@@ -34,8 +35,10 @@ def _read_params(repo, configs, rev):
3435

3536
with repo.tree.open(config, "r") as fobj:
3637
try:
37-
res[str(config)] = yaml.safe_load(fobj)
38-
except yaml.YAMLError:
38+
res[str(config)] = ParamsDependency.PARAMS_FILE_LOADERS[
39+
config.suffix.lower()
40+
](fobj)
41+
except (yaml.YAMLError, toml.TomlDecodeError):
3942
logger.debug(
4043
"failed to read '%s' on '%s'", config, rev, exc_info=True
4144
)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ count=true
1717
[isort]
1818
include_trailing_comma=true
1919
known_first_party=dvc,tests
20-
known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,grandalf,mock,moto,nanotime,networkx,packaging,pathspec,pygtrie,pylint,pytest,requests,ruamel,setuptools,shortuuid,shtab,tqdm,voluptuous,yaml,zc
20+
known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,grandalf,mock,moto,nanotime,networkx,packaging,pathspec,pygtrie,pylint,pytest,requests,ruamel,setuptools,shortuuid,shtab,toml,tqdm,voluptuous,yaml,zc
2121
line_length=79
2222
force_grid_wrap=0
2323
use_parentheses=True

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def run(self):
6262
"appdirs>=1.4.3",
6363
"PyYAML>=5.1.2,<5.4", # Compatibility with awscli
6464
"ruamel.yaml>=0.16.1",
65+
"toml>=0.10.1",
6566
"funcy>=1.14",
6667
"pathspec>=0.6.0",
6768
"shortuuid>=0.5.0",

tests/func/params/test_show.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@ def test_show(tmp_dir, dvc):
1414
assert dvc.params.show() == {"": {"params.yaml": {"foo": "bar"}}}
1515

1616

17+
def test_show_toml(tmp_dir, dvc):
18+
tmp_dir.gen("params.toml", "[foo]\nbar = 42\nbaz = [1, 2]\n")
19+
dvc.run(
20+
cmd="echo params.toml", params=["params.toml:foo"], single_stage=True
21+
)
22+
assert dvc.params.show() == {
23+
"": {"params.toml": {"foo": {"bar": 42, "baz": [1, 2]}}}
24+
}
25+
26+
1727
def test_show_multiple(tmp_dir, dvc):
1828
tmp_dir.gen("params.yaml", "foo: bar\nbaz: qux\n")
1929
dvc.run(

tests/unit/dependency/test_params.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import toml
23
import yaml
34

45
from dvc.dependency import ParamsDependency, loadd_from, loads_params
@@ -99,6 +100,37 @@ def test_read_params_nested(tmp_dir, dvc):
99100
assert dep.read_params() == {"some.path.foo": ["val1", "val2"]}
100101

101102

103+
def test_read_params_default_loader(tmp_dir, dvc):
104+
parameters_file = "parameters.foo"
105+
tmp_dir.gen(
106+
parameters_file,
107+
yaml.dump({"some": {"path": {"foo": ["val1", "val2"]}}}),
108+
)
109+
dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"])
110+
assert dep.read_params() == {"some.path.foo": ["val1", "val2"]}
111+
112+
113+
def test_read_params_wrong_suffix(tmp_dir, dvc):
114+
parameters_file = "parameters.toml"
115+
tmp_dir.gen(
116+
parameters_file,
117+
yaml.dump({"some": {"path": {"foo": ["val1", "val2"]}}}),
118+
)
119+
dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"])
120+
with pytest.raises(BadParamFileError):
121+
dep.read_params()
122+
123+
124+
def test_read_params_toml(tmp_dir, dvc):
125+
parameters_file = "parameters.toml"
126+
tmp_dir.gen(
127+
parameters_file,
128+
toml.dumps({"some": {"path": {"foo": ["val1", "val2"]}}}),
129+
)
130+
dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"])
131+
assert dep.read_params() == {"some.path.foo": ["val1", "val2"]}
132+
133+
102134
def test_save_info_missing_config(dvc):
103135
dep = ParamsDependency(Stage(dvc), None, ["foo"])
104136
with pytest.raises(MissingParamsError):

0 commit comments

Comments
 (0)