|
13 | 13 | # limitations under the License.
|
14 | 14 | """Tests for Bert model."""
|
15 | 15 |
|
| 16 | +import os |
| 17 | + |
16 | 18 | import tensorflow as tf
|
| 19 | +from tensorflow import keras |
17 | 20 |
|
18 | 21 | from keras_nlp.models import bert
|
19 | 22 |
|
@@ -76,3 +79,34 @@ def test_valid_call_bert_base(self):
|
76 | 79 | "input_mask": tf.constant([1] * 512, shape=(1, 512)),
|
77 | 80 | }
|
78 | 81 | model(input_data)
|
| 82 | + |
| 83 | + def test_saving_model(self): |
| 84 | + model = bert.Bert( |
| 85 | + vocabulary_size=30522, |
| 86 | + num_layers=12, |
| 87 | + num_heads=12, |
| 88 | + hidden_dim=768, |
| 89 | + intermediate_dim=3072, |
| 90 | + max_sequence_length=12, |
| 91 | + name="encoder", |
| 92 | + ) |
| 93 | + input_data = { |
| 94 | + "input_ids": tf.random.uniform( |
| 95 | + shape=(1, 12), dtype=tf.int64, maxval=30522 |
| 96 | + ), |
| 97 | + "segment_ids": tf.constant( |
| 98 | + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12) |
| 99 | + ), |
| 100 | + "input_mask": tf.constant( |
| 101 | + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12) |
| 102 | + ), |
| 103 | + } |
| 104 | + model_output = model(input_data) |
| 105 | + save_path = os.path.join(self.get_temp_dir(), "model") |
| 106 | + model.save(save_path) |
| 107 | + restored_model = keras.models.load_model(save_path) |
| 108 | + |
| 109 | + restored_output = restored_model(input_data) |
| 110 | + self.assertAllClose( |
| 111 | + model_output["pooled_output"], restored_output["pooled_output"] |
| 112 | + ) |
0 commit comments