Skip to content

Commit ab32439

Browse files
artidoroshauheen
authored andcommitted
Conversion of Parallel Stochastic Gradient Descent (SymSGD) to estimator (#1012)
1 parent b4a95aa commit ab32439

File tree

2 files changed

+79
-19
lines changed

2 files changed

+79
-19
lines changed

src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Float = System.Single;
6-
75
using System;
86
using System.Collections.Generic;
97
using System.Runtime.InteropServices;
108
using System.Security;
9+
using Microsoft.ML.Core.Data;
1110
using Microsoft.ML.Runtime;
1211
using Microsoft.ML.Runtime.CommandLine;
1312
using Microsoft.ML.Runtime.Data;
@@ -26,20 +25,18 @@
2625
SymSgdClassificationTrainer.LoadNameValue,
2726
SymSgdClassificationTrainer.ShortName)]
2827

29-
[assembly: LoadableClass(typeof(void), typeof(SymSgdClassificationTrainer), null, typeof(SignatureEntryPointModule), "SymSGD")]
28+
[assembly: LoadableClass(typeof(void), typeof(SymSgdClassificationTrainer), null, typeof(SignatureEntryPointModule), SymSgdClassificationTrainer.LoadNameValue)]
3029

3130
namespace Microsoft.ML.Runtime.SymSgd
3231
{
33-
using TPredictor = IPredictorWithFeatureWeights<Float>;
32+
using TPredictor = IPredictorWithFeatureWeights<float>;
3433

3534
/// <include file='doc.xml' path='doc/members/member[@name="SymSGD"]/*' />
36-
public sealed class SymSgdClassificationTrainer :
37-
TrainerBase<TPredictor>,
38-
ITrainer<TPredictor>
35+
public sealed class SymSgdClassificationTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<TPredictor>, TPredictor>
3936
{
40-
public const string LoadNameValue = "SymbolicSGD";
41-
public const string UserNameValue = "Symbolic SGD (binary)";
42-
public const string ShortName = "SymSGD";
37+
internal const string LoadNameValue = "SymbolicSGD";
38+
internal const string UserNameValue = "Symbolic SGD (binary)";
39+
internal const string ShortName = "SymSGD";
4340

4441
public sealed class Arguments : LearnerInputBaseWithLabel
4542
{
@@ -78,7 +75,7 @@ public sealed class Arguments : LearnerInputBaseWithLabel
7875
public bool Shuffle = true;
7976

8077
[Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")]
81-
public Float PositiveInstanceWeight = 1;
78+
public float PositiveInstanceWeight = 1;
8279

8380
public void Check(IExceptionContext ectx)
8481
{
@@ -135,7 +132,7 @@ private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedDa
135132
return examplesToFeedTrain;
136133
}
137134

138-
public override TPredictor Train(TrainContext context)
135+
protected override TPredictor TrainModelCore(TrainContext context)
139136
{
140137
Host.CheckValue(context, nameof(context));
141138
TPredictor pred;
@@ -155,25 +152,64 @@ public override TPredictor Train(TrainContext context)
155152

156153
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
157154

158-
public SymSgdClassificationTrainer(IHostEnvironment env, Arguments args)
159-
: base(env, LoadNameValue)
155+
/// <summary>
156+
/// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
157+
/// </summary>
158+
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
159+
/// <param name="labelColumn">The name of the label column.</param>
160+
/// <param name="featureColumn">The name of the feature column.</param>
161+
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
162+
public SymSgdClassificationTrainer(IHostEnvironment env, string featureColumn, string labelColumn, Action<Arguments> advancedSettings = null)
163+
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn),
164+
TrainerUtils.MakeBoolScalarLabel(labelColumn))
165+
{
166+
_args = new Arguments();
167+
168+
// Apply the advanced args, if the user supplied any.
169+
_args.Check(Host);
170+
advancedSettings?.Invoke(_args);
171+
_args.FeatureColumn = featureColumn;
172+
_args.LabelColumn = labelColumn;
173+
174+
Info = new TrainerInfo();
175+
}
176+
177+
/// <summary>
178+
/// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
179+
/// </summary>
180+
internal SymSgdClassificationTrainer(IHostEnvironment env, Arguments args)
181+
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
182+
TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
160183
{
161184
args.Check(Host);
162185
_args = args;
163186
Info = new TrainerInfo();
164187
}
165188

166-
private TPredictor CreatePredictor(VBuffer<Float> weights, Float bias)
189+
private TPredictor CreatePredictor(VBuffer<float> weights, float bias)
167190
{
168191
Host.CheckParam(weights.Length > 0, nameof(weights));
169192

170-
VBuffer<Float> maybeSparseWeights = default;
193+
VBuffer<float> maybeSparseWeights = default;
171194
VBufferUtils.CreateMaybeSparseCopy(ref weights, ref maybeSparseWeights,
172-
Conversions.Instance.GetIsDefaultPredicate<Float>(NumberType.Float));
195+
Conversions.Instance.GetIsDefaultPredicate<float>(NumberType.R4));
173196
var predictor = new LinearBinaryPredictor(Host, ref maybeSparseWeights, bias);
174197
return new ParameterMixingCalibratedPredictor(Host, predictor, new PlattCalibrator(Host, -1, 0));
175198
}
176199

200+
protected override BinaryPredictionTransformer<TPredictor> MakeTransformer(TPredictor model, ISchema trainSchema)
201+
=> new BinaryPredictionTransformer<TPredictor>(Host, model, trainSchema, FeatureColumn.Name);
202+
203+
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
204+
{
205+
return new[]
206+
{
207+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
208+
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))),
209+
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
210+
};
211+
}
212+
177213
[TlcModule.EntryPoint(Name = "Trainers.SymSgdBinaryClassifier",
178214
Desc = "Train a symbolic SGD.",
179215
UserName = SymSgdClassificationTrainer.UserNameValue,
@@ -496,9 +532,9 @@ public void LoadAsMuchAsPossible()
496532
continue;
497533
}
498534

499-
// We assume that cursor.Features.values are represented by Float and cursor.Features.indices are represented by int
535+
// We assume that cursor.Features.values are represented by float and cursor.Features.indices are represented by int
500536
// We conservatively assume that an instance is sparse and therefore, it has an array of Floats and ints for values and indices
501-
int perNonZeroInBytes = sizeof(Float) + sizeof(int);
537+
int perNonZeroInBytes = sizeof(float) + sizeof(int);
502538
if (featureCount > _trainer.AcceleratedMemoryBudgetBytes / perNonZeroInBytes)
503539
{
504540
// Hopefully this never happens. But the memorySize must >= perNonZeroInBytes * length(the longest instance).
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.RunTests;
8+
using Microsoft.ML.Runtime.SymSgd;
9+
using Xunit;
10+
11+
namespace Microsoft.ML.Tests.TrainerEstimators
12+
{
13+
public partial class TrainerEstimators
14+
{
15+
[Fact]
16+
public void TestEstimatorSymSgdClassificationTrainer()
17+
{
18+
(var pipe, var dataView) = GetBinaryClassificationPipeline();
19+
pipe.Append(new SymSgdClassificationTrainer(Env, "Features", "Label"));
20+
TestEstimatorCore(pipe, dataView);
21+
Done();
22+
}
23+
}
24+
}

0 commit comments

Comments
 (0)