Skip to content

Commit 71a4c7c

Browse files
committed
Improve download_test_db.py
1 parent 8813224 commit 71a4c7c

File tree

1 file changed

+103
-47
lines changed

1 file changed

+103
-47
lines changed

utils/download_test_db.py

Lines changed: 103 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,120 @@
1-
import pickle
2-
import gzip
1+
from __future__ import division
2+
from __future__ import print_function
3+
4+
import argparse
35
import array
4-
import urllib
5-
import tarfile
6+
import collections
7+
import gzip
68
import os.path
9+
import pickle
10+
import sys
11+
import tarfile
12+
import urllib
13+
14+
try:
15+
from urllib.error import URLError
16+
except ImportError:
17+
from urllib2 import URLError
18+
19+
20+
# Each dataset entry: the local file name to save as, the download URL,
# and the handler invoked after download to unpack/convert the file.
Dataset = collections.namedtuple('Dataset', ['filename', 'url', 'handler'])
21+
22+
23+
def handle_mnist(filename):
    """Unpack the pickled MNIST archive into raw binary files.

    Reads the gzipped pickle at *filename* (a (train, valid, test)
    triple), then writes the training images to 'mnist_images.bin' and
    the training labels, as unsigned bytes, to 'mnist_labels.bin' in the
    current directory. Only the training set is kept.
    """
    print('Extracting {} ...'.format(filename))
    # 'pickle_file' instead of 'file': don't shadow the builtin.
    with gzip.open(filename, 'rb') as pickle_file:
        training_set, _, _ = pickle.load(pickle_file)
    data, labels = training_set

    # 'with' guarantees the output files are closed even if a write fails.
    with open('mnist_images.bin', 'wb') as images_file:
        data.tofile(images_file)

    with open('mnist_labels.bin', 'wb') as labels_file:
        # Labels are small ints (digits 0-9), so one unsigned byte each.
        array.array('B', labels).tofile(labels_file)
37+
738

8-
# This script downloads and extracts the mnist and cifar-10 databases.
39+
def untar(filename):
    """Extract every member of the gzipped tarball *filename* into the CWD.

    Uses a context manager so the archive is closed even when extraction
    raises part-way through.
    """
    print('Extracting {} ...'.format(filename))
    # NOTE(review): extractall() trusts member paths inside the archive;
    # acceptable here because the tarballs come from known dataset mirrors.
    with tarfile.open(filename, "r:gz") as tar:
        tar.extractall()
944

10-
print("""Downloading test files. If the download fails try setting up a proxy:
11-
#export http_proxy="http://fwdproxy:8080
1245

13-
""")
46+
# Registry of downloadable datasets, keyed by the name accepted on the
# command line. Each value pairs a local file name and URL with the
# handler that unpacks the downloaded archive.
DATASETS = {
    'mnist': Dataset(
        'mnist.pkl.gz',
        'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz',
        handle_mnist,
    ),
    'cifar10': Dataset(
        'cifar-10.binary.tar.gz',
        'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz',
        untar,
    ),
    'ptb': Dataset(
        'ptb.tgz',
        'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz',
        untar,
    ),
}

# All valid dataset names, in registry order (used for --all and choices).
DATASET_NAMES = list(DATASETS)
1464

15-
mnist_filename = "mnist.pkl.gz"
16-
cifar10_filename = "cifar-10.binary.tar.gz"
17-
ptb_filename = "ptb.tgz"
1865

19-
if os.path.exists(mnist_filename):
20-
print("MNIST file found. Not downloading.")
21-
else:
22-
print("Downloading MNIST ... ")
23-
urllib.urlretrieve ("http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz", mnist_filename)
66+
def report_download_progress(chunk_number, chunk_size, file_size):
    """Reporthook for urlretrieve: draw a 64-character progress bar.

    Args:
        chunk_number: number of chunks retrieved so far.
        chunk_size: size of each chunk, in bytes.
        file_size: total size in bytes, or -1 when the server did not
            report one (in which case nothing is drawn).
    """
    if file_size != -1:
        # Clamp at 1: the final chunk can overshoot the exact file size.
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = '#' * int(64 * percent)
        sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))
        # '\r' never emits a newline, so stdout may buffer the bar
        # indefinitely; flush to make progress visible during download.
        sys.stdout.flush()
2471

25-
if os.path.exists(cifar10_filename):
26-
print("CIFAR file found. Not downloading.")
27-
else:
28-
print("Downloading CIFAR ... ")
29-
urllib.urlretrieve ("http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz", cifar10_filename)
3072

31-
if os.path.exists(ptb_filename):
32-
print("PTB file found. Not downloading.")
33-
else:
34-
print("Downloading PTB ... ")
35-
urllib.urlretrieve ("http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz", ptb_filename)
73+
def download_dataset(dataset):
    """Download *dataset* (a Dataset tuple) to dataset.filename.

    Skips the download when the file already exists. A download failure
    is reported but not raised, so the remaining datasets can still be
    fetched by the caller.
    """
    # urlretrieve lives in urllib.request on Python 3; the bare
    # 'urllib.urlretrieve' only exists on Python 2, so resolve it here
    # in a version-agnostic way (mirrors the URLError import at the top).
    try:
        from urllib.request import urlretrieve
    except ImportError:  # Python 2
        from urllib import urlretrieve

    if os.path.exists(dataset.filename):
        print('{} already exists, skipping ...'.format(dataset.filename))
        return
    print('Downloading {} from {} ...'.format(dataset.filename,
                                              dataset.url))
    try:
        urlretrieve(
            dataset.url,
            dataset.filename,
            reporthook=report_download_progress)
    except URLError:
        print('Error downloading {}!'.format(dataset.filename))
    finally:
        # Just a newline.
        print()
3689

37-
def dumpToFile(dataset):
38-
data, labels = dataset
3990

40-
imagesFile = open('mnist_images.bin', 'wb')
41-
data.tofile(imagesFile)
42-
imagesFile.close()
91+
def parse():
    """Parse command-line options and return the dataset names to fetch.

    Exits with a usage error (via argparse) when neither --all nor at
    least one -d/--datasets name was given.
    """
    parser = argparse.ArgumentParser(description='Download datasets for Glow')
    parser.add_argument('-d', '--datasets', nargs='+', choices=DATASET_NAMES)
    parser.add_argument('-a', '--all', action='store_true')
    options = parser.parse_args()

    # --all takes precedence over an explicit -d list.
    if options.all:
        return DATASET_NAMES
    if options.datasets:
        return options.datasets
    parser.error('Must specify at least one dataset or --all.')
50105

51-
with gzip.open(mnist_filename, 'rb') as f:
52-
train_set, valid_set, test_set = pickle.load(f)
53-
dumpToFile(train_set)
54106

107+
def main():
    """Entry point: download and unpack each requested dataset."""
    names = parse()
    try:
        for name in names:
            entry = DATASETS[name]
            download_dataset(entry)
            entry.handler(entry.filename)
        print('Done.')
    except KeyboardInterrupt:
        # Ctrl-C mid-download: exit with a short message, not a traceback.
        print('Interrupted')
55117

56-
print("Extracting the CIFAR-10 database.")
57-
tar = tarfile.open(cifar10_filename, "r:gz")
58-
tar.extractall()
59-
tar.close()
60118

61-
print("Extracting the PTB database.")
62-
tar = tarfile.open(ptb_filename, "r:gz")
63-
tar.extractall('ptb')
64-
tar.close()
119+
# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)