-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Public API for remaining learners #1901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
_latentWeightsAligned.CopyFrom(latentWeightsAligned.AsSpan()); | ||
} | ||
|
||
internal FieldAwareFactorizationMachineModelParameters(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim, | ||
float[] linearWeights, AlignedArray latentWeightsAligned) : base(env, LoaderSignature) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
float[] linearWeights, AlignedArray latentWeightsAligned [](start = 12, length = 56)
@wschin, are this two similar array with only difference is one of them is array and second is align array?
If yes, I don't understand why I need to pass both of them, and not just get first and construct align version.
If not, why they called this way?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
linearWeights
is linear coefficient where SSE is not very helpful to compute related values. In contrast, latentWeights
is the latent vector of each feature and heavily uses SEE.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is all nice from internal point of view.
My question are user specific.
Does user need to provide this TWO SAME ARRAYS or not?
If not, can you formulate rule according to which we can construct align array?
In reply to: 242650269 [](ancestors = 242650269)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same array? LinearWeights
and latentWeights
are two independent things.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I shouldn't review code at 11 pm, you are right.
Still I don't think we should expect user to pass align array and construct it internally.
In reply to: 242666580 [](ancestors = 242666580)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the new iteration, I have added methods to return latent weights and linear weights to the user. The latent dimension is something specified by the user at training. The latent weights are converted to the aligned weights by adjusting the latent dimension. This is now hidden from the user. The expectation is that user will call FieldAwareFactorizationMachines.GetLatentWeights()
which will have the unadjusted latent dimension, then they can adjust the weights as they want, and pass it to the public constructor. The latent dimension will be adjusted during construction.
In reply to: 242693412 [](ancestors = 242693412,242666580)
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; | ||
|
||
public BinaryClassificationGamPredictor(IHostEnvironment env, int inputLength, Dataset trainset, | ||
public GamBinaryClassificationModelParameters(IHostEnvironment env, int inputLength, Dataset trainset, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GamBinaryClassificationModelParameters [](start = 15, length = 38)
Do you plan to add comments to all public constructors in separate PR? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Separate PR sometimes means forever. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🕐
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Show resolved
Hide resolved
src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
Outdated
Show resolved
Hide resolved
Yeah, it's overridden in GamRegression and GamVlassification, but those are derived classes so I'll make it In reply to: 448111459 [](ancestors = 448111459) Refers to: src/Microsoft.ML.FastTree/GamTrainer.cs:245 in d3ea4df. [](commit_id = d3ea4df, deletion_comment = False) |
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ooof, quite a lot of work I see @najeeb-kazmi ! Thanks for fixing the names in so many places, and making so much more internal.
Overall looks good. Two comments: |
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Outdated
Show resolved
Hide resolved
...ft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
Show resolved
Hide resolved
Items.AsSpan().CopyTo(dst); | ||
} | ||
|
||
public void CopyTo(Span<float> dst, int index, int count) | ||
{ | ||
Contracts.Assert(0 <= count && count <= _size); | ||
Contracts.Assert(dst != null); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
double-checking that this is intentional?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Must be. I did not touch this code, except Float -> float!
In reply to: 243714825 [](ancestors = 243714825)
@@ -107,44 +107,44 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc | |||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unrelated to your PR, but since i stumped upon this, is this used at all?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it is part of TrainerEstimatorBase
and has to be implemented
In reply to: 243714960 [](ancestors = 243714960)
_latentWeightsAligned[indexAligned + k] = 0; | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tagging @wschin to check on this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixes #1703
Predictors covered in this PR:
EnsemblePredictorBase
EnsembleDistributionPredictor
EnsemblePredictor
EnsembleMultiClassPredictor
GamPredictorBase
BinaryClassificationGamPredictor
RegressionGamPredictor
PcaPredictor
FieldAwareFactorizationmachinePredictor
MultiClassNaiveBayesPredictor
OvaPredictor
PkpdPredictor
RandomPredictor
PriorPredictor