diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py
index d4f5b4658542..b5c44e868d38 100644
--- a/src/diffusers/models/transformers/lumina_nextdit2d.py
+++ b/src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
@@ -20,7 +20,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import LuminaFeedForward
-from ..attention_processor import Attention, LuminaAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, LuminaAttnProcessor2_0
from ..embeddings import (
LuminaCombinedTimestepCaptionEmbedding,
LuminaPatchEmbed,
@@ -115,6 +115,16 @@ def __init__(
self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
def forward(
self,
hidden_states: torch.Tensor,
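Context for the `_chunk_size`/`_chunk_dim` attributes added above (not part of the patch): once a chunk size is set, the block's feed-forward can be applied piecewise along the chosen dimension and the pieces concatenated, which lowers peak activation memory without changing the result. A minimal sketch of that pattern, with an illustrative `feed_forward` module standing in for the block's own `LuminaFeedForward`:

```python
import torch
import torch.nn as nn


def chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int) -> torch.Tensor:
    # Split the input along `chunk_dim`, run the feed-forward on each slice,
    # then concatenate: same output as ff(hidden_states), lower peak memory.
    num_chunks = max(hidden_states.shape[chunk_dim] // chunk_size, 1)
    return torch.cat(
        [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )


# Illustrative feed-forward; any shape-preserving module behaves the same way.
feed_forward = nn.Sequential(nn.Linear(32, 128), nn.GELU(), nn.Linear(128, 32))
x = torch.randn(4, 16, 32)
assert torch.allclose(feed_forward(x), chunked_feed_forward(feed_forward, x, chunk_dim=0, chunk_size=1), atol=1e-6)
```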
@@ -282,10 +292,113 @@ def __init__(
bias=True,
out_dim=patch_size * patch_size * self.out_channels,
)
- # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)
assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(LuminaAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
def forward(
self,
hidden_states: torch.Tensor,
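For reference rather than as part of the patch, this is roughly how the processor plumbing added above would be exercised; `transformer` is assumed to be an already instantiated `LuminaNextDiT2DModel`:

```python
from diffusers.models.attention_processor import LuminaAttnProcessor2_0

# Current processors, keyed by the module path of each Attention layer.
processors = transformer.attn_processors

# Set one processor instance on every Attention layer...
transformer.set_attn_processor(LuminaAttnProcessor2_0())

# ...or pass a dict keyed exactly like `attn_processors` to set them per layer.
transformer.set_attn_processor({name: LuminaAttnProcessor2_0() for name in processors})

# Experimental fused-QKV path; `unfuse_qkv_projections` restores the saved processors.
transformer.fuse_qkv_projections()
transformer.unfuse_qkv_projections()
```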
@@ -297,7 +410,8 @@ def forward(
return_dict=True,
) -> torch.Tensor:
"""
- Forward pass of LuminaNextDiT.
+        The `LuminaNextDiT2DModel` forward method. See the [Lumina
+        paper](https://arxiv.org/abs/2406.18583) for details.
Parameters:
hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
@@ -338,3 +452,45 @@ def forward(
return (output,)
return Transformer2DModelOutput(sample=output)
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+        Enables [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) on the transformer blocks.
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+ def disable_forward_chunking(self):
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, None, 0)
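Not part of the patch, but a sketch of how the new chunking switch would typically be driven from a pipeline; the checkpoint id and prompt are illustrative:

```python
import torch
from diffusers import LuminaText2ImgPipeline

pipe = LuminaText2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# Chunk the feed-forward layers over the batch dimension to lower peak memory.
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
image = pipe("a photo of an astronaut riding a horse on mars").images[0]

# Restore the default, unchunked feed-forward path.
pipe.transformer.disable_forward_chunking()
```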
diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py
index a53758ce2808..489c62fcd94d 100644
--- a/tests/pipelines/lumina/test_lumina_nextdit.py
+++ b/tests/pipelines/lumina/test_lumina_nextdit.py
@@ -13,10 +13,10 @@
torch_device,
)
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineTesterMixin, to_np
-class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+class LuminaText2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = LuminaText2ImgPipeline
params = frozenset(
[
@@ -119,6 +119,60 @@ def test_lumina_prompt_embeds(self):
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
+ def test_feed_forward_chunking(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_no_chunking = image[0, -3:, -3:, -1]
+
+ pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_chunking = image[0, -3:, -3:, -1]
+
+ max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
+ self.assertLess(max_diff, 1e-4)
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["return_dict"] = False
+ image = pipe(**inputs)[0]
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ pipe.transformer.fuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ inputs["return_dict"] = False
+ image_fused = pipe(**inputs)[0]
+ image_slice_fused = image_fused[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ inputs["return_dict"] = False
+ image_disabled = pipe(**inputs)[0]
+ image_slice_disabled = image_disabled[0, -3:, -3:, -1]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
+
@slow
@require_torch_gpu
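The fusion assertions above rest on a simple linear-algebra fact: concatenating the query/key/value weight matrices and projecting once is mathematically identical to three separate projections, so fusing should only introduce floating-point noise. A standalone illustration, independent of diffusers:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim = 8
x = torch.randn(2, 4, dim)

# Three separate projections, as in the unfused attention path.
wq, wk, wv = (torch.nn.Linear(dim, dim, bias=False) for _ in range(3))
q, k, v = wq(x), wk(x), wv(x)

# One fused projection: stack the weights, project once, split the result.
w_qkv = torch.cat([wq.weight, wk.weight, wv.weight], dim=0)  # shape (3 * dim, dim)
q_f, k_f, v_f = F.linear(x, w_qkv).chunk(3, dim=-1)

assert torch.allclose(q, q_f, atol=1e-6)
assert torch.allclose(k, k_f, atol=1e-6)
assert torch.allclose(v, v_f, atol=1e-6)
```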