Skip to content

Commit 0d57af5

Browse files
Yinghai Lufacebook-github-bot
Yinghai Lu
authored andcommitted
Fix onnx exporter to remove unnecessary Reshape
Summary: We should not simply add the `Reshape` when we see a `Save`, which doesn't make sense. Reviewed By: tracelogfb Differential Revision: D18632551 fbshipit-source-id: 4f06a030d5453610d68710c61cae27e45be2f4cb
1 parent a210f97 commit 0d57af5

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

lib/Exporter/ONNXModelWriter.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,19 @@ void inputsToProto(const Node *node, ONNX_NAMESPACE::NodeProto *proto) {
197197
}
198198
}
199199

200+
/// Write the output of the provided node, add SaveNode if necessary
201+
void outputKindToProto(const Node *node, ONNX_NAMESPACE::NodeProto *proto) {
202+
for (const auto &use : node->getUsers()) {
203+
const auto *user = use.getUser();
204+
if (user->getKind() == Kinded::Kind::SaveNodeKind) {
205+
const SaveNode *SN = llvm::cast<SaveNode>(user);
206+
proto->add_output(SN->getPlaceholder()->getName());
207+
} else {
208+
outputsToProto(user, proto);
209+
}
210+
}
211+
}
212+
200213
/// Write the output of the provided type only of node outputs.
201214
bool outputKindToProto(Kinded::Kind kind, const Node *node,
202215
ONNX_NAMESPACE::NodeProto *proto) {
@@ -757,14 +770,8 @@ Error ONNXModelWriter::writeBatchedReduceMean(const BatchedReduceMeanNode *node,
757770
proto->set_op_type("ReduceMean");
758771
inputsToProto(node, proto);
759772

760-
// Use the output of reshape node.
761-
if (outputKindToProto(Kinded::Kind::ReshapeNodeKind, node, proto)) {
762-
// Add dictionary entries.
763-
addValueAttribute(proto, "keepdims", 1);
764-
} else {
765-
addValueAttribute(proto, "keepdims", 0);
766-
outputsToProto(node, proto);
767-
}
773+
addValueAttribute(proto, "keepdims", 0);
774+
outputKindToProto(node, proto);
768775

769776
return Error::success();
770777
}
@@ -781,14 +788,8 @@ Error ONNXModelWriter::writeBatchedReduceAdd(const BatchedReduceAddNode *node,
781788
proto->set_op_type("ReduceSum");
782789
inputsToProto(node, proto);
783790

784-
// Use the output of reshape node.
785-
if (outputKindToProto(Kinded::Kind::ReshapeNodeKind, node, proto)) {
786-
// Add dictionary entries.
787-
addValueAttribute(proto, "keepdims", 1);
788-
} else {
789-
addValueAttribute(proto, "keepdims", 0);
790-
outputsToProto(node, proto);
791-
}
791+
addValueAttribute(proto, "keepdims", 0);
792+
outputKindToProto(node, proto);
792793

793794
return Error::success();
794795
}

0 commit comments

Comments
 (0)