Skip to content

Commit e9a3dd3

Browse files
shiyi9801 and ankitm3k
authored and committed
Reland "[WebNN] Fallback the node when its output doesn't have shape info" (microsoft#22685)
The previous PR was reverted because it caused the whole model to fall back whenever any output shape info was missing. This PR fixes the issue by removing the redundant fallbacks.
1 parent f71ccd5 commit e9a3dd3

File tree

5 files changed

+25
-21
lines changed

5 files changed

+25
-21
lines changed

onnxruntime/core/providers/webnn/builders/helper.cc

+8-17
Original file line numberDiff line numberDiff line change
@@ -69,30 +69,28 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
6969
}
7070
}
7171

72-
bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) {
73-
const auto& input_name = input.Name();
74-
const auto* shape_proto = input.Shape();
72+
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
73+
const auto& node_arg_name = node_arg.Name();
74+
const auto* shape_proto = node_arg.Shape();
7575
// Optional tensors can be indicated by an empty name, just ignore it.
76-
if (input_name.empty()) {
76+
if (node_arg_name.empty()) {
7777
return true;
7878
}
79-
// We do not support input with no shape.
79+
// We do not support input/output with no shape.
8080
if (!shape_proto) {
81-
LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name
82-
<< "] has not shape";
81+
LOGS(logger, VERBOSE) << "Node arg [" << node_arg_name << "] of [" << parent_name << "] has not shape";
8382
return false;
8483
}
8584

8685
for (const auto& dim : shape_proto->dim()) {
8786
// WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape.
8887
if (!dim.has_dim_value()) {
8988
LOGS(logger, VERBOSE) << "Dynamic shape is not supported, "
90-
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: "
91-
<< input_name;
89+
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
9290
return false;
9391
}
9492
if (dim.dim_value() == 0) {
95-
LOGS(logger, VERBOSE) << "The shape of [" << input_name << "] has 0 dimension which is not supported by WebNN";
93+
LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN";
9694
return false;
9795
}
9896
}
@@ -106,13 +104,6 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
106104
const emscripten::val& wnn_limits,
107105
const logging::Logger& logger) {
108106
std::vector<std::vector<size_t>> supported_node_groups;
109-
110-
for (const auto* input : graph_viewer.GetInputs()) {
111-
if (!IsInputSupported(*input, "graph", logger)) {
112-
return supported_node_groups;
113-
}
114-
}
115-
116107
std::vector<size_t> supported_node_group;
117108
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
118109

onnxruntime/core/providers/webnn/builders/helper.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
180180
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
181181
}
182182

183-
bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
183+
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
184184

185185
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
186186
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,

onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc

+14-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
3232
if (!HasSupportedInputs(node, wnn_limits, logger))
3333
return false;
3434

35-
if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
35+
if (!HasSupportedOutputs(node, wnn_limits, logger))
3636
return false;
3737

3838
if (!HasSupportedOpSet(node, logger))
@@ -45,7 +45,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
4545
const logging::Logger& logger) const {
4646
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
4747
for (const auto* input : node.InputDefs()) {
48-
if (!IsInputSupported(*input, node_name, logger)) {
48+
if (!IsTensorShapeSupported(*input, node_name, logger)) {
4949
return false;
5050
}
5151
}
@@ -66,6 +66,18 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
6666
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
6767
}
6868

69+
bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
70+
const logging::Logger& logger) const {
71+
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
72+
for (const auto* output : node.OutputDefs()) {
73+
if (!IsTensorShapeSupported(*output, node_name, logger)) {
74+
return false;
75+
}
76+
}
77+
78+
return HasSupportedOutputsImpl(node, wnn_limits, logger);
79+
}
80+
6981
bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
7082
const emscripten::val& wnn_limits,
7183
const logging::Logger& logger) const {

onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class BaseOpBuilder : public IOpBuilder {
5454
private:
5555
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
5656
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
57+
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
5758
};
5859

5960
} // namespace webnn

onnxruntime/core/providers/webnn/builders/model_builder.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
227227
if (!shape.empty()) {
228228
dims.reserve(shape.size());
229229
for (const auto& dim : shape) {
230-
// dim_param free dimensions should have already been excluded by IsInputSupported().
230+
// dim_param free dimensions should have already been excluded by IsTensorShapeSupported().
231231
assert(dim.has_dim_value());
232232
dims.push_back(SafeInt<int32_t>(dim.dim_value()));
233233
}

0 commit comments

Comments
 (0)