Skip to content

Commit ddd3111

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] allow single extension as str in make_dataset (#5229)
Summary: * allow single extension as str in make_dataset * remove test class * remove regex * revert collection to tuple * cleanup Reviewed By: jdsgomes, prabhat00155 Differential Revision: D33739394 fbshipit-source-id: 460a1576a18d9ee61657302c10b3e51156cd66fd
1 parent e0a4b0e commit ddd3111

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

test/test_datasets_utils.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import contextlib
22
import gzip
33
import os
4+
import pathlib
5+
import re
46
import tarfile
57
import zipfile
68

79
import pytest
810
import torchvision.datasets.utils as utils
911
from torch._utils_internal import get_file_path_2
12+
from torchvision.datasets.folder import make_dataset
1013
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
1114

12-
1315
TEST_FILE = get_file_path_2(
1416
os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
1517
)
@@ -214,5 +216,29 @@ def test_verify_str_arg(self):
214216
pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
215217

216218

219+
@pytest.mark.parametrize(
220+
("kwargs", "expected_error_msg"),
221+
[
222+
(dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"),
223+
(dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")),
224+
(dict(extensions=(".png", ".jpeg")), re.escape("classes c. Supported extensions are: .png, .jpeg")),
225+
],
226+
)
227+
def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg):
228+
tmpdir = pathlib.Path(tmpdir)
229+
230+
(tmpdir / "a").mkdir()
231+
(tmpdir / "a" / "a.png").touch()
232+
233+
(tmpdir / "b").mkdir()
234+
(tmpdir / "b" / "b.jpeg").touch()
235+
236+
(tmpdir / "c").mkdir()
237+
(tmpdir / "c" / "c.unknown").touch()
238+
239+
with pytest.raises(FileNotFoundError, match=expected_error_msg):
240+
make_dataset(str(tmpdir), **kwargs)
241+
242+
217243
if __name__ == "__main__":
218244
pytest.main([__file__])

torchvision/datasets/folder.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import os
22
import os.path
33
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
4+
from typing import Union
45

56
from PIL import Image
67

78
from .vision import VisionDataset
89

910

10-
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
11+
def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
1112
"""Checks if a file is an allowed extension.
1213
1314
Args:
@@ -17,7 +18,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo
1718
Returns:
1819
bool: True if the filename ends with one of given extensions
1920
"""
20-
return filename.lower().endswith(extensions)
21+
return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
2122

2223

2324
def is_image_file(filename: str) -> bool:
@@ -48,7 +49,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
4849
def make_dataset(
4950
directory: str,
5051
class_to_idx: Optional[Dict[str, int]] = None,
51-
extensions: Optional[Tuple[str, ...]] = None,
52+
extensions: Optional[Union[str, Tuple[str, ...]]] = None,
5253
is_valid_file: Optional[Callable[[str], bool]] = None,
5354
) -> List[Tuple[str, int]]:
5455
"""Generates a list of samples of a form (path_to_sample, class).
@@ -73,7 +74,7 @@ def make_dataset(
7374
if extensions is not None:
7475

7576
def is_valid_file(x: str) -> bool:
76-
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
77+
return has_file_allowed_extension(x, extensions) # type: ignore[arg-type]
7778

7879
is_valid_file = cast(Callable[[str], bool], is_valid_file)
7980

@@ -98,7 +99,7 @@ def is_valid_file(x: str) -> bool:
9899
if empty_classes:
99100
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
100101
if extensions is not None:
101-
msg += f"Supported extensions are: {', '.join(extensions)}"
102+
msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
102103
raise FileNotFoundError(msg)
103104

104105
return instances

0 commit comments

Comments
 (0)