diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py
index d4f5b4658542..b5c44e868d38 100644
--- a/src/diffusers/models/transformers/lumina_nextdit2d.py
+++ b/src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -20,7 +20,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ..attention import LuminaFeedForward
-from ..attention_processor import Attention, LuminaAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, LuminaAttnProcessor2_0
 from ..embeddings import (
     LuminaCombinedTimestepCaptionEmbedding,
     LuminaPatchEmbed,
@@ -115,6 +115,16 @@ def __init__(
         self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
 
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -282,10 +292,113 @@ def __init__(
             bias=True,
             out_dim=patch_size * patch_size * self.out_channels,
         )
-        # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)
 
         assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
 
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(LuminaAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -297,7 +410,8 @@ def forward(
         return_dict=True,
     ) -> torch.Tensor:
         """
-        Forward pass of LuminaNextDiT.
+        The [`LuminaNextDiT2DModel`] forward method. See the [Lumina
+        paper](https://arxiv.org/abs/2406.18583) for details.
 
         Parameters:
             hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
@@ -338,3 +452,45 @@ def forward(
             return (output,)
 
         return Transformer2DModelOutput(sample=output)
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
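For context (not part of the patch itself), here is a minimal usage sketch of the switches this diff adds to `LuminaNextDiT2DModel`, mirroring what the new fast tests below exercise on dummy components. The checkpoint name and prompt are assumptions; any Lumina-Next checkpoint in the diffusers format should work.

```python
import torch

from diffusers import LuminaText2ImgPipeline

# Hypothetical checkpoint name; substitute any Lumina-Next checkpoint saved in diffusers format.
pipe = LuminaText2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Chunk the feed-forward layers over the batch dimension to lower peak memory (at some speed cost).
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)

# Experimental: fuse the QKV projections (KV only for cross-attention) of the attention layers.
pipe.transformer.fuse_qkv_projections()

image = pipe(prompt="a photo of an astronaut riding a horse on mars").images[0]

# Both switches can be reverted.
pipe.transformer.disable_forward_chunking()
pipe.transformer.unfuse_qkv_projections()
```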
diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py
index a53758ce2808..489c62fcd94d 100644
--- a/tests/pipelines/lumina/test_lumina_nextdit.py
+++ b/tests/pipelines/lumina/test_lumina_nextdit.py
@@ -13,10 +13,10 @@
     torch_device,
 )
 
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineTesterMixin, to_np
 
 
-class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+class LuminaText2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
     pipeline_class = LuminaText2ImgPipeline
     params = frozenset(
         [
@@ -119,6 +119,60 @@ def test_lumina_prompt_embeds(self):
         max_diff = np.abs(output_with_prompt - output_with_embeds).max()
         assert max_diff < 1e-4
 
+    def test_feed_forward_chunking(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_no_chunking = image[0, -3:, -3:, -1]
+
+        pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_chunking = image[0, -3:, -3:, -1]
+
+        max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
+        self.assertLess(max_diff, 1e-4)
+
+    def test_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["return_dict"] = False
+        image = pipe(**inputs)[0]
+        original_image_slice = image[0, -3:, -3:, -1]
+
+        pipe.transformer.fuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        inputs["return_dict"] = False
+        image_fused = pipe(**inputs)[0]
+        image_slice_fused = image_fused[0, -3:, -3:, -1]
+
+        pipe.transformer.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        inputs["return_dict"] = False
+        image_disabled = pipe(**inputs)[0]
+        image_slice_disabled = image_disabled[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."
+
 
 @slow
 @require_torch_gpu
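Similarly, a hedged sketch of the attention-processor utilities added above (again not part of the patch; the checkpoint name and the example key are assumptions):

```python
from diffusers import LuminaText2ImgPipeline
from diffusers.models.attention_processor import LuminaAttnProcessor2_0

# Hypothetical checkpoint name; any Lumina-Next checkpoint in diffusers format works here.
transformer = LuminaText2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers").transformer

# `attn_processors` returns one entry per Attention layer, keyed by its weight name
# (e.g. something like "layers.0.attn1.processor").
print(len(transformer.attn_processors))

# A single processor instance is applied to every layer; passing a dict keyed like
# `attn_processors` would set them individually.
transformer.set_attn_processor(LuminaAttnProcessor2_0())

# Reset to the default Lumina processor.
transformer.set_default_attn_processor()
```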