Skip to content

Commit 7d6f15f

Browse files
committed
Adding tests, and re-establishing the role-mapped data before calibration
1 parent adcb4d2 commit 7d6f15f

File tree

5 files changed

+144
-54
lines changed

5 files changed

+144
-54
lines changed

src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515

1616
namespace Microsoft.ML.Runtime.Learners
1717
{
18-
// using TScalarTrainer = ITrainer<IPredictorProducing<Float>>;
19-
using TScalarTrainer = ITrainerEstimator<IPredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
20-
// using TScalarTrainer = IEstimator<IPredictionTransformer<IPredictorProducing<float>>>;
18+
using TScalarTrainer = ITrainerEstimator<IPredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
2119

2220
public abstract class MetaMulticlassTrainer<TTransformer, TModel> : ITrainerEstimator<TTransformer, TModel>, ITrainer<TModel>
2321
where TTransformer : IPredictionTransformer<TModel>
@@ -81,8 +79,7 @@ internal MetaMulticlassTrainer(IHostEnvironment env, ArgumentsBase args, string
8179
// Create the first trainer so errors in the args surface early.
8280
_trainer = singleEstimator?? CreateTrainer();
8381

84-
if (calibrator != null)
85-
Calibrator = calibrator;
82+
Calibrator = calibrator?? null;
8683

8784
if (args.Calibrator != null)
8885
Calibrator = args.Calibrator.CreateComponent(Host);
@@ -105,27 +102,26 @@ private TScalarTrainer CreateTrainer()
105102
new LinearSvm(Host, new LinearSvm.Arguments());
106103
}
107104

108-
protected IDataView MapLabelsCore<T>(ColumnType type, RefPredicate<T> equalsTarget, RoleMappedData data, string dstName)
105+
protected IDataView MapLabelsCore<T>(ColumnType type, RefPredicate<T> equalsTarget, RoleMappedData data)
109106
{
110107
Host.AssertValue(type);
111108
Host.Assert(type.RawType == typeof(T));
112109
Host.AssertValue(equalsTarget);
113110
Host.AssertValue(data);
114111
Host.AssertValue(data.Schema.Label);
115-
Host.AssertNonWhiteSpace(dstName);
116112

117113
var lab = data.Schema.Label;
118114

119115
RefPredicate<T> isMissing;
120116
if (!Args.ImputeMissingLabelsAsNegative && Conversions.Instance.TryGetIsNAPredicate(type, out isMissing))
121117
{
122118
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
123-
lab.Name, dstName, type, NumberType.Float,
119+
lab.Name, lab.Name, type, NumberType.Float,
124120
(ref T src, ref float dst) =>
125121
dst = equalsTarget(ref src) ? 1 : (isMissing(ref src) ? float.NaN : default(float)));
126122
}
127123
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
128-
lab.Name, dstName, type, NumberType.Float,
124+
lab.Name, lab.Name, type, NumberType.Float,
129125
(ref T src, ref float dst) =>
130126
dst = equalsTarget(ref src) ? 1 : default(float));
131127
}

