@@ -69,30 +69,28 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
69
69
}
70
70
}
71
71
72
- bool IsInputSupported (const NodeArg& input , const std::string& parent_name, const logging::Logger& logger) {
73
- const auto & input_name = input .Name ();
74
- const auto * shape_proto = input .Shape ();
72
+ bool IsTensorShapeSupported (const NodeArg& node_arg , const std::string& parent_name, const logging::Logger& logger) {
73
+ const auto & node_arg_name = node_arg .Name ();
74
+ const auto * shape_proto = node_arg .Shape ();
75
75
// Optional tensors can be indicated by an empty name, just ignore it.
76
- if (input_name .empty ()) {
76
+ if (node_arg_name .empty ()) {
77
77
return true ;
78
78
}
79
- // We do not support input with no shape.
79
+ // We do not support input/output with no shape.
80
80
if (!shape_proto) {
81
- LOGS (logger, VERBOSE) << " Input [" << input_name << " ] of [" << parent_name
82
- << " ] has not shape" ;
81
+ LOGS (logger, VERBOSE) << " Node arg [" << node_arg_name << " ] of [" << parent_name << " ] has not shape" ;
83
82
return false ;
84
83
}
85
84
86
85
for (const auto & dim : shape_proto->dim ()) {
87
86
// WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape.
88
87
if (!dim.has_dim_value ()) {
89
88
LOGS (logger, VERBOSE) << " Dynamic shape is not supported, "
90
- << " use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: "
91
- << input_name;
89
+ << " use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
92
90
return false ;
93
91
}
94
92
if (dim.dim_value () == 0 ) {
95
- LOGS (logger, VERBOSE) << " The shape of [" << input_name << " ] has 0 dimension which is not supported by WebNN" ;
93
+ LOGS (logger, VERBOSE) << " The shape of [" << node_arg_name << " ] has 0 dimension which is not supported by WebNN" ;
96
94
return false ;
97
95
}
98
96
}
@@ -106,13 +104,6 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
106
104
const emscripten::val& wnn_limits,
107
105
const logging::Logger& logger) {
108
106
std::vector<std::vector<size_t >> supported_node_groups;
109
-
110
- for (const auto * input : graph_viewer.GetInputs ()) {
111
- if (!IsInputSupported (*input, " graph" , logger)) {
112
- return supported_node_groups;
113
- }
114
- }
115
-
116
107
std::vector<size_t > supported_node_group;
117
108
const auto & node_indices = graph_viewer.GetNodesInTopologicalOrder ();
118
109
0 commit comments