-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add typing Annotations to detection/utils #4583
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
50ab495
dab9402
5f7d0d9
d017161
8e83805
f6bb442
c4c2884
f157190
f1df465
c8fa413
f78dd49
4df6ccf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
from typing import List, Tuple | ||
|
||
import torch | ||
from torch import Tensor | ||
from torch import Tensor, nn | ||
from torchvision.ops.misc import FrozenBatchNorm2d | ||
|
||
|
||
|
@@ -12,18 +12,16 @@ class BalancedPositiveNegativeSampler(object): | |
This class samples batches, ensuring that they contain a fixed proportion of positives | ||
""" | ||
|
||
def __init__(self, batch_size_per_image, positive_fraction): | ||
# type: (int, float) -> None | ||
def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None: | ||
oke-aditya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Args: | ||
batch_size_per_image (int): number of elements to be selected per image | ||
positive_fraction (float): percentace of positive elements per batch | ||
positive_fraction (float): percentage of positive elements per batch | ||
oke-aditya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
self.batch_size_per_image = batch_size_per_image | ||
self.positive_fraction = positive_fraction | ||
|
||
def __call__(self, matched_idxs): | ||
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]] | ||
def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: | ||
""" | ||
Args: | ||
matched idxs: list of tensors containing -1, 0 or positive values. | ||
|
@@ -73,8 +71,7 @@ def __call__(self, matched_idxs): | |
|
||
|
||
@torch.jit._script_if_tracing | ||
def encode_boxes(reference_boxes, proposals, weights): | ||
# type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor | ||
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor: | ||
""" | ||
Encode a set of proposals with respect to some | ||
reference boxes | ||
|
@@ -127,8 +124,9 @@ class BoxCoder(object): | |
the representation used for training the regressors. | ||
""" | ||
|
||
def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)): | ||
# type: (Tuple[float, float, float, float], float) -> None | ||
def __init__( | ||
self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16) | ||
) -> None: | ||
""" | ||
Args: | ||
weights (4-element tuple) | ||
|
@@ -137,15 +135,14 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)): | |
self.weights = weights | ||
self.bbox_xform_clip = bbox_xform_clip | ||
|
||
def encode(self, reference_boxes, proposals): | ||
# type: (List[Tensor], List[Tensor]) -> List[Tensor] | ||
def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]: | ||
boxes_per_image = [len(b) for b in reference_boxes] | ||
reference_boxes = torch.cat(reference_boxes, dim=0) | ||
proposals = torch.cat(proposals, dim=0) | ||
targets = self.encode_single(reference_boxes, proposals) | ||
return targets.split(boxes_per_image, 0) | ||
|
||
def encode_single(self, reference_boxes, proposals): | ||
def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: | ||
""" | ||
Encode a set of proposals with respect to some | ||
reference boxes | ||
|
@@ -161,8 +158,7 @@ def encode_single(self, reference_boxes, proposals): | |
|
||
return targets | ||
|
||
def decode(self, rel_codes, boxes): | ||
# type: (Tensor, List[Tensor]) -> Tensor | ||
def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor: | ||
assert isinstance(boxes, (list, tuple)) | ||
assert isinstance(rel_codes, torch.Tensor) | ||
boxes_per_image = [b.size(0) for b in boxes] | ||
|
@@ -177,7 +173,7 @@ def decode(self, rel_codes, boxes): | |
pred_boxes = pred_boxes.reshape(box_sum, -1, 4) | ||
return pred_boxes | ||
|
||
def decode_single(self, rel_codes, boxes): | ||
def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: | ||
""" | ||
From a set of original boxes and encoded relative box offsets, | ||
get the decoded boxes. | ||
|
@@ -244,8 +240,7 @@ class Matcher(object): | |
"BETWEEN_THRESHOLDS": int, | ||
} | ||
|
||
def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): | ||
# type: (float, float, bool) -> None | ||
def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None: | ||
""" | ||
Args: | ||
high_threshold (float): quality values greater than or equal to | ||
|
@@ -266,7 +261,7 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=Fals | |
self.low_threshold = low_threshold | ||
self.allow_low_quality_matches = allow_low_quality_matches | ||
|
||
def __call__(self, match_quality_matrix): | ||
def __call__(self, match_quality_matrix: Tensor) -> Tensor: | ||
""" | ||
Args: | ||
match_quality_matrix (Tensor[float]): an MxN tensor, containing the | ||
|
@@ -290,7 +285,7 @@ def __call__(self, match_quality_matrix): | |
if self.allow_low_quality_matches: | ||
all_matches = matches.clone() | ||
else: | ||
all_matches = None | ||
all_matches = None # type: ignore[assignment] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should ignore this error, as we are directly assigning None to a Tensor object. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I was stuck at this point. Is there any other method of adding types here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if we have a better way as we are assigning all_matches on the fly to be either Tensor or None. Previously too we had to use such workaround see https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/googlenet.py#L80 cc @pmeier There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me have a quick try at that. I will get back in an hour. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Didn't work well, see the below JIT error. |
||
|
||
# Assign candidate matches with low quality to negative (unassigned) values | ||
below_low_threshold = matched_vals < self.low_threshold | ||
|
@@ -304,7 +299,7 @@ def __call__(self, match_quality_matrix): | |
|
||
return matches | ||
|
||
def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): | ||
def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None: | ||
""" | ||
Produce additional matches for predictions that have only low-quality matches. | ||
Specifically, for each ground-truth find the set of predictions that have | ||
|
@@ -335,10 +330,10 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): | |
|
||
|
||
class SSDMatcher(Matcher): | ||
def __init__(self, threshold): | ||
def __init__(self, threshold: float) -> None: | ||
oke-aditya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
super().__init__(threshold, threshold, allow_low_quality_matches=False) | ||
oke-aditya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __call__(self, match_quality_matrix): | ||
def __call__(self, match_quality_matrix: Tensor) -> Tensor: | ||
matches = super().__call__(match_quality_matrix) | ||
|
||
# For each gt, find the prediction with which it has the highest quality | ||
|
@@ -350,7 +345,7 @@ def __call__(self, match_quality_matrix): | |
return matches | ||
|
||
|
||
def overwrite_eps(model, eps): | ||
def overwrite_eps(model: nn.Module, eps: float) -> None: | ||
oke-aditya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
This method overwrites the default eps values of all the | ||
FrozenBatchNorm2d layers of the model with the provided value. | ||
|
@@ -368,7 +363,7 @@ def overwrite_eps(model, eps): | |
module.eps = eps | ||
|
||
|
||
def retrieve_out_channels(model, size): | ||
def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]: | ||
""" | ||
This method retrieves the number of output channels of a specific model. | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.