Skip to content

Commit 6697257

Browse files
committed
conversion of symsgdclassificationtrainer, added new estimator test, still need to pass the weight column
1 parent 5e8fcdb commit 6697257

File tree

2 files changed

+71
-16
lines changed

2 files changed

+71
-16
lines changed

src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Collections.Generic;
77
using System.Runtime.InteropServices;
88
using System.Security;
9+
using Microsoft.ML.Core.Data;
910
using Microsoft.ML.Runtime;
1011
using Microsoft.ML.Runtime.CommandLine;
1112
using Microsoft.ML.Runtime.Data;
@@ -131,7 +132,7 @@ private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedDa
131132
return examplesToFeedTrain;
132133
}
133134

134-
public override TPredictor Train(TrainContext context)
135+
protected override TPredictor TrainModelCore(TrainContext context)
135136
{
136137
Host.CheckValue(context, nameof(context));
137138
TPredictor pred;
@@ -161,29 +162,33 @@ public override TPredictor Train(TrainContext context)
161162
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
162163
public SymSgdClassificationTrainer(IHostEnvironment env, string featureColumn, string labelColumn,
163164
string weightColumn = null, Action<Arguments> advancedSettings = null)
164-
: this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings))
165+
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn),
166+
TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
165167
{
166-
}
168+
var args = new Arguments();
167169

168-
public SymSgdClassificationTrainer(IHostEnvironment env, Arguments args)
169-
: base(env, LoadNameValue)
170-
{
170+
//apply the advanced args, if the user supplied any
171171
args.Check(Host);
172+
advancedSettings?.Invoke(args);
173+
args.FeatureColumn = featureColumn;
174+
args.LabelColumn = labelColumn;
175+
// TODO:
176+
//args.WeightColumn = weightColumn;
177+
172178
_args = args;
173179
Info = new TrainerInfo();
174180
}
175181

176-
private static Arguments ArgsInit(string featureColumn, string labelColumn,
177-
string weightColumn, Action<Arguments> advancedSettings)
182+
/// <summary>
183+
/// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
184+
/// </summary>
185+
internal SymSgdClassificationTrainer(IHostEnvironment env, Arguments args)
186+
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
187+
TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(null))
178188
{
179-
var args = new Arguments();
180-
181-
//apply the advanced args, if the user supplied any
182-
advancedSettings?.Invoke(args);
183-
args.FeatureColumn = featureColumn;
184-
args.LabelColumn = labelColumn;
185-
args.WeightColumn = weightColumn;
186-
return args;
189+
args.Check(Host);
190+
_args = args;
191+
Info = new TrainerInfo();
187192
}
188193

189194
private TPredictor CreatePredictor(VBuffer<float> weights, float bias)
@@ -197,6 +202,19 @@ private TPredictor CreatePredictor(VBuffer<float> weights, float bias)
197202
return new ParameterMixingCalibratedPredictor(Host, predictor, new PlattCalibrator(Host, -1, 0));
198203
}
199204

205+
protected override BinaryPredictionTransformer<TPredictor> MakeTransformer(TPredictor model, ISchema trainSchema)
206+
=> new BinaryPredictionTransformer<TPredictor>(Host, model, trainSchema, FeatureColumn.Name);
207+
208+
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
209+
{
210+
return new[]
211+
{
212+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
213+
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))),
214+
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
215+
};
216+
}
217+
200218
[TlcModule.EntryPoint(Name = "Trainers.SymSgdBinaryClassifier",
201219
Desc = "Train a symbolic SGD.",
202220
UserName = SymSgdClassificationTrainer.UserNameValue,
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Data;
6+
using Microsoft.ML.Runtime.RunTests;
7+
using Microsoft.ML.Runtime.SymSgd;
8+
using Xunit;
9+
10+
namespace Microsoft.ML.Tests.TrainerEstimators
11+
{
12+
public partial class TrainerEstimators
13+
{
14+
private IDataView GetBreastCancerDataview()
15+
{
16+
return new TextLoader(Env,
17+
new TextLoader.Arguments()
18+
{
19+
HasHeader = true,
20+
Column = new[]
21+
{
22+
new TextLoader.Column("Label", DataKind.R4, 0),
23+
new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(1, 9) } )
24+
}
25+
}).Read(new MultiFileSource(GetDataPath(TestDatasets.breastCancer.trainFilename)));
26+
}
27+
28+
[Fact]
29+
public void TestEstimatorSymSgdClassificationTrainer()
30+
{
31+
var dataView = GetBreastCancerDataview();
32+
var pipe = new SymSgdClassificationTrainer(Env, "Features", "Label");
33+
TestEstimatorCore(pipe, dataView);
34+
Done();
35+
}
36+
}
37+
}

0 commit comments

Comments
 (0)