diff --git a/lib/Importer/ONNXModelLoader.cpp b/lib/Importer/ONNXModelLoader.cpp index b34dca79fa..c25f397030 100644 --- a/lib/Importer/ONNXModelLoader.cpp +++ b/lib/Importer/ONNXModelLoader.cpp @@ -1445,8 +1445,31 @@ Error ONNXModelLoader::loadSlice(const ONNX_NAMESPACE::NodeProto &op, axesC->getType()->getElementName().str().c_str()))); } - RETURN_ERR_IF_NOT(op.input_size() == 4, - opErrMsg(op, "Steps is not currently supported!")); + if (op.input_size() > 4) { + std::vector step; + Constant *stepC = getConstantByNameOrNull(op.input(4)); + + RETURN_ERR_IF_NOT(stepC, opErrMsg(op, "Step tensor is not Constant.")); + + if (stepC->getElementType() == ElemKind::Int64ITy) { + helperSetter(stepC, step); + } else if (stepC->getElementType() == ElemKind::Int32ITy) { + helperSetter(stepC, step); + } else { + RETURN_ERR_IF_NOT( + false, + opErrMsg( + op, + strFormat("Step Tensor has unsupported type '%s'", + stepC->getType()->getElementName().str().c_str()))); + } + + // Step is interpreted 1 as default. + for (size_t i = 0; i < step.size(); i++) { + RETURN_ERR_IF_NOT(step[i] == 1, + opErrMsg(op, "step!=1 is currently not supported")); + } + } } } else { // Attributes 'starts' and 'ends' are mandatory and must be consistent. diff --git a/tests/models/onnxModels/sliceWithStep.onnxtxt b/tests/models/onnxModels/sliceWithStep.onnxtxt new file mode 100644 index 0000000000..ac035250ce --- /dev/null +++ b/tests/models/onnxModels/sliceWithStep.onnxtxt @@ -0,0 +1,69 @@ +ir_version: 5 +producer_name: "test4glow" +opset_import { + version: 10 +} + +graph { + name: "test-model" + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 3 + } + dim { + dim_value: 3 + } + } + } + } + } + initializer { + dims: 4 + data_type: 7 + name: "starts" + raw_data: "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000" + } + initializer { + dims: 4 + data_type: 7 + name: "ends" + raw_data: "\002\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000" + } + initializer { + dims: 4 + data_type: 7 + name: "axes" + raw_data: "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000" + } + initializer { + dims: 4 + data_type: 7 + name: "step" + raw_data: "\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000" + } + node { + input: "data" + input: "starts" + input: "ends" + input: "axes" + input: "step" + output: "out" + name: "DynamicSlice" + op_type: "Slice" + domain: "" + } + output { + name: "out" + } +} \ No newline at end of file diff --git a/tests/models/onnxModels/sliceWithUnsupportedStep.onnxtxt b/tests/models/onnxModels/sliceWithUnsupportedStep.onnxtxt new file mode 100644 index 0000000000..ff9bb381c2 --- /dev/null +++ b/tests/models/onnxModels/sliceWithUnsupportedStep.onnxtxt @@ -0,0 +1,69 @@ +ir_version: 5 +producer_name: "test4glow" +opset_import { + version: 10 +} + +graph { + name: "test-model" + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 3 + } + dim { + dim_value: 3 + } + } + } + } + } + initializer { + dims: 4 + data_type: 7 + name: "starts" + raw_data: "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000" + } + initializer { + dims: 4 + data_type: 7 + name: "ends" + raw_data: "\002\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000" + } + initializer { + dims: 4 + data_type: 7 + name: "axes" + raw_data: "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000" + } + initializer { + dims: 4 + data_type: 7 + name: "step" + raw_data: "\002\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000" + } + node { + input: "data" + input: "starts" + input: "ends" + input: "axes" + input: "step" + output: "out" + name: "DynamicSlice" + op_type: "Slice" + domain: "" + } + output { + name: "out" + } +} \ No newline at end of file diff --git a/tests/unittests/OnnxExporterTest.cpp b/tests/unittests/OnnxExporterTest.cpp index 3f74a08ebf..7a5e022afe 100644 --- a/tests/unittests/OnnxExporterTest.cpp +++ b/tests/unittests/OnnxExporterTest.cpp @@ -453,7 +453,8 @@ TEST(exporter, onnxModels) { name.find("pow_scalar_broadcast.onnxtxt") != std::string::npos || name.find("simpleConvTransposeAutoPadSameUpper.onnxtxt") != std::string::npos || - name.find("sliceInvalidAxes.onnxtxt") != std::string::npos) { + name.find("sliceInvalidAxes.onnxtxt") != std::string::npos || + name.find("sliceWithUnsupportedStep.onnxtxt") != std::string::npos) { // Ignore invalid ONNX files and graphs without nodes. llvm::outs() << "Ignore invalid input files: " << name << "\n"; continue; diff --git a/tests/unittests/OnnxImporterTest.cpp b/tests/unittests/OnnxImporterTest.cpp index 101240be69..fef2c9c039 100644 --- a/tests/unittests/OnnxImporterTest.cpp +++ b/tests/unittests/OnnxImporterTest.cpp @@ -2750,6 +2750,19 @@ TEST_F(OnnxImporterTest, importSliceInvalidAxes) { {2, 1, 2, 2} /* output */, true); } +TEST_F(OnnxImporterTest, importSliceWithStep) { + importSliceTest("sliceWithStep.onnxtxt", "data", {2, 3, 3, 3} /* input */, + {0, 1, 1, 1} /* starts */, /* ends: {2, 2, 3, 3} */ + {2, 1, 2, 2} /* output */); +} + +TEST_F(OnnxImporterTest, importSliceWithUnsupportedStep) { + importSliceTest("sliceWithUnsupportedStep.onnxtxt", "data", + {2, 3, 3, 3} /* input */, + {0, 1, 1, 1} /* starts */, /* ends: {2, 2, 3, 3} */ + {2, 1, 2, 2} /* output */, true); +} + static void importCast(llvm::StringRef fileName, llvm::StringRef inputName, llvm::ArrayRef inputShape, ElemKind outputKind) { ExecutionEngine EE{};