diff --git a/src/diffusers/pipelines/onnx_utils.py b/src/diffusers/pipelines/onnx_utils.py
index 0e12340f6895..74e9f0b97800 100644
--- a/src/diffusers/pipelines/onnx_utils.py
+++ b/src/diffusers/pipelines/onnx_utils.py
@@ -75,6 +75,11 @@ def load_model(path: Union[str, Path], provider=None, sess_options=None, provide
         logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
         provider = "CPUExecutionProvider"

+        if provider_options is None:
+            provider_options = [{}]
+        elif not isinstance(provider_options, list):
+            provider_options = [provider_options]
+
         return ort.InferenceSession(
             path, providers=[provider], sess_options=sess_options, provider_options=provider_options
         )
@@ -174,7 +179,10 @@ def _from_pretrained(
         # load model from local directory
         if os.path.isdir(model_id):
             model = OnnxRuntimeModel.load_model(
-                Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
+                Path(model_id, model_file_name).as_posix(),
+                provider=provider,
+                sess_options=sess_options,
+                provider_options=kwargs.pop("provider_options", None),
             )
             kwargs["model_save_dir"] = Path(model_id)
         # load model from hub
@@ -190,7 +198,12 @@
             )
             kwargs["model_save_dir"] = Path(model_cache_path).parent
             kwargs["latest_model_name"] = Path(model_cache_path).name
-            model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
+            model = OnnxRuntimeModel.load_model(
+                model_cache_path,
+                provider=provider,
+                sess_options=sess_options,
+                provider_options=kwargs.pop("provider_options", None),
+            )
         return cls(model=model, **kwargs)

     @classmethod
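
For reviewers, a quick sketch of the usage this patch enables. The model path and option values below are placeholders; `device_id` is a documented option of onnxruntime's `CUDAExecutionProvider`, and running this example assumes `onnxruntime-gpu` is installed.

```python
from diffusers.pipelines.onnx_utils import OnnxRuntimeModel

# Placeholder path: a local directory containing a model.onnx file.
model = OnnxRuntimeModel.from_pretrained(
    "./my_onnx_model",
    provider="CUDAExecutionProvider",
    # With this patch a single options dict can be passed; load_model wraps
    # it in a list to match onnxruntime's one-dict-per-provider convention.
    provider_options={"device_id": 0},
)
```

Since `providers=[provider]` always has length one, `load_model` normalizes `provider_options` to a one-element list (an empty options dict when nothing is given), keeping the two arguments the same length as onnxruntime's `InferenceSession` requires; `_from_pretrained` pops the option out of `kwargs` with a `None` default so loading without `provider_options` keeps working.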