diff --git a/examples/community/README.md b/examples/community/README.md
index bf121b5b7050..f467ee38de3b 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -70,6 +70,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 | Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
 | Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
 |   FRESCO V2V Pipeline                                                                                                    | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962)                                                                                                                                                                                                                                                                                                                                                                                                                                      | [FRESCO V2V Pipeline](#fresco)      | - |              [Yifan Zhou](https://github.com/SingleZombie) |
+| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
 
@@ -4099,6 +4100,117 @@ output_frames[0].save(output_video_path, save_all=True,
                  append_images=output_frames[1:], duration=100, loop=0)
 ```
 
+### AnimateDiff on IPEX
+
+This diffusion pipeline aims to accelerate the inference of AnimateDiff on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).
+
+To use this pipeline, you need to:
+1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)
+
+**Note:** Each PyTorch release has a corresponding IPEX release; the mapping is shown in the table below. We recommend installing PyTorch/IPEX 2.3 to get the best performance.
+
+|PyTorch Version|IPEX Version|
+|--|--|
+|[v2.3.\*](https://github.com/pytorch/pytorch/tree/v2.3.0 "v2.3.0")|[v2.3.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.3.0+cpu)|
+|[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|
+
+You can use pip to install the latest version of IPEX:
+```bash
+python -m pip install intel_extension_for_pytorch
+```
+**Note:** To install a specific version, run the following command:
+```bash
+python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
+```
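+
+To confirm that the installed PyTorch and IPEX versions line up, you can print both versions (an optional sanity check, not required by the pipeline itself):
+```python
+import torch
+import intel_extension_for_pytorch as ipex
+
+# The PyTorch and IPEX major/minor versions should match, e.g. 2.3.* with 2.3.*
+print("PyTorch version:", torch.__version__)
+print("IPEX version:", ipex.__version__)
+```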
+2. After pipeline initialization, call `prepare_for_ipex()` to enable IPEX acceleration. The supported inference datatypes are Float32 and BFloat16.
+
+```python
+pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+# For Float32
+pipe.prepare_for_ipex(torch.float32, prompt="A girl smiling")
+# For BFloat16
+pipe.prepare_for_ipex(torch.bfloat16, prompt="A girl smiling")
+```
+
+Then you can use the IPEX pipeline in a similar way to the default AnimateDiff pipeline.
+```python
+# For Float32
+output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
+# For BFloat16
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+    output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
+```
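+
+The generated frames can then be exported to a GIF, just like with the default AnimateDiff pipeline:
+```python
+from diffusers.utils import export_to_gif
+
+# `output` is the result returned by the pipeline call above
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```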
+
+The following code compares the performance of the original AnimateDiff pipeline with the IPEX-optimized pipeline.
+With the optimized pipeline, we can get about a 1.5-2.2x performance boost with BFloat16 on fifth-generation Intel Xeon CPUs (code-named Emerald Rapids).
+
+```python
+import time
+
+import torch
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
+from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler, MotionAdapter
+from pipeline_animatediff_ipex import AnimateDiffPipelineIpex
+
+device = "cpu"
+dtype = torch.float32
+
+prompt = "A girl smiling"
+step = 8  # Options: [1,2,4,8]
+repo = "ByteDance/AnimateDiff-Lightning"
+ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
+base = "emilianJR/epiCRealism"  # Choose to your favorite base model.
+
+adapter = MotionAdapter().to(device, dtype)
+adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
+
+# Helper function for time evaluation
+def elapsed_time(pipeline, nb_pass=3, num_inference_steps=1):
+    # warmup
+    for _ in range(2):
+        pipeline(prompt=prompt, guidance_scale=1.0, num_inference_steps=num_inference_steps)
+    # time evaluation
+    start = time.time()
+    for _ in range(nb_pass):
+        pipeline(prompt=prompt, guidance_scale=1.0, num_inference_steps=num_inference_steps)
+    end = time.time()
+    return (end - start) / nb_pass
+
+##############     bf16 inference performance    ###############
+
+# 1. IPEX Pipeline initialization
+pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+pipe.prepare_for_ipex(torch.bfloat16, prompt=prompt)
+
+# 2. Original Pipeline initialization
+pipe2 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+pipe2.scheduler = EulerDiscreteScheduler.from_config(pipe2.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+    latency = elapsed_time(pipe, num_inference_steps=step)
+    print("Latency of AnimateDiffPipelineIpex--bf16", latency, "s for total", step, "steps")
+    latency = elapsed_time(pipe2, num_inference_steps=step)
+    print("Latency of AnimateDiffPipeline--bf16", latency, "s for total", step, "steps")
+
+##############     fp32 inference performance    ###############
+
+# 1. IPEX Pipeline initialization
+pipe3 = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+pipe3.scheduler = EulerDiscreteScheduler.from_config(pipe3.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+pipe3.prepare_for_ipex(torch.float32, prompt=prompt)
+
+# 2. Original Pipeline initialization
+pipe4 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+pipe4.scheduler = EulerDiscreteScheduler.from_config(pipe4.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
+latency = elapsed_time(pipe3, num_inference_steps=step)
+print("Latency of AnimateDiffPipelineIpex--fp32", latency, "s for total", step, "steps")
+latency = elapsed_time(pipe4, num_inference_steps=step)
+print("Latency of AnimateDiffPipeline--fp32",latency, "s for total", step, "steps")
+```
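+
+Alternatively, since this is a community pipeline, it can be loaded through the `custom_pipeline` argument described at the top of this document instead of importing the class from the local file. A minimal sketch, reusing the `adapter` from the example above and assuming the pipeline file is available as `pipeline_animatediff_ipex` in `diffusers/examples/community`:
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+# `adapter` is the MotionAdapter loaded with the AnimateDiff-Lightning weights above
+pipe = DiffusionPipeline.from_pretrained(
+    "emilianJR/epiCRealism",
+    custom_pipeline="pipeline_animatediff_ipex",
+    motion_adapter=adapter,
+    torch_dtype=torch.float32,
+).to("cpu")
+pipe.prepare_for_ipex(torch.float32, prompt="A girl smiling")
+```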
+
 # Perturbed-Attention Guidance
 
 [Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
diff --git a/examples/community/pipeline_animatediff_ipex.py b/examples/community/pipeline_animatediff_ipex.py
new file mode 100644
index 000000000000..dc65e76bc43b
--- /dev/null
+++ b/examples/community/pipeline_animatediff_ipex.py
@@ -0,0 +1,1002 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import intel_extension_for_pytorch as ipex
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.image_processor import PipelineImageInput
+from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.unets.unet_motion_model import MotionAdapter
+from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
+from diffusers.pipelines.free_init_utils import FreeInitMixin
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.schedulers import (
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import MotionAdapter, EulerDiscreteScheduler
+        >>> from diffusers.utils import export_to_gif
+        >>> from huggingface_hub import hf_hub_download
+        >>> from safetensors.torch import load_file
+        >>> from pipeline_animatediff_ipex import AnimateDiffPipelineIpex
+
+        >>> device = "cpu"
+        >>> dtype = torch.float32
+
+        >>> # ByteDance/AnimateDiff-Lightning, a distilled version of AnimateDiff SD1.5 v2,
+        >>> # a lightning-fast text-to-video generation model which can generate videos
+        >>> # more than ten times faster than the original AnimateDiff.
+        >>> step = 8  # Options: [1,2,4,8]
+        >>> repo = "ByteDance/AnimateDiff-Lightning"
+        >>> ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
+        >>> base = "emilianJR/epiCRealism"  # Choose your favorite base model.
+
+        >>> adapter = MotionAdapter().to(device, dtype)
+        >>> adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
+
+        >>> pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+        >>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+
+        >>> # For Float32
+        >>> pipe.prepare_for_ipex(torch.float32, prompt="A girl smiling")
+        >>> # For BFloat16
+        >>> pipe.prepare_for_ipex(torch.bfloat16, prompt="A girl smiling")
+
+        >>> # For Float32
+        >>> output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
+        >>> # For BFloat16
+        >>> with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+        ...     output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
+
+        >>> frames = output.frames[0]
+        >>> export_to_gif(frames, "animation.gif")
+        ```
+"""
+
+
+class AnimateDiffPipelineIpex(
+    DiffusionPipeline,
+    StableDiffusionMixin,
+    TextualInversionLoaderMixin,
+    IPAdapterMixin,
+    LoraLoaderMixin,
+    FreeInitMixin,
+):
+    r"""
+    Pipeline for text-to-video generation.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+        tokenizer (`CLIPTokenizer`):
+            A [`~transformers.CLIPTokenizer`] to tokenize text.
+        unet ([`UNet2DConditionModel`]):
+            A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+        motion_adapter ([`MotionAdapter`]):
+            A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+    """
+
+    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+    _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: Union[UNet2DConditionModel, UNetMotionModel],
+        motion_adapter: MotionAdapter,
+        scheduler: Union[
+            DDIMScheduler,
+            PNDMScheduler,
+            LMSDiscreteScheduler,
+            EulerDiscreteScheduler,
+            EulerAncestralDiscreteScheduler,
+            DPMSolverMultistepScheduler,
+        ],
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
+    ):
+        super().__init__()
+        if isinstance(unet, UNet2DConditionModel):
+            unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            motion_adapter=motion_adapter,
+            scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            lora_scale (`float`, *optional*):
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+            max_length = prompt_embeds.shape[1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            negative_prompt_embeds = self.text_encoder(
+                uncond_input.input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            negative_prompt_embeds = negative_prompt_embeds[0]
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        if self.text_encoder is not None:
+            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            repeat_dims = [1]
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
+    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
+    def decode_latents(self, latents):
+        latents = 1 / self.vae.config.scaling_factor * latents
+
+        batch_size, channels, num_frames, height, width = latents.shape
+        latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+
+        image = self.vae.decode(latents).sample
+        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        video = video.float()
+        return video
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+    ):
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
+    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
+    def prepare_latents(
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+    ):
+        shape = (
+            batch_size,
+            num_channels_latents,
+            num_frames,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_frames: Optional[int] = 16,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_videos_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+    ):
+        r"""
+        The call function to the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated video.
+            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated video.
+            num_frames (`int`, *optional*, defaults to 16):
+                The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
+                amounts to 2 seconds of video.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to higher quality videos at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                A higher guidance scale value encourages the model to generate images closely linked to the text
+                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
+                `(batch_size, num_channel, num_frames, height, width)`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
+                of a plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
+                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+        """
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        num_videos_per_prompt = 1
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            callback_on_step_end_tensor_inputs,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            device,
+            num_videos_per_prompt,
+            self.do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
+        )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_videos_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            num_channels_latents,
+            num_frames,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Add image embeds for IP-Adapter
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+            else None
+        )
+
+        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+        for free_init_iter in range(num_free_init_iters):
+            if self.free_init_enabled:
+                latents, timesteps = self._apply_free_init(
+                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+                )
+
+            self._num_timesteps = len(timesteps)
+            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+            # 8. Denoising loop
+            with self.progress_bar(total=self._num_timesteps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+                    # NOTE: after `prepare_for_ipex`, `self.unet.forward` is a traced TorchScript module that only
+                    # accepts (sample, timestep, encoder_hidden_states) and returns a dict, so `cross_attention_kwargs`
+                    # and `added_cond_kwargs` are not forwarded here and the output is indexed with "sample" instead
+                    # of using the `.sample` attribute.
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                    )["sample"]
+
+                    # perform guidance
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+
+        # 9. Post processing
+        if output_type == "latent":
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
+
+        # 10. Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (video,)
+
+        return AnimateDiffPipelineOutput(frames=video)
+
+    @torch.no_grad()
+    def prepare_for_ipex(
+        self,
+        dtype=torch.float32,
+        prompt: Union[str, List[str]] = None,
+        num_frames: Optional[int] = 16,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_videos_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+    ):
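+        r"""
+        Optimize the UNet, VAE decoder and text encoder with IPEX for the given `dtype` (`torch.float32` or
+        `torch.bfloat16`) and JIT-trace the UNet and VAE decoder for better performance. The remaining arguments
+        mirror `__call__` and are only used to build example inputs for tracing, so they should generally match the
+        arguments the pipeline will later be called with.
+        """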
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        num_videos_per_prompt = 1
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            callback_on_step_end_tensor_inputs,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            device,
+            num_videos_per_prompt,
+            self.do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
+        )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            num_channels_latents,
+            num_frames,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+        for free_init_iter in range(num_free_init_iters):
+            if self.free_init_enabled:
+                latents, timesteps = self._apply_free_init(
+                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+                )
+
+        self._num_timesteps = len(timesteps)
+
+        dummy = timesteps[0]
+        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+        latent_model_input = self.scheduler.scale_model_input(latent_model_input, dummy)
+
+        self.unet = self.unet.to(memory_format=torch.channels_last)
+        self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last)
+        self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last)
+
+        unet_input_example = {
+            "sample": latent_model_input,
+            "timestep": dummy,
+            "encoder_hidden_states": prompt_embeds,
+        }
+
+        fake_latents = 1 / self.vae.config.scaling_factor * latents
+        batch_size, channels, num_frames, height, width = fake_latents.shape
+        fake_latents = fake_latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+        vae_decoder_input_example = fake_latents
+
+        # optimize with ipex
+        if dtype == torch.bfloat16:
+            self.unet = ipex.optimize(self.unet.eval(), dtype=torch.bfloat16, inplace=True)
+            self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True)
+            self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
+        elif dtype == torch.float32:
+            self.unet = ipex.optimize(
+                self.unet.eval(),
+                dtype=torch.float32,
+                inplace=True,
+                # sample_input=unet_input_example,
+                level="O1",
+                weights_prepack=True,
+                auto_kernel_selection=False,
+            )
+            self.vae.decoder = ipex.optimize(
+                self.vae.decoder.eval(),
+                dtype=torch.float32,
+                inplace=True,
+                level="O1",
+                weights_prepack=True,
+                auto_kernel_selection=False,
+            )
+            self.text_encoder = ipex.optimize(
+                self.text_encoder.eval(),
+                dtype=torch.float32,
+                inplace=True,
+                level="O1",
+                weights_prepack=True,
+                auto_kernel_selection=False,
+            )
+        else:
+            raise ValueError("The value of 'dtype' should be 'torch.bfloat16' or 'torch.float32'.")
+
+        # trace unet model to get better performance on IPEX
+        with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():
+            unet_trace_model = torch.jit.trace(
+                self.unet, example_kwarg_inputs=unet_input_example, check_trace=False, strict=False
+            )
+            unet_trace_model = torch.jit.freeze(unet_trace_model)
+            self.unet.forward = unet_trace_model.forward
+
+        # trace vae.decoder model to get better performance on IPEX
+        with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():
+            vae_decoder_trace_model = torch.jit.trace(
+                self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False
+            )
+            vae_decoder_trace_model = torch.jit.freeze(vae_decoder_trace_model)
+            self.vae.decoder.forward = vae_decoder_trace_model.forward