Skip to content

Improve download_test_db.py #216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 3, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 103 additions & 47 deletions utils/download_test_db.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,120 @@
import pickle
import gzip
from __future__ import division
from __future__ import print_function

import argparse
import array
import urllib
import tarfile
import collections
import gzip
import os.path
import pickle
import sys
import tarfile
import urllib

try:
from urllib.error import URLError
except ImportError:
from urllib2 import URLError


# A downloadable dataset: the local filename to save to, the remote URL to
# fetch from, and a handler callable that post-processes (extracts) the file.
Dataset = collections.namedtuple('Dataset', 'filename, url, handler')


def handle_mnist(filename):
    """Extract the pickled MNIST archive into raw binary files.

    Reads the training split out of ``filename`` (a gzipped pickle of
    ``(train, valid, test)`` tuples) and writes the images to
    ``mnist_images.bin`` and the labels, as unsigned bytes, to
    ``mnist_labels.bin`` in the current directory.
    """
    print('Extracting {} ...'.format(filename))
    with gzip.open(filename, 'rb') as file:
        # The upstream archive was pickled by Python 2; Python 3 needs an
        # explicit encoding to decode it. Python 2's pickle.load has no
        # encoding parameter, so branch on the interpreter version.
        if sys.version_info[0] >= 3:
            training_set, _, _ = pickle.load(file, encoding='latin1')
        else:
            training_set, _, _ = pickle.load(file)
    data, labels = training_set

    # Context managers guarantee the output files are closed even if a
    # write fails partway through.
    with open('mnist_images.bin', 'wb') as images_file:
        data.tofile(images_file)

    with open('mnist_labels.bin', 'wb') as labels_file:
        array.array('B', labels).tofile(labels_file)


# This script downloads and extracts the mnist and cifar-10 databases.
def untar(filename):
    """Extract a gzip-compressed tarball into the current directory.

    The ``with`` block ensures the tar handle is closed even when
    extraction raises (the original leaked it on error).
    """
    print('Extracting {} ...'.format(filename))
    with tarfile.open(filename, 'r:gz') as tar:
        tar.extractall()

print("""Downloading test files. If the download fails try setting up a proxy:
#export http_proxy="http://fwdproxy:8080

""")
# Registry of supported datasets, keyed by the name used on the command line.
DATASETS = {
    'mnist': Dataset(
        'mnist.pkl.gz',
        'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz',
        handle_mnist,
    ),
    'cifar10': Dataset(
        'cifar-10.binary.tar.gz',
        'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz',
        untar,
    ),
    'ptb': Dataset(
        'ptb.tgz',
        'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz',
        untar,
    ),
}
# Names accepted by the -d/--datasets option.
DATASET_NAMES = list(DATASETS.keys())

mnist_filename = "mnist.pkl.gz"
cifar10_filename = "cifar-10.binary.tar.gz"
ptb_filename = "ptb.tgz"

if os.path.exists(mnist_filename):
print("MNIST file found. Not downloading.")
else:
print("Downloading MNIST ... ")
urllib.urlretrieve ("http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz", mnist_filename)
def report_download_progress(chunk_number, chunk_size, file_size):
    """Reporthook for urlretrieve: draw an in-place ASCII progress bar.

    ``file_size`` is -1 when the server does not report a total size, in
    which case nothing is printed.
    """
    if file_size == -1:
        return
    fraction = min(1, (chunk_number * chunk_size) / file_size)
    filled = '#' * int(64 * fraction)
    sys.stdout.write('\r0% |{:<64}| {}%'.format(filled, int(fraction * 100)))

if os.path.exists(cifar10_filename):
print("CIFAR file found. Not downloading.")
else:
print("Downloading CIFAR ... ")
urllib.urlretrieve ("http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz", cifar10_filename)

if os.path.exists(ptb_filename):
print("PTB file found. Not downloading.")
else:
print("Downloading PTB ... ")
urllib.urlretrieve ("http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz", ptb_filename)
def download_dataset(dataset):
    """Download ``dataset.filename`` from ``dataset.url`` unless it exists.

    Network errors are reported rather than raised, so one failed dataset
    does not abort the others. Returns None.
    """
    if os.path.exists(dataset.filename):
        print('{} already exists, skipping ...'.format(dataset.filename))
        return
    print('Downloading {} from {} ...'.format(dataset.filename, dataset.url))
    # urlretrieve moved to urllib.request in Python 3; the bare
    # urllib.urlretrieve used before only exists on Python 2.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve
    try:
        urlretrieve(
            dataset.url,
            dataset.filename,
            reporthook=report_download_progress)
    except URLError:
        print('Error downloading {}!'.format(dataset.filename))
    finally:
        # Terminate the progress-bar line with a newline.
        print()

def dumpToFile(dataset):
data, labels = dataset

imagesFile = open('mnist_images.bin', 'wb')
data.tofile(imagesFile)
imagesFile.close()
def parse(argv=None):
    """Parse command-line options and return the dataset names to fetch.

    argv: optional list of argument strings; defaults to ``sys.argv[1:]``
    (exposed mainly for testing — a backward-compatible addition).

    Exits with a usage error (via ``parser.error``) when neither
    ``--datasets`` nor ``--all`` is given.
    """
    parser = argparse.ArgumentParser(description='Download datasets for Glow')
    parser.add_argument('-d', '--datasets', nargs='+', choices=DATASET_NAMES)
    parser.add_argument('-a', '--all', action='store_true')
    options = parser.parse_args(argv)

    if options.all:
        # --all takes precedence over any explicit -d selection.
        datasets = DATASET_NAMES
    elif options.datasets:
        datasets = options.datasets
    else:
        parser.error('Must specify at least one dataset or --all.')

    return datasets

with gzip.open(mnist_filename, 'rb') as f:
train_set, valid_set, test_set = pickle.load(f)
dumpToFile(train_set)

def main():
    """Entry point: download and run the handler for each chosen dataset."""
    names = parse()
    try:
        for name in names:
            entry = DATASETS[name]
            download_dataset(entry)
            entry.handler(entry.filename)
        print('Done.')
    except KeyboardInterrupt:
        print('Interrupted')

print("Extracting the CIFAR-10 database.")
tar = tarfile.open(cifar10_filename, "r:gz")
tar.extractall()
tar.close()

print("Extracting the PTB database.")
tar = tarfile.open(ptb_filename, "r:gz")
tar.extractall('ptb')
tar.close()
# Allow importing this module (e.g. for tests) without triggering downloads.
if __name__ == '__main__':
    main()