Support for `tf.data.Dataset` added in `plot_image_gallery`, removed hard dependency on `rows` and `cols` params #1837
removed hard dependency on having rows and cols params, added support for processing tf.data.Dataset internally Signed-off-by: Suvaditya Mukherjee <[email protected]>
Thanks for the PR!
Let's also update `examples/visualization/plot_image_gallery.py` to show the usage of both paths for this API.
Signed-off-by: Suvaditya Mukherjee <[email protected]>
Hi @ianstenbit, I have addressed the comments. Do take a look!
Thanks for the updates!
A few more comments for you
updated error messages, added example with nparrays, corrected conditions Signed-off-by: Suvaditya Mukherjee <[email protected]>
@ianstenbit Please take a look when you can, thanks!
        return inputs["image"]

    # Calculate appropriate number of rows and columns
    if rows is None or cols is None:
If one of them is set and the other isn't, this behaves a bit weirdly.
If we set `rows=2` and `cols=None`, and our batch size is `9`, we'll end up with a 3x3, which seems counterintuitive.
In this case, we can divide the batch size by whichever argument is not `None`, so that the resulting grid at least honors the value the user supplied.

    if rows is None and cols is not None:
        rows = ceil(batch_size / cols)
    elif rows is not None and cols is None:
        cols = ceil(batch_size / rows)

This should work in theory.
For your example, it would mean we pass `rows=2`, `cols=None`, `batch_size=9` and get `rows=2`, `cols=5`.
The final grid would have 5 images in the first row and 4 images in the second row. Is that fine?
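The inference proposed in this comment can be written as a small standalone helper. This is a minimal sketch, not the code from the PR: the name `infer_grid` is hypothetical, and the near-square fallback used when both arguments are missing is an assumption.

```python
import math


def infer_grid(batch_size, rows=None, cols=None):
    """Fill in whichever of rows/cols is missing (hypothetical helper)."""
    if rows is None and cols is None:
        # Assumption: fall back to a near-square grid.
        cols = math.ceil(math.sqrt(batch_size))
        rows = math.ceil(batch_size / cols)
    elif rows is None:
        # Honor the caller's cols and derive rows from it.
        rows = math.ceil(batch_size / cols)
    elif cols is None:
        # Honor the caller's rows and derive cols from it.
        cols = math.ceil(batch_size / rows)
    return rows, cols


# The case discussed in this thread: rows=2 with a batch of 9 images.
print(infer_grid(9, rows=2))  # (2, 5)
```

With `rows=2` and nine images, the last row is only partly filled (five images, then four), which is the trade-off raised in the question.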
yeah this sounds good to me
        )

    elif rows is not None and cols is not None:
        if isinstance(images, tf.data.Dataset):
If rows and cols are both specified, why do we need to compute batch size? I'm not sure what the purpose of this branch is
This is done to take out the required number of images from the dataset
I see. It seems like it would be a lot easier to just use `tfds.as_numpy(images.take(rows * cols))`.
(We probably have to optionally unbatch first.)
Overall, I think the logic would be much simpler if we separated the logic for converting a tf.data.Dataset into a numpy array from the logic for computing batch size / rows / cols. As-is there's some unnecessary indirection so the resulting code is more complex than it needs to be.
One simple approach would be to convert all possible input types to an unbatched iterator at the top of the method, and record the batch size if there was one.
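The "unbatched iterator plus recorded batch size" idea can be illustrated without TensorFlow. The sketch below is a plain-Python stand-in, not the PR's code: `to_unbatched` and its `batched` flag are hypothetical, and nested lists play the role of a batched dataset. For a real `tf.data.Dataset`, `unbatch()` and `take()` would fill the same roles.

```python
import itertools


def to_unbatched(images, batched):
    """Return (iterator over single images, batch size or None).

    Hypothetical helper: `images` is any iterable, and `batched` says
    whether each element is itself a batch (a sequence of images).
    """
    if not batched:
        return iter(images), None
    batches = iter(images)
    first = next(batches, None)
    if first is None:
        # Empty input: nothing to iterate over and no batch size to record.
        return iter(()), None
    batch_size = len(first)
    # Put the peeked batch back in front, then flatten batch by batch.
    unbatched = itertools.chain.from_iterable(itertools.chain([first], batches))
    return unbatched, batch_size


batches = [["a", "b", "c"], ["d", "e", "f"], ["g", "h", "i"]]
unbatched, batch_size = to_unbatched(batches, batched=True)
# Taking rows * cols images is now independent of how the input was batched.
print(batch_size, list(itertools.islice(unbatched, 2 * 2)))  # 3 ['a', 'b', 'c', 'd']
```

Doing the conversion once at the top means the later rows/cols logic only ever sees a flat stream of images.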
`tfds` is a special case. Most folks may also be creating `tf.data.Dataset` instances from a simple `image_dataset_from_directory()` call or from `from_tensor_slices()` too. Using `tfds.as_numpy` might not be the best choice for that reason.
        scale: how large to scale the images in the gallery
        rows: (Optional) number of rows in the gallery to show. Required if inputs are unbatched.
After chatting with the team a bit, and thinking about this code some more, I think we should keep rows/cols required and get rid of the logic to try to infer how many images should be shown.
It seems reasonable to ask users for a number of rows/cols, and it makes the code nice and overt on the caller's side.
It also dramatically simplifies how this can be implemented, because we don't need to try to infer batch sizes for different data types, and to add support for `tf.data.Dataset`s, we can just use `tfds.as_numpy(dataset.take(rows*cols))` (after optionally unbatching the dataset) 😄
Sorry for the churn on this, and thanks for the awesome work!
I'd prefer an automatic rows and cols, mostly to remove many useless lines of code from our examples.
This can be implemented in just a few lines, even with a fancy aspect ratio adaptation:

    N = len(images)
    sqrtN = int(math.ceil(math.sqrt(N)))
    aspect = sum([im.shape[1] / im.shape[0] for im in images]) / N  # mean aspect ratio of images
    fig = plt.figure(figsize=(15, 15 / aspect), frameon=False)

(I've been using this code in all my samples from the past couple of years)
And it's fine to require either both rows and cols or neither and not handle the fancy cases in between.
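The automatic layout described in this comment can be separated from the plotting call so that the arithmetic is easy to check. This is a sketch under stated assumptions: `auto_layout` is a hypothetical name, image shapes are given as `(height, width, ...)` tuples rather than arrays, and deriving `rows` from `cols` (instead of using the square root for both) is a choice made here, not something the comment specifies.

```python
import math


def auto_layout(shapes, fig_width=15.0):
    """Return (rows, cols, figsize) for a gallery (hypothetical helper).

    `shapes` is a list of (height, width, ...) tuples, one per image.
    """
    n = len(shapes)
    cols = math.ceil(math.sqrt(n))
    # Assumption: use only as many rows as the images actually need.
    rows = math.ceil(n / cols)
    # Mean width/height ratio of the images, as in the snippet above.
    aspect = sum(shape[1] / shape[0] for shape in shapes) / n
    return rows, cols, (fig_width, fig_width / aspect)


# Nine landscape images, each 100 pixels high and 200 pixels wide.
print(auto_layout([(100, 200, 3)] * 9))  # (3, 3, (15.0, 7.5))
```

The returned `figsize` could then be passed to `plt.figure`, as the original snippet does.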
Okay this sgtm then -- @suvadityamuk I think this PR is pretty close then
See Martin's comment about rows/cols specification.
I think the code here would be a lot cleaner if we structured it like this:

    def visualize(....):
        image_batch = _extract_image_batch(images, rows, cols)  # Throws if input is invalid
        rows = rows or ...
        cols = cols or ...
        # actually do the plotting

Basically the current logic of intermixing the conditionals of input data type and whether rows/cols are specified gets a bit hairy and we should break out the logic to extract an image batch.
Using some guard clauses would make the code a lot easier to understand 🙂
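One way the suggested structure might look, with guard clauses in the extraction helper and the plotting left out. This is a sketch rather than the refactored PR code: `plan_gallery` is a hypothetical stand-in for the plotting function, inputs are assumed to be finite plain-Python iterables, and the exact validation rules (both-or-neither, enough images) are assumptions based on the discussion.

```python
import itertools
import math


def _extract_image_batch(images, rows, cols):
    """Return a list of images to plot, raising if the input is invalid."""
    # Guard clause: require both rows and cols, or neither.
    if (rows is None) != (cols is None):
        raise ValueError("Specify both `rows` and `cols`, or neither.")
    if rows is not None:
        batch = list(itertools.islice(images, rows * cols))
        # Guard clause (assumption): the grid must be completely filled.
        if len(batch) < rows * cols:
            raise ValueError("Not enough images to fill the requested grid.")
        return batch
    batch = list(images)  # Assumption: the input is finite.
    if not batch:
        raise ValueError("No images to plot.")
    return batch


def plan_gallery(images, rows=None, cols=None):
    """Resolve the batch and grid size; the actual plotting is omitted."""
    image_batch = _extract_image_batch(images, rows, cols)
    cols = cols or math.ceil(math.sqrt(len(image_batch)))
    rows = rows or math.ceil(len(image_batch) / cols)
    return image_batch, rows, cols


print(plan_gallery(range(10), rows=2, cols=3))  # ([0, 1, 2, 3, 4, 5], 2, 3)
```

Because every invalid combination is rejected up front, the grid logic after the helper call no longer needs to branch on the input type.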
            "Passed `tf.data.Dataset` does not appear to be batched. Please batch using the `.batch().`"
        )

    images = images.map(unpack_images)
I think we can drop this since sample["image"] is already what we want
        images = images[:batch_size, ...]
    else:
        raise ValueError(
            "plot_image_gallery() expects `tf.data.Dataset` to be batched if rows or cols are not specified."
This message still seems incorrect -- this branch will get executed even if one of `rows` or `cols` is set -- and in this case I think we can probably just infer the other. Maybe the `and` on L120 was meant to be `or`?
            "plot_image_gallery() expects `tf.data.Dataset` to be batched if rows or cols are not specified."
        )

    rows = int(math.ceil(batch_size**0.5))
This overrides `rows` / `cols` even if they were specified. Let's use `rows = rows or ...` instead.
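The difference between the flagged line and the suggested `rows = rows or ...` form shows up as soon as a caller passes a value. Both functions below are illustrative only: their names are hypothetical, and the `cols` formula is an assumption, since the diff excerpt shows only the `rows` line.

```python
import math


def grid_overriding(batch_size, rows=None, cols=None):
    # Behavior flagged in review: caller-supplied values are discarded.
    rows = int(math.ceil(batch_size**0.5))
    cols = int(math.ceil(batch_size / rows))
    return rows, cols


def grid_with_fallback(batch_size, rows=None, cols=None):
    # Suggested form: compute a default only when the argument is unset.
    rows = rows or int(math.ceil(batch_size**0.5))
    cols = cols or int(math.ceil(batch_size / rows))
    return rows, cols


print(grid_overriding(9, rows=2))     # (3, 3): the caller's rows=2 is lost
print(grid_with_fallback(9, rows=2))  # (2, 5): the caller's rows=2 is kept
```

Note that `or` also treats `0` as unset; an explicit `is None` check would be needed if zero had to be distinguished from a missing value.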
addressed most comments, have rewritten the code from scratch to find a new and cleaner method to work with the images Signed-off-by: Suvaditya Mukherjee <[email protected]>
I have performed a ton of refactoring to try and reduce the complexity of the code 😅.
Beautiful -- I love the new structure.
Thank you for the PR, and for sticking with it despite the churn on my end 😄
Aha, no worries at all, happy to help! 😅
Signed-off-by: Suvaditya Mukherjee <[email protected]>
Just committed a linted fix!
Looks like the linter is still unhappy @suvadityamuk
Signed-off-by: Suvaditya Mukherjee <[email protected]>
Aha, the example was not linted, it seems. This new commit should do the trick, hopefully! 🤞
…n file Signed-off-by: Suvaditya Mukherjee <[email protected]>
Seems like it needed that one more final "push", haha. Please check it out now, have run
…hard dependency on `rows` and `cols` params (keras-team#1837)
* feat: add support for tf.data.dataset
  removed hard dependency on having rows and cols params, added support for processing tf.data.Dataset internally
* chore: removed unused numpy import
* chore: added `tf.data.Dataset` to docstring
* chore: addressed comments and modified example file
* chore: addressed newer comments
  updated error messages, added example with nparrays, corrected conditions
* fix: refactored code
  addressed most comments, have rewritten the code from scratch to find a new and cleaner method to work with the images
* chore: fix with isort
* chore: fix linting on example
* chore: linting issues with black and isort solved for example and main file
---------
Signed-off-by: Suvaditya Mukherjee <[email protected]>
What does this PR do?
This PR adds support for passing `tf.data.Dataset` instances directly to `keras_cv.visualization.plot_image_gallery`, so the user no longer has to convert the dataset manually before plotting. It still supports NumPy arrays and Tensors as it used to.
It also handles a possible future deprecation of `plt.tight_layout()` by making use of `plt.subplot(..., layout='tight')`, as per this warning in Matplotlib.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
cc: @ianstenbit (as discussed previously) @jbischof @tanzhenyu