Skip to content

Commit d8e0462

Browse files
jwood803shauheen
authored and committed
Explainability doc (#2901)
1 parent 3fa207d commit d8e0462

File tree

2 files changed

+229
-0
lines changed

2 files changed

+229
-0
lines changed

docs/code/MlNetCookBook.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,60 @@ var biases = modelParameters.GetBiases();
578578
579579
```
580580

581+
## How do I get a model's weights to look at the global feature importance?
582+
The below snippet shows how to get a linear model's weights to help determine the model's feature importance.
583+
584+
```csharp
585+
var linearModel = model.LastTransformer.Model;
586+
587+
var weights = linearModel.Weights;
588+
```
589+
590+
The below snippet shows how to get the weights for a fast tree model.
591+
592+
```csharp
593+
var treeModel = model.LastTransformer.Model;
594+
595+
var weights = new VBuffer<float>();
596+
treeModel.GetFeatureWeights(ref weights);
597+
```
598+
599+
## How do I look at the global feature importance?
600+
The below snippet shows how to get a glimpse of the feature importance. Permutation Feature Importance works by computing the change in the evaluation metrics when the values of a feature are randomly shuffled. In this case, we are investigating the change in the root mean squared error. For more information on permutation feature importance, review the [documentation](https://docs.microsoft.com/en-us/dotnet/machine-learning/how-to-guides/determine-global-feature-importance-in-model).
601+
602+
```csharp
603+
var transformedData = model.Transform(data);
604+
605+
var featureImportance = context.Regression.PermutationFeatureImportance(model.LastTransformer, transformedData);
606+
607+
for (int i = 0; i < featureImportance.Count(); i++)
608+
{
609+
Console.WriteLine($"Feature {i}: Difference in RMS - {featureImportance[i].RootMeanSquaredError.Mean}");
610+
}
611+
```
612+
613+
## How do I look at the local feature importance per example?
614+
The below snippet shows how to get feature importance for each example of data.
615+
616+
```csharp
617+
var model = pipeline.Fit(data);
618+
var transformedData = model.Transform(data);
619+
620+
var linearModel = model.LastTransformer;
621+
622+
var featureContributionCalculation = context.Transforms.CalculateFeatureContribution(linearModel, normalize: false);
623+
624+
var featureContributionData = featureContributionCalculation.Fit(transformedData).Transform(transformedData);
625+
626+
var shuffledSubset = context.Data.TakeRows(context.Data.ShuffleRows(featureContributionData), 10);
627+
var scoringEnumerator = context.Data.CreateEnumerable<HousingData>(shuffledSubset, true);
628+
629+
foreach (var row in scoringEnumerator)
630+
{
631+
Console.WriteLine(row);
632+
}
633+
```
634+
581635
## What is normalization and why do I need to care?
582636

583637
In ML.NET we expose a number of [parametric and non-parametric algorithms](https://machinelearningmastery.com/parametric-and-nonparametric-machine-learning-algorithms/).
@@ -788,6 +842,7 @@ var transformedData = pipeline.Fit(data).Transform(data);
788842
var embeddings = transformedData.GetColumn<float[]>(mlContext, "Embeddings").Take(10).ToArray();
789843
var unigrams = transformedData.GetColumn<float[]>(mlContext, "BagOfWords").Take(10).ToArray();
790844
```
845+
791846
## How do I train using cross-validation?
792847

793848
[Cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)) is a useful technique for ML applications. It helps estimate the variance of the model quality from one run to another and also eliminates the need to extract a separate test set for evaluation.
@@ -838,6 +893,7 @@ var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro);
838893
Console.WriteLine(microAccuracies.Average());
839894

