-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Ova and Pkpd as estimators #865
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
43090ce
cd05a04
0c94bf6
3358e63
45b050a
537b9f1
d1f780a
cb8113d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,53 +2,87 @@ | |
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Float = System.Single; | ||
|
||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime.CommandLine; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Data.Conversion; | ||
using Microsoft.ML.Runtime.EntryPoints; | ||
using Microsoft.ML.Runtime.Internal.Calibration; | ||
using Microsoft.ML.Runtime.Internal.Internallearn; | ||
using Microsoft.ML.Runtime.Training; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
|
||
namespace Microsoft.ML.Runtime.Learners | ||
{ | ||
using TScalarTrainer = ITrainer<IPredictorProducing<Float>>; | ||
using TScalarTrainer = ITrainerEstimator<IPredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
should be BinaryPredictionTransformer #Closed There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Binary is not an interface, so we can't have covariance if we have it there. In reply to: 217234106 [](ancestors = 217234106,217218242) |
||
|
||
public abstract class MetaMulticlassTrainer<TPred, TArgs> : TrainerBase<TPred> | ||
where TPred : IPredictor | ||
where TArgs : MetaMulticlassTrainer<TPred, TArgs>.ArgumentsBase | ||
public abstract class MetaMulticlassTrainer<TTransformer, TModel> : ITrainerEstimator<TTransformer, TModel>, ITrainer<TModel> | ||
where TTransformer : IPredictionTransformer<TModel> | ||
where TModel : IPredictor | ||
{ | ||
public abstract class ArgumentsBase | ||
{ | ||
[Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 1, SignatureType = typeof(SignatureBinaryClassifierTrainer))] | ||
[Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 4, SignatureType = typeof(SignatureBinaryClassifierTrainer))] | ||
[TGUI(Label = "Predictor Type", Description = "Type of underlying binary predictor")] | ||
public IComponentFactory<TScalarTrainer> PredictorType; | ||
|
||
[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "<None>", SignatureType = typeof(SignatureCalibrator))] | ||
[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", SortOrder = 150, NullName = "<None>", SignatureType = typeof(SignatureCalibrator))] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
What is the purpose of these 150 sort orders? #Pending There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the old GUI doesn't show them. They are the "advanced" ones? In reply to: 216345747 [](ancestors = 216345747) |
||
public IComponentFactory<ICalibratorTrainer> Calibrator = new PlattCalibratorTrainerFactory(); | ||
|
||
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")] | ||
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", SortOrder = 150, ShortName = "numcali")] | ||
public int MaxCalibrationExamples = 1000000000; | ||
|
||
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", ShortName = "missNeg")] | ||
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", SortOrder = 150, ShortName = "missNeg")] | ||
public bool ImputeMissingLabelsAsNegative; | ||
} | ||
|
||
protected readonly TArgs Args; | ||
/// <summary> | ||
/// The label column that the trainer expects. | ||
/// </summary> | ||
public readonly SchemaShape.Column LabelColumn; | ||
|
||
protected readonly ArgumentsBase Args; | ||
protected readonly IHost Host; | ||
protected readonly ICalibratorTrainer Calibrator; | ||
|
||
private TScalarTrainer _trainer; | ||
|
||
public sealed override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; | ||
public override TrainerInfo Info { get; } | ||
public PredictionKind PredictionKind => PredictionKind.MultiClassClassification; | ||
|
||
protected SchemaShape.Column[] OutputColumns; | ||
|
||
internal MetaMulticlassTrainer(IHostEnvironment env, TArgs args, string name) | ||
: base(env, name) | ||
public TrainerInfo Info { get; } | ||
|
||
public TScalarTrainer PredictorType; | ||
|
||
/// <summary> | ||
/// Initializes the <see cref="MetaMulticlassTrainer{TTransformer, TModel}"/> from the Arguments class. | ||
/// </summary> | ||
/// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param> | ||
/// <param name="args">The legacy arguments <see cref="ArgumentsBase"/> class.</param> | ||
/// <param name="name">The component name.</param> | ||
/// <param name="labelColumn">The label column for the metalinear trainer and the binary trainer.</param> | ||
/// <param name="singleEstimator">The binary estimator.</param> | ||
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param> | ||
internal MetaMulticlassTrainer(IHostEnvironment env, ArgumentsBase args, string name, string labelColumn = null, | ||
TScalarTrainer singleEstimator = null, ICalibratorTrainer calibrator = null) | ||
{ | ||
Host = Contracts.CheckRef(env, nameof(env)).Register(name); | ||
Host.CheckValue(args, nameof(args)); | ||
Args = args; | ||
|
||
if (labelColumn != null) | ||
LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); | ||
|
||
// Create the first trainer so errors in the args surface early. | ||
_trainer = CreateTrainer(); | ||
_trainer = singleEstimator ?? CreateTrainer(); | ||
|
||
Calibrator = calibrator ?? new PlattCalibratorTrainer(env); | ||
|
||
if (args.Calibrator != null) | ||
Calibrator = args.Calibrator.CreateComponent(Host); | ||
|
||
// Regarding caching, no matter what the internal predictor, we're performing many passes | ||
// simply by virtue of this being a meta-trainer, so we will still cache. | ||
Info = new TrainerInfo(normalization: _trainer.Info.NeedNormalization); | ||
|
@@ -61,29 +95,28 @@ private TScalarTrainer CreateTrainer() | |
new LinearSvm(Host, new LinearSvm.Arguments()); | ||
} | ||
|
||
protected IDataView MapLabelsCore<T>(ColumnType type, RefPredicate<T> equalsTarget, RoleMappedData data, string dstName) | ||
protected IDataView MapLabelsCore<T>(ColumnType type, RefPredicate<T> equalsTarget, RoleMappedData data) | ||
{ | ||
Host.AssertValue(type); | ||
Host.Assert(type.RawType == typeof(T)); | ||
Host.AssertValue(equalsTarget); | ||
Host.AssertValue(data); | ||
Host.AssertValue(data.Schema.Label); | ||
Host.AssertNonWhiteSpace(dstName); | ||
|
||
var lab = data.Schema.Label; | ||
|
||
RefPredicate<T> isMissing; | ||
if (!Args.ImputeMissingLabelsAsNegative && Conversions.Instance.TryGetIsNAPredicate(type, out isMissing)) | ||
{ | ||
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data, | ||
lab.Name, dstName, type, NumberType.Float, | ||
(ref T src, ref Float dst) => | ||
dst = equalsTarget(ref src) ? 1 : (isMissing(ref src) ? Float.NaN : default(Float))); | ||
lab.Name, lab.Name, type, NumberType.Float, | ||
(ref T src, ref float dst) => | ||
dst = equalsTarget(ref src) ? 1 : (isMissing(ref src) ? float.NaN : default(float))); | ||
} | ||
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data, | ||
lab.Name, dstName, type, NumberType.Float, | ||
(ref T src, ref Float dst) => | ||
dst = equalsTarget(ref src) ? 1 : default(Float)); | ||
lab.Name, lab.Name, type, NumberType.Float, | ||
(ref T src, ref float dst) => | ||
dst = equalsTarget(ref src) ? 1 : default(float)); | ||
} | ||
|
||
protected TScalarTrainer GetTrainer() | ||
|
@@ -95,9 +128,14 @@ protected TScalarTrainer GetTrainer() | |
return train; | ||
} | ||
|
||
protected abstract TPred TrainCore(IChannel ch, RoleMappedData data, int count); | ||
protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, int count); | ||
|
||
public override TPred Train(TrainContext context) | ||
/// <summary> | ||
/// The legacy train method. | ||
/// </summary> | ||
/// <param name="context">The training context for this learner.</param> | ||
/// <returns>The trained model.</returns> | ||
public TModel Train(TrainContext context) | ||
{ | ||
Host.CheckValue(context, nameof(context)); | ||
var data = context.TrainingSet; | ||
|
@@ -116,5 +154,76 @@ public override TPred Train(TrainContext context) | |
return pred; | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Gets the output columns. | ||
/// </summary> | ||
/// <param name="inputSchema">The input schema. </param> | ||
/// <returns>The output <see cref="SchemaShape"/></returns> | ||
public SchemaShape GetOutputSchema(SchemaShape inputSchema) | ||
{ | ||
Host.CheckValue(inputSchema, nameof(inputSchema)); | ||
|
||
if (LabelColumn != null) | ||
{ | ||
if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol)) | ||
throw Host.ExceptSchemaMismatch(nameof(labelCol), DefaultColumnNames.PredictedLabel, DefaultColumnNames.PredictedLabel); | ||
|
||
if (!LabelColumn.IsCompatibleWith(labelCol)) | ||
throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible"); | ||
} | ||
|
||
var outColumns = inputSchema.Columns.ToDictionary(x => x.Name); | ||
foreach (var col in GetOutputColumnsCore(inputSchema)) | ||
outColumns[col.Name] = col; | ||
|
||
return new SchemaShape(outColumns.Values); | ||
} | ||
|
||
private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) | ||
{ | ||
if (LabelColumn != null) | ||
{ | ||
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol); | ||
Contracts.Assert(success); | ||
|
||
var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) | ||
.Concat(MetadataForScoreColumn())); | ||
return new[] | ||
{ | ||
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(MetadataForScoreColumn())), | ||
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, metadata) | ||
}; | ||
} | ||
else | ||
return new[] | ||
{ | ||
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(MetadataForScoreColumn())), | ||
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, new SchemaShape(MetadataForScoreColumn())) | ||
}; | ||
} | ||
|
||
/// <summary> | ||
/// Normal metadata that we produce for score columns. | ||
/// </summary> | ||
private static IEnumerable<SchemaShape.Column> MetadataForScoreColumn() | ||
{ | ||
var cols = new List<SchemaShape.Column>(); | ||
cols.Add(new SchemaShape.Column(MetadataUtils.Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true)); | ||
cols.Add(new SchemaShape.Column(MetadataUtils.Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextType.Instance, false)); | ||
cols.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); | ||
cols.Add(new SchemaShape.Column(MetadataUtils.Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextType.Instance, false)); | ||
|
||
return cols; | ||
} | ||
|
||
IPredictor ITrainer.Train(TrainContext context) => Train(context); | ||
|
||
/// <summary> | ||
/// Fits the data to the trainer. | ||
/// </summary> | ||
/// <param name="input">The input data to fit to.</param> | ||
/// <returns>The transformer.</returns> | ||
public abstract TTransformer Fit(IDataView input); | ||
} | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is temporary, until FastTree is an estimator. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @sfilipi, honestly I feel like this sort of thing really should not have a default. It seems to be along the same lines as the discussion in #682.
In reply to: 217222841 [](ancestors = 217222841)