Skip to content

One type label policy in trainers #2804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Core/Data/AnnotationUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Column
schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames);
}

public static bool HasKeyValues(this SchemaShape.Column col)
public static bool NeedsSlotNames(this SchemaShape.Column col)
{
return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector
Expand Down Expand Up @@ -442,7 +442,7 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
{
var cols = new List<SchemaShape.Column>();
if (labelColumn != null && labelColumn.Value.IsKey && HasKeyValues(labelColumn.Value))
if (labelColumn != null && labelColumn.Value.IsKey && NeedsSlotNames(labelColumn.Value))
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
cols.AddRange(GetTrainerOutputAnnotation());
return cols;
Expand Down
27 changes: 25 additions & 2 deletions src/Microsoft.ML.Data/Data/Conversion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ private Conversions()
AddStd<I1, R4>(Convert);
AddStd<I1, R8>(Convert);
AddAux<I1, SB>(Convert);
AddStd<I1, BL>(Convert);

AddStd<I2, I1>(Convert);
AddStd<I2, I2>(Convert);
Expand All @@ -119,6 +120,7 @@ private Conversions()
AddStd<I2, R4>(Convert);
AddStd<I2, R8>(Convert);
AddAux<I2, SB>(Convert);
AddStd<I2, BL>(Convert);

AddStd<I4, I1>(Convert);
AddStd<I4, I2>(Convert);
Expand All @@ -127,6 +129,7 @@ private Conversions()
AddStd<I4, R4>(Convert);
AddStd<I4, R8>(Convert);
AddAux<I4, SB>(Convert);
AddStd<I4, BL>(Convert);

AddStd<I8, I1>(Convert);
AddStd<I8, I2>(Convert);
Expand All @@ -135,6 +138,7 @@ private Conversions()
AddStd<I8, R4>(Convert);
AddStd<I8, R8>(Convert);
AddAux<I8, SB>(Convert);
AddStd<I8, BL>(Convert);

AddStd<U1, U1>(Convert);
AddStd<U1, U2>(Convert);
Expand All @@ -144,6 +148,7 @@ private Conversions()
AddStd<U1, R4>(Convert);
AddStd<U1, R8>(Convert);
AddAux<U1, SB>(Convert);
AddStd<U1, BL>(Convert);

AddStd<U2, U1>(Convert);
AddStd<U2, U2>(Convert);
Expand All @@ -153,6 +158,7 @@ private Conversions()
AddStd<U2, R4>(Convert);
AddStd<U2, R8>(Convert);
AddAux<U2, SB>(Convert);
AddStd<U2, BL>(Convert);

AddStd<U4, U1>(Convert);
AddStd<U4, U2>(Convert);
Expand All @@ -162,6 +168,7 @@ private Conversions()
AddStd<U4, R4>(Convert);
AddStd<U4, R8>(Convert);
AddAux<U4, SB>(Convert);
AddStd<U4, BL>(Convert);

AddStd<U8, U1>(Convert);
AddStd<U8, U2>(Convert);
Expand All @@ -171,6 +178,7 @@ private Conversions()
AddStd<U8, R4>(Convert);
AddStd<U8, R8>(Convert);
AddAux<U8, SB>(Convert);
AddStd<U8, BL>(Convert);

AddStd<UG, U1>(Convert);
AddStd<UG, U2>(Convert);
Expand All @@ -180,11 +188,13 @@ private Conversions()
AddAux<UG, SB>(Convert);

AddStd<R4, R4>(Convert);
AddStd<R4, BL>(Convert);
AddStd<R4, R8>(Convert);
AddAux<R4, SB>(Convert);

AddStd<R8, R4>(Convert);
AddStd<R8, R8>(Convert);
AddStd<R8, BL>(Convert);
AddAux<R8, SB>(Convert);

AddStd<TX, I1>(Convert);
Expand Down Expand Up @@ -901,6 +911,19 @@ public void Convert(in BL src, ref SB dst)
public void Convert(in DZ src, ref SB dst) { ClearDst(ref dst); dst.AppendFormat("{0:o}", src); }
#endregion ToStringBuilder

#region ToBL
public void Convert(in R8 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
public void Convert(in R4 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
public void Convert(in I1 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
public void Convert(in I2 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
public void Convert(in I4 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
public void Convert(in I8 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
public void Convert(in U1 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
public void Convert(in U2 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
public void Convert(in U4 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
public void Convert(in U8 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
#endregion

#region FromR4
public void Convert(in R4 src, ref R4 dst) => dst = src;
public void Convert(in R4 src, ref R8 dst) => dst = src;
Expand Down Expand Up @@ -1139,7 +1162,7 @@ private bool TryParseCore(ReadOnlySpan<char> span, out ulong dst)
dst = res;
return true;

LFail:
LFail:
dst = 0;
return false;
}
Expand Down Expand Up @@ -1246,7 +1269,7 @@ private bool TryParseNonNegative(ReadOnlySpan<char> span, out long result)
result = res;
return true;

LFail:
LFail:
result = 0;
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ internal static bool CanWrap(ISchemaBoundMapper mapper, DataViewType labelNameTy
var scoreType = outSchema[scoreIdx].Type;

// Check that the type is vector, and is of compatible size with the score output.
return labelNameType is VectorType vectorType && vectorType.Size == scoreType.GetVectorSize();
return labelNameType is VectorType vectorType && vectorType.Size == scoreType.GetVectorSize() && vectorType.ItemType == TextDataViewType.Instance;
}

internal static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
Expand Down
9 changes: 8 additions & 1 deletion src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,15 @@ private protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol)
private protected TTransformer TrainTransformer(IDataView trainSet,
IDataView validationSet = null, IPredictor initPredictor = null)
{
CheckInputSchema(SchemaShape.Create(trainSet.Schema));
Copy link
Contributor

@rogancarr rogancarr Mar 8, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CheckInputSchema [](start = 12, length = 16)

Nit: Ordering for train and valid in Make and Check are different. I'd prefer to have the same sequence. #Resolved

var trainRoleMapped = MakeRoles(trainSet);
var validRoleMapped = validationSet == null ? null : MakeRoles(validationSet);
RoleMappedData validRoleMapped = null;

if (validationSet != null)
{
CheckInputSchema(SchemaShape.Create(validationSet.Schema));
validRoleMapped = MakeRoles(validationSet);
}

var pred = TrainModelCore(new TrainContext(trainRoleMapped, validRoleMapped, null, initPredictor));
return MakeTransformer(pred, trainSet.Schema);
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.FastTree/FastTreeRanking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ private protected override void CheckLabelCompatible(SchemaShape.Column labelCol
if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single)
error();
}

private protected override float GetMaxLabel()
{
return GetLabelGains().Length - 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.Conversion;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Runtime;

using Microsoft.ML.Transforms;
namespace Microsoft.ML.Trainers
{
using TScalarTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
Expand All @@ -32,7 +31,7 @@ internal abstract class OptionsBase
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", SortOrder = 150, ShortName = "numcali")]
internal int MaxCalibrationExamples = 1000000000;

[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", SortOrder = 150, ShortName = "missNeg")]
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, or exclude their rows from dataview.", SortOrder = 150, ShortName = "missNeg")]
public bool ImputeMissingLabelsAsNegative;
}

Expand Down Expand Up @@ -98,20 +97,15 @@ private protected IDataView MapLabelsCore<T>(DataViewType type, InPredicate<T> e
Host.AssertValue(data);
Host.Assert(data.Schema.Label.HasValue);

var lab = data.Schema.Label.Value;
var label = data.Schema.Label.Value;
IDataView dataView = data.Data;
if (!Args.ImputeMissingLabelsAsNegative)
dataView = new NAFilter(Host, data.Data, false, label.Name);

InPredicate<T> isMissing;
if (!Args.ImputeMissingLabelsAsNegative && Conversions.Instance.TryGetIsNAPredicate(type, out isMissing))
{
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
lab.Name, lab.Name, type, NumberDataViewType.Single,
(in T src, ref float dst) =>
dst = equalsTarget(in src) ? 1 : (isMissing(in src) ? float.NaN : default(float)));
}
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
lab.Name, lab.Name, type, NumberDataViewType.Single,
(in T src, ref float dst) =>
dst = equalsTarget(in src) ? 1 : default(float));
label.Name, label.Name, type, BooleanDataViewType.Instance,
(in T src, ref bool dst) =>
dst = equalsTarget(in src) ? true : false);
}

private protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, int count);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,28 +145,18 @@ private ISingleFeaturePredictionTransformer<TScalarPredictor> TrainOne(IChannel

private IDataView MapLabels(RoleMappedData data, int cls)
{
var lab = data.Schema.Label.Value;
Host.Assert(!lab.IsHidden);
Host.Assert(lab.Type.GetKeyCount() > 0 || lab.Type == NumberDataViewType.Single || lab.Type == NumberDataViewType.Double);
var label = data.Schema.Label.Value;
Host.Assert(!label.IsHidden);
Host.Assert(label.Type.GetKeyCount() > 0 || label.Type == NumberDataViewType.Single || label.Type == NumberDataViewType.Double);

if (lab.Type.GetKeyCount() > 0)
if (label.Type.GetKeyCount() > 0)
{
// Key values are 1-based.
uint key = (uint)(cls + 1);
return MapLabelsCore(NumberDataViewType.UInt32, (in uint val) => key == val, data);
}
if (lab.Type == NumberDataViewType.Single)
{
float key = cls;
return MapLabelsCore(NumberDataViewType.Single, (in float val) => key == val, data);
}
if (lab.Type == NumberDataViewType.Double)
{
double key = cls;
return MapLabelsCore(NumberDataViewType.Double, (in double val) => key == val, data);
}

throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainer: {lab.Type.RawType}");
throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainer: {label.Type.RawType}");
}

/// <summary> Trains a <see cref="MulticlassPredictionTransformer{OneVersusAllModelParameters}"/> model.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,31 +142,19 @@ private ISingleFeaturePredictionTransformer<TDistPredictor> TrainOne(IChannel ch

private IDataView MapLabels(RoleMappedData data, int cls1, int cls2)
{
var lab = data.Schema.Label.Value;
Host.Assert(!lab.IsHidden);
Host.Assert(lab.Type.GetKeyCount() > 0 || lab.Type == NumberDataViewType.Single || lab.Type == NumberDataViewType.Double);
var label = data.Schema.Label.Value;
Host.Assert(!label.IsHidden);
Host.Assert(label.Type.GetKeyCount() > 0 || label.Type == NumberDataViewType.Single || label.Type == NumberDataViewType.Double);

if (lab.Type.GetKeyCount() > 0)
if (label.Type.GetKeyCount() > 0)
{
// Key values are 1-based.
uint key1 = (uint)(cls1 + 1);
uint key2 = (uint)(cls2 + 1);
return MapLabelsCore(NumberDataViewType.UInt32, (in uint val) => val == key1 || val == key2, data);
}
if (lab.Type == NumberDataViewType.Single)
{
float key1 = cls1;
float key2 = cls2;
return MapLabelsCore(NumberDataViewType.Single, (in float val) => val == key1 || val == key2, data);
}
if (lab.Type == NumberDataViewType.Double)
{
double key1 = cls1;
double key2 = cls2;
return MapLabelsCore(NumberDataViewType.Double, (in double val) => val == key1 || val == key2, data);
}

throw Host.ExceptNotSupp($"Label column type is not supported by nameof(PairwiseCouplingTrainer): {lab.Type.RawType}");
throw Host.ExceptNotSupp($"Label column type is not supported by nameof(PairwiseCouplingTrainer): {label.Type.RawType}");
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,20 +174,6 @@ private protected override void CheckLabels(RoleMappedData data)
data.CheckBinaryLabel();
}

private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
Contracts.Assert(labelCol.IsValid);

Action error =
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double, bool or KeyType", labelCol.GetTypeString());

if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
error();

if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single && labelCol.ItemType != NumberDataViewType.Double && !(labelCol.ItemType is BooleanDataViewType))
error();
}

private protected override TrainStateBase MakeState(IChannel ch, int numFeatures, LinearModelParameters predictor)
{
return new TrainState(ch, numFeatures, predictor, this);
Expand Down
14 changes: 0 additions & 14 deletions src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1512,20 +1512,6 @@ private protected SdcaBinaryTrainerBase(IHostEnvironment env, BinaryOptionsBase

private protected abstract SchemaShape.Column[] ComputeSdcaBinaryClassifierSchemaShape();

private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
Contracts.Assert(labelCol.IsValid);

Action error =
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double, bool or KeyType", labelCol.GetTypeString());

if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
error();

if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single && labelCol.ItemType != NumberDataViewType.Double && !(labelCol.ItemType is BooleanDataViewType))
error();
}

private protected LinearBinaryModelParameters CreateLinearBinaryModelParameters(VBuffer<float>[] weights, float[] bias)
{
Host.CheckParam(Utils.Size(weights) == 1, nameof(weights));
Expand Down
13 changes: 0 additions & 13 deletions src/Microsoft.ML.StandardTrainers/Standard/SdcaMultiClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,6 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape
};
}

private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
Contracts.Assert(labelCol.IsValid);

Action error =
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double or KeyType", labelCol.GetTypeString());

if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
error();
if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single && labelCol.ItemType != NumberDataViewType.Double)
error();
}

/// <inheritdoc/>
private protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, Random rand,
IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, float[] biasReg, float[] invariants, float lambdaNInv,
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Transforms/Text/NgramTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
if (!IsSchemaColumnValid(col))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, ExpectedColumnType, col.GetTypeString());
var metadata = new List<SchemaShape.Column>();
if (col.HasKeyValues())
if (col.NeedsSlotNames())
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(metadata));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
#@ col=X2VBuffer:R4:1-4
#@ col=X3Important:R4:5
#@ col=Label:R4:6
#@ col=Features:R4:7-12
#@ col=Features:R4:13-18
#@ col=FeatureContributions:R4:19-24
#@ col=FeatureContributions:R4:25-30
#@ col=FeatureContributions:R4:31-36
#@ col=FeatureContributions:R4:37-42
#@ col=Label:BL:7
Copy link
Member

@wschin wschin Mar 7, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to have two label columns? #ByDesign

Copy link
Contributor

@rogancarr rogancarr Mar 8, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ByDesign — response on a different thread.


In reply to: 263191537 [](ancestors = 263191537)

#@ col=Features:R4:8-13
#@ col=Features:R4:14-19
#@ col=FeatureContributions:R4:20-25
#@ col=FeatureContributions:R4:26-31
#@ col=FeatureContributions:R4:32-37
#@ col=FeatureContributions:R4:38-43
#@ }
950 757 692 720 297 7515 1 950 757 692 720 297 7515 0.956696868 0.760804 0.7872582 0.754716933 0.297893673 0.7578661 0 0.1527809 0 0 0 1 -0.6583012 0 -1 -0.517060339 0 0 12 2:-0.13028869 5:1 8:-0.370813 11:2.84608746
459 961 0 659 274 2147 0 459 961 0 659 274 2147 0.462235659 0.965829134 0 0.690775633 0.27482447 0.21651876 0 0.6788808 0 0 0 0.99999994 -0.6720779 0 0 -1 -0.870772958 0 12 3:-0.215823054 5:0.99999994 9:-0.175488681 11:0.8131137
672 275 0 65 195 9818 1 672 275 0 65 195 9818 0.6767372 0.2763819 0 0.06813417 0.195586756 0.990116954 0 0.04248268 0 0 0 1 -1 0 0 -0.100242466 -0.6298147 0 12 0:-0.04643902 5:1 6:-0.172673345 11:3.71828127
186 301 0 681 526 1456 0 186 301 0 681 526 1456 0.187311172 0.302512556 0 0.713836432 0.527582765 0.1468334 0 0.313550383 0 0 0 1 -0.162922 0 0 -0.6181894 -1 0 12 4:-0.5319963 5:1 10:-0.293352127 11:0.5514176
950 757 692 720 297 7515 1 1 950 757 692 720 297 7515 0.956696868 0.760804 0.7872582 0.754716933 0.297893673 0.7578661 0 0.1527809 0 0 0 1 -0.6583012 0 -1 -0.517060339 0 0 12 2:-0.13028869 5:1 8:-0.370813 11:2.84608746
459 961 0 659 274 2147 0 0 459 961 0 659 274 2147 0.462235659 0.965829134 0 0.690775633 0.27482447 0.21651876 0 0.6788808 0 0 0 0.99999994 -0.6720779 0 0 -1 -0.870772958 0 12 3:-0.215823054 5:0.99999994 9:-0.175488681 11:0.8131137
672 275 0 65 195 9818 1 1 672 275 0 65 195 9818 0.6767372 0.2763819 0 0.06813417 0.195586756 0.990116954 0 0.04248268 0 0 0 1 -1 0 0 -0.100242466 -0.6298147 0 12 0:-0.04643902 5:1 6:-0.172673345 11:3.71828127
186 301 0 681 526 1456 0 0 186 301 0 681 526 1456 0.187311172 0.302512556 0 0.713836432 0.527582765 0.1468334 0 0.313550383 0 0 0 1 -0.162922 0 0 -0.6181894 -1 0 12 4:-0.5319963 5:1 10:-0.293352127 11:0.5514176
Loading