diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 8aae80e29da..64de5a7a1f4 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -30,7 +30,10 @@ def ram_cache(self) -> ModelCache: @abstractmethod def load_model_from_path( - self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None + self, + model_path: Path, + loader: Optional[Callable[[Path], AnyModel]] = None, + cache_key_extra: Optional[str] = None, ) -> LoadedModelWithoutConfig: """ Load the model file or directory located at the indicated Path. @@ -46,6 +49,8 @@ def load_model_from_path( Args: model_path: A pathlib.Path to a checkpoint-style models file loader: A Callable that expects a Path and returns a Dict[str, Tensor] + cache_key_extra: A string to append to the cache key. This is useful for + differentiating an instances of the same model with different parameters. Returns: A LoadedModel object. diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index dac45c70252..d12c6a6b528 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -76,9 +76,12 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo return loaded_model def load_model_from_path( - self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None + self, + model_path: Path, + loader: Optional[Callable[[Path], AnyModel]] = None, + cache_key_extra: Optional[str] = None, ) -> LoadedModelWithoutConfig: - cache_key = str(model_path) + cache_key = f"{model_path}:{cache_key_extra}" if cache_key_extra else str(model_path) try: return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache) except IndexError: diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 743b6208ead..51163a590a4 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -497,6 +497,7 @@ def load_local_model( self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None, + cache_key_extra: Optional[str] = None, ) -> LoadedModelWithoutConfig: """ Load the model file located at the indicated path @@ -509,18 +510,25 @@ def load_local_model( Args: path: A model Path loader: A Callable that expects a Path and returns a dict[str|int, Any] + cache_key_extra: A string to append to the cache key. This is useful for + differentiating an instances of the same model with different parameters. Returns: A LoadedModelWithoutConfig object. """ self._util.signal_progress(f"Loading model {model_path.name}") - return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader) + return self._services.model_manager.load.load_model_from_path( + model_path=model_path, + loader=loader, + cache_key_extra=cache_key_extra, + ) def load_remote_model( self, source: str | AnyHttpUrl, loader: Optional[Callable[[Path], AnyModel]] = None, + cache_key_extra: Optional[str] = None, ) -> LoadedModelWithoutConfig: """ Download, cache, and load the model file located at the indicated URL or repo_id. @@ -535,6 +543,8 @@ def load_remote_model( Args: source: A URL or huggingface repoid. loader: A Callable that expects a Path and returns a dict[str|int, Any] + cache_key_extra: A string to append to the cache key. This is useful for + differentiating an instances of the same model with different parameters. Returns: A LoadedModelWithoutConfig object. @@ -542,7 +552,11 @@ def load_remote_model( model_path = self._services.model_manager.install.download_and_cache_model(source=str(source)) self._util.signal_progress(f"Loading model {source}") - return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader) + return self._services.model_manager.load.load_model_from_path( + model_path=model_path, + loader=loader, + cache_key_extra=cache_key_extra, + ) def get_absolute_path(self, config_or_path: AnyModelConfig | Path | str) -> Path: """Gets the absolute path for a given model config or path.