Boxes with negative scores in NMS input? #3198

Closed
masahi opened this issue Dec 21, 2020 · 12 comments


@masahi
Contributor

masahi commented Dec 21, 2020

Hi, I found that the use of NMS in RegionProposalNetwork can receive boxes with negative scores as input. I noticed this when running MaskRCNN from the v0.8 release.

keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

In the other use of NMS, in RoIHeads, scores are thresholded to keep only boxes with positive scores:

inds = torch.where(scores > self.score_thresh)[0]

I'm wondering whether the lack of score thresholding in RPN is intentional... In TVM, we expect NMS inputs with negative scores to be invalid. Since NMS in PyTorch doesn't have a score threshold parameter, we didn't realize that there could be boxes with negative scores.

I proposed to fix TVM's NMS conversion in apache/tvm#7137, but since it would have a big performance implication, and I heard that negative boxes don't matter in the final output anyway, I'm now inclined not to fix this on the TVM side.

cc @fmassa @t-vi

@datumbox
Contributor

I believe the original intention of the following snippet was to speed up the execution of the NMS op rather than to explicitly filter out non-positive values:

# remove low scoring boxes
inds = torch.where(scores > self.score_thresh)[0]
boxes, scores, labels = boxes[inds], scores[inds], labels[inds]

Could you provide a snippet that reproduces negative scores so that we can investigate?

@masahi
Contributor Author

masahi commented Dec 22, 2020

Sure, this is the script: https://gist.github.com/masahi/ea002c85e7d665d40eeb5c6422490e63

I dumped the inputs to batched_nms and I can see many negative boxes. My dumps are available at https://github.com/masahi/torchscript-to-tvm/tree/master/nms-perf-issue

I think thresholding scores at 0 in roi_head.py makes sense, because the scores in RoIHeads are outputs from a softmax and are supposed to be probability-like values. But in the other use of batched_nms, the one I'm getting many negative boxes from, the scores are "objectness". Objectness seems to be the output of convolution layers, whose values can be anything. Actually, when I run my script, about 150,000 boxes are generated from the objectness network, and only about 3,000 of them have positive scores.

So the two uses of NMS have slightly different notions of score. I don't know whether NMS with negative scores is common in networks that use an RPN.

@datumbox
Contributor

Thanks a lot for providing additional info.

NMS does not require the scores to be positive. The kernel uses the scores to rank the boxes from highest to lowest. Thus they can be virtually any well-behaved score and the method will still work:

auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));

As you said, in the case of ROIHeads the scores are positive because they are probabilities:

pred_scores = F.softmax(class_logits, -1)

Thus, in this case there is a natural range, and a lower threshold value can be assumed. I think this is why applying the heuristic thresholding is easier.

On the other hand, the objectness scores are logits, so they take values in R (a logit is negative when p < 0.5). You can see that this is the case here:

objectness_loss = F.binary_cross_entropy_with_logits(
    objectness[sampled_inds], labels[sampled_inds]
)

So in principle you could introduce such a threshold in the RPN method too, but you would have to turn the logits into probabilities so that their scale is more intuitive for the user and the threshold is easier to tune. To reiterate, we would be doing this not because NMS cannot handle negative scores, but because we might want to apply the same heuristic as in RoIHeads and drop boxes before NMS. Perhaps if we had a few benchmarks that indicate massive speed-up gains with minimal loss in accuracy, you could build a strong case.
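For illustration, a minimal sketch of what such a heuristic could look like, with a hypothetical objectness_prob_thresh parameter (not part of the current API):

import torch
from torchvision.ops import boxes as box_ops

def rpn_nms_with_threshold(boxes, objectness, lvl, nms_thresh, objectness_prob_thresh=0.05):
    # Turn the logits into probabilities so the threshold has an intuitive (0, 1) scale.
    scores = torch.sigmoid(objectness)
    # Drop low-probability boxes before NMS, shrinking the sorting workload.
    keep = torch.where(scores >= objectness_prob_thresh)[0]
    boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
    # Sigmoid is monotonic, so the ranking matches that of the raw logits.
    return box_ops.batched_nms(boxes, scores, lvl, nms_thresh)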

I'd like the opinion of @fmassa, who is far more into object detection than me, on whether this is something we want to bring into TorchVision. He is currently on holiday, so he is likely to respond once he is back.

@masahi
Contributor Author

masahi commented Dec 22, 2020

OK, I understand that negative boxes are considered valid inputs to NMS. But then I have some deeper questions, like the ones below. I don't expect all of them to be answered, since some of them might not have a clear answer.

  • What do boxes with negative objectness mean?
  • What makes one box positive while others negative?
  • Can a totally reasonable box have a negative objectness? Some people in the TVM community are claiming that negative boxes don't matter in the final detection results in practice.

If being negative doesn't matter much in terms of objectness, would it be better to rescale the scores into the range [0, 1) for more intuitive understanding? If yes, I think it makes sense to introduce a user-chosen threshold parameter to cut low-objectness boxes. There is already a box_score_thresh parameter for the NMS use in roi_head.py, so having a similar tunable parameter for the NMS with objectness in rpn.py would be reasonable.

For another data point, MXNet's RPN implementation applies a sigmoid to the raw objectness scores, so they don't have negative boxes: https://github.com/dmlc/gluon-cv/blob/f9a8a284b8222794bc842453e2bebe5746516048/gluoncv/model_zoo/rcnn/rpn/rpn.py#L238

@zhiqwang
Contributor

I'm not sure: could negative boxes be causing the problem described in #1705?

@masahi
Contributor Author

masahi commented Dec 22, 2020

@zhiqwang No. If negative boxes are considered valid, we cannot reduce the number of boxes.

@datumbox
Contributor

@masahi I'm going to give a lengthier answer before jumping to your questions, because it will give you a better understanding of how the implementation works:

In both RPN and RoIHeads we have parts of the network that do classification and estimate class logits instead of class probabilities (one for whether there is anything inside the anchor, the other for the category label). You can easily convert logits to probabilities and vice versa, but estimating logits is cheaper because you don't have to exponentiate, normalize, etc.

You can confirm that we work with class logits by looking at the loss functions used in the Faster R-CNN components (binary_cross_entropy_with_logits and fastrcnn_loss). In both RPN and RoIHeads those "scores" are fitted against some labels (not the same ones!). So in both cases we estimate classification logits.

So why does one of them end up with negative scores? Well, in RoIHeads we do the following just before calling NMS, turning the logits into probabilities:

pred_scores = F.softmax(class_logits, -1)

On the other hand, in the case of RPN we use the logits directly (in the code they are called objectness) without converting them to probabilities (you can confirm this by tracing back how these scores are estimated). Whether we use probabilities or logits, the order of the boxes inside NMS will not change, because, as we discussed earlier, the scores are only used for ordering and both should produce the same order (at least that's what the maths says; in practice you might get slight differences due to rounding errors, etc.).
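A quick self-contained check of that ordering claim (illustrative, not code from the library):

import torch

logits = torch.randn(1000)     # raw objectness-style scores, values in R
probs = torch.sigmoid(logits)  # the same scores mapped into (0, 1)

# Sigmoid is strictly increasing, so both rankings agree
# (barring float rounding for extreme logit values).
assert torch.equal(logits.argsort(descending=True),
                   probs.argsort(descending=True))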


What do boxes with negative objectness mean?

The "objectness" are the logit scores estimated by the model for whether there is any object detected to a specific anchor. Logits can be converted to probabilities as said earlier. The higher the probability of an event, the higher its logit (as p -> 1 then logit -> +oo and as p -> 0 then logit -> -oo). Logit < 0.0 means probability < 0.5.

What makes one box positive while others negative?

The model believes it is more likely that an object is in the anchor. If the probability is > 0.5 the logit is positive, else negative. Note that this is just a transformation; we don't explicitly remove boxes based on the sign of the score. In other words, the sign of the score does not matter currently.

Can a totally reasonable box have a negative objectness? Some people in the TVM community are claiming that negative boxes don't matter in the final detection results in practice.

I believe that yes, it's totally possible for a reasonable box to have negative objectness. The scale of the score is not taken into account in our implementation, so a negative score just means that the specific box happened to have a slightly lower score; there is nothing significant about it. If you believe that negative boxes don't matter, you might do so because you are willing to threshold out any box with a probability of less than 0.5. But that's something you decide based on the domain, problem, data, cost matrix, etc.; it's not a general rule.

If being negative doesn't matter much in terms of objectness, would it be better to rescale them in the range [0, 1), for more intuitive understanding?

Arguably probabilities are easier to interpret than logits, but they do come with an overhead because they require additional calculations. Perhaps that extra overhead is justified if people intend to apply thresholds to those probabilities to reduce the number of boxes before the O(n log n) sorting part of NMS. But that's something that needs to be weighed to ensure it's relatively straightforward to tune and does improve the overall speed without hurting accuracy.

To check the above, I drafted PR #3205 to investigate the speed/accuracy tradeoff if we switch to probabilities and apply thresholds. Note that this is just an investigation; it's definitely not guaranteed that the change will land on master, because it's unclear whether there are any benefits.

@masahi
Contributor Author

masahi commented Dec 23, 2020

Wow, that's a great answer! Thank you very much, it makes a lot of sense now :)

Indeed, since I'm only interested in testing / model compiling, I hadn't paid attention to the loss functions. But it is a great insight that I can look at the losses to understand what the objectness network is trying to compute.

Logit < 0.0 means probability < 0.5

👍 Great now I understand where the negatives come from and that they do matter.

I'm now convinced that TVM's conversion rule for PyTorch NMS is definitely wrong and needs fixing. I'll try to convince others who don't want to remove the "cheat" we are doing because removing it makes things slower :)

A bit off topic, but I got the impression that you care deeply about the performance of NMS. If that's the case, I suggest reworking your batched_nms implementation: https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py#L80-L88. I think the way it is implemented is more of a hack to get things going fast: it means most of the O(N^2) IoU tests you do in nms_kernel are pointless, and the O(N^2) memory requirement for the bit mask can become significant for no good reason (a large fraction of the entries are 0 by construction).
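For context, those lines implement batched NMS by offsetting each category's boxes so that boxes from different categories can never overlap, and then running a single plain nms call; roughly (paraphrased sketch, not the exact code):

import torch
from torchvision.ops import nms

def batched_nms_sketch(boxes, scores, idxs, iou_threshold):
    # Shift each category by more than the largest coordinate, so cross-category
    # pairs have zero IoU and one nms call acts independently per category.
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    boxes_for_nms = boxes + offsets[:, None]
    return nms(boxes_for_nms, scores, iou_threshold)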

Feel free to close this issue, since my original questions are answered. Alternatively, it may make sense to keep it open until #3205 is merged or closed.

@datumbox
Contributor

@masahi Awesome! I'll close the ticket but please feel free to open it again if there are more concerns.

Concerning NMS, it is indeed one of the most costly parts of inference, so we try to optimize it as much as possible. My understanding is that the current implementation tries to make use of PyTorch-specific features to make it very fast. Nevertheless, if you think there is a better option, please send a PR and we will help you benchmark it.

@fmassa
Member

fmassa commented Jan 5, 2021

Hi,

Thanks for the very comprehensive discussion @datumbox, I totally agree with your answers.

@masahi about your comment

reworking your batched_nms implementation

Yes, you are right. The current implementation is tailored mostly for GPUs, in order to re-use the kernels we already have and avoid for loops in Python (which would be costly).
It would be preferable to have an optimized C++ / CUDA implementation that could do this more efficiently. For CPU that would most probably mean just a simple for loop, but for CUDA things get a bit more complicated.

@masahi
Contributor Author

masahi commented Jan 5, 2021

@fmassa @datumbox

About batched_nms performance, one thing I found is that all usages of batched_nms are followed by a post-NMS topk, like this:

keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
# keep only topk scoring predictions
keep = keep[:self.post_nms_top_n()]

So if we pass this topk parameter to NMS, we can make NMS faster by breaking out of the sequential CPU loop below as soon as num_to_keep reaches the topk value.

int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
  int nblock = i / threadsPerBlock;
  int inblock = i % threadsPerBlock;
  if (!(remv[nblock] & (1ULL << inblock))) {
    keep_out[num_to_keep++] = i;
    unsigned long long* p = mask_host + i * col_blocks;
    for (int j = nblock; j < col_blocks; j++) {
      remv[j] |= p[j];
    }
  }
}
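A sketch of that early exit, assuming a hypothetical top_k argument threaded through to the kernel:

int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
  int nblock = i / threadsPerBlock;
  int inblock = i % threadsPerBlock;
  if (!(remv[nblock] & (1ULL << inblock))) {
    keep_out[num_to_keep++] = i;
    // Proposed: callers discard everything past post_nms_top_n anyway,
    // so stop as soon as enough boxes survive.
    if (num_to_keep == top_k) break;
    unsigned long long* p = mask_host + i * col_blocks;
    for (int j = nblock; j < col_blocks; j++) {
      remv[j] |= p[j];
    }
  }
}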

@fmassa
Member

fmassa commented Jan 6, 2021

Yes, that is true, and it's the approach that some kernels follow (like the ones from Caffe2 and ONNX).
I'm not necessarily super excited about this change though, as it sets a precedent for other micro-optimizations that could be applied to NMS (like the one discussed in #2552).

But I would like to see some performance numbers on representative use cases to motivate the potential gains of restricting num_to_keep inside NMS -- I wouldn't expect this to make a very significant difference runtime-wise in end-to-end model inference (but I might be wrong).
