Skip to content

Commit 212e4ae

Browse files
gramalingam and askhade
authored
Simplify definition of context-free function body definitions (onnx#3717)
* Update function defs to use parser Signed-off-by: Ganesan Ramalingam <[email protected]> * Extend parser to support attribute references Signed-off-by: Ganesan Ramalingam <[email protected]> * Update function body for MVN Signed-off-by: Ganesan Ramalingam <[email protected]> * Allow untyped inputs and outputs in graphs Signed-off-by: Ganesan Ramalingam <[email protected]> * Update range function body Signed-off-by: Ganesan Ramalingam <[email protected]> * Delete old function body builder for range Signed-off-by: Ganesan Ramalingam <[email protected]> * Minor formatting for range function body Signed-off-by: Ganesan Ramalingam <[email protected]> * Add support for comments in parser Signed-off-by: Ganesan Ramalingam <[email protected]> * Add support for exponent notation in float literals Signed-off-by: Ganesan Ramalingam <[email protected]> * Address PR feedback (add check for parse status) Signed-off-by: Ganesan Ramalingam <[email protected]> * Remove attr ref Signed-off-by: Ganesan Ramalingam <[email protected]> * Add back attr type in ref attr Signed-off-by: Ganesan Ramalingam <[email protected]> Co-authored-by: Ashwini Khade <[email protected]>
1 parent 1cc5e12 commit 212e4ae

File tree

10 files changed

+201
-176
lines changed

10 files changed

+201
-176
lines changed

onnx/defs/generator/defs.cc

Lines changed: 18 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -715,97 +715,6 @@ inline int64_t compute_output_dim_for_range(
715715
return n;
716716
}
717717

718-
const std::vector<NodeProto> build_nodes_range_op() {
719-
// body for 'Loop node'
720-
GraphProto loop_sub_graph;
721-
loop_sub_graph.set_name("loop_body_attribute");
722-
723-
// 'Loop' node 'body' attribute's graph inputs
724-
// input 0 - number of iteration
725-
auto* input_value_info_proto_0 = loop_sub_graph.add_input();
726-
input_value_info_proto_0->set_name("i");
727-
// add an empty shape
728-
auto* input_0_type_proto_tensor =
729-
input_value_info_proto_0->mutable_type()->mutable_tensor_type();
730-
input_0_type_proto_tensor->mutable_shape()->Clear();
731-
// always INT64 type
732-
input_0_type_proto_tensor->set_elem_type(TensorProto_DataType_INT64);
733-
734-
// input 1 - condition
735-
auto* input_value_info_proto_1 = loop_sub_graph.add_input();
736-
input_value_info_proto_1->set_name("cond");
737-
// add an empty shape
738-
auto* input_1_type_proto_tensor =
739-
input_value_info_proto_1->mutable_type()->mutable_tensor_type();
740-
input_1_type_proto_tensor->mutable_shape()->Clear();
741-
// always BOOL type
742-
input_1_type_proto_tensor->set_elem_type(TensorProto_DataType_BOOL);
743-
744-
// input 2 - loop carried dependency
745-
auto* input_value_info_proto_2 = loop_sub_graph.add_input();
746-
input_value_info_proto_2->set_name("prev");
747-
748-
// 'Loop' node 'body' attribute's graph nodes
749-
auto* node_proto_0 = loop_sub_graph.add_node();
750-
node_proto_0->set_op_type("Identity");
751-
node_proto_0->add_input();
752-
node_proto_0->set_input(0, "cond");
753-
node_proto_0->add_output();
754-
node_proto_0->set_output(0, "cond_out");
755-
756-
auto* node_proto_1 = loop_sub_graph.add_node();
757-
node_proto_1->set_op_type("Add");
758-
node_proto_1->add_input();
759-
node_proto_1->set_input(0, "prev");
760-
node_proto_1->add_input();
761-
node_proto_1->set_input(1, "delta");
762-
node_proto_1->add_output();
763-
node_proto_1->set_output(0, "current");
764-
765-
auto* node_proto_2 = loop_sub_graph.add_node();
766-
node_proto_2->set_op_type("Identity");
767-
node_proto_2->add_input();
768-
node_proto_2->set_input(0, "prev");
769-
node_proto_2->add_output();
770-
node_proto_2->set_output(0, "range");
771-
772-
// 'Loop' node 'body' attribute's graph inputs
773-
auto* output_value_info_proto_0 = loop_sub_graph.add_output();
774-
output_value_info_proto_0->set_name("cond_out");
775-
776-
auto* output_value_info_proto_1 = loop_sub_graph.add_output();
777-
output_value_info_proto_1->set_name("current");
778-
779-
auto* output_value_info_proto_2 = loop_sub_graph.add_output();
780-
output_value_info_proto_2->set_name("range");
781-
782-
return FunctionBodyHelper::BuildNodes(
783-
{// nodes: {outputs, op, inputs, attributes}
784-
{{"sub_result"}, "Sub", {"limit", "start"}},
785-
{{"sub_result_casted"},
786-
"Cast",
787-
{"sub_result"},
788-
{{"to", static_cast<int64_t>(1)}}},
789-
{{"delta_casted"}, "Cast", {"delta"}, {{"to", static_cast<int64_t>(1)}}},
790-
{{"div_result"}, "Div", {"sub_result_casted", "delta_casted"}},
791-
{{"ceil_result"}, "Ceil", {"div_result"}},
792-
// we want max(0, ceil_cast_int) as negative values would evaluate to
793-
// bool true in next step
794-
{{"ceil_result_relu"}, "Relu", {"ceil_result"}},
795-
{{"ceil_result_relu_int"},
796-
"Cast",
797-
{"ceil_result_relu"},
798-
{{"to", static_cast<int64_t>(7)}}},
799-
{{"ceil_result_relu_bool"},
800-
"Cast",
801-
{"ceil_result_relu"},
802-
{{"to", static_cast<int64_t>(9)}}},
803-
{{"variadic_output", "output"},
804-
"Loop",
805-
{"ceil_result_relu_int", "ceil_result_relu_bool", "start"},
806-
{MakeAttribute("body", loop_sub_graph)}}});
807-
}
808-
809718
ONNX_OPERATOR_SET_SCHEMA(
810719
Range,
811720
11,
@@ -835,7 +744,24 @@ ONNX_OPERATOR_SET_SCHEMA(
835744
"tensor(int32)",
836745
"tensor(int64)"},
837746
"Constrain input types to common numeric type tensors.")
838-
.FunctionBody(build_nodes_range_op())
747+
.FunctionBody(R"ONNX(
748+
{
749+
sub_result = Sub (limit, start)
750+
sub_result_casted = Cast <to = 1> (sub_result)
751+
delta_casted = Cast <to = 1> (delta)
752+
div_result = Div (sub_result_casted, delta_casted)
753+
ceil_result = Ceil (div_result)
754+
ceil_result_relu = Relu (ceil_result)
755+
ceil_result_relu_int = Cast <to = 7> (ceil_result_relu)
756+
ceil_result_relu_bool = Cast <to = 9> (ceil_result_relu)
757+
variadic_output, output = Loop (ceil_result_relu_int, ceil_result_relu_bool, start)
758+
<body = loop_body_attribute (int64 i, bool cond, prev) => (cond_out, current, range) {
759+
cond_out = Identity (cond)
760+
current = Add (prev, delta)
761+
range = Identity (prev)
762+
}>
763+
}
764+
)ONNX")
839765
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
840766
// Type inference
841767
propagateElemTypeFromInputToOutput(ctx, 0, 0);

onnx/defs/logical/defs.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,13 @@ ONNX_OPERATOR_SET_SCHEMA(
284284
{"tensor(bool)"},
285285
"Constrains output to boolean tensor.")
286286
.TypeAndShapeInferenceFunction(InferenceFunction())
287-
.FunctionBody(FunctionBodyHelper::BuildNodes(
288-
{// nodes: {outputs, op, inputs, attributes}
289-
{{"O1"}, "Less", {"A", "B"}},
290-
{{"O2"}, "Equal", {"A", "B"}},
291-
{{"C"}, "Or", {"O1", "O2"}}})));
287+
.FunctionBody(R"ONNX(
288+
{
289+
O1 = Less (A, B)
290+
O2 = Equal (A, B)
291+
C = Or (O1, O2)
292+
}
293+
)ONNX"));
292294

