Skip to content

Commit 57f1528

Browse files
committed
feat: expose loader argument in Country211 and EuroSAT.
1 parent fab1188 commit 57f1528

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

torchvision/datasets/country211.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pathlib import Path
2-
from typing import Callable, Optional, Union
2+
from typing import Any, Callable, Optional, Union
33

4-
from .folder import ImageFolder
4+
from .folder import default_loader, ImageFolder
55
from .utils import download_and_extract_archive, verify_str_arg
66

77

@@ -21,6 +21,7 @@ class Country211(ImageFolder):
2121
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
2222
download (bool, optional): If True, downloads the dataset from the internet and puts it into
2323
``root/country211/``. If dataset is already downloaded, it is not downloaded again.
24+
loader (callable, optional): A function to load an image given its path.
2425
"""
2526

2627
_URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
@@ -33,6 +34,7 @@ def __init__(
3334
transform: Optional[Callable] = None,
3435
target_transform: Optional[Callable] = None,
3536
download: bool = False,
37+
loader: Callable[[str], Any] = default_loader,
3638
) -> None:
3739
self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
3840

@@ -46,7 +48,12 @@ def __init__(
4648
if not self._check_exists():
4749
raise RuntimeError("Dataset not found. You can use download=True to download it")
4850

49-
super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
51+
super().__init__(
52+
str(self._base_folder / self._split),
53+
transform=transform,
54+
target_transform=target_transform,
55+
loader=loader,
56+
)
5057
self.root = str(root)
5158

5259
def _check_exists(self) -> bool:

torchvision/datasets/eurosat.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os
22
from pathlib import Path
3-
from typing import Callable, Optional, Union
3+
from typing import Any, Callable, Optional, Union
44

5-
from .folder import ImageFolder
5+
from .folder import default_loader, ImageFolder
66
from .utils import download_and_extract_archive
77

88

@@ -21,6 +21,7 @@ class EuroSAT(ImageFolder):
2121
download (bool, optional): If True, downloads the dataset from the internet and
2222
puts it in root directory. If dataset is already downloaded, it is not
2323
downloaded again. Default is False.
24+
loader (callable, optional): A function to load an image given its path.
2425
"""
2526

2627
def __init__(
@@ -29,6 +30,7 @@ def __init__(
2930
transform: Optional[Callable] = None,
3031
target_transform: Optional[Callable] = None,
3132
download: bool = False,
33+
loader: Callable[[str], Any] = default_loader,
3234
) -> None:
3335
self.root = os.path.expanduser(root)
3436
self._base_folder = os.path.join(self.root, "eurosat")
@@ -40,7 +42,12 @@ def __init__(
4042
if not self._check_exists():
4143
raise RuntimeError("Dataset not found. You can use download=True to download it")
4244

43-
super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
45+
super().__init__(
46+
self._data_folder,
47+
transform=transform,
48+
target_transform=target_transform,
49+
loader=loader,
50+
)
4451
self.root = os.path.expanduser(root)
4552

4653
def __len__(self) -> int:

0 commit comments

Comments
 (0)