
A Possible Bug in EDMEulerScheduler #7406

@KeepNoob

Description


Describe the bug

I was trying to train a model with a DDPM-style training loop using EDMEulerScheduler. I noticed that noise_scheduler.timesteps in EDMEulerScheduler holds floats rather than integers. During training, however, we sample the timesteps ourselves and pass them to the UNet, and that is where the problem appears: when add_noise is called, an error is raised in step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]. I traced it to indices = (schedule_timesteps == timestep).nonzero() inside self.index_for_timestep. Because schedule_timesteps contains floating-point values (from about 1.0955 down to -1.5537) while each sampled timestep is an integer between 0 and 999, the comparison matches no entry, indices is empty (length 0), and indexing into it raises the IndexError shown under Logs.
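The failing lookup can be reproduced in isolation, without the training loop. The sketch below is a simplified, hypothetical re-implementation of the lookup in index_for_timestep, written from the lines visible in the traceback; the helper name lookup_step_index and the four sample values (copied from the printed schedule) are assumptions for illustration, not the library code itself.

```python
import torch

# A few float values copied from the printed noise_scheduler.timesteps
schedule_timesteps = torch.tensor([1.0955, 1.0941, 1.0928, -1.5537])


def lookup_step_index(schedule_timesteps, timestep):
    """Hypothetical simplified copy of the lookup done in index_for_timestep."""
    # Exact equality: only entries identical to `timestep` are returned
    indices = (schedule_timesteps == timestep).nonzero()
    # Same position rule as in the traceback: second match if there are several
    pos = 1 if len(indices) > 1 else 0
    return indices[pos].item()


# An integer timestep, as sampled in the training loop, matches no entry
int_timestep = torch.tensor(577, dtype=torch.int64)
matches_for_int = len((schedule_timesteps == int_timestep).nonzero())
print("matches for integer timestep:", matches_for_int)

try:
    lookup_step_index(schedule_timesteps, int_timestep)
except IndexError as err:
    # Same failure as in the logs: the empty `indices` tensor cannot be indexed
    print("IndexError:", err)

# A value taken from the schedule itself is found by exact equality
print("index for schedule value:", lookup_step_index(schedule_timesteps, schedule_timesteps[2]))
```

The contrast between the last two calls is the whole bug: the lookup only works for values that already exist in the float schedule.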

noise_scheduler.timesteps
tensor([ 1.0955e+00,  1.0941e+00,  1.0928e+00,  1.0914e+00,  1.0900e+00,
         1.0887e+00,  1.0873e+00,  1.0859e+00,  1.0845e+00,  1.0832e+00,
         1.0818e+00,  1.0804e+00,  1.0790e+00,  1.0777e+00,  1.0763e+00,
         1.0749e+00,  1.0735e+00,  1.0721e+00,  1.0707e+00,  1.0694e+00,
         1.0680e+00,  1.0666e+00,  1.0652e+00,  1.0638e+00,  1.0624e+00,
         1.0610e+00,  1.0596e+00,  1.0582e+00,  1.0568e+00,  1.0554e+00,
         1.0540e+00,  1.0526e+00,  1.0512e+00,  1.0498e+00,  1.0484e+00,
         1.0470e+00,  1.0456e+00,  1.0442e+00,  1.0428e+00,  1.0414e+00,
         1.0400e+00,  1.0386e+00,  1.0372e+00,  1.0357e+00,  1.0343e+00,
         1.0329e+00,  1.0315e+00,  1.0301e+00,  1.0287e+00,  1.0272e+00,
         1.0258e+00,  1.0244e+00,  1.0230e+00,  1.0216e+00,  1.0201e+00,
         1.0187e+00,  1.0173e+00,  1.0158e+00,  1.0144e+00,  1.0130e+00,
         1.0116e+00,  1.0101e+00,  1.0087e+00,  1.0072e+00,  1.0058e+00,
         1.0044e+00,  1.0029e+00,  1.0015e+00,  1.0000e+00,  9.9860e-01,
         9.9716e-01,  9.9571e-01,  9.9426e-01,  9.9282e-01,  9.9137e-01,
         9.8992e-01,  9.8846e-01,  9.8701e-01,  9.8556e-01,  9.8410e-01,
         9.8264e-01,  9.8119e-01,  9.7973e-01,  9.7827e-01,  9.7681e-01,
         9.7534e-01,  9.7388e-01,  9.7241e-01,  9.7095e-01,  9.6948e-01,
         9.6801e-01,  9.6654e-01,  9.6507e-01,  9.6360e-01,  9.6212e-01,
         9.6065e-01,  9.5917e-01,  9.5769e-01,  9.5622e-01,  9.5474e-01,
         9.5325e-01,  9.5177e-01,  9.5029e-01,  9.4880e-01,  9.4732e-01,
         9.4583e-01,  9.4434e-01,  9.4285e-01,  9.4136e-01,  9.3987e-01,
         9.3837e-01,  9.3688e-01,  9.3538e-01,  9.3388e-01,  9.3238e-01,
         9.3088e-01,  9.2938e-01,  9.2788e-01,  9.2637e-01,  9.2487e-01,
         9.2336e-01,  9.2185e-01,  9.2034e-01,  9.1883e-01,  9.1732e-01,
...
        -1.4107e+00, -1.4164e+00, -1.4221e+00, -1.4279e+00, -1.4337e+00,
        -1.4395e+00, -1.4453e+00, -1.4512e+00, -1.4570e+00, -1.4629e+00,
        -1.4688e+00, -1.4748e+00, -1.4807e+00, -1.4867e+00, -1.4926e+00,
        -1.4987e+00, -1.5047e+00, -1.5107e+00, -1.5168e+00, -1.5229e+00,
        -1.5290e+00, -1.5351e+00, -1.5413e+00, -1.5475e+00, -1.5537e+00])
timesteps = torch.randint(0, 1000, (batch_size,), device=clean_images.device, dtype=torch.int64)
timesteps
tensor([577, 793, 397, 291, 522, 668, 928, 616, 760, 528, 178,  58, 608, 633,
         26, 505, 294, 233,  95, 830, 534,  44, 892, 833, 770, 595, 256, 515,
        942, 450, 456, 892])
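The printed float values are consistent with the EDM noise conditioning c_noise(sigma) = 0.25 * ln(sigma) applied to a Karras sigma schedule. The sketch below recomputes them in plain Python, assuming sigma_min=0.002, sigma_max=80, rho=7 and 1000 training steps (assumed here to be the scheduler's defaults); karras_sigmas and edm_timesteps are hypothetical helper names, not diffusers functions.

```python
import math


def karras_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras-style sigma schedule, from sigma_max down to sigma_min (assumed defaults)."""
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    # Interpolate linearly in sigma**(1/rho) space, then raise back to the power rho
    return [
        (max_inv_rho + (i / (n - 1)) * (min_inv_rho - max_inv_rho)) ** rho
        for i in range(n)
    ]


def edm_timesteps(sigmas):
    """EDM noise conditioning: c_noise(sigma) = 0.25 * ln(sigma)."""
    return [0.25 * math.log(s) for s in sigmas]


timesteps = edm_timesteps(karras_sigmas(1000))
print("first three:", [round(t, 4) for t in timesteps[:3]])
print("last:", round(timesteps[-1], 4))
```

If the first and last values line up with 1.0955 and -1.5537 in the printout above, the schedule lives entirely in roughly [-1.55, 1.10], so an integer timestep sampled from 0 to 999 will not coincide with any of its entries.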

Reproduction

from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from diffusers import EDMEulerScheduler
import torch

preprocess = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
def transform(examples):
    images = [preprocess(image.convert("L")) for image in examples["image"]]
    return {"images": images}
dataset = load_dataset("mnist", split="train")
dataset.set_transform(transform)
train_dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, pin_memory=True, num_workers=0)

noise_scheduler = EDMEulerScheduler(num_train_timesteps=1000)

for step, batch in enumerate(train_dataloader):
    clean_images = batch["images"]
    # Sample noise to add to the images
    noise = torch.randn(clean_images.shape, device=clean_images.device)
    bs = clean_images.shape[0]

    # Sample a random timestep for each image
    timesteps = torch.randint(
        0, 1000, (bs,), device=clean_images.device,
        dtype=torch.int64
        )

    # Add noise to the clean images according to the noise magnitude at each timestep
    # (this is the forward diffusion process).
    # With EDMEulerScheduler, this call raises the IndexError shown under Logs.
    noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

Logs

IndexError                                Traceback (most recent call last)
Input In [6], in <cell line: 22>()
     29 timesteps = torch.randint(
     30     0, 1000, (bs,), device=clean_images.device,
     31     dtype=torch.int64
     32     )
     34 # Add noise to the clean images according to the noise magnitude at each timestep
     35 # (this is the forward diffusion process)
---> 36 noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

File c:\Users\User\miniconda3\envs\Pytorch\lib\site-packages\diffusers\schedulers\scheduling_edm_euler.py:369, in EDMEulerScheduler.add_noise(self, original_samples, noise, timesteps)
    367 # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
    368 if self.begin_index is None:
--> 369     step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    370 else:
    371     step_indices = [self.begin_index] * timesteps.shape[0]

File c:\Users\User\miniconda3\envs\Pytorch\lib\site-packages\diffusers\schedulers\scheduling_edm_euler.py:369, in <listcomp>(.0)
    367 # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
    368 if self.begin_index is None:
--> 369     step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    370 else:
    371     step_indices = [self.begin_index] * timesteps.shape[0]

File c:\Users\User\miniconda3\envs\Pytorch\lib\site-packages\diffusers\schedulers\scheduling_edm_euler.py:239, in EDMEulerScheduler.index_for_timestep(self, timestep, schedule_timesteps)
    233 # The sigma index that is taken for the **very** first `step`
    234 # is always the second index (or the last index if there is only 1)
    235 # This way we can ensure we don't accidentally skip a sigma in
    236 # case we start in the middle of the denoising schedule (e.g. for image-to-image)
    237 pos = 1 if len(indices) > 1 else 0
--> 239 return indices[pos].item()

IndexError: index 0 is out of bounds for dimension 0 with size 0

System Info

Because I could not run the system-info command successfully, I have collected as much of the information as I could by hand:
diffusers: 0.27.1 (conda-forge, build pyhd8ed1ab_0)
OS: Windows 11
Python: 3.10.6
pytorch: 2.2.1 (pytorch channel, build py3.10_cuda11.8_cudnn8_0)
huggingface-hub: 0.20.1 (pypi, build pypi_0)
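When the system-info command fails, roughly the same details can be gathered with the Python standard library. This is a minimal sketch; collect_env and its default package list are hypothetical, and importlib.metadata only sees packages that were installed with Python distribution metadata, so some conda packages may be reported as "not installed".

```python
import platform
import sys
from importlib import metadata


def collect_env(packages=("diffusers", "torch", "huggingface-hub")):
    """Hypothetical helper: report Python, OS and package versions."""
    info = {
        "Python": sys.version.split()[0],
        "OS": platform.platform(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package absent, or installed without Python distribution metadata
            info[name] = "not installed"
    return info


for key, value in collect_env().items():
    print(f"{key}: {value}")
```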

Who can help?

@yiyixuxu

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), stale (Issues that haven't received updates)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests