diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs index fb1dfacf50..fad5444016 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs @@ -23,7 +23,7 @@ public static void Example() // Define the trainer options. var options = new AveragedPerceptronTrainer.Options() { - LossFunction = new SmoothedHingeLoss.Options(), + LossFunction = new SmoothedHingeLoss(), LearningRate = 0.1f, DoLazyUpdates = false, RecencyGain = 0.1f, diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs index fe54dc2728..c9ab5c91c5 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs @@ -29,7 +29,7 @@ public static void Example() var options = new SdcaMultiClassTrainer.Options { // Add custom loss - LossFunction = new HingeLoss.Options(), + LossFunction = new HingeLoss(), // Make the convergence tolerance tighter. ConvergenceTolerance = 0.05f, // Increase the maximum number of passes over training data. diff --git a/src/Microsoft.ML.Data/Dirty/ILoss.cs b/src/Microsoft.ML.Data/Dirty/ILoss.cs index 6e35c17bea..694cd34471 100644 --- a/src/Microsoft.ML.Data/Dirty/ILoss.cs +++ b/src/Microsoft.ML.Data/Dirty/ILoss.cs @@ -16,7 +16,7 @@ public interface ILossFunction Double Loss(TOutput output, TLabel label); } - public interface IScalarOutputLoss : ILossFunction + public interface IScalarLoss : ILossFunction { /// /// Derivative of the loss function with respect to output @@ -25,20 +25,22 @@ public interface IScalarOutputLoss : ILossFunction } [TlcModule.ComponentKind("RegressionLossFunction")] - public interface ISupportRegressionLossFactory : IComponentFactory + [BestFriend] + internal interface ISupportRegressionLossFactory : IComponentFactory { } - public interface IRegressionLoss : IScalarOutputLoss + public interface IRegressionLoss : IScalarLoss { } [TlcModule.ComponentKind("ClassificationLossFunction")] - public interface ISupportClassificationLossFactory : IComponentFactory + [BestFriend] + internal interface ISupportClassificationLossFactory : IComponentFactory { } - public interface IClassificationLoss : IScalarOutputLoss + public interface IClassificationLoss : IScalarLoss { } diff --git a/src/Microsoft.ML.Data/Utils/LossFunctions.cs b/src/Microsoft.ML.Data/Utils/LossFunctions.cs index a48b05c1be..7a91a521af 100644 --- a/src/Microsoft.ML.Data/Utils/LossFunctions.cs +++ b/src/Microsoft.ML.Data/Utils/LossFunctions.cs @@ -43,7 +43,7 @@ namespace Microsoft.ML /// The loss function may know the close-form solution to the optimal dual update /// Ref: Sec(6.2) of http://jmlr.org/papers/volume14/shalev-shwartz13a/shalev-shwartz13a.pdf /// - public interface ISupportSdcaLoss : IScalarOutputLoss + public interface ISupportSdcaLoss : IScalarLoss { //This method helps the optimizer pre-compute the invariants that will be used later in DualUpdate. //scaledFeaturesNormSquared = instanceWeight * (|x|^2 + 1) / (lambda * n), where @@ -69,7 +69,7 @@ public interface ISupportSdcaLoss : IScalarOutputLoss /// /// The label of the example. /// The dual variable of the example. - Double DualLoss(float label, Double dual); + Double DualLoss(float label, float dual); } public interface ISupportSdcaClassificationLoss : ISupportSdcaLoss, IClassificationLoss @@ -81,19 +81,22 @@ public interface ISupportSdcaRegressionLoss : ISupportSdcaLoss, IRegressionLoss } [TlcModule.ComponentKind("SDCAClassificationLossFunction")] - public interface ISupportSdcaClassificationLossFactory : IComponentFactory + [BestFriend] + internal interface ISupportSdcaClassificationLossFactory : IComponentFactory { } [TlcModule.ComponentKind("SDCARegressionLossFunction")] - public interface ISupportSdcaRegressionLossFactory : IComponentFactory + [BestFriend] + internal interface ISupportSdcaRegressionLossFactory : IComponentFactory { new ISupportSdcaRegressionLoss CreateComponent(IHostEnvironment env); } [TlcModule.Component(Name = "LogLoss", FriendlyName = "Log loss", Aliases = new[] { "Logistic", "CrossEntropy" }, Desc = "Log loss.")] - public sealed class LogLossFactory : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory + [BestFriend] + internal sealed class LogLossFactory : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory { public ISupportSdcaClassificationLoss CreateComponent(IHostEnvironment env) => new LogLoss(); @@ -136,7 +139,7 @@ public float DualUpdate(float output, float label, float dual, float invariant, return maxNumThreads >= 2 && Math.Abs(fullUpdate) > Threshold ? fullUpdate / maxNumThreads : fullUpdate; } - public Double DualLoss(float label, Double dual) + public Double DualLoss(float label, float dual) { // Normalize the dual with label. if (label <= 0) @@ -161,7 +164,8 @@ private static Double Log(Double x) public sealed class HingeLoss : ISupportSdcaClassificationLoss { [TlcModule.Component(Name = "HingeLoss", FriendlyName = "Hinge loss", Alias = "Hinge", Desc = "Hinge loss.")] - public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory + [BestFriend] + internal sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Margin value", ShortName = "marg")] public float Margin = Defaults.Margin; @@ -175,7 +179,7 @@ public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportCla private const float Threshold = 0.5f; private readonly float _margin; - internal HingeLoss(Options options) + private HingeLoss(Options options) { _margin = options.Margin; } @@ -216,7 +220,7 @@ public float DualUpdate(float output, float label, float alpha, float invariant, return maxNumThreads >= 2 && Math.Abs(fullUpdate) > Threshold ? fullUpdate / maxNumThreads : fullUpdate; } - public Double DualLoss(float label, Double dual) + public Double DualLoss(float label, float dual) { if (label <= 0) dual = -dual; @@ -233,7 +237,7 @@ public sealed class SmoothedHingeLoss : ISupportSdcaClassificationLoss { [TlcModule.Component(Name = "SmoothedHingeLoss", FriendlyName = "Smoothed Hinge Loss", Alias = "SmoothedHinge", Desc = "Smoothed Hinge loss.")] - public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory + internal sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Smoothing constant", ShortName = "smooth")] public float SmoothingConst = Defaults.SmoothingConst; @@ -313,7 +317,7 @@ public float DualUpdate(float output, float label, float alpha, float invariant, return maxNumThreads >= 2 && Math.Abs(fullUpdate) > Threshold ? fullUpdate / maxNumThreads : fullUpdate; } - public Double DualLoss(float label, Double dual) + public Double DualLoss(float label, float dual) { if (label <= 0) dual = -dual; @@ -332,7 +336,7 @@ public Double DualLoss(float label, Double dual) public sealed class ExpLoss : IClassificationLoss { [TlcModule.Component(Name = "ExpLoss", FriendlyName = "Exponential Loss", Desc = "Exponential loss.")] - public sealed class Options : ISupportClassificationLossFactory + internal sealed class Options : ISupportClassificationLossFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Beta (dilation)", ShortName = "beta")] public float Beta = 1; @@ -344,11 +348,16 @@ public sealed class Options : ISupportClassificationLossFactory private readonly float _beta; - public ExpLoss(Options options) + internal ExpLoss(Options options) { _beta = options.Beta; } + public ExpLoss(float beta = 1) + { + _beta = beta; + } + public Double Loss(float output, float label) { float truth = label > 0 ? 1 : -1; @@ -364,7 +373,8 @@ public float Derivative(float output, float label) } [TlcModule.Component(Name = "SquaredLoss", FriendlyName = "Squared Loss", Alias = "L2", Desc = "Squared loss.")] - public sealed class SquaredLossFactory : ISupportSdcaRegressionLossFactory, ISupportRegressionLossFactory + [BestFriend] + internal sealed class SquaredLossFactory : ISupportSdcaRegressionLossFactory, ISupportRegressionLossFactory { public ISupportSdcaRegressionLoss CreateComponent(IHostEnvironment env) => new SquaredLoss(); @@ -398,14 +408,15 @@ public float DualUpdate(float output, float label, float dual, float invariant, return maxNumThreads >= 2 ? fullUpdate / maxNumThreads : fullUpdate; } - public Double DualLoss(float label, Double dual) + public Double DualLoss(float label, float dual) { return -dual * (dual / 4 - label); } } [TlcModule.Component(Name = "PoissonLoss", FriendlyName = "Poisson Loss", Desc = "Poisson loss.")] - public sealed class PoissonLossFactory : ISupportRegressionLossFactory + [BestFriend] + internal sealed class PoissonLossFactory : ISupportRegressionLossFactory { public IRegressionLoss CreateComponent(IHostEnvironment env) => new PoissonLoss(); } @@ -437,7 +448,7 @@ public float Derivative(float output, float label) public sealed class TweedieLoss : IRegressionLoss { [TlcModule.Component(Name = "TweedieLoss", FriendlyName = "Tweedie Loss", Alias = "tweedie", Desc = "Tweedie loss.")] - public sealed class Options : ISupportRegressionLossFactory + internal sealed class Options : ISupportRegressionLossFactory { [Argument(ArgumentType.LastOccurenceWins, HelpText = "Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, " + @@ -453,7 +464,7 @@ public sealed class Options : ISupportRegressionLossFactory private readonly Double _index1; // 1 minus the index parameter. private readonly Double _index2; // 2 minus the index parameter. - public TweedieLoss(Options options) + private TweedieLoss(Options options) { Contracts.CheckUserArg(1 <= options.Index && options.Index <= 2, nameof(options.Index), "Must be in the range [1, 2]"); _index = options.Index; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index 6e291134e9..ed92b3b930 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -111,7 +111,7 @@ internal class AveragedDefault : OnlineLinearOptions.OnlineDefault public const float L2RegularizerWeight = 0; } - internal abstract IComponentFactory LossFunctionFactory { get; } + internal abstract IComponentFactory LossFunctionFactory { get; } } public abstract class AveragedLinearTrainer : OnlineLinearTrainer @@ -119,7 +119,7 @@ public abstract class AveragedLinearTrainer : OnlineLinear where TModel : class { private protected readonly AveragedLinearOptions AveragedLinearTrainerOptions; - private protected IScalarOutputLoss LossFunction; + private protected IScalarLoss LossFunction; private protected abstract class AveragedTrainStateBase : TrainStateBase { @@ -141,7 +141,7 @@ private protected abstract class AveragedTrainStateBase : TrainStateBase protected readonly bool Averaged; private readonly long _resetWeightsAfterXExamples; private readonly AveragedLinearOptions _args; - private readonly IScalarOutputLoss _loss; + private readonly IScalarLoss _loss; private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, AveragedLinearTrainer parent) : base(ch, numFeatures, predictor, parent) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index 3fe119d98a..2df8c6289a 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -60,8 +60,13 @@ public sealed class Options : AveragedLinearOptions /// /// A custom loss. /// - [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] - public ISupportClassificationLossFactory LossFunction = new HingeLoss.Options(); + [Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] + internal ISupportClassificationLossFactory ClassificationLossFunctionFactory = new HingeLoss.Options(); + + /// + /// A custom loss. + /// + public IClassificationLoss LossFunction { get; set; } /// /// The calibrator for producing probabilities. Default is exponential (aka Platt) calibration. @@ -75,7 +80,7 @@ public sealed class Options : AveragedLinearOptions [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] internal int MaxCalibrationExamples = 1000000; - internal override IComponentFactory LossFunctionFactory => LossFunction; + internal override IComponentFactory LossFunctionFactory => ClassificationLossFunctionFactory; } private sealed class TrainState : AveragedTrainStateBase @@ -112,7 +117,7 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, Options options) : base(options, env, UserNameValue, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName)) { _args = options; - LossFunction = _args.LossFunction.CreateComponent(env); + LossFunction = _args.LossFunction ?? _args.LossFunctionFactory.CreateComponent(env); } /// @@ -143,23 +148,11 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, DecreaseLearningRate = decreaseLearningRate, L2RegularizerWeight = l2RegularizerWeight, NumberOfIterations = numIterations, - LossFunction = new TrivialFactory(lossFunction ?? new HingeLoss()) + LossFunction = lossFunction ?? new HingeLoss() }) { } - private sealed class TrivialFactory : ISupportClassificationLossFactory - { - private IClassificationLoss _loss; - - public TrivialFactory(IClassificationLoss loss) - { - _loss = loss; - } - - IClassificationLoss IComponentFactory.CreateComponent(IHostEnvironment env) => _loss; - } - private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification; private protected override bool NeedCalibration => true; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index 771fe8b149..0476d7e411 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -35,9 +35,13 @@ public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer LossFunctionFactory => RegressionLossFunctionFactory; /// /// Set defaults that vary from the base type. @@ -48,8 +52,6 @@ public Options() DecreaseLearningRate = OgdDefaultArgs.DecreaseLearningRate; } - internal override IComponentFactory LossFunctionFactory => LossFunction; - [BestFriend] internal class OgdDefaultArgs : AveragedDefault { @@ -113,27 +115,15 @@ internal OnlineGradientDescentTrainer(IHostEnvironment env, NumberOfIterations = numIterations, LabelColumnName = labelColumn, FeatureColumnName = featureColumn, - LossFunction = new TrivialFactory(lossFunction ?? new SquaredLoss()) + LossFunction = lossFunction ?? new SquaredLoss() }) { } - private sealed class TrivialFactory : ISupportRegressionLossFactory - { - private IRegressionLoss _loss; - - public TrivialFactory(IRegressionLoss loss) - { - _loss = loss; - } - - IRegressionLoss IComponentFactory.CreateComponent(IHostEnvironment env) => _loss; - } - internal OnlineGradientDescentTrainer(IHostEnvironment env, Options options) : base(options, env, UserNameValue, TrainerUtils.MakeR4ScalarColumn(options.LabelColumnName)) { - LossFunction = options.LossFunction.CreateComponent(env); + LossFunction = options.LossFunction ?? options.LossFunctionFactory.CreateComponent(env); } private protected override PredictionKind PredictionKind => PredictionKind.Regression; diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs index a7b3fbfccb..faa340f5b6 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs @@ -1649,8 +1649,16 @@ public sealed class Options : BinaryOptionsBase /// /// If unspecified, will be used. /// - [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] - public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory(); + [Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] + internal ISupportSdcaClassificationLossFactory LossFunctionFactory = new LogLossFactory(); + + /// + /// The custom loss. + /// + /// + /// If unspecified, will be used. + /// + public ISupportSdcaClassificationLoss LossFunction { get; set; } } internal SdcaNonCalibratedBinaryTrainer(IHostEnvironment env, @@ -1666,7 +1674,7 @@ internal SdcaNonCalibratedBinaryTrainer(IHostEnvironment env, } internal SdcaNonCalibratedBinaryTrainer(IHostEnvironment env, Options options) - : base(env, options, options.LossFunction.CreateComponent(env)) + : base(env, options, options.LossFunction ?? options.LossFunctionFactory.CreateComponent(env)) { } diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index cbb61a2c0b..3bea3852cf 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -46,8 +46,16 @@ public sealed class Options : OptionsBase /// /// If unspecified, will be used. /// - [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] - public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory(); + [Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] + internal ISupportSdcaClassificationLossFactory LossFunctionFactory = new LogLossFactory(); + + /// + /// The custom loss. + /// + /// + /// If unspecified, will be used. + /// + public ISupportSdcaClassificationLoss LossFunction { get; set; } } private readonly ISupportSdcaClassificationLoss _loss; @@ -78,7 +86,7 @@ internal SdcaMultiClassTrainer(IHostEnvironment env, { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - _loss = loss ?? SdcaTrainerOptions.LossFunction.CreateComponent(env); + _loss = loss ?? SdcaTrainerOptions.LossFunction ?? SdcaTrainerOptions.LossFunctionFactory.CreateComponent(env); Loss = _loss; } @@ -89,7 +97,7 @@ internal SdcaMultiClassTrainer(IHostEnvironment env, Options options, Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); - _loss = options.LossFunction.CreateComponent(env); + _loss = options.LossFunction ?? options.LossFunctionFactory.CreateComponent(env); Loss = _loss; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index bc373fa3d0..53ed341007 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -43,8 +43,16 @@ public sealed class Options : OptionsBase /// /// Defaults to /// - [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] - public ISupportSdcaRegressionLossFactory LossFunction = new SquaredLossFactory(); + [Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] + internal ISupportSdcaRegressionLossFactory LossFunctionFactory = new SquaredLossFactory(); + + /// + /// A custom loss. + /// + /// + /// Defaults to + /// + public ISupportSdcaRegressionLoss LossFunction { get; set; } /// /// Create the object. @@ -87,7 +95,7 @@ internal SdcaRegressionTrainer(IHostEnvironment env, { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - _loss = loss ?? SdcaTrainerOptions.LossFunction.CreateComponent(env); + _loss = loss ?? SdcaTrainerOptions.LossFunction ?? SdcaTrainerOptions.LossFunctionFactory.CreateComponent(env); Loss = _loss; } @@ -97,7 +105,7 @@ internal SdcaRegressionTrainer(IHostEnvironment env, Options options, string fea Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); - _loss = options.LossFunction.CreateComponent(env); + _loss = options.LossFunction ?? options.LossFunctionFactory.CreateComponent(env); Loss = _loss; } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs index 4daa225d92..f3f8df1305 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs @@ -28,7 +28,7 @@ public class TestLoss /// step, given label and output /// Whether the loss function is differentiable /// w.r.t. the output in the vicinity of the output value - private void TestHelper(IScalarOutputLoss lossFunc, double label, double output, double expectedLoss, double expectedUpdate, bool differentiable = true) + private void TestHelper(IScalarLoss lossFunc, double label, double output, double expectedLoss, double expectedUpdate, bool differentiable = true) { Double loss = lossFunc.Loss((float)output, (float)label); float derivative = lossFunc.Derivative((float)output, (float)label);