diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
index e173e01cbe..253d76d654 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
@@ -703,10 +703,12 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments
using (var ch = host.Start("Create Tree Ensemble Scorer"))
{
var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments() { Suffix = args.Suffix };
- var predictor = args.PredictorModel?.Predictor;
+ var predictor = args.PredictorModel.Predictor;
ch.Trace("Prepare data");
RoleMappedData data = null;
- args.PredictorModel?.PrepareData(env, input, out data, out var predictor2);
+ args.PredictorModel.PrepareData(env, input, out data, out var predictor2);
+ ch.AssertValue(data);
+ ch.Assert(predictor == predictor2);
// Make sure that the given predictor has the correct number of input features.
if (predictor is CalibratedPredictorBase)
@@ -715,16 +717,16 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments
// be non-null.
var vm = predictor as IValueMapper;
ch.CheckUserArg(vm != null, nameof(args.PredictorModel), "Predictor does not have compatible type");
- if (data != null && vm?.InputType.VectorSize != data.Schema.Feature.Type.VectorSize)
+ if (data != null && vm.InputType.VectorSize != data.Schema.Feature.Type.VectorSize)
{
throw ch.ExceptUserArg(nameof(args.PredictorModel),
"Predictor expects {0} features, but data has {1} features",
- vm?.InputType.VectorSize, data.Schema.Feature.Type.VectorSize);
+ vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize);
}
var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
- var bound = bindable.Bind(env, data?.Schema);
- xf = new GenericScorer(env, scorerArgs, input, bound, data?.Schema);
+ var bound = bindable.Bind(env, data.Schema);
+ xf = new GenericScorer(env, scorerArgs, data.Data, bound, data.Schema);
ch.Done();
}
return xf;
diff --git a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj
index da16090d43..1660eefce2 100644
--- a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj
+++ b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj
@@ -18,6 +18,7 @@
+
\ No newline at end of file
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
index b52ee90f17..e8be6c0370 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
@@ -6,13 +6,13 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
-using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Core.Tests.UnitTests;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.EntryPoints.JsonUtils;
+using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Learners;
using Newtonsoft.Json;
@@ -2521,5 +2521,70 @@ public void EntryPointPrepareLabelConvertPredictedLabel()
}
}
}
+
+ [Fact]
+ public void EntryPointTreeLeafFeaturizer()
+ {
+ var dataPath = GetDataPath(@"adult.tiny.with-schema.txt");
+ var inputFile = new SimpleFileHandle(Env, dataPath, false, false);
+ var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile }).Data;
+ var cat = Categorical.CatTransformDict(Env, new CategoricalTransform.Arguments()
+ {
+ Data = dataView,
+ Column = new[] { new CategoricalTransform.Column { Name = "Categories", Source = "Categories" } }
+ });
+ var concat = SchemaManipulation.ConcatColumns(Env, new ConcatTransform.Arguments()
+ {
+ Data = cat.OutputData,
+ Column = new[] { new ConcatTransform.Column { Name = "Features", Source = new[] { "Categories", "NumericFeatures" } } }
+ });
+
+ var fastTree = FastTree.FastTree.TrainBinary(Env, new FastTreeBinaryClassificationTrainer.Arguments
+ {
+ FeatureColumn = "Features",
+ NumTrees = 5,
+ NumLeaves = 4,
+ LabelColumn = DefaultColumnNames.Label,
+ TrainingData = concat.OutputData
+ });
+
+ var combine = ModelOperations.CombineModels(Env, new ModelOperations.PredictorModelInput()
+ {
+ PredictorModel = fastTree.PredictorModel,
+ TransformModels = new[] { cat.Model, concat.Model }
+ });
+
+ var treeLeaf = TreeFeaturize.Featurizer(Env, new TreeEnsembleFeaturizerTransform.ArgumentsForEntryPoint
+ {
+ Data = dataView,
+ PredictorModel = combine.PredictorModel
+ });
+
+ var view = treeLeaf.OutputData;
+ Assert.True(view.Schema.TryGetColumnIndex("Trees", out int treesCol));
+ Assert.True(view.Schema.TryGetColumnIndex("Leaves", out int leavesCol));
+ Assert.True(view.Schema.TryGetColumnIndex("Paths", out int pathsCol));
+ VBuffer treeValues = default(VBuffer);
+ VBuffer leafIndicators = default(VBuffer);
+ VBuffer pathIndicators = default(VBuffer);
+ using (var curs = view.GetRowCursor(c => c == treesCol || c == leavesCol || c == pathsCol))
+ {
+ var treesGetter = curs.GetGetter>(treesCol);
+ var leavesGetter = curs.GetGetter>(leavesCol);
+ var pathsGetter = curs.GetGetter>(pathsCol);
+ while (curs.MoveNext())
+ {
+ treesGetter(ref treeValues);
+ leavesGetter(ref leafIndicators);
+ pathsGetter(ref pathIndicators);
+
+ Assert.Equal(5, treeValues.Length);
+ Assert.Equal(5, treeValues.Count);
+ Assert.Equal(20, leafIndicators.Length);
+ Assert.Equal(5, leafIndicators.Count);
+ Assert.Equal(15, pathIndicators.Length);
+ }
+ }
+ }
}
}
\ No newline at end of file