Field-aware factorization machine to estimator #912
FAFM doesn't inherit from TrainerEstimatorBase, just implements ITrainerEstimator
@@ -30,23 +32,33 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
  protected readonly ISchemaBindableMapper BindableMapper;
  protected readonly ISchema TrainSchema;

- public string FeatureColumn { get; }
+ public string[] FeatureColumn { get; }
FeatureColumn
Oh no... I don't like this change already.
Forcing ALL predictors to expose parallel arrays of feature columns is not a great change. #Closed
I don't think there's much overhead from it. Do you think it will cause problems? #Closed
IEstimator<ITransformer> est = new FieldAwareFactorizationMachineTrainer(Env, "Label", new[] { "Feature1", "Feature2", "Feature3", "Feature4" });

//var result = est.Fit(data);
//
Remove the commented-out code. #Resolved
  {
- public sealed class OnlineLinearTests : TestDataPipeBase
+ public partial class TrainerEstimators
TrainerEstimators
Thanks for making this change. #ByDesign
using Microsoft.ML.Runtime.Internal.Calibration;
using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.ML.Runtime
{
    public interface IPredictionTransformer<out TModel> : ITransformer
IPredictionTransformer
I don't think we even need an interface for FFM, since for the time being it's the only trainer that accepts multiple feature columns. #Closed
If we want to inquire about all trainers, it is useful to have them extend one interface. #Closed
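To make the point about a shared interface concrete, here is a minimal sketch. All type names and members below are simplified, hypothetical stand-ins chosen for illustration; they are not the actual ML.NET definitions or the code in this PR. It assumes a base interface exposing only the model, plus a narrower interface for the single-feature-column case.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical, simplified stand-ins for illustration only;
// these are not the actual ML.NET definitions.
public interface IPredictionTransformer<out TModel>
{
    TModel Model { get; }
}

// Assumed narrower interface for transformers with exactly one features column.
public interface ISingleFeaturePredictionTransformer<out TModel> : IPredictionTransformer<TModel>
{
    string FeatureColumn { get; }
}

// A single-features-column transformer (hypothetical).
public sealed class LinearTransformer : ISingleFeaturePredictionTransformer<string>
{
    public string Model => "linear";
    public string FeatureColumn => "Features";
}

// A multi-column transformer, as in the FFM case (hypothetical).
public sealed class FfmTransformer : IPredictionTransformer<string>
{
    public string Model => "ffm";
    public string[] FeatureColumns { get; } = new[] { "Feature1", "Feature2" };
}

public static class Demo
{
    // Works for every transformer because all of them share the base interface.
    public static List<string> ModelNames(IEnumerable<IPredictionTransformer<object>> transformers)
    {
        var names = new List<string>();
        foreach (var t in transformers)
            names.Add((string)t.Model);
        return names;
    }

    public static void Main()
    {
        var all = new IPredictionTransformer<object>[] { new LinearTransformer(), new FfmTransformer() };
        Console.WriteLine(string.Join(", ", ModelNames(all)));
    }
}
```

With this shape, code that only needs the model can treat single-column and multi-column transformers uniformly, while code that needs the feature column name asks for the narrower interface.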
🕐
…tures. Splitting IPredictionTransformer into two interfaces. Creating a transformer wrapping the FAFM predictor.
{
    public partial class TrainerEstimators : TestDataPipeBase
    {
        [Fact]
[Fact]
The test is failing to check whether the input is valid for fit. Resolve before check-in. #Resolved
    TModel Model { get; }
}

public interface IClassicPredictionTransformer<out TModel> : IPredictionTransformer<TModel>
Classic
'Classic' sounds a bit wacky, even though it was my suggestion. Maybe 'SingleInput'? #Resolved
protected void SaveModel(ModelSaveContext ctx)
{
    // *** Binary format ***
*** Binary format ***
Whenever you save or load, you need a *** Binary format *** comment. #Pending
@@ -7,7 +7,7 @@
  namespace Microsoft.ML.Runtime.Training
  {
  public interface ITrainerEstimator<out TTransformer, out TPredictor>: IEstimator<TTransformer>
-     where TTransformer: IPredictionTransformer<TPredictor>
+     where TTransformer: IClassicPredictionTransformer<TPredictor>
IClassicPredictionTransformer
I believe this change is incorrect, isn't it? #Pending
No. Maybe I named them in reverse, but this is the old interface.
In reply to: 218633926
public string[] FeatureColumns { get; }

/// <summary>
/// The type of the prediction transformer
The type of the prediction transformer
Fix the comment. #Resolved
        loaderSignature: LoaderSignature);
}

private static FieldAwareFactorizationMachinePredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
private
Is this just to avoid having a public ctor? #Pending
It didn't load through the ctor; I think that's because the loadable class signature for the ctor requires the args.
In reply to: 218634595
protected void SaveModel(ModelSaveContext ctx)
{
    // *** Binary format ***
    // model: prediction model.
model: prediction model.
Technically the model isn't part of this format: you're not writing it to the stream, you're writing it somewhere else entirely. But that's OK. Consider fixing it if you have to change the code anyway. #Resolved
@[email protected] fixing it == remove the comment?
In reply to: 218818528
    }
}

public abstract class ClassicPredictionTransformerBase<TModel> : PredictionTransformerBase<TModel>, IClassicPredictionTransformer<TModel>, ICanSaveModel
I guess "classic" here just means a prediction transformer base that takes a single features column as its input. Classic is a bit of a funny word, but then again SingleFeaturesPredictionTransformerBase might be a bit of a mouthful and itself potentially confusing. #Resolved
@@ -26,7 +26,7 @@ namespace Microsoft.ML.Runtime.Learners
  {

  using TDistPredictor = IDistPredictorProducing<float, float>;
- using TScalarTrainer = ITrainerEstimator<IPredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
+ using TScalarTrainer = ITrainerEstimator<IClassicPredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
Classic
Yes, after looking at this for a bit, I think the word "classic" is just going to confuse the heck out of a lot of people. Please consider doing something else. #Resolved
    TModel Model { get; }
}

public interface IClassicPredictionTransformer<out TModel> : IPredictionTransformer<TModel>
IClassicPredictionTransformer
I sort of feel like an interface called IClassicPredictionTransformer needs some XML comment on it. All public interfaces should have one, but especially one with the word "classic" in the name.
I am choosing to interpret this interface as one that gives me classic Coke whenever I use it. #Resolved
- CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema);
- CheckSameValues(scoredTrain, scoredTrain2);
+   CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema);
+   CheckSameValues(scoredTrain, scoredTrain2);
  };
Somehow these got indented two spaces for no obvious reason. #Resolved
/// Base class for transformers with no feature column, or more than one feature columns.
/// </summary>
/// <typeparam name="TModel"></typeparam>
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>
No documentation? #Pending
Looks pretty good, thanks @sfilipi.
close-reopen to trigger build
FAFM now extends TrainerEstimatorBase