|
6 | 6 | import json
|
7 | 7 | import subprocess
|
8 | 8 | import sys
|
| 9 | +import tempfile |
9 | 10 | import unittest
|
10 | 11 | from multiprocessing.connection import Listener
|
| 12 | +from pathlib import Path |
11 | 13 |
|
12 | 14 | import torch
|
13 | 15 | from executorch.backends.qualcomm.tests.utils import (
|
@@ -1102,6 +1104,19 @@ def test_qnn_backend_shared_buffer(self):
|
1102 | 1104 | expected_partitions=1,
|
1103 | 1105 | )
|
1104 | 1106 |
|
| 1107 | + def test_qnn_backend_online_prepare(self): |
| 1108 | + backend_options = generate_htp_compiler_spec(use_fp16=True) |
| 1109 | + TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( |
| 1110 | + soc_model=self.arch_table[TestQNN.model], |
| 1111 | + backend_options=backend_options, |
| 1112 | + debug=False, |
| 1113 | + saver=False, |
| 1114 | + online_prepare=True, |
| 1115 | + ) |
| 1116 | + module = SimpleModel() # noqa: F405 |
| 1117 | + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) |
| 1118 | + self.lower_module_and_test_output(module, sample_input) |
| 1119 | + |
1105 | 1120 |
|
1106 | 1121 | class TestQNNQuantizedUtils(TestQNN):
|
1107 | 1122 | # TODO: refactor to support different backends
|
@@ -1223,6 +1238,20 @@ def test_qnn_backend_shared_buffer(self):
|
1223 | 1238 | expected_partitions=1,
|
1224 | 1239 | )
|
1225 | 1240 |
|
| 1241 | + def test_qnn_backend_online_prepare(self): |
| 1242 | + backend_options = generate_htp_compiler_spec(use_fp16=False) |
| 1243 | + TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( |
| 1244 | + soc_model=self.arch_table[TestQNN.model], |
| 1245 | + backend_options=backend_options, |
| 1246 | + debug=False, |
| 1247 | + saver=False, |
| 1248 | + online_prepare=True, |
| 1249 | + ) |
| 1250 | + module = SimpleModel() # noqa: F405 |
| 1251 | + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) |
| 1252 | + module = self.get_qdq_module(module, sample_input) |
| 1253 | + self.lower_module_and_test_output(module, sample_input) |
| 1254 | + |
1226 | 1255 |
|
1227 | 1256 | class TestExampleOssScript(TestQNN):
|
1228 | 1257 | def required_envs(self, conditions=None) -> bool:
|
@@ -1640,6 +1669,29 @@ def test_ptq_mobilebert(self):
|
1640 | 1669 | for k, v in cpu.items():
|
1641 | 1670 | self.assertLessEqual(abs(v[0] - htp[k][0]), 5)
|
1642 | 1671 |
|
| 1672 | + def test_export_example(self): |
| 1673 | + if not self.required_envs([self.model_name]): |
| 1674 | + self.skipTest("missing required envs") |
| 1675 | + |
| 1676 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 1677 | + cmds = [ |
| 1678 | + "python", |
| 1679 | + "qualcomm/scripts/export_example.py", |
| 1680 | + "--model_name", |
| 1681 | + self.model_name, |
| 1682 | + "--output_folder", |
| 1683 | + "{}/".format(tmp_dir), |
| 1684 | + "--generate_etrecord", |
| 1685 | + ] |
| 1686 | + |
| 1687 | + p = subprocess.Popen( |
| 1688 | + cmds, stdout=subprocess.DEVNULL, cwd=f"{self.executorch_root}/examples" |
| 1689 | + ) |
| 1690 | + p.communicate() |
| 1691 | + self.assertTrue( |
| 1692 | + Path("{0}/{1}.pte".format(tmp_dir, self.model_name)).exists() |
| 1693 | + ) |
| 1694 | + |
1643 | 1695 |
|
1644 | 1696 | def setup_environment():
|
1645 | 1697 | parser = setup_common_args_and_variables()
|
@@ -1669,6 +1721,12 @@ def setup_environment():
|
1669 | 1721 | default="",
|
1670 | 1722 | type=str,
|
1671 | 1723 | )
|
| 1724 | + parser.add_argument( |
| 1725 | + "-n", |
| 1726 | + "--model_name", |
| 1727 | + help="Input the model to export", |
| 1728 | + type=str, |
| 1729 | + ) |
1672 | 1730 | parser.add_argument(
|
1673 | 1731 | "-o",
|
1674 | 1732 | "--online_prepare",
|
@@ -1697,6 +1755,7 @@ def setup_environment():
|
1697 | 1755 | TestQNN.artifact_dir = args.artifact_dir
|
1698 | 1756 | TestQNN.image_dataset = args.image_dataset
|
1699 | 1757 | TestQNN.pretrained_weight = args.pretrained_weight
|
| 1758 | + TestQNN.model_name = args.model_name |
1700 | 1759 | TestQNN.online_prepare = args.online_prepare
|
1701 | 1760 | TestQNN.enable_profile = args.enable_profile
|
1702 | 1761 | TestQNN.error_only = args.error_only
|
|
0 commit comments