Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Core/Data/AnnotationUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Column
schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames);
}

public static bool HasKeyValues(this SchemaShape.Column col)
public static bool NeedsSlotNames(this SchemaShape.Column col)
{
return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector
Expand Down Expand Up @@ -441,7 +441,7 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
{
var cols = new List<SchemaShape.Column>();
if (labelColumn != null && labelColumn.Value.IsKey && HasKeyValues(labelColumn.Value))
if (labelColumn != null && labelColumn.Value.IsKey && NeedsSlotNames(labelColumn.Value))
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
cols.AddRange(GetTrainerOutputAnnotation());
return cols;
Expand Down
26 changes: 24 additions & 2 deletions src/Microsoft.ML.Data/Data/Conversion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ private Conversions()
AddStd<I1, R4>(Convert);
AddStd<I1, R8>(Convert);
AddAux<I1, SB>(Convert);
AddStd<I1, BL>(Convert);

AddStd<I2, I1>(Convert);
AddStd<I2, I2>(Convert);
Expand All @@ -118,6 +119,7 @@ private Conversions()
AddStd<I2, R4>(Convert);
AddStd<I2, R8>(Convert);
AddAux<I2, SB>(Convert);
AddStd<I2, BL>(Convert);

AddStd<I4, I1>(Convert);
AddStd<I4, I2>(Convert);
Expand All @@ -126,6 +128,7 @@ private Conversions()
AddStd<I4, R4>(Convert);
AddStd<I4, R8>(Convert);
AddAux<I4, SB>(Convert);
AddStd<I4, BL>(Convert);

AddStd<I8, I1>(Convert);
AddStd<I8, I2>(Convert);
Expand All @@ -134,6 +137,7 @@ private Conversions()
AddStd<I8, R4>(Convert);
AddStd<I8, R8>(Convert);
AddAux<I8, SB>(Convert);
AddStd<I8, BL>(Convert);

AddStd<U1, U1>(Convert);
AddStd<U1, U2>(Convert);
Expand All @@ -143,6 +147,7 @@ private Conversions()
AddStd<U1, R4>(Convert);
AddStd<U1, R8>(Convert);
AddAux<U1, SB>(Convert);
AddStd<U1, BL>(Convert);

AddStd<U2, U1>(Convert);
AddStd<U2, U2>(Convert);
Expand All @@ -152,6 +157,7 @@ private Conversions()
AddStd<U2, R4>(Convert);
AddStd<U2, R8>(Convert);
AddAux<U2, SB>(Convert);
AddStd<U2, BL>(Convert);

AddStd<U4, U1>(Convert);
AddStd<U4, U2>(Convert);
Expand All @@ -161,6 +167,7 @@ private Conversions()
AddStd<U4, R4>(Convert);
AddStd<U4, R8>(Convert);
AddAux<U4, SB>(Convert);
AddStd<U4, BL>(Convert);

AddStd<U8, U1>(Convert);
AddStd<U8, U2>(Convert);
Expand All @@ -170,6 +177,7 @@ private Conversions()
AddStd<U8, R4>(Convert);
AddStd<U8, R8>(Convert);
AddAux<U8, SB>(Convert);
AddStd<U8, BL>(Convert);

AddStd<UG, U1>(Convert);
AddStd<UG, U2>(Convert);
Expand All @@ -179,11 +187,13 @@ private Conversions()
AddAux<UG, SB>(Convert);

AddStd<R4, R4>(Convert);
AddStd<R4, BL>(Convert);
AddStd<R4, R8>(Convert);
AddAux<R4, SB>(Convert);

AddStd<R8, R4>(Convert);
AddStd<R8, R8>(Convert);
AddStd<R8, BL>(Convert);
AddAux<R8, SB>(Convert);

AddStd<TX, I1>(Convert);
Expand Down Expand Up @@ -899,6 +909,18 @@ public void Convert(in BL src, ref SB dst)
public void Convert(in DT src, ref SB dst) { ClearDst(ref dst); dst.AppendFormat("{0:o}", src); }
public void Convert(in DZ src, ref SB dst) { ClearDst(ref dst); dst.AppendFormat("{0:o}", src); }
#endregion ToStringBuilder
#region ToBL
/// <summary>
/// Converts an <c>R8</c> value to a boolean: <paramref name="dst"/> becomes true only when
/// <paramref name="src"/> is strictly greater than zero.
/// </summary>
// NOTE(review): R8 is presumably this file's alias for double; if so, NaN maps to false
// because every ordered comparison with NaN is false - confirm that is the intended behavior.
public void Convert(in R8 src, ref BL dst) => dst = src > 0;
Copy link
Member

@wschin wschin Mar 7, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public void Convert(in R8 src, ref BL dst) => dst = src > 0 ? true : false;
public void Convert(in R8 src, ref BL dst) => dst = src > 0.5 ? true : false;

Rounding to the nearest number looks more reasonable. #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about that.


In reply to: 263190893 [](ancestors = 263190893)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 0.00000001 looks closer to 0?


In reply to: 263196830 [](ancestors = 263196830,263190893)

/// <summary>
/// Converts an <c>R4</c> value to a boolean: <paramref name="dst"/> becomes true only when
/// <paramref name="src"/> is strictly greater than zero.
/// </summary>
// NOTE(review): R4 is presumably this file's alias for float; if so, NaN maps to false
// because every ordered comparison with NaN is false - confirm that is the intended behavior.
public void Convert(in R4 src, ref BL dst) => dst = src > 0;
Copy link
Member

@wschin wschin Mar 7, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public void Convert(in R4 src, ref BL dst) => dst = src > 0 ? true : false;
public void Convert(in R4 src, ref BL dst) => dst = src > 0.5 ? true : false;
``` #Closed

// Integer-to-boolean conversions: the destination is true only for strictly positive sources.
// The comparison already yields a bool, so no conditional operator is needed.
public void Convert(in I1 src, ref BL dst) => dst = src > 0;
public void Convert(in I2 src, ref BL dst) => dst = src > 0;
public void Convert(in I4 src, ref BL dst) => dst = src > 0;
public void Convert(in I8 src, ref BL dst) => dst = src > 0;
public void Convert(in U1 src, ref BL dst) => dst = src > 0;
public void Convert(in U2 src, ref BL dst) => dst = src > 0;
public void Convert(in U4 src, ref BL dst) => dst = src > 0;
public void Convert(in U8 src, ref BL dst) => dst = src > 0;
#endregion

#region FromR4
public void Convert(in R4 src, ref R4 dst) => dst = src;
Expand Down Expand Up @@ -1138,7 +1160,7 @@ private bool TryParseCore(ReadOnlySpan<char> span, out ulong dst)
dst = res;
return true;

LFail:
LFail:
dst = 0;
return false;
}
Expand Down Expand Up @@ -1245,7 +1267,7 @@ private bool TryParseNonNegative(ReadOnlySpan<char> span, out long result)
result = res;
return true;

LFail:
LFail:
result = 0;
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ internal static bool CanWrap(ISchemaBoundMapper mapper, DataViewType labelNameTy
var scoreType = outSchema[scoreIdx].Type;

// Check that the type is vector, and is of compatible size with the score output.
return labelNameType is VectorType vectorType && vectorType.Size == scoreType.GetVectorSize();
return labelNameType is VectorType vectorType && vectorType.Size == scoreType.GetVectorSize() && vectorType.ItemType == TextDataViewType.Instance;
}

internal static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
Expand Down
9 changes: 8 additions & 1 deletion src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,14 @@ private protected TTransformer TrainTransformer(IDataView trainSet,
IDataView validationSet = null, IPredictor initPredictor = null)
{
var trainRoleMapped = MakeRoles(trainSet);
var validRoleMapped = validationSet == null ? null : MakeRoles(validationSet);
CheckInputSchema(SchemaShape.Create(trainSet.Schema));
Copy link
Contributor

@rogancarr rogancarr Mar 8, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CheckInputSchema [](start = 12, length = 16)

Nit: The ordering for train and valid in Make and Check is different. I'd prefer to have the same sequence. #Resolved

RoleMappedData validRoleMapped = null;

if (validationSet != null)
{
CheckInputSchema(SchemaShape.Create(validationSet.Schema));
validRoleMapped = MakeRoles(validationSet);
}

var pred = TrainModelCore(new TrainContext(trainRoleMapped, validRoleMapped, null, initPredictor));
return MakeTransformer(pred, trainSet.Schema);
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.FastTree/FastTreeRanking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ private protected override void CheckLabelCompatible(SchemaShape.Column labelCol
if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single)
error();
}

private protected override float GetMaxLabel()
{
return GetLabelGains().Length - 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.Conversion;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;

namespace Microsoft.ML.Trainers
{
Expand All @@ -32,7 +31,7 @@ internal abstract class OptionsBase
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", SortOrder = 150, ShortName = "numcali")]
internal int MaxCalibrationExamples = 1000000000;

[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", SortOrder = 150, ShortName = "missNeg")]
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, or exclude their rows from dataview.", SortOrder = 150, ShortName = "missNeg")]
public bool ImputeMissingLabelsAsNegative;
}

Expand Down Expand Up @@ -99,19 +98,16 @@ private protected IDataView MapLabelsCore<T>(DataViewType type, InPredicate<T> e
Host.Assert(data.Schema.Label.HasValue);

var lab = data.Schema.Label.Value;
Copy link
Contributor

@rogancarr rogancarr Mar 8, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lab [](start = 16, length = 3)

Nit: This hurts my eyes. #Resolved


InPredicate<T> isMissing;
if (!Args.ImputeMissingLabelsAsNegative && Conversions.Instance.TryGetIsNAPredicate(type, out isMissing))
IDataView dataView = data.Data;
if (!Args.ImputeMissingLabelsAsNegative)
{
Copy link
Member

@sfilipi sfilipi Mar 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{ [](start = 12, length = 1)

remove #Resolved

return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
lab.Name, lab.Name, type, NumberDataViewType.Single,
(in T src, ref float dst) =>
dst = equalsTarget(in src) ? 1 : (isMissing(in src) ? float.NaN : default(float)));
dataView = new NAFilter(Host, data.Data, false, lab.Name);
}

return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
lab.Name, lab.Name, type, NumberDataViewType.Single,
(in T src, ref float dst) =>
dst = equalsTarget(in src) ? 1 : default(float));
lab.Name, lab.Name, type, BooleanDataViewType.Instance,
(in T src, ref bool dst) =>
dst = equalsTarget(in src) ? true : false);
}

private protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, int count);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,6 @@ private IDataView MapLabels(RoleMappedData data, int cls)
uint key = (uint)(cls + 1);
return MapLabelsCore(NumberDataViewType.UInt32, (in uint val) => key == val, data);
}
if (lab.Type == NumberDataViewType.Single)
{
float key = cls;
return MapLabelsCore(NumberDataViewType.Single, (in float val) => key == val, data);
}
if (lab.Type == NumberDataViewType.Double)
{
double key = cls;
return MapLabelsCore(NumberDataViewType.Double, (in double val) => key == val, data);
}

throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainer: {lab.Type.RawType}");
}
Expand Down Expand Up @@ -465,7 +455,7 @@ protected bool IsValid(IValueMapper mapper, ref VectorType inputType)
return false;
if (mapper.OutputType != NumberDataViewType.Single)
return false;
if (!(mapper.InputType is VectorType mapperVectorType)|| mapperVectorType.ItemType != NumberDataViewType.Single)
if (!(mapper.InputType is VectorType mapperVectorType) || mapperVectorType.ItemType != NumberDataViewType.Single)
return false;
if (inputType == null)
inputType = mapperVectorType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,6 @@ private IDataView MapLabels(RoleMappedData data, int cls1, int cls2)
uint key2 = (uint)(cls2 + 1);
return MapLabelsCore(NumberDataViewType.UInt32, (in uint val) => val == key1 || val == key2, data);
}
if (lab.Type == NumberDataViewType.Single)
{
float key1 = cls1;
float key2 = cls2;
return MapLabelsCore(NumberDataViewType.Single, (in float val) => val == key1 || val == key2, data);
}
if (lab.Type == NumberDataViewType.Double)
{
double key1 = cls1;
double key2 = cls2;
return MapLabelsCore(NumberDataViewType.Double, (in double val) => val == key1 || val == key2, data);
}

throw Host.ExceptNotSupp($"Label column type is not supported by nameof(PairwiseCouplingTrainer): {lab.Type.RawType}");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,20 +180,6 @@ private protected override void CheckLabels(RoleMappedData data)
data.CheckBinaryLabel();
}

/// <summary>
/// Verifies that the label column is a scalar whose item type is a key, float, double or bool,
/// and throws a schema-mismatch exception otherwise.
/// </summary>
/// <param name="labelCol">The label column shape to validate; asserted to be valid.</param>
private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
Contracts.Assert(labelCol.IsValid);

// Short-circuiting keeps the item-type checks from running for non-scalar columns.
bool isCompatible = labelCol.Kind == SchemaShape.Column.VectorKind.Scalar
&& (labelCol.IsKey
|| labelCol.ItemType == NumberDataViewType.Single
|| labelCol.ItemType == NumberDataViewType.Double
|| labelCol.ItemType is BooleanDataViewType);

if (!isCompatible)
throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double, bool or KeyType", labelCol.GetTypeString());
}

private protected override TrainStateBase MakeState(IChannel ch, int numFeatures, LinearModelParameters predictor)
{
return new TrainState(ch, numFeatures, predictor, this);
Expand Down
14 changes: 0 additions & 14 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1523,20 +1523,6 @@ private protected SdcaBinaryTrainerBase(IHostEnvironment env, BinaryOptionsBase

private protected abstract SchemaShape.Column[] ComputeSdcaBinaryClassifierSchemaShape();

private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
Copy link
Member

@wschin wschin Mar 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CheckLabelCompatible [](start = 40, length = 20)

Can the label be text? If not, we still need a check. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you can see, this is an override method. By removing it, I force the system to fall back to the default behavior, which is to check that it's a boolean column.


In reply to: 263182006 [](ancestors = 263182006)

{
Contracts.Assert(labelCol.IsValid);

Action error =
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double, bool or KeyType", labelCol.GetTypeString());

if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
error();

if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single && labelCol.ItemType != NumberDataViewType.Double && !(labelCol.ItemType is BooleanDataViewType))
error();
}

private protected LinearBinaryModelParameters CreateLinearBinaryModelParameters(VBuffer<float>[] weights, float[] bias)
{
Host.CheckParam(Utils.Size(weights) == 1, nameof(weights));
Expand Down
13 changes: 0 additions & 13 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,6 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape
};
}

/// <summary>
/// Verifies that the label column is a scalar whose item type is a key, float or double,
/// and throws a schema-mismatch exception otherwise.
/// </summary>
/// <param name="labelCol">The label column shape to validate; asserted to be valid.</param>
private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
Contracts.Assert(labelCol.IsValid);

// Short-circuiting keeps the item-type checks from running for non-scalar columns.
bool isCompatible = labelCol.Kind == SchemaShape.Column.VectorKind.Scalar
&& (labelCol.IsKey
|| labelCol.ItemType == NumberDataViewType.Single
|| labelCol.ItemType == NumberDataViewType.Double);

if (!isCompatible)
throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double or KeyType", labelCol.GetTypeString());
}

/// <inheritdoc/>
private protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, Random rand,
IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, float[] biasReg, float[] invariants, float lambdaNInv,
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Transforms/Text/NgramTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
if (!IsSchemaColumnValid(col))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, ExpectedColumnType, col.GetTypeString());
var metadata = new List<SchemaShape.Column>();
if (col.HasKeyValues())
if (col.NeedsSlotNames())
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(metadata));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
#@ col=X2VBuffer:R4:1-4
#@ col=X3Important:R4:5
#@ col=Label:R4:6
#@ col=Features:R4:7-12
#@ col=Features:R4:13-18
#@ col=FeatureContributions:R4:19-24
#@ col=FeatureContributions:R4:25-30
#@ col=FeatureContributions:R4:31-36
#@ col=FeatureContributions:R4:37-42
#@ col=Label:BL:7
Copy link
Member

@wschin wschin Mar 7, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to have two label columns? #ByDesign

Copy link
Contributor

@rogancarr rogancarr Mar 8, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ByDesign — response on a different thread.


In reply to: 263191537 [](ancestors = 263191537)

#@ col=Features:R4:8-13
#@ col=Features:R4:14-19
#@ col=FeatureContributions:R4:20-25
#@ col=FeatureContributions:R4:26-31
#@ col=FeatureContributions:R4:32-37
#@ col=FeatureContributions:R4:38-43
#@ }
950 757 692 720 297 7515 1 950 757 692 720 297 7515 0.956696868 0.760804 0.7872582 0.754716933 0.297893673 0.7578661 0 0.1527809 0 0 0 1 -0.6583012 0 -1 -0.517060339 0 0 12 2:-0.13028869 5:1 8:-0.370813 11:2.84608746
459 961 0 659 274 2147 0 459 961 0 659 274 2147 0.462235659 0.965829134 0 0.690775633 0.27482447 0.21651876 0 0.6788808 0 0 0 0.99999994 -0.6720779 0 0 -1 -0.870772958 0 12 3:-0.215823054 5:0.99999994 9:-0.175488681 11:0.8131137
672 275 0 65 195 9818 1 672 275 0 65 195 9818 0.6767372 0.2763819 0 0.06813417 0.195586756 0.990116954 0 0.04248268 0 0 0 1 -1 0 0 -0.100242466 -0.6298147 0 12 0:-0.04643902 5:1 6:-0.172673345 11:3.71828127
186 301 0 681 526 1456 0 186 301 0 681 526 1456 0.187311172 0.302512556 0 0.713836432 0.527582765 0.1468334 0 0.313550383 0 0 0 1 -0.162922 0 0 -0.6181894 -1 0 12 4:-0.5319963 5:1 10:-0.293352127 11:0.5514176
950 757 692 720 297 7515 1 1 950 757 692 720 297 7515 0.956696868 0.760804 0.7872582 0.754716933 0.297893673 0.7578661 0 0.1527809 0 0 0 1 -0.6583012 0 -1 -0.517060339 0 0 12 2:-0.13028869 5:1 8:-0.370813 11:2.84608746
459 961 0 659 274 2147 0 0 459 961 0 659 274 2147 0.462235659 0.965829134 0 0.690775633 0.27482447 0.21651876 0 0.6788808 0 0 0 0.99999994 -0.6720779 0 0 -1 -0.870772958 0 12 3:-0.215823054 5:0.99999994 9:-0.175488681 11:0.8131137
672 275 0 65 195 9818 1 1 672 275 0 65 195 9818 0.6767372 0.2763819 0 0.06813417 0.195586756 0.990116954 0 0.04248268 0 0 0 1 -1 0 0 -0.100242466 -0.6298147 0 12 0:-0.04643902 5:1 6:-0.172673345 11:3.71828127
186 301 0 681 526 1456 0 0 186 301 0 681 526 1456 0.187311172 0.302512556 0 0.713836432 0.527582765 0.1468334 0 0.313550383 0 0 0 1 -0.162922 0 0 -0.6181894 -1 0 12 4:-0.5319963 5:1 10:-0.293352127 11:0.5514176
Loading