src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
namespace Microsoft.ML.Runtime.Learners
3737
{
3838
using TScalarPredictor = IPredictorProducing<Float>;
39-
// using TScalarTrainer = IEstimator<IPredictionTransformer<IPredictor>>;
4039
using TScalarTrainer = ITrainerEstimator<IPredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
4140
using CR = RoleMappedSchema.ColumnRole;
4241

@@ -113,25 +112,24 @@ protected override OvaPredictor TrainCore(IChannel ch, RoleMappedData data, int
113112

114113
private IPredictionTransformer<TScalarPredictor> TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls)
115114
{
116-
string dstName;
117-
var view = MapLabels(data, cls, out dstName);
115+
var view = MapLabels(data, cls);
118116

119117
string trainerLabel = data.Schema.Label.Name;
120118

121-
//copy the newly created column into what the learner knows as the label column
122-
var trainerData = new CopyColumnsTransform(Host, (dstName, trainerLabel)).Transform(view);
123-
124119
// REVIEW: In principle we could support validation sets and the like via the train context, but
125120
// this is currently unsupported.
126-
var transformer = trainer.Fit(trainerData);
121+
var transformer = trainer.Fit(view);
127122

128123
if (_args.UseProbabilities)
129124
{
130125
var calibratedModel = transformer.Model as TScalarPredictor;
131126

127+
// the validations in the calibrator check for the feature column, in the RoleMappedData
128+
var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn);
129+
132130
if (calibratedModel == null)
133131
// calibratedModel = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, _args.MaxCalibrationExamples, trainer, transformer.Model, data) as TScalarPredictor;
134-
calibratedModel = CalibratorUtils.TrainCalibrator(Host, ch, Calibrator, Args.MaxCalibrationExamples, transformer.Model, data) as TScalarPredictor;
132+
calibratedModel = CalibratorUtils.TrainCalibrator(Host, ch, Calibrator, Args.MaxCalibrationExamples, transformer.Model, trainedData) as TScalarPredictor;
135133

136134
Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface");
137135
return new BinaryPredictionTransformer<TScalarPredictor>(Host, calibratedModel, data.Data.Schema, transformer.FeatureColumn);
@@ -140,30 +138,30 @@ private IPredictionTransformer<TScalarPredictor> TrainOne(IChannel ch, TScalarTr
140138
return new BinaryPredictionTransformer<TScalarPredictor>(Host, transformer.Model, data.Data.Schema, transformer.FeatureColumn);
141139
}
142140

143-
private IDataView MapLabels(RoleMappedData data, int cls, out string dstName)
141+
private IDataView MapLabels(RoleMappedData data, int cls)
144142
{
145143
var lab = data.Schema.Label;
146144
Host.Assert(!data.Schema.Schema.IsHidden(lab.Index));
147145
Host.Assert(lab.Type.KeyCount > 0 || lab.Type == NumberType.R4 || lab.Type == NumberType.R8);
148146

149147
// Get the destination label column name.
150-
dstName = data.Schema.Schema.GetTempColumnName();
148+
//dstName = data.Schema.Schema.GetTempColumnName();
151149

152150
if (lab.Type.KeyCount > 0)
153151
{
154152
// Key values are 1-based.
155153
uint key = (uint)(cls + 1);
156-
return MapLabelsCore(NumberType.U4, (ref uint val) => key == val, data, dstName);
154+
return MapLabelsCore(NumberType.U4, (ref uint val) => key == val, data);
157155
}
158156
if (lab.Type == NumberType.R4)
159157
{
160158
Float key = cls;
161-
return MapLabelsCore(NumberType.R4, (ref float val) => key == val, data, dstName);
159+
return MapLabelsCore(NumberType.R4, (ref float val) => key == val, data);
162160
}
163161
if (lab.Type == NumberType.R8)
164162
{
165163
Double key = cls;
166-
return MapLabelsCore(NumberType.R8, (ref double val) => key == val, data, dstName);
164+
return MapLabelsCore(NumberType.R8, (ref double val) => key == val, data);
167165
}
168166

169167
throw Host.ExceptNotSupp($"Label column type is not supported by OVA: {lab.Type}");

src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ namespace Microsoft.ML.Runtime.Learners
2929
{
3030

3131
using TDistPredictor = IDistPredictorProducing<float, float>;
32-
// using TScalarTrainer = IEstimator<IPredictionTransformer<IPredictorProducing<float>>>;
3332
using TScalarTrainer = ITrainerEstimator<IPredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
3433
using CR = RoleMappedSchema.ColumnRole;
3534
using TTransformer = MulticlassPredictionTransformer<PkpdPredictor>;
@@ -131,52 +130,50 @@ protected override PkpdPredictor TrainCore(IChannel ch, RoleMappedData data, int
131130

132131
private IPredictionTransformer<TDistPredictor> TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2)
133132
{
134-
string dstName;
135-
var view = MapLabels(data, cls1, cls2, out dstName);
136-
137133
// this should not be necessary when the legacy constructor doesn't exist, and the label column is not an optional parameter on the
138134
// MetaMulticlassTrainer constructor.
139135
string trainerLabel = data.Schema.Label.Name;
140136

141-
//copy the newly created column into what the learner knows as the label column
142-
var trainerData = new CopyColumnsTransform(Host, (dstName, trainerLabel)).Transform(view);
143-
137+
var view = MapLabels(data, cls1, cls2);
144138
var transformer = trainer.Fit(view);
145139

140+
// the validations in the calibrator check for the feature column, in the RoleMappedData
141+
var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn);
142+
146143
var calibratedModel = transformer.Model as TDistPredictor;
147144
if (calibratedModel == null)
148-
calibratedModel = CalibratorUtils.TrainCalibrator(Host, ch, Calibrator, Args.MaxCalibrationExamples, transformer.Model, data) as TDistPredictor;
145+
calibratedModel = CalibratorUtils.TrainCalibrator(Host, ch, Calibrator, Args.MaxCalibrationExamples, transformer.Model, trainedData) as TDistPredictor;
149146

150147
return new BinaryPredictionTransformer<TDistPredictor>(Host, calibratedModel, data.Data.Schema, transformer.FeatureColumn);
151148
}
152149

153-
private IDataView MapLabels(RoleMappedData data, int cls1, int cls2, out string dstName)
150+
private IDataView MapLabels(RoleMappedData data, int cls1, int cls2)
154151
{
155152
var lab = data.Schema.Label;
156153
Host.Assert(!data.Schema.Schema.IsHidden(lab.Index));
157154
Host.Assert(lab.Type.KeyCount > 0 || lab.Type == NumberType.R4 || lab.Type == NumberType.R8);
158155

159156
// Get the destination label column name.
160-
dstName = data.Schema.Schema.GetTempColumnName();
157+
//dstName = data.Schema.Schema.GetTempColumnName();
161158

162159
if (lab.Type.KeyCount > 0)
163160
{
164161
// Key values are 1-based.
165162
uint key1 = (uint)(cls1 + 1);
166163
uint key2 = (uint)(cls2 + 1);
167-
return MapLabelsCore(NumberType.U4, (ref uint val) => val == key1 || val == key2, data, dstName);
164+
return MapLabelsCore(NumberType.U4, (ref uint val) => val == key1 || val == key2, data);
168165
}
169166
if (lab.Type == NumberType.R4)
170167
{
171168
float key1 = cls1;
172169
float key2 = cls2;
173-
return MapLabelsCore(NumberType.R4, (ref float val) => val == key1 || val == key2, data, dstName);
170+
return MapLabelsCore(NumberType.R4, (ref float val) => val == key1 || val == key2, data);
174171
}
175172
if (lab.Type == NumberType.R8)
176173
{
177174
double key1 = cls1;
178175
double key2 = cls2;
179-
return MapLabelsCore(NumberType.R8, (ref double val) => val == key1 || val == key2, data, dstName);
176+
return MapLabelsCore(NumberType.R8, (ref double val) => val == key1 || val == key2, data);
180177
}
181178

182179
throw Host.ExceptNotSupp($"Label column type is not supported by PKPD: {lab.Type}");

src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Float = System.Single;
6-
75
using System;
86
using Microsoft.ML.Runtime;
97
using Microsoft.ML.Runtime.CommandLine;
@@ -49,7 +47,7 @@ public sealed class Arguments : OnlineLinearArguments
4947
[Argument(ArgumentType.AtMostOnce, HelpText = "Regularizer constant", ShortName = "lambda", SortOrder = 50)]
5048
[TGUI(SuggestedSweeps = "0.00001-0.1;log;inc:10")]
5149
[TlcModule.SweepableFloatParamAttribute("Lambda", 0.00001f, 0.1f, 10, isLogScale:true)]
52-
public Float Lambda = (Float)0.001;
50+
public float Lambda = (float)0.001;
5351

5452
[Argument(ArgumentType.AtMostOnce, HelpText = "Batch size", ShortName = "batch", SortOrder = 190)]
5553
[TGUI(Label = "Batch Size")]
@@ -79,9 +77,9 @@ public sealed class Arguments : OnlineLinearArguments
7977
// weightsUpdate/weightsUpdateScale/biasUpdate are similar to weights/weightsScale/bias, in that
8078
// all elements of weightsUpdate are considered to be multiplied by weightsUpdateScale, and the
8179
// bias update term is not considered to be multiplied by the scale.
82-
private VBuffer<Float> _weightsUpdate;
83-
private Float _weightsUpdateScale;
84-
private Float _biasUpdate;
80+
private VBuffer<float> _weightsUpdate;
81+
private float _weightsUpdateScale;
82+
private float _biasUpdate;
8583

8684
protected override bool NeedCalibration => true;
8785

@@ -114,7 +112,7 @@ protected override void CheckLabel(RoleMappedData data)
114112
/// <summary>
115113
/// Return the raw margin from the decision hyperplane
116114
/// </summary>
117-
protected override Float Margin(ref VBuffer<Float> feat)
115+
protected override float Margin(ref VBuffer<float> feat)
118116
{
119117
return Bias + VectorUtils.DotProduct(ref feat, ref Weights) * WeightsScale;
120118
}
@@ -134,7 +132,7 @@ protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor p
134132
if (predictor == null)
135133
VBufferUtils.Densify(ref Weights);
136134

137-
_weightsUpdate = VBufferUtils.CreateEmpty<Float>(numFeatures);
135+
_weightsUpdate = VBufferUtils.CreateEmpty<float>(numFeatures);
138136
}
139137

140138
protected override void BeginIteration(IChannel ch)
@@ -148,10 +146,10 @@ private void BeginBatch()
148146
_batch++;
149147
_numBatchExamples = 0;
150148
_biasUpdate = 0;
151-
_weightsUpdate = new VBuffer<Float>(_weightsUpdate.Length, 0, _weightsUpdate.Values, _weightsUpdate.Indices);
149+
_weightsUpdate = new VBuffer<float>(_weightsUpdate.Length, 0, _weightsUpdate.Values, _weightsUpdate.Indices);
152150
}
153151

154-
private void FinishBatch(ref VBuffer<Float> weightsUpdate, Float weightsUpdateScale)
152+
private void FinishBatch(ref VBuffer<float> weightsUpdate, float weightsUpdateScale)
155153
{
156154
if (_numBatchExamples > 0)
157155
UpdateWeights(ref weightsUpdate, weightsUpdateScale);
@@ -161,19 +159,19 @@ private void FinishBatch(ref VBuffer<Float> weightsUpdate, Float weightsUpdateSc
161159
/// <summary>
162160
/// Observe an example and update weights if necessary
163161
/// </summary>
164-
protected override void ProcessDataInstance(IChannel ch, ref VBuffer<Float> feat, Float label, Float weight)
162+
protected override void ProcessDataInstance(IChannel ch, ref VBuffer<float> feat, float label, float weight)
165163
{
166164
base.ProcessDataInstance(ch, ref feat, label, weight);
167165

168166
// compute the update and update if needed
169-
Float output = Margin(ref feat);
170-
Float trueOutput = (label > 0 ? 1 : -1);
171-
Float loss = output * trueOutput - 1;
167+
float output = Margin(ref feat);
168+
float trueOutput = (label > 0 ? 1 : -1);
169+
float loss = output * trueOutput - 1;
172170

173171
// Accumulate the update if there is a loss and we have larger batches.
174172
if (Args.BatchSize > 1 && loss < 0)
175173
{
176-
Float currentBiasUpdate = trueOutput * weight;
174+
float currentBiasUpdate = trueOutput * weight;
177175
_biasUpdate += currentBiasUpdate;
178176
// Only aggregate in the case where we're handling multiple instances.
179177
if (_weightsUpdate.Count == 0)
@@ -192,7 +190,7 @@ protected override void ProcessDataInstance(IChannel ch, ref VBuffer<Float> feat
192190
Contracts.Assert(_weightsUpdate.Count == 0);
193191
// If we aren't aggregating multiple instances, just use the instance's
194192
// vector directly.
195-
Float currentBiasUpdate = trueOutput * weight;
193+
float currentBiasUpdate = trueOutput * weight;
196194
_biasUpdate += currentBiasUpdate;
197195
FinishBatch(ref feat, currentBiasUpdate);
198196
}
@@ -206,13 +204,13 @@ protected override void ProcessDataInstance(IChannel ch, ref VBuffer<Float> feat
206204
/// Updates the weights at the end of the batch. Since weightsUpdate can be an instance
207205
/// feature vector, this function should not change the contents of weightsUpdate.
208206
/// </summary>
209-
private void UpdateWeights(ref VBuffer<Float> weightsUpdate, Float weightsUpdateScale)
207+
private void UpdateWeights(ref VBuffer<float> weightsUpdate, float weightsUpdateScale)
210208
{
211209
Contracts.Assert(_batch > 0);
212210

213211
// REVIEW: This is really odd - normally lambda is small, so the learning rate is initially huge!?!?!
214212
// Changed from the paper's recommended rate = 1 / (lambda * t) to rate = 1 / (1 + lambda * t).
215-
Float rate = 1 / (1 + Args.Lambda * _batch);
213+
float rate = 1 / (1 + Args.Lambda * _batch);
216214

217215
// w_{t+1/2} = (1 - eta*lambda) w_t + eta/k * totalUpdate
218216
WeightsScale *= 1 - rate * Args.Lambda;
@@ -226,7 +224,7 @@ private void UpdateWeights(ref VBuffer<Float> weightsUpdate, Float weightsUpdate
226224
// w_{t+1} = min{1, 1/sqrt(lambda)/|w_{t+1/2}|} * w_{t+1/2}
227225
if (Args.PerformProjection)
228226
{
229-
Float normalizer = 1 / (MathUtils.Sqrt(Args.Lambda) * VectorUtils.Norm(Weights) * Math.Abs(WeightsScale));
227+
float normalizer = 1 / (MathUtils.Sqrt(Args.Lambda) * VectorUtils.Norm(Weights) * Math.Abs(WeightsScale));
230228
if (normalizer < 1)
231229
{
232230
// REVIEW: Why would we not scale _bias if we're scaling the weights?

0 commit comments

Comments
 (0)