@@ -975,34 +975,36 @@ def testLuongMonotonicScaled(self):
             create_attention_kwargs=create_attention_kwargs,
         )
 
-    def test_attention_state_with_keras_rnn(self):
-        # See https://github.com/tensorflow/addons/issues/1095.
-        cell = tf.keras.layers.LSTMCell(8)
 
-        mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8)))
+def test_attention_state_with_keras_rnn():
+    # See https://github.com/tensorflow/addons/issues/1095.
+    cell = tf.keras.layers.LSTMCell(8)
 
-        cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism)
+    mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8)))
 
-        layer = tf.keras.layers.RNN(cell)
-        _ = layer(inputs=tf.ones((2, 4, 8)))
+    cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism)
 
-        # Make sure the explicit initial_state also works.
-        initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32)
-        _ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state)
+    layer = tf.keras.layers.RNN(cell)
+    _ = layer(inputs=tf.ones((2, 4, 8)))
 
-    def test_attention_state_with_variable_length_input(self):
-        cell = tf.keras.layers.LSTMCell(3)
-        mechanism = wrapper.LuongAttention(units=3)
-        cell = wrapper.AttentionWrapper(cell, mechanism)
+    # Make sure the explicit initial_state also works.
+    initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32)
+    _ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state)
 
-        var_len = tf.random.uniform(shape=(), minval=2, maxval=10, dtype=tf.int32)
-        lengths = tf.random.uniform(
-            shape=(var_len,), minval=1, maxval=var_len + 1, dtype=tf.int32
-        )
-        data = tf.ones(shape=(var_len, var_len, 3))
-        mask = tf.sequence_mask(lengths, maxlen=var_len)
 
-        mechanism.setup_memory(data)
-        layer = tf.keras.layers.RNN(cell)
+def test_attention_state_with_variable_length_input():
+    cell = tf.keras.layers.LSTMCell(3)
+    mechanism = wrapper.LuongAttention(units=3)
+    cell = wrapper.AttentionWrapper(cell, mechanism)
+
+    var_len = tf.random.uniform(shape=(), minval=2, maxval=10, dtype=tf.int32)
+    lengths = tf.random.uniform(
+        shape=(var_len,), minval=1, maxval=var_len + 1, dtype=tf.int32
+    )
+    data = tf.ones(shape=(var_len, var_len, 3))
+    mask = tf.sequence_mask(lengths, maxlen=var_len)
+
+    mechanism.setup_memory(data)
+    layer = tf.keras.layers.RNN(cell)
 
-        _ = layer(data, mask=mask)
+    _ = layer(data, mask=mask)
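For context, the following is a minimal standalone sketch of the usage the two converted tests exercise. It assumes TensorFlow Addons is installed and that the public aliases tfa.seq2seq.LuongAttention and tfa.seq2seq.AttentionWrapper point at the same classes as the wrapper module imported by the test file; the fixed shapes and sequence lengths below are illustrative stand-ins for the test's randomly drawn sizes.

import tensorflow as tf
import tensorflow_addons as tfa

# AttentionWrapper inside a Keras RNN layer (the path covered by issue #1095).
memory = tf.ones((2, 4, 8))  # [batch, max_time, depth]
mechanism = tfa.seq2seq.LuongAttention(units=8, memory=memory)
cell = tfa.seq2seq.AttentionWrapper(
    cell=tf.keras.layers.LSTMCell(8), attention_mechanism=mechanism
)

layer = tf.keras.layers.RNN(cell)
layer(inputs=tf.ones((2, 4, 8)))  # the layer builds its own initial state

# An explicit initial state produced by the wrapper works as well.
initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32)
layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state)

# Variable-length input: attach the memory after construction via
# setup_memory() and pass a boolean mask to the RNN layer so padded
# time steps are ignored.
masked_mechanism = tfa.seq2seq.LuongAttention(units=3)
masked_cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(3), masked_mechanism
)

data = tf.ones((2, 5, 3))                  # [batch, max_time, depth]
mask = tf.sequence_mask([5, 2], maxlen=5)  # valid length of each sequence
masked_mechanism.setup_memory(data)

masked_layer = tf.keras.layers.RNN(masked_cell)
masked_layer(data, mask=mask)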