Add README and generate.py script #7

Merged 1 commit on Oct 31, 2016

35 changes: 30 additions & 5 deletions word_language_model/README.md
@@ -1,9 +1,34 @@
# Language Modeling example
# Word-level language modeling RNN

This example showcases training a language model.
By default, the data used is the Penn TreeBank dataset
This example trains a multi-layer RNN (Elman, GRU, or LSTM) on a language modeling task.
By default, the training script uses the PTB dataset, which is provided.
The trained model can then be used by the generate script to generate new text.

```bash
pip install -r requirements.txt
python main.py
python main.py -cuda # Train an LSTM on PTB with CUDA (cuDNN). Should reach a perplexity of ~116
python generate.py # Generate samples from the trained LSTM model.
```

The model uses the `nn.RNN` module (and its sister modules `nn.GRU` and `nn.LSTM`) which will automatically use the cuDNN backend if run on CUDA with cuDNN installed.
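
For readers unfamiliar with these modules, here is a minimal, self-contained sketch (not part of this repository) of calling `nn.LSTM` directly; the sizes are arbitrary illustration values, and the `Variable` wrapper matches the PyTorch API used elsewhere in this PR.

```python
# Minimal sketch of the recurrent core this model wraps (illustrative sizes).
# On a CUDA device with cuDNN installed, these modules dispatch to cuDNN kernels.
import torch
import torch.nn as nn
from torch.autograd import Variable

rnn = nn.LSTM(200, 200, 2)                    # input size, hidden size, number of layers
inp = Variable(torch.randn(35, 20, 200))      # (seq_len, batch, input_size)
h0 = Variable(torch.zeros(2, 20, 200))        # (num_layers, batch, hidden_size)
c0 = Variable(torch.zeros(2, 20, 200))
output, (hn, cn) = rnn(inp, (h0, c0))         # output: (seq_len, batch, hidden_size)
```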

The `main.py` script accepts the following arguments:

```bash
optional arguments:
  -h, --help            show this help message and exit
  -data DATA            Location of the data corpus
  -model MODEL          Type of recurrent net. RNN_TANH, RNN_RELU, LSTM, or GRU.
  -emsize EMSIZE        Size of word embeddings
  -nhid NHID            Number of hidden units per layer.
  -nlayers NLAYERS      Number of layers.
  -lr LR                Initial learning rate.
  -clip CLIP            Gradient clipping.
  -maxepoch MAXEPOCH    Upper epoch limit.
  -batchsize BATCHSIZE  Batch size.
  -bptt BPTT            Sequence length.
  -seed SEED            Random seed.
  -cuda                 Use CUDA.
  -reportint REPORTINT  Report interval.
  -save SAVE            Path to save the final model.
```
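
The parser definition in `main.py` itself is elided from this diff. Purely as an illustration of how the flags above map onto `argparse`, a sketch might look like the following; defaults are omitted because they are not documented here, so treat this as an assumption rather than the script's actual code.

```python
# Hypothetical sketch of a parser matching the help text above; main.py's actual
# definition (including its default values) is not shown in this diff.
import argparse

parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')
parser.add_argument('-data',      type=str,            help='Location of the data corpus')
parser.add_argument('-model',     type=str,            help='Type of recurrent net. RNN_TANH, RNN_RELU, LSTM, or GRU.')
parser.add_argument('-emsize',    type=int,            help='Size of word embeddings')
parser.add_argument('-nhid',      type=int,            help='Number of hidden units per layer.')
parser.add_argument('-nlayers',   type=int,            help='Number of layers.')
parser.add_argument('-lr',        type=float,          help='Initial learning rate.')
parser.add_argument('-clip',      type=float,          help='Gradient clipping.')
parser.add_argument('-maxepoch',  type=int,            help='Upper epoch limit.')
parser.add_argument('-batchsize', type=int,            help='Batch size.')
parser.add_argument('-bptt',      type=int,            help='Sequence length.')
parser.add_argument('-seed',      type=int,            help='Random seed.')
parser.add_argument('-cuda',      action='store_true', help='Use CUDA.')
parser.add_argument('-reportint', type=int,            help='Report interval.')
parser.add_argument('-save',      type=str,            help='Path to save the final model.')
args = parser.parse_args()
```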
72 changes: 72 additions & 0 deletions word_language_model/generate.py
@@ -0,0 +1,72 @@
###############################################################################
# Language Modeling on Penn Tree Bank
#
# With the default parameters, this should achieve ~116 perplexity on the
# test set.
###############################################################################

import argparse
import time
import math

import torch
import torch.nn as nn
from torch.autograd import Variable

import data

parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')

# Model parameters.
parser.add_argument('-data' , type=str, default='./data/penn', help='Location of the data corpus' )
parser.add_argument('-checkpoint', type=str, default='./model.pt' , help='Checkpoint file path' )
parser.add_argument('-outf' , type=str, default='generated.out', help='Output file for generated text.' )
parser.add_argument('-nwords'    , type=int, default=1000            , help='Number of words of text to generate' )
parser.add_argument('-seed' , type=int, default=1111 , help='Random seed.' )
parser.add_argument('-cuda' , action='store_true' , help='Use CUDA.' )
parser.add_argument('-temperature', type=float, default=1.0 , help='Temperature. Higher will increase diversity')
parser.add_argument('-reportinterval', type=int, default=100 , help='Reporting interval' )
args = parser.parse_args()

# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
# If the GPU is enabled, do some plumbing.

if torch.cuda.is_available() and not args.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with -cuda")

with open(args.checkpoint, 'rb') as f:
    model = torch.load(f)

# Waiting on https://github.com/pytorch/pytorch/issues/188
# if args.cuda:
# model.cuda()
# else:
# model.cpu()

corpus = data.Corpus(args.data)
ntokens = corpus.dic.ntokens()

hidden = model.initHidden(1)

input = torch.LongTensor(1,1).fill_(math.floor(torch.rand(1)[0] * ntokens))
if args.cuda:
    input = input.cuda()

temperature = max(args.temperature, 1e-3)
with open(args.outf, 'w') as outf:
    for i in range(args.nwords):

        output, hidden = model(Variable(input, requires_grad=False), hidden)
        gen = torch.multinomial(output[0].data.cpu().div(temperature).exp(), 1)[0][0] # FIXME: no multinomial on GPU?
        input.fill_(gen)
        word = corpus.dic.idx2word[gen]
        outf.write(word)

        if i % 20 == 19:
            outf.write("\n")
        else:
            outf.write(" ")

        if i % args.reportinterval == 0:
            print('| Generated {}/{} words'.format(i, args.nwords))
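
The sampling line above divides the decoder's raw output scores by the temperature, exponentiates them, and draws a word index with `torch.multinomial`; lower temperatures sharpen the distribution, higher ones flatten it. A standalone sketch of the same idea (the function and variable names here are illustrative, not from the repository):

```python
# Standalone sketch of temperature sampling over a vector of raw scores (logits).
# torch.multinomial normalizes its input, so the exponentiated scores need not sum to 1.
import torch

def sample_with_temperature(logits, temperature=1.0):
    # Dividing by the temperature before exp() sharpens (<1) or flattens (>1)
    # the resulting categorical distribution.
    weights = logits.div(max(temperature, 1e-3)).exp()
    return torch.multinomial(weights, 1)[0]

scores = torch.randn(10)          # stand-in for decoder output over a 10-word vocabulary
word_idx = sample_with_temperature(scores, temperature=0.8)
```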
41 changes: 4 additions & 37 deletions word_language_model/main.py
@@ -14,6 +14,7 @@
from torch.autograd import Variable

import data
import model

parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')

@@ -70,47 +71,13 @@ def batchify(data, bsz):
# MAKE MODEL
###############################################################################

class RNNModel(nn.Container):
    """A container module with an encoder, an RNN (one of several flavors),
    and a decoder. Runs one RNN step at a time.
    """

    def __init__(self, rnnType, ntoken, ninp, nhid, nlayers):
        super(RNNModel, self).__init__(
            encoder = nn.sparse.Embedding(ntoken, ninp),
            rnn = nn.RNNBase(rnnType, ninp, nhid, nlayers, bias=False),
            decoder = nn.Linear(nhid, ntoken),
        )

        # FIXME: add stdv named argument to reset_parameters
        # (and/or to the constructors)
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.encoder(input)
        output, hidden = self.rnn(emb, hidden)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

ntokens = corpus.dic.ntokens()
model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers)
if args.cuda:
    model.cuda()

criterion = nn.CrossEntropyLoss()

def initHidden(model, bsz):
    weight = next(model.parameters()).data
    if args.model == 'LSTM':
        return (Variable(weight.new(args.nlayers, bsz, args.nhid).zero_()),
                Variable(weight.new(args.nlayers, bsz, args.nhid).zero_()))
    else:
        return Variable(weight.new(args.nlayers, bsz, args.nhid).zero_())


########################################
# TRAINING
########################################
@@ -123,7 +90,7 @@ def initHidden(model, bsz):
# Perform the forward pass only.
def evaluate(model, data, criterion, bsz):
    loss = 0
    hidden = initHidden(model, bsz)
    hidden = model.initHidden(bsz)
    # Loop over validation data.
    for i in range(0, data.size(0) - 1, bptt):
        seq_len = min(bptt, data.size(0) - 1 - i)
@@ -158,7 +125,7 @@ def repackageHidden(h):
    total_loss = 0
    epoch_start_time = time.time()
    # Start with an initial hidden state.
    hidden = initHidden(model, bsz)
    hidden = model.initHidden(bsz)

    loss = 0
    i = 0
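
Both scripts import `data` and rely on `data.Corpus(path)` exposing a dictionary with `ntokens()` and `idx2word`; `data.py` itself is not part of this diff. Below is a minimal sketch of an interface consistent with those call sites; it is an assumption for illustration, not the repository's actual implementation, and helper names like `addword` are hypothetical.

```python
# Hypothetical sketch of the corpus/dictionary interface used by main.py and
# generate.py (corpus.dic.ntokens(), corpus.dic.idx2word[i]); the real data.py
# is not shown in this diff and may differ.
import os
import torch

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def addword(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def ntokens(self):
        return len(self.idx2word)

class Corpus(object):
    def __init__(self, path):
        self.dic = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))

    def tokenize(self, path):
        # Map every word (plus an end-of-sentence marker) to an integer id.
        ids = []
        with open(path, 'r') as f:
            for line in f:
                for word in line.split() + ['<eos>']:
                    ids.append(self.dic.addword(word))
        return torch.LongTensor(ids)
```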
40 changes: 40 additions & 0 deletions word_language_model/model.py
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
from torch.autograd import Variable

class RNNModel(nn.Container):
    """A container module with an encoder, an RNN (one of several flavors),
    and a decoder. Runs one RNN step at a time.
    """

    def __init__(self, rnnType, ntoken, ninp, nhid, nlayers):
        super(RNNModel, self).__init__(
            encoder = nn.sparse.Embedding(ntoken, ninp),
            rnn = nn.RNNBase(rnnType, ninp, nhid, nlayers, bias=False),
            decoder = nn.Linear(nhid, ntoken),
        )

        # FIXME: add stdv named argument to reset_parameters
        # (and/or to the constructors)
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

        self.rnnType = rnnType
        self.nhid = nhid
        self.nlayers = nlayers

    def forward(self, input, hidden):
        emb = self.encoder(input)
        output, hidden = self.rnn(emb, hidden)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def initHidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnnType == 'LSTM':
            return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
                    Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
        else:
            return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())
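
For reference, here is a condensed sketch of how the training and generation scripts above drive this module: construct it with a vocabulary size, ask it for an initial hidden state, and pass (input, hidden) pairs through `forward`. The embedding and hidden sizes below are illustrative values, not defaults from the scripts.

```python
# Condensed usage sketch of model.RNNModel, mirroring the call sites in main.py
# and generate.py above (the sizes here are illustrative).
import torch
from torch.autograd import Variable
import model

ntokens = 10000                                   # vocabulary size from the corpus
rnn_model = model.RNNModel('LSTM', ntokens, 200, 200, 2)

hidden = rnn_model.initHidden(1)                  # (h, c) pair for an LSTM, batch size 1
input = torch.LongTensor(1, 1).fill_(0)           # a single seed word id
output, hidden = rnn_model(Variable(input, requires_grad=False), hidden)
print(output.size())                              # (1, 1, ntokens) scores over the vocabulary
```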