Skip to content

Adds EuroSAT to the list of supported datasets #5114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Country211
DTD
EMNIST
EuroSAT
FakeData
FashionMNIST
FER2013
Expand Down
21 changes: 21 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2169,6 +2169,27 @@ def inject_fake_data(self, tmpdir, config):
return num_sequences * (num_examples_per_sequence - 1)


class EuroSATTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.EuroSAT`` backed by a small fake image tree."""

    DATASET_CLASS = datasets.EuroSAT
    FEATURE_TYPES = (PIL.Image.Image, int)

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir/eurosat/2750`` with fake per-class image folders.

        One folder is requested from ``datasets_utils.create_image_folder`` for
        each fake class name, each with the same number of examples.

        Returns:
            int: Total number of examples requested across all classes.
        """
        images_root = os.path.join(tmpdir, "eurosat", "2750")
        os.makedirs(images_root)

        per_class = 3
        class_names = ("AnnualCrop", "Forest")
        for class_name in class_names:
            # File names follow the "<class>_<index>.jpg" pattern.
            datasets_utils.create_image_folder(
                root=images_root,
                name=class_name,
                file_name_fn=lambda idx: f"{class_name}_{idx}.jpg",
                num_examples=per_class,
            )

        return per_class * len(class_names)


class Food101TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Food101
FEATURE_TYPES = (PIL.Image.Image, int)
Expand Down
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .coco import CocoCaptions, CocoDetection
from .country211 import Country211
from .dtd import DTD
from .eurosat import EuroSAT
from .fakedata import FakeData
from .fer2013 import FER2013
from .fgvc_aircraft import FGVCAircraft
Expand Down Expand Up @@ -98,4 +99,5 @@
"OxfordIIITPet",
"Country211",
"FGVCAircraft",
"EuroSAT",
)
56 changes: 56 additions & 0 deletions torchvision/datasets/eurosat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
from typing import Any

from .folder import ImageFolder
from .utils import download_and_extract_archive


class EuroSAT(ImageFolder):
    """RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.

    The images are read from ``root/eurosat/2750``; any extra keyword
    arguments are forwarded unchanged to :class:`ImageFolder`.

    Args:
        root (string): Root directory of dataset where ``root/eurosat`` exists.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        RuntimeError: If the data folder does not exist after the optional download.
    """

    url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip"
    md5 = "c8fa014336c82ac7804f0398fcb19387"

    def __init__(self, root: str, download: bool = False, **kwargs: Any) -> None:
        expanded_root = os.path.expanduser(root)
        self.root = expanded_root
        self._base_folder = os.path.join(expanded_root, "eurosat")
        self._data_folder = os.path.join(self._base_folder, "2750")

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        super().__init__(self._data_folder, **kwargs)
        # NOTE(review): presumably the parent constructor rebinds ``self.root``
        # to the data folder it was given; point it back at the user-supplied
        # root. Confirm against ImageFolder.
        self.root = expanded_root

    def __len__(self) -> int:
        """Return the number of entries in ``self.samples``."""
        return len(self.samples)

    def _check_exists(self) -> bool:
        """Return True if the ``eurosat/2750`` data folder path exists."""
        return os.path.exists(self._data_folder)

    def download(self) -> None:
        """Download and extract the archive into ``root/eurosat`` if the data folder is missing."""
        if not self._check_exists():
            os.makedirs(self._base_folder, exist_ok=True)
            download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5)