Skip to content

Commit 1c751fa

Browse files
committed
Rename several maximum entropy models and trainers
1 parent b1f044e commit 1c751fa

File tree

18 files changed

+131
-125
lines changed

18 files changed

+131
-125
lines changed

src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ public Arguments()
6464
// non-default column names. Unfortuantely no method of resolving this temporary strikes me as being any
6565
// less laborious than the proper fix, which is that this "meta" component should itself be a trainer
6666
// estimator, as opposed to a regular trainer.
67-
var trainerEstimator = new LogisticRegressionMulticlassClassificationTrainer(env, LabelColumnName, FeatureColumnName);
68-
return TrainerUtils.MapTrainerEstimatorToTrainer<LogisticRegressionMulticlassClassificationTrainer,
69-
MulticlassLogisticRegressionModelParameters, MulticlassLogisticRegressionModelParameters>(env, trainerEstimator);
67+
var trainerEstimator = new LbfgsMaximumEntropyTrainer(env, LabelColumnName, FeatureColumnName);
68+
return TrainerUtils.MapTrainerEstimatorToTrainer<LbfgsMaximumEntropyTrainer,
69+
MaximumEntropyModelParameters, MaximumEntropyModelParameters>(env, trainerEstimator);
7070
})
7171
};
7272
}

src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs

Lines changed: 68 additions & 62 deletions
Large diffs are not rendered by default.

src/Microsoft.ML.StandardTrainers/Standard/SdcaMulticlass.cs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -433,10 +433,10 @@ private protected override float GetInstanceWeight(FloatLabelCursor cursor)
433433