293295
ONNX_OPERATOR_SET_SCHEMA(
294296
GreaterOrEqual,

onnx/defs/math/defs.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,13 +1117,12 @@ ONNX_OPERATOR_SET_SCHEMA(
11171117
{"tensor(float16)", "tensor(float)", "tensor(double)"},
11181118
"Constrain input and output types to float tensors.")
11191119
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
1120-
.FunctionBody(FunctionBodyHelper::BuildNodes({
1121-
// nodes: {outputs, op, inputs, attributes}
1122-
{{"HS_X"},
1123-
"HardSigmoid",
1124-
{"X"},
1125-
{MakeAttribute("alpha", 1.0f/6.0f), MakeAttribute("beta", 0.5f)}},
1126-
{{"Y"}, "Mul", {"X", "HS_X"}}})));
1120+
.FunctionBody(R"ONNX(
1121+
{
1122+
HS_X = HardSigmoid<alpha = 0.16666667163372, beta = 0.5>(X)
1123+
Y = Mul (X, HS_X)
1124+
}
1125+
)ONNX"));
11271126

11281127
// Generate opschema for element-wise ops. Leaves type constraint "T"
11291128
// unspecified.

onnx/defs/nn/defs.cc

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2466,24 +2466,21 @@ ONNX_OPERATOR_SET_SCHEMA(
24662466
"tensor(double)",
24672467
"tensor(bfloat16)"},
24682468
"Constrain input and output types to all numeric tensors.")
2469-
.FunctionBody(FunctionBodyHelper::BuildNodes(
2470-
{// nodes: {outputs, op, inputs, attributes}
2471-
FunctionBodyHelper::Const<float>("Exponent", 2.0f),
2472-
FunctionBodyHelper::Const<float>("Epsilon", float(1e-9)),
2473-
{{"X_RM"},
2474-
"ReduceMean",
2475-
{"X"},
2476-
{MakeRefAttribute("axes", AttributeProto::INTS)}},
2477-
{{"EX_squared"}, "Pow", {"X_RM", "Exponent"}},
2478-
{{"X_squared"}, "Pow", {"X", "Exponent"}},
2479-
{{"E_Xsquared"},
2480-
"ReduceMean",
2481-
{"X_squared"},
2482-
{MakeRefAttribute("axes", AttributeProto::INTS)}},
2483-
{{"Variance"}, "Sub", {"E_Xsquared", "EX_squared"}},
2484-
{{"STD"}, "Sqrt", {"Variance"}},
2485-
{{"X_variance"}, "Sub", {"X", "X_RM"}},
2486-
{{"Processed_STD"}, "Add", {"STD", "Epsilon"}},
2487-
{{"Y"}, "Div", {"X_variance", "Processed_STD"}}})));
2469+
.FunctionBody(R"ONNX(
2470+
{
2471+
Exponent = Constant <value = float {2.0}>()
2472+
Epsilon = Constant <value = float {1e-9}>()
2473+
X_RM = ReduceMean <axes : ints = @axes> (X)
2474+
EX_squared = Pow (X_RM, Exponent)
2475+
X_squared = Pow (X, Exponent)
2476+
E_Xsquared = ReduceMean <axes : ints = @axes> (X_squared)
2477+
Variance = Sub (E_Xsquared, EX_squared)
2478+
STD = Sqrt (Variance)
2479+
X_variance = Sub (X, X_RM)
2480+
Processed_STD = Add (STD, Epsilon)
2481+
Y = Div (X_variance, Processed_STD)
2482+
}
2483+
)ONNX"
2484+
));
24882485

