Description
With the `keep_checkpoint` option we can specify how many checkpoints should be kept. However, checkpoints are simply saved sequentially and never ordered by performance. That means that if your best-performing model was saved early on, it may still get removed.
OpenNMT-py/onmt/models/model_saver.py
Lines 79 to 83 in 0734288
```python
if self.keep_checkpoint > 0:
    if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
        todel = self.checkpoint_queue.popleft()
        self._rm_checkpoint(todel)
    self.checkpoint_queue.append(chkpt_name)
```
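To make the problem concrete, here is a small, self-contained simulation of the eviction pattern quoted above. The function name `fifo_retention`, the checkpoint names and the loss values are made up for illustration; only the `deque`/`popleft` logic mirrors the quoted lines.

```python
from collections import deque


def fifo_retention(checkpoints, keep):
    """Mimic the deque-based eviction: when the queue is full, the oldest
    checkpoint is removed, regardless of its validation loss."""
    queue = deque(maxlen=keep)
    removed = []
    for name, _loss in checkpoints:  # the loss is ignored by this policy
        if len(queue) == queue.maxlen:
            removed.append(queue.popleft())
        queue.append(name)
    return list(queue), removed


# Hypothetical run: the best (lowest) loss is at step 1000.
history = [("step_1000", 2.1), ("step_2000", 2.4), ("step_3000", 2.6), ("step_4000", 2.5)]
kept, removed = fifo_retention(history, keep=2)
print(kept)     # ['step_3000', 'step_4000']
print(removed)  # ['step_1000', 'step_2000']
```

The best checkpoint (step 1000) is the first one to be deleted.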
As an alternative approach, I would suggest that, if validation is done before each save step, the validation loss is also passed to the `save` method. `self.checkpoint_queue` could then contain tuples of `(loss, chkpt_name)`, and after each append the queue gets sorted on `loss`. That way, only the worst-performing models are removed.
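A minimal sketch of this proposal, assuming lower loss is better and that every save comes with a validation loss. The queue holds `(loss, name)` tuples and is re-sorted after each append, so the last element is always the worst checkpoint. `metric_retention` is a hypothetical helper for illustration, not part of OpenNMT-py.

```python
def metric_retention(checkpoints, keep):
    """Keep the `keep` checkpoints with the lowest validation loss.

    The queue holds (loss, name) tuples and is re-sorted after every append,
    so the worst checkpoint is always the last element."""
    queue = []
    removed = []
    for name, loss in checkpoints:
        queue.append((loss, name))
        queue.sort()  # ascending loss: best first, worst last
        if len(queue) > keep:
            _worst_loss, worst_name = queue.pop()
            removed.append(worst_name)
    return [name for _loss, name in queue], removed


# Hypothetical run: the best (lowest) loss is at step 1000.
history = [("step_1000", 2.1), ("step_2000", 2.4), ("step_3000", 2.6), ("step_4000", 2.5)]
kept, removed = metric_retention(history, keep=2)
print(kept)     # ['step_1000', 'step_2000']
print(removed)  # ['step_3000', 'step_4000']
```

One side effect worth noting: when the newest checkpoint is the worst one, it is deleted right after being written.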
Things to consider: `ModelSaver` would then need to know whether the metric is higher=better or lower=better, and a fallback needs to be in place when no loss is passed.
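One way to handle the metric direction, sketched under the assumption that the saver receives a boolean such as `higher_is_better` (a hypothetical parameter name): negate the score when higher is better, so a single ascending sort always puts the best checkpoint first.

```python
def rank_key(score, higher_is_better):
    """Sort key where a smaller value always means 'better'."""
    return -score if higher_is_better else score


def metric_retention(checkpoints, keep, higher_is_better=False):
    """Keep the `keep` best checkpoints for a metric of either direction."""
    queue = []  # (key, name) pairs; best first after sorting
    removed = []
    for name, score in checkpoints:
        queue.append((rank_key(score, higher_is_better), name))
        queue.sort()
        if len(queue) > keep:
            removed.append(queue.pop()[1])  # drop the worst entry
    return [name for _key, name in queue], removed


# Hypothetical accuracy-style metric where higher is better.
history = [("step_1000", 55.0), ("step_2000", 61.0), ("step_3000", 58.0)]
kept, removed = metric_retention(history, keep=2, higher_is_better=True)
print(kept)     # ['step_2000', 'step_3000']
print(removed)  # ['step_1000']
```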
Activity
francoishernandez commented on Nov 24, 2020
Hey Bram,
Yes, there is a pending PR #1859 and issue #1856 about this, but the first proposals did not convince me, and I have not yet taken the time to have a deeper look.
I like the idea of having a queue that gets updated each time. Maybe some people would still like to keep the N last checkpoints (chronologically), though, but that could be handled with a flag, for instance `keep_checkpoints_order` with `choices=[metricA, metricB, "chronological"]`. Feel free to open a PR for such an implementation.
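A sketch of how such a flag could be declared, using plain `argparse` rather than OpenNMT-py's own option definitions. The flag name and choices come from the comment above; `metricA` and `metricB` are placeholders, and the `"chronological"` default is an assumption chosen to preserve the current behaviour.

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag: which order decides the checkpoints to delete.
parser.add_argument(
    "-keep_checkpoints_order",
    choices=["metricA", "metricB", "chronological"],
    default="chronological",
    help="Order used to decide which checkpoints to remove.",
)

opts = parser.parse_args(["-keep_checkpoints_order", "metricA"])
print(opts.keep_checkpoints_order)  # metricA
```

An unsupported value such as `-keep_checkpoints_order bleu` makes `argparse` exit with an "invalid choice" error.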
BramVanroy commented on Nov 24, 2020
@francoishernandez Ah, my bad - should've looked in the PRs.
In your proposal, you use metricA and metricB. In OpenNMT, can validation ever be done with more than one metric? From `build_loss_compute` it seems that there is only ever one metric during validation. I would simply (optionally) pass the validation loss to the model saver here:
OpenNMT-py/onmt/trainer.py
Lines 276 to 279 in bc95e03
If no loss is passed (i.e. if no validation is done), the behaviour defaults to "chronological"; otherwise the validation metric is used. If that seems like something that interests you, I can implement this next week.
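A sketch of this fallback, assuming the trainer passes `valid_loss=None` when no validation ran before a save and that lower loss is better. `RetentionPolicy` and its `add` method are hypothetical names; an actual change would live inside `ModelSaver`.

```python
class RetentionPolicy:
    """Hypothetical helper that picks which checkpoint to delete once more
    than `keep` exist. Falls back to chronological (FIFO) order whenever
    a validation loss is missing."""

    def __init__(self, keep):
        self.keep = keep
        self.entries = []  # (step, name, loss) in save order

    def add(self, step, name, valid_loss=None):
        """Register a new checkpoint; return the name to delete, or None."""
        self.entries.append((step, name, valid_loss))
        if len(self.entries) <= self.keep:
            return None
        if any(loss is None for _step, _name, loss in self.entries):
            # No complete validation info: behave like the current FIFO queue.
            victim = self.entries[0]
        else:
            # Lower loss is better: evict the entry with the highest loss.
            victim = max(self.entries, key=lambda entry: entry[2])
        self.entries.remove(victim)
        return victim[1]


no_val = RetentionPolicy(keep=2)
no_val_deleted = [no_val.add(s, f"step_{s}") for s in (1000, 2000, 3000)]
with_val = RetentionPolicy(keep=2)
with_val_deleted = [
    with_val.add(s, f"step_{s}", loss)
    for s, loss in ((1000, 2.1), (2000, 2.4), (3000, 2.6))
]
print(no_val_deleted)    # [None, None, 'step_1000']
print(with_val_deleted)  # [None, None, 'step_3000']
```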
francoishernandez commentedon Nov 24, 2020
This was just a way of keeping it generic for any further extension. A simple boolean flag can indeed do the trick at first.