-
Notifications
You must be signed in to change notification settings - Fork 1.9k
One type label policy in trainers #2804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
62c50a3
1631815
1ae2d9a
1620d74
f90311b
9f93371
395bc18
9dff7be
5707e17
538c7c1
909ba4e
44174f1
84dc7ce
fe64174
a08a0ac
be2afa2
55565d2
aeb7338
c173249
7d154e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -110,6 +110,7 @@ private Conversions() | |||||||
AddStd<I1, R4>(Convert); | ||||||||
AddStd<I1, R8>(Convert); | ||||||||
AddAux<I1, SB>(Convert); | ||||||||
AddStd<I1, BL>(Convert); | ||||||||
|
||||||||
AddStd<I2, I1>(Convert); | ||||||||
AddStd<I2, I2>(Convert); | ||||||||
|
@@ -118,6 +119,7 @@ private Conversions() | |||||||
AddStd<I2, R4>(Convert); | ||||||||
AddStd<I2, R8>(Convert); | ||||||||
AddAux<I2, SB>(Convert); | ||||||||
AddStd<I2, BL>(Convert); | ||||||||
|
||||||||
AddStd<I4, I1>(Convert); | ||||||||
AddStd<I4, I2>(Convert); | ||||||||
|
@@ -126,6 +128,7 @@ private Conversions() | |||||||
AddStd<I4, R4>(Convert); | ||||||||
AddStd<I4, R8>(Convert); | ||||||||
AddAux<I4, SB>(Convert); | ||||||||
AddStd<I4, BL>(Convert); | ||||||||
|
||||||||
AddStd<I8, I1>(Convert); | ||||||||
AddStd<I8, I2>(Convert); | ||||||||
|
@@ -134,6 +137,7 @@ private Conversions() | |||||||
AddStd<I8, R4>(Convert); | ||||||||
AddStd<I8, R8>(Convert); | ||||||||
AddAux<I8, SB>(Convert); | ||||||||
AddStd<I8, BL>(Convert); | ||||||||
|
||||||||
AddStd<U1, U1>(Convert); | ||||||||
AddStd<U1, U2>(Convert); | ||||||||
|
@@ -143,6 +147,7 @@ private Conversions() | |||||||
AddStd<U1, R4>(Convert); | ||||||||
AddStd<U1, R8>(Convert); | ||||||||
AddAux<U1, SB>(Convert); | ||||||||
AddStd<U1, BL>(Convert); | ||||||||
|
||||||||
AddStd<U2, U1>(Convert); | ||||||||
AddStd<U2, U2>(Convert); | ||||||||
|
@@ -152,6 +157,7 @@ private Conversions() | |||||||
AddStd<U2, R4>(Convert); | ||||||||
AddStd<U2, R8>(Convert); | ||||||||
AddAux<U2, SB>(Convert); | ||||||||
AddStd<U2, BL>(Convert); | ||||||||
|
||||||||
AddStd<U4, U1>(Convert); | ||||||||
AddStd<U4, U2>(Convert); | ||||||||
|
@@ -161,6 +167,7 @@ private Conversions() | |||||||
AddStd<U4, R4>(Convert); | ||||||||
AddStd<U4, R8>(Convert); | ||||||||
AddAux<U4, SB>(Convert); | ||||||||
AddStd<U4, BL>(Convert); | ||||||||
|
||||||||
AddStd<U8, U1>(Convert); | ||||||||
AddStd<U8, U2>(Convert); | ||||||||
|
@@ -170,6 +177,7 @@ private Conversions() | |||||||
AddStd<U8, R4>(Convert); | ||||||||
AddStd<U8, R8>(Convert); | ||||||||
AddAux<U8, SB>(Convert); | ||||||||
AddStd<U8, BL>(Convert); | ||||||||
|
||||||||
AddStd<UG, U1>(Convert); | ||||||||
AddStd<UG, U2>(Convert); | ||||||||
|
@@ -179,11 +187,13 @@ private Conversions() | |||||||
AddAux<UG, SB>(Convert); | ||||||||
|
||||||||
AddStd<R4, R4>(Convert); | ||||||||
AddStd<R4, BL>(Convert); | ||||||||
AddStd<R4, R8>(Convert); | ||||||||
AddAux<R4, SB>(Convert); | ||||||||
|
||||||||
AddStd<R8, R4>(Convert); | ||||||||
AddStd<R8, R8>(Convert); | ||||||||
AddStd<R8, BL>(Convert); | ||||||||
AddAux<R8, SB>(Convert); | ||||||||
|
||||||||
AddStd<TX, I1>(Convert); | ||||||||
|
@@ -899,6 +909,18 @@ public void Convert(in BL src, ref SB dst) | |||||||
public void Convert(in DT src, ref SB dst) { ClearDst(ref dst); dst.AppendFormat("{0:o}", src); } | ||||||||
public void Convert(in DZ src, ref SB dst) { ClearDst(ref dst); dst.AppendFormat("{0:o}", src); } | ||||||||
#endregion ToStringBuilder | ||||||||
#region ToBL | ||||||||
public void Convert(in R8 src, ref BL dst) => dst = src > 0 ? true : false; | ||||||||
public void Convert(in R4 src, ref BL dst) => dst = src > 0 ? true : false; | ||||||||
|
public void Convert(in R4 src, ref BL dst) => dst = src > 0 ? true : false; | |
public void Convert(in R4 src, ref BL dst) => dst = src > 0.5 ? true : false; | |
``` #Closed |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -145,7 +145,14 @@ private protected TTransformer TrainTransformer(IDataView trainSet, | |
IDataView validationSet = null, IPredictor initPredictor = null) | ||
{ | ||
var trainRoleMapped = MakeRoles(trainSet); | ||
var validRoleMapped = validationSet == null ? null : MakeRoles(validationSet); | ||
CheckInputSchema(SchemaShape.Create(trainSet.Schema)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Nit: Ordering for |
||
RoleMappedData validRoleMapped = null; | ||
|
||
if (validationSet != null) | ||
{ | ||
CheckInputSchema(SchemaShape.Create(validationSet.Schema)); | ||
validRoleMapped = MakeRoles(validationSet); | ||
} | ||
|
||
var pred = TrainModelCore(new TrainContext(trainRoleMapped, validRoleMapped, null, initPredictor)); | ||
return MakeTransformer(pred, trainSet.Schema); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,9 +8,8 @@ | |
using Microsoft.ML.Calibrators; | ||
using Microsoft.ML.CommandLine; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Data.Conversion; | ||
using Microsoft.ML.Internal.Internallearn; | ||
using Microsoft.ML.Trainers; | ||
using Microsoft.ML.Transforms; | ||
|
||
namespace Microsoft.ML.Trainers | ||
{ | ||
|
@@ -32,7 +31,7 @@ internal abstract class OptionsBase | |
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", SortOrder = 150, ShortName = "numcali")] | ||
internal int MaxCalibrationExamples = 1000000000; | ||
|
||
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", SortOrder = 150, ShortName = "missNeg")] | ||
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, or exclude their rows from dataview.", SortOrder = 150, ShortName = "missNeg")] | ||
public bool ImputeMissingLabelsAsNegative; | ||
} | ||
|
||
|
@@ -99,19 +98,16 @@ private protected IDataView MapLabelsCore<T>(DataViewType type, InPredicate<T> e | |
Host.Assert(data.Schema.Label.HasValue); | ||
|
||
var lab = data.Schema.Label.Value; | ||
|
||
|
||
InPredicate<T> isMissing; | ||
if (!Args.ImputeMissingLabelsAsNegative && Conversions.Instance.TryGetIsNAPredicate(type, out isMissing)) | ||
IDataView dataView = data.Data; | ||
if (!Args.ImputeMissingLabelsAsNegative) | ||
{ | ||
|
||
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data, | ||
lab.Name, lab.Name, type, NumberDataViewType.Single, | ||
(in T src, ref float dst) => | ||
dst = equalsTarget(in src) ? 1 : (isMissing(in src) ? float.NaN : default(float))); | ||
dataView = new NAFilter(Host, data.Data, false, lab.Name); | ||
} | ||
|
||
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data, | ||
lab.Name, lab.Name, type, NumberDataViewType.Single, | ||
(in T src, ref float dst) => | ||
dst = equalsTarget(in src) ? 1 : default(float)); | ||
lab.Name, lab.Name, type, BooleanDataViewType.Instance, | ||
(in T src, ref bool dst) => | ||
dst = equalsTarget(in src) ? true : false); | ||
} | ||
|
||
private protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, int count); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1523,20 +1523,6 @@ private protected SdcaBinaryTrainerBase(IHostEnvironment env, BinaryOptionsBase | |
|
||
private protected abstract SchemaShape.Column[] ComputeSdcaBinaryClassifierSchemaShape(); | ||
|
||
private protected override void CheckLabelCompatible(SchemaShape.Column labelCol) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Can label be text? If no, we still a check. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As you can see this is override method. By removing this one, I would force system to fallback to default behavior. Which is - to check it's a boolean column. In reply to: 263182006 [](ancestors = 263182006) |
||
{ | ||
Contracts.Assert(labelCol.IsValid); | ||
|
||
Action error = | ||
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double, bool or KeyType", labelCol.GetTypeString()); | ||
|
||
if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar) | ||
error(); | ||
|
||
if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single && labelCol.ItemType != NumberDataViewType.Double && !(labelCol.ItemType is BooleanDataViewType)) | ||
error(); | ||
} | ||
|
||
private protected LinearBinaryModelParameters CreateLinearBinaryModelParameters(VBuffer<float>[] weights, float[] bias) | ||
{ | ||
Host.CheckParam(Utils.Size(weights) == 1, nameof(weights)); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,14 +4,15 @@ | |
#@ col=X2VBuffer:R4:1-4 | ||
#@ col=X3Important:R4:5 | ||
#@ col=Label:R4:6 | ||
#@ col=Features:R4:7-12 | ||
#@ col=Features:R4:13-18 | ||
#@ col=FeatureContributions:R4:19-24 | ||
#@ col=FeatureContributions:R4:25-30 | ||
#@ col=FeatureContributions:R4:31-36 | ||
#@ col=FeatureContributions:R4:37-42 | ||
#@ col=Label:BL:7 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason to have two label columns? #ByDesign There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
#@ col=Features:R4:8-13 | ||
#@ col=Features:R4:14-19 | ||
#@ col=FeatureContributions:R4:20-25 | ||
#@ col=FeatureContributions:R4:26-31 | ||
#@ col=FeatureContributions:R4:32-37 | ||
#@ col=FeatureContributions:R4:38-43 | ||
#@ } | ||
950 757 692 720 297 7515 1 950 757 692 720 297 7515 0.956696868 0.760804 0.7872582 0.754716933 0.297893673 0.7578661 0 0.1527809 0 0 0 1 -0.6583012 0 -1 -0.517060339 0 0 12 2:-0.13028869 5:1 8:-0.370813 11:2.84608746 | ||
459 961 0 659 274 2147 0 459 961 0 659 274 2147 0.462235659 0.965829134 0 0.690775633 0.27482447 0.21651876 0 0.6788808 0 0 0 0.99999994 -0.6720779 0 0 -1 -0.870772958 0 12 3:-0.215823054 5:0.99999994 9:-0.175488681 11:0.8131137 | ||
672 275 0 65 195 9818 1 672 275 0 65 195 9818 0.6767372 0.2763819 0 0.06813417 0.195586756 0.990116954 0 0.04248268 0 0 0 1 -1 0 0 -0.100242466 -0.6298147 0 12 0:-0.04643902 5:1 6:-0.172673345 11:3.71828127 | ||
186 301 0 681 526 1456 0 186 301 0 681 526 1456 0.187311172 0.302512556 0 0.713836432 0.527582765 0.1468334 0 0.313550383 0 0 0 1 -0.162922 0 0 -0.6181894 -1 0 12 4:-0.5319963 5:1 10:-0.293352127 11:0.5514176 | ||
950 757 692 720 297 7515 1 1 950 757 692 720 297 7515 0.956696868 0.760804 0.7872582 0.754716933 0.297893673 0.7578661 0 0.1527809 0 0 0 1 -0.6583012 0 -1 -0.517060339 0 0 12 2:-0.13028869 5:1 8:-0.370813 11:2.84608746 | ||
459 961 0 659 274 2147 0 0 459 961 0 659 274 2147 0.462235659 0.965829134 0 0.690775633 0.27482447 0.21651876 0 0.6788808 0 0 0 0.99999994 -0.6720779 0 0 -1 -0.870772958 0 12 3:-0.215823054 5:0.99999994 9:-0.175488681 11:0.8131137 | ||
672 275 0 65 195 9818 1 1 672 275 0 65 195 9818 0.6767372 0.2763819 0 0.06813417 0.195586756 0.990116954 0 0.04248268 0 0 0 1 -1 0 0 -0.100242466 -0.6298147 0 12 0:-0.04643902 5:1 6:-0.172673345 11:3.71828127 | ||
186 301 0 681 526 1456 0 0 186 301 0 681 526 1456 0.187311172 0.302512556 0 0.713836432 0.527582765 0.1468334 0 0.313550383 0 0 0 1 -0.162922 0 0 -0.6181894 -1 0 12 4:-0.5319963 5:1 10:-0.293352127 11:0.5514176 |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rounding to the nearest number looks more reasonable. #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about that.
In reply to: 263190893 [](ancestors = 263190893)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think 0.00000001 looks more closer to 0?
In reply to: 263196830 [](ancestors = 263196830,263190893)