Skip to content

Commit d2486f6

Browse files
authored
Put back CelebA download (#6147)
* Revert "Indicate Celeba download parameter is deprecated and will be removed (#6059)" This reverts commit 49496c4. * Revert "Deactivate CelebA download (#6052)" This reverts commit ac56f52.
1 parent efc67ea commit d2486f6

File tree

3 files changed

+25
-48
lines changed

3 files changed

+25
-48
lines changed

torchvision/datasets/celeba.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import csv
22
import os
3-
import warnings
43
from collections import namedtuple
54
from typing import Any, Callable, List, Optional, Union, Tuple
65

76
import PIL
87
import torch
98

10-
from .utils import check_integrity, verify_str_arg
9+
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive
1110
from .vision import VisionDataset
1211

1312
CSV = namedtuple("CSV", ["header", "index", "data"])
@@ -36,17 +35,9 @@ class CelebA(VisionDataset):
3635
and returns a transformed version. E.g, ``transforms.PILToTensor``
3736
target_transform (callable, optional): A function/transform that takes in the
3837
target and transforms it.
39-
download (bool, optional): Deprecated.
40-
41-
.. warning::
42-
43-
Downloading CelebA is not supported anymore as of 0.13 and this
44-
parameter will be removed in 0.15. See
45-
`this issue <https://github.com/pytorch/vision/issues/5705>`__
46-
for more details.
47-
Please download the files from
48-
https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract
49-
them in ``root/celeba``.
38+
download (bool, optional): If true, downloads the dataset from the internet and
39+
puts it in root directory. If dataset is already downloaded, it is not
40+
downloaded again.
5041
"""
5142

5243
base_folder = "celeba"
@@ -73,7 +64,7 @@ def __init__(
7364
target_type: Union[List[str], str] = "attr",
7465
transform: Optional[Callable] = None,
7566
target_transform: Optional[Callable] = None,
76-
download: bool = None,
67+
download: bool = False,
7768
) -> None:
7869
super().__init__(root, transform=transform, target_transform=target_transform)
7970
self.split = split
@@ -85,15 +76,6 @@ def __init__(
8576
if not self.target_type and self.target_transform is not None:
8677
raise RuntimeError("target_transform is specified but target_type is empty")
8778

88-
if download is not None:
89-
warnings.warn(
90-
"Downloading CelebA is not supported anymore as of 0.13, and the "
91-
"download parameter will be removed in 0.15. See "
92-
"https://github.com/pytorch/vision/issues/5705 for more details. "
93-
"Please download the files from "
94-
"https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
95-
"in ``root/celeba``."
96-
)
9779
if download:
9880
self.download()
9981

@@ -164,14 +146,10 @@ def download(self) -> None:
164146
print("Files already downloaded and verified")
165147
return
166148

167-
raise ValueError(
168-
"Downloading CelebA is not supported anymore as of 0.13, and the "
169-
"download parameter will be removed in 0.15. See "
170-
"https://github.com/pytorch/vision/issues/5705 for more details. "
171-
"Please download the files from "
172-
"https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
173-
"in ``root/celeba``."
174-
)
149+
for (file_id, md5, filename) in self.file_list:
150+
download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
151+
152+
extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
175153

176154
def __getitem__(self, index: int) -> Tuple[Any, Any]:
177155
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212
from torchvision.prototype.datasets.utils import (
1313
Dataset,
14-
ManualDownloadResource,
14+
GDriveResource,
1515
OnlineResource,
1616
)
1717
from torchvision.prototype.datasets.utils._internal import (
@@ -85,34 +85,33 @@ def __init__(
8585
super().__init__(root, skip_integrity_check=skip_integrity_check)
8686

8787
def _resources(self) -> List[OnlineResource]:
88-
instructions = "Please download the file from https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html."
89-
splits = ManualDownloadResource(
90-
instructions=instructions,
88+
splits = GDriveResource(
89+
"0B7EVK8r0v71pY0NSMzRuSXJEVkk",
9190
sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
9291
file_name="list_eval_partition.txt",
9392
)
94-
images = ManualDownloadResource(
95-
instructions=instructions,
93+
images = GDriveResource(
94+
"0B7EVK8r0v71pZjFTYXZWM3FlRnM",
9695
sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74",
9796
file_name="img_align_celeba.zip",
9897
)
99-
identities = ManualDownloadResource(
100-
instructions=instructions,
98+
identities = GDriveResource(
99+
"1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS",
101100
sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0",
102101
file_name="identity_CelebA.txt",
103102
)
104-
attributes = ManualDownloadResource(
105-
instructions=instructions,
103+
attributes = GDriveResource(
104+
"0B7EVK8r0v71pblRyaVFSWGxPY0U",
106105
sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
107106
file_name="list_attr_celeba.txt",
108107
)
109-
bounding_boxes = ManualDownloadResource(
110-
instructions=instructions,
108+
bounding_boxes = GDriveResource(
109+
"0B7EVK8r0v71pbThiMVRxWXZ4dU0",
111110
sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
112111
file_name="list_bbox_celeba.txt",
113112
)
114-
landmarks = ManualDownloadResource(
115-
instructions=instructions,
113+
landmarks = GDriveResource(
114+
"0B7EVK8r0v71pd0FJY3Blby1HUTQ",
116115
sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
117116
file_name="list_landmarks_align_celeba.txt",
118117
)

torchvision/prototype/datasets/utils/_resource.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,9 @@ def __init__(self, instructions: str, **kwargs: Any) -> None:
216216

217217
def _download(self, root: pathlib.Path) -> NoReturn:
218218
raise RuntimeError(
219-
f"The file {self.file_name} was not found, and cannot be downloaded automatically.\n\n"
220-
f"{self.instructions.strip()}\n\n"
221-
f"Once it is downloaded, please place the file in {root}."
219+
f"The file {self.file_name} cannot be downloaded automatically. "
220+
f"Please follow the instructions below and place it in {root}\n\n"
221+
f"{self.instructions}"
222222
)
223223

224224

0 commit comments

Comments
 (0)