Skip to content

Commit c6b53ba

Browse files
committed
review comments
1 parent c82e0f2 commit c6b53ba

File tree

8 files changed

+45
-55
lines changed

8 files changed

+45
-55
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ public static void Example()
3535
// Create a Feature Contribution Calculator
3636
// Calculate the feature contributions for all features given trained model parameters
3737
// And don't normalize the contribution scores
38-
var featureContributionCalculator = mlContext.Transforms.CalculateFeatureContribution(model.Model, model.FeatureColumn, numberOfPositiveContributions: 11, normalize: false);
38+
var featureContributionCalculator = mlContext.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 11, normalize: false);
3939
var outputData = featureContributionCalculator.Fit(scoredData).Transform(scoredData);
4040

4141
// FeatureContributionCalculatingEstimator can be use as an intermediary step in a pipeline.
4242
// The features retained by FeatureContributionCalculatingEstimator will be in the FeatureContribution column.
43-
var pipeline = mlContext.Transforms.CalculateFeatureContribution(model.Model, model.FeatureColumn, numberOfPositiveContributions: 11)
43+
var pipeline = mlContext.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 11)
4444
.Append(mlContext.Regression.Trainers.Ols(featureColumnName: "FeatureContributions"));
4545
var outData = featureContributionCalculator.Fit(scoredData).Transform(scoredData);
4646

docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public static void Example()
3131
// Compute the permutation metrics using the properly normalized data.
3232
var transformedData = model.Transform(data);
3333
var permutationMetrics = mlContext.Regression.PermutationFeatureImportance(
34-
linearPredictor, transformedData, labelColumnName: labelName, featureColumnName: "Features", permutationCount: 3);
34+
linearPredictor, transformedData, labelColumnName: labelName, permutationCount: 3);
3535

3636
// Now let's look at which features are most important to the model overall
3737
// Get the feature indices sorted by their impact on R-Squared

docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public static void Example()
3535
// Compute the permutation metrics using the properly normalized data.
3636
var transformedData = model.Transform(data);
3737
var permutationMetrics = mlContext.BinaryClassification.PermutationFeatureImportance(
38-
linearPredictor, transformedData, labelColumnName: labelName, featureColumnName: "Features", permutationCount: 3);
38+
linearPredictor, transformedData, labelColumnName: labelName, permutationCount: 3);
3939

4040
// Now let's look at which features are most important to the model overall.
4141
// Get the feature indices sorted by their impact on AreaUnderRocCurve.

src/Microsoft.ML.Data/Transforms/ExplainabilityCatalog.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ public static class ExplainabilityCatalog
1717
/// Note that this functionality is not supported by all the models. See <see cref="FeatureContributionCalculatingTransformer"/> for a list of the suported models.
1818
/// </summary>
1919
/// <param name="catalog">The model explainability operations catalog.</param>
20-
/// <param name="modelParameters">Trained model parameters that support Feature Contribution Calculation and which will be used for scoring.</param>
21-
/// <param name="featureColumnName">The name of the feature column that will be used as input.</param>
20+
/// <param name="predictionTransformer">A <see cref="ISingleFeaturePredictionTransformer{TModel}"/> that supports Feature Contribution Calculation,
21+
/// and which will also be used for scoring.</param>
2222
/// <param name="numberOfPositiveContributions">The number of positive contributions to report, sorted from highest magnitude to lowest magnitude.
2323
/// Note that if there are fewer features with positive contributions than <paramref name="numberOfPositiveContributions"/>, the rest will be returned as zeros.</param>
2424
/// <param name="numberOfNegativeContributions">The number of negative contributions to report, sorted from highest magnitude to lowest magnitude.
@@ -32,11 +32,10 @@ public static class ExplainabilityCatalog
3232
/// </format>
3333
/// </example>
3434
public static FeatureContributionCalculatingEstimator CalculateFeatureContribution(this TransformsCatalog catalog,
35-
ICalculateFeatureContribution modelParameters,
36-
string featureColumnName = DefaultColumnNames.Features,
35+
ISingleFeaturePredictionTransformer<ICalculateFeatureContribution> predictionTransformer,
3736
int numberOfPositiveContributions = FeatureContributionDefaults.NumberOfPositiveContributions,
3837
int numberOfNegativeContributions = FeatureContributionDefaults.NumberOfNegativeContributions,
3938
bool normalize = FeatureContributionDefaults.Normalize)
40-
=> new FeatureContributionCalculatingEstimator(CatalogUtils.GetEnvironment(catalog), modelParameters, featureColumnName, numberOfPositiveContributions, numberOfNegativeContributions, normalize);
39+
=> new FeatureContributionCalculatingEstimator(CatalogUtils.GetEnvironment(catalog), predictionTransformer, numberOfPositiveContributions, numberOfNegativeContributions, normalize);
4140
}
4241
}

