Commit e92a36a

Moved some functions out of run_all_in_graph_and_eager_mode. (#1639)
1 parent: 84657f8
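In practice the change is mechanical: each test drops its `self` parameter and moves out of the `tf.test.TestCase` class (whose methods the class-level run_all_in_graph_and_eager_mode decorator named in the commit message re-runs in graph and eager mode) to module level, where pytest collects it as a plain function. A minimal sketch of the before/after shape, using a hypothetical `ExampleTest` class and `test_cell_output_shape` test rather than the actual names in this file:

import tensorflow as tf


# Before (hypothetical illustration): the test is a method on a
# tf.test.TestCase subclass, so a class-level decorator can re-run
# it in both graph and eager mode.
class ExampleTest(tf.test.TestCase):
    def test_cell_output_shape(self):
        cell = tf.keras.layers.LSTMCell(8)
        layer = tf.keras.layers.RNN(cell)
        out = layer(tf.ones((2, 4, 8)))
        self.assertEqual(out.shape, (2, 8))


# After: the same check as a module-level function with no `self`,
# collected directly by pytest.
def test_cell_output_shape():
    cell = tf.keras.layers.LSTMCell(8)
    layer = tf.keras.layers.RNN(cell)
    out = layer(tf.ones((2, 4, 8)))
    assert out.shape == (2, 8)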

File tree: 1 file changed, +25 -23 lines

tensorflow_addons/seq2seq/tests/attention_wrapper_test.py

@@ -975,34 +975,36 @@ def testLuongMonotonicScaled(self):
             create_attention_kwargs=create_attention_kwargs,
         )
 
-    def test_attention_state_with_keras_rnn(self):
-        # See https://github.com/tensorflow/addons/issues/1095.
-        cell = tf.keras.layers.LSTMCell(8)
 
-        mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8)))
+def test_attention_state_with_keras_rnn():
+    # See https://github.com/tensorflow/addons/issues/1095.
+    cell = tf.keras.layers.LSTMCell(8)
 
-        cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism)
+    mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8)))
 
-        layer = tf.keras.layers.RNN(cell)
-        _ = layer(inputs=tf.ones((2, 4, 8)))
+    cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism)
 
-        # Make sure the explicit initial_state also works.
-        initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32)
-        _ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state)
+    layer = tf.keras.layers.RNN(cell)
+    _ = layer(inputs=tf.ones((2, 4, 8)))
 
-    def test_attention_state_with_variable_length_input(self):
-        cell = tf.keras.layers.LSTMCell(3)
-        mechanism = wrapper.LuongAttention(units=3)
-        cell = wrapper.AttentionWrapper(cell, mechanism)
+    # Make sure the explicit initial_state also works.
+    initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32)
+    _ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state)
 
-        var_len = tf.random.uniform(shape=(), minval=2, maxval=10, dtype=tf.int32)
-        lengths = tf.random.uniform(
-            shape=(var_len,), minval=1, maxval=var_len + 1, dtype=tf.int32
-        )
-        data = tf.ones(shape=(var_len, var_len, 3))
-        mask = tf.sequence_mask(lengths, maxlen=var_len)
 
-        mechanism.setup_memory(data)
-        layer = tf.keras.layers.RNN(cell)
+def test_attention_state_with_variable_length_input():
+    cell = tf.keras.layers.LSTMCell(3)
+    mechanism = wrapper.LuongAttention(units=3)
+    cell = wrapper.AttentionWrapper(cell, mechanism)
+
+    var_len = tf.random.uniform(shape=(), minval=2, maxval=10, dtype=tf.int32)
+    lengths = tf.random.uniform(
+        shape=(var_len,), minval=1, maxval=var_len + 1, dtype=tf.int32
+    )
+    data = tf.ones(shape=(var_len, var_len, 3))
+    mask = tf.sequence_mask(lengths, maxlen=var_len)
+
+    mechanism.setup_memory(data)
+    layer = tf.keras.layers.RNN(cell)
 
-        _ = layer(data, mask=mask)
+    _ = layer(data, mask=mask)
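
For reference, here is the first of the moved tests in a self-contained form that can be run on its own. The import alias is an assumption, not copied from the file's header: the test module is presumed to refer to the attention code as `wrapper`, e.g. via tensorflow_addons.seq2seq.attention_wrapper.

# Self-contained sketch of the moved test; the `wrapper` alias below is an
# assumed import, not taken from the diff.
import tensorflow as tf
from tensorflow_addons.seq2seq import attention_wrapper as wrapper


def test_attention_state_with_keras_rnn():
    # See https://github.com/tensorflow/addons/issues/1095.
    cell = tf.keras.layers.LSTMCell(8)
    mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8)))
    cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism)

    layer = tf.keras.layers.RNN(cell)
    _ = layer(inputs=tf.ones((2, 4, 8)))

    # The explicit initial_state path should also work.
    initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32)
    _ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state)


if __name__ == "__main__":
    # Outside of pytest, the function can simply be called directly.
    test_attention_state_with_keras_rnn()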
