diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index e173e01cbe..253d76d654 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -703,10 +703,12 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments using (var ch = host.Start("Create Tree Ensemble Scorer")) { var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments() { Suffix = args.Suffix }; - var predictor = args.PredictorModel?.Predictor; + var predictor = args.PredictorModel.Predictor; ch.Trace("Prepare data"); RoleMappedData data = null; - args.PredictorModel?.PrepareData(env, input, out data, out var predictor2); + args.PredictorModel.PrepareData(env, input, out data, out var predictor2); + ch.AssertValue(data); + ch.Assert(predictor == predictor2); // Make sure that the given predictor has the correct number of input features. if (predictor is CalibratedPredictorBase) @@ -715,16 +717,16 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments // be non-null. 
var vm = predictor as IValueMapper; ch.CheckUserArg(vm != null, nameof(args.PredictorModel), "Predictor does not have compatible type"); - if (data != null && vm?.InputType.VectorSize != data.Schema.Feature.Type.VectorSize) + if (data != null && vm.InputType.VectorSize != data.Schema.Feature.Type.VectorSize) { throw ch.ExceptUserArg(nameof(args.PredictorModel), "Predictor expects {0} features, but data has {1} features", - vm?.InputType.VectorSize, data.Schema.Feature.Type.VectorSize); + vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize); } var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); - var bound = bindable.Bind(env, data?.Schema); - xf = new GenericScorer(env, scorerArgs, input, bound, data?.Schema); + var bound = bindable.Bind(env, data.Schema); + xf = new GenericScorer(env, scorerArgs, data.Data, bound, data.Schema); ch.Done(); } return xf; diff --git a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj index da16090d43..1660eefce2 100644 --- a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj +++ b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj @@ -18,6 +18,7 @@ + \ No newline at end of file diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index b52ee90f17..e8be6c0370 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -6,13 +6,13 @@ using System.Collections.Generic; using System.IO; using System.Linq; -using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Core.Tests.UnitTests; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.EntryPoints.JsonUtils; +using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.Internal.Utilities; 
using Microsoft.ML.Runtime.Learners; using Newtonsoft.Json;
@@ -2521,5 +2521,80 @@ public void EntryPointPrepareLabelConvertPredictedLabel()
                 }
             }
         }
+
+        /// <summary>
+        /// Trains a 5-tree, 4-leaf FastTree binary classifier through the entry-point API, combines it
+        /// with the categorical and concat transform models, runs the tree-leaf featurizer on the raw
+        /// imported data, and checks the sizes of the "Trees", "Leaves" and "Paths" columns on every row.
+        /// </summary>
+        [Fact]
+        public void EntryPointTreeLeafFeaturizer()
+        {
+            var dataPath = GetDataPath(@"adult.tiny.with-schema.txt");
+            var inputFile = new SimpleFileHandle(Env, dataPath, false, false);
+            var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile }).Data;
+            var cat = Categorical.CatTransformDict(Env, new CategoricalTransform.Arguments()
+            {
+                Data = dataView,
+                Column = new[] { new CategoricalTransform.Column { Name = "Categories", Source = "Categories" } }
+            });
+            var concat = SchemaManipulation.ConcatColumns(Env, new ConcatTransform.Arguments()
+            {
+                Data = cat.OutputData,
+                Column = new[] { new ConcatTransform.Column { Name = "Features", Source = new[] { "Categories", "NumericFeatures" } } }
+            });
+
+            var fastTree = FastTree.FastTree.TrainBinary(Env, new FastTreeBinaryClassificationTrainer.Arguments
+            {
+                FeatureColumn = "Features",
+                NumTrees = 5,
+                NumLeaves = 4,
+                LabelColumn = DefaultColumnNames.Label,
+                TrainingData = concat.OutputData
+            });
+
+            // Bundle the trained predictor with the two transforms that produced its "Features" column.
+            var combine = ModelOperations.CombineModels(Env, new ModelOperations.PredictorModelInput()
+            {
+                PredictorModel = fastTree.PredictorModel,
+                TransformModels = new[] { cat.Model, concat.Model }
+            });
+
+            // The raw imported view is passed on purpose (not concat.OutputData): the combined model
+            // carries the categorical and concat transforms, so the featurizer is expected to apply them.
+            var treeLeaf = TreeFeaturize.Featurizer(Env, new TreeEnsembleFeaturizerTransform.ArgumentsForEntryPoint
+            {
+                Data = dataView,
+                PredictorModel = combine.PredictorModel
+            });
+
+            var view = treeLeaf.OutputData;
+            Assert.True(view.Schema.TryGetColumnIndex("Trees", out int treesCol));
+            Assert.True(view.Schema.TryGetColumnIndex("Leaves", out int leavesCol));
+            Assert.True(view.Schema.TryGetColumnIndex("Paths", out int pathsCol));
+            VBuffer<float> treeValues = default(VBuffer<float>);
+            VBuffer<float> leafIndicators = default(VBuffer<float>);
+            VBuffer<float> pathIndicators = default(VBuffer<float>);
+            using (var curs = view.GetRowCursor(c => c == treesCol || c == leavesCol || c == pathsCol))
+            {
+                var treesGetter = curs.GetGetter<VBuffer<float>>(treesCol);
+                var leavesGetter = curs.GetGetter<VBuffer<float>>(leavesCol);
+                var pathsGetter = curs.GetGetter<VBuffer<float>>(pathsCol);
+                while (curs.MoveNext())
+                {
+                    treesGetter(ref treeValues);
+                    leavesGetter(ref leafIndicators);
+                    pathsGetter(ref pathIndicators);
+
+                    // Expected sizes follow NumTrees = 5 and NumLeaves = 4: 5 * 4 = 20 leaf slots with 5 set
+                    // (consistent with one active leaf per tree) and 5 * (4 - 1) = 15 path slots.
+                    Assert.Equal(5, treeValues.Length);
+                    Assert.Equal(5, treeValues.Count);
+                    Assert.Equal(20, leafIndicators.Length);
+                    Assert.Equal(5, leafIndicators.Count);
+                    Assert.Equal(15, pathIndicators.Length);
+                }
+            }
+        }
     }
 }
\ No newline at end of file