Skip to content

Commit 563d9ca

Browse files
authored
improve error message for prototype datasets without options (#5224)
* improve error message for prototype datasets without options * remove early exit * add test
1 parent e32b19e commit 563d9ca

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

test/test_prototype_datasets_api.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,16 @@ def test_default_config(self, info):
126126
assert info.default_config == default_config
127127

128128
@pytest.mark.parametrize(
129-
("options", "expected_error_msg"),
129+
("valid_options", "options", "expected_error_msg"),
130130
[
131-
pytest.param(dict(unknown_option=None), "Unknown option 'unknown_option'", id="unknown_option"),
132-
pytest.param(dict(split="unknown_split"), "Invalid argument 'unknown_split'", id="invalid_argument"),
131+
(dict(), dict(any_option=None), "does not take any options"),
132+
(dict(split="train"), dict(unknown_option=None), "Unknown option 'unknown_option'"),
133+
(dict(split="train"), dict(split="invalid_argument"), "Invalid argument 'invalid_argument'"),
133134
],
134135
)
135-
def test_make_config_invalid_inputs(self, info, options, expected_error_msg):
136+
def test_make_config_invalid_inputs(self, info, valid_options, options, expected_error_msg):
137+
info = make_minimal_dataset_info(valid_options=valid_options)
138+
136139
with pytest.raises(ValueError, match=expected_error_msg):
137140
info.make_config(**options)
138141

torchvision/prototype/datasets/utils/_dataset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ def read_categories_file(path: pathlib.Path) -> List[List[str]]:
7878
return [row for row in csv.reader(file)]
7979

8080
def make_config(self, **options: Any) -> DatasetConfig:
81+
if not self._valid_options and options:
82+
raise ValueError(
83+
f"Dataset {self.name} does not take any options, "
84+
f"but got {sequence_to_str(list(options), separate_last=' and')}."
85+
)
86+
8187
for name, arg in options.items():
8288
if name not in self._valid_options:
8389
raise ValueError(

0 commit comments

Comments
 (0)