-
Notifications
You must be signed in to change notification settings - Fork 1.9k
API scenarios implemented with low-level functions #653
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
86cf8c9
add 10 examples of api scenarios
e3db0e7
introspective training
7351367
address some Pete comments
1436f00
use same normalization in test and train folds.
fac8f53
saveasonnx and empty lines!
b067989
Visibility
dc7d1e3
metacomponents (OVA)
9012351
Merge branch 'master' into ivanidze/api_scenarios
68fda63
merge with master and address changes in OVA.
575da9b
address some PR comments
33306c1
address few more comments
ecc0690
add extensibility example!
e0e8d50
the ugliness of cross validation!
5b6161a
add multithreaded prediction test
590a415
some changes
292554f
only lock option for multiprediction
689194f
go through schema to detect column names and it's types.
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
58 changes: 58 additions & 0 deletions
58
test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime.Api; | ||
using Microsoft.ML.TestFramework; | ||
using Xunit.Abstractions; | ||
|
||
namespace Microsoft.ML.Tests.Scenarios.Api | ||
{ | ||
/// <summary> | ||
/// Common utility functions for API scenarios tests. | ||
/// </summary> | ||
public partial class ApiScenariosTests : BaseTestClass | ||
{ | ||
public ApiScenariosTests(ITestOutputHelper output) : base(output) | ||
{ | ||
} | ||
|
||
public const string IrisDataPath = "iris.data"; | ||
public const string SentimentDataPath = "wikipedia-detox-250-line-data.tsv"; | ||
public const string SentimentTestPath = "wikipedia-detox-250-line-test.tsv"; | ||
|
||
public class IrisData : IrisDataNoLabel | ||
{ | ||
public string Label; | ||
} | ||
|
||
public class IrisDataNoLabel | ||
{ | ||
public float SepalLength; | ||
public float SepalWidth; | ||
public float PetalLength; | ||
public float PetalWidth; | ||
} | ||
|
||
public class IrisPrediction | ||
{ | ||
public string PredictedLabel; | ||
public float[] Score; | ||
} | ||
|
||
public class SentimentData | ||
{ | ||
[ColumnName("Label")] | ||
public bool Sentiment; | ||
public string SentimentText; | ||
} | ||
|
||
public class SentimentPrediction | ||
{ | ||
[ColumnName("PredictedLabel")] | ||
public bool Sentiment; | ||
|
||
public float Score; | ||
} | ||
} | ||
} |
49 changes: 49 additions & 0 deletions
49
test/Microsoft.ML.Tests/Scenarios/Api/AutoNormalizationAndCaching.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Learners; | ||
using Xunit; | ||
|
||
namespace Microsoft.ML.Tests.Scenarios.Api | ||
{ | ||
public partial class ApiScenariosTests | ||
{ | ||
/// <summary> | ||
/// Auto-normalization and caching: It should be relatively easy for normalization | ||
/// and caching to be introduced for training, if the trainer supports or would benefit | ||
/// from that. | ||
/// </summary> | ||
[Fact] | ||
public void AutoNormalizationAndCaching() | ||
{ | ||
var dataPath = GetDataPath(SentimentDataPath); | ||
var testDataPath = GetDataPath(SentimentTestPath); | ||
|
||
using (var env = new TlcEnvironment(seed: 1, conc: 1)) | ||
{ | ||
// Pipeline. | ||
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); | ||
|
||
var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader); | ||
|
||
// Train. | ||
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments | ||
{ | ||
NumThreads = 1, | ||
ConvergenceTolerance = 1f | ||
}); | ||
|
||
// Auto-caching. | ||
IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, trans, prefetch: null) : trans; | ||
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features"); | ||
|
||
// Auto-normalization. | ||
NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer); | ||
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles)); | ||
} | ||
|
||
} | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Models; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Learners; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using Xunit; | ||
|
||
namespace Microsoft.ML.Tests.Scenarios.Api | ||
{ | ||
public partial class ApiScenariosTests | ||
{ | ||
/// <summary> | ||
/// Cross-validation: Have a mechanism to do cross validation, that is, you come up with | ||
/// a data source (optionally with stratification column), come up with an instantiable transform | ||
/// and trainer pipeline, and it will handle (1) splitting up the data, (2) training the separate | ||
/// pipelines on in-fold data, (3) scoring on the out-fold data, (4) returning the set of | ||
/// evaluations and optionally trained pipes. (People always want metrics out of xfold, | ||
/// they sometimes want the actual models too.) | ||
/// </summary> | ||
[Fact] | ||
void CrossValidation() | ||
{ | ||
var dataPath = GetDataPath(SentimentDataPath); | ||
var testDataPath = GetDataPath(SentimentTestPath); | ||
|
||
int numFolds = 5; | ||
using (var env = new TlcEnvironment(seed: 1, conc: 1)) | ||
{ | ||
// Pipeline. | ||
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); | ||
|
||
var text = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader); | ||
IDataView trans = new GenerateNumberTransform(env, text, "StratificationColumn"); | ||
// Train. | ||
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments | ||
{ | ||
NumThreads = 1, | ||
ConvergenceTolerance = 1f | ||
}); | ||
|
||
|
||
var metrics = new List<BinaryClassificationMetrics>(); | ||
for (int fold = 0; fold < numFolds; fold++) | ||
{ | ||
IDataView trainPipe = new RangeFilter(env, new RangeFilter.Arguments() | ||
{ | ||
Column = "StratificationColumn", | ||
Min = (Double)fold / numFolds, | ||
Max = (Double)(fold + 1) / numFolds, | ||
Complement = true | ||
}, trans); | ||
trainPipe = new OpaqueDataView(trainPipe); | ||
var trainData = new RoleMappedData(trainPipe, label: "Label", feature: "Features"); | ||
// Auto-normalization. | ||
NormalizeTransform.CreateIfNeeded(env, ref trainData, trainer); | ||
var preCachedData = trainData; | ||
// Auto-caching. | ||
if (trainer.Info.WantCaching) | ||
{ | ||
var prefetch = trainData.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray(); | ||
var cacheView = new CacheDataView(env, trainData.Data, prefetch); | ||
// Because the prefetching worked, we know that these are valid columns. | ||
trainData = new RoleMappedData(cacheView, trainData.Schema.GetColumnRoleNames()); | ||
} | ||
|
||
var predictor = trainer.Train(new Runtime.TrainContext(trainData)); | ||
IDataView testPipe = new RangeFilter(env, new RangeFilter.Arguments() | ||
{ | ||
Column = "StratificationColumn", | ||
Min = (Double)fold / numFolds, | ||
Max = (Double)(fold + 1) / numFolds, | ||
Complement = false | ||
}, trans); | ||
testPipe = new OpaqueDataView(testPipe); | ||
var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, preCachedData.Data, testPipe, trainPipe); | ||
|
||
var testRoles = new RoleMappedData(pipe, trainData.Schema.GetColumnRoleNames()); | ||
|
||
IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, testRoles, env, testRoles.Schema); | ||
|
||
BinaryClassifierMamlEvaluator eval = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { }); | ||
var dataEval = new RoleMappedData(scorer, testRoles.Schema.GetColumnRoleNames(), opt: true); | ||
var dict = eval.Evaluate(dataEval); | ||
var foldMetrics = BinaryClassificationMetrics.FromMetrics(env, dict["OverallMetrics"], dict["ConfusionMatrix"]); | ||
metrics.Add(foldMetrics.Single()); | ||
} | ||
} | ||
} | ||
} | ||
} |
60 changes: 60 additions & 0 deletions
60
test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime.Api; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Learners; | ||
using System.Linq; | ||
using Xunit; | ||
|
||
namespace Microsoft.ML.Tests.Scenarios.Api | ||
{ | ||
|
||
public partial class ApiScenariosTests | ||
{ | ||
/// <summary> | ||
/// Decomposable train and predict: Train on Iris multiclass problem, which will require | ||
/// a transform on labels. Be able to reconstitute the pipeline for a prediction only task, | ||
/// which will essentially "drop" the transform over labels, while retaining the property | ||
/// that the predicted label for this has a key-type, the probability outputs for the classes | ||
/// have the class labels as slot names, etc. This should be do-able without ugly compromises like, | ||
/// say, injecting a dummy label. | ||
/// </summary> | ||
[Fact] | ||
void DecomposableTrainAndPredict() | ||
{ | ||
var dataPath = GetDataPath(IrisDataPath); | ||
using (var env = new TlcEnvironment()) | ||
{ | ||
var loader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); | ||
var term = new TermTransform(env, loader, "Label"); | ||
var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"); | ||
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }); | ||
|
||
IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat; | ||
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features"); | ||
|
||
// Auto-normalization. | ||
NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer); | ||
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles)); | ||
|
||
var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features"); | ||
IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema); | ||
|
||
// Cut out term transform from pipeline. | ||
var newScorer = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term); | ||
var keyToValue = new KeyToValueTransform(env, newScorer, "PredictedLabel"); | ||
var model = env.CreatePredictionEngine<IrisDataNoLabel, IrisPrediction>(keyToValue); | ||
|
||
var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); | ||
var testData = testLoader.AsEnumerable<IrisDataNoLabel>(env, false); | ||
foreach (var input in testData.Take(20)) | ||
{ | ||
var prediction = model.Predict(input); | ||
Assert.True(prediction.PredictedLabel == "Iris-setosa"); | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime.Api; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Learners; | ||
using Xunit; | ||
using Microsoft.ML.Models; | ||
|
||
namespace Microsoft.ML.Tests.Scenarios.Api | ||
{ | ||
public partial class ApiScenariosTests | ||
{ | ||
/// <summary> | ||
/// Evaluation: Similar to the simple train scenario, except instead of having some | ||
/// predictive structure, be able to score another "test" data file, run the result | ||
/// through an evaluator and get metrics like AUC, accuracy, PR curves, and whatnot. | ||
/// Getting metrics out of this shoudl be as straightforward and unannoying as possible. | ||
/// </summary> | ||
[Fact] | ||
public void Evaluation() | ||
{ | ||
var dataPath = GetDataPath(SentimentDataPath); | ||
var testDataPath = GetDataPath(SentimentTestPath); | ||
|
||
using (var env = new TlcEnvironment(seed: 1, conc: 1)) | ||
{ | ||
// Pipeline | ||
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); | ||
|
||
var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader); | ||
|
||
// Train | ||
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments | ||
{ | ||
NumThreads = 1 | ||
}); | ||
|
||
var cached = new CacheDataView(env, trans, prefetch: null); | ||
var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); | ||
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles)); | ||
var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); | ||
IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema); | ||
|
||
// Create prediction engine and test predictions. | ||
var model = env.CreatePredictionEngine<SentimentData, SentimentPrediction>(scorer); | ||
|
||
// Take a couple examples out of the test data and run predictions on top. | ||
var testLoader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath))); | ||
var testData = testLoader.AsEnumerable<SentimentData>(env, false); | ||
|
||
var dataEval = new RoleMappedData(scorer, label: "Label", feature: "Features", opt: true); | ||
|
||
var evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { }); | ||
var metricsDict = evaluator.Evaluate(dataEval); | ||
|
||
var metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0]; | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What, all of them? :) #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
data file sorted by label, so first 50 is iris-setosa, then next 50 is iris-versicolor, and then last 50 is iris-virginica
In reply to: 208363615 [](ancestors = 208363615)