Similarity learning reference code #1101

Merged: 19 commits, Jul 17, 2019

Conversation

dakshjotwani
Contributor

@fmassa PR for embedding learning reference code as discussed in #1042. I decided not to create a VGGFace2 dataset for now, since that would require more thought and planning. For now I'm using FMNIST.

@codecov-io

codecov-io commented Jul 8, 2019

Codecov Report

Merging #1101 into master will increase coverage by 0.38%.
The diff coverage is n/a.

@@            Coverage Diff             @@
##           master    #1101      +/-   ##
==========================================
+ Coverage   64.57%   64.95%   +0.38%     
==========================================
  Files          68       68              
  Lines        5411     5413       +2     
  Branches      831      835       +4     
==========================================
+ Hits         3494     3516      +22     
+ Misses       1665     1641      -24     
- Partials      252      256       +4
| Impacted Files | Coverage Δ |
|---|---|
| torchvision/models/detection/roi_heads.py | 55.93% <0%> (-0.97%) ⬇️ |
| torchvision/ops/boxes.py | 94.73% <0%> (ø) ⬆️ |
| torchvision/transforms/transforms.py | 81.53% <0%> (+0.98%) ⬆️ |
| torchvision/datasets/fakedata.py | 26.92% <0%> (+3.58%) ⬆️ |
| torchvision/datasets/svhn.py | 67.3% <0%> (+32.69%) ⬆️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Member

@fmassa fmassa left a comment

Thanks a lot for the PR!

I have made a few comments.

Also, is FashionMNIST a dataset that is generally used for embedding learning, or did you just pick it as an example? I really think that we should focus on a more realistic dataset, but we can change that in the future.

It might also be good to add a basic README explaining what the reference script is meant to do, so that users know what they are looking at. This is missing for the other reference scripts too, and I should add those in the future.

train_dataset = FashionMNIST(args.train_data, train=True, transform=transform, download=True)
test_dataset = FashionMNIST(args.test_data, train=False, transform=transform, download=True)

targets = train_dataset.targets.tolist()
Member

A comment here would be helpful, as this is generally something that users will need to change if they change the dataset.

Contributor Author

I added comments mentioning that any classification dataset should be fine here as long as targets is constructed as described. Is that sufficient?

Member

The comment saying that it should be any classification dataset is misleading, because not all datasets have the .targets attribute, even if they are classification datasets. Maybe just check that the dataset has a targets attribute, and raise a nice error message if not?
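
The check suggested above could look roughly like the following. This is only a sketch: `get_targets` is a hypothetical helper name that does not come from the PR, and the error wording is an assumption. It relies on nothing beyond the dataset optionally exposing a `targets` attribute (a tensor, array, or list).

```python
def get_targets(dataset):
    """Return the dataset's labels as a plain list, or fail with a clear message.

    Hypothetical helper for illustration; not part of the PR or of torchvision.
    """
    targets = getattr(dataset, "targets", None)
    if targets is None:
        raise AttributeError(
            f"{type(dataset).__name__} has no 'targets' attribute; "
            "build a list with one integer label per sample instead."
        )
    # Tensors and NumPy arrays expose tolist(); plain sequences are copied.
    return targets.tolist() if hasattr(targets, "tolist") else list(targets)
```

With FashionMNIST this would reduce to the same `train_dataset.targets.tolist()` call shown in the diff, while a dataset without the attribute would fail early with a readable message.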

@dakshjotwani
Contributor Author

@fmassa I have made all the requested changes, other than the helper method change (I'm not sure which section you wanted to make a method).

Member

@fmassa fmassa left a comment

This is looking very good, thanks!

I think some parts of this code (like the samplers) would be great to move into torchvision in the future, once we figure out where to put them. But for that, tests would be necessary.

I have a couple more comments, but I think this is almost ready to merge, thanks!

self.k = k
self.groups = create_groups(groups, self.k)

def __iter__(self):
Member

This is ok as is because we are not yet adding samplers to the main library, but once we move it to the torchvision package, it would be good to have tests for it.

If you could write a basic test now checking the behavior (with dummy data), it would make it much easier for moving this to torchvision core later on.
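
A basic behavioral test with dummy data might be sketched as below. The sampler in the PR is not shown in this thread, so `PKSamplerSketch` and `group_indices_by_label` are hypothetical stand-ins written only for illustration. The sketch assumes a sampler that draws `p` labels per batch and `k` samples per label, which may differ from what the PR actually implements, and it uses only the standard library so that it runs without PyTorch.

```python
import random
from collections import defaultdict


def group_indices_by_label(targets, k):
    """Map each label to its sample indices, keeping labels with at least k samples."""
    groups = defaultdict(list)
    for index, label in enumerate(targets):
        groups[label].append(index)
    return {label: idxs for label, idxs in groups.items() if len(idxs) >= k}


class PKSamplerSketch:
    """Hypothetical sampler: each batch holds p distinct labels, k samples of each."""

    def __init__(self, targets, p, k, num_batches, seed=0):
        self.p, self.k, self.num_batches = p, k, num_batches
        self.groups = group_indices_by_label(targets, k)
        if len(self.groups) < p:
            raise ValueError("need at least p labels with k or more samples each")
        self.rng = random.Random(seed)

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            # Pick p labels, then k distinct sample indices for each label.
            for label in self.rng.sample(sorted(self.groups), self.p):
                batch.extend(self.rng.sample(self.groups[label], self.k))
            yield batch


def test_sampler_batches():
    # Dummy data: 4 labels, 5 samples each.
    targets = [label for label in range(4) for _ in range(5)]
    p, k = 2, 3
    for batch in PKSamplerSketch(targets, p, k, num_batches=10):
        labels = [targets[i] for i in batch]
        assert len(batch) == p * k
        assert len(set(batch)) == p * k  # no index repeats within a batch
        assert len(set(labels)) == p  # exactly p distinct labels
        assert all(labels.count(label) == k for label in set(labels))


test_sampler_batches()
```

Checking batch size, label count, and per-label sample count on synthetic labels keeps the test independent of any real dataset, which is what would make it easy to carry over if the sampler moves into the main package.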

Contributor Author

@dakshjotwani dakshjotwani Jul 10, 2019

@fmassa I have written the test cases for the sampler and refactored the accuracy section into another method.

I don't think we should assume that a targets attribute exists. Instead I feel like I can explain the targets data structure better, so that users can construct targets accordingly (or use the targets attribute if it exists). I have changed the comments slightly, explaining what is expected of the targets variable. Will this be okay instead?

@dakshjotwani
Contributor Author

@fmassa I have made the changes. Instead of expecting a targets attribute from the dataset, I elaborated further on the semantics and requirements of the targets data structure, which users can build while or after initializing their dataset. Will that be OK?
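
For a dataset that has no targets attribute, building the structure by hand could look something like this. It is a sketch under assumptions: the thread only shows `targets` being taken from `train_dataset.targets.tolist()`, so the layout assumed here is a list in which entry `i` is the integer label of sample `i`. `build_targets` is a hypothetical name, and the dataset is assumed to yield `(sample, label)` pairs.

```python
def build_targets(dataset):
    """Collect one integer label per sample by iterating over (sample, label) pairs.

    Hypothetical helper for illustration; assumes targets[i] is the label of sample i.
    """
    targets = []
    for _, label in dataset:
        targets.append(int(label))
    return targets
```

Note that iterating this way loads and transforms every sample once, so for large datasets it may be preferable to read labels from the dataset's annotation files directly.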

@dakshjotwani dakshjotwani changed the title from "Embedding learning reference code" to "Similarity learning reference code" on Jul 16, 2019
@dakshjotwani
Contributor Author

Renamed embedding to similarity to be more consistent with existing literature. Both are used, but similarity is more common.

Member

@fmassa fmassa left a comment

LGTM, thanks a lot!

@fmassa fmassa merged commit bbd363c into pytorch:master Jul 17, 2019