21 changes: 8 additions & 13 deletions src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -430,13 +430,14 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

# 4. Prepare timesteps

# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)

# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
effective_batch_size = batch_size * num_images_per_prompt
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
@@ -450,21 +451,15 @@ def __call__(

# 6. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
dt = 1.0 / num_inference_steps
dt = (
torch.tensor([dt] * effective_batch_size)
.to(self.device)
.view([effective_batch_size, *([1] * len(latents.shape[1:]))])
)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(range(num_inference_steps, 0, -1)):
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

# aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
t = t / num_inference_steps
timestep = (
torch.tensor([t]).expand(latent_model_input.shape[0]).to(latents.device, dtype=latents.dtype)
)
timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0])
timestep = timestep.to(latents.device, dtype=latents.dtype)

# predict noise model_output
noise_pred = self.transformer(
@@ -480,7 +475,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = (latents - dt * noise_pred).to(latents.dtype)
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
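
For context on the hunks above: the denoising loop now iterates over the scheduler's own `timesteps` and delegates the Euler update to `scheduler.step`, instead of hand-rolling `latents - dt * noise_pred` with a precomputed `dt`. A minimal standalone sketch of that pattern, assuming `shift=1.73` matches AuraFlow's scheduler config (a zero tensor stands in for the transformer's velocity prediction):

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1.73)
scheduler.set_timesteps(num_inference_steps=4)

latents = torch.randn(1, 4, 64, 64)
for t in scheduler.timesteps:
    # stand-in for the AuraFlow transformer's velocity prediction
    velocity = torch.zeros_like(latents)
    # the scheduler applies the Euler update: latents += (sigma_next - sigma) * velocity
    latents = scheduler.step(velocity, t, latents, return_dict=False)[0]
```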
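And an end-to-end usage sketch under stated assumptions: the public `fal/AuraFlow` checkpoint, and `__call__` exposing the `sigmas` argument it forwards to `retrieve_timesteps` in the first hunk. Because the scheduler converts custom sigmas with `torch.from_numpy`, they are passed as a NumPy array here:

```python
import numpy as np
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# Optional custom schedule; retrieve_timesteps hands it to
# FlowMatchEulerDiscreteScheduler.set_timesteps (second file below).
sigmas = np.linspace(1.0, 1 / 50, 50)

image = pipe("a corgi wearing a top hat", sigmas=sigmas, guidance_scale=3.5).images[0]
image.save("aura_flow.png")
```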
25 changes: 16 additions & 9 deletions src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
@@ -158,7 +158,12 @@ def scale_noise(
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

@@ -168,17 +173,19 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps

timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
if sigmas is None:
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)

sigmas = timesteps / self.config.num_train_timesteps
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
sigmas = timesteps / self.config.num_train_timesteps
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps

self.timesteps = timesteps.to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

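A short sketch of the new `sigmas` escape hatch in `set_timesteps` (again assuming `shift=1.73` for AuraFlow): when the caller supplies a schedule explicitly, the `num_inference_steps` branch, including the `shift * s / (1 + (shift - 1) * s)` transform it applies, is skipped and the given values are used as-is. They are converted with `torch.from_numpy`, so a NumPy array is expected:

```python
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1.73)

# Default path: the scheduler builds and shifts the sigma schedule itself.
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.sigmas[0], scheduler.sigmas[-1])  # shifted schedule, trailing zero appended

# New path: pass an explicit (unshifted) schedule; the shift transform is not applied here.
custom = np.linspace(1.0, 1 / 1000, 25)
scheduler.set_timesteps(sigmas=custom)
print(scheduler.timesteps[:3])  # sigmas * num_train_timesteps
```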