
[Community Pipelines] Accelerate inference of AnimateDiff by IPEX on CPU #8643


Merged · 4 commits · Jul 12, 2024
112 changes: 112 additions & 0 deletions examples/community/README.md
@@ -70,6 +70,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |

To load a custom pipeline, just pass the `custom_pipeline` argument to `DiffusionPipeline`, naming one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.

@@ -4099,6 +4100,117 @@ output_frames[0].save(output_video_path, save_all=True,
append_images=output_frames[1:], duration=100, loop=0)
```

### AnimateDiff on IPEX

This diffusion pipeline aims to accelerate the inference of AnimateDiff on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).

To use this pipeline, you need to:
1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)

**Note:** Each PyTorch release has a corresponding IPEX release; the mapping is shown below. Installing PyTorch/IPEX 2.3 is recommended for the best performance. (A quick version check is sketched at the end of this step.)

|PyTorch Version|IPEX Version|
|--|--|
|[v2.3.\*](https://github.com/pytorch/pytorch/tree/v2.3.0 "v2.3.0")|[v2.3.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.3.0+cpu)|
|[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|

You can use pip to install the latest version of IPEX:
```
python -m pip install intel_extension_for_pytorch
```
**Note:** To install a specific version, run the following command:
```
python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
```
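
Since a PyTorch/IPEX version mismatch is a common source of import errors, here is a minimal sketch (assuming both packages are installed) to verify that the two releases line up according to the table above:

```python
import torch
import intel_extension_for_pytorch as ipex

# The major.minor versions should match, e.g. torch 2.3.x with ipex 2.3.x+cpu.
print("PyTorch:", torch.__version__)
print("IPEX:", ipex.__version__)
```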
2. After pipeline initialization, call `prepare_for_ipex()` to enable IPEX acceleration. The supported inference data types are Float32 and BFloat16.

```python
pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
# For Float32
pipe.prepare_for_ipex(torch.float32, prompt="A girl smiling")
# For BFloat16
pipe.prepare_for_ipex(torch.bfloat16, prompt="A girl smiling")
```

Then you can use the IPEX pipeline in a similar way to the default AnimateDiff pipeline.
```python
# For Float32
output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
# For BFloat16
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
```

The following code compares the performance of the original AnimateDiff pipeline with the IPEX-optimized pipeline.
With BFloat16, the optimized pipeline yields roughly a 1.5-2.2x speedup on fifth-generation Intel Xeon CPUs (code-named Emerald Rapids).

```python
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from pipeline_animatediff_ipex import AnimateDiffPipelineIpex
import time

device = "cpu"
dtype = torch.float32

prompt = "A girl smiling"
step = 8 # Options: [1,2,4,8]
repo = "ByteDance/AnimateDiff-Lightning"
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
base = "emilianJR/epiCRealism" # Choose your favorite base model.

adapter = MotionAdapter().to(device, dtype)
# Reviewer note from the PR discussion: "Would actually prefer if from_pretrained
# could be used consistently but this is no problem either."
adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))

# Helper function for time evaluation
def elapsed_time(pipeline, nb_pass=3, num_inference_steps=1):
# warmup
for _ in range(2):
output = pipeline(prompt = prompt, guidance_scale=1.0, num_inference_steps = num_inference_steps)
#time evaluation
start = time.time()
for _ in range(nb_pass):
pipeline(prompt = prompt, guidance_scale=1.0, num_inference_steps = num_inference_steps)
end = time.time()
return (end - start) / nb_pass

############## bf16 inference performance ###############

# 1. IPEX Pipeline initialization
pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
pipe.prepare_for_ipex(torch.bfloat16, prompt=prompt)

# 2. Original Pipeline initialization
pipe2 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
pipe2.scheduler = EulerDiscreteScheduler.from_config(pipe2.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")

# 3. Compare performance between Original Pipeline and IPEX Pipeline
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    latency = elapsed_time(pipe, num_inference_steps=step)
    print("Latency of AnimateDiffPipelineIpex--bf16", latency, "s for total", step, "steps")
    latency = elapsed_time(pipe2, num_inference_steps=step)
    print("Latency of AnimateDiffPipeline--bf16", latency, "s for total", step, "steps")

############## fp32 inference performance ###############

# 1. IPEX Pipeline initialization
pipe3 = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
pipe3.scheduler = EulerDiscreteScheduler.from_config(pipe3.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
pipe3.prepare_for_ipex(torch.float32, prompt=prompt)

# 2. Original Pipeline initialization
pipe4 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
pipe4.scheduler = EulerDiscreteScheduler.from_config(pipe4.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")

# 3. Compare performance between Original Pipeline and IPEX Pipeline
latency = elapsed_time(pipe3, num_inference_steps=step)
print("Latency of AnimateDiffPipelineIpex--fp32", latency, "s for total", step, "steps")
latency = elapsed_time(pipe4, num_inference_steps=step)
print("Latency of AnimateDiffPipeline--fp32",latency, "s for total", step, "steps")
```
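
The frames returned by the pipeline can be exported like any other AnimateDiff output. A minimal sketch using diffusers' `export_to_gif` utility (the output filename is an arbitrary example; `pipe`, `prompt`, and `step` are the objects defined above):

```python
from diffusers.utils import export_to_gif

# Run the prepared BF16 pipeline once and save the resulting animation.
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    output = pipe(prompt=prompt, guidance_scale=1.0, num_inference_steps=step)

# output.frames[0] is the list of PIL images for the first prompt.
export_to_gif(output.frames[0], "animatediff_ipex.gif")
```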

# Perturbed-Attention Guidance

[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)