Labels
backend tester: This bug was found by the backend test suite.
module: qnn: Issues related to Qualcomm's QNN delegate and code under backends/qualcomm/
partner: qualcomm: For backend delegation, kernels, demo, etc. from the 3rd-party partner, Qualcomm
Description
🐛 Describe the bug
The conformer model from torchaudio does not successfully lower on the QNN backend. It fails with the error "aten_where_self" generated: could not create op. The same error message shows up in a few other models, so resolving it would likely help those as well.
Output excerpt:
INFO:executorch.backends.qualcomm.qnn_preprocess:Visiting: aten_permute_copy_default_49, aten.permute_copy.default
[ERROR] [Qnn ExecuTorch]: graph_prepare.cc:219::ERROR:could not create op: q::Select.exe.tcm
[ERROR] [Qnn ExecuTorch]: graph_prepare.cc:221::ERROR:Op creation failure, op id=0x32b00000001b (q::Select.exe.tcm) total_inputs=3
[ERROR] [Qnn ExecuTorch]: graph_prepare.cc:207: Input 0: id=[0x2c43f00000002] op=[*[email protected]*6.] output0=[14ConcreteTensorIN5Tdefs14QuantUint8_TCMEE]
[ERROR] [Qnn ExecuTorch]: graph_prepare.cc:207: Input 1: id=[0x4fd9500000019] op=[[email protected].] output0=[14ConcreteTensorIN5Tdefs14PlainFloat_TCMEE]
[ERROR] [Qnn ExecuTorch]: graph_prepare.cc:207: Input 2: id=[0x4fd960000001a] op=[[email protected].] output0=[14ConcreteTensorIN5Tdefs14QuantUint8_TCMEE]
[ERROR] [Qnn ExecuTorch]: graph_prepare.cc:1573::ERROR:Op 0x32b00000001b preparation failed with err:-1
[ERROR] [Qnn ExecuTorch]: <E> "aten_where_self" generated: could not create op
[ERROR] [Qnn ExecuTorch]: <E> RouterX86 graph prepare failed 12
[ERROR] [Qnn ExecuTorch]: <E> Failed to finalize graph (id: 1) with err 1002
[ERROR] [Qnn ExecuTorch]: Failed to finalize Qnn Graph with error: 1002
[ERROR] [Qnn ExecuTorch]: Fail to compile QNN graph
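For reference, `aten_where_self` is the exported-graph name for `torch.where`; in this model it is presumably introduced by masking inside the attention layers, though that is an assumption. A minimal sketch of a module containing the op:

```python
import torch


class WhereModule(torch.nn.Module):
    # Tiny module whose forward contains torch.where; when exported,
    # this shows up in the graph as aten.where.self -- the op named
    # in the error above.
    def forward(
        self, mask: torch.Tensor, a: torch.Tensor, b: torch.Tensor
    ) -> torch.Tensor:
        return torch.where(mask, a, b)


m = WhereModule()
mask = torch.tensor([True, False, True])
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([10.0, 20.0, 30.0])
out = m(mask, a, b)  # selects from a where mask is True, else from b
```

A stripped-down module like this could help isolate whether the failure is specific to how the where op is quantized or laid out in this graph, or whether any `aten.where.self` fails on this flow.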
This can be reproduced with the following test case command or standalone script.
python -m executorch.backends.test.suite.runner models --flow qnn --filter "test_conformer_qnn_float32$"
Standalone repro:
from typing import Tuple

import executorch
import torch
import torchaudio
from executorch.backends.qualcomm.utils.utils import (
    generate_qnn_executorch_compiler_spec,
    generate_htp_compiler_spec,
    QcomChipset,
    to_edge_transform_and_lower_to_qnn,
)


class PatchedConformer(torch.nn.Module):
    """
    A lightly modified version of the top-level Conformer module, such that it can be exported.
    Instead of taking lengths and computing the padding mask, it takes the padding mask directly.
    See https://github.com/pytorch/audio/blob/main/src/torchaudio/models/conformer.py#L215
    """

    def __init__(self, conformer):
        super().__init__()
        self.conformer = conformer

    def forward(
        self, input: torch.Tensor, encoder_padding_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = input.transpose(0, 1)
        for layer in self.conformer.conformer_layers:
            x = layer(x, encoder_padding_mask)
        return x.transpose(0, 1)


inner_model = torchaudio.models.Conformer(
    input_dim=80,
    num_heads=4,
    ffn_dim=128,
    num_layers=4,
    depthwise_conv_kernel_size=31,
)
model = PatchedConformer(inner_model).eval()

lengths = torch.randint(1, 400, (10,))
encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask(lengths)
inputs = (
    torch.rand(10, int(lengths.max()), 80),
    encoder_padding_mask,
)

ep = torch.export.export(model, inputs)

backend_options = generate_htp_compiler_spec(
    use_fp16=True,
)
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,
    backend_options=backend_options,
)
model = to_edge_transform_and_lower_to_qnn(
    model,
    inputs,
    compile_spec,
).to_executorch()
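For context on the inputs above: `_lengths_to_padding_mask` is a private torchaudio helper. To the best of my understanding it returns a boolean mask where True marks padded positions at or beyond each sequence's length. A sketch of an equivalent computation (the function name below is my own, not torchaudio's):

```python
import torch


def lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    # For each sequence, mark positions >= its length as padding (True).
    # Shapes: lengths is (batch,), result is (batch, max_len).
    max_len = int(lengths.max())
    positions = torch.arange(max_len).unsqueeze(0)  # (1, max_len)
    return positions >= lengths.unsqueeze(1)        # broadcast to (batch, max_len)


mask = lengths_to_padding_mask(torch.tensor([2, 4, 1]))
# mask[0] -> [False, False, True, True]
```

This mask is what `PatchedConformer` takes directly in place of the lengths, which keeps the mask computation out of the exported graph.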
Note that running the backend test case requires ExecuTorch's Python bindings to be built with the QNN backend. An example build command is below; the library paths will still need to be set up properly as described in the ET QNN docs.
CMAKE_ARGS="-DEXECUTORCH_BUILD_QNN=ON -DQNN_SDK_ROOT=$QNN_SDK_ROOT" ./install_executorch.sh --editable
Versions
Commit fbda3a9, x86-64 simulator, WSL