Skip to content

Commit ddb59f1

Browse files
committed
Saving model test
1 parent 49fd2e5 commit ddb59f1

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

keras_nlp/models/bert_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
# limitations under the License.
1414
"""Tests for Bert model."""
1515

16+
import os
17+
1618
import tensorflow as tf
19+
from tensorflow import keras
1720

1821
from keras_nlp.models import bert
1922

@@ -76,3 +79,34 @@ def test_valid_call_bert_base(self):
7679
"input_mask": tf.constant([1] * 512, shape=(1, 512)),
7780
}
7881
model(input_data)
82+
83+
def test_saving_model(self):
84+
model = bert.Bert(
85+
vocabulary_size=30522,
86+
num_layers=12,
87+
num_heads=12,
88+
hidden_dim=768,
89+
intermediate_dim=3072,
90+
max_sequence_length=12,
91+
name="encoder",
92+
)
93+
input_data = {
94+
"input_ids": tf.random.uniform(
95+
shape=(1, 12), dtype=tf.int64, maxval=30522
96+
),
97+
"segment_ids": tf.constant(
98+
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
99+
),
100+
"input_mask": tf.constant(
101+
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
102+
),
103+
}
104+
model_output = model(input_data)
105+
save_path = os.path.join(self.get_temp_dir(), "model")
106+
model.save(save_path)
107+
restored_model = keras.models.load_model(save_path)
108+
109+
restored_output = restored_model(input_data)
110+
self.assertAllClose(
111+
model_output["pooled_output"], restored_output["pooled_output"]
112+
)

0 commit comments

Comments
 (0)