Description
Bug description
Almost all of the code is taken from this simple example.
I have already submitted a few issues about the simple GAN training code above when run with DDP: even if every parameter in the models is used, the training cannot run.
Based on #18740, I compared how PyTorch Lightning sets up DDP with plain DDP code that uses no other framework. The key difference is that PyTorch Lightning wraps the whole trainer, which contains both the generator and the discriminator, whereas the plain torch version does not.
pytorch-lightning ver.
DDP(trainer) -> trainer has two models; generator and discriminator
simple torch ver.
DDP(generator), DDP(discriminator)
I think wrapping DDP around a pl.trainer
makes all parameters be handled in a single self.manual_backward
call. This is also consistent with the observation in #18740, because gen_loss.backward
and dis_loss.backward
do not use each other's parameters during optimization.
So I wrote a simple custom DDP option to prevent this problem; although I'm not an expert in this kind of thing, it looks fine now.
You can see the custom DDP option in the below section.
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
import os
import cv2
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
# Dataset root directory; can be overridden with the PATH_DATASETS env variable.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
# Use a larger batch when a CUDA device is available.
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
# Half of the CPUs go to data loading. os.cpu_count() may return None when the
# count cannot be determined, so fall back to 1 instead of raising TypeError.
NUM_WORKERS = (os.cpu_count() or 1) // 2
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module that serves MNIST train/val/test dataloaders.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of worker processes used by every dataloader.
    """

    def __init__(
        self,
        data_dir: str = PATH_DATASETS,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Convert to tensor, then normalize with the given mean/std constants.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.transform = transforms.Compose([to_tensor, normalize])

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download the train and test splits into ``data_dir``."""
        for is_train in (True, False):
            MNIST(self.data_dir, train=is_train, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them when ``None``)."""
        # Train/val datasets: a 55000/5000 random split of the training set.
        if stage in ("fit", None):
            full_train = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(full_train, [55000, 5000])

        # Test dataset.
        if stage in ("test", None):
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def _make_loader(self, dataset):
        """Build a DataLoader over ``dataset`` with the configured batch/worker sizes."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the dataloader over the training split."""
        return self._make_loader(self.mnist_train)

    def val_dataloader(self):
        """Return the dataloader over the validation split."""
        return self._make_loader(self.mnist_val)

    def test_dataloader(self):
        """Return the dataloader over the test split."""
        return self._make_loader(self.mnist_test)
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to images.

    Args:
        latent_dim: Size of the input latent vector.
        img_shape: Shape of one generated image; the output is viewed as
            ``(batch, *img_shape)``.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        # Hidden stack: Linear -> (BatchNorm1d) -> LeakyReLU per stage, with
        # no normalization on the very first stage.
        widths = (latent_dim, 128, 256, 512, 1024)
        layers = []
        for stage, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if stage > 0:
                # NOTE(review): the original passed 0.8 positionally, which
                # binds to `eps` (not `momentum`) - confirm this is intended.
                layers.append(nn.BatchNorm1d(n_out, eps=0.8))
            layers.append(nn.LeakyReLU(0.01, inplace=True))

        # Output head: project to the flattened image size and squash with tanh.
        layers.append(nn.Linear(widths[-1], int(np.prod(img_shape))))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from the latent batch ``z``."""
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)
class Discriminator(nn.Module):
    """Fully connected discriminator producing one sigmoid score per image.

    Args:
        img_shape: Shape of one input image; its product is the size of the
            flattened input the first linear layer expects.
    """

    def __init__(self, img_shape):
        super().__init__()
        n_pixels = int(np.prod(img_shape))
        layers = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Flatten each image in ``img`` and return its score, shape ``(batch, 1)``."""
        batch = img.size(0)
        return self.model(img.view(batch, -1))
