In #11806, we introduced automatic extraction of exclude_modules and started passing it to the following function when preparing the LoraConfig kwargs:

diffusers/src/diffusers/utils/peft_utils.py, line 153 in 425a715: def get_peft_kwargs()

The problem is that if target_modules contains a module name in which an entry from exclude_modules appears as a substring, that module is treated as excluded and becomes invalid for peft. For example, if target_modules has single_transformer_blocks.0.proj_out and exclude_modules has proj_out, we will see a problem. See the warning below.
Loading adapter weights from state_dict led to unexpected keys found in the model: single_transformer_blocks.0.proj_out.lora_A.default_0.weight, single_transformer_blocks.0.proj_out.lora_B.default_0.weight, single_transformer_blocks.1.proj_out.lora_A.default_0.weight, single_transformer_blocks.1.proj_out.lora_B.default_0.weight, single_transformer_blocks.2.proj_out.lora_A.default_0.weight, single_transformer_blocks.2.proj_out.lora_B.default_0.weight, single_transformer_blocks.3.proj_out.lora_A.default_0.weight, single_transformer_blocks.3.proj_out.lora_B.default_0.weight, single_transformer_blocks.4.proj_out.lora_A.default_0.weight, single_transformer_blocks.4.proj_out.lora_B.default_0.weight, single_transformer_blocks.5.proj_out.lora_A.default_0.weight, single_transformer_blocks.5.proj_out.lora_B.default_0.weight, single_transformer_blocks.6.proj_out.lora_A.default_0.weight, single_transformer_blocks.6.proj_out.lora_B.default_0.weight, single_transformer_blocks.7.proj_out.lora_A.default_0.weight, single_transformer_blocks.7.proj_out.lora_B.default_0.weight, single_transformer_blocks.8.proj_out.lora_A.default_0.weight, single_transformer_blocks.8.proj_out.lora_B.default_0.weight, single_transformer_blocks.9.proj_out.lora_A.default_0.weight, single_transformer_blocks.9.proj_out.lora_B.default_0.weight, single_transformer_blocks.10.proj_out.lora_A.default_0.weight, single_transformer_blocks.10.proj_out.lora_B.default_0.weight, single_transformer_blocks.11.proj_out.lora_A.default_0.weight, single_transformer_blocks.11.proj_out.lora_B.default_0.weight, single_transformer_blocks.12.proj_out.lora_A.default_0.weight, single_transformer_blocks.12.proj_out.lora_B.default_0.weight, single_transformer_blocks.13.proj_out.lora_A.default_0.weight, single_transformer_blocks.13.proj_out.lora_B.default_0.weight, single_transformer_blocks.14.proj_out.lora_A.default_0.weight, single_transformer_blocks.14.proj_out.lora_B.default_0.weight, single_transformer_blocks.15.proj_out.lora_A.default_0.weight, single_transformer_blocks.15.proj_out.lora_B.default_0.weight, single_transformer_blocks.16.proj_out.lora_A.default_0.weight, single_transformer_blocks.16.proj_out.lora_B.default_0.weight, single_transformer_blocks.17.proj_out.lora_A.default_0.weight, single_transformer_blocks.17.proj_out.lora_B.default_0.weight, single_transformer_blocks.18.proj_out.lora_A.default_0.weight, single_transformer_blocks.18.proj_out.lora_B.default_0.weight, single_transformer_blocks.19.proj_out.lora_A.default_0.weight, single_transformer_blocks.19.proj_out.lora_B.default_0.weight, single_transformer_blocks.20.proj_out.lora_A.default_0.weight, single_transformer_blocks.20.proj_out.lora_B.default_0.weight, single_transformer_blocks.21.proj_out.lora_A.default_0.weight, single_transformer_blocks.21.proj_out.lora_B.default_0.weight, single_transformer_blocks.22.proj_out.lora_A.default_0.weight, single_transformer_blocks.22.proj_out.lora_B.default_0.weight, single_transformer_blocks.23.proj_out.lora_A.default_0.weight, single_transformer_blocks.23.proj_out.lora_B.default_0.weight, single_transformer_blocks.24.proj_out.lora_A.default_0.weight, single_transformer_blocks.24.proj_out.lora_B.default_0.weight, single_transformer_blocks.25.proj_out.lora_A.default_0.weight, single_transformer_blocks.25.proj_out.lora_B.default_0.weight, single_transformer_blocks.26.proj_out.lora_A.default_0.weight, single_transformer_blocks.26.proj_out.lora_B.default_0.weight, single_transformer_blocks.27.proj_out.lora_A.default_0.weight, 
single_transformer_blocks.27.proj_out.lora_B.default_0.weight, single_transformer_blocks.28.proj_out.lora_A.default_0.weight, single_transformer_blocks.28.proj_out.lora_B.default_0.weight, single_transformer_blocks.29.proj_out.lora_A.default_0.weight, single_transformer_blocks.29.proj_out.lora_B.default_0.weight, single_transformer_blocks.30.proj_out.lora_A.default_0.weight, single_transformer_blocks.30.proj_out.lora_B.default_0.weight, single_transformer_blocks.31.proj_out.lora_A.default_0.weight, single_transformer_blocks.31.proj_out.lora_B.default_0.weight, single_transformer_blocks.32.proj_out.lora_A.default_0.weight, single_transformer_blocks.32.proj_out.lora_B.default_0.weight, single_transformer_blocks.33.proj_out.lora_A.default_0.weight, single_transformer_blocks.33.proj_out.lora_B.default_0.weight, single_transformer_blocks.34.proj_out.lora_A.default_0.weight, single_transformer_blocks.34.proj_out.lora_B.default_0.weight, single_transformer_blocks.35.proj_out.lora_A.default_0.weight, single_transformer_blocks.35.proj_out.lora_B.default_0.weight, single_transformer_blocks.36.proj_out.lora_A.default_0.weight, single_transformer_blocks.36.proj_out.lora_B.default_0.weight, single_transformer_blocks.37.proj_out.lora_A.default_0.weight, single_transformer_blocks.37.proj_out.lora_B.default_0.weight.
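Here is a minimal sketch of the mismatch outside diffusers. The toy model and module names are hypothetical; it assumes peft >= 0.14 (where LoraConfig accepts exclude_modules) and that exclusion uses the same name/suffix matching as target_modules:

import torch.nn as nn
from peft import LoraConfig, get_peft_model


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(8, 8)
        self.proj_out = nn.Linear(8, 8)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block() for _ in range(2)])
        self.proj_out = nn.Linear(8, 8)  # the module exclude_modules is actually meant for


config = LoraConfig(
    r=4,
    target_modules=["blocks.0.to_q", "blocks.1.to_q", "blocks.0.proj_out", "blocks.1.proj_out"],
    exclude_modules=["proj_out"],  # intended to exclude only the top-level proj_out
)
model = get_peft_model(ToyModel(), config)

# "proj_out" also matches the explicitly targeted blocks.*.proj_out, so those layers never
# receive LoRA layers; their weights later surface as "unexpected keys" during state_dict loading.
print(type(model.base_model.model.blocks[0].proj_out))  # plain nn.Linear, no lora_A/lora_B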
Reproducer:

from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

pipe_kwargs = {
    "prompt": "{trigger_word} A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 28,
    "max_sequence_length": 512,
}

# Loading this LoRA triggers the "unexpected keys" warning shown above.
pipeline.load_lora_weights("glif/l0w-r3z")
image = pipeline(**pipe_kwargs).images[0]
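After running the reproducer, one (hypothetical) way to confirm which targeted proj_out layers were dropped is to check whether they were wrapped by peft, assuming LoRA-injected layers are instances of peft's LoraLayer:

from peft.tuners.lora import LoraLayer

# proj_out layers in the single transformer blocks that did NOT receive LoRA layers;
# with the current exclude_modules extraction this list is expected to be non-empty.
not_injected = [
    name
    for name, module in pipeline.transformer.named_modules()
    if name.endswith("proj_out")
    and "single_transformer_blocks" in name
    and not isinstance(module, LoraLayer)
]
print(not_injected)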
See also: the existing test test_lora_exclude_modules_wanvace, which exercises exclude_modules handling.
Cc: @BenjaminBossan