
manual_backward and .backward() have different behaviour. #18740

@roedoejet

Description

Bug description

I expected self.manual_backward(loss) and loss.backward() to perform backpropagation in the same way. However, when I use self.manual_backward, a number of parameters end up unused (their gradients stay None). If I call .backward() directly, the problem doesn't occur.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import torch
import torch.nn.functional as F

# training_step of the LightningModule (manual optimization). self.mpd, self.msd,
# the loss helpers and dynamic_range_compression_torch are defined elsewhere
# in the project and are not shown here.
def training_step(self, batch, batch_idx):
    x, y, _, y_mel = batch
    y = y.unsqueeze(1)
    # x.size() & y_mel.size() = [batch_size, n_mels=80, n_frames=32]
    # y.size() = [batch_size, segment_size=8192]
    optim_g, optim_d = self.optimizers()
    scheduler_g, scheduler_d = self.lr_schedulers()
    # generate waveform
    if self.config.model.istft_layer:
        mag, phase = self(x)
        generated_wav = self.inverse_spectral_transform(
            mag * torch.exp(phase * 1j)
        ).unsqueeze(-2)
    else:
        generated_wav = self(x)

    # create mel
    generated_mel_spec = dynamic_range_compression_torch(
        self.spectral_transform(generated_wav).squeeze(1)[:, :, 1:]
    )
    # train discriminators
    optim_d.zero_grad()
    # MPD
    y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, generated_wav.detach())
    if self.use_gradient_penalty:
        gp_f = self.compute_gradient_penalty(y.data, generated_wav.detach().data, self.mpd)
    else:
        gp_f = None
    loss_disc_f, _, _ = self.discriminator_loss(y_df_hat_r, y_df_hat_g, gp=gp_f)
    self.log("training/disc/mpd_loss", loss_disc_f, prog_bar=False)
    # MSD (gradient penalty computed the same way as for the MPD)
    y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, generated_wav.detach())
    if self.use_gradient_penalty:
        gp_s = self.compute_gradient_penalty(y.data, generated_wav.detach().data, self.msd)
    else:
        gp_s = None
    loss_disc_s, _, _ = self.discriminator_loss(y_ds_hat_r, y_ds_hat_g, gp=gp_s)
    self.log("training/disc/msd_loss", loss_disc_s, prog_bar=False)
    # calculate loss
    disc_loss_total = loss_disc_s + loss_disc_f
    # manual optimization, because PyTorch Lightning 2.0+ doesn't support automatic optimization with multiple optimizers
    # this works
    disc_loss_total.backward()
    # this does not
    # self.manual_backward(disc_loss_total)
    optim_d.step()
    scheduler_d.step()
    # log discriminator loss
    self.log("training/disc/d_loss_total", disc_loss_total, prog_bar=False)

    # train generator
    optim_g.zero_grad()
    # calculate loss
    _, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, generated_wav)
    _, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, generated_wav)
    loss_fm_f = self.feature_loss(fmap_f_r, fmap_f_g)
    loss_fm_s = self.feature_loss(fmap_s_r, fmap_s_g)
    loss_gen_f, _ = self.generator_loss(
        y_df_hat_g, gp=self.use_gradient_penalty
    )
    loss_gen_s, _ = self.generator_loss(
        y_ds_hat_g, gp=self.use_gradient_penalty
    )
    self.log("training/gen/loss_fmap_f", loss_fm_f, prog_bar=False)
    self.log("training/gen/loss_fmap_s", loss_fm_s, prog_bar=False)
    self.log("training/gen/loss_gen_f", loss_gen_f, prog_bar=False)
    self.log("training/gen/loss_gen_s", loss_gen_s, prog_bar=False)
    loss_mel = F.l1_loss(y_mel, generated_mel_spec) * 45
    gen_loss_total = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
    # manual optimization, because PyTorch Lightning 2.0+ doesn't support automatic optimization with multiple optimizers
    # self.manual_backward(gen_loss_total) shows the same problem
    gen_loss_total.backward()
    optim_g.step()
    scheduler_g.step()
    # log generator loss
    self.log("training/gen/gen_loss_total", gen_loss_total, prog_bar=True)
    self.log("training/gen/mel_spec_error", loss_mel / 45, prog_bar=False)

I caught this by adding an on_after_backward hook. When I use self.manual_backward(disc_loss_total) or self.manual_backward(gen_loss_total), a number of parameters have p.grad == None, but when I use disc_loss_total.backward() everything works fine.
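
A minimal sketch of the kind of on_after_backward check described above. The helper name params_without_grad is hypothetical and not part of Lightning or of the original report. It accepts the (name, parameter) pairs that torch.nn.Module.named_parameters() yields and groups the trainable parameters whose .grad is None by top-level submodule (for example mpd or msd). Grouping helps show which modules the missing gradients belong to. Note that disc_loss_total is computed from generated_wav.detach(), so backpropagating it does not produce gradients for the generator's parameters. The demo uses SimpleNamespace stand-ins instead of real tensors (an assumption made only so the sketch runs without torch), and the submodule name "generator" in it is illustrative.

```python
from collections import defaultdict
from types import SimpleNamespace
from typing import Any, Dict, Iterable, List, Tuple


def params_without_grad(
    named_parameters: Iterable[Tuple[str, Any]],
) -> Dict[str, List[str]]:
    """Group trainable parameters whose .grad is None by top-level submodule.

    `named_parameters` is expected to look like the output of
    torch.nn.Module.named_parameters(): (dotted_name, parameter) pairs.
    Parameters with requires_grad=False are skipped, since they never
    receive a gradient anyway.
    """
    missing: Dict[str, List[str]] = defaultdict(list)
    for name, param in named_parameters:
        if not getattr(param, "requires_grad", True):
            continue
        if getattr(param, "grad", None) is None:
            top_level = name.split(".", 1)[0]  # e.g. "mpd" from "mpd.convs.0.weight"
            missing[top_level].append(name)
    return dict(missing)


# Hypothetical use inside the LightningModule:
# def on_after_backward(self):
#     for submodule, names in params_without_grad(self.named_parameters()).items():
#         print(f"{submodule}: {len(names)} parameter(s) with p.grad == None")

# Stand-in parameters (illustrative names, not real tensors).
fake_params = [
    ("generator.conv.weight", SimpleNamespace(requires_grad=True, grad=None)),
    ("mpd.conv.weight", SimpleNamespace(requires_grad=True, grad=0.1)),
    ("msd.frozen.weight", SimpleNamespace(requires_grad=False, grad=None)),
]
print(params_without_grad(fake_params))  # {'generator': ['generator.conv.weight']}
```

Calling the helper once after the discriminator backward and once after the generator backward, under both manual_backward and .backward(), would show whether the set of parameters without gradients actually differs between the two calls.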

Error messages and logs


Environment

Current environment
  • CUDA:
    • GPU:
      • Tesla V100-SXM2-16GB
    • available: True
    • version: 11.7
  • Lightning:
    • lightning: 2.0.4
    • lightning-cloud: 0.5.39
    • lightning-utilities: 0.9.0
    • pytorch-lightning: 2.0.9.post0
    • torch: 2.0.1+cu117
    • torchaudio: 2.0.2+cu117
    • torchmetrics: 1.2.0
  • Packages:
    • absl-py: 2.0.0
    • aiohttp: 3.8.5
    • aiosignal: 1.3.1
    • aniso8601: 9.0.1
    • annotated-types: 0.5.0
    • anyio: 3.7.1
    • anytree: 2.9.0
    • arrow: 1.3.0
    • async-timeout: 4.0.3
    • attrs: 23.1.0
    • audioread: 3.0.1
    • beautifulsoup4: 4.12.2
    • bidict: 0.22.1
    • black: 22.12.0
    • blessed: 1.20.0
    • cachetools: 5.3.1
    • certifi: 2023.7.22
    • cffi: 1.16.0
    • cfgv: 3.4.0
    • charset-normalizer: 3.3.0
    • click: 8.1.7
    • clipdetect: 0.1.3
    • cmake: 3.27.6
    • colorama: 0.4.6
    • coloredlogs: 14.0
    • contourpy: 1.1.1
    • croniter: 1.3.15
    • cycler: 0.12.0
    • cython: 3.0.3
    • dateutils: 0.6.12
    • decorator: 5.1.1
    • deepdiff: 6.6.0
    • distlib: 0.3.7
    • dnspython: 2.3.0
    • editdistance: 0.6.2
    • einops: 0.5.0
    • et-xmlfile: 1.1.0
    • eventlet: 0.33.3
    • everyvoice: 0.1.20231005
    • exceptiongroup: 1.1.3
    • fastapi: 0.103.2
    • filelock: 3.12.4
    • flake8: 6.1.0
    • flask: 2.2.5
    • flask-cors: 4.0.0
    • flask-restful: 0.3.10
    • flask-socketio: 5.3.6
    • flask-talisman: 1.1.0
    • fonttools: 4.43.0
    • frozenlist: 1.4.0
    • fsspec: 2023.9.2
    • g2p: 1.1.20230822
    • gitlint-core: 0.19.1
    • google-auth: 2.23.2
    • google-auth-oauthlib: 1.0.0
    • greenlet: 3.0.0
    • grpcio: 1.59.0
    • h11: 0.14.0
    • humanfriendly: 10.0
    • identify: 2.5.30
    • idna: 3.4
    • importlib-metadata: 6.8.0
    • iniconfig: 2.0.0
    • inquirer: 3.1.3
    • isort: 5.12.0
    • itsdangerous: 2.1.2
    • jinja2: 3.1.2
    • joblib: 1.3.2
    • jsonschema: 4.19.1
    • jsonschema-specifications: 2023.7.1
    • kiwisolver: 1.4.5
    • librosa: 0.9.2
    • lightning: 2.0.4
    • lightning-cloud: 0.5.39
    • lightning-utilities: 0.9.0
    • lit: 17.0.2
    • llvmlite: 0.41.0
    • loguru: 0.6.0
    • markdown: 3.4.4
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • matplotlib: 3.6.0
    • mccabe: 0.7.0
    • mdurl: 0.1.2
    • merge-args: 0.1.5
    • mpmath: 1.3.0
    • multidict: 6.0.4
    • munkres: 1.1.4
    • mypy: 1.5.1
    • mypy-extensions: 1.0.0
    • networkx: 2.8.4
    • nltk: 3.7
    • nodeenv: 1.8.0
    • numba: 0.58.0
    • numpy: 1.25.2
    • oauthlib: 3.2.2
    • openpyxl: 3.1.2
    • ordered-set: 4.1.0
    • packaging: 23.2
    • pandas: 1.4.4
    • panphon: 0.20.0
    • pathspec: 0.11.2
    • pillow: 10.0.1
    • pip: 23.2.1
    • platformdirs: 3.11.0
    • pluggy: 1.3.0
    • pooch: 1.7.0
    • pre-commit: 3.4.0
    • prompt-toolkit: 3.0.39
    • protobuf: 4.24.4
    • psutil: 5.9.5
    • pyasn1: 0.5.0
    • pyasn1-modules: 0.3.0
    • pycodestyle: 2.11.0
    • pycountry: 22.3.5
    • pycparser: 2.21
    • pydantic: 2.4.2
    • pydantic-core: 2.10.1
    • pyflakes: 3.1.0
    • pygments: 2.16.1
    • pyjwt: 2.8.0
    • pympi-ling: 1.70.2
    • pyparsing: 3.1.1
    • pysdtw: 0.0.5
    • pytest: 7.4.2
    • python-dateutil: 2.8.2
    • python-editor: 1.0.4
    • python-engineio: 4.7.1
    • python-multipart: 0.0.6
    • python-socketio: 5.9.0
    • pytorch-lightning: 2.0.9.post0
    • pytz: 2023.3.post1
    • pyworld: 0.3.4
    • pyyaml: 6.0.1
    • questionary: 1.10.0
    • readchar: 4.0.5
    • referencing: 0.30.2
    • regex: 2023.10.3
    • requests: 2.31.0
    • requests-oauthlib: 1.3.1
    • resampy: 0.4.2
    • rich: 13.6.0
    • rpds-py: 0.10.4
    • rsa: 4.9
    • scikit-learn: 1.3.1
    • scipy: 1.11.3
    • setuptools: 59.5.0
    • sh: 2.0.6
    • shellingham: 1.5.3
    • simple-term-menu: 1.5.2
    • simple-websocket: 1.0.0
    • six: 1.16.0
    • sniffio: 1.3.0
    • soundfile: 0.12.1
    • soupsieve: 2.5
    • starlette: 0.27.0
    • starsessions: 1.3.0
    • sympy: 1.12
    • tabulate: 0.8.10
    • tensorboard: 2.14.1
    • tensorboard-data-server: 0.7.1
    • text-unidecode: 1.3
    • threadpoolctl: 3.2.0
    • tomli: 2.0.1
    • torch: 2.0.1+cu117
    • torchaudio: 2.0.2+cu117
    • torchmetrics: 1.2.0
    • tqdm: 4.66.1
    • traitlets: 5.11.2
    • triton: 2.0.0
    • typer: 0.9.0
    • types-python-dateutil: 2.8.19.14
    • types-pyyaml: 6.0.12.12
    • types-requests: 2.31.0.8
    • types-setuptools: 68.2.0.0
    • types-tabulate: 0.8.11
    • typing-extensions: 4.8.0
    • unicodecsv: 0.14.1
    • urllib3: 2.0.6
    • uvicorn: 0.23.2
    • virtualenv: 20.24.5
    • wcwidth: 0.2.8
    • websocket-client: 1.6.3
    • websockets: 11.0.3
    • werkzeug: 2.2.3
    • wheel: 0.41.2
    • wsproto: 1.2.0
    • yarl: 1.9.2
    • zipp: 3.17.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.9.18
    • release: 4.15.0-204-generic
    • version: Demos #215-Ubuntu SMP Fri Jan 20 18:24:59 UTC 2023

More info

No response

Metadata

Labels
  • bug: Something isn't working
  • repro needed: The issue is missing a reproducible example
  • ver: 2.0.x
