Skip to content

Commit c45017b

Browse files
prabhat00155facebook-github-bot
authored andcommitted
[fbsync] Change to stable sort in nms implementations (#4767)
Summary: * change to stable sort in nms implementations Reviewed By: NicolasHug Differential Revision: D32694315 fbshipit-source-id: e2ff4d0ed84ca7a4ef2982f4d9bb3192a88dc9b0
1 parent 28c597b commit c45017b

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

torchvision/csrc/ops/cpu/nms_kernel.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ at::Tensor nms_kernel_impl(
2727

2828
at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
2929

30-
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
30+
auto order_t = std::get<1>(
31+
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
3132

3233
auto ndets = dets.size(0);
3334
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));

torchvision/csrc/ops/cuda/nms_kernel.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ at::Tensor nms_kernel(
109109
return at::empty({0}, dets.options().dtype(at::kLong));
110110
}
111111

112-
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
112+
auto order_t = std::get<1>(
113+
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
113114
auto dets_sorted = dets.index_select(0, order_t).contiguous();
114115

115116
int dets_num = dets.size(0);

torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ at::Tensor qnms_kernel_impl(
2727
auto y1_t = dets.select(1, 1).contiguous();
2828
auto x2_t = dets.select(1, 2).contiguous();
2929
auto y2_t = dets.select(1, 3).contiguous();
30-
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
30+
auto order_t = std::get<1>(
31+
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
3132
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
3233
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
3334
at::Tensor areas_t = at::zeros({ndets}, dets.options().dtype(at::kFloat));

0 commit comments

Comments
 (0)