6
6
using System . Collections . Generic ;
7
7
using System . Runtime . InteropServices ;
8
8
using System . Security ;
9
+ using Microsoft . ML . Core . Data ;
9
10
using Microsoft . ML . Runtime ;
10
11
using Microsoft . ML . Runtime . CommandLine ;
11
12
using Microsoft . ML . Runtime . Data ;
@@ -131,7 +132,7 @@ private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedDa
131
132
return examplesToFeedTrain ;
132
133
}
133
134
134
- public override TPredictor Train ( TrainContext context )
135
+ protected override TPredictor TrainModelCore ( TrainContext context )
135
136
{
136
137
Host . CheckValue ( context , nameof ( context ) ) ;
137
138
TPredictor pred ;
@@ -161,29 +162,33 @@ public override TPredictor Train(TrainContext context)
161
162
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
162
163
public SymSgdClassificationTrainer ( IHostEnvironment env , string featureColumn , string labelColumn ,
163
164
string weightColumn = null , Action < Arguments > advancedSettings = null )
164
- : this ( env , ArgsInit ( featureColumn , labelColumn , weightColumn , advancedSettings ) )
165
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( LoadNameValue ) , TrainerUtils . MakeR4VecFeature ( featureColumn ) ,
166
+ TrainerUtils . MakeR4ScalarLabel ( labelColumn ) , TrainerUtils . MakeR4ScalarWeightColumn ( weightColumn ) )
165
167
{
166
- }
168
+ var args = new Arguments ( ) ;
167
169
168
- public SymSgdClassificationTrainer ( IHostEnvironment env , Arguments args )
169
- : base ( env , LoadNameValue )
170
- {
170
+ //apply the advanced args, if the user supplied any
171
171
args . Check ( Host ) ;
172
+ advancedSettings ? . Invoke ( args ) ;
173
+ args . FeatureColumn = featureColumn ;
174
+ args . LabelColumn = labelColumn ;
175
+ // TODO:
176
+ //args.WeightColumn = weightColumn;
177
+
172
178
_args = args ;
173
179
Info = new TrainerInfo ( ) ;
174
180
}
175
181
176
- private static Arguments ArgsInit ( string featureColumn , string labelColumn ,
177
- string weightColumn , Action < Arguments > advancedSettings )
182
+ /// <summary>
183
+ /// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
184
+ /// </summary>
185
+ internal SymSgdClassificationTrainer ( IHostEnvironment env , Arguments args )
186
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( LoadNameValue ) , TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) ,
187
+ TrainerUtils . MakeR4ScalarLabel ( args . LabelColumn ) , TrainerUtils . MakeR4ScalarWeightColumn ( null ) )
178
188
{
179
- var args = new Arguments ( ) ;
180
-
181
- //apply the advanced args, if the user supplied any
182
- advancedSettings ? . Invoke ( args ) ;
183
- args . FeatureColumn = featureColumn ;
184
- args . LabelColumn = labelColumn ;
185
- args . WeightColumn = weightColumn ;
186
- return args ;
189
+ args . Check ( Host ) ;
190
+ _args = args ;
191
+ Info = new TrainerInfo ( ) ;
187
192
}
188
193
189
194
private TPredictor CreatePredictor ( VBuffer < float > weights , float bias )
@@ -197,6 +202,19 @@ private TPredictor CreatePredictor(VBuffer<float> weights, float bias)
197
202
return new ParameterMixingCalibratedPredictor ( Host , predictor , new PlattCalibrator ( Host , - 1 , 0 ) ) ;
198
203
}
199
204
205
+ protected override BinaryPredictionTransformer < TPredictor > MakeTransformer ( TPredictor model , ISchema trainSchema )
206
+ => new BinaryPredictionTransformer < TPredictor > ( Host , model , trainSchema , FeatureColumn . Name ) ;
207
+
208
+ protected override SchemaShape . Column [ ] GetOutputColumnsCore ( SchemaShape inputSchema )
209
+ {
210
+ return new [ ]
211
+ {
212
+ new SchemaShape . Column ( DefaultColumnNames . Score , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( ) ) ) ,
213
+ new SchemaShape . Column ( DefaultColumnNames . Probability , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( true ) ) ) ,
214
+ new SchemaShape . Column ( DefaultColumnNames . PredictedLabel , SchemaShape . Column . VectorKind . Scalar , BoolType . Instance , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( ) ) )
215
+ } ;
216
+ }
217
+
200
218
[ TlcModule . EntryPoint ( Name = "Trainers.SymSgdBinaryClassifier" ,
201
219
Desc = "Train a symbolic SGD." ,
202
220
UserName = SymSgdClassificationTrainer . UserNameValue ,
0 commit comments