src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,23 +286,22 @@ internal static class Defaults
286286
/// Note that this functionality is not supported by all the models. See <see cref="FeatureContributionCalculatingTransformer"/> for a list of the suported models.
287287
/// </summary>
288288
/// <param name="env">The environment to use.</param>
289-
/// <param name="modelParameters">Trained model parameters that support Feature Contribution Calculation and which will be used for scoring.</param>
290-
/// <param name="featureColumnName">The name of the feature column that will be used as input.</param>
289+
/// <param name="predictionTransformer">A <see cref="ISingleFeaturePredictionTransformer{TModel}"/> that supports Feature Contribution Calculation,
290+
/// and which will also be used for scoring.</param>
291291
/// <param name="numberOfPositiveContributions">The number of positive contributions to report, sorted from highest magnitude to lowest magnitude.
292292
/// Note that if there are fewer features with positive contributions than <paramref name="numberOfPositiveContributions"/>, the rest will be returned as zeros.</param>
293293
/// <param name="numberOfNegativeContributions">The number of negative contributions to report, sorted from highest magnitude to lowest magnitude.
294294
/// Note that if there are fewer features with negative contributions than <paramref name="numberOfNegativeContributions"/>, the rest will be returned as zeros.</param>
295295
/// <param name="normalize">Whether the feature contributions should be normalized to the [-1, 1] interval.</param>
296-
internal FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateFeatureContribution modelParameters,
297-
string featureColumnName = DefaultColumnNames.Features,
296+
internal FeatureContributionCalculatingEstimator(IHostEnvironment env, ISingleFeaturePredictionTransformer<ICalculateFeatureContribution> predictionTransformer,
298297
int numberOfPositiveContributions = Defaults.NumberOfPositiveContributions,
299298
int numberOfNegativeContributions = Defaults.NumberOfNegativeContributions,
300299
bool normalize = Defaults.Normalize)
301300
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)),
302-
new FeatureContributionCalculatingTransformer(env, modelParameters, featureColumnName, numberOfPositiveContributions, numberOfNegativeContributions, normalize))
301+
new FeatureContributionCalculatingTransformer(env, predictionTransformer.Model, predictionTransformer.FeatureColumn, numberOfPositiveContributions, numberOfNegativeContributions, normalize))
303302
{
304-
_featureColumn = featureColumnName;
305-
_predictor = modelParameters;
303+
_featureColumn = predictionTransformer.FeatureColumn;
304+
_predictor = predictionTransformer.Model;
306305
}
307306

308307
/// <summary>

