Skip to content

Commit 2c1ed4f

Browse files
Added python implementation for mish (#1139)
* Added py implementation for mish
1 parent 34132df commit 2c1ed4f

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

tensorflow_addons/activations/mish.py

+4
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,7 @@ def mish(x: types.TensorLike) -> tf.Tensor:
4242
@tf.RegisterGradient("Addons>Mish")
4343
def _mish_grad(op, grad):
4444
return _activation_so.ops.addons_mish_grad(grad, op.inputs[0])
45+
46+
47+
def _mish_py(x):
48+
return x * tf.math.tanh(tf.math.softplus(x))

tensorflow_addons/activations/mish_test.py

+23
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import tensorflow as tf
2020
from tensorflow_addons.activations import mish
21+
from tensorflow_addons.activations.mish import _mish_py
2122
from tensorflow_addons.utils import test_utils
2223

2324

@@ -42,6 +43,28 @@ def test_theoretical_gradients(self, dtype):
4243
theoretical, numerical = tf.test.compute_gradient(mish, [x])
4344
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
4445

46+
@parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
47+
def test_same_as_py_func(self, dtype):
48+
np.random.seed(1234)
49+
for _ in range(20):
50+
self.verify_funcs_are_equivalent(dtype)
51+
52+
def verify_funcs_are_equivalent(self, dtype):
53+
x_np = np.random.uniform(-10, 10, size=(4, 4)).astype(dtype)
54+
x = tf.convert_to_tensor(x_np)
55+
56+
with tf.GradientTape(persistent=True) as t:
57+
t.watch(x)
58+
y_native = mish(x)
59+
y_py = _mish_py(x)
60+
61+
self.assertAllCloseAccordingToType(y_native, y_py, atol=1e-4)
62+
63+
grad_native = t.gradient(y_native, x)
64+
grad_py = t.gradient(y_py, x)
65+
66+
self.assertAllCloseAccordingToType(grad_native, grad_py, atol=1e-4)
67+
4568

4669
if __name__ == "__main__":
4770
tf.test.main()

0 commit comments

Comments
 (0)