diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 9eb849bbe34..7392a33f6b2 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -200,6 +200,13 @@ def make_dataset( extensions: Optional[Tuple[str, ...]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, ) -> List[Tuple[str, int]]: + if class_to_idx is None: + # prevent potential bug since make_dataset() would use the class_to_idx logic of the + # find_classes() function, instead of using that of the find_classes() method, which + # is potentially overridden and thus could have a different logic. + raise ValueError( + "The class_to_idx parameter cannot be None." + ) return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) def find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: