Bug description
I expected self.manual_backward(loss) and loss.backward() to perform backpropagation in the same way, but when I use self.manual_backward, a number of parameters end up unused (their gradient is None). If I use loss.backward() instead, the problem doesn't occur.
What version are you seeing the problem on?
v2.0
How to reproduce the bug
import torch
import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    x, y, _, y_mel = batch
    y = y.unsqueeze(1)
    # x.size() & y_mel.size() = [batch_size, n_mels=80, n_frames=32]
    # y.size() = [batch_size, segment_size=8192]
    optim_g, optim_d = self.optimizers()
    scheduler_g, scheduler_d = self.lr_schedulers()
    # generate waveform
    if self.config.model.istft_layer:
        mag, phase = self(x)
        generated_wav = self.inverse_spectral_transform(
            mag * torch.exp(phase * 1j)
        ).unsqueeze(-2)
    else:
        generated_wav = self(x)
    # create mel
    generated_mel_spec = dynamic_range_compression_torch(
        self.spectral_transform(generated_wav).squeeze(1)[:, :, 1:]
    )
    # train discriminators
    optim_d.zero_grad()
    # MPD
    y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, generated_wav.detach())
    if self.use_gradient_penalty:
        gp_f = self.compute_gradient_penalty(y.data, generated_wav.detach().data, self.mpd)
    else:
        gp_f = None
    loss_disc_f, _, _ = self.discriminator_loss(y_df_hat_r, y_df_hat_g, gp=gp_f)
    self.log("training/disc/mpd_loss", loss_disc_f, prog_bar=False)
    # MSD
    y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, generated_wav.detach())
    # gp_s was undefined in the original snippet; computed here by analogy with gp_f (assumption)
    if self.use_gradient_penalty:
        gp_s = self.compute_gradient_penalty(y.data, generated_wav.detach().data, self.msd)
    else:
        gp_s = None
    loss_disc_s, _, _ = self.discriminator_loss(y_ds_hat_r, y_ds_hat_g, gp=gp_s)
    self.log("training/disc/msd_loss", loss_disc_s, prog_bar=False)
    # calculate loss
    disc_loss_total = loss_disc_s + loss_disc_f
    # manual optimization because PyTorch Lightning 2.0+ doesn't handle automatic optimization for multiple optimizers
    # this works
    disc_loss_total.backward()
    # this does not
    # self.manual_backward(disc_loss_total)
    optim_d.step()
    scheduler_d.step()
    # log discriminator loss
    self.log("training/disc/d_loss_total", disc_loss_total, prog_bar=False)
    # train generator
    optim_g.zero_grad()
    # calculate loss
    _, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, generated_wav)
    _, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, generated_wav)
    loss_fm_f = self.feature_loss(fmap_f_r, fmap_f_g)
    loss_fm_s = self.feature_loss(fmap_s_r, fmap_s_g)
    loss_gen_f, _ = self.generator_loss(
        y_df_hat_g, gp=self.use_gradient_penalty
    )
    loss_gen_s, _ = self.generator_loss(
        y_ds_hat_g, gp=self.use_gradient_penalty
    )
    self.log("training/gen/loss_fmap_f", loss_fm_f, prog_bar=False)
    self.log("training/gen/loss_fmap_s", loss_fm_s, prog_bar=False)
    self.log("training/gen/loss_gen_f", loss_gen_f, prog_bar=False)
    self.log("training/gen/loss_gen_s", loss_gen_s, prog_bar=False)
    loss_mel = F.l1_loss(y_mel, generated_mel_spec) * 45
    gen_loss_total = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
    # manual optimization because PyTorch Lightning 2.0+ doesn't handle automatic optimization for multiple optimizers
    gen_loss_total.backward()
    optim_g.step()
    scheduler_g.step()
    # log generator loss
    self.log("training/gen/gen_loss_total", gen_loss_total, prog_bar=True)
    self.log("training/gen/mel_spec_error", loss_mel / 45, prog_bar=False)
I caught this by adding an on_after_backward method. When I use self.manual_backward(disc_loss_total) or self.manual_backward(gen_loss_total), I get a bunch of parameters with p.grad == None, but when I use disc_loss_total.backward() everything works fine.
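For reference, the kind of check described above can be sketched roughly as follows. This is a minimal illustration, not the exact hook from my module: the helper name params_missing_grad and its only argument are hypothetical, and the SimpleNamespace objects are stand-ins for real torch parameters so the snippet runs without torch installed. The only option restricts the check to the parameters of a single optimizer, which may matter here because the step uses separate optimizers for the generator and the discriminators.

```python
from types import SimpleNamespace


def params_missing_grad(named_parameters, only=None):
    """Return names of trainable parameters whose .grad is None.

    named_parameters: iterable of (name, parameter) pairs,
        e.g. module.named_parameters().
    only: optional collection of parameters to restrict the check to,
        e.g. the parameters held by one optimizer (hypothetical option).
    """
    # Compare by identity: "in" on real tensors would compare element-wise.
    only_ids = None if only is None else {id(p) for p in only}
    missing = []
    for name, param in named_parameters:
        if not param.requires_grad:
            continue  # frozen parameters never receive a gradient
        if only_ids is not None and id(param) not in only_ids:
            continue  # parameter belongs to a different optimizer
        if param.grad is None:
            missing.append(name)
    return missing


# Stand-ins for torch parameters (assumption: only .requires_grad and .grad are used).
gen_w = SimpleNamespace(requires_grad=True, grad=None)
disc_w = SimpleNamespace(requires_grad=True, grad=0.5)
frozen = SimpleNamespace(requires_grad=False, grad=None)
named = [("generator.w", gen_w), ("mpd.w", disc_w), ("mpd.frozen", frozen)]

print(params_missing_grad(named))                 # ['generator.w']
print(params_missing_grad(named, only=[disc_w]))  # []
```

Inside a LightningModule, on_after_backward(self) could call params_missing_grad(self.named_parameters()). As far as I understand, Lightning invokes that hook as part of self.manual_backward, whereas a plain loss.backward() goes straight to PyTorch, so the two paths may not run the check at the same points.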
Error messages and logs
No response
Environment
Current environment
- CUDA:
- GPU:
- Tesla V100-SXM2-16GB
- available: True
- version: 11.7
- GPU:
- Lightning:
- lightning: 2.0.4
- lightning-cloud: 0.5.39
- lightning-utilities: 0.9.0
- pytorch-lightning: 2.0.9.post0
- torch: 2.0.1+cu117
- torchaudio: 2.0.2+cu117
- torchmetrics: 1.2.0
- Packages:
- absl-py: 2.0.0
- aiohttp: 3.8.5
- aiosignal: 1.3.1
- aniso8601: 9.0.1
- annotated-types: 0.5.0
- anyio: 3.7.1
- anytree: 2.9.0
- arrow: 1.3.0
- async-timeout: 4.0.3
- attrs: 23.1.0
- audioread: 3.0.1
- beautifulsoup4: 4.12.2
- bidict: 0.22.1
- black: 22.12.0
- blessed: 1.20.0
- cachetools: 5.3.1
- certifi: 2023.7.22
- cffi: 1.16.0
- cfgv: 3.4.0
- charset-normalizer: 3.3.0
- click: 8.1.7
- clipdetect: 0.1.3
- cmake: 3.27.6
- colorama: 0.4.6
- coloredlogs: 14.0
- contourpy: 1.1.1
- croniter: 1.3.15
- cycler: 0.12.0
- cython: 3.0.3
- dateutils: 0.6.12
- decorator: 5.1.1
- deepdiff: 6.6.0
- distlib: 0.3.7
- dnspython: 2.3.0
- editdistance: 0.6.2
- einops: 0.5.0
- et-xmlfile: 1.1.0
- eventlet: 0.33.3
- everyvoice: 0.1.20231005
- exceptiongroup: 1.1.3
- fastapi: 0.103.2
- filelock: 3.12.4
- flake8: 6.1.0
- flask: 2.2.5
- flask-cors: 4.0.0
- flask-restful: 0.3.10
- flask-socketio: 5.3.6
- flask-talisman: 1.1.0
- fonttools: 4.43.0
- frozenlist: 1.4.0
- fsspec: 2023.9.2
- g2p: 1.1.20230822
- gitlint-core: 0.19.1
- google-auth: 2.23.2
- google-auth-oauthlib: 1.0.0
- greenlet: 3.0.0
- grpcio: 1.59.0
- h11: 0.14.0
- humanfriendly: 10.0
- identify: 2.5.30
- idna: 3.4
- importlib-metadata: 6.8.0
- iniconfig: 2.0.0
- inquirer: 3.1.3
- isort: 5.12.0
- itsdangerous: 2.1.2
- jinja2: 3.1.2
- joblib: 1.3.2
- jsonschema: 4.19.1
- jsonschema-specifications: 2023.7.1
- kiwisolver: 1.4.5
- librosa: 0.9.2
- lightning: 2.0.4
- lightning-cloud: 0.5.39
- lightning-utilities: 0.9.0
- lit: 17.0.2
- llvmlite: 0.41.0
- loguru: 0.6.0
- markdown: 3.4.4
- markdown-it-py: 3.0.0
- markupsafe: 2.1.3
- matplotlib: 3.6.0
- mccabe: 0.7.0
- mdurl: 0.1.2
- merge-args: 0.1.5
- mpmath: 1.3.0
- multidict: 6.0.4
- munkres: 1.1.4
- mypy: 1.5.1
- mypy-extensions: 1.0.0
- networkx: 2.8.4
- nltk: 3.7
- nodeenv: 1.8.0
- numba: 0.58.0
- numpy: 1.25.2
- oauthlib: 3.2.2
- openpyxl: 3.1.2
- ordered-set: 4.1.0
- packaging: 23.2
- pandas: 1.4.4
- panphon: 0.20.0
- pathspec: 0.11.2
- pillow: 10.0.1
- pip: 23.2.1
- platformdirs: 3.11.0
- pluggy: 1.3.0
- pooch: 1.7.0
- pre-commit: 3.4.0
- prompt-toolkit: 3.0.39
- protobuf: 4.24.4
- psutil: 5.9.5
- pyasn1: 0.5.0
- pyasn1-modules: 0.3.0
- pycodestyle: 2.11.0
- pycountry: 22.3.5
- pycparser: 2.21
- pydantic: 2.4.2
- pydantic-core: 2.10.1
- pyflakes: 3.1.0
- pygments: 2.16.1
- pyjwt: 2.8.0
- pympi-ling: 1.70.2
- pyparsing: 3.1.1
- pysdtw: 0.0.5
- pytest: 7.4.2
- python-dateutil: 2.8.2
- python-editor: 1.0.4
- python-engineio: 4.7.1
- python-multipart: 0.0.6
- python-socketio: 5.9.0
- pytorch-lightning: 2.0.9.post0
- pytz: 2023.3.post1
- pyworld: 0.3.4
- pyyaml: 6.0.1
- questionary: 1.10.0
- readchar: 4.0.5
- referencing: 0.30.2
- regex: 2023.10.3
- requests: 2.31.0
- requests-oauthlib: 1.3.1
- resampy: 0.4.2
- rich: 13.6.0
- rpds-py: 0.10.4
- rsa: 4.9
- scikit-learn: 1.3.1
- scipy: 1.11.3
- setuptools: 59.5.0
- sh: 2.0.6
- shellingham: 1.5.3
- simple-term-menu: 1.5.2
- simple-websocket: 1.0.0
- six: 1.16.0
- sniffio: 1.3.0
- soundfile: 0.12.1
- soupsieve: 2.5
- starlette: 0.27.0
- starsessions: 1.3.0
- sympy: 1.12
- tabulate: 0.8.10
- tensorboard: 2.14.1
- tensorboard-data-server: 0.7.1
- text-unidecode: 1.3
- threadpoolctl: 3.2.0
- tomli: 2.0.1
- torch: 2.0.1+cu117
- torchaudio: 2.0.2+cu117
- torchmetrics: 1.2.0
- tqdm: 4.66.1
- traitlets: 5.11.2
- triton: 2.0.0
- typer: 0.9.0
- types-python-dateutil: 2.8.19.14
- types-pyyaml: 6.0.12.12
- types-requests: 2.31.0.8
- types-setuptools: 68.2.0.0
- types-tabulate: 0.8.11
- typing-extensions: 4.8.0
- unicodecsv: 0.14.1
- urllib3: 2.0.6
- uvicorn: 0.23.2
- virtualenv: 20.24.5
- wcwidth: 0.2.8
- websocket-client: 1.6.3
- websockets: 11.0.3
- werkzeug: 2.2.3
- wheel: 0.41.2
- wsproto: 1.2.0
- yarl: 1.9.2
- zipp: 3.17.0
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.9.18
- release: 4.15.0-204-generic
- version: Demos #215-Ubuntu SMP Fri Jan 20 18:24:59 UTC 2023
More info
No response