434434
/// <summary>
435435
/// The <see cref="IEstimator{TTransformer}"/> for training a maximum entropy classification model using the stochastic dual coordinate ascent method.
436-
/// The trained model <see cref="MulticlassLogisticRegressionModelParameters"/> produces probabilities of classes.
436+
/// The trained model <see cref="MaximumEntropyModelParameters"/> produces probabilities of classes.
437437
/// </summary>
438438
/// <include file='doc.xml' path='doc/members/member[@name="SDCA_remarks"]/*' />
439-
public sealed class SdcaMulticlassClassificationTrainer : SdcaMulticlassClassificationTrainerBase<MulticlassLogisticRegressionModelParameters>
439+
public sealed class SdcaMulticlassClassificationTrainer : SdcaMulticlassClassificationTrainerBase<MaximumEntropyModelParameters>
440440
{
441441
internal SdcaMulticlassClassificationTrainer(IHostEnvironment env,
442442
string labelColumn = DefaultColumnNames.Label,
@@ -462,28 +462,28 @@ internal SdcaMulticlassClassificationTrainer(IHostEnvironment env, Options optio
462462
{
463463
}
464464

465-
private protected override MulticlassLogisticRegressionModelParameters CreatePredictor(VBuffer<float>[] weights, float[] bias)
465+
private protected override MaximumEntropyModelParameters CreatePredictor(VBuffer<float>[] weights, float[] bias)
466466
{
467467
Host.CheckValue(weights, nameof(weights));
468468
Host.CheckValue(bias, nameof(bias));
469469
Host.CheckParam(weights.Length > 0, nameof(weights));
470470
Host.CheckParam(weights.Length == bias.Length, nameof(weights));
471471

472-
return new MulticlassLogisticRegressionModelParameters(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null);
472+
return new MaximumEntropyModelParameters(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null);
473473
}
474474

475-
private protected override MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters> MakeTransformer(
476-
MulticlassLogisticRegressionModelParameters model, DataViewSchema trainSchema) =>
477-
new MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
475+
private protected override MulticlassPredictionTransformer<MaximumEntropyModelParameters> MakeTransformer(
476+
MaximumEntropyModelParameters model, DataViewSchema trainSchema) =>
477+
new MulticlassPredictionTransformer<MaximumEntropyModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
478478
}
479479

480480
/// <summary>
481481
/// The <see cref="IEstimator{TTransformer}"/> for training a multiclass linear model using the stochastic dual coordinate ascent method.
482-
/// The trained model <see cref="MulticlassLinearModelParameters"/> does not produces probabilities of classes, but we can still make decisions
482+
/// The trained model <see cref="LinearMulticlassModelParameters"/> does not produces probabilities of classes, but we can still make decisions
483483
/// by choosing the class associated with the largest score.
484484
/// </summary>
485485
/// <include file='doc.xml' path='doc/members/member[@name="SDCA_remarks"]/*' />
486-
public sealed class SdcaNonCalibratedMulticlassClassificationTrainer : SdcaMulticlassClassificationTrainerBase<MulticlassLinearModelParameters>
486+
public sealed class SdcaNonCalibratedMulticlassClassificationTrainer : SdcaMulticlassClassificationTrainerBase<LinearMulticlassModelParameters>
487487
{
488488
internal SdcaNonCalibratedMulticlassClassificationTrainer(IHostEnvironment env,
489489
string labelColumn = DefaultColumnNames.Label,
@@ -509,19 +509,19 @@ internal SdcaNonCalibratedMulticlassClassificationTrainer(IHostEnvironment env,
509509
{
510510
}
511511

512-
private protected override MulticlassLinearModelParameters CreatePredictor(VBuffer<float>[] weights, float[] bias)
512+
private protected override LinearMulticlassModelParameters CreatePredictor(VBuffer<float>[] weights, float[] bias)
513513
{
514514
Host.CheckValue(weights, nameof(weights));
515515
Host.CheckValue(bias, nameof(bias));
516516
Host.CheckParam(weights.Length > 0, nameof(weights));
517517
Host.CheckParam(weights.Length == bias.Length, nameof(weights));
518518

519-
return new MulticlassLinearModelParameters(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null);
519+
return new LinearMulticlassModelParameters(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null);
520520
}
521521

522-
private protected override MulticlassPredictionTransformer<MulticlassLinearModelParameters> MakeTransformer(
523-
MulticlassLinearModelParameters model, DataViewSchema trainSchema) =>
524-
new MulticlassPredictionTransformer<MulticlassLinearModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
522+
private protected override MulticlassPredictionTransformer<LinearMulticlassModelParameters> MakeTransformer(
523+
LinearMulticlassModelParameters model, DataViewSchema trainSchema) =>
524+
new MulticlassPredictionTransformer<LinearMulticlassModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
525525
}
526526

527527
/// <summary>

src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ public static PoissonRegressionTrainer PoissonRegression(this RegressionCatalog.
587587
}
588588

589589
/// <summary>
590-
/// Predict a target using a linear multiclass classification model trained with the <see cref="LogisticRegressionMulticlassClassificationTrainer"/> trainer.
590+
/// Predict a target using a linear multiclass classification model trained with the <see cref="LbfgsMaximumEntropyTrainer"/> trainer.
591591
/// </summary>
592592
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog.MulticlassClassificationTrainers"/>.</param>
593593
/// <param name="labelColumnName">The name of the label column.</param>
@@ -596,9 +596,9 @@ public static PoissonRegressionTrainer PoissonRegression(this RegressionCatalog.
596596
/// <param name="enforceNonNegativity">Enforce non-negative weights.</param>
597597
/// <param name="l1Regularization">Weight of L1 regularization term.</param>
598598
/// <param name="l2Regularization">Weight of L2 regularization term.</param>
599-
/// <param name="historySize">Memory size for <see cref="Microsoft.ML.Trainers.LogisticRegressionMulticlassClassificationTrainer"/>. Low=faster, less accurate.</param>
599+
/// <param name="historySize">Memory size for <see cref="Microsoft.ML.Trainers.LbfgsMaximumEntropyTrainer"/>. Low=faster, less accurate.</param>
600600
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param>
601-
public static LogisticRegressionMulticlassClassificationTrainer LogisticRegression(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
601+
public static LbfgsMaximumEntropyTrainer LbfgsMaximumEntropy(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
602602
string labelColumnName = DefaultColumnNames.Label,
603603
string featureColumnName = DefaultColumnNames.Features,
604604
string exampleWeightColumnName = null,
@@ -610,22 +610,22 @@ public static LogisticRegressionMulticlassClassificationTrainer LogisticRegressi
610610
{
611611
Contracts.CheckValue(catalog, nameof(catalog));
612612
var env = CatalogUtils.GetEnvironment(catalog);
613-
return new LogisticRegressionMulticlassClassificationTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l1Regularization, l2Regularization, optimizationTolerance, historySize, enforceNonNegativity);
613+
return new LbfgsMaximumEntropyTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l1Regularization, l2Regularization, optimizationTolerance, historySize, enforceNonNegativity);
614614
}
615615

616616
/// <summary>
617-
/// Predict a target using a linear multiclass classification model trained with the <see cref="LogisticRegressionMulticlassClassificationTrainer"/> trainer.
617+
/// Predict a target using a linear multiclass classification model trained with the <see cref="LbfgsMaximumEntropyTrainer"/> trainer.
618618
/// </summary>
619619
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog.MulticlassClassificationTrainers"/>.</param>
620620
/// <param name="options">Advanced arguments to the algorithm.</param>
621-
public static LogisticRegressionMulticlassClassificationTrainer LogisticRegression(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
622-
LogisticRegressionMulticlassClassificationTrainer.Options options)
621+
public static LbfgsMaximumEntropyTrainer LbfgsMaximumEntropy(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
622+
LbfgsMaximumEntropyTrainer.Options options)
623623
{
624624
Contracts.CheckValue(catalog, nameof(catalog));
625625
Contracts.CheckValue(options, nameof(options));
626626

627627
var env = CatalogUtils.GetEnvironment(catalog);
628-
return new LogisticRegressionMulticlassClassificationTrainer(env, options);
628+
return new LbfgsMaximumEntropyTrainer(env, options);
629629
}
630630

631631
/// <summary>

src/Microsoft.ML.StaticPipe/LbfgsStatic.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ public static Scalar<float> PoissonRegression(this RegressionCatalog.RegressionT
209209
public static class LbfgsMulticlassExtensions
210210
{
211211
/// <summary>
212-
/// Predict a target using a linear multiclass classification model trained with the <see cref="Microsoft.ML.Trainers.LogisticRegressionMulticlassClassificationTrainer"/> trainer.
212+
/// Predict a target using a linear multiclass classification model trained with the <see cref="Microsoft.ML.Trainers.LbfgsMaximumEntropyTrainer"/> trainer.
213213
/// </summary>
214214
/// <param name="catalog">The multiclass classification catalog trainer object.</param>
215215
/// <param name="label">The label, or dependent variable.</param>
@@ -227,7 +227,7 @@ public static class LbfgsMulticlassExtensions
227227
/// result in any way; it is only a way for the caller to be informed about what was learnt.</param>
228228
/// <returns>The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label.</returns>
229229
public static (Vector<float> score, Key<uint, TVal> predictedLabel)
230-
MulticlassLogisticRegression<TVal>(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
230+
LbfgsMaximumEntropy<TVal>(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
231231
Key<uint, TVal> label,
232232
Vector<float> features,
233233
Scalar<float> weights = null,
@@ -236,14 +236,14 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
236236
float optimizationTolerance = Options.Defaults.OptimizationTolerance,
237237
int historySize = Options.Defaults.HistorySize,
238238
bool enforceNonNegativity = Options.Defaults.EnforceNonNegativity,
239-
Action<MulticlassLogisticRegressionModelParameters> onFit = null)
239+
Action<MaximumEntropyModelParameters> onFit = null)
240240
{
241241
LbfgsStaticUtils.ValidateParams(label, features, weights, l1Regularization, l2Regularization, optimizationTolerance, historySize, enforceNonNegativity, onFit);
242242

243243
var rec = new TrainerEstimatorReconciler.MulticlassClassificationReconciler<TVal>(
244244
(env, labelName, featuresName, weightsName) =>
245245
{
246-
var trainer = new LogisticRegressionMulticlassClassificationTrainer(env, labelName, featuresName, weightsName,
246+
var trainer = new LbfgsMaximumEntropyTrainer(env, labelName, featuresName, weightsName,
247247
l1Regularization, l2Regularization, optimizationTolerance, historySize, enforceNonNegativity);
248248

249249
if (onFit != null)
@@ -255,7 +255,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
255255
}
256256

257257
/// <summary>
258-
/// Predict a target using a linear multiclass classification model trained with the <see cref="Microsoft.ML.Trainers.LogisticRegressionMulticlassClassificationTrainer"/> trainer.
258+
/// Predict a target using a linear multiclass classification model trained with the <see cref="Microsoft.ML.Trainers.LbfgsMaximumEntropyTrainer"/> trainer.
259259
/// </summary>
260260
/// <param name="catalog">The multiclass classification catalog trainer object.</param>
261261
/// <param name="label">The label, or dependent variable.</param>
@@ -269,12 +269,12 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
269269
/// result in any way; it is only a way for the caller to be informed about what was learnt.</param>
270270
/// <returns>The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label.</returns>
271271
public static (Vector<float> score, Key<uint, TVal> predictedLabel)
272-
MulticlassLogisticRegression<TVal>(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
272+
LbfgsMaximumEntropy<TVal>(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
273273
Key<uint, TVal> label,
274274
Vector<float> features,
275275
Scalar<float> weights,
276-
LogisticRegressionMulticlassClassificationTrainer.Options options,
277-
Action<MulticlassLogisticRegressionModelParameters> onFit = null)
276+
LbfgsMaximumEntropyTrainer.Options options,
277+
Action<MaximumEntropyModelParameters> onFit = null)
278278
{
279279
Contracts.CheckValue(label, nameof(label));
280280
Contracts.CheckValue(features, nameof(features));
@@ -288,7 +288,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
288288
options.FeatureColumnName = featuresName;
289289
options.ExampleWeightColumnName = weightsName;
290290

291-
var trainer = new LogisticRegressionMulticlassClassificationTrainer(env, options);
291+
var trainer = new LbfgsMaximumEntropyTrainer(env, options);
292292

293293
if (onFit != null)
294294
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));

src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel) Sdca<TVal>(
352352
float? l2Regularization = null,
353353
float? l1Threshold = null,
354354
int? numberOfIterations = null,
355-
Action<MulticlassLogisticRegressionModelParameters> onFit = null)
355+
Action<MaximumEntropyModelParameters> onFit = null)
356356
{
357357
Contracts.CheckValue(label, nameof(label));
358358
Contracts.CheckValue(features, nameof(features));
@@ -394,7 +394,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel) Sdca<TVal>(
394394
Vector<float> features,
395395
Scalar<float> weights,
396396
SdcaMulticlassClassificationTrainer.Options options,
397-
Action<MulticlassLogisticRegressionModelParameters> onFit = null)
397+
Action<MaximumEntropyModelParameters> onFit = null)
398398
{
399399
Contracts.CheckValue(label, nameof(label));
400400
Contracts.CheckValue(features, nameof(features));
@@ -443,7 +443,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel) SdcaNonCalib
443443
float? l2Regularization = null,
444444
float? l1Threshold = null,
445445
int? numberOfIterations = null,
446-
Action<MulticlassLinearModelParameters> onFit = null)
446+
Action<LinearMulticlassModelParameters> onFit = null)
447447
{
448448
Contracts.CheckValue(label, nameof(label));
449449
Contracts.CheckValue(features, nameof(features));
@@ -486,7 +486,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel) SdcaNonCalib
486486
Vector<float> features,
487487
Scalar<float> weights,
488488
SdcaNonCalibratedMulticlassClassificationTrainer.Options options,
489-
Action<MulticlassLinearModelParameters> onFit = null)
489+
Action<LinearMulticlassModelParameters> onFit = null)
490490
{
491491
Contracts.CheckValue(label, nameof(label));
492492
Contracts.CheckValue(features, nameof(features));

test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class StochasticDualCoordinateAscentClassifierBench : WithExtraMetrics
3535
PetalWidth = 5.1f,
3636
};
3737

38-
private TransformerChain<MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>> _trainedModel;
38+
private TransformerChain<MulticlassPredictionTransformer<MaximumEntropyModelParameters>> _trainedModel;
3939
private PredictionEngine<IrisData, IrisPrediction> _predictionEngine;
4040
private IrisData[][] _batches;
4141
private MulticlassClassificationMetrics _metrics;
@@ -54,9 +54,9 @@ protected override IEnumerable<Metric> GetMetrics()
5454
}
5555

5656
[Benchmark]
57-
public TransformerChain<MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>> TrainIris() => Train(_dataPath);
57+
public TransformerChain<MulticlassPredictionTransformer<MaximumEntropyModelParameters>> TrainIris() => Train(_dataPath);
5858

59-
private TransformerChain<MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>> Train(string dataPath)
59+
private TransformerChain<MulticlassPredictionTransformer<MaximumEntropyModelParameters>> Train(string dataPath)
6060
{
6161
// Create text loader.
6262
var options = new TextLoader.Options()

test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ public void CV_Multiclass_WikiDetox_WordEmbeddings_SDCAMC()
8585
" xf=WordEmbeddingsTransform{col=FeaturesWordEmbedding:FeaturesText_TransformedText model=FastTextWikipedia300D}" +
8686
" xf=Concat{col=Features:FeaturesWordEmbedding,logged_in,ns}";
8787

88-
var environment = EnvironmentFactory.CreateClassificationEnvironment<TextLoader, OneHotEncodingTransformer, SdcaMulticlassClassificationTrainer, MulticlassLogisticRegressionModelParameters>();
88+
var environment = EnvironmentFactory.CreateClassificationEnvironment<TextLoader, OneHotEncodingTransformer, SdcaMulticlassClassificationTrainer, MaximumEntropyModelParameters>();
8989
cmd.ExecuteMamlCommand(environment);
9090
}
9191
}

0 commit comments

Comments
 (0)