1
1
import os
2
- from typing import Any , Callable , List , Optional , Tuple
2
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
3
3
4
4
from PIL import Image
5
5
@@ -38,7 +38,7 @@ def __init__(
38
38
transform : Optional [Callable ] = None ,
39
39
target_transform : Optional [Callable ] = None ,
40
40
download : bool = False ,
41
- ):
41
+ ) -> None :
42
42
super ().__init__ (os .path .join (root , self .base_folder ), transform = transform , target_transform = target_transform )
43
43
44
44
self .image_set = verify_str_arg (image_set .lower (), "image_set" , self .file_dict .keys ())
@@ -62,7 +62,7 @@ def _loader(self, path: str) -> Image.Image:
62
62
img = Image .open (f )
63
63
return img .convert ("RGB" )
64
64
65
- def _check_integrity (self ):
65
+ def _check_integrity (self ) -> bool :
66
66
st1 = check_integrity (os .path .join (self .root , self .filename ), self .md5 )
67
67
st2 = check_integrity (os .path .join (self .root , self .labels_file ), self .checksums [self .labels_file ])
68
68
if not st1 or not st2 :
@@ -71,7 +71,7 @@ def _check_integrity(self):
71
71
return check_integrity (os .path .join (self .root , self .names ), self .checksums [self .names ])
72
72
return True
73
73
74
- def download (self ):
74
+ def download (self ) -> None :
75
75
if self ._check_integrity ():
76
76
print ("Files already downloaded and verified" )
77
77
return
@@ -81,13 +81,13 @@ def download(self):
81
81
if self .view == "people" :
82
82
download_url (f"{ self .download_url_prefix } { self .names } " , self .root )
83
83
84
- def _get_path (self , identity , no ) :
84
+ def _get_path (self , identity : str , no : Union [ int , str ]) -> str :
85
85
return os .path .join (self .images_dir , identity , f"{ identity } _{ int (no ):04d} .jpg" )
86
86
87
87
def extra_repr (self ) -> str :
88
88
return f"Alignment: { self .image_set } \n Split: { self .split } "
89
89
90
- def __len__ (self ):
90
+ def __len__ (self ) -> int :
91
91
return len (self .data )
92
92
93
93
@@ -119,13 +119,13 @@ def __init__(
119
119
transform : Optional [Callable ] = None ,
120
120
target_transform : Optional [Callable ] = None ,
121
121
download : bool = False ,
122
- ):
122
+ ) -> None :
123
123
super ().__init__ (root , split , image_set , "people" , transform , target_transform , download )
124
124
125
125
self .class_to_idx = self ._get_classes ()
126
126
self .data , self .targets = self ._get_people ()
127
127
128
- def _get_people (self ):
128
+ def _get_people (self ) -> Tuple [ List [ str ], List [ int ]] :
129
129
data , targets = [], []
130
130
with open (os .path .join (self .root , self .labels_file )) as f :
131
131
lines = f .readlines ()
@@ -143,7 +143,7 @@ def _get_people(self):
143
143
144
144
return data , targets
145
145
146
- def _get_classes (self ):
146
+ def _get_classes (self ) -> Dict [ str , int ] :
147
147
with open (os .path .join (self .root , self .names )) as f :
148
148
lines = f .readlines ()
149
149
names = [line .strip ().split ()[0 ] for line in lines ]
@@ -201,12 +201,12 @@ def __init__(
201
201
transform : Optional [Callable ] = None ,
202
202
target_transform : Optional [Callable ] = None ,
203
203
download : bool = False ,
204
- ):
204
+ ) -> None :
205
205
super ().__init__ (root , split , image_set , "pairs" , transform , target_transform , download )
206
206
207
207
self .pair_names , self .data , self .targets = self ._get_pairs (self .images_dir )
208
208
209
- def _get_pairs (self , images_dir ) :
209
+ def _get_pairs (self , images_dir : str ) -> Tuple [ List [ Tuple [ str , str ]], List [ Tuple [ str , str ]], List [ int ]] :
210
210
pair_names , data , targets = [], [], []
211
211
with open (os .path .join (self .root , self .labels_file )) as f :
212
212
lines = f .readlines ()
0 commit comments