@@ -230,10 +230,13 @@ def inference(x, is_train, sequence_length, reuse=None):
    rnn_init = tf.random_uniform_initializer(-init_scale, init_scale)
    with tf.variable_scope("model", reuse=reuse):
        network = EmbeddingInputlayer(x, vocab_size, hidden_size, rnn_init, name='embedding')
-        network = RNNLayer(network, cell_fn=tf.contrib.rnn.BasicLSTMCell, \
-            cell_init_args={'forget_bias': 0.0, 'state_is_tuple': True}, \
-            n_hidden=hidden_size, initializer=rnn_init, n_steps=sequence_length, return_last=False,
-            return_seq_2d=True, name='lstm1')
+        network = RNNLayer(
+            network, cell_fn=tf.contrib.rnn.BasicLSTMCell, cell_init_args={
+                'forget_bias': 0.0,
+                'state_is_tuple': True
+            }, n_hidden=hidden_size, initializer=rnn_init, n_steps=sequence_length, return_last=False,
+            return_seq_2d=True, name='lstm1'
+        )
        lstm1 = network
        network = DenseLayer(network, vocab_size, W_init=rnn_init, b_init=rnn_init, act=tf.identity, name='output')
    return network, lstm1
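For context, a minimal sketch of how this `inference` builder is typically called to create the training graph and the single-step generation graph that the later hunks feed. The placeholder shapes and the `reuse=True` call are assumptions based on the surrounding tutorial, not lines touched by this commit:

# Assumed usage sketch: build the training graph once, then rebuild the same
# variables with reuse=True for one-word-at-a-time generation.
input_data = tf.placeholder(tf.int32, [batch_size, sequence_length])
targets = tf.placeholder(tf.int32, [batch_size, sequence_length])
network, lstm1 = inference(input_data, is_train=True, sequence_length=sequence_length, reuse=None)

input_data_test = tf.placeholder(tf.int32, [1, 1])
network_test, lstm1_test = inference(input_data_test, is_train=False, sequence_length=1, reuse=True)
y_soft = tf.nn.softmax(network_test.outputs)  # per-word distribution used when sampling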
@@ -297,14 +300,21 @@ def loss_fn(outputs, targets, batch_size, sequence_length):
        ## reset all states at the beginning of every epoch
        state1 = tl.layers.initialize_rnn_state(lstm1.initial_state)
        for step, (x, y) in enumerate(tl.iterate.ptb_iterator(train_data, batch_size, sequence_length)):
-            _cost, state1, _ = sess.run([cost, lstm1.final_state, train_op], \
-                feed_dict={input_data: x, targets: y, lstm1.initial_state: state1})
+            _cost, state1, _ = sess.run(
+                [cost, lstm1.final_state, train_op], feed_dict={
+                    input_data: x,
+                    targets: y,
+                    lstm1.initial_state: state1
+                }
+            )
            costs += _cost
            iters += sequence_length

            if step % (epoch_size // 10) == 1:
-                print("%.3f perplexity: %.3f speed: %.0f wps" % \
-                    (step * 1.0 / epoch_size, np.exp(costs / iters), iters * batch_size / (time.time() - start_time)))
+                print(
+                    "%.3f perplexity: %.3f speed: %.0f wps" %
+                    (step * 1.0 / epoch_size, np.exp(costs / iters), iters * batch_size / (time.time() - start_time))
+                )
        train_perplexity = np.exp(costs / iters)
        # print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
        print("Epoch: %d/%d Train Perplexity: %.3f" % (i + 1, max_max_epoch, train_perplexity))
@@ -319,14 +329,22 @@ def loss_fn(outputs, targets, batch_size, sequence_length):
        # feed the seed to initialize the state for generation.
        for ids in outs_id[:-1]:
            a_id = np.asarray(ids).reshape(1, 1)
-            state1 = sess.run([lstm1_test.final_state], \
-                feed_dict={input_data_test: a_id, lstm1_test.initial_state: state1})
+            state1 = sess.run(
+                [lstm1_test.final_state], feed_dict={
+                    input_data_test: a_id,
+                    lstm1_test.initial_state: state1
+                }
+            )
        # feed the last word in seed, and start to generate sentence.
        a_id = outs_id[-1]
        for _ in range(print_length):
            a_id = np.asarray(a_id).reshape(1, 1)
-            out, state1 = sess.run([y_soft, lstm1_test.final_state], \
-                feed_dict={input_data_test: a_id, lstm1_test.initial_state: state1})
+            out, state1 = sess.run(
+                [y_soft, lstm1_test.final_state], feed_dict={
+                    input_data_test: a_id,
+                    lstm1_test.initial_state: state1
+                }
+            )
            ## Without sampling
            # a_id = np.argmax(out[0])
            ## Sample from all words, if vocab_size is large,
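The comment above is cut off in this hunk; in the surrounding tutorial the next word id is presumably drawn from the softmax output `out[0]` rather than taken with argmax. A minimal sampling sketch, assuming `out[0]` is a probability vector over the vocabulary (the renormalization step is illustrative, not part of this commit):

# Assumption: sample the next word id from the softmax distribution so generation
# does not always repeat the single most likely word.
probs = out[0].astype(np.float64)
probs = probs / probs.sum()                   # guard against floating-point drift
a_id = np.random.choice(len(probs), p=probs)  # sampled id, fed back on the next step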