diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index 97a0e234f33d..261d8e93dcef 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -430,13 +430,14 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
 
         # 4. Prepare timesteps
+
+        # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler, num_inference_steps, device, timesteps, sigmas
         )
 
         # 5. Prepare latents.
         latent_channels = self.transformer.config.in_channels
-        effective_batch_size = batch_size * num_images_per_prompt
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             latent_channels,
@@ -450,21 +451,15 @@ def __call__(
         # 6. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        dt = 1.0 / num_inference_steps
-        dt = (
-            torch.tensor([dt] * effective_batch_size)
-            .to(self.device)
-            .view([effective_batch_size, *([1] * len(latents.shape[1:]))])
-        )
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(range(num_inference_steps, 0, -1)):
+            for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+                # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                t = t / num_inference_steps
-                timestep = (
-                    torch.tensor([t]).expand(latent_model_input.shape[0]).to(latents.device, dtype=latents.dtype)
-                )
+                timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0])
+                timestep = timestep.to(latents.device, dtype=latents.dtype)
 
                 # predict noise model_output
                 noise_pred = self.transformer(
@@ -480,7 +475,7 @@ def __call__(
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = (latents - dt * noise_pred).to(latents.dtype)
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 83ce63981abd..779e691f0c27 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -158,7 +158,12 @@ def scale_noise(
     def _sigma_to_t(self, sigma):
         return sigma * self.config.num_train_timesteps
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        sigmas: Optional[List[float]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -168,17 +173,19 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-        self.num_inference_steps = num_inference_steps
-        timesteps = np.linspace(
-            self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
-        )
+        if sigmas is None:
+            self.num_inference_steps = num_inference_steps
+            timesteps = np.linspace(
+                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
+            )
 
-        sigmas = timesteps / self.config.num_train_timesteps
-        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
-        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+            sigmas = timesteps / self.config.num_train_timesteps
+            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
+            sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
 
         timesteps = sigmas * self.config.num_train_timesteps
+        self.timesteps = timesteps.to(device=device)
 
         self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
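For context on the pipeline change above: the hand-rolled update `latents - dt * noise_pred` with a constant `dt = 1.0 / num_inference_steps` is replaced by `self.scheduler.step(...)`, which takes its step sizes from the scheduler's sigma schedule (including the `shift` transform). The sketch below is a minimal paraphrase of what a single flow-match Euler step amounts to, not the library implementation; `sigmas` and `i` are assumed to be the scheduler's descending schedule (with a trailing 0.0 appended) and the current step index.

```python
import torch


def euler_flow_match_step(
    noise_pred: torch.Tensor, latents: torch.Tensor, sigmas: torch.Tensor, i: int
) -> torch.Tensor:
    # dt is negative because sigmas decrease toward 0 over the denoising loop
    dt = sigmas[i + 1] - sigmas[i]
    # Only for a uniform, unshifted schedule does this reduce to the old
    # `latents - noise_pred / num_inference_steps` update.
    return latents + dt * noise_pred
```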
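The `sigmas` keyword added to `set_timesteps` is what lets the pipeline forward a custom schedule through `retrieve_timesteps`: that helper inspects the scheduler's `set_timesteps` signature and refuses to pass `sigmas` to schedulers that do not accept it. The snippet below is a rough sketch of that capability check under the stated assumption, not the exact diffusers helper.

```python
import inspect


def can_forward_sigmas(scheduler) -> bool:
    # True when the scheduler's set_timesteps signature exposes a `sigmas` keyword,
    # as FlowMatchEulerDiscreteScheduler does after this patch.
    return "sigmas" in inspect.signature(scheduler.set_timesteps).parameters
```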