-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Update Reference scripts to support the prototype models #4837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
💊 CI failures summary and remediationsAs of commit 55ddb93 (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Please report bugs/suggestions to the (internal) Dr. CI Users group. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some clarifications below:
else: | ||
fn = PM.segmentation.__dict__[args.model] | ||
weights = PM._api.get_weight(fn, args.weights) | ||
return weights.transforms() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we are in train mode, we always initialize the SegmentationPresetTrain. For validation if the weights are not defined (aka not a prototype model) then use the old preprocessing method for evaluation. Else use the one attached to the weights.
model = torchvision.models.segmentation.__dict__[args.model]( | ||
num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained | ||
) | ||
if not args.weights: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the weights are not defined, we use the standard way. Else it's a prototype run which means we will use the prototype model mechanism.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving to unblock
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) | ||
else: | ||
fn = PM.video.__dict__[args.model] | ||
weights = PM._api.get_weight(fn, args.weights) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: using a private API here. We probably don't want to advertise private APIs in the references
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -149,7 +155,12 @@ def main(args): | |||
print("Loading validation data") | |||
cache_path = _get_cache_path(valdir) | |||
|
|||
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) | |||
if not args.weights: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: pretrained
and weights
are overlapping and can be confusing. This ideally should be cleaned up in the future
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's exactly the plan. --pretrained
will go away and --weights
is going to be the right parameter. Right now we support both temporarily so that we can switch between the two completely different APIs. The --weights
acts as a feature switch here.
if not args.weights: | ||
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) | ||
else: | ||
fn = PM.video.__dict__[args.model] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are we providing some sort of registration API to get the models without having to resort to __dict__
manipulations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeap, that's the plan. There will be a proper registration mechanism, possibly something similar to what was discussed here. There are still pending discussions with other domains, so I didn't want to adopt something before those discussions take place.
* Adding prototype preprocessing on segmentation references. * Adding prototype preprocessing on video references.
Fixes #4671
This PR adds a similar mechanism as in
classification
forsegmentation
andvideo
. The target is to enable us to test the new model weights API (+ presets) and confirm it returns the same results as the old one. The co-existence of--pretrained
and--weights
is temporary and allows us to test that all models we introduce produce the expected results.The approach is not perfect as it exposes the
prototype
stuff in the example reference scripts but the alternative would be to duplicate the reference scripts or keep a separate branch with their modifications which makes the work cumbersome. These will be cleaned up prior to adopting the new API, see #4652 and #4679.cc @datumbox @bjuncek