Description
Describe the bug
I was trying to train a model in a DDPM-style pipeline using EDMEulerScheduler. I noticed that noise_scheduler.timesteps in EDMEulerScheduler contains floats rather than integers. During training, however, we sample integer timesteps to pass to the UNet model, and this is where the problem appears: when add_noise is called on EDMEulerScheduler, an error is raised at step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]. The cause is indices = (schedule_timesteps == timestep).nonzero() in self.index_for_timestep. Since schedule_timesteps holds floating-point values and timestep is an integer, no entry matches and indices has length 0.
noise_scheduler.timesteps
tensor([ 1.0955e+00, 1.0941e+00, 1.0928e+00, 1.0914e+00, 1.0900e+00,
1.0887e+00, 1.0873e+00, 1.0859e+00, 1.0845e+00, 1.0832e+00,
1.0818e+00, 1.0804e+00, 1.0790e+00, 1.0777e+00, 1.0763e+00,
1.0749e+00, 1.0735e+00, 1.0721e+00, 1.0707e+00, 1.0694e+00,
1.0680e+00, 1.0666e+00, 1.0652e+00, 1.0638e+00, 1.0624e+00,
1.0610e+00, 1.0596e+00, 1.0582e+00, 1.0568e+00, 1.0554e+00,
1.0540e+00, 1.0526e+00, 1.0512e+00, 1.0498e+00, 1.0484e+00,
1.0470e+00, 1.0456e+00, 1.0442e+00, 1.0428e+00, 1.0414e+00,
1.0400e+00, 1.0386e+00, 1.0372e+00, 1.0357e+00, 1.0343e+00,
1.0329e+00, 1.0315e+00, 1.0301e+00, 1.0287e+00, 1.0272e+00,
1.0258e+00, 1.0244e+00, 1.0230e+00, 1.0216e+00, 1.0201e+00,
1.0187e+00, 1.0173e+00, 1.0158e+00, 1.0144e+00, 1.0130e+00,
1.0116e+00, 1.0101e+00, 1.0087e+00, 1.0072e+00, 1.0058e+00,
1.0044e+00, 1.0029e+00, 1.0015e+00, 1.0000e+00, 9.9860e-01,
9.9716e-01, 9.9571e-01, 9.9426e-01, 9.9282e-01, 9.9137e-01,
9.8992e-01, 9.8846e-01, 9.8701e-01, 9.8556e-01, 9.8410e-01,
9.8264e-01, 9.8119e-01, 9.7973e-01, 9.7827e-01, 9.7681e-01,
9.7534e-01, 9.7388e-01, 9.7241e-01, 9.7095e-01, 9.6948e-01,
9.6801e-01, 9.6654e-01, 9.6507e-01, 9.6360e-01, 9.6212e-01,
9.6065e-01, 9.5917e-01, 9.5769e-01, 9.5622e-01, 9.5474e-01,
9.5325e-01, 9.5177e-01, 9.5029e-01, 9.4880e-01, 9.4732e-01,
9.4583e-01, 9.4434e-01, 9.4285e-01, 9.4136e-01, 9.3987e-01,
9.3837e-01, 9.3688e-01, 9.3538e-01, 9.3388e-01, 9.3238e-01,
9.3088e-01, 9.2938e-01, 9.2788e-01, 9.2637e-01, 9.2487e-01,
9.2336e-01, 9.2185e-01, 9.2034e-01, 9.1883e-01, 9.1732e-01,
...
-1.4107e+00, -1.4164e+00, -1.4221e+00, -1.4279e+00, -1.4337e+00,
-1.4395e+00, -1.4453e+00, -1.4512e+00, -1.4570e+00, -1.4629e+00,
-1.4688e+00, -1.4748e+00, -1.4807e+00, -1.4867e+00, -1.4926e+00,
-1.4987e+00, -1.5047e+00, -1.5107e+00, -1.5168e+00, -1.5229e+00,
-1.5290e+00, -1.5351e+00, -1.5413e+00, -1.5475e+00, -1.5537e+00])
timesteps = torch.randint( 0, 1000, (batch_size,), device=clean_images.device, dtype=torch.int64)
timesteps
tensor([577, 793, 397, 291, 522, 668, 928, 616, 760, 528, 178, 58, 608, 633,
26, 505, 294, 233, 95, 830, 534, 44, 892, 833, 770, 595, 256, 515,
942, 450, 456, 892])
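The mismatch described above can be reproduced in isolation. The following is a minimal sketch in plain Python, not the scheduler's actual code: matching_indices is a hypothetical helper that mimics the exact-equality lookup with a list, and the schedule is just a handful of values copied from the noise_scheduler.timesteps output above.

```python
# A few values taken from the printed noise_scheduler.timesteps output above.
schedule_timesteps = [1.0955, 1.0941, 0.9986, -1.4107, -1.5537]


def matching_indices(timestep, schedule):
    """Hypothetical stand-in for (schedule_timesteps == timestep).nonzero():
    return every position whose value equals `timestep` exactly."""
    return [i for i, value in enumerate(schedule) if value == timestep]


# An integer timestep, as sampled by torch.randint(0, 1000, ...), matches nothing.
int_matches = matching_indices(577, schedule_timesteps)
print(int_matches)

# A value drawn from the schedule itself matches its own position.
float_matches = matching_indices(schedule_timesteps[2], schedule_timesteps)
print(float_matches)
```

With an empty match list, any subsequent indices[pos] access fails, which is consistent with the IndexError reported in the logs.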
Reproduction
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from diffusers import EDMEulerScheduler
import torch
preprocess = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
def transform(examples):
    images = [preprocess(image.convert("L")) for image in examples["image"]]
    return {"images": images}
dataset = load_dataset("mnist", split="train")
dataset.set_transform(transform)
train_dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, pin_memory=True, num_workers=0)
noise_scheduler = EDMEulerScheduler(num_train_timesteps=1000)
for step, batch in enumerate(train_dataloader):
    clean_images = batch["images"]
    # Sample noise to add to the images
    noise = torch.randn(clean_images.shape, device=clean_images.device)
    bs = clean_images.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(
        0, 1000, (bs,), device=clean_images.device,
        dtype=torch.int64
    )
    # Add noise to the clean images according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
Logs
IndexError Traceback (most recent call last)
Input In [6], in <cell line: 22>()
29 timesteps = torch.randint(
30 0, 1000, (bs,), device=clean_images.device,
31 dtype=torch.int64
32 )
34 # Add noise to the clean images according to the noise magnitude at each timestep
35 # (this is the forward diffusion process)
---> 36 noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
File c:\Users\User\miniconda3\envs\Pytorch\lib\site-packages\diffusers\schedulers\scheduling_edm_euler.py:369, in EDMEulerScheduler.add_noise(self, original_samples, noise, timesteps)
367 # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
368 if self.begin_index is None:
--> 369 step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
370 else:
371 step_indices = [self.begin_index] * timesteps.shape[0]
File c:\Users\User\miniconda3\envs\Pytorch\lib\site-packages\diffusers\schedulers\scheduling_edm_euler.py:369, in <listcomp>(.0)
367 # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
368 if self.begin_index is None:
--> 369 step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
370 else:
371 step_indices = [self.begin_index] * timesteps.shape[0]
File c:\Users\User\miniconda3\envs\Pytorch\lib\site-packages\diffusers\schedulers\scheduling_edm_euler.py:239, in EDMEulerScheduler.index_for_timestep(self, timestep, schedule_timesteps)
233 # The sigma index that is taken for the **very** first `step`
234 # is always the second index (or the last index if there is only 1)
235 # This way we can ensure we don't accidentally skip a sigma in
236 # case we start in the middle of the denoising schedule (e.g. for image-to-image)
237 pos = 1 if len(indices) > 1 else 0
--> 239 return indices[pos].item()
IndexError: index 0 is out of bounds for dimension 0 with size 0
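One possible workaround, given the traceback above, is to sample integer indices into the schedule and look up the float timesteps at those positions, so that every timestep handed to add_noise is a value that actually exists in the schedule. The sketch below illustrates the idea in plain Python. sample_schedule_timesteps is a hypothetical helper, the schedule is a short list of values copied from the output in the description, and this has not been verified against EDMEulerScheduler itself.

```python
import random

# A few values taken from the printed noise_scheduler.timesteps output.
schedule_timesteps = [1.0955, 1.0941, 0.9986, -1.4107, -1.5537]


def sample_schedule_timesteps(schedule, batch_size, rng):
    """Hypothetical helper: draw random positions in the schedule and
    return both the positions and the float timesteps stored there."""
    indices = [rng.randrange(len(schedule)) for _ in range(batch_size)]
    timesteps = [schedule[i] for i in indices]
    return indices, timesteps


rng = random.Random(0)  # seeded only to make the sketch repeatable
indices, timesteps = sample_schedule_timesteps(schedule_timesteps, 4, rng)

# Every sampled timestep is an exact member of the schedule by construction,
# so an exact-equality lookup always finds at least one match.
for t in timesteps:
    assert any(value == t for value in schedule_timesteps)
```

In the training loop, the analogous step would presumably be to sample indices with torch.randint and index noise_scheduler.timesteps with them, instead of passing the raw integers to add_noise.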
System Info
Since I could not run the command successfully, I have listed as much information as I can below.
diffusers 0.27.1 pyhd8ed1ab_0 conda-forge
OS Windows 11
Python 3.10.6
pytorch 2.2.1 py3.10_cuda11.8_cudnn8_0 pytorch
huggingface-hub 0.20.1 pypi_0 pypi