Skip to content

Conversion of Hogwild SGD to estimator #1134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions src/Microsoft.ML.Data/Training/TrainerUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -384,26 +384,51 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}

/// <summary>
/// Verifies that a column name carried by the trainer arguments still equals its default name.
/// </summary>
/// <param name="host">The host environment used to create the exception.</param>
/// <param name="defaultColName">The default name of the column being checked.</param>
/// <param name="argValue">The column name found in the arguments.</param>
private static void CheckArgColName(IHostEnvironment host, string defaultColName, string argValue)
{
    // Any value other than the default means the caller set the column through the arguments,
    // where it would be ignored; fail loudly instead.
    if (argValue != defaultColName)
        throw host.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead.");
}

/// <summary>
/// Check that the label, feature, weights, groupId column names are not supplied in the args of the constructor, through the advancedSettings parameter,
/// for cases when the public constructor is called.
/// The recommendation is to set the column names directly.
/// </summary>
/// <param name="host">The host environment used to create the exception raised for a non-default column name.</param>
/// <param name="args">The learner arguments whose column names are checked.</param>
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithGroupId args)
{
    // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
    CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn);
    CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn);
    CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn);

    // The group id column is only checked when one is set; a null value is accepted as-is.
    if (args.GroupIdColumn != null)
    {
        CheckArgColName(host, DefaultColumnNames.GroupId, args.GroupIdColumn);
    }
}

/// <summary>
/// Ensures that the label, feature and weight column names were not set on the args passed through the advancedSettings parameter
/// when the public constructor is used. Column names are expected to be given directly instead.
/// </summary>
/// <param name="host">The host environment used to create the exception raised for a non-default column name.</param>
/// <param name="args">The learner arguments whose column names are checked.</param>
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithWeight args)
{
    // Each column is read and checked in turn: label, then features, then weight.
    var suppliedLabel = args.LabelColumn;
    CheckArgColName(host, DefaultColumnNames.Label, suppliedLabel);

    var suppliedFeatures = args.FeatureColumn;
    CheckArgColName(host, DefaultColumnNames.Features, suppliedFeatures);

    var suppliedWeight = args.WeightColumn;
    CheckArgColName(host, DefaultColumnNames.Weight, suppliedWeight);
}

/// <summary>
/// Ensures that the label and feature column names were not set on the args passed through the advancedSettings parameter
/// when the public constructor is used. Column names are expected to be given directly instead.
/// </summary>
/// <param name="host">The host environment used to create the exception raised for a non-default column name.</param>
/// <param name="args">The learner arguments whose column names are checked.</param>
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithLabel args)
{
    // Each column is read and checked in turn: label first, then features.
    var suppliedLabel = args.LabelColumn;
    CheckArgColName(host, DefaultColumnNames.Label, suppliedLabel);

    var suppliedFeatures = args.FeatureColumn;
    CheckArgColName(host, DefaultColumnNames.Features, suppliedFeatures);
}

/// <summary>
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Legacy/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10524,10 +10524,10 @@ public sealed partial class StochasticGradientDescentBinaryClassifier : Microsof
public ClassificationLossFunction LossFunction { get; set; } = new LogLossClassificationLossFunction();

/// <summary>
/// L2 regularizer constant
/// L2 Regularization constant
/// </summary>
[TlcModule.SweepableDiscreteParamAttribute("L2Const", new object[]{1E-07f, 5E-07f, 1E-06f, 5E-06f, 1E-05f})]
public float L2Const { get; set; } = 1E-06f;
public float L2Weight { get; set; } = 1E-06f;

/// <summary>
/// Degree of lock-free parallelism. Defaults to automatic depending on data sparseness. Determinism not guaranteed.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.StaticPipe.Runtime;

namespace Microsoft.ML.StaticPipe
{
using Arguments = StochasticGradientDescentClassificationTrainer.Arguments;

/// <summary>
/// Extension methods that provide binary classification trainer estimators.
/// </summary>
public static partial class BinaryClassificationTrainers
{
    /// <summary>
    /// Predict a target using a linear binary classification model trained with the <see cref="Microsoft.ML.Runtime.Learners.StochasticGradientDescentClassificationTrainer"/> trainer.
    /// </summary>
    /// <param name="ctx">The binary classification context trainer object.</param>
    /// <param name="label">The label column.</param>
    /// <param name="features">The features column.</param>
    /// <param name="weights">The optional example weights column.</param>
    /// <param name="maxIterations">The maximum number of iterations; set to 1 to simulate online learning.</param>
    /// <param name="initLearningRate">The initial learning rate used by SGD.</param>
    /// <param name="l2Weight">The L2 regularization constant.</param>
    /// <param name="loss">The loss function to use.</param>
    /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
    /// <param name="onFit">A delegate that is called every time the
    /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
    /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive
    /// the linear model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to
    /// be informed about what was learnt.</param>
    /// <returns>The score, probability and predicted label columns produced by the trained model.</returns>
    public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) StochasticGradientDescentClassificationTrainer(
        this BinaryClassificationContext.BinaryClassificationTrainers ctx,
        Scalar<bool> label,
        Vector<float> features,
        Scalar<float> weights = null,
        int maxIterations = Arguments.Defaults.MaxIterations,
        double initLearningRate = Arguments.Defaults.InitLearningRate,
        float l2Weight = Arguments.Defaults.L2Weight,
        ISupportClassificationLossFactory loss = null,
        Action<Arguments> advancedSettings = null,
        Action<IPredictorWithFeatureWeights<float>> onFit = null)
    {
        var reconciler = new TrainerEstimatorReconciler.BinaryClassifier(
            (env, labelName, featuresName, weightsName) =>
            {
                var sgd = new StochasticGradientDescentClassificationTrainer(env, featuresName, labelName, weightsName, maxIterations, initLearningRate, l2Weight, loss, advancedSettings);

                // Without a callback the trainer is handed back unchanged.
                if (onFit == null)
                    return sgd;

                // Otherwise wrap it so the caller is given the trained model after each fit.
                return sgd.WithOnFitDelegate(transformer => onFit(transformer.Model));
            },
            label, features, weights);

        return reconciler.Output;
    }
}
}
Loading