Describe the bug
The function signature of load_model_dict_into_meta changed in #10604, and device is no longer an accepted argument. However, IP-Adapter loading still passes device, as we can see below:
diffusers/src/diffusers/loaders/transformer_flux.py (in e3bc4aa):

load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)

diffusers/src/diffusers/loaders/transformer_sd3.py
Lines 78 to 80 in e3bc4aa

load_model_dict_into_meta(
    attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
)

diffusers/src/diffusers/loaders/transformer_sd3.py (in e3bc4aa):

load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype)

diffusers/src/diffusers/loaders/unet.py
Line 756 in e3bc4aa

load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)

diffusers/src/diffusers/loaders/unet.py
Line 849 in e3bc4aa

load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
Now that #10604 is merged, should we follow a similar approach to FromOriginalModelMixin, as below? That is, pass device_map as {"": param_device}?
diffusers/src/diffusers/loaders/single_file_model.py
Lines 369 to 387 in e3bc4aa

device_map = None
if is_accelerate_available():
    param_device = torch.device(device) if device else torch.device("cpu")
    empty_state_dict = model.state_dict()
    unexpected_keys = [
        param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
    ]
    device_map = {"": param_device}
    load_model_dict_into_meta(
        model,
        diffusers_format_checkpoint,
        dtype=torch_dtype,
        device_map=device_map,
        hf_quantizer=hf_quantizer,
        keep_in_fp32_modules=keep_in_fp32_modules,
        unexpected_keys=unexpected_keys,
    )
else:
    _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
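For illustration, here is a minimal, untested sketch of what the updated attention-processor call in transformer_flux.py could look like if we mirror the FromOriginalModelMixin snippet above (the {"": device} mapping is an assumption carried over from that snippet, not a change that exists in the codebase yet):

# Hypothetical sketch of the proposed fix, not an actual patch:
# replace device=device with a whole-module device_map, as in single_file_model.py.
device_map = {"": device}  # assumption: map every submodule to one device
load_model_dict_into_meta(
    attn_procs[name],
    value_dict,
    device_map=device_map,
    dtype=dtype,
)

The same one-line change would apply to the other call sites listed above, using self.device where that is what was previously passed.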
Happy to contribute a PR updating IP-Adapter loading :)
Reproduction
Loading any IP-Adapter with low_cpu_mem_usage=True (the default when torch >= 1.9.0) triggers the error, for example:
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe: FluxPipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)

# Fails here: load_ip_adapter defaults to low_cpu_mem_usage=True, which
# reaches the removed `device` argument of load_model_dict_into_meta.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(0.6)
pipe.enable_sequential_cpu_offload()

ip_adapter_image = load_image("https://huggingface.co/guiyrt/sample-images/resolve/main/astronaut.jpg")
image = pipe(
    width=1024,
    height=1024,
    prompt="A vintage picture of an astronaut in a starry sky",
    generator=torch.manual_seed(42),
    ip_adapter_image=ip_adapter_image,
).images[0]
image.save("result.jpg")
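Until this is fixed, an untested workaround (assuming load_ip_adapter forwards the flag, as the traceback below suggests) is to opt out of the meta-device path:

# Hypothetical workaround: low_cpu_mem_usage=False skips the
# load_model_dict_into_meta call that raises the TypeError.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
    low_cpu_mem_usage=False,
)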
Logs
Traceback (most recent call last):
File "/home/guiyrt/diffusers/run.py", line 10, in <module>
pipe.load_ip_adapter(
File "/home/guiyrt/anaconda3/envs/diffusers/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
return fn(*args, **kwargs)
File "/home/guiyrt/diffusers/src/diffusers/loaders/ip_adapter.py", line 553, in load_ip_adapter
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
File "/home/guiyrt/diffusers/src/diffusers/loaders/transformer_flux.py", line 168, in _load_ip_adapter_weights
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
File "/home/guiyrt/diffusers/src/diffusers/loaders/transformer_flux.py", line 156, in _convert_ip_adapter_attn_to_diffusers
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
TypeError: load_model_dict_into_meta() got an unexpected keyword argument 'device'
System Info
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.16
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): 0.10.2 (cpu)
- Jax version: 0.5.0
- JaxLib version: 0.5.0
- Huggingface_hub version: 0.28.1
- Transformers version: 4.48.3
- Accelerate version: 1.3.0
- PEFT version: 0.14.0
- Bitsandbytes version: not installed
- Safetensors version: 0.5.2
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4070 Ti SUPER, 16376 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?: