diff --git a/backends/xnnpack/runtime/XNNExecutor.cpp b/backends/xnnpack/runtime/XNNExecutor.cpp
index 0311489d9df..1680b635993 100644
--- a/backends/xnnpack/runtime/XNNExecutor.cpp
+++ b/backends/xnnpack/runtime/XNNExecutor.cpp
@@ -86,6 +86,11 @@ __ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) {
     // Reshape runtime inputs
     if (i < input_ids_.size()) {
       size_t num_dims = tensor->dim();
+      ET_CHECK_OR_RETURN_ERROR(
+          is_contiguous_dim_order(tensor->dim_order().data(), tensor->dim()),
+          Internal,
+          "Expecting default dim_order but got a non-default dim_order tensor for external input %u",
+          i);
       size_t dims[XNN_MAX_TENSOR_DIMS];
       ET_CHECK_OR_RETURN_ERROR(
           num_dims <= XNN_MAX_TENSOR_DIMS,
diff --git a/backends/xnnpack/xnnpack_preprocess.py b/backends/xnnpack/xnnpack_preprocess.py
index 5d4c05a5350..b7ee440c289 100644
--- a/backends/xnnpack/xnnpack_preprocess.py
+++ b/backends/xnnpack/xnnpack_preprocess.py
@@ -78,6 +78,22 @@ def generate_node_to_external_map(
     return node_to_external_map
 
 
+def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None:
+    for node in edge_graph_module.graph.nodes:
+        if node.op != "placeholder":
+            continue
+
+        # We expect the default dim order for all tensor-like inputs, i.e. inputs, buffers, and params.
+        t = node.meta.get("val", None)
+        if t is not None and getattr(t, "dim_order", None) is not None:
+            default_dim_order = tuple(range(t.dim()))
+            if t.dim_order() != default_dim_order:
+                raise RuntimeError(
+                    "XNNPACK backend only supports contiguous memory format for inputs. "
+                    f"Expecting dim_order: {default_dim_order}, but got {node.meta['val'].dim_order()} for placeholder node {node}."
+                )
+
+
 @final
 class XnnpackBackend(BackendDetails):
     @staticmethod
@@ -126,6 +142,9 @@ def preprocess(
 
         node_to_external_map = generate_node_to_external_map(ep, graph_module)
 
+        # Make sure all inputs use the contiguous (NCHW / default) dim order
+        assert_default_dim_order(graph_module)
+
         # TODO retrace the graph module to lift the new params may have
         # been added to the graph in passes