Skip to content

Fix GAN training example using DDP due to find_unused_parameters #20866

Open
@samsara-ku

Description

@samsara-ku

Bug description

Ref #20685 #20328

Almost all of the code is from this simple example

I have already submitted a few issues about the simple GAN training code above with DDP; even if you use every parameter in the models, the training cannot run.

Based on #18740, I compared the procedure of PyTorch Lightning with DDP against plain DDP code without any other framework. The key difference is that PyTorch Lightning wraps the whole trainer, which contains both gen and disc, whereas the plain torch version does not.

pytorch-lightning ver.
DDP(trainer) -> trainer has two models; generator and discriminator

simple torch ver.
DDP(generator), DDP(discriminator)

I think wrapping DDP around a pl.trainer makes all parameters get calculated in one self.manual_backward call, and this is also consistent with the observation in #18740, because gen_loss.backward and dis_loss.backward do not use each other's parameters during optimization.

So I just made a simple custom DDP option to prevent this problem; although I'm not an expert in this kind of thing, it looks fine now.

You can see the custom DDP option in the section below.

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

import os
import cv2

os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

# Root directory for the MNIST files; override with the PATH_DATASETS env var.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
# Use a larger batch when a CUDA device is available.
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
# Half of the reported CPUs are used as dataloader workers. os.cpu_count()
# returns None when the count cannot be determined; fall back to 0 workers
# (loading in the main process) instead of raising TypeError on `None / 2`.
NUM_WORKERS = (os.cpu_count() or 0) // 2


class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module that serves MNIST train/val/test dataloaders.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Worker count used by all three dataloaders.
    """

    def __init__(
        self,
        data_dir: str = PATH_DATASETS,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # presumably the usual MNIST mean/std constants -- TODO confirm
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.transform = transforms.Compose([to_tensor, normalize])

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download the train and test splits into ``data_dir``."""
        for is_train in (True, False):
            MNIST(self.data_dir, train=is_train, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them when None)."""
        # "fit" (or no stage): 55000/5000 random train/val split of the train set.
        if stage in ("fit", None):
            full_train = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(full_train, [55000, 5000])

        # "test" (or no stage): the held-out test set.
        if stage in ("test", None):
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def _make_loader(self, dataset):
        """Build a DataLoader over ``dataset`` with the shared loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the dataloader over the training split."""
        return self._make_loader(self.mnist_train)

    def val_dataloader(self):
        """Return the dataloader over the validation split."""
        return self._make_loader(self.mnist_val)

    def test_dataloader(self):
        """Return the dataloader over the test split."""
        return self._make_loader(self.mnist_test)


class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to images.

    Args:
        latent_dim: Size of each input latent vector.
        img_shape: Shape of one output image (without the batch axis).
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        # Hidden stack: Linear -> (BatchNorm1d) -> LeakyReLU per width; the
        # first hidden layer has no batch norm.
        layers = []
        in_features = latent_dim
        for index, out_features in enumerate((128, 256, 512, 1024)):
            layers.append(nn.Linear(in_features, out_features))
            if index > 0:
                # NOTE(review): 0.8 lands on BatchNorm1d's second parameter,
                # which is eps. If momentum was intended this is a latent
                # bug -- confirm before changing, it alters training.
                layers.append(nn.BatchNorm1d(out_features, eps=0.8))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            in_features = out_features

        # Output head: project to the flattened image size and squash with tanh.
        layers.append(nn.Linear(in_features, int(np.prod(img_shape))))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return generated images reshaped to ``(batch, *img_shape)``."""
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)


