Skip to content

Commit d9c6f80

Browse files
authored
Arm backend: Add model name to -llama_inputs (#10775)
This way other Llama variants than stories110m can be run.
1 parent 5e8295e commit d9c6f80

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

backends/arm/test/models/test_llama.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,35 @@
3333
class TestLlama(unittest.TestCase):
3434
"""
3535
Test class of Llama models. Type of Llama model depends on command line parameters:
36-
--llama_inputs <path to .pt file> <path to json file>
37-
Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json
36+
--llama_inputs <path to .pt file> <path to json file> <name of model variant>
37+
Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json stories110m
38+
For more examples and info see examples/models/llama/README.md.
3839
"""
3940

4041
def prepare_model(self):
4142

4243
checkpoint = None
4344
params_file = None
45+
usage = "To run use --llama_inputs <.pt/.pth> <.json> <name>"
46+
4447
if conftest.is_option_enabled("llama_inputs"):
4548
param_list = conftest.get_option("llama_inputs")
46-
assert (
47-
isinstance(param_list, list) and len(param_list) == 2
48-
), "invalid number of inputs for --llama_inputs"
49+
50+
if not isinstance(param_list, list) or len(param_list) != 3:
51+
raise RuntimeError(
52+
f"Invalid number of inputs for --llama_inputs. {usage}"
53+
)
54+
if not all(isinstance(param, str) for param in param_list):
55+
raise RuntimeError(
56+
f"All --llama_inputs are expected to be strings. {usage}"
57+
)
58+
4959
checkpoint = param_list[0]
5060
params_file = param_list[1]
51-
assert isinstance(checkpoint, str) and isinstance(
52-
params_file, str
53-
), "invalid input for --llama_inputs"
61+
model_name = param_list[2]
5462
else:
5563
logger.warning(
56-
"Skipping Llama test because of lack of input. To run use --llama_inputs <.pt> <.json>"
64+
"Skipping Llama tests because of missing --llama_inputs. {usage}"
5765
)
5866
return None, None, None
5967

@@ -71,7 +79,7 @@ def prepare_model(self):
7179
"-p",
7280
params_file,
7381
"--model",
74-
"stories110m",
82+
model_name,
7583
]
7684
parser = build_args_parser()
7785
args = parser.parse_args(args)
@@ -122,6 +130,7 @@ def test_llama_tosa_BI(self):
122130
.quantize()
123131
.export()
124132
.to_edge_transform_and_lower()
133+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
125134
.to_executorch()
126135
.run_method_and_compare_outputs(
127136
inputs=llama_inputs,

0 commit comments

Comments
 (0)