diff --git a/src/Microsoft.ML/LearningPipeline.cs b/src/Microsoft.ML/LearningPipeline.cs index 630084a588..87389077ed 100644 --- a/src/Microsoft.ML/LearningPipeline.cs +++ b/src/Microsoft.ML/LearningPipeline.cs @@ -182,7 +182,9 @@ public PredictionModel Train() if (transformModels.Count > 0) { - transformModels.Insert(0,lastTransformModel); + if (lastTransformModel != null) + transformModels.Insert(0, lastTransformModel); + var modelInput = new Transforms.ModelCombiner { Models = new ArrayVar(transformModels.ToArray()) diff --git a/test/Microsoft.ML.Tests/LearningPipelineTests.cs b/test/Microsoft.ML.Tests/LearningPipelineTests.cs index bafef95040..30dd844d58 100644 --- a/test/Microsoft.ML.Tests/LearningPipelineTests.cs +++ b/test/Microsoft.ML.Tests/LearningPipelineTests.cs @@ -3,7 +3,9 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML; +using Microsoft.ML.Runtime.Api; using Microsoft.ML.TestFramework; +using Microsoft.ML.Transforms; using System.Linq; using Xunit; using Xunit.Abstractions; @@ -42,5 +44,40 @@ public void CanAddAndRemoveFromPipeline() pipeline.Add(new Trainers.StochasticDualCoordinateAscentRegressor()); Assert.Equal(3, pipeline.Count); } + + private class InputData + { + [Column(ordinal: "1")] + public string F1; + } + + private class TransformedData + { +#pragma warning disable 649 + [ColumnName("F1")] + public float[] TransformedF1; +#pragma warning restore 649 + } + + [Fact] + public void TransformOnlyPipeline() + { + const string _dataPath = @"..\..\Data\breast-cancer.txt"; + var pipeline = new LearningPipeline(); + pipeline.Add(new TextLoader(_dataPath, useHeader: false)); + pipeline.Add(new CategoricalHashOneHotVectorizer("F1") { HashBits = 10, Seed = 314489979, OutputKind = CategoricalTransformOutputKind.Bag }); + var model = pipeline.Train(); + var predictionModel = model.Predict(new InputData() { F1 = "5" }); + + Assert.NotNull(predictionModel); + Assert.NotNull(predictionModel.TransformedF1); + Assert.Equal(1024, predictionModel.TransformedF1.Length); + + for (int index = 0; index < 1024; index++) + if (index == 265) + Assert.Equal(1, predictionModel.TransformedF1[index]); + else + Assert.Equal(0, predictionModel.TransformedF1[index]); + } } }