add cutout image op #1338
Conversation
if tf.equal(tf.rank(image), 3):
    mask = tf.expand_dims(mask, -1)
    mask = tf.tile(mask, [1, 1, tf.shape(image)[-1]])
elif tf.equal(tf.rank(image), 4):
    mask = tf.expand_dims(mask, 0)
    mask = tf.expand_dims(mask, -1)
    mask = tf.tile(mask, [tf.shape(image)[0], 1, 1, tf.shape(image)[-1]])
Can `from_4D_image` and `to_4D_image` be used to handle this?
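For context, the pattern the reviewer refers to normalizes any 2D/3D/4D input to a 4D (N, H, W, C) tensor before the op runs, then restores the original rank afterwards. A minimal NumPy sketch of that idea (the names `to_4d`/`from_4d` are illustrative, not the actual tensorflow_addons helpers):

```python
import numpy as np

def to_4d(image):
    """Return a 4D (N, H, W, C) view of `image` plus its original rank."""
    rank = image.ndim
    if rank == 2:                      # (H, W) -> (1, H, W, 1)
        image = image[np.newaxis, ..., np.newaxis]
    elif rank == 3:                    # (H, W, C) -> (1, H, W, C)
        image = image[np.newaxis, ...]
    return image, rank

def from_4d(image, rank):
    """Undo to_4d, restoring the original rank."""
    if rank == 2:
        return image[0, ..., 0]
    if rank == 3:
        return image[0]
    return image

img = np.ones((40, 40))                # a 2D grayscale image
img4d, rank = to_4d(img)               # shape (1, 40, 40, 1)
restored = from_4d(img4d, rank)        # back to shape (40, 40)
```

With this pattern, the op itself only ever has to handle the 4D case, which removes the rank branching in the diff above.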
Changed. But I wonder about the performance of these ops.
for channel in [0, 1, 3, 4]:
    with self.subTest(channel=channel):
        test_image = tf.image.decode_image(
            test_image_file, channels=channel, dtype=tf.uint8
Did you mean `dtype=dtype` here and in some other places?
I have tested `dtype` in `test_different_dtypes`, and using it on all test cases would increase the test time.
Thank you for this pull request! I believe we're going to change the API of the reference implementation to make it more user friendly. Here is what I propose:

def cutout(
    images: TensorLike,
    mask_size: TensorLike,
    offset: TensorLike = (0, 0),
    constant_values: Number = 0,
    data_format="channels_last",
) -> tf.Tensor:
    """
    Args:
      images: A 4D tensor, (N, H, W, C).
      mask_size: A tuple (or tensor) with two values, the height and width of the mask (corresponding to 2x pad_size). A single scalar means a square mask (as in Keras conv and pad layers). Can be a tensor of shape (N, 2) to have different mask sizes in the same batch.
      offset: The offset relative to the center of the image. The default, (0, 0), means the mask will be in the middle of the image. Can be a tensor of shape (N, 2) to have different offsets in the same batch.
      constant_values: The values used to fill the mask.
      data_format: The data format.
    """

The rationale behind accepting only 4D tensors is that when users pass 3D tensors, there is a non-negligible probability of user mistakes.

It's totally fine to not implement everything. For example, you can do:

if data_format == "channels_first":
    raise NotImplementedError(
        "Channels first is not yet available for cutout. Contributions welcome!"
    )
if tf.rank(mask_size) != 0:
    raise NotImplementedError(
        "Having non-square masks is not supported yet, contributions welcome."
    )

random_cutout should follow a similar pattern. I believe it should not change your code too much; it's already really close to the end result.

For the tests, please avoid adding images to the git repo; also, we're moving away from …
@fsx950223 feel free to ping me when you're done with the changes.
Thank you for the pull request. We're nearly there; a few changes still need to be made here and there.
images: A tensor of shape
    (num_images, num_rows, num_columns, num_channels) (NHWC),
    (num_images, num_channels, num_rows, num_columns) (NCHW),
    (num_rows, num_columns, num_channels) (HWC), or
    (num_rows, num_columns) (HW).
We'll focus here only on 4D tensors for easier implementation and to avoid users shooting themselves in the foot.
mask_size: Specifies how big the zero mask that will be generated is that
    is applied to the images. The mask will be of size
    (2*pad_size[0] x 2*pad_size[1]).
constant_values: What pixel value to fill in the images in the area that has
    the cutout mask applied to it.
We should specify what shapes are authorized, as it's not clear for someone who didn't read the code.
def _norm_params(images, mask_size, data_format):
    mask_size = tf.convert_to_tensor(mask_size)
    if tf.equal(tf.rank(mask_size), 0):
For readability, I believe you can use:

-    if tf.equal(tf.rank(mask_size), 0):
+    if tf.rank(mask_size) == 0:
    mask_4d, [tf.shape(images)[0], tf.shape(images)[1], 1, 1]
)
images = tf.where(
    tf.equal(mask, 0),
Same here: you can replace all the `tf.equal` calls with `==`.
    shape=[], minval=0, maxval=image_height, dtype=tf.int32, seed=seed
)
cutout_center_width = tf.random.uniform(
    shape=[], minval=0, maxval=image_width, dtype=tf.int32, seed=seed
If I understand correctly, if a batch of images is passed and random cutout is applied, the mask will be at the same place for all the images in the batch, right? Is this something we want? Maybe users would expect (as an augmentation strategy) the mask to be at a different place for each image in the batch.

If the original implementation uses the same mask for all the images in the batch, let's go with this implementation, but the function shouldn't be public. I believe a public `random_cutout` should respect the principle of least astonishment.

What do you think?
Exactly! In a batch, the cutout operation has to be random for each image. I will change them later.
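The behavior agreed on above (an independent mask position per image, rather than one position for the whole batch) can be sketched in NumPy like this; the function name and signature are illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)

def random_cutout(images, mask_size, rng=rng):
    """Zero a mask_size x mask_size square at a random place in EACH image."""
    n, h, w, c = images.shape
    half = mask_size // 2
    out = images.copy()
    # Draw an independent mask center per image, not one for the whole batch.
    centers_y = rng.integers(0, h, size=n)
    centers_x = rng.integers(0, w, size=n)
    for i in range(n):
        top = max(centers_y[i] - half, 0)
        left = max(centers_x[i] - half, 0)
        out[i, top:centers_y[i] + half, left:centers_x[i] + half, :] = 0
    return out

batch = np.ones((4, 16, 16, 1))
out = random_cutout(batch, mask_size=4)
```

Drawing the centers as vectors of length N (instead of scalars, as in the `tf.random.uniform(shape=[], ...)` excerpt above) is the key change that makes the augmentation vary within a batch.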
Thank you again for your work, the implementation looks great! Some comments on the docs and the tests. I'm sorry to always ask more, but I feel this feature is going to be used by many for data augmentation, we need to make it perfect :)
mask_size: Specifies how big the zero mask that will be generated is that
    is applied to the images. The mask will be of size
    (2 * mask_height x 2 * mask_width).
`mask_size` should be the shape of the mask. Here the shape of the true mask is (2 * mask_height x 2 * mask_width), which might confuse users. Could you change that to (mask_height x mask_width)? You can throw an error if the width and height cannot be divided by 2.
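The requested validation is straightforward: accept the full mask shape and reject odd sizes, since the implementation pads `mask_size // 2` on each side of the center. A small NumPy sketch (the helper name is hypothetical):

```python
import numpy as np

def validate_mask_size(mask_size):
    """Check that every mask dimension is even; return the half-sizes."""
    mask_size = np.asarray(mask_size)
    if np.any(mask_size % 2 != 0):
        raise ValueError("mask_size should be divisible by 2")
    return mask_size // 2              # half-size used internally as pad_size

half = validate_mask_size((4, 6))      # ok: half-sizes (2, 3)
```

This keeps the public API in terms of the true mask shape while the internal code continues to work with half-sizes.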
mask_size: Specifies how big the zero mask that will be generated is that
    is applied to the images. The mask will be of size
    (2 * mask_height x 2 * mask_width).
Same as above.
mask_size: Specifies how big the zero mask that will be generated is that
    is applied to the images. The mask will be of size
    (2 * mask_height x 2 * mask_width).
offset: A tuple of (height, width)
Could you specify all the shapes that are possible? A tuple is possible, but so is a 2D tensor of shape (batch_size, 2).
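The two accepted shapes can be normalized to one internal representation, which also makes them easy to document. A NumPy sketch of that normalization (helper name illustrative):

```python
import numpy as np

def norm_offset(offset, batch_size):
    """Accept one (dy, dx) pair or a (batch_size, 2) array of offsets."""
    offset = np.asarray(offset)
    if offset.ndim == 1:               # single pair -> broadcast to the batch
        offset = np.tile(offset, (batch_size, 1))
    if offset.shape != (batch_size, 2):
        raise ValueError("offset must be (2,) or (batch_size, 2)")
    return offset

per_batch = norm_offset((0, 0), batch_size=4)        # same offset everywhere
per_image = norm_offset([[1, 2], [3, 4]], batch_size=2)
```

After this step, the rest of the op only ever sees a (batch_size, 2) array.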
"""Apply cutout (https://arxiv.org/abs/1708.04552) to images. | ||
|
||
This operation applies a (2 * mask_height x 2 * mask_width) mask of zeros to | ||
a random location within `img`. The pixel values filled in will be of the |
-a random location within `img`. The pixel values filled in will be of the
+a location within `img` specified by the offset. The pixel values filled in will be of the
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_with_tf_function():
    test_image = tf.ones([1, 40, 40, 1], dtype=tf.uint8)
    result_image = tf.function(random_cutout)(test_image, 2)
Great way to test the graph mode :)

Here we need to make sure that the for loop with the TensorArray works well. To do that, we need to make sure the loop is possible even if the size of the batch is not known when tracing the graph. Could you use `input_signature` and set the shape of `images` to `[None, 40, 40, 1]`?
np.testing.assert_allclose(tf.shape(result_image), tf.shape(expect_image))


if __name__ == "__main__":
Could you add one more test to ensure that with random cutout, the masks are at different places in the images of the batch? I'll let you decide what's the easier way of doing that.
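One easy way to write the requested test: run random cutout on a batch of identical images and assert that the zeroed regions do not all land in the same place. Sketched below in NumPy with a hypothetical per-image `random_cutout`, since the exact API is still being settled above:

```python
import numpy as np

rng = np.random.default_rng(42)

def random_cutout(images, half):
    """Per-image random cutout: each image gets its own mask center."""
    out = images.copy()
    n, h, w, _ = images.shape
    for i in range(n):
        cy = rng.integers(half, h - half)   # keep the mask fully inside
        cx = rng.integers(half, w - half)
        out[i, cy - half:cy + half, cx - half:cx + half, :] = 0
    return out

batch = np.ones((8, 32, 32, 1))
result = random_cutout(batch, half=2)
# With independent mask positions, at least one image should differ
# from the first; if all eight masks coincided, randomness is broken.
masks_differ = any(
    not np.array_equal(result[0], result[i]) for i in range(1, 8)
)
```

With 8 images and hundreds of possible mask positions per image, the probability of a false failure (all masks coinciding by chance) is negligible, so the test is effectively deterministic.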
@parameterized.named_parameters(
    ("float16", np.float16), ("float32", np.float32), ("uint8", np.uint8)
)
-@parameterized.named_parameters(
-    ("float16", np.float16), ("float32", np.float32), ("uint8", np.uint8)
-)
+@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.uint8])
np.testing.assert_allclose(result_image, expect_image)
 np.testing.assert_allclose(result_image, expect_image)
+assert result_image.dtype == dtype

cc @gabrieldemarmiesse.
with tf.control_dependencies(
    [
        tf.assert_equal(
            tf.reduce_any(mask_size % 2 != 0),
            False,
            "mask_size should be divisible by 2",
        )
    ]
):
`control_dependencies` is quite TF 1.x style; I recommend activating the check only in eager mode. It will also be easier for users to debug.
Thanks for this great pull request!
Thanks for your review!
* add cutout op
* export module
* remove test_utils
* use tf.rank
* remove decorator
* add tf function test
* fix cutout channels test
* add norm param
* change batch random strategy
* fix flake8
* add more checks
* add missing comment
* add seed
* remove control dependencies
Related #1333