diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index c720b379551f..84c31350f0ce 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -799,6 +799,7 @@ def __init__( ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, + qk_norm: Optional[str] = None, ): super().__init__() self.dim = dim @@ -867,6 +868,7 @@ def __init__( cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, out_bias=attention_out_bias, + qk_norm=qk_norm, ) # 2. Cross-Attn @@ -897,6 +899,7 @@ def __init__( bias=attention_bias, upcast_attention=upcast_attention, out_bias=attention_out_bias, + qk_norm=qk_norm, ) # is self-attn if encoder_hidden_states is none else: if norm_type == "ada_norm_single": # For Latte diff --git a/src/diffusers/models/transformers/dual_transformer_2d.py b/src/diffusers/models/transformers/dual_transformer_2d.py index 24eed2168229..4917f52adce2 100644 --- a/src/diffusers/models/transformers/dual_transformer_2d.py +++ b/src/diffusers/models/transformers/dual_transformer_2d.py @@ -60,6 +60,7 @@ def __init__( num_vector_embeds: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, + qk_norm: Optional[str] = None, ): super().__init__() self.transformers = nn.ModuleList( @@ -77,6 +78,7 @@ def __init__( num_vector_embeds=num_vector_embeds, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, + qk_norm=qk_norm, ) for _ in range(2) ] diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 67fe9a33109b..0b76a89ee426 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -96,6 +96,7 @@ def __init__( caption_channels: int = None, interpolation_scale: float = None, use_additional_conditions: Optional[bool] = None, + qk_norm: Optional[str] = None, ): super().__init__() @@ -199,6 +200,7 @@ def _init_continuous_input(self, norm_type): norm_elementwise_affine=self.config.norm_elementwise_affine, norm_eps=self.config.norm_eps, attention_type=self.config.attention_type, + qk_norm=self.config.qk_norm, ) for _ in range(self.config.num_layers) ] diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 94a9245e567c..d4e8fa0b68e2 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -66,6 +66,7 @@ def get_down_block( attention_head_dim: Optional[int] = None, downsample_type: Optional[str] = None, dropout: float = 0.0, + qk_norm: Optional[str] = None, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: @@ -122,6 +123,7 @@ def get_down_block( attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, downsample_type=downsample_type, + qk_norm=qk_norm, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: @@ -146,6 +148,7 @@ def get_down_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, + qk_norm=qk_norm, ) elif down_block_type == "SimpleCrossAttnDownBlock2D": if cross_attention_dim is None: @@ -167,6 +170,7 @@ def get_down_block( output_scale_factor=resnet_out_scale_factor, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, + qk_norm=qk_norm, ) 
elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( @@ -193,6 +197,7 @@ def get_down_block( resnet_act_fn=resnet_act_fn, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, + qk_norm=qk_norm, ) elif down_block_type == "DownEncoderBlock2D": return DownEncoderBlock2D( @@ -220,6 +225,7 @@ def get_down_block( downsample_padding=downsample_padding, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, + qk_norm=qk_norm, ) elif down_block_type == "KDownBlock2D": return KDownBlock2D( @@ -245,6 +251,7 @@ def get_down_block( cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, add_self_attention=True if not add_downsample else False, + qk_norm=qk_norm, ) raise ValueError(f"{down_block_type} does not exist.") @@ -270,6 +277,7 @@ def get_mid_block( cross_attention_norm: Optional[str] = None, attention_head_dim: Optional[int] = 1, dropout: float = 0.0, + qk_norm: Optional[str] = None, ): if mid_block_type == "UNetMidBlock2DCrossAttn": return UNetMidBlock2DCrossAttn( @@ -288,6 +296,7 @@ def get_mid_block( use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, + qk_norm=qk_norm, ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": return UNetMidBlock2DSimpleCrossAttn( @@ -304,6 +313,7 @@ def get_mid_block( skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, + qk_norm=qk_norm, ) elif mid_block_type == "UNetMidBlock2D": return UNetMidBlock2D( @@ -351,6 +361,7 @@ def get_up_block( attention_head_dim: Optional[int] = None, upsample_type: Optional[str] = None, dropout: float = 0.0, + qk_norm: Optional[str] = None, ) -> nn.Module: # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: @@ -416,6 +427,7 @@ def get_up_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, + qk_norm=qk_norm, ) elif up_block_type == "SimpleCrossAttnUpBlock2D": if cross_attention_dim is None: @@ -439,6 +451,7 @@ def get_up_block( output_scale_factor=resnet_out_scale_factor, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, + qk_norm=qk_norm, ) elif up_block_type == "AttnUpBlock2D": if add_upsample is False: @@ -460,6 +473,7 @@ def get_up_block( attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, upsample_type=upsample_type, + qk_norm=qk_norm, ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( @@ -489,6 +503,7 @@ def get_up_block( resnet_act_fn=resnet_act_fn, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, + qk_norm=qk_norm, ) elif up_block_type == "UpDecoderBlock2D": return UpDecoderBlock2D( @@ -518,6 +533,7 @@ def get_up_block( attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, temb_channels=temb_channels, + qk_norm=qk_norm, ) elif up_block_type == "KUpBlock2D": return KUpBlock2D( @@ -544,6 +560,7 @@ def get_up_block( resnet_act_fn=resnet_act_fn, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, + qk_norm=qk_norm, ) raise ValueError(f"{up_block_type} does not exist.") @@ -632,6 +649,7 @@ def __init__( add_attention: bool = True, attention_head_dim: int = 1, output_scale_factor: float = 1.0, + qk_norm: Optional[str] = None, ): super().__init__() resnet_groups = resnet_groups if 
resnet_groups is not None else min(in_channels // 4, 32) @@ -693,6 +711,7 @@ def __init__( bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, + qk_norm=qk_norm, ) ) else: @@ -770,6 +789,7 @@ def __init__( use_linear_projection: bool = False, upcast_attention: bool = False, attention_type: str = "default", + qk_norm: Optional[str] = None, ): super().__init__() @@ -818,6 +838,7 @@ def __init__( use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, + qk_norm=qk_norm, ) ) else: @@ -829,6 +850,7 @@ def __init__( num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, + qk_norm=qk_norm, ) ) resnets.append( @@ -908,6 +930,7 @@ def __init__( skip_time_act: bool = False, only_cross_attention: bool = False, cross_attention_norm: Optional[str] = None, + qk_norm: Optional[str] = None, ): super().__init__() @@ -954,6 +977,7 @@ def __init__( only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, processor=processor, + qk_norm=qk_norm, ) ) resnets.append( @@ -1032,6 +1056,7 @@ def __init__( output_scale_factor: float = 1.0, downsample_padding: int = 1, downsample_type: str = "conv", + qk_norm: Optional[str] = None, ): super().__init__() resnets = [] @@ -1072,6 +1097,7 @@ def __init__( bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, + qk_norm=qk_norm, ) ) @@ -1168,6 +1194,7 @@ def __init__( only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", + qk_norm: Optional[str] = None, ): super().__init__() resnets = [] @@ -1207,6 +1234,7 @@ def __init__( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, attention_type=attention_type, + qk_norm=qk_norm, ) ) else: @@ -1218,6 +1246,7 @@ def __init__( num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, + qk_norm=qk_norm, ) ) self.attentions = nn.ModuleList(attentions) @@ -1464,6 +1493,7 @@ def __init__( output_scale_factor: float = 1.0, add_downsample: bool = True, downsample_padding: int = 1, + qk_norm: Optional[str] = None, ): super().__init__() resnets = [] @@ -1518,6 +1548,7 @@ def __init__( bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, + qk_norm=qk_norm, ) ) @@ -1566,6 +1597,7 @@ def __init__( attention_head_dim: int = 1, output_scale_factor: float = np.sqrt(2.0), add_downsample: bool = True, + qk_norm: Optional[str] = None, ): super().__init__() self.attentions = nn.ModuleList([]) @@ -1606,6 +1638,7 @@ def __init__( bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, + qk_norm=qk_norm, ) ) @@ -1863,6 +1896,7 @@ def __init__( skip_time_act: bool = False, only_cross_attention: bool = False, cross_attention_norm: Optional[str] = None, + qk_norm: Optional[str] = None, ): super().__init__() @@ -1909,6 +1943,7 @@ def __init__( only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, processor=processor, + qk_norm=qk_norm, ) ) self.attentions = nn.ModuleList(attentions) @@ -2079,6 +2114,7 @@ def __init__( add_self_attention: bool = False, resnet_eps: float = 1e-5, resnet_act_fn: str = "gelu", + qk_norm: Optional[str] = None, ): super().__init__() resnets = [] @@ -2116,6 +2152,7 @@ def __init__( add_self_attention=add_self_attention, cross_attention_norm="layer_norm", group_size=resnet_group_size, + qk_norm=qk_norm, ) ) @@ -2200,6 +2237,7 @@ def __init__( attention_head_dim: int = 1, output_scale_factor: float = 1.0, upsample_type: str = "conv", + 
qk_norm: Optional[str] = None, ): super().__init__() resnets = [] @@ -2243,6 +2281,7 @@ def __init__( bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, + qk_norm=qk_norm, ) ) @@ -2336,6 +2375,7 @@ def __init__( only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", + qk_norm: Optional[str] = None, ): super().__init__() resnets = [] @@ -2378,6 +2418,7 @@ def __init__( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, attention_type=attention_type, + qk_norm=qk_norm, ) ) else: @@ -2389,6 +2430,7 @@ def __init__( num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, + qk_norm=qk_norm, ) ) self.attentions = nn.ModuleList(attentions) @@ -2662,6 +2704,7 @@ def __init__( output_scale_factor: float = 1.0, add_upsample: bool = True, temb_channels: Optional[int] = None, + qk_norm: Optional[str] = None, ): super().__init__() resnets = [] @@ -2719,6 +2762,7 @@ def __init__( bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, + qk_norm=qk_norm, ) ) @@ -2761,6 +2805,7 @@ def __init__( attention_head_dim: int = 1, output_scale_factor: float = np.sqrt(2.0), add_upsample: bool = True, + qk_norm: Optional[str] = None, ): super().__init__() self.attentions = nn.ModuleList([]) @@ -2804,6 +2849,7 @@ def __init__( bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, + qk_norm=qk_norm, ) ) @@ -3110,6 +3156,7 @@ def __init__( skip_time_act: bool = False, only_cross_attention: bool = False, cross_attention_norm: Optional[str] = None, + qk_norm: Optional[str] = None, ): super().__init__() resnets = [] @@ -3157,6 +3204,7 @@ def __init__( only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, processor=processor, + qk_norm=qk_norm, ) ) self.attentions = nn.ModuleList(attentions) @@ -3341,6 +3389,7 @@ def __init__( cross_attention_dim: int = 768, add_upsample: bool = True, upcast_attention: bool = False, + qk_norm: Optional[str] = None, ): super().__init__() resnets = [] @@ -3397,6 +3446,7 @@ def __init__( add_self_attention=add_self_attention, cross_attention_norm="layer_norm", upcast_attention=upcast_attention, + qk_norm=qk_norm, ) ) @@ -3497,6 +3547,7 @@ def __init__( add_self_attention: bool = False, cross_attention_norm: Optional[str] = None, group_size: int = 32, + qk_norm: Optional[str] = None, ): super().__init__() self.add_self_attention = add_self_attention @@ -3512,6 +3563,7 @@ def __init__( bias=attention_bias, cross_attention_dim=None, cross_attention_norm=None, + qk_norm=qk_norm, ) # 2. Cross-Attn @@ -3525,6 +3577,7 @@ def __init__( bias=attention_bias, upcast_attention=upcast_attention, cross_attention_norm=cross_attention_norm, + qk_norm=qk_norm, ) def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor: diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 736deb28c376..65a666b13f15 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -162,6 +162,7 @@ class conditioning with `class_embed_type` equal to `None`. `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` otherwise. + qk_norm: Normalization to apply to attention Q/K. 
""" _supports_gradient_checkpointing = True @@ -225,6 +226,7 @@ def __init__( mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, + qk_norm: Optional[str] = None, ): super().__init__() @@ -380,6 +382,7 @@ def __init__( cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, + qk_norm=qk_norm, ) self.down_blocks.append(down_block) @@ -405,6 +408,7 @@ def __init__( cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[-1], dropout=dropout, + qk_norm=qk_norm, ) # count how many layers upsample the images @@ -463,6 +467,7 @@ def __init__( cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, + qk_norm=qk_norm, ) self.up_blocks.append(up_block) diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index f6fa82aeb713..3824ac18b561 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -152,7 +152,6 @@ def test_gradient_checkpointing_is_applied(self): expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels ) - class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet2DModel main_input_name = "sample" diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 123dff16f8b0..410ff4f1cbd9 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -27,11 +27,13 @@ from diffusers import UNet2DConditionModel from diffusers.models.attention_processor import ( + Attention, CustomDiffusionAttnProcessor, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, ) from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection +from diffusers.models.normalization import RMSNorm from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -1143,6 +1145,94 @@ def test_save_attn_procs_raise_warning(self): warning_message = str(warning.warnings[0].message) assert "Using the `save_attn_procs()` method has been deprecated" in warning_message + def test_qk_norm_argument_unet2dcondition(self): + config_minimal = {} + model = UNet2DConditionModel.from_config(config_minimal) + qk_norm_modules = [n for n, _ in model.named_modules() + if n.endswith('.norm_q') or n.endswith('.norm_k')] + assert len(qk_norm_modules) == 0 + + config_minimal_qknorm = {"qk_norm": "rms_norm"} + model_qknorm = UNet2DConditionModel.from_config(config_minimal_qknorm) + qk_norm_modules = [n for n, _ in model_qknorm.named_modules() + if n.endswith('.norm_q') or n.endswith('.norm_k')] + assert len(qk_norm_modules) > 0 + + def test_qk_norm_argument_blocks(self): + from diffusers.models.unets.unet_2d_blocks import get_down_block, get_mid_block, get_up_block + + # test if block can be instantiated with qk_norm argument + # TODO: is there a canonical list of supported block types? 
+ down_block_types = [ + "DownBlock2D", + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + "CrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + "SkipDownBlock2D", + "AttnSkipDownBlock2D", + "DownEncoderBlock2D", + "AttnDownEncoderBlock2D", + "KDownBlock2D", + "KCrossAttnDownBlock2D", + ] + mid_block_types = ["UNetMidBlock2DCrossAttn", "UNetMidBlock2DSimpleCrossAttn", "UNetMidBlock2D"] + up_block_types = [ + block_type.replace("Down", "Up") + for block_type in down_block_types + if (block_type != "DownEncoderBlock2D" and block_type != "AttnDownEncoderBlock2D") + ] + + for block_getter, block_types, block_type_arg, extra_kwargs in [ + ( + get_down_block, + down_block_types, + "down_block_type", + {"num_layers": 2, "out_channels": 64, "add_downsample": False, "attention_head_dim": 8}, + ), + (get_mid_block, mid_block_types, "mid_block_type", {}), + ( + get_up_block, + up_block_types, + "up_block_type", + { + "num_layers": 2, + "out_channels": 64, + "prev_output_channel": 64, + "add_upsample": False, + "attention_head_dim": 8, + }, + ), + ]: + for block_type in block_types: + block_type_kwarg = {block_type_arg: block_type} + block = block_getter( + **block_type_kwarg, + **extra_kwargs, + in_channels=32, + temb_channels=32, + resnet_groups=32, + resnet_eps=1e-5, + resnet_act_fn="silu", + cross_attention_dim=1024, + num_attention_heads=8, + qk_norm="rms_norm", # <--- new argument + ) + if "Attn" in block_type: + # discover attentions + attentions = [module for module in block.modules() if isinstance(module, Attention)] + for attn in attentions: + k_norm = getattr(attn, "norm_k", None) + assert k_norm is not None + assert isinstance(k_norm, RMSNorm) + q_norm = getattr(attn, "norm_q", None) + assert q_norm is not None + assert isinstance(q_norm, RMSNorm) + else: + # make sure the attention-free block types really contain no Attention modules + attentions = [module for module in block.modules() if isinstance(module, Attention)] + assert len(attentions) == 0 + class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = UNet2DConditionModel
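Note (not part of the patch): a minimal end-to-end usage sketch of the new `qk_norm` argument, assuming the diff above is applied. The tiny config values are illustrative only, not the library defaults, and the `norm_q` / `norm_k` probing mirrors what the new tests check.

# Minimal usage sketch (assumes this patch is applied). The small config
# values below are illustrative, not the library defaults.
import torch

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    qk_norm="rms_norm",  # new argument introduced by this patch
)

# The Attention modules should now carry `norm_q` / `norm_k` submodules,
# which is exactly what the new tests look for.
qk_norm_layers = [name for name, _ in unet.named_modules() if name.endswith((".norm_q", ".norm_k"))]
assert len(qk_norm_layers) > 0

# The added normalization is shape-preserving, so a regular forward pass still works.
sample = torch.randn(1, 4, 16, 16)
encoder_hidden_states = torch.randn(1, 8, 32)
with torch.no_grad():
    out = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states).sample
assert out.shape == sample.shape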