diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 618c74e8706..116d58a84f9 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -674,6 +674,12 @@ def _validate_args(args): "If you need this feature, please file an issue." ) + if args.xnnpack: + if args.dtype_override not in ["fp32", "fp16"]: + raise ValueError( + f"XNNPACK supports either fp32 or fp16 dtypes only for now. Given {args.dtype_override}." + ) + def _export_llama(args) -> LLMEdgeManager: # noqa: C901 _validate_args(args)