
Add typing Annotations to detection/utils #4583


Merged

merged 12 commits on Oct 12, 2021
50 changes: 49 additions & 1 deletion mypy.ini
@@ -17,7 +17,55 @@ ignore_errors = True

ignore_errors=True

[mypy-torchvision.models.detection.*]
[mypy-torchvision.models.detection.anchor_utils]

ignore_errors = True

[mypy-torchvision.models.detection.backbone_utils]

ignore_errors = True

[mypy-torchvision.models.detection.image_list]

ignore_errors = True

[mypy-torchvision.models.detection.transform]

ignore_errors = True

[mypy-torchvision.models.detection.rpn]

ignore_errors = True

[mypy-torchvision.models.detection.roi_heads]

ignore_errors = True

[mypy-torchvision.models.detection.generalized_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.faster_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.mask_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.keypoint_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.retinanet]

ignore_errors = True

[mypy-torchvision.models.detection.ssd]

ignore_errors = True

[mypy-torchvision.models.detection.ssdlite]

ignore_errors = True
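
Splitting the former wildcard section [mypy-torchvision.models.detection.*] into per-module sections lets each module be opted back into type checking independently: once a module is fully annotated, its section can simply be deleted. _utils, which this PR annotates, is accordingly no longer listed.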

45 changes: 20 additions & 25 deletions torchvision/models/detection/_utils.py
@@ -3,7 +3,7 @@
from typing import List, Tuple

import torch
from torch import Tensor
from torch import Tensor, nn
from torchvision.ops.misc import FrozenBatchNorm2d


@@ -12,18 +12,16 @@ class BalancedPositiveNegativeSampler(object):
This class samples batches, ensuring that they contain a fixed proportion of positives
"""

def __init__(self, batch_size_per_image, positive_fraction):
# type: (int, float) -> None
def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
"""
Args:
batch_size_per_image (int): number of elements to be selected per image
positive_fraction (float): percentace of positive elements per batch
positive_fraction (float): percentage of positive elements per batch
"""
self.batch_size_per_image = batch_size_per_image
self.positive_fraction = positive_fraction

def __call__(self, matched_idxs):
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
"""
Args:
matched idxs: list of tensors containing -1, 0 or positive values.
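
For illustration, a minimal usage sketch of the sampler (not part of the diff; the values are made up). With batch_size_per_image=4 and positive_fraction=0.5, at most 2 positives and 2 negatives are kept per image:

import torch
from torchvision.models.detection._utils import BalancedPositiveNegativeSampler

sampler = BalancedPositiveNegativeSampler(batch_size_per_image=4, positive_fraction=0.5)
# per anchor: -1 = ignore, 0 = negative, >= 1 = matched ground-truth label
matched_idxs = [torch.tensor([1, 0, 0, -1, 2, 0, 0, 1])]
pos_masks, neg_masks = sampler(matched_idxs)  # binary masks, one tensor per image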
@@ -73,8 +71,7 @@ def __call__(self, matched_idxs):


@torch.jit._script_if_tracing
def encode_boxes(reference_boxes, proposals, weights):
# type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
"""
Encode a set of proposals with respect to some
reference boxes
@@ -127,8 +124,9 @@ class BoxCoder(object):
the representation used for training the regressors.
"""

def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
# type: (Tuple[float, float, float, float], float) -> None
def __init__(
self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
) -> None:
"""
Args:
weights (4-element tuple)
@@ -137,15 +135,14 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
self.weights = weights
self.bbox_xform_clip = bbox_xform_clip

def encode(self, reference_boxes, proposals):
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
boxes_per_image = [len(b) for b in reference_boxes]
reference_boxes = torch.cat(reference_boxes, dim=0)
proposals = torch.cat(proposals, dim=0)
targets = self.encode_single(reference_boxes, proposals)
return targets.split(boxes_per_image, 0)

def encode_single(self, reference_boxes, proposals):
def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
"""
Encode a set of proposals with respect to some
reference boxes
@@ -161,8 +158,7 @@ def encode_single(self, reference_boxes, proposals):

return targets

def decode(self, rel_codes, boxes):
# type: (Tensor, List[Tensor]) -> Tensor
def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
assert isinstance(boxes, (list, tuple))
assert isinstance(rel_codes, torch.Tensor)
boxes_per_image = [b.size(0) for b in boxes]
@@ -177,7 +173,7 @@ def decode(self, rel_codes, boxes):
pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
return pred_boxes

def decode_single(self, rel_codes, boxes):
def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
"""
From a set of original boxes and encoded relative box offsets,
get the decoded boxes.
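
A round-trip sketch of the encode/decode pair (illustration only, not part of the diff; assumes the default bbox_xform_clip):

import torch
from torchvision.models.detection._utils import BoxCoder

coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
proposals = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
reference = torch.tensor([[1.0, 1.0, 11.0, 11.0]])
rel_codes = coder.encode_single(reference, proposals)  # (dx, dy, dw, dh) regression targets
decoded = coder.decode_single(rel_codes, proposals)    # recovers the reference boxes
assert torch.allclose(decoded, reference, atol=1e-4)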
@@ -244,8 +240,7 @@ class Matcher(object):
"BETWEEN_THRESHOLDS": int,
}

def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
# type: (float, float, bool) -> None
def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
"""
Args:
high_threshold (float): quality values greater than or equal to
@@ -266,7 +261,7 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
self.low_threshold = low_threshold
self.allow_low_quality_matches = allow_low_quality_matches

def __call__(self, match_quality_matrix):
def __call__(self, match_quality_matrix: Tensor) -> Tensor:
"""
Args:
match_quality_matrix (Tensor[float]): an MxN tensor, containing the
@@ -290,7 +285,7 @@ def __call__(self, match_quality_matrix):
if self.allow_low_quality_matches:
all_matches = matches.clone()
else:
all_matches = None
all_matches = None # type: ignore[assignment]
Contributor Author

I think we should ignore this error, as we are directly assigning None to a Tensor object.

Contributor

@khushi-411 Oct 11, 2021

Ah, I was stuck at this point. Is there any other method of adding types here?
I read one issue (#40867). That's quite relatable.

Contributor Author

I'm not sure if we have a better way, as we are assigning all_matches on the fly to be either Tensor or None.
Where could we declare all_matches to be Optional[Tensor]?

Previously too we had to use such a workaround, see https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/googlenet.py#L80

cc @pmeier

Collaborator

If Optional[Tensor] is possible, then we should definitely go for that. Otherwise probably Tensor and ignore the error here.

Contributor Author

Let me have a quick try at that. I will get back in an hour.

Contributor Author

Didn't work well, see the below JIT error.
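
For reference, the Optional[Tensor] variant discussed above would look roughly like the sketch below. This is a hypothetical reconstruction, not the exact patch that was tried (pick_matches and the torch.minimum update are stand-ins); per the comment above, it hit a JIT error in this code base:

from typing import Optional

import torch
from torch import Tensor

def pick_matches(match_quality_matrix: Tensor, allow_low_quality_matches: bool) -> Tensor:
    _, matches = match_quality_matrix.max(dim=0)
    all_matches: Optional[Tensor] = None
    if allow_low_quality_matches:
        all_matches = matches.clone()
    if all_matches is not None:
        # mypy narrows Optional[Tensor] to Tensor inside this branch
        matches = torch.minimum(matches, all_matches)  # stand-in for the real low-quality fix-up
    return matches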


# Assign candidate matches with low quality to negative (unassigned) values
below_low_threshold = matched_vals < self.low_threshold
@@ -304,7 +299,7 @@ def __call__(self, match_quality_matrix):

return matches

def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
"""
Produce additional matches for predictions that have only low-quality matches.
Specifically, for each ground-truth find the set of predictions that have
@@ -335,10 +330,10 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):


class SSDMatcher(Matcher):
def __init__(self, threshold):
def __init__(self, threshold: float) -> None:
super().__init__(threshold, threshold, allow_low_quality_matches=False)

def __call__(self, match_quality_matrix):
def __call__(self, match_quality_matrix: Tensor) -> Tensor:
matches = super().__call__(match_quality_matrix)

# For each gt, find the prediction with which it has the highest quality
@@ -350,7 +345,7 @@ def __call__(self, match_quality_matrix):
return matches
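
A usage sketch for the base Matcher (illustration only, not part of the diff; boxes are made up and the IoUs are computed with torchvision.ops.box_iou):

import torch
from torchvision.models.detection._utils import Matcher
from torchvision.ops import box_iou

gt_boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]])
pred_boxes = torch.tensor([[0.0, 0.0, 9.0, 9.0], [21.0, 19.0, 29.0, 29.0], [50.0, 50.0, 60.0, 60.0]])
matcher = Matcher(high_threshold=0.5, low_threshold=0.3, allow_low_quality_matches=False)
match_quality = box_iou(gt_boxes, pred_boxes)  # M x N: rows are ground truths, columns predictions
matches = matcher(match_quality)  # one gt index per prediction, negative flags for unmatched
print(matches)  # tensor([0, 1, -1])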


def overwrite_eps(model, eps):
def overwrite_eps(model: nn.Module, eps: float) -> None:
"""
This method overwrites the default eps values of all the
FrozenBatchNorm2d layers of the model with the provided value.
@@ -368,7 +363,7 @@ def overwrite_eps(model, eps):
module.eps = eps
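
Usage sketch (illustrative; torchvision uses this when loading certain pretrained detection weights to keep FrozenBatchNorm2d numerically backward compatible):

import torchvision
from torchvision.models.detection._utils import overwrite_eps

model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=False)
overwrite_eps(model, 0.0)  # sets eps on every FrozenBatchNorm2d in the model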


def retrieve_out_channels(model, size):
def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
"""
This method retrieves the number of output channels of a specific model.
