[Alpha-VLLM Team] Feat: added fused qkv and chunked ffn #8815


Closed · wants to merge 3 commits
164 changes: 160 additions & 4 deletions src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention, LuminaAttnProcessor2_0
from ..attention_processor import Attention, AttentionProcessor, LuminaAttnProcessor2_0
from ..embeddings import (
LuminaCombinedTimestepCaptionEmbedding,
LuminaPatchEmbed,
@@ -115,6 +115,16 @@ def __init__(

self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0

# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
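
For context, a minimal sketch of how these chunk settings are typically consumed when the feed-forward runs, mirroring `_chunked_feed_forward` in `diffusers.models.attention` (the helper name and wiring below are illustrative, not this PR's exact code):

    import torch
    import torch.nn as nn

    def apply_chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int) -> torch.Tensor:
        if hidden_states.shape[chunk_dim] % chunk_size != 0:
            raise ValueError(
                f"Chunked dimension {hidden_states.shape[chunk_dim]} must be divisible by chunk size {chunk_size}."
            )
        num_chunks = hidden_states.shape[chunk_dim] // chunk_size
        # Run the feed-forward over smaller slices sequentially (lower peak
        # activation memory, more kernel launches), then stitch the outputs back.
        return torch.cat(
            [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
            dim=chunk_dim,
        )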

def forward(
self,
hidden_states: torch.Tensor,
@@ -282,10 +292,113 @@ def __init__(
bias=True,
out_dim=patch_size * patch_size * self.out_channels,
)
# self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)

assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"

@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
"""
# set recursively
processors = {}

def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()

for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

return processors

for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)

return processors
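
A quick usage sketch for this property (the checkpoint id and variable name are illustrative assumptions, not part of this diff):

    from diffusers import LuminaNextDiT2DModel

    transformer = LuminaNextDiT2DModel.from_pretrained(
        "Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="transformer"
    )
    # Keys are dotted module paths ending in ".processor",
    # e.g. "layers.0.attn1.processor".
    for name, processor in transformer.attn_processors.items():
        print(name, processor.__class__.__name__)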

def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(LuminaAttnProcessor2_0())

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.

Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.

If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.

"""
count = len(self.attn_processors.keys())

if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)

def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))

for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
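
For example, processors can be swapped globally or per layer (a sketch, reusing the hypothetical `transformer` from above):

    from diffusers.models.attention_processor import LuminaAttnProcessor2_0

    # One processor instance shared by all Attention layers:
    transformer.set_attn_processor(LuminaAttnProcessor2_0())

    # Or a dict keyed by the paths exposed by `attn_processors`, which is
    # the recommended form for trainable processors:
    transformer.set_attn_processor(
        {name: LuminaAttnProcessor2_0() for name in transformer.attn_processors.keys()}
    )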

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.

<Tip warning={true}>

This API is 🧪 experimental.

</Tip>
"""
self.original_attn_processors = None

for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

self.original_attn_processors = self.attn_processors

for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.

<Tip warning={true}>

This API is 🧪 experimental.

</Tip>

"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
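
A typical fuse/unfuse round trip at inference time (a sketch; the checkpoint id and prompt are assumptions):

    import torch
    from diffusers import LuminaText2ImgPipeline

    pipe = LuminaText2ImgPipeline.from_pretrained(
        "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
    ).to("cuda")

    pipe.transformer.fuse_qkv_projections()  # q/k/v become a single projection per self-attention block
    image = pipe(prompt="a photo of a cat").images[0]
    pipe.transformer.unfuse_qkv_projections()  # restores the processors saved in `original_attn_processors`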

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -297,7 +410,8 @@ def forward(
return_dict=True,
) -> torch.Tensor:
"""
Forward pass of LuminaNextDiT.
The forward method of the `LuminaNextDiT2DModel`. See the [Lumina
paper](https://arxiv.org/abs/2406.18583) for details.

Parameters:
hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
@@ -338,3 +452,45 @@ def forward(
return (output,)

return Transformer2DModelOutput(sample=output)

# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

Parameters:
chunk_size (`int`, *optional*):
The chunk size of the feed-forward layers. If not specified, the feed-forward layer runs individually
over each tensor along `dim` (equivalent to `chunk_size=1`).
dim (`int`, *optional*, defaults to `0`):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
"""
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

# By default chunk size is 1
chunk_size = chunk_size or 1

def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)

for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)

# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self):
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)

for module in self.children():
fn_recursive_feed_forward(module, None, 0)
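
Putting both features together, forward chunking can be toggled the same way (a sketch reusing the hypothetical `pipe` from above):

    # Chunk the feed-forward over the batch dimension to cut peak memory:
    pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
    image = pipe(prompt="a photo of a cat").images[0]

    # Restore the default, unchunked feed-forward:
    pipe.transformer.disable_forward_chunking()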
58 changes: 56 additions & 2 deletions tests/pipelines/lumina/test_lumina_nextdit.py
@@ -13,10 +13,10 @@
torch_device,
)

from ..test_pipelines_common import PipelineTesterMixin
from ..test_pipelines_common import PipelineTesterMixin, to_np


class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
class LuminaText2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = LuminaText2ImgPipeline
params = frozenset(
[
@@ -119,6 +119,60 @@ def test_lumina_prompt_embeds(self):
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4

def test_feed_forward_chunking(self):
device = "cpu"

components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_no_chunking = image[0, -3:, -3:, -1]

pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_chunking = image[0, -3:, -3:, -1]

max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
self.assertLess(max_diff, 1e-4)

def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image = pipe(**inputs)[0]
original_image_slice = image[0, -3:, -3:, -1]

pipe.transformer.fuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image_fused = pipe(**inputs)[0]
image_slice_fused = image_fused[0, -3:, -3:, -1]

pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]

assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."


@slow
@require_torch_gpu