Skip to content

Commit 072aa71

Browse files
authored
Put back CelebA download (#6147) (#6149)
* Revert "Indicate Celeba download parameter is deprecated and will be removed (#6059)" This reverts commit 49496c4. * Revert "Deactivate CelebA download (#6052)" This reverts commit ac56f52.
1 parent 1a288d1 commit 072aa71

File tree

1 file changed

+9
-31
lines changed

1 file changed

+9
-31
lines changed

torchvision/datasets/celeba.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import csv
22
import os
3-
import warnings
43
from collections import namedtuple
54
from typing import Any, Callable, List, Optional, Union, Tuple
65

76
import PIL
87
import torch
98

10-
from .utils import check_integrity, verify_str_arg
9+
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive
1110
from .vision import VisionDataset
1211

1312
CSV = namedtuple("CSV", ["header", "index", "data"])
@@ -36,17 +35,9 @@ class CelebA(VisionDataset):
3635
and returns a transformed version. E.g, ``transforms.PILToTensor``
3736
target_transform (callable, optional): A function/transform that takes in the
3837
target and transforms it.
39-
download (bool, optional): Deprecated.
40-
41-
.. warning::
42-
43-
Downloading CelebA is not supported anymore as of 0.13 and this
44-
parameter will be removed in 0.15. See
45-
`this issue <https://github.com/pytorch/vision/issues/5705>`__
46-
for more details.
47-
Please download the files from
48-
https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract
49-
them in ``root/celeba``.
38+
download (bool, optional): If true, downloads the dataset from the internet and
39+
puts it in root directory. If dataset is already downloaded, it is not
40+
downloaded again.
5041
"""
5142

5243
base_folder = "celeba"
@@ -73,7 +64,7 @@ def __init__(
7364
target_type: Union[List[str], str] = "attr",
7465
transform: Optional[Callable] = None,
7566
target_transform: Optional[Callable] = None,
76-
download: bool = None,
67+
download: bool = False,
7768
) -> None:
7869
super().__init__(root, transform=transform, target_transform=target_transform)
7970
self.split = split
@@ -85,15 +76,6 @@ def __init__(
8576
if not self.target_type and self.target_transform is not None:
8677
raise RuntimeError("target_transform is specified but target_type is empty")
8778

88-
if download is not None:
89-
warnings.warn(
90-
"Downloading CelebA is not supported anymore as of 0.13, and the "
91-
"download parameter will be removed in 0.15. See "
92-
"https://github.com/pytorch/vision/issues/5705 for more details. "
93-
"Please download the files from "
94-
"https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
95-
"in ``root/celeba``."
96-
)
9779
if download:
9880
self.download()
9981

@@ -164,14 +146,10 @@ def download(self) -> None:
164146
print("Files already downloaded and verified")
165147
return
166148

167-
raise ValueError(
168-
"Downloading CelebA is not supported anymore as of 0.13, and the "
169-
"download parameter will be removed in 0.15. See "
170-
"https://github.com/pytorch/vision/issues/5705 for more details. "
171-
"Please download the files from "
172-
"https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
173-
"in ``root/celeba``."
174-
)
149+
for (file_id, md5, filename) in self.file_list:
150+
download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
151+
152+
extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
175153

176154
def __getitem__(self, index: int) -> Tuple[Any, Any]:
177155
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

0 commit comments

Comments
 (0)