Skip to content

Commit 188f8eb

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
Fixing mypy errors (#3335)
Summary: * Fixing mypy errors. * Fixing typing issue. Reviewed By: datumbox Differential Revision: D26226616 fbshipit-source-id: 5ebd98d70d9ae6edc2b8e960c6cea38279fd60c1
1 parent eb2e156 commit 188f8eb

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

torchvision/datasets/semeion.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ def __init__(
4141
raise RuntimeError('Dataset not found or corrupted.' +
4242
' You can use download=True to download it')
4343

44-
self.data = []
45-
self.labels = []
4644
fp = os.path.join(self.root, self.filename)
4745
data = np.loadtxt(fp)
4846
# convert value to 8 bit unsigned integer

torchvision/datasets/stl10.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
'You can use download=True to download it')
6868

6969
# now load the picked numpy arrays
70-
self.labels: np.ndarray
70+
self.labels: Optional[np.ndarray]
7171
if self.split == 'train':
7272
self.data, self.labels = self.__loadfile(
7373
self.train_list[0][0], self.train_list[1][0])
@@ -182,4 +182,6 @@ def __load_folds(self, folds: Optional[int]) -> None:
182182
with open(path_to_folds, 'r') as f:
183183
str_idx = f.read().splitlines()[folds]
184184
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
185-
self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
185+
self.data = self.data[list_idx, :, :, :]
186+
if self.labels is not None:
187+
self.labels = self.labels[list_idx]

torchvision/datasets/usps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def __init__(
5757
import bz2
5858
with bz2.open(full_path) as fp:
5959
raw_data = [line.decode().split() for line in fp.readlines()]
60-
imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
61-
imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16))
60+
tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
61+
imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
6262
imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
6363
targets = [int(d[0]) - 1 for d in raw_data]
6464

0 commit comments

Comments
 (0)