From d1eebb4b8584caf65172cdaa80d00e582495b021 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Wed, 3 Oct 2018 00:00:11 -0700 Subject: [PATCH 1/5] howeildsgd conversion to estimator and pigstension with related tests --- .../Standard/LinearClassificationStatics.cs | 63 ++++++++++ .../Standard/LinearClassificationTrainer.cs | 110 ++++++++++++++---- .../Training.cs | 38 ++++++ .../TrainerEstimators/HogwildSGDTests.cs | 23 ++++ 4 files changed, 210 insertions(+), 24 deletions(-) create mode 100644 src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs create mode 100644 test/Microsoft.ML.Tests/TrainerEstimators/HogwildSGDTests.cs diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs new file mode 100644 index 0000000000..a37374911b --- /dev/null +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.StaticPipe.Runtime; + +namespace Microsoft.ML.StaticPipe +{ + using Arguments = StochasticGradientDescentClassificationTrainer.Arguments; + + /// + /// Binary Classification trainer estimators. + /// + public static partial class BinaryClassificationTrainers + { + /// + /// Predict a target using a linear binary classification model trained with the trainer. + /// + /// The binary classificaiton context trainer object. + /// The name of the label column. + /// The name of the feature column. + /// The name for the example weight column. + /// The maximum number of iterations; set to 1 to simulate online learning. + /// The initial Learning Rate used by SGD. + /// The L2 regularizer constant. + /// A delegate to apply all the advanced arguments to the algorithm. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to + /// be informed about what was learnt. + /// The predicted output. + public static (Scalar score, Scalar probability, Scalar predictedLabel) StochasticGradientDescentClassificationTrainer(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + Scalar label, + Vector features, + Scalar weights = null, + int maxIterations = Arguments.Defaults.MaxIterations, + double initLearningRate = Arguments.Defaults.InitLearningRate, + float l2Weight = Arguments.Defaults.L2Weight, + Action advancedSettings = null, + Action> onFit = null) + { + var rec = new TrainerEstimatorReconciler.BinaryClassifier( + (env, labelName, featuresName, weightsName) => + { + var trainer = new StochasticGradientDescentClassificationTrainer(env, featuresName, labelName, weightsName, maxIterations, initLearningRate, l2Weight, advancedSettings); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + + }, label, features, weights); + + return rec.Output; + } + } +} diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 29e2e1e89f..028936f14b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -43,9 +43,12 @@ namespace Microsoft.ML.Runtime.Learners using Stopwatch = System.Diagnostics.Stopwatch; using TScalarPredictor = IPredictorWithFeatureWeights; - public abstract class LinearTrainerBase : TrainerBase - where TPredictor : IPredictor + public abstract class LinearTrainerBase : TrainerEstimatorBase + where TTransformer : ISingleFeaturePredictionTransformer + where TModel : IPredictor { + private const string RegisterName = nameof(LinearTrainerBase); + protected bool NeedShuffle; private static readonly TrainerInfo _info = new TrainerInfo(); @@ -56,15 +59,17 @@ public abstract class LinearTrainerBase : TrainerBase /// protected abstract bool ShuffleData { get; } - private protected LinearTrainerBase(IHostEnvironment env, string name) - : base(env, name) + private protected LinearTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn, + string weightColumn = null) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), + labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { } - public override TPredictor Train(TrainContext context) + protected override TModel TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); - TPredictor pred; + TModel pred; using (var ch = Host.Start("Training")) { var preparedData = PrepareDataFromTrainingExamples(ch, context.TrainingSet, out int weightSetCount); @@ -79,7 +84,7 @@ public override TPredictor Train(TrainContext context) return pred; } - protected abstract TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor, int weightSetCount); + protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor, int weightSetCount); /// /// This method ensures that the data meets the requirements of this trainer and its @@ -1496,11 +1501,11 @@ protected override BinaryPredictionTransformer MakeTransformer } public sealed class StochasticGradientDescentClassificationTrainer : - LinearTrainerBase + LinearTrainerBase, TScalarPredictor> { - public const string LoadNameValue = "BinarySGD"; - public const string UserNameValue = "Hogwild SGD (binary)"; - public const string ShortName = "HogwildSGD"; + internal const string LoadNameValue = "BinarySGD"; + internal const string UserNameValue = "Hogwild SGD (binary)"; + internal const string ShortName = "HogwildSGD"; public sealed class Arguments : LearnerInputBaseWithWeight { @@ -1510,7 +1515,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight [Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularizer constant", ShortName = "l2", SortOrder = 50)] [TGUI(Label = "L2 Regularizer Constant", SuggestedSweeps = "1e-7,5e-7,1e-6,5e-6,1e-5")] [TlcModule.SweepableDiscreteParam("L2Const", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f })] - public float L2Const = (float)1e-6; + public float L2Weight = Defaults.L2Weight; [Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Defaults to automatic depending on data sparseness. Determinism not guaranteed.", ShortName = "nt,t,threads", SortOrder = 50)] [TGUI(Label = "Number of threads", SuggestedSweeps = "1,2,4")] @@ -1519,16 +1524,16 @@ public sealed class Arguments : LearnerInputBaseWithWeight [Argument(ArgumentType.AtMostOnce, HelpText = "Exponential moving averaged improvement tolerance for convergence", ShortName = "tol")] [TGUI(SuggestedSweeps = "1e-2,1e-3,1e-4,1e-5")] [TlcModule.SweepableDiscreteParam("ConvergenceTolerance", new object[] { 1e-2f, 1e-3f, 1e-4f, 1e-5f })] - public Double ConvergenceTolerance = 1e-4; + public double ConvergenceTolerance = 1e-4; [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of iterations; set to 1 to simulate online learning.", ShortName = "iter")] [TGUI(Label = "Max number of iterations", SuggestedSweeps = "1,5,10,20")] [TlcModule.SweepableDiscreteParam("MaxIterations", new object[] { 1, 5, 10, 20 })] - public int MaxIterations = 20; + public int MaxIterations = Defaults.MaxIterations; [Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate (only used by SGD)", ShortName = "ilr,lr")] [TGUI(Label = "Initial Learning Rate (for SGD)")] - public Double InitLearningRate = 0.01; + public double InitLearningRate = Defaults.InitLearningRate; [Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data every epoch?", ShortName = "shuf")] [TlcModule.SweepableDiscreteParam("Shuffle", null, isBool: true)] @@ -1549,17 +1554,17 @@ public sealed class Arguments : LearnerInputBaseWithWeight internal void Check(IHostEnvironment env) { Contracts.CheckValue(env, nameof(env)); - env.CheckUserArg(L2Const >= 0, nameof(L2Const), "Must be non-negative."); + env.CheckUserArg(L2Weight >= 0, nameof(L2Weight), "Must be non-negative."); env.CheckUserArg(InitLearningRate > 0, nameof(InitLearningRate), "Must be positive."); env.CheckUserArg(MaxIterations > 0, nameof(MaxIterations), "Must be positive."); env.CheckUserArg(PositiveInstanceWeight > 0, nameof(PositiveInstanceWeight), "Must be positive"); - if (InitLearningRate * L2Const >= 1) + if (InitLearningRate * L2Weight >= 1) { using (var ch = env.Start("Argument Adjustment")) { ch.Warning("{0} {1} set too high; reducing to {1}", nameof(InitLearningRate), - InitLearningRate, InitLearningRate = (float)0.5 / L2Const); + InitLearningRate, InitLearningRate = (float)0.5 / L2Weight); ch.Done(); } } @@ -1567,6 +1572,12 @@ internal void Check(IHostEnvironment env) if (ConvergenceTolerance <= 0) ConvergenceTolerance = float.Epsilon; } + internal static class Defaults + { + internal const float L2Weight = 1e-6f; + internal const int MaxIterations = 20; + internal const double InitLearningRate = 0.01; + } } private readonly IClassificationLoss _loss; @@ -1578,8 +1589,46 @@ internal void Check(IHostEnvironment env) public override TrainerInfo Info { get; } - public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Arguments args) - : base(env, LoadNameValue) + /// + /// Initializes a new instance of + /// + /// The environment to use. + /// The name of the feature column. + /// The name of the label column. + /// The name for the example weight column. + /// The maximum number of iterations; set to 1 to simulate online learning. + /// The initial Learning Rate used by SGD. + /// The L2 regularizer constant. + /// A delegate to apply all the advanced arguments to the algorithm. + public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, string featureColumn, string labelColumn, string weightColumn = null, + int maxIterations = Arguments.Defaults.MaxIterations, + double initLearningRate = Arguments.Defaults.InitLearningRate, + float l2Weight = Arguments.Defaults.L2Weight, + Action advancedSettings = null) + : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weightColumn) + { + _args = new Arguments(); + advancedSettings?.Invoke(_args); + + // Apply the advanced args, if the user supplied any. + _args.FeatureColumn = featureColumn; + _args.LabelColumn = labelColumn; + _args.WeightColumn = weightColumn; + _args.MaxIterations = maxIterations; + _args.InitLearningRate = initLearningRate; + _args.L2Weight = l2Weight; + _args.Check(env); + + _loss = _args.LossFunction.CreateComponent(env); + Info = new TrainerInfo(calibration: !(_loss is LogLoss), supportIncrementalTrain: true); + NeedShuffle = _args.Shuffle; + } + + /// + /// Initializes a new instance of + /// + internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Arguments args) + : base(env, args.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn), args.WeightColumn) { args.Check(env); _loss = args.LossFunction.CreateComponent(env); @@ -1588,6 +1637,19 @@ public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Argu _args = args; } + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } + + protected override BinaryPredictionTransformer MakeTransformer(TScalarPredictor model, ISchema trainSchema) + => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + //For complexity analysis, we assume that // - The number of features is N // - Average number of non-zero per instance is k @@ -1613,7 +1675,7 @@ protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedData data, int checkFrequency = _args.CheckFrequency ?? numThreads; if (checkFrequency <= 0) checkFrequency = int.MaxValue; - var l2Const = _args.L2Const; + var l2Weight = _args.L2Weight; var lossFunc = _loss; var pOptions = new ParallelOptions { MaxDegreeOfParallelism = numThreads }; var positiveInstanceWeight = _args.PositiveInstanceWeight; @@ -1660,7 +1722,7 @@ protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedData data, } } - var newLoss = lossSum.Sum / count + l2Const * VectorUtils.NormSquared(weights) * 0.5; + var newLoss = lossSum.Sum / count + l2Weight * VectorUtils.NormSquared(weights) * 0.5; improvement = improvement == 0 ? loss - newLoss : 0.5 * (loss - newLoss + improvement); loss = newLoss; @@ -1698,9 +1760,9 @@ protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedData data, if (label > 0) derivative *= positiveInstanceWeight; - Double rate = ilr / (1 + ilr * l2Const * (t++)); + Double rate = ilr / (1 + ilr * l2Weight * (t++)); Double step = -derivative * rate; - weightScaling *= 1 - rate * l2Const; + weightScaling *= 1 - rate * l2Weight; VectorUtils.AddMult(ref features, weights.Values, (float)(step / weightScaling)); bias += (float)step; } diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 92bfe1e2c1..89e99fea72 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -749,5 +749,43 @@ public void FastTreeRanking() Assert.InRange(metrics.Ndcg[1], 36.5, 37); Assert.InRange(metrics.Ndcg[2], 36.5, 37); } + + [Fact] + public void HogwildSGDBinaryClassification() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); + var dataSource = new MultiFileSource(dataPath); + var ctx = new BinaryClassificationContext(env); + + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadBool(0), features: c.LoadFloat(1, 9))); + + IPredictorWithFeatureWeights pred = null; + + var est = reader.MakeNewEstimator() + .Append(r => (r.label, preds: ctx.Trainers.StochasticGradientDescentClassificationTrainer(r.label, r.features, + l2Weight: 0, + onFit: (p) => { pred = p; }))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + + // 9 input features, so we ought to have 9 weights. + VBuffer weights = new VBuffer(); + pred.GetFeatureWeights(ref weights); + Assert.Equal(9, weights.Length); + + var data = model.Read(dataSource); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.preds); + // Run a sanity check against a few of the metrics. + Assert.InRange(metrics.Accuracy, 0, 1); + Assert.InRange(metrics.Auc, 0, 1); + Assert.InRange(metrics.Auprc, 0, 1); + } } } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/HogwildSGDTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/HogwildSGDTests.cs new file mode 100644 index 0000000000..802cb205e0 --- /dev/null +++ b/test/Microsoft.ML.Tests/TrainerEstimators/HogwildSGDTests.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using Xunit; + +namespace Microsoft.ML.Tests.TrainerEstimators +{ + public partial class TrainerEstimators + { + [Fact] + public void TestEstimatorHogwildSGD() + { + (IEstimator pipe, IDataView dataView) = GetBinaryClassificationPipeline(); + pipe.Append(new StochasticGradientDescentClassificationTrainer(Env, "Features", "Label")); + TestEstimatorCore(pipe, dataView); + Done(); + } + } +} From bd18b1d0101477fd5906b74768fa1bbe162a76d1 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Wed, 3 Oct 2018 00:07:44 -0700 Subject: [PATCH 2/5] add check --- .../Standard/LinearClassificationTrainer.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 028936f14b..d9953de07c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1607,6 +1607,9 @@ public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, stri Action advancedSettings = null) : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weightColumn) { + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + _args = new Arguments(); advancedSettings?.Invoke(_args); From 7ffaac94138bb1ba49f048d4b6967f800232bff1 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Wed, 3 Oct 2018 12:59:19 -0700 Subject: [PATCH 3/5] fixed errors --- src/Microsoft.ML.Legacy/CSharpApi.cs | 4 ++-- .../Standard/LinearClassificationStatics.cs | 2 +- .../Standard/LinearClassificationTrainer.cs | 4 ++-- test/BaselineOutput/Common/EntryPoints/core_manifest.json | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index 8b8790ddd7..a1aa6bcc75 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -10524,10 +10524,10 @@ public sealed partial class StochasticGradientDescentBinaryClassifier : Microsof public ClassificationLossFunction LossFunction { get; set; } = new LogLossClassificationLossFunction(); /// - /// L2 regularizer constant + /// L2 Regularization constant /// [TlcModule.SweepableDiscreteParamAttribute("L2Const", new object[]{1E-07f, 5E-07f, 1E-06f, 5E-06f, 1E-05f})] - public float L2Const { get; set; } = 1E-06f; + public float L2Weight { get; set; } = 1E-06f; /// /// Degree of lock-free parallelism. Defaults to automatic depending on data sparseness. Determinism not guaranteed. diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs index a37374911b..5d0a8db238 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs @@ -28,7 +28,7 @@ public static partial class BinaryClassificationTrainers /// The name for the example weight column. /// The maximum number of iterations; set to 1 to simulate online learning. /// The initial Learning Rate used by SGD. - /// The L2 regularizer constant. + /// The L2 regularization constant. /// A delegate to apply all the advanced arguments to the algorithm. /// A delegate that is called every time the /// method is called on the diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index d9953de07c..81545e17cf 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1512,8 +1512,8 @@ public sealed class Arguments : LearnerInputBaseWithWeight [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] public ISupportClassificationLossFactory LossFunction = new LogLossFactory(); - [Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularizer constant", ShortName = "l2", SortOrder = 50)] - [TGUI(Label = "L2 Regularizer Constant", SuggestedSweeps = "1e-7,5e-7,1e-6,5e-6,1e-5")] + [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization constant", ShortName = "l2", SortOrder = 50)] + [TGUI(Label = "L2 Regularization Constant", SuggestedSweeps = "1e-7,5e-7,1e-6,5e-6,1e-5")] [TlcModule.SweepableDiscreteParam("L2Const", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f })] public float L2Weight = Defaults.L2Weight; diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 4dae0b59cf..e6005b54d6 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -16518,9 +16518,9 @@ } }, { - "Name": "L2Const", + "Name": "L2Weight", "Type": "Float", - "Desc": "L2 regularizer constant", + "Desc": "L2 Regularization constant", "Aliases": [ "l2" ], From b2e07d7a0f849343bb926f0ba2a4953f097b9560 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Thu, 4 Oct 2018 14:13:29 -0700 Subject: [PATCH 4/5] fixed review comments --- .../Training/TrainerUtils.cs | 45 ++++++++++++++----- .../Standard/LinearClassificationStatics.cs | 4 +- .../Standard/LinearClassificationTrainer.cs | 35 +++++++++++++++ .../TrainerEstimators/HogwildSGDTests.cs | 23 ---------- .../TrainerEstimators/TrainerEstimators.cs | 12 +++++ 5 files changed, 85 insertions(+), 34 deletions(-) delete mode 100644 test/Microsoft.ML.Tests/TrainerEstimators/HogwildSGDTests.cs diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index 7eadddbba2..dff5748635 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -384,6 +384,12 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn) return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } + private static void CheckArgColName(IHostEnvironment host, string defaultColName, string argValue) + { + if (argValue != defaultColName) + throw host.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead instead."); + } + /// /// Check that the label, feature, weights, groupId column names are not supplied in the args of the constructor, through the advancedSettings parameter, /// for cases when the public constructor is called. @@ -391,19 +397,38 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn) /// public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithGroupId args) { - Action checkArgColName = (defaultColName, argValue) => - { - if (argValue != defaultColName) - throw host.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead instead."); - }; - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - checkArgColName(DefaultColumnNames.Label, args.LabelColumn); - checkArgColName(DefaultColumnNames.Features, args.FeatureColumn); - checkArgColName(DefaultColumnNames.Weight, args.WeightColumn); + CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); + CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); + CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn); if (args.GroupIdColumn != null) - checkArgColName(DefaultColumnNames.GroupId, args.GroupIdColumn); + CheckArgColName(host, DefaultColumnNames.GroupId, args.GroupIdColumn); + } + + /// + /// Check that the label, feature, and weights column names are not supplied in the args of the constructor, through the advancedSettings parameter, + /// for cases when the public constructor is called. + /// The recommendation is to set the column names directly. + /// + public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithWeight args) + { + // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly + CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); + CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); + CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn); + } + + /// + /// Check that the label and feature column names are not supplied in the args of the constructor, through the advancedSettings parameter, + /// for cases when the public constructor is called. + /// The recommendation is to set the column names directly. + /// + public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithLabel args) + { + // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly + CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); + CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); } /// diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs index 5d0a8db238..5e703b18f6 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatics.cs @@ -29,6 +29,7 @@ public static partial class BinaryClassificationTrainers /// The maximum number of iterations; set to 1 to simulate online learning. /// The initial Learning Rate used by SGD. /// The L2 regularization constant. + /// The loss function to use. /// A delegate to apply all the advanced arguments to the algorithm. /// A delegate that is called every time the /// method is called on the @@ -43,13 +44,14 @@ public static (Scalar score, Scalar probability, Scalar pred int maxIterations = Arguments.Defaults.MaxIterations, double initLearningRate = Arguments.Defaults.InitLearningRate, float l2Weight = Arguments.Defaults.L2Weight, + ISupportClassificationLossFactory loss = null, Action advancedSettings = null, Action> onFit = null) { var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => { - var trainer = new StochasticGradientDescentClassificationTrainer(env, featuresName, labelName, weightsName, maxIterations, initLearningRate, l2Weight, advancedSettings); + var trainer = new StochasticGradientDescentClassificationTrainer(env, featuresName, labelName, weightsName, maxIterations, initLearningRate, l2Weight, loss, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 81545e17cf..0208589fc2 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1599,11 +1599,13 @@ internal static class Defaults /// The maximum number of iterations; set to 1 to simulate online learning. /// The initial Learning Rate used by SGD. /// The L2 regularizer constant. + /// The loss function to use. /// A delegate to apply all the advanced arguments to the algorithm. public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, string featureColumn, string labelColumn, string weightColumn = null, int maxIterations = Arguments.Defaults.MaxIterations, double initLearningRate = Arguments.Defaults.InitLearningRate, float l2Weight = Arguments.Defaults.L2Weight, + ISupportClassificationLossFactory loss = null, Action advancedSettings = null) : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weightColumn) { @@ -1613,6 +1615,12 @@ public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, stri _args = new Arguments(); advancedSettings?.Invoke(_args); + // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly + TrainerUtils.CheckArgsHaveDefaultColNames(Host, _args); + + if (advancedSettings != null) + CheckArgsAndAdvancedSettingMismatch(maxIterations, initLearningRate, l2Weight, loss, new Arguments(), _args); + // Apply the advanced args, if the user supplied any. _args.FeatureColumn = featureColumn; _args.LabelColumn = labelColumn; @@ -1620,6 +1628,8 @@ public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, stri _args.MaxIterations = maxIterations; _args.InitLearningRate = initLearningRate; _args.L2Weight = l2Weight; + if (loss != null) + _args.LossFunction = loss; _args.Check(env); _loss = _args.LossFunction.CreateComponent(env); @@ -1640,6 +1650,31 @@ internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Ar _args = args; } + /// + /// If, after applying the advancedSettings delegate, the args are different that the default value + /// and are also different than the value supplied directly to the xtension method, warn the user + /// about which value is being used. + /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune. + /// This list should follow the one in the constructor, and the extension methods on the . + /// + internal void CheckArgsAndAdvancedSettingMismatch(int maxIterations, + double initLearningRate, + float l2Weight, + ISupportClassificationLossFactory loss, + Arguments snapshot, + Arguments currentArgs) + { + using (var ch = Host.Start("Comparing advanced settings with the directly provided values.")) + { + // Check that the user didn't supply different parameters in the args, from what it specified directly. + TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, maxIterations, snapshot.MaxIterations, currentArgs.MaxIterations, nameof(maxIterations)); + TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, initLearningRate, snapshot.InitLearningRate, currentArgs.InitLearningRate, nameof(initLearningRate)); + TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, l2Weight, snapshot.L2Weight, currentArgs.L2Weight, nameof(l2Weight)); + TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, loss, snapshot.LossFunction, currentArgs.LossFunction, nameof(loss)); + ch.Done(); + } + } + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) { return new[] diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/HogwildSGDTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/HogwildSGDTests.cs deleted file mode 100644 index 802cb205e0..0000000000 --- a/test/Microsoft.ML.Tests/TrainerEstimators/HogwildSGDTests.cs +++ /dev/null @@ -1,23 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Learners; -using Xunit; - -namespace Microsoft.ML.Tests.TrainerEstimators -{ - public partial class TrainerEstimators - { - [Fact] - public void TestEstimatorHogwildSGD() - { - (IEstimator pipe, IDataView dataView) = GetBinaryClassificationPipeline(); - pipe.Append(new StochasticGradientDescentClassificationTrainer(Env, "Features", "Label")); - TestEstimatorCore(pipe, dataView); - Done(); - } - } -} diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs index 49b74c7fa2..c37238e24f 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs @@ -77,6 +77,18 @@ public void KMeansEstimator() Done(); } + /// + /// HogwildSGD TrainerEstimator test + /// + [Fact] + public void TestEstimatorHogwildSGD() + { + (IEstimator pipe, IDataView dataView) = GetBinaryClassificationPipeline(); + pipe.Append(new StochasticGradientDescentClassificationTrainer(Env, "Features", "Label")); + TestEstimatorCore(pipe, dataView); + Done(); + } + private (IEstimator, IDataView) GetBinaryClassificationPipeline() { var data = new TextLoader(Env, From 2d84a9f92a9f45d2438592f65a8b45b7b2fe230c Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Fri, 5 Oct 2018 10:35:08 -0700 Subject: [PATCH 5/5] corrected comment --- .../Standard/LinearClassificationStatic.cs | 2 +- .../Standard/LinearClassificationTrainer.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatic.cs index 5e703b18f6..59880c99a4 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatic.cs @@ -27,7 +27,7 @@ public static partial class BinaryClassificationTrainers /// The name of the feature column. /// The name for the example weight column. /// The maximum number of iterations; set to 1 to simulate online learning. - /// The initial Learning Rate used by SGD. + /// The initial learning rate used by SGD. /// The L2 regularization constant. /// The loss function to use. /// A delegate to apply all the advanced arguments to the algorithm. diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 0208589fc2..7bcb8ef04c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1597,7 +1597,7 @@ internal static class Defaults /// The name of the label column. /// The name for the example weight column. /// The maximum number of iterations; set to 1 to simulate online learning. - /// The initial Learning Rate used by SGD. + /// The initial learning rate used by SGD. /// The L2 regularizer constant. /// The loss function to use. /// A delegate to apply all the advanced arguments to the algorithm.