@@ -197,6 +197,19 @@ void inputsToProto(const Node *node, ONNX_NAMESPACE::NodeProto *proto) {
197
197
}
198
198
}
199
199
200
+ // / Write the output of the provided node, add SaveNode if necessary
201
+ void outputKindToProto (const Node *node, ONNX_NAMESPACE::NodeProto *proto) {
202
+ for (const auto &use : node->getUsers ()) {
203
+ const auto *user = use.getUser ();
204
+ if (user->getKind () == Kinded::Kind::SaveNodeKind) {
205
+ const SaveNode *SN = llvm::cast<SaveNode>(user);
206
+ proto->add_output (SN->getPlaceholder ()->getName ());
207
+ } else {
208
+ outputsToProto (user, proto);
209
+ }
210
+ }
211
+ }
212
+
200
213
// / Write the output of the provided type only of node outputs.
201
214
bool outputKindToProto (Kinded::Kind kind, const Node *node,
202
215
ONNX_NAMESPACE::NodeProto *proto) {
@@ -757,14 +770,8 @@ Error ONNXModelWriter::writeBatchedReduceMean(const BatchedReduceMeanNode *node,
757
770
proto->set_op_type (" ReduceMean" );
758
771
inputsToProto (node, proto);
759
772
760
- // Use the output of reshape node.
761
- if (outputKindToProto (Kinded::Kind::ReshapeNodeKind, node, proto)) {
762
- // Add dictionary entries.
763
- addValueAttribute (proto, " keepdims" , 1 );
764
- } else {
765
- addValueAttribute (proto, " keepdims" , 0 );
766
- outputsToProto (node, proto);
767
- }
773
+ addValueAttribute (proto, " keepdims" , 0 );
774
+ outputKindToProto (node, proto);
768
775
769
776
return Error::success ();
770
777
}
@@ -781,14 +788,8 @@ Error ONNXModelWriter::writeBatchedReduceAdd(const BatchedReduceAddNode *node,
781
788
proto->set_op_type (" ReduceSum" );
782
789
inputsToProto (node, proto);
783
790
784
- // Use the output of reshape node.
785
- if (outputKindToProto (Kinded::Kind::ReshapeNodeKind, node, proto)) {
786
- // Add dictionary entries.
787
- addValueAttribute (proto, " keepdims" , 1 );
788
- } else {
789
- addValueAttribute (proto, " keepdims" , 0 );
790
- outputsToProto (node, proto);
791
- }
791
+ addValueAttribute (proto, " keepdims" , 0 );
792
+ outputKindToProto (node, proto);
792
793
793
794
return Error::success ();
794
795
}
0 commit comments