Skip to content

Support for tf.data.Dataset added in plot_image_gallery, removed hard dependency on rows and cols params #1837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 26, 2023

Conversation

suvadityamuk
Copy link
Contributor

@suvadityamuk suvadityamuk commented May 30, 2023

What does this PR do?

This PR is meant to add support for using tf.data.Dataset instances directly with keras_cv.visualization.plot_image_gallery instead of the user having to manually convert the dataset and then perform processing. It still supports NumPy arrays or Tensors as it used to.

Also, handled a possible future deprecation of removing plt.tight_layout() by making use of plt.subplot(...,layout='tight') as per this warning in Matplotlib

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue? Please add a link
    to it if that's the case.
  • Did you write any new necessary tests?
  • If this adds a new model, can you run a few training steps on TPU in Colab to ensure that no XLA incompatible OP are used?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc: @ianstenbit (as discussed previously) @jbischof @tanzhenyu

removed hard dependency on having rows and cols params, added support for processing tf.data.Dataset internally

Signed-off-by: Suvaditya Mukherjee <[email protected]>
Signed-off-by: Suvaditya Mukherjee <[email protected]>
Copy link
Contributor

@ianstenbit ianstenbit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

Let's also update examples/visualization/plot_image_gallery.py to show the usage of both paths for this API

@suvadityamuk suvadityamuk requested a review from ianstenbit May 30, 2023 19:32
@suvadityamuk
Copy link
Contributor Author

suvadityamuk commented May 30, 2023

Hi @ianstenbit,

Have addressed the comments. Do take a look!

Copy link
Contributor

@ianstenbit ianstenbit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates!

A few more comments for you

updated error messages, added example with nparrays, corrected conditions

Signed-off-by: Suvaditya Mukherjee <[email protected]>
@suvadityamuk
Copy link
Contributor Author

@ianstenbit Please take a look when you can, thanks!

return inputs["image"]

# Calculate appropriate number of rows and columns
if rows is None or cols is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If one of them is set and the other isn't, this behaves a bit weirdly.

If we set rows=2 and cols=None, and our batch size is 9, we'll end up with a 3x3, which seems counterintuitive.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So in this case, we can do something like divide the not-None argument such that the batch will create something which at least fixes the user's args.

if rows is None and cols is not None:
    rows = ceil(batch_size / cols)
elif rows is not None and cols is None:
    cols = ceil(batch_size / rows)

This should work in theory.
For your example, it would mean we do rows=2, columns=None, batch_size=9 and get rows=2, cols=5

The final grid would have 5 images in the first row and 4 images in the second row. Is that fine?

Copy link
Contributor

@ianstenbit ianstenbit Jun 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this sounds good to me

)