class GAN(pl.LightningModule):
    """GAN LightningModule trained with manual optimization.

    Holds a ``Generator`` and a ``Discriminator`` and updates them with two
    separate optimizers inside ``training_step``.

    Args:
        channels: First entry of the image shape passed to both networks.
        width: Second entry of the image shape.
        height: Third entry of the image shape.
        latent_dim: Size of the latent vector fed to the generator.
        lr: Learning rate of both Adam optimizers.
        b1: First Adam beta coefficient.
        b2: Second Adam beta coefficient.
        batch_size: Stored as a hyperparameter; not read by this class.
        **kwargs: Extra keyword arguments; accepted but not read by this class.
    """

    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = BATCH_SIZE,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Both optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        # networks
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        # Fixed noise reused for the periodic sample images.
        self.validation_z = torch.randn(8, self.hparams.latent_dim)
        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        """Generate images from the latent batch ``z``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Return the binary cross-entropy between predictions and targets."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch):
        """Run one generator update followed by one discriminator update."""
        imgs, _ = batch
        optimizer_g, optimizer_d = self.optimizers()
        batch_size = imgs.size(0)

        # Noise and labels are created on the CPU and then matched to the
        # dtype/device of the real images.
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim).type_as(imgs)
        valid = torch.ones(batch_size, 1).type_as(imgs)
        fake = torch.zeros(batch_size, 1).type_as(imgs)

        # ---- train generator ----
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z)
        # The generator is rewarded when the discriminator labels fakes as real.
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        # NOTE: the issue author reports no problem when bypassing Lightning here,
        # i.e. g_loss.backward(), optimizer_g.optimizer.step() and
        # optimizer_g.optimizer.zero_grad().
        self.untoggle_optimizer(optimizer_g)

        # ---- train discriminator ----
        # Measure the discriminator's ability to tell real from generated samples.
        self.toggle_optimizer(optimizer_d)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
        # detach() cuts the graph to the generator without the CPU/NumPy round
        # trip and without hard-coding a CUDA device.
        fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake)
        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        # NOTE: as above, the issue author reports no problem with d_loss.backward(),
        # optimizer_d.optimizer.step() and optimizer_d.optimizer.zero_grad().
        self.untoggle_optimizer(optimizer_d)

    def validation_step(self, batch, batch_idx):
        """No per-batch validation; samples are produced in on_validation_epoch_end."""
        pass

    def configure_optimizers(self):
        """Return one Adam optimizer per network (generator first) and no schedulers."""
        lr = self.hparams.lr
        # b1/b2 are constructor hyperparameters; pass them on so they take effect.
        betas = (self.hparams.b1, self.hparams.b2)
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas)
        return [opt_g, opt_d], []

    @torch.no_grad()
    def on_validation_epoch_end(self, save_size=128):
        """Every 10th epoch, render the fixed-noise samples into one PNG.

        Args:
            save_size: Side length in pixels each sample is resized to.
        """
        if self.current_epoch % 10 != 0:
            return
        self.generator.eval()
        self.discriminator.eval()
        os.makedirs("pl_test", exist_ok=True)

        # Take dtype/device from any generator parameter; this works whether or
        # not the generator has been wrapped (e.g. in DistributedDataParallel).
        z = self.validation_z.type_as(next(self.generator.parameters()))
        sample_imgs = self(z)

        result = []
        for idx in range(sample_imgs.shape[0]):
            # CHW -> HWC, scale to 0..255 and clip before converting to uint8.
            cur_img = sample_imgs[idx].detach().cpu().numpy().transpose(1, 2, 0) * 255
            cur_img = np.clip(cur_img, 0, 255).astype(np.uint8)
            cur_img = cv2.resize(cur_img, (save_size, save_size), interpolation=cv2.INTER_LINEAR)
            result.append(cur_img)
        result = np.array(result)
        # Lay the samples side by side in one row.
        # NOTE(review): assumes each resized sample is 2-D (single channel) - confirm.
        result = result.transpose(1, 0, 2).reshape(save_size, -1)
        result = result.astype(np.uint8)

        # Every rank runs the forward pass above; only global rank zero writes the file.
        if self.trainer.is_global_zero:
            cv2.imwrite(os.path.join("pl_test", f"epoch_{self.current_epoch}.png"), result)
if __name__ == "__main__":
    # The custom DDP strategy is imported from the separate `custom` module.
    from custom import TwoDDPStrategy

    datamodule = MNISTDataModule()
    gan = GAN(*datamodule.dims)

    # Train on GPUs 2 and 3 for 200 epochs with the custom strategy.
    trainer = pl.Trainer(
        accelerator="auto",
        devices=[2, 3],
        strategy=TwoDDPStrategy(),
        max_epochs=200,
    )
    trainer.fit(gan, datamodule)
...
...
import torch
from contextlib import nullcontext
from pytorch_lightning import Trainer
from pytorch_lightning.strategies.ddp import DDPStrategy
from torch.nn.parallel import DistributedDataParallel
class TwoDDPStrategy(DDPStrategy):
    """DDP strategy that wraps the generator and discriminator separately.

    Instead of wrapping the whole LightningModule in one
    ``DistributedDataParallel``, ``configure_ddp`` replaces
    ``self.model.generator`` and ``self.model.discriminator`` with their own
    ``DistributedDataParallel`` wrappers.

    Ref pytorch_lightning.strategies.ddp.DDPStrategy::

        def configure_ddp(self) -> None:
            log.debug(f"{self.__class__.__name__}: configuring DistributedDataParallel")
            assert isinstance(self.model, pl.LightningModule)
            self.model = self._setup_model(self.model)  # <--- maybe problem?
            self._register_ddp_hooks()

        ....

        def _setup_model(self, model: Module) -> DistributedDataParallel:
            'Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.'
            device_ids = self.determine_ddp_device_ids()
            log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
            # https://pytorch.org/docs/stable/notes/cuda.html#id5
            ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
            with ctx:
                return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
    """

    def configure_ddp(self) -> None:
        """Wrap ``generator`` and ``discriminator`` of ``self.model`` in DDP, in place."""
        device_ids = self.determine_ddp_device_ids()
        # Default to find_unused_parameters=False, but honor any DDP kwargs that
        # were passed to the strategy constructor instead of silently dropping them.
        ddp_kwargs = {"find_unused_parameters": False, **self._ddp_kwargs}
        # Same side-stream construction as the reference _setup_model above.
        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
        with ctx:
            for name in ("generator", "discriminator"):
                wrapped = DistributedDataParallel(
                    getattr(self.model, name),
                    device_ids=device_ids,
                    **ddp_kwargs,
                )
                setattr(self.model, name, wrapped)
Error messages and logs
# Error messages and logs here please
Environment
Current environment
aiohappyeyeballs 2.6.1
aiohttp 3.11.14
aiosignal 1.3.2
annotated-types 0.7.0
antlr4-python3-runtime 4.9.3
anykeystore 0.2
apex 0.9.10.dev0
attrs 25.3.0
certifi 2025.1.31
charset-normalizer 3.4.1
click 8.1.8
cryptacular 1.6.2
decorator 4.4.2
defusedxml 0.7.1
docker-pycreds 0.4.0
facenet-pytorch 2.6.0
filelock 3.18.0
flow-vis 0.1
frozenlist 1.5.0
fsspec 2025.3.0
gitdb 4.0.12
GitPython 3.1.44
greenlet 3.0.3
h5py 3.11.0
hupper 1.12.1
idna 3.10
Jinja2 3.1.6
lightning-utilities 0.14.1
MarkupSafe 3.0.2
mpmath 1.3.0
multidict 6.2.0
munkres 1.1.4
munkres 1.1.4
natsort 8.4.0
natsort 8.4.0
networkx 3.4.2
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-cusparselt-cu12 0.6.2
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
omegaconf 2.3.0
opencv-python 4.11.0.86
packaging 24.2
PasteDeploy 3.1.0
pbkdf2 1.3
pillow 10.2.0
pip 25.0
plaster 1.1.2
plaster-pastedeploy 1.0.1
platformdirs 4.3.6
proglog 0.1.10
propcache 0.3.0
protobuf 5.29.3
psutil 7.0.0
pyav 11.4.1
pycocotools 2.0.8
pycocotools 2.0.8
pydantic 2.10.6
pydantic_core 2.27.2
pyramid 2.0.2
pyramid-mailer 0.15.1
python3-openid 3.2.0
pytorch-lightning 2.5.0.post0
PyYAML 6.0.2
repoze.sendmail 4.4.1
requests 2.32.3
requests-oauthlib 2.0.0
sentry-sdk 2.23.1
setproctitle 1.3.5
setuptools 75.8.0
six 1.17.0
slack_sdk 3.35.0
smmap 5.0.2
SQLAlchemy 2.0.30
sympy 1.13.1
tensorboardX 2.6.2.2
torch 2.2.2
torchaudio 2.2.2
torchmetrics 1.6.3
torchvision 0.17.2
tqdm 4.67.1
transaction 4.0
translationstring 1.4
triton 2.2.0
typing_extensions 4.12.2
urllib3 2.3.0
velruse 1.1.1
venusian 3.1.0
wandb 0.19.8
WebOb 1.8.7
wheel 0.45.1
WTForms 3.1.2
wtforms-recaptcha 0.3.2
yarl 1.18.3
zope.deprecation 5.0
zope.interface 6.4.post2
zope.sqlalchemy 3.1
More info
No response