-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Similarity learning reference code #1101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## master #1101 +/- ##
==========================================
+ Coverage 64.57% 64.95% +0.38%
==========================================
Files 68 68
Lines 5411 5413 +2
Branches 831 835 +4
==========================================
+ Hits 3494 3516 +22
+ Misses 1665 1641 -24
- Partials 252 256 +4
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the PR!
I have made a few comments.
Also, is FashionMNIST a dataset that is generally used for embedding learning, or you just took it for an example? I really think that we should focus on some more realistic dataset, but we can change that in the future.
Also, it might be good adding a basic README explaining what the reference script is meant to do, so that users know what they are looking for. This is something that is missing for the other reference scripts, but I should add those in the future.
references/embedding/train.py
Outdated
train_dataset = FashionMNIST(args.train_data, train=True, transform=transform, download=True) | ||
test_dataset = FashionMNIST(args.test_data, train=False, transform=transform, download=True) | ||
|
||
targets = train_dataset.targets.tolist() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A comment here would be helpful, as this is generally something that the user will need to change if they change the dataset
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added comments mentioning that any classification dataset should be fine here as long as targets
is constructed as described. Is that sufficient?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment saying that it should be any classification dataset is misleading, because not all datasets have the .targets
attribute, even if they are classification datasets. Maybe just check that the dataset has a targets
attribute, and raise a nice error message if not?
@fmassa I have made all the requested changes, other than the helper method change (I'm not sure which section you wanted to make a method). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking very good, thanks!
I think some of the parts of this code (like the samplers) will be great to be moved to torchvision in the future, once we figure out where to put it. But for that, tests would be necessary.
I've a couple more comments, but I think this is almost ready to merge, thanks!
references/embedding/train.py
Outdated
train_dataset = FashionMNIST(args.train_data, train=True, transform=transform, download=True) | ||
test_dataset = FashionMNIST(args.test_data, train=False, transform=transform, download=True) | ||
|
||
targets = train_dataset.targets.tolist() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment saying that it should be any classification dataset is misleading, because not all datasets have the .targets
attribute, even if they are classification datasets. Maybe just check that the dataset has a targets
attribute, and raise a nice error message if not?
references/embedding/sampler.py
Outdated
self.k = k | ||
self.groups = create_groups(groups, self.k) | ||
|
||
def __iter__(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is ok as is because we are not yet adding samplers to the main library, but once we move it to the torchvision
package, it would be good to have tests for it.
If you could write a basic test now checking the behavior (with dummy data), it would make it much easier for moving this to torchvision core later on.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fmassa I have written the test cases for the sampler and refactored the accuracy section into another method.
I don't think we should assume that a targets attribute exists. Instead I feel like I can explain the targets data structure better, so that users can construct targets accordingly (or use the targets attribute if it exists). I have changed the comments slightly, explaining what is expected of the targets variable. Will this be okay instead?
@fmassa I have made the changes. Instead of expecting a targets attribute from the dataset, I elaborated further on the semantics and requirements from the targets data structure, which users can build during or after they have initialized their dataset. Will that be ok? |
Renamed embedding to similarity to be more consistent with existing literature. Both are used, but similarity is more common. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks a lot!
@fmassa PR for embedding learning reference code as discussed in #1042. I decided not to create a VGGFace2 dataset for now, since that would require more thought and planning. For now I'm using FMNIST.