1
1
from pathlib import Path
2
- from typing import Callable , Optional , Union
2
+ from typing import Any , Callable , Optional , Union
3
3
4
- from .folder import ImageFolder
4
+ from .folder import default_loader , ImageFolder
5
5
from .utils import download_and_extract_archive , verify_str_arg
6
6
7
7
@@ -21,6 +21,7 @@ class Country211(ImageFolder):
21
21
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
22
22
download (bool, optional): If True, downloads the dataset from the internet and puts it into
23
23
``root/country211/``. If dataset is already downloaded, it is not downloaded again.
24
+ loader (callable, optional): A function to load an image given its path.
24
25
"""
25
26
26
27
_URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
@@ -33,6 +34,7 @@ def __init__(
33
34
transform : Optional [Callable ] = None ,
34
35
target_transform : Optional [Callable ] = None ,
35
36
download : bool = False ,
37
+ loader : Callable [[str ], Any ] = default_loader ,
36
38
) -> None :
37
39
self ._split = verify_str_arg (split , "split" , ("train" , "valid" , "test" ))
38
40
@@ -46,7 +48,12 @@ def __init__(
46
48
if not self ._check_exists ():
47
49
raise RuntimeError ("Dataset not found. You can use download=True to download it" )
48
50
49
- super ().__init__ (str (self ._base_folder / self ._split ), transform = transform , target_transform = target_transform )
51
+ super ().__init__ (
52
+ str (self ._base_folder / self ._split ),
53
+ transform = transform ,
54
+ target_transform = target_transform ,
55
+ loader = loader ,
56
+ )
50
57
self .root = str (root )
51
58
52
59
def _check_exists (self ) -> bool :
0 commit comments