From 6032a0804830d3168467aa4893f80a2ab9e418f2 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 14 Apr 2023 23:43:39 +0800 Subject: [PATCH 01/13] add stable_diffusion_ipex community pipeline --- examples/community/README.md | 74 +- examples/community/stable_diffusion_ipex.py | 829 ++++++++++++++++++++ 2 files changed, 902 insertions(+), 1 deletion(-) create mode 100644 examples/community/stable_diffusion_ipex.py diff --git a/examples/community/README.md b/examples/community/README.md index 11da90764579..8ed5d639da4e 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -31,7 +31,7 @@ MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | | DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - |[Aengus (Duc-Anh)](https://github.com/aengusng8) | | CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) | - +| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on CPU by [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. @@ -1130,3 +1130,75 @@ Init Image Output Image ![img2img_clip_guidance](https://huggingface.co/datasets/njindal/images/resolve/main/clip_guided_img2img.jpg) + + +### Stable Diffusion on IPEX + +This diffusion pipeline can significantly accelarate the inference of Stable-Diffusion on Intel CPUs with BF16/FP32 precision by IPEX. +You need: +1.Install IPEX +```python +python -m pip install intel_extension_for_pytorch +``` +2.After pipeline initialization, prepare_for_ipex() should be called to enable IPEX accelaration. +```python +pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex") +pipe.prepare_for_ipex(prompt,infer_type='bf16') +``` + +Other usage of "stable_diffusion_ipex" pipeline is same as the default stable diffusion pipeline. +Following code compares the performance of original stable diffusion pipeline with ipex pipeline. 
+ +```python +import torch +import intel_extension_for_pytorch as ipex +from diffusers import StableDiffusionPipeline +import time + +prompt = "sailing ship in storm by Rembrandt" +model_id = "runwayml/stable-diffusion-v1-5" +#Help function for time evaluation +def elapsed_time(pipeline, nb_pass=3, num_inference_steps=20): + # warmup + for _ in range(2): + images = pipeline(prompt, num_inference_steps=num_inference_steps).images + #time evaluation + start = time.time() + for _ in range(nb_pass): + pipeline(prompt, num_inference_steps=num_inference_steps) + end = time.time() + return (end - start) / nb_pass + +############## bf16 inference performance ############### + +#1.IPEX Pipeline initialization +pipe = DiffusionPipeline.from_pretrained(model_id, custom_pipeline="stable_diffusion_ipex") +pipe.prepare_for_ipex(prompt,infer_type='bf16') + +#2.Original Pipeline initialization +pipe2 = StableDiffusionPipeline.from_pretrained(model_id) + +#3.Compare performance between Original Pipeline and IPEX Pipeline +with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): + latency = elapsed_time(pipe) + print("Latency of StableDiffusionIPEXPipeline--bf16", latency) + latency = elapsed_time(pipe2) + print("Latency of StableDiffusionPipeline--bf16",latency) + +############## fp32 inference performance ############### + +#1.IPEX Pipeline initialization +pipe3 = DiffusionPipeline.from_pretrained(model_id, custom_pipeline="stable_diffusion_ipex") +pipe3.prepare_for_ipex(prompt,infer_type='fp32') + +#2.Original Pipeline initialization +pipe4 = StableDiffusionPipeline.from_pretrained(model_id) + +#3.Compare performance between Original Pipeline and IPEX Pipeline +with torch.no_grad(): + latency = elapsed_time(pipe3) + print("Latency of StableDiffusionIPEXPipeline--fp32", latency) + latency = elapsed_time(pipe4) + print("Latency of StableDiffusionPipeline--fp32",latency) + +``` \ No newline at end of file diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py new file mode 100644 index 000000000000..d4ee86e29606 --- /dev/null +++ b/examples/community/stable_diffusion_ipex.py @@ -0,0 +1,829 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + +import intel_extension_for_pytorch as ipex + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class StableDiffusionIPEXPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion on IPEX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
+ """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + + + def get_input_example(self,prompt,height = None,width = None,guidance_scale=7.5,num_images_per_prompt =1): + + prompt_embeds = None + negative_prompt_embeds = None + negative_prompt = None + callback_steps = 1 + generator = None + latents = None + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + + device = "cpu" + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + + # 5. 
Prepare latent variables + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + dummy = torch.ones(1, dtype=torch.int32) + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, dummy) + + unet_input_example = (latent_model_input, dummy, prompt_embeds) + vae_decoder_input_example = latents + + return unet_input_example,vae_decoder_input_example + + + + def prepare_for_ipex(self,promt,infer_type = 'bf16',height = None,width = None,guidance_scale=7.5): + self.unet = self.unet.to(memory_format=torch.channels_last) + self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last) + self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last) + if self.safety_checker != None: + self.safety_checker = self.safety_checker.to(memory_format=torch.channels_last) + + unet_input_example,vae_decoder_input_example = self.get_input_example(promt,height,width,guidance_scale) + + # optimize with ipex + if infer_type == 'bf16': + self.unet = ipex.optimize(self.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=unet_input_example) + self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True) + self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True) + if self.safety_checker != None: + self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.bfloat16, inplace=True) + elif infer_type == 'fp32': + self.unet = ipex.optimize(self.unet.eval(), dtype=torch.float32, inplace=True, sample_input=unet_input_example, level="O1", weights_prepack=True, auto_kernel_selection=False) + self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.float32, inplace=True, level="O1", weights_prepack=True, auto_kernel_selection=False) + self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.float32, inplace=True, level="O1", weights_prepack=True, auto_kernel_selection=False) + if self.safety_checker != None: + self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.float32, inplace=True, level="O1", weights_prepack=True, auto_kernel_selection=False) + else: + raise ValueError( + f" The value of infer_type should be 'bf16' or 'fp32' !" + ) + + # trace unet model to get better performance on IPEX + with torch.cpu.amp.autocast(enabled=infer_type=='bf16'), torch.no_grad(): + unet_trace_model = torch.jit.trace(self.unet, unet_input_example, check_trace=False, strict=False) + unet_trace_model = torch.jit.freeze(unet_trace_model) + self.unet.forward = unet_trace_model.forward + + # trace vae.decoder model to get better performance on IPEX + with torch.cpu.amp.autocast(enabled=infer_type=='bf16'), torch.no_grad(): + ave_decoder_trace_model = torch.jit.trace(self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False) + ave_decoder_trace_model = torch.jit.freeze(ave_decoder_trace_model) + self.vae.decoder.forward = ave_decoder_trace_model.forward + + + + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. 
If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds 
is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. 
+ prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. 
Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds + )['sample'] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. 
Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + From ce14a43eb7d88020e30965f7fcb0050c3082df93 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 15 Apr 2023 00:16:33 +0800 Subject: [PATCH 02/13] Update readme.md --- examples/community/README.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/community/README.md b/examples/community/README.md index 8ed5d639da4e..ffd49a83ada0 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -31,7 +31,7 @@ MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | | DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - |[Aengus (Duc-Anh)](https://github.com/aengusng8) | | CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) | -| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on CPU by [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) | +| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel CPUs by [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. @@ -1134,19 +1134,21 @@ Output Image ### Stable Diffusion on IPEX -This diffusion pipeline can significantly accelarate the inference of Stable-Diffusion on Intel CPUs with BF16/FP32 precision by IPEX. -You need: -1.Install IPEX +This diffusion pipeline can accelarate the inference of Stable-Diffusion on Intel CPUs with BF16/FP32 precision by [IPEX](https://github.com/intel/intel-extension-for-pytorch). + +To use this pipeline, You need to: +1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch) ```python python -m pip install intel_extension_for_pytorch ``` -2.After pipeline initialization, prepare_for_ipex() should be called to enable IPEX accelaration. +2. After pipeline initialization, prepare_for_ipex() should be called to enable IPEX accelaration. 
```python pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex") pipe.prepare_for_ipex(prompt,infer_type='bf16') ``` -Other usage of "stable_diffusion_ipex" pipeline is same as the default stable diffusion pipeline. +Other usage of this ipex pipeline is same as the default stable diffusion pipeline. + Following code compares the performance of original stable diffusion pipeline with ipex pipeline. ```python From 2eab9c4daee8cfbfb0ac7f1314d6565eaf67c7f6 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 18 Apr 2023 23:20:25 +0800 Subject: [PATCH 03/13] reformat --- examples/community/stable_diffusion_ipex.py | 87 ++++++++++++--------- 1 file changed, 51 insertions(+), 36 deletions(-) diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py index d4ee86e29606..85aba627e7ce 100644 --- a/examples/community/stable_diffusion_ipex.py +++ b/examples/community/stable_diffusion_ipex.py @@ -160,7 +160,6 @@ def __init__( new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) - self.register_modules( vae=vae, text_encoder=text_encoder, @@ -173,9 +172,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - - - def get_input_example(self,prompt,height = None,width = None,guidance_scale=7.5,num_images_per_prompt =1): + def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1): prompt_embeds = None negative_prompt_embeds = None @@ -205,7 +202,6 @@ def get_input_example(self,prompt,height = None,width = None,guidance_scale=7.5, # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, @@ -217,7 +213,6 @@ def get_input_example(self,prompt,height = None,width = None,guidance_scale=7.5, negative_prompt_embeds=negative_prompt_embeds, ) - # 5. 
Prepare latent variables latents = self.prepare_latents( batch_size * num_images_per_prompt, @@ -236,52 +231,78 @@ def get_input_example(self,prompt,height = None,width = None,guidance_scale=7.5, unet_input_example = (latent_model_input, dummy, prompt_embeds) vae_decoder_input_example = latents - return unet_input_example,vae_decoder_input_example - - + return unet_input_example, vae_decoder_input_example - def prepare_for_ipex(self,promt,infer_type = 'bf16',height = None,width = None,guidance_scale=7.5): + def prepare_for_ipex(self, promt, infer_type="bf16", height=None, width=None, guidance_scale=7.5): self.unet = self.unet.to(memory_format=torch.channels_last) self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last) self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last) if self.safety_checker != None: self.safety_checker = self.safety_checker.to(memory_format=torch.channels_last) - unet_input_example,vae_decoder_input_example = self.get_input_example(promt,height,width,guidance_scale) + unet_input_example, vae_decoder_input_example = self.get_input_example(promt, height, width, guidance_scale) # optimize with ipex - if infer_type == 'bf16': - self.unet = ipex.optimize(self.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=unet_input_example) + if infer_type == "bf16": + self.unet = ipex.optimize( + self.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=unet_input_example + ) self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True) self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True) if self.safety_checker != None: self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.bfloat16, inplace=True) - elif infer_type == 'fp32': - self.unet = ipex.optimize(self.unet.eval(), dtype=torch.float32, inplace=True, sample_input=unet_input_example, level="O1", weights_prepack=True, auto_kernel_selection=False) - self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.float32, inplace=True, level="O1", weights_prepack=True, auto_kernel_selection=False) - self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.float32, inplace=True, level="O1", weights_prepack=True, auto_kernel_selection=False) + elif infer_type == "fp32": + self.unet = ipex.optimize( + self.unet.eval(), + dtype=torch.float32, + inplace=True, + sample_input=unet_input_example, + level="O1", + weights_prepack=True, + auto_kernel_selection=False, + ) + self.vae.decoder = ipex.optimize( + self.vae.decoder.eval(), + dtype=torch.float32, + inplace=True, + level="O1", + weights_prepack=True, + auto_kernel_selection=False, + ) + self.text_encoder = ipex.optimize( + self.text_encoder.eval(), + dtype=torch.float32, + inplace=True, + level="O1", + weights_prepack=True, + auto_kernel_selection=False, + ) if self.safety_checker != None: - self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.float32, inplace=True, level="O1", weights_prepack=True, auto_kernel_selection=False) + self.safety_checker = ipex.optimize( + self.safety_checker.eval(), + dtype=torch.float32, + inplace=True, + level="O1", + weights_prepack=True, + auto_kernel_selection=False, + ) else: - raise ValueError( - f" The value of infer_type should be 'bf16' or 'fp32' !" 
- ) + raise ValueError(f" The value of infer_type should be 'bf16' or 'fp32' !") # trace unet model to get better performance on IPEX - with torch.cpu.amp.autocast(enabled=infer_type=='bf16'), torch.no_grad(): + with torch.cpu.amp.autocast(enabled=infer_type == "bf16"), torch.no_grad(): unet_trace_model = torch.jit.trace(self.unet, unet_input_example, check_trace=False, strict=False) unet_trace_model = torch.jit.freeze(unet_trace_model) self.unet.forward = unet_trace_model.forward # trace vae.decoder model to get better performance on IPEX - with torch.cpu.amp.autocast(enabled=infer_type=='bf16'), torch.no_grad(): - ave_decoder_trace_model = torch.jit.trace(self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False) + with torch.cpu.amp.autocast(enabled=infer_type == "bf16"), torch.no_grad(): + ave_decoder_trace_model = torch.jit.trace( + self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False + ) ave_decoder_trace_model = torch.jit.freeze(ave_decoder_trace_model) self.vae.decoder.forward = ave_decoder_trace_model.forward - - - def enable_vae_slicing(self): r""" Enable sliced VAE decoding. @@ -459,7 +480,6 @@ def _encode_prompt( ) prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape @@ -767,7 +787,7 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - + # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -777,11 +797,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds - )['sample'] + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)["sample"] # perform guidance if do_classifier_free_guidance: @@ -811,7 +827,7 @@ def __call__( # 10. Convert to PIL image = self.numpy_to_pil(image) else: - + # 8. 
Post-processing image = self.decode_latents(latents) @@ -826,4 +842,3 @@ def __call__( return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - From 3c63137a7d9c7c4caf9f8ab89d9f384a5ae743f5 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 18 Apr 2023 23:47:31 +0800 Subject: [PATCH 04/13] reformat --- examples/community/stable_diffusion_ipex.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py index 85aba627e7ce..e93e1aeb8aa0 100644 --- a/examples/community/stable_diffusion_ipex.py +++ b/examples/community/stable_diffusion_ipex.py @@ -15,12 +15,16 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union +import intel_extension_for_pytorch as ipex import torch from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( deprecate, @@ -30,11 +34,6 @@ randn_tensor, replace_example_docstring, ) -from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker - -import intel_extension_for_pytorch as ipex logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -173,7 +172,6 @@ def __init__( self.register_to_config(requires_safety_checker=requires_safety_checker) def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1): - prompt_embeds = None negative_prompt_embeds = None negative_prompt = None @@ -237,7 +235,7 @@ def prepare_for_ipex(self, promt, infer_type="bf16", height=None, width=None, gu self.unet = self.unet.to(memory_format=torch.channels_last) self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last) self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last) - if self.safety_checker != None: + if self.safety_checker is not None: self.safety_checker = self.safety_checker.to(memory_format=torch.channels_last) unet_input_example, vae_decoder_input_example = self.get_input_example(promt, height, width, guidance_scale) @@ -249,7 +247,7 @@ def prepare_for_ipex(self, promt, infer_type="bf16", height=None, width=None, gu ) self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True) self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True) - if self.safety_checker != None: + if self.safety_checker is not None: self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.bfloat16, inplace=True) elif infer_type == "fp32": self.unet = ipex.optimize( @@ -277,7 +275,7 @@ def prepare_for_ipex(self, promt, infer_type="bf16", height=None, width=None, gu weights_prepack=True, auto_kernel_selection=False, ) - if self.safety_checker != None: + if self.safety_checker is not None: self.safety_checker = ipex.optimize( self.safety_checker.eval(), dtype=torch.float32, @@ -287,7 
+285,7 @@ def prepare_for_ipex(self, promt, infer_type="bf16", height=None, width=None, gu auto_kernel_selection=False, ) else: - raise ValueError(f" The value of infer_type should be 'bf16' or 'fp32' !") + raise ValueError(" The value of infer_type should be 'bf16' or 'fp32' !") # trace unet model to get better performance on IPEX with torch.cpu.amp.autocast(enabled=infer_type == "bf16"), torch.no_grad(): @@ -817,7 +815,6 @@ def __call__( image = latents has_nsfw_concept = None elif output_type == "pil": - # 8. Post-processing image = self.decode_latents(latents) @@ -827,7 +824,6 @@ def __call__( # 10. Convert to PIL image = self.numpy_to_pil(image) else: - # 8. Post-processing image = self.decode_latents(latents) From d047fed357cdc2600cd131feedfcfec42b135b8b Mon Sep 17 00:00:00 2001 From: yingjieh Date: Thu, 18 May 2023 09:32:12 +0800 Subject: [PATCH 05/13] Update examples/community/README.md Co-authored-by: Pedro Cuenca --- examples/community/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/README.md b/examples/community/README.md index ffd49a83ada0..e0b1508b00d0 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -1141,7 +1141,7 @@ To use this pipeline, You need to: ```python python -m pip install intel_extension_for_pytorch ``` -2. After pipeline initialization, prepare_for_ipex() should be called to enable IPEX accelaration. +2. After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX accelaration. ```python pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex") pipe.prepare_for_ipex(prompt,infer_type='bf16') From 4059ae00c0703832255ead0fb8dc4e059eaac875 Mon Sep 17 00:00:00 2001 From: yingjieh Date: Thu, 18 May 2023 09:33:44 +0800 Subject: [PATCH 06/13] Update examples/community/README.md Co-authored-by: Pedro Cuenca --- examples/community/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/README.md b/examples/community/README.md index e0b1508b00d0..97a2b8145b03 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -1149,7 +1149,7 @@ pipe.prepare_for_ipex(prompt,infer_type='bf16') Other usage of this ipex pipeline is same as the default stable diffusion pipeline. -Following code compares the performance of original stable diffusion pipeline with ipex pipeline. +The following code compares the performance of the original stable diffusion pipeline with the ipex-optimized pipeline. ```python import torch From d36aec2ead659c6d679bf16977446f7cadfedcee Mon Sep 17 00:00:00 2001 From: yingjieh Date: Thu, 18 May 2023 09:34:04 +0800 Subject: [PATCH 07/13] Update examples/community/README.md Co-authored-by: Pedro Cuenca --- examples/community/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/README.md b/examples/community/README.md index 97a2b8145b03..092fb4272979 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -1196,7 +1196,7 @@ pipe3.prepare_for_ipex(prompt,infer_type='fp32') #2.Original Pipeline initialization pipe4 = StableDiffusionPipeline.from_pretrained(model_id) -#3.Compare performance between Original Pipeline and IPEX Pipeline +# 3. 
From d047fed357cdc2600cd131feedfcfec42b135b8b Mon Sep 17 00:00:00 2001
From: yingjieh
Date: Thu, 18 May 2023 09:32:12 +0800
Subject: [PATCH 05/13] Update examples/community/README.md

Co-authored-by: Pedro Cuenca
---
 examples/community/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/community/README.md b/examples/community/README.md
index ffd49a83ada0..e0b1508b00d0 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -1141,7 +1141,7 @@ To use this pipeline, You need to:
 ```python
 python -m pip install intel_extension_for_pytorch
 ```
-2. After pipeline initialization, prepare_for_ipex() should be called to enable IPEX accelaration.
+2. After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX accelaration.
 ```python
 pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex")
 pipe.prepare_for_ipex(prompt,infer_type='bf16')

From 4059ae00c0703832255ead0fb8dc4e059eaac875 Mon Sep 17 00:00:00 2001
From: yingjieh
Date: Thu, 18 May 2023 09:33:44 +0800
Subject: [PATCH 06/13] Update examples/community/README.md

Co-authored-by: Pedro Cuenca
---
 examples/community/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/community/README.md b/examples/community/README.md
index e0b1508b00d0..97a2b8145b03 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -1149,7 +1149,7 @@ pipe.prepare_for_ipex(prompt,infer_type='bf16')
 
 Other usage of this ipex pipeline is same as the default stable diffusion pipeline.
-Following code compares the performance of original stable diffusion pipeline with ipex pipeline.
+The following code compares the performance of the original stable diffusion pipeline with the ipex-optimized pipeline.

From d36aec2ead659c6d679bf16977446f7cadfedcee Mon Sep 17 00:00:00 2001
From: yingjieh
Date: Thu, 18 May 2023 09:34:04 +0800
Subject: [PATCH 07/13] Update examples/community/README.md

Co-authored-by: Pedro Cuenca
---
 examples/community/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/community/README.md b/examples/community/README.md
index 97a2b8145b03..092fb4272979 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -1196,7 +1196,7 @@ pipe3.prepare_for_ipex(prompt,infer_type='fp32')
 #2.Original Pipeline initialization
 pipe4 = StableDiffusionPipeline.from_pretrained(model_id)
 
-#3.Compare performance between Original Pipeline and IPEX Pipeline
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
 with torch.no_grad():
     latency = elapsed_time(pipe3)
     print("Latency of StableDiffusionIPEXPipeline--fp32", latency)

From 5fa265de86373e150ab75c7772c651527c6026e0 Mon Sep 17 00:00:00 2001
From: yingjieh
Date: Thu, 18 May 2023 09:34:23 +0800
Subject: [PATCH 08/13] Update examples/community/README.md

Co-authored-by: Pedro Cuenca
---
 examples/community/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/community/README.md b/examples/community/README.md
index 092fb4272979..10961d15b8b0 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -1159,7 +1159,7 @@ import time
 
 prompt = "sailing ship in storm by Rembrandt"
 model_id = "runwayml/stable-diffusion-v1-5"
-#Help function for time evaluation
+# Helper function for time evaluation
 def elapsed_time(pipeline, nb_pass=3, num_inference_steps=20):
     # warmup
     for _ in range(2):

From bd410a7a849aeae04db526b4198cd191d8f15d9a Mon Sep 17 00:00:00 2001
From: yingjieh
Date: Thu, 18 May 2023 09:36:43 +0800
Subject: [PATCH 09/13] Apply suggestions from code review

Co-authored-by: Pedro Cuenca
---
 examples/community/README.md | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/community/README.md b/examples/community/README.md
index 10961d15b8b0..369efbaeccf7 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -1173,14 +1173,14 @@ def elapsed_time(pipeline, nb_pass=3, num_inference_steps=20):
 
 ############## bf16 inference performance ###############
 
-#1.IPEX Pipeline initialization
+# 1. IPEX Pipeline initialization
 pipe = DiffusionPipeline.from_pretrained(model_id, custom_pipeline="stable_diffusion_ipex")
 pipe.prepare_for_ipex(prompt,infer_type='bf16')
 
-#2.Original Pipeline initialization
+# 2. Original Pipeline initialization
 pipe2 = StableDiffusionPipeline.from_pretrained(model_id)
 
-#3.Compare performance between Original Pipeline and IPEX Pipeline
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
 with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
     latency = elapsed_time(pipe)
     print("Latency of StableDiffusionIPEXPipeline--bf16", latency)
@@ -1189,11 +1189,11 @@ with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16)
 
 ############## fp32 inference performance ###############
 
-#1.IPEX Pipeline initialization
+# 1. IPEX Pipeline initialization
 pipe3 = DiffusionPipeline.from_pretrained(model_id, custom_pipeline="stable_diffusion_ipex")
 pipe3.prepare_for_ipex(prompt,infer_type='fp32')
 
-#2.Original Pipeline initialization
+# 2. Original Pipeline initialization
 pipe4 = StableDiffusionPipeline.from_pretrained(model_id)

From ed81aa2bdbb699be7695521becd178036695c1a9 Mon Sep 17 00:00:00 2001
From: yingjieh
Date: Fri, 19 May 2023 15:20:09 +0800
Subject: [PATCH 10/13] Update README.md

---
 examples/community/README.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/community/README.md b/examples/community/README.md
index eb0adc913deb..535cb938f3d1 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -1139,6 +1139,7 @@ This diffusion pipeline can accelarate the inference of Stable-Diffusion on Inte
 To use this pipeline, You need to:
 1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)
 **Note:** For each PyTorch release, there is a corresponding release of the IPEX. Here are the mapping relationship.It is recommanded to install Pytorch/IPEX2.0 to get the best performance.
+
 |PyTorch Version|IPEX Version|
 |--|--|
 |[v2.0.\*](https://github.com/pytorch/pytorch/tree/v2.0.1 "v2.0.1")|[v2.0.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.0.100+cpu)|
 |[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|
@@ -1224,4 +1225,4 @@ print("Latency of StableDiffusionIPEXPipeline--fp32", latency)
     latency = elapsed_time(pipe4)
     print("Latency of StableDiffusionPipeline--fp32",latency)
 
-```
\ No newline at end of file
+```

From ddbe4e15243ee74289869d4ea44f586d83939557 Mon Sep 17 00:00:00 2001
From: yingjieh
Date: Fri, 19 May 2023 15:33:47 +0800
Subject: [PATCH 11/13] Update README.md

---
 examples/community/README.md | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/examples/community/README.md b/examples/community/README.md
index 535cb938f3d1..625f9b3c0d90 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -1134,10 +1134,11 @@ Output Image
 ### Stable Diffusion on IPEX
 
-This diffusion pipeline can accelarate the inference of Stable-Diffusion on Intel CPUs with BF16/FP32 precision by [IPEX](https://github.com/intel/intel-extension-for-pytorch).
+This diffusion pipeline aims to accelarate the inference of Stable-Diffusion on Intel Xeon CPUs with BF16/FP32 precision by [IPEX](https://github.com/intel/intel-extension-for-pytorch).
 
 To use this pipeline, You need to:
 1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)
+
 **Note:** For each PyTorch release, there is a corresponding release of the IPEX. Here are the mapping relationship.It is recommanded to install Pytorch/IPEX2.0 to get the best performance.
 
 |PyTorch Version|IPEX Version|
@@ -1145,15 +1146,17 @@ To use this pipeline, You need to:
 |[v2.0.\*](https://github.com/pytorch/pytorch/tree/v2.0.1 "v2.0.1")|[v2.0.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.0.100+cpu)|
 |[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|
 
+You can use normal pip command to install IPEX with the latest version.
 ```python
 python -m pip install intel_extension_for_pytorch
 ```
 **Note:** To install a specific version, run with the following command:
 ```
-python -m pip install == -f https://developer.intel.com/ipex-whl-stable-cpu
+python -m pip install intel_extension_for_pytorch== -f https://developer.intel.com/ipex-whl-stable-cpu
 ```
 2. After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX accelaration. Supported inference datatypes are Float32 and BFloat16.
+
 **Note:** The setting of generated image height/width for `prepare_for_ipex()` should be same as the setting of pipeline inference.
 ```python
 pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex")
@@ -1163,13 +1166,13 @@ pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) #value
 pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512) #value of image height/width should be consistent with the pipeline inference
 
-Then you can use the ipex pipeline in similar way as the default stable diffusion pipeline.
+Then you can use the ipex pipeline in a similar way to the default stable diffusion pipeline.
 ```python
 # For Float32
-image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'
+image = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'
 # For BFloat16
 with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
-    image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'
+    image = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'
 ```
 
 The following code compares the performance of the original stable diffusion pipeline with the ipex-optimized pipeline.

From e7d54a1c41e4cb158b4cd4ebe60dca627e239885 Mon Sep 17 00:00:00 2001
From: yingjieh
Date: Tue, 23 May 2023 15:41:00 +0800
Subject: [PATCH 12/13] Apply suggestions from code review

Co-authored-by: Pedro Cuenca
---
 examples/community/README.md | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/community/README.md b/examples/community/README.md
index 625f9b3c0d90..f63551106504 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -31,7 +31,7 @@ MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt
 | UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
 | DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - |[Aengus (Duc-Anh)](https://github.com/aengusng8) |
 | CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
-| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel CPUs by [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) |
+| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) |
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -1134,19 +1134,19 @@ Output Image
 ### Stable Diffusion on IPEX
 
-This diffusion pipeline aims to accelarate the inference of Stable-Diffusion on Intel Xeon CPUs with BF16/FP32 precision by [IPEX](https://github.com/intel/intel-extension-for-pytorch).
+This diffusion pipeline aims to accelarate the inference of Stable-Diffusion on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).
 
-To use this pipeline, You need to:
+To use this pipeline, you need to:
 1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)
 
-**Note:** For each PyTorch release, there is a corresponding release of the IPEX. Here are the mapping relationship.It is recommanded to install Pytorch/IPEX2.0 to get the best performance.
+**Note:** For each PyTorch release, there is a corresponding release of the IPEX. Here is the mapping relationship. It is recommended to install Pytorch/IPEX2.0 to get the best performance.
 
 |PyTorch Version|IPEX Version|
 |--|--|
 |[v2.0.\*](https://github.com/pytorch/pytorch/tree/v2.0.1 "v2.0.1")|[v2.0.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.0.100+cpu)|
 |[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|
 
-You can use normal pip command to install IPEX with the latest version.
+You can simply use pip to install IPEX with the latest version.
 ```python
 python -m pip install intel_extension_for_pytorch
 ```

From 5091e6f830d710819b88fe9fa05642b94e3efb2b Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 23 May 2023 10:02:30 +0200
Subject: [PATCH 13/13] style

---
 examples/community/stable_diffusion_ipex.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py
index 3d10361e0323..9abe16d56f10 100644
--- a/examples/community/stable_diffusion_ipex.py
+++ b/examples/community/stable_diffusion_ipex.py
@@ -48,13 +48,13 @@
         >>> # For Float32
         >>> pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) #value of image height/width should be consistent with the pipeline inference
 
-        >>> # For BFloat16
+        >>> # For BFloat16
         >>> pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512) #value of image height/width should be consistent with the pipeline inference
 
         >>> prompt = "a photo of an astronaut riding a horse on mars"
         >>> # For Float32
         >>> image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'
-        >>> # For BFloat16
+        >>> # For BFloat16
         >>> with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
         >>>     image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'
         ```
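
Aside (editorial sketch, not part of the patch series): for quick reference, this is a minimal end-to-end sketch of the BF16 usage the README arrives at by the end of the series. The model id comes from the README; the prompt, step count, and output filename are illustrative. As the README notes, the height/width passed to `prepare_for_ipex()` must match the values used when calling the pipeline.

```python
import torch
from diffusers import DiffusionPipeline

prompt = "sailing ship in storm by Rembrandt"  # illustrative prompt

# Load the community pipeline added by this patch series.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex"
)

# Optimize and trace the sub-models once; height/width must match the call below.
pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512)

# BF16 inference runs under CPU autocast, as in the README benchmark.
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    image = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0]

image.save("ship_bf16.png")
```
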