Skip to content

Adding multiweight support to FasterRCNN #4847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups


try:
from torchvision.prototype import models as PM
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporarily adding support of prototype on the detection reference script similar to all other areas, so we can test the results.

except ImportError:
PM = None


def get_dataset(name, image_set, transform, data_path):
paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
p, ds_fn, num_classes = paths[name]
Expand All @@ -41,8 +47,15 @@ def get_dataset(name, image_set, transform, data_path):
return ds, num_classes


def get_transform(train, data_augmentation):
return presets.DetectionPresetTrain(data_augmentation) if train else presets.DetectionPresetEval()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would it be clearer to have weights as explicit argument

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to avoid polluting massive parts of the references with weights references while pretrained is supported concurrently. Are you OK if we do this as part of the cleanup operations recorded at #4652?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes that makes sense.

def get_transform(train, args):
if train:
return presets.DetectionPresetTrain(args.data_augmentation)
elif not args.weights:
return presets.DetectionPresetEval()
else:
fn = PM.detection.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
return weights.transforms()


def get_args_parser(add_help=True):
Expand Down Expand Up @@ -128,6 +141,9 @@ def get_args_parser(add_help=True):
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

return parser


Expand All @@ -143,10 +159,8 @@ def main(args):
# Data loading code
print("Loading data")

dataset, num_classes = get_dataset(
args.dataset, "train", get_transform(True, args.data_augmentation), args.data_path
)
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path)
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)

print("Creating data loaders")
if args.distributed:
Expand Down Expand Up @@ -175,9 +189,14 @@ def main(args):
if "rcnn" in args.model:
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
model = torchvision.models.detection.__dict__[args.model](
num_classes=num_classes, pretrained=args.pretrained, **kwargs
)
if not args.weights:
model = torchvision.models.detection.__dict__[args.model](
pretrained=args.pretrained, num_classes=num_classes, **kwargs
)
else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
Expand Down
14 changes: 14 additions & 0 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ def test_classification_model(model_fn, dev):
TM.test_classification_model(model_fn, dev)


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_detection_model(model_fn, dev):
TM.test_detection_model(model_fn, dev)


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_quantized_classification_model(model_fn):
Expand All @@ -71,6 +78,7 @@ def test_video_model(model_fn, dev):
@pytest.mark.parametrize(
"model_fn, module_name",
get_models_with_module_names(models)
+ get_models_with_module_names(models.detection)
+ get_models_with_module_names(models.quantization)
+ get_models_with_module_names(models.segmentation)
+ get_models_with_module_names(models.video),
Expand All @@ -82,6 +90,9 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
"models": {
"input_shape": (1, 3, 224, 224),
},
"detection": {
"input_shape": (3, 300, 300),
},
"quantization": {
"input_shape": (1, 3, 224, 224),
},
Expand All @@ -95,7 +106,10 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
model_name = model_fn.__name__
kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape")
kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models
x = torch.rand(input_shape).to(device=dev)
if module_name == "detection":
x = [x]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Detection models receive a list of images.


# compare with new model builder parameterized in the old fashion way
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/shufflenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa
if pretrained:
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError(f"pretrained {arch} is not supported as of now")
raise ValueError(f"No checkpoint is available for model type {arch}")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just an unrelated correction, to align the exception type and value with the entire TorchVision so that we can capture it easier.

else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
Expand Down
152 changes: 148 additions & 4 deletions torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,72 @@
import warnings
from typing import Any, Optional
from typing import Any, Optional, Union

from ....models.detection.faster_rcnn import (
_validate_trainable_layers,
_mobilenet_extractor,
_resnet_fpn_extractor,
_validate_trainable_layers,
AnchorGenerator,
FasterRCNN,
misc_nn_ops,
overwrite_eps,
)
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..resnet import ResNet50Weights, resnet50


__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
__all__ = [
"FasterRCNN",
"FasterRCNNResNet50FPNWeights",
"FasterRCNNMobileNetV3LargeFPNWeights",
"FasterRCNNMobileNetV3Large320FPNWeights",
"fasterrcnn_resnet50_fpn",
"fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
]


_common_meta = {"categories": _COCO_CATEGORIES}


class FasterRCNNResNet50FPNWeights(Weights):
Coco_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval,
meta={
"categories": _COCO_CATEGORIES,
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0,
},
)


class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
Coco_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval,
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8,
},
)


class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
Coco_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval,
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8,
},
)


def fasterrcnn_resnet50_fpn(
weights: Optional[FasterRCNNResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
Expand Down Expand Up @@ -64,3 +102,109 @@ def fasterrcnn_resnet50_fpn(
overwrite_eps(model, 0.0)

return model


def _fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 91,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])

trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3
)

backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
anchor_sizes = (
(
32,
64,
128,
256,
512,
),
) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
model = FasterRCNN(
backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))

return model


def fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 91,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will eventually be refactored to reduce duplicate code but that's part of the clean up.


defaults = {
"rpn_score_thresh": 0.05,
}

kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights,
weights_backbone,
progress,
num_classes,
trainable_backbone_layers,
**kwargs,
)


def fasterrcnn_mobilenet_v3_large_320_fpn(
weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 91,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

defaults = {
"min_size": 320,
"max_size": 640,
"rpn_pre_nms_top_n_test": 150,
"rpn_post_nms_top_n_test": 150,
"rpn_score_thresh": 0.05,
}

kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights,
weights_backbone,
progress,
num_classes,
trainable_backbone_layers,
**kwargs,
)