cleanup mnist #18

Merged
merged 1 commit on Dec 15, 2016

1 change: 0 additions & 1 deletion mnist/README.md
@@ -2,6 +2,5 @@

```bash
pip install -r requirements.txt
python data.py
python main.py
```
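The `python data.py` step is the line dropped here: as the mnist/main.py hunk below shows, the script now fetches and preprocesses the dataset on first run by importing data.py when the processed tensors are missing. In essence, the guard added to main.py is:

```python
import os

# Added in main.py below: download/preprocess MNIST on the first run
# by importing data.py when the processed file is not on disk yet.
if not os.path.exists('data/processed/training.pt'):
    import data
```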
117 changes: 64 additions & 53 deletions mnist/main.py
@@ -1,52 +1,64 @@
from __future__ import print_function
import os
import os, argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

cuda = torch.cuda.is_available()

def print_header(msg):
print('===>', msg)
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batchSize', type=int, default=64, metavar='input batch size')
parser.add_argument('--testBatchSize', type=int, default=1000, metavar='input batch size for testing')
parser.add_argument('--trainSize', type=int, default=1000, metavar='Train dataset size (max=60000). Default: 1000')
parser.add_argument('--nEpochs', type=int, default=2, metavar='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.01, metavar='Learning Rate. Default=0.01')
parser.add_argument('--momentum', type=float, default=0.5, metavar='Default=0.5')
parser.add_argument('--seed', type=int, default=123, metavar='Random Seed to use. Default=123')
opt = parser.parse_args()
print(opt)

torch.manual_seed(opt.seed)
if cuda == True:
torch.cuda.manual_seed(opt.seed)

if not os.path.exists('data/processed/training.pt'):
import data

# Data
print_header('Loading data')
print('===> Loading data')
with open('data/processed/training.pt', 'rb') as f:
training_set = torch.load(f)
with open('data/processed/test.pt', 'rb') as f:
test_set = torch.load(f)

training_data = training_set[0].view(-1, 1, 28, 28).div(255)
training_data = training_data[:opt.trainSize]
training_labels = training_set[1]
test_data = test_set[0].view(-1, 1, 28, 28).div(255)
test_labels = test_set[1]

del training_set
del test_set

# Model
print_header('Building model')
print('===> Building model')
class Net(nn.Container):
def __init__(self):
super(Net, self).__init__(
conv1 = nn.Conv2d(1, 20, 5),
pool1 = nn.MaxPool2d(2, 2),
conv2 = nn.Conv2d(20, 50, 5),
pool2 = nn.MaxPool2d(2, 2),
fc1 = nn.Linear(800, 500),
fc2 = nn.Linear(500, 10),
relu = nn.ReLU(),
softmax = nn.LogSoftmax(),
)
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, 5)
self.pool1 = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(10, 20, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
self.relu = nn.ReLU()
self.softmax = nn.LogSoftmax()

def forward(self, x):
x = self.relu(self.pool1(self.conv1(x)))
x = self.relu(self.pool2(self.conv2(x)))
x = x.view(-1, 800)
x = x.view(-1, 320)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
return self.softmax(x)
@@ -56,60 +68,59 @@ def forward(self, x):
model.cuda()

criterion = nn.NLLLoss()

# Training settings
BATCH_SIZE = 150
TEST_BATCH_SIZE = 1000
NUM_EPOCHS = 2

optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum)

def train(epoch):
batch_data_t = torch.FloatTensor(BATCH_SIZE, 1, 28, 28)
batch_targets_t = torch.LongTensor(BATCH_SIZE)
# create buffers for mini-batch
batch_data = torch.FloatTensor(opt.batchSize, 1, 28, 28)
batch_targets = torch.LongTensor(opt.batchSize)
if cuda:
batch_data_t = batch_data_t.cuda()
batch_targets_t = batch_targets_t.cuda()
batch_data = Variable(batch_data_t, requires_grad=False)
batch_targets = Variable(batch_targets_t, requires_grad=False)
for i in range(0, training_data.size(0), BATCH_SIZE):
batch_data, batch_targets = batch_data.cuda(), batch_targets.cuda()

# create autograd Variables over these buffers
batch_data, batch_targets = Variable(batch_data), Variable(batch_targets)

for i in range(0, training_data.size(0)-opt.batchSize+1, opt.batchSize):
start, end = i, i+opt.batchSize
optimizer.zero_grad()
batch_data.data[:] = training_data[i:i+BATCH_SIZE]
batch_targets.data[:] = training_labels[i:i+BATCH_SIZE]
loss = criterion(model(batch_data), batch_targets)
batch_data.data[:] = training_data[start:end]
batch_targets.data[:] = training_labels[start:end]
output = model(batch_data)
loss = criterion(output, batch_targets)
loss.backward()
loss = loss.data[0]
optimizer.step()
print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(epoch,
i+BATCH_SIZE, training_data.size(0),
float(i+BATCH_SIZE)/training_data.size(0)*100, loss))
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'
.format(epoch, end, opt.trainSize, float(end)/opt.trainSize*100, loss))

def test(epoch):
test_loss = 0
batch_data_t = torch.FloatTensor(TEST_BATCH_SIZE, 1, 28, 28)
batch_targets_t = torch.LongTensor(TEST_BATCH_SIZE)
# create buffers for mini-batch
batch_data = torch.FloatTensor(opt.testBatchSize, 1, 28, 28)
batch_targets = torch.LongTensor(opt.testBatchSize)
if cuda:
batch_data_t = batch_data_t.cuda()
batch_targets_t = batch_targets_t.cuda()
batch_data = Variable(batch_data_t, volatile=True)
batch_targets = Variable(batch_targets_t, volatile=True)
batch_data, batch_targets = batch_data.cuda(), batch_targets.cuda()

# create autograd Variables over these buffers
batch_data = Variable(batch_data, volatile=True)
batch_targets = Variable(batch_targets, volatile=True)

test_loss = 0
correct = 0
for i in range(0, test_data.size(0), TEST_BATCH_SIZE):
print('Testing model: {}/{}'.format(i, test_data.size(0)), end='\r')
batch_data.data[:] = test_data[i:i+TEST_BATCH_SIZE]
batch_targets.data[:] = test_labels[i:i+TEST_BATCH_SIZE]

for i in range(0, test_data.size(0), opt.testBatchSize):
batch_data.data[:] = test_data[i:i+opt.testBatchSize]
batch_targets.data[:] = test_labels[i:i+opt.testBatchSize]
output = model(batch_data)
test_loss += criterion(output, batch_targets)
pred = output.data.max(1)[1]
pred = output.data.max(1)[1] # get the index of the max log-probability
correct += pred.long().eq(batch_targets.data.long()).cpu().sum()

test_loss = test_loss.data[0]
test_loss /= (test_data.size(0) / TEST_BATCH_SIZE) # criterion averages over batch size
print('TEST SET RESULTS:' + ' ' * 20)
print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
test_loss /= (test_data.size(0) / opt.testBatchSize) # criterion averages over batch size
print('\nTest Set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, test_data.size(0),
float(correct)/test_data.size(0)*100))

for epoch in range(1, NUM_EPOCHS+1):
for epoch in range(1, opt.nEpochs+1):
train(epoch)
test(epoch)
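With the hard-coded BATCH_SIZE / TEST_BATCH_SIZE / NUM_EPOCHS constants replaced by argparse options, a run of the cleaned-up script would look something like the sketch below; the flag values simply echo the parser defaults and are shown only for illustration:

```bash
pip install -r requirements.txt
# every flag is optional; these values mirror the defaults defined in main.py
python main.py --batchSize 64 --testBatchSize 1000 --trainSize 1000 \
               --nEpochs 2 --lr 0.01 --momentum 0.5 --seed 123
```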
1 change: 1 addition & 0 deletions mnist/requirements.txt
@@ -1,2 +1,3 @@
torch
six
tqdm
10 changes: 3 additions & 7 deletions word_language_model/main.py
@@ -5,18 +5,15 @@
# test set.
###############################################################################

import argparse
import time
import math

import argparse, time, math
import torch
import torch.nn as nn
from torch.autograd import Variable

import data
import model

parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')
parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')

# Data parameters
parser.add_argument('-data' , type=str, default='./data/penn', help='Location of the data corpus' )
@@ -41,8 +38,7 @@

# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
# If the GPU is enabled, do some plumbing.

# If the GPU is enabled, warn the user to use it
if torch.cuda.is_available() and not args.cuda:
print("WARNING: You have a CUDA device, so you should probably run with -cuda")

8 changes: 4 additions & 4 deletions word_language_model/model.py
@@ -7,10 +7,10 @@ class RNNModel(nn.Container):
and a decoder. Runs one RNN step at a time.
"""

def __init__(self, rnnType, ntoken, ninp, nhid, nlayers):
def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers):
super(RNNModel, self).__init__(
encoder = nn.sparse.Embedding(ntoken, ninp),
rnn = nn.RNNBase(rnnType, ninp, nhid, nlayers, bias=False),
rnn = nn.RNNBase(rnn_type, ninp, nhid, nlayers, bias=False),
decoder = nn.Linear(nhid, ntoken),
)

@@ -21,7 +21,7 @@ def __init__(self, rnnType, ntoken, ninp, nhid, nlayers):
self.decoder.bias.data.fill_(0)
self.decoder.weight.data.uniform_(-initrange, initrange)

self.rnnType = rnnType
self.rnn_type = rnn_type
self.nhid = nhid
self.nlayers = nlayers

@@ -33,7 +33,7 @@ def forward(self, input, hidden):

def initHidden(self, bsz):
weight = next(self.parameters()).data
if self.rnnType == 'LSTM':
if self.rnn_type == 'LSTM':
return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
else:
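For reference, a rough sketch of constructing the model after this rename, using the Container-style API of this December 2016 snapshot; the vocabulary and layer sizes below are made up for illustration and are not taken from the repository:

```python
import model  # word_language_model/model.py as shown in this diff

ntokens = 10000  # hypothetical vocabulary size
lm = model.RNNModel('LSTM', ntoken=ntokens, ninp=200, nhid=200, nlayers=2)

# for 'LSTM', initHidden returns a pair of zeroed (nlayers, bsz, nhid) Variables
hidden = lm.initHidden(bsz=20)
```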