Merged
Changes from 2 commits
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
@@ -57,7 +57,7 @@ public Arguments()
env => new Ova(env, new Ova.Arguments()
{
PredictorType = ComponentFactoryUtils.CreateFromFunction(
e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
Member Author
@sfilipi sfilipi Sep 12, 2018

e => new FastTreeBinaryClassificationTrain [](start = 26, length = 44)

this is temporary, until FastTree is an estimator. #Resolved

Contributor

Hi @sfilipi, honestly I feel like this sort of thing really should not have a default. It seems to be along the same lines as the discussion in #682.


In reply to: 217222841 [](ancestors = 217222841)

e => new AveragedPerceptronTrainer(e, new AveragedPerceptronTrainer.Arguments()))
}));
}
}
MetaMulticlassTrainer.cs
@@ -2,56 +2,97 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Float = System.Single;

using System;
using System.Linq;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.Conversion;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Core.Data;

namespace Microsoft.ML.Runtime.Learners
{
using TScalarTrainer = ITrainer<IPredictorProducing<Float>>;
using TScalarTrainer = ITrainerEstimator<IPredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
Member Author
@sfilipi sfilipi Sep 12, 2018

IPredictionTransformer [](start = 45, length = 22)

should be BinaryPredictionTransformer #Closed

Contributor

No, it's fine as it is


In reply to: 217218242 [](ancestors = 217218242)

Contributor

Binary is not an interface, so we can't have covariance if we have it there


In reply to: 217234106 [](ancestors = 217234106,217218242)
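
For readers following the covariance point above, a minimal standalone sketch (the types below are illustrative stand-ins with the same names, not the real ML.NET APIs) of why the alias has to stay on the IPredictionTransformer interface: C# only allows variance annotations ("out"/"in") on interfaces and delegates, so a concrete BinaryPredictionTransformer class could never be covariant in its model type.

// Minimal sketch; illustrative stand-in types, not the real ML.NET definitions.
public interface IModel { }
public class LinearModel : IModel { }

// Covariance ("out") is only legal on interfaces and delegates...
public interface IPredictionTransformer<out TModel> where TModel : IModel { }

// ...so a class cannot declare it, however it is written.
public class BinaryPredictionTransformer<TModel> : IPredictionTransformer<TModel>
    where TModel : IModel
{ }

public static class VarianceDemo
{
    public static void Main()
    {
        IPredictionTransformer<LinearModel> specific = new BinaryPredictionTransformer<LinearModel>();

        // Allowed: the interface is covariant, so a transformer producing a derived
        // model type is usable where one producing the base model type is expected.
        IPredictionTransformer<IModel> general = specific;

        // Not allowed: classes carry no variance, so the class-typed equivalent
        // would not compile.
        // BinaryPredictionTransformer<IModel> bad = new BinaryPredictionTransformer<LinearModel>();
    }
}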


public abstract class MetaMulticlassTrainer<TPred, TArgs> : TrainerBase<TPred>
where TPred : IPredictor
where TArgs : MetaMulticlassTrainer<TPred, TArgs>.ArgumentsBase
public abstract class MetaMulticlassTrainer<TTransformer, TModel> : ITrainerEstimator<TTransformer, TModel>, ITrainer<TModel>
where TTransformer : IPredictionTransformer<TModel>
where TModel : IPredictor
{
public abstract class ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 1, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
[Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 4, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
[TGUI(Label = "Predictor Type", Description = "Type of underlying binary predictor")]
public IComponentFactory<TScalarTrainer> PredictorType;

[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "<None>", SignatureType = typeof(SignatureCalibrator))]
[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", SortOrder = 150, NullName = "<None>", SignatureType = typeof(SignatureCalibrator))]
Contributor
@TomFinley TomFinley Sep 10, 2018

150 [](start = 109, length = 3)

What is the purpose of these 150 sort orders? #Pending

Member Author

I think the old GUI doesn't show them. They are the "advanced" ones?


In reply to: 216345747 [](ancestors = 216345747)

public IComponentFactory<ICalibratorTrainer> Calibrator = new PlattCalibratorTrainerFactory();

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")]
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", SortOrder = 150, ShortName = "numcali")]
public int MaxCalibrationExamples = 1000000000;

[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", ShortName = "missNeg")]
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", SortOrder = 150, ShortName = "missNeg")]
public bool ImputeMissingLabelsAsNegative;
}

protected readonly TArgs Args;
/// <summary>
/// The label column that the trainer expects. Can be <c>null</c>, which indicates that the label
Contributor
@Zruty0 Zruty0 Sep 13, 2018

Can be null [](start = 55, length = 18)

can it? #Pending

Member Author

Removed the comment since it's not relevant here; but the LabelColumn can be null if instantiated through the legacy constructors.
Added a check for it in the new constructors, where it should not be null.


In reply to: 217234672 [](ancestors = 217234672)

/// is not used for training.
/// </summary>
public readonly SchemaShape.Column LabelColumn;
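
A hedged sketch of the check described in the reply above, as it might look in the new, non-legacy constructors (which are not part of this hunk); the labelColumn parameter name is illustrative:

// Hedged sketch; labelColumn is the label name handed to a new, non-legacy constructor.
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);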

protected readonly ArgumentsBase Args;
protected readonly IHost Host;
protected readonly ICalibratorTrainer Calibrator;

private TScalarTrainer _trainer;

public sealed override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
public override TrainerInfo Info { get; }
public PredictionKind PredictionKind => PredictionKind.MultiClassClassification;

protected SchemaShape.Column[] OutputColumns { get; }

public TrainerInfo Info { get; }

public TScalarTrainer PredictorType;

internal MetaMulticlassTrainer(IHostEnvironment env, TArgs args, string name)
: base(env, name)
/// <summary>
/// Initializes the <see cref="MetaMulticlassTrainer{TTransformer, TModel}"/> from the Arguments class.
/// </summary>
/// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
/// <param name="args">The legacy arguments <see cref="ArgumentsBase"/>class.</param>
/// <param name="name">The component name.</param>
/// <param name="labelColumn">The label column for the metalinear trainer and the binary trainer.</param>
/// <param name="singleEstimator">The binary estimator.</param>
/// <param name="calibrator">The calibrator.</param>
internal MetaMulticlassTrainer(IHostEnvironment env, ArgumentsBase args, string name, string labelColumn = null,
TScalarTrainer singleEstimator = null, ICalibratorTrainer calibrator = null)
{
Host = Contracts.CheckRef(env, nameof(env)).Register(name);
Host.CheckValue(args, nameof(args));
Args = args;

if (labelColumn != null)
LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);

// Create the first trainer so errors in the args surface early.
_trainer = CreateTrainer();
_trainer = singleEstimator ?? CreateTrainer();

Calibrator = calibrator ?? null;

if (args.Calibrator != null)
Calibrator = args.Calibrator.CreateComponent(Host);

// Regarding caching, no matter what the internal predictor, we're performing many passes
// simply by virtue of this being a meta-trainer, so we will still cache.
Info = new TrainerInfo(normalization: _trainer.Info.NeedNormalization);

OutputColumns = new[]
Contributor
@Zruty0 Zruty0 Sep 13, 2018

OutputColumns [](start = 12, length = 13)

you no longer need these, right? #Resolved

{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true)
Contributor
@Zruty0 Zruty0 Sep 13, 2018

true [](start = 127, length = 4)

Check what I do with the base class now: you will have to do the same (as in, pass through the keyValues metadata here, if present). #Pending

Member Author

Skip adding the metadata, since I am creating the LabelColumn on line 77, without metadata?


In reply to: 217235004 [](ancestors = 217235004)

Contributor

you will need metadata for keys. Check what multi-SDCA does


In reply to: 217271448 [](ancestors = 217271448,217235004)

};
}
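
As context for the key-value metadata discussion above, a hedged sketch of the direction the reviewer suggests (mirroring what the multiclass SDCA estimator does): carry the label column's KeyValues metadata over to the PredictedLabel output column. The metadata argument on SchemaShape.Column and the MetadataUtils.Kinds.KeyValues name are assumptions here rather than facts from this diff, and the label column with metadata would have to come from the input schema (for example in GetOutputSchema), since the LabelColumn built in this constructor carries none.

// Hedged sketch only; labelCol is the label column found in the input schema.
var keyValueMetadata = new SchemaShape(labelCol.Metadata.Columns
    .Where(x => x.Name == MetadataUtils.Kinds.KeyValues));

var predictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel,
    SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, keyValueMetadata);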

private TScalarTrainer CreateTrainer()
@@ -61,29 +102,28 @@ private TScalarTrainer CreateTrainer()
new LinearSvm(Host, new LinearSvm.Arguments());
}

protected IDataView MapLabelsCore<T>(ColumnType type, RefPredicate<T> equalsTarget, RoleMappedData data, string dstName)
protected IDataView MapLabelsCore<T>(ColumnType type, RefPredicate<T> equalsTarget, RoleMappedData data)
{
Host.AssertValue(type);
Host.Assert(type.RawType == typeof(T));
Host.AssertValue(equalsTarget);
Host.AssertValue(data);
Host.AssertValue(data.Schema.Label);
Host.AssertNonWhiteSpace(dstName);

var lab = data.Schema.Label;

RefPredicate<T> isMissing;
if (!Args.ImputeMissingLabelsAsNegative && Conversions.Instance.TryGetIsNAPredicate(type, out isMissing))
{
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
lab.Name, dstName, type, NumberType.Float,
(ref T src, ref Float dst) =>
dst = equalsTarget(ref src) ? 1 : (isMissing(ref src) ? Float.NaN : default(Float)));
lab.Name, lab.Name, type, NumberType.Float,
(ref T src, ref float dst) =>
dst = equalsTarget(ref src) ? 1 : (isMissing(ref src) ? float.NaN : default(float)));
}
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
lab.Name, dstName, type, NumberType.Float,
(ref T src, ref Float dst) =>
dst = equalsTarget(ref src) ? 1 : default(Float));
lab.Name, lab.Name, type, NumberType.Float,
(ref T src, ref float dst) =>
dst = equalsTarget(ref src) ? 1 : default(float));
}

protected TScalarTrainer GetTrainer()
@@ -95,9 +135,14 @@ protected TScalarTrainer GetTrainer()
return train;
}

protected abstract TPred TrainCore(IChannel ch, RoleMappedData data, int count);
protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, int count);

public override TPred Train(TrainContext context)
/// <summary>
/// The legacy train method.
/// </summary>
/// <param name="context">The trainig context for this learner.</param>
/// <returns>The trained model.</returns>
public TModel Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var data = context.TrainingSet;
@@ -116,5 +161,41 @@ public override TPred Train(TrainContext context)
return pred;
}
}

/// <summary>
/// Gets the output schema of the transformer produced by this trainer, given an input schema.
/// </summary>
/// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
/// <returns>The output <see cref="SchemaShape"/>.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));

// Special treatment for label column: we allow different types of labels, so the trainers
// may define their own requirements on the label column.
Contributor
@Zruty0 Zruty0 Sep 13, 2018

may define their own requirements on the label column [](start = 14, length = 54)

they may not any more #Closed

if (LabelColumn != null)
{
if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
throw Host.Except($"Label column '{LabelColumn.Name}' is not found");
Contributor
@Zruty0 Zruty0 Sep 13, 2018

Except [](start = 31, length = 6)

ExceptSchemaMismatch #Resolved

Member Author

I thought ExceptSchemaMismatch is when there is a type mismatch. This is checking for existence?


In reply to: 217234801 [](ancestors = 217234801)

Contributor

There are two overloads of ExceptSchemaMismatch: one says 'XXX column YYY not found', the other says 'XXX column YYY is expected to be ZZZ but is QQQ'.


In reply to: 217272278 [](ancestors = 217272278,217234801)


if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4 && labelCol.ItemType != NumberType.R8)
throw Host.ExceptSchemaMismatch(nameof(labelCol), DefaultColumnNames.PredictedLabel, labelCol.Name, "R8, R4 or a Key", labelCol.GetTypeString());
}
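
Given the reviewer's note above on the two overloads, a hedged sketch of how the existence check could use the shorter one; only the five-argument form appears in this diff, so the exact parameter list of the three-argument overload is an assumption:

// Hedged sketch; the three-argument overload ("XXX column YYY not found") is described
// in the review above, but its precise signature is assumed here.
if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);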

var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
foreach (var col in OutputColumns)
outColumns[col.Name] = col;

return new SchemaShape(outColumns.Values);
}

IPredictor ITrainer.Train(TrainContext context) => Train(context);

/// <summary>
/// Fits the estimator to the input data and returns the resulting transformer.
/// </summary>
/// <param name="input">The input data to fit.</param>
/// <returns>The trained <typeparamref name="TTransformer"/>.</returns>
public abstract TTransformer Fit(IDataView input);
}
}
}