diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index 9d969b6be2..2b9bcb6cf9 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -33,6 +33,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TestFramework"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Predictor.Tests", "test\Microsoft.ML.Predictor.Tests\Microsoft.ML.Predictor.Tests.csproj", "{6B047E09-39C9-4583-96F3-685D84CA4117}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Functional.Tests", "test\Microsoft.ML.Functional.Tests\Microsoft.ML.Functional.Tests.csproj", "{CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}"
+EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ResultProcessor", "src\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj", "{3769FCC3-9AFF-4C37-97E9-6854324681DF}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.FastTree", "src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj", "{B7B593C5-FB8C-4ADA-A638-5B53B47D087E}"
@@ -928,6 +930,18 @@ Global
{5E920CAC-5A28-42FB-936E-49C472130953}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
{5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
{5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release|Any CPU.Build.0 = Release|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -1011,6 +1025,7 @@ Global
{85D0CAFD-2FE8-496A-88C7-585D35B94243} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{31D38B21-102B-41C0-9E0A-2FE0BF68D123} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{5E920CAC-5A28-42FB-936E-49C472130953} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
+ {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
diff --git a/build/Dependencies.props b/build/Dependencies.props
index 896ca68978..221d1e6552 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -43,6 +43,8 @@
0.11.3
0.0.3-test
+ 0.0.10-test
+ 0.0.4-test
diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md
index 1a4b423e82..bb5e866c04 100644
--- a/docs/code/MlNetCookBook.md
+++ b/docs/code/MlNetCookBook.md
@@ -688,11 +688,11 @@ var catColumns = data.GetColumn(mlContext, "CategoricalFeatures").Take
// Build several alternative featurization pipelines.
var pipeline =
// Convert each categorical feature into one-hot encoding independently.
- mlContext.Transforms.Categorical.OneHotEncoding("CategoricalFeatures", "CategoricalOneHot")
+ mlContext.Transforms.Categorical.OneHotEncoding("CategoricalOneHot", "CategoricalFeatures")
// Convert all categorical features into indices, and build a 'word bag' of these.
- .Append(mlContext.Transforms.Categorical.OneHotEncoding("CategoricalFeatures", "CategoricalBag", CategoricalTransform.OutputKind.Bag))
+ .Append(mlContext.Transforms.Categorical.OneHotEncoding("CategoricalBag", "CategoricalFeatures", CategoricalTransform.OutputKind.Bag))
// One-hot encode the workclass column, then drop all the categories that have fewer than 10 instances in the train set.
- .Append(mlContext.Transforms.Categorical.OneHotEncoding("Workclass", "WorkclassOneHot"))
+ .Append(mlContext.Transforms.Categorical.OneHotEncoding("WorkclassOneHot", "Workclass"))
.Append(mlContext.Transforms.FeatureSelection.CountFeatureSelectingEstimator("WorkclassOneHot", "WorkclassOneHotTrimmed", count: 10));
// Let's train our pipeline, and then apply it to the same data.
@@ -825,12 +825,12 @@ var pipeline =
.Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent());
// Split the data 90:10 into train and test sets, train and evaluate.
-var (trainData, testData) = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);
+var split = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);
// Train the model.
-var model = pipeline.Fit(trainData);
+var model = pipeline.Fit(split.TrainSet);
// Compute quality metrics on the test set.
-var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(testData));
+var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet));
Console.WriteLine(metrics.AccuracyMicro);
// Now run the 5-fold cross-validation experiment, using the same pipeline.
@@ -838,7 +838,7 @@ var cvResults = mlContext.MulticlassClassification.CrossValidate(data, pipeline,
// The results object is an array of 5 elements. For each of the 5 folds, we have metrics, model and scored test data.
// Let's compute the average micro-accuracy.
-var microAccuracies = cvResults.Select(r => r.metrics.AccuracyMicro);
+var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro);
Console.WriteLine(microAccuracies.Average());
```
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
index eb46a9683a..ed75040e77 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
@@ -43,7 +43,7 @@ public static void Calibration()
var data = reader.Read(dataFile);
// Split the dataset into two parts: one used for training, the other to train the calibrator
- var (trainData, calibratorTrainingData) = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
+ var split = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
// Featurize the text column through the FeaturizeText API.
// Then append the StochasticDualCoordinateAscentBinary binary classifier, setting the "Label" column as the label of the dataset, and
@@ -56,12 +56,12 @@ public static void Calibration()
loss: new HingeLoss())); // By specifying loss: new HingeLoss(), StochasticDualCoordinateAscent will train a support vector machine (SVM).
// Fit the pipeline, and get a transformer that knows how to score new data.
- var transformer = pipeline.Fit(trainData);
+ var transformer = pipeline.Fit(split.TrainSet);
IPredictor model = transformer.LastTransformer.Model;
// Let's score the new data. The score will give us a numerical estimation of the chance that the particular sample
// bears positive sentiment. This estimate is relative to the numbers obtained.
- var scoredData = transformer.Transform(calibratorTrainingData);
+ var scoredData = transformer.Transform(split.TestSet);
var scoredDataPreview = scoredData.Preview();
PrintRowViewValues(scoredDataPreview);
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs
index 7f588568c4..9b34510bd0 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs
@@ -5,11 +5,9 @@
using System;
using System.Collections.Generic;
using System.IO;
-using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
-using Microsoft.ML.TimeSeries;
-using Microsoft.ML.TimeSeriesProcessing;
+using Microsoft.ML.Transforms.TimeSeries;
namespace Microsoft.ML.Samples.Dynamic
{
@@ -54,16 +52,9 @@ public static void IidChangePointDetectorTransform()
// Setup IidSpikeDetector arguments
string outputColumnName = nameof(ChangePointPrediction.Prediction);
string inputColumnName = nameof(IidChangePointData.Value);
- var args = new IidChangePointDetector.Arguments()
- {
- Source = inputColumnName,
- Name = outputColumnName,
- Confidence = 95, // The confidence for spike detection in the range [0, 100]
- ChangeHistoryLength = Size / 4, // The length of the sliding window on p-values for computing the martingale score.
- };
// The transformed data.
- var transformedData = new IidChangePointEstimator(ml, args).Fit(dataView).Transform(dataView);
+ var transformedData = ml.Transforms.IidChangePointEstimator(outputColumnName, inputColumnName, 95, Size / 4).Fit(dataView).Transform(dataView);
// Getting the data of the newly created column as an IEnumerable of ChangePointPrediction.
var predictionColumn = ml.CreateEnumerable(transformedData, reuseRowObject: false);
@@ -119,16 +110,9 @@ public static void IidChangePointDetectorPrediction()
// Setup IidSpikeDetector arguments
string outputColumnName = nameof(ChangePointPrediction.Prediction);
string inputColumnName = nameof(IidChangePointData.Value);
- var args = new IidChangePointDetector.Arguments()
- {
- Source = inputColumnName,
- Name = outputColumnName,
- Confidence = 95, // The confidence for spike detection in the range [0, 100]
- ChangeHistoryLength = Size / 4, // The length of the sliding window on p-values for computing the martingale score.
- };
// Time Series model.
- ITransformer model = new IidChangePointEstimator(ml, args).Fit(dataView);
+ ITransformer model = ml.Transforms.IidChangePointEstimator(outputColumnName, inputColumnName, 95, Size / 4).Fit(dataView);
// Create a time series prediction engine from the model.
var engine = model.CreateTimeSeriesPredictionFunction(ml);
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs
index bcbeaa36ee..c2fedc5275 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs
@@ -1,11 +1,9 @@
using System;
using System.Collections.Generic;
using System.IO;
-using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
-using Microsoft.ML.TimeSeries;
-using Microsoft.ML.TimeSeriesProcessing;
+using Microsoft.ML.Transforms.TimeSeries;
namespace Microsoft.ML.Samples.Dynamic
{
@@ -51,16 +49,9 @@ public static void IidSpikeDetectorTransform()
// Setup IidSpikeDetector arguments
string outputColumnName = nameof(IidSpikePrediction.Prediction);
string inputColumnName = nameof(IidSpikeData.Value);
- var args = new IidSpikeDetector.Arguments()
- {
- Source = inputColumnName,
- Name = outputColumnName,
- Confidence = 95, // The confidence for spike detection in the range [0, 100]
- PvalueHistoryLength = Size / 4 // The size of the sliding window for computing the p-value; shorter windows are more sensitive to spikes.
- };
// The transformed data.
- var transformedData = new IidSpikeEstimator(ml, args).Fit(dataView).Transform(dataView);
+ var transformedData = ml.Transforms.IidSpikeEstimator(outputColumnName, inputColumnName, 95, Size / 4).Fit(dataView).Transform(dataView);
// Getting the data of the newly created column as an IEnumerable of IidSpikePrediction.
var predictionColumn = ml.CreateEnumerable(transformedData, reuseRowObject: false);
@@ -108,16 +99,8 @@ public static void IidSpikeDetectorPrediction()
// Setup IidSpikeDetector arguments
string outputColumnName = nameof(IidSpikePrediction.Prediction);
string inputColumnName = nameof(IidSpikeData.Value);
- var args = new IidSpikeDetector.Arguments()
- {
- Source = inputColumnName,
- Name = outputColumnName,
- Confidence = 95, // The confidence for spike detection in the range [0, 100]
- PvalueHistoryLength = Size / 4 // The size of the sliding window for computing the p-value; shorter windows are more sensitive to spikes.
- };
-
// The transformed model.
- ITransformer model = new IidSpikeEstimator(ml, args).Fit(dataView);
+    ITransformer model = ml.Transforms.IidSpikeEstimator(outputColumnName, inputColumnName, 95, Size / 4).Fit(dataView);
// Create a time series prediction engine from the model.
var engine = model.CreateTimeSeriesPredictionFunction(ml);
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs
index 4b25bdc805..f85b68bb7e 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs
@@ -57,7 +57,7 @@ public static void LogisticRegression()
IDataView data = reader.Read(dataFilePath);
- var (trainData, testData) = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2);
+ var split = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2);
var pipeline = ml.Transforms.Concatenate("Text", "workclass", "education", "marital-status",
"relationship", "ethnicity", "sex", "native-country")
@@ -66,9 +66,9 @@ public static void LogisticRegression()
"education-num", "capital-gain", "capital-loss", "hours-per-week"))
.Append(ml.BinaryClassification.Trainers.LogisticRegression());
- var model = pipeline.Fit(trainData);
+ var model = pipeline.Fit(split.TrainSet);
- var dataWithPredictions = model.Transform(testData);
+ var dataWithPredictions = model.Transform(split.TestSet);
var metrics = ml.BinaryClassification.Evaluate(dataWithPredictions);
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs
index 79052c2314..73471a0c52 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs
@@ -1,7 +1,8 @@
using System;
using System.Linq;
using Microsoft.Data.DataView;
-using Microsoft.ML.Learners;
+using Microsoft.ML.Data;
+using Microsoft.ML.Trainers;
using Microsoft.ML.SamplesUtils;
using Microsoft.ML.Trainers.HalLearners;
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs
index 205f7cc4ce..77e2e63d27 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs
@@ -1,6 +1,6 @@
using System;
using System.Linq;
-using Microsoft.ML.Learners;
+using Microsoft.ML.Trainers;
namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
index 341827bde3..0f06b9c905 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
@@ -54,7 +54,7 @@ public static void SDCA_BinaryClassification()
// Step 3: Run Cross-Validation on this pipeline.
var cvResults = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment");
- var accuracies = cvResults.Select(r => r.metrics.Accuracy);
+ var accuracies = cvResults.Select(r => r.Metrics.Accuracy);
Console.WriteLine(accuracies.Average());
// If we wanted to specify more advanced parameters for the algorithm,
@@ -70,7 +70,7 @@ public static void SDCA_BinaryClassification()
// Run Cross-Validation on this second pipeline.
var cvResults_advancedPipeline = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment", numFolds: 3);
- accuracies = cvResults_advancedPipeline.Select(r => r.metrics.Accuracy);
+ accuracies = cvResults_advancedPipeline.Select(r => r.Metrics.Accuracy);
Console.WriteLine(accuracies.Average());
}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs
index efc52a90dc..223bab2277 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs
@@ -1,11 +1,9 @@
using System;
using System.Collections.Generic;
using System.IO;
-using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
-using Microsoft.ML.TimeSeries;
-using Microsoft.ML.TimeSeriesProcessing;
+using Microsoft.ML.Transforms.TimeSeries;
namespace Microsoft.ML.Samples.Dynamic
{
@@ -49,19 +47,9 @@ public static void SsaChangePointDetectorTransform()
// Setup SsaChangePointDetector arguments
var inputColumnName = nameof(SsaChangePointData.Value);
var outputColumnName = nameof(ChangePointPrediction.Prediction);
- var args = new SsaChangePointDetector.Arguments()
- {
- Source = inputColumnName,
- Name = outputColumnName,
- Confidence = 95, // The confidence for spike detection in the range [0, 100]
- ChangeHistoryLength = 8, // The length of the window for detecting a change in trend; shorter windows are more sensitive to spikes.
- TrainingWindowSize = TrainingSize, // The number of points from the beginning of the sequence used for training.
- SeasonalWindowSize = SeasonalitySize + 1 // An upper bound on the largest relevant seasonality in the input time series."
-
- };
// The transformed data.
- var transformedData = new SsaChangePointEstimator(ml, args).Fit(dataView).Transform(dataView);
+ var transformedData = ml.Transforms.SsaChangePointEstimator(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView).Transform(dataView);
// Getting the data of the newly created column as an IEnumerable of ChangePointPrediction.
var predictionColumn = ml.CreateEnumerable(transformedData, reuseRowObject: false);
@@ -120,19 +108,9 @@ public static void SsaChangePointDetectorPrediction()
// Setup SsaChangePointDetector arguments
var inputColumnName = nameof(SsaChangePointData.Value);
var outputColumnName = nameof(ChangePointPrediction.Prediction);
- var args = new SsaChangePointDetector.Arguments()
- {
- Source = inputColumnName,
- Name = outputColumnName,
- Confidence = 95, // The confidence for spike detection in the range [0, 100]
- ChangeHistoryLength = 8, // The length of the window for detecting a change in trend; shorter windows are more sensitive to spikes.
- TrainingWindowSize = TrainingSize, // The number of points from the beginning of the sequence used for training.
- SeasonalWindowSize = SeasonalitySize + 1 // An upper bound on the largest relevant seasonality in the input time series."
-
- };
// Train the change point detector.
- ITransformer model = new SsaChangePointEstimator(ml, args).Fit(dataView);
+ ITransformer model = ml.Transforms.SsaChangePointEstimator(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView);
// Create a prediction engine from the model for feeding new data.
var engine = model.CreateTimeSeriesPredictionFunction(ml);
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs
index db8bfcad70..9ebf1a41d2 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs
@@ -1,11 +1,9 @@
using System;
using System.Collections.Generic;
using System.IO;
-using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
-using Microsoft.ML.TimeSeries;
-using Microsoft.ML.TimeSeriesProcessing;
+using Microsoft.ML.Transforms.TimeSeries;
namespace Microsoft.ML.Samples.Dynamic
{
@@ -56,18 +54,9 @@ public static void SsaSpikeDetectorTransform()
// Setup IidSpikeDetector arguments
var inputColumnName = nameof(SsaSpikeData.Value);
var outputColumnName = nameof(SsaSpikePrediction.Prediction);
- var args = new SsaSpikeDetector.Arguments()
- {
- Source = inputColumnName,
- Name = outputColumnName,
- Confidence = 95, // The confidence for spike detection in the range [0, 100]
- PvalueHistoryLength = 8, // The size of the sliding window for computing the p-value; shorter windows are more sensitive to spikes.
- TrainingWindowSize = TrainingSize, // The number of points from the beginning of the sequence used for training.
- SeasonalWindowSize = SeasonalitySize + 1 // An upper bound on the largest relevant seasonality in the input time series."
- };
// The transformed data.
- var transformedData = new SsaSpikeEstimator(ml, args).Fit(dataView).Transform(dataView);
+ var transformedData = ml.Transforms.SsaSpikeEstimator(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView).Transform(dataView);
// Getting the data of the newly created column as an IEnumerable of SsaSpikePrediction.
var predictionColumn = ml.CreateEnumerable(transformedData, reuseRowObject: false);
@@ -127,18 +116,9 @@ public static void SsaSpikeDetectorPrediction()
// Setup IidSpikeDetector arguments
var inputColumnName = nameof(SsaSpikeData.Value);
var outputColumnName = nameof(SsaSpikePrediction.Prediction);
- var args = new SsaSpikeDetector.Arguments()
- {
- Source = inputColumnName,
- Name = outputColumnName,
- Confidence = 95, // The confidence for spike detection in the range [0, 100]
- PvalueHistoryLength = 8, // The size of the sliding window for computing the p-value; shorter windows are more sensitive to spikes.
- TrainingWindowSize = TrainingSize, // The number of points from the beginning of the sequence used for training.
- SeasonalWindowSize = SeasonalitySize + 1 // An upper bound on the largest relevant seasonality in the input time series."
- };
// Train the change point detector.
- ITransformer model = new SsaSpikeEstimator(ml, args).Fit(dataView);
+            ITransformer model = ml.Transforms.SsaSpikeEstimator(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView);
// Create a prediction engine from the model for feeding new data.
var engine = model.CreateTimeSeriesPredictionFunction(ml);
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs
new file mode 100644
index 0000000000..ee2e3fdd94
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs
@@ -0,0 +1,46 @@
+using Microsoft.ML;
+
+namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification
+{
+ public static class AveragedPerceptron
+ {
+ public static void Example()
+ {
+ // In this examples we will use the adult income dataset. The goal is to predict
+ // if a person's income is above $50K or not, based on different pieces of information about that person.
+ // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult
+
+ // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+ // as a catalog of available operations and as the source of randomness.
+ // Setting the seed to a fixed number in this example to make outputs deterministic.
+ var mlContext = new MLContext(seed: 0);
+
+ // Download and featurize the dataset
+ var data = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext);
+
+ // Leave out 10% of data for testing
+ var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
+
+ // Create data training pipeline
+ var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
+ "IsOver50K", "Features", numIterations: 10);
+
+ // Fit this pipeline to the training data
+ var model = pipeline.Fit(trainTestData.TrainSet);
+
+ // Evaluate how the model is doing on the test data
+ var dataWithPredictions = model.Transform(trainTestData.TestSet);
+ var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(dataWithPredictions, "IsOver50K");
+ SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
+
+ // Output:
+ // Accuracy: 0.86
+ // AUC: 0.91
+ // F1 Score: 0.68
+ // Negative Precision: 0.90
+ // Negative Recall: 0.91
+ // Positive Precision: 0.70
+ // Positive Recall: 0.66
+ }
+ }
+}
\ No newline at end of file
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs
new file mode 100644
index 0000000000..ac9296a96d
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs
@@ -0,0 +1,58 @@
+using Microsoft.ML;
+using Microsoft.ML.Trainers.Online;
+
+namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification
+{
+ public static class AveragedPerceptronWithOptions
+ {
+ public static void Example()
+ {
+ // In this examples we will use the adult income dataset. The goal is to predict
+ // if a person's income is above $50K or not, based on different pieces of information about that person.
+ // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult
+
+ // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+ // as a catalog of available operations and as the source of randomness.
+ // Setting the seed to a fixed number in this example to make outputs deterministic.
+ var mlContext = new MLContext(seed: 0);
+
+ // Download and featurize the dataset
+ var data = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext);
+
+ // Leave out 10% of data for testing
+ var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
+
+ // Define the trainer options
+ var options = new AveragedPerceptronTrainer.Options()
+ {
+ LossFunction = new SmoothedHingeLoss.Arguments(),
+ LearningRate = 0.1f,
+ DoLazyUpdates = false,
+ RecencyGain = 0.1f,
+ NumberOfIterations = 10,
+ LabelColumn = "IsOver50K",
+ FeatureColumn = "Features"
+ };
+
+ // Create data training pipeline
+ var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(options);
+
+ // Fit this pipeline to the training data
+ var model = pipeline.Fit(trainTestData.TrainSet);
+
+ // Evaluate how the model is doing on the test data
+ var dataWithPredictions = model.Transform(trainTestData.TestSet);
+ var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(dataWithPredictions, "IsOver50K");
+ SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
+
+ // Output:
+ // Accuracy: 0.86
+ // AUC: 0.90
+ // F1 Score: 0.66
+ // Negative Precision: 0.89
+ // Negative Recall: 0.93
+ // Positive Precision: 0.72
+ // Positive Recall: 0.61
+ }
+ }
+}
\ No newline at end of file
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorization.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorization.cs
new file mode 100644
index 0000000000..91de5a8be8
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorization.cs
@@ -0,0 +1,62 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML.Data;
+using static Microsoft.ML.SamplesUtils.DatasetUtils;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ public partial class MatrixFactorizationExample
+ {
+ // This example first creates in-memory data and then use it to train a matrix factorization mode with default parameters. Afterward, quality metrics are reported.
+ public static void MatrixFactorization()
+ {
+ // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+ // as a catalog of available operations and as the source of randomness.
+ var mlContext = new MLContext(seed: 0, conc: 1);
+
+ // Get a small in-memory dataset.
+ var data = GetRecommendationData();
+
+ // Convert the in-memory matrix into an IDataView so that ML.NET components can consume it.
+ var dataView = mlContext.Data.ReadFromEnumerable(data);
+
+ // Create a matrix factorization trainer which may consume "Value" as the training label, "MatrixColumnIndex" as the
+ // matrix's column index, and "MatrixRowIndex" as the matrix's row index. Here nameof(...) is used to extract field
+ // names' in MatrixElement class.
+ var pipeline = mlContext.Recommendation().Trainers.MatrixFactorization(nameof(MatrixElement.MatrixColumnIndex),
+ nameof(MatrixElement.MatrixRowIndex), nameof(MatrixElement.Value), 10, 0.2, 10);
+
+ // Train a matrix factorization model.
+ var model = pipeline.Fit(dataView);
+
+ // Apply the trained model to the training set.
+ var prediction = model.Transform(dataView);
+
+ // Calculate regression matrices for the prediction result.
+ var metrics = mlContext.Recommendation().Evaluate(prediction,
+ label: nameof(MatrixElement.Value), score: nameof(MatrixElementForScore.Score));
+
+ // Print out some metrics for checking the model's quality.
+ Console.WriteLine($"L1 - {metrics.L1}"); // 0.17208
+ Console.WriteLine($"L2 - {metrics.L2}"); // 0.04766
+ Console.WriteLine($"LossFunction - {metrics.LossFn}"); // 0.04766
+ Console.WriteLine($"RMS - {metrics.Rms}"); //0.21831
+ Console.WriteLine($"RSquared - {metrics.RSquared}"); // 0.97616
+
+            // Create two entries for making prediction. Of course, the prediction value, Score, is unknown so it can be anything
+ // (here we use Score=0 and it will be overwritten by the true prediction). If any of row and column indexes are out-of-range
+ // (e.g., MatrixColumnIndex=99999), the prediction value will be NaN.
+            var testMatrix = new List<MatrixElementForScore>() {
+ new MatrixElementForScore() { MatrixColumnIndex = 1, MatrixRowIndex = 7, Score = 0 },
+ new MatrixElementForScore() { MatrixColumnIndex = 3, MatrixRowIndex = 6, Score = 0 } };
+
+ // Again, convert the test data to a format supported by ML.NET.
+ var testDataView = mlContext.Data.ReadFromEnumerable(testMatrix);
+ // Feed the test data into the model and then iterate through all predictions.
+            foreach (var pred in mlContext.CreateEnumerable<MatrixElementForScore>(model.Transform(testDataView), false))
+ Console.WriteLine($"Predicted value at row {pred.MatrixRowIndex - 1} and column {pred.MatrixColumnIndex - 1} is {pred.Score}");
+ // Predicted value at row 7 and column 1 is 2.876928
+ // Predicted value at row 6 and column 3 is 3.587935
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorizationWithOptions.cs
similarity index 50%
rename from docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs
rename to docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorizationWithOptions.cs
index 449aa46017..eb776087bf 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorizationWithOptions.cs
@@ -2,65 +2,29 @@
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
+using static Microsoft.ML.SamplesUtils.DatasetUtils;
namespace Microsoft.ML.Samples.Dynamic
{
- public class MatrixFactorizationExample
+ public partial class MatrixFactorizationExample
{
- // The following variables defines the shape of a matrix. Its shape is _synthesizedMatrixRowCount-by-_synthesizedMatrixColumnCount.
- // Because in ML.NET key type's minimal value is zero, the first row index is always zero in C# data structure (e.g., MatrixColumnIndex=0
- // and MatrixRowIndex=0 in MatrixElement below specifies the value at the upper-left corner in the training matrix). If user's row index
- // starts with 1, their row index 1 would be mapped to the 2nd row in matrix factorization module and their first row may contain no values.
- // This behavior is also true to column index.
- const int _synthesizedMatrixFirstColumnIndex = 1;
- const int _synthesizedMatrixFirstRowIndex = 1;
- const int _synthesizedMatrixColumnCount = 60;
- const int _synthesizedMatrixRowCount = 100;
-
- // A data structure used to encode a single value in matrix
- internal class MatrixElement
- {
- // Matrix column index is at most _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex.
- [KeyType(Count = _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex)]
- public uint MatrixColumnIndex;
- // Matrix row index is at most _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex.
- [KeyType(Count = _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex)]
- public uint MatrixRowIndex;
- // The value at the column MatrixColumnIndex and row MatrixRowIndex.
- public float Value;
- }
-
- // A data structure used to encode prediction result. Comparing with MatrixElement, The field Value in MatrixElement is
- // renamed to Score because Score is the default name of matrix factorization's output.
- internal class MatrixElementForScore
- {
- [KeyType(Count = _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex)]
- public uint MatrixColumnIndex;
- [KeyType(Count = _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex)]
- public uint MatrixRowIndex;
- public float Score;
- }
// This example first creates in-memory data and then use it to train a matrix factorization model. Afterward, quality metrics are reported.
- public static void MatrixFactorizationInMemoryData()
+ public static void MatrixFactorizationWithOptions()
{
- // Create an in-memory matrix as a list of tuples (column index, row index, value).
- var dataMatrix = new List();
- for (uint i = _synthesizedMatrixFirstColumnIndex; i < _synthesizedMatrixFirstColumnIndex + _synthesizedMatrixColumnCount; ++i)
- for (uint j = _synthesizedMatrixFirstRowIndex; j < _synthesizedMatrixFirstRowIndex + _synthesizedMatrixRowCount; ++j)
- dataMatrix.Add(new MatrixElement() { MatrixColumnIndex = i, MatrixRowIndex = j, Value = (i + j) % 5 });
-
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
// as a catalog of available operations and as the source of randomness.
var mlContext = new MLContext(seed: 0, conc: 1);
+ // Get a small in-memory dataset.
+ var data = GetRecommendationData();
+
// Convert the in-memory matrix into an IDataView so that ML.NET components can consume it.
- var dataView = mlContext.Data.ReadFromEnumerable(dataMatrix);
+ var dataView = mlContext.Data.ReadFromEnumerable(data);
// Create a matrix factorization trainer which may consume "Value" as the training label, "MatrixColumnIndex" as the
// matrix's column index, and "MatrixRowIndex" as the matrix's row index. Here nameof(...) is used to extract field
// names' in MatrixElement class.
-
var options = new MatrixFactorizationTrainer.Options
{
MatrixColumnIndexColumnName = nameof(MatrixElement.MatrixColumnIndex),
@@ -68,7 +32,8 @@ public static void MatrixFactorizationInMemoryData()
LabelColumnName = nameof(MatrixElement.Value),
NumIterations = 10,
NumThreads = 1,
- K = 32,
+ ApproximationRank = 32,
+ LearningRate = 0.3
};
var pipeline = mlContext.Recommendation().Trainers.MatrixFactorization(options);
@@ -84,11 +49,11 @@ public static void MatrixFactorizationInMemoryData()
label: nameof(MatrixElement.Value), score: nameof(MatrixElementForScore.Score));
// Print out some metrics for checking the model's quality.
- Console.WriteLine($"L1 - {metrics.L1}");
- Console.WriteLine($"L2 - {metrics.L2}");
- Console.WriteLine($"LossFunction - {metrics.LossFn}");
- Console.WriteLine($"RMS - {metrics.Rms}");
- Console.WriteLine($"RSquared - {metrics.RSquared}");
+ Console.WriteLine($"L1 - {metrics.L1}"); // 0.16375
+ Console.WriteLine($"L2 - {metrics.L2}"); // 0.04407
+ Console.WriteLine($"LossFunction - {metrics.LossFn}"); // 0.04407
+ Console.WriteLine($"RMS - {metrics.Rms}"); // 0.2099
+ Console.WriteLine($"RSquared - {metrics.RSquared}"); // 0.97797
// Create two two entries for making prediction. Of course, the prediction value, Score, is unknown so it can be anything
// (here we use Score=0 and it will be overwritten by the true prediction). If any of row and column indexes are out-of-range
@@ -101,8 +66,10 @@ public static void MatrixFactorizationInMemoryData()
var testDataView = mlContext.Data.ReadFromEnumerable(testMatrix);
// Feed the test data into the model and then iterate through all predictions.
- foreach (var pred in mlContext.CreateEnumerable(testDataView, false))
- Console.WriteLine($"Predicted value at row {pred.MatrixRowIndex} and column {pred.MatrixColumnIndex} is {pred.Score}");
+ foreach (var pred in mlContext.CreateEnumerable(model.Transform(testDataView), false))
+ Console.WriteLine($"Predicted value at row {pred.MatrixRowIndex-1} and column {pred.MatrixColumnIndex-1} is {pred.Score}");
+ // Predicted value at row 7 and column 1 is 2.828761
+ // Predicted value at row 6 and column 3 is 3.642226
}
}
}
diff --git a/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs b/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs
index 50d5b8c6aa..914788eae8 100644
--- a/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs
+++ b/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs
@@ -1,7 +1,7 @@
using System;
using Microsoft.ML.Data;
-using Microsoft.ML.Learners;
using Microsoft.ML.StaticPipe;
+using Microsoft.ML.Trainers;
namespace Microsoft.ML.Samples.Static
{
diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs
index 6ebef55827..1e96439a1e 100644
--- a/src/Microsoft.ML.Core/Data/IEstimator.cs
+++ b/src/Microsoft.ML.Core/Data/IEstimator.cs
@@ -7,6 +7,7 @@
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;
+using Microsoft.ML.Model;
namespace Microsoft.ML.Core.Data
{
@@ -263,7 +264,7 @@ public interface IDataReaderEstimator
/// The transformer is a component that transforms data.
/// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'.
///
- public interface ITransformer
+ public interface ITransformer : ICanSaveModel
{
///
/// Schema propagation for transformers.
diff --git a/src/Microsoft.ML.Data/Model/ModelHeader.cs b/src/Microsoft.ML.Core/Data/ModelHeader.cs
similarity index 99%
rename from src/Microsoft.ML.Data/Model/ModelHeader.cs
rename to src/Microsoft.ML.Core/Data/ModelHeader.cs
index 88fcf9bdad..d1b572e8bc 100644
--- a/src/Microsoft.ML.Data/Model/ModelHeader.cs
+++ b/src/Microsoft.ML.Core/Data/ModelHeader.cs
@@ -10,7 +10,8 @@
namespace Microsoft.ML.Model
{
- [StructLayout(LayoutKind.Explicit, Size = ModelHeader.Size)]
+ [BestFriend]
+ [StructLayout(LayoutKind.Explicit, Size = Size)]
internal struct ModelHeader
{
///
diff --git a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs b/src/Microsoft.ML.Core/Data/ModelLoadContext.cs
similarity index 100%
rename from src/Microsoft.ML.Data/Model/ModelLoadContext.cs
rename to src/Microsoft.ML.Core/Data/ModelLoadContext.cs
diff --git a/src/Microsoft.ML.Data/Model/ModelLoading.cs b/src/Microsoft.ML.Core/Data/ModelLoading.cs
similarity index 98%
rename from src/Microsoft.ML.Data/Model/ModelLoading.cs
rename to src/Microsoft.ML.Core/Data/ModelLoading.cs
index 06ebc0bf06..d561389e39 100644
--- a/src/Microsoft.ML.Data/Model/ModelLoading.cs
+++ b/src/Microsoft.ML.Core/Data/ModelLoading.cs
@@ -10,6 +10,12 @@
namespace Microsoft.ML.Model
{
+ /// <summary>
+ /// Signature for a repository based model loader. This is the dual of <see cref="ICanSaveModel"/>.
+ /// </summary>
+ [BestFriend]
+ internal delegate void SignatureLoadModel(ModelLoadContext ctx);
+
public sealed partial class ModelLoadContext : IDisposable
{
public const string ModelStreamName = "Model.key";
diff --git a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs b/src/Microsoft.ML.Core/Data/ModelSaveContext.cs
similarity index 100%
rename from src/Microsoft.ML.Data/Model/ModelSaveContext.cs
rename to src/Microsoft.ML.Core/Data/ModelSaveContext.cs
diff --git a/src/Microsoft.ML.Data/Model/ModelSaving.cs b/src/Microsoft.ML.Core/Data/ModelSaving.cs
similarity index 100%
rename from src/Microsoft.ML.Data/Model/ModelSaving.cs
rename to src/Microsoft.ML.Core/Data/ModelSaving.cs
diff --git a/src/Microsoft.ML.Data/Model/Repository.cs b/src/Microsoft.ML.Core/Data/Repository.cs
similarity index 97%
rename from src/Microsoft.ML.Data/Model/Repository.cs
rename to src/Microsoft.ML.Core/Data/Repository.cs
index dd3ebfe43a..42cf955b84 100644
--- a/src/Microsoft.ML.Data/Model/Repository.cs
+++ b/src/Microsoft.ML.Core/Data/Repository.cs
@@ -10,13 +10,11 @@
namespace Microsoft.ML.Model
{
- ///
- /// Signature for a repository based model loader. This is the dual of ICanSaveModel.
- ///
- public delegate void SignatureLoadModel(ModelLoadContext ctx);
-
/// <summary>
/// For saving a model into a repository.
+ /// Classes implementing <see cref="ICanSaveModel"/> should do an explicit implementation of <see cref="Save(ModelSaveContext)"/>.
+ /// Classes inheriting <see cref="ICanSaveModel"/> from a base class should overwrite the function invoked by <see cref="Save(ModelSaveContext)"/>
+ /// in that base class, if there is one.
/// </summary>
public interface ICanSaveModel
{
@@ -293,6 +291,8 @@ protected Entry AddEntry(string pathEnt, Stream stream)
public sealed class RepositoryWriter : Repository
{
+ private const string DirTrainingInfo = "TrainingInfo";
+
private ZipArchive _archive;
private Queue> _closed;
@@ -301,7 +301,7 @@ public static RepositoryWriter CreateNew(Stream stream, IExceptionContext ectx =
Contracts.CheckValueOrNull(ectx);
ectx.CheckValue(stream, nameof(stream));
var rep = new RepositoryWriter(stream, ectx, useFileSystem);
- using (var ent = rep.CreateEntry(ModelFileUtils.DirTrainingInfo, "Version.txt"))
+ using (var ent = rep.CreateEntry(DirTrainingInfo, "Version.txt"))
using (var writer = Utils.OpenWriter(ent.Stream))
writer.WriteLine(typeof(RepositoryWriter).Assembly.GetName().Version);
return rep;
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
index 25d00518f3..a42a84d559 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
@@ -942,7 +942,7 @@ private static Stream OpenStream(string filename)
return OpenStream(files);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
index ff5348c58c..ccd25a0039 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
@@ -502,7 +502,7 @@ private static IDataLoader LoadTransforms(ModelLoadContext ctx, IDataLoader srcL
});
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
index 795fa2c66e..2a8a25df2b 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
@@ -834,7 +834,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
OutputSchema = ComputeOutputSchema();
}
- public void Save(ModelSaveContext ctx)
+ internal void Save(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
@@ -1283,7 +1283,7 @@ internal static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiS
internal static IDataView ReadFile(IHostEnvironment env, Arguments args, IMultiStreamSource fileSource)
=> new TextLoader(env, args, fileSource).Read(fileSource);
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -1420,7 +1420,7 @@ public RowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int
return Cursor.CreateSet(_reader, _files, active, n);
}
- public void Save(ModelSaveContext ctx) => _reader.Save(ctx);
+ void ICanSaveModel.Save(ModelSaveContext ctx) => ((ICanSaveModel)_reader).Save(ctx);
}
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
index bf2af44b4c..24a29b1cc9 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
@@ -18,7 +18,7 @@ namespace Microsoft.ML.Data
{
// REVIEW: this class is public, as long as the Wrappers.cs in tests still rely on it.
// It needs to become internal.
- public sealed class TransformWrapper : ITransformer, ICanSaveModel
+ public sealed class TransformWrapper : ITransformer
{
public const string LoaderSignature = "TransformWrapper";
private const string TransformDirTemplate = "Step_{0:000}";
@@ -46,7 +46,7 @@ public Schema GetOutputSchema(Schema inputSchema)
return output.Schema;
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
if (!_allowSave)
throw _host.Except("Saving is not permitted.");
diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
index 6b5ee01d8d..6e52f83e75 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
@@ -51,7 +51,7 @@ internal interface ITransformerChainAccessor
/// A chain of transformers (possibly empty) that end with a .
/// For an empty chain, is always .
///
- public sealed class TransformerChain : ITransformer, ICanSaveModel, IEnumerable, ITransformerChainAccessor
+ public sealed class TransformerChain : ITransformer, IEnumerable, ITransformerChainAccessor
where TLastTransformer : class, ITransformer
{
private readonly ITransformer[] _transformers;
@@ -165,7 +165,7 @@ public TransformerChain Append(TNewLast transformer, Transfo
return new TransformerChain(_transformers.AppendElement(transformer), _scopes.AppendElement(scope));
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
@@ -181,7 +181,7 @@ public void Save(ModelSaveContext ctx)
}
///
- /// The loading constructor of transformer chain. Reverse of .
+ /// The loading constructor of transformer chain. Reverse of .
///
internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx)
{
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
index 625bcf71e7..f3b7629bc4 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
@@ -513,7 +513,7 @@ public static TransposeLoader Create(IHostEnvironment env, ModelLoadContext ctx,
});
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs
index 9f7290e0d1..c34881b2d6 100644
--- a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs
+++ b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs
@@ -140,7 +140,7 @@ public Impl(IHostEnvironment env, string name, IDataView input, OneToOneColumn c
Metadata.Seal();
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.Assert(false, "Shouldn't serialize this!");
throw Host.ExceptNotSupp("Shouldn't serialize this");
diff --git a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs
index 9e726ad9d4..f1cf418431 100644
--- a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs
+++ b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs
@@ -96,7 +96,7 @@ public Impl(IHostEnvironment env, string name, IDataView input,
_conv = conv;
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.Assert(false, "Shouldn't serialize this!");
throw Host.ExceptNotSupp("Shouldn't serialize this");
diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
index e3a2377009..23b0211404 100644
--- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
+++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
@@ -131,7 +131,7 @@ public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadCont
return h.Apply("Loading Model", ch => new RowToRowMapperTransform(h, ctx, input));
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs
index 3b88307eb6..affe071745 100644
--- a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs
+++ b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs
@@ -223,7 +223,7 @@ public static ChooseColumnsByIndexTransform Create(IHostEnvironment env, ModelLo
return h.Apply("Loading Model", ch => new ChooseColumnsByIndexTransform(h, ctx, input));
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
index c4e5ad76cd..7bade429f1 100644
--- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
@@ -948,7 +948,7 @@ public static BinaryPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadC
return new BinaryPerInstanceEvaluator(env, ctx, schema);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -960,7 +960,7 @@ public override void Save(ModelSaveContext ctx)
// float: _threshold
// byte: _useRaw
- base.Save(ctx);
+ base.SaveModel(ctx);
ctx.SaveStringOrNull(_probCol);
Contracts.Assert(FloatUtils.IsFinite(_threshold));
ctx.Writer.Write(_threshold);
diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
index 067dccfb50..f1c4b19b0a 100644
--- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
@@ -628,13 +628,13 @@ public static ClusteringPerInstanceEvaluator Create(IHostEnvironment env, ModelL
return new ClusteringPerInstanceEvaluator(env, ctx, schema);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
// *** Binary format **
// base
// int: number of clusters
- base.Save(ctx);
+ base.SaveModel(ctx);
Host.Assert(_numClusters > 0);
ctx.Writer.Write(_numClusters);
}
diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs
index 88fa1ee285..f15c514e0e 100644
--- a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs
+++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs
@@ -510,7 +510,12 @@ protected PerInstanceEvaluatorBase(IHostEnvironment env, ModelLoadContext ctx,
throw Host.ExceptSchemaMismatch(nameof(schema), "score", ScoreCol);
}
- public virtual void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ /// <summary>
+ /// Derived class, for example A, should overwrite <see cref="SaveModel(ModelSaveContext)"/> so that ((<see cref="ICanSaveModel"/>)A).Save(ctx) can correctly dump A.
+ /// </summary>
+ private protected virtual void SaveModel(ModelSaveContext ctx)
{
// *** Binary format **
// int: Id of the score column name
diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs
index 7e78be4a75..bb0a27d87d 100644
--- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs
@@ -631,7 +631,7 @@ public static MultiClassPerInstanceEvaluator Create(IHostEnvironment env, ModelL
return new MultiClassPerInstanceEvaluator(env, ctx, schema);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -642,7 +642,7 @@ public override void Save(ModelSaveContext ctx)
// int: number of classes
// int[]: Ids of the class names
- base.Save(ctx);
+ base.SaveModel(ctx);
Host.Assert(_numClasses > 0);
ctx.Writer.Write(_numClasses);
for (int i = 0; i < _numClasses; i++)
diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
index 1e232aff0d..6757da940c 100644
--- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
@@ -426,7 +426,7 @@ public static MultiOutputRegressionPerInstanceEvaluator Create(IHostEnvironment
return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, schema);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -434,7 +434,7 @@ public override void Save(ModelSaveContext ctx)
// *** Binary format **
// base
- base.Save(ctx);
+ base.SaveModel(ctx);
}
private protected override Func GetDependenciesCore(Func activeOutput)
diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
index 1332c58948..1f16938682 100644
--- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
@@ -324,7 +324,7 @@ public static QuantileRegressionPerInstanceEvaluator Create(IHostEnvironment env
return new QuantileRegressionPerInstanceEvaluator(env, ctx, schema);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -335,7 +335,7 @@ public override void Save(ModelSaveContext ctx)
// int: _scoreSize
// int[]: Ids of the quantile names
- base.Save(ctx);
+ base.SaveModel(ctx);
Host.Assert(_scoreSize > 0);
ctx.Writer.Write(_scoreSize);
var quantiles = _quantiles.GetValues();
diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
index ed215e8387..b420b1ad27 100644
--- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
@@ -597,11 +597,11 @@ public static RankerPerInstanceTransform Create(IHostEnvironment env, ModelLoadC
return h.Apply("Loading Model", ch => new RankerPerInstanceTransform(h, ctx, input));
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
- _transform.Save(ctx);
+ ((ICanSaveModel)_transform).Save(ctx);
}
public long? GetRowCount()
@@ -715,7 +715,7 @@ public Transform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
_bindings = new Bindings(Host, input.Schema, false, LabelCol, ScoreCol, GroupCol, _truncationLevel);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.AssertValue(ctx);
@@ -725,7 +725,7 @@ public override void Save(ModelSaveContext ctx)
// int: _labelGains.Length
// double[]: _labelGains
- base.Save(ctx);
+ base.SaveModel(ctx);
Host.Assert(0 < _truncationLevel && _truncationLevel < 100);
ctx.Writer.Write(_truncationLevel);
ctx.Writer.WriteDoubleArray(_labelGains);
diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
index 69df01da67..2015f30eea 100644
--- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
@@ -234,7 +234,7 @@ public static RegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelL
return new RegressionPerInstanceEvaluator(env, ctx, schema);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -242,7 +242,7 @@ public override void Save(ModelSaveContext ctx)
// *** Binary format **
// base
- base.Save(ctx);
+ base.SaveModel(ctx);
}
private protected override Func GetDependenciesCore(Func activeOutput)
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index ef942c5cdf..2ee036d760 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -407,7 +407,7 @@ private static ValueMapperCalibratedModelParameters Crea
return new ValueMapperCalibratedModelParameters(env, ctx);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -462,7 +462,7 @@ private static FeatureWeightsCalibratedModelParameters C
return new FeatureWeightsCalibratedModelParameters(env, ctx);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -528,7 +528,7 @@ private static ParameterMixingCalibratedModelParameters
return new ParameterMixingCalibratedModelParameters(env, ctx);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -698,7 +698,7 @@ private static SchemaBindableCalibratedModelParameters C
return new SchemaBindableCalibratedModelParameters(env, ctx);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.CheckAtModel();
@@ -1488,7 +1488,7 @@ private static PlattCalibrator Create(IHostEnvironment env, ModelLoadContext ctx
return new PlattCalibrator(env, ctx);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs
index ca384c0e4e..f53bb6993e 100644
--- a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs
+++ b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs
@@ -200,7 +200,7 @@ internal CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, strin
bool ITransformer.IsRowToRowMapper => true;
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.CheckAtModel();
@@ -245,7 +245,7 @@ internal Mapper(CalibratorTransformer parent, TCalibrator calibrato
private protected override Func GetDependenciesCore(Func activeOutput)
=> col => col == _scoreColIndex;
- public override void Save(ModelSaveContext ctx) => _parent.Save(ctx);
+ private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
protected override Schema.DetachedColumn[] GetOutputColumnsCore()
{
diff --git a/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs
index 55e30203d4..d90539cc80 100644
--- a/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs
+++ b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs
@@ -146,7 +146,7 @@ private protected virtual void PredictionEngineCore(IHostEnvironment env, DataVi
disposer = inputRow.Dispose;
}
- protected virtual Func TransformerChecker(IExceptionContext ectx, ITransformer transformer)
+ private protected virtual Func TransformerChecker(IExceptionContext ectx, ITransformer transformer)
{
ectx.CheckValue(transformer, nameof(transformer));
ectx.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs
index afb1f9bb51..5f837507db 100644
--- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs
+++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs
@@ -163,7 +163,7 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
return new RowMapper(env, this, schema);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.SetVersionInfo(GetVersionInfo());
diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
index 60f5864f9f..612231ebf8 100644
--- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
@@ -114,7 +114,7 @@ public static Bindings Create(ModelLoadContext ctx,
return Create(env, bindable, input, roles, suffix, user: false);
}
- public override void Save(ModelSaveContext ctx)
+ internal override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
@@ -205,7 +205,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.SetVersionInfo(GetVersionInfo());
- _bindings.Save(ctx);
+ _bindings.SaveModel(ctx);
}
void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
index 2f57bab274..25366f02f8 100644
--- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
@@ -165,7 +165,7 @@ private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadConte
return h.Apply("Loading Model", ch => new LabelNameBindableMapper(h, ctx));
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
index a72b09906c..bf2ec96e10 100644
--- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
+++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
@@ -160,7 +160,7 @@ public static BindingsImpl Create(ModelLoadContext ctx, Schema input,
return new BindingsImpl(input, rowMapper, suffix, scoreKind, false, scoreColIndex, predColType);
}
- public override void Save(ModelSaveContext ctx)
+ internal override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
@@ -335,7 +335,7 @@ private protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDa
private protected override void SaveCore(ModelSaveContext ctx)
{
Host.AssertValue(ctx);
- Bindings.Save(ctx);
+ Bindings.SaveModel(ctx);
}
void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
index cbbb2bbe19..64f71aba83 100644
--- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
+++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
@@ -134,7 +134,11 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema)
return (IRowToRowMapper)Scorer.ApplyToData(Host, new EmptyDataView(Host, inputSchema));
}
- protected void SaveModel(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ private protected abstract void SaveModel(ModelSaveContext ctx);
+
+ protected void SaveModelCore(ModelSaveContext ctx)
{
// *** Binary format ***
//
@@ -157,7 +161,7 @@ protected void SaveModel(ModelSaveContext ctx)
/// Those are all the transformers that work with one feature column.
///
/// The model used to transform the data.
- public abstract class SingleFeaturePredictionTransformerBase : PredictionTransformerBase, ISingleFeaturePredictionTransformer, ICanSaveModel
+ public abstract class SingleFeaturePredictionTransformerBase : PredictionTransformerBase, ISingleFeaturePredictionTransformer
where TModel : class
{
///
@@ -226,7 +230,7 @@ public sealed override Schema GetOutputSchema(Schema inputSchema)
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
}
- public void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -235,7 +239,7 @@ public void Save(ModelSaveContext ctx)
protected virtual void SaveCore(ModelSaveContext ctx)
{
- SaveModel(ctx);
+ SaveModelCore(ctx);
ctx.SaveStringOrNull(FeatureColumn);
}
diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs
index fb68553672..082b083a40 100644
--- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs
+++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs
@@ -49,7 +49,7 @@ private protected RowToRowScorerBase(IHost host, ModelLoadContext ctx, IDataView
ctx.LoadModel(host, out Bindable, "SchemaBindableMapper");
}
- public sealed override void Save(ModelSaveContext ctx)
+ private protected sealed override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.CheckAtModel();
@@ -417,7 +417,7 @@ protected void SaveBase(ModelSaveContext ctx)
}
}
- public abstract void Save(ModelSaveContext ctx);
+ internal abstract void SaveModel(ModelSaveContext ctx);
protected override ColumnType GetColumnTypeCore(int iinfo)
{
diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
index 5d5e9e3178..5b62f45a7e 100644
--- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
+++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
@@ -76,7 +76,9 @@ protected SchemaBindablePredictorWrapperBase(IHostEnvironment env, ModelLoadCont
ScoreType = GetScoreType(Predictor, out ValueMapper);
}
- public virtual void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ private protected virtual void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -283,11 +285,11 @@ public static SchemaBindablePredictorWrapper Create(IHostEnvironment env, ModelL
return new SchemaBindablePredictorWrapper(env, ctx);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.SetVersionInfo(GetVersionInfo());
- base.Save(ctx);
+ base.SaveModel(ctx);
}
private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames)
@@ -390,11 +392,11 @@ public static SchemaBindableBinaryPredictorWrapper Create(IHostEnvironment env,
return new SchemaBindableBinaryPredictorWrapper(env, ctx);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.SetVersionInfo(GetVersionInfo());
- base.Save(ctx);
+ base.SaveModel(ctx);
}
private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames)
@@ -631,7 +633,7 @@ private SchemaBindableQuantileRegressionPredictor(IHostEnvironment env, ModelLoa
Contracts.CheckDecode(Utils.Size(_quantiles) > 0);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.SetVersionInfo(GetVersionInfo());
@@ -641,7 +643,7 @@ public override void Save(ModelSaveContext ctx)
// int: the number of quantiles
// double[]: the quantiles
- base.Save(ctx);
+ base.SaveModel(ctx);
ctx.Writer.WriteDoubleArray(_quantiles);
}
diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs
index d8b4c53265..586ac66bd0 100644
--- a/src/Microsoft.ML.Data/TrainCatalog.cs
+++ b/src/Microsoft.ML.Data/TrainCatalog.cs
@@ -25,6 +25,31 @@ public abstract class TrainCatalogBase
[BestFriend]
internal IHostEnvironment Environment => Host;
+ ///
+ /// A pair of datasets, for the train and test set.
+ ///
+ public struct TrainTestData
+ {
+ ///
+ /// Training set.
+ ///
+ public readonly IDataView TrainSet;
+ ///
+ /// Testing set.
+ ///
+ public readonly IDataView TestSet;
+ ///
+ /// Create pair of datasets.
+ ///
+ /// Training set.
+ /// Testing set.
+ internal TrainTestData(IDataView trainSet, IDataView testSet)
+ {
+ TrainSet = trainSet;
+ TestSet = testSet;
+ }
+ }
+
///
/// Split the dataset into the train set and test set according to the given fraction.
/// Respects the if provided.
@@ -37,8 +62,7 @@ public abstract class TrainCatalogBase
/// Optional parameter used in combination with the .
/// If the is not provided, the random numbers generated to create it, will use this seed as value.
/// And if it is not provided, the default value will be used.
- /// A pair of datasets, for the train and test set.
- public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null)
+ public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null)
{
Host.CheckValue(data, nameof(data));
Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
@@ -61,14 +85,71 @@ public abstract class TrainCatalogBase
Complement = false
}, data);
- return (trainFilter, testFilter);
+ return new TrainTestData(trainFilter, testFilter);
+ }
+
+ ///
+ /// Results for specific cross-validation fold.
+ ///
+ protected internal struct CrossValidationResult
+ {
+ ///
+ /// Model trained during cross validation fold.
+ ///
+ public readonly ITransformer Model;
+ ///
+ /// Scored test set with for this fold.
+ ///
+ public readonly IDataView Scores;
+ ///
+ /// Fold number.
+ ///
+ public readonly int Fold;
+
+ public CrossValidationResult(ITransformer model, IDataView scores, int fold)
+ {
+ Model = model;
+ Scores = scores;
+ Fold = fold;
+ }
+ }
+ ///
+ /// Results of running cross-validation.
+ ///
+ /// Type of metric class.
+ public sealed class CrossValidationResult where T : class
+ {
+ ///
+ /// Metrics for this cross-validation fold.
+ ///
+ public readonly T Metrics;
+ ///
+ /// Model trained during cross-validation fold.
+ ///
+ public readonly ITransformer Model;
+ ///
+ /// The scored hold-out set for this fold.
+ ///
+ public readonly IDataView ScoredHoldOutSet;
+ ///
+ /// Fold number.
+ ///
+ public readonly int Fold;
+
+ internal CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold)
+ {
+ Model = model;
+ Metrics = metrics;
+ ScoredHoldOutSet = scores;
+ Fold = fold;
+ }
}
///
/// Train the on folds of the data sequentially.
/// Return each model and each scored test dataset.
///
- protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator estimator,
+ protected internal CrossValidationResult[] CrossValidateTrain(IDataView data, IEstimator estimator,
int numFolds, string stratificationColumn, uint? seed = null)
{
Host.CheckValue(data, nameof(data));
@@ -78,7 +159,7 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate
EnsureStratificationColumn(ref data, ref stratificationColumn, seed);
- Func foldFunction =
+ Func foldFunction =
fold =>
{
var trainFilter = new RangeFilter(Host, new RangeFilter.Options
@@ -98,17 +179,17 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate
var model = estimator.Fit(trainFilter);
var scoredTest = model.Transform(testFilter);
- return (scoredTest, model);
+ return new CrossValidationResult(model, scoredTest, fold);
};
// Sequential per-fold training.
// REVIEW: we could have a parallel implementation here. We would need to
// spawn off a separate host per fold in that case.
- var result = new List<(IDataView scores, ITransformer model)>();
+ var result = new CrossValidationResult[numFolds];
for (int fold = 0; fold < numFolds; fold++)
- result.Add(foldFunction(fold));
+ result[fold] = foldFunction(fold);
- return result.ToArray();
+ return result;
}
protected internal TrainCatalogBase(IHostEnvironment env, string registrationName)
@@ -263,13 +344,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
/// If the is not provided, the random numbers generated to create it, will use this seed as value.
/// And if it is not provided, the default value will be used.
/// Per-fold results: metrics, models, scored datasets.
- public (BinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated(
+ public CrossValidationResult[] CrossValidateNonCalibrated(
IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
string stratificationColumn = null, uint? seed = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
- return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+ return result.Select(x => new CrossValidationResult(x.Model,
+ EvaluateNonCalibrated(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
}
///
@@ -287,13 +369,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
/// train to the test set.
/// If not present in dataset we will generate random filled column based on provided .
/// Per-fold results: metrics, models, scored datasets.
- public (CalibratedBinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+ public CrossValidationResult[] CrossValidate(
IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
string stratificationColumn = null, uint? seed = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
- return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+ return result.Select(x => new CrossValidationResult(x.Model,
+ Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
}
}
@@ -369,12 +452,13 @@ public ClusteringMetrics Evaluate(IDataView data,
/// If the is not provided, the random numbers generated to create it, will use this seed as value.
/// And if it is not provided, the default value will be used.
/// Per-fold results: metrics, models, scored datasets.
- public (ClusteringMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+ public CrossValidationResult[] CrossValidate(
IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = null, string featuresColumn = null,
string stratificationColumn = null, uint? seed = null)
{
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
- return result.Select(x => (Evaluate(x.scoredTestSet, label: labelColumn, features: featuresColumn), x.model, x.scoredTestSet)).ToArray();
+ return result.Select(x => new CrossValidationResult(x.Model,
+ Evaluate(x.Scores, label: labelColumn, features: featuresColumn), x.Scores, x.Fold)).ToArray();
}
}
@@ -444,13 +528,14 @@ public MultiClassClassifierMetrics Evaluate(IDataView data, string label = Defau
/// If the is not provided, the random numbers generated to create it, will use this seed as value.
/// And if it is not provided, the default value will be used.
/// Per-fold results: metrics, models, scored datasets.
- public (MultiClassClassifierMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+ public CrossValidationResult[] CrossValidate(
IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
string stratificationColumn = null, uint? seed = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
- return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+ return result.Select(x => new CrossValidationResult(x.Model,
+ Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
}
}
@@ -511,13 +596,14 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa
/// If the is not provided, the random numbers generated to create it, will use this seed as value.
/// And if it is not provided, the default value will be used.
/// Per-fold results: metrics, models, scored datasets.
- public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+ public CrossValidationResult[] CrossValidate(
IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
string stratificationColumn = null, uint? seed = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
- return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+ return result.Select(x => new CrossValidationResult(x.Model,
+ Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs b/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs
index cb61c00db1..ee3289396f 100644
--- a/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs
@@ -129,7 +129,7 @@ private BootstrapSamplingTransformer(IHost host, ModelLoadContext ctx, IDataView
Host.CheckDecode(_poolSize >= 0);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs
index f9b5d6bf8e..8206e66723 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs
@@ -263,7 +263,7 @@ private static VersionInfo GetVersionInfo()
private const int VersionAddedAliases = 0x00010002;
private const int VersionTransformer = 0x00010003;
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -864,7 +864,7 @@ private protected override Func GetDependenciesCore(Func a
protected override Schema.DetachedColumn[] GetOutputColumnsCore() => _columns.Select(x => x.MakeSchemaColumn()).ToArray();
- public override void Save(ModelSaveContext ctx) => _parent.Save(ctx);
+ private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer)
{
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
index 7dfdf17d23..07e5784bad 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
@@ -166,7 +166,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
ctx.SetVersionInfo(GetVersionInfo());
SaveColumns(ctx);
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
index c3e6b50018..00bd42fee0 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
@@ -124,7 +124,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
///
/// The allows the user to specify columns to drop or keep from a given input.
///
- public sealed class ColumnSelectingTransformer : ITransformer, ICanSaveModel
+ public sealed class ColumnSelectingTransformer : ITransformer
{
internal const string Summary = "Selects which columns from the dataset to keep.";
internal const string UserName = "Select Columns Transform";
@@ -417,7 +417,9 @@ private static IDataTransform Create(IHostEnvironment env, Options options, IDat
return new SelectColumnsDataTransform(env, transform, new Mapper(transform, input.Schema), input);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ internal void SaveModel(ModelSaveContext ctx)
{
ctx.SetVersionInfo(GetVersionInfo());
@@ -678,7 +680,7 @@ public RowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int
return cursors;
}
- public void Save(ModelSaveContext ctx) => _transform.Save(ctx);
+ void ICanSaveModel.Save(ModelSaveContext ctx) => _transform.SaveModel(ctx);
public Func GetDependencies(Func activeOutput)
{
diff --git a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs
index 461a3f61e9..f6f46d62ac 100644
--- a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs
@@ -172,7 +172,7 @@ private FeatureContributionCalculatingTransformer(IHostEnvironment env, ModelLoa
Normalize = ctx.Reader.ReadBoolByte();
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.SetVersionInfo(GetVersionInfo());
diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
index 54fb453894..c14cdd49d9 100644
--- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
@@ -163,7 +163,7 @@ public static Bindings Create(ModelLoadContext ctx, Schema input)
return new Bindings(useCounter, states, input, false, names);
}
- public void Save(ModelSaveContext ctx)
+ internal void Save(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
@@ -309,7 +309,7 @@ public static GenerateNumberTransform Create(IHostEnvironment env, ModelLoadCont
return h.Apply("Loading Model", ch => new GenerateNumberTransform(h, ctx, input));
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs
index f98f4baa47..d60794d2d8 100644
--- a/src/Microsoft.ML.Data/Transforms/Hashing.cs
+++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs
@@ -278,7 +278,7 @@ private HashingTransformer(IHost host, ModelLoadContext ctx)
TextModelHelper.LoadAll(Host, ctx, columnsLength, out _keyValues, out _kvTypes);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
index ffb6c61da3..eaaf566e95 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
@@ -144,7 +144,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs
index d2c7a5e260..cc4ed51905 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs
@@ -143,7 +143,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(KeyToVectorMappingTransformer).Assembly.FullName);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
index 408e62e330..cf0b9790ea 100644
--- a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
@@ -120,7 +120,7 @@ public static LabelConvertTransform Create(IHostEnvironment env, ModelLoadContex
});
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
index 21b0b13bce..e848312ee9 100644
--- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
@@ -98,7 +98,7 @@ public static LabelIndicatorTransform Create(IHostEnvironment env,
ch => new LabelIndicatorTransform(h, args, input));
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs
index ffc04cbc80..956f201cd6 100644
--- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs
@@ -169,7 +169,7 @@ public static NAFilter Create(IHostEnvironment env, ModelLoadContext ctx, IDataV
return h.Apply("Loading Model", ch => new NAFilter(h, ctx, input));
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/NopTransform.cs b/src/Microsoft.ML.Data/Transforms/NopTransform.cs
index 92f701d704..234c060be2 100644
--- a/src/Microsoft.ML.Data/Transforms/NopTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/NopTransform.cs
@@ -89,7 +89,7 @@ private NopTransform(IHost host, ModelLoadContext ctx, IDataView input)
// Nothing :)
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
index 4888aeb4c4..f15f3e77fc 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
@@ -369,7 +369,9 @@ private AffineColumnFunction(IHost host)
Host = host;
}
- public abstract void Save(ModelSaveContext ctx);
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ private protected abstract void SaveModel(ModelSaveContext ctx);
public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken);
public bool CanSaveOnnx(OnnxContext ctx) => true;
@@ -485,7 +487,9 @@ private CdfColumnFunction(IHost host)
Host = host;
}
- public abstract void Save(ModelSaveContext ctx);
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ private protected abstract void SaveModel(ModelSaveContext ctx);
public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => null;
@@ -614,7 +618,9 @@ protected BinColumnFunction(IHost host)
Host = host;
}
- public abstract void Save(ModelSaveContext ctx);
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ private protected abstract void SaveModel(ModelSaveContext ctx);
public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => null;
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
index 54b4c11ce4..e2050bdcdc 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
@@ -568,7 +568,7 @@ private void GetResult(ref TFloat input, ref TFloat value)
value = (input - Offset) * Scale;
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
AffineNormSerializationUtils.SaveModel(ctx, 1, null, new[] { Scale }, new[] { Offset }, saveText: true);
}
@@ -628,7 +628,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr
return new ImplVec(host, scales, offsets, (offsets != null && nz.Count < cv / 2) ? nz.ToArray() : null);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
AffineNormSerializationUtils.SaveModel(ctx, Scale.Length, null, Scale, Offset, saveText: true);
}
@@ -892,7 +892,7 @@ private void GetResult(ref TFloat input, ref TFloat value)
value = CdfUtils.Cdf(val, Mean, Stddev);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.CheckAtModel();
@@ -947,7 +947,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr
return new ImplVec(host, mean, stddev, useLog);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.CheckAtModel();
@@ -1071,7 +1071,7 @@ public ImplOne(IHost host, TFloat[] binUpperBounds, bool fixZero)
return new ImplOne(host, binUpperBounds[0], fixZero);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.CheckAtModel();
@@ -1157,7 +1157,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr
return new ImplVec(host, binUpperBounds, fixZero);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
index 482e5d8f71..7a4abcdfbb 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
@@ -570,7 +570,7 @@ private void GetResult(ref TFloat input, ref TFloat value)
value = (input - Offset) * Scale;
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
AffineNormSerializationUtils.SaveModel(ctx, 1, null, new[] { Scale }, new[] { Offset }, saveText: true);
}
@@ -629,7 +629,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr
return new ImplVec(host, scales, offsets, (offsets != null && nz.Count < cv / 2) ? nz.ToArray() : null);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
AffineNormSerializationUtils.SaveModel(ctx, Scale.Length, null, Scale, Offset, saveText: true);
}
@@ -896,7 +896,7 @@ private void GetResult(ref TFloat input, ref TFloat value)
value = CdfUtils.Cdf(val, Mean, Stddev);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.CheckAtModel();
@@ -951,7 +951,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr
return new ImplVec(host, mean, stddev, useLog);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.CheckAtModel();
@@ -1076,7 +1076,7 @@ public ImplOne(IHost host, TFloat[] binUpperBounds, bool fixZero)
return new ImplOne(host, binUpperBounds[0], fixZero);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.CheckAtModel();
@@ -1162,7 +1162,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr
return new ImplVec(host, binUpperBounds, fixZero);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs
index bff9503614..78edabbe0c 100644
--- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs
+++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs
@@ -540,7 +540,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs
index f1cf1f52b3..4a4a94f7de 100644
--- a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs
+++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs
@@ -110,7 +110,7 @@ private protected override Func GetDependenciesCore(Func a
return col => active[col];
}
- public override void Save(ModelSaveContext ctx) => _parent.Save(ctx);
+ private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
}
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs
index c5e95a7683..f226d28a30 100644
--- a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs
+++ b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs
@@ -133,7 +133,9 @@ protected PerGroupTransformBase(IHostEnvironment env, ModelLoadContext ctx, IDat
GroupCol = ctx.LoadNonEmptyString();
}
- public virtual void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ private protected virtual void SaveModel(ModelSaveContext ctx)
{
Host.AssertValue(ctx);
diff --git a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
index 838f1698b2..e6f356ffe4 100644
--- a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
@@ -177,7 +177,7 @@ public static RangeFilter Create(IHostEnvironment env, ModelLoadContext ctx, IDa
return h.Apply("Loading Model", ch => new RangeFilter(h, ctx, input));
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
index 0ed952c379..0f6b6a5dc3 100644
--- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
@@ -159,7 +159,7 @@ public static RowShufflingTransformer Create(IHostEnvironment env, ModelLoadCont
return h.Apply("Loading Model", ch => new RowShufflingTransformer(h, ctx, input));
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs
index 104cdad96d..13622d6063 100644
--- a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs
+++ b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs
@@ -13,7 +13,7 @@ namespace Microsoft.ML.Data
///
/// Base class for transformer which produce new columns, but doesn't affect existing ones.
///
- public abstract class RowToRowTransformerBase : ITransformer, ICanSaveModel
+ public abstract class RowToRowTransformerBase : ITransformer
{
protected readonly IHost Host;
@@ -23,7 +23,9 @@ protected RowToRowTransformerBase(IHost host)
Host = host;
}
- public abstract void Save(ModelSaveContext ctx);
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ private protected abstract void SaveModel(ModelSaveContext ctx);
public bool IsRowToRowMapper => true;
@@ -109,7 +111,9 @@ Func IRowMapper.GetDependencies(Func activeOutput)
[BestFriend]
private protected abstract Func GetDependenciesCore(Func activeOutput);
- public abstract void Save(ModelSaveContext ctx);
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ private protected abstract void SaveModel(ModelSaveContext ctx);
public ITransformer GetTransformer() => _parent;
}
diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
index 6c8a316a7c..b83e701597 100644
--- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
@@ -168,7 +168,7 @@ public static SkipTakeFilter Create(IHostEnvironment env, ModelLoadContext ctx,
}
///Saves class data to context
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs
index 9b57d6d031..037521e7e1 100644
--- a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs
@@ -324,7 +324,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs
index ec8dec1af9..0ab2b61c27 100644
--- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs
+++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs
@@ -43,7 +43,9 @@ protected TransformBase(IHost host, IDataView input)
Source = input;
}
- public abstract void Save(ModelSaveContext ctx);
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ private protected abstract void SaveModel(ModelSaveContext ctx);
public abstract long? GetRowCount();
@@ -388,7 +390,7 @@ public static Bindings Create(OneToOneTransformBase parent, ModelLoadContext ctx
return new Bindings(parent, infos, inputSchema, false, names);
}
- public void Save(ModelSaveContext ctx)
+ internal void Save(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
index 5bdbdb0a29..7dcd1ab641 100644
--- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
+++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
@@ -206,7 +206,7 @@ internal TypeConvertingTransformer(IHostEnvironment env, params TypeConvertingEs
_columns = columns.ToArray();
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs
index b6469b55cc..c74650a7a0 100644
--- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs
+++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs
@@ -740,7 +740,7 @@ protected static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorT
return ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.SetVersionInfo(GetVersionInfo());
diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
index 011d491608..dacd543f29 100644
--- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
@@ -641,7 +641,7 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info
return termMap;
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
diff --git a/src/Microsoft.ML.Ensemble/Batch.cs b/src/Microsoft.ML.Ensemble/Batch.cs
index caaf6bf4f3..756b095c72 100644
--- a/src/Microsoft.ML.Ensemble/Batch.cs
+++ b/src/Microsoft.ML.Ensemble/Batch.cs
@@ -4,7 +4,7 @@
using Microsoft.ML.Data;
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
internal sealed class Batch
{
diff --git a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs
index 2af9b96f3b..342a33b9bd 100644
--- a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs
+++ b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs
@@ -7,7 +7,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
internal static class EnsembleUtils
{
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
index cf88c501b8..6ce5673797 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
@@ -10,15 +10,13 @@
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble;
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(void), typeof(EnsembleCreator), null, typeof(SignatureEntryPointModule), "CreateEnsemble")]
-namespace Microsoft.ML.EntryPoints
+namespace Microsoft.ML.Trainers.Ensemble
{
///
/// A component to combine given models into an ensemble model.
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs b/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs
index 83d8fd0a96..075b3659ec 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs
@@ -2,16 +2,15 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.DiversityMeasure;
using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.DiversityMeasure;
[assembly: EntryPointModule(typeof(DisagreementDiversityFactory))]
[assembly: EntryPointModule(typeof(RegressionDisagreementDiversityFactory))]
[assembly: EntryPointModule(typeof(MultiDisagreementDiversityFactory))]
-namespace Microsoft.ML.Ensemble.EntryPoints
+namespace Microsoft.ML.Trainers.Ensemble
{
[TlcModule.Component(Name = DisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)]
internal sealed class DisagreementDiversityFactory : ISupportBinaryDiversityMeasureFactory
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
index 6cf8b418bc..afec3af471 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
@@ -3,12 +3,12 @@
// See the LICENSE file in the project root for more information.
using Microsoft.ML;
-using Microsoft.ML.Ensemble.EntryPoints;
using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(void), typeof(Ensemble), null, typeof(SignatureEntryPointModule), "TrainEnsemble")]
-namespace Microsoft.ML.Ensemble.EntryPoints
+namespace Microsoft.ML.Trainers.Ensemble
{
internal static class Ensemble
{
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs b/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs
index c2d7862a31..8af31a0723 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs
@@ -2,15 +2,14 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.FeatureSelector;
using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.FeatureSelector;
[assembly: EntryPointModule(typeof(AllFeatureSelectorFactory))]
[assembly: EntryPointModule(typeof(RandomFeatureSelector))]
-namespace Microsoft.ML.Ensemble.EntryPoints
+namespace Microsoft.ML.Trainers.Ensemble
{
[TlcModule.Component(Name = AllFeatureSelector.LoadName, FriendlyName = AllFeatureSelector.UserName)]
public sealed class AllFeatureSelectorFactory : ISupportFeatureSelectorFactory
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs b/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs
index 7583d72cfb..5af5cdf487 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs
@@ -2,9 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: EntryPointModule(typeof(AverageFactory))]
[assembly: EntryPointModule(typeof(MedianFactory))]
@@ -18,7 +17,7 @@
[assembly: EntryPointModule(typeof(VotingFactory))]
[assembly: EntryPointModule(typeof(WeightedAverage))]
-namespace Microsoft.ML.Ensemble.EntryPoints
+namespace Microsoft.ML.Trainers.Ensemble
{
[TlcModule.Component(Name = Average.LoadName, FriendlyName = Average.UserName)]
public sealed class AverageFactory : ISupportBinaryOutputCombinerFactory, ISupportRegressionOutputCombinerFactory
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs
index c80dba1738..2bafa85a2f 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs
@@ -4,13 +4,13 @@
using Microsoft.Data.DataView;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.EntryPoints;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Calibration;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: EntryPointModule(typeof(PipelineEnsemble))]
-namespace Microsoft.ML.Ensemble.EntryPoints
+namespace Microsoft.ML.Trainers.Ensemble
{
internal static class PipelineEnsemble
{
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs b/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs
index 10b1551c62..867833f999 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs
@@ -2,10 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.SubModelSelector;
using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.SubModelSelector;
[assembly: EntryPointModule(typeof(AllSelectorFactory))]
[assembly: EntryPointModule(typeof(AllSelectorMultiClassFactory))]
@@ -16,7 +15,7 @@
[assembly: EntryPointModule(typeof(BestPerformanceSelector))]
[assembly: EntryPointModule(typeof(BestPerformanceSelectorMultiClass))]
-namespace Microsoft.ML.Ensemble.EntryPoints
+namespace Microsoft.ML.Trainers.Ensemble
{
[TlcModule.Component(Name = AllSelector.LoadName, FriendlyName = AllSelector.UserName)]
public sealed class AllSelectorFactory : ISupportBinarySubModelSelectorFactory, ISupportRegressionSubModelSelectorFactory
diff --git a/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs b/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs
index 733573d39f..d583a1b125 100644
--- a/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs
+++ b/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs
@@ -6,7 +6,7 @@
using System.Collections.Generic;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
internal sealed class FeatureSubsetModel
{
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs
index 05b85b37ec..e78e63cecc 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs
@@ -4,15 +4,15 @@
using System;
using Microsoft.ML;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(Average), null, typeof(SignatureCombiner), Average.UserName)]
[assembly: LoadableClass(typeof(Average), null, typeof(SignatureLoadModel), Average.UserName, Average.LoaderSignature)]
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
- public sealed class Average : BaseAverager, ICanSaveModel, IRegressionOutputCombiner
+ public sealed class Average : BaseAverager, IRegressionOutputCombiner
{
public const string UserName = "Average";
public const string LoadName = "Average";
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs
index 1ffbe13ec1..1d8f56e3fe 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs
@@ -5,9 +5,9 @@
using System;
using Microsoft.ML.Model;
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
- public abstract class BaseAverager : IBinaryOutputCombiner
+ public abstract class BaseAverager : IBinaryOutputCombiner, ICanSaveModel
{
protected readonly IHost Host;
public BaseAverager(IHostEnvironment env, string name)
@@ -30,7 +30,7 @@ protected BaseAverager(IHostEnvironment env, string name, ModelLoadContext ctx)
Host.CheckDecode(cbFloat == sizeof(Single));
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs
index 73e6f4ea7e..044ccfb44e 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs
@@ -8,7 +8,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Numeric;
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
public abstract class BaseMultiAverager : BaseMultiCombiner
{
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs
index 5e8296862b..d0ee4583b3 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs
@@ -9,9 +9,9 @@
using Microsoft.ML.Model;
using Microsoft.ML.Numeric;
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
- public abstract class BaseMultiCombiner : IMultiClassOutputCombiner
+ public abstract class BaseMultiCombiner : IMultiClassOutputCombiner, ICanSaveModel
{
protected readonly IHost Host;
@@ -49,7 +49,7 @@ internal BaseMultiCombiner(IHostEnvironment env, string name, ModelLoadContext c
Normalize = ctx.Reader.ReadBoolByte();
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
index ba75e05080..59700cb18c 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
@@ -7,7 +7,7 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
internal abstract class BaseScalarStacking : BaseStacking
{
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
index 81d0c545f9..a172a6bce4 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
@@ -13,9 +13,9 @@
using Microsoft.ML.Model;
using Microsoft.ML.Training;
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
- internal abstract class BaseStacking : IStackingTrainer
+ internal abstract class BaseStacking : IStackingTrainer, ICanSaveModel
{
public abstract class ArgumentsBase
{
@@ -67,7 +67,7 @@ private protected BaseStacking(IHostEnvironment env, string name, ModelLoadConte
CheckMeta();
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Host.Check(Meta != null, "Can't save an untrained Stacking combiner");
Host.CheckValue(ctx, nameof(ctx));
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs
index bbbbecf217..e054cd2fa8 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs
@@ -7,7 +7,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
///
/// Signature for combiners.
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs
index 3b94196eb9..88a9d73ca3 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs
@@ -4,14 +4,14 @@
using System;
using Microsoft.ML;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(Median), null, typeof(SignatureCombiner), Median.UserName, Median.LoadName)]
[assembly: LoadableClass(typeof(Median), null, typeof(SignatureLoadModel), Median.UserName, Median.LoaderSignature)]
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
///
/// Generic interface for combining outputs of multiple models
@@ -59,7 +59,7 @@ public static Median Create(IHostEnvironment env, ModelLoadContext ctx)
return new Median(env, ctx);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs
index 1ba5cdf028..3cf424b50f 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs
@@ -5,18 +5,18 @@
using System;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(MultiAverage), typeof(MultiAverage.Arguments), typeof(SignatureCombiner),
Average.UserName, MultiAverage.LoadName)]
[assembly: LoadableClass(typeof(MultiAverage), null, typeof(SignatureLoadModel), Average.UserName,
MultiAverage.LoadName, MultiAverage.LoaderSignature)]
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
- public sealed class MultiAverage : BaseMultiAverager, ICanSaveModel
+ public sealed class MultiAverage : BaseMultiAverager
{
public const string LoadName = "MultiAverage";
public const string LoaderSignature = "MultiAverageCombiner";
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs
index 477a658355..3b44413397 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs
@@ -5,21 +5,21 @@
using System;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(MultiMedian), typeof(MultiMedian.Arguments), typeof(SignatureCombiner),
Median.UserName, MultiMedian.LoadName)]
[assembly: LoadableClass(typeof(MultiMedian), null, typeof(SignatureLoadModel), Median.UserName, MultiMedian.LoaderSignature)]
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
///
/// Generic interface for combining outputs of multiple models
///
- public sealed class MultiMedian : BaseMultiCombiner, ICanSaveModel
+ public sealed class MultiMedian : BaseMultiCombiner
{
public const string LoadName = "MultiMedian";
public const string LoaderSignature = "MultiMedianCombiner";
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
index 3ec8ea85c8..01011a99d4 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
@@ -6,10 +6,10 @@
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner),
Stacking.UserName, MultiStacking.LoadName)]
@@ -17,10 +17,10 @@
[assembly: LoadableClass(typeof(MultiStacking), null, typeof(SignatureLoadModel),
Stacking.UserName, MultiStacking.LoaderSignature)]
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
using TVectorPredictor = IPredictorProducing>;
- internal sealed class MultiStacking : BaseStacking>, ICanSaveModel, IMultiClassOutputCombiner
+ internal sealed class MultiStacking : BaseStacking>, IMultiClassOutputCombiner
{
public const string LoadName = "MultiStacking";
public const string LoaderSignature = "MultiStackingCombiner";
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
index a8ba932926..f943a4176c 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
@@ -5,19 +5,19 @@
using System;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Numeric;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureCombiner), Voting.UserName, MultiVoting.LoadName)]
[assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureLoadModel), Voting.UserName, MultiVoting.LoaderSignature)]
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
// REVIEW: Why is MultiVoting based on BaseMultiCombiner? Normalizing the model outputs
// is senseless, so the base adds no real functionality.
- public sealed class MultiVoting : BaseMultiCombiner, ICanSaveModel
+ public sealed class MultiVoting : BaseMultiCombiner
{
public const string LoadName = "MultiVoting";
public const string LoaderSignature = "MultiVotingCombiner";
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs
index deef23abcd..aecc3963bc 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs
@@ -6,10 +6,10 @@
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(MultiWeightedAverage), typeof(MultiWeightedAverage.Arguments), typeof(SignatureCombiner),
MultiWeightedAverage.UserName, MultiWeightedAverage.LoadName)]
@@ -17,12 +17,12 @@
[assembly: LoadableClass(typeof(MultiWeightedAverage), null, typeof(SignatureLoadModel),
MultiWeightedAverage.UserName, MultiWeightedAverage.LoaderSignature)]
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
///
/// Generic interface for combining outputs of multiple models
///
- public sealed class MultiWeightedAverage : BaseMultiAverager, IWeightedAverager, ICanSaveModel
+ public sealed class MultiWeightedAverage : BaseMultiAverager, IWeightedAverager
{
public const string UserName = "Multi Weighted Average";
public const string LoadName = "MultiWeightedAverage";
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
index ab57cfe9aa..ae9ef67db8 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
@@ -4,10 +4,10 @@
using System;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(RegressionStacking), typeof(RegressionStacking.Arguments), typeof(SignatureCombiner),
Stacking.UserName, RegressionStacking.LoadName)]
@@ -15,11 +15,11 @@
[assembly: LoadableClass(typeof(RegressionStacking), null, typeof(SignatureLoadModel),
Stacking.UserName, RegressionStacking.LoaderSignature)]
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
using TScalarPredictor = IPredictorProducing;
- internal sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel
+ internal sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner
{
public const string LoadName = "RegressionStacking";
public const string LoaderSignature = "RegressionStacking";
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
index 891963dbea..93fe1d3240 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
@@ -5,18 +5,18 @@
using System;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(Stacking), typeof(Stacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, Stacking.LoadName)]
[assembly: LoadableClass(typeof(Stacking), null, typeof(SignatureLoadModel), Stacking.UserName, Stacking.LoaderSignature)]
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
using TScalarPredictor = IPredictorProducing;
- internal sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel
+ internal sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner
{
public const string UserName = "Stacking";
public const string LoadName = "Stacking";
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs
index b18de67b18..17c569d8ce 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs
@@ -4,14 +4,14 @@
using System;
using Microsoft.ML;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(Voting), null, typeof(SignatureCombiner), Voting.UserName, Voting.LoadName)]
[assembly: LoadableClass(typeof(Voting), null, typeof(SignatureLoadModel), Voting.UserName, Voting.LoaderSignature)]
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
public sealed class Voting : IBinaryOutputCombiner, ICanSaveModel
{
@@ -57,7 +57,7 @@ public static Voting Create(IHostEnvironment env, ModelLoadContext ctx)
return new Voting(env, ctx);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs
index ec1d2dcd11..e5aa458768 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs
@@ -6,10 +6,10 @@
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(WeightedAverage), typeof(WeightedAverage.Arguments), typeof(SignatureCombiner),
WeightedAverage.UserName, WeightedAverage.LoadName)]
@@ -17,9 +17,9 @@
[assembly: LoadableClass(typeof(WeightedAverage), null, typeof(SignatureLoadModel),
WeightedAverage.UserName, WeightedAverage.LoaderSignature)]
-namespace Microsoft.ML.Ensemble.OutputCombiners
+namespace Microsoft.ML.Trainers.Ensemble
{
- public sealed class WeightedAverage : BaseAverager, IWeightedAverager, ICanSaveModel
+ public sealed class WeightedAverage : BaseAverager, IWeightedAverager
{
public const string UserName = "Weighted Average";
public const string LoadName = "WeightedAverage";
diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
index 2e5da36c3a..4f44a42204 100644
--- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
+++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
@@ -10,18 +10,17 @@
using Microsoft.Data.DataView;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(SchemaBindablePipelineEnsembleBase), null, typeof(SignatureLoadModel),
SchemaBindablePipelineEnsembleBase.UserName, SchemaBindablePipelineEnsembleBase.LoaderSignature)]
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
///
/// This class represents an ensemble predictor, where each predictor has its own featurization pipeline. It is
@@ -482,7 +481,7 @@ protected SchemaBindablePipelineEnsembleBase(IHostEnvironment env, ModelLoadCont
_inputCols[i] = ctx.LoadNonEmptyString();
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Host.AssertValue(ctx);
ctx.SetVersionInfo(GetVersionInfo());
diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs
index b4106a7b1d..c2f40e0b1a 100644
--- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs
@@ -6,7 +6,7 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
-namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure
+namespace Microsoft.ML.Trainers.Ensemble.DiversityMeasure
{
internal abstract class BaseDisagreementDiversityMeasure : IDiversityMeasure
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs
index bb0127003f..4cb1cbd883 100644
--- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs
@@ -4,13 +4,13 @@
using System;
using Microsoft.ML;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.DiversityMeasure;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.DiversityMeasure;
[assembly: LoadableClass(typeof(DisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure),
DisagreementDiversityMeasure.UserName, DisagreementDiversityMeasure.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure
+namespace Microsoft.ML.Trainers.Ensemble.DiversityMeasure
{
internal sealed class DisagreementDiversityMeasure : BaseDisagreementDiversityMeasure, IBinaryDiversityMeasure
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs
index b182ad9963..04c9e7f2d2 100644
--- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure
+namespace Microsoft.ML.Trainers.Ensemble.DiversityMeasure
{
internal sealed class ModelDiversityMetric
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs
index ffb6839818..ccdd4d6b0b 100644
--- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs
@@ -5,14 +5,14 @@
using System;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.DiversityMeasure;
using Microsoft.ML.Numeric;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.DiversityMeasure;
[assembly: LoadableClass(typeof(MultiDisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure),
DisagreementDiversityMeasure.UserName, MultiDisagreementDiversityMeasure.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure
+namespace Microsoft.ML.Trainers.Ensemble.DiversityMeasure
{
internal sealed class MultiDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure>, IMulticlassDiversityMeasure
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs
index 2b60e44f1c..c1b37411b6 100644
--- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs
@@ -4,13 +4,13 @@
using System;
using Microsoft.ML;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.DiversityMeasure;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.DiversityMeasure;
[assembly: LoadableClass(typeof(RegressionDisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure),
DisagreementDiversityMeasure.UserName, RegressionDisagreementDiversityMeasure.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure
+namespace Microsoft.ML.Trainers.Ensemble.DiversityMeasure
{
internal sealed class RegressionDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure, IRegressionDiversityMeasure
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs
index f2ffadf877..52c0a4752b 100644
--- a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs
@@ -5,13 +5,13 @@
using System;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.FeatureSelector;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.FeatureSelector;
[assembly: LoadableClass(typeof(AllFeatureSelector), null, typeof(SignatureEnsembleFeatureSelector),
AllFeatureSelector.UserName, AllFeatureSelector.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.FeatureSelector
+namespace Microsoft.ML.Trainers.Ensemble.FeatureSelector
{
internal sealed class AllFeatureSelector : IFeatureSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs
index 93c4bd7603..5841d8c126 100644
--- a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs
@@ -7,15 +7,15 @@
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.FeatureSelector;
using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.FeatureSelector;
using Microsoft.ML.Training;
[assembly: LoadableClass(typeof(RandomFeatureSelector), typeof(RandomFeatureSelector.Arguments),
typeof(SignatureEnsembleFeatureSelector), RandomFeatureSelector.UserName, RandomFeatureSelector.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.FeatureSelector
+namespace Microsoft.ML.Trainers.Ensemble.FeatureSelector
{
internal class RandomFeatureSelector : IFeatureSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs
index 8ac30f3818..96642ccc8f 100644
--- a/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs
@@ -6,10 +6,10 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.Selector.DiversityMeasure;
using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Trainers.Ensemble.DiversityMeasure;
-namespace Microsoft.ML.Ensemble.Selector
+namespace Microsoft.ML.Trainers.Ensemble
{
internal interface IDiversityMeasure
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs
index e4eb986294..6ccc6da5d7 100644
--- a/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs
@@ -6,7 +6,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
-namespace Microsoft.ML.Ensemble.Selector
+namespace Microsoft.ML.Trainers.Ensemble
{
internal interface IFeatureSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs
index e5b35082ee..6f910f6a44 100644
--- a/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs
@@ -7,7 +7,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
-namespace Microsoft.ML.Ensemble.Selector
+namespace Microsoft.ML.Trainers.Ensemble
{
internal interface ISubModelSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs b/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs
index 6ba7002508..8ffef7aba1 100644
--- a/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs
@@ -7,7 +7,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
-namespace Microsoft.ML.Ensemble.Selector
+namespace Microsoft.ML.Trainers.Ensemble
{
internal interface ISubsetSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs
index f88df3bfee..ce62ee9180 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs
@@ -4,12 +4,12 @@
using System;
using Microsoft.ML;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.SubModelSelector;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.SubModelSelector;
[assembly: LoadableClass(typeof(AllSelector), null, typeof(SignatureEnsembleSubModelSelector), AllSelector.UserName, AllSelector.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
internal sealed class AllSelector : BaseSubModelSelector, IBinarySubModelSelector, IRegressionSubModelSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs
index 2158b05733..6905579c3b 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs
@@ -5,13 +5,13 @@
using System;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.SubModelSelector;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.SubModelSelector;
[assembly: LoadableClass(typeof(AllSelectorMultiClass), null, typeof(SignatureEnsembleSubModelSelector),
AllSelectorMultiClass.UserName, AllSelectorMultiClass.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
internal sealed class AllSelectorMultiClass : BaseSubModelSelector>, IMulticlassSubModelSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs
index 55add475f5..a13ef47b35 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs
@@ -8,7 +8,7 @@
using System.Reflection;
using Microsoft.ML.CommandLine;
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
internal abstract class BaseBestPerformanceSelector : SubModelDataSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs
index e1edba726e..8f6d783c76 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs
@@ -6,11 +6,11 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.Selector.DiversityMeasure;
using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Trainers.Ensemble.DiversityMeasure;
using Microsoft.ML.Training;
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
internal abstract class BaseDiverseSelector : SubModelDataSelector
where TDiversityMetric : class, IDiversityMeasure
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
index 8d1189d049..87a61192cd 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
@@ -8,7 +8,7 @@
using Microsoft.Data.DataView;
using Microsoft.ML.Data;
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
internal abstract class BaseSubModelSelector : ISubModelSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs
index e2a0147f01..869da7307e 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs
@@ -7,20 +7,17 @@
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.DiversityMeasure;
-using Microsoft.ML.Ensemble.Selector.SubModelSelector;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.DiversityMeasure;
+using Microsoft.ML.Trainers.Ensemble.SubModelSelector;
[assembly: LoadableClass(typeof(BestDiverseSelectorBinary), typeof(BestDiverseSelectorBinary.Arguments),
typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorBinary.UserName, BestDiverseSelectorBinary.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
- using TScalarPredictor = IPredictorProducing;
-
internal sealed class BestDiverseSelectorBinary : BaseDiverseSelector, IBinarySubModelSelector
{
public const string UserName = "Best Diverse Selector";
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs
index a25fe4d8d1..9a39b4f5b3 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs
@@ -8,20 +8,17 @@
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.DiversityMeasure;
-using Microsoft.ML.Ensemble.Selector.SubModelSelector;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.DiversityMeasure;
+using Microsoft.ML.Trainers.Ensemble.SubModelSelector;
[assembly: LoadableClass(typeof(BestDiverseSelectorMultiClass), typeof(BestDiverseSelectorMultiClass.Arguments),
typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorMultiClass.UserName, BestDiverseSelectorMultiClass.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
- using TVectorPredictor = IPredictorProducing>;
-
internal sealed class BestDiverseSelectorMultiClass : BaseDiverseSelector, IDiversityMeasure>>, IMulticlassSubModelSelector
{
public const string UserName = "Best Diverse Selector";
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs
index ef310eab7e..2022e7d057 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs
@@ -7,20 +7,17 @@
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.DiversityMeasure;
-using Microsoft.ML.Ensemble.Selector.SubModelSelector;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.DiversityMeasure;
+using Microsoft.ML.Trainers.Ensemble.SubModelSelector;
[assembly: LoadableClass(typeof(BestDiverseSelectorRegression), typeof(BestDiverseSelectorRegression.Arguments),
typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorRegression.UserName, BestDiverseSelectorRegression.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
- using TScalarPredictor = IPredictorProducing;
-
internal sealed class BestDiverseSelectorRegression : BaseDiverseSelector, IRegressionSubModelSelector
{
public const string UserName = "Best Diverse Selector";
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs
index be7b7e8f35..de6fe8874d 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs
@@ -6,15 +6,15 @@
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.SubModelSelector;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.SubModelSelector;
[assembly: LoadableClass(typeof(BestPerformanceRegressionSelector), typeof(BestPerformanceRegressionSelector.Arguments),
typeof(SignatureEnsembleSubModelSelector), BestPerformanceRegressionSelector.UserName, BestPerformanceRegressionSelector.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
internal sealed class BestPerformanceRegressionSelector : BaseBestPerformanceSelector, IRegressionSubModelSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs
index 9b24798276..f11b30556f 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs
@@ -6,15 +6,15 @@
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.SubModelSelector;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.SubModelSelector;
[assembly: LoadableClass(typeof(BestPerformanceSelector), typeof(BestPerformanceSelector.Arguments),
typeof(SignatureEnsembleSubModelSelector), BestPerformanceSelector.UserName, BestPerformanceSelector.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
internal sealed class BestPerformanceSelector : BaseBestPerformanceSelector, IBinarySubModelSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs
index 36f9635c87..34ce8719db 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs
@@ -6,15 +6,15 @@
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.SubModelSelector;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.SubModelSelector;
[assembly: LoadableClass(typeof(BestPerformanceSelectorMultiClass), typeof(BestPerformanceSelectorMultiClass.Arguments),
typeof(SignatureEnsembleSubModelSelector), BestPerformanceSelectorMultiClass.UserName, BestPerformanceSelectorMultiClass.LoadName)]
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
internal sealed class BestPerformanceSelectorMultiClass : BaseBestPerformanceSelector>, IMulticlassSubModelSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs
index 0ba6497f25..ff8dd74354 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs
@@ -6,7 +6,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Internallearn;
-namespace Microsoft.ML.Ensemble.Selector.SubModelSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector
{
internal abstract class SubModelDataSelector : BaseSubModelSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs
index 43203e9ca9..1c1c15c3cd 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs
@@ -5,16 +5,16 @@
using System;
using System.Collections.Generic;
using Microsoft.ML;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.SubsetSelector;
using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.SubsetSelector;
[assembly: LoadableClass(typeof(AllInstanceSelector), typeof(AllInstanceSelector.Arguments),
typeof(SignatureEnsembleDataSelector), AllInstanceSelector.UserName, AllInstanceSelector.LoadName)]
[assembly: EntryPointModule(typeof(AllInstanceSelector))]
-namespace Microsoft.ML.Ensemble.Selector.SubsetSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubsetSelector
{
internal sealed class AllInstanceSelector : BaseSubsetSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs
index a17ea20d88..3305574d36 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs
@@ -6,10 +6,9 @@
using System.Collections.Generic;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.EntryPoints;
using Microsoft.ML.Transforms;
-namespace Microsoft.ML.Ensemble.Selector.SubsetSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubsetSelector
{
internal abstract class BaseSubsetSelector : ISubsetSelector
where TArgs : BaseSubsetSelector.ArgumentsBase
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs
index c3a98e47c8..8006ba2195 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs
@@ -6,9 +6,9 @@
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.SubsetSelector;
using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.SubsetSelector;
using Microsoft.ML.Transforms;
[assembly: LoadableClass(typeof(BootstrapSelector), typeof(BootstrapSelector.Arguments),
@@ -16,7 +16,7 @@
[assembly: EntryPointModule(typeof(BootstrapSelector))]
-namespace Microsoft.ML.Ensemble.Selector.SubsetSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubsetSelector
{
internal sealed class BootstrapSelector : BaseSubsetSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs
index a142189808..a7f52e6b7b 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs
@@ -6,9 +6,9 @@
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.SubsetSelector;
using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.Ensemble.SubsetSelector;
using Microsoft.ML.Transforms;
[assembly: LoadableClass(typeof(RandomPartitionSelector), typeof(RandomPartitionSelector.Arguments),
@@ -16,7 +16,7 @@
[assembly: EntryPointModule(typeof(RandomPartitionSelector))]
-namespace Microsoft.ML.Ensemble.Selector.SubsetSelector
+namespace Microsoft.ML.Trainers.Ensemble.SubsetSelector
{
internal sealed class RandomPartitionSelector : BaseSubsetSelector
{
diff --git a/src/Microsoft.ML.Ensemble/Subset.cs b/src/Microsoft.ML.Ensemble/Subset.cs
index 743be33df6..e1e579fd7b 100644
--- a/src/Microsoft.ML.Ensemble/Subset.cs
+++ b/src/Microsoft.ML.Ensemble/Subset.cs
@@ -5,7 +5,7 @@
using System.Collections;
using Microsoft.ML.Data;
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
internal sealed class Subset
{
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
index 83637c333f..1e734c875c 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
@@ -7,12 +7,8 @@
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
-using Microsoft.ML.Ensemble;
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Ensemble.OutputCombiners;
-using Microsoft.ML.Ensemble.Selector;
using Microsoft.ML.Internal.Internallearn;
-using Microsoft.ML.Learners;
+using Microsoft.ML.Trainers.Ensemble;
using Microsoft.ML.Trainers.Online;
using Microsoft.ML.Training;
@@ -23,7 +19,7 @@
[assembly: LoadableClass(typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments), typeof(SignatureModelCombiner),
"Binary Classification Ensemble Model Combiner", EnsembleTrainer.LoadNameValue, "pe", "ParallelEnsemble")]
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
using TDistPredictor = IDistPredictorProducing;
using TScalarPredictor = IPredictorProducing;
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs
index e63f2059e5..d425db8e4f 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs
@@ -9,16 +9,15 @@
using Microsoft.Data.DataView;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
// These are for deserialization from a model repository.
[assembly: LoadableClass(typeof(EnsembleDistributionModelParameters), null, typeof(SignatureLoadModel),
EnsembleDistributionModelParameters.UserName, EnsembleDistributionModelParameters.LoaderSignature)]
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
using TDistPredictor = IDistPredictorProducing;
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs
index 032ff39695..9eb56d5bfd 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs
@@ -7,17 +7,16 @@
using Microsoft.Data.DataView;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(EnsembleModelParameters), null, typeof(SignatureLoadModel), EnsembleModelParameters.UserName,
EnsembleModelParameters.LoaderSignature)]
[assembly: EntryPointModule(typeof(EnsembleModelParameters))]
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
///
/// A class for artifacts of ensembled models.
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs
index 5e79bbe41e..f8568cc2d7 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs
@@ -6,12 +6,12 @@
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
public abstract class EnsembleModelParametersBase : ModelParametersBase,
IPredictorProducing, ICanSaveInTextFormat, ICanSaveSummary
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
index ef7ebc6d4d..b9aad6228f 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
@@ -8,15 +8,13 @@
using System.Threading.Tasks;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble.OutputCombiners;
-using Microsoft.ML.Ensemble.Selector;
-using Microsoft.ML.Ensemble.Selector.SubsetSelector;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Trainers.Ensemble.SubsetSelector;
using Microsoft.ML.Training;
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
using Stopwatch = System.Diagnostics.Stopwatch;
diff --git a/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs b/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs
index 85f55d8111..48af1cfcfa 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs
@@ -4,7 +4,7 @@
using System.Collections.Generic;
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
public delegate void SignatureModelCombiner(PredictionKind kind);
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs
index 134e58d51d..163b256f1e 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs
@@ -7,17 +7,14 @@
using Microsoft.Data.DataView;
using Microsoft.ML;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble;
-using Microsoft.ML.Ensemble.OutputCombiners;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(EnsembleMultiClassModelParameters), null, typeof(SignatureLoadModel),
EnsembleMultiClassModelParameters.UserName, EnsembleMultiClassModelParameters.LoaderSignature)]
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
- using TVectorPredictor = IPredictorProducing>;
-
public sealed class EnsembleMultiClassModelParameters : EnsembleModelParametersBase>, IValueMapper
{
internal const string UserName = "Ensemble Multiclass Executor";
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
index d194dfc947..4420b15ba9 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
@@ -8,12 +8,8 @@
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble;
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Ensemble.OutputCombiners;
-using Microsoft.ML.Ensemble.Selector;
using Microsoft.ML.Internal.Internallearn;
-using Microsoft.ML.Learners;
+using Microsoft.ML.Trainers.Ensemble;
using Microsoft.ML.Training;
[assembly: LoadableClass(MulticlassDataPartitionEnsembleTrainer.Summary, typeof(MulticlassDataPartitionEnsembleTrainer),
@@ -25,7 +21,7 @@
[assembly: LoadableClass(typeof(MulticlassDataPartitionEnsembleTrainer), typeof(MulticlassDataPartitionEnsembleTrainer.Arguments),
typeof(SignatureModelCombiner), "Multiclass Classification Ensemble Model Combiner", MulticlassDataPartitionEnsembleTrainer.LoadNameValue)]
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
using TVectorPredictor = IPredictorProducing>;
///
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
index 569d4cd829..e3b7976d51 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
@@ -7,13 +7,8 @@
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
-using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble;
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Ensemble.OutputCombiners;
-using Microsoft.ML.Ensemble.Selector;
using Microsoft.ML.Internal.Internallearn;
-using Microsoft.ML.Learners;
+using Microsoft.ML.Trainers.Ensemble;
using Microsoft.ML.Trainers.Online;
using Microsoft.ML.Training;
@@ -25,7 +20,7 @@
[assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments), typeof(SignatureModelCombiner),
"Regression Ensemble Model Combiner", RegressionEnsembleTrainer.LoadNameValue)]
-namespace Microsoft.ML.Ensemble
+namespace Microsoft.ML.Trainers.Ensemble
{
using TScalarPredictor = IPredictorProducing;
internal sealed class RegressionEnsembleTrainer : EnsembleTrainerBase ComputeTests(double[] scores)
- {
- List result = new List()
- {
- new TestResult("DominationLoss", ComputeDominationLoss(scores), 1, true, TestResult.ValueOperator.Sum),
- };
-
- return result;
- }
- }
-#endif
-}
diff --git a/src/Microsoft.ML.FastTree/Application/LogLossApplication.cs b/src/Microsoft.ML.FastTree/Application/LogLossApplication.cs
deleted file mode 100644
index 9dd896a401..0000000000
--- a/src/Microsoft.ML.FastTree/Application/LogLossApplication.cs
+++ /dev/null
@@ -1,169 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-namespace Microsoft.ML.Trainers.FastTree.Internal
-{
-#if OLD_DATALOAD
- public class LogLossCommandLineArgs : TrainingCommandLineArgs
- {
- public enum LogLossMode { Pairwise, Wholepage };
- [Argument(ArgumentType.LastOccurenceWins, HelpText = "Which style of log loss to use", ShortName = "llm")]
- public LogLossMode loglossmode = LogLossMode.Pairwise;
- [Argument(ArgumentType.LastOccurenceWins, HelpText = "log loss cefficient", ShortName = "llc")]
- public double loglosscoef = 1.0;
- }
-
- public class LogLossTrainingApplication : ApplicationBase
- {
- new LogLossCommandLineArgs cmd;
- private const string RegistrationName = "LogLossApplication";
-
- public LogLossTrainingApplication(IHostEnvironment env, string args, TrainingApplicationData data)
- : base(env, RegistrationName, args, data)
- {
- base.cmd = this.cmd = new LogLossCommandLineArgs();
- }
-
- public override ObjectiveFunction ConstructObjFunc()
- {
- return new LogLossObjectiveFunction(TrainSet, cmd);
- }
-
- public override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
- {
- OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
- // optimizationAlgorithm.AdjustTreeOutputsOverride = new NoOutputOptimization();
- var lossCalculator = new LogLossTest(optimizationAlgorithm.TrainingScores, cmd.loglosscoef);
- int lossIndex = (cmd.loglossmode == LogLossCommandLineArgs.LogLossMode.Pairwise) ? 0 : 1;
-
- optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, lossIndex, cmd.numPostBracketSteps, cmd.minStepSize);
-
- return optimizationAlgorithm;
- }
-
- protected override void PrepareLabels(IChannel ch)
- {
- }
-
- protected override Test ConstructTestForTrainingData()
- {
- return new LogLossTest(ConstructScoreTracker(TrainSet), cmd.loglosscoef);
- }
-
- protected override void InitializeTests()
- {
- Tests.Add(new LogLossTest(ConstructScoreTracker(TrainSet), cmd.loglosscoef));
- if (ValidSet != null)
- Tests.Add(new LogLossTest(ConstructScoreTracker(ValidSet), cmd.loglosscoef));
-
- if (TestSets != null)
- {
- for (int t = 0; t < TestSets.Length; ++t)
- {
- Tests.Add(new LogLossTest(ConstructScoreTracker(TestSets[t]), cmd.loglosscoef));
- }
- }
- }
- }
-
- public class LogLossObjectiveFunction : RankingObjectiveFunction
- {
- private LogLossCommandLineArgs.LogLossMode _mode;
- private double _coef = 1.0;
-
- public LogLossObjectiveFunction(Dataset trainSet, LogLossCommandLineArgs cmd)
- : base(trainSet, trainSet.Ratings, cmd)
- {
- _mode = cmd.loglossmode;
- _coef = cmd.loglosscoef;
- }
-
- protected override void GetGradientInOneQuery(int query, int threadIndex)
- {
- int begin = Dataset.Boundaries[query];
- int end = Dataset.Boundaries[query + 1];
- short[] labels = Dataset.Ratings;
-
- if (end - begin <= 1)
- return;
- Array.Clear(_gradient, begin, end - begin);
-
- for (int d1 = begin; d1 < end - 1; ++d1)
- {
- int stop = (_mode == LogLossCommandLineArgs.LogLossMode.Pairwise) ? d1 + 2 : end;
- for (int d2 = d1 + 1; d2 < stop; ++d2)
- {
- short labelDiff = (short)(labels[d1] - labels[d2]);
- if (labelDiff == 0)
- continue;
- double delta = (_coef * labelDiff) / (1.0 + Math.Exp(_coef * labelDiff * (_scores[d1] - _scores[d2])));
-
- _gradient[d1] += delta;
- _gradient[d2] -= delta;
- }
- }
- }
- }
-
- public class LogLossTest : Test
- {
- protected double _coef;
- public LogLossTest(ScoreTracker scoreTracker, double coef)
- : base(scoreTracker)
- {
- _coef = coef;
- }
-
- public override IEnumerable ComputeTests(double[] scores)
- {
- Object _lock = new Object();
- double pairedLoss = 0.0;
- double allPairLoss = 0.0;
- short maxLabel = 0;
- short minLabel = 10000;
-
- for (int query = 0; query < Dataset.Boundaries.Length - 1; query++)
- {
- int start = Dataset.Boundaries[query];
- int length = Dataset.Boundaries[query + 1] - start;
- for (int i = start; i < start + length; i++)
- for (int j = i + 1; j < start + length; j++)
- allPairLoss += Math.Log((1.0 + Math.Exp(-_coef * (Dataset.Ratings[i] - Dataset.Ratings[j]) * (scores[i] - scores[j]))));
- //allPairLoss += Math.Max(0.0, _coef - (Dataset.Ratings[i] - Dataset.Ratings[j]) * (scores[i] - scores[j]));
-
- for (int i = start; i < start + length - 1; i++)
- {
- pairedLoss += Math.Log((1.0 + Math.Exp(-_coef * (Dataset.Ratings[i] - Dataset.Ratings[i + 1]) * (scores[i] - scores[i + 1]))));
- // pairedLoss += Math.Max(0.0, _coef - (Dataset.Ratings[i] - Dataset.Ratings[i + 1]) * (scores[i] - scores[i + 1]));
- }
- }
-
- for (int i = 0; i < Dataset.Ratings.Length; i++)
- {
- if (Dataset.Ratings[i] > maxLabel)
- maxLabel = Dataset.Ratings[i];
- if (Dataset.Ratings[i] < minLabel)
- minLabel = Dataset.Ratings[i];
- }
- List result = new List()
- {
- new TestResult("paired loss", pairedLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average),
- new TestResult("all pairs loss", allPairLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average),
- new TestResult("coefficient", _coef, 1, true, TestResult.ValueOperator.Constant),
- new TestResult("max Label", maxLabel, 1, false, TestResult.ValueOperator.Max),
- new TestResult("min Label", minLabel, 1, true, TestResult.ValueOperator.Min),
- };
-
- return result;
- }
- }
-
- public class NoOutputOptimization : IStepSearch
- {
- public NoOutputOptimization() { }
-
- public void AdjustTreeOutputs(IChannel ch, RegressionTree tree, DocumentPartitioning partitioning, ScoreTracker trainingScores) { }
- }
-#endif
-}
diff --git a/src/Microsoft.ML.FastTree/Application/SizeAdjustedLogLossApplication.cs b/src/Microsoft.ML.FastTree/Application/SizeAdjustedLogLossApplication.cs
deleted file mode 100644
index 0cb8c9496f..0000000000
--- a/src/Microsoft.ML.FastTree/Application/SizeAdjustedLogLossApplication.cs
+++ /dev/null
@@ -1,388 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-namespace Microsoft.ML.Trainers.FastTree.Internal
-{
-#if OLD_DATALOAD
- public class SizeAdjustedLogLossCommandLineArgs : TrainingCommandLineArgs
- {
- public enum LogLossMode
- {
- Pairwise,
- Wholepage
- };
-
- public enum CostFunctionMode
- {
- SizeAdjustedWinratePredictor,
- SizeAdjustedPageOrdering
- };
-
- [Argument(ArgumentType.LastOccurenceWins, HelpText = "Which style of log loss to use", ShortName = "llm")]
- public LogLossMode loglossmode = LogLossMode.Pairwise;
-
- [Argument(ArgumentType.LastOccurenceWins, HelpText = "Which style of cost function to use", ShortName = "cfm")]
- public CostFunctionMode costFunctionMode = CostFunctionMode.SizeAdjustedPageOrdering;
-
- // REVIEW: If we ever want to expose this application in TLC, the natural thing would be for these to
- // be loaded from a column from the input data view. However I'll keep them around for now, so that when we
- // do migrate there's some clear idea of how it is to be done.
- [Argument(ArgumentType.AtMostOnce, HelpText = "tab seperated file which contains the max and min score from the model", ShortName = "srange")]
- public string scoreRangeFileName = null;
-
- [Argument(ArgumentType.MultipleUnique, HelpText = "TSV filename with float size values associated with train bin file", ShortName = "trsl")]
- public string[] trainSizeLabelFilenames = null;
-
- [Argument(ArgumentType.LastOccurenceWins, HelpText = "TSV filename with float size values associated with validation bin file", ShortName = "vasl")]
- public string validSizeLabelFilename = null;
-
- [Argument(ArgumentType.MultipleUnique, HelpText = "TSV filename with float size values associated with test bin file", ShortName = "tesl")]
- public string[] testSizeLabelFilenames = null;
-
- [Argument(ArgumentType.LastOccurenceWins, HelpText = "log loss cefficient", ShortName = "llc")]
- public double loglosscoef = 1.0;
-
- }
-
- public class SizeAdjustedLogLossUtil
- {
- public const string sizeLabelName = "size";
-
- public static float[] GetSizeLabels(Dataset set)
- {
- if (set == null)
- {
- return null;
- }
- float[] labels = set.Skeleton.GetData(sizeLabelName);
- if (labels == null)
- {
- labels = set.Ratings.Select(x => (float)x).ToArray();
- }
- return labels;
- }
- }
-
- public class SizeAdjustedLogLossTrainingApplication : ApplicationBase
- {
- new SizeAdjustedLogLossCommandLineArgs cmd;
- float[] trainSetSizeLabels;
- private const string RegistrationName = "SizeAdjustedLogLossApplication";
-
- public SizeAdjustedLogLossTrainingApplication(IHostEnvironment env, string args, TrainingApplicationData data)
- : base(env, RegistrationName, args, data)
- {
- base.cmd = this.cmd = new SizeAdjustedLogLossCommandLineArgs();
- }
-
- public override ObjectiveFunction ConstructObjFunc()
- {
- return new SizeAdjustedLogLossObjectiveFunction(TrainSet, cmd);
- }
-
- public override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
- {
- OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
- // optimizationAlgorithm.AdjustTreeOutputsOverride = new NoOutputOptimization(); // For testing purposes - this will not use line search and thus the scores won't be scaled.
- var lossCalculator = new SizeAdjustedLogLossTest(optimizationAlgorithm.TrainingScores, cmd.scoreRangeFileName, trainSetSizeLabels, cmd.loglosscoef);
-
- // The index of the label signifies which index from TestResult would be used as a loss. For every query, we compute both wholepage and pairwise loss with the two cost function modes, this index lets us pick the appropriate one.
- int lossIndex = 0;
- if (cmd.loglossmode == SizeAdjustedLogLossCommandLineArgs.LogLossMode.Wholepage && cmd.costFunctionMode == SizeAdjustedLogLossCommandLineArgs.CostFunctionMode.SizeAdjustedPageOrdering)
- {
- lossIndex = 1;
- }
- else if (cmd.loglossmode == SizeAdjustedLogLossCommandLineArgs.LogLossMode.Pairwise && cmd.costFunctionMode == SizeAdjustedLogLossCommandLineArgs.CostFunctionMode.SizeAdjustedWinratePredictor)
- {
- lossIndex = 2;
- }
- else if (cmd.loglossmode == SizeAdjustedLogLossCommandLineArgs.LogLossMode.Wholepage && cmd.costFunctionMode == SizeAdjustedLogLossCommandLineArgs.CostFunctionMode.SizeAdjustedWinratePredictor)
- {
- lossIndex = 3;
- }
-
- optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, lossIndex, cmd.numPostBracketSteps, cmd.minStepSize);
-
- return optimizationAlgorithm;
- }
-
- private static IEnumerable LoadSizeLabels(string filename)
- {
- using (StreamReader reader = new StreamReader(new FileStream(filename, FileMode.Open, FileAccess.Read, FileShare.Read)))
- {
- string line = reader.ReadLine();
- Contracts.Check(line != null && line.Trim() == "m:Size", "Regression label file should contain only one column m:Size");
- while ((line = reader.ReadLine()) != null)
- {
- float val = float.Parse(line.Trim(), CultureInfo.InvariantCulture);
- yield return val;
- }
- }
- }
-
- protected override void PrepareLabels(IChannel ch)
- {
- trainSetSizeLabels = SizeAdjustedLogLossUtil.GetSizeLabels(TrainSet);
- }
-
- protected override Test ConstructTestForTrainingData()
- {
- return new SizeAdjustedLogLossTest(ConstructScoreTracker(TrainSet), cmd.scoreRangeFileName, SizeAdjustedLogLossUtil.GetSizeLabels(TrainSet), cmd.loglosscoef);
- }
-
- protected override void ProcessBinFile(int trainValidTest, int index, DatasetBinFile bin)
- {
- string labelPath = null;
- switch (trainValidTest)
- {
- case 0:
- if (cmd.trainSizeLabelFilenames != null && index < cmd.trainSetFilenames.Length)
- {
- labelPath = cmd.trainSizeLabelFilenames[index];
- }
- break;
- case 1:
- if (cmd.validSizeLabelFilename != null)
- {
- labelPath = cmd.validSizeLabelFilename;
- }
- break;
- case 2:
- if (cmd.testSizeLabelFilenames != null && index < cmd.testSizeLabelFilenames.Length)
- {
- labelPath = cmd.testSizeLabelFilenames[index];
- }
- break;
- }
- // If we have no labels, return.
- if (labelPath == null)
- {
- return;
- }
- float[] labels = LoadSizeLabels(labelPath).ToArray();
- bin.DatasetSkeleton.SetData(SizeAdjustedLogLossUtil.sizeLabelName, labels, false);
- }
-
- protected override void InitializeTests()
- {
- Tests.Add(new SizeAdjustedLogLossTest(ConstructScoreTracker(TrainSet), cmd.scoreRangeFileName, SizeAdjustedLogLossUtil.GetSizeLabels(TrainSet), cmd.loglosscoef));
- if (ValidSet != null)
- {
- Tests.Add(new SizeAdjustedLogLossTest(ConstructScoreTracker(ValidSet), cmd.scoreRangeFileName, SizeAdjustedLogLossUtil.GetSizeLabels(ValidSet), cmd.loglosscoef));
- }
-
- if (TestSets != null && TestSets.Length > 0)
- {
- for (int t = 0; t < TestSets.Length; ++t)
- {
- Tests.Add(new SizeAdjustedLogLossTest(ConstructScoreTracker(TestSets[t]), cmd.scoreRangeFileName, SizeAdjustedLogLossUtil.GetSizeLabels(TestSets[t]), cmd.loglosscoef));
- }
- }
- }
-
- public override void PrintIterationMessage(IChannel ch, IProgressChannel pch)
- {
- base.PrintIterationMessage(ch, pch);
- }
- }
-
- public class SizeAdjustedLogLossObjectiveFunction : RankingObjectiveFunction
- {
- private SizeAdjustedLogLossCommandLineArgs.LogLossMode _mode;
- private SizeAdjustedLogLossCommandLineArgs.CostFunctionMode _algo;
- private double _llc;
-
- public SizeAdjustedLogLossObjectiveFunction(Dataset trainSet, SizeAdjustedLogLossCommandLineArgs cmd)
- : base(trainSet, trainSet.Ratings, cmd)
- {
- _mode = cmd.loglossmode;
- _algo = cmd.costFunctionMode;
- _llc = cmd.loglosscoef;
- }
-
- protected override void GetGradientInOneQuery(int query, int threadIndex)
- {
- int begin = Dataset.Boundaries[query];
- int end = Dataset.Boundaries[query + 1];
- short[] labels = Dataset.Ratings;
- float[] sizes = SizeAdjustedLogLossUtil.GetSizeLabels(Dataset);
-
- Contracts.Check(Dataset.NumDocs == sizes.Length, "Mismatch between dataset and labels");
-
- if (end - begin <= 1)
- {
- return;
- }
- Array.Clear(_gradient, begin, end - begin);
-
- for (int d1 = begin; d1 < end - 1; ++d1)
- {
- for (int d2 = d1 + 1; d2 < end; ++d2)
- {
- float size = sizes[d1];
-
- //Compute Lij
- float sizeAdjustedLoss = 0.0F;
- for (int d3 = d2; d3 < end; ++d3)
- {
- size -= sizes[d3];
- if (size >= 0.0F && labels[d3] > 0)
- {
- sizeAdjustedLoss = 1.0F;
- }
- else if (size < 0.0F && labels[d3] > 0)
- {
- sizeAdjustedLoss = (1.0F + (size / sizes[d3]));
- }
-
- if (size <= 0.0F || sizeAdjustedLoss > 0.0F)
- {
- // Exit condition- we have reached size or size adjusted loss is already populated.
- break;
- }
- }
-
- double scoreDiff = _scores[d1] - _scores[d2];
- float labelDiff = ((float)labels[d1] - sizeAdjustedLoss);
- double delta = 0.0;
- if (_algo == SizeAdjustedLogLossCommandLineArgs.CostFunctionMode.SizeAdjustedPageOrdering)
- {
- delta = (_llc * labelDiff) / (1.0 + Math.Exp(_llc * labelDiff * scoreDiff));
- }
- else
- {
- delta = (double)labels[d1] - ((double)(labels[d1] + sizeAdjustedLoss) / (1.0 + Math.Exp(-scoreDiff)));
- }
-
- _gradient[d1] += delta;
- _gradient[d2] -= delta;
-
- if (_mode == SizeAdjustedLogLossCommandLineArgs.LogLossMode.Pairwise)
- {
- break;
- }
- }
- }
- }
- }
-
- public class SizeAdjustedLogLossTest : Test
- {
- protected string _scoreRangeFileName = null;
- static double maxScore = double.MinValue;
- static double minScore = double.MaxValue;
- private float[] _sizeLabels;
- private double _llc;
-
- public SizeAdjustedLogLossTest(ScoreTracker scoreTracker, string scoreRangeFileName, float[] sizeLabels, double loglossCoeff)
- : base(scoreTracker)
- {
- Contracts.Check(scoreTracker.Dataset.NumDocs == sizeLabels.Length, "Mismatch between dataset and labels");
- _sizeLabels = sizeLabels;
- _scoreRangeFileName = scoreRangeFileName;
- _llc = loglossCoeff;
- }
-
- public override IEnumerable ComputeTests(double[] scores)
- {
- Object _lock = new Object();
- double pairedPageOrderingLoss = 0.0;
- double allPairPageOrderingLoss = 0.0;
- double pairedSAWRPredictLoss = 0.0;
- double allPairSAWRPredictLoss = 0.0;
-
- short maxLabel = 0;
- short minLabel = 10000;
- for (int query = 0; query < Dataset.Boundaries.Length - 1; ++query)
- {
- int begin = Dataset.Boundaries[query];
- int end = Dataset.Boundaries[query + 1];
-
- if (end - begin <= 1)
- {
- continue;
- }
-
- for (int d1 = begin; d1 < end - 1; ++d1)
- {
- bool firstTime = false;
- for (int d2 = d1 + 1; d2 < end; ++d2)
- {
- float size = _sizeLabels[d1];
-
- //Compute Lij
- float sizeAdjustedLoss = 0.0F;
- for (int d3 = d2; d3 < end; ++d3)
- {
- size -= _sizeLabels[d3];
- if (size >= 0.0F && Dataset.Ratings[d3] > 0)
- {
- sizeAdjustedLoss = 1.0F;
- }
- else if (size < 0.0F && Dataset.Ratings[d3] > 0)
- {
- sizeAdjustedLoss = (1.0F + (size / _sizeLabels[d3]));
- }
-
- if (size <= 0.0F || sizeAdjustedLoss > 0.0F)
- {
- // Exit condition- we have reached size or size adjusted loss is already populated.
- break;
- }
- }
- //Compute page ordering loss
- double scoreDiff = scores[d1] - scores[d2];
- float labelDiff = ((float)Dataset.Ratings[d1] - sizeAdjustedLoss);
- double pageOrderingLoss = Math.Log(1.0 + Math.Exp(-_llc * labelDiff * scoreDiff));
-
- // Compute SAWR predict loss
- double sawrPredictLoss = (double)((Dataset.Ratings[d1] + sizeAdjustedLoss) * Math.Log(1.0 + Math.Exp(scoreDiff))) - ((double)Dataset.Ratings[d1] * scoreDiff);
- if (!firstTime)
- {
- pairedPageOrderingLoss += pageOrderingLoss;
- pairedSAWRPredictLoss += sawrPredictLoss;
- firstTime = true;
- }
- allPairPageOrderingLoss += pageOrderingLoss;
- allPairSAWRPredictLoss += sawrPredictLoss;
- }
- }
- }
-
- for (int i = 0; i < Dataset.Ratings.Length; i++)
- {
- if (Dataset.Ratings[i] > maxLabel)
- maxLabel = Dataset.Ratings[i];
- if (Dataset.Ratings[i] < minLabel)
- minLabel = Dataset.Ratings[i];
- if (scores[i] > maxScore)
- maxScore = scores[i];
- if (scores[i] < minScore)
- minScore = scores[i];
- }
-
- if (_scoreRangeFileName != null)
- {
- using (StreamWriter sw = File.CreateText(_scoreRangeFileName))
- {
- sw.WriteLine(string.Format("{0}\t{1}", minScore, maxScore));
- }
- }
-
- List result = new List()
- {
- // The index of the label signifies which index from TestResult would be used as a loss. For every query, we compute both wholepage and pairwise loss with the two cost function modes, this index lets us pick the appropriate one.
- new TestResult("page ordering paired loss", pairedPageOrderingLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average),
- new TestResult("page ordering all pairs loss", allPairPageOrderingLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average),
- new TestResult("SAWR predict paired loss", pairedSAWRPredictLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average),
- new TestResult("SAWR predict all pairs loss", allPairSAWRPredictLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average),
- new TestResult("max Label", maxLabel, 1, false, TestResult.ValueOperator.Max),
- new TestResult("min Label", minLabel, 1, true, TestResult.ValueOperator.Min),
- };
-
- return result;
- }
- }
-#endif
-}
diff --git a/src/Microsoft.ML.FastTree/Application/WinLossSurplusApplication.cs b/src/Microsoft.ML.FastTree/Application/WinLossSurplusApplication.cs
deleted file mode 100644
index 1708033161..0000000000
--- a/src/Microsoft.ML.FastTree/Application/WinLossSurplusApplication.cs
+++ /dev/null
@@ -1,181 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-namespace Microsoft.ML.Trainers.FastTree.Internal
-{
-#if OLD_DATALOAD
- public class WinLossSurplusCommandLineArgs : TrainingCommandLineArgs
- {
- [Argument(ArgumentType.AtMostOnce, HelpText = "Scaling Factor for win loss surplus", ShortName = "wls")]
- public double winlossScaleFactor = 1.0;
- }
-
- public class WinLossSurplusTrainingApplication : RankingApplication
- {
- new WinLossSurplusCommandLineArgs cmd;
- private const string RegistrationName = "WinLossSurplusApplication";
-
- public WinLossSurplusTrainingApplication(IHostEnvironment env, string args, TrainingApplicationData data)
- : base(env, RegistrationName, args, data)
- {
- base.cmd = this.cmd = new WinLossSurplusCommandLineArgs();
- }
-
- public override ObjectiveFunction ConstructObjFunc()
- {
- return new WinLossSurplusObjectiveFunction(TrainSet, TrainSet.Ratings, cmd);
- }
-
- public override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
- {
- OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
- var lossCalculator = new WinLossSurplusTest(optimizationAlgorithm.TrainingScores, TrainSet.Ratings, cmd.sortingAlgorithm, cmd.winlossScaleFactor);
- optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 0, cmd.numPostBracketSteps, cmd.minStepSize);
- return optimizationAlgorithm;
- }
-
- protected override Test ConstructTestForTrainingData()
- {
- return new WinLossSurplusTest(ConstructScoreTracker(TrainSet), TrainSet.Ratings, cmd.sortingAlgorithm, cmd.winlossScaleFactor);
- }
-
- public override void PrintIterationMessage(IChannel ch, IProgressChannel pch)
- {
- // REVIEW: Shift this to use progress channels.
-#if OLD_TRACING
- if (PruningTest != null)
- {
- if (PruningTest is TestWindowWithTolerance)
- {
- if (PruningTest.BestIteration != -1)
- ch.Info("Iteration {0} \t(Best tolerated validation moving average WinLossSurplus {1}:{2:00.00}~{3:00.00})",
- Ensemble.NumTrees,
- PruningTest.BestIteration,
- (PruningTest as TestWindowWithTolerance).BestAverageValue,
- (PruningTest as TestWindowWithTolerance).CurrentAverageValue);
- else
- ch.Info("Iteration {0}", Ensemble.NumTrees);
- }
- else
- {
- ch.Info("Iteration {0} \t(best validation WinLoss {1}:{2:00.00}>{3:00.00})",
- Ensemble.NumTrees,
- PruningTest.BestIteration,
- PruningTest.BestResult.FinalValue,
- PruningTest.ComputeTests().First().FinalValue);
- }
- }
- else
- base.PrintIterationMessage(ch, pch);
-#else
- base.PrintIterationMessage(ch, pch);
-#endif
- }
-
- protected override Test CreateStandardTest(Dataset dataset)
- {
- return new WinLossSurplusTest(
- ConstructScoreTracker(dataset),
- dataset.Ratings,
- cmd.sortingAlgorithm,
- cmd.winlossScaleFactor);
- }
-
- protected override Test CreateSpecialTrainSetTest()
- {
- return CreateStandardTest(TrainSet);
- }
-
- protected override Test CreateSpecialValidSetTest()
- {
- return CreateStandardTest(ValidSet);
- }
-
- protected override string GetTestGraphHeader()
- {
- return "Eval:\tFileName\tMaxSurplus\tSurplus@100\tSurplus@200\tSurplus@300\tSurplus@400\tSurplus@500\tSurplus@1000\tMaxSurplusPos\tPercentTop\n";
- }
- }
-
- public class WinLossSurplusObjectiveFunction : LambdaRankObjectiveFunction
- {
- public WinLossSurplusObjectiveFunction(Dataset trainSet, short[] labels, WinLossSurplusCommandLineArgs cmd)
- : base(trainSet, labels, cmd)
- {
- }
-
- protected override void GetGradientInOneQuery(int query, int threadIndex)
- {
- int begin = Dataset.Boundaries[query];
- int numDocuments = Dataset.Boundaries[query + 1] - Dataset.Boundaries[query];
-
- Array.Clear(_gradient, begin, numDocuments);
- Array.Clear(_weights, begin, numDocuments);
-
- double inverseMaxDCG = _inverseMaxDCGT[query];
-
- int[] permutation = _permutationBuffers[threadIndex];
-
- short[] labels = Labels;
- double[] scoresToUse = _scores;
-
- // Keep track of top 3 labels for later use
- //GetTopQueryLabels(query, permutation, false);
- unsafe
- {
- fixed (int* pPermutation = permutation)
- fixed (short* pLabels = labels)
- fixed (double* pScores = scoresToUse, pLambdas = _gradient, pWeights = _weights, pDiscount = _discount)
- fixed (double* pGain = _gain, pGainLabels = _gainLabels, pSigmoidTable = _sigmoidTable)
- fixed (int* pOneTwoThree = _oneTwoThree)
- {
- // calculates the permutation that orders "scores" in descending order, without modifying "scores"
- Array.Copy(_oneTwoThree, permutation, numDocuments);
-#if USE_FASTTREENATIVE
- double lambdaSum = 0;
-
- C_Sort(pPermutation, &pScores[begin], &pLabels[begin], numDocuments);
-
- // Keep track of top 3 labels for later use
- GetTopQueryLabels(query, permutation, true);
-
- int numActualResults = numDocuments;
-
- C_GetSurplusDerivatives(numDocuments, begin, pPermutation, pLabels,
- pScores, pLambdas, pWeights, pDiscount,
- pGainLabels,
- pSigmoidTable, _minScore, _maxScore, _sigmoidTable.Length, _scoreToSigmoidTableFactor,
- _costFunctionParam, _distanceWeight2, &lambdaSum, double.MinValue);
-
- if (_normalizeQueryLambdas)
- {
- if (lambdaSum > 0)
- {
- double normFactor = (10 * Math.Log(1 + lambdaSum)) / lambdaSum;
-
- for (int i = begin; i < begin + numDocuments; ++i)
- {
- pLambdas[i] = pLambdas[i] * normFactor;
- pWeights[i] = pWeights[i] * normFactor;
- }
- }
- }
-#else
- throw new Exception("Shifted NDCG / ContinuousWeightedRanknet / WinLossSurplus / distanceWeight2 / normalized lambdas are only supported by unmanaged code");
-#endif
- }
- }
- }
-
- [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)]
- private unsafe static extern void C_GetSurplusDerivatives(int numDocuments, int begin, int* pPermutation, short* pLabels,
- double* pScores, double* pLambdas, double* pWeights, double* pDiscount,
- double* pGainLabels, double* lambdaTable, double minScore, double maxScore,
- int lambdaTableLength, double scoreToLambdaTableFactor,
- char costFunctionParam, [MarshalAs(UnmanagedType.U1)] bool distanceWeight2,
- double* pLambdaSum, double doubleMinValue);
-
- }
-#endif
-}
diff --git a/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs b/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs
index 7725cb7aeb..7df57005b7 100644
--- a/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs
+++ b/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs
@@ -7,7 +7,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// A class that bins vectors of doubles into a specified number of equal mass bins.
diff --git a/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs b/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs
index 9b28df134c..257b9e83b5 100644
--- a/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs
+++ b/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs
@@ -7,7 +7,7 @@
using System.Runtime.InteropServices;
using System.Text;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
internal sealed class IniFileParserInterface
{
diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs
index e84ec0117b..072ee2e20f 100644
--- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs
+++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs
@@ -6,7 +6,6 @@
using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Internal.Internallearn;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Float = System.Single;
namespace Microsoft.ML.Trainers.FastTree
diff --git a/src/Microsoft.ML.FastTree/Dataset/Dataset.cs b/src/Microsoft.ML.FastTree/Dataset/Dataset.cs
index 330ef028ca..1e5de9b229 100644
--- a/src/Microsoft.ML.FastTree/Dataset/Dataset.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/Dataset.cs
@@ -8,7 +8,7 @@
using System.Threading.Tasks;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// A dataset of features.
diff --git a/src/Microsoft.ML.FastTree/Dataset/DatasetUtils.cs b/src/Microsoft.ML.FastTree/Dataset/DatasetUtils.cs
index 542b097381..617e6559e0 100644
--- a/src/Microsoft.ML.FastTree/Dataset/DatasetUtils.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/DatasetUtils.cs
@@ -5,7 +5,7 @@
using System.Collections.Generic;
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// Loads training/validation/test sets from file
diff --git a/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs
index c6f10843d8..68c33a0d2d 100644
--- a/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs
@@ -8,7 +8,7 @@
using System.Runtime.InteropServices;
using System.Security;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
using FloatType = System.Single;
diff --git a/src/Microsoft.ML.FastTree/Dataset/Feature.cs b/src/Microsoft.ML.FastTree/Dataset/Feature.cs
index c74ef5baf4..3013b536cf 100644
--- a/src/Microsoft.ML.FastTree/Dataset/Feature.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/Feature.cs
@@ -4,7 +4,7 @@
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
diff --git a/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs
index 73c6b646bf..2d370499a4 100644
--- a/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs
@@ -15,7 +15,7 @@
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// Holds statistics per bin value for a feature. These are yielded by
diff --git a/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs b/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs
index 2294525a9f..4968ad6a58 100644
--- a/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs
@@ -5,7 +5,7 @@
using System;
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
using FloatType = System.Single;
diff --git a/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs b/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs
index 9f7623d0ad..ff6e990dca 100644
--- a/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
#if !NO_STORE
///
diff --git a/src/Microsoft.ML.FastTree/Dataset/IntArray.cs b/src/Microsoft.ML.FastTree/Dataset/IntArray.cs
index e2c8de872b..05ed42c957 100644
--- a/src/Microsoft.ML.FastTree/Dataset/IntArray.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/IntArray.cs
@@ -6,7 +6,7 @@
using System.Collections.Generic;
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
using FloatType = System.Single;
diff --git a/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs
index fcfa6c938a..43eacf22f0 100644
--- a/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// This is a feature flock that misuses a property of
diff --git a/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs
index da8592fc14..605d912328 100644
--- a/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs
@@ -4,7 +4,7 @@
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// A feature flock for a set of features where per example at most one of the features has a
diff --git a/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs
index 3b9502a8db..ac176c20ec 100644
--- a/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs
@@ -6,7 +6,7 @@
using System.Collections.Generic;
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
using FloatType = Single;
diff --git a/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs
index 156f5e227c..bfeb19f2f2 100644
--- a/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs
@@ -7,7 +7,7 @@
using System.Runtime.InteropServices;
using System.Security;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
using FloatType = System.Single;
diff --git a/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs
index c1cf5fe4d4..30f609faa4 100644
--- a/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs
@@ -5,7 +5,7 @@
using System.Linq;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// The singleton feature flock is the simplest possible sort of flock, that is, a flock
diff --git a/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs
index b3dc61b07a..b9e5cc0f28 100644
--- a/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs
+++ b/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs
@@ -7,7 +7,7 @@
using System.Runtime.InteropServices;
using System.Security;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
using FloatType = System.Single;
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index fc94e6099c..c2a90c4f3e 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -23,7 +23,6 @@
using Microsoft.ML.Model;
using Microsoft.ML.Model.Onnx;
using Microsoft.ML.Model.Pfa;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.Conversions;
diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
index ff47449238..4ae7ab7555 100644
--- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
@@ -15,7 +15,6 @@
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
[assembly: LoadableClass(FastTreeBinaryClassificationTrainer.Summary, typeof(FastTreeBinaryClassificationTrainer), typeof(FastTreeBinaryClassificationTrainer.Options),
diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
index 485a7639ca..059cd8ab93 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
@@ -17,7 +17,6 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
// REVIEW: Do we really need all these names?
diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
index e26ce00e82..1399b8bebe 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
@@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Linq;
using System.Text;
using Microsoft.Data.DataView;
@@ -13,7 +12,6 @@
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
[assembly: LoadableClass(FastTreeRegressionTrainer.Summary, typeof(FastTreeRegressionTrainer), typeof(FastTreeRegressionTrainer.Options),
diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
index b8ea0ce030..547d48f7ca 100644
--- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
@@ -14,7 +14,6 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
[assembly: LoadableClass(FastTreeTweedieTrainer.Summary, typeof(FastTreeTweedieTrainer), typeof(FastTreeTweedieTrainer.Options),
diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs
index 635862e178..79f8265dbb 100644
--- a/src/Microsoft.ML.FastTree/GamClassification.cs
+++ b/src/Microsoft.ML.FastTree/GamClassification.cs
@@ -14,7 +14,6 @@
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
[assembly: LoadableClass(BinaryClassificationGamTrainer.Summary,
diff --git a/src/Microsoft.ML.FastTree/GamModelParameters.cs b/src/Microsoft.ML.FastTree/GamModelParameters.cs
index 38c931f03f..a726edd1bc 100644
--- a/src/Microsoft.ML.FastTree/GamModelParameters.cs
+++ b/src/Microsoft.ML.FastTree/GamModelParameters.cs
@@ -18,7 +18,6 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
[assembly: LoadableClass(typeof(GamModelParametersBase.VisualizationCommand), typeof(GamModelParametersBase.VisualizationCommand.Arguments), typeof(SignatureCommand),
diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs
index 100cc6b61a..4c3fbe021b 100644
--- a/src/Microsoft.ML.FastTree/GamRegression.cs
+++ b/src/Microsoft.ML.FastTree/GamRegression.cs
@@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using Microsoft.Data.DataView;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
@@ -11,7 +10,6 @@
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
[assembly: LoadableClass(RegressionGamTrainer.Summary,
diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs
index 87bcc370ba..11953bbe36 100644
--- a/src/Microsoft.ML.FastTree/GamTrainer.cs
+++ b/src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -14,15 +14,12 @@
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
-using Timer = Microsoft.ML.Trainers.FastTree.Internal.Timer;
[assembly: LoadableClass(typeof(void), typeof(Gam), null, typeof(SignatureEntryPointModule), "GAM")]
namespace Microsoft.ML.Trainers.FastTree
{
- using AutoResetEvent = System.Threading.AutoResetEvent;
using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo;
///
diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs
index 943b10f621..aea3a991bf 100644
--- a/src/Microsoft.ML.FastTree/RandomForest.cs
+++ b/src/Microsoft.ML.FastTree/RandomForest.cs
@@ -2,9 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using Microsoft.ML.Core.Data;
-using Microsoft.ML.Trainers.FastTree.Internal;
namespace Microsoft.ML.Trainers.FastTree
{
diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs
index d6a31ec6a0..13d2261ea9 100644
--- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs
+++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs
@@ -15,7 +15,6 @@
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
[assembly: LoadableClass(FastForestClassification.Summary, typeof(FastForestClassification), typeof(FastForestClassification.Options),
diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
index de83cabeca..daed4b26c1 100644
--- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs
+++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
@@ -13,7 +13,6 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
[assembly: LoadableClass(FastForestRegression.Summary, typeof(FastForestRegression), typeof(FastForestRegression.Options),
diff --git a/src/Microsoft.ML.FastTree/RegressionTree.cs b/src/Microsoft.ML.FastTree/RegressionTree.cs
index 867d6015ef..96a0784602 100644
--- a/src/Microsoft.ML.FastTree/RegressionTree.cs
+++ b/src/Microsoft.ML.FastTree/RegressionTree.cs
@@ -4,7 +4,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
-using Microsoft.ML.Trainers.FastTree.Internal;
+using Microsoft.ML.Trainers.FastTree;
namespace Microsoft.ML.FastTree
{
diff --git a/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs b/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs
index a25b370586..7816dc993b 100644
--- a/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs
+++ b/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs
@@ -17,14 +17,13 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
[assembly: LoadableClass(typeof(SumupPerformanceCommand), typeof(SumupPerformanceCommand.Arguments), typeof(SignatureCommand),
"", "FastTreeSumupPerformance", "ftsumup")]
namespace Microsoft.ML.Trainers.FastTree
{
-using Stopwatch = System.Diagnostics.Stopwatch;
+ using Stopwatch = System.Diagnostics.Stopwatch;
///
/// This is an internal utility command to measure the performance of the IntArray sumup operation.
diff --git a/src/Microsoft.ML.FastTree/Training/Applications/GradientWrappers.cs b/src/Microsoft.ML.FastTree/Training/Applications/GradientWrappers.cs
index a1e65a455b..b360b9ea4a 100644
--- a/src/Microsoft.ML.FastTree/Training/Applications/GradientWrappers.cs
+++ b/src/Microsoft.ML.FastTree/Training/Applications/GradientWrappers.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// Trivial weights wrapper. Creates proxy class holding the targets.
diff --git a/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs b/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs
index 2808e24cb7..644add5903 100644
--- a/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs
+++ b/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs
@@ -7,7 +7,7 @@
using System.Linq;
using System.Threading.Tasks;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public abstract class ObjectiveFunctionBase
{
diff --git a/src/Microsoft.ML.FastTree/Training/BaggingProvider.cs b/src/Microsoft.ML.FastTree/Training/BaggingProvider.cs
index 3fd5125620..0dc50ce83b 100644
--- a/src/Microsoft.ML.FastTree/Training/BaggingProvider.cs
+++ b/src/Microsoft.ML.FastTree/Training/BaggingProvider.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public class BaggingProvider
{
diff --git a/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs b/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs
index e9284a6a5e..7fd960a190 100644
--- a/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs
+++ b/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs
@@ -7,7 +7,7 @@
using System.Threading.Tasks;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public sealed class DcgCalculator
{
diff --git a/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs b/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs
index 05e2955ec9..e0fb3ff926 100644
--- a/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs
+++ b/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs
@@ -4,7 +4,7 @@
using System.Collections.Generic;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public abstract class DcgPermutationComparer : IComparer
{
diff --git a/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs b/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs
index 3a52acdebf..6e06480014 100644
--- a/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs
+++ b/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs
@@ -7,7 +7,7 @@
using System.Linq;
using System.Threading.Tasks;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
using FloatType = System.Single;
diff --git a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs
index 7f3c687819..fb87ab95f2 100644
--- a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs
+++ b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
internal interface IEnsembleCompressor
{
diff --git a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs
index eba82d0cfe..69b2acd798 100644
--- a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs
+++ b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs
@@ -6,7 +6,7 @@
using System.Collections.Generic;
using System.Diagnostics;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// This implementation is based on:
diff --git a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoFit.cs b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoFit.cs
index 7a1e7ac321..a29752f612 100644
--- a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoFit.cs
+++ b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoFit.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public sealed class LassoFit
{
diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs
index 0678eb4060..72cf5af52c 100644
--- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs
+++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
//Accelerated gradient descent score tracker
internal class AcceleratedGradientDescent : GradientDescent
diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs
index 044ef7bcf3..bde9917c53 100644
--- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs
+++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
// Conjugate gradient descent
internal class ConjugateGradientDescent : GradientDescent
diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs
index b2cf4f584b..5543763351 100644
--- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs
+++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs
@@ -6,7 +6,7 @@
using System.Collections.Generic;
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
internal class GradientDescent : OptimizationAlgorithm
{
diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs
index 1925272122..9e53353316 100644
--- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs
+++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// This is dummy optimizer. As Random forest does not have any boosting based optimization, this is place holder to be consistent
diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs
index 299d13a300..986a9c13f7 100644
--- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs
+++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs
@@ -5,7 +5,7 @@
using System;
using System.Collections.Generic;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
//An interface that can be implemnted on
public interface IFastTrainingScoresUpdate
diff --git a/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs b/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs
index fccb0673eb..965026556c 100644
--- a/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs
+++ b/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs
@@ -4,7 +4,6 @@
using System;
using Microsoft.ML.EntryPoints;
-using Microsoft.ML.Trainers.FastTree.Internal;
namespace Microsoft.ML.Trainers.FastTree
{
diff --git a/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs b/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs
index 734cda69d9..52adef9a50 100644
--- a/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs
+++ b/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs
@@ -15,9 +15,8 @@
namespace Microsoft.ML.Trainers.FastTree
{
- using Microsoft.ML.Trainers.FastTree.Internal;
- using LeafSplitCandidates = Internal.LeastSquaresRegressionTreeLearner.LeafSplitCandidates;
- using SplitInfo = Internal.LeastSquaresRegressionTreeLearner.SplitInfo;
+ using LeafSplitCandidates = LeastSquaresRegressionTreeLearner.LeafSplitCandidates;
+ using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo;
internal sealed class SingleTrainer : IParallelTraining
{
diff --git a/src/Microsoft.ML.FastTree/Training/RegressionTreeNodeDocuments.cs b/src/Microsoft.ML.FastTree/Training/RegressionTreeNodeDocuments.cs
index 75fe766cc6..8849aaa53e 100644
--- a/src/Microsoft.ML.FastTree/Training/RegressionTreeNodeDocuments.cs
+++ b/src/Microsoft.ML.FastTree/Training/RegressionTreeNodeDocuments.cs
@@ -5,7 +5,7 @@
using System.Collections.Generic;
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
// RegressionTreeNodeDocuments represents an association between a node in a regression
// tree and documents belonging to that node.
diff --git a/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs b/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs
index 297aa06e29..882fa4a2d6 100644
--- a/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs
+++ b/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs
@@ -6,7 +6,7 @@
using System.Collections.Generic;
using System.Threading.Tasks;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public class ScoreTracker
{
diff --git a/src/Microsoft.ML.FastTree/Training/StepSearch.cs b/src/Microsoft.ML.FastTree/Training/StepSearch.cs
index bffa5213ca..ca767a22ed 100644
--- a/src/Microsoft.ML.FastTree/Training/StepSearch.cs
+++ b/src/Microsoft.ML.FastTree/Training/StepSearch.cs
@@ -5,7 +5,7 @@
using System;
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
internal interface IStepSearch
{
diff --git a/src/Microsoft.ML.FastTree/Training/Test.cs b/src/Microsoft.ML.FastTree/Training/Test.cs
index c76903d522..88f985f67a 100644
--- a/src/Microsoft.ML.FastTree/Training/Test.cs
+++ b/src/Microsoft.ML.FastTree/Training/Test.cs
@@ -8,7 +8,7 @@
using System.Threading;
using System.Threading.Tasks;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public sealed class TestResult : IComparable
{
diff --git a/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs b/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs
index 3009bd500a..ae7f2a3b2f 100644
--- a/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs
+++ b/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
internal class RandomForestLeastSquaresTreeLearner : LeastSquaresRegressionTreeLearner
{
diff --git a/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs b/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs
index 823942f824..2cf8f44a75 100644
--- a/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs
+++ b/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs
@@ -9,7 +9,7 @@
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
using FloatType = System.Single;
diff --git a/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs b/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs
index e5f633b9c9..93dd45ca60 100644
--- a/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs
+++ b/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs
@@ -4,7 +4,7 @@
using System;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
internal abstract class TreeLearner
{
diff --git a/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs b/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs
index 0320a23ded..a183ddc9f6 100644
--- a/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs
+++ b/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs
@@ -8,7 +8,7 @@
using System.Threading.Tasks;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public sealed class WinLossCalculator
{
diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs
index f93111201a..7a8d5dcf23 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs
@@ -7,7 +7,7 @@
using Microsoft.ML.Model;
using Float = System.Single;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
internal class InternalQuantileRegressionTree : InternalRegressionTree
{
@@ -50,7 +50,7 @@ public InternalQuantileRegressionTree(byte[] buffer, ref int position)
_instanceWeights = buffer.ToDoubleArray(ref position);
}
- public override void Save(ModelSaveContext ctx)
+ internal override void Save(ModelSaveContext ctx)
{
// *** Binary format ***
// double[]: Labels Distribution.
diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs
index 4727919022..31f65ae92e 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs
@@ -15,7 +15,7 @@
using Microsoft.ML.Model.Pfa;
using Newtonsoft.Json.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
/// Note that is shared between FastTree and LightGBM assemblies,
/// so has .
@@ -409,7 +409,7 @@ protected void Save(ModelSaveContext ctx, TreeType code)
writer.WriteDoubleArray(_previousLeafValue);
}
- public virtual void Save(ModelSaveContext ctx)
+ internal virtual void Save(ModelSaveContext ctx)
{
Save(ctx, TreeType.Regression);
}
diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs
index 421f788ef6..62fd670fc8 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs
@@ -14,7 +14,7 @@
using Microsoft.ML.Model.Pfa;
using Newtonsoft.Json.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
[BestFriend]
internal class InternalTreeEnsemble
@@ -56,7 +56,7 @@ public InternalTreeEnsemble(ModelLoadContext ctx, bool usingDefaultValues, bool
_firstInputInitializationContent = ctx.LoadStringOrNull();
}
- public void Save(ModelSaveContext ctx)
+ internal void Save(ModelSaveContext ctx)
{
// *** Binary format ***
// int: Number of trees
diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs
index f215dec82a..ab451165c8 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs
@@ -4,16 +4,14 @@
using System.Collections.Generic;
using Microsoft.ML;
-using Microsoft.ML.Calibrator;
using Microsoft.ML.Data;
-using Microsoft.ML.Ensemble;
using Microsoft.ML.Internal.Calibration;
-using Microsoft.ML.Internal.Internallearn;
-using Microsoft.ML.Trainers.FastTree.Internal;
+using Microsoft.ML.Trainers.Ensemble;
+using Microsoft.ML.Trainers.FastTree;
[assembly: LoadableClass(typeof(TreeEnsembleCombiner), null, typeof(SignatureModelCombiner), "Fast Tree Model Combiner", "FastTreeCombiner")]
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public sealed class TreeEnsembleCombiner : IModelCombiner
{
diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
index e1aa26f56f..18a0cf1aec 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
@@ -393,7 +393,7 @@ public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, ModelLoadConte
_totalLeafCount = CountLeaves(_ensemble);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.FastTree/Utils/Algorithms.cs b/src/Microsoft.ML.FastTree/Utils/Algorithms.cs
index 9ad94b359a..553e0e7553 100644
--- a/src/Microsoft.ML.FastTree/Utils/Algorithms.cs
+++ b/src/Microsoft.ML.FastTree/Utils/Algorithms.cs
@@ -5,7 +5,7 @@
using System;
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public static class Algorithms
{
diff --git a/src/Microsoft.ML.FastTree/Utils/BlockingThreadPool.cs b/src/Microsoft.ML.FastTree/Utils/BlockingThreadPool.cs
index 76d623a850..e096de7005 100644
--- a/src/Microsoft.ML.FastTree/Utils/BlockingThreadPool.cs
+++ b/src/Microsoft.ML.FastTree/Utils/BlockingThreadPool.cs
@@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// This class wraps the standard .NET ThreadPool and adds the following functionality:
diff --git a/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs b/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs
index cd0d03cfc8..526fddba09 100644
--- a/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs
+++ b/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs
@@ -10,7 +10,7 @@
using System.Linq;
using System.Runtime.InteropServices;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// This class enables basic buffer pooling.
diff --git a/src/Microsoft.ML.FastTree/Utils/CompressUtils.cs b/src/Microsoft.ML.FastTree/Utils/CompressUtils.cs
index 0588e386ae..4c87a666a9 100644
--- a/src/Microsoft.ML.FastTree/Utils/CompressUtils.cs
+++ b/src/Microsoft.ML.FastTree/Utils/CompressUtils.cs
@@ -7,7 +7,7 @@
using System.IO;
using System.IO.Compression;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
internal struct BufferBlock
{
diff --git a/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs b/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs
index 939306cbaa..a9400739dd 100644
--- a/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs
+++ b/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs
@@ -8,7 +8,7 @@
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
internal static class FastTreeIniFileUtils
{
diff --git a/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs b/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs
index 88d8825205..0b912e58bc 100644
--- a/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs
+++ b/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs
@@ -6,7 +6,7 @@
using System.Collections.Generic;
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public static class LinqExtensions
{
diff --git a/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs b/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs
index fc31ec097e..8a5fda81ea 100644
--- a/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs
+++ b/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs
@@ -7,7 +7,7 @@
using System.Security.Cryptography;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public struct MD5Hash
{
diff --git a/src/Microsoft.ML.FastTree/Utils/MappedObjectPool.cs b/src/Microsoft.ML.FastTree/Utils/MappedObjectPool.cs
index bcfb68e210..ad9b54e6cd 100644
--- a/src/Microsoft.ML.FastTree/Utils/MappedObjectPool.cs
+++ b/src/Microsoft.ML.FastTree/Utils/MappedObjectPool.cs
@@ -5,7 +5,7 @@
using System;
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// Implements a paging mechanism on indexed objects.
diff --git a/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs b/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs
index 59adf3a453..9a7800d5ac 100644
--- a/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs
+++ b/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs
@@ -5,7 +5,7 @@
using System;
using System.Linq;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// This class defines a psuedorandom function, mapping a number to
diff --git a/src/Microsoft.ML.FastTree/Utils/StreamExtensions.cs b/src/Microsoft.ML.FastTree/Utils/StreamExtensions.cs
index c12ba5d066..125b49fecb 100644
--- a/src/Microsoft.ML.FastTree/Utils/StreamExtensions.cs
+++ b/src/Microsoft.ML.FastTree/Utils/StreamExtensions.cs
@@ -5,7 +5,7 @@
using System.IO;
using System.IO.Compression;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public static class StreamExtensions
{
diff --git a/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs b/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs
index d479c2c53b..aa0e06ffc0 100644
--- a/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs
+++ b/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs
@@ -7,7 +7,7 @@
using System.Linq;
using System.Threading.Tasks;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
internal static class ThreadTaskManager
{
diff --git a/src/Microsoft.ML.FastTree/Utils/Timer.cs b/src/Microsoft.ML.FastTree/Utils/Timer.cs
index 35770a5e50..feb8d631cc 100644
--- a/src/Microsoft.ML.FastTree/Utils/Timer.cs
+++ b/src/Microsoft.ML.FastTree/Utils/Timer.cs
@@ -6,7 +6,7 @@
using System.Text;
using System.Threading;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
using Stopwatch = System.Diagnostics.Stopwatch;
diff --git a/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs b/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs
index 108aeeac87..1452a3bbba 100644
--- a/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs
+++ b/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs
@@ -7,7 +7,7 @@
using System.Text;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
///
/// This class contains extension methods that support binary serialization of some base C# types
diff --git a/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs b/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs
index 8941d50da0..3b54b1a33b 100644
--- a/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs
+++ b/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs
@@ -5,7 +5,7 @@
using System;
using System.Text;
-namespace Microsoft.ML.Trainers.FastTree.Internal
+namespace Microsoft.ML.Trainers.FastTree
{
public class VectorUtils
{
diff --git a/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs b/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs
index bc30deabfc..ce3bad8b99 100644
--- a/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs
+++ b/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs
@@ -7,7 +7,7 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Trainers.HalLearners;
-namespace Microsoft.ML.Learners
+namespace Microsoft.ML.Trainers
{
using Mkl = OlsLinearRegressionTrainer.Mkl;
diff --git a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs
index 1b635d49fc..564a5df3c6 100644
--- a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs
+++ b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs
@@ -2,11 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Trainers.HalLearners;
-using Microsoft.ML.Trainers.SymSgd;
using Microsoft.ML.Transforms.Projections;
namespace Microsoft.ML
diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
index 18a5d97b97..31c01b6324 100644
--- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
+++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
@@ -15,7 +15,6 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.HalLearners;
using Microsoft.ML.Training;
diff --git a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
index 8b7377bbe1..e23619e81c 100644
--- a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
+++ b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
@@ -16,8 +16,7 @@
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
-using Microsoft.ML.Trainers.SymSgd;
+using Microsoft.ML.Trainers.HalLearners;
using Microsoft.ML.Training;
using Microsoft.ML.Transforms;
@@ -29,7 +28,7 @@
[assembly: LoadableClass(typeof(void), typeof(SymSgdClassificationTrainer), null, typeof(SignatureEntryPointModule), SymSgdClassificationTrainer.LoadNameValue)]
-namespace Microsoft.ML.Trainers.SymSgd
+namespace Microsoft.ML.Trainers.HalLearners
{
using TPredictor = CalibratedModelParametersBase;
diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs
index f865ba4fcd..fc298c1ed1 100644
--- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs
+++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs
@@ -464,7 +464,7 @@ private static void TrainModels(IHostEnvironment env, IChannel ch, float[][] col
}
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs
index cf5388d8da..fc9ed6c699 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs
@@ -141,7 +141,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs
index 8829d1dd2f..dab2dc17c2 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs
@@ -138,7 +138,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, TextType.Instance.ToString(), inputSchema[srcCol].Type.ToString());
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs
index a34fff871c..8359c4d9a6 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs
@@ -252,7 +252,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs
index 665cfd516d..a23c0d0a93 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs
@@ -235,7 +235,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
diff --git a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs
index cc0a5da52e..da26e1f5dc 100644
--- a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs
@@ -195,7 +195,7 @@ public ColInfoEx(ModelLoadContext ctx)
Interleave = ctx.Reader.ReadBoolByte();
}
- public void Save(ModelSaveContext ctx)
+ internal void Save(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
@@ -306,7 +306,7 @@ public static VectorToImageTransform Create(IHostEnvironment env, ModelLoadConte
});
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
index e73bf4acb9..bf01aecbfd 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
@@ -13,7 +13,6 @@
using Microsoft.ML.LightGBM;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(Options),
diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
index a3ec3ffef5..3b2a1b56af 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
@@ -12,7 +12,7 @@
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.LightGBM;
using Microsoft.ML.Trainers;
-using Microsoft.ML.Trainers.FastTree.Internal;
+using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Training;
[assembly: LoadableClass(LightGbmMulticlassTrainer.Summary, typeof(LightGbmMulticlassTrainer), typeof(Options),
diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
index 04e78e7a4d..559ed6b5e1 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
@@ -11,7 +11,6 @@
using Microsoft.ML.LightGBM;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
[assembly: LoadableClass(LightGbmRankingTrainer.UserName, typeof(LightGbmRankingTrainer), typeof(Options),
diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs
index 676627aa84..2068746cea 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs
@@ -10,7 +10,6 @@
using Microsoft.ML.LightGBM;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
[assembly: LoadableClass(LightGbmRegressorTrainer.Summary, typeof(LightGbmRegressorTrainer), typeof(Options),
diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
index 1ce9b0ccf4..695157f91b 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
@@ -8,7 +8,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Trainers.FastTree.Internal;
+using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Training;
namespace Microsoft.ML.LightGBM
diff --git a/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs b/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs
index 04dc9f9f62..cf8b9e4c8e 100644
--- a/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs
+++ b/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs
@@ -6,7 +6,7 @@
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
-using Microsoft.ML.Trainers.FastTree.Internal;
+using Microsoft.ML.Trainers.FastTree;
namespace Microsoft.ML.LightGBM
{
diff --git a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs
index 30f20846c4..22bb76c019 100644
--- a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs
+++ b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs
@@ -261,7 +261,7 @@ internal OnnxTransformer(IHostEnvironment env, string[] outputColumnNames, strin
{
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.AssertValue(ctx);
@@ -368,7 +368,7 @@ private protected override Func GetDependenciesCore(Func a
return col => Enumerable.Range(0, _parent.Outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col);
}
- public override void Save(ModelSaveContext ctx) => _parent.Save(ctx);
+ private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
private interface INamedOnnxValueGetter
{
diff --git a/src/Microsoft.ML.PCA/PcaTransformer.cs b/src/Microsoft.ML.PCA/PcaTransformer.cs
index b6b134a1e0..dd17255d0a 100644
--- a/src/Microsoft.ML.PCA/PcaTransformer.cs
+++ b/src/Microsoft.ML.PCA/PcaTransformer.cs
@@ -140,7 +140,7 @@ public TransformInfo(ModelLoadContext ctx)
Contracts.CheckDecode(MeanProjected == null || (MeanProjected.Length == Rank && FloatUtils.IsFinite(MeanProjected)));
}
- public void Save(ModelSaveContext ctx)
+ internal void Save(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
@@ -279,7 +279,7 @@ private static PrincipalComponentAnalysisTransformer Create(IHostEnvironment env
return new PrincipalComponentAnalysisTransformer(host, ctx);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs
index 0f733273fb..9f5eaa97d3 100644
--- a/src/Microsoft.ML.Parquet/ParquetLoader.cs
+++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs
@@ -406,7 +406,7 @@ public RowCursor[] GetRowCursorSet(IEnumerable columnsNee
return new RowCursor[] { GetRowCursor(columnsNeeded, rand) };
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs
index 8f2c83a17f..98c83a59e2 100644
--- a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs
+++ b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs
@@ -253,7 +253,7 @@ public static PartitionedFileLoader Create(IHostEnvironment env, ModelLoadContex
ch => new PartitionedFileLoader(host, ctx, files));
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Parquet/PartitionedPathParser.cs b/src/Microsoft.ML.Parquet/PartitionedPathParser.cs
index 06ae2b9700..70a01be64b 100644
--- a/src/Microsoft.ML.Parquet/PartitionedPathParser.cs
+++ b/src/Microsoft.ML.Parquet/PartitionedPathParser.cs
@@ -148,7 +148,7 @@ public static SimplePartitionedPathParser Create(IHostEnvironment env, ModelLoad
ch => new SimplePartitionedPathParser(host, ctx));
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.SetVersionInfo(GetVersionInfo());
@@ -261,7 +261,7 @@ public static ParquetPartitionedPathParser Create(IHostEnvironment env, ModelLoa
ch => new ParquetPartitionedPathParser(host, ctx));
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.SetVersionInfo(GetVersionInfo());
diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs
index 8642188c2f..a0007d9108 100644
--- a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs
+++ b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs
@@ -16,7 +16,7 @@
using Microsoft.ML.Recommender.Internal;
using Microsoft.ML.Trainers.Recommender;
-[assembly: LoadableClass(typeof(MatrixFactorizationPredictor), null, typeof(SignatureLoadModel), "Matrix Factorization Predictor Executor", MatrixFactorizationPredictor.LoaderSignature)]
+[assembly: LoadableClass(typeof(MatrixFactorizationModelParameters), null, typeof(SignatureLoadModel), "Matrix Factorization Predictor Executor", MatrixFactorizationModelParameters.LoaderSignature)]
[assembly: LoadableClass(typeof(MatrixFactorizationPredictionTransformer), typeof(MatrixFactorizationPredictionTransformer),
null, typeof(SignatureLoadModel), "", MatrixFactorizationPredictionTransformer.LoaderSignature)]
@@ -24,12 +24,15 @@
namespace Microsoft.ML.Trainers.Recommender
{
///
- /// stores two factor matrices, P and Q, for approximating the training matrix, R, by P * Q,
- /// where * is a matrix multiplication. This predictor expects two inputs, row index and column index, and produces the (approximated)
+ /// Model parameters for matrix factorization recommender.
+ ///
+ ///
+ /// stores two factor matrices, P and Q, for approximating the training matrix, R, by P * Q,
+ /// where * is a matrix multiplication. This model expects two inputs, row index and column index, and produces the (approximated)
/// value at the location specified by the two inputs in R. More specifically, if input row and column indices are u and v, respectively.
/// The output (a scalar) would be the inner product product of the u-th row in P and the v-th column in Q.
- ///
- public sealed class MatrixFactorizationPredictor : IPredictor, ICanSaveModel, ICanSaveInTextFormat, ISchemaBindableMapper
+ ///
+ public sealed class MatrixFactorizationModelParameters : IPredictor, ICanSaveModel, ICanSaveInTextFormat, ISchemaBindableMapper
{
internal const string LoaderSignature = "MFPredictor";
internal const string RegistrationName = "MatrixFactorizationPredictor";
@@ -43,33 +46,42 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(MatrixFactorizationPredictor).Assembly.FullName);
+ loaderAssemblyName: typeof(MatrixFactorizationModelParameters).Assembly.FullName);
}
private const uint VersionNoMinCount = 0x00010002;
private readonly IHost _host;
- // The number of rows.
- private readonly int _numberOfRows;
- // The number of columns.
- private readonly int _numberofColumns;
- // The rank of the factor matrices.
- private readonly int _approximationRank;
- // Packed _numberOfRows by _approximationRank matrix.
- private readonly float[] _leftFactorMatrix;
- // Packed _approximationRank by _numberofColumns matrix.
- private readonly float[] _rightFactorMatrix;
-
- public PredictionKind PredictionKind
- {
- get { return PredictionKind.Recommendation; }
- }
+ /// The number of rows.
+ public readonly int NumberOfRows;
+ /// The number of columns.
+ public readonly int NumberOfColumns;
+ /// The rank of the factor matrices.
+ public readonly int ApproximationRank;
+ ///
+ /// Left approximation matrix
+ ///
+ ///
+ /// This is two dimensional matrix with size of * flattened into one-dimensional matrix.
+ /// Row by row.
+ ///
+ public readonly IReadOnlyList LeftFactorMatrix;
+ ///
+ /// Right approximation matrix
+ ///
+ ///
+ /// This is two dimensional matrix with size of * flattened into one-dimensional matrix.
+ /// Row by row.
+ ///
+ public readonly IReadOnlyList RightFactorMatrix;
+
+ public PredictionKind PredictionKind => PredictionKind.Recommendation;
- public ColumnType OutputType { get { return NumberType.Float; } }
+ private ColumnType OutputType => NumberType.Float;
- public ColumnType MatrixColumnIndexType { get; }
- public ColumnType MatrixRowIndexType { get; }
+ internal ColumnType MatrixColumnIndexType { get; }
+ internal ColumnType MatrixRowIndexType { get; }
- internal MatrixFactorizationPredictor(IHostEnvironment env, SafeTrainingAndModelBuffer buffer, KeyType matrixColumnIndexType, KeyType matrixRowIndexType)
+ internal MatrixFactorizationModelParameters(IHostEnvironment env, SafeTrainingAndModelBuffer buffer, KeyType matrixColumnIndexType, KeyType matrixRowIndexType)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(RegistrationName);
@@ -78,18 +90,19 @@ internal MatrixFactorizationPredictor(IHostEnvironment env, SafeTrainingAndModel
_host.CheckValue(buffer, nameof(buffer));
_host.CheckValue(matrixColumnIndexType, nameof(matrixColumnIndexType));
_host.CheckValue(matrixRowIndexType, nameof(matrixRowIndexType));
-
- buffer.Get(out _numberOfRows, out _numberofColumns, out _approximationRank, out _leftFactorMatrix, out _rightFactorMatrix);
- _host.Assert(_numberofColumns == matrixColumnIndexType.GetCountAsInt32(_host));
- _host.Assert(_numberOfRows == matrixRowIndexType.GetCountAsInt32(_host));
- _host.Assert(_leftFactorMatrix.Length == _numberOfRows * _approximationRank);
- _host.Assert(_rightFactorMatrix.Length == _numberofColumns * _approximationRank);
+ buffer.Get(out NumberOfRows, out NumberOfColumns, out ApproximationRank, out var leftFactorMatrix, out var rightFactorMatrix);
+ LeftFactorMatrix = leftFactorMatrix;
+ RightFactorMatrix = rightFactorMatrix;
+ _host.Assert(NumberOfColumns == matrixColumnIndexType.GetCountAsInt32(_host));
+ _host.Assert(NumberOfRows == matrixRowIndexType.GetCountAsInt32(_host));
+ _host.Assert(LeftFactorMatrix.Count == NumberOfRows * ApproximationRank);
+ _host.Assert(RightFactorMatrix.Count == ApproximationRank * NumberOfColumns);
MatrixColumnIndexType = matrixColumnIndexType;
MatrixRowIndexType = matrixRowIndexType;
}
- private MatrixFactorizationPredictor(IHostEnvironment env, ModelLoadContext ctx)
+ private MatrixFactorizationModelParameters(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(RegistrationName);
@@ -100,49 +113,49 @@ private MatrixFactorizationPredictor(IHostEnvironment env, ModelLoadContext ctx)
// float[m * k]: the left factor matrix
// float[k * n]: the right factor matrix
- _numberOfRows = ctx.Reader.ReadInt32();
- _host.CheckDecode(_numberOfRows > 0);
+ NumberOfRows = ctx.Reader.ReadInt32();
+ _host.CheckDecode(NumberOfRows > 0);
if (ctx.Header.ModelVerWritten < VersionNoMinCount)
{
ulong mMin = ctx.Reader.ReadUInt64();
// We no longer support non zero Min for KeyType.
_host.CheckDecode(mMin == 0);
- _host.CheckDecode((ulong)_numberOfRows <= ulong.MaxValue - mMin);
+ _host.CheckDecode((ulong)NumberOfRows <= ulong.MaxValue - mMin);
}
- _numberofColumns = ctx.Reader.ReadInt32();
- _host.CheckDecode(_numberofColumns > 0);
+ NumberOfColumns = ctx.Reader.ReadInt32();
+ _host.CheckDecode(NumberOfColumns > 0);
if (ctx.Header.ModelVerWritten < VersionNoMinCount)
{
ulong nMin = ctx.Reader.ReadUInt64();
// We no longer support non zero Min for KeyType.
_host.CheckDecode(nMin == 0);
- _host.CheckDecode((ulong)_numberofColumns <= ulong.MaxValue - nMin);
+ _host.CheckDecode((ulong)NumberOfColumns <= ulong.MaxValue - nMin);
}
- _approximationRank = ctx.Reader.ReadInt32();
- _host.CheckDecode(_approximationRank > 0);
+ ApproximationRank = ctx.Reader.ReadInt32();
+ _host.CheckDecode(ApproximationRank > 0);
- _leftFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(_numberOfRows * _approximationRank));
- _rightFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(_numberofColumns * _approximationRank));
+ LeftFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(NumberOfRows * ApproximationRank));
+ RightFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(NumberOfColumns * ApproximationRank));
- MatrixColumnIndexType = new KeyType(typeof(uint), _numberofColumns);
- MatrixRowIndexType = new KeyType(typeof(uint), _numberOfRows);
+ MatrixColumnIndexType = new KeyType(typeof(uint), NumberOfColumns);
+ MatrixRowIndexType = new KeyType(typeof(uint), NumberOfRows);
}
///
/// Load model from the given context
///
- public static MatrixFactorizationPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
+ private static MatrixFactorizationModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- return new MatrixFactorizationPredictor(env, ctx);
+ return new MatrixFactorizationModelParameters(env, ctx);
}
///
/// Save model to the given context
///
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
@@ -154,16 +167,16 @@ public void Save(ModelSaveContext ctx)
// float[m * k]: the left factor matrix
// float[k * n]: the right factor matrix
- _host.Check(_numberOfRows > 0, "Number of rows must be positive");
- _host.Check(_numberofColumns > 0, "Number of columns must be positive");
- _host.Check(_approximationRank > 0, "Number of latent factors must be positive");
- ctx.Writer.Write(_numberOfRows);
- ctx.Writer.Write(_numberofColumns);
- ctx.Writer.Write(_approximationRank);
- _host.Check(Utils.Size(_leftFactorMatrix) == _numberOfRows * _approximationRank, "Unexpected matrix size of a factor matrix (matrix P in LIBMF paper)");
- _host.Check(Utils.Size(_rightFactorMatrix) == _numberofColumns * _approximationRank, "Unexpected matrix size of a factor matrix (matrix Q in LIBMF paper)");
- Utils.WriteSinglesNoCount(ctx.Writer, _leftFactorMatrix.AsSpan(0, _numberOfRows * _approximationRank));
- Utils.WriteSinglesNoCount(ctx.Writer, _rightFactorMatrix.AsSpan(0, _numberofColumns * _approximationRank));
+ _host.Check(NumberOfRows > 0, "Number of rows must be positive");
+ _host.Check(NumberOfColumns > 0, "Number of columns must be positive");
+ _host.Check(ApproximationRank > 0, "Number of latent factors must be positive");
+ ctx.Writer.Write(NumberOfRows);
+ ctx.Writer.Write(NumberOfColumns);
+ ctx.Writer.Write(ApproximationRank);
+ _host.Check(Utils.Size(LeftFactorMatrix) == NumberOfRows * ApproximationRank, "Unexpected matrix size of a factor matrix (matrix P in LIBMF paper)");
+ _host.Check(Utils.Size(RightFactorMatrix) == NumberOfColumns * ApproximationRank, "Unexpected matrix size of a factor matrix (matrix Q in LIBMF paper)");
+ Utils.WriteSinglesNoCount(ctx.Writer, LeftFactorMatrix as float[]);
+ Utils.WriteSinglesNoCount(ctx.Writer, RightFactorMatrix as float[]);
}
///
@@ -172,20 +185,20 @@ public void Save(ModelSaveContext ctx)
void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
{
writer.WriteLine("# Imputed matrix is P * Q'");
- writer.WriteLine("# P in R^({0} x {1}), rows correpond to Y item", _numberOfRows, _approximationRank);
- for (int i = 0; i < _leftFactorMatrix.Length; ++i)
+ writer.WriteLine("# P in R^({0} x {1}), rows correpond to Y item", NumberOfRows, ApproximationRank);
+ for (int i = 0; i < LeftFactorMatrix.Count; ++i)
{
- writer.Write(_leftFactorMatrix[i].ToString("G"));
- if (i % _approximationRank == _approximationRank - 1)
+ writer.Write(LeftFactorMatrix[i].ToString("G"));
+ if (i % ApproximationRank == ApproximationRank - 1)
writer.WriteLine();
else
writer.Write('\t');
}
- writer.WriteLine("# Q in R^({0} x {1}), rows correpond to X item", _numberofColumns, _approximationRank);
- for (int i = 0; i < _rightFactorMatrix.Length; ++i)
+ writer.WriteLine("# Q in R^({0} x {1}), rows correpond to X item", NumberOfColumns, ApproximationRank);
+ for (int i = 0; i < RightFactorMatrix.Count; ++i)
{
- writer.Write(_rightFactorMatrix[i].ToString("G"));
- if (i % _approximationRank == _approximationRank - 1)
+ writer.Write(RightFactorMatrix[i].ToString("G"));
+ if (i % ApproximationRank == ApproximationRank - 1)
writer.WriteLine();
else
writer.Write('\t');
@@ -241,7 +254,7 @@ private void MapperCore(in uint srcCol, ref uint srcRow, ref float dst)
// training. For higher-than-expected values, the predictor version would return
// 0, rather than NaN as we do here. It is in my mind an open question as to what
// is actually correct.
- if (srcRow == 0 || srcRow > _numberOfRows || srcCol == 0 || srcCol > _numberofColumns)
+ if (srcRow == 0 || srcRow > NumberOfRows || srcCol == 0 || srcCol > NumberOfColumns)
{
dst = float.NaN;
return;
@@ -251,15 +264,15 @@ private void MapperCore(in uint srcCol, ref uint srcRow, ref float dst)
private float Score(int columnIndex, int rowIndex)
{
- _host.Assert(0 <= rowIndex && rowIndex < _numberOfRows);
- _host.Assert(0 <= columnIndex && columnIndex < _numberofColumns);
+ _host.Assert(0 <= rowIndex && rowIndex < NumberOfRows);
+ _host.Assert(0 <= columnIndex && columnIndex < NumberOfColumns);
float score = 0;
// Starting position of the rowIndex-th row in the left factor factor matrix
- int rowOffset = rowIndex * _approximationRank;
+ int rowOffset = rowIndex * ApproximationRank;
// Starting position of the columnIndex-th column in the right factor factor matrix
- int columnOffset = columnIndex * _approximationRank;
- for (int i = 0; i < _approximationRank; i++)
- score += _leftFactorMatrix[rowOffset + i] * _rightFactorMatrix[columnOffset + i];
+ int columnOffset = columnIndex * ApproximationRank;
+ for (int i = 0; i < ApproximationRank; i++)
+ score += LeftFactorMatrix[rowOffset + i] * RightFactorMatrix[columnOffset + i];
return score;
}
@@ -276,7 +289,7 @@ ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSc
private sealed class RowMapper : ISchemaBoundRowMapper
{
- private readonly MatrixFactorizationPredictor _parent;
+ private readonly MatrixFactorizationModelParameters _parent;
// The tail "ColumnIndex" means the column index in IDataView
private readonly int _matrixColumnIndexColumnIndex;
private readonly int _matrixRowIndexCololumnIndex;
@@ -289,7 +302,7 @@ private sealed class RowMapper : ISchemaBoundRowMapper
public RoleMappedSchema InputRoleMappedSchema { get; }
- public RowMapper(IHostEnvironment env, MatrixFactorizationPredictor parent, RoleMappedSchema schema, Schema outputSchema)
+ public RowMapper(IHostEnvironment env, MatrixFactorizationModelParameters parent, RoleMappedSchema schema, Schema outputSchema)
{
Contracts.AssertValue(parent);
_env = env;
@@ -375,13 +388,16 @@ public Row GetRow(Row input, Func active)
}
}
- public sealed class MatrixFactorizationPredictionTransformer : PredictionTransformerBase, ICanSaveModel
+ ///
+ /// Trains a . It factorizes the training matrix into the product of two low-rank matrices.
+ ///
+ public sealed class MatrixFactorizationPredictionTransformer : PredictionTransformerBase
{
- public const string LoaderSignature = "MaFactPredXf";
- public string MatrixColumnIndexColumnName { get; }
- public string MatrixRowIndexColumnName { get; }
- public ColumnType MatrixColumnIndexColumnType { get; }
- public ColumnType MatrixRowIndexColumnType { get; }
+ internal const string LoaderSignature = "MaFactPredXf";
+ internal string MatrixColumnIndexColumnName { get; }
+ internal string MatrixRowIndexColumnName { get; }
+ internal ColumnType MatrixColumnIndexColumnType { get; }
+ internal ColumnType MatrixRowIndexColumnType { get; }
///
/// Build a transformer based on matrix factorization predictor (model) and the input schema (trainSchema). The created
@@ -395,7 +411,7 @@ public sealed class MatrixFactorizationPredictionTransformer : PredictionTransfo
/// The name of the column used as role in matrix factorization world
/// The name of the column used as role in matrix factorization world
/// A string attached to the output column name of this transformer
- public MatrixFactorizationPredictionTransformer(IHostEnvironment env, MatrixFactorizationPredictor model, Schema trainSchema,
+ internal MatrixFactorizationPredictionTransformer(IHostEnvironment env, MatrixFactorizationModelParameters model, Schema trainSchema,
string matrixColumnIndexColumnName, string matrixRowIndexColumnName, string scoreColumnNameSuffix = "")
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MatrixFactorizationPredictionTransformer)), model, trainSchema)
{
@@ -432,7 +448,7 @@ private RoleMappedSchema GetSchema()
/// The counter constructor of re-creating from the context where
/// the original transform is saved.
///
- public MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
+ private MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
: base(Contracts.CheckRef(host, nameof(host)).Register(nameof(MatrixFactorizationPredictionTransformer)), ctx)
{
// *** Binary format ***
@@ -458,6 +474,10 @@ public MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoad
Scorer = new GenericScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
}
+ ///
+ /// Schema propagation for transformers.
+ /// Returns the output schema of the data, if the input schema is like the one provided.
+ ///
public override Schema GetOutputSchema(Schema inputSchema)
{
if (!inputSchema.TryGetColumnIndex(MatrixColumnIndexColumnName, out int xCol))
@@ -468,7 +488,7 @@ public override Schema GetOutputSchema(Schema inputSchema)
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
}
- public void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
index c3696d22bf..e09eab6ace 100644
--- a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
+++ b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
@@ -83,15 +83,37 @@ namespace Microsoft.ML.Trainers
///
///
///
///
///
- public sealed class MatrixFactorizationTrainer : TrainerBase,
+ public sealed class MatrixFactorizationTrainer : TrainerBase,
IEstimator
{
- public enum LossFunctionType { SquareLossRegression = 0, SquareLossOneClass = 12 };
+ ///
+ /// Type of loss function.
+ ///
+ public enum LossFunctionType
+ {
+ ///
+ /// Used in traditional collaborative filtering problem with squared loss.
+ ///
+ ///
+ /// See Equation (1).
+ ///
+ SquareLossRegression = 0,
+ ///
+ /// Used in implicit-feedback recommendation problem.
+ ///
+ ///
+ /// See Equation (3).
+ ///
+ SquareLossOneClass = 12
+ };
+ ///
+ /// Advanced options for the .
+ ///
public sealed class Options
{
///
@@ -110,40 +132,69 @@ public sealed class Options
public string LabelColumnName;
///
- /// Loss function minimized for finding factor matrices. Two values are allowed, 0 or 12. The values 0 means traditional collaborative filtering
- /// problem with squared loss. The value 12 triggers one-class matrix factorization for implicit-feedback recommendation problem.
+ /// Loss function minimized for finding factor matrices.
///
+ ///
+ /// Two values are allowed, or .
+ /// The means traditional collaborative filtering problem with squared loss.
+ /// The triggers one-class matrix factorization for implicit-feedback recommendation problem.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Loss function minimized for finding factor matrices.")]
[TGUI(SuggestedSweeps = "0,12")]
[TlcModule.SweepableDiscreteParam("LossFunction", new object[] { LossFunctionType.SquareLossRegression, LossFunctionType.SquareLossOneClass })]
- public LossFunctionType LossFunction = LossFunctionType.SquareLossRegression;
+ public LossFunctionType LossFunction = Defaults.LossFunction;
+ ///
+ /// Regularization parameter.
+ ///
+ ///
+ /// It's the weight of factor matrices' Frobenius norms in the objective function minimized by matrix factorization's algorithm. A small value could cause over-fitting.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Regularization parameter. " +
- "It's the weight of factor matrices' norms in the objective function minimized by matrix factorization's algorithm. " +
+ "It's the weight of factor matrices Frobenius norms in the objective function minimized by matrix factorization's algorithm. " +
"A small value could cause over-fitting.")]
[TGUI(SuggestedSweeps = "0.01,0.05,0.1,0.5,1")]
[TlcModule.SweepableDiscreteParam("Lambda", new object[] { 0.01f, 0.05f, 0.1f, 0.5f, 1f })]
- public double Lambda = 0.1;
+ public double Lambda = Defaults.Lambda;
+ ///
+ /// Rank of approximation matrices.
+ ///
+ ///
+ /// If input data has size of m-by-n we would build two approximation matrices m-by-k and k-by-n where k is approximation rank.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Latent space dimension (denoted by k). If the factorized matrix is m-by-n, " +
"two factor matrices found by matrix factorization are m-by-k and k-by-n, respectively. " +
- "This value is also known as the rank of matrix factorization because k is generally much smaller than m and n.")]
+ "This value is also known as the rank of matrix factorization because k is generally much smaller than m and n.", ShortName = "K")]
[TGUI(SuggestedSweeps = "8,16,64,128")]
[TlcModule.SweepableDiscreteParam("K", new object[] { 8, 16, 64, 128 })]
- public int K = 8;
+ public int ApproximationRank = Defaults.ApproximationRank;
+ ///
+ /// Number of training iterations.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Training iterations; that is, the times that the training algorithm iterates through the whole training data once.", ShortName = "iter")]
[TGUI(SuggestedSweeps = "10,20,40")]
[TlcModule.SweepableDiscreteParam("NumIterations", new object[] { 10, 20, 40 })]
- public int NumIterations = 20;
-
+ public int NumIterations = Defaults.NumIterations;
+
+ ///
+ /// Initial learning rate. It specifies the speed of the training algorithm.
+ ///
+ ///
+ /// Small value may increase the number of iterations needed to achieve a reasonable result.
+ /// Large value may lead to numerical difficulty such as an infinity value.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate. It specifies the speed of the training algorithm. " +
- "Small value may increase the number of iterations needed to achieve a reasonable result. Large value may lead to numerical difficulty such as a infinity value.")]
+ "Small value may increase the number of iterations needed to achieve a reasonable result. Large value may lead to numerical difficulty such as a infinity value.", ShortName = "Eta")]
[TGUI(SuggestedSweeps = "0.001,0.01,0.1")]
[TlcModule.SweepableDiscreteParam("Eta", new object[] { 0.001f, 0.01f, 0.1f })]
- public double Eta = 0.1;
+ public double LearningRate = Defaults.LearningRate;
///
+ /// Importance of unobserved entries' loss in one-class matrix factorization. Applicable if set to
+ ///
+ ///
/// Importance of unobserved (i.e., negative) entries' loss in one-class matrix factorization.
/// In general, only a few of matrix entries (e.g., less than 1%) in the training are observed (i.e., positive).
/// To balance the contributions from unobserved and obverved in the overall loss function, this parameter is
@@ -154,32 +205,57 @@ public sealed class Options
/// Alpha = (# of observed entries) / (# of unobserved entries) can make observed and unobserved entries equally important
/// in the minimized loss function. However, the best setting in machine learning is alwasy data-depedent so user still needs to
/// try multiple values.
- ///
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Importance of unobserved entries' loss in one-class matrix factorization.")]
[TGUI(SuggestedSweeps = "1,0.01,0.0001,0.000001")]
[TlcModule.SweepableDiscreteParam("Alpha", new object[] { 1f, 0.01f, 0.0001f, 0.000001f })]
- public double Alpha = 0.0001;
+ public double Alpha = Defaults.Alpha;
///
- /// Desired negative entries value in one-class matrix factorization. In one-class matrix factorization, all matrix values observed are one
- /// (which can be viewed as positive cases in binary classification) while unobserved values (which can be viewed as negative cases in binary
- /// classification) need to be specified manually using this option.
+ /// Desired negative entries value in one-class matrix factorization. Applicable if set to
///
+ ///
+ /// In one-class matrix factorization, all matrix values observed are one (which can be viewed as positive cases in binary classification)
+ /// while unobserved values (which can be viewed as negative cases in binary classification) need to be specified manually using this option.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Desired negative entries' value in one-class matrix factorization")]
[TGUI(SuggestedSweeps = "0.000001,0,0001,0.01")]
[TlcModule.SweepableDiscreteParam("C", new object[] { 0.000001f, 0.0001f, 0.01f })]
- public double C = 0.000001f;
+ public double C = Defaults.C;
+ ///
+ /// Number of threads used during training. If unspecified, all available threads will be used.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of threads can be used in the training procedure.", ShortName = "t")]
public int? NumThreads;
+ ///
+ /// Suppress writing additional information to output.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Suppress writing additional information to output.")]
- public bool Quiet;
+ public bool Quiet = Defaults.Quiet;
+ ///
+ /// Force the factor matrices to be non-negative.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Force the factor matrices to be non-negative.", ShortName = "nn")]
- public bool NonNegative;
+ public bool NonNegative = Defaults.NonNegative;
};
+ [BestFriend]
+ internal static class Defaults
+ {
+ public const bool Quiet = false;
+ public const bool NonNegative = false;
+ public const double C = 0.000001f;
+ public const double Alpha = 0.0001f;
+ public const double LearningRate = 0.1;
+ public const int NumIterations = 20;
+ public const int ApproximationRank = 8;
+ public const double Lambda = 0.1;
+ public const LossFunctionType LossFunction = LossFunctionType.SquareLossRegression;
+ }
+
internal const string Summary = "From pairs of row/column indices and a value of a matrix, this trains a predictor capable of filling in unknown entries of the matrix, "
+ "using a low-rank matrix factorization. This technique is often used in recommender system, where the row and column indices indicate users and items, "
+ "and the values of the matrix are ratings. ";
@@ -197,7 +273,8 @@ public sealed class Options
private readonly bool _doNmf;
public override PredictionKind PredictionKind => PredictionKind.Recommendation;
- public const string LoadNameValue = "MatrixFactorization";
+
+ internal const string LoadNameValue = "MatrixFactorization";
///
/// The row index, column index, and label columns needed to specify the training matrix. This trainer uses tuples of (row index, column index, label value) to specify a matrix.
@@ -238,18 +315,18 @@ internal MatrixFactorizationTrainer(IHostEnvironment env, Options options) : bas
{
const string posError = "Parameter must be positive";
Host.CheckValue(options, nameof(options));
- Host.CheckUserArg(options.K > 0, nameof(options.K), posError);
+ Host.CheckUserArg(options.ApproximationRank > 0, nameof(options.ApproximationRank), posError);
Host.CheckUserArg(!options.NumThreads.HasValue || options.NumThreads > 0, nameof(options.NumThreads), posError);
Host.CheckUserArg(options.NumIterations > 0, nameof(options.NumIterations), posError);
Host.CheckUserArg(options.Lambda > 0, nameof(options.Lambda), posError);
- Host.CheckUserArg(options.Eta > 0, nameof(options.Eta), posError);
+ Host.CheckUserArg(options.LearningRate > 0, nameof(options.LearningRate), posError);
Host.CheckUserArg(options.Alpha > 0, nameof(options.Alpha), posError);
_fun = (int)options.LossFunction;
_lambda = options.Lambda;
- _k = options.K;
+ _k = options.ApproximationRank;
_iter = options.NumIterations;
- _eta = options.Eta;
+ _eta = options.LearningRate;
_alpha = options.Alpha;
_c = options.C;
_threads = options.NumThreads ?? Environment.ProcessorCount;
@@ -270,21 +347,27 @@ internal MatrixFactorizationTrainer(IHostEnvironment env, Options options) : bas
/// The name of the column hosting the matrix's column IDs.
/// The name of the column hosting the matrix's row IDs.
/// The name of the label column.
+ /// Rank of approximation matrices.
+ /// Initial learning rate. It specifies the speed of the training algorithm.
+ /// Number of training iterations.
[BestFriend]
internal MatrixFactorizationTrainer(IHostEnvironment env,
string matrixColumnIndexColumnName,
string matrixRowIndexColumnName,
- string labelColumn = DefaultColumnNames.Label)
+ string labelColumn = DefaultColumnNames.Label,
+ int approximationRank = Defaults.ApproximationRank,
+ double learningRate = Defaults.LearningRate,
+ int numIterations = Defaults.NumIterations)
: base(env, LoadNameValue)
{
var args = new Options();
_fun = (int)args.LossFunction;
- _lambda = args.Lambda;
- _k = args.K;
- _iter = args.NumIterations;
- _eta = args.Eta;
+ _k = approximationRank;
+ _iter = numIterations;
+ _eta = learningRate;
_alpha = args.Alpha;
+ _lambda = args.Lambda;
_c = args.C;
_threads = args.NumThreads ?? Environment.ProcessorCount;
_quiet = args.Quiet;
@@ -301,7 +384,7 @@ internal MatrixFactorizationTrainer(IHostEnvironment env,
/// Train a matrix factorization model based on training data, validation data, and so on in the given context.
///
/// The information collection needed for training. for details.
- private protected override MatrixFactorizationPredictor Train(TrainContext context)
+ private protected override MatrixFactorizationModelParameters Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
using (var ch = Host.Start("Training"))
@@ -310,7 +393,7 @@ private protected override MatrixFactorizationPredictor Train(TrainContext conte
}
}
- private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, RoleMappedData validData = null)
+ private MatrixFactorizationModelParameters TrainCore(IChannel ch, RoleMappedData data, RoleMappedData validData = null)
{
Host.AssertValue(ch);
ch.AssertValue(data);
@@ -321,7 +404,7 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data,
var labelCol = data.Schema.Label.Value;
if (labelCol.Type != NumberType.R4 && labelCol.Type != NumberType.R8)
throw ch.Except("Column '{0}' for label should be floating point, but is instead {1}", labelCol.Name, labelCol.Type);
- MatrixFactorizationPredictor predictor;
+ MatrixFactorizationModelParameters predictor;
if (validData != null)
{
ch.CheckValue(validData, nameof(validData));
@@ -362,7 +445,7 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data,
using (var buffer = PrepareBuffer())
{
buffer.Train(ch, rowCount, colCount, cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter);
- predictor = new MatrixFactorizationPredictor(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type);
+ predictor = new MatrixFactorizationModelParameters(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type);
}
}
else
@@ -380,7 +463,7 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data,
buffer.TrainWithValidation(ch, rowCount, colCount,
cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter,
validCursor, validLabelGetter, validMatrixRowIndexGetter, validMatrixColumnIndexGetter);
- predictor = new MatrixFactorizationPredictor(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type);
+ predictor = new MatrixFactorizationModelParameters(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type);
}
}
}
@@ -403,7 +486,7 @@ private SafeTrainingAndModelBuffer PrepareBuffer()
/// The validation data set.
public MatrixFactorizationPredictionTransformer Train(IDataView trainData, IDataView validationData = null)
{
- MatrixFactorizationPredictor model = null;
+ MatrixFactorizationModelParameters model = null;
var roles = new List>();
roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Label, LabelName));
@@ -427,6 +510,10 @@ public MatrixFactorizationPredictionTransformer Train(IDataView trainData, IData
/// The training data set.
public MatrixFactorizationPredictionTransformer Fit(IDataView input) => Train(input);
+ ///
+ /// Schema propagation for transformers. Returns the output schema of the data, if
+ /// the input schema is like the one provided.
+ ///
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
diff --git a/src/Microsoft.ML.Recommender/RecommenderCatalog.cs b/src/Microsoft.ML.Recommender/RecommenderCatalog.cs
index 448595ec41..637e4962ea 100644
--- a/src/Microsoft.ML.Recommender/RecommenderCatalog.cs
+++ b/src/Microsoft.ML.Recommender/RecommenderCatalog.cs
@@ -55,11 +55,24 @@ internal RecommendationTrainers(RecommendationCatalog catalog)
/// The name of the column hosting the matrix's column IDs.
/// The name of the column hosting the matrix's row IDs.
/// The name of the label column.
+ /// Rank of approximation matrices.
+ /// Initial learning rate. It specifies the speed of the training algorithm.
+ /// Number of training iterations.
+ ///
+ ///
+ ///
+ ///
public MatrixFactorizationTrainer MatrixFactorization(
string matrixColumnIndexColumnName,
string matrixRowIndexColumnName,
- string labelColumn = DefaultColumnNames.Label)
- => new MatrixFactorizationTrainer(Owner.Environment, matrixColumnIndexColumnName, matrixRowIndexColumnName, labelColumn);
+ string labelColumn = DefaultColumnNames.Label,
+ int approximationRank = MatrixFactorizationTrainer.Defaults.ApproximationRank,
+ double learningRate = MatrixFactorizationTrainer.Defaults.LearningRate,
+ int numIterations = MatrixFactorizationTrainer.Defaults.NumIterations)
+ => new MatrixFactorizationTrainer(Owner.Environment, matrixColumnIndexColumnName, matrixRowIndexColumnName, labelColumn,
+ approximationRank, learningRate, numIterations);
///
/// Train a matrix factorization model. It factorizes the training matrix into the product of two low-rank matrices.
@@ -74,7 +87,7 @@ public MatrixFactorizationTrainer MatrixFactorization(
///
///
///
///
public MatrixFactorizationTrainer MatrixFactorization(
@@ -115,13 +128,13 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa
/// If the is not provided, the random numbers generated to create it, will use this seed as value.
/// And if it is not provided, the default value will be used.
/// Per-fold results: metrics, models, scored datasets.
- public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+ public CrossValidationResult[] CrossValidate(
IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
string stratificationColumn = null, uint? seed = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
- return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+ return result.Select(x => new CrossValidationResult(x.Model, Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
}
}
}
diff --git a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs
index 2ce8ae224f..ba7f0b2fce 100644
--- a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs
+++ b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs
@@ -12,6 +12,7 @@
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Tools;
@@ -20,7 +21,7 @@
using Microsoft.ML.ExperimentVisualization;
#endif
-namespace Microsoft.ML.Internal.Internallearn.ResultProcessor
+namespace Microsoft.ML.ResultProcessor
{
using Float = System.Single;
///
diff --git a/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs b/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs
new file mode 100644
index 0000000000..83fafd8658
--- /dev/null
+++ b/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.SamplesUtils
+{
+ ///
+ /// Utilities for creating console outputs in samples' code.
+ ///
+ public static class ConsoleUtils
+ {
+ ///
+ /// Pretty-print BinaryClassificationMetrics objects.
+ ///
+ /// Binary classification metrics.
+ public static void PrintMetrics(BinaryClassificationMetrics metrics)
+ {
+ Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}");
+ Console.WriteLine($"AUC: {metrics.Auc:F2}");
+ Console.WriteLine($"F1 Score: {metrics.F1Score:F2}");
+ Console.WriteLine($"Negative Precision: {metrics.NegativePrecision:F2}");
+ Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}");
+ Console.WriteLine($"Positive Precision: {metrics.PositivePrecision:F2}");
+ Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}");
+ }
+ }
+}
diff --git a/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj b/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj
index e4d6c5d504..0bdb047d42 100644
--- a/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj
+++ b/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj
@@ -6,7 +6,9 @@
+
+
diff --git a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs
index edaf2d55c5..918cd26a9d 100644
--- a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs
+++ b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs
@@ -7,6 +7,7 @@
using System.IO;
using System.Net;
using Microsoft.Data.DataView;
+using Microsoft.ML;
using Microsoft.ML.Data;
namespace Microsoft.ML.SamplesUtils
@@ -17,7 +18,12 @@ public static class DatasetUtils
/// Downloads the housing dataset from the ML.NET repo.
///
public static string DownloadHousingRegressionDataset()
- => Download("https://raw.githubusercontent.com/dotnet/machinelearning/024bd4452e1d3660214c757237a19d6123f951ca/test/data/housing.txt", "housing.txt");
+ {
+ var fileName = "housing.txt";
+ if (!File.Exists(fileName))
+ Download("https://raw.githubusercontent.com/dotnet/machinelearning/024bd4452e1d3660214c757237a19d6123f951ca/test/data/housing.txt", fileName);
+ return fileName;
+ }
public static IDataView LoadHousingRegressionDataset(MLContext mlContext)
{
@@ -81,6 +87,57 @@ public static string DownloadSentimentDataset()
public static string DownloadAdultDataset()
=> Download("https://raw.githubusercontent.com/dotnet/machinelearning/244a8c2ac832657af282aa312d568211698790aa/test/data/adult.train", "adult.txt");
+ public static IDataView LoadFeaturizedAdultDataset(MLContext mlContext)
+ {
+ // Download the file
+ string dataFile = DownloadAdultDataset();
+
+ // Define the columns to read
+ var reader = mlContext.Data.CreateTextLoader(
+ columns: new[]
+ {
+ new TextLoader.Column("age", DataKind.R4, 0),
+ new TextLoader.Column("workclass", DataKind.TX, 1),
+ new TextLoader.Column("fnlwgt", DataKind.R4, 2),
+ new TextLoader.Column("education", DataKind.TX, 3),
+ new TextLoader.Column("education-num", DataKind.R4, 4),
+ new TextLoader.Column("marital-status", DataKind.TX, 5),
+ new TextLoader.Column("occupation", DataKind.TX, 6),
+ new TextLoader.Column("relationship", DataKind.TX, 7),
+ new TextLoader.Column("ethnicity", DataKind.TX, 8),
+ new TextLoader.Column("sex", DataKind.TX, 9),
+ new TextLoader.Column("capital-gain", DataKind.R4, 10),
+ new TextLoader.Column("capital-loss", DataKind.R4, 11),
+ new TextLoader.Column("hours-per-week", DataKind.R4, 12),
+ new TextLoader.Column("native-country", DataKind.R4, 13),
+ new TextLoader.Column("IsOver50K", DataKind.BL, 14),
+ },
+ separatorChar: ',',
+ hasHeader: true
+ );
+
+ // Create data featurizing pipeline
+ var pipeline =
+ // Convert categorical features to one-hot vectors
+ mlContext.Transforms.Categorical.OneHotEncoding("workclass")
+ .Append(mlContext.Transforms.Categorical.OneHotEncoding("education"))
+ .Append(mlContext.Transforms.Categorical.OneHotEncoding("marital-status"))
+ .Append(mlContext.Transforms.Categorical.OneHotEncoding("occupation"))
+ .Append(mlContext.Transforms.Categorical.OneHotEncoding("relationship"))
+ .Append(mlContext.Transforms.Categorical.OneHotEncoding("ethnicity"))
+ .Append(mlContext.Transforms.Categorical.OneHotEncoding("native-country"))
+ // Combine all features into one feature vector
+ .Append(mlContext.Transforms.Concatenate("Features", "workclass", "education", "marital-status",
+ "occupation", "relationship", "ethnicity", "native-country", "age", "education-num",
+ "capital-gain", "capital-loss", "hours-per-week"))
+ // Min-max normalized all the features
+ .Append(mlContext.Transforms.Normalize("Features"));
+
+ var data = reader.Read(dataFile);
+ var featurizedData = pipeline.Fit(data).Transform(data);
+ return featurizedData;
+ }
+
///
/// Downloads the breast cancer dataset from the ML.NET repo.
///
@@ -121,14 +178,14 @@ public static string DownloadTensorFlowSentimentModel()
string remotePath = "https://github.com/dotnet/machinelearning-testdata/raw/master/Microsoft.ML.TensorFlow.TestModels/sentiment_model/";
string path = "sentiment_model";
- if(!Directory.Exists(path))
+ if (!Directory.Exists(path))
Directory.CreateDirectory(path);
string varPath = Path.Combine(path, "variables");
if (!Directory.Exists(varPath))
Directory.CreateDirectory(varPath);
- Download(Path.Combine(remotePath, "saved_model.pb"), Path.Combine(path,"saved_model.pb"));
+ Download(Path.Combine(remotePath, "saved_model.pb"), Path.Combine(path, "saved_model.pb"));
Download(Path.Combine(remotePath, "imdb_word_index.csv"), Path.Combine(path, "imdb_word_index.csv"));
Download(Path.Combine(remotePath, "variables", "variables.data-00000-of-00001"), Path.Combine(varPath, "variables.data-00000-of-00001"));
Download(Path.Combine(remotePath, "variables", "variables.index"), Path.Combine(varPath, "variables.index"));
@@ -221,7 +278,7 @@ public static IEnumerable GetTopicsData()
public class SampleTemperatureData
{
- public DateTime Date {get; set; }
+ public DateTime Date { get; set; }
public float Temperature { get; set; }
}
@@ -374,7 +431,7 @@ public class BinaryLabelFloatFeatureVectorSample
public float[] Features;
}
- public static IEnumerable GenerateBinaryLabelFloatFeatureVectorSamples(int exampleCount)
+ public static IEnumerable GenerateBinaryLabelFloatFeatureVectorSamples(int exampleCount)
{
var rnd = new Random(0);
var data = new List();
@@ -405,7 +462,7 @@ public class FloatLabelFloatFeatureVectorSample
public float[] Features;
}
- public static IEnumerable GenerateFloatLabelFloatFeatureVectorSamples(int exampleCount, double naRate = 0)
+ public static IEnumerable GenerateFloatLabelFloatFeatureVectorSamples(int exampleCount, double naRate = 0)
{
var rnd = new Random(0);
var data = new List();
@@ -446,17 +503,20 @@ public class FfmExample
public float[] Field2;
}
- public static IEnumerable GenerateFfmSamples(int exampleCount)
+ public static IEnumerable GenerateFfmSamples(int exampleCount)
{
var rnd = new Random(0);
var data = new List();
for (int i = 0; i < exampleCount; ++i)
{
// Initialize an example with a random label and an empty feature vector.
- var sample = new FfmExample() { Label = rnd.Next() % 2 == 0,
+ var sample = new FfmExample()
+ {
+ Label = rnd.Next() % 2 == 0,
Field0 = new float[_simpleBinaryClassSampleFeatureLength],
Field1 = new float[_simpleBinaryClassSampleFeatureLength],
- Field2 = new float[_simpleBinaryClassSampleFeatureLength] };
+ Field2 = new float[_simpleBinaryClassSampleFeatureLength]
+ };
// Fill feature vector according the assigned label.
for (int j = 0; j < 10; ++j)
{
@@ -546,5 +606,49 @@ public static List GenerateRandomMulticlassClas
}
return examples;
}
+
+ // The following variables define the shape of a matrix. Its shape is _synthesizedMatrixRowCount-by-_synthesizedMatrixColumnCount.
+ // Because in ML.NET key type's minimal value is zero, the first row index is always zero in C# data structure (e.g., MatrixColumnIndex=0
+ // and MatrixRowIndex=0 in MatrixElement below specifies the value at the upper-left corner in the training matrix). If user's row index
+ // starts with 1, their row index 1 would be mapped to the 2nd row in matrix factorization module and their first row may contain no values.
+ // This behavior is also true to column index.
+ private const int _synthesizedMatrixFirstColumnIndex = 1;
+ private const int _synthesizedMatrixFirstRowIndex = 1;
+ private const int _synthesizedMatrixColumnCount = 60;
+ private const int _synthesizedMatrixRowCount = 100;
+
+ // A data structure used to encode a single value in matrix
+ public class MatrixElement
+ {
+ // Matrix column index is less than _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex.
+ [KeyType(Count = _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex)]
+ public uint MatrixColumnIndex;
+ // Matrix row index is less than _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex.
+ [KeyType(Count = _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex)]
+ public uint MatrixRowIndex;
+ // The value at the column MatrixColumnIndex and row MatrixRowIndex.
+ public float Value;
+ }
+
+ // A data structure used to encode prediction result. Compared with MatrixElement, the field Value in MatrixElement is
+ // renamed to Score because Score is the default name of matrix factorization's output.
+ public class MatrixElementForScore
+ {
+ [KeyType(Count = _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex)]
+ public uint MatrixColumnIndex;
+ [KeyType(Count = _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex)]
+ public uint MatrixRowIndex;
+ public float Score;
+ }
+
+ // Create an in-memory matrix as a list of tuples (column index, row index, value).
+ public static List GetRecommendationData()
+ {
+ var dataMatrix = new List();
+ for (uint i = _synthesizedMatrixFirstColumnIndex; i < _synthesizedMatrixFirstColumnIndex + _synthesizedMatrixColumnCount; ++i)
+ for (uint j = _synthesizedMatrixFirstRowIndex; j < _synthesizedMatrixFirstRowIndex + _synthesizedMatrixRowCount; ++j)
+ dataMatrix.Add(new MatrixElement() { MatrixColumnIndex = i, MatrixRowIndex = j, Value = (i + j) % 5 });
+ return dataMatrix;
+ }
}
}
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
index dcfe680c4d..ad81336f58 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
@@ -511,6 +511,10 @@ public FieldAwareFactorizationMachinePredictionTransformer Train(IDataView train
public FieldAwareFactorizationMachinePredictionTransformer Fit(IDataView input) => Train(input);
+ ///
+ /// Schema propagation for transformers. Returns the output schema of the data, if
+ /// the input schema is like the one provided.
+ ///
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
index 4bf005b25a..9aac84dae1 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
@@ -283,7 +283,7 @@ public float[] GetLatentWeights()
}
}
- public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase, ICanSaveModel
+ public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase
{
public const string LoaderSignature = "FAFMPredXfer";
@@ -387,7 +387,7 @@ public override Schema GetOutputSchema(Schema inputSchema)
/// Saves the transformer to file.
///
/// The that facilitates saving to the .
- public void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs
index 3a88c62128..df52e48433 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs
@@ -14,11 +14,11 @@
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Model;
using Microsoft.ML.Model.Onnx;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Numeric;
+using Microsoft.ML.Trainers;
using Newtonsoft.Json.Linq;
// This is for deserialization from a model repository.
@@ -36,7 +36,7 @@
"Poisson Regression Executor",
PoissonRegressionModelParameters.LoaderSignature)]
-namespace Microsoft.ML.Learners
+namespace Microsoft.ML.Trainers
{
public abstract class LinearModelParameters : ModelParametersBase,
IValueMapper,
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictorUtils.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictorUtils.cs
index f9f64f3689..b61f443e02 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictorUtils.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictorUtils.cs
@@ -13,7 +13,7 @@
using Microsoft.ML.Internal.Utilities;
using Float = System.Single;
-namespace Microsoft.ML.Learners
+namespace Microsoft.ML.Trainers
{
///
/// Helper methods for linear predictors
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
index c6c6602c52..ca335e033f 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
@@ -14,7 +14,7 @@
using Microsoft.ML.Numeric;
using Microsoft.ML.Training;
-namespace Microsoft.ML.Learners
+namespace Microsoft.ML.Trainers
{
public abstract class LbfgsTrainerBase : TrainerEstimatorBase
where TTransformer : ISingleFeaturePredictionTransformer
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
index 3cdb24c7df..d5fe66bdd7 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
@@ -13,8 +13,8 @@
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Numeric;
+using Microsoft.ML.Trainers;
using Microsoft.ML.Training;
[assembly: LoadableClass(LogisticRegression.Summary, typeof(LogisticRegression), typeof(LogisticRegression.Options),
@@ -26,7 +26,7 @@
[assembly: LoadableClass(typeof(void), typeof(LogisticRegression), null, typeof(SignatureEntryPointModule), LogisticRegression.LoadNameValue)]
-namespace Microsoft.ML.Learners
+namespace Microsoft.ML.Trainers
{
///
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
index 0ed3f008da..9d809dd3fa 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
@@ -14,7 +14,6 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Model;
using Microsoft.ML.Model.Onnx;
using Microsoft.ML.Model.Pfa;
@@ -35,7 +34,7 @@
"Multiclass LR Executor",
MulticlassLogisticRegressionModelParameters.LoaderSignature)]
-namespace Microsoft.ML.Learners
+namespace Microsoft.ML.Trainers
{
///
///
diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs
index d2efdc8d57..e5c2854da7 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs
@@ -12,15 +12,15 @@
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Model;
+using Microsoft.ML.Trainers;
// This is for deserialization from a model repository.
[assembly: LoadableClass(typeof(LinearModelStatistics), null, typeof(SignatureLoadModel),
"Linear Model Statistics",
LinearModelStatistics.LoaderSignature)]
-namespace Microsoft.ML.Learners
+namespace Microsoft.ML.Trainers
{
///
/// Represents a coefficient statistics object.
@@ -166,7 +166,7 @@ internal static LinearModelStatistics Create(IHostEnvironment env, ModelLoadCont
return new LinearModelStatistics(env, ctx);
}
- public void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx)
{
Contracts.AssertValue(_env);
_env.CheckValue(ctx, nameof(ctx));
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
index fdd32f2f20..448a078a8b 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
@@ -14,7 +14,7 @@
using Microsoft.ML.Trainers.Online;
using Microsoft.ML.Training;
-namespace Microsoft.ML.Learners
+namespace Microsoft.ML.Trainers
{
using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
index cf7a083985..53f0f35fca 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
@@ -16,7 +16,6 @@
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Model;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Trainers;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
index 5e3dd33f8b..344f9e5f6a 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
@@ -11,7 +11,6 @@
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Model;
using Microsoft.ML.Trainers;
using Microsoft.ML.Training;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs
index bf7a88d27d..9234cd2df4 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs
@@ -9,7 +9,6 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Numeric;
// TODO: Check if it works properly if Averaged is set to false
@@ -18,38 +17,89 @@ namespace Microsoft.ML.Trainers.Online
{
public abstract class AveragedLinearArguments : OnlineLinearArguments
{
+ ///
+ /// Learning rate.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", SortOrder = 50)]
[TGUI(Label = "Learning rate", SuggestedSweeps = "0.01,0.1,0.5,1.0")]
[TlcModule.SweepableDiscreteParam("LearningRate", new object[] { 0.01, 0.1, 0.5, 1.0 })]
public float LearningRate = AveragedDefaultArgs.LearningRate;
+ ///
+ /// Determine whether to decrease the or not.
+ ///
+ ///
+ /// to decrease the as iterations progress; otherwise, .
+ /// Default is .
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Decrease learning rate", ShortName = "decreaselr", SortOrder = 50)]
[TGUI(Label = "Decrease Learning Rate", Description = "Decrease learning rate as iterations progress")]
[TlcModule.SweepableDiscreteParam("DecreaseLearningRate", new object[] { false, true })]
public bool DecreaseLearningRate = AveragedDefaultArgs.DecreaseLearningRate;
+ ///
+ /// Number of examples after which weights will be reset to the current average.
+ ///
+ ///
+ /// Default is , which disables this feature.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of examples after which weights will be reset to the current average", ShortName = "numreset")]
public long? ResetWeightsAfterXExamples = null;
+ ///
+ /// Determines when to update averaged weights.
+ ///
+ ///
+ /// to update averaged weights only when loss is nonzero.
+ /// to update averaged weights on every example.
+ /// Default is .
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Instead of updating averaged weights on every example, only update when loss is nonzero", ShortName = "lazy")]
public bool DoLazyUpdates = true;
+ ///
+ /// L2 weight for regularization.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)]
[TGUI(Label = "L2 Regularization Weight")]
[TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)]
public float L2RegularizerWeight = AveragedDefaultArgs.L2RegularizerWeight;
+ ///
+ /// Extra weight given to more recent updates.
+ ///
+ ///
+ /// Default is 0, i.e. no extra gain.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Extra weight given to more recent updates", ShortName = "rg")]
public float RecencyGain = 0;
+ ///
+ /// Determines whether is multiplicative or additive.
+ ///
+ ///
+ /// means is multiplicative.
+ /// means is additive.
+ /// Default is .
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether Recency Gain is multiplicative (vs. additive)", ShortName = "rgm")]
public bool RecencyGainMulti = false;
+ ///
+ /// Determines whether to do averaging or not.
+ ///
+ ///
+ /// to do averaging; otherwise, .
+ /// Default is .
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Do averaging?", ShortName = "avg")]
public bool Averaged = true;
+ ///
+ /// The inexactness tolerance for averaging.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "The inexactness tolerance for averaging", ShortName = "avgtol")]
- public float AveragedTolerance = (float)1e-2;
+ internal float AveragedTolerance = (float)1e-2;
[BestFriend]
internal class AveragedDefaultArgs : OnlineDefaultArgs
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
index b3659974fa..d3dbdf619e 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
@@ -11,7 +11,6 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Internallearn;
-using Microsoft.ML.Learners;
using Microsoft.ML.Numeric;
using Microsoft.ML.Trainers.Online;
using Microsoft.ML.Training;
@@ -25,12 +24,12 @@
namespace Microsoft.ML.Trainers.Online
{
- // This is an averaged perceptron classifier.
- // Configurable subcomponents:
- // - Loss function. By default, hinge loss (aka max-margin avgd perceptron)
- // - Feature normalization. By default, rescaling between min and max values for every feature
- // - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration.
- ///
+ ///
+ /// This is the averaged perceptron trainer.
+ ///
+ ///
+ /// For usage details, please see
+ ///
public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer, LinearBinaryModelParameters>
{
public const string LoadNameValue = "AveragedPerceptron";
@@ -42,12 +41,21 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer
+ /// The custom loss.
+ ///
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
public ISupportClassificationLossFactory LossFunction = new HingeLoss.Arguments();
+ ///
+ /// The calibrator for producing probabilities. Default is exponential (aka Platt) calibration.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory();
+ ///
+ /// The maximum number of examples to use when training the calibrator.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public int MaxCalibrationExamples = 1000000;
@@ -101,9 +109,9 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, Options options)
/// The name of the feature column.
/// The optional name of the weights column.
/// The learning rate.
- /// Wheather to decrease learning rate as iterations progress.
+ /// Whether to decrease learning rate as iterations progress.
/// L2 Regularization Weight.
- /// The number of training iteraitons.
+ /// The number of training iterations.
internal AveragedPerceptronTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
@@ -121,7 +129,7 @@ internal AveragedPerceptronTrainer(IHostEnvironment env,
LearningRate = learningRate,
DecreaseLearningRate = decreaseLearningRate,
L2RegularizerWeight = l2RegularizerWeight,
- NumIterations = numIterations,
+ NumberOfIterations = numIterations,
LossFunction = new TrivialFactory(lossFunction ?? new HingeLoss())
})
{
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
index 184a6554aa..2cd75e623c 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
@@ -12,7 +12,6 @@
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Numeric;
using Microsoft.ML.Trainers.Online;
using Microsoft.ML.Training;
@@ -240,7 +239,7 @@ internal LinearSvmTrainer(IHostEnvironment env,
LabelColumn = labelColumn,
FeatureColumn = featureColumn,
InitialWeights = weightsColumn,
- NumIterations = numIterations,
+ NumberOfIterations = numIterations,
})
{
}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs
index 39ed31a6ec..b683aefb07 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs
@@ -10,7 +10,6 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
-using Microsoft.ML.Learners;
using Microsoft.ML.Numeric;
using Microsoft.ML.Trainers.Online;
using Microsoft.ML.Training;
@@ -115,7 +114,7 @@ internal OnlineGradientDescentTrainer(IHostEnvironment env,
LearningRate = learningRate,
DecreaseLearningRate = decreaseLearningRate,
L2RegularizerWeight = l2RegularizerWeight,
- NumIterations = numIterations,
+ NumberOfIterations = numIterations,
LabelColumn = labelColumn,
FeatureColumn = featureColumn,
InitialWeights = weightsColumn,
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
index 662afe9a01..4dbdc8da57 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
@@ -11,7 +11,6 @@
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Numeric;
using Microsoft.ML.Training;
@@ -20,27 +19,41 @@ namespace Microsoft.ML.Trainers.Online
public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
{
+ ///
+ /// Number of training iterations through the data.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter", SortOrder = 50)]
[TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")]
[TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)]
- public int NumIterations = OnlineDefaultArgs.NumIterations;
+ public int NumberOfIterations = OnlineDefaultArgs.NumIterations;
+ ///
+ /// Initial weights and bias, comma-separated.
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Initial Weights and bias, comma-separated", ShortName = "initweights")]
[TGUI(NoSweep = true)]
public string InitialWeights;
- [Argument(ArgumentType.AtMostOnce, HelpText = "Init weights diameter", ShortName = "initwts", SortOrder = 140)]
+ ///
+ /// Initial weights and bias scale.
+ ///
+ ///
+ /// This property is only used if the provided value is positive and is not specified.
+ /// The weights and bias will be randomly selected from InitialWeights * [-0.5,0.5] interval with uniform distribution.
+ ///
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Init weights diameter", ShortName = "initwts, initWtsDiameter", SortOrder = 140)]
[TGUI(Label = "Initial Weights Scale", SuggestedSweeps = "0,0.1,0.5,1")]
[TlcModule.SweepableFloatParamAttribute("InitWtsDiameter", 0.0f, 1.0f, numSteps: 5)]
- public float InitWtsDiameter = 0;
+ public float InitialWeightsDiameter = 0;
+ ///
+ /// to shuffle data for each training iteration; otherwise, .
+ /// Default is .
+ ///
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to shuffle for each training iteration", ShortName = "shuf")]
[TlcModule.SweepableDiscreteParamAttribute("Shuffle", new object[] { false, true })]
public bool Shuffle = true;
- [Argument(ArgumentType.AtMostOnce, HelpText = "Size of cache when trained in Scope", ShortName = "cache")]
- public int StreamingCacheSize = 1000000;
-
[BestFriend]
internal class OnlineDefaultArgs
{
@@ -135,13 +148,13 @@ protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters pre
Weights = new VBuffer(numFeatures, weightValues);
Bias = float.Parse(weightStr[numFeatures], CultureInfo.InvariantCulture);
}
- else if (parent.Args.InitWtsDiameter > 0)
+ else if (parent.Args.InitialWeightsDiameter > 0)
{
var weightValues = new float[numFeatures];
for (int i = 0; i < numFeatures; i++)
- weightValues[i] = parent.Args.InitWtsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
+ weightValues[i] = parent.Args.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
Weights = new VBuffer(numFeatures, weightValues);
- Bias = parent.Args.InitWtsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
+ Bias = parent.Args.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
}
else if (numFeatures <= 1000)
Weights = VBufferUtils.CreateDense(numFeatures);
@@ -239,9 +252,8 @@ private protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironme
: base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights))
{
Contracts.CheckValue(args, nameof(args));
- Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive);
- Contracts.CheckUserArg(args.InitWtsDiameter >= 0, nameof(args.InitWtsDiameter), UserErrorNonNegative);
- Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive);
+ Contracts.CheckUserArg(args.NumberOfIterations > 0, nameof(args.NumberOfIterations), UserErrorPositive);
+ Contracts.CheckUserArg(args.InitialWeightsDiameter >= 0, nameof(args.InitialWeightsDiameter), UserErrorNonNegative);
Args = args;
Name = name;
@@ -300,7 +312,7 @@ private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state)
var cursorFactory = new FloatLabelCursor.Factory(data, cursorOpt);
long numBad = 0;
- while (state.Iteration < Args.NumIterations)
+ while (state.Iteration < Args.NumberOfIterations)
{
state.BeginIteration(ch);
@@ -318,7 +330,7 @@ private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state)
{
ch.Warning(
"Skipped {0} instances with missing features during training (over {1} iterations; {2} inst/iter)",
- numBad, Args.NumIterations, numBad / Args.NumIterations);
+ numBad, Args.NumberOfIterations, numBad / Args.NumberOfIterations);
}
}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml b/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml
index 8e8f5dc2ba..292aeface5 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml
@@ -25,44 +25,5 @@
-
-
-
- Averaged Perceptron Binary Classifier.
-
-
- Perceptron is a classification algorithm that makes its predictions based on a linear function.
- I.e., for an instance with feature values f0, f1,..., f_D-1, , the prediction is given by the sign of sigma[0,D-1] ( w_i * f_i), where w_0, w_1,...,w_D-1 are the weights computed by the algorithm.
-
- Perceptron is an online algorithm, i.e., it processes the instances in the training set one at a time.
- The weights are initialized to be 0, or some random values. Then, for each example in the training set, the value of sigma[0, D-1] (w_i * f_i) is computed.
- If this value has the same sign as the label of the current example, the weights remain the same. If they have opposite signs,
- the weights vector is updated by either subtracting or adding (if the label is negative or positive, respectively) the feature vector of the current example,
- multiplied by a factor 0 < a <= 1, called the learning rate. In a generalization of this algorithm, the weights are updated by adding the feature vector multiplied by the learning rate,
- and by the gradient of some loss function (in the specific case described above, the loss is hinge-loss, whose gradient is 1 when it is non-zero).
-
-
- In Averaged Perceptron (AKA voted-perceptron), the weight vectors are stored,
- together with a weight that counts the number of iterations it survived (this is equivalent to storing the weight vector after every iteration, regardless of whether it was updated or not).
- The prediction is then calculated by taking the weighted average of all the sums sigma[0, D-1] (w_i * f_i) or the different weight vectors.
-
- For more information see:
- Wikipedia entry for Perceptron
- Large Margin Classification Using the Perceptron Algorithm
-
-
-
-
-
- new AveragedPerceptronBinaryClassifier
- {
- NumIterations = 10,
- L2RegularizerWeight = 0.01f,
- LossFunction = new ExpLossClassificationLossFunction()
- }
-
-
-
-
diff --git a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
index caf4a37166..b389fc6034 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
@@ -10,7 +10,6 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Numeric;
using Microsoft.ML.Trainers;
using Microsoft.ML.Training;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
index 191c906b78..ca07f7fbe5 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
@@ -18,7 +18,6 @@
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Numeric;
using Microsoft.ML.Trainers;
using Microsoft.ML.Training;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
index aceabb7b82..ff49788028 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
@@ -14,7 +14,6 @@
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Numeric;
using Microsoft.ML.Trainers;
using Microsoft.ML.Training;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
index 49b65abe89..9524c48365 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
@@ -12,7 +12,6 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
-using Microsoft.ML.Learners;
using Microsoft.ML.Trainers;
using Microsoft.ML.Training;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs
index 10fa3b75c8..d9265513f8 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs
@@ -10,7 +10,7 @@
using Microsoft.ML.Training;
using Microsoft.ML.Transforms;
-namespace Microsoft.ML.Learners
+namespace Microsoft.ML.Trainers
{
public abstract class StochasticTrainerBase : TrainerEstimatorBase
where TTransformer : ISingleFeaturePredictionTransformer
diff --git a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs
index 6442e4eb01..d78464acde 100644
--- a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs
+++ b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs
@@ -5,7 +5,6 @@
using System;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Calibration;
-using Microsoft.ML.Learners;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.Online;
using Microsoft.ML.Training;
@@ -191,17 +190,46 @@ public static SdcaMultiClassTrainer StochasticDualCoordinateAscent(this Multicla
}
///
- /// Predict a target using a linear binary classification model trained with the AveragedPerceptron trainer.
+ /// Predict a target using a linear binary classification model trained with the averaged perceptron trainer.
///
+ ///
+ /// Perceptron is a classification algorithm that makes its predictions by finding a separating hyperplane.
+ /// For instance, with feature values f0, f1,..., f_D-1, the prediction is given by determining what side of the hyperplane the point falls into.
+ /// That is the same as the sign of sigma[0, D-1] (w_i * f_i), where w_0, w_1,..., w_D-1 are the weights computed by the algorithm.
+ ///
+ /// The perceptron is an online algorithm, which means it processes the instances in the training set one at a time.
+ /// It starts with a set of initial weights (zero, random, or initialized from a previous learner). Then, for each example in the training set, the weighted sum of the features (sigma[0, D-1] (w_i * f_i)) is computed.
+ /// If this value has the same sign as the label of the current example, the weights remain the same. If they have opposite signs,
+ /// the weights vector is updated by either subtracting or adding (if the label is negative or positive, respectively) the feature vector of the current example,
+ /// multiplied by a factor 0 < a <= 1, called the learning rate. In a generalization of this algorithm, the weights are updated by adding the feature vector multiplied by the learning rate,
+ /// and by the gradient of some loss function (in the specific case described above, the loss is hinge-loss, whose gradient is 1 when it is non-zero).
+ ///
+ /// In Averaged Perceptron (AKA voted-perceptron), the weight vectors are stored,
+ /// together with a weight that counts the number of iterations it survived (this is equivalent to storing the weight vector after every iteration, regardless of whether it was updated or not).
+ /// The prediction is then calculated by taking the weighted average of all the sums sigma[0, D-1] (w_i * f_i) or the different weight vectors.
+ ///
+ /// For more information see Wikipedia entry for Perceptron
+ /// or Large Margin Classification Using the Perceptron Algorithm
+ ///
/// The binary classification catalog trainer object.
/// The name of the label column, or dependent variable.
/// The features, or independent variables.
- /// The custom loss.
+ /// The custom loss. If , hinge loss will be used resulting in max-margin averaged perceptron.
/// The optional example weights.
- /// The learning Rate.
- /// Decrease learning rate as iterations progress.
- /// L2 regularization weight.
+ /// Learning rate.
+ ///
+ /// to decrease the as iterations progress; otherwise, .
+ /// Default is .
+ ///
+ /// L2 weight for regularization.
/// Number of training iterations through the data.
+ ///
+ ///
+ ///
+ ///
+ ///
public static AveragedPerceptronTrainer AveragedPerceptron(
this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
@@ -220,10 +248,18 @@ public static AveragedPerceptronTrainer AveragedPerceptron(
}
///
- /// Predict a target using a linear binary classification model trained with the AveragedPerceptron trainer.
+ /// Predict a target using a linear binary classification model trained with the averaged perceptron trainer using advanced options.
+ /// For usage details, please see
///
/// The binary classification catalog trainer object.
- /// Advanced arguments to the algorithm.
+ /// Trainer options.
+ ///
+ ///
+ ///
+ ///
+ ///
public static AveragedPerceptronTrainer AveragedPerceptron(
this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, AveragedPerceptronTrainer.Options options)
{
@@ -293,7 +329,7 @@ public static OnlineGradientDescentTrainer OnlineGradientDescent(this Regression
}
///
- /// Predict a target using a linear binary classification model trained with the trainer.
+ /// Predict a target using a linear binary classification model trained with the trainer.
///
/// The binary classificaiton catalog trainer object.
/// The label column name, or dependent variable.
@@ -302,7 +338,7 @@ public static OnlineGradientDescentTrainer OnlineGradientDescent(this Regression
/// Enforce non-negative weights.
/// Weight of L1 regularization term.
/// Weight of L2 regularization term.
- /// Memory size for . Low=faster, less accurate.
+ /// Memory size for . Low=faster, less accurate.
/// Threshold for optimizer convergence.
///
///
@@ -327,7 +363,7 @@ public static LogisticRegression LogisticRegression(this BinaryClassificationCat
}
///
- /// Predict a target using a linear binary classification model trained with the trainer.
+ /// Predict a target using a linear binary classification model trained with the trainer.
///
/// The binary classificaiton catalog trainer object.
/// Advanced arguments to the algorithm.
@@ -341,7 +377,7 @@ public static LogisticRegression LogisticRegression(this BinaryClassificationCat
}
///
- /// Predict a target using a linear regression model trained with the trainer.
+ /// Predict a target using a linear regression model trained with the trainer.
///
/// The regression catalog trainer object.
/// The labelColumn, or dependent variable.
@@ -350,7 +386,7 @@ public static LogisticRegression LogisticRegression(this BinaryClassificationCat
/// Weight of L1 regularization term.
/// Weight of L2 regularization term.
/// Threshold for optimizer convergence.
- /// Memory size for . Low=faster, less accurate.
+ /// Memory size for . Low=faster, less accurate.
/// Enforce non-negative weights.
public static PoissonRegression PoissonRegression(this RegressionCatalog.RegressionTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
@@ -368,7 +404,7 @@ public static PoissonRegression PoissonRegression(this RegressionCatalog.Regress
}
///
- /// Predict a target using a linear regression model trained with the trainer.
+ /// Predict a target using a linear regression model trained with the trainer.
///
/// The regression catalog trainer object.
/// Advanced arguments to the algorithm.
@@ -382,7 +418,7 @@ public static PoissonRegression PoissonRegression(this RegressionCatalog.Regress
}
///
- /// Predict a target using a linear multiclass classification model trained with the trainer.
+ /// Predict a target using a linear multiclass classification model trained with the trainer.
///
/// The .
/// The labelColumn, or dependent variable.
@@ -391,7 +427,7 @@ public static PoissonRegression PoissonRegression(this RegressionCatalog.Regress
/// Enforce non-negative weights.
/// Weight of L1 regularization term.
/// Weight of L2 regularization term.
- /// Memory size for . Low=faster, less accurate.
+ /// Memory size for . Low=faster, less accurate.
/// Threshold for optimizer convergence.
public static MulticlassLogisticRegression LogisticRegression(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
@@ -409,7 +445,7 @@ public static MulticlassLogisticRegression LogisticRegression(this MulticlassCla
}
///
- /// Predict a target using a linear multiclass classification model trained with the trainer.
+ /// Predict a target using a linear multiclass classification model trained with the trainer.
///
/// The .
/// Advanced arguments to the algorithm.
diff --git a/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs b/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs
index 5de22b6e89..7e2ed821c7 100644
--- a/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs
+++ b/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs
@@ -6,7 +6,6 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Calibration;
-using Microsoft.ML.Learners;
using Microsoft.ML.StaticPipe.Runtime;
using Microsoft.ML.Trainers;
@@ -20,7 +19,7 @@ namespace Microsoft.ML.StaticPipe
public static class LbfgsBinaryClassificationStaticExtensions
{
///
- /// Predict a target using a linear binary classification model trained with the trainer.
+ /// Predict a target using a linear binary classification model trained with the trainer.
///
/// The binary classificaiton catalog trainer object.
/// The label, or dependent variable.
@@ -29,7 +28,7 @@ public static class LbfgsBinaryClassificationStaticExtensions
/// Enforce non-negative weights.
/// Weight of L1 regularization term.
/// Weight of L2 regularization term.
- /// Memory size for . Low=faster, less accurate.
+ /// Memory size for . Low=faster, less accurate.
/// Threshold for optimizer convergence.
/// A delegate that is called every time the
/// method is called on the
@@ -66,7 +65,7 @@ public static (Scalar score, Scalar probability, Scalar pred
}
///
- /// Predict a target using a linear binary classification model trained with the trainer.
+ /// Predict a target using a linear binary classification model trained with the trainer.
///
/// The binary classificaiton catalog trainer object.
/// The label, or dependent variable.
@@ -116,7 +115,7 @@ public static (Scalar score, Scalar probability, Scalar pred
public static class LbfgsRegressionExtensions
{
///
- /// Predict a target using a linear regression model trained with the trainer.
+ /// Predict a target using a linear regression model trained with the trainer.
///
/// The regression catalog trainer object.
/// The label, or dependent variable.
@@ -125,7 +124,7 @@ public static class LbfgsRegressionExtensions
/// Enforce non-negative weights.
/// Weight of L1 regularization term.
/// Weight of L2 regularization term.
- /// Memory size for . Low=faster, less accurate.
+ /// Memory size for . Low=faster, less accurate.
/// Threshold for optimizer convergence.
/// A delegate that is called every time the
/// method is called on the
@@ -162,7 +161,7 @@ public static Scalar PoissonRegression(this RegressionCatalog.RegressionT
}
///
- /// Predict a target using a linear regression model trained with the trainer.
+ /// Predict a target using a linear regression model trained with the trainer.
///
/// The regression catalog trainer object.
/// The label, or dependent variable.
@@ -212,7 +211,7 @@ public static Scalar PoissonRegression(this RegressionCatalog.RegressionT
public static class LbfgsMulticlassExtensions
{
///
- /// Predict a target using a linear multiclass classification model trained with the trainer.
+ /// Predict a target using a linear multiclass classification model trained with the trainer.
///
/// The multiclass classification catalog trainer object.
/// The label, or dependent variable.
@@ -221,7 +220,7 @@ public static class LbfgsMulticlassExtensions
/// Enforce non-negative weights.
/// Weight of L1 regularization term.
/// Weight of L2 regularization term.
- /// Memory size for . Low=faster, less accurate.
+ /// Memory size for . Low=faster, less accurate.
/// Threshold for optimizer convergence.
/// A delegate that is called every time the
/// method is called on the
@@ -258,7 +257,7 @@ public static (Vector score, Key predictedLabel)
}
///
- /// Predict a target using a linear multiclass classification model trained with the trainer.
+ /// Predict a target using a linear multiclass classification model trained with the trainer.
///
/// The multiclass classification catalog trainer object.
/// The label, or dependent variable.
diff --git a/src/Microsoft.ML.StaticPipe/MatrixFactorizationStatic.cs b/src/Microsoft.ML.StaticPipe/MatrixFactorizationStatic.cs
index 0ca0f30854..2ecdf6e438 100644
--- a/src/Microsoft.ML.StaticPipe/MatrixFactorizationStatic.cs
+++ b/src/Microsoft.ML.StaticPipe/MatrixFactorizationStatic.cs
@@ -32,7 +32,7 @@ public static class MatrixFactorizationExtensions
public static Scalar MatrixFactorization(this RegressionCatalog.RegressionTrainers catalog,
Scalar label, Key matrixColumnIndex, Key matrixRowIndex,
MatrixFactorizationTrainer.Options options,
- Action onFit = null)
+ Action onFit = null)
{
Contracts.CheckValue(label, nameof(label));
Contracts.CheckValue(matrixColumnIndex, nameof(matrixColumnIndex));
diff --git a/src/Microsoft.ML.StaticPipe/OnlineLearnerStatic.cs b/src/Microsoft.ML.StaticPipe/OnlineLearnerStatic.cs
index fea4867428..38165451a6 100644
--- a/src/Microsoft.ML.StaticPipe/OnlineLearnerStatic.cs
+++ b/src/Microsoft.ML.StaticPipe/OnlineLearnerStatic.cs
@@ -3,8 +3,8 @@
// See the LICENSE file in the project root for more information.
using System;
-using Microsoft.ML.Learners;
using Microsoft.ML.StaticPipe.Runtime;
+using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.Online;
namespace Microsoft.ML.StaticPipe
diff --git a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs
index 444a9fd7bc..62bc7f5e48 100644
--- a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs
+++ b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs
@@ -4,7 +4,6 @@
using System;
using Microsoft.ML.Internal.Calibration;
-using Microsoft.ML.Learners;
using Microsoft.ML.StaticPipe.Runtime;
using Microsoft.ML.Trainers;
diff --git a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs
index 357246a658..53992cf4ff 100644
--- a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs
+++ b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs
@@ -49,8 +49,8 @@ public static (DataView trainSet, DataView testSet) TrainTestSplit(this
stratName = indexer.Get(column);
}
- var (trainData, testData) = catalog.TrainTestSplit(data.AsDynamic, testFraction, stratName, seed);
- return (new DataView(env, trainData, data.Shape), new DataView(env, testData, data.Shape));
+ var split = catalog.TrainTestSplit(data.AsDynamic, testFraction, stratName, seed);
+ return (new DataView(env, split.TrainSet, data.Shape), new DataView(env, split.TestSet, data.Shape));
}
///
@@ -105,9 +105,9 @@ public static (RegressionMetrics metrics, Transformer (
- x.metrics,
- new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape),
- new DataView(env, x.scoredTestData, estimator.Shape)))
+ x.Metrics,
+ new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape),
+ new DataView(env, x.ScoredHoldOutSet, estimator.Shape)))
.ToArray();
}
@@ -163,9 +163,9 @@ public static (MultiClassClassifierMetrics metrics, Transformer (
- x.metrics,
- new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape),
- new DataView(env, x.scoredTestData, estimator.Shape)))
+ x.Metrics,
+ new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape),
+ new DataView(env, x.ScoredHoldOutSet, estimator.Shape)))
.ToArray();
}
@@ -221,9 +221,9 @@ public static (BinaryClassificationMetrics metrics, Transformer (
- x.metrics,
- new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape),
- new DataView(env, x.scoredTestData, estimator.Shape)))
+ x.Metrics,
+ new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape),
+ new DataView(env, x.ScoredHoldOutSet, estimator.Shape)))
.ToArray();
}
@@ -279,9 +279,9 @@ public static (CalibratedBinaryClassificationMetrics metrics, Transformer (
- x.metrics,
- new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape),
- new DataView(env, x.scoredTestData, estimator.Shape)))
+ x.Metrics,
+ new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape),
+ new DataView(env, x.ScoredHoldOutSet, estimator.Shape)))
.ToArray();
}
}
diff --git a/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs
index 7e0ca2c035..469988ec7c 100644
--- a/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs
+++ b/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs
@@ -9,7 +9,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Sweeper.Algorithms;
-using Microsoft.ML.Trainers.FastTree.Internal;
+using Microsoft.ML.Trainers.FastTree;
using Float = System.Single;
[assembly: LoadableClass(typeof(KdoSweeper), typeof(KdoSweeper.Arguments), typeof(SignatureSweeper),
diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
index 6c1738b297..8a81844eb0 100644
--- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
+++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
@@ -13,7 +13,6 @@
using Microsoft.ML.Sweeper;
using Microsoft.ML.Sweeper.Algorithms;
using Microsoft.ML.Trainers.FastTree;
-using Microsoft.ML.Trainers.FastTree.Internal;
using Float = System.Single;
[assembly: LoadableClass(typeof(SmacSweeper), typeof(SmacSweeper.Arguments), typeof(SignatureSweeper),
diff --git a/src/Microsoft.ML.Sweeper/ConfigRunner.cs b/src/Microsoft.ML.Sweeper/ConfigRunner.cs
index 082a3fb4c9..cbe2a99900 100644
--- a/src/Microsoft.ML.Sweeper/ConfigRunner.cs
+++ b/src/Microsoft.ML.Sweeper/ConfigRunner.cs
@@ -12,7 +12,7 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Sweeper;
-using ResultProcessorInternal = Microsoft.ML.Internal.Internallearn.ResultProcessor;
+using ResultProcessorInternal = Microsoft.ML.ResultProcessor;
[assembly: LoadableClass(typeof(LocalExeConfigRunner), typeof(LocalExeConfigRunner.Arguments), typeof(SignatureConfigRunner),
"Local Sweep Config Runner", "Local")]
diff --git a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
index 8b03b89cb9..a55f8d4647 100644
--- a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
+++ b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
@@ -8,7 +8,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Sweeper;
-using ResultProcessor = Microsoft.ML.Internal.Internallearn.ResultProcessor;
+using ResultProcessor = Microsoft.ML.ResultProcessor;
[assembly: LoadableClass(typeof(InternalSweepResultEvaluator), typeof(InternalSweepResultEvaluator.Arguments), typeof(SignatureSweepResultEvaluator),
"TLC Sweep Result Evaluator", "TlcEvaluator", "Tlc")]
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index bef2bf1b1f..cfe041445e 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -667,15 +667,6 @@ internal static (TFDataType[] tfInputTypes, TFShape[] tfInputShapes) GetInputInf
var tfInput = new TFOutput(session.Graph[inputs[i]]);
tfInputTypes[i] = tfInput.OutputType;
tfInputShapes[i] = session.Graph.GetTensorShape(tfInput);
- if (tfInputShapes[i].NumDimensions != -1)
- {
- var newShape = new long[tfInputShapes[i].NumDimensions];
- newShape[0] = tfInputShapes[i][0] == -1 ? BatchSize : tfInputShapes[i][0];
-
- for (int j = 1; j < tfInputShapes[i].NumDimensions; j++)
- newShape[j] = tfInputShapes[i][j];
- tfInputShapes[i] = new TFShape(newShape);
- }
}
return (tfInputTypes, tfInputShapes);
}
@@ -698,7 +689,14 @@ internal static (TFDataType[] tfOutputTypes, ColumnType[] outputTypes) GetOutput
{
var tfOutput = new TFOutput(session.Graph[outputs[i]]);
var shape = session.Graph.GetTensorShape(tfOutput);
+
+ // The transformer can only retrieve the output as a fixed length vector with shape of kind [-1, d1, d2, d3, ...]
+ // i.e. the first dimension (if unknown) is assumed to be batch dimension.
+ // If there are other dimensions that are unknown, the transformer will return a variable length vector.
+ // This is the work around in absence of reshape transformer.
int[] dims = shape.NumDimensions > 0 ? shape.ToIntArray().Skip(shape[0] == -1 ? 1 : 0).ToArray() : new[] { 0 };
+ for (int j = 0; j < dims.Length; j++)
+ dims[j] = dims[j] == -1 ? 0 : dims[j];
var type = TensorFlowUtils.Tf2MlNetType(tfOutput.OutputType);
outputTypes[i] = new VectorType(type, dims);
tfOutputTypes[i] = tfOutput.OutputType;
@@ -709,7 +707,7 @@ internal static (TFDataType[] tfOutputTypes, ColumnType[] outputTypes) GetOutput
private protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema);
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.AssertValue(ctx);
ctx.CheckAtModel();
@@ -837,14 +835,22 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) :
var originalShape = _parent.TFInputShapes[i];
var shape = originalShape.ToIntArray();
- var colTypeDims = vecType.Dimensions.Prepend(1).Select(dim => (long)dim).ToArray();
+ var colTypeDims = vecType.Dimensions.Select(dim => (long)dim).ToArray();
if (shape == null)
_fullySpecifiedShapes[i] = new TFShape(colTypeDims);
- else if (vecType.Dimensions.Length == 1)
+ else
{
// If the column is one dimension we make sure that the total size of the TF shape matches.
// Compute the total size of the known dimensions of the shape.
- int valCount = shape.Where(x => x > 0).Aggregate((x, y) => x * y);
+ int valCount = 1;
+ int numOfUnkDim = 0;
+ foreach (var s in shape)
+ {
+ if (s > 0)
+ valCount *= s;
+ else
+ numOfUnkDim++;
+ }
// The column length should be divisible by this, so that the other dimensions can be integral.
int typeValueCount = type.GetValueCount();
if (typeValueCount % valCount != 0)
@@ -853,8 +859,8 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) :
// If the shape is multi-dimensional, we should be able to create the length of the vector by plugging
// in a single value for the unknown shapes. For example, if the shape is [?,?,3], then there should exist a value
// d such that d*d*3 is equal to the length of the input column.
- var d = originalShape.NumDimensions > 2 ? Math.Pow(typeValueCount / valCount, 1.0 / (originalShape.NumDimensions - 2)) : 1;
- if (originalShape.NumDimensions > 2 && d - (int)d != 0)
+ var d = numOfUnkDim > 0 ? Math.Pow(typeValueCount / valCount, 1.0 / numOfUnkDim) : 0;
+ if (d - (int)d != 0)
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}.");
// Fill in the unknown dimensions.
@@ -863,21 +869,10 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) :
l[ishape] = originalShape[ishape] == -1 ? (int)d : originalShape[ishape];
_fullySpecifiedShapes[i] = new TFShape(l);
}
- else
- {
- if (shape.Select((dim, j) => dim != -1 && dim != colTypeDims[j]).Any(b => b))
- throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is {vecType.ToString()}.");
-
- // Fill in the unknown dimensions.
- var l = new long[originalShape.NumDimensions];
- for (int ishape = 0; ishape < originalShape.NumDimensions; ishape++)
- l[ishape] = originalShape[ishape] == -1 ? colTypeDims[ishape] : originalShape[ishape];
- _fullySpecifiedShapes[i] = new TFShape(l);
- }
}
}
- public override void Save(ModelSaveContext ctx) => _parent.Save(ctx);
+ private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
private class OutputCache
{
diff --git a/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs b/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs
index 8e8f7e137b..8eec3a20e3 100644
--- a/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs
+++ b/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs
@@ -5,12 +5,10 @@
using System.Collections.Generic;
using Microsoft.ML.Core.Data;
using Microsoft.ML.StaticPipe.Runtime;
-using Microsoft.ML.TimeSeriesProcessing;
+using Microsoft.ML.Transforms.TimeSeries;
namespace Microsoft.ML.StaticPipe
{
- using IidBase = Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
- using SsaBase = Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
///
/// Static API extension methods for .
@@ -25,7 +23,7 @@ public OutColumn(
Scalar input,
int confidence,
int changeHistoryLength,
- IidBase.MartingaleType martingale,
+ MartingaleType martingale,
double eps)
: base(new Reconciler(confidence, changeHistoryLength, martingale, eps), input)
{
@@ -37,13 +35,13 @@ private sealed class Reconciler : EstimatorReconciler
{
private readonly int _confidence;
private readonly int _changeHistoryLength;
- private readonly IidBase.MartingaleType _martingale;
+ private readonly MartingaleType _martingale;
private readonly double _eps;
public Reconciler(
int confidence,
int changeHistoryLength,
- IidBase.MartingaleType martingale,
+ MartingaleType martingale,
double eps)
{
_confidence = confidence;
@@ -60,11 +58,11 @@ public override IEstimator Reconcile(IHostEnvironment env,
{
Contracts.Assert(toOutput.Length == 1);
var outCol = (OutColumn)toOutput[0];
- return new IidChangePointEstimator(env,
+ return new MLContext().Transforms.IidChangePointEstimator(
outputNames[outCol],
+ inputNames[outCol.Input],
_confidence,
_changeHistoryLength,
- inputNames[outCol.Input],
_martingale,
_eps);
}
@@ -77,7 +75,7 @@ public static Vector IidChangePointDetect(
this Scalar input,
int confidence,
int changeHistoryLength,
- IidBase.MartingaleType martingale = IidBase.MartingaleType.Power,
+ MartingaleType martingale = MartingaleType.Power,
double eps = 0.1) => new OutColumn(input, confidence, changeHistoryLength, martingale, eps);
}
@@ -93,7 +91,7 @@ private sealed class OutColumn : Vector
public OutColumn(Scalar input,
int confidence,
int pvalueHistoryLength,
- IidBase.AnomalySide side)
+ AnomalySide side)
: base(new Reconciler(confidence, pvalueHistoryLength, side), input)
{
Input = input;
@@ -104,12 +102,12 @@ private sealed class Reconciler : EstimatorReconciler
{
private readonly int _confidence;
private readonly int _pvalueHistoryLength;
- private readonly IidBase.AnomalySide _side;
+ private readonly AnomalySide _side;
public Reconciler(
int confidence,
int pvalueHistoryLength,
- IidBase.AnomalySide side)
+ AnomalySide side)
{
_confidence = confidence;
_pvalueHistoryLength = pvalueHistoryLength;
@@ -124,11 +122,11 @@ public override IEstimator Reconcile(IHostEnvironment env,
{
Contracts.Assert(toOutput.Length == 1);
var outCol = (OutColumn)toOutput[0];
- return new IidSpikeEstimator(env,
+ return new MLContext().Transforms.IidSpikeEstimator(
outputNames[outCol],
+ inputNames[outCol.Input],
_confidence,
_pvalueHistoryLength,
- inputNames[outCol.Input],
_side);
}
}
@@ -140,7 +138,7 @@ public static Vector IidSpikeDetect(
this Scalar input,
int confidence,
int pvalueHistoryLength,
- IidBase.AnomalySide side = IidBase.AnomalySide.TwoSided
+ AnomalySide side = AnomalySide.TwoSided
) => new OutColumn(input, confidence, pvalueHistoryLength, side);
}
@@ -158,8 +156,8 @@ public OutColumn(Scalar input,
int changeHistoryLength,
int trainingWindowSize,
int seasonalityWindowSize,
- ErrorFunctionUtils.ErrorFunction errorFunction,
- SsaBase.MartingaleType martingale,
+ ErrorFunction errorFunction,
+ MartingaleType martingale,
double eps)
: base(new Reconciler(confidence, changeHistoryLength, trainingWindowSize, seasonalityWindowSize, errorFunction, martingale, eps), input)
{
@@ -173,8 +171,8 @@ private sealed class Reconciler : EstimatorReconciler
private readonly int _changeHistoryLength;
private readonly int _trainingWindowSize;
private readonly int _seasonalityWindowSize;
- private readonly ErrorFunctionUtils.ErrorFunction _errorFunction;
- private readonly SsaBase.MartingaleType _martingale;
+ private readonly ErrorFunction _errorFunction;
+ private readonly MartingaleType _martingale;
private readonly double _eps;
public Reconciler(
@@ -182,8 +180,8 @@ public Reconciler(
int changeHistoryLength,
int trainingWindowSize,
int seasonalityWindowSize,
- ErrorFunctionUtils.ErrorFunction errorFunction,
- SsaBase.MartingaleType martingale,
+ ErrorFunction errorFunction,
+ MartingaleType martingale,
double eps)
{
_confidence = confidence;
@@ -203,13 +201,13 @@ public override IEstimator Reconcile(IHostEnvironment env,
{
Contracts.Assert(toOutput.Length == 1);
var outCol = (OutColumn)toOutput[0];
- return new SsaChangePointEstimator(env,
+ return new MLContext().Transforms.SsaChangePointEstimator(
outputNames[outCol],
+ inputNames[outCol.Input],
_confidence,
_changeHistoryLength,
_trainingWindowSize,
_seasonalityWindowSize,
- inputNames[outCol.Input],
_errorFunction,
_martingale,
_eps);
@@ -225,8 +223,8 @@ public static Vector SsaChangePointDetect(
int changeHistoryLength,
int trainingWindowSize,
int seasonalityWindowSize,
- ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference,
- SsaBase.MartingaleType martingale = SsaBase.MartingaleType.Power,
+ ErrorFunction errorFunction = ErrorFunction.SignedDifference,
+ MartingaleType martingale = MartingaleType.Power,
double eps = 0.1) => new OutColumn(input, confidence, changeHistoryLength, trainingWindowSize, seasonalityWindowSize, errorFunction, martingale, eps);
}
@@ -244,8 +242,8 @@ public OutColumn(Scalar input,
int pvalueHistoryLength,
int trainingWindowSize,
int seasonalityWindowSize,
- SsaBase.AnomalySide side,
- ErrorFunctionUtils.ErrorFunction errorFunction)
+ AnomalySide side,
+ ErrorFunction errorFunction)
: base(new Reconciler(confidence, pvalueHistoryLength, trainingWindowSize, seasonalityWindowSize, side, errorFunction), input)
{
Input = input;
@@ -258,16 +256,16 @@ private sealed class Reconciler : EstimatorReconciler
private readonly int _pvalueHistoryLength;
private readonly int _trainingWindowSize;
private readonly int _seasonalityWindowSize;
- private readonly SsaBase.AnomalySide _side;
- private readonly ErrorFunctionUtils.ErrorFunction _errorFunction;
+ private readonly AnomalySide _side;
+ private readonly ErrorFunction _errorFunction;
public Reconciler(
int confidence,
int pvalueHistoryLength,
int trainingWindowSize,
int seasonalityWindowSize,
- SsaBase.AnomalySide side,
- ErrorFunctionUtils.ErrorFunction errorFunction)
+ AnomalySide side,
+ ErrorFunction errorFunction)
{
_confidence = confidence;
_pvalueHistoryLength = pvalueHistoryLength;
@@ -285,13 +283,13 @@ public override IEstimator Reconcile(IHostEnvironment env,
{
Contracts.Assert(toOutput.Length == 1);
var outCol = (OutColumn)toOutput[0];
- return new SsaSpikeEstimator(env,
+ return new MLContext().Transforms.SsaSpikeEstimator(
outputNames[outCol],
+ inputNames[outCol.Input],
_confidence,
_pvalueHistoryLength,
_trainingWindowSize,
_seasonalityWindowSize,
- inputNames[outCol.Input],
_side,
_errorFunction);
}
@@ -306,8 +304,8 @@ public static Vector SsaSpikeDetect(
int changeHistoryLength,
int trainingWindowSize,
int seasonalityWindowSize,
- SsaBase.AnomalySide side = SsaBase.AnomalySide.TwoSided,
- ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference
+ AnomalySide side = AnomalySide.TwoSided,
+ ErrorFunction errorFunction = ErrorFunction.SignedDifference
) => new OutColumn(input, confidence, changeHistoryLength, trainingWindowSize, seasonalityWindowSize, side, errorFunction);
}
diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs
index 4a3593ca14..cf5e602cf9 100644
--- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs
+++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs
@@ -11,22 +11,21 @@
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
-using Microsoft.ML.TimeSeries;
-using Microsoft.ML.TimeSeriesProcessing;
+using Microsoft.ML.Transforms.TimeSeries;
[assembly: LoadableClass(typeof(AdaptiveSingularSpectrumSequenceModeler), typeof(AdaptiveSingularSpectrumSequenceModeler), null, typeof(SignatureLoadModel),
"SSA Sequence Modeler",
AdaptiveSingularSpectrumSequenceModeler.LoaderSignature)]
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
///
/// This class implements basic Singular Spectrum Analysis (SSA) model for modeling univariate time-series.
/// For the details of the model, refer to http://arxiv.org/pdf/1206.6910.pdf.
///
- public sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase
+ internal sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase
{
- public const string LoaderSignature = "SSAModel";
+ internal const string LoaderSignature = "SSAModel";
public enum RankSelectionMethod
{
@@ -35,7 +34,7 @@ public enum RankSelectionMethod
Fast
}
- public sealed class SsaForecastResult : ForecastResultBase
+ internal sealed class SsaForecastResult : ForecastResultBase
{
public VBuffer ForecastStandardDeviation;
public VBuffer UpperBound;
@@ -465,7 +464,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
_xSmooth = new CpuAlignedVector(_windowSize, CpuMathUtils.GetVectorAlignment());
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -1517,7 +1516,7 @@ internal override SequenceModelerBase Clone()
///
/// The input forecast object
/// The confidence level in [0, 1)
- public static void ComputeForecastIntervals(ref SsaForecastResult forecast, Single confidenceLevel = 0.95f)
+ internal static void ComputeForecastIntervals(ref SsaForecastResult forecast, Single confidenceLevel = 0.95f)
{
Contracts.CheckParam(0 <= confidenceLevel && confidenceLevel < 1, nameof(confidenceLevel), "The confidence level must be in [0, 1).");
Contracts.CheckValue(forecast, nameof(forecast));
diff --git a/src/Microsoft.ML.TimeSeries/EigenUtils.cs b/src/Microsoft.ML.TimeSeries/EigenUtils.cs
index 8d0b6794e5..fc8224b94e 100644
--- a/src/Microsoft.ML.TimeSeries/EigenUtils.cs
+++ b/src/Microsoft.ML.TimeSeries/EigenUtils.cs
@@ -8,10 +8,10 @@
using Microsoft.ML.Internal.Utilities;
using Float = System.Single;
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
//REVIEW: improve perf with SSE and Multithreading
- public static class EigenUtils
+ internal static class EigenUtils
{
//Compute the Eigen-decomposition of a symmetric matrix
//REVIEW: use matrix/vector operations, not Array Math
diff --git a/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs b/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs
index 15e7a84a8f..29653b2620 100644
--- a/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs
+++ b/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs
@@ -10,25 +10,26 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
-using Microsoft.ML.TimeSeriesProcessing;
+using Microsoft.ML.Transforms.TimeSeries;
[assembly: LoadableClass(ExponentialAverageTransform.Summary, typeof(ExponentialAverageTransform), typeof(ExponentialAverageTransform.Arguments), typeof(SignatureDataTransform),
ExponentialAverageTransform.UserName, ExponentialAverageTransform.LoaderSignature, ExponentialAverageTransform.ShortName)]
[assembly: LoadableClass(ExponentialAverageTransform.Summary, typeof(ExponentialAverageTransform), null, typeof(SignatureLoadDataTransform),
ExponentialAverageTransform.UserName, ExponentialAverageTransform.LoaderSignature)]
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
///
/// ExponentialAverageTransform is a weighted average of the values: ExpAvg(y_t) = a * y_t + (1-a) * ExpAvg(y_(t-1)).
///
- public sealed class ExponentialAverageTransform : SequentialTransformBase
+ internal sealed class ExponentialAverageTransform : SequentialTransformBase
{
public const string Summary = "Applies a Exponential average on a time series.";
public const string LoaderSignature = "ExpAverageTransform";
public const string UserName = "Exponential Average Transform";
public const string ShortName = "ExpAvg";
+#pragma warning disable 0649
public sealed class Arguments : TransformInputBase
{
[Argument(ArgumentType.Required, HelpText = "The name of the source column", ShortName = "src",
@@ -43,6 +44,7 @@ public sealed class Arguments : TransformInputBase
ShortName = "d", SortOrder = 4)]
public Single Decay = 0.9f;
}
+#pragma warning restore 0649
private static VersionInfo GetVersionInfo()
{
@@ -77,7 +79,7 @@ public ExponentialAverageTransform(IHostEnvironment env, ModelLoadContext ctx, I
Host.CheckDecode(WindowSize == 1);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Assert(WindowSize >= 1);
@@ -89,7 +91,7 @@ public override void Save(ModelSaveContext ctx)
//
// Single _decay
- base.Save(ctx);
+ base.SaveModel(ctx);
ctx.Writer.Write(_decay);
}
diff --git a/src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs b/src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs
new file mode 100644
index 0000000000..e013f7c671
--- /dev/null
+++ b/src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs
@@ -0,0 +1,87 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using Microsoft.ML.Transforms.TimeSeries;
+
+namespace Microsoft.ML
+{
+ public static class TimeSeriesCatalog
+ {
+ ///
+ /// Create a new instance of
+ ///
+ /// The transform's catalog.
+ /// Name of the column resulting from the transformation of .
+ /// Column is a vector of type double and size 4. The vector contains Alert, Raw Score, P-Value and Martingale score as first four values.
+ /// Name of column to transform. If set to , the value of the will be used as source.
+ /// The confidence for change point detection in the range [0, 100].
+ /// The length of the sliding window on p-values for computing the martingale score.
+ /// The martingale used for scoring.
+ /// The epsilon parameter for the Power martingale.
+ public static IidChangePointEstimator IidChangePointEstimator(this TransformsCatalog catalog, string outputColumnName, string inputColumnName,
+ int confidence, int changeHistoryLength, MartingaleType martingale = MartingaleType.Power, double eps = 0.1)
+ => new IidChangePointEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, confidence, changeHistoryLength, inputColumnName, martingale, eps);
+
+ ///
+ /// Create a new instance of
+ ///
+ /// The transform's catalog.
+ /// Name of the column resulting from the transformation of .
+ /// Name of column to transform. If set to , the value of the will be used as source.
+ /// The confidence for spike detection in the range [0, 100].
+ /// The size of the sliding window for computing the p-value.
+ /// The argument that determines whether to detect positive or negative anomalies, or both.
+ public static IidSpikeEstimator IidSpikeEstimator(this TransformsCatalog catalog, string outputColumnName, string inputColumnName,
+ int confidence, int pvalueHistoryLength, AnomalySide side = AnomalySide.TwoSided)
+ => new IidSpikeEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, confidence, pvalueHistoryLength, inputColumnName, side);
+
+ ///
+ /// Create a new instance of
+ ///
+ /// The transform's catalog.
+ /// Name of the column resulting from the transformation of .
+ /// Column is a vector of type double and size 4. The vector contains Alert, Raw Score, P-Value and Martingale score as first four values.
+ /// Name of column to transform. If set to , the value of the will be used as source.
+ /// The confidence for change point detection in the range [0, 100].
+ /// The number of points from the beginning of the sequence used for training.
+ /// The size of the sliding window for computing the p-value.
+ /// An upper bound on the largest relevant seasonality in the input time-series.
+ /// The function used to compute the error between the expected and the observed value.
+ /// The martingale used for scoring.
+ /// The epsilon parameter for the Power martingale.
+ public static SsaChangePointEstimator SsaChangePointEstimator(this TransformsCatalog catalog, string outputColumnName, string inputColumnName,
+ int confidence, int changeHistoryLength, int trainingWindowSize, int seasonalityWindowSize, ErrorFunction errorFunction = ErrorFunction.SignedDifference,
+ MartingaleType martingale = MartingaleType.Power, double eps = 0.1)
+ => new SsaChangePointEstimator(CatalogUtils.GetEnvironment(catalog), new SsaChangePointDetector.Options
+ {
+ Name = outputColumnName,
+ Source = inputColumnName ?? outputColumnName,
+ Confidence = confidence,
+ ChangeHistoryLength = changeHistoryLength,
+ TrainingWindowSize = trainingWindowSize,
+ SeasonalWindowSize = seasonalityWindowSize,
+ Martingale = martingale,
+ PowerMartingaleEpsilon = eps,
+ ErrorFunction = errorFunction
+ });
+
+ ///
+ /// Create a new instance of
+ ///
+ /// The transform's catalog.
+ /// Name of the column resulting from the transformation of .
+ /// Name of column to transform. If set to , the value of the will be used as source.
+ /// The confidence for spike detection in the range [0, 100].
+ /// The size of the sliding window for computing the p-value.
+ /// The number of points from the beginning of the sequence used for training.
+ /// An upper bound on the largest relevant seasonality in the input time-series.
+ /// The vector contains Alert, Raw Score, P-Value as first three values.
+ /// The argument that determines whether to detect positive or negative anomalies, or both.
+ /// The function used to compute the error between the expected and the observed value.
+ public static SsaSpikeEstimator SsaSpikeEstimator(this TransformsCatalog catalog, string outputColumnName, string inputColumnName, int confidence, int pvalueHistoryLength,
+ int trainingWindowSize, int seasonalityWindowSize, AnomalySide side = AnomalySide.TwoSided, ErrorFunction errorFunction = ErrorFunction.SignedDifference)
+ => new SsaSpikeEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, confidence, pvalueHistoryLength, trainingWindowSize, seasonalityWindowSize, inputColumnName, side, errorFunction);
+ }
+}
diff --git a/src/Microsoft.ML.TimeSeries/FftUtils.cs b/src/Microsoft.ML.TimeSeries/FftUtils.cs
index a02575c031..28a19a89ed 100644
--- a/src/Microsoft.ML.TimeSeries/FftUtils.cs
+++ b/src/Microsoft.ML.TimeSeries/FftUtils.cs
@@ -6,7 +6,7 @@
using System.Runtime.InteropServices;
using System.Security;
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
///
/// The utility functions that wrap the native Discrete Fast Fourier Transform functionality from Intel MKL.
diff --git a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs
index eeff5db3d2..25cf99347b 100644
--- a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs
+++ b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs
@@ -5,107 +5,193 @@
using System;
using System.IO;
using Microsoft.Data.DataView;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
-using Microsoft.ML.TimeSeries;
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
///
- /// This transform computes the p-values and martingale scores for a supposedly i.i.d input sequence of floats. In other words, it assumes
+ /// The is the wrapper to that computes the p-values and martingale scores for a supposedly i.i.d input sequence of floats. In other words, it assumes
/// the input sequence represents the raw anomaly score which might have been computed via another process.
///
- public abstract class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase
+ public class IidAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveModel
{
- public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env)
- : base(args, name, env)
+ ///
+ /// Whether a call to should succeed, on an
+ /// appropriate schema.
+ ///
+ public bool IsRowToRowMapper => InternalTransform.IsRowToRowMapper;
+
+ ///
+ /// Creates a clone of the transformer. Used for taking the snapshot of the state.
+ ///
+ ///
+ IStatefulTransformer IStatefulTransformer.Clone() => InternalTransform.Clone();
+
+ ///
+ /// Schema propagation for transformers.
+ /// Returns the output schema of the data, if the input schema is like the one provided.
+ ///
+ public Schema GetOutputSchema(Schema inputSchema) => InternalTransform.GetOutputSchema(inputSchema);
+
+ ///
+ /// Constructs a row-to-row mapper based on an input schema. If
+ /// is false, then an exception should be thrown. If the input schema is in any way
+ /// unsuitable for constructing the mapper, an exception should likewise be thrown.
+ ///
+ /// The input schema for which we should get the mapper.
+ /// The row to row mapper.
+ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) => InternalTransform.GetRowToRowMapper(inputSchema);
+
+ ///
+ /// Same as but also supports mechanism to save the state.
+ ///
+ /// The input schema for which we should get the mapper.
+ /// The row to row mapper.
+ public IRowToRowMapper GetStatefulRowToRowMapper(Schema inputSchema) => ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema);
+
+ ///
+ /// Take the data in, make transformations, output the data.
+ /// Note that 's are lazy, so no actual transformations happen here, just schema validation.
+ ///
+ public IDataView Transform(IDataView input) => InternalTransform.Transform(input);
+
+ ///
+ /// For saving a model into a repository.
+ ///
+ public virtual void Save(ModelSaveContext ctx)
{
- InitialWindowSize = 0;
- StateRef = new State();
- StateRef.InitState(WindowSize, InitialWindowSize, this, Host);
+ InternalTransform.SaveThis(ctx);
}
- public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name)
- : base(env, ctx, name)
- {
- Host.CheckDecode(InitialWindowSize == 0);
- StateRef = new State(ctx.Reader);
- StateRef.InitState(this, Host);
- }
-
- public override Schema GetOutputSchema(Schema inputSchema)
- {
- Host.CheckValue(inputSchema, nameof(inputSchema));
+ ///
+ /// Creates a row mapper from Schema.
+ ///
+ internal IStatefulRowMapper MakeRowMapper(Schema schema) => InternalTransform.MakeRowMapper(schema);
- if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col))
- throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName);
+ ///
+ /// Creates an IDataTransform from an IDataView.
+ ///
+ internal IDataTransform MakeDataTransform(IDataView input) => InternalTransform.MakeDataTransform(input);
- var colType = inputSchema[col].Type;
- if (colType != NumberType.R4)
- throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, "float", colType.ToString());
+ internal IidAnomalyDetectionBase InternalTransform;
- return Transform(new EmptyDataView(Host, inputSchema)).Schema;
+ internal IidAnomalyDetectionBaseWrapper(ArgumentsBase args, string name, IHostEnvironment env)
+ {
+ InternalTransform = new IidAnomalyDetectionBase(args, name, env, this);
}
- public override void Save(ModelSaveContext ctx)
+ internal IidAnomalyDetectionBaseWrapper(IHostEnvironment env, ModelLoadContext ctx, string name)
{
- ctx.CheckAtModel();
- Host.Assert(InitialWindowSize == 0);
- base.Save(ctx);
-
- // *** Binary format ***
- //
- // State: StateRef
- StateRef.Save(ctx.Writer);
+ InternalTransform = new IidAnomalyDetectionBase(env, ctx, name, this);
}
- public sealed class State : AnomalyDetectionStateBase
+ ///
+ /// This transform computes the p-values and martingale scores for a supposedly i.i.d input sequence of floats. In other words, it assumes
+ /// the input sequence represents the raw anomaly score which might have been computed via another process.
+ ///
+ internal class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase
{
- public State()
- {
- }
+ internal IidAnomalyDetectionBaseWrapper Parent;
- internal State(BinaryReader reader) : base(reader)
+ public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env, IidAnomalyDetectionBaseWrapper parent)
+ : base(args, name, env)
{
- WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
- InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
+ InitialWindowSize = 0;
+ StateRef = new State();
+ StateRef.InitState(WindowSize, InitialWindowSize, this, Host);
+ Parent = parent;
}
- internal override void Save(BinaryWriter writer)
+ public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name, IidAnomalyDetectionBaseWrapper parent)
+ : base(env, ctx, name)
{
- base.Save(writer);
- TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer);
- TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer);
+ Host.CheckDecode(InitialWindowSize == 0);
+ StateRef = new State(ctx.Reader);
+ StateRef.InitState(this, Host);
+ Parent = parent;
}
- private protected override void CloneCore(StateBase state)
+ public override Schema GetOutputSchema(Schema inputSchema)
{
- base.CloneCore(state);
- Contracts.Assert(state is State);
- var stateLocal = state as State;
- stateLocal.WindowedBuffer = WindowedBuffer.Clone();
- stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone();
- }
+ Host.CheckValue(inputSchema, nameof(inputSchema));
- private protected override void LearnStateFromDataCore(FixedSizeQueue data)
- {
- // This method is empty because there is no need for initial tuning for this transform.
+ if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName);
+
+ var colType = inputSchema[col].Type;
+ if (colType != NumberType.R4)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, NumberType.R4.ToString(), colType.ToString());
+
+ return Transform(new EmptyDataView(Host, inputSchema)).Schema;
}
- private protected override void InitializeAnomalyDetector()
+ private protected override void SaveModel(ModelSaveContext ctx)
{
- // This method is empty because there is no need for any extra initialization for this transform.
+ Parent.Save(ctx);
}
- private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration)
+ internal void SaveThis(ModelSaveContext ctx)
{
- // This transform treats the input sequenence as the raw anomaly score.
- return (double)input;
+ ctx.CheckAtModel();
+ Host.Assert(InitialWindowSize == 0);
+ base.SaveModel(ctx);
+
+ // *** Binary format ***
+ //
+ // State: StateRef
+ StateRef.Save(ctx.Writer);
}
- public override void Consume(float value)
+ internal sealed class State : AnomalyDetectionStateBase
{
+ public State()
+ {
+ }
+
+ internal State(BinaryReader reader) : base(reader)
+ {
+ WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
+ InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
+ }
+
+ internal override void Save(BinaryWriter writer)
+ {
+ base.Save(writer);
+ TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer);
+ TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer);
+ }
+
+ private protected override void CloneCore(State state)
+ {
+ base.CloneCore(state);
+ Contracts.Assert(state is State);
+ var stateLocal = state as State;
+ stateLocal.WindowedBuffer = WindowedBuffer.Clone();
+ stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone();
+ }
+
+ private protected override void LearnStateFromDataCore(FixedSizeQueue data)
+ {
+ // This method is empty because there is no need for initial tuning for this transform.
+ }
+
+ private protected override void InitializeAnomalyDetector()
+ {
+ // This method is empty because there is no need for any extra initialization for this transform.
+ }
+
+ private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration)
+ {
+ // This transform treats the input sequence as the raw anomaly score.
+ return (double)input;
+ }
+
+ public override void Consume(float value)
+ {
+ }
}
}
}
diff --git a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
index cb0874e2ce..9f5fc85a2d 100644
--- a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
@@ -12,11 +12,9 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model;
-using Microsoft.ML.TimeSeries;
-using Microsoft.ML.TimeSeriesProcessing;
-using static Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
+using Microsoft.ML.Transforms.TimeSeries;
-[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), typeof(IidChangePointDetector.Options), typeof(SignatureDataTransform),
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature, IidChangePointDetector.ShortName)]
[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), null, typeof(SignatureLoadDataTransform),
@@ -28,19 +26,19 @@
[assembly: LoadableClass(typeof(IRowMapper), typeof(IidChangePointDetector), null, typeof(SignatureLoadRowMapper),
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)]
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
///
/// This class implements the change point detector transform for an i.i.d. sequence based on adaptive kernel density estimation and martingales.
///
- public sealed class IidChangePointDetector : IidAnomalyDetectionBase
+ public sealed class IidChangePointDetector : IidAnomalyDetectionBaseWrapper, IStatefulTransformer
{
internal const string Summary = "This transform detects the change-points in an i.i.d. sequence using adaptive kernel density estimation and martingales.";
- public const string LoaderSignature = "IidChangePointDetector";
- public const string UserName = "IID Change Point Detection";
- public const string ShortName = "ichgpnt";
+ internal const string LoaderSignature = "IidChangePointDetector";
+ internal const string UserName = "IID Change Point Detection";
+ internal const string ShortName = "ichgpnt";
- public sealed class Arguments : TransformInputBase
+ internal sealed class Options : TransformInputBase
{
[Argument(ArgumentType.Required, HelpText = "The name of the source column.", ShortName = "src",
SortOrder = 1, Purpose = SpecialPurpose.ColumnName)]
@@ -59,7 +57,7 @@ public sealed class Arguments : TransformInputBase
public double Confidence = 95;
[Argument(ArgumentType.AtMostOnce, HelpText = "The martingale used for scoring.", ShortName = "mart", SortOrder = 103)]
- public MartingaleType Martingale = SequentialAnomalyDetectionTransformBase.MartingaleType.Power;
+ public MartingaleType Martingale = MartingaleType.Power;
[Argument(ArgumentType.AtMostOnce, HelpText = "The epsilon parameter for the Power martingale.",
ShortName = "eps", SortOrder = 104)]
@@ -68,27 +66,27 @@ public sealed class Arguments : TransformInputBase
private sealed class BaseArguments : ArgumentsBase
{
- public BaseArguments(Arguments args)
+ public BaseArguments(Options options)
{
- Source = args.Source;
- Name = args.Name;
- Side = SequentialAnomalyDetectionTransformBase.AnomalySide.TwoSided;
- WindowSize = args.ChangeHistoryLength;
- Martingale = args.Martingale;
- PowerMartingaleEpsilon = args.PowerMartingaleEpsilon;
- AlertOn = SequentialAnomalyDetectionTransformBase.AlertingScore.MartingaleScore;
+ Source = options.Source;
+ Name = options.Name;
+ Side = AnomalySide.TwoSided;
+ WindowSize = options.ChangeHistoryLength;
+ Martingale = options.Martingale;
+ PowerMartingaleEpsilon = options.PowerMartingaleEpsilon;
+ AlertOn = AlertingScore.MartingaleScore;
}
public BaseArguments(IidChangePointDetector transform)
{
- Source = transform.InputColumnName;
- Name = transform.OutputColumnName;
+ Source = transform.InternalTransform.InputColumnName;
+ Name = transform.InternalTransform.OutputColumnName;
Side = AnomalySide.TwoSided;
- WindowSize = transform.WindowSize;
- Martingale = transform.Martingale;
- PowerMartingaleEpsilon = transform.PowerMartingaleEpsilon;
+ WindowSize = transform.InternalTransform.WindowSize;
+ Martingale = transform.InternalTransform.Martingale;
+ PowerMartingaleEpsilon = transform.InternalTransform.PowerMartingaleEpsilon;
AlertOn = AlertingScore.MartingaleScore;
- AlertThreshold = transform.AlertThreshold;
+ AlertThreshold = transform.InternalTransform.AlertThreshold;
}
}
@@ -103,39 +101,39 @@ private static VersionInfo GetVersionInfo()
}
// Factory method for SignatureDataTransform.
- private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- return new IidChangePointDetector(env, args).MakeDataTransform(input);
+ return new IidChangePointDetector(env, options).MakeDataTransform(input);
}
- internal override IStatefulTransformer Clone()
+ IStatefulTransformer IStatefulTransformer.Clone()
{
var clone = (IidChangePointDetector)MemberwiseClone();
- clone.StateRef = (State)clone.StateRef.Clone();
- clone.StateRef.InitState(clone, Host);
+ clone.InternalTransform.StateRef = (IidAnomalyDetectionBase.State)clone.InternalTransform.StateRef.Clone();
+ clone.InternalTransform.StateRef.InitState(clone.InternalTransform, InternalTransform.Host);
return clone;
}
- internal IidChangePointDetector(IHostEnvironment env, Arguments args)
- : base(new BaseArguments(args), LoaderSignature, env)
+ internal IidChangePointDetector(IHostEnvironment env, Options options)
+ : base(new BaseArguments(options), LoaderSignature, env)
{
- switch (Martingale)
+ switch (InternalTransform.Martingale)
{
case MartingaleType.None:
- AlertThreshold = Double.MaxValue;
+ InternalTransform.AlertThreshold = Double.MaxValue;
break;
case MartingaleType.Power:
- AlertThreshold = Math.Exp(WindowSize * LogPowerMartigaleBettingFunc(1 - args.Confidence / 100, PowerMartingaleEpsilon));
+ InternalTransform.AlertThreshold = Math.Exp(InternalTransform.WindowSize * InternalTransform.LogPowerMartigaleBettingFunc(1 - options.Confidence / 100, InternalTransform.PowerMartingaleEpsilon));
break;
case MartingaleType.Mixture:
- AlertThreshold = Math.Exp(WindowSize * LogMixtureMartigaleBettingFunc(1 - args.Confidence / 100));
+ InternalTransform.AlertThreshold = Math.Exp(InternalTransform.WindowSize * InternalTransform.LogMixtureMartigaleBettingFunc(1 - options.Confidence / 100));
break;
default:
- throw Host.ExceptParam(nameof(args.Martingale),
+ throw InternalTransform.Host.ExceptParam(nameof(options.Martingale),
"The martingale type can be only (0) None, (1) Power or (2) Mixture.");
}
}
@@ -166,8 +164,8 @@ internal IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx)
// *** Binary format ***
//
- Host.CheckDecode(ThresholdScore == AlertingScore.MartingaleScore);
- Host.CheckDecode(Side == AnomalySide.TwoSided);
+ InternalTransform.Host.CheckDecode(InternalTransform.ThresholdScore == AlertingScore.MartingaleScore);
+ InternalTransform.Host.CheckDecode(InternalTransform.Side == AnomalySide.TwoSided);
}
private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector transform)
@@ -177,12 +175,12 @@ private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector tran
public override void Save(ModelSaveContext ctx)
{
- Host.CheckValue(ctx, nameof(ctx));
+ InternalTransform.Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
- Host.Assert(ThresholdScore == AlertingScore.MartingaleScore);
- Host.Assert(Side == AnomalySide.TwoSided);
+ InternalTransform.Host.Assert(InternalTransform.ThresholdScore == AlertingScore.MartingaleScore);
+ InternalTransform.Host.Assert(InternalTransform.Side == AnomalySide.TwoSided);
// *** Binary format ***
//
@@ -219,10 +217,10 @@ public sealed class IidChangePointEstimator : TrivialEstimatorName of column to transform. If set to , the value of the will be used as source.
/// The martingale used for scoring.
/// The epsilon parameter for the Power martingale.
- public IidChangePointEstimator(IHostEnvironment env, string outputColumnName, int confidence,
+ internal IidChangePointEstimator(IHostEnvironment env, string outputColumnName, int confidence,
int changeHistoryLength, string inputColumnName, MartingaleType martingale = MartingaleType.Power, double eps = 0.1)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)),
- new IidChangePointDetector(env, new IidChangePointDetector.Arguments
+ new IidChangePointDetector(env, new IidChangePointDetector.Options
{
Name = outputColumnName,
Source = inputColumnName ?? outputColumnName,
@@ -234,28 +232,32 @@ public IidChangePointEstimator(IHostEnvironment env, string outputColumnName, in
{
}
- public IidChangePointEstimator(IHostEnvironment env, IidChangePointDetector.Arguments args)
+ internal IidChangePointEstimator(IHostEnvironment env, IidChangePointDetector.Options options)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)),
- new IidChangePointDetector(env, args))
+ new IidChangePointDetector(env, options))
{
}
+ ///
+ /// Returns the of the schema which will be produced by the transformer.
+ /// Used for schema propagation and verification in a pipeline.
+ ///
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- if (!inputSchema.TryFindColumn(Transformer.InputColumnName, out var col))
- throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName);
+ if (!inputSchema.TryFindColumn(Transformer.InternalTransform.InputColumnName, out var col))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InternalTransform.InputColumnName);
if (col.ItemType != NumberType.R4)
- throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName, "float", col.GetTypeString());
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InternalTransform.InputColumnName, "float", col.GetTypeString());
var metadata = new List() {
new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
};
var resultDic = inputSchema.ToDictionary(x => x.Name);
- resultDic[Transformer.OutputColumnName] = new SchemaShape.Column(
- Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
+ resultDic[Transformer.InternalTransform.OutputColumnName] = new SchemaShape.Column(
+ Transformer.InternalTransform.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
return new SchemaShape(resultDic.Values);
}
diff --git a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
index 20cecfc285..813043606a 100644
--- a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
@@ -11,11 +11,9 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model;
-using Microsoft.ML.TimeSeries;
-using Microsoft.ML.TimeSeriesProcessing;
-using static Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
+using Microsoft.ML.Transforms.TimeSeries;
-[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), typeof(IidSpikeDetector.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), typeof(IidSpikeDetector.Options), typeof(SignatureDataTransform),
IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature, IidSpikeDetector.ShortName)]
[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), null, typeof(SignatureLoadDataTransform),
@@ -27,19 +25,19 @@
[assembly: LoadableClass(typeof(IRowMapper), typeof(IidSpikeDetector), null, typeof(SignatureLoadRowMapper),
IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature)]
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
///
/// This class implements the spike detector transform for an i.i.d. sequence based on adaptive kernel density estimation.
///
- public sealed class IidSpikeDetector : IidAnomalyDetectionBase
+ public sealed class IidSpikeDetector : IidAnomalyDetectionBaseWrapper, IStatefulTransformer
{
internal const string Summary = "This transform detects the spikes in a i.i.d. sequence using adaptive kernel density estimation.";
- public const string LoaderSignature = "IidSpikeDetector";
- public const string UserName = "IID Spike Detection";
- public const string ShortName = "ispike";
+ internal const string LoaderSignature = "IidSpikeDetector";
+ internal const string UserName = "IID Spike Detection";
+ internal const string ShortName = "ispike";
- public sealed class Arguments : TransformInputBase
+ internal sealed class Options : TransformInputBase
{
[Argument(ArgumentType.Required, HelpText = "The name of the source column.", ShortName = "src",
SortOrder = 1, Purpose = SpecialPurpose.ColumnName)]
@@ -64,24 +62,24 @@ public sealed class Arguments : TransformInputBase
private sealed class BaseArguments : ArgumentsBase
{
- public BaseArguments(Arguments args)
+ public BaseArguments(Options options)
{
- Source = args.Source;
- Name = args.Name;
- Side = args.Side;
- WindowSize = args.PvalueHistoryLength;
- AlertThreshold = 1 - args.Confidence / 100;
- AlertOn = SequentialAnomalyDetectionTransformBase.AlertingScore.PValueScore;
+ Source = options.Source;
+ Name = options.Name;
+ Side = options.Side;
+ WindowSize = options.PvalueHistoryLength;
+ AlertThreshold = 1 - options.Confidence / 100;
+ AlertOn = AlertingScore.PValueScore;
Martingale = MartingaleType.None;
}
public BaseArguments(IidSpikeDetector transform)
{
- Source = transform.InputColumnName;
- Name = transform.OutputColumnName;
- Side = transform.Side;
- WindowSize = transform.WindowSize;
- AlertThreshold = transform.AlertThreshold;
+ Source = transform.InternalTransform.InputColumnName;
+ Name = transform.InternalTransform.OutputColumnName;
+ Side = transform.InternalTransform.Side;
+ WindowSize = transform.InternalTransform.WindowSize;
+ AlertThreshold = transform.InternalTransform.AlertThreshold;
AlertOn = AlertingScore.PValueScore;
Martingale = MartingaleType.None;
}
@@ -99,25 +97,25 @@ private static VersionInfo GetVersionInfo()
}
// Factory method for SignatureDataTransform.
- private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- return new IidSpikeDetector(env, args).MakeDataTransform(input);
+ return new IidSpikeDetector(env, options).MakeDataTransform(input);
}
- internal override IStatefulTransformer Clone()
+ IStatefulTransformer IStatefulTransformer.Clone()
{
var clone = (IidSpikeDetector)MemberwiseClone();
- clone.StateRef = (State)clone.StateRef.Clone();
- clone.StateRef.InitState(clone, Host);
+ clone.InternalTransform.StateRef = (IidAnomalyDetectionBase.State)clone.InternalTransform.StateRef.Clone();
+ clone.InternalTransform.StateRef.InitState(clone.InternalTransform, InternalTransform.Host);
return clone;
}
- internal IidSpikeDetector(IHostEnvironment env, Arguments args)
- : base(new BaseArguments(args), LoaderSignature, env)
+ internal IidSpikeDetector(IHostEnvironment env, Options options)
+ : base(new BaseArguments(options), LoaderSignature, env)
{
// This constructor is empty.
}
@@ -142,14 +140,15 @@ private static IidSpikeDetector Create(IHostEnvironment env, ModelLoadContext ct
return new IidSpikeDetector(env, ctx);
}
- public IidSpikeDetector(IHostEnvironment env, ModelLoadContext ctx)
+ internal IidSpikeDetector(IHostEnvironment env, ModelLoadContext ctx)
: base(env, ctx, LoaderSignature)
{
// *** Binary format ***
//
- Host.CheckDecode(ThresholdScore == AlertingScore.PValueScore);
+ InternalTransform.Host.CheckDecode(InternalTransform.ThresholdScore == AlertingScore.PValueScore);
}
+
private IidSpikeDetector(IHostEnvironment env, IidSpikeDetector transform)
: base(new BaseArguments(transform), LoaderSignature, env)
{
@@ -157,11 +156,11 @@ private IidSpikeDetector(IHostEnvironment env, IidSpikeDetector transform)
public override void Save(ModelSaveContext ctx)
{
- Host.CheckValue(ctx, nameof(ctx));
+ InternalTransform.Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
- Host.Assert(ThresholdScore == AlertingScore.PValueScore);
+ InternalTransform.Host.Assert(InternalTransform.ThresholdScore == AlertingScore.PValueScore);
// *** Binary format ***
//
@@ -197,9 +196,9 @@ public sealed class IidSpikeEstimator : TrivialEstimator
/// The size of the sliding window for computing the p-value.
/// Name of column to transform. If set to , the value of the will be used as source.
/// The argument that determines whether to detect positive or negative anomalies, or both.
- public IidSpikeEstimator(IHostEnvironment env, string outputColumnName, int confidence, int pvalueHistoryLength, string inputColumnName, AnomalySide side = AnomalySide.TwoSided)
+ internal IidSpikeEstimator(IHostEnvironment env, string outputColumnName, int confidence, int pvalueHistoryLength, string inputColumnName, AnomalySide side = AnomalySide.TwoSided)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidSpikeDetector)),
- new IidSpikeDetector(env, new IidSpikeDetector.Arguments
+ new IidSpikeDetector(env, new IidSpikeDetector.Options
{
Name = outputColumnName,
Source = inputColumnName,
@@ -210,26 +209,30 @@ public IidSpikeEstimator(IHostEnvironment env, string outputColumnName, int conf
{
}
- public IidSpikeEstimator(IHostEnvironment env, IidSpikeDetector.Arguments args)
- : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidSpikeEstimator)), new IidSpikeDetector(env, args))
+ internal IidSpikeEstimator(IHostEnvironment env, IidSpikeDetector.Options options)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidSpikeEstimator)), new IidSpikeDetector(env, options))
{
}
+ ///
+ /// Schema propagation for transformers.
+ /// Returns the output schema of the data, if the input schema is like the one provided.
+ ///
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- if (!inputSchema.TryFindColumn(Transformer.InputColumnName, out var col))
- throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName);
+ if (!inputSchema.TryFindColumn(Transformer.InternalTransform.InputColumnName, out var col))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InternalTransform.InputColumnName);
if (col.ItemType != NumberType.R4)
- throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName, "float", col.GetTypeString());
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InternalTransform.InputColumnName, "float", col.GetTypeString());
var metadata = new List() {
new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
};
var resultDic = inputSchema.ToDictionary(x => x.Name);
- resultDic[Transformer.OutputColumnName] = new SchemaShape.Column(
- Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
+ resultDic[Transformer.InternalTransform.OutputColumnName] = new SchemaShape.Column(
+ Transformer.InternalTransform.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
return new SchemaShape(resultDic.Values);
}
diff --git a/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs b/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs
index f207c7b75a..1e04173d49 100644
--- a/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs
+++ b/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs
@@ -10,20 +10,20 @@
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
-using Microsoft.ML.TimeSeriesProcessing;
+using Microsoft.ML.Transforms.TimeSeries;
[assembly: LoadableClass(MovingAverageTransform.Summary, typeof(MovingAverageTransform), typeof(MovingAverageTransform.Arguments), typeof(SignatureDataTransform),
"Moving Average Transform", MovingAverageTransform.LoaderSignature, "MoAv")]
[assembly: LoadableClass(MovingAverageTransform.Summary, typeof(MovingAverageTransform), null, typeof(SignatureLoadDataTransform),
"Moving Average Transform", MovingAverageTransform.LoaderSignature)]
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
///
/// MovingAverageTransform is a weighted average of the values in
/// the sliding window.
///
- public sealed class MovingAverageTransform : SequentialTransformBase
+ internal sealed class MovingAverageTransform : SequentialTransformBase
{
public const string Summary = "Applies a moving average on a time series. Only finite values are taken into account.";
public const string LoaderSignature = "MovingAverageTransform";
@@ -94,7 +94,7 @@ public MovingAverageTransform(IHostEnvironment env, ModelLoadContext ctx, IDataV
Host.CheckDecode(_weights == null || Utils.Size(_weights) == WindowSize + 1 - _lag);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Assert(WindowSize >= 1);
@@ -107,7 +107,7 @@ public override void Save(ModelSaveContext ctx)
// int: _lag
// Single[]: _weights
- base.Save(ctx);
+ base.SaveModel(ctx);
ctx.Writer.Write(_lag);
Host.Assert(_weights == null || Utils.Size(_weights) == WindowSize + 1 - _lag);
ctx.Writer.WriteSingleArray(_weights);
diff --git a/src/Microsoft.ML.TimeSeries/PValueTransform.cs b/src/Microsoft.ML.TimeSeries/PValueTransform.cs
index d8f3adcb29..528f8b7fd0 100644
--- a/src/Microsoft.ML.TimeSeries/PValueTransform.cs
+++ b/src/Microsoft.ML.TimeSeries/PValueTransform.cs
@@ -10,20 +10,20 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
-using Microsoft.ML.TimeSeriesProcessing;
+using Microsoft.ML.Transforms.TimeSeries;
[assembly: LoadableClass(PValueTransform.Summary, typeof(PValueTransform), typeof(PValueTransform.Arguments), typeof(SignatureDataTransform),
PValueTransform.UserName, PValueTransform.LoaderSignature, PValueTransform.ShortName)]
[assembly: LoadableClass(PValueTransform.Summary, typeof(PValueTransform), null, typeof(SignatureLoadDataTransform),
PValueTransform.UserName, PValueTransform.LoaderSignature)]
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
///
/// PValueTransform is a sequential transform that computes the empirical p-value of the current value in the series based on the other values in
/// the sliding window.
///
- public sealed class PValueTransform : SequentialTransformBase
+ internal sealed class PValueTransform : SequentialTransformBase
{
internal const string Summary = "This P-Value transform calculates the p-value of the current input in the sequence with regard to the values in the sliding window.";
public const string LoaderSignature = "PValueTransform";
@@ -91,7 +91,7 @@ public PValueTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView inp
Host.CheckDecode(WindowSize >= 1);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Assert(WindowSize >= 1);
@@ -103,7 +103,7 @@ public override void Save(ModelSaveContext ctx)
// int: _percentile
// byte: _isPositiveSide
- base.Save(ctx);
+ base.SaveModel(ctx);
ctx.Writer.Write(_seed);
ctx.Writer.WriteBoolByte(_isPositiveSide);
}
diff --git a/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs b/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs
index 868e446899..a4e8c792d0 100644
--- a/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs
+++ b/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs
@@ -10,20 +10,20 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
-using Microsoft.ML.TimeSeriesProcessing;
+using Microsoft.ML.Transforms.TimeSeries;
[assembly: LoadableClass(PercentileThresholdTransform.Summary, typeof(PercentileThresholdTransform), typeof(PercentileThresholdTransform.Arguments), typeof(SignatureDataTransform),
PercentileThresholdTransform.UserName, PercentileThresholdTransform.LoaderSignature, PercentileThresholdTransform.ShortName)]
[assembly: LoadableClass(PercentileThresholdTransform.Summary, typeof(PercentileThresholdTransform), null, typeof(SignatureLoadDataTransform),
PercentileThresholdTransform.UserName, PercentileThresholdTransform.LoaderSignature)]
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
///
/// PercentileThresholdTransform is a sequential transform that decides whether the current value of the time-series belongs to the 'percentile' % of the top values in
/// the sliding window. The output of the transform will be a boolean flag.
///
- public sealed class PercentileThresholdTransform : SequentialTransformBase
+ internal sealed class PercentileThresholdTransform : SequentialTransformBase
{
public const string Summary = "Detects the values of time-series that are in the top percentile of the sliding window.";
public const string LoaderSignature = "PercentThrTransform";
@@ -86,7 +86,7 @@ public PercentileThresholdTransform(IHostEnvironment env, ModelLoadContext ctx,
Host.CheckDecode(MinPercentile <= _percentile && _percentile <= MaxPercentile);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Assert(MinPercentile <= _percentile && _percentile <= MaxPercentile);
@@ -98,7 +98,7 @@ public override void Save(ModelSaveContext ctx)
//
// Double: _percentile
- base.Save(ctx);
+ base.SaveModel(ctx);
ctx.Writer.Write(_percentile);
}
diff --git a/src/Microsoft.ML.TimeSeries/PolynomialUtils.cs b/src/Microsoft.ML.TimeSeries/PolynomialUtils.cs
index 094984dcb2..7ba7dd0d25 100644
--- a/src/Microsoft.ML.TimeSeries/PolynomialUtils.cs
+++ b/src/Microsoft.ML.TimeSeries/PolynomialUtils.cs
@@ -8,9 +8,9 @@
using System.Numerics;
using Microsoft.ML.Internal.Utilities;
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
- public static class PolynomialUtils
+ internal static class PolynomialUtils
{
// Part 1: Computing the polynomial real and complex roots from its real coefficients
diff --git a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs
index e0c352e632..5aab239fef 100644
--- a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs
+++ b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs
@@ -10,7 +10,7 @@
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
-namespace Microsoft.ML.TimeSeries
+namespace Microsoft.ML.Transforms.TimeSeries
{
internal interface IStatefulRowToRowMapper : IRowToRowMapper
{
@@ -18,8 +18,17 @@ internal interface IStatefulRowToRowMapper : IRowToRowMapper
internal interface IStatefulTransformer : ITransformer
{
+ ///
+ /// Same as but also supports mechanism to save the state.
+ ///
+ /// The input schema for which we should get the mapper.
+ /// The row to row mapper.
IRowToRowMapper GetStatefulRowToRowMapper(Schema inputSchema);
+ ///
+ /// Creates a clone of the transformer. Used for taking the snapshot of the state.
+ ///
+ ///
IStatefulTransformer Clone();
}
@@ -90,6 +99,10 @@ private static ITransformer CloneTransformers(ITransformer transformer)
return transformer is IStatefulTransformer ? ((IStatefulTransformer)transformer).Clone() : transformer;
}
+ ///
+ /// Constructor for creating a time series specific prediction engine. It allows the time series model to be updated with the observations
+ /// seen at prediction time via
+ ///
public TimeSeriesPredictionFunction(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) :
base(env, CloneTransformers(transformer), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
@@ -200,7 +213,7 @@ private IRowToRowMapper GetRowToRowMapper(Schema inputSchema)
return new CompositeRowToRowMapper(inputSchema, mappers);
}
- protected override Func TransformerChecker(IExceptionContext ectx, ITransformer transformer)
+ private protected override Func TransformerChecker(IExceptionContext ectx, ITransformer transformer)
{
ectx.CheckValue(transformer, nameof(transformer));
ectx.CheckParam(IsRowToRowMapper(transformer), nameof(transformer), "Must be a row to row mapper or " + nameof(IStatefulTransformer));
diff --git a/src/Microsoft.ML.TimeSeries/Properties/AssemblyInfo.cs b/src/Microsoft.ML.TimeSeries/Properties/AssemblyInfo.cs
index 56899c0981..a16f8e17fd 100644
--- a/src/Microsoft.ML.TimeSeries/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.TimeSeries/Properties/AssemblyInfo.cs
@@ -5,4 +5,7 @@
using System.Runtime.CompilerServices;
using Microsoft.ML;
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries.Tests" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)]
+
+[assembly: WantsToBeBestFriends]
diff --git a/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs b/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs
index d66b58c9f8..b6da0925cc 100644
--- a/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs
+++ b/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs
@@ -6,13 +6,13 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
///
/// The base container class for the forecast result on a sequence of type .
///
/// The type of the elements in the sequence
- public abstract class ForecastResultBase
+ internal abstract class ForecastResultBase
{
public VBuffer PointForecast;
}
@@ -22,7 +22,7 @@ public abstract class ForecastResultBase
///
/// The type of the elements in the input sequence
/// The type of the elements in the output sequence
- public abstract class SequenceModelerBase : ICanSaveModel
+ internal abstract class SequenceModelerBase : ICanSaveModel
{
private protected SequenceModelerBase()
{
@@ -75,6 +75,8 @@ private protected SequenceModelerBase()
///
/// Implementation of .
///
- public abstract void Save(ModelSaveContext ctx);
+ void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
+
+ private protected abstract void SaveModel(ModelSaveContext ctx);
}
}
diff --git a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs
index 91a554b96d..4bd1cbc87d 100644
--- a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs
+++ b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs
@@ -12,682 +12,680 @@
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
-using Microsoft.ML.TimeSeries;
+using Microsoft.ML.Transforms.TimeSeries;
-namespace Microsoft.ML.TimeSeriesProcessing
+namespace Microsoft.ML.Transforms.TimeSeries
{
- // REVIEW: This base class and its children classes generate one output column of type VBuffer to output 3 different anomaly scores as well as
- // the alert flag. Ideally these 4 output information should be put in four seaparate columns instead of one VBuffer<> column. However, this is not currently
- // possible due to our design restriction. This must be fixed in the next version and will potentially affect the children classes.
-
///
- /// The base class for sequential anomaly detection transforms that supports the p-value as well as the martingales scores computation from the sequence of
- /// raw anomaly scores whose calculation is specified by the children classes. This class also provides mechanism for the threshold-based alerting on
- /// the raw anomaly score, the p-value score or the martingale score. Currently, this class supports Power and Mixture martingales.
- /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf
+ /// The type of the martingale.
///
- /// The type of the input sequence
- /// The type of the state object for sequential anomaly detection. Must be a class inherited from AnomalyDetectionStateBase
- public abstract class SequentialAnomalyDetectionTransformBase : SequentialTransformerBase, TState>
- where TState : SequentialAnomalyDetectionTransformBase.AnomalyDetectionStateBase, new()
+ public enum MartingaleType : byte
{
///
- /// The type of the martingale.
+ /// (None) No martingale is used.
///
- public enum MartingaleType : byte
- {
- ///
- /// (None) No martingale is used.
- ///
- None,
- ///
- /// (Power) The Power martingale is used.
- ///
- Power,
- ///
- /// (Mixture) The Mixture martingale is used.
- ///
- Mixture
- }
-
+ None,
///
- /// The side of anomaly detection.
+ /// (Power) The Power martingale is used.
///
- public enum AnomalySide : byte
- {
- ///
- /// (Positive) Only positive anomalies are detected.
- ///
- Positive,
- ///
- /// (Negative) Only negative anomalies are detected.
- ///
- Negative,
- ///
- /// (TwoSided) Both positive and negative anomalies are detected.
- ///
- TwoSided
- }
+ Power,
+ ///
+ /// (Mixture) The Mixture martingale is used.
+ ///
+ Mixture
+ }
+ ///
+ /// The side of anomaly detection.
+ ///
+ public enum AnomalySide : byte
+ {
///
- /// The score that should be thresholded to generate alerts.
+ /// (Positive) Only positive anomalies are detected.
///
- public enum AlertingScore : byte
- {
- ///
- /// (RawScore) The raw anomaly score is thresholded.
- ///
- RawScore,
- ///
- /// (PValueScore) The p-value score is thresholded.
- ///
- PValueScore,
- ///
- /// (MartingaleScore) The martingale score is thresholded.
- ///
- MartingaleScore
- }
+ Positive,
+ ///
+ /// (Negative) Only negative anomalies are detected.
+ ///
+ Negative,
+ ///
+ /// (TwoSided) Both positive and negative anomalies are detected.
+ ///
+ TwoSided
+ }
+ ///
+ /// The score that should be thresholded to generate alerts.
+ ///
+ internal enum AlertingScore : byte
+ {
///
- /// The base class that can be inherited by the 'Argument' classes in the derived classes containing the shared input parameters.
+ /// (RawScore) The raw anomaly score is thresholded.
///
- public abstract class ArgumentsBase
- {
- [Argument(ArgumentType.Required, HelpText = "The name of the source column", ShortName = "src",
- SortOrder = 1, Purpose = SpecialPurpose.ColumnName)]
- public string Source;
+ RawScore,
+ ///
+ /// (PValueScore) The p-value score is thresholded.
+ ///
+ PValueScore,
+ ///
+ /// (MartingaleScore) The martingale score is thresholded.
+ ///
+ MartingaleScore
+ }
- [Argument(ArgumentType.Required, HelpText = "The name of the new column", ShortName = "name",
- SortOrder = 2)]
- public string Name;
+ ///
+ /// The base class that can be inherited by the 'Argument' classes in the derived classes containing the shared input parameters.
+ ///
+ internal abstract class ArgumentsBase
+ {
+ [Argument(ArgumentType.Required, HelpText = "The name of the source column", ShortName = "src",
+ SortOrder = 1, Purpose = SpecialPurpose.ColumnName)]
+ public string Source;
- [Argument(ArgumentType.AtMostOnce, HelpText = "The argument that determines whether to detect positive or negative anomalies, or both", ShortName = "side",
- SortOrder = 3)]
- public AnomalySide Side = AnomalySide.TwoSided;
+ [Argument(ArgumentType.Required, HelpText = "The name of the new column", ShortName = "name",
+ SortOrder = 2)]
+ public string Name;
- [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the sliding window for computing the p-value.", ShortName = "wnd",
- SortOrder = 4)]
- public int WindowSize = 1;
+ [Argument(ArgumentType.AtMostOnce, HelpText = "The argument that determines whether to detect positive or negative anomalies, or both", ShortName = "side",
+ SortOrder = 3)]
+ public AnomalySide Side = AnomalySide.TwoSided;
- [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the initial window for computing the p-value as well as training if needed. The default value is set to 0, which means there is no initial window considered.",
- ShortName = "initwnd", SortOrder = 5)]
- public int InitialWindowSize = 0;
+ [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the sliding window for computing the p-value.", ShortName = "wnd",
+ SortOrder = 4)]
+ public int WindowSize = 1;
- [Argument(ArgumentType.AtMostOnce, HelpText = "The martingale used for scoring",
- ShortName = "martingale", SortOrder = 6)]
- public MartingaleType Martingale = MartingaleType.Power;
+ [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the initial window for computing the p-value as well as training if needed. The default value is set to 0, which means there is no initial window considered.",
+ ShortName = "initwnd", SortOrder = 5)]
+ public int InitialWindowSize = 0;
- [Argument(ArgumentType.AtMostOnce, HelpText = "The argument that determines whether anomalies should be detected based on the raw anomaly score, the p-value or the martingale score",
- ShortName = "alert", SortOrder = 7)]
- public AlertingScore AlertOn = AlertingScore.MartingaleScore;
+ [Argument(ArgumentType.AtMostOnce, HelpText = "The martingale used for scoring",
+ ShortName = "martingale", SortOrder = 6)]
+ public MartingaleType Martingale = MartingaleType.Power;
- [Argument(ArgumentType.AtMostOnce, HelpText = "The epsilon parameter for the Power martingale",
- ShortName = "eps", SortOrder = 8)]
- public Double PowerMartingaleEpsilon = 0.1;
+ [Argument(ArgumentType.AtMostOnce, HelpText = "The argument that determines whether anomalies should be detected based on the raw anomaly score, the p-value or the martingale score",
+ ShortName = "alert", SortOrder = 7)]
+ public AlertingScore AlertOn = AlertingScore.MartingaleScore;
- [Argument(ArgumentType.Required, HelpText = "The threshold for alerting",
- ShortName = "thr", SortOrder = 9)]
- public Double AlertThreshold;
- }
+ [Argument(ArgumentType.AtMostOnce, HelpText = "The epsilon parameter for the Power martingale",
+ ShortName = "eps", SortOrder = 8)]
+ public Double PowerMartingaleEpsilon = 0.1;
- // Determines the side of anomaly detection for this transform.
- protected AnomalySide Side;
+ [Argument(ArgumentType.Required, HelpText = "The threshold for alerting",
+ ShortName = "thr", SortOrder = 9)]
+ public Double AlertThreshold;
+ }
- // Determines the type of martingale used by this transform.
- protected MartingaleType Martingale;
+ // REVIEW: This base class and its children classes generate one output column of type VBuffer to output 3 different anomaly scores as well as
+ // the alert flag. Ideally these 4 output information should be put in four seaparate columns instead of one VBuffer<> column. However, this is not currently
+ // possible due to our design restriction. This must be fixed in the next version and will potentially affect the children classes.
+ ///
+ /// The base class for sequential anomaly detection transforms that supports the p-value as well as the martingales scores computation from the sequence of
+ /// raw anomaly scores whose calculation is specified by the children classes. This class also provides mechanism for the threshold-based alerting on
+ /// the raw anomaly score, the p-value score or the martingale score. Currently, this class supports Power and Mixture martingales.
+ /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf
+ ///
+ /// The type of the input sequence
+ /// The type of the input sequence
+ internal abstract class SequentialAnomalyDetectionTransformBase : SequentialTransformerBase, TState>
+ where TState : SequentialAnomalyDetectionTransformBase.AnomalyDetectionStateBase, new()
+ {
+ // Determines the side of anomaly detection for this transform.
+ internal AnomalySide Side;
- // The epsilon parameter used by the Power martingale.
- protected Double PowerMartingaleEpsilon;
+ // Determines the type of martingale used by this transform.
+ internal MartingaleType Martingale;
- // Determines the score that should be thresholded to generate alerts by this transform.
- protected AlertingScore ThresholdScore;
+ // The epsilon parameter used by the Power martingale.
+ internal Double PowerMartingaleEpsilon;
- // Determines the threshold for generating alerts.
- protected Double AlertThreshold;
+ // Determines the score that should be thresholded to generate alerts by this transform.
+ internal AlertingScore ThresholdScore;
- // The size of the VBuffer in the dst column.
- private int _outputLength;
+ // Determines the threshold for generating alerts.
+ internal Double AlertThreshold;
- private static int GetOutputLength(AlertingScore alertingScore, IHostEnvironment host)
- {
- switch (alertingScore)
+ // The size of the VBuffer in the dst column.
+ internal int OutputLength;
+
+ // The minimum value for p-values. The smaller p-values are ceiled to this value.
+ internal const Double MinPValue = 1e-8;
+
+ // The maximun value for p-values. The larger p-values are floored to this value.
+ internal const Double MaxPValue = 1 - MinPValue;
+
+ private static int GetOutputLength(AlertingScore alertingScore, IHostEnvironment host)
+ {
+ switch (alertingScore)
+ {
+ case AlertingScore.RawScore:
+ return 2;
+ case AlertingScore.PValueScore:
+ return 3;
+ case AlertingScore.MartingaleScore:
+ return 4;
+ default:
+ throw host.Except("The alerting score can be only (0) RawScore, (1) PValueScore or (2) MartingaleScore.");
+ }
+ }
+
+ private protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env,
+ AnomalySide anomalySide, MartingaleType martingale, AlertingScore alertingScore, Double powerMartingaleEpsilon,
+ Double alertThreshold)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(name), windowSize, initialWindowSize, outputColumnName, inputColumnName, new VectorType(NumberType.R8, GetOutputLength(alertingScore, env)))
{
- case AlertingScore.RawScore:
- return 2;
- case AlertingScore.PValueScore:
- return 3;
- case AlertingScore.MartingaleScore:
- return 4;
- default:
- throw host.Except("The alerting score can be only (0) RawScore, (1) PValueScore or (2) MartingaleScore.");
+ Host.CheckUserArg(Enum.IsDefined(typeof(MartingaleType), martingale), nameof(ArgumentsBase.Martingale), "Value is undefined.");
+ Host.CheckUserArg(Enum.IsDefined(typeof(AnomalySide), anomalySide), nameof(ArgumentsBase.Side), "Value is undefined.");
+ Host.CheckUserArg(Enum.IsDefined(typeof(AlertingScore), alertingScore), nameof(ArgumentsBase.AlertOn), "Value is undefined.");
+ Host.CheckUserArg(martingale != MartingaleType.None || alertingScore != AlertingScore.MartingaleScore, nameof(ArgumentsBase.Martingale), "A martingale type should be specified if alerting is based on the martingale score.");
+ Host.CheckUserArg(windowSize > 0 || alertingScore == AlertingScore.RawScore, nameof(ArgumentsBase.AlertOn),
+ "When there is no windowed buffering (i.e., " + nameof(ArgumentsBase.WindowSize) + " = 0), the alert can be generated only based on the raw score (i.e., "
+ + nameof(ArgumentsBase.AlertOn) + " = " + nameof(AlertingScore.RawScore) + ")");
+ Host.CheckUserArg(0 < powerMartingaleEpsilon && powerMartingaleEpsilon < 1, nameof(ArgumentsBase.PowerMartingaleEpsilon), "Should be in (0,1).");
+ Host.CheckUserArg(alertThreshold >= 0, nameof(ArgumentsBase.AlertThreshold), "Must be non-negative.");
+ Host.CheckUserArg(alertingScore != AlertingScore.PValueScore || (0 <= alertThreshold && alertThreshold <= 1), nameof(ArgumentsBase.AlertThreshold), "Must be in [0,1].");
+
+ ThresholdScore = alertingScore;
+ Side = anomalySide;
+ Martingale = martingale;
+ PowerMartingaleEpsilon = powerMartingaleEpsilon;
+ AlertThreshold = alertThreshold;
+ OutputLength = GetOutputLength(ThresholdScore, Host);
}
- }
- private protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env,
- AnomalySide anomalySide, MartingaleType martingale, AlertingScore alertingScore, Double powerMartingaleEpsilon,
- Double alertThreshold)
- : base(Contracts.CheckRef(env, nameof(env)).Register(name), windowSize, initialWindowSize, outputColumnName, inputColumnName, new VectorType(NumberType.R8, GetOutputLength(alertingScore, env)))
- {
- Host.CheckUserArg(Enum.IsDefined(typeof(MartingaleType), martingale), nameof(ArgumentsBase.Martingale), "Value is undefined.");
- Host.CheckUserArg(Enum.IsDefined(typeof(AnomalySide), anomalySide), nameof(ArgumentsBase.Side), "Value is undefined.");
- Host.CheckUserArg(Enum.IsDefined(typeof(AlertingScore), alertingScore), nameof(ArgumentsBase.AlertOn), "Value is undefined.");
- Host.CheckUserArg(martingale != MartingaleType.None || alertingScore != AlertingScore.MartingaleScore, nameof(ArgumentsBase.Martingale), "A martingale type should be specified if alerting is based on the martingale score.");
- Host.CheckUserArg(windowSize > 0 || alertingScore == AlertingScore.RawScore, nameof(ArgumentsBase.AlertOn),
- "When there is no windowed buffering (i.e., " + nameof(ArgumentsBase.WindowSize) + " = 0), the alert can be generated only based on the raw score (i.e., "
- + nameof(ArgumentsBase.AlertOn) + " = " + nameof(AlertingScore.RawScore) + ")");
- Host.CheckUserArg(0 < powerMartingaleEpsilon && powerMartingaleEpsilon < 1, nameof(ArgumentsBase.PowerMartingaleEpsilon), "Should be in (0,1).");
- Host.CheckUserArg(alertThreshold >= 0, nameof(ArgumentsBase.AlertThreshold), "Must be non-negative.");
- Host.CheckUserArg(alertingScore != AlertingScore.PValueScore || (0 <= alertThreshold && alertThreshold <= 1), nameof(ArgumentsBase.AlertThreshold), "Must be in [0,1].");
-
- ThresholdScore = alertingScore;
- Side = anomalySide;
- Martingale = martingale;
- PowerMartingaleEpsilon = powerMartingaleEpsilon;
- AlertThreshold = alertThreshold;
- _outputLength = GetOutputLength(ThresholdScore, Host);
- }
+ private protected SequentialAnomalyDetectionTransformBase(ArgumentsBase args, string name, IHostEnvironment env)
+ : this(args.WindowSize, args.InitialWindowSize, args.Source, args.Name, name, env, args.Side, args.Martingale,
+ args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold)
+ {
+ }
- private protected SequentialAnomalyDetectionTransformBase(ArgumentsBase args, string name, IHostEnvironment env)
- : this(args.WindowSize, args.InitialWindowSize, args.Source, args.Name, name, env, args.Side, args.Martingale,
- args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold)
- {
- }
+ private protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(name), ctx)
+ {
+ // *** Binary format ***
+ //
+ // byte: _martingale
+ // byte: _alertingScore
+ // byte: _anomalySide
+ // Double: _powerMartingaleEpsilon
+ // Double: _alertThreshold
+
+ byte temp;
+ temp = ctx.Reader.ReadByte();
+ Host.CheckDecode(Enum.IsDefined(typeof(MartingaleType), temp));
+ Martingale = (MartingaleType)temp;
+
+ temp = ctx.Reader.ReadByte();
+ Host.CheckDecode(Enum.IsDefined(typeof(AlertingScore), temp));
+ ThresholdScore = (AlertingScore)temp;
+
+ Host.CheckDecode(Martingale != MartingaleType.None || ThresholdScore != AlertingScore.MartingaleScore);
+ Host.CheckDecode(WindowSize > 0 || ThresholdScore == AlertingScore.RawScore);
+
+ temp = ctx.Reader.ReadByte();
+ Host.CheckDecode(Enum.IsDefined(typeof(AnomalySide), temp));
+ Side = (AnomalySide)temp;
+
+ PowerMartingaleEpsilon = ctx.Reader.ReadDouble();
+ Host.CheckDecode(0 < PowerMartingaleEpsilon && PowerMartingaleEpsilon < 1);
+
+ AlertThreshold = ctx.Reader.ReadDouble();
+ Host.CheckDecode(AlertThreshold >= 0);
+ Host.CheckDecode(ThresholdScore != AlertingScore.PValueScore || (0 <= AlertThreshold && AlertThreshold <= 1));
+
+ OutputLength = GetOutputLength(ThresholdScore, Host);
+ }
- private protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name)
- : base(Contracts.CheckRef(env, nameof(env)).Register(name), ctx)
- {
- // *** Binary format ***
- //
- // byte: _martingale
- // byte: _alertingScore
- // byte: _anomalySide
- // Double: _powerMartingaleEpsilon
- // Double: _alertThreshold
-
- byte temp;
- temp = ctx.Reader.ReadByte();
- Host.CheckDecode(Enum.IsDefined(typeof(MartingaleType), temp));
- Martingale = (MartingaleType)temp;
-
- temp = ctx.Reader.ReadByte();
- Host.CheckDecode(Enum.IsDefined(typeof(AlertingScore), temp));
- ThresholdScore = (AlertingScore)temp;
-
- Host.CheckDecode(Martingale != MartingaleType.None || ThresholdScore != AlertingScore.MartingaleScore);
- Host.CheckDecode(WindowSize > 0 || ThresholdScore == AlertingScore.RawScore);
-
- temp = ctx.Reader.ReadByte();
- Host.CheckDecode(Enum.IsDefined(typeof(AnomalySide), temp));
- Side = (AnomalySide)temp;
-
- PowerMartingaleEpsilon = ctx.Reader.ReadDouble();
- Host.CheckDecode(0 < PowerMartingaleEpsilon && PowerMartingaleEpsilon < 1);
-
- AlertThreshold = ctx.Reader.ReadDouble();
- Host.CheckDecode(AlertThreshold >= 0);
- Host.CheckDecode(ThresholdScore != AlertingScore.PValueScore || (0 <= AlertThreshold && AlertThreshold <= 1));
-
- _outputLength = GetOutputLength(ThresholdScore, Host);
- }
+ private protected override void SaveModel(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+
+ Host.Assert(Enum.IsDefined(typeof(MartingaleType), Martingale));
+ Host.Assert(Enum.IsDefined(typeof(AlertingScore), ThresholdScore));
+ Host.Assert(Martingale != MartingaleType.None || ThresholdScore != AlertingScore.MartingaleScore);
+ Host.Assert(WindowSize > 0 || ThresholdScore == AlertingScore.RawScore);
+ Host.Assert(Enum.IsDefined(typeof(AnomalySide), Side));
+ Host.Assert(0 < PowerMartingaleEpsilon && PowerMartingaleEpsilon < 1);
+ Host.Assert(AlertThreshold >= 0);
+ Host.Assert(ThresholdScore != AlertingScore.PValueScore || (0 <= AlertThreshold && AlertThreshold <= 1));
+
+ // *** Binary format ***
+ //
+ // byte: _martingale
+ // byte: _alertingScore
+ // byte: _anomalySide
+ // Double: _powerMartingaleEpsilon
+ // Double: _alertThreshold
+
+ base.SaveModel(ctx);
+ ctx.Writer.Write((byte)Martingale);
+ ctx.Writer.Write((byte)ThresholdScore);
+ ctx.Writer.Write((byte)Side);
+ ctx.Writer.Write(PowerMartingaleEpsilon);
+ ctx.Writer.Write(AlertThreshold);
+ }
- public override void Save(ModelSaveContext ctx)
- {
- Host.CheckValue(ctx, nameof(ctx));
- ctx.CheckAtModel();
-
- Host.Assert(Enum.IsDefined(typeof(MartingaleType), Martingale));
- Host.Assert(Enum.IsDefined(typeof(AlertingScore), ThresholdScore));
- Host.Assert(Martingale != MartingaleType.None || ThresholdScore != AlertingScore.MartingaleScore);
- Host.Assert(WindowSize > 0 || ThresholdScore == AlertingScore.RawScore);
- Host.Assert(Enum.IsDefined(typeof(AnomalySide), Side));
- Host.Assert(0 < PowerMartingaleEpsilon && PowerMartingaleEpsilon < 1);
- Host.Assert(AlertThreshold >= 0);
- Host.Assert(ThresholdScore != AlertingScore.PValueScore || (0 <= AlertThreshold && AlertThreshold <= 1));
-
- // *** Binary format ***
- //
- // byte: _martingale
- // byte: _alertingScore
- // byte: _anomalySide
- // Double: _powerMartingaleEpsilon
- // Double: _alertThreshold
-
- base.Save(ctx);
- ctx.Writer.Write((byte)Martingale);
- ctx.Writer.Write((byte)ThresholdScore);
- ctx.Writer.Write((byte)Side);
- ctx.Writer.Write(PowerMartingaleEpsilon);
- ctx.Writer.Write(AlertThreshold);
- }
+ ///
+ /// Calculates the betting function for the Power martingale in the log scale.
+ /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf.
+ ///
+ /// The p-value
+ /// The epsilon
+ /// The Power martingale betting function value in the natural logarithmic scale.
+ internal Double LogPowerMartigaleBettingFunc(Double p, Double epsilon)
+ {
+ Host.Assert(MinPValue > 0);
+ Host.Assert(MaxPValue < 1);
+ Host.Assert(MinPValue <= p && p <= MaxPValue);
+ Host.Assert(0 < epsilon && epsilon < 1);
- // The minimum value for p-values. The smaller p-values are ceiled to this value.
- private const Double MinPValue = 1e-8;
+ return Math.Log(epsilon) + (epsilon - 1) * Math.Log(p);
+ }
- // The maximun value for p-values. The larger p-values are floored to this value.
- private const Double MaxPValue = 1 - MinPValue;
+ ///
+ /// Calculates the betting function for the Mixture martingale in the log scale.
+ /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf.
+ ///
+ /// The p-value
+ /// The Mixure (marginalized over epsilon) martingale betting function value in the natural logarithmic scale.
+ internal Double LogMixtureMartigaleBettingFunc(Double p)
+ {
+ Host.Assert(MinPValue > 0);
+ Host.Assert(MaxPValue < 1);
+ Host.Assert(MinPValue <= p && p <= MaxPValue);
- ///
- /// Calculates the betting function for the Power martingale in the log scale.
- /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf.
- ///
- /// The p-value
- /// The epsilon
- /// The Power martingale betting function value in the natural logarithmic scale.
- protected Double LogPowerMartigaleBettingFunc(Double p, Double epsilon)
- {
- Host.Assert(MinPValue > 0);
- Host.Assert(MaxPValue < 1);
- Host.Assert(MinPValue <= p && p <= MaxPValue);
- Host.Assert(0 < epsilon && epsilon < 1);
+ Double logP = Math.Log(p);
+ return Math.Log(p * logP + 1 - p) - 2 * Math.Log(-logP) - logP;
+ }
- return Math.Log(epsilon) + (epsilon - 1) * Math.Log(p);
- }
+ internal override IStatefulRowMapper MakeRowMapper(Schema schema) => new Mapper(Host, this, schema);
- ///
- /// Calculates the betting function for the Mixture martingale in the log scale.
- /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf.
- ///
- /// The p-value
- /// The Mixure (marginalized over epsilon) martingale betting function value in the natural logarithmic scale.
- protected Double LogMixtureMartigaleBettingFunc(Double p)
- {
- Host.Assert(MinPValue > 0);
- Host.Assert(MaxPValue < 1);
- Host.Assert(MinPValue <= p && p <= MaxPValue);
+ internal sealed class Mapper : IStatefulRowMapper
+ {
+ private readonly IHost _host;
+ private readonly SequentialAnomalyDetectionTransformBase _parent;
+ private readonly Schema _parentSchema;
+ private readonly int _inputColumnIndex;
+ private readonly VBuffer> _slotNames;
+ private AnomalyDetectionStateBase State { get; set; }
+
+ public Mapper(IHostEnvironment env, SequentialAnomalyDetectionTransformBase parent, Schema inputSchema)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(Mapper));
+ _host.CheckValue(inputSchema, nameof(inputSchema));
+ _host.CheckValue(parent, nameof(parent));
- Double logP = Math.Log(p);
- return Math.Log(p * logP + 1 - p) - 2 * Math.Log(-logP) - logP;
- }
+ if (!inputSchema.TryGetColumnIndex(parent.InputColumnName, out _inputColumnIndex))
+ throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName);
- ///
- /// The base state class for sequential anomaly detection: this class implements the p-values and martinagle calculations for anomaly detection
- /// given that the raw anomaly score calculation is specified by the derived classes.
- ///
- public abstract class AnomalyDetectionStateBase : StateBase
- {
- // A reference to the parent transform.
- protected SequentialAnomalyDetectionTransformBase Parent;
+ var colType = inputSchema[_inputColumnIndex].Type;
+ if (colType != NumberType.R4)
+ throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName, "float", colType.ToString());
- // A windowed buffer to cache the update values to the martingale score in the log scale.
- private FixedSizeQueue LogMartingaleUpdateBuffer { get; set; }
+ _parent = parent;
+ _parentSchema = inputSchema;
+ _slotNames = new VBuffer>(4, new[] { "Alert".AsMemory(), "Raw Score".AsMemory(),
+ "P-Value Score".AsMemory(), "Martingale Score".AsMemory() });
- // A windowed buffer to cache the raw anomaly scores for p-value calculation.
- private FixedSizeQueue RawScoreBuffer { get; set; }
+ State = (AnomalyDetectionStateBase)_parent.StateRef;
+ }
- // The current martingale score in the log scale.
- private Double _logMartingaleValue;
+ public Schema.DetachedColumn[] GetOutputColumns()
+ {
+ var meta = new MetadataBuilder();
+ meta.AddSlotNames(_parent.OutputLength, GetSlotNames);
+ var info = new Schema.DetachedColumn[1];
+ info[0] = new Schema.DetachedColumn(_parent.OutputColumnName, new VectorType(NumberType.R8, _parent.OutputLength), meta.GetMetadata());
+ return info;
+ }
- // Sum of the squared Euclidean distances among the raw socres in the buffer.
- // Used for computing the optimal bandwidth for Kernel Density Estimation of p-values.
- private Double _sumSquaredDist;
+ public void GetSlotNames(ref VBuffer> dst) => _slotNames.CopyTo(ref dst, 0, _parent.OutputLength);
- private int _martingaleAlertCounter;
+ public Func GetDependencies(Func activeOutput)
+ {
+ if (activeOutput(0))
+ return col => col == _inputColumnIndex;
+ else
+ return col => false;
+ }
- protected Double LatestMartingaleScore => Math.Exp(_logMartingaleValue);
+ public void Save(ModelSaveContext ctx) => _parent.SaveModel(ctx);
- private protected AnomalyDetectionStateBase() { }
+ public Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer)
+ {
+ disposer = null;
+ var getters = new Delegate[1];
+ if (activeOutput(0))
+ getters[0] = MakeGetter(input, State);
- private protected override void CloneCore(StateBase state)
- {
- base.CloneCore(state);
- Contracts.Assert(state is AnomalyDetectionStateBase);
- var stateLocal = state as AnomalyDetectionStateBase;
- stateLocal.LogMartingaleUpdateBuffer = LogMartingaleUpdateBuffer.Clone();
- stateLocal.RawScoreBuffer = RawScoreBuffer.Clone();
- }
+ return getters;
+ }
- private protected AnomalyDetectionStateBase(BinaryReader reader) : base(reader)
- {
- LogMartingaleUpdateBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueDouble(reader, Host);
- RawScoreBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
- _logMartingaleValue = reader.ReadDouble();
- _sumSquaredDist = reader.ReadDouble();
- _martingaleAlertCounter = reader.ReadInt32();
- }
+ private delegate void ProcessData(ref TInput src, ref VBuffer dst);
- internal override void Save(BinaryWriter writer)
- {
- base.Save(writer);
- TimeSeriesUtils.SerializeFixedSizeQueue(LogMartingaleUpdateBuffer, writer);
- TimeSeriesUtils.SerializeFixedSizeQueue(RawScoreBuffer, writer);
- writer.Write(_logMartingaleValue);
- writer.Write(_sumSquaredDist);
- writer.Write(_martingaleAlertCounter);
- }
+ private Delegate MakeGetter(Row input, AnomalyDetectionStateBase state)
+ {
+ _host.AssertValue(input);
+ var srcGetter = input.GetGetter(_inputColumnIndex);
+ ProcessData processData = _parent.WindowSize > 0 ?
+ (ProcessData)state.Process : state.ProcessWithoutBuffer;
- private Double ComputeKernelPValue(Double rawScore)
- {
- int i;
- int n = RawScoreBuffer.Count;
+ ValueGetter> valueGetter = (ref VBuffer dst) =>
+ {
+ TInput src = default;
+ srcGetter(ref src);
+ processData(ref src, ref dst);
+ };
+ return valueGetter;
+ }
- if (n == 0)
- return 0.5;
+ public Action CreatePinger(Row input, Func activeOutput, out Action disposer)
+ {
+ disposer = null;
+ Action pinger = null;
+ if (activeOutput(0))
+ pinger = MakePinger(input, State);
- Double pValue = 0;
- Double bandWidth = Math.Sqrt(2) * ((n == 1) ? 1 : Math.Sqrt(_sumSquaredDist) / n);
- bandWidth = Math.Max(bandWidth, 1e-6);
+ return pinger;
+ }
- Double diff;
- for (i = 0; i < n; ++i)
+ private Action MakePinger(Row input, AnomalyDetectionStateBase state)
{
- diff = rawScore - RawScoreBuffer[i];
- pValue -= ProbabilityFunctions.Erf(diff / bandWidth);
- _sumSquaredDist += diff * diff;
+ _host.AssertValue(input);
+ var srcGetter = input.GetGetter(_inputColumnIndex);
+ Action pinger = (long rowPosition) =>
+ {
+ TInput src = default;
+ srcGetter(ref src);
+ state.UpdateState(ref src, rowPosition, _parent.WindowSize > 0);
+ };
+ return pinger;
}
- pValue = 0.5 + pValue / (2 * n);
- if (RawScoreBuffer.IsFull)
+ public void CloneState()
{
- for (i = 1; i < n; ++i)
+ if (Interlocked.Increment(ref _parent.StateRefCount) > 1)
{
- diff = RawScoreBuffer[0] - RawScoreBuffer[i];
- _sumSquaredDist -= diff * diff;
+ State = (AnomalyDetectionStateBase)_parent.StateRef.Clone();
}
-
- diff = RawScoreBuffer[0] - rawScore;
- _sumSquaredDist -= diff * diff;
}
- return pValue;
+ public ITransformer GetTransformer()
+ {
+ return _parent;
+ }
}
-
- private protected override void SetNaOutput(ref VBuffer