
Commit e0c029c

Update Permutation Feature Importance Samples (#3247)
* Updating PFI Docs.
1 parent 6723bb6 commit e0c029c

File tree

8 files changed: +445 -218 lines changed


docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs

Lines changed: 0 additions & 61 deletions
This file was deleted.

docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs

Lines changed: 0 additions & 78 deletions
This file was deleted.

docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs

Lines changed: 0 additions & 77 deletions
This file was deleted.
Lines changed: 105 additions & 0 deletions
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;

namespace Samples.Dynamic.Trainers.BinaryClassification
{
    public static class PermutationFeatureImportance
    {
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            var mlContext = new MLContext(seed: 1);

            // Create sample data.
            var samples = GenerateData();

            // Load the sample data as an IDataView.
            var data = mlContext.Data.LoadFromEnumerable(samples);

            // Define a training pipeline that concatenates features into a vector, normalizes them, and then
            // trains a linear model.
            var featureColumns = new string[] { nameof(Data.Feature1), nameof(Data.Feature2) };
            var pipeline = mlContext.Transforms.Concatenate("Features", featureColumns)
                .Append(mlContext.Transforms.NormalizeMinMax("Features"))
                .Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression());

            // Fit the pipeline to the data.
            var model = pipeline.Fit(data);

            // Transform the dataset.
            var transformedData = model.Transform(data);

            // Extract the predictor.
            var linearPredictor = model.LastTransformer;

            // Compute the permutation metrics for the linear model using the normalized data.
            var permutationMetrics = mlContext.BinaryClassification.PermutationFeatureImportance(
                linearPredictor, transformedData, permutationCount: 30);

            // Now let's look at which features are most important to the model overall.
            // Get the feature indices sorted by their impact on AUC.
            var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.AreaUnderRocCurve })
                .OrderByDescending(feature => Math.Abs(feature.AreaUnderRocCurve.Mean))
                .Select(feature => feature.index);

            Console.WriteLine("Feature\tModel Weight\tChange in AUC\t95% Confidence in the Mean Change in AUC");
            var auc = permutationMetrics.Select(x => x.AreaUnderRocCurve).ToArray();
            foreach (int i in sortedIndices)
            {
                Console.WriteLine("{0}\t{1:0.00}\t{2:G4}\t{3:G4}",
                    featureColumns[i],
                    linearPredictor.Model.SubModel.Weights[i],
                    auc[i].Mean,
                    1.96 * auc[i].StandardError);
            }

            // Expected output:
            //  Feature     Model Weight    Change in AUC   95% Confidence in the Mean Change in AUC
            //  Feature2        35.15       -0.387          0.002015
            //  Feature1        17.94       -0.1514         0.0008963
        }

        private class Data
        {
            public bool Label { get; set; }

            public float Feature1 { get; set; }

            public float Feature2 { get; set; }
        }

        /// <summary>
        /// Generate an enumerable of Data objects, creating the label as a simple
        /// linear combination of the features.
        /// </summary>
        /// <param name="nExamples">The number of examples.</param>
        /// <param name="bias">The bias, or offset, in the calculation of the label.</param>
        /// <param name="weight1">The weight to multiply the first feature with to compute the label.</param>
        /// <param name="weight2">The weight to multiply the second feature with to compute the label.</param>
        /// <param name="seed">The seed for generating feature values and label noise.</param>
        /// <returns>An enumerable of Data objects.</returns>
        private static IEnumerable<Data> GenerateData(int nExamples = 10000,
            double bias = 0, double weight1 = 1, double weight2 = 2, int seed = 1)
        {
            var rng = new Random(seed);
            for (int i = 0; i < nExamples; i++)
            {
                var data = new Data
                {
                    Feature1 = (float)(rng.Next(10) * (rng.NextDouble() - 0.5)),
                    Feature2 = (float)(rng.Next(10) * (rng.NextDouble() - 0.5)),
                };

                // Create a noisy label.
                var value = (float)(bias + weight1 * data.Feature1 + weight2 * data.Feature2 + rng.NextDouble() - 0.5);
                data.Label = Sigmoid(value) > 0.5;
                yield return data;
            }
        }

        private static double Sigmoid(double x) => 1.0 / (1.0 + Math.Exp(-1 * x));
    }
}
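
To try the new sample locally, a minimal console entry point such as the sketch below should suffice. The Program class, project setup, and Microsoft.ML package reference are assumptions of this sketch, not part of the commit.

namespace Samples.Dynamic.Trainers.BinaryClassification
{
    // Hypothetical entry point for running the sample; assumes this file and the
    // sample above sit in the same console project with the Microsoft.ML package referenced.
    internal static class Program
    {
        private static void Main()
        {
            // Runs the PFI binary classification sample, which prints the per-feature
            // model weight and change in AUC to the console.
            PermutationFeatureImportance.Example();
        }
    }
}

Running it reproduces the expected output shown in the sample's comments: the feature with the larger learned weight (Feature2) also produces the larger drop in AUC when permuted.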
