Skip to content

Fixing inconsistency in usage of LossFunction #2856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public static void Example()
// Define the trainer options.
var options = new AveragedPerceptronTrainer.Options()
{
LossFunction = new SmoothedHingeLoss.Options(),
LossFunction = new SmoothedHingeLoss(),
LearningRate = 0.1f,
DoLazyUpdates = false,
RecencyGain = 0.1f,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public static void Example()
var options = new SdcaMultiClassTrainer.Options
{
// Add custom loss
LossFunction = new HingeLoss.Options(),
LossFunction = new HingeLoss(),
// Make the convergence tolerance tighter.
ConvergenceTolerance = 0.05f,
// Increase the maximum number of passes over training data.
Expand Down
12 changes: 7 additions & 5 deletions src/Microsoft.ML.Data/Dirty/ILoss.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public interface ILossFunction<in TOutput, in TLabel>
Double Loss(TOutput output, TLabel label);
}

public interface IScalarOutputLoss : ILossFunction<float, float>
public interface IScalarLoss : ILossFunction<float, float>
{
/// <summary>
/// Derivative of the loss function with respect to output
Expand All @@ -25,20 +25,22 @@ public interface IScalarOutputLoss : ILossFunction<float, float>
}

[TlcModule.ComponentKind("RegressionLossFunction")]
public interface ISupportRegressionLossFactory : IComponentFactory<IRegressionLoss>
[BestFriend]
internal interface ISupportRegressionLossFactory : IComponentFactory<IRegressionLoss>
{
}

public interface IRegressionLoss : IScalarOutputLoss
public interface IRegressionLoss : IScalarLoss
{
}

[TlcModule.ComponentKind("ClassificationLossFunction")]
public interface ISupportClassificationLossFactory : IComponentFactory<IClassificationLoss>
[BestFriend]
internal interface ISupportClassificationLossFactory : IComponentFactory<IClassificationLoss>
{
}

public interface IClassificationLoss : IScalarOutputLoss
public interface IClassificationLoss : IScalarLoss
{
}

Expand Down
47 changes: 29 additions & 18 deletions src/Microsoft.ML.Data/Utils/LossFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace Microsoft.ML
/// The loss function may know the close-form solution to the optimal dual update
/// Ref: Sec(6.2) of http://jmlr.org/papers/volume14/shalev-shwartz13a/shalev-shwartz13a.pdf
/// </summary>
public interface ISupportSdcaLoss : IScalarOutputLoss
public interface ISupportSdcaLoss : IScalarLoss
{
//This method helps the optimizer pre-compute the invariants that will be used later in DualUpdate.
//scaledFeaturesNormSquared = instanceWeight * (|x|^2 + 1) / (lambda * n), where
Expand All @@ -69,7 +69,7 @@ public interface ISupportSdcaLoss : IScalarOutputLoss
/// </summary>
/// <param name="label">The label of the example.</param>
/// <param name="dual">The dual variable of the example.</param>
Double DualLoss(float label, Double dual);
Double DualLoss(float label, float dual);
}

public interface ISupportSdcaClassificationLoss : ISupportSdcaLoss, IClassificationLoss
Expand All @@ -81,19 +81,22 @@ public interface ISupportSdcaRegressionLoss : ISupportSdcaLoss, IRegressionLoss
}

[TlcModule.ComponentKind("SDCAClassificationLossFunction")]
public interface ISupportSdcaClassificationLossFactory : IComponentFactory<ISupportSdcaClassificationLoss>
[BestFriend]
internal interface ISupportSdcaClassificationLossFactory : IComponentFactory<ISupportSdcaClassificationLoss>
{
}

[TlcModule.ComponentKind("SDCARegressionLossFunction")]
public interface ISupportSdcaRegressionLossFactory : IComponentFactory<ISupportSdcaRegressionLoss>
[BestFriend]
internal interface ISupportSdcaRegressionLossFactory : IComponentFactory<ISupportSdcaRegressionLoss>
{
new ISupportSdcaRegressionLoss CreateComponent(IHostEnvironment env);
}

