Optim-wip: Consolidate CenterCrop & improve ToRGB #573
a8301ac to 7aec698
Thank you, @ProGamerGov! I'll take a look!
Thank you for working on this PR, @ProGamerGov! I left a couple of comments.
@NarineK I've made all the suggested changes, and it should be ready for merging now if there are no other issues!
Thank you for addressing the comments, @ProGamerGov! Posting a couple more comments, primarily related to type hints.
captum/optim/_utils/typing.py
TransformValList = Union[Sequence[int], Sequence[float], Tensor]
TransformVal = Union[int, float, Tensor]
TransformSize = Union[List[int], Tuple[int], int]
SquashFuncType = Callable[[Tensor], Tensor]
Does [Tensor] mean a list of tensors?
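No, it's just how lambda functions (and callables in general) are type hinted: Callable[[<in type>], <out type>]. The inner brackets list the argument types, so [Tensor] here means a single Tensor argument.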
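To make the Callable syntax above concrete, here is a minimal sketch. The `Tensor` alias is a stand-in for `torch.Tensor` so the example runs without PyTorch, and `apply_squash` is a hypothetical helper, not part of Captum.

```python
from typing import Callable, List

# Stand-in for torch.Tensor so this sketch runs without PyTorch (assumption).
Tensor = List[float]

# The inner list holds the argument types: one Tensor in, one Tensor out.
SquashFuncType = Callable[[Tensor], Tensor]

# For contrast, a callable that takes a *list of tensors* would be written:
ListSquashFuncType = Callable[[List[Tensor]], Tensor]


def apply_squash(x: Tensor, squash_func: SquashFuncType) -> Tensor:
    """Hypothetical helper: apply a user-supplied squash function to x."""
    return squash_func(x)


# A lambda satisfies SquashFuncType because it maps a Tensor to a Tensor.
halved = apply_squash([2.0, 4.0], lambda t: [v / 2 for v in t])
print(halved)
```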
" torchvision.transforms.RandomRotation(degrees=(-5,5)),\r\n",
" optimviz.transform.RandomSpatialJitter(8),\r\n",
" optimviz.transform.CenterCrop(16),\r\n",
" optimviz.transform.CenterCrop((224,224)),\r\n",
Question: would this be equivalent to passing optimviz.transform.CenterCrop(16, pixels_from_edges=True)?
For removing the padding from torch.nn.ReflectionPad2d(16), the equivalent to optimviz.transform.CenterCrop(224, False) is optimviz.transform.CenterCrop(32, True) (16 * 2 = 32).
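The arithmetic above can be checked with a small sketch. `center_crop_bounds` is a hypothetical one-dimensional helper written for illustration; it assumes that with pixels_from_edges=True the size is the total number of pixels removed along a dimension (split across both sides), and it is not the Captum implementation.

```python
from typing import Tuple


def center_crop_bounds(
    length: int, size: int, pixels_from_edges: bool = False
) -> Tuple[int, int]:
    """Hypothetical helper: return (start, end) of a centered crop along one axis.

    Assumption: with pixels_from_edges=True, `size` is the total number of
    pixels removed along the axis; otherwise `size` is the output length.
    """
    out_length = length - size if pixels_from_edges else size
    start = length // 2 - out_length // 2
    return start, start + out_length


# A 224-pixel side padded by ReflectionPad2d(16) becomes 224 + 2 * 16 = 256.
padded = 224 + 2 * 16

by_size = center_crop_bounds(padded, 224, pixels_from_edges=False)
by_edges = center_crop_bounds(padded, 32, pixels_from_edges=True)
# Under the same assumption, a size of 16 removes only 8 pixels per side.
size_16 = center_crop_bounds(padded, 16, pixels_from_edges=True)

print(by_size, by_edges, size_16)
```

Under these assumptions the first two calls select the same 224-pixel window, while a size of 16 leaves 240 pixels, which is the discrepancy raised in the next comment.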
I see. @ProGamerGov, was the original line optimviz.transform.CenterCrop(16) incorrect? It looks like the original cropped 8 from each side.
@NarineK Yeah, it was! I mistakenly put 16 as I forgot about the division by 2.
I can change it so that 16 is the correct user input if you think it would be better.
Thank you for working on this PR, @ProGamerGov! Merging ...
I merged the two center crop algorithms into a single function and class.
ToRGB can now accept user matrices, and it can also be disabled in NaturalImage.