|
11 | 11 | def _register_custom_op():
|
12 | 12 | from torch.onnx.symbolic_helper import parse_args
|
13 | 13 | from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze
|
14 |
| - from torch.onnx.symbolic_opset9 import _cast_Long |
15 | 14 |
|
16 | 15 | @parse_args("v", "v", "f")
|
17 | 16 | def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
|
18 | 17 | boxes = unsqueeze(g, boxes, 0)
|
19 | 18 | scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
|
20 | 19 | max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
|
21 | 20 | iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
|
22 |
| - nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold) |
| 21 | + nms_out = g.op( |
| 22 | + "NonMaxSuppression", |
| 23 | + g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT), |
| 24 | + g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT), |
| 25 | + max_output_per_class, |
| 26 | + iou_threshold, |
| 27 | + ) |
23 | 28 | return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
|
24 | 29 |
|
25 | 30 | def _process_batch_indices_for_roi_align(g, rois):
|
26 |
| - return _cast_Long( |
27 |
| - g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False |
28 |
| - ) |
| 31 | + indices = squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1) |
| 32 | + return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64) |
29 | 33 |
|
30 | 34 | def _process_rois_for_roi_align(g, rois):
|
31 | 35 | return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
|
|
0 commit comments