1
1
import csv
2
2
import os
3
- import warnings
4
3
from collections import namedtuple
5
4
from typing import Any , Callable , List , Optional , Union , Tuple
6
5
7
6
import PIL
8
7
import torch
9
8
10
- from .utils import check_integrity , verify_str_arg
9
+ from .utils import download_file_from_google_drive , check_integrity , verify_str_arg , extract_archive
11
10
from .vision import VisionDataset
12
11
13
12
CSV = namedtuple ("CSV" , ["header" , "index" , "data" ])
@@ -36,17 +35,9 @@ class CelebA(VisionDataset):
36
35
and returns a transformed version. E.g, ``transforms.PILToTensor``
37
36
target_transform (callable, optional): A function/transform that takes in the
38
37
target and transforms it.
39
- download (bool, optional): Deprecated.
40
-
41
- .. warning::
42
-
43
- Downloading CelebA is not supported anymore as of 0.13 and this
44
- parameter will be removed in 0.15. See
45
- `this issue <https://github.com/pytorch/vision/issues/5705>`__
46
- for more details.
47
- Please download the files from
48
- https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract
49
- them in ``root/celeba``.
38
+ download (bool, optional): If true, downloads the dataset from the internet and
39
+ puts it in root directory. If dataset is already downloaded, it is not
40
+ downloaded again.
50
41
"""
51
42
52
43
base_folder = "celeba"
@@ -73,7 +64,7 @@ def __init__(
73
64
target_type : Union [List [str ], str ] = "attr" ,
74
65
transform : Optional [Callable ] = None ,
75
66
target_transform : Optional [Callable ] = None ,
76
- download : bool = None ,
67
+ download : bool = False ,
77
68
) -> None :
78
69
super ().__init__ (root , transform = transform , target_transform = target_transform )
79
70
self .split = split
@@ -85,15 +76,6 @@ def __init__(
85
76
if not self .target_type and self .target_transform is not None :
86
77
raise RuntimeError ("target_transform is specified but target_type is empty" )
87
78
88
- if download is not None :
89
- warnings .warn (
90
- "Downloading CelebA is not supported anymore as of 0.13, and the "
91
- "download parameter will be removed in 0.15. See "
92
- "https://github.com/pytorch/vision/issues/5705 for more details. "
93
- "Please download the files from "
94
- "https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
95
- "in ``root/celeba``."
96
- )
97
79
if download :
98
80
self .download ()
99
81
@@ -164,14 +146,10 @@ def download(self) -> None:
164
146
print ("Files already downloaded and verified" )
165
147
return
166
148
167
- raise ValueError (
168
- "Downloading CelebA is not supported anymore as of 0.13, and the "
169
- "download parameter will be removed in 0.15. See "
170
- "https://github.com/pytorch/vision/issues/5705 for more details. "
171
- "Please download the files from "
172
- "https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
173
- "in ``root/celeba``."
174
- )
149
+ for (file_id , md5 , filename ) in self .file_list :
150
+ download_file_from_google_drive (file_id , os .path .join (self .root , self .base_folder ), filename , md5 )
151
+
152
+ extract_archive (os .path .join (self .root , self .base_folder , "img_align_celeba.zip" ))
175
153
176
154
def __getitem__ (self , index : int ) -> Tuple [Any , Any ]:
177
155
X = PIL .Image .open (os .path .join (self .root , self .base_folder , "img_align_celeba" , self .filename [index ]))
0 commit comments