Skip to content

Commit f2e76af

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] style: Added typing to datasets/lfw (#6844)
Summary: Co-authored-by: Philip Meier <[email protected]> Reviewed By: datumbox Differential Revision: D40851024 fbshipit-source-id: 0f4de9f2868aa3453d0670d0dc4d9fbaf9f504d4
1 parent 74c55d6 commit f2e76af

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

torchvision/datasets/lfw.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Any, Callable, List, Optional, Tuple
2+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
33

44
from PIL import Image
55

@@ -38,7 +38,7 @@ def __init__(
3838
transform: Optional[Callable] = None,
3939
target_transform: Optional[Callable] = None,
4040
download: bool = False,
41-
):
41+
) -> None:
4242
super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform)
4343

4444
self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
@@ -62,7 +62,7 @@ def _loader(self, path: str) -> Image.Image:
6262
img = Image.open(f)
6363
return img.convert("RGB")
6464

65-
def _check_integrity(self):
65+
def _check_integrity(self) -> bool:
6666
st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
6767
st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
6868
if not st1 or not st2:
@@ -71,7 +71,7 @@ def _check_integrity(self):
7171
return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])
7272
return True
7373

74-
def download(self):
74+
def download(self) -> None:
7575
if self._check_integrity():
7676
print("Files already downloaded and verified")
7777
return
@@ -81,13 +81,13 @@ def download(self):
8181
if self.view == "people":
8282
download_url(f"{self.download_url_prefix}{self.names}", self.root)
8383

84-
def _get_path(self, identity, no):
84+
def _get_path(self, identity: str, no: Union[int, str]) -> str:
8585
return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")
8686

8787
def extra_repr(self) -> str:
8888
return f"Alignment: {self.image_set}\nSplit: {self.split}"
8989

90-
def __len__(self):
90+
def __len__(self) -> int:
9191
return len(self.data)
9292

9393

@@ -119,13 +119,13 @@ def __init__(
119119
transform: Optional[Callable] = None,
120120
target_transform: Optional[Callable] = None,
121121
download: bool = False,
122-
):
122+
) -> None:
123123
super().__init__(root, split, image_set, "people", transform, target_transform, download)
124124

125125
self.class_to_idx = self._get_classes()
126126
self.data, self.targets = self._get_people()
127127

128-
def _get_people(self):
128+
def _get_people(self) -> Tuple[List[str], List[int]]:
129129
data, targets = [], []
130130
with open(os.path.join(self.root, self.labels_file)) as f:
131131
lines = f.readlines()
@@ -143,7 +143,7 @@ def _get_people(self):
143143

144144
return data, targets
145145

146-
def _get_classes(self):
146+
def _get_classes(self) -> Dict[str, int]:
147147
with open(os.path.join(self.root, self.names)) as f:
148148
lines = f.readlines()
149149
names = [line.strip().split()[0] for line in lines]
@@ -201,12 +201,12 @@ def __init__(
201201
transform: Optional[Callable] = None,
202202
target_transform: Optional[Callable] = None,
203203
download: bool = False,
204-
):
204+
) -> None:
205205
super().__init__(root, split, image_set, "pairs", transform, target_transform, download)
206206

207207
self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
208208

209-
def _get_pairs(self, images_dir):
209+
def _get_pairs(self, images_dir: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[int]]:
210210
pair_names, data, targets = [], [], []
211211
with open(os.path.join(self.root, self.labels_file)) as f:
212212
lines = f.readlines()

0 commit comments

Comments
 (0)