How to inspect OneVersusAll models #3701
@rauhs Is your main motivation to find feature importance? If so, it's easier to address the PFI slowness you're seeing with OVA. AFAIK, PFI runtime for OVA is a linear function of n_rows * n_features * n_classes. It should be very close to the PFI runtime of the underlying binary trainer multiplied by n_classes, because OVA contains that many binary models, all of which must be evaluated during PFI. My PFI code below finishes in under 7 min for 1M rows, 40 features, and 6 classes on a very modest VM with specs below. What are the properties of the data that takes a day to finish? Which binary trainer are you using with OVA? Have you tried OVA with other binary trainers? Please share your code if possible.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.MulticlassClassification
{
    public static class OvaPfi
    {
        public static void Example()
        {
            var mlContext = new MLContext(seed: 0);

            // 1M rows, 40 features, 6 classes.
            var dataPoints = GenerateRandomDataPoints(1000 * 1000);
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Multiclass trainers expect a key-typed label.
            trainingData = mlContext.Transforms.Conversion.MapValueToKey("Label").Fit(trainingData).Transform(trainingData);

            // OVA trains one binary SDCA model per class.
            var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.SdcaNonCalibrated());
            var model = pipeline.Fit(trainingData);

            // Time PFI alone: each feature is permuted in turn and the data is re-scored.
            var stopWatch = new System.Diagnostics.Stopwatch();
            stopWatch.Start();
            var x = mlContext.MulticlassClassification.PermutationFeatureImportance(model, trainingData);
            stopWatch.Stop();
            Console.WriteLine(stopWatch.Elapsed);
            // Output:
            // 00:06:29.0938075
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)(random.NextDouble() - 0.5);
            for (int i = 0; i < count; i++)
            {
                var label = random.Next(6);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    // Create random features that are correlated with the label.
                    // The feature values are slightly increased by adding a constant multiple of label.
                    Features = Enumerable.Repeat(label, 40).Select(x => randomFloat() + label * 0.2f).ToArray()
                };
            }
        }

        private class DataPoint
        {
            public uint Label { get; set; }
            [VectorType(40)]
            public float[] Features { get; set; }
        }
    }
}
```
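The linear-cost claim above can be turned into a back-of-the-envelope estimator. This is a minimal sketch, assuming PFI cost scales only with rows * features * classes as stated; `PfiCostEstimate`, `BinaryScoringCalls`, and `Extrapolate` are hypothetical names, not ML.NET APIs, and the sketch ignores the per-model scoring cost and any fixed overhead.

```csharp
using System;

// Hypothetical helper (not part of ML.NET). Assumes, as claimed above, that PFI
// cost for an OVA model grows linearly in rows * features * classes.
public static class PfiCostEstimate
{
    // Approximate number of binary-model scoring calls: every permuted feature
    // triggers a pass over all rows, and each row is scored by all `classes` models.
    public static long BinaryScoringCalls(long rows, long features, long classes, int permutationCount = 1)
        => rows * features * classes * permutationCount;

    // Scales a measured PFI runtime to another problem size by the ratio of calls.
    public static TimeSpan Extrapolate(
        TimeSpan measured,
        long measuredRows, long measuredFeatures, long measuredClasses,
        long targetRows, long targetFeatures, long targetClasses)
    {
        double ratio = (double)BinaryScoringCalls(targetRows, targetFeatures, targetClasses)
                     / BinaryScoringCalls(measuredRows, measuredFeatures, measuredClasses);
        return TimeSpan.FromTicks((long)Math.Round(measured.Ticks * ratio));
    }
}
```

Under this model, keeping 1M rows and 40 features but going from 6 to 2,000 classes multiplies the 6:29 measurement by 2000/6, which is roughly 36 hours. That illustrates how a class count in the thousands alone can push PFI from minutes to more than a day.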
My hunch is that the … We have hundreds of classes, sometimes even 2-3k classes. This particular setup:
Log:
So that's ~15 min for PFI and only 15 s for training, and I've already reduced our usual problem size here. I passed the test set (1,000 samples) to the PFI call.
We need to do some profiling for this. The large number of classes is making it very slow.
Version: 1.0
Since 2b417bb made `SubModelParameters` private, there is no way to get at any of the sub-models, which would be needed for feature importance. How do I inspect OVA models? In particular, how do I get feature importance?
FWIW, using PFI for feature importance isn't an option for me: it would take a day to run, which is about 1000x slower than training any of my models.
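Because `SubModelParameters` is private in 1.0, one possible workaround is to skip the OVA wrapper and train the per-class binary models yourself, so each one can be inspected directly. The sketch below illustrates that decomposition from scratch in plain C# rather than with ML.NET: a hypothetical `ManualOva` class trains one logistic-regression model per class (class k vs. the rest) with batch gradient descent, then averages the absolute per-class weights into a crude importance score. It swaps plain logistic regression in for the SDCA trainer used in this thread, and absolute weights are only comparable when features share a scale, so treat it as an illustration of the idea rather than a replacement for PFI.

```csharp
using System;
using System.Linq;

// Illustrative from-scratch one-versus-all; NOT ML.NET's implementation.
public static class ManualOva
{
    // Trains one logistic-regression model per class (class k vs. rest) with plain
    // batch gradient descent. Returns weights[k][j] for class k and feature j;
    // the last entry of each weights[k] is that model's bias.
    public static double[][] Train(double[][] x, int[] y, int nClasses, int epochs = 200, double lr = 0.1)
    {
        int nFeatures = x[0].Length;
        var weights = new double[nClasses][];
        for (int k = 0; k < nClasses; k++)
        {
            var w = new double[nFeatures + 1];
            for (int epoch = 0; epoch < epochs; epoch++)
            {
                var grad = new double[nFeatures + 1];
                for (int i = 0; i < x.Length; i++)
                {
                    double z = w[nFeatures];
                    for (int j = 0; j < nFeatures; j++) z += w[j] * x[i][j];
                    double p = 1.0 / (1.0 + Math.Exp(-z));
                    // Binary target for this sub-model: is the row of class k?
                    double err = p - (y[i] == k ? 1.0 : 0.0);
                    for (int j = 0; j < nFeatures; j++) grad[j] += err * x[i][j];
                    grad[nFeatures] += err;
                }
                for (int j = 0; j <= nFeatures; j++) w[j] -= lr * grad[j] / x.Length;
            }
            weights[k] = w;
        }
        return weights;
    }

    // Crude per-feature importance: mean absolute weight across the per-class models
    // (the bias entry is ignored).
    public static double[] MeanAbsWeight(double[][] weights, int nFeatures) =>
        Enumerable.Range(0, nFeatures)
            .Select(j => weights.Average(w => Math.Abs(w[j])))
            .ToArray();
}
```

Since each per-class model is an ordinary binary model, it can be inspected on its own, for example to see which features distinguish one particular class. With ML.NET the same idea would mean fitting one binary trainer per class on a derived boolean label.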