From e925604040071bc07821e1f6754ef4fcc5c77b32 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 17 May 2018 11:36:26 -0700 Subject: [PATCH 1/3] handle boolean type in construction utils. --- .../DataViewConstructionUtils.cs | 18 +++++++++++++++ .../LearningPipelineTests.cs | 23 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 03e7408e71..033bfaeed3 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -152,6 +152,12 @@ private Delegate CreateGetter(int index) Ch.Assert(colType.IsText); return CreateStringToTextGetter(index); } + else if (outputType == typeof(bool)) + { + Ch.Assert(colType.IsBool); + return CreateBooleanToDvBoolGetter(index); + } + // T -> T Ch.Assert(colType.RawType == outputType); del = CreateDirectGetter; @@ -197,6 +203,18 @@ private Delegate CreateStringToTextGetter(int index) }); } + private Delegate CreateBooleanToDvBoolGetter(int index) + { + var peek = DataView._peeks[index] as Peek; + Ch.AssertValue(peek); + bool buf = false; + return (ValueGetter)((ref DvBool dst) => + { + peek(GetCurrentRowObject(), Position, ref buf); + dst = buf ? 
DvBool.True : DvBool.False; + }); + } + private Delegate CreateArrayToVBufferGetter(int index) { var peek = DataView._peeks[index] as Peek; diff --git a/test/Microsoft.ML.Tests/LearningPipelineTests.cs b/test/Microsoft.ML.Tests/LearningPipelineTests.cs index ec7a6b6e92..a9ff1e86a4 100644 --- a/test/Microsoft.ML.Tests/LearningPipelineTests.cs +++ b/test/Microsoft.ML.Tests/LearningPipelineTests.cs @@ -110,5 +110,28 @@ public void NoTransformPipeline() pipeline.Add(new FastForestBinaryClassifier()); var model = pipeline.Train(); } + + public class BooleanLabelData + { + [ColumnName("Features")] + [VectorType(2)] + public float[] Features; + + [ColumnName("Label")] + public bool Label; + } + + [Fact] + public void BooleanLabelPipeline() + { + var data = new BooleanLabelData[1]; + data[0] = new BooleanLabelData(); + data[0].Features = new float[] { 0.0f, 1.0f }; + data[0].Label = false; + var pipeline = new LearningPipeline(); + pipeline.Add(CollectionDataSource.Create(data)); + pipeline.Add(new FastForestBinaryClassifier()); + var model = pipeline.Train(); + } } } From 18d4bb440329980b17246dc47640ab1dfba64455 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Fri, 18 May 2018 10:22:44 -0700 Subject: [PATCH 2/3] handle nullable boolean type (bool?) in construction utils. --- src/Microsoft.ML.Api/ApiUtils.cs | 2 +- .../DataViewConstructionUtils.cs | 23 ++++++++++++++-- src/Microsoft.ML.Core/Data/DataKind.cs | 2 +- .../LearningPipelineTests.cs | 26 +++++++++++++++++++ 4 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 80967007ec..5b5936b5a8 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -21,7 +21,7 @@ private static OpCode GetAssignmentOpCode(Type t) // REVIEW: This should be a Dictionary based solution. // DvTexts, strings, arrays, and VBuffers. 
if (t == typeof(DvInt8) || t == typeof(DvInt4) || t == typeof(DvInt2) || t == typeof(DvInt1) || - t == typeof(DvBool) || t == typeof(DvText) || t == typeof(string) || t.IsArray || + t == typeof(DvBool) || t==typeof(bool?) || t == typeof(DvText) || t == typeof(string) || t.IsArray || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || t == typeof(DvDateTime) || t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128)) { diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 033bfaeed3..6115bce21c 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -154,10 +154,17 @@ private Delegate CreateGetter(int index) } else if (outputType == typeof(bool)) { + // Bool -> DvBool Ch.Assert(colType.IsBool); return CreateBooleanToDvBoolGetter(index); } - + else if (outputType == typeof(bool?)) + { + // Bool -> DvBool + Ch.Assert(colType.IsBool); + return CreateNullableBooleanToDvBoolGetter(index); + } + // T -> T Ch.Assert(colType.RawType == outputType); del = CreateDirectGetter; @@ -211,7 +218,19 @@ private Delegate CreateBooleanToDvBoolGetter(int index) return (ValueGetter)((ref DvBool dst) => { peek(GetCurrentRowObject(), Position, ref buf); - dst = buf ? DvBool.True : DvBool.False; + dst = (DvBool)buf; + }); + } + + private Delegate CreateNullableBooleanToDvBoolGetter(int index) + { + var peek = DataView._peeks[index] as Peek; + Ch.AssertValue(peek); + bool? buf = null; + return (ValueGetter)((ref DvBool dst) => + { + peek(GetCurrentRowObject(), Position, ref buf); + dst = buf.HasValue ? 
(DvBool)buf.Value : DvBool.NA; }); } diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index 1c043b070e..5ed5ded1c1 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -207,7 +207,7 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) kind = DataKind.R8; else if (type == typeof(DvText)) kind = DataKind.TX; - else if (type == typeof(DvBool) || type == typeof(bool)) + else if (type == typeof(DvBool) || type == typeof(bool) ||type ==typeof(bool?)) kind = DataKind.BL; else if (type == typeof(DvTimeSpan)) kind = DataKind.TS; diff --git a/test/Microsoft.ML.Tests/LearningPipelineTests.cs b/test/Microsoft.ML.Tests/LearningPipelineTests.cs index a9ff1e86a4..4519fc5285 100644 --- a/test/Microsoft.ML.Tests/LearningPipelineTests.cs +++ b/test/Microsoft.ML.Tests/LearningPipelineTests.cs @@ -133,5 +133,31 @@ public void BooleanLabelPipeline() pipeline.Add(new FastForestBinaryClassifier()); var model = pipeline.Train(); } + + public class NullableBooleanLabelData + { + [ColumnName("Features")] + [VectorType(2)] + public float[] Features; + + [ColumnName("Label")] + public bool? Label; + } + + [Fact] + public void NullableBooleanLabelPipeline() + { + var data = new NullableBooleanLabelData[2]; + data[0] = new NullableBooleanLabelData(); + data[0].Features = new float[] { 0.0f, 1.0f }; + data[0].Label = null; + data[1] = new NullableBooleanLabelData(); + data[1].Features = new float[] { 1.0f, 0.0f }; + data[1].Label = false; + var pipeline = new LearningPipeline(); + pipeline.Add(CollectionDataSource.Create(data)); + pipeline.Add(new FastForestBinaryClassifier()); + var model = pipeline.Train(); + } } } From fc741dd3b07a4f929bb703d9cd8a0d9c37dc117c Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Fri, 18 May 2018 10:27:16 -0700 Subject: [PATCH 3/3] fix comment on the bool? -> DvBool getter branch. 
--- src/Microsoft.ML.Api/DataViewConstructionUtils.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 6115bce21c..840b866fe4 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -160,7 +160,7 @@ private Delegate CreateGetter(int index) } else if (outputType == typeof(bool?)) { - // Bool -> DvBool + // Bool? -> DvBool Ch.Assert(colType.IsBool); return CreateNullableBooleanToDvBoolGetter(index); }