Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1122,8 +1122,8 @@ public sealed class PlattCalibratorTrainer : CalibratorTrainerBase
private Double _paramA;
private Double _paramB;

public const string UserName = "Sigmoid Calibration";
public const string LoadName = "PlattCalibration";
internal const string UserName = "Sigmoid Calibration";
internal const string LoadName = "PlattCalibration";
internal const string Summary = "This model was introduced by Platt in the paper Probabilistic Outputs for Support Vector Machines "
+ "and Comparisons to Regularized Likelihood Methods";

Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public Arguments()
env => new Ova(env, new Ova.Arguments()
{
PredictorType = ComponentFactoryUtils.CreateFromFunction(
e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
Copy link
Member Author

@sfilipi sfilipi Sep 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e => new FastTreeBinaryClassificationTrain [](start = 26, length = 44)

this is temporary, until FastTree is an estimator. #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @sfilipi, honestly I feel like this sort of thing really should not have a default. It seems to be along the same lines as the discussion in #682.


In reply to: 217222841 [](ancestors = 217222841)

e => new AveragedPerceptronTrainer(e, new AveragedPerceptronTrainer.Arguments()))
}));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,53 +2,87 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Float = System.Single;

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.Conversion;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Training;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.ML.Runtime.Learners
{
using TScalarTrainer = ITrainer<IPredictorProducing<Float>>;
using TScalarTrainer = ITrainerEstimator<IPredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
Copy link
Member Author

@sfilipi sfilipi Sep 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IPredictionTransformer [](start = 45, length = 22)

should be BinaryPredictionTransformer #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's fine as it is


In reply to: 217218242 [](ancestors = 217218242)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Binary is not an interface, so we can't have covariance if we have it there


In reply to: 217234106 [](ancestors = 217234106,217218242)


public abstract class MetaMulticlassTrainer<TPred, TArgs> : TrainerBase<TPred>
where TPred : IPredictor
where TArgs : MetaMulticlassTrainer<TPred, TArgs>.ArgumentsBase
public abstract class MetaMulticlassTrainer<TTransformer, TModel> : ITrainerEstimator<TTransformer, TModel>, ITrainer<TModel>
where TTransformer : IPredictionTransformer<TModel>
where TModel : IPredictor
{
public abstract class ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 1, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
[Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 4, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
[TGUI(Label = "Predictor Type", Description = "Type of underlying binary predictor")]
public IComponentFactory<TScalarTrainer> PredictorType;

[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "<None>", SignatureType = typeof(SignatureCalibrator))]
[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", SortOrder = 150, NullName = "<None>", SignatureType = typeof(SignatureCalibrator))]
Copy link
Contributor

@TomFinley TomFinley Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

150 [](start = 109, length = 3)

What is the purpose of these 150 sort orders? #Pending

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think the old GUI doesn't show them. They are the "advanced" ones?


In reply to: 216345747 [](ancestors = 216345747)

public IComponentFactory<ICalibratorTrainer> Calibrator = new PlattCalibratorTrainerFactory();

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")]
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", SortOrder = 150, ShortName = "numcali")]
public int MaxCalibrationExamples = 1000000000;

[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", ShortName = "missNeg")]
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", SortOrder = 150, ShortName = "missNeg")]
public bool ImputeMissingLabelsAsNegative;
}

protected readonly TArgs Args;
/// <summary>
/// The label column that the trainer expects.
/// </summary>
public readonly SchemaShape.Column LabelColumn;

protected readonly ArgumentsBase Args;
protected readonly IHost Host;
protected readonly ICalibratorTrainer Calibrator;

private TScalarTrainer _trainer;

public sealed override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
public override TrainerInfo Info { get; }
public PredictionKind PredictionKind => PredictionKind.MultiClassClassification;

protected SchemaShape.Column[] OutputColumns;

internal MetaMulticlassTrainer(IHostEnvironment env, TArgs args, string name)
: base(env, name)
public TrainerInfo Info { get; }

public TScalarTrainer PredictorType;

/// <summary>
/// Initializes the <see cref="MetaMulticlassTrainer{TTransformer, TModel}"/> from the Arguments class.
/// </summary>
/// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
/// <param name="args">The legacy arguments <see cref="ArgumentsBase"/> class.</param>
/// <param name="name">The component name.</param>
/// <param name="labelColumn">The label column for the metalinear trainer and the binary trainer.</param>
/// <param name="singleEstimator">The binary estimator.</param>
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
internal MetaMulticlassTrainer(IHostEnvironment env, ArgumentsBase args, string name, string labelColumn = null,
TScalarTrainer singleEstimator = null, ICalibratorTrainer calibrator = null)
{
Host = Contracts.CheckRef(env, nameof(env)).Register(name);
Host.CheckValue(args, nameof(args));
Args = args;

if (labelColumn != null)
LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);

// Create the first trainer so errors in the args surface early.
_trainer = CreateTrainer();
_trainer = singleEstimator ?? CreateTrainer();

Calibrator = calibrator ?? new PlattCalibratorTrainer(env);

if (args.Calibrator != null)
Calibrator = args.Calibrator.CreateComponent(Host);

// Regarding caching, no matter what the internal predictor, we're performing many passes
// simply by virtue of this being a meta-trainer, so we will still cache.
Info = new TrainerInfo(normalization: _trainer.Info.NeedNormalization);
Expand All @@ -61,29 +95,28 @@ private TScalarTrainer CreateTrainer()
new LinearSvm(Host, new LinearSvm.Arguments());
}

protected IDataView MapLabelsCore<T>(ColumnType type, RefPredicate<T> equalsTarget, RoleMappedData data, string dstName)
protected IDataView MapLabelsCore<T>(ColumnType type, RefPredicate<T> equalsTarget, RoleMappedData data)
{
Host.AssertValue(type);
Host.Assert(type.RawType == typeof(T));
Host.AssertValue(equalsTarget);
Host.AssertValue(data);
Host.AssertValue(data.Schema.Label);
Host.AssertNonWhiteSpace(dstName);

var lab = data.Schema.Label;

RefPredicate<T> isMissing;
if (!Args.ImputeMissingLabelsAsNegative && Conversions.Instance.TryGetIsNAPredicate(type, out isMissing))
{
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
lab.Name, dstName, type, NumberType.Float,
(ref T src, ref Float dst) =>
dst = equalsTarget(ref src) ? 1 : (isMissing(ref src) ? Float.NaN : default(Float)));
lab.Name, lab.Name, type, NumberType.Float,
(ref T src, ref float dst) =>
dst = equalsTarget(ref src) ? 1 : (isMissing(ref src) ? float.NaN : default(float)));
}
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
lab.Name, dstName, type, NumberType.Float,
(ref T src, ref Float dst) =>
dst = equalsTarget(ref src) ? 1 : default(Float));
lab.Name, lab.Name, type, NumberType.Float,
(ref T src, ref float dst) =>
dst = equalsTarget(ref src) ? 1 : default(float));
}

