Skip to content

Commit 31d5d9f

Browse files
NicolasHugmalfetpmeier
authored andcommitted
[fbsync] [ONNX] Fix dtype for NonMaxSuppression (#7056)
Summary: Reviewed By: vmoens Differential Revision: D44416637 fbshipit-source-id: c06807ab7d9d71f2272761bf366411268ed5b462 Co-authored-by: Nikita Shulga <[email protected]> Co-authored-by: Philip Meier <[email protected]>
1 parent fd2d42a commit 31d5d9f

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

torchvision/ops/_register_onnx_ops.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,25 @@
1111
def _register_custom_op():
1212
from torch.onnx.symbolic_helper import parse_args
1313
from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze
14-
from torch.onnx.symbolic_opset9 import _cast_Long
1514

1615
@parse_args("v", "v", "f")
1716
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
1817
boxes = unsqueeze(g, boxes, 0)
1918
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
2019
max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
2120
iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
22-
nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold)
21+
nms_out = g.op(
22+
"NonMaxSuppression",
23+
g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
24+
g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
25+
max_output_per_class,
26+
iou_threshold,
27+
)
2328
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
2429

2530
def _process_batch_indices_for_roi_align(g, rois):
26-
return _cast_Long(
27-
g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False
28-
)
31+
indices = squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1)
32+
return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)
2933

3034
def _process_rois_for_roi_align(g, rois):
3135
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))

0 commit comments

Comments
 (0)