elif rows is not None and cols is not None:
if isinstance(images, tf.data.Dataset):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If rows and cols are both specified, why do we need to compute batch size? I'm not sure what the purpose of this branch is

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done to take out the required number of images from the dataset

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. It seems like it would be a lot easier to just use tfds.as_numpy(images.take(rows*cols)

(Probably have to optionally unbatch first)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, I think the logic would be much simpler if we separated the logic for converting a tf.data.Dataset into a numpy array from the logic for computing batch size / rows / cols. As-is there's some unnecessary indirection so the resulting code is more complex than it needs to be.

One simple approach would be to convert all possible input types to an unbatched iterator at the top of the method, and record the batch size if there was one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tfds is a special case. Most folks may also be invoking and creating tf.data.Dataset instances from a simple image_dataset_from_directory() call or from_tensor_slices() too. Using tfds.as_numpy might not be the best choice for that reason.

scale: how large to scale the images in the gallery
rows: (Optional) number of rows in the gallery to show. Required if inputs are unbatched.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After chatting with the team a bit, and thinking about this code some more, I think we should keep rows/cols required and get rid of the logic to try to infer how many images should be shown.

It seems reasonable to ask users for a number of rows/cols, and it makes the code nice and overt on the caller's side.

It also dramatically simplifies how this can be implemented, because we don't need to try to infer batch sizes for different data types, and to add support for tf.data.Datasets, we can just use tfds.as_numpy(dataset.take(rows*cols)) (after optionally unbatching the dataset 😄

Sorry for the churn on this, and thanks for the awesome work!

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer an automatic rows and cols, mostly to remove many useless lines of code from our examples.
This can be implemented in just a few lines, even with a fancy aspect ratio adaptation:

N = len(images)
sqrtN = int(math.ceil(math.sqrt(N)))
aspect = sum([im.shape[1]/im.shape[0] for im in images]) / N # mean aspect ratio of images
fig = plt.figure(figsize=(15,15/aspect), frameon=False)

(I've been using this code in all my samples from the past couple of years)
And it's fine to require either both rows and cols or neither and not handle the fancy cases in between.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay this sgtm then -- @suvadityamuk I think this PR is pretty close then

Copy link
Contributor

@ianstenbit ianstenbit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See Martin's comment about rows/cols specification.

I think the code here would be a lot cleaner if we structured it like this:

def visualize(....):
    image_batch = _extract_image_batch(images, rows, cols) # Throws if input is invalid
    rows = rows or ...
    cols = cols or ...
    # actually do the plotting

Basically the current logic of intermixing the conditionals of input data type and whether rows/cols are specified gets a bit hairy and we should break out the logic to extract an image batch.

Using some guard clauses would make the code a lot easier to understand 🙂

scale: how large to scale the images in the gallery
rows: (Optional) number of rows in the gallery to show. Required if inputs are unbatched.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay this sgtm then -- @suvadityamuk I think this PR is pretty close then

"Passed `tf.data.Dataset` does not appear to be batched. Please batch using the `.batch().`"
)

images = images.map(unpack_images)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can drop this since sample["image"] is already what we want

images = images[:batch_size, ...]
else:
raise ValueError(
"plot_image_gallery() expects `tf.data.Dataset` to be batched if rows or cols are not specified."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This message still seems incorrect -- this branch will get executed even if one of rows or cols is set -- and in this case I think we can probably just infer the other. Maybe the and on L120 was meant to be or?

"plot_image_gallery() expects `tf.data.Dataset` to be batched if rows or cols are not specified."
)

rows = int(math.ceil(batch_size**0.5))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this overrides rows / cols even if they were specified. Let's use rows = rows or ... instead

addressed most comments, have rewritten the code from scratch to find a new and cleaner method to work with the images

Signed-off-by: Suvaditya Mukherjee <[email protected]>
@suvadityamuk
Copy link
Contributor Author

I have performed a ton of refactoring around the code to try and reduce the complexity of the code 😅.
Let me know if this is a better method to work with. This code handles more or less every case I could try, except for passing np.array in an unbatched format. Most other possibilities seem to be accounted for.

Copy link
Contributor

@ianstenbit ianstenbit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful -- I love the new structure.

Thank you for the PR, and for sticking with it despite the churn on my end 😄

@suvadityamuk
Copy link
Contributor Author

Aha, no worries at all, happy to help! 😅

Signed-off-by: Suvaditya Mukherjee <[email protected]>
@suvadityamuk
Copy link
Contributor Author

Just committed a linted fix!

@ianstenbit
Copy link
Contributor

Details

Looks like the linter is still unhappy @suvadityamuk

Signed-off-by: Suvaditya Mukherjee <[email protected]>
@suvadityamuk
Copy link
Contributor Author

Aha, the example was not linted, it seems. This new commit should do the trick, hopefully! 🤞

@suvadityamuk
Copy link
Contributor Author

suvadityamuk commented Jun 26, 2023

Seems like it needed that one more final "push", haha. Please check it out now, have run bash shell/format.sh, bash shell/lint.sh and fixed all the problems

@ianstenbit ianstenbit merged commit 1e08996 into keras-team:master Jun 26, 2023
ghost pushed a commit to y-vectorfield/keras-cv that referenced this pull request Nov 16, 2023
…hard dependency on `rows` and `cols` params (keras-team#1837)

* feat: add support for tf.data.dataset

removed hard dependency on having rows and cols params, added support for processing tf.data.Dataset internally

Signed-off-by: Suvaditya Mukherjee <[email protected]>

* chore: removed unused numpy import

Signed-off-by: Suvaditya Mukherjee <[email protected]>

* chore: added `tf.data.Dataset` to docstring

Signed-off-by: Suvaditya Mukherjee <[email protected]>

* chore: addressed comments and modified example file

Signed-off-by: Suvaditya Mukherjee <[email protected]>

* chore: addressed newer comments

updated error messages, added example with nparrays, corrected conditions

Signed-off-by: Suvaditya Mukherjee <[email protected]>

* fix: refactored code

addressed most comments, have rewritten the code from scratch to find a new and cleaner method to work with the images

Signed-off-by: Suvaditya Mukherjee <[email protected]>

* chore: fix with isort

Signed-off-by: Suvaditya Mukherjee <[email protected]>

* chore: fix linting on example

Signed-off-by: Suvaditya Mukherjee <[email protected]>

* chore: linting issues with black and isort solved for example and main file

Signed-off-by: Suvaditya Mukherjee <[email protected]>

---------

Signed-off-by: Suvaditya Mukherjee <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants