33
33
class TestLlama (unittest .TestCase ):
34
34
"""
35
35
Test class of Llama models. Type of Llama model depends on command line parameters:
36
- --llama_inputs <path to .pt file> <path to json file>
37
- Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json
36
+ --llama_inputs <path to .pt file> <path to json file> <name of model variant>
37
+ Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json stories110m
38
+ For more examples and info see examples/models/llama/README.md.
38
39
"""
39
40
40
41
def prepare_model (self ):
41
42
42
43
checkpoint = None
43
44
params_file = None
45
+ usage = "To run use --llama_inputs <.pt/.pth> <.json> <name>"
46
+
44
47
if conftest .is_option_enabled ("llama_inputs" ):
45
48
param_list = conftest .get_option ("llama_inputs" )
46
- assert (
47
- isinstance (param_list , list ) and len (param_list ) == 2
48
- ), "invalid number of inputs for --llama_inputs"
49
+
50
+ if not isinstance (param_list , list ) or len (param_list ) != 3 :
51
+ raise RuntimeError (
52
+ f"Invalid number of inputs for --llama_inputs. { usage } "
53
+ )
54
+ if not all (isinstance (param , str ) for param in param_list ):
55
+ raise RuntimeError (
56
+ f"All --llama_inputs are expected to be strings. { usage } "
57
+ )
58
+
49
59
checkpoint = param_list [0 ]
50
60
params_file = param_list [1 ]
51
- assert isinstance (checkpoint , str ) and isinstance (
52
- params_file , str
53
- ), "invalid input for --llama_inputs"
61
+ model_name = param_list [2 ]
54
62
else :
55
63
logger .warning (
56
- "Skipping Llama test because of lack of input. To run use --llama_inputs <.pt> <.json> "
64
+ f"Skipping Llama tests because of missing --llama_inputs. {usage} "
57
65
)
58
66
return None , None , None
59
67
@@ -71,7 +79,7 @@ def prepare_model(self):
71
79
"-p" ,
72
80
params_file ,
73
81
"--model" ,
74
- "stories110m" ,
82
+ model_name ,
75
83
]
76
84
parser = build_args_parser ()
77
85
args = parser .parse_args (args )
@@ -122,6 +130,7 @@ def test_llama_tosa_BI(self):
122
130
.quantize ()
123
131
.export ()
124
132
.to_edge_transform_and_lower ()
133
+ .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
125
134
.to_executorch ()
126
135
.run_method_and_compare_outputs (
127
136
inputs = llama_inputs ,
0 commit comments