-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Extended contexts to regression and multiclass, added FFM pigstension #993
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,13 +51,13 @@ public static BinaryClassifierEvaluator.CalibratedResult Evaluate<T>( | |
} | ||
|
||
/// <summary> | ||
/// Evaluates scored binary classification data. | ||
/// Evaluates scored binary classification data, if the predictions are not calibrated. | ||
/// </summary> | ||
/// <typeparam name="T">The shape type for the input data.</typeparam> | ||
/// <param name="ctx">The binary classification context.</param> | ||
/// <param name="data">The data to evaluate.</param> | ||
/// <param name="label">The index delegate for the label column.</param> | ||
/// <param name="pred">The index delegate for columns from calibrated prediction of a binary classifier. | ||
/// <param name="pred">The index delegate for columns from uncalibrated prediction of a binary classifier. | ||
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param> | ||
/// <returns>The evaluation results for these uncalibrated outputs.</returns> | ||
public static BinaryClassifierEvaluator.Result Evaluate<T>( | ||
|
@@ -83,5 +83,89 @@ public static BinaryClassifierEvaluator.Result Evaluate<T>( | |
var eval = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { }); | ||
return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName); | ||
} | ||
|
||
/// <summary> | ||
/// Evaluates scored multiclass classification data. | ||
/// </summary> | ||
/// <typeparam name="T">The shape type for the input data.</typeparam> | ||
/// <typeparam name="TKey">The value type for the key label.</typeparam> | ||
/// <param name="ctx">The multiclass classification context.</param> | ||
/// <param name="data">The data to evaluate.</param> | ||
/// <param name="label">The index delegate for the label column.</param> | ||
/// <param name="pred">The index delegate for columns from the prediction of a multiclass classifier. | ||
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param> | ||
/// <param name="topK">If given a positive value, the <see cref="MultiClassClassifierEvaluator.Result.TopKAccuracy"/> will be filled with | ||
/// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within | ||
/// the top-K values as being stored "correctly."</param> | ||
/// <returns>The evaluation metrics.</returns> | ||
public static MultiClassClassifierEvaluator.Result Evaluate<T, TKey>( | ||
this MulticlassClassificationContext ctx, | ||
DataView<T> data, | ||
Func<T, Key<uint, TKey>> label, | ||
Func<T, (Vector<float> score, Key<uint, TKey> predictedLabel)> pred, | ||
int topK = 0) | ||
{ | ||
Contracts.CheckValue(data, nameof(data)); | ||
var env = StaticPipeUtils.GetEnvironment(data); | ||
Contracts.AssertValue(env); | ||
env.CheckValue(label, nameof(label)); | ||
env.CheckValue(pred, nameof(pred)); | ||
env.CheckParam(topK >= 0, nameof(topK), "Must not be negative."); | ||
|
||
var indexer = StaticPipeUtils.GetIndexer(data); | ||
string labelName = indexer.Get(label(indexer.Indices)); | ||
(var scoreCol, var predCol) = pred(indexer.Indices); | ||
Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column."); | ||
Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column."); | ||
string scoreName = indexer.Get(scoreCol); | ||
string predName = indexer.Get(predCol); | ||
|
||
var args = new MultiClassClassifierEvaluator.Arguments() { }; | ||
if (topK > 0) | ||
args.OutputTopKAcc = topK; | ||
|
||
var eval = new MultiClassClassifierEvaluator(env, args); | ||
return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName); | ||
} | ||
|
||
private sealed class TrivialRegressionLossFactory : ISupportRegressionLossFactory | ||
{ | ||
private readonly IRegressionLoss _loss; | ||
public TrivialRegressionLossFactory(IRegressionLoss loss) => _loss = loss; | ||
public IRegressionLoss CreateComponent(IHostEnvironment env) => _loss; | ||
} | ||
|
||
/// <summary> | ||
/// Evaluates scored multiclass classification data. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
regression #Resolved |
||
/// </summary> | ||
/// <typeparam name="T">The shape type for the input data.</typeparam> | ||
/// <param name="ctx">The regression context.</param> | ||
/// <param name="data">The data to evaluate.</param> | ||
/// <param name="label">The index delegate for the label column.</param> | ||
/// <param name="score">The index delegate for predicted score column.</param> | ||
/// <param name="loss">Potentially custom loss function. If left unspecified defaults to <see cref="SquaredLoss"/>.</param> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
maybe just simpler: The loss function? #Resolved |
||
/// <returns>The evaluation metrics.</returns> | ||
public static RegressionEvaluator.Result Evaluate<T>( | ||
this RegressionContext ctx, | ||
DataView<T> data, | ||
Func<T, Scalar<float>> label, | ||
Func<T, Scalar<float>> score, | ||
IRegressionLoss loss = null) | ||
{ | ||
Contracts.CheckValue(data, nameof(data)); | ||
var env = StaticPipeUtils.GetEnvironment(data); | ||
Contracts.AssertValue(env); | ||
env.CheckValue(label, nameof(label)); | ||
env.CheckValue(score, nameof(score)); | ||
|
||
var indexer = StaticPipeUtils.GetIndexer(data); | ||
string labelName = indexer.Get(label(indexer.Indices)); | ||
string scoreName = indexer.Get(score(indexer.Indices)); | ||
|
||
var args = new RegressionEvaluator.Arguments() { }; | ||
if (loss != null) | ||
args.LossFunction = new TrivialRegressionLossFactory(loss); | ||
return new RegressionEvaluator(env, args).Evaluate(data.AsDynamic, labelName, scoreName); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it time to have a Contracts.CheckIndexingDelegate() :) #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure it is :)
In reply to: 219660357 [](ancestors = 219660357)