Skip to content

Commit f4ec25e

Browse files
Lewuathefacebook-github-bot
authored andcommitted
Support step attribute in slice from Onnx (#5454)
Summary: Support step attribute in slice node when loading ONNX format. Although it succeeds to load ONNX format having slice node with step attribute, it does not support non-1 step values. Since it should be rare to use non-1 step values for major machine learning models, we may want to prioritize loading major models such as YOLO, LSTM. See: #3987 Pull Request resolved: #5454 Test Plan: Added two test cases passing slice node with the default steps (all having 1s) and slice node with non-1 step attributes. Reviewed By: jackm321 Differential Revision: D27598418 Pulled By: jfix71 fbshipit-source-id: f4e51d69772f5f633c6173bdc5aea60f89b51dac
1 parent 682bb08 commit f4ec25e

File tree

5 files changed

+178
-3
lines changed

5 files changed

+178
-3
lines changed

lib/Importer/ONNXModelLoader.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,8 +1445,31 @@ Error ONNXModelLoader::loadSlice(const ONNX_NAMESPACE::NodeProto &op,
14451445
axesC->getType()->getElementName().str().c_str())));
14461446
}
14471447

1448-
RETURN_ERR_IF_NOT(op.input_size() == 4,
1449-
opErrMsg(op, "Steps is not currently supported!"));
1448+
if (op.input_size() > 4) {
1449+
std::vector<ssize_t> step;
1450+
Constant *stepC = getConstantByNameOrNull(op.input(4));
1451+
1452+
RETURN_ERR_IF_NOT(stepC, opErrMsg(op, "Step tensor is not Constant."));
1453+
1454+
if (stepC->getElementType() == ElemKind::Int64ITy) {
1455+
helperSetter<int64_t>(stepC, step);
1456+
} else if (stepC->getElementType() == ElemKind::Int32ITy) {
1457+
helperSetter<int32_t>(stepC, step);
1458+
} else {
1459+
RETURN_ERR_IF_NOT(
1460+
false,
1461+
opErrMsg(
1462+
op,
1463+
strFormat("Step Tensor has unsupported type '%s'",
1464+
stepC->getType()->getElementName().str().c_str())));
1465+
}
1466+
1467+
// Step is interpreted 1 as default.
1468+
for (size_t i = 0; i < step.size(); i++) {
1469+
RETURN_ERR_IF_NOT(step[i] == 1,
1470+
opErrMsg(op, "step!=1 is currently not supported"));
1471+
}
1472+
}
14501473
}
14511474
} else {
14521475
// Attributes 'starts' and 'ends' are mandatory and must be consistent.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
ir_version: 5
2+
producer_name: "test4glow"
3+
opset_import {
4+
version: 10
5+
}
6+
7+
graph {
8+
name: "test-model"
9+
input {
10+
name: "data"
11+
type {
12+
tensor_type {
13+
elem_type: 1
14+
shape {
15+
dim {
16+
dim_value: 2
17+
}
18+
dim {
19+
dim_value: 3
20+
}
21+
dim {
22+
dim_value: 3
23+
}
24+
dim {
25+
dim_value: 3
26+
}
27+
}
28+
}
29+
}
30+
}
31+
initializer {
32+
dims: 4
33+
data_type: 7
34+
name: "starts"
35+
raw_data: "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000"
36+
}
37+
initializer {
38+
dims: 4
39+
data_type: 7
40+
name: "ends"
41+
raw_data: "\002\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000"
42+
}
43+
initializer {
44+
dims: 4
45+
data_type: 7
46+
name: "axes"
47+
raw_data: "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000"
48+
}
49+
initializer {
50+
dims: 4
51+
data_type: 7
52+
name: "step"
53+
raw_data: "\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000"
54+
}
55+
node {
56+
input: "data"
57+
input: "starts"
58+
input: "ends"
59+
input: "axes"
60+
input: "step"
61+
output: "out"
62+
name: "DynamicSlice"
63+
op_type: "Slice"
64+
domain: ""
65+
}
66+
output {
67+
name: "out"
68+
}
69+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
ir_version: 5
2+
producer_name: "test4glow"
3+
opset_import {
4+
version: 10
5+
}
6+
7+
graph {
8+
name: "test-model"
9+
input {
10+
name: "data"
11+
type {
12+
tensor_type {
13+
elem_type: 1
14+
shape {
15+
dim {
16+
dim_value: 2
17+
}
18+
dim {
19+
dim_value: 3
20+
}
21+
dim {
22+
dim_value: 3
23+
}
24+
dim {
25+
dim_value: 3
26+
}
27+
}
28+
}
29+
}
30+
}
31+
initializer {
32+
dims: 4
33+
data_type: 7
34+
name: "starts"
35+
raw_data: "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000"
36+
}
37+
initializer {
38+
dims: 4
39+
data_type: 7
40+
name: "ends"
41+
raw_data: "\002\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000"
42+
}
43+
initializer {
44+
dims: 4
45+
data_type: 7
46+
name: "axes"
47+
raw_data: "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000"
48+
}
49+
initializer {
50+
dims: 4
51+
data_type: 7
52+
name: "step"
53+
raw_data: "\002\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000"
54+
}
55+
node {
56+
input: "data"
57+
input: "starts"
58+
input: "ends"
59+
input: "axes"
60+
input: "step"
61+
output: "out"
62+
name: "DynamicSlice"
63+
op_type: "Slice"
64+
domain: ""
65+
}
66+
output {
67+
name: "out"
68+
}
69+
}

tests/unittests/OnnxExporterTest.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,8 @@ TEST(exporter, onnxModels) {
453453
name.find("pow_scalar_broadcast.onnxtxt") != std::string::npos ||
454454
name.find("simpleConvTransposeAutoPadSameUpper.onnxtxt") !=
455455
std::string::npos ||
456-
name.find("sliceInvalidAxes.onnxtxt") != std::string::npos) {
456+
name.find("sliceInvalidAxes.onnxtxt") != std::string::npos ||
457+
name.find("sliceWithUnsupportedStep.onnxtxt") != std::string::npos) {
457458
// Ignore invalid ONNX files and graphs without nodes.
458459
llvm::outs() << "Ignore invalid input files: " << name << "\n";
459460
continue;

tests/unittests/OnnxImporterTest.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2750,6 +2750,19 @@ TEST_F(OnnxImporterTest, importSliceInvalidAxes) {
27502750
{2, 1, 2, 2} /* output */, true);
27512751
}
27522752

2753+
TEST_F(OnnxImporterTest, importSliceWithStep) {
2754+
importSliceTest("sliceWithStep.onnxtxt", "data", {2, 3, 3, 3} /* input */,
2755+
{0, 1, 1, 1} /* starts */, /* ends: {2, 2, 3, 3} */
2756+
{2, 1, 2, 2} /* output */);
2757+
}
2758+
2759+
TEST_F(OnnxImporterTest, importSliceWithUnsupportedStep) {
2760+
importSliceTest("sliceWithUnsupportedStep.onnxtxt", "data",
2761+
{2, 3, 3, 3} /* input */,
2762+
{0, 1, 1, 1} /* starts */, /* ends: {2, 2, 3, 3} */
2763+
{2, 1, 2, 2} /* output */, true);
2764+
}
2765+
27532766
static void importCast(llvm::StringRef fileName, llvm::StringRef inputName,
27542767
llvm::ArrayRef<dim_t> inputShape, ElemKind outputKind) {
27552768
ExecutionEngine EE{};

0 commit comments

Comments
 (0)