Commit 589815f

add more detail to tbptt example (#755)

* add more detail to tbptt example
* warn user about new arg in training_step

1 parent 76a1c67 commit 589815f

File tree

2 files changed: +12 -0 lines changed

pytorch_lightning/core/lightning.py

Lines changed: 8 additions & 0 deletions

@@ -168,6 +168,14 @@ def training_step(self, batch, batch_idx, optimizer_idx):
     # Truncated back-propagation through time
     def training_step(self, batch, batch_idx, hiddens):
         # hiddens are the hiddens from the previous truncated backprop step
+        ...
+        out, hiddens = self.lstm(data, hiddens)
+        ...
+
+        return {
+            "loss": ...,
+            "hiddens": hiddens  # remember to detach() this
+        }
 
 You can also return a -1 instead of a dict to stop the current loop. This is useful
 if you want to break out of the current training epoch early.
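
To make the docstring example concrete, here is a minimal runnable sketch of a LightningModule that follows this pattern. The LSTM sizes, the linear head, the MSE loss, and the module name TBPTTModule are illustrative assumptions, not part of the commit; only the training_step(self, batch, batch_idx, hiddens) signature and the {"loss": ..., "hiddens": ...} return dict come from the diff above.

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl


    class TBPTTModule(pl.LightningModule):
        # Hypothetical module for illustration; sizes and loss are assumptions.

        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
            self.head = nn.Linear(20, 1)

        def training_step(self, batch, batch_idx, hiddens):
            # `batch` is one truncated slice of the full sequence, and `hiddens`
            # is the state returned by the previous truncated backprop step
            # (None on the first slice).
            x, y = batch
            out, hiddens = self.lstm(x, hiddens)
            loss = nn.functional.mse_loss(self.head(out), y)
            return {
                "loss": loss,
                # detach() so gradients stop at the slice boundary, as the
                # docstring comment above warns
                "hiddens": tuple(h.detach() for h in hiddens),
            }

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

Note that nn.LSTM returns its state as a (h, c) tuple, so both tensors are detached before being handed back to the trainer.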

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 0 deletions

@@ -448,6 +448,10 @@ def __init__(
     # backprop every 5 steps in a batch
     trainer = Trainer(truncated_bptt_steps=5)
 
+    Using this feature requires updating your LightningModule's `training_step()` to include
+    a `hiddens` arg.
+
+
 resume_from_checkpoint (str): To resume training from a specific checkpoint pass in the path here.
 Example::
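
Tying the two files together: the Trainer splits each batch into truncated_bptt_steps-sized slices along the time dimension and calls training_step() once per slice, feeding the previous slice's returned "hiddens" back in. A hedged usage sketch, assuming the TBPTTModule from the previous example and random stand-in data:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    # 32 sequences, 100 time steps, 10 features, with per-step targets.
    x = torch.randn(32, 100, 10)
    y = torch.randn(32, 100, 1)
    loader = DataLoader(TensorDataset(x, y), batch_size=8)

    model = TBPTTModule()  # the sketch module from the previous example
    trainer = pl.Trainer(truncated_bptt_steps=5, max_epochs=1)
    # fit() keyword names have varied across Lightning versions.
    trainer.fit(model, train_dataloader=loader)

With truncated_bptt_steps=5, each 100-step sequence is processed as 20 slices, and gradients only flow within a slice because the hidden state is detached between them.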
