2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
- using Float = System . Single ;
6
-
7
5
using System ;
8
6
using System . Collections . Generic ;
9
7
using System . Runtime . InteropServices ;
10
8
using System . Security ;
9
+ using Microsoft . ML . Core . Data ;
11
10
using Microsoft . ML . Runtime ;
12
11
using Microsoft . ML . Runtime . CommandLine ;
13
12
using Microsoft . ML . Runtime . Data ;
26
25
SymSgdClassificationTrainer . LoadNameValue ,
27
26
SymSgdClassificationTrainer . ShortName ) ]
28
27
29
- [ assembly: LoadableClass ( typeof ( void ) , typeof ( SymSgdClassificationTrainer ) , null , typeof ( SignatureEntryPointModule ) , "SymSGD" ) ]
28
+ [ assembly: LoadableClass ( typeof ( void ) , typeof ( SymSgdClassificationTrainer ) , null , typeof ( SignatureEntryPointModule ) , SymSgdClassificationTrainer . LoadNameValue ) ]
30
29
31
30
namespace Microsoft . ML . Runtime . SymSgd
32
31
{
33
- using TPredictor = IPredictorWithFeatureWeights < Float > ;
32
+ using TPredictor = IPredictorWithFeatureWeights < float > ;
34
33
35
34
/// <include file='doc.xml' path='doc/members/member[@name="SymSGD"]/*' />
36
- public sealed class SymSgdClassificationTrainer :
37
- TrainerBase < TPredictor > ,
38
- ITrainer < TPredictor >
35
+ public sealed class SymSgdClassificationTrainer : TrainerEstimatorBase < BinaryPredictionTransformer < TPredictor > , TPredictor >
39
36
{
40
- public const string LoadNameValue = "SymbolicSGD" ;
41
- public const string UserNameValue = "Symbolic SGD (binary)" ;
42
- public const string ShortName = "SymSGD" ;
37
+ internal const string LoadNameValue = "SymbolicSGD" ;
38
+ internal const string UserNameValue = "Symbolic SGD (binary)" ;
39
+ internal const string ShortName = "SymSGD" ;
43
40
44
41
public sealed class Arguments : LearnerInputBaseWithLabel
45
42
{
@@ -78,7 +75,7 @@ public sealed class Arguments : LearnerInputBaseWithLabel
78
75
public bool Shuffle = true ;
79
76
80
77
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Apply weight to the positive class, for imbalanced data" , ShortName = "piw" ) ]
81
- public Float PositiveInstanceWeight = 1 ;
78
+ public float PositiveInstanceWeight = 1 ;
82
79
83
80
public void Check ( IExceptionContext ectx )
84
81
{
@@ -135,7 +132,7 @@ private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedDa
135
132
return examplesToFeedTrain ;
136
133
}
137
134
138
- public override TPredictor Train ( TrainContext context )
135
+ protected override TPredictor TrainModelCore ( TrainContext context )
139
136
{
140
137
Host . CheckValue ( context , nameof ( context ) ) ;
141
138
TPredictor pred ;
@@ -155,25 +152,64 @@ public override TPredictor Train(TrainContext context)
155
152
156
153
public override PredictionKind PredictionKind => PredictionKind . BinaryClassification ;
157
154
158
- public SymSgdClassificationTrainer ( IHostEnvironment env , Arguments args )
159
- : base ( env , LoadNameValue )
155
+ /// <summary>
156
+ /// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
157
+ /// </summary>
158
+ /// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
159
+ /// <param name="labelColumn">The name of the label column.</param>
160
+ /// <param name="featureColumn">The name of the feature column.</param>
161
+ /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
162
+ public SymSgdClassificationTrainer ( IHostEnvironment env , string featureColumn , string labelColumn , Action < Arguments > advancedSettings = null )
163
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( LoadNameValue ) , TrainerUtils . MakeR4VecFeature ( featureColumn ) ,
164
+ TrainerUtils . MakeBoolScalarLabel ( labelColumn ) )
165
+ {
166
+ _args = new Arguments ( ) ;
167
+
168
+ // Apply the advanced args, if the user supplied any.
169
+ _args . Check ( Host ) ;
170
+ advancedSettings ? . Invoke ( _args ) ;
171
+ _args . FeatureColumn = featureColumn ;
172
+ _args . LabelColumn = labelColumn ;
173
+
174
+ Info = new TrainerInfo ( ) ;
175
+ }
176
+
177
+ /// <summary>
178
+ /// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
179
+ /// </summary>
180
+ internal SymSgdClassificationTrainer ( IHostEnvironment env , Arguments args )
181
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( LoadNameValue ) , TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) ,
182
+ TrainerUtils . MakeBoolScalarLabel ( args . LabelColumn ) )
160
183
{
161
184
args . Check ( Host ) ;
162
185
_args = args ;
163
186
Info = new TrainerInfo ( ) ;
164
187
}
165
188
166
- private TPredictor CreatePredictor ( VBuffer < Float > weights , Float bias )
189
+ private TPredictor CreatePredictor ( VBuffer < float > weights , float bias )
167
190
{
168
191
Host . CheckParam ( weights . Length > 0 , nameof ( weights ) ) ;
169
192
170
- VBuffer < Float > maybeSparseWeights = default ;
193
+ VBuffer < float > maybeSparseWeights = default ;
171
194
VBufferUtils . CreateMaybeSparseCopy ( ref weights , ref maybeSparseWeights ,
172
- Conversions . Instance . GetIsDefaultPredicate < Float > ( NumberType . Float ) ) ;
195
+ Conversions . Instance . GetIsDefaultPredicate < float > ( NumberType . R4 ) ) ;
173
196
var predictor = new LinearBinaryPredictor ( Host , ref maybeSparseWeights , bias ) ;
174
197
return new ParameterMixingCalibratedPredictor ( Host , predictor , new PlattCalibrator ( Host , - 1 , 0 ) ) ;
175
198
}
176
199
200
+ protected override BinaryPredictionTransformer < TPredictor > MakeTransformer ( TPredictor model , ISchema trainSchema )
201
+ => new BinaryPredictionTransformer < TPredictor > ( Host , model , trainSchema , FeatureColumn . Name ) ;
202
+
203
+ protected override SchemaShape . Column [ ] GetOutputColumnsCore ( SchemaShape inputSchema )
204
+ {
205
+ return new [ ]
206
+ {
207
+ new SchemaShape . Column ( DefaultColumnNames . Score , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( ) ) ) ,
208
+ new SchemaShape . Column ( DefaultColumnNames . Probability , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( true ) ) ) ,
209
+ new SchemaShape . Column ( DefaultColumnNames . PredictedLabel , SchemaShape . Column . VectorKind . Scalar , BoolType . Instance , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( ) ) )
210
+ } ;
211
+ }
212
+
177
213
[ TlcModule . EntryPoint ( Name = "Trainers.SymSgdBinaryClassifier" ,
178
214
Desc = "Train a symbolic SGD." ,
179
215
UserName = SymSgdClassificationTrainer . UserNameValue ,
@@ -496,9 +532,9 @@ public void LoadAsMuchAsPossible()
496
532
continue ;
497
533
}
498
534
499
- // We assume that cursor.Features.values are represented by Float and cursor.Features.indices are represented by int
535
+ // We assume that cursor.Features.values are represented by float and cursor.Features.indices are represented by int
500
536
// We conservatively assume that an instance is sparse and therefore, it has an array of Floats and ints for values and indices
501
- int perNonZeroInBytes = sizeof ( Float ) + sizeof ( int ) ;
537
+ int perNonZeroInBytes = sizeof ( float ) + sizeof ( int ) ;
502
538
if ( featureCount > _trainer . AcceleratedMemoryBudgetBytes / perNonZeroInBytes )
503
539
{
504
540
// Hopefully this never happens. But the memorySize must >= perNonZeroInBytes * length(the longest instance).
0 commit comments