class Discriminator(nn.Module):
    """Fully connected discriminator producing one sigmoid score per image.

    Args:
        img_shape: Shape of one input image (without the batch axis).
    """

    def __init__(self, img_shape):
        super().__init__()

        # Linear -> LeakyReLU for each hidden width, then a 1-unit sigmoid head.
        layers = []
        in_features = int(np.prod(img_shape))
        for out_features in (512, 256):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_features = out_features
        layers.append(nn.Linear(in_features, 1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Flatten each image and return its score, shape ``(batch, 1)``."""
        flattened = img.view(img.size(0), -1)
        return self.model(flattened)


class GAN(pl.LightningModule):
    """GAN LightningModule trained with manual optimization.

    Holds a ``Generator`` and a ``Discriminator`` and updates them one after
    the other inside ``training_step`` using two Adam optimizers.

    Args:
        channels: First entry of the image shape handed to both networks.
        width: Second entry of the image shape.
        height: Third entry of the image shape.
        latent_dim: Size of the generator's latent input vector.
        lr: Learning rate of both Adam optimizers.
        b1: Adam beta1 of both optimizers.
        b2: Adam beta2 of both optimizers.
        batch_size: Not read by this class beyond ``save_hyperparameters()``.
        **kwargs: Not read by this class beyond ``save_hyperparameters()``.
    """

    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = BATCH_SIZE,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Both optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        # networks
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        # Fixed latent batch reused for the periodic sample images.
        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        """Generate images from the latent batch ``z``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between predicted scores and target labels."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch):
        """Run one generator update followed by one discriminator update.

        Args:
            batch: ``(images, labels)`` pair; the labels are ignored.
        """
        imgs, _ = batch
        batch_len = imgs.size(0)

        optimizer_g, optimizer_d = self.optimizers()

        # sample noise, then match the dtype/device of the real images
        z = torch.randn(batch_len, self.hparams.latent_dim).type_as(imgs)

        # Targets: 1 for "real", 0 for "fake"; both are reused below.
        valid = torch.ones(batch_len, 1).type_as(imgs)
        fake = torch.zeros(batch_len, 1).type_as(imgs)

        # ---- train generator ----
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z)

        # The generator is rewarded when the discriminator labels its
        # samples as real.
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        # NOTE: per the original report, calling g_loss.backward() and
        # optimizer_g.optimizer.step()/zero_grad() directly "has no problem".

        self.untoggle_optimizer(optimizer_g)

        # ---- train discriminator ----
        self.toggle_optimizer(optimizer_d)

        # how well can it label real images as real?
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label generated images as fake? detach() cuts the
        # graph back to the generator without the previous CPU/numpy round
        # trip and hard-coded .cuda(), which failed on non-CUDA runs.
        fake_imgs = self.generated_imgs.detach()
        fake_loss = self.adversarial_loss(self.discriminator(fake_imgs), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        # NOTE: per the original report, calling d_loss.backward() and
        # optimizer_d.optimizer.step()/zero_grad() directly "has no problem".

        self.untoggle_optimizer(optimizer_d)

    def validation_step(self, batch, batch_idx):
        """No per-batch validation; samples are saved in the epoch-end hook."""
        pass

    def configure_optimizers(self):
        """Return one Adam optimizer per network and no schedulers."""
        # b1/b2 were accepted by __init__ but never reached the optimizers.
        betas = (self.hparams.b1, self.hparams.b2)
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.hparams.lr, betas=betas)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams.lr, betas=betas)

        return [opt_g, opt_d], []

    @torch.no_grad()
    def on_validation_epoch_end(self, save_size=128):
        """Every 10th epoch, save a strip of generated samples as a PNG.

        Args:
            save_size: Side length in pixels each sample is resized to.
        """
        if self.current_epoch % 10 != 0:
            return

        self.generator.eval()
        self.discriminator.eval()

        # Take dtype/device from the generator's first parameter. This works
        # whether or not the generator is wrapped (e.g. in
        # DistributedDataParallel), unlike reaching through `.module`.
        z = self.validation_z.type_as(next(self.generator.parameters()))

        # The forward pass runs on every rank, as before; presumably a
        # DDP-wrapped generator may issue collective ops here -- confirm
        # before restricting it to rank zero.
        sample_imgs = self(z)

        # Only rank zero writes the file, so only it builds the image.
        if not self.trainer.is_global_zero:
            return

        os.makedirs("pl_test", exist_ok=True)

        result = []

        # NOTE(review): values are scaled by 255 and clipped to [0, 255], so
        # any negative generator output is saved as black -- confirm intended.
        for sample in sample_imgs:
            cur_img = sample.detach().cpu().numpy().transpose(1, 2, 0) * 255
            cur_img = np.clip(cur_img, 0, 255).astype(np.uint8)
            cur_img = cv2.resize(cur_img, (save_size, save_size), interpolation=cv2.INTER_LINEAR)

            result.append(cur_img)

        # Lay the tiles side by side: (n, S, S) -> (S, n * S).
        result = np.array(result)
        result = result.transpose(1, 0, 2).reshape(save_size, -1)
        result = result.astype(np.uint8)

        cv2.imwrite(os.path.join("pl_test", f"epoch_{self.current_epoch}.png"), result)


if __name__ == "__main__":
    # The custom strategy is imported from a separate local module, custom.py.
    from custom import TwoDDPStrategy

    datamodule = MNISTDataModule()
    # dims is (1, 28, 28); GAN takes the three entries as separate arguments.
    channels, width, height = datamodule.dims
    gan = GAN(channels, width, height)

    # Train on devices 2 and 3 with the per-network DDP strategy.
    trainer = pl.Trainer(
        accelerator="auto",
        devices=[2, 3],
        strategy=TwoDDPStrategy(),
        max_epochs=200,
    )
    trainer.fit(gan, datamodule)

...
...

import torch

from contextlib import nullcontext
from pytorch_lightning import Trainer
from pytorch_lightning.strategies.ddp import DDPStrategy
from torch.nn.parallel import DistributedDataParallel


