From 3dceb84a89d3a4cbbf37db518aca325fb6989cae Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 10 Jul 2024 06:25:11 +0200 Subject: [PATCH 01/13] start debugging the problem, --- src/diffusers/models/attention_processor.py | 23 ++++++++++++++++++- .../models/transformers/transformer_sd3.py | 9 +++++++- .../test_pipeline_stable_diffusion_3.py | 1 + 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ef25d24f9f1a..68b2548c363f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -649,7 +649,6 @@ def fuse_projections(self, fuse=True): if self.use_bias: concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) self.to_qkv.bias.copy_(concatenated_bias) - else: concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) in_features = concatenated_weights.shape[1] @@ -661,6 +660,27 @@ def fuse_projections(self, fuse=True): concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) self.to_kv.bias.copy_(concatenated_bias) + + print(f"{hasattr(self, 'add_q_proj')=}, {hasattr(self, 'add_k_proj')=}, {hasattr(self, 'add_v_proj')=}") + if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): + concatenated_weights = torch.cat([self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype) + self.to_added_qkv.weight.copy_(concatenated_weights) + concatenated_bias = torch.cat([self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + self.to_added_qkv.bias.copy_(concatenated_bias) + # elif hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): + # concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + # in_features = concatenated_weights.shape[1] + # out_features = concatenated_weights.shape[0] + + # self.to_added_kv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype) + # self.to_added_kv.weight.copy_(concatenated_weights) + # concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + # self.to_added_kv.bias.copy_(concatenated_bias) + self.fused_projections = fuse @@ -1093,6 +1113,7 @@ def __call__( query, key, value = torch.split(qkv, split_size, dim=-1) # `context` projections. 
+ print(f"{hasattr(attn, 'to_added_qkv')=}") encoder_qkv = attn.to_added_qkv(encoder_hidden_states) split_size = encoder_qkv.shape[-1] // 3 ( diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 1b9126b3b849..cffe308bf21d 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...models.attention import JointTransformerBlock -from ...models.attention_processor import Attention, AttentionProcessor +from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers @@ -211,6 +211,7 @@ def fuse_qkv_projections(self): """ + print("I am here.") self.original_attn_processors = None for _, attn_processor in self.attn_processors.items(): @@ -221,8 +222,14 @@ def fuse_qkv_projections(self): for module in self.modules(): if isinstance(module, Attention): + print(module.__class__.__name__) module.fuse_projections(fuse=True) + self.set_attn_processor(FusedJointAttnProcessor2_0()) + for key, value in self.attn_processors.items(): + print(key, value) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 93e740145477..b56a4da7a935 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -192,6 +192,7 @@ def test_fused_qkv_projections(self): original_image_slice = image[0, -3:, -3:, -1] pipe.transformer.fuse_qkv_projections() + assert pipe.transformer is None inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images image_slice_fused = image[0, -3:, -3:, -1] From a352a67c57b160ac9553a5cf97a6526c0942d87d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 10 Jul 2024 11:21:59 +0200 Subject: [PATCH 02/13] start --- src/diffusers/models/transformers/transformer_sd3.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index cffe308bf21d..19c2e2fc9591 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -220,10 +220,14 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - for module in self.modules(): - if isinstance(module, Attention): - print(module.__class__.__name__) - module.fuse_projections(fuse=True) + def fuse_recursively(module): + for submodule in module.children(): + if isinstance(submodule, Attention): + submodule.fuse_projections(fuse=True) + # Recursively call this function on the submodule to handle nesting + fuse_recursively(submodule) + + fuse_recursively(self) self.set_attn_processor(FusedJointAttnProcessor2_0()) for key, value in self.attn_processors.items(): From 77ab54529d2f1dc34dcc647350de5025fd0a6329 Mon Sep 
17 00:00:00 2001 From: sayakpaul Date: Wed, 10 Jul 2024 14:45:10 +0200 Subject: [PATCH 03/13] fix --- src/diffusers/models/attention_processor.py | 1 - src/diffusers/models/transformers/transformer_sd3.py | 4 ---- .../stable_diffusion_3/test_pipeline_stable_diffusion_3.py | 1 - 3 files changed, 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 68b2548c363f..c5fb937e1c7c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -445,7 +445,6 @@ def forward( # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) quiet_attn_parameters = {"ip_adapter_masks"} unused_kwargs = [ diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 19c2e2fc9591..82cd4970385d 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -211,7 +211,6 @@ def fuse_qkv_projections(self): """ - print("I am here.") self.original_attn_processors = None for _, attn_processor in self.attn_processors.items(): @@ -230,9 +229,6 @@ def fuse_recursively(module): fuse_recursively(self) self.set_attn_processor(FusedJointAttnProcessor2_0()) - for key, value in self.attn_processors.items(): - print(key, value) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index b56a4da7a935..93e740145477 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -192,7 +192,6 @@ def test_fused_qkv_projections(self): original_image_slice = image[0, -3:, -3:, -1] pipe.transformer.fuse_qkv_projections() - assert pipe.transformer is None inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images image_slice_fused = image[0, -3:, -3:, -1] From 2ee335f87bc164d2ee3108c118100f2497c362df Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 10 Jul 2024 19:11:15 +0200 Subject: [PATCH 04/13] fix --- src/diffusers/models/attention_processor.py | 13 ++++++++----- .../models/transformers/transformer_sd3.py | 4 ++-- src/diffusers/models/unets/unet_2d_condition.py | 14 +++++++++++--- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c5fb937e1c7c..6d785912cd0b 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -659,22 +659,25 @@ def fuse_projections(self, fuse=True): concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) self.to_kv.bias.copy_(concatenated_bias) - print(f"{hasattr(self, 'add_q_proj')=}, {hasattr(self, 'add_k_proj')=}, {hasattr(self, 'add_v_proj')=}") if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): - concatenated_weights = torch.cat([self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + 
concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) in_features = concatenated_weights.shape[1] out_features = concatenated_weights.shape[0] - + self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype) self.to_added_qkv.weight.copy_(concatenated_weights) - concatenated_bias = torch.cat([self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) self.to_added_qkv.bias.copy_(concatenated_bias) # elif hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): # concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) # in_features = concatenated_weights.shape[1] # out_features = concatenated_weights.shape[0] - + # self.to_added_kv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype) # self.to_added_kv.weight.copy_(concatenated_weights) # concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 82cd4970385d..b7d0fac3efc3 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -199,7 +199,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) @@ -225,7 +225,7 @@ def fuse_recursively(module): submodule.fuse_projections(fuse=True) # Recursively call this function on the submodule to handle nesting fuse_recursively(submodule) - + fuse_recursively(self) self.set_attn_processor(FusedJointAttnProcessor2_0()) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 2b9122799bf3..9a53531b77ce 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -30,6 +30,7 @@ AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, + FusedAttnProcessor2_0, ) from ..embeddings import ( GaussianFourierProjection, @@ -886,9 +887,16 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) + def fuse_recursively(module): + for submodule in module.children(): + if isinstance(submodule, Attention): + submodule.fuse_projections(fuse=True) + # Recursively call this function on the submodule to handle nesting + fuse_recursively(submodule) + + fuse_recursively(self) + + self.set_attn_processor(FusedAttnProcessor2_0()) def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. From ddcc102f02e2041c0ebd9e1073be97dd5ef82b54 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 10 Jul 2024 19:19:19 +0200 Subject: [PATCH 05/13] fix imports. 
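Propagate the fused-projection changes to every model that defines
fuse_qkv_projections(). The change has the same shape in each file; roughly
(a sketch of the common pattern, not the verbatim diff):

    self.original_attn_processors = self.attn_processors  # kept for unfuse_qkv_projections()

    for module in self.modules():
        if isinstance(module, Attention):
            module.fuse_projections(fuse=True)  # concatenates to_q/to_k/to_v into to_qkv

    self.set_attn_processor(FusedAttnProcessor2_0())  # or the model-specific Fused* variant

Fusing the weights alone is not enough: the default processors keep reading
from attn.to_q/to_k/to_v, so the fused to_qkv layer would never be used
without the final set_attn_processor() call.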
--- .../models/autoencoders/autoencoder_kl.py | 14 +++++++++++--- src/diffusers/models/controlnet_sd3.py | 17 ++++++++++++----- src/diffusers/models/controlnet_xs.py | 14 +++++++++++--- .../transformers/hunyuan_transformer_2d.py | 15 +++++++++++---- src/diffusers/models/unets/unet_3d_condition.py | 14 +++++++++++--- src/diffusers/models/unets/unet_i2vgen_xl.py | 14 +++++++++++--- src/diffusers/models/unets/unet_motion_model.py | 14 +++++++++++--- 7 files changed, 78 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index abc187bf848d..a5b521195b20 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -26,6 +26,7 @@ AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, + FusedAttnProcessor2_0, ) from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin @@ -480,9 +481,16 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) + def fuse_recursively(module): + for submodule in module.children(): + if isinstance(submodule, Attention): + submodule.fuse_projections(fuse=True) + # Recursively call this function on the submodule to handle nesting + fuse_recursively(submodule) + + fuse_recursively(self) + + self.set_attn_processor(FusedAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index 25eb6384c68c..c3d946a49c1c 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import FromOriginalModelMixin, PeftAdapterMixin from ..models.attention import JointTransformerBlock -from ..models.attention_processor import Attention, AttentionProcessor +from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 from ..models.modeling_outputs import Transformer2DModelOutput from ..models.modeling_utils import ModelMixin from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers @@ -196,7 +196,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections def fuse_qkv_projections(self): """ Enables fused QKV projections. 
For self-attention modules, all projection matrices (i.e., query, key, value) @@ -216,9 +216,16 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) + def fuse_recursively(module): + for submodule in module.children(): + if isinstance(submodule, Attention): + submodule.fuse_projections(fuse=True) + # Recursively call this function on the submodule to handle nesting + fuse_recursively(submodule) + + fuse_recursively(self) + + self.set_attn_processor(FusedJointAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 354acfebe0a2..d5f00cf22296 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -29,6 +29,7 @@ AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, + FusedAttnProcessor2_0, ) from .controlnet import ControlNetConditioningEmbedding from .embeddings import TimestepEmbedding, Timesteps @@ -997,9 +998,16 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) + def fuse_recursively(module): + for submodule in module.children(): + if isinstance(submodule, Attention): + submodule.fuse_projections(fuse=True) + # Recursively call this function on the submodule to handle nesting + fuse_recursively(submodule) + + fuse_recursively(self) + + self.set_attn_processor(FusedAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 8313ffd87a50..0da5795a5c9f 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -21,7 +21,7 @@ from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0 +from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0, HunyuanAttnProcessor2_0 from ..embeddings import ( HunyuanCombinedTimestepTextSizeStyleEmbedding, PatchEmbed, @@ -346,9 +346,16 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) + def fuse_recursively(module): + for submodule in module.children(): + if isinstance(submodule, Attention): + submodule.fuse_projections(fuse=True) + # Recursively call this function on the submodule to handle nesting + fuse_recursively(submodule) + + fuse_recursively(self) + + self.set_attn_processor(FusedAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 40b3b92427ce..45216d4aae2f 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -31,6 +31,7 @@ 
AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, + FusedAttnProcessor2_0, ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin @@ -528,9 +529,16 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) + def fuse_recursively(module): + for submodule in module.children(): + if isinstance(submodule, Attention): + submodule.fuse_projections(fuse=True) + # Recursively call this function on the submodule to handle nesting + fuse_recursively(submodule) + + fuse_recursively(self) + + self.set_attn_processor(FusedAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index b650f0e21af0..8fb833374d2d 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -29,6 +29,7 @@ AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, + FusedAttnProcessor2_0, ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin @@ -494,9 +495,16 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) + def fuse_recursively(module): + for submodule in module.children(): + if isinstance(submodule, Attention): + submodule.fuse_projections(fuse=True) + # Recursively call this function on the submodule to handle nesting + fuse_recursively(submodule) + + fuse_recursively(self) + + self.set_attn_processor(FusedAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index e2657e56901f..e869144aca5b 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -29,6 +29,7 @@ AttnAddedKVProcessor, AttnProcessor, AttnProcessor2_0, + FusedAttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, ) @@ -925,9 +926,16 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) + def fuse_recursively(module): + for submodule in module.children(): + if isinstance(submodule, Attention): + submodule.fuse_projections(fuse=True) + # Recursively call this function on the submodule to handle nesting + fuse_recursively(submodule) + + fuse_recursively(self) + + self.set_attn_processor(FusedAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): From 8f4617721b73f98d9953fd4b8268e2d709101083 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 10 Jul 2024 19:28:10 +0200 Subject: [PATCH 06/13] handle hunyuan --- src/diffusers/models/attention_processor.py | 103 ++++++++++++++++++ .../transformers/hunyuan_transformer_2d.py | 6 +- 2 files changed, 106 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6d785912cd0b..8c50429d4bf8 100644 --- 
a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1621,6 +1621,109 @@ def __call__( return hidden_states +class FusedHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused + projection layers. This is used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on + query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + query = attn.to_q(hidden_states) + + kv = attn.to_kv(encoder_hidden_states) + split_size = kv.shape[-1] // 2 + key, value = torch.split(kv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = 
hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class LuminaAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 0da5795a5c9f..14143b552507 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -21,7 +21,7 @@ from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0, HunyuanAttnProcessor2_0 +from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0 from ..embeddings import ( HunyuanCombinedTimestepTextSizeStyleEmbedding, PatchEmbed, @@ -326,7 +326,7 @@ def __init__( self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0 def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) @@ -355,7 +355,7 @@ def fuse_recursively(module): fuse_recursively(self) - self.set_attn_processor(FusedAttnProcessor2_0()) + self.set_attn_processor(FusedHunyuanAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): From ee39007636c3a2eb143be124d648ef05e5503196 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 10 Jul 2024 19:33:04 +0200 Subject: [PATCH 07/13] remove residuals. 
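Drop the leftover debugging prints and the commented-out to_added_kv branch;
only the to_added_qkv fusion needed for joint attention stays. The fusion is
a pure reparameterization, so outputs must not change. A minimal standalone
sanity check of that equivalence (illustrative only, not part of the diff):

    import torch
    import torch.nn as nn

    dim = 8
    to_q, to_k, to_v = (nn.Linear(dim, dim, bias=True) for _ in range(3))

    # build the fused layer the same way fuse_projections() does
    fused = nn.Linear(dim, 3 * dim, bias=True)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
        fused.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias]))

    x = torch.randn(2, 4, dim)
    q, k, v = torch.split(fused(x), dim, dim=-1)
    assert torch.allclose(q, to_q(x), atol=1e-6)
    assert torch.allclose(k, to_k(x), atol=1e-6)
    assert torch.allclose(v, to_v(x), atol=1e-6)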
--- src/diffusers/models/attention_processor.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8c50429d4bf8..945aaea88a2d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -445,6 +445,7 @@ def forward( # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) quiet_attn_parameters = {"ip_adapter_masks"} unused_kwargs = [ @@ -648,6 +649,7 @@ def fuse_projections(self, fuse=True): if self.use_bias: concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) self.to_qkv.bias.copy_(concatenated_bias) + else: concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) in_features = concatenated_weights.shape[1] @@ -659,7 +661,7 @@ def fuse_projections(self, fuse=True): concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) self.to_kv.bias.copy_(concatenated_bias) - print(f"{hasattr(self, 'add_q_proj')=}, {hasattr(self, 'add_k_proj')=}, {hasattr(self, 'add_v_proj')=}") + # handle added projections for SD3 and others. if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): concatenated_weights = torch.cat( [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] @@ -673,15 +675,6 @@ def fuse_projections(self, fuse=True): [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] ) self.to_added_qkv.bias.copy_(concatenated_bias) - # elif hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): - # concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) - # in_features = concatenated_weights.shape[1] - # out_features = concatenated_weights.shape[0] - - # self.to_added_kv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype) - # self.to_added_kv.weight.copy_(concatenated_weights) - # concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) - # self.to_added_kv.bias.copy_(concatenated_bias) self.fused_projections = fuse @@ -1115,7 +1108,6 @@ def __call__( query, key, value = torch.split(qkv, split_size, dim=-1) # `context` projections. - print(f"{hasattr(attn, 'to_added_qkv')=}") encoder_qkv = attn.to_added_qkv(encoder_hidden_states) split_size = encoder_qkv.shape[-1] // 3 ( From be3adcf7c64802c77d79840cff787f16551174de Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 11 Jul 2024 10:28:13 +0200 Subject: [PATCH 08/13] add a check for making sure there's appropriate procs. 
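attn_processors maps module paths to processor instances, e.g. (illustrative
keys; the exact paths depend on the model):

    {
        "transformer_blocks.0.attn.processor": JointAttnProcessor2_0(),
        "transformer_blocks.1.attn.processor": JointAttnProcessor2_0(),
        ...
    }

Because set_attn_processor() broadcasts a single fused processor to every
attention module, fusion must leave the set of keys, and therefore the length
of this dict, unchanged. The new test helper asserts exactly that.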
--- tests/pipelines/hunyuan_dit/test_hunyuan_dit.py | 12 ++++++++++++ .../test_pipeline_stable_diffusion_3.py | 11 +++++++++++ tests/pipelines/test_pipelines_common.py | 16 ++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py index ad5f5f3ef2ca..bfd7d79dd903 100644 --- a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py +++ b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py @@ -42,6 +42,11 @@ enable_full_determinism() +def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors): + current_attn_processors = model.attn_processors + return len(current_attn_processors) == len(original_attn_processors) + + class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = HunyuanDiTPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} @@ -261,6 +266,13 @@ def test_fused_qkv_projections(self): original_image_slice = image[0, -3:, -3:, -1] pipe.transformer.fuse_qkv_projections() + # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added + # to the pipeline level. + pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + inputs = self.get_dummy_inputs(device) inputs["return_dict"] = False image_fused = pipe(**inputs)[0] diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 93e740145477..73e7454ffb22 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -16,6 +16,11 @@ from ..test_pipelines_common import PipelineTesterMixin +def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors): + current_attn_processors = model.attn_processors + return len(current_attn_processors) == len(original_attn_processors) + + class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): pipeline_class = StableDiffusion3Pipeline params = frozenset( @@ -191,7 +196,13 @@ def test_fused_qkv_projections(self): image = pipe(**inputs).images original_image_slice = image[0, -3:, -3:, -1] + # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added + # to the pipeline level. pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." 
+ inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images image_slice_fused = image[0, -3:, -3:, -1] diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 8f2419db92e3..f9e2435b4e64 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -13,6 +13,7 @@ import numpy as np import PIL.Image import torch +import torch.nn as nn from huggingface_hub import ModelCard, delete_repo from huggingface_hub.utils import is_jinja_available from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -67,6 +68,11 @@ def check_same_shape(tensor_list): return all(shape == shapes[0] for shape in shapes[1:]) +def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors): + current_attn_processors = model.attn_processors + return len(current_attn_processors) == len(original_attn_processors) + + class SDFunctionTesterMixin: """ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. @@ -196,6 +202,16 @@ def test_fused_qkv_projections(self): original_image_slice = image[0, -3:, -3:, -1] pipe.fuse_qkv_projections() + for _, component in pipe.components.items(): + if ( + isinstance(component, nn.Module) + and hasattr(component, "original_attn_processors") + and component.original_attn_processors is not None + ): + assert check_qkv_fusion_matches_attn_procs_length( + component, component.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + inputs = self.get_dummy_inputs(device) inputs["return_dict"] = False image_fused = pipe(**inputs)[0] From eb94d4f4eb798aef15aab82ec7bba8b1a7039d6b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 18 Jul 2024 17:04:59 +0530 Subject: [PATCH 09/13] add more rigor to the tests. 
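The length check alone cannot catch a fusion that silently leaves the
original processors installed, so a second helper verifies that every
processor present after fusion is one of the Fused* variants. Both helpers
live in diffusers.utils.testing_utils for now and are used like this in the
tests (sketch):

    pipe.transformer.fuse_qkv_projections()
    assert check_qkv_fusion_processors_exist(pipe.transformer)
    assert check_qkv_fusion_matches_attn_procs_length(
        pipe.transformer, pipe.transformer.original_attn_processors
    )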
--- src/diffusers/utils/testing_utils.py | 11 +++++++++++ tests/pipelines/hunyuan_dit/test_hunyuan_dit.py | 10 +++++----- .../test_pipeline_stable_diffusion_3.py | 10 +++++----- tests/pipelines/test_pipelines_common.py | 17 +++++++++++------ 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index be3e9983c80f..81df76924b97 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -830,6 +830,17 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): test_case.fail(f'{results["error"]}') +def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors): + current_attn_processors = model.attn_processors + return len(current_attn_processors) == len(original_attn_processors) + + +def check_qkv_fusion_processors_exist(model): + current_attn_processors = model.attn_processors + proc_names = list(current_attn_processors.keys()) + return all("Fused" in p for p in proc_names) + + class CaptureLogger: """ Args: diff --git a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py index bfd7d79dd903..f3e67e290d25 100644 --- a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py +++ b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py @@ -28,6 +28,8 @@ HunyuanDiTPipeline, ) from diffusers.utils.testing_utils import ( + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, enable_full_determinism, numpy_cosine_similarity_distance, require_torch_gpu, @@ -42,11 +44,6 @@ enable_full_determinism() -def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors): - current_attn_processors = model.attn_processors - return len(current_attn_processors) == len(original_attn_processors) - - class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = HunyuanDiTPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} @@ -269,6 +266,9 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." 
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 73e7454ffb22..f9384da1891a 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -7,6 +7,8 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline from diffusers.utils.testing_utils import ( + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, numpy_cosine_similarity_distance, require_torch_gpu, slow, @@ -16,11 +18,6 @@ from ..test_pipelines_common import PipelineTesterMixin -def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors): - current_attn_processors = model.attn_processors - return len(current_attn_processors) == len(original_attn_processors) - - class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): pipeline_class = StableDiffusion3Pipeline params = frozenset( @@ -199,6 +196,9 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." assert check_qkv_fusion_matches_attn_procs_length( pipe.transformer, pipe.transformer.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f9e2435b4e64..1c9fd6a7203d 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -41,7 +41,14 @@ from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import logging from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available -from diffusers.utils.testing_utils import CaptureLogger, require_torch, skip_mps, torch_device +from diffusers.utils.testing_utils import ( + CaptureLogger, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, + require_torch, + skip_mps, + torch_device, +) from ..models.autoencoders.test_models_vae import ( get_asym_autoencoder_kl_config, @@ -68,11 +75,6 @@ def check_same_shape(tensor_list): return all(shape == shapes[0] for shape in shapes[1:]) -def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors): - current_attn_processors = model.attn_processors - return len(current_attn_processors) == len(original_attn_processors) - - class SDFunctionTesterMixin: """ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. @@ -208,6 +210,9 @@ def test_fused_qkv_projections(self): and hasattr(component, "original_attn_processors") and component.original_attn_processors is not None ): + assert check_qkv_fusion_processors_exist( + component + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." assert check_qkv_fusion_matches_attn_procs_length( component, component.original_attn_processors ), "Something wrong with the attention processors concerning the fused QKV projections." 
From ce67fe8bfd4137f64995abd249affc0f4e65a31f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 18 Jul 2024 17:17:07 +0530 Subject: [PATCH 10/13] fix test --- src/diffusers/utils/testing_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 81df76924b97..f99108d56fcf 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -837,8 +837,8 @@ def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors): def check_qkv_fusion_processors_exist(model): current_attn_processors = model.attn_processors - proc_names = list(current_attn_processors.keys()) - return all("Fused" in p for p in proc_names) + proc_names = [v.__class__.__name__ for _, v in current_attn_processors.items()] + return all(p.startswith("Fused") for p in proc_names) class CaptureLogger: From d563b9eac2e31689900d798f7cc77e1cdde267bf Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 19 Jul 2024 17:23:47 +0530 Subject: [PATCH 11/13] remove redundant check --- src/diffusers/models/autoencoders/autoencoder_kl.py | 11 +++-------- src/diffusers/models/controlnet_xs.py | 11 +++-------- .../models/transformers/hunyuan_transformer_2d.py | 11 +++-------- src/diffusers/models/transformers/transformer_sd3.py | 11 +++-------- src/diffusers/models/unets/unet_2d_condition.py | 11 +++-------- src/diffusers/models/unets/unet_3d_condition.py | 11 +++-------- src/diffusers/models/unets/unet_i2vgen_xl.py | 11 +++-------- src/diffusers/models/unets/unet_motion_model.py | 11 +++-------- 8 files changed, 24 insertions(+), 64 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 87ff8a5a6e9b..80ffe17a4ebd 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -483,14 +483,9 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - def fuse_recursively(module): - for submodule in module.children(): - if isinstance(submodule, Attention): - submodule.fuse_projections(fuse=True) - # Recursively call this function on the submodule to handle nesting - fuse_recursively(submodule) - - fuse_recursively(self) + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) self.set_attn_processor(FusedAttnProcessor2_0()) diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index d5f00cf22296..0fa21755f09c 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -998,14 +998,9 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - def fuse_recursively(module): - for submodule in module.children(): - if isinstance(submodule, Attention): - submodule.fuse_projections(fuse=True) - # Recursively call this function on the submodule to handle nesting - fuse_recursively(submodule) - - fuse_recursively(self) + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) self.set_attn_processor(FusedAttnProcessor2_0()) diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 9f91074668c6..7f3dab220aaa 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -337,14 
+337,9 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - def fuse_recursively(module): - for submodule in module.children(): - if isinstance(submodule, Attention): - submodule.fuse_projections(fuse=True) - # Recursively call this function on the submodule to handle nesting - fuse_recursively(submodule) - - fuse_recursively(self) + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) self.set_attn_processor(FusedHunyuanAttnProcessor2_0()) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 873bac1f187d..a02c7a471f3a 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -231,14 +231,9 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - def fuse_recursively(module): - for submodule in module.children(): - if isinstance(submodule, Attention): - submodule.fuse_projections(fuse=True) - # Recursively call this function on the submodule to handle nesting - fuse_recursively(submodule) - - fuse_recursively(self) + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) self.set_attn_processor(FusedJointAttnProcessor2_0()) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 9a53531b77ce..611ac6087e4a 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -887,14 +887,9 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - def fuse_recursively(module): - for submodule in module.children(): - if isinstance(submodule, Attention): - submodule.fuse_projections(fuse=True) - # Recursively call this function on the submodule to handle nesting - fuse_recursively(submodule) - - fuse_recursively(self) + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) self.set_attn_processor(FusedAttnProcessor2_0()) diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 45216d4aae2f..3081fdc4700c 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -529,14 +529,9 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - def fuse_recursively(module): - for submodule in module.children(): - if isinstance(submodule, Attention): - submodule.fuse_projections(fuse=True) - # Recursively call this function on the submodule to handle nesting - fuse_recursively(submodule) - - fuse_recursively(self) + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) self.set_attn_processor(FusedAttnProcessor2_0()) diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index 8fb833374d2d..6ab3a577b892 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -495,14 +495,9 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - def fuse_recursively(module): - for submodule in module.children(): - if isinstance(submodule, Attention): - submodule.fuse_projections(fuse=True) - # Recursively call this function on the submodule to handle nesting - fuse_recursively(submodule) - - 
fuse_recursively(self) + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) self.set_attn_processor(FusedAttnProcessor2_0()) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 71b510517d0c..196f947d599b 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -926,14 +926,9 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - def fuse_recursively(module): - for submodule in module.children(): - if isinstance(submodule, Attention): - submodule.fuse_projections(fuse=True) - # Recursively call this function on the submodule to handle nesting - fuse_recursively(submodule) - - fuse_recursively(self) + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) self.set_attn_processor(FusedAttnProcessor2_0()) From 8214c8865325c75a528afd46f80c5c33afc087b6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 19 Jul 2024 17:28:49 +0530 Subject: [PATCH 12/13] fix-copies --- src/diffusers/models/controlnet_sd3.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index c3d946a49c1c..305401164b2f 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -216,14 +216,9 @@ def fuse_qkv_projections(self): self.original_attn_processors = self.attn_processors - def fuse_recursively(module): - for submodule in module.children(): - if isinstance(submodule, Attention): - submodule.fuse_projections(fuse=True) - # Recursively call this function on the submodule to handle nesting - fuse_recursively(submodule) - - fuse_recursively(self) + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) self.set_attn_processor(FusedJointAttnProcessor2_0()) From 81caa930a254f8ad87013f837b9f50a7e23cda7e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 20 Jul 2024 10:02:36 +0530 Subject: [PATCH 13/13] move check_qkv_fusion_matches_attn_procs_length and check_qkv_fusion_processors_exist. 
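These two helpers are only used by the pipeline tests, so move them out of
the public diffusers.utils.testing_utils module and into
tests/pipelines/test_pipelines_common.py next to the test mixins. Call sites
switch to the relative import:

    from ..test_pipelines_common import (
        PipelineTesterMixin,
        check_qkv_fusion_matches_attn_procs_length,
        check_qkv_fusion_processors_exist,
    )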
--- src/diffusers/utils/testing_utils.py | 11 ----------- tests/pipelines/hunyuan_dit/test_hunyuan_dit.py | 9 ++++++--- .../test_pipeline_stable_diffusion_3.py | 8 +++++--- tests/pipelines/test_pipelines_common.py | 13 +++++++++++-- 4 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index f99108d56fcf..be3e9983c80f 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -830,17 +830,6 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): test_case.fail(f'{results["error"]}') -def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors): - current_attn_processors = model.attn_processors - return len(current_attn_processors) == len(original_attn_processors) - - -def check_qkv_fusion_processors_exist(model): - current_attn_processors = model.attn_processors - proc_names = [v.__class__.__name__ for _, v in current_attn_processors.items()] - return all(p.startswith("Fused") for p in proc_names) - - class CaptureLogger: """ Args: diff --git a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py index f3e67e290d25..653cb41e4bc4 100644 --- a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py +++ b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py @@ -28,8 +28,6 @@ HunyuanDiTPipeline, ) from diffusers.utils.testing_utils import ( - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, enable_full_determinism, numpy_cosine_similarity_distance, require_torch_gpu, @@ -38,7 +36,12 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, + to_np, +) enable_full_determinism() diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index f9384da1891a..75a7d88ea4f2 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -7,15 +7,17 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline from diffusers.utils.testing_utils import ( - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device, ) -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, +) class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 77d98264dfbf..06fcc1c90b71 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -43,8 +43,6 @@ from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available from diffusers.utils.testing_utils import ( CaptureLogger, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, require_torch, skip_mps, torch_device, @@ -75,6 +73,17 @@ def 
check_same_shape(tensor_list): return all(shape == shapes[0] for shape in shapes[1:]) +def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors): + current_attn_processors = model.attn_processors + return len(current_attn_processors) == len(original_attn_processors) + + +def check_qkv_fusion_processors_exist(model): + current_attn_processors = model.attn_processors + proc_names = [v.__class__.__name__ for _, v in current_attn_processors.items()] + return all(p.startswith("Fused") for p in proc_names) + + class SDFunctionTesterMixin: """ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.