src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -46,33 +46,31 @@ public static class PermutationFeatureImportanceExtensions
4646
/// </format>
4747
/// </example>
4848
/// <param name="catalog">The regression catalog.</param>
49-
/// <param name="model">The model on which to evaluate feature importance.</param>
49+
/// <param name="predictionTransformer">The model on which to evaluate feature importance.</param>
5050
/// <param name="data">The evaluation data set.</param>
5151
/// <param name="labelColumnName">Label column name.</param>
52-
/// <param name="featureColumnName">Feature column name.</param>
5352
/// <param name="useFeatureWeightFilter">Use features weight to pre-filter features.</param>
5453
/// <param name="numberOfExamplesToUse">Limit the number of examples to evaluate on. <cref langword="null"/> means up to ~2 bln examples from <paramref param="data"/> will be used.</param>
5554
/// <param name="permutationCount">The number of permutations to perform.</param>
5655
/// <returns>Array of per-feature 'contributions' to the score.</returns>
5756
public static ImmutableArray<RegressionMetricsStatistics>
5857
PermutationFeatureImportance<TModel>(
5958
this RegressionCatalog catalog,
60-
IPredictionTransformer<TModel> model,
59+
ISingleFeaturePredictionTransformer<TModel> predictionTransformer,
6160
IDataView data,
6261
string labelColumnName = DefaultColumnNames.Label,
63-
string featureColumnName = DefaultColumnNames.Features,
6462
bool useFeatureWeightFilter = false,
6563
int? numberOfExamplesToUse = null,
6664
int permutationCount = 1)
6765
{
6866
return PermutationFeatureImportance<TModel, RegressionMetrics, RegressionMetricsStatistics>.GetImportanceMetricsMatrix(
6967
catalog.GetEnvironment(),
70-
model,
68+
predictionTransformer,
7169
data,
7270
() => new RegressionMetricsStatistics(),
7371
idv => catalog.Evaluate(idv, labelColumnName),
7472
RegressionDelta,
75-
featureColumnName,
73+
predictionTransformer.FeatureColumn,
7674
permutationCount,
7775
useFeatureWeightFilter,
7876
numberOfExamplesToUse);
@@ -124,33 +122,31 @@ private static RegressionMetrics RegressionDelta(
124122
/// </format>
125123
/// </example>
126124
/// <param name="catalog">The binary classification catalog.</param>
127-
/// <param name="model">The model on which to evaluate feature importance.</param>
125+
/// <param name="predictionTransformer">The model on which to evaluate feature importance.</param>
128126
/// <param name="data">The evaluation data set.</param>
129127
/// <param name="labelColumnName">Label column name.</param>
130-
/// <param name="featureColumnName">Feature column name.</param>
131128
/// <param name="useFeatureWeightFilter">Use features weight to pre-filter features.</param>
132129
/// <param name="numberOfExamplesToUse">Limit the number of examples to evaluate on. <cref langword="null"/> means up to ~2 bln examples from <paramref param="data"/> will be used.</param>
133130
/// <param name="permutationCount">The number of permutations to perform.</param>
134131
/// <returns>Array of per-feature 'contributions' to the score.</returns>
135132
public static ImmutableArray<BinaryClassificationMetricsStatistics>
136133
PermutationFeatureImportance<TModel>(
137134
this BinaryClassificationCatalog catalog,
138-
IPredictionTransformer<TModel> model,
135+
ISingleFeaturePredictionTransformer<TModel> predictionTransformer,
139136
IDataView data,
140137
string labelColumnName = DefaultColumnNames.Label,
141-
string featureColumnName = DefaultColumnNames.Features,
142138
bool useFeatureWeightFilter = false,
143139
int? numberOfExamplesToUse = null,
144140
int permutationCount = 1)
145141
{
146142
return PermutationFeatureImportance<TModel, BinaryClassificationMetrics, BinaryClassificationMetricsStatistics>.GetImportanceMetricsMatrix(
147143
catalog.GetEnvironment(),
148-
model,
144+
predictionTransformer,
149145
data,
150146
() => new BinaryClassificationMetricsStatistics(),
151147
idv => catalog.Evaluate(idv, labelColumnName),
152148
BinaryClassifierDelta,
153-
featureColumnName,
149+
predictionTransformer.FeatureColumn,
154150
permutationCount,
155151
useFeatureWeightFilter,
156152
numberOfExamplesToUse);
@@ -199,33 +195,31 @@ private static BinaryClassificationMetrics BinaryClassifierDelta(
199195
/// </para>
200196
/// </remarks>
201197
/// <param name="catalog">The clustering catalog.</param>
202-
/// <param name="model">The model on which to evaluate feature importance.</param>
198+
/// <param name="predictionTransformer">The model on which to evaluate feature importance.</param>
203199
/// <param name="data">The evaluation data set.</param>
204200
/// <param name="labelColumnName">Label column name.</param>
205-
/// <param name="featureColumnName">Feature column name.</param>
206201
/// <param name="useFeatureWeightFilter">Use features weight to pre-filter features.</param>
207202
/// <param name="numberOfExamplesToUse">Limit the number of examples to evaluate on. <cref langword="null"/> means up to ~2 bln examples from <paramref param="data"/> will be used.</param>
208203
/// <param name="permutationCount">The number of permutations to perform.</param>
209204
/// <returns>Array of per-feature 'contributions' to the score.</returns>
210205
public static ImmutableArray<MulticlassClassificationMetricsStatistics>
211206
PermutationFeatureImportance<TModel>(
212207
this MulticlassClassificationCatalog catalog,
213-
IPredictionTransformer<TModel> model,
208+
ISingleFeaturePredictionTransformer<TModel> predictionTransformer,
214209
IDataView data,
215210
string labelColumnName = DefaultColumnNames.Label,
216-
string featureColumnName = DefaultColumnNames.Features,
217211
bool useFeatureWeightFilter = false,
218212
int? numberOfExamplesToUse = null,
219213
int permutationCount = 1)
220214
{
221215
return PermutationFeatureImportance<TModel, MulticlassClassificationMetrics, MulticlassClassificationMetricsStatistics>.GetImportanceMetricsMatrix(
222216
catalog.GetEnvironment(),
223-
model,
217+
predictionTransformer,
224218
data,
225219
() => new MulticlassClassificationMetricsStatistics(),
226220
idv => catalog.Evaluate(idv, labelColumnName),
227221
MulticlassClassificationDelta,
228-
featureColumnName,
222+
predictionTransformer.FeatureColumn,
229223
permutationCount,
230224
useFeatureWeightFilter,
231225
numberOfExamplesToUse);
@@ -279,35 +273,33 @@ private static MulticlassClassificationMetrics MulticlassClassificationDelta(
279273
/// </para>
280274
/// </remarks>
281275
/// <param name="catalog">The clustering catalog.</param>
282-
/// <param name="model">The model on which to evaluate feature importance.</param>
276+
/// <param name="predictionTransformer">The model on which to evaluate feature importance.</param>
283277
/// <param name="data">The evaluation data set.</param>
284278
/// <param name="labelColumnName">Label column name.</param>
285279
/// <param name="rowGroupColumnName">GroupId column name</param>
286-
/// <param name="featureColumnName">Feature column name.</param>
287280
/// <param name="useFeatureWeightFilter">Use features weight to pre-filter features.</param>
288281
/// <param name="numberOfExamplesToUse">Limit the number of examples to evaluate on. <cref langword="null"/> means up to ~2 bln examples from <paramref param="data"/> will be used.</param>
289282
/// <param name="permutationCount">The number of permutations to perform.</param>
290283
/// <returns>Array of per-feature 'contributions' to the score.</returns>
291284
public static ImmutableArray<RankingMetricsStatistics>
292285
PermutationFeatureImportance<TModel>(
293286
this RankingCatalog catalog,
294-
IPredictionTransformer<TModel> model,
287+
ISingleFeaturePredictionTransformer<TModel> predictionTransformer,
295288
IDataView data,
296289
string labelColumnName = DefaultColumnNames.Label,
297290
string rowGroupColumnName = DefaultColumnNames.GroupId,
298-
string featureColumnName = DefaultColumnNames.Features,
299291
bool useFeatureWeightFilter = false,
300292
int? numberOfExamplesToUse = null,
301293
int permutationCount = 1)
302294
{
303295
return PermutationFeatureImportance<TModel, RankingMetrics, RankingMetricsStatistics>.GetImportanceMetricsMatrix(
304296
catalog.GetEnvironment(),
305-
model,
297+
predictionTransformer,
306298
data,
307299
() => new RankingMetricsStatistics(),
308300
idv => catalog.Evaluate(idv, labelColumnName, rowGroupColumnName),
309301
RankingDelta,
310-
featureColumnName,
302+
predictionTransformer.FeatureColumn,
311303
permutationCount,
312304
useFeatureWeightFilter,
313305
numberOfExamplesToUse);

0 commit comments

Comments
 (0)