-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Extended contexts to regression and multiclass, added FFM pigstension #993
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
/// <param name="label">The name of the label column in <paramref name="data"/>.</param> | ||
/// <param name="score">The name of the score column in <paramref name="data"/>.</param> | ||
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param> | ||
/// <returns>The evaluation results for these uncalibrated outputs.</returns> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
returns [](start = 13, length = 7)
remove #Closed
using Microsoft.ML.Runtime.Training; | ||
using System; | ||
using Xunit; | ||
using Xunit.Abstractions; | ||
using Microsoft.ML.Runtime.FactorizationMachine; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using [](start = 0, length = 5)
(placeholder) #Closed
namespace Microsoft.ML.Trainers | ||
{ | ||
/// <summary> | ||
/// Extension methods and utilities for instantiating FFM trainer estimators inside statically typed pipelines. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FFM trai [](start = 58, length = 8)
#Resolved
var indexer = StaticPipeUtils.GetIndexer(data); | ||
string labelName = indexer.Get(label(indexer.Indices)); | ||
(var scoreCol, var predCol) = pred(indexer.Indices); | ||
Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indexing delegate res [](start = 66, length = 21)
is it time to have a Contracts.CheckIndexingDelegate() :) #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
} | ||
|
||
/// <summary> | ||
/// Evaluates scored multiclass classification data. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
multiclass [](start = 29, length = 10)
regression #Resolved
/// <param name="data">The data to evaluate.</param> | ||
/// <param name="label">The index delegate for the label column.</param> | ||
/// <param name="score">The index delegate for predicted score column.</param> | ||
/// <param name="loss">Potentially custom loss function. If left unspecified defaults to <see cref="SquaredLoss"/>.</param> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potentially custom loss function [](start = 31, length = 32)
maybe just simpler: The loss function? #Resolved
public sealed class MulticlassClassificationContext : TrainContextBase | ||
{ | ||
/// <summary> | ||
/// For trainers for performing multiclass classification. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For trainers for performing multiclass classification. [](start = 12, length = 54)
clarify, maybe. #Resolved
/// the model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to | ||
/// be informed about what was learnt.</param> | ||
/// <returns>The predicted output.</returns> | ||
public static (Scalar<float> score, Scalar<bool> predictedLabel) FieldAwareFactorizationMachine(this BinaryClassificationContext.BinaryClassificationTrainers ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FieldAwareFactorizationMachine [](start = 73, length = 30)
Shall we go with FAFM, if we are doing the acronym for SDCA? #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the choice is to go with acronyms, we should use FFM to follow the relevant papers (such as this one). #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I would prefer NOT to have acronyms, so fix SDCA instead
In reply to: 219684257 [](ancestors = 219684257)
|
||
private sealed class CustomReconciler : TrainerEstimatorReconciler | ||
{ | ||
private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Score [](start = 92, length = 5)
is probability missing because FAFM doesn't calibrate? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -7,6 +7,7 @@ | |||
<ProjectReference Include="..\..\src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" /> | |||
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" /> | |||
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" /> | |||
<NativeAssemblyReference Include="FactorizationMachineNative" /> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NativeAssemblyReference [](start = 5, length = 23)
want to move it close to line 17? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -7,6 +7,7 @@ | |||
<ProjectReference Include="..\..\src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" /> | |||
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" /> | |||
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" /> | |||
<NativeAssemblyReference Include="FactorizationMachineNative" /> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FactorizationMachineNative [](start = 38, length = 26)
Hmmm. Opportunistic fix for some other issue? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it's that FFM requires FFMNative, and I added a test for it in static training
In reply to: 219902222 [](ancestors = 219902222)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Zruty0 !
Part of #754, extends #949.
Added MulticlassClassification context and RegressionContext, with corresponding Evaluate methods. Also added FFM binary trainer to the context extensions.