diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 7fa6dcf7666..c8158acfeb7 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -7,6 +7,7 @@ import numpy as np import torch import codecs +from .utils import download_url class MNIST(data.Dataset): @@ -120,12 +121,10 @@ def download(self): raise for url in self.urls: - print('Downloading ' + url) - data = urllib.request.urlopen(url) filename = url.rpartition('/')[2] file_path = os.path.join(self.root, self.raw_folder, filename) - with open(file_path, 'wb') as f: - f.write(data.read()) + download_url(url, root=os.path.join(self.root, self.raw_folder), + filename=filename, md5=None) with open(file_path.replace('.gz', ''), 'wb') as out_f, \ gzip.GzipFile(file_path) as zip_f: out_f.write(zip_f.read()) @@ -247,13 +246,10 @@ def download(self): else: raise - print('Downloading ' + self.url) - data = urllib.request.urlopen(self.url) filename = self.url.rpartition('/')[2] raw_folder = os.path.join(self.root, self.raw_folder) file_path = os.path.join(raw_folder, filename) - with open(file_path, 'wb') as f: - f.write(data.read()) + download_url(self.url, root=file_path, filename=filename, md5=None) print('Extracting zip archive') with zipfile.ZipFile(file_path) as zip_f: diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index fef84cbeea7..43e5896801a 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -15,7 +15,9 @@ def bar_update(count, block_size, total_size): return bar_update -def check_integrity(fpath, md5): +def check_integrity(fpath, md5=None): + if md5 is None: + return True if not os.path.isfile(fpath): return False md5o = hashlib.md5()