diff --git a/tensorflow_addons/image/BUILD b/tensorflow_addons/image/BUILD index a1a09d66bf..4c9a7408bb 100644 --- a/tensorflow_addons/image/BUILD +++ b/tensorflow_addons/image/BUILD @@ -1,25 +1,28 @@ licenses(["notice"]) # Apache 2.0 -package(default_visibility = ["//visibility:public"]) +package(default_visibility=["//visibility:public"]) py_library( - name = "image", - srcs = ([ - "__init__.py", - "dense_image_warp.py", - "distance_transform.py", - "distort_image_ops.py", - "filters.py", - "transform_ops.py", - "translate_ops.py", - "utils.py", - "sparse_image_warp.py", - "interpolate_spline.py", - "connected_components.py", - "resampler_ops.py", - "compose_ops.py", - ]), + + name="image", + srcs=( + [ + "__init__.py", + "dense_image_warp.py", + "distance_transform.py", + "distort_image_ops.py", + "filters.py", + "transform_ops.py", + "translate_ops.py", + "utils.py", + "sparse_image_warp.py", + "interpolate_spline.py", + "connected_components.py", + "resampler_ops.py", + "solarize_ops.py", + ]), data = [ + ":sparse_image_warp_test_data", "//tensorflow_addons/custom_ops/image:_distort_image_ops.so", "//tensorflow_addons/custom_ops/image:_image_ops.so", @@ -29,141 +32,104 @@ py_library( ) filegroup( - name = "sparse_image_warp_test_data", - srcs = glob(["test_data/*.png"]), + name="sparse_image_warp_test_data", srcs=glob(["test_data/*.png"]), ) py_test( - name = "dense_image_warp_test", - size = "small", - srcs = [ - "dense_image_warp_test.py", - ], - main = "dense_image_warp_test.py", - deps = [ - ":image", - ], + name="dense_image_warp_test", + size="small", + srcs=["dense_image_warp_test.py",], + main="dense_image_warp_test.py", + deps=[":image",], ) py_test( - name = "distance_transform_ops_test", - size = "small", - srcs = [ - "distance_transform_test.py", - ], - main = "distance_transform_test.py", - deps = [ - ":image", - ], + name="distance_transform_ops_test", + size="small", + srcs=["distance_transform_test.py",], + main="distance_transform_test.py", + deps=[":image",], ) py_test( - name = "distort_image_ops_test", - size = "small", - srcs = [ - "distort_image_ops_test.py", - ], - main = "distort_image_ops_test.py", - deps = [ - ":image", - ], + name="distort_image_ops_test", + size="small", + srcs=["distort_image_ops_test.py",], + main="distort_image_ops_test.py", + deps=[":image",], ) py_test( - name = "filters_test", - size = "medium", - srcs = [ - "filters_test.py", - ], - flaky = True, - main = "filters_test.py", - deps = [ - ":image", - ], + name="filters_test", + size="medium", + srcs=["filters_test.py",], + flaky=True, + main="filters_test.py", + deps=[":image",], ) py_test( - name = "transform_ops_test", - size = "medium", - srcs = [ - "transform_ops_test.py", - ], - main = "transform_ops_test.py", - deps = [ - ":image", - ], + name="transform_ops_test", + size="medium", + srcs=["transform_ops_test.py",], + main="transform_ops_test.py", + deps=[":image",], ) py_test( - name = "translate_ops_test", - size = "medium", - srcs = [ - "translate_ops_test.py", - ], - main = "translate_ops_test.py", - deps = [ - ":image", - ], + name="translate_ops_test", + size="medium", + srcs=["translate_ops_test.py",], + main="translate_ops_test.py", + deps=[":image",], ) py_test( - name = "utils_test", - size = "small", - srcs = [ - "utils_test.py", - ], - main = "utils_test.py", - deps = [ - ":image", - ], + name="utils_test", + size="small", + srcs=["utils_test.py",], + main="utils_test.py", + deps=[":image",], ) py_test( - name = "sparse_image_warp_test", - size = "medium", - srcs = [ - "sparse_image_warp_test.py", - ], - main = "sparse_image_warp_test.py", - deps = [ - ":image", - ], + name="sparse_image_warp_test", + size="medium", + srcs=["sparse_image_warp_test.py",], + main="sparse_image_warp_test.py", + deps=[":image",], ) py_test( - name = "interpolate_spline_test", - size = "medium", - srcs = [ - "interpolate_spline_test.py", - ], - main = "interpolate_spline_test.py", - deps = [ - ":image", - ], + name="interpolate_spline_test", + size="medium", + srcs=["interpolate_spline_test.py",], + main="interpolate_spline_test.py", + deps=[":image",], ) py_test( - name = "connected_components_test", - size = "medium", - srcs = [ - "connected_components_test.py", - ], - main = "connected_components_test.py", - deps = [ - ":image", - ], + name="connected_components_test", + size="medium", + srcs=["connected_components_test.py",], + main="connected_components_test.py", + deps=[":image",], ) py_test( - name = "resampler_ops_test", - size = "medium", - srcs = [ - "resampler_ops_test.py", - ], - main = "resampler_ops_test.py", - deps = [ - ":image", - ], + name="resampler_ops_test", + size="medium", + srcs=["resampler_ops_test.py",], + main="resampler_ops_test.py", + deps=[":image",], +) + +py_test( + name="solarize_ops_test", + size="medium", + srcs=["solarize_ops_test.py",], + main="solarize_ops_test.py", + deps=[":image",], ) py_test( diff --git a/tensorflow_addons/image/__init__.py b/tensorflow_addons/image/__init__.py index fbd5cda029..ebd0994d24 100644 --- a/tensorflow_addons/image/__init__.py +++ b/tensorflow_addons/image/__init__.py @@ -28,4 +28,6 @@ from tensorflow_addons.image.transform_ops import rotate from tensorflow_addons.image.transform_ops import transform from tensorflow_addons.image.translate_ops import translate +from tensorflow_addons.image.solarize_ops import solarize +from tensorflow_addons.image.solarize_ops import solarize_add from tensorflow_addons.image.compose_ops import blend diff --git a/tensorflow_addons/image/solarize_ops.py b/tensorflow_addons/image/solarize_ops.py new file mode 100644 index 0000000000..6b390b5940 --- /dev/null +++ b/tensorflow_addons/image/solarize_ops.py @@ -0,0 +1,30 @@ +""" This module is used to invert all pixel values above a threshold + which simply means segmentation. """ + +import tensorflow as tf + + +def solarize(image, threshold=128): + """Method to solarize the image + image: input image + threshold: threshold value to solarize the image + """ + # For each pixel in the image, select the pixel + # if the value is less than the threshold. + # Otherwise, subtract 255 from the pixel. + return tf.where(image < threshold, image, 255 - image) + + +def solarize_add(image, addition=0, threshold=128): + """Method to add solarize to the image + image: input image + addition: addition amount to add in image + threshold: threshold value to solarize the image + """ + # For each pixel in the image less than threshold + # we add 'addition' amount to it and then clip the + # pixel value to be between 0 and 255. The value + # of 'addition' is between -128 and 128. + added_image = tf.cast(image, tf.int64) + addition + added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8) + return tf.where(image < threshold, added_image, image) diff --git a/tensorflow_addons/image/solarize_ops_test.py b/tensorflow_addons/image/solarize_ops_test.py new file mode 100644 index 0000000000..050c55111d --- /dev/null +++ b/tensorflow_addons/image/solarize_ops_test.py @@ -0,0 +1,30 @@ +"""Test of solarize_ops""" +import sys +import pytest +import tensorflow as tf +from tensorflow_addons.image import solarize_ops +from tensorflow_addons.utils import test_utils +from absl.testing import parameterized + + +@test_utils.run_all_in_graph_and_eager_modes +class SolarizeOPSTest(tf.test.TestCase, parameterized.TestCase): + """SolarizeOPSTest class to test the solarize images""" + + def test_solarize(self): + if tf.executing_eagerly(): + image2 = tf.constant( + [ + [255, 255, 255, 255], + [255, 255, 255, 255], + [255, 255, 255, 255], + [255, 255, 255, 255], + ], + dtype=tf.uint8, + ) + threshold = 10 + solarize_img = solarize_ops.solarize(image2, threshold) + self.assertAllEqual(tf.shape(solarize_img), tf.shape(image2)) + +if __name__ == "__main__": + sys.exit(pytest.main([__file__]))