Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 103 additions & 47 deletions utils/download_test_db.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,120 @@
import pickle
import gzip
from __future__ import division
from __future__ import print_function

import argparse
import array
import urllib
import tarfile
import collections
import gzip
import os.path
import pickle
import sys
import tarfile
import urllib

try:
from urllib.error import URLError
except ImportError:
from urllib2 import URLError


# Lightweight record describing one downloadable dataset: the local file
# name to save under, the remote URL, and a callable that unpacks the
# downloaded file.
Dataset = collections.namedtuple('Dataset', ['filename', 'url', 'handler'])


def handle_mnist(filename):
    """Unpack the pickled MNIST archive and dump the training set as raw
    binaries in the current directory: pixel data to mnist_images.bin and
    labels (one unsigned byte each) to mnist_labels.bin.
    """
    print('Extracting {} ...'.format(filename))
    with gzip.open(filename, 'rb') as archive:
        # The pickle holds (training, validation, test); only training is kept.
        training_set, _, _ = pickle.load(archive)
    data, labels = training_set

    # 'with' guarantees the output files are flushed and closed even if
    # tofile() raises part-way through (the originals leaked on error).
    with open('mnist_images.bin', 'wb') as images_file:
        data.tofile(images_file)
    with open('mnist_labels.bin', 'wb') as labels_file:
        array.array('B', labels).tofile(labels_file)


def untar(filename):
    """Extract a gzip-compressed tarball into the current directory."""
    print('Extracting {} ...'.format(filename))
    # The context manager closes the archive even if extraction fails
    # (the original left it open on error).
    # NOTE(review): extractall trusts member paths inside the archive;
    # acceptable here since the archives come from the known dataset URLs.
    with tarfile.open(filename, "r:gz") as tar:
        tar.extractall()
# Registry of the datasets this script knows how to fetch, keyed by the
# name accepted on the command line. Each entry carries its local file
# name, download URL, and the post-download extraction handler.
DATASETS = {
    'mnist': Dataset(
        'mnist.pkl.gz',
        'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz',
        handle_mnist,
    ),
    'cifar10': Dataset(
        'cifar-10.binary.tar.gz',
        'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz',
        untar,
    ),
    'ptb': Dataset(
        'ptb.tgz',
        'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz',
        untar,
    ),
}
DATASET_NAMES = list(DATASETS.keys())

def report_download_progress(chunk_number, chunk_size, file_size):
    """urlretrieve reporthook: redraw a 64-character progress bar in place.

    file_size is -1 when the server sent no Content-Length header; no
    progress can be computed then, so nothing is printed.
    """
    if file_size != -1:
        # Clamp at 1.0: the final chunk usually overshoots file_size.
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = '#' * int(64 * percent)
        sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))
        # The line carries no newline, so flush explicitly or a
        # block-buffered stdout shows nothing until the download ends.
        sys.stdout.flush()

def download_dataset(dataset):
    """Fetch dataset.url into dataset.filename, reporting progress.

    Skips the transfer when the file already exists. A URLError is
    reported rather than propagated so one failed download does not
    abort the remaining datasets.
    """
    if os.path.exists(dataset.filename):
        print('{} already exists, skipping ...'.format(dataset.filename))
        return

    print('Downloading {} from {} ...'.format(dataset.filename,
                                              dataset.url))
    # urlretrieve lives in urllib.request on Python 3 but directly in
    # urllib on Python 2; bare urllib.urlretrieve raises AttributeError
    # on Python 3 even though the file already shims URLError for both.
    try:
        from urllib.request import urlretrieve
    except ImportError:  # Python 2
        from urllib import urlretrieve
    try:
        urlretrieve(
            dataset.url,
            dataset.filename,
            reporthook=report_download_progress)
    except URLError:
        print('Error downloading {}!'.format(dataset.filename))
    finally:
        # Just a newline, terminating the progress-bar line.
        print()

def parse():
    """Return the list of dataset names selected on the command line.

    --all wins over an explicit --datasets list; when neither is given,
    parser.error() prints a usage message and exits.
    """
    parser = argparse.ArgumentParser(description='Download datasets for Glow')
    parser.add_argument('-d', '--datasets', nargs='+', choices=DATASET_NAMES)
    parser.add_argument('-a', '--all', action='store_true')
    options = parser.parse_args()

    if options.all:
        return DATASET_NAMES
    if options.datasets:
        return options.datasets
    parser.error('Must specify at least one dataset or --all.')

def main():
    """Download, then unpack, every dataset the user requested."""
    selected = parse()
    try:
        for name in selected:
            entry = DATASETS[name]
            download_dataset(entry)
            # Hand the downloaded archive to its dataset-specific extractor.
            entry.handler(entry.filename)
        print('Done.')
    except KeyboardInterrupt:
        print('Interrupted')

# Run only when executed as a script, never on import.
if __name__ == '__main__':
    main()