File tree 3 files changed +6
-3
lines changed
3 files changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -27,7 +27,8 @@ at::Tensor nms_kernel_impl(
27
27
28
28
at::Tensor areas_t = (x2_t - x1_t ) * (y2_t - y1_t );
29
29
30
- auto order_t = std::get<1 >(scores.sort (0 , /* descending=*/ true ));
30
+ auto order_t = std::get<1 >(
31
+ scores.sort (/* stable=*/ true , /* dim=*/ 0 , /* descending=*/ true ));
31
32
32
33
auto ndets = dets.size (0 );
33
34
at::Tensor suppressed_t = at::zeros ({ndets}, dets.options ().dtype (at::kByte ));
Original file line number Diff line number Diff line change @@ -109,7 +109,8 @@ at::Tensor nms_kernel(
109
109
return at::empty ({0 }, dets.options ().dtype (at::kLong ));
110
110
}
111
111
112
- auto order_t = std::get<1 >(scores.sort (0 , /* descending=*/ true ));
112
+ auto order_t = std::get<1 >(
113
+ scores.sort (/* stable=*/ true , /* dim=*/ 0 , /* descending=*/ true ));
113
114
auto dets_sorted = dets.index_select (0 , order_t ).contiguous ();
114
115
115
116
int dets_num = dets.size (0 );
Original file line number Diff line number Diff line change @@ -27,7 +27,8 @@ at::Tensor qnms_kernel_impl(
27
27
auto y1_t = dets.select (1 , 1 ).contiguous ();
28
28
auto x2_t = dets.select (1 , 2 ).contiguous ();
29
29
auto y2_t = dets.select (1 , 3 ).contiguous ();
30
- auto order_t = std::get<1 >(scores.sort (0 , /* descending=*/ true ));
30
+ auto order_t = std::get<1 >(
31
+ scores.sort (/* stable=*/ true , /* dim=*/ 0 , /* descending=*/ true ));
31
32
at::Tensor suppressed_t = at::zeros ({ndets}, dets.options ().dtype (at::kByte ));
32
33
at::Tensor keep_t = at::zeros ({ndets}, dets.options ().dtype (at::kLong ));
33
34
at::Tensor areas_t = at::zeros ({ndets}, dets.options ().dtype (at::kFloat ));
You can’t perform that action at this time.
0 commit comments