Change IModelCombiner to not be generic, and add unit tests #1305

Merged 4 commits on Oct 23, 2018
9 changes: 5 additions & 4 deletions src/Microsoft.ML.Core/Prediction/ITrainer.cs
@@ -95,10 +95,11 @@ public static TPredictor Train<TPredictor>(this ITrainer<TPredictor> trainer, Ro
=> trainer.Train(new TrainContext(trainData));
}

// A trainer can optionally implement this to indicate it can combine multiple models into a single predictor.
public interface IModelCombiner<TModel, TPredictor>
where TPredictor : IPredictor
/// <summary>
/// An interface that combines multiple predictors into a single predictor.
/// </summary>
public interface IModelCombiner
{
TPredictor CombineModels(IEnumerable<TModel> models);
IPredictor CombineModels(IEnumerable<IPredictor> models);
}
}
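
With the type parameters gone, a combiner is written against the IPredictor base interface and performs any narrowing itself. A minimal sketch of an implementation, assuming IPredictor and IModelCombiner live in the Microsoft.ML.Runtime namespace as in the file above (PassThroughCombiner is an invented name, not part of this PR):

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Runtime;

// Illustrative only: accepts any predictors and returns the first one unchanged.
// Real implementations (the ensemble trainers below) downcast each IPredictor
// to the concrete predictor type they actually combine.
public sealed class PassThroughCombiner : IModelCombiner
{
    public IPredictor CombineModels(IEnumerable<IPredictor> models)
    {
        if (models == null)
            throw new ArgumentNullException(nameof(models));
        var first = models.FirstOrDefault();
        if (first == null)
            throw new ArgumentException("At least one model is required.", nameof(models));
        return first;
    }
}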
21 changes: 17 additions & 4 deletions src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
@@ -19,6 +19,9 @@
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
EnsembleTrainer.UserNameValue, EnsembleTrainer.LoadNameValue, "pe", "ParallelEnsemble")]

[assembly: LoadableClass(typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments), typeof(SignatureModelCombiner),
"Binary Classification Ensemble Model Combiner", EnsembleTrainer.LoadNameValue, "pe", "ParallelEnsemble")]

namespace Microsoft.ML.Runtime.Ensemble
{
using TDistPredictor = IDistPredictorProducing<Single, Single>;
@@ -28,7 +31,7 @@ namespace Microsoft.ML.Runtime.Ensemble
/// </summary>
public sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
IBinarySubModelSelector, IBinaryOutputCombiner>,
IModelCombiner<TScalarPredictor, TScalarPredictor>
IModelCombiner
{
public const string LoadNameValue = "WeightedEnsemble";
public const string UserNameValue = "Parallel Ensemble (bagging, stacking, etc)";
@@ -70,6 +73,12 @@ public EnsembleTrainer(IHostEnvironment env, Arguments args)
Combiner = args.OutputCombiner.CreateComponent(Host);
}

private EnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind predictionKind)
: this(env, args)
{
Host.CheckParam(predictionKind == PredictionKind.BinaryClassification, nameof(PredictionKind));
}

public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetModel<TScalarPredictor>> models)
@@ -79,18 +88,22 @@ private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetMo
return new EnsemblePredictor(Host, PredictionKind, CreateModels<TScalarPredictor>(models), Combiner);
}

public TScalarPredictor CombineModels(IEnumerable<TScalarPredictor> models)
public IPredictor CombineModels(IEnumerable<IPredictor> models)
{
Host.CheckValue(models, nameof(models));

var combiner = _outputCombiner.CreateComponent(Host);
var p = models.First();

if (p is TDistPredictor)
{
Host.CheckParam(models.All(m => m is TDistPredictor), nameof(models));
return new EnsembleDistributionPredictor(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel<TDistPredictor>((TDistPredictor)k)).ToArray(), combiner);
}

Host.CheckParam(models.All(m => m is TScalarPredictor), nameof(models));
return new EnsemblePredictor(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel<TScalarPredictor>(k)).ToArray(), combiner);
models.Select(k => new FeatureSubsetModel<TScalarPredictor>((TScalarPredictor)k)).ToArray(), combiner);
}
}
}
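
Because CombineModels now receives plain IPredictor instances, each trainer validates and downcasts the models at runtime instead of relying on a generic constraint; the multiclass and regression ensemble trainers in the next two files get the same treatment. The pattern can be captured in a small hypothetical helper (CastModels is illustrative, not part of this PR):

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Runtime;

internal static class ModelCastUtils
{
    // Mirrors Host.CheckParam(models.All(m => m is T), nameof(models)) followed by the cast.
    public static T[] CastModels<T>(IEnumerable<IPredictor> models)
        where T : class, IPredictor
    {
        var casted = models.Select(m => m as T).ToArray();
        if (casted.Any(m => m == null))
            throw new ArgumentException($"All models must implement {typeof(T).Name}.", nameof(models));
        return casted;
    }
}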
@@ -12,7 +12,6 @@
using Microsoft.ML.Runtime.Ensemble;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.Ensemble.Selector;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Learners;

@@ -22,6 +21,9 @@
MulticlassDataPartitionEnsembleTrainer.UserNameValue,
MulticlassDataPartitionEnsembleTrainer.LoadNameValue)]

[assembly: LoadableClass(typeof(MulticlassDataPartitionEnsembleTrainer), typeof(MulticlassDataPartitionEnsembleTrainer.Arguments),
typeof(SignatureModelCombiner), "Multiclass Classification Ensemble Model Combiner", MulticlassDataPartitionEnsembleTrainer.LoadNameValue)]

namespace Microsoft.ML.Runtime.Ensemble
{
using TVectorPredictor = IPredictorProducing<VBuffer<Single>>;
@@ -31,7 +33,7 @@ namespace Microsoft.ML.Runtime.Ensemble
public sealed class MulticlassDataPartitionEnsembleTrainer :
EnsembleTrainerBase<VBuffer<Single>, EnsembleMultiClassPredictor,
IMulticlassSubModelSelector, IMultiClassOutputCombiner>,
IModelCombiner<TVectorPredictor, TVectorPredictor>
IModelCombiner
{
public const string LoadNameValue = "WeightedEnsembleMulticlass";
public const string UserNameValue = "Multi-class Parallel Ensemble (bagging, stacking, etc)";
@@ -72,19 +74,28 @@ public MulticlassDataPartitionEnsembleTrainer(IHostEnvironment env, Arguments ar
Combiner = args.OutputCombiner.CreateComponent(Host);
}

private MulticlassDataPartitionEnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind predictionKind)
: this(env, args)
{
Host.CheckParam(predictionKind == PredictionKind.MultiClassClassification, nameof(PredictionKind));
}

public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;

private protected override EnsembleMultiClassPredictor CreatePredictor(List<FeatureSubsetModel<TVectorPredictor>> models)
{
return new EnsembleMultiClassPredictor(Host, CreateModels<TVectorPredictor>(models), Combiner as IMultiClassOutputCombiner);
}

public TVectorPredictor CombineModels(IEnumerable<TVectorPredictor> models)
public IPredictor CombineModels(IEnumerable<IPredictor> models)
{
var predictor = new EnsembleMultiClassPredictor(Host,
models.Select(k => new FeatureSubsetModel<TVectorPredictor>(k)).ToArray(),
_outputCombiner.CreateComponent(Host));
Host.CheckValue(models, nameof(models));
Host.CheckParam(models.All(m => m is TVectorPredictor), nameof(models));

var combiner = _outputCombiner.CreateComponent(Host);
var predictor = new EnsembleMultiClassPredictor(Host,
models.Select(k => new FeatureSubsetModel<TVectorPredictor>((TVectorPredictor)k)).ToArray(),
combiner);
return predictor;
}
}
@@ -12,7 +12,6 @@
using Microsoft.ML.Runtime.Ensemble;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.Ensemble.Selector;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Learners;

@@ -21,12 +20,15 @@
RegressionEnsembleTrainer.UserNameValue,
RegressionEnsembleTrainer.LoadNameValue)]

[assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments), typeof(SignatureModelCombiner),
"Regression Ensemble Model Combiner", RegressionEnsembleTrainer.LoadNameValue)]

namespace Microsoft.ML.Runtime.Ensemble
{
using TScalarPredictor = IPredictorProducing<Single>;
public sealed class RegressionEnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
IRegressionSubModelSelector, IRegressionOutputCombiner>,
IModelCombiner<TScalarPredictor, TScalarPredictor>
IModelCombiner
{
public const string LoadNameValue = "EnsembleRegression";
public const string UserNameValue = "Regression Ensemble (bagging, stacking, etc)";
@@ -66,20 +68,29 @@ public RegressionEnsembleTrainer(IHostEnvironment env, Arguments args)
Combiner = args.OutputCombiner.CreateComponent(Host);
}

private RegressionEnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind predictionKind)
: this(env, args)
{
Host.CheckParam(predictionKind == PredictionKind.Regression, nameof(PredictionKind));
}

public override PredictionKind PredictionKind => PredictionKind.Regression;

private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetModel<TScalarPredictor>> models)
{
return new EnsemblePredictor(Host, PredictionKind, CreateModels<TScalarPredictor>(models), Combiner);
}

public TScalarPredictor CombineModels(IEnumerable<TScalarPredictor> models)
public IPredictor CombineModels(IEnumerable<IPredictor> models)
{
Host.CheckValue(models, nameof(models));
Host.CheckParam(models.All(m => m is TScalarPredictor), nameof(models));

var combiner = _outputCombiner.CreateComponent(Host);
var p = models.First();

var predictor = new EnsemblePredictor(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel<TScalarPredictor>(k)).ToArray(), combiner);
models.Select(k => new FeatureSubsetModel<TScalarPredictor>((TScalarPredictor)k)).ToArray(), combiner);

return predictor;
}
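
The new SignatureModelCombiner LoadableClass registrations are what let callers create a combiner by load name instead of referencing the trainer type directly; the tests later in this PR do exactly that. A rough usage sketch, assuming an IHostEnvironment env and an IPredictorModel[] predictorModels are in scope (the load name "pe" and option "oc=average" are taken from the binary-classification test below):

// Look up the binary ensemble trainer as a model combiner and combine the predictors.
IModelCombiner combiner = ComponentCatalog.CreateInstance<IModelCombiner>(
    env, typeof(SignatureModelCombiner), "pe", "oc=average",
    PredictionKind.BinaryClassification);
IPredictor combined = combiner.CombineModels(predictorModels.Select(pm => pm.Predictor));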
@@ -11,7 +11,7 @@

namespace Microsoft.ML.Runtime.FastTree.Internal
{
public sealed class TreeEnsembleCombiner : IModelCombiner<IPredictorProducing<float>, IPredictorProducing<float>>
public sealed class TreeEnsembleCombiner : IModelCombiner
{
private readonly IHost _host;
private readonly PredictionKind _kind;
@@ -32,7 +32,7 @@ public TreeEnsembleCombiner(IHostEnvironment env, PredictionKind kind)
}
}

public IPredictorProducing<float> CombineModels(IEnumerable<IPredictorProducing<float>> models)
public IPredictor CombineModels(IEnumerable<IPredictor> models)
{
_host.CheckValue(models, nameof(models));

316 changes: 268 additions & 48 deletions test/Microsoft.ML.Predictor.Tests/TestPredictors.cs
@@ -14,6 +14,8 @@ namespace Microsoft.ML.Runtime.RunTests
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.FastTree.Internal;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.LightGBM;
using Microsoft.ML.Runtime.SymSgd;
using Microsoft.ML.TestFramework;
@@ -585,14 +587,12 @@ public void RankingLightGBMTest()
Done();
}

[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 fails with "Unknown command: 'train'; Format error at (83,3)-(83,4011): Illegal quoting"
// x86 fails. Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))]
public void TestTreeEnsembleCombiner()
{
var dataPath = GetDataPath("breast-cancer.txt");
var inputFile = new SimpleFileHandle(Env, dataPath, false, false);
#pragma warning disable 0618
var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile }).Data;
#pragma warning restore 0618
var dataView = TextLoader.Create(Env, new TextLoader.Arguments(), new MultiFileSource(dataPath));

var fastTrees = new IPredictorModel[3];
for (int i = 0; i < 3; i++)
@@ -609,14 +609,12 @@ public void TestTreeEnsembleCombiner()
CombineAndTestTreeEnsembles(dataView, fastTrees);
}

[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 fails with "Unknown command: 'train'; Format error at (83,3)-(83,4011): Illegal quoting"
// x86 fails. Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))]
public void TestTreeEnsembleCombinerWithCategoricalSplits()
{
var dataPath = GetDataPath("adult.tiny.with-schema.txt");
var inputFile = new SimpleFileHandle(Env, dataPath, false, false);
#pragma warning disable 0618
var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile }).Data;
#pragma warning restore 0618
var dataView = TextLoader.Create(Env, new TextLoader.Arguments(), new MultiFileSource(dataPath));

var cat = new CategoricalEstimator(Env, "Categories", "Features").Fit(dataView).Transform(dataView);
var fastTrees = new IPredictorModel[3];
@@ -647,62 +645,284 @@ private void CombineAndTestTreeEnsembles(IDataView idv, IPredictorModel[] fastTr
Assert.True(scored.Schema.TryGetColumnIndex("Probability", out int probCol));
Assert.True(scored.Schema.TryGetColumnIndex("PredictedLabel", out int predCol));

var scoredArray = new IDataView[3];
var scoreColArray = new int[3];
var probColArray = new int[3];
var predColArray = new int[3];
for (int i = 0; i < 3; i++)
int predCount = Utils.Size(fastTrees);
var scoredArray = new IDataView[predCount];
var scoreColArray = new int[predCount];
var probColArray = new int[predCount];
var predColArray = new int[predCount];
for (int i = 0; i < predCount; i++)
{
scoredArray[i] = ScoreModel.Score(Env, new ScoreModel.Input() { Data = idv, PredictorModel = fastTrees[i] }).ScoredData;
Assert.True(scoredArray[i].Schema.TryGetColumnIndex("Score", out scoreColArray[i]));
Assert.True(scoredArray[i].Schema.TryGetColumnIndex("Probability", out probColArray[i]));
Assert.True(scoredArray[i].Schema.TryGetColumnIndex("PredictedLabel", out predColArray[i]));
}

var cursors = new IRowCursor[3];
using (var curs = scored.GetRowCursor(c => c == scoreCol || c == probCol || c == predCol))
using (cursors[0] = scoredArray[0].GetRowCursor(c => c == scoreColArray[0] || c == probColArray[0] || c == predColArray[0]))
using (cursors[1] = scoredArray[1].GetRowCursor(c => c == scoreColArray[1] || c == probColArray[1] || c == predColArray[1]))
using (cursors[2] = scoredArray[2].GetRowCursor(c => c == scoreColArray[2] || c == probColArray[2] || c == predColArray[2]))
var cursors = new IRowCursor[predCount];
for (int i = 0; i < predCount; i++)
cursors[i] = scoredArray[i].GetRowCursor(c => c == scoreColArray[i] || c == probColArray[i] || c == predColArray[i]);

try
{
using (var curs = scored.GetRowCursor(c => c == scoreCol || c == probCol || c == predCol))
{
var scoreGetter = curs.GetGetter<float>(scoreCol);
var probGetter = curs.GetGetter<float>(probCol);
var predGetter = curs.GetGetter<bool>(predCol);
var scoreGetters = new ValueGetter<float>[predCount];
var probGetters = new ValueGetter<float>[predCount];
var predGetters = new ValueGetter<bool>[predCount];
for (int i = 0; i < predCount; i++)
{
scoreGetters[i] = cursors[i].GetGetter<float>(scoreColArray[i]);
probGetters[i] = cursors[i].GetGetter<float>(probColArray[i]);
predGetters[i] = cursors[i].GetGetter<bool>(predColArray[i]);
}

float score = 0;
float prob = 0;
bool pred = default;
var scores = new float[predCount];
var probs = new float[predCount];
var preds = new bool[predCount];
while (curs.MoveNext())
{
scoreGetter(ref score);
probGetter(ref prob);
predGetter(ref pred);
for (int i = 0; i < predCount; i++)
{
Assert.True(cursors[i].MoveNext());
scoreGetters[i](ref scores[i]);
probGetters[i](ref probs[i]);
predGetters[i](ref preds[i]);
}
Assert.Equal(score, 0.4 * scores.Sum() / predCount, 5);
Assert.Equal(prob, 1 / (1 + Math.Exp(-score)), 6);
Assert.True(pred == score > 0);
}
}
}
finally
{
for (int i = 0; i < predCount; i++)
cursors[i].Dispose();
}
}
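
Per row, the assertions above pin down the combined binary tree ensemble exactly: its score is 0.4 times the mean of the member scores, its probability is the logistic transform of that score, and its predicted label is simply the sign of the score. A hedged restatement in helper form (ExpectedCombined is not part of the test):

using System;
using System.Linq;

static (float Score, float Prob, bool Pred) ExpectedCombined(float[] memberScores)
{
    // Values mirror the Assert.Equal calls above.
    float score = 0.4f * memberScores.Sum() / memberScores.Length;
    float prob = 1f / (1f + (float)Math.Exp(-score));
    return (score, prob, score > 0);
}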

// x86 fails. Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))]
public void TestEnsembleCombiner()
{
var dataPath = GetDataPath("breast-cancer.txt");
var dataView = TextLoader.Create(Env, new TextLoader.Arguments(), new MultiFileSource(dataPath));

var predictors = new IPredictorModel[]
{
FastTree.TrainBinary(Env, new FastTreeBinaryClassificationTrainer.Arguments
{
FeatureColumn = "Features",
NumTrees = 5,
NumLeaves = 4,
LabelColumn = DefaultColumnNames.Label,
TrainingData = dataView
}).PredictorModel,
AveragedPerceptronTrainer.TrainBinary(Env, new AveragedPerceptronTrainer.Arguments()
{
FeatureColumn = "Features",
LabelColumn = DefaultColumnNames.Label,
NumIterations = 2,
TrainingData = dataView,
NormalizeFeatures = NormalizeOption.No
}).PredictorModel,
LogisticRegression.TrainBinary(Env, new LogisticRegression.Arguments()
{
FeatureColumn = "Features",
LabelColumn = DefaultColumnNames.Label,
OptTol = 10e-4F,
TrainingData = dataView,
NormalizeFeatures = NormalizeOption.No
}).PredictorModel,
LogisticRegression.TrainBinary(Env, new LogisticRegression.Arguments()
{
FeatureColumn = "Features",
LabelColumn = DefaultColumnNames.Label,
OptTol = 10e-3F,
TrainingData = dataView,
NormalizeFeatures = NormalizeOption.No
}).PredictorModel
};
CombineAndTestEnsembles(dataView, "pe", "oc=average", PredictionKind.BinaryClassification, predictors);
}

// x86 fails. Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))]
public void TestMultiClassEnsembleCombiner()
{
var dataPath = GetDataPath("breast-cancer.txt");
var dataView = TextLoader.Create(Env, new TextLoader.Arguments(), new MultiFileSource(dataPath));

var predictors = new IPredictorModel[]
{
LightGbm.TrainMultiClass(Env, new LightGbmArguments
{
FeatureColumn = "Features",
NumBoostRound = 5,
NumLeaves = 4,
LabelColumn = DefaultColumnNames.Label,
TrainingData = dataView
}).PredictorModel,
LogisticRegression.TrainMultiClass(Env, new MulticlassLogisticRegression.Arguments()
{
FeatureColumn = "Features",
LabelColumn = DefaultColumnNames.Label,
OptTol = 10e-4F,
TrainingData = dataView,
NormalizeFeatures = NormalizeOption.No
}).PredictorModel,
LogisticRegression.TrainMultiClass(Env, new MulticlassLogisticRegression.Arguments()
{
FeatureColumn = "Features",
LabelColumn = DefaultColumnNames.Label,
OptTol = 10e-3F,
TrainingData = dataView,
NormalizeFeatures = NormalizeOption.No
}).PredictorModel
};
CombineAndTestEnsembles(dataView, "weightedensemblemulticlass", "oc=multiaverage", PredictionKind.MultiClassClassification, predictors);
}

private void CombineAndTestEnsembles(IDataView idv, string name, string options, PredictionKind predictionKind,
IPredictorModel[] predictors)
{
var combiner = ComponentCatalog.CreateInstance<IModelCombiner>(
Env, typeof(SignatureModelCombiner), name, options, predictionKind);

var predictor = combiner.CombineModels(predictors.Select(pm => pm.Predictor));

var data = new RoleMappedData(idv, label: null, feature: "Features");
var scored = ScoreModel.Score(Env, new ScoreModel.Input() { Data = idv, PredictorModel = new PredictorModel(Env, data, idv, predictor) }).ScoredData;

var predCount = Utils.Size(predictors);

Assert.True(scored.Schema.TryGetColumnIndex("Score", out int scoreCol));
int probCol = -1;
int predCol = -1;
if (predictionKind == PredictionKind.BinaryClassification)
{
Assert.True(scored.Schema.TryGetColumnIndex("Probability", out probCol));
Assert.True(scored.Schema.TryGetColumnIndex("PredictedLabel", out predCol));
}

var scoredArray = new IDataView[predCount];
int[] scoreColArray = new int[predCount];
int[] probColArray = new int[predCount];
int[] predColArray = new int[predCount];

for (int i = 0; i < predCount; i++)
{
var scoreGetter = curs.GetGetter<float>(scoreCol);
var probGetter = curs.GetGetter<float>(probCol);
var predGetter = curs.GetGetter<bool>(predCol);
var scoreGetters = new ValueGetter<float>[3];
var probGetters = new ValueGetter<float>[3];
var predGetters = new ValueGetter<bool>[3];
for (int i = 0; i < 3; i++)
scoredArray[i] = ScoreModel.Score(Env, new ScoreModel.Input() { Data = idv, PredictorModel = predictors[i] }).ScoredData;
Assert.True(scoredArray[i].Schema.TryGetColumnIndex("Score", out scoreColArray[i]));
if (predictionKind == PredictionKind.BinaryClassification)
{
scoreGetters[i] = cursors[i].GetGetter<float>(scoreColArray[i]);
probGetters[i] = cursors[i].GetGetter<float>(probColArray[i]);
predGetters[i] = cursors[i].GetGetter<bool>(predColArray[i]);
Assert.True(scoredArray[i].Schema.TryGetColumnIndex("Probability", out probColArray[i]));
Assert.True(scoredArray[i].Schema.TryGetColumnIndex("PredictedLabel", out predColArray[i]));
}
else
{
probColArray[i] = -1;
predColArray[i] = -1;
}
}

var cursors = new IRowCursor[predCount];
for (int i = 0; i < predCount; i++)
cursors[i] = scoredArray[i].GetRowCursor(c => c == scoreColArray[i] || c == probColArray[i] || c == predColArray[i]);

float score = 0;
float prob = 0;
bool pred = default;
var scores = new float[3];
var probs = new float[3];
var preds = new bool[3];
while (curs.MoveNext())
try
{
using (var curs = scored.GetRowCursor(c => c == scoreCol || c == probCol || c == predCol))
{
scoreGetter(ref score);
probGetter(ref prob);
predGetter(ref pred);
for (int i = 0; i < 3; i++)
var scoreGetter = predictionKind == PredictionKind.MultiClassClassification ?
(ref float dst) => dst = 0 :
curs.GetGetter<float>(scoreCol);
var vectorScoreGetter = predictionKind == PredictionKind.MultiClassClassification ?
curs.GetGetter<VBuffer<float>>(scoreCol) :
(ref VBuffer<float> dst) => dst = default;
var probGetter = predictionKind == PredictionKind.BinaryClassification ?
curs.GetGetter<float>(probCol) :
(ref float dst) => dst = 0;
var predGetter = predictionKind == PredictionKind.BinaryClassification ?
curs.GetGetter<bool>(predCol) :
(ref bool dst) => dst = false;

var scoreGetters = new ValueGetter<float>[predCount];
var vectorScoreGetters = new ValueGetter<VBuffer<float>>[predCount];
var probGetters = new ValueGetter<float>[predCount];
var predGetters = new ValueGetter<bool>[predCount];
for (int i = 0; i < predCount; i++)
{
Assert.True(cursors[i].MoveNext());
scoreGetters[i](ref scores[i]);
probGetters[i](ref probs[i]);
predGetters[i](ref preds[i]);
scoreGetters[i] = predictionKind == PredictionKind.MultiClassClassification ?
(ref float dst) => dst = 0 :
cursors[i].GetGetter<float>(scoreColArray[i]);
vectorScoreGetters[i] = predictionKind == PredictionKind.MultiClassClassification ?
cursors[i].GetGetter<VBuffer<float>>(scoreColArray[i]) :
(ref VBuffer<float> dst) => dst = default;
probGetters[i] = predictionKind == PredictionKind.BinaryClassification ?
cursors[i].GetGetter<float>(probColArray[i]) :
(ref float dst) => dst = 0;
predGetters[i] = predictionKind == PredictionKind.BinaryClassification ?
cursors[i].GetGetter<bool>(predColArray[i]) :
(ref bool dst) => dst = false;
}

float score = 0;
VBuffer<float> vectorScore = default;
float prob = 0;
bool pred = false;
var scores = new float[predCount];
var vectorScores = new VBuffer<float>[predCount];
var probs = new float[predCount];
var preds = new bool[predCount];
while (curs.MoveNext())
{
scoreGetter(ref score);
vectorScoreGetter(ref vectorScore);
probGetter(ref prob);
predGetter(ref pred);

for (int i = 0; i < predCount; i++)
{
Assert.True(cursors[i].MoveNext());
scoreGetters[i](ref scores[i]);
vectorScoreGetters[i](ref vectorScores[i]);
probGetters[i](ref probs[i]);
predGetters[i](ref preds[i]);
}
if (scores.All(s => !float.IsNaN(s)))
Assert.Equal(score, scores.Sum() / predCount, 3);
for (int i = 0; i < predCount; i++)
Assert.Equal(vectorScore.Length, vectorScores[i].Length);
for (int i = 0; i < vectorScore.Length; i++)
{
float sum = 0;
for (int j = 0; j < predCount; j++)
sum += vectorScores[j].GetItemOrDefault(i);
if (!float.IsNaN(sum))
Assert.Equal(vectorScore.GetItemOrDefault(i), sum / predCount, 3);
}
Assert.Equal(probs.Count(p => p >= prob), probs.Count(p => p <= prob));
}
Assert.Equal(score, 0.4 * scores.Sum() / 3, 5);
Assert.Equal(prob, 1 / (1 + Math.Exp(-score)), 6);
Assert.True(pred == score > 0);
}
}
finally
{
for (int i = 0; i < predCount; i++)
cursors[i].Dispose();
}
}
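
For the generic ensembles the helper checks simple averaging: the combined scalar score (when none of the member scores is NaN) and each element of the combined vector score should equal the mean of the member values, and the combined probability should have as many member probabilities at or above it as at or below it. A hedged per-row restatement of the scalar checks (CheckRow is illustrative, not part of the test):

using System.Linq;
using Xunit;

static void CheckRow(float combinedScore, float[] memberScores,
                     float combinedProb, float[] memberProbs)
{
    // Combined score equals the mean of the member scores (3-digit precision).
    if (memberScores.All(s => !float.IsNaN(s)))
        Assert.Equal(combinedScore, memberScores.Sum() / memberScores.Length, 3);

    // Combined probability splits the member probabilities evenly.
    Assert.Equal(memberProbs.Count(p => p >= combinedProb),
                 memberProbs.Count(p => p <= combinedProb));
}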


[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline
[TestCategory("Binary")]
[TestCategory("FastTree")]