
add PAG support for Stable Diffusion 3 #8861


Merged

merged 39 commits into huggingface:main on Aug 6, 2024
Conversation

sunovivid
Contributor

What does this PR do?

This is a PR for adding PAG support for SD3!

SD3 differs slightly from SD and SDXL because it employs Rectified Flow and an MMDiT backbone rather than a UNet.

For the joint attention in MMDiT, we can apply perturbed self-attention by masking attention between image patches, following the principles of PAG.
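
To make the idea concrete, here is a minimal sketch of the perturbed branch (illustrative only, not the actual diffusers attention processor; the function name and the `num_text_tokens` split are assumptions):

```python
import torch.nn.functional as F

def perturbed_joint_attention(q, k, v, num_text_tokens):
    # q, k, v: (batch, heads, seq_len, head_dim) over the concatenated
    # [text, image] token sequence used by MMDiT joint attention.
    out = F.scaled_dot_product_attention(q, k, v)
    # PAG perturbation: for image-patch queries, replace softmax(QK^T)
    # with the identity map, so each image token attends only to itself
    # and its output is simply its own value vector.
    out = out.clone()
    out[:, :, num_text_tokens:, :] = v[:, :, num_text_tokens:, :]
    return out
```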

It works quite well. Here is an example with PAG + SD3.

Examples

"Pirate ship trapped in a cosmic maelstrom nebula"

From left to right, the PAG scale is 0.5, 1.0, 3.0, 5.0, and 7.0.
From top to bottom, the CFG scale is 1.0, 3.0, 5.0, and 7.0.

How to Use

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
    enable_pag=True,
    pag_applied_layers=["13"],
)
pipe.to("cuda")

prompt = "a photo of a cat holding a sign that says hello world"

image = pipe(
    prompt=prompt,
    negative_prompt="",
    num_inference_steps=28,
    height=1024,
    width=1024,
    guidance_scale=0,
    pag_scale=5.0,
).images[0]

image.save("result.png")
```

Ablations on Layer

For those interested in which layer to perturb, here are the results of single-layer perturbation for all the MMDiT attention layers in SD3.

Ablations. From left to right, guidance scale: 3.0, 5.0, 7.0.
From top to bottom in each image grid, pag_applied_layer: 0-3, 4-7, 8-11, 12-15, 16-19, and 20-23.

Before submitting

Who can review?

@yiyixuxu @asomoza

@sayakpaul @a-r-r-o-w also could be interested

Thank you for your time!


Member

@sayakpaul left a comment

I have some minor comments but other than those this looks quite sleek!

@sayakpaul sayakpaul requested a review from yiyixuxu July 14, 2024 02:48
@asomoza
Member

asomoza commented Jul 15, 2024

I gave this a test. It does improve the generations a little, but I find the improvement is not as big or as noticeable as with SDXL. I still think it's a good addition, and it will probably help with the controlnets.

[Images, left to right: w/o PAG, PAG, PAG]

The PAG results use different layers, but if I use more than two layers the generation looks more fake and cartoonish. I also had to lower the scale to 2.5 to make it less cartoonish.

With both PAG results we can see that the "little people" have better details, but overall I like the style of the original more, so I will probably just inpaint and improve the first generation.

If I use a more realistic photo, the fake, cartoonish effect is even more noticeable, contrary to SDXL, where PAG made subjects more realistic and well defined.

Collaborator

@yiyixuxu left a comment

thanks a ton for working on this!

@yiyixuxu yiyixuxu requested a review from DN6 July 16, 2024 16:48
@asomoza
Member

asomoza commented Jul 19, 2024

I was playing with this again because I needed a cool generation, and I found a cool application that I'd like to share.

I found two things. First, I had been using a PAG scale that was too high; just using something like 0.5 or 0.7 gives good results that clean up the image a little. For example:

[Images, left to right: original, PAG 0.5, PAG 0.7]

The original has a lot of detail, which always impresses me about how good the SD3 VAE is, but it has some defects, like the shape of the tires and the weirdness of the front bumper, among other things.

Just applying a low PAG scale to specific layers cleans up the image a lot and makes the generation better. The only problem is that it also makes the truck cleaner, which is wrong, but maybe that can be fixed with some other layers, or this can just be used as the base for a better generation.

The second thing I found is that we can change the "style" of the generation just by applying PAG. For example, without changing anything in the prompt, I can shift this generation to a more concept-art style, which I think is really cool.

[Images, left to right: original, with PAG]

And if I change the prompt to "concept art" instead of "photo":

[Images, left to right: original, with PAG]

It seems to help keep the composition consistent even if you change the style in the prompt.

@sunovivid
Contributor Author

@yiyixuxu @DN6 Thank you for your hard work on PAG!
I want to merge this PR, but it seems to be stalled. What steps remain for it? I believe I have revised the code following all the comments. Any suggestions?

@a-r-r-o-w
Member

@sunovivid Just for notification - we merged #8936 now. It should now be easier to apply PAG directly, without the extra SD3PAG class. You'll have to pass the PAG-related SD3 attention processors as a parameter when calling self.set_pag_applied_layers(..., pag_attn_processors=(...)) (you can look at how it's done in Hunyuan for an example).
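
For reference, a minimal sketch of what that call could look like, modeled on the Hunyuan PAG pipeline; the SD3 processor class names and the surrounding __init__ context are assumptions for illustration, not a confirmed API:

```python
# Sketch only: assumed SD3 joint-attention PAG processors, following the
# pattern used by the Hunyuan PAG pipeline after #8936.
from diffusers.models.attention_processor import (
    PAGCFGJointAttnProcessor2_0,  # assumed: PAG variant used alongside CFG
    PAGJointAttnProcessor2_0,     # assumed: PAG variant without CFG
)

# Inside the SD3 PAG pipeline's __init__ (self is the pipeline):
self.set_pag_applied_layers(
    pag_applied_layers,
    pag_attn_processors=(
        PAGCFGJointAttnProcessor2_0(),
        PAGJointAttnProcessor2_0(),
    ),
)
```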

Thanks for the amazing work here! Although a bit unrealistic, it might be nice to get this ready by tomorrow, since a Diffusers release is happening soon and it would be great to ship a PAG variant for SD3.

@sunovivid
Contributor Author

sunovivid commented Aug 5, 2024

@a-r-r-o-w OK! I'd like to give it a shot, so I will try it now. Thank you for your awesome refactoring!

Member

@a-r-r-o-w left a comment

@sunovivid Thank you for the kind words! This is really close to merging, just a couple of things:

  • Could you run make fix-copies?
  • Could you add a test for SD3? It should mostly be a copy of the SD3 tests, with two additional tests similar to this and this

@sunovivid
Contributor Author

@a-r-r-o-w I fixed the code following all your comments and added tests! Thank you for your comments and guidance.
For now I copied only the fast tests from the SD3 test suite, following @sayakpaul's comments. WDYT?

@a-r-r-o-w
Member

Also, another test seems to be failing.

This is probably because num_layers=1, but we are trying to find the second block using blocks.1. You can remove blocks.1 from that specific failing test to fix it, I think.

@sunovivid
Contributor Author

I see! How about setting num_layers=2, as in the HunyuanDiT PAG test? I think it is the more robust way (we can test regex-based layer matching; see the toy example below).
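
Here is a standalone toy illustration of the matching logic (not the actual diffusers helper), showing why blocks.1 finds nothing when num_layers=1:

```python
import re

def match_layers(patterns, num_layers):
    # Toy stand-in for regex-based layer matching: each entry in
    # pag_applied_layers is searched against the block module names.
    names = [f"transformer_blocks.{i}" for i in range(num_layers)]
    return [n for n in names if any(re.search(p, n) for p in patterns)]

print(match_layers(["blocks.1"], num_layers=1))  # [] -> no layer to perturb
print(match_layers(["blocks.1"], num_layers=2))  # ['transformer_blocks.1']
```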

@a-r-r-o-w
Member

That sounds good as well. I tried it, and there isn't a significant impact on speed (a few extra seconds, but that's okay, and it's good to test with at least two layers).

@sunovivid
Contributor Author

Thanks! I set num_layers=2.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sunovivid
Contributor Author

OK. There is another error. Sorry for bothering you. I'll check it out.

@sunovivid
Contributor Author

@a-r-r-o-w Thank you for your patience! The failure in the first test case was due to SD3's joint attention, which uses the name attn without distinguishing between attn1 and attn2. I have fixed this error.

The second test case failure was harder to pinpoint. It occurred because this is a randomly initialized model: the random weights caused the activations to become nan at a specific point (specifically, when hidden_state_ptb passed through attn.to_out). Thus, the assertion assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3 failed, as it tried to compare nan with 0.001. I believe this is due to the randomness; when actually sampling, it works correctly.
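
A quick standalone illustration (toy values, not the actual test tensors) of why the comparison silently fails once a nan appears:

```python
import numpy as np

out = np.array([0.1, np.nan, 0.3])
out_pag_enabled = np.array([0.1, 0.2, 0.3])

diff = np.abs(out.flatten() - out_pag_enabled.flatten())
print(diff.max())         # nan: the nan propagates through abs() and max()
print(diff.max() > 1e-3)  # False: any comparison with nan is False,
                          # so the assertion fails even though outputs differ
```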


Sample grid using CFG and PAG (from left to right, PAG = [0.5, 0.75, 1.0, 3.0]; from top to bottom, CFG = [1.0, 1.5, 3.0, 5.0]):

a photo of a cat holding a sign that says hello world

19th century Scottish wizard with a mysterious smile and a piercing gaze, enigmatic, photorealistic, incredibly detailed, sharpness, detail, cinematic lighting

Sampling code I used (a bit dirty, but shared for reproducibility):

```python
import multiprocessing

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.utils import make_image_grid

gpu_list = [0, 1]

def sample(gpu_id, prompts):
    pipe = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        torch_dtype=torch.float16,
        enable_pag=True, pag_applied_layers=["blocks.9"]
    )
    pipe.to(f"cuda:{gpu_id}")

    seeds = range(1, 6)
    pag_scale_list = [0.5, 0.75, 1.0, 3.0,]
    cfg_scale_list = [1.0, 1.5, 3.0, 5.0,]

    for prompt in prompts:
        for seed in seeds:
            images = []
            for cfg_scale in cfg_scale_list:
                for pag_scale in pag_scale_list:
                    generator = torch.Generator(device="cpu").manual_seed(seed)
                    print(f'GPU {gpu_id} - Prompt: {prompt[:15]} - Seed: {seed} - CFG: {cfg_scale} - PAG: {pag_scale}')
                    image = pipe(
                        prompt=prompt,
                        negative_prompt="",
                        num_inference_steps=28,
                        height=1024,
                        width=1024,
                        guidance_scale=cfg_scale,
                        pag_scale=pag_scale,
                        generator=generator,
                    ).images[0]
                    images.append(image)

            grid_image = make_image_grid(images, rows=len(cfg_scale_list), cols=len(pag_scale_list))
            grid_image.save(f"gpu{gpu_id}_prompt:{prompt[:15]}_seed:{seed}.png")

if __name__ == "__main__":
    prompts = [
        "a photo of a cat holding a sign that says hello world",
        "a photo of an astronaut riding a horse on mars",
        "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.",
        "A close-up photo of a person. The subject is a woman. She wore a blue coat with a gray dress underneath. She has blue eyes and blond hair, and wears a pair of earrings. Behind are blurred city buildings and streets.",
        "two white cats playing on top of the orange sofa, very comfortable Best quality",
        "an illustration of a stylish swordsman",
        "19th century Scottish wizard with a mysterious smile and a piercing gaze, enigmatic, photorealistic, incredibly detailed, sharpness, detail, cinematic lighting",
        "3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. Flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring.",
        "Pirate ship trapped in a cosmic maelstrom nebula",
        "Oppenheimer sits on the beach on a chair, watching a nuclear exposition with a huge mushroom cloud, 120mm",
    ]

    # Distribute prompts among GPUs
    num_gpus = len(gpu_list)
    prompts_per_gpu = len(prompts) // num_gpus
    distributed_prompts = [prompts[i * prompts_per_gpu:(i + 1) * prompts_per_gpu] for i in range(num_gpus)]

    # Handle any remaining prompts
    remaining_prompts = prompts[num_gpus * prompts_per_gpu:]
    for i in range(len(remaining_prompts)):
        distributed_prompts[i % num_gpus].append(remaining_prompts[i])

    processes = []
    for gpu_id, gpu_prompts in zip(gpu_list, distributed_prompts):
        p = multiprocessing.Process(target=sample, args=(gpu_id, gpu_prompts))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
```

Therefore, I have removed the assert statement from the test. I think it is sufficient that out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] generates samples without errors, since the results for pag_scale=3.0 and pag_scale=0.0 are not always guaranteed to differ. If you want to include the check, you could make it nan-safe with nan_to_num, like the following, but I don't think it's necessary:

```python
assert np.max(np.abs(np.nan_to_num(out, nan=0) - np.nan_to_num(out_pag_enabled, nan=0)).flatten()) > 1e-3
```

A new test failure has also appeared that I don't recognize. It seems to be related to an internal testing model. Do you know why this failure was added?

```
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED tests/schedulers/test_schedulers.py::SchedulerBaseTests::test_default_arguments_not_in_config - FileNotFoundError: [Errno 2] No such file or directory: '/github/home/.cache/huggingface/hub/models--hf-internal-testing--tiny-stable-diffusion-pipe/snapshots/3ee6c9f225f088ad5d35b624b6514b091e6a4849/unet/diffusion_pytorch_model.bin'
===== 1 failed, 1445 passed, 406 skipped, 82 warnings in 60.15s (0:01:00) ======
```

@yiyixuxu yiyixuxu merged commit 926daa3 into huggingface:main Aug 6, 2024
14 of 15 checks passed
@sayakpaul
Member

Thanks for your awesome contributions, @sunovivid!

@yiyixuxu yiyixuxu added the PAG label Sep 4, 2024
sayakpaul added a commit that referenced this pull request Dec 23, 2024
add pag sd3


---------

Co-authored-by: HyoungwonCho <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: crepejung00 <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Aryan <[email protected]>
Co-authored-by: Aryan <[email protected]>