From d1023bbd888c217143e43a886bea8fade92767f6 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Oct 2021 12:13:17 +0100 Subject: [PATCH 1/7] adding multiweight support for deeplabv3 prototype models --- .../prototype/models/segmentation/__init__.py | 1 + .../models/segmentation/deeplabv3.py | 169 ++++++++++++++++++ 2 files changed, 170 insertions(+) create mode 100644 torchvision/prototype/models/segmentation/deeplabv3.py diff --git a/torchvision/prototype/models/segmentation/__init__.py b/torchvision/prototype/models/segmentation/__init__.py index bb3b6697514..20273be2170 100644 --- a/torchvision/prototype/models/segmentation/__init__.py +++ b/torchvision/prototype/models/segmentation/__init__.py @@ -1,2 +1,3 @@ from .fcn import * from .lraspp import * +from .deeplabv3 import * diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py new file mode 100644 index 00000000000..e49247cd96e --- /dev/null +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -0,0 +1,169 @@ +import warnings +from functools import partial +from typing import Any, Callable, Optional, Type, Dict + +from torchvision.prototype.models.resnet import resnet50, resnet101 + +from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet +from ...transforms.presets import VocEval +from .._api import Weights, WeightEntry +from .._meta import _VOC_CATEGORIES +from ..mobilenetv3 import MobileNetV3LargeWeights +from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large +from ..resnet import ResNet50Weights, ResNet101Weights + + +__all__ = [ + "DeepLabV3", + "DeepLabV3ResNet50Weights", + "DeepLabV3ResNet101Weights", + "DeepLabV3MobileNetV3LargeWeights", + "deeplabv3_mobilenet_v3_large", + "deeplabv3_resnet50", + "deeplabv3_resnet101", +] + + +class DeepLabV3ResNet50Weights(Weights): + CocoWithVocLabels_RefV1 = WeightEntry( + 
url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", + transforms=partial(VocEval, resize_size=520), + meta={ + "categories": _VOC_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50", + "mIoU": 66.4, + "acc": 92.4, + }, + ) + + +class DeepLabV3ResNet101Weights(Weights): + CocoWithVocLabels_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", + transforms=partial(VocEval, resize_size=520), + meta={ + "categories": _VOC_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101", + "mIoU": 67.4, + "acc": 92.4, + }, + ) + + +class DeepLabV3MobileNetV3LargeWeights(Weights): + CocoWithVocLabels_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", + transforms=partial(VocEval, resize_size=520), + meta={ + "categories": _VOC_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large", + "mIoU": 60.3, + "acc": 91.2, + }, + ) + + +def _deeplabv3( + weights_class: Type[Weights], + weights: Weights, + model_builder: Callable, + weights_backbone_class: Type[Weights], + weights_backbone: Weights, + backbone_model_builder: Callable, + progress: bool, + num_classes: int, + backbone_args: Dict[str, Any], + aux_loss: Optional[bool] = None, + **kwargs: Any, +) -> DeepLabV3: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = weights_class.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None + + weights = weights_class.verify(weights) + if "pretrained_backbone" in kwargs: + warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") + weights_backbone = weights_backbone_class.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None + weights_backbone = 
weights_backbone_class.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + aux_loss = True + num_classes = len(weights.meta["categories"]) + + backbone = backbone_model_builder(weights=weights_backbone, **backbone_args) + model = model_builder(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + + return model + + +def deeplabv3_resnet50( + weights: Optional[DeepLabV3ResNet50Weights] = None, + weights_backbone: Optional[ResNet50Weights] = None, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + **kwargs: Any, +) -> DeepLabV3: + return _deeplabv3( + DeepLabV3ResNet50Weights, + weights, + _deeplabv3_resnet, + ResNet50Weights, + weights_backbone, + resnet50, + progress, + num_classes, + {"replace_stride_with_dilation": [False, True, True]}, + aux_loss, + kwargs=kwargs, + ) + + +def deeplabv3_resnet101( + weights: Optional[DeepLabV3ResNet50Weights] = None, + weights_backbone: Optional[ResNet50Weights] = None, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + **kwargs: Any, +) -> DeepLabV3: + return _deeplabv3( + DeepLabV3ResNet101Weights, + weights, + _deeplabv3_resnet, + ResNet101Weights, + weights_backbone, + resnet101, + progress, + num_classes, + {"replace_stride_with_dilation": [False, True, True]}, + aux_loss, + kwargs=kwargs, + ) + + +def deeplabv3_mobilenet_v3_large( + weights: Optional[DeepLabV3MobileNetV3LargeWeights] = None, + weights_backbone: Optional[MobileNetV3LargeWeights] = None, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + **kwargs: Any, +) -> DeepLabV3: + return _deeplabv3( + DeepLabV3MobileNetV3LargeWeights, + weights, + _deeplabv3_mobilenetv3, + MobileNetV3LargeWeights, + weights_backbone, + mobilenet_v3_large, + progress, + num_classes, + {"dilated": True}, + aux_loss, + kwargs=kwargs, + ) From e1f25d129fe20326e146e7a63bf4142db58f1d32 
Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Oct 2021 12:24:43 +0100 Subject: [PATCH 2/7] adding default values for optional params --- .../prototype/models/segmentation/deeplabv3.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index e49247cd96e..8a374fe9329 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -8,7 +8,6 @@ from ...transforms.presets import VocEval from .._api import Weights, WeightEntry from .._meta import _VOC_CATEGORIES -from ..mobilenetv3 import MobileNetV3LargeWeights from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..resnet import ResNet50Weights, ResNet101Weights @@ -65,14 +64,14 @@ class DeepLabV3MobileNetV3LargeWeights(Weights): def _deeplabv3( weights_class: Type[Weights], - weights: Weights, model_builder: Callable, weights_backbone_class: Type[Weights], - weights_backbone: Weights, backbone_model_builder: Callable, progress: bool, num_classes: int, backbone_args: Dict[str, Any], + weights: Optional[Weights] = None, + weights_backbone: Optional[Weights] = None, aux_loss: Optional[bool] = None, **kwargs: Any, ) -> DeepLabV3: @@ -110,14 +109,14 @@ def deeplabv3_resnet50( ) -> DeepLabV3: return _deeplabv3( DeepLabV3ResNet50Weights, - weights, _deeplabv3_resnet, ResNet50Weights, - weights_backbone, resnet50, progress, num_classes, {"replace_stride_with_dilation": [False, True, True]}, + weights, + weights_backbone, aux_loss, kwargs=kwargs, ) @@ -133,14 +132,14 @@ def deeplabv3_resnet101( ) -> DeepLabV3: return _deeplabv3( DeepLabV3ResNet101Weights, - weights, _deeplabv3_resnet, ResNet101Weights, - weights_backbone, resnet101, progress, num_classes, {"replace_stride_with_dilation": [False, True, True]}, + weights, + weights_backbone, aux_loss, kwargs=kwargs, ) @@ -156,14 +155,14 @@ 
def deeplabv3_mobilenet_v3_large( ) -> DeepLabV3: return _deeplabv3( DeepLabV3MobileNetV3LargeWeights, - weights, _deeplabv3_mobilenetv3, MobileNetV3LargeWeights, - weights_backbone, mobilenet_v3_large, progress, num_classes, {"dilated": True}, + weights, aux_loss, + weights_backbone, kwargs=kwargs, ) From 9e9905fda0a507dcdcc389d619f323df10f5ce15 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Oct 2021 12:42:10 +0100 Subject: [PATCH 3/7] fixing positional argument order (weights_backbone before aux_loss) in deeplabv3_mobilenet_v3_large --- torchvision/prototype/models/segmentation/deeplabv3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 8a374fe9329..9dc32f81f52 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -162,7 +162,7 @@ def deeplabv3_mobilenet_v3_large( num_classes, {"dilated": True}, weights, - aux_loss, weights_backbone, + aux_loss, kwargs=kwargs, ) From 81e888dd70935cb659ebb7611c90b2fde2db340f Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Oct 2021 14:41:06 +0100 Subject: [PATCH 4/7] addressing PR comment: inline _deeplabv3 helper into each model builder --- .../models/segmentation/deeplabv3.py | 127 ++++++++---------- 1 file changed, 58 insertions(+), 69 deletions(-) diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 9dc32f81f52..857e5e4c2de 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -1,14 +1,13 @@ import warnings from functools import partial -from typing import Any, Callable, Optional, Type, Dict - -from torchvision.prototype.models.resnet import resnet50, resnet101 +from typing import Any, Optional from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet from ...transforms.presets import VocEval from .._api import Weights, WeightEntry from
.._meta import _VOC_CATEGORIES from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large +from ..resnet import resnet50, resnet101 from ..resnet import ResNet50Weights, ResNet101Weights @@ -62,36 +61,31 @@ class DeepLabV3MobileNetV3LargeWeights(Weights): ) -def _deeplabv3( - weights_class: Type[Weights], - model_builder: Callable, - weights_backbone_class: Type[Weights], - backbone_model_builder: Callable, - progress: bool, - num_classes: int, - backbone_args: Dict[str, Any], - weights: Optional[Weights] = None, - weights_backbone: Optional[Weights] = None, +def deeplabv3_resnet50( + weights: Optional[DeepLabV3ResNet50Weights] = None, + weights_backbone: Optional[ResNet50Weights] = None, + progress: bool = True, + num_classes: int = 21, aux_loss: Optional[bool] = None, **kwargs: Any, ) -> DeepLabV3: if "pretrained" in kwargs: warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = weights_class.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None + weights = DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None - weights = weights_class.verify(weights) + weights = DeepLabV3ResNet50Weights.verify(weights) if "pretrained_backbone" in kwargs: warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") - weights_backbone = weights_backbone_class.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None - weights_backbone = weights_backbone_class.verify(weights_backbone) + weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None + weights_backbone = ResNet50Weights.verify(weights_backbone) if weights is not None: weights_backbone = None aux_loss = True num_classes = len(weights.meta["categories"]) - backbone = backbone_model_builder(weights=weights_backbone, **backbone_args) - model = model_builder(backbone, num_classes, aux_loss) + backbone = resnet50(weights=weights_backbone, 
replace_stride_with_dilation=[False, True, True]) + model = _deeplabv3_resnet(backbone, num_classes, aux_loss) if weights is not None: model.load_state_dict(weights.state_dict(progress=progress)) @@ -99,50 +93,36 @@ def _deeplabv3( return model -def deeplabv3_resnet50( - weights: Optional[DeepLabV3ResNet50Weights] = None, - weights_backbone: Optional[ResNet50Weights] = None, +def deeplabv3_resnet101( + weights: Optional[DeepLabV3ResNet101Weights] = None, + weights_backbone: Optional[ResNet101Weights] = None, progress: bool = True, num_classes: int = 21, aux_loss: Optional[bool] = None, **kwargs: Any, ) -> DeepLabV3: - return _deeplabv3( - DeepLabV3ResNet50Weights, - _deeplabv3_resnet, - ResNet50Weights, - resnet50, - progress, - num_classes, - {"replace_stride_with_dilation": [False, True, True]}, - weights, - weights_backbone, - aux_loss, - kwargs=kwargs, - ) + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None + weights = DeepLabV3ResNet101Weights.verify(weights) + if "pretrained_backbone" in kwargs: + warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") + weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None + weights_backbone = ResNet101Weights.verify(weights_backbone) -def deeplabv3_resnet101( - weights: Optional[DeepLabV3ResNet50Weights] = None, - weights_backbone: Optional[ResNet50Weights] = None, - progress: bool = True, - num_classes: int = 21, - aux_loss: Optional[bool] = None, - **kwargs: Any, -) -> DeepLabV3: - return _deeplabv3( - DeepLabV3ResNet101Weights, - _deeplabv3_resnet, - ResNet101Weights, - resnet101, - progress, - num_classes, - {"replace_stride_with_dilation": [False, True, True]}, - weights, - weights_backbone, - aux_loss, - kwargs=kwargs, - ) + if weights is not None: + weights_backbone 
= None + aux_loss = True + num_classes = len(weights.meta["categories"]) + + backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) + model = _deeplabv3_resnet(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + + return model def deeplabv3_mobilenet_v3_large( @@ -153,16 +133,25 @@ def deeplabv3_mobilenet_v3_large( aux_loss: Optional[bool] = None, **kwargs: Any, ) -> DeepLabV3: - return _deeplabv3( - DeepLabV3MobileNetV3LargeWeights, - _deeplabv3_mobilenetv3, - MobileNetV3LargeWeights, - mobilenet_v3_large, - progress, - num_classes, - {"dilated": True}, - weights, - weights_backbone, - aux_loss, - kwargs=kwargs, - ) + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None + + weights = DeepLabV3MobileNetV3LargeWeights.verify(weights) + if "pretrained_backbone" in kwargs: + warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") + weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None + weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + aux_loss = True + num_classes = len(weights.meta["categories"]) + + backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) + model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + + return model From e826f572ac4affe0f1315c33ad3cfc03d875a158 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 27 Oct 2021 15:29:54 +0100 Subject: [PATCH 5/7] fixing seed in test_batched_nms_implementations --- test/test_ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/test/test_ops.py b/test/test_ops.py index 892496dffca..d7223e1e165 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -555,8 +555,10 @@ def test_nms_cuda_float16(self): keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres) assert_equal(keep32, keep16) - def test_batched_nms_implementations(self): + @pytest.mark.parametrize("seed", range(10)) + def test_batched_nms_implementations(self, seed): """Make sure that both implementations of batched_nms yield identical results""" + torch.random.manual_seed(seed) num_boxes = 1000 iou_threshold = 0.9 From cd68508e931d41091a68a557c75ef3a587a9531b Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Thu, 28 Oct 2021 09:59:44 +0100 Subject: [PATCH 6/7] change seeds in test_batched_nms_implementations to range(20, 30) --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index b5cd7381539..77d5315711f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -557,7 +557,7 @@ def test_nms_cuda_float16(self): keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres) assert_equal(keep32, keep16) - @pytest.mark.parametrize("seed", range(10)) + @pytest.mark.parametrize("seed", range(20, 30)) def test_batched_nms_implementations(self, seed): """Make sure that both implementations of batched_nms yield identical results""" torch.random.manual_seed(seed) From dcf1ebaff0f7c2abed82683c92805fd57751b225 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Thu, 28 Oct 2021 10:54:44 +0100 Subject: [PATCH 7/7] revert seeds in test_batched_nms_implementations to range(10) --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 77d5315711f..b5cd7381539 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -557,7 +557,7 @@ def test_nms_cuda_float16(self): keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres) assert_equal(keep32, keep16) - @pytest.mark.parametrize("seed",
range(20, 30)) + @pytest.mark.parametrize("seed", range(10)) def test_batched_nms_implementations(self, seed): """Make sure that both implementations of batched_nms yield identical results""" torch.random.manual_seed(seed)