840895
```
896+
841897
## Can I mix and match static and dynamic pipelines?
842898

843899
Yes, we can have both of them in our codebase. The static pipelines are just a statically-typed way to build dynamic pipelines.

test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,164 @@ private void NormalizationWorkout(string dataPath)
249249
public void Normalization()
250250
=> NormalizationWorkout(GetDataPath("iris.data"));
251251

252+
[Fact]
public void GlobalFeatureImportance()
{
    // Housing regression data set shipped with the test assets.
    var dataPath = GetDataPath("housing.txt");

    var context = new MLContext();

    // Column 0 is the regression label; columns 1-11 are numeric features.
    IDataView data = context.Data.LoadFromTextFile(dataPath, new[]
    {
        new TextLoader.Column("Label", DataKind.Single, 0),
        new TextLoader.Column("CrimesPerCapita", DataKind.Single, 1),
        new TextLoader.Column("PercentResidental", DataKind.Single, 2),
        new TextLoader.Column("PercentNonRetail", DataKind.Single, 3),
        new TextLoader.Column("CharlesRiver", DataKind.Single, 4),
        new TextLoader.Column("NitricOxides", DataKind.Single, 5),
        new TextLoader.Column("RoomsPerDwelling", DataKind.Single, 6),
        new TextLoader.Column("PercentPre40s", DataKind.Single, 7),
        new TextLoader.Column("EmploymentDistance", DataKind.Single, 8),
        new TextLoader.Column("HighwayDistance", DataKind.Single, 9),
        new TextLoader.Column("TaxRate", DataKind.Single, 10),
        new TextLoader.Column("TeacherRatio", DataKind.Single, 11)
    },
    hasHeader: true);

    // Concatenate all feature columns into a single "Features" vector and train a FastTree regressor.
    var pipeline = context.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
        "RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
        .Append(context.Regression.Trainers.FastTree());

    var model = pipeline.Fit(data);

    var transformedData = model.Transform(data);

    // Permutation Feature Importance: change in evaluation metrics when a feature's values are permuted.
    var featureImportance = context.Regression.PermutationFeatureImportance(model.LastTransformer, transformedData);

    // Hoist Count() out of the loop condition so the LINQ extension is not re-evaluated on every iteration.
    var featureCount = featureImportance.Count();
    for (int i = 0; i < featureCount; i++)
    {
        Console.WriteLine($"Feature {i}: Difference in RMS - {featureImportance[i].RootMeanSquaredError.Mean}");
    }
}
291+
292+
[Fact]
public void GetLinearModelWeights()
{
    // Housing regression data set shipped with the test assets.
    var dataPath = GetDataPath("housing.txt");

    var context = new MLContext();

    // Column 0 is the regression label; columns 1-11 are numeric features.
    var columns = new[]
    {
        new TextLoader.Column("Label", DataKind.Single, 0),
        new TextLoader.Column("CrimesPerCapita", DataKind.Single, 1),
        new TextLoader.Column("PercentResidental", DataKind.Single, 2),
        new TextLoader.Column("PercentNonRetail", DataKind.Single, 3),
        new TextLoader.Column("CharlesRiver", DataKind.Single, 4),
        new TextLoader.Column("NitricOxides", DataKind.Single, 5),
        new TextLoader.Column("RoomsPerDwelling", DataKind.Single, 6),
        new TextLoader.Column("PercentPre40s", DataKind.Single, 7),
        new TextLoader.Column("EmploymentDistance", DataKind.Single, 8),
        new TextLoader.Column("HighwayDistance", DataKind.Single, 9),
        new TextLoader.Column("TaxRate", DataKind.Single, 10),
        new TextLoader.Column("TeacherRatio", DataKind.Single, 11)
    };

    IDataView data = context.Data.LoadFromTextFile(dataPath, columns, hasHeader: true);

    // Names of the feature columns to concatenate into the "Features" vector.
    var featureColumnNames = new[]
    {
        "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
        "RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"
    };

    // Train an SDCA linear regression model over the concatenated feature vector.
    var pipeline = context.Transforms.Concatenate("Features", featureColumnNames)
        .Append(context.Regression.Trainers.Sdca());

    var model = pipeline.Fit(data);

    // For a linear model, the learned per-feature weights are exposed directly on the model.
    var linearModel = model.LastTransformer.Model;

    var weights = linearModel.Weights;
}
326+
327+
[Fact]
public void GetFastTreeModelWeights()
{
    // Housing regression data set shipped with the test assets.
    var dataPath = GetDataPath("housing.txt");

    var context = new MLContext();

    // Column 0 is the regression label; columns 1-11 are numeric features.
    var columns = new[]
    {
        new TextLoader.Column("Label", DataKind.Single, 0),
        new TextLoader.Column("CrimesPerCapita", DataKind.Single, 1),
        new TextLoader.Column("PercentResidental", DataKind.Single, 2),
        new TextLoader.Column("PercentNonRetail", DataKind.Single, 3),
        new TextLoader.Column("CharlesRiver", DataKind.Single, 4),
        new TextLoader.Column("NitricOxides", DataKind.Single, 5),
        new TextLoader.Column("RoomsPerDwelling", DataKind.Single, 6),
        new TextLoader.Column("PercentPre40s", DataKind.Single, 7),
        new TextLoader.Column("EmploymentDistance", DataKind.Single, 8),
        new TextLoader.Column("HighwayDistance", DataKind.Single, 9),
        new TextLoader.Column("TaxRate", DataKind.Single, 10),
        new TextLoader.Column("TeacherRatio", DataKind.Single, 11)
    };

    IDataView data = context.Data.LoadFromTextFile(dataPath, columns, hasHeader: true);

    // Names of the feature columns to concatenate into the "Features" vector.
    var featureColumnNames = new[]
    {
        "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
        "RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"
    };

    // Train a FastTree regression model over the concatenated feature vector.
    var pipeline = context.Transforms.Concatenate("Features", featureColumnNames)
        .Append(context.Regression.Trainers.FastTree());

    var model = pipeline.Fit(data);

    // Renamed from "linearModel": this is a tree-ensemble model, matching the cookbook
    // snippet's "treeModel" naming. Tree models expose weights via GetFeatureWeights.
    var treeModel = model.LastTransformer.Model;

    VBuffer<float> weights = default;
    treeModel.GetFeatureWeights(ref weights);
}
362+
363+
[Fact]
public void FeatureImportanceForEachRow()
{
    // Housing regression data set shipped with the test assets.
    var dataPath = GetDataPath("housing.txt");

    var context = new MLContext();

    // Column 0 is the label ("MedianHomeValue"); columns 1-11 are numeric features.
    var columns = new[]
    {
        new TextLoader.Column("MedianHomeValue", DataKind.Single, 0),
        new TextLoader.Column("CrimesPerCapita", DataKind.Single, 1),
        new TextLoader.Column("PercentResidental", DataKind.Single, 2),
        new TextLoader.Column("PercentNonRetail", DataKind.Single, 3),
        new TextLoader.Column("CharlesRiver", DataKind.Single, 4),
        new TextLoader.Column("NitricOxides", DataKind.Single, 5),
        new TextLoader.Column("RoomsPerDwelling", DataKind.Single, 6),
        new TextLoader.Column("PercentPre40s", DataKind.Single, 7),
        new TextLoader.Column("EmploymentDistance", DataKind.Single, 8),
        new TextLoader.Column("HighwayDistance", DataKind.Single, 9),
        new TextLoader.Column("TaxRate", DataKind.Single, 10),
        new TextLoader.Column("TeacherRatio", DataKind.Single, 11)
    };

    IDataView data = context.Data.LoadFromTextFile(dataPath, columns, hasHeader: true);

    // Names of the feature columns to concatenate into the "Features" vector.
    var featureColumnNames = new[]
    {
        "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
        "RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"
    };

    // Train a FastTree regressor against the "MedianHomeValue" label.
    var pipeline = context.Transforms.Concatenate("Features", featureColumnNames)
        .Append(context.Regression.Trainers.FastTree(labelColumnName: "MedianHomeValue"));

    var model = pipeline.Fit(data);

    // Renamed from the original's typo "transfomedData".
    var transformedData = model.Transform(data);

    var lastTransformer = model.LastTransformer;

    // Compute per-row feature contributions (local feature importance) for the trained model.
    var featureContributionCalculation = context.Transforms.CalculateFeatureContribution(lastTransformer, normalize: false);

    var featureContributionData = featureContributionCalculation.Fit(transformedData).Transform(transformedData);

    // Print a random sample of 10 rows with their contributions.
    var shuffledSubset = context.Data.TakeRows(context.Data.ShuffleRows(featureContributionData), 10);
    var scoringEnumerator = context.Data.CreateEnumerable<HousingData>(shuffledSubset, reuseRowObject: true);

    foreach (var row in scoringEnumerator)
    {
        Console.WriteLine(row);
    }
}
409+
252410
private IEnumerable<CustomerChurnInfo> GetChurnInfo()
253411
{
254412
var r = new Random(454);
@@ -625,5 +783,20 @@ private class AdultData
625783
public float Target { get; set; }
626784
}
627785

786+
/// <summary>
/// Row type for the housing data set; properties are mapped by name to the
/// columns declared in the <c>TextLoader</c> when enumerating scored rows.
/// </summary>
private class HousingData
{
    public float MedianHomeValue { get; set; }
    public float CrimesPerCapita { get; set; }
    public float PercentResidental { get; set; }
    public float PercentNonRetail { get; set; }
    public float CharlesRiver { get; set; }
    public float NitricOxides { get; set; }
    public float RoomsPerDwelling { get; set; }
    public float PercentPre40s { get; set; }
    public float EmploymentDistance { get; set; }
    public float HighwayDistance { get; set; }
    public float TaxRate { get; set; }
    public float TeacherRatio { get; set; }
}
628801
}
629802
}

0 commit comments

Comments
 (0)