[TlcModule.Component(Name = "LogLoss", FriendlyName = "Log loss", Aliases = new[] { "Logistic", "CrossEntropy" },
Desc = "Log loss.")]
public sealed class LogLossFactory : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
[BestFriend]
internal sealed class LogLossFactory : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
{
public ISupportSdcaClassificationLoss CreateComponent(IHostEnvironment env) => new LogLoss();

Expand Down Expand Up @@ -136,7 +139,7 @@ public float DualUpdate(float output, float label, float dual, float invariant,
return maxNumThreads >= 2 && Math.Abs(fullUpdate) > Threshold ? fullUpdate / maxNumThreads : fullUpdate;
}

public Double DualLoss(float label, Double dual)
public Double DualLoss(float label, float dual)
{
// Normalize the dual with label.
if (label <= 0)
Expand All @@ -161,7 +164,8 @@ private static Double Log(Double x)
public sealed class HingeLoss : ISupportSdcaClassificationLoss
{
[TlcModule.Component(Name = "HingeLoss", FriendlyName = "Hinge loss", Alias = "Hinge", Desc = "Hinge loss.")]
public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
[BestFriend]
internal sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Margin value", ShortName = "marg")]
public float Margin = Defaults.Margin;
Expand All @@ -175,7 +179,7 @@ public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportCla
private const float Threshold = 0.5f;
private readonly float _margin;

internal HingeLoss(Options options)
private HingeLoss(Options options)
{
_margin = options.Margin;
}
Expand Down Expand Up @@ -216,7 +220,7 @@ public float DualUpdate(float output, float label, float alpha, float invariant,
return maxNumThreads >= 2 && Math.Abs(fullUpdate) > Threshold ? fullUpdate / maxNumThreads : fullUpdate;
}

public Double DualLoss(float label, Double dual)
public Double DualLoss(float label, float dual)
{
if (label <= 0)
dual = -dual;
Expand All @@ -233,7 +237,7 @@ public sealed class SmoothedHingeLoss : ISupportSdcaClassificationLoss
{
[TlcModule.Component(Name = "SmoothedHingeLoss", FriendlyName = "Smoothed Hinge Loss", Alias = "SmoothedHinge",
Desc = "Smoothed Hinge loss.")]
public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
internal sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Smoothing constant", ShortName = "smooth")]
public float SmoothingConst = Defaults.SmoothingConst;
Expand Down Expand Up @@ -313,7 +317,7 @@ public float DualUpdate(float output, float label, float alpha, float invariant,
return maxNumThreads >= 2 && Math.Abs(fullUpdate) > Threshold ? fullUpdate / maxNumThreads : fullUpdate;
}

public Double DualLoss(float label, Double dual)
public Double DualLoss(float label, float dual)
{
if (label <= 0)
dual = -dual;
Expand All @@ -332,7 +336,7 @@ public Double DualLoss(float label, Double dual)
public sealed class ExpLoss : IClassificationLoss
{
[TlcModule.Component(Name = "ExpLoss", FriendlyName = "Exponential Loss", Desc = "Exponential loss.")]
public sealed class Options : ISupportClassificationLossFactory
internal sealed class Options : ISupportClassificationLossFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Beta (dilation)", ShortName = "beta")]
public float Beta = 1;
Expand All @@ -344,11 +348,16 @@ public sealed class Options : ISupportClassificationLossFactory

private readonly float _beta;

public ExpLoss(Options options)
internal ExpLoss(Options options)
{
_beta = options.Beta;
}

public ExpLoss(float beta = 1)
{
_beta = beta;
}

public Double Loss(float output, float label)
{
float truth = label > 0 ? 1 : -1;
Expand All @@ -364,7 +373,8 @@ public float Derivative(float output, float label)
}

[TlcModule.Component(Name = "SquaredLoss", FriendlyName = "Squared Loss", Alias = "L2", Desc = "Squared loss.")]
public sealed class SquaredLossFactory : ISupportSdcaRegressionLossFactory, ISupportRegressionLossFactory
[BestFriend]
internal sealed class SquaredLossFactory : ISupportSdcaRegressionLossFactory, ISupportRegressionLossFactory
{
public ISupportSdcaRegressionLoss CreateComponent(IHostEnvironment env) => new SquaredLoss();

Expand Down Expand Up @@ -398,14 +408,15 @@ public float DualUpdate(float output, float label, float dual, float invariant,
return maxNumThreads >= 2 ? fullUpdate / maxNumThreads : fullUpdate;
}

public Double DualLoss(float label, Double dual)
public Double DualLoss(float label, float dual)
{
return -dual * (dual / 4 - label);
}
}

[TlcModule.Component(Name = "PoissonLoss", FriendlyName = "Poisson Loss", Desc = "Poisson loss.")]
public sealed class PoissonLossFactory : ISupportRegressionLossFactory
[BestFriend]
internal sealed class PoissonLossFactory : ISupportRegressionLossFactory
{
public IRegressionLoss CreateComponent(IHostEnvironment env) => new PoissonLoss();
}
Expand Down Expand Up @@ -437,7 +448,7 @@ public float Derivative(float output, float label)
public sealed class TweedieLoss : IRegressionLoss
{
[TlcModule.Component(Name = "TweedieLoss", FriendlyName = "Tweedie Loss", Alias = "tweedie", Desc = "Tweedie loss.")]
public sealed class Options : ISupportRegressionLossFactory
internal sealed class Options : ISupportRegressionLossFactory
{
[Argument(ArgumentType.LastOccurenceWins, HelpText =
"Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, " +
Expand All @@ -453,7 +464,7 @@ public sealed class Options : ISupportRegressionLossFactory
private readonly Double _index1; // 1 minus the index parameter.
private readonly Double _index2; // 2 minus the index parameter.

public TweedieLoss(Options options)
private TweedieLoss(Options options)
{
Contracts.CheckUserArg(1 <= options.Index && options.Index <= 2, nameof(options.Index), "Must be in the range [1, 2]");
_index = options.Index;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,15 @@ internal class AveragedDefault : OnlineLinearOptions.OnlineDefault
public const float L2RegularizerWeight = 0;
}

internal abstract IComponentFactory<IScalarOutputLoss> LossFunctionFactory { get; }
internal abstract IComponentFactory<IScalarLoss> LossFunctionFactory { get; }
}

public abstract class AveragedLinearTrainer<TTransformer, TModel> : OnlineLinearTrainer<TTransformer, TModel>
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
where TModel : class
{
private protected readonly AveragedLinearOptions AveragedLinearTrainerOptions;
private protected IScalarOutputLoss LossFunction;
private protected IScalarLoss LossFunction;

private protected abstract class AveragedTrainStateBase : TrainStateBase
{
Expand All @@ -141,7 +141,7 @@ private protected abstract class AveragedTrainStateBase : TrainStateBase
protected readonly bool Averaged;
private readonly long _resetWeightsAfterXExamples;
private readonly AveragedLinearOptions _args;
private readonly IScalarOutputLoss _loss;
private readonly IScalarLoss _loss;

private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, AveragedLinearTrainer<TTransformer, TModel> parent)
: base(ch, numFeatures, predictor, parent)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,13 @@ public sealed class Options : AveragedLinearOptions
/// <summary>
/// A custom <a href="tmpurl_loss">loss</a>.
/// </summary>
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
public ISupportClassificationLossFactory LossFunction = new HingeLoss.Options();
[Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
internal ISupportClassificationLossFactory ClassificationLossFunctionFactory = new HingeLoss.Options();

/// <summary>
/// A custom <a href="tmpurl_loss">loss</a>.
Copy link
Contributor

@TomFinley TomFinley Mar 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tmpurl_loss [](start = 34, length = 11)

Were you intending to update this? #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I noted elsewhere you noted what the replacement value would be, which struck me as positive. Will we do that here? (That said, that is merely documentation, which while important, I would not prioritize since the important thing is to get the shape of the API right, which you seem to have done here.)


In reply to: 263108655 [](ancestors = 263108655)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shahab is working on this, issue #2356


In reply to: 263109039 [](ancestors = 263109039,263108655)

/// </summary>
public IClassificationLoss LossFunction { get; set; }

/// <summary>
/// The <a href="tmpurl_calib">calibrator</a> for producing probabilities. Default is exponential (aka Platt) calibration.
Expand All @@ -75,7 +80,7 @@ public sealed class Options : AveragedLinearOptions
[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
internal int MaxCalibrationExamples = 1000000;

internal override IComponentFactory<IScalarOutputLoss> LossFunctionFactory => LossFunction;
internal override IComponentFactory<IScalarLoss> LossFunctionFactory => ClassificationLossFunctionFactory;
}

private sealed class TrainState : AveragedTrainStateBase
Expand Down Expand Up @@ -112,7 +117,7 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, Options options)
: base(options, env, UserNameValue, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
{
_args = options;
LossFunction = _args.LossFunction.CreateComponent(env);
LossFunction = _args.LossFunction ?? _args.LossFunctionFactory.CreateComponent(env);
}

/// <summary>
Expand Down Expand Up @@ -143,23 +148,11 @@ internal AveragedPerceptronTrainer(IHostEnvironment env,
DecreaseLearningRate = decreaseLearningRate,
L2RegularizerWeight = l2RegularizerWeight,
NumberOfIterations = numIterations,
LossFunction = new TrivialFactory(lossFunction ?? new HingeLoss())
LossFunction = lossFunction ?? new HingeLoss()
})
{
}

private sealed class TrivialFactory : ISupportClassificationLossFactory
{
private IClassificationLoss _loss;

public TrivialFactory(IClassificationLoss loss)
{
_loss = loss;
}

IClassificationLoss IComponentFactory<IClassificationLoss>.CreateComponent(IHostEnvironment env) => _loss;
}

private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

private protected override bool NeedCalibration => true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer<Regress

public sealed class Options : AveragedLinearOptions
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
[Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
[TGUI(Label = "Loss Function")]
public ISupportRegressionLossFactory LossFunction = new SquaredLossFactory();
internal ISupportRegressionLossFactory RegressionLossFunctionFactory = new SquaredLossFactory();

public IRegressionLoss LossFunction { get; set; }

internal override IComponentFactory<IScalarLoss> LossFunctionFactory => RegressionLossFunctionFactory;

/// <summary>
/// Set defaults that vary from the base type.
Expand All @@ -48,8 +52,6 @@ public Options()
DecreaseLearningRate = OgdDefaultArgs.DecreaseLearningRate;
}

internal override IComponentFactory<IScalarOutputLoss> LossFunctionFactory => LossFunction;

[BestFriend]
internal class OgdDefaultArgs : AveragedDefault
{
Expand Down Expand Up @@ -113,27 +115,15 @@ internal OnlineGradientDescentTrainer(IHostEnvironment env,
NumberOfIterations = numIterations,
LabelColumnName = labelColumn,
FeatureColumnName = featureColumn,
LossFunction = new TrivialFactory(lossFunction ?? new SquaredLoss())
LossFunction = lossFunction ?? new SquaredLoss()
})
{
}

private sealed class TrivialFactory : ISupportRegressionLossFactory
{
private IRegressionLoss _loss;

public TrivialFactory(IRegressionLoss loss)
{
_loss = loss;
}

IRegressionLoss IComponentFactory<IRegressionLoss>.CreateComponent(IHostEnvironment env) => _loss;
}

internal OnlineGradientDescentTrainer(IHostEnvironment env, Options options)
: base(options, env, UserNameValue, TrainerUtils.MakeR4ScalarColumn(options.LabelColumnName))
{
LossFunction = options.LossFunction.CreateComponent(env);
LossFunction = options.LossFunction ?? options.LossFunctionFactory.CreateComponent(env);
}

private protected override PredictionKind PredictionKind => PredictionKind.Regression;
Expand Down
14 changes: 11 additions & 3 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1649,8 +1649,16 @@ public sealed class Options : BinaryOptionsBase
/// <value>
/// If unspecified, <see cref="LogLoss"/> will be used.
/// </value>
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory();
[Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
internal ISupportSdcaClassificationLossFactory LossFunctionFactory = new LogLossFactory();

/// <summary>
/// The custom <a href="tmpurl_loss">loss</a>.
Copy link
Member

@wschin wschin Mar 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tmpurl_loss [](start = 36, length = 11)

What does it link to? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shahab is working on this, issue #2356


In reply to: 262720628 [](ancestors = 262720628)

/// </summary>
/// <value>
/// If unspecified, <see cref="LogLoss"/> will be used.
/// </value>
public ISupportSdcaClassificationLoss LossFunction { get; set; }
}

internal SdcaNonCalibratedBinaryTrainer(IHostEnvironment env,
Expand All @@ -1666,7 +1674,7 @@ internal SdcaNonCalibratedBinaryTrainer(IHostEnvironment env,
}

internal SdcaNonCalibratedBinaryTrainer(IHostEnvironment env, Options options)
: base(env, options, options.LossFunction.CreateComponent(env))
: base(env, options, options.LossFunction ?? options.LossFunctionFactory.CreateComponent(env))
Copy link
Member

@wschin wschin Mar 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only constructor of Options assigns new LogLoss() to LossFunction. Is it still possible to execute CreateComponent? #ByDesign

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not anymore, LossFunction is null by default


In reply to: 262723959 [](ancestors = 262723959)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then, user will see a trainer with null loss, but that trainer is minimizing LogLoss. That's why I implemented a pattern mentioned by Tom.


In reply to: 262724703 [](ancestors = 262724703,262723959)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, as per documentation line 1659:
If unspecified, LogLoss will be used.


In reply to: 262726971 [](ancestors = 262726971,262724703,262723959)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might prefer that the value be present (the idea of null acting like a particular specific instance is a bit suboptimal), but I don't insist on it I guess...


In reply to: 262728446 [](ancestors = 262728446,262726971,262724703,262723959)

{
}

Expand Down
Loading