18
18
import numpy as np
19
19
import tensorflow as tf
20
20
from tensorflow_addons .activations import mish
21
+ from tensorflow_addons .activations .mish import _mish_py
21
22
from tensorflow_addons .utils import test_utils
22
23
23
24
@@ -42,6 +43,28 @@ def test_theoretical_gradients(self, dtype):
42
43
theoretical , numerical = tf .test .compute_gradient (mish , [x ])
43
44
self .assertAllCloseAccordingToType (theoretical , numerical , atol = 1e-4 )
44
45
46
+ @parameterized .named_parameters (("float32" , np .float32 ), ("float64" , np .float64 ))
47
+ def test_same_as_py_func (self , dtype ):
48
+ np .random .seed (1234 )
49
+ for _ in range (20 ):
50
+ self .verify_funcs_are_equivalent (dtype )
51
+
52
+ def verify_funcs_are_equivalent (self , dtype ):
53
+ x_np = np .random .uniform (- 10 , 10 , size = (4 , 4 )).astype (dtype )
54
+ x = tf .convert_to_tensor (x_np )
55
+
56
+ with tf .GradientTape (persistent = True ) as t :
57
+ t .watch (x )
58
+ y_native = mish (x )
59
+ y_py = _mish_py (x )
60
+
61
+ self .assertAllCloseAccordingToType (y_native , y_py , atol = 1e-4 )
62
+
63
+ grad_native = t .gradient (y_native , x )
64
+ grad_py = t .gradient (y_py , x )
65
+
66
+ self .assertAllCloseAccordingToType (grad_native , grad_py , atol = 1e-4 )
67
+
45
68
46
69
if __name__ == "__main__" :
47
70
tf .test .main ()
0 commit comments