Skip to content

Commit 121bb55

Browse files
committed
Merge branch 'main' of github.com:pytorch/vision into stanford_cars
2 parents b593760 + 57a77c4 commit 121bb55

File tree

4 files changed

+80
-0
lines changed

4 files changed

+80
-0
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
4141
Country211
4242
DTD
4343
EMNIST
44+
EuroSAT
4445
FakeData
4546
FashionMNIST
4647
FER2013

test/test_datasets.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2169,6 +2169,27 @@ def inject_fake_data(self, tmpdir, config):
21692169
return num_sequences * (num_examples_per_sequence - 1)
21702170

21712171

2172+
class EuroSATTestCase(datasets_utils.ImageDatasetTestCase):
2173+
DATASET_CLASS = datasets.EuroSAT
2174+
FEATURE_TYPES = (PIL.Image.Image, int)
2175+
2176+
def inject_fake_data(self, tmpdir, config):
2177+
data_folder = os.path.join(tmpdir, "eurosat", "2750")
2178+
os.makedirs(data_folder)
2179+
2180+
num_examples_per_class = 3
2181+
classes = ("AnnualCrop", "Forest")
2182+
for cls in classes:
2183+
datasets_utils.create_image_folder(
2184+
root=data_folder,
2185+
name=cls,
2186+
file_name_fn=lambda idx: f"{cls}_{idx}.jpg",
2187+
num_examples=num_examples_per_class,
2188+
)
2189+
2190+
return len(classes) * num_examples_per_class
2191+
2192+
21722193
class Food101TestCase(datasets_utils.ImageDatasetTestCase):
21732194
DATASET_CLASS = datasets.Food101
21742195
FEATURE_TYPES = (PIL.Image.Image, int)

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .coco import CocoCaptions, CocoDetection
88
from .country211 import Country211
99
from .dtd import DTD
10+
from .eurosat import EuroSAT
1011
from .fakedata import FakeData
1112
from .fer2013 import FER2013
1213
from .fgvc_aircraft import FGVCAircraft
@@ -100,4 +101,5 @@
100101
"OxfordIIITPet",
101102
"Country211",
102103
"FGVCAircraft",
104+
"EuroSAT",
103105
)

torchvision/datasets/eurosat.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import os
2+
from typing import Any
3+
4+
from .folder import ImageFolder
5+
from .utils import download_and_extract_archive
6+
7+
8+
class EuroSAT(ImageFolder):
9+
"""RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.
10+
11+
Args:
12+
root (string): Root directory of dataset where ``root/eurosat`` exists.
13+
download (bool, optional): If True, downloads the dataset from the internet and
14+
puts it in root directory. If dataset is already downloaded, it is not
15+
downloaded again. Default is False.
16+
transform (callable, optional): A function/transform that takes in an PIL image
17+
and returns a transformed version. E.g, ``transforms.RandomCrop``
18+
target_transform (callable, optional): A function/transform that takes in the
19+
target and transforms it.
20+
"""
21+
22+
url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip"
23+
md5 = "c8fa014336c82ac7804f0398fcb19387"
24+
25+
def __init__(
26+
self,
27+
root: str,
28+
download: bool = False,
29+
**kwargs: Any,
30+
) -> None:
31+
self.root = os.path.expanduser(root)
32+
self._base_folder = os.path.join(self.root, "eurosat")
33+
self._data_folder = os.path.join(self._base_folder, "2750")
34+
35+
if download:
36+
self.download()
37+
38+
if not self._check_exists():
39+
raise RuntimeError("Dataset not found. You can use download=True to download it")
40+
41+
super().__init__(self._data_folder, **kwargs)
42+
self.root = os.path.expanduser(root)
43+
44+
def __len__(self) -> int:
45+
return len(self.samples)
46+
47+
def _check_exists(self) -> bool:
48+
return os.path.exists(self._data_folder)
49+
50+
def download(self) -> None:
51+
52+
if self._check_exists():
53+
return
54+
55+
os.makedirs(self._base_folder, exist_ok=True)
56+
download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5)

0 commit comments

Comments
 (0)