Restore OVA ability to preserve key names on predicted label #3101
Changes from all commits
33a578e
a73b04e
fbbcfcf
d46ae57
2bb3ecc
a61b914
b68c4e5
06cce1a
10541d6
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -438,7 +438,7 @@ public void ContinueTrainingSymbolicStochasticGradientDescent() | |
} | ||
|
||
/// <summary> | ||
/// Training: Meta-compononts function as expected. For OVA (one-versus-all), a user will be able to specify only | ||
/// Training: Meta-components function as expected. For OVA (one-versus-all), a user will be able to specify only | ||
/// binary classifier trainers. If they specify a different model class there should be a compile error. | ||
/// </summary> | ||
[Fact] | ||
|
@@ -467,5 +467,39 @@ public void MetacomponentsFunctionAsExpectedOva() | |
// Evaluate the model. | ||
var binaryClassificationMetrics = mlContext.MulticlassClassification.Evaluate(binaryClassificationPredictions); | ||
} | ||
|
||
/// <summary> | ||
/// Training: Meta-components function as expected. For OVA (one-versus-all), a user will be able to specify only | ||
/// binary classifier trainers. If they specify a different model class there should be a compile error. | ||
/// </summary> | ||
[Fact] | ||
public void MetacomponentsFunctionWithKeyHandling() | ||
{ | ||
var mlContext = new MLContext(seed: 1); | ||
|
||
var data = mlContext.Data.LoadFromTextFile<Iris>(GetDataPath(TestDatasets.iris.trainFilename), | ||
hasHeader: TestDatasets.iris.fileHasHeader, | ||
separatorChar: TestDatasets.iris.fileSeparator); | ||
|
||
// Create a model training an OVA trainer with a binary classifier. | ||
var binaryClassificationTrainer = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression( | ||
new LbfgsLogisticRegressionBinaryTrainer.Options { MaximumNumberOfIterations = 10, NumberOfThreads = 1, }); | ||
var binaryClassificationPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features) | ||
.AppendCacheCheckpoint(mlContext) | ||
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label")) | ||
.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryClassificationTrainer)) | ||
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel")); | ||
|
||
// Fit the binary classification pipeline. | ||
var binaryClassificationModel = binaryClassificationPipeline.Fit(data); | ||
|
||
// Transform the data | ||
var binaryClassificationPredictions = binaryClassificationModel.Transform(data); | ||
|
||
// Evaluate the model. | ||
var binaryClassificationMetrics = mlContext.MulticlassClassification.Evaluate(binaryClassificationPredictions); | ||
Can we spot check some of these values to make sure the code isn't returning garbage? |
||
|
||
Assert.Equal(0.4367, binaryClassificationMetrics.LogLoss, 4); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,22 +37,22 @@ void PredictAndMetadata() | |
|
||
var testLoader = ml.Data.LoadFromTextFile(dataPath, TestDatasets.irisData.GetLoaderColumns(), separatorChar: ',', hasHeader: true); | ||
var testData = ml.Data.CreateEnumerable<IrisData>(testLoader, false); | ||
|
||
// During prediction we will get Score column with 3 float values. | ||
// We need to find way to map each score to original label. | ||
// In order to do what we need to get SlotNames from Score column. | ||
// Slot names on top of Score column represent original labels for i-th value in Score array. | ||
VBuffer<ReadOnlyMemory<char>> slotNames = default; | ||
engine.OutputSchema[nameof(IrisPrediction.Score)].GetSlotNames(ref slotNames); | ||
// In order to do what we need to get TrainingLabelValues from Score column. | ||
// TrainingLabelValues on top of Score column represent original labels for i-th value in Score array. | ||
VBuffer<ReadOnlyMemory<char>> originalLabels = default; | ||
In this particular case we should be propagating both slot names and label names, right? Since they're strings in both cases? While I see the point in augmenting the test to cover this new metadata type, is there any particular reason to remove the test that the vector has the appropriate slot names? #WontFix
This is an old scenario test. Its purpose is to show the user how to work with metadata. All tests routed through TestEstimatorCore would check for the presence of slot names and TrainingLabelValues, and we have plenty of them. Why should we do anything here with slot names? In reply to: 271046772
OK, that's fine. I had the idea that these "showing the user" things were more the point of functional tests, but as you like. In reply to: 271050961 |
||
engine.OutputSchema[nameof(IrisPrediction.Score)].Annotations.GetValue(AnnotationUtils.Kinds.TrainingLabelValues, ref originalLabels); | ||
Is the distinction with slot names that slot names must be text, while these might be any type? That might excuse not using them. But in such a case I'd argue that we should still have the slot names for descriptive user-facing purposes, so I'd like to confirm we're still doing that. #Resolved
No, we don't, but I can change that. Any reason why we want to continue to propagate slot names? #Resolved
Well, let's imagine I write out a text file, and I have this scores column. With slot names, I get a descriptive header. Without them I don't. Does that make sense? #Resolved
Only if the original labels were strings, but OK.
Multiple metadata kinds is good, thanks Ivan. I believe this is what you did; that is, from the code I read you are always propagating the labels, and propagating slot names if they're text, and that seems fine to me. In reply to: 269742478 |
||
// Since we apply the MapValueToKey estimator with default parameters, key values | ||
// depend on the order of occurrence in the data file, which is "Iris-setosa", "Iris-versicolor", "Iris-virginica". | ||
// So if the Score column equals [0.2, 0.3, 0.5], that means the score for | ||
// Iris-setosa is 0.2 | ||
// Iris-versicolor is 0.3 | ||
// Iris-virginica is 0.5. | ||
Assert.True(slotNames.GetItemOrDefault(0).ToString() == "Iris-setosa"); | ||
Assert.True(slotNames.GetItemOrDefault(1).ToString() == "Iris-versicolor"); | ||
Assert.True(slotNames.GetItemOrDefault(2).ToString() == "Iris-virginica"); | ||
Assert.Equal("Iris-setosa", originalLabels.GetItemOrDefault(0).ToString()); | ||
Assert.Equal("Iris-versicolor", originalLabels.GetItemOrDefault(1).ToString()); | ||
Assert.Equal("Iris-virginica", originalLabels.GetItemOrDefault(2).ToString()); | ||
|
||
// Let's look how we can convert key value for PredictedLabel to original labels. | ||
// We need to read KeyValues for "PredictedLabel" column. | ||
|
Was this change discovered by a test?
I wanted to make sure all multiclass learners work fine with my changes, so I ran a bunch of different ones on my test.
My test has only 2 features, and I got an exception,
mainly because Utils.EnsureSize uses 4 as the array length even if you specify 1, 2, or 3.
It makes sense for VBuffer (since we have Count or Length; I'm always confused about which one is the size of the whole collection and which is the number of elements in it), but I'm not sure why we do it for arrays as well.
In reply to: 271070400
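The sizing behavior described in the reply above can be illustrated with a small sketch. `ArrayGrowthSketch`, `EnsureCapacity`, and `MinimumCapacity` are hypothetical names invented for this illustration; this is not the ML.NET `Utils.EnsureSize` implementation, only an assumed model of a growth helper that never allocates fewer than 4 elements. It shows why code that reads `array.Length` instead of the requested size can end up iterating over more slots than there are features.

```csharp
using System;

public static class ArrayGrowthSketch
{
    // Assumption for illustration: the helper never allocates fewer than 4 slots.
    private const int MinimumCapacity = 4;

    // Hypothetical stand-in for a grow-on-demand helper.
    // Returns the resulting array length, which may exceed `requested`.
    public static int EnsureCapacity<T>(ref T[] array, int requested)
    {
        if (requested < 0)
            throw new ArgumentOutOfRangeException(nameof(requested));

        int current = array?.Length ?? 0;
        if (current >= requested)
            return current;

        // Grow to at least the minimum capacity, doubling where possible.
        int newSize = Math.Max(requested, Math.Max(MinimumCapacity, current * 2));
        Array.Resize(ref array, newSize); // allocates a new array when `array` is null
        return newSize;
    }

    public static void Main()
    {
        float[] weights = null;
        int featureCount = 2;

        EnsureCapacity(ref weights, featureCount);

        // The buffer is larger than the number of features that were requested.
        Console.WriteLine($"requested={featureCount}, allocated={weights.Length}");

        // Safe: iterate over the logical size, not over weights.Length.
        for (int i = 0; i < featureCount; i++)
            weights[i] = 1.0f;
    }
}
```

Under these assumptions, a caller that requests 2 elements gets an array of length 4, so any consumer has to track the logical size separately rather than rely on the array length.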
Would it make sense to put all those tests into regression? That way we can catch bugs like this in the future?