From 7f04b87c2044c935b21a5135729e418cb5523918 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 24 Sep 2018 10:59:49 -0700 Subject: [PATCH 1/9] First commit --- src/Microsoft.ML.Data/Training/TrainContext.cs | 3 ++- .../Training/TrainingStaticExtensions.cs | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs index b15bd1b317..bcb27961e4 100644 --- a/src/Microsoft.ML.Data/Training/TrainContext.cs +++ b/src/Microsoft.ML.Data/Training/TrainContext.cs @@ -2,9 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; -namespace Microsoft.ML.Runtime.Training +namespace Microsoft.ML { /// /// A training context is an object instantiable by a user to do various tasks relating to a particular diff --git a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs new file mode 100644 index 0000000000..e17e5ab782 --- /dev/null +++ b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.ML +{ + /// + /// Defines static extension methods that allow operations like train-test split, cross-validate, + /// sampling etc. with the . + /// + public static class TrainingStaticExtensions + { + } +} From 260aff64da5783955bdb33e5bddeb16858d957f5 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 24 Sep 2018 11:55:13 -0700 Subject: [PATCH 2/9] Temp 2 --- .../Training/TrainContext.cs | 67 +++++++++++++++++++ .../Training/TrainingStaticExtensions.cs | 30 +++++++++ .../StaticPipeTests.cs | 2 + 3 files changed, 99 insertions(+) diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs index bcb27961e4..d09d127bc9 100644 --- a/src/Microsoft.ML.Data/Training/TrainContext.cs +++ b/src/Microsoft.ML.Data/Training/TrainContext.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Transforms; namespace Microsoft.ML { @@ -17,6 +18,72 @@ public abstract class TrainContextBase protected readonly IHost Host; internal IHostEnvironment Environment => Host; + /// + /// Split the dataset into the train set and test set according to the given fraction. + /// Respects the if provided. + /// + /// The dataset to split. + /// The fraction of data to go into the test set. + /// Optional stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// A pair of datasets, for the train and test set. + public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null) + { + Host.CheckValue(data, nameof(data)); + Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1"); + Host.CheckValueOrNull(stratificationColumn); + + // We need to handle two cases: if the stratification column is provided, we use hashJoin to + // build a single hash of it. If it is not, we generate a random number. + + if (stratificationColumn != null) + { + stratificationColumn = data.Schema.GetTempColumnName("StratificationColumn"); + data = new GenerateNumberTransform(Host, data, stratificationColumn); + } + else + { + if (!data.Schema.TryGetColumnIndex(stratificationColumn, out int stratCol)) + throw Host.ExceptSchemaMismatch(nameof(stratificationColumn), "stratification", stratificationColumn); + + var type = data.Schema.GetColumnType(stratCol); + if (!RangeFilter.IsValidRangeFilterColumnType(Host, type)) + { + // Hash the stratification column. + // REVIEW: this could currently crash, since Hash only accepts a limited set + // of column types. It used to be HashJoin, but we should probably extend Hash + // instead of having two hash transformations. + var origStratCol = stratificationColumn; + int tmp; + int inc = 0; + + // Generate a new column with the hashed stratification column. + while (data.Schema.TryGetColumnIndex(stratificationColumn, out tmp)) + stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc); + data = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(data).Transform(data); + } + } + + var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments() + { + Column = stratificationColumn, + Min = 0, + Max = testFraction, + Complement = true + }, data); + var testFilter = new RangeFilter(Host, new RangeFilter.Arguments() + { + Column = stratificationColumn, + Min = 0, + Max = testFraction, + Complement = false + }, data); + + return (trainFilter, testFilter); + } + protected TrainContextBase(IHostEnvironment env, string registrationName) { Contracts.CheckValue(env, nameof(env)); diff --git a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs index e17e5ab782..db07eaf609 100644 --- a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs +++ b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs @@ -2,6 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Data.StaticPipe; +using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; using System; using System.Collections.Generic; using System.Text; @@ -14,5 +17,32 @@ namespace Microsoft.ML /// public static class TrainingStaticExtensions { + /// + /// Split the dataset into the train set and test set according to the given fraction. + /// Respects the if provided. + /// + /// The tuple describing the data schema. + /// The training context. + /// The dataset to split. + /// The fraction of data to go into the test set. + /// Optional selector for the stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// A pair of datasets, for the train and test set. + public static (DataView trainSet, DataView testSet) TrainTestSplit(this TrainContextBase context, + DataView data, double testFraction = 0.1, Func stratificationColumn = null) + { + var env = StaticPipeUtils.GetEnvironment(data); + Contracts.AssertValue(env); + env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1"); + env.CheckValueOrNull(stratificationColumn); + + var indexer = StaticPipeUtils.GetIndexer(data); + string stratName = indexer?.Get(stratificationColumn(indexer.Indices)); + + var (trainData, testData) = context.TrainTestSplit(data.AsDynamic, testFraction, stratName); + return (new DataView(env, trainData, data.Shape), new DataView(env, testData, data.Shape)); + } } } diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index dad5148ada..6987e10c60 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -558,5 +558,7 @@ public void LpGcNormAndWhitening() type = schema.GetColumnType(pcswhitenedCol); Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); } + // + //[Fact] } } From ab76c3d70eba3eddfee6d067d00babc6c8716ea6 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 24 Sep 2018 13:00:29 -0700 Subject: [PATCH 3/9] Temp 3 --- src/Microsoft.ML.Data/Training/TrainContext.cs | 2 +- .../Training/TrainingStaticExtensions.cs | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs index d09d127bc9..58d1d6ebe9 100644 --- a/src/Microsoft.ML.Data/Training/TrainContext.cs +++ b/src/Microsoft.ML.Data/Training/TrainContext.cs @@ -32,7 +32,7 @@ public abstract class TrainContextBase public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null) { Host.CheckValue(data, nameof(data)); - Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1"); + Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); Host.CheckValueOrNull(stratificationColumn); // We need to handle two cases: if the stratification column is provided, we use hashJoin to diff --git a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs index db07eaf609..619b8f96a2 100644 --- a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs +++ b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs @@ -35,11 +35,18 @@ public static (DataView trainSet, DataView testSet) TrainTestSplit(this { var env = StaticPipeUtils.GetEnvironment(data); Contracts.AssertValue(env); - env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1"); + env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); env.CheckValueOrNull(stratificationColumn); - var indexer = StaticPipeUtils.GetIndexer(data); - string stratName = indexer?.Get(stratificationColumn(indexer.Indices)); + string stratName = null; + + if (stratificationColumn != null) + { + var indexer = StaticPipeUtils.GetIndexer(data); + var column = stratificationColumn(indexer.Indices); + env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found"); + stratName = indexer.Get(column); + } var (trainData, testData) = context.TrainTestSplit(data.AsDynamic, testFraction, stratName); return (new DataView(env, trainData, data.Shape), new DataView(env, testData, data.Shape)); From 6c75e84652b39da88810d03d012e5468683c7c8d Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 24 Sep 2018 14:07:11 -0700 Subject: [PATCH 4/9] Added a test --- .../Training/TrainContext.cs | 2 +- .../StaticPipeTests.cs | 38 +++++++++++++++++-- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs index 58d1d6ebe9..fe31935799 100644 --- a/src/Microsoft.ML.Data/Training/TrainContext.cs +++ b/src/Microsoft.ML.Data/Training/TrainContext.cs @@ -38,7 +38,7 @@ public abstract class TrainContextBase // We need to handle two cases: if the stratification column is provided, we use hashJoin to // build a single hash of it. If it is not, we generate a random number. - if (stratificationColumn != null) + if (stratificationColumn == null) { stratificationColumn = data.Schema.GetTempColumnName("StratificationColumn"); data = new GenerateNumberTransform(Host, data, stratificationColumn); diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 6987e10c60..d7a41358d5 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.TestFramework; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Text; @@ -13,6 +14,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.IO; +using System.Linq; using System.Text; using Xunit; using Xunit.Abstractions; @@ -301,7 +303,7 @@ public void NormalizerWithOnFit() ImmutableArray> bb; var est = reader.MakeNewEstimator() - .Append(r => (r, + .Append(r => (r, ncdf: r.NormalizeByCumulativeDistribution(onFit: (m, s) => mm = m), n: r.NormalizeByMeanVar(onFit: (s, o) => { ss = s; Assert.Empty(o); }), b: r.NormalizeByBinning(onFit: b => bb = b))); @@ -534,7 +536,7 @@ public void LpGcNormAndWhitening() var data = reader.Read(dataSource); var est = reader.MakeNewEstimator() - .Append(r => (r.label, + .Append(r => (r.label, lpnorm: r.features.LpNormalize(), gcnorm: r.features.GlobalContrastNormalize(), zcawhitened: r.features.ZcaWhitening(), @@ -558,7 +560,35 @@ public void LpGcNormAndWhitening() type = schema.GetColumnType(pcswhitenedCol); Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); } - // - //[Fact] + + [Fact] + public void TrainTestSplit() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.iris.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new BinaryClassificationContext(env); + + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadFloat(0), features: c.LoadFloat(1, 4))); + var data = reader.Read(dataSource); + + var (train, test) = ctx.TrainTestSplit(data, 0.5); + + // Just make sure that the train is about the same size as the test set. + var trainCount = train.GetColumn(r => r.label).Count(); + var testCount = test.GetColumn(r => r.label).Count(); + + Assert.InRange(trainCount * 1.0 / testCount, 0.8, 1.2); + + // Now stratify by label. Silly thing to do. + (train, test) = ctx.TrainTestSplit(data, 0.5, stratificationColumn: r => r.label); + var trainLabels = train.GetColumn(r => r.label).Distinct(); + var testLabels = test.GetColumn(r => r.label).Distinct(); + Assert.True(trainLabels.Count() > 0); + Assert.True(testLabels.Count() > 0); + Assert.False(trainLabels.Intersect(testLabels).Any()); + } } } From 227f38a36db3b3726bd6f5745599e243a605676a Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 24 Sep 2018 10:59:49 -0700 Subject: [PATCH 5/9] Added TrainTestSplit to training contexts. --- .../Training/TrainContext.cs | 70 ++++++++++++++++++- .../Training/TrainingStaticExtensions.cs | 55 +++++++++++++++ .../StaticPipeTests.cs | 36 +++++++++- 3 files changed, 158 insertions(+), 3 deletions(-) create mode 100644 src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs index b15bd1b317..fe31935799 100644 --- a/src/Microsoft.ML.Data/Training/TrainContext.cs +++ b/src/Microsoft.ML.Data/Training/TrainContext.cs @@ -2,9 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Transforms; -namespace Microsoft.ML.Runtime.Training +namespace Microsoft.ML { /// /// A training context is an object instantiable by a user to do various tasks relating to a particular @@ -16,6 +18,72 @@ public abstract class TrainContextBase protected readonly IHost Host; internal IHostEnvironment Environment => Host; + /// + /// Split the dataset into the train set and test set according to the given fraction. + /// Respects the if provided. + /// + /// The dataset to split. + /// The fraction of data to go into the test set. + /// Optional stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// A pair of datasets, for the train and test set. + public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null) + { + Host.CheckValue(data, nameof(data)); + Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); + Host.CheckValueOrNull(stratificationColumn); + + // We need to handle two cases: if the stratification column is provided, we use hashJoin to + // build a single hash of it. If it is not, we generate a random number. + + if (stratificationColumn == null) + { + stratificationColumn = data.Schema.GetTempColumnName("StratificationColumn"); + data = new GenerateNumberTransform(Host, data, stratificationColumn); + } + else + { + if (!data.Schema.TryGetColumnIndex(stratificationColumn, out int stratCol)) + throw Host.ExceptSchemaMismatch(nameof(stratificationColumn), "stratification", stratificationColumn); + + var type = data.Schema.GetColumnType(stratCol); + if (!RangeFilter.IsValidRangeFilterColumnType(Host, type)) + { + // Hash the stratification column. + // REVIEW: this could currently crash, since Hash only accepts a limited set + // of column types. It used to be HashJoin, but we should probably extend Hash + // instead of having two hash transformations. + var origStratCol = stratificationColumn; + int tmp; + int inc = 0; + + // Generate a new column with the hashed stratification column. + while (data.Schema.TryGetColumnIndex(stratificationColumn, out tmp)) + stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc); + data = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(data).Transform(data); + } + } + + var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments() + { + Column = stratificationColumn, + Min = 0, + Max = testFraction, + Complement = true + }, data); + var testFilter = new RangeFilter(Host, new RangeFilter.Arguments() + { + Column = stratificationColumn, + Min = 0, + Max = testFraction, + Complement = false + }, data); + + return (trainFilter, testFilter); + } + protected TrainContextBase(IHostEnvironment env, string registrationName) { Contracts.CheckValue(env, nameof(env)); diff --git a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs new file mode 100644 index 0000000000..619b8f96a2 --- /dev/null +++ b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data.StaticPipe; +using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.ML +{ + /// + /// Defines static extension methods that allow operations like train-test split, cross-validate, + /// sampling etc. with the . + /// + public static class TrainingStaticExtensions + { + /// + /// Split the dataset into the train set and test set according to the given fraction. + /// Respects the if provided. + /// + /// The tuple describing the data schema. + /// The training context. + /// The dataset to split. + /// The fraction of data to go into the test set. + /// Optional selector for the stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// A pair of datasets, for the train and test set. + public static (DataView trainSet, DataView testSet) TrainTestSplit(this TrainContextBase context, + DataView data, double testFraction = 0.1, Func stratificationColumn = null) + { + var env = StaticPipeUtils.GetEnvironment(data); + Contracts.AssertValue(env); + env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); + env.CheckValueOrNull(stratificationColumn); + + string stratName = null; + + if (stratificationColumn != null) + { + var indexer = StaticPipeUtils.GetIndexer(data); + var column = stratificationColumn(indexer.Indices); + env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found"); + stratName = indexer.Get(column); + } + + var (trainData, testData) = context.TrainTestSplit(data.AsDynamic, testFraction, stratName); + return (new DataView(env, trainData, data.Shape), new DataView(env, testData, data.Shape)); + } + } +} diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index dad5148ada..d7a41358d5 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.TestFramework; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Text; @@ -13,6 +14,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.IO; +using System.Linq; using System.Text; using Xunit; using Xunit.Abstractions; @@ -301,7 +303,7 @@ public void NormalizerWithOnFit() ImmutableArray> bb; var est = reader.MakeNewEstimator() - .Append(r => (r, + .Append(r => (r, ncdf: r.NormalizeByCumulativeDistribution(onFit: (m, s) => mm = m), n: r.NormalizeByMeanVar(onFit: (s, o) => { ss = s; Assert.Empty(o); }), b: r.NormalizeByBinning(onFit: b => bb = b))); @@ -534,7 +536,7 @@ public void LpGcNormAndWhitening() var data = reader.Read(dataSource); var est = reader.MakeNewEstimator() - .Append(r => (r.label, + .Append(r => (r.label, lpnorm: r.features.LpNormalize(), gcnorm: r.features.GlobalContrastNormalize(), zcawhitened: r.features.ZcaWhitening(), @@ -558,5 +560,35 @@ public void LpGcNormAndWhitening() type = schema.GetColumnType(pcswhitenedCol); Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); } + + [Fact] + public void TrainTestSplit() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.iris.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new BinaryClassificationContext(env); + + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadFloat(0), features: c.LoadFloat(1, 4))); + var data = reader.Read(dataSource); + + var (train, test) = ctx.TrainTestSplit(data, 0.5); + + // Just make sure that the train is about the same size as the test set. + var trainCount = train.GetColumn(r => r.label).Count(); + var testCount = test.GetColumn(r => r.label).Count(); + + Assert.InRange(trainCount * 1.0 / testCount, 0.8, 1.2); + + // Now stratify by label. Silly thing to do. + (train, test) = ctx.TrainTestSplit(data, 0.5, stratificationColumn: r => r.label); + var trainLabels = train.GetColumn(r => r.label).Distinct(); + var testLabels = test.GetColumn(r => r.label).Distinct(); + Assert.True(trainLabels.Count() > 0); + Assert.True(testLabels.Count() > 0); + Assert.False(trainLabels.Intersect(testLabels).Any()); + } } } From 985fe7dbcc343987d25ce2a51a24dfe02c37f152 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 24 Sep 2018 14:53:46 -0700 Subject: [PATCH 6/9] Minor unrelated fix --- src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs index db4f207fa3..f3d84812ce 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs @@ -87,6 +87,12 @@ public void SaveTo(IHostEnvironment env, Stream outputStream) /// public static class CompositeDataReader { + /// + /// Save the contents to a stream, as a "model file". + /// + public static void SaveTo(this IDataReader reader, IHostEnvironment env, Stream outputStream) + => new CompositeDataReader(reader).SaveTo(env, outputStream); + /// /// Load the pipeline from stream. /// From 106a6a0bf3db9f50f9aaf6e915f7b31849b640a4 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 24 Sep 2018 16:44:54 -0700 Subject: [PATCH 7/9] CV work --- .../Training/TrainContext.cs | 141 +++++++++++++++--- .../Training/TrainingStaticExtensions.cs | 28 ++++ 2 files changed, 145 insertions(+), 24 deletions(-) diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs index fe31935799..e8d5168958 100644 --- a/src/Microsoft.ML.Data/Training/TrainContext.cs +++ b/src/Microsoft.ML.Data/Training/TrainContext.cs @@ -2,9 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Transforms; +using System; +using System.Collections.Generic; +using System.Linq; namespace Microsoft.ML { @@ -35,6 +39,87 @@ public abstract class TrainContextBase Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); Host.CheckValueOrNull(stratificationColumn); + EnsureStratificationColumn(ref data, ref stratificationColumn); + + var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments() + { + Column = stratificationColumn, + Min = 0, + Max = testFraction, + Complement = true + }, data); + var testFilter = new RangeFilter(Host, new RangeFilter.Arguments() + { + Column = stratificationColumn, + Min = 0, + Max = testFraction, + Complement = false + }, data); + + return (trainFilter, testFilter); + } + + /// + /// Train the on folds of the data sequentially. + /// Return each model and each scored test dataset. + /// + protected (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator estimator, + int numFolds, string stratificationColumn) + { + Host.CheckValue(data, nameof(data)); + Host.CheckValue(estimator, nameof(estimator)); + Host.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); + Host.CheckValueOrNull(stratificationColumn); + + EnsureStratificationColumn(ref data, ref stratificationColumn); + + Func foldFunction = + fold => + { + var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments + { + Column = stratificationColumn, + Min = (double)fold / numFolds, + Max = (double)(fold + 1) / numFolds, + Complement = true + }, data); + var testFilter = new RangeFilter(Host, new RangeFilter.Arguments + { + Column = stratificationColumn, + Min = (double)fold / numFolds, + Max = (double)(fold + 1) / numFolds, + Complement = false + }, data); + + var model = estimator.Fit(trainFilter); + var scoredTest = model.Transform(testFilter); + return (scoredTest, model); + }; + + // Sequential per-fold training. + // REVIEW: we could have a parallel implementation here. We would need to + // spawn off a separate host per fold in that case. + var result = new List<(IDataView scores, ITransformer model)>(); + for (int fold = 0; fold < numFolds; fold++) + foldFunction(fold); + + return result.ToArray(); + } + + protected TrainContextBase(IHostEnvironment env, string registrationName) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckNonEmpty(registrationName, nameof(registrationName)); + Host = env.Register(registrationName); + } + + /// + /// Make sure the provided is valid + /// for , hash it if needed, or introduce a new one + /// if needed. + /// + private void EnsureStratificationColumn(ref IDataView data, ref string stratificationColumn) + { // We need to handle two cases: if the stratification column is provided, we use hashJoin to // build a single hash of it. If it is not, we generate a random number. @@ -65,30 +150,6 @@ public abstract class TrainContextBase data = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(data).Transform(data); } } - - var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments() - { - Column = stratificationColumn, - Min = 0, - Max = testFraction, - Complement = true - }, data); - var testFilter = new RangeFilter(Host, new RangeFilter.Arguments() - { - Column = stratificationColumn, - Min = 0, - Max = testFraction, - Complement = false - }, data); - - return (trainFilter, testFilter); - } - - protected TrainContextBase(IHostEnvironment env, string registrationName) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckNonEmpty(registrationName, nameof(registrationName)); - Host = env.Register(registrationName); } /// @@ -208,6 +269,22 @@ public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, st var eval = new BinaryClassifierEvaluator(Host, new BinaryClassifierEvaluator.Arguments() { }); return eval.Evaluate(data, label, score, predictedLabel); } + + public (BinaryClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated( + IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) + { + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); + return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + } + + public (BinaryClassifierEvaluator.CalibratedResult metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) + { + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); + return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + } } /// @@ -259,6 +336,14 @@ public MultiClassClassifierEvaluator.Result Evaluate(IDataView data, string labe var eval = new MultiClassClassifierEvaluator(Host, args); return eval.Evaluate(data, label, score, predictedLabel); } + + public (MultiClassClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) + { + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); + return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + } } /// @@ -301,5 +386,13 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string var eval = new RegressionEvaluator(Host, new RegressionEvaluator.Arguments() { }); return eval.Evaluate(data, label, score); } + + public (RegressionEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) + { + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); + return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + } } } diff --git a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs index 619b8f96a2..b4f3344f89 100644 --- a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs +++ b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs @@ -2,9 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Core.Data; using Microsoft.ML.Data.StaticPipe; using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; using System; using System.Collections.Generic; using System.Text; @@ -51,5 +53,31 @@ public static (DataView trainSet, DataView testSet) TrainTestSplit(this var (trainData, testData) = context.TrainTestSplit(data.AsDynamic, testFraction, stratName); return (new DataView(env, trainData, data.Shape), new DataView(env, testData, data.Shape)); } + + /// + /// Runs a sequential cross-validation by training + /// on in folds. + /// + /// The input schema shape. + /// The output schema shape. + /// The type of the trained model. + /// + /// + /// + /// + /// + /// + /// + public static (RegressionEvaluator.Result metrics, Transformer model, DataView scoredTestData)[] CrossValidate( + this RegressionContext context, + DataView data, + Estimator estimator, + Func> label, + int numFolds = 5, + Func stratificationColumn = null) + where TTransformer : class, ITransformer + { + + } } } From fbee542827fa4b7e7f596ad03cc85e39fc11a436 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 24 Sep 2018 17:32:38 -0700 Subject: [PATCH 8/9] Added CV methods and extensions. --- .../Training/TrainContext.cs | 58 ++++- .../Training/TrainingStaticExtensions.cs | 219 +++++++++++++++++- .../Standard/SdcaStatic.cs | 2 +- .../Training.cs | 27 ++- 4 files changed, 293 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs index e8d5168958..c2593654cf 100644 --- a/src/Microsoft.ML.Data/Training/TrainContext.cs +++ b/src/Microsoft.ML.Data/Training/TrainContext.cs @@ -101,7 +101,7 @@ public abstract class TrainContextBase // spawn off a separate host per fold in that case. var result = new List<(IDataView scores, ITransformer model)>(); for (int fold = 0; fold < numFolds; fold++) - foldFunction(fold); + result.Add(foldFunction(fold)); return result.ToArray(); } @@ -270,6 +270,20 @@ public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, st return eval.Evaluate(data, label, score, predictedLabel); } + /// + /// Run cross-validation over folds of , by fitting , + /// and respecting if provided. + /// Then evaluate each sub-model against and return metrics. + /// + /// The data to run cross-validation on. + /// The estimator to fit. + /// Number of cross-validation folds. + /// The label column (for evaluation). + /// Optional stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// Per-fold results: metrics, models, scored datasets. public (BinaryClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) { @@ -278,6 +292,20 @@ public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, st return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); } + /// + /// Run cross-validation over folds of , by fitting , + /// and respecting if provided. + /// Then evaluate each sub-model against and return metrics. + /// + /// The data to run cross-validation on. + /// The estimator to fit. + /// Number of cross-validation folds. + /// The label column (for evaluation). + /// Optional stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// Per-fold results: metrics, models, scored datasets. public (BinaryClassifierEvaluator.CalibratedResult metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) { @@ -337,6 +365,20 @@ public MultiClassClassifierEvaluator.Result Evaluate(IDataView data, string labe return eval.Evaluate(data, label, score, predictedLabel); } + /// + /// Run cross-validation over folds of , by fitting , + /// and respecting if provided. + /// Then evaluate each sub-model against and return metrics. + /// + /// The data to run cross-validation on. + /// The estimator to fit. + /// Number of cross-validation folds. + /// The label column (for evaluation). + /// Optional stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// Per-fold results: metrics, models, scored datasets. public (MultiClassClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) { @@ -387,6 +429,20 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string return eval.Evaluate(data, label, score); } + /// + /// Run cross-validation over folds of , by fitting , + /// and respecting if provided. + /// Then evaluate each sub-model against and return metrics. + /// + /// The data to run cross-validation on. + /// The estimator to fit. + /// Number of cross-validation folds. + /// The label column (for evaluation). + /// Optional stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// Per-fold results: metrics, models, scored datasets. public (RegressionEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) { diff --git a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs index b4f3344f89..8361e94686 100644 --- a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs +++ b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs @@ -9,6 +9,7 @@ using Microsoft.ML.Runtime.Data; using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Microsoft.ML @@ -55,29 +56,227 @@ public static (DataView trainSet, DataView testSet) TrainTestSplit(this } /// - /// Runs a sequential cross-validation by training - /// on in folds. + /// Run cross-validation over folds of , by fitting , + /// and respecting if provided. + /// Then evaluate each sub-model against and return metrics. /// /// The input schema shape. /// The output schema shape. /// The type of the trained model. - /// - /// - /// - /// - /// - /// - /// + /// The training context. + /// The data to run cross-validation on. + /// The estimator to fit. + /// Number of cross-validation folds. + /// The label column (for evaluation). + /// Optional stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// Per-fold results: metrics, models, scored datasets. public static (RegressionEvaluator.Result metrics, Transformer model, DataView scoredTestData)[] CrossValidate( this RegressionContext context, DataView data, Estimator estimator, Func> label, int numFolds = 5, - Func stratificationColumn = null) + Func stratificationColumn = null) + where TTransformer : class, ITransformer + { + var env = StaticPipeUtils.GetEnvironment(data); + Contracts.AssertValue(env); + env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); + env.CheckValue(label, nameof(label)); + env.CheckValueOrNull(stratificationColumn); + + var outIndexer = StaticPipeUtils.GetIndexer(estimator); + var labelColumn = label(outIndexer.Indices); + env.CheckParam(labelColumn != null, nameof(stratificationColumn), "Stratification column not found"); + var labelName = outIndexer.Get(labelColumn); + + string stratName = null; + if (stratificationColumn != null) + { + var indexer = StaticPipeUtils.GetIndexer(data); + var column = stratificationColumn(indexer.Indices); + env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found"); + stratName = indexer.Get(column); + } + + var results = context.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName); + + return results.Select(x => ( + x.metrics, + new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), + new DataView(env, x.scoredTestData, estimator.Shape))) + .ToArray(); + } + + /// + /// Run cross-validation over folds of , by fitting , + /// and respecting if provided. + /// Then evaluate each sub-model against and return metrics. + /// + /// The input schema shape. + /// The output schema shape. + /// The type of the trained model. + /// The training context. + /// The data to run cross-validation on. + /// The estimator to fit. + /// Number of cross-validation folds. + /// The label column (for evaluation). + /// Optional stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// Per-fold results: metrics, models, scored datasets. + public static (MultiClassClassifierEvaluator.Result metrics, Transformer model, DataView scoredTestData)[] CrossValidate( + this MulticlassClassificationContext context, + DataView data, + Estimator estimator, + Func> label, + int numFolds = 5, + Func stratificationColumn = null) where TTransformer : class, ITransformer { + var env = StaticPipeUtils.GetEnvironment(data); + Contracts.AssertValue(env); + env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); + env.CheckValue(label, nameof(label)); + env.CheckValueOrNull(stratificationColumn); + + var outputIndexer = StaticPipeUtils.GetIndexer(estimator); + var labelColumn = label(outputIndexer.Indices); + env.CheckParam(labelColumn != null, nameof(stratificationColumn), "Stratification column not found"); + var labelName = outputIndexer.Get(labelColumn); + + string stratName = null; + if (stratificationColumn != null) + { + var indexer = StaticPipeUtils.GetIndexer(data); + var column = stratificationColumn(indexer.Indices); + env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found"); + stratName = indexer.Get(column); + } + + var results = context.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName); + + return results.Select(x => ( + x.metrics, + new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), + new DataView(env, x.scoredTestData, estimator.Shape))) + .ToArray(); + } + + /// + /// Run cross-validation over folds of , by fitting , + /// and respecting if provided. + /// Then evaluate each sub-model against and return metrics. + /// + /// The input schema shape. + /// The output schema shape. + /// The type of the trained model. + /// The training context. + /// The data to run cross-validation on. + /// The estimator to fit. + /// Number of cross-validation folds. + /// The label column (for evaluation). + /// Optional stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// Per-fold results: metrics, models, scored datasets. + public static (BinaryClassifierEvaluator.Result metrics, Transformer model, DataView scoredTestData)[] CrossValidateNonCalibrated( + this BinaryClassificationContext context, + DataView data, + Estimator estimator, + Func> label, + int numFolds = 5, + Func stratificationColumn = null) + where TTransformer : class, ITransformer + { + var env = StaticPipeUtils.GetEnvironment(data); + Contracts.AssertValue(env); + env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); + env.CheckValue(label, nameof(label)); + env.CheckValueOrNull(stratificationColumn); + + var outputIndexer = StaticPipeUtils.GetIndexer(estimator); + var labelColumn = label(outputIndexer.Indices); + env.CheckParam(labelColumn != null, nameof(stratificationColumn), "Stratification column not found"); + var labelName = outputIndexer.Get(labelColumn); + + string stratName = null; + if (stratificationColumn != null) + { + var indexer = StaticPipeUtils.GetIndexer(data); + var column = stratificationColumn(indexer.Indices); + env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found"); + stratName = indexer.Get(column); + } + + var results = context.CrossValidateNonCalibrated(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName); + + return results.Select(x => ( + x.metrics, + new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), + new DataView(env, x.scoredTestData, estimator.Shape))) + .ToArray(); + } + + /// + /// Run cross-validation over folds of , by fitting , + /// and respecting if provided. + /// Then evaluate each sub-model against and return metrics. + /// + /// The input schema shape. + /// The output schema shape. + /// The type of the trained model. + /// The training context. + /// The data to run cross-validation on. + /// The estimator to fit. + /// Number of cross-validation folds. + /// The label column (for evaluation). + /// Optional stratification column. + /// If two examples share the same value of the (if provided), + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from + /// train to the test set. + /// Per-fold results: metrics, models, scored datasets. + public static (BinaryClassifierEvaluator.CalibratedResult metrics, Transformer model, DataView scoredTestData)[] CrossValidate( + this BinaryClassificationContext context, + DataView data, + Estimator estimator, + Func> label, + int numFolds = 5, + Func stratificationColumn = null) + where TTransformer : class, ITransformer + { + var env = StaticPipeUtils.GetEnvironment(data); + Contracts.AssertValue(env); + env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); + env.CheckValue(label, nameof(label)); + env.CheckValueOrNull(stratificationColumn); + + var outputIndexer = StaticPipeUtils.GetIndexer(estimator); + var labelColumn = label(outputIndexer.Indices); + env.CheckParam(labelColumn != null, nameof(stratificationColumn), "Stratification column not found"); + var labelName = outputIndexer.Get(labelColumn); + + string stratName = null; + if (stratificationColumn != null) + { + var indexer = StaticPipeUtils.GetIndexer(data); + var column = stratificationColumn(indexer.Indices); + env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found"); + stratName = indexer.Get(column); + } + + var results = context.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName); + return results.Select(x => ( + x.metrics, + new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), + new DataView(env, x.scoredTestData, estimator.Shape))) + .ToArray(); } } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs index 0a92451e78..7eb68008d9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs @@ -236,7 +236,7 @@ public static (Vector score, Key predictedLabel) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); - Contracts.CheckValue(loss, nameof(loss)); + Contracts.CheckValueOrNull(loss); Contracts.CheckValueOrNull(weights); Contracts.CheckParam(!(l2Const < 0), nameof(l2Const), "Must not be negative"); Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative"); diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index a1c2e04463..b6d36fabe3 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -11,6 +11,7 @@ using Microsoft.ML.Runtime.Training; using Microsoft.ML.Trainers; using System; +using System.Linq; using Xunit; using Xunit.Abstractions; @@ -71,7 +72,7 @@ public void SdcaRegressionNameCollision() var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new RegressionContext(env); - + // Here we introduce another column called "Score" to collide with the name of the default output. Heh heh heh... var reader = TextLoader.CreateReader(env, c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10), Score: c.LoadText(2)), @@ -260,5 +261,29 @@ public void SdcaMulticlass() Assert.True(metrics.LogLoss > 0); Assert.True(metrics.TopKAccuracy > 0); } + + [Fact] + public void CrossValidate() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.iris.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new MulticlassClassificationContext(env); + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); + + var est = reader.MakeNewEstimator() + .Append(r => (label: r.label.ToKey(), r.features)) + .Append(r => (r.label, preds: ctx.Trainers.Sdca( + r.label, + r.features, + maxIterations: 2))); + + var results = ctx.CrossValidate(reader.Read(dataSource), est, r => r.label) + .Select(x => x.metrics).ToArray(); + Assert.Equal(5, results.Length); + Assert.True(results.All(x => x.LogLoss > 0)); + } } } From 7569e08e8628e12ecdcf139562267a5aa0e1fda8 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Tue, 25 Sep 2018 13:27:08 -0700 Subject: [PATCH 9/9] Fixed build after merging --- src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs index 8361e94686..8cfda0485e 100644 --- a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs +++ b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs @@ -3,8 +3,8 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Core.Data; -using Microsoft.ML.Data.StaticPipe; -using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.StaticPipe; +using Microsoft.ML.StaticPipe.Runtime; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using System;