class TwoDDPStrategy(DDPStrategy):
    """DDP strategy that wraps ``generator`` and ``discriminator`` separately.

    Instead of wrapping ``self.model`` as a whole, ``configure_ddp`` replaces
    its ``generator`` and ``discriminator`` attributes with their own
    ``DistributedDataParallel`` instances. ``self.model`` itself is left
    unwrapped.

    Ref pytorch_lightning.strategies.ddp.DDPStrategy

        def configure_ddp(self) -> None:
            log.debug(f"{self.__class__.__name__}: configuring DistributedDataParallel")
            assert isinstance(self.model, pl.LightningModule)
            self.model = self._setup_model(self.model) <--- maybe problem?
            self._register_ddp_hooks()

        ....


        def _setup_model(self, model: Module) -> DistributedDataParallel:
            'Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.'
            device_ids = self.determine_ddp_device_ids()
            log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
            # https://pytorch.org/docs/stable/notes/cuda.html#id5
            ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
            with ctx:
                return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
    """

    def configure_ddp(self) -> None:
        """Wrap ``self.model.generator`` and ``self.model.discriminator`` in DDP.

        ``find_unused_parameters`` defaults to False. Keyword arguments given
        to the strategy constructor (``self._ddp_kwargs`` in the reference
        above) are forwarded and take precedence over that default.
        """
        device_ids = self.determine_ddp_device_ids()

        # Fall back to an empty dict if the attribute is missing in this
        # Lightning version; with no user kwargs this matches the old call.
        user_kwargs = getattr(self, "_ddp_kwargs", None) or {}
        ddp_kwargs = {"find_unused_parameters": False, **user_kwargs}

        # Same side-stream construction as the reference _setup_model above.
        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()

        with ctx:
            for name in ("generator", "discriminator"):
                wrapped = DistributedDataParallel(
                    getattr(self.model, name),
                    device_ids=device_ids,
                    **ddp_kwargs,
                )
                setattr(self.model, name, wrapped)

        # _register_ddp_hooks() from the reference is deliberately not called,
        # as in the original override. Nothing is returned: the method is
        # declared ``-> None`` and the reference implementation returns None.

Error messages and logs

# Error messages and logs here please

Environment

Current environment
aiohappyeyeballs         2.6.1
aiohttp                  3.11.14
aiosignal                1.3.2
annotated-types          0.7.0
antlr4-python3-runtime   4.9.3
anykeystore              0.2
apex                     0.9.10.dev0
attrs                    25.3.0
certifi                  2025.1.31
charset-normalizer       3.4.1
click                    8.1.8
cryptacular              1.6.2
decorator                4.4.2
defusedxml               0.7.1
docker-pycreds           0.4.0
facenet-pytorch          2.6.0
filelock                 3.18.0
flow-vis                 0.1
frozenlist               1.5.0
fsspec                   2025.3.0
gitdb                    4.0.12
GitPython                3.1.44
greenlet                 3.0.3
h5py                     3.11.0
hupper                   1.12.1
idna                     3.10
Jinja2                   3.1.6
lightning-utilities      0.14.1
MarkupSafe               3.0.2
mpmath                   1.3.0
multidict                6.2.0
munkres                  1.1.4
munkres                  1.1.4
natsort                  8.4.0
natsort                  8.4.0
networkx                 3.4.2
numpy                    1.26.4
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-cusparselt-cu12   0.6.2
nvidia-nccl-cu12         2.19.3
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.1.105
oauthlib                 3.2.2
omegaconf                2.3.0
opencv-python            4.11.0.86
packaging                24.2
PasteDeploy              3.1.0
pbkdf2                   1.3
pillow                   10.2.0
pip                      25.0
plaster                  1.1.2
plaster-pastedeploy      1.0.1
platformdirs             4.3.6
proglog                  0.1.10
propcache                0.3.0
protobuf                 5.29.3
psutil                   7.0.0
pyav                     11.4.1
pycocotools              2.0.8
pycocotools              2.0.8
pydantic                 2.10.6
pydantic_core            2.27.2
pyramid                  2.0.2
pyramid-mailer           0.15.1
python3-openid           3.2.0
pytorch-lightning        2.5.0.post0
PyYAML                   6.0.2
repoze.sendmail          4.4.1
requests                 2.32.3
requests-oauthlib        2.0.0
sentry-sdk               2.23.1
setproctitle             1.3.5
setuptools               75.8.0
six                      1.17.0
slack_sdk                3.35.0
smmap                    5.0.2
SQLAlchemy               2.0.30
sympy                    1.13.1
tensorboardX             2.6.2.2
torch                    2.2.2
torchaudio               2.2.2
torchmetrics             1.6.3
torchvision              0.17.2
tqdm                     4.67.1
transaction              4.0
translationstring        1.4
triton                   2.2.0
typing_extensions        4.12.2
urllib3                  2.3.0
velruse                  1.1.1
venusian                 3.1.0
wandb                    0.19.8
WebOb                    1.8.7
wheel                    0.45.1
WTForms                  3.1.2
wtforms-recaptcha        0.3.2
yarl                     1.18.3
zope.deprecation         5.0
zope.interface           6.4.post2
zope.sqlalchemy          3.1

More info

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.5.x

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions