-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Adding multiweight support to FasterRCNN #4847
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
144c6cf
b1215fb
2829f98
46f30c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,12 @@ | |
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups | ||
|
||
|
||
try: | ||
from torchvision.prototype import models as PM | ||
except ImportError: | ||
PM = None | ||
|
||
|
||
def get_dataset(name, image_set, transform, data_path): | ||
paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} | ||
p, ds_fn, num_classes = paths[name] | ||
|
@@ -41,8 +47,15 @@ def get_dataset(name, image_set, transform, data_path): | |
return ds, num_classes | ||
|
||
|
||
def get_transform(train, data_augmentation): | ||
return presets.DetectionPresetTrain(data_augmentation) if train else presets.DetectionPresetEval() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: would it be clearer to have weights as explicit argument There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would like to avoid polluting massive parts of the references with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes that makes sense. |
||
def get_transform(train, args): | ||
if train: | ||
return presets.DetectionPresetTrain(args.data_augmentation) | ||
elif not args.weights: | ||
return presets.DetectionPresetEval() | ||
else: | ||
fn = PM.detection.__dict__[args.model] | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
weights = PM._api.get_weight(fn, args.weights) | ||
return weights.transforms() | ||
|
||
|
||
def get_args_parser(add_help=True): | ||
|
@@ -128,6 +141,9 @@ def get_args_parser(add_help=True): | |
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") | ||
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") | ||
|
||
# Prototype models only | ||
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") | ||
|
||
return parser | ||
|
||
|
||
|
@@ -143,10 +159,8 @@ def main(args): | |
# Data loading code | ||
print("Loading data") | ||
|
||
dataset, num_classes = get_dataset( | ||
args.dataset, "train", get_transform(True, args.data_augmentation), args.data_path | ||
) | ||
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path) | ||
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path) | ||
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path) | ||
|
||
print("Creating data loaders") | ||
if args.distributed: | ||
|
@@ -175,9 +189,14 @@ def main(args): | |
if "rcnn" in args.model: | ||
if args.rpn_score_thresh is not None: | ||
kwargs["rpn_score_thresh"] = args.rpn_score_thresh | ||
model = torchvision.models.detection.__dict__[args.model]( | ||
num_classes=num_classes, pretrained=args.pretrained, **kwargs | ||
) | ||
if not args.weights: | ||
model = torchvision.models.detection.__dict__[args.model]( | ||
pretrained=args.pretrained, num_classes=num_classes, **kwargs | ||
) | ||
else: | ||
if PM is None: | ||
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") | ||
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) | ||
model.to(device) | ||
if args.distributed and args.sync_bn: | ||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,6 +48,13 @@ def test_classification_model(model_fn, dev): | |
TM.test_classification_model(model_fn, dev) | ||
|
||
|
||
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection)) | ||
@pytest.mark.parametrize("dev", cpu_and_gpu()) | ||
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") | ||
def test_detection_model(model_fn, dev): | ||
TM.test_detection_model(model_fn, dev) | ||
|
||
|
||
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization)) | ||
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") | ||
def test_quantized_classification_model(model_fn): | ||
|
@@ -71,6 +78,7 @@ def test_video_model(model_fn, dev): | |
@pytest.mark.parametrize( | ||
"model_fn, module_name", | ||
get_models_with_module_names(models) | ||
+ get_models_with_module_names(models.detection) | ||
+ get_models_with_module_names(models.quantization) | ||
+ get_models_with_module_names(models.segmentation) | ||
+ get_models_with_module_names(models.video), | ||
|
@@ -82,6 +90,9 @@ def test_old_vs_new_factory(model_fn, module_name, dev): | |
"models": { | ||
"input_shape": (1, 3, 224, 224), | ||
}, | ||
"detection": { | ||
"input_shape": (3, 300, 300), | ||
}, | ||
"quantization": { | ||
"input_shape": (1, 3, 224, 224), | ||
}, | ||
|
@@ -95,7 +106,10 @@ def test_old_vs_new_factory(model_fn, module_name, dev): | |
model_name = model_fn.__name__ | ||
kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})} | ||
input_shape = kwargs.pop("input_shape") | ||
kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models | ||
x = torch.rand(input_shape).to(device=dev) | ||
if module_name == "detection": | ||
x = [x] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Detection models receive a list of images. |
||
|
||
# compare with new model builder parameterized in the old fashion way | ||
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -162,7 +162,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa | |
if pretrained: | ||
model_url = model_urls[arch] | ||
if model_url is None: | ||
raise NotImplementedError(f"pretrained {arch} is not supported as of now") | ||
raise ValueError(f"No checkpoint is available for model type {arch}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just an unrelated correction, to align the exception type and value with the entire TorchVision so that we can capture it easier. |
||
else: | ||
state_dict = load_state_dict_from_url(model_url, progress=progress) | ||
model.load_state_dict(state_dict) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,72 @@ | ||
import warnings | ||
from typing import Any, Optional | ||
from typing import Any, Optional, Union | ||
|
||
from ....models.detection.faster_rcnn import ( | ||
_validate_trainable_layers, | ||
_mobilenet_extractor, | ||
_resnet_fpn_extractor, | ||
_validate_trainable_layers, | ||
AnchorGenerator, | ||
FasterRCNN, | ||
misc_nn_ops, | ||
overwrite_eps, | ||
) | ||
from ...transforms.presets import CocoEval | ||
from .._api import Weights, WeightEntry | ||
from .._meta import _COCO_CATEGORIES | ||
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large | ||
from ..resnet import ResNet50Weights, resnet50 | ||
|
||
|
||
__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"] | ||
__all__ = [ | ||
"FasterRCNN", | ||
"FasterRCNNResNet50FPNWeights", | ||
"FasterRCNNMobileNetV3LargeFPNWeights", | ||
"FasterRCNNMobileNetV3Large320FPNWeights", | ||
"fasterrcnn_resnet50_fpn", | ||
"fasterrcnn_mobilenet_v3_large_fpn", | ||
"fasterrcnn_mobilenet_v3_large_320_fpn", | ||
] | ||
|
||
|
||
_common_meta = {"categories": _COCO_CATEGORIES} | ||
|
||
|
||
class FasterRCNNResNet50FPNWeights(Weights): | ||
Coco_RefV1 = WeightEntry( | ||
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", | ||
transforms=CocoEval, | ||
meta={ | ||
"categories": _COCO_CATEGORIES, | ||
**_common_meta, | ||
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", | ||
"map": 37.0, | ||
}, | ||
) | ||
|
||
|
||
class FasterRCNNMobileNetV3LargeFPNWeights(Weights): | ||
Coco_RefV1 = WeightEntry( | ||
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", | ||
transforms=CocoEval, | ||
meta={ | ||
**_common_meta, | ||
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", | ||
"map": 32.8, | ||
}, | ||
) | ||
|
||
|
||
class FasterRCNNMobileNetV3Large320FPNWeights(Weights): | ||
Coco_RefV1 = WeightEntry( | ||
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", | ||
transforms=CocoEval, | ||
meta={ | ||
**_common_meta, | ||
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", | ||
"map": 22.8, | ||
}, | ||
) | ||
|
||
|
||
def fasterrcnn_resnet50_fpn( | ||
weights: Optional[FasterRCNNResNet50FPNWeights] = None, | ||
weights_backbone: Optional[ResNet50Weights] = None, | ||
|
@@ -64,3 +102,109 @@ def fasterrcnn_resnet50_fpn( | |
overwrite_eps(model, 0.0) | ||
|
||
return model | ||
|
||
|
||
def _fasterrcnn_mobilenet_v3_large_fpn( | ||
weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]] = None, | ||
weights_backbone: Optional[MobileNetV3LargeWeights] = None, | ||
progress: bool = True, | ||
num_classes: int = 91, | ||
trainable_backbone_layers: Optional[int] = None, | ||
**kwargs: Any, | ||
) -> FasterRCNN: | ||
if weights is not None: | ||
weights_backbone = None | ||
num_classes = len(weights.meta["categories"]) | ||
|
||
trainable_backbone_layers = _validate_trainable_layers( | ||
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3 | ||
) | ||
|
||
backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) | ||
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) | ||
anchor_sizes = ( | ||
( | ||
32, | ||
64, | ||
128, | ||
256, | ||
512, | ||
), | ||
) * 3 | ||
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) | ||
model = FasterRCNN( | ||
backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs | ||
) | ||
|
||
if weights is not None: | ||
model.load_state_dict(weights.state_dict(progress=progress)) | ||
|
||
return model | ||
|
||
|
||
def fasterrcnn_mobilenet_v3_large_fpn( | ||
weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None, | ||
weights_backbone: Optional[MobileNetV3LargeWeights] = None, | ||
progress: bool = True, | ||
num_classes: int = 91, | ||
trainable_backbone_layers: Optional[int] = None, | ||
**kwargs: Any, | ||
) -> FasterRCNN: | ||
if "pretrained" in kwargs: | ||
warnings.warn("The argument pretrained is deprecated, please use weights instead.") | ||
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None | ||
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights) | ||
if "pretrained_backbone" in kwargs: | ||
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") | ||
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None | ||
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will eventually be refactored to reduce duplicate code but that's part of the clean up. |
||
|
||
defaults = { | ||
"rpn_score_thresh": 0.05, | ||
} | ||
|
||
kwargs = {**defaults, **kwargs} | ||
return _fasterrcnn_mobilenet_v3_large_fpn( | ||
weights, | ||
weights_backbone, | ||
progress, | ||
num_classes, | ||
trainable_backbone_layers, | ||
**kwargs, | ||
) | ||
|
||
|
||
def fasterrcnn_mobilenet_v3_large_320_fpn( | ||
weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None, | ||
weights_backbone: Optional[MobileNetV3LargeWeights] = None, | ||
progress: bool = True, | ||
num_classes: int = 91, | ||
trainable_backbone_layers: Optional[int] = None, | ||
**kwargs: Any, | ||
) -> FasterRCNN: | ||
if "pretrained" in kwargs: | ||
warnings.warn("The argument pretrained is deprecated, please use weights instead.") | ||
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None | ||
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights) | ||
if "pretrained_backbone" in kwargs: | ||
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") | ||
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None | ||
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) | ||
|
||
defaults = { | ||
"min_size": 320, | ||
"max_size": 640, | ||
"rpn_pre_nms_top_n_test": 150, | ||
"rpn_post_nms_top_n_test": 150, | ||
"rpn_score_thresh": 0.05, | ||
} | ||
|
||
kwargs = {**defaults, **kwargs} | ||
return _fasterrcnn_mobilenet_v3_large_fpn( | ||
weights, | ||
weights_backbone, | ||
progress, | ||
num_classes, | ||
trainable_backbone_layers, | ||
**kwargs, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Temporarily adding support of prototype on the detection reference script similar to all other areas, so we can test the results.