Thanks to @xihui-wu's talk earlier today, I learned about the `TrainingLoop` struct. I had essentially replicated this functionality in a messier way in my own code, so I'm looking into whether I can replace my training loop with this cleaner implementation. I believe the only issue I might face concerns the loss function.
The current loss function takes only the model's output and the target as parameters: `public typealias F = @differentiable (Output, @noDerivative Target) -> Tensor<Float>` (from the `TrainingLoop` source). This covers the vast majority of supervised training situations, but there are cases where we want a more complicated loss function. For example, how might we mask the output and the target for each sample when calculating the loss, as done in this paper:
> at each time step the model tries to predict the full, uncorrupted input vectors xt; however, only the predictions on the masked values are considered in the Mean Squared Error loss.
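For concreteness, here's a rough sketch of what such a masked loss might look like (the function name and mask convention are my own, not anything in the repo). Note that it needs a third, per-sample argument that the current two-parameter signature has no room for:

```swift
import TensorFlow

// Hypothetical masked MSE: only positions where `mask` is 1 contribute.
// The mask varies per sample, so it can't be baked into a two-parameter
// (Output, Target) loss ahead of time.
@differentiable(wrt: predicted)
func maskedMeanSquaredError(
  predicted: Tensor<Float>,
  expected: Tensor<Float>,
  mask: Tensor<Float>  // 1.0 at masked (corrupted) positions, 0.0 elsewhere
) -> Tensor<Float> {
  let squaredError = (predicted - expected).squared() * mask
  // Average over the masked positions only.
  return squaredError.sum() / mask.sum()
}
```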
Another situation that comes to mind is a loss that requires a third, external set of values. Perhaps it's an RL agent whose current loss is a function of recent past losses, or a risk-adjusted metric where the "risk" depends on some external value that is not static.
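A capturing closure does cover the case where the external value is fixed up front. For example (purely illustrative; `riskWeight` is a made-up name), this still fits the current two-parameter shape:

```swift
import TensorFlow

// Hypothetical fixed risk weight, captured at closure-creation time.
let riskWeight = Tensor<Float>(1.5)

// Still matches (Output, @noDerivative Target) -> Tensor<Float>, because
// the external value is captured rather than passed in.
let riskAdjustedLoss: @differentiable (Tensor<Float>, @noDerivative Tensor<Float>) -> Tensor<Float> =
  { predicted, expected in
    meanSquaredError(predicted: predicted, expected: expected) * riskWeight
  }
```

But that only works when the value is constant for the whole run; if it has to change every step (recent past losses, a live risk signal), there's no way to feed the updated value through the existing signature.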
Is my reading of the code correct that these kinds of loss functions are not currently supported? If so, could the protocol reasonably be modified to optionally support them?
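To make the question concrete, one possible shape (just a sketch of what I mean, not a design proposal) would be an alternative typealias that threads an arbitrary, non-differentiable context through to the loss:

```swift
// Sketch only: a third, @noDerivative "context" slot that could carry a
// per-sample mask, external risk values, recent losses, etc.
public typealias ContextualLossFunction<Output: Differentiable, Target, Context> =
  @differentiable (Output, @noDerivative Target, @noDerivative Context) -> Tensor<Float>
```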