protected TScalarTrainer GetTrainer()
Expand All @@ -95,9 +128,14 @@ protected TScalarTrainer GetTrainer()
return train;
}

protected abstract TPred TrainCore(IChannel ch, RoleMappedData data, int count);
protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, int count);

public override TPred Train(TrainContext context)
/// <summary>
/// The legacy train method.
/// </summary>
/// <param name="context">The training context for this learner.</param>
/// <returns>The trained model.</returns>
public TModel Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var data = context.TrainingSet;
Expand All @@ -116,5 +154,76 @@ public override TPred Train(TrainContext context)
return pred;
}
}

/// <summary>
/// Gets the output schema produced by this trainer: the input columns plus (or overwriting)
/// the Score and PredictedLabel columns.
/// </summary>
/// <param name="inputSchema">The input schema.</param>
/// <returns>The output <see cref="SchemaShape"/>.</returns>
/// <exception cref="T:Microsoft.ML.Runtime.Data.SchemaException">
/// Thrown when the expected label column is absent or incompatible with the trainer's label column.
/// </exception>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));

    if (LabelColumn != null)
    {
        // Fix: the mismatch exception previously reported DefaultColumnNames.PredictedLabel,
        // misidentifying the missing column; the column being looked up is the label column.
        if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);

        if (!LabelColumn.IsCompatibleWith(labelCol))
            throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
    }

    // Start from the input columns; trainer outputs overwrite same-named columns or are appended.
    var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
    foreach (var col in GetOutputColumnsCore(inputSchema))
        outColumns[col.Name] = col;

    return new SchemaShape(outColumns.Values);
}

/// <summary>
/// Produces the Score and PredictedLabel column shapes. When a label column is known, its
/// key-value metadata is propagated onto the predicted label.
/// </summary>
private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    // Metadata attached to the predicted label: label key-values (when available) plus the
    // standard scorer metadata.
    SchemaShape predictedLabelMetadata;
    if (LabelColumn != null)
    {
        bool found = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
        Contracts.Assert(found);

        predictedLabelMetadata = new SchemaShape(
            labelCol.Metadata.Columns
                .Where(c => c.Name == MetadataUtils.Kinds.KeyValues)
                .Concat(MetadataForScoreColumn()));
    }
    else
        predictedLabelMetadata = new SchemaShape(MetadataForScoreColumn());

    return new[]
    {
        new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(MetadataForScoreColumn())),
        new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, predictedLabelMetadata)
    };
}

/// <summary>
/// Normal metadata that we produce for score columns: score-column set id/kind/value-kind
/// plus slot names.
/// </summary>
private static IEnumerable<SchemaShape.Column> MetadataForScoreColumn()
{
    return new[]
    {
        new SchemaShape.Column(MetadataUtils.Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true),
        new SchemaShape.Column(MetadataUtils.Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextType.Instance, false),
        new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false),
        new SchemaShape.Column(MetadataUtils.Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextType.Instance, false)
    };
}

// Explicit ITrainer implementation: forwards to the strongly-typed Train(TrainContext) above.
IPredictor ITrainer.Train(TrainContext context) => Train(context);

/// <summary>
/// Fits the data to the trainer.
/// </summary>
/// <param name="input">The input data to fit to.</param>
/// <returns>The trained <typeparamref name="TTransformer"/>.</returns>
public abstract TTransformer Fit(IDataView input);
}
}
}
Loading