self.manual_backward() makes all gradients gone #20685

Open
@samsara-ku

Description

Bug description

When I train a GAN with a LightningModule on multiple GPUs, I get the following error:

[rank0]: RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. If this is intentional, you must enable the detection of unused parameters in DDP, either by setting the string value `strategy='ddp_find_unused_parameters_true'` or by setting the flag in the strategy with `strategy=DDPStrategy(find_unused_parameters=True)`.

To investigate, I added the following lines to main.py and ran the GAN training code again:

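import os

# Set at the very top of main.py, before torch is imported, so that the
# C++ logger and the DDP reducer run with verbose diagnostics.
os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"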

With these settings, the logs list every parameter that received no gradient:

[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.conv.weight_orig did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.conv.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.3.norm.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.3.norm.weight did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.3.conv.weight_orig did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.3.conv.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.2.norm.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.2.norm.weight did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.2.conv.weight_orig did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.2.conv.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.1.norm.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.1.norm.weight did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.1.conv.weight_orig did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.1.conv.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.0.conv.weight_orig did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.0.conv.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.conv.weight_orig did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.conv.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.3.norm.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.3.norm.weight did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.3.conv.weight_orig did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.3.conv.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.2.norm.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.2.norm.weight did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.2.conv.weight_orig did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.2.conv.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.1.norm.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.1.norm.weight did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.1.conv.weight_orig did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.1.conv.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.0.conv.weight_orig did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.0.conv.bias did not get gradient in backwards pass.
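You can collect the same information on a single process, without DDP. Run one backward pass, then list the trainable parameters whose `.grad` is still `None`. The sketch below does this in plain PyTorch. `params_without_grad` and `ToyMultiScaleDisc` are hypothetical names, and the toy discriminator is not the reporter's `Discriminator`. It only shows what the check reports when a sub-module never takes part in the forward pass.

```python
from typing import List

import torch
import torch.nn as nn


def params_without_grad(module: nn.Module) -> List[str]:
    """Names of trainable parameters whose .grad is still None after backward()."""
    return [
        name
        for name, p in module.named_parameters()
        if p.requires_grad and p.grad is None
    ]


class ToyMultiScaleDisc(nn.Module):
    """Hypothetical stand-in: one sub-discriminator per scale, each run only
    if the matching "prediction_<scale>" key is present in the input dict."""

    def __init__(self, scales):
        super().__init__()
        self.scales = list(scales)
        self.discs = nn.ModuleList([nn.Linear(4, 1) for _ in self.scales])

    def forward(self, pyramid):
        out = {}
        for scale, disc in zip(self.scales, self.discs):
            key = f"prediction_{scale}"
            if key in pyramid:
                out[f"prediction_map_{scale}"] = disc(pyramid[key])
        return out


disc = ToyMultiScaleDisc(scales=[1, 0.5])
maps = disc({"prediction_1": torch.randn(2, 4)})  # no "prediction_0.5" entry
loss = sum(((1 - m) ** 2).mean() for m in maps.values())
loss.backward()
print(params_without_grad(disc))  # ['discs.1.weight', 'discs.1.bias']
```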

I suspect the problem comes from how the two manual_backward() calls interact in the LightningModule. After the first call succeeds, the second one does not seem to work correctly, as if the first call had removed the gradients of the other module.
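This two-backward pattern can be checked in isolation with plain PyTorch. The sketch below uses hypothetical toy modules. It flips `requires_grad` by hand as a rough stand-in for `toggle_optimizer`/`untoggle_optimizer`, which is an assumption about their effect, not Lightning's code. The run shows two things. First, the second `backward()` leaves the generator's gradients in place. Second, during the generator step the frozen discriminator receives no gradient at all, which is the kind of parameter the DDP log above lists.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(4, 4)   # hypothetical toy generator
disc = nn.Linear(4, 1)  # hypothetical toy discriminator


def set_requires_grad(module: nn.Module, flag: bool) -> None:
    # Rough stand-in (assumption) for toggle_optimizer/untoggle_optimizer,
    # which freeze and restore the parameters of the inactive optimizer.
    for p in module.parameters():
        p.requires_grad_(flag)


def no_grad_names(module: nn.Module, prefix: str):
    # Parameters that have not received a gradient in any backward pass so far.
    return [f"{prefix}.{n}" for n, p in module.named_parameters() if p.grad is None]


real = torch.randn(8, 4)
fake = gen(torch.randn(8, 4))

# Generator step: discriminator frozen, so only generator parameters get .grad.
set_requires_grad(disc, False)
g_loss = ((1 - disc(fake)) ** 2).mean()
g_loss.backward()
set_requires_grad(disc, True)
after_g = no_grad_names(gen, "gen") + no_grad_names(disc, "disc")

# Discriminator step: fake is detached, so only discriminator parameters get .grad.
d_loss = ((1 - disc(real)) ** 2 + disc(fake.detach()) ** 2).mean()
d_loss.backward()
after_d = no_grad_names(gen, "gen") + no_grad_names(disc, "disc")

print(after_g)  # ['disc.weight', 'disc.bias']
print(after_d)  # []
```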

In a related thread, someone suggested calling manual_backward() only once, on the sum of both losses. I think this differs from the usual GAN training scheme, where the generator and the discriminator are updated in separate steps:

self.manual_backward(d_loss + g_loss)  # I suspect this is a problem, but I cannot find any other way to avoid the unused-parameter error

self.toggle_optimizer(optimizer_d)
optimizer_d.step()
optimizer_d.zero_grad()
self.untoggle_optimizer(optimizer_d)
self.toggle_optimizer(optimizer_g)
optimizer_g.step()
optimizer_g.zero_grad()
self.untoggle_optimizer(optimizer_g)
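Summing the losses is not equivalent to alternating updates. The gradient of a sum is the sum of the gradients, so `optimizer_d.step()` after `manual_backward(d_loss + g_loss)` also applies the generator objective's gradient to the discriminator. The plain-PyTorch sketch below illustrates this with hypothetical toy modules and least-squares losses shaped like those in the reproduction code. `losses` and `disc_weight_grad` are hypothetical helpers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(4, 4)   # hypothetical toy generator
disc = nn.Linear(4, 1)  # hypothetical toy discriminator
real, z = torch.randn(8, 4), torch.randn(8, 4)


def losses():
    """Least-squares GAN losses, shaped like the ones in the reproduction code."""
    fake = gen(z)
    g_loss = ((1 - disc(fake)) ** 2).mean()
    d_loss = ((1 - disc(real)) ** 2 + disc(fake.detach()) ** 2).mean()
    return g_loss, d_loss


def disc_weight_grad(make_loss):
    """Gradient of make_loss() w.r.t. the discriminator weight, from zeroed grads."""
    gen.zero_grad(set_to_none=True)
    disc.zero_grad(set_to_none=True)
    make_loss().backward()
    return disc.weight.grad.clone()


from_d = disc_weight_grad(lambda: losses()[1])
from_g = disc_weight_grad(lambda: losses()[0])
combined = disc_weight_grad(lambda: sum(losses()))

print(torch.allclose(combined, from_d + from_g))  # True: the g_loss term reaches D
print(torch.allclose(combined, from_d))           # False
```

In the suggested snippet, `toggle_optimizer` is called only after the combined backward, so it cannot mask this cross term. The gradients have already accumulated by then.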

Does anyone know how to solve this problem?

What version are you seeing the problem on?

v2.5

How to reproduce the bug

import pytorch_lightning as pl

# Generator and Discriminator are the reporter's own modules (definitions not shown).


class G_base(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.gen_scales = cfg.generator.scales
        self.dis_scales = cfg.discriminator.scales

        self.generator = Generator(cfg)
        self.discriminator = Discriminator(**cfg.discriminator)

        # Manual optimization: training_step calls manual_backward()/step() itself.
        self.automatic_optimization = False

    def training_step(self, batch):
        src_img, drv_img = batch["src"], batch["drv"]

        # configure_optimizers (not shown) is assumed to return the generator optimizer first.
        opt_G, opt_D = self.optimizers()

        output = self.generator(src_img, drv_img)

        ### Gen ###
        # toggle_optimizer(opt_G) sets requires_grad=False on parameters owned by
        # the other optimizer (the discriminator) until untoggle_optimizer(opt_G).
        self.toggle_optimizer(opt_G)
        gan_g_loss = 0

        pyramid_generated_gen = {"prediction_1": output["gen_img"]}

        disc_map_generated_gen = self.discriminator(pyramid_generated_gen)

        # Least-squares generator loss, summed over discriminator scales.
        for scale in self.dis_scales:
            key = "prediction_map_%s" % scale
            value = (1 - disc_map_generated_gen[key]) ** 2
            gan_g_loss += value.mean()

        opt_G.zero_grad()
        self.manual_backward(gan_g_loss)
        opt_G.step()
        self.untoggle_optimizer(opt_G)

        self.log_dict({"gan_g_loss": gan_g_loss}, prog_bar=True)

        ### Dis ###
        self.toggle_optimizer(opt_D)
        gan_d_loss = 0

        pyramid_real = {"prediction_1": drv_img}
        # Detached, so the discriminator loss does not backpropagate into the generator.
        pyramid_generated = {"prediction_1": output["gen_img"].detach()}

        disc_map_real = self.discriminator(pyramid_real)
        disc_map_generated = self.discriminator(pyramid_generated)

        # Least-squares discriminator loss: real maps pushed to 1, generated maps to 0.
        for scale in self.dis_scales:
            key = "prediction_map_%s" % scale
            value = (1 - disc_map_real[key]) ** 2 + disc_map_generated[key] ** 2
            gan_d_loss += value.mean()

        opt_D.zero_grad()
        self.manual_backward(gan_d_loss)
        opt_D.step()
        self.untoggle_optimizer(opt_D)

        self.log_dict({"gan_d_loss": gan_d_loss}, prog_bar=True)


Environment

Current environment
aiohappyeyeballs         2.6.1
aiohttp                  3.11.14
aiosignal                1.3.2
annotated-types          0.7.0
antlr4-python3-runtime   4.9.3
anykeystore              0.2
apex                     0.9.10.dev0
attrs                    25.3.0
certifi                  2025.1.31
charset-normalizer       3.4.1
click                    8.1.8
cryptacular              1.6.2
decorator                4.4.2
defusedxml               0.7.1
docker-pycreds           0.4.0
facenet-pytorch          2.6.0
filelock                 3.18.0
flow-vis                 0.1
frozenlist               1.5.0
fsspec                   2025.3.0
gitdb                    4.0.12
GitPython                3.1.44
greenlet                 3.0.3
h5py                     3.11.0
hupper                   1.12.1
idna                     3.10
Jinja2                   3.1.6
lightning-utilities      0.14.1
MarkupSafe               3.0.2
mpmath                   1.3.0
multidict                6.2.0
munkres                  1.1.4
natsort                  8.4.0
networkx                 3.4.2
numpy                    1.26.4
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-cusparselt-cu12   0.6.2
nvidia-nccl-cu12         2.19.3
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.1.105
oauthlib                 3.2.2
omegaconf                2.3.0
opencv-python            4.11.0.86
packaging                24.2
PasteDeploy              3.1.0
pbkdf2                   1.3
pillow                   10.2.0
pip                      25.0
plaster                  1.1.2
plaster-pastedeploy      1.0.1
platformdirs             4.3.6
proglog                  0.1.10
propcache                0.3.0
protobuf                 5.29.3
psutil                   7.0.0
pyav                     11.4.1
pycocotools              2.0.8
pydantic                 2.10.6
pydantic_core            2.27.2
pyramid                  2.0.2
pyramid-mailer           0.15.1
python3-openid           3.2.0
pytorch-lightning        2.5.0.post0
PyYAML                   6.0.2
repoze.sendmail          4.4.1
requests                 2.32.3
requests-oauthlib        2.0.0
sentry-sdk               2.23.1
setproctitle             1.3.5
setuptools               75.8.0
six                      1.17.0
slack_sdk                3.35.0
smmap                    5.0.2
SQLAlchemy               2.0.30
sympy                    1.13.1
tensorboardX             2.6.2.2
torch                    2.2.2
torchaudio               2.2.2
torchmetrics             1.6.3
torchvision              0.17.2
tqdm                     4.67.1
transaction              4.0
translationstring        1.4
triton                   2.2.0
typing_extensions        4.12.2
urllib3                  2.3.0
velruse                  1.1.1
venusian                 3.1.0
wandb                    0.19.8
WebOb                    1.8.7
wheel                    0.45.1
WTForms                  3.1.2
wtforms-recaptcha        0.3.2
yarl                     1.18.3
zope.deprecation         5.0
zope.interface           6.4.post2
zope.sqlalchemy          3.1

More info

No response

Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x
