From b36f5b6304cf94c4f193ca125020dc8b4519bafa Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 14 May 2018 16:50:48 -0700 Subject: [PATCH 1/3] fix. --- src/Microsoft.ML/LearningPipeline.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML/LearningPipeline.cs b/src/Microsoft.ML/LearningPipeline.cs index 630084a588..87389077ed 100644 --- a/src/Microsoft.ML/LearningPipeline.cs +++ b/src/Microsoft.ML/LearningPipeline.cs @@ -182,7 +182,9 @@ public PredictionModel Train() if (transformModels.Count > 0) { - transformModels.Insert(0,lastTransformModel); + if (lastTransformModel != null) + transformModels.Insert(0, lastTransformModel); + var modelInput = new Transforms.ModelCombiner { Models = new ArrayVar(transformModels.ToArray()) From 1b225f0e2e2a75cb8263117f81d04f11cc2477ef Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 14 May 2018 17:19:55 -0700 Subject: [PATCH 2/3] Add test. --- .../LearningPipelineTests.cs | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/Microsoft.ML.Tests/LearningPipelineTests.cs b/test/Microsoft.ML.Tests/LearningPipelineTests.cs index bafef95040..19eebbfb58 100644 --- a/test/Microsoft.ML.Tests/LearningPipelineTests.cs +++ b/test/Microsoft.ML.Tests/LearningPipelineTests.cs @@ -3,7 +3,9 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML; +using Microsoft.ML.Runtime.Api; using Microsoft.ML.TestFramework; +using Microsoft.ML.Transforms; using System.Linq; using Xunit; using Xunit.Abstractions; @@ -42,5 +44,40 @@ public void CanAddAndRemoveFromPipeline() pipeline.Add(new Trainers.StochasticDualCoordinateAscentRegressor()); Assert.Equal(3, pipeline.Count); } + + class InputData + { + [Column(ordinal: "1")] + public string F1; + } + + class TransformedData + { + [ColumnName("F1")] + public float[] TransformedF1; + } + + [Fact] + public void TransformOnlyPipeline() + { + const string _dataPath = @"..\..\Data\breast-cancer.txt"; + var pipeline = new LearningPipeline(); + pipeline.Add(new TextLoader(_dataPath, useHeader: false)); + pipeline.Add(new CategoricalHashOneHotVectorizer("F1") { HashBits = 10, Seed = 314489979, OutputKind = CategoricalTransformOutputKind.Bag }); + var model = pipeline.Train(); + var predictionModel = model.Predict(new InputData() { F1 = "5" }); + + Assert.NotNull(predictionModel); + Assert.NotNull(predictionModel.TransformedF1); + Assert.Equal(1024, predictionModel.TransformedF1.Length); + + for (int index = 0; index < 1024; index++) + if (index == 265) + Assert.Equal(1, predictionModel.TransformedF1[index]); + else + Assert.Equal(0, predictionModel.TransformedF1[index]); + + predictionModel.TransformedF1 = null; + } } } From bbe39e48e69f6e9df54ea1fb1ad6805986c31159 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 14 May 2018 17:43:30 -0700 Subject: [PATCH 3/3] PR feedback. --- test/Microsoft.ML.Tests/LearningPipelineTests.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Microsoft.ML.Tests/LearningPipelineTests.cs b/test/Microsoft.ML.Tests/LearningPipelineTests.cs index 19eebbfb58..30dd844d58 100644 --- a/test/Microsoft.ML.Tests/LearningPipelineTests.cs +++ b/test/Microsoft.ML.Tests/LearningPipelineTests.cs @@ -45,16 +45,18 @@ public void CanAddAndRemoveFromPipeline() Assert.Equal(3, pipeline.Count); } - class InputData + private class InputData { [Column(ordinal: "1")] public string F1; } - class TransformedData + private class TransformedData { +#pragma warning disable 649 [ColumnName("F1")] public float[] TransformedF1; +#pragma warning restore 649 } [Fact] @@ -76,8 +78,6 @@ public void TransformOnlyPipeline() Assert.Equal(1, predictionModel.TransformedF1[index]); else Assert.Equal(0, predictionModel.TransformedF1[index]); - - predictionModel.TransformedF1 = null; } } }