This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Complex Loss Functions Inside TrainingLoop #720

@xanderdunn

Description

Thanks to @xihui-wu's talk earlier today, I learned about the TrainingLoop struct. I had essentially replicated this functionality in a messier way in my code, so I'm looking at it to see if I could replace my train loop with this cleaner implementation. I believe the only issue I might face is with respect to the loss function.

The current loss function takes only the model's output and the target as parameters: public typealias F = @differentiable (Output, @noDerivative Target) -> Tensor<Float> from here. This covers the vast majority of supervised training situations, but there are cases where we might want a more complicated loss function. For example, how might we mask the output and the target for each sample when computing the loss, as done in this paper:

at each time step the model tries to predict the full, uncorrupted input vectors xt; however, only the predictions on the masked values are considered in the Mean Squared Error loss.
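For concreteness, a masked MSE of this kind might look roughly like the following. This is only a sketch: it uses plain [Float] arrays in place of Tensor<Float>, maskedMSE is a hypothetical name, and a real implementation would need to be @differentiable:

```swift
// Masked mean squared error: only positions where mask[i] == true
// contribute to the loss, matching the paper's masked-prediction setup.
// Sketch only — a real version would operate on Tensor<Float>.
func maskedMSE(predicted: [Float], target: [Float], mask: [Bool]) -> Float {
    var sum: Float = 0
    var count = 0
    for i in predicted.indices where mask[i] {
        let diff = predicted[i] - target[i]
        sum += diff * diff
        count += 1
    }
    return count > 0 ? sum / Float(count) : 0
}
```

The crux of the question is that the mask varies per sample, but the F typealias above has no third parameter through which it could be passed.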

Another situation that comes to mind is a loss that requires some third, external set of values. Perhaps this is an RL agent whose current loss is a function of recent past losses. Another example is a risk-adjusted metric where the "risk" depends on some external value that is not static.
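For the external-value case specifically, one workaround that fits the existing two-parameter shape is a closure that captures mutable external state. A minimal sketch, with hypothetical names and plain Float in place of tensors:

```swift
// Hypothetical risk-adjusted loss: the penalty depends on an external,
// non-static risk value captured by the closure. Sketch only — the real
// TrainingLoop loss would need to stay @differentiable in its first argument.
final class RiskState {
    var risk: Float = 1.0  // updated elsewhere, e.g. by the environment
}

func makeRiskAdjustedLoss(state: RiskState) -> (Float, Float) -> Float {
    return { predicted, target in
        let err = predicted - target
        return err * err * state.risk  // squared error scaled by current risk
    }
}
```

Capturing works for state that changes between steps, but not for values like the mask above that must arrive alongside each individual batch.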

Is my reading of the code correct that these types of loss functions are not currently supported? If so, could the protocol reasonably be modified to optionally support such complex loss functions?

Many thanks!
