Skip to content

Update Reference scripts to support the prototype models #4837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
from torch import nn


try:
from torchvision.prototype import models as PM
except ImportError:
PM = None


def get_dataset(dir_path, name, image_set, transform):
def sbd(*args, **kwargs):
return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)
Expand All @@ -26,11 +32,15 @@ def sbd(*args, **kwargs):
return ds, num_classes


def get_transform(train):
base_size = 520
crop_size = 480

return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size)
def get_transform(train, args):
if train:
return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
elif not args.weights:
return presets.SegmentationPresetEval(base_size=520)
else:
fn = PM.segmentation.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
return weights.transforms()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are in train mode, we always initialize the SegmentationPresetTrain. For validation if the weights are not defined (aka not a prototype model) then use the old preprocessing method for evaluation. Else use the one attached to the weights.



def criterion(inputs, target):
Expand Down Expand Up @@ -90,8 +100,8 @@ def main(args):

device = torch.device(args.device)

dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True))
dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False))
dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))

if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
Expand All @@ -113,9 +123,18 @@ def main(args):
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)

model = torchvision.models.segmentation.__dict__[args.model](
num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained
)
if not args.weights:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the weights are not defined, we use the standard way. Else it's a prototype run which means we will use the prototype model mechanism.

model = torchvision.models.segmentation.__dict__[args.model](
pretrained=args.pretrained,
num_classes=num_classes,
aux_loss=args.aux_loss,
)
else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.segmentation.__dict__[args.model](
weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
)
model.to(device)
if args.distributed:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
Expand Down Expand Up @@ -247,6 +266,9 @@ def get_args_parser(add_help=True):
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

return parser


Expand Down
23 changes: 21 additions & 2 deletions references/video_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
amp = None


try:
from torchvision.prototype import models as PM
except ImportError:
PM = None


def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
Expand Down Expand Up @@ -149,7 +155,12 @@ def main(args):
print("Loading validation data")
cache_path = _get_cache_path(valdir)

transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
if not args.weights:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: pretrained and weights are overlapping and can be confusing. This ideally should be cleaned up in the future

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's exactly the plan. --pretrained will go away and --weights is going to be the right parameter. Right now we support both temporarily so that we can switch between the two completely different APIs. The --weights acts as a feature switch here.

transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
else:
fn = PM.video.__dict__[args.model]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we providing some sort of registration API to get the models without having to resort to __dict__ manipulations?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeap, that's the plan. There will be a proper registration mechanism, possibly something similar to what was discussed here. There are still pending discussions with other domains, so I didn't want to adopt something before those discussions take place.

weights = PM._api.get_weight(fn, args.weights)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: using a private API here. We probably don't want to advertise private APIs in the references

Copy link
Contributor Author

@datumbox datumbox Nov 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is temporary and behind a feature switch (--weights). It's already recorded to clean up at #4652

Edit: I've edited the PR description to add some context. Also see #4734 for related discussion.

transform_test = weights.transforms()

if args.cache_dataset and os.path.exists(cache_path):
print(f"Loading dataset_test from {cache_path}")
Expand Down Expand Up @@ -200,7 +211,12 @@ def main(args):
)

print("Creating model")
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
if not args.weights:
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.video.__dict__[args.model](weights=args.weights)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
Expand Down Expand Up @@ -363,6 +379,9 @@ def parse_args():
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

args = parser.parse_args()

return args
Expand Down