Skip to content

Commit 66e8ca8

Browse files
Added a py function for hardshrink (#1128)
* Added a py function. * Added a simple test.
1 parent 3955638 commit 66e8ca8

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

tensorflow_addons/activations/hardshrink.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,20 @@ def _hardshrink_grad(op, grad):
4848
return _activation_so.ops.addons_hardshrink_grad(
4949
grad, op.inputs[0], op.get_attr("lower"), op.get_attr("upper")
5050
)
51+
52+
53+
def _hardshrink_py(
    x: types.TensorLike, lower: Number = -0.5, upper: Number = 0.5
) -> tf.Tensor:
    """Pure-python reference implementation of the hardshrink activation.

    Elements of `x` inside the closed interval ``[lower, upper]`` are zeroed
    out; elements strictly outside it pass through unchanged.

    Args:
        x: A tensor-like input; converted with `tf.convert_to_tensor`.
        lower: Lower bound of the shrink interval. Defaults to -0.5.
        upper: Upper bound of the shrink interval. Defaults to 0.5.

    Returns:
        A `tf.Tensor` with the same shape and dtype as `x`.

    Raises:
        ValueError: If `lower` is greater than `upper`.
    """
    if lower > upper:
        # NOTE: original message read "...than the value variable upper,
        # which is {} ." — broken grammar and a stray space before the
        # period; fixed here.
        raise ValueError(
            "The value of lower is {} and should"
            " not be higher than the value of"
            " upper, which is {}.".format(lower, upper)
        )
    x = tf.convert_to_tensor(x)
    # 1.0 where x is outside [lower, upper], 0.0 inside — multiplying by
    # this mask zeroes the interior while keeping dtype and shape.
    outside_interval = tf.logical_or(x < lower, upper < x)
    return x * tf.cast(outside_interval, x.dtype)

tensorflow_addons/activations/hardshrink_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import tensorflow as tf
2020
from tensorflow_addons.activations import hardshrink
2121
from tensorflow_addons.utils import test_utils
22+
from tensorflow_addons.activations.hardshrink import _hardshrink_py
2223

2324

2425
@test_utils.run_all_in_graph_and_eager_modes
@@ -53,6 +54,30 @@ def test_theoretical_gradients(self, dtype):
5354
theoretical, numerical = tf.test.compute_gradient(hardshrink, [x])
5455
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
5556

57+
@parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
def test_same_as_py_func(self, dtype):
    """Check the native op against the pure-python fallback on random data."""
    np.random.seed(1234)
    # Repeat with fresh random inputs/thresholds to cover a spread of cases.
    for _trial in range(20):
        self.verify_funcs_are_equivalent(dtype)
62+
63+
def verify_funcs_are_equivalent(self, dtype):
    """Assert that `hardshrink` and `_hardshrink_py` agree in value and gradient."""
    input_np = np.random.uniform(-10, 10, size=(4, 4)).astype(dtype)
    input_tensor = tf.convert_to_tensor(input_np)
    # Draw a random threshold pair guaranteed to satisfy lower <= upper.
    lower = np.random.uniform(-10, 10)
    upper = lower + np.random.uniform(0, 10)

    with tf.GradientTape(persistent=True) as tape:
        tape.watch(input_tensor)
        y_native = hardshrink(input_tensor, lower, upper)
        y_py = _hardshrink_py(input_tensor, lower, upper)

    # Forward outputs must match.
    self.assertAllCloseAccordingToType(y_native, y_py, atol=1e-4)

    # Gradients w.r.t. the input must match as well (persistent tape
    # allows two gradient calls).
    self.assertAllCloseAccordingToType(
        tape.gradient(y_native, input_tensor),
        tape.gradient(y_py, input_tensor),
        atol=1e-4,
    )
80+
5681

5782
if __name__ == "__main__":
5883
tf.test.main()

0 commit comments

Comments (0)