Skip to content

Commit 4c93238

Browse files
committed
fixing inference to use volatile variables
1 parent cf9c232 commit 4c93238

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

word_language_model/generate.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
###############################################################################
22
# Language Modeling on Penn Tree Bank
33
#
4-
# With the default parameters, this should achieve ~116 perplexity on the
5-
# test set.
4+
# This file generates new sentences sampled from the language model
5+
#
66
###############################################################################
77

88
import argparse
@@ -38,27 +38,26 @@
3838
with open(args.checkpoint, 'rb') as f:
3939
model = torch.load(f)
4040

41-
# Waiting on https://github.com/pytorch/pytorch/issues/188
42-
# if args.cuda:
43-
# model.cuda()
44-
# else:
45-
# model.cpu()
41+
if args.cuda:
42+
model.cuda()
43+
else:
44+
model.cpu()
4645

4746
corpus = data.Corpus(args.data)
4847
ntokens = corpus.dic.ntokens()
4948

5049
hidden = model.initHidden(1)
5150

52-
input = torch.LongTensor(1,1).fill_(math.floor(torch.rand(1)[0] * ntokens))
51+
input = torch.LongTensor(1,1).fill_(int(math.floor(torch.rand(1)[0] * ntokens)))
5352
if args.cuda:
5453
input = input.cuda()
5554

5655
temperature = max(args.temperature, 1e-3)
5756
with open(args.outf, 'w') as outf:
5857
for i in range(args.nwords):
5958

60-
output, hidden = model(Variable(input, requires_grad=False), hidden)
61-
gen = torch.multinomial(output[0].data.cpu().div(temperature).exp(), 1)[0][0] # FIXME: no multinomial on GPU?
59+
output, hidden = model(Variable(input, volatile=True), hidden)
60+
gen = torch.multinomial(output[0].data.div(temperature).exp().cpu(), 1)[0][0] # FIXME: multinomial is only for CPU
6261
input.fill_(gen)
6362
word = corpus.dic.idx2word[gen]
6463
outf.write(word)

0 commit comments

Comments
 (0)