24892486
} // namespace ONNX_NAMESPACE

onnx/defs/parser.cc

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ Status OnnxParser::Parse(TypeProto& typeProto) {
9292
}
9393

9494
Status OnnxParser::Parse(ValueInfoProto& valueinfo) {
95-
PARSE(*valueinfo.mutable_type());
95+
if (NextIsType())
96+
PARSE(*valueinfo.mutable_type());
9697
std::string name;
9798
CHECK_PARSER_STATUS(ParseIdentifier(name));
9899
valueinfo.set_name(name);
@@ -239,19 +240,27 @@ Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypePr
239240
return Status::OK();
240241
}
241242

243+
bool OnnxParser::NextIsType() {
244+
std::string id("");
245+
(void)PeekIdentifier(id);
246+
return (PrimitiveTypeNameMap::IsTypeName(id));
247+
}
248+
242249
Status OnnxParser::ParseSingleAttributeValue(AttributeProto& attr) {
243250
// Parse a single-value
244251
auto next = NextChar();
245252
if (isalpha(next) || next == '_') {
246-
std::string id("");
247-
(void)PeekIdentifier(id);
248-
if (PrimitiveTypeNameMap::IsTypeName(id)) {
253+
if (NextIsType()) {
249254
attr.set_type(AttributeProto_AttributeType_TENSOR);
250255
Parse(*attr.mutable_t());
251256
} else {
252257
attr.set_type(AttributeProto_AttributeType_GRAPH);
253258
Parse(*attr.mutable_g());
254259
}
260+
} else if (Matches('@')) {
261+
std::string name;
262+
CHECK_PARSER_STATUS(ParseIdentifier(name));
263+
attr.set_ref_attr_name(name);
255264
} else {
256265
Literal literal;
257266
PARSE_TOKEN(literal);
@@ -279,6 +288,15 @@ Status OnnxParser::Parse(AttributeProto& attr) {
279288
std::string name;
280289
CHECK_PARSER_STATUS(ParseIdentifier(name));
281290
attr.set_name(name);
291+
if (Matches(':')) {
292+
CHECK_PARSER_STATUS(ParseIdentifier(name));
293+
int attrtype = AttributeTypeNameMap::Lookup(name);
294+
if (attrtype != 0) {
295+
attr.set_type(static_cast<AttributeProto_AttributeType>(attrtype));
296+
} else {
297+
return ParseError("Unexpected attribute type.");
298+
}
299+
}
282300
MATCH('=');
283301
if (NextChar() == '[') {
284302
// Parse a list of values. For now, empty list is not allowed, as we need to

0 commit comments

Comments
 (0)