From e952e09d0a97be78247c2c0609a03a2df1abbfa5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 1 Feb 2021 14:20:53 +0000 Subject: [PATCH 1/2] Fixing mypy errors. --- torchvision/datasets/semeion.py | 2 -- torchvision/datasets/stl10.py | 2 +- torchvision/datasets/usps.py | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py index 411a5faed19..dad530ffa15 100644 --- a/torchvision/datasets/semeion.py +++ b/torchvision/datasets/semeion.py @@ -41,8 +41,6 @@ def __init__( raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it') - self.data = [] - self.labels = [] fp = os.path.join(self.root, self.filename) data = np.loadtxt(fp) # convert value to 8 bit unsigned integer diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index 1ef861fe563..f4c517f39a2 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -67,7 +67,7 @@ def __init__( 'You can use download=True to download it') # now load the picked numpy arrays - self.labels: np.ndarray + self.labels: Optional[np.ndarray] if self.split == 'train': self.data, self.labels = self.__loadfile( self.train_list[0][0], self.train_list[1][0]) diff --git a/torchvision/datasets/usps.py b/torchvision/datasets/usps.py index e2938dd671b..c315b8d3111 100644 --- a/torchvision/datasets/usps.py +++ b/torchvision/datasets/usps.py @@ -57,8 +57,8 @@ def __init__( import bz2 with bz2.open(full_path) as fp: raw_data = [line.decode().split() for line in fp.readlines()] - imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data] - imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16)) + tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data] + imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16)) imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8) targets = [int(d[0]) - 1 for d in raw_data] From 5196845e1194a8c3170da567e29cc58bcc4c49ff Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 1 Feb 2021 14:29:00 +0000 Subject: [PATCH 2/2] Fixing typing issue. --- torchvision/datasets/stl10.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index f4c517f39a2..35c71bd2cf1 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -182,4 +182,6 @@ def __load_folds(self, folds: Optional[int]) -> None: with open(path_to_folds, 'r') as f: str_idx = f.read().splitlines()[folds] list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ') - self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx] + self.data = self.data[list_idx, :, :, :] + if self.labels is not None: + self.labels = self.labels[list_idx]