diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 9d969b6be2..2b9bcb6cf9 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -33,6 +33,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TestFramework" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Predictor.Tests", "test\Microsoft.ML.Predictor.Tests\Microsoft.ML.Predictor.Tests.csproj", "{6B047E09-39C9-4583-96F3-685D84CA4117}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Functional.Tests", "test\Microsoft.ML.Functional.Tests\Microsoft.ML.Functional.Tests.csproj", "{CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ResultProcessor", "src\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj", "{3769FCC3-9AFF-4C37-97E9-6854324681DF}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.FastTree", "src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj", "{B7B593C5-FB8C-4ADA-A638-5B53B47D087E}" @@ -928,6 +930,18 @@ Global {5E920CAC-5A28-42FB-936E-49C472130953}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU {5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU {5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug|Any CPU.Build.0 = Debug|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release|Any CPU.ActiveCfg = Release|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release|Any CPU.Build.0 = Release|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1011,6 +1025,7 @@ Global {85D0CAFD-2FE8-496A-88C7-585D35B94243} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {31D38B21-102B-41C0-9E0A-2FE0BF68D123} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {5E920CAC-5A28-42FB-936E-49C472130953} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} + {CFED9F0C-FF81-4C96-8D5E-0436264CA7B5} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/build/Dependencies.props b/build/Dependencies.props index 896ca68978..221d1e6552 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -43,6 +43,8 @@ 0.11.3 0.0.3-test + 0.0.10-test + 0.0.4-test diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md index 1a4b423e82..bb5e866c04 100644 --- a/docs/code/MlNetCookBook.md +++ b/docs/code/MlNetCookBook.md @@ -688,11 +688,11 @@ var catColumns = data.GetColumn(mlContext, "CategoricalFeatures").Take // Build 
several alternative featurization pipelines. var pipeline = // Convert each categorical feature into one-hot encoding independently. - mlContext.Transforms.Categorical.OneHotEncoding("CategoricalFeatures", "CategoricalOneHot") + mlContext.Transforms.Categorical.OneHotEncoding("CategoricalOneHot", "CategoricalFeatures") // Convert all categorical features into indices, and build a 'word bag' of these. - .Append(mlContext.Transforms.Categorical.OneHotEncoding("CategoricalFeatures", "CategoricalBag", CategoricalTransform.OutputKind.Bag)) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("CategoricalBag", "CategoricalFeatures", CategoricalTransform.OutputKind.Bag)) // One-hot encode the workclass column, then drop all the categories that have fewer than 10 instances in the train set. - .Append(mlContext.Transforms.Categorical.OneHotEncoding("Workclass", "WorkclassOneHot")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("WorkclassOneHot", "Workclass")) .Append(mlContext.Transforms.FeatureSelection.CountFeatureSelectingEstimator("WorkclassOneHot", "WorkclassOneHotTrimmed", count: 10)); // Let's train our pipeline, and then apply it to the same data. @@ -825,12 +825,12 @@ var pipeline = .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent()); // Split the data 90:10 into train and test sets, train and evaluate. -var (trainData, testData) = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1); +var split = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1); // Train the model. -var model = pipeline.Fit(trainData); +var model = pipeline.Fit(split.TrainSet); // Compute quality metrics on the test set. -var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(testData)); +var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet)); Console.WriteLine(metrics.AccuracyMicro); // Now run the 5-fold cross-validation experiment, using the same pipeline. @@ -838,7 +838,7 @@ var cvResults = mlContext.MulticlassClassification.CrossValidate(data, pipeline, // The results object is an array of 5 elements. For each of the 5 folds, we have metrics, model and scored test data. // Let's compute the average micro-accuracy. -var microAccuracies = cvResults.Select(r => r.metrics.AccuracyMicro); +var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro); Console.WriteLine(microAccuracies.Average()); ``` diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs index eb46a9683a..ed75040e77 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs @@ -43,7 +43,7 @@ public static void Calibration() var data = reader.Read(dataFile); // Split the dataset into two parts: one used for training, the other to train the calibrator - var (trainData, calibratorTrainingData) = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1); + var split = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1); // Featurize the text column through the FeaturizeText API. // Then append the StochasticDualCoordinateAscentBinary binary classifier, setting the "Label" column as the label of the dataset, and @@ -56,12 +56,12 @@ public static void Calibration() loss: new HingeLoss())); // By specifying loss: new HingeLoss(), StochasticDualCoordinateAscent will train a support vector machine (SVM). 
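The cookbook and sample hunks above replace tuple deconstruction of TrainTestSplit with a result object (split.TrainSet / split.TestSet) and rename the per-fold cross-validation fields to PascalCase properties (r.Metrics, r.Model). A minimal, self-contained sketch of that surface, assuming only the API shape shown in the diff; the Passenger type and the synthetic data are invented for illustration:

```csharp
using System;
using System.Linq;
using Microsoft.ML;

// Hypothetical toy input type; only the split/CV call shape comes from the hunks above.
public class Passenger
{
    public bool Label { get; set; }
    public float Age { get; set; }
    public float Fare { get; set; }
}

public static class SplitAndCrossValidationSketch
{
    public static void Run()
    {
        var mlContext = new MLContext(seed: 0);

        // Synthetic in-memory data, just so the sketch runs end to end.
        var samples = Enumerable.Range(0, 200)
            .Select(i => new Passenger { Label = i % 2 == 0, Age = i % 60, Fare = (i * 7) % 100 })
            .ToList();
        var data = mlContext.Data.ReadFromEnumerable(samples);

        var pipeline = mlContext.Transforms.Concatenate("Features", "Age", "Fare")
            .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

        // TrainTestSplit now returns an object with TrainSet/TestSet instead of a tuple.
        var split = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
        var model = pipeline.Fit(split.TrainSet);
        var metrics = mlContext.BinaryClassification.Evaluate(model.Transform(split.TestSet));
        Console.WriteLine(metrics.Accuracy);

        // Per-fold cross-validation results expose PascalCase properties such as r.Metrics.
        var cvResults = mlContext.BinaryClassification.CrossValidate(data, pipeline, numFolds: 3);
        Console.WriteLine(cvResults.Select(r => r.Metrics.Accuracy).Average());
    }
}
```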
// Fit the pipeline, and get a transformer that knows how to score new data. - var transformer = pipeline.Fit(trainData); + var transformer = pipeline.Fit(split.TrainSet); IPredictor model = transformer.LastTransformer.Model; // Let's score the new data. The score will give us a numerical estimation of the chance that the particular sample // bears positive sentiment. This estimate is relative to the numbers obtained. - var scoredData = transformer.Transform(calibratorTrainingData); + var scoredData = transformer.Transform(split.TestSet); var scoredDataPreview = scoredData.Preview(); PrintRowViewValues(scoredDataPreview); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs index 7f588568c4..9b34510bd0 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs @@ -5,11 +5,9 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.TimeSeries; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; namespace Microsoft.ML.Samples.Dynamic { @@ -54,16 +52,9 @@ public static void IidChangePointDetectorTransform() // Setup IidSpikeDetector arguments string outputColumnName = nameof(ChangePointPrediction.Prediction); string inputColumnName = nameof(IidChangePointData.Value); - var args = new IidChangePointDetector.Arguments() - { - Source = inputColumnName, - Name = outputColumnName, - Confidence = 95, // The confidence for spike detection in the range [0, 100] - ChangeHistoryLength = Size / 4, // The length of the sliding window on p-values for computing the martingale score. - }; // The transformed data. - var transformedData = new IidChangePointEstimator(ml, args).Fit(dataView).Transform(dataView); + var transformedData = ml.Transforms.IidChangePointEstimator(outputColumnName, inputColumnName, 95, Size / 4).Fit(dataView).Transform(dataView); // Getting the data of the newly created column as an IEnumerable of ChangePointPrediction. var predictionColumn = ml.CreateEnumerable(transformedData, reuseRowObject: false); @@ -119,16 +110,9 @@ public static void IidChangePointDetectorPrediction() // Setup IidSpikeDetector arguments string outputColumnName = nameof(ChangePointPrediction.Prediction); string inputColumnName = nameof(IidChangePointData.Value); - var args = new IidChangePointDetector.Arguments() - { - Source = inputColumnName, - Name = outputColumnName, - Confidence = 95, // The confidence for spike detection in the range [0, 100] - ChangeHistoryLength = Size / 4, // The length of the sliding window on p-values for computing the martingale score. - }; // Time Series model. - ITransformer model = new IidChangePointEstimator(ml, args).Fit(dataView); + ITransformer model = ml.Transforms.IidChangePointEstimator(outputColumnName, inputColumnName, 95, Size / 4).Fit(dataView); // Create a time series prediction engine from the model. 
var engine = model.CreateTimeSeriesPredictionFunction(ml); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs index bcbeaa36ee..c2fedc5275 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs @@ -1,11 +1,9 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.TimeSeries; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; namespace Microsoft.ML.Samples.Dynamic { @@ -51,16 +49,9 @@ public static void IidSpikeDetectorTransform() // Setup IidSpikeDetector arguments string outputColumnName = nameof(IidSpikePrediction.Prediction); string inputColumnName = nameof(IidSpikeData.Value); - var args = new IidSpikeDetector.Arguments() - { - Source = inputColumnName, - Name = outputColumnName, - Confidence = 95, // The confidence for spike detection in the range [0, 100] - PvalueHistoryLength = Size / 4 // The size of the sliding window for computing the p-value; shorter windows are more sensitive to spikes. - }; // The transformed data. - var transformedData = new IidSpikeEstimator(ml, args).Fit(dataView).Transform(dataView); + var transformedData = ml.Transforms.IidSpikeEstimator(outputColumnName, inputColumnName, 95, Size / 4).Fit(dataView).Transform(dataView); // Getting the data of the newly created column as an IEnumerable of IidSpikePrediction. var predictionColumn = ml.CreateEnumerable(transformedData, reuseRowObject: false); @@ -108,16 +99,8 @@ public static void IidSpikeDetectorPrediction() // Setup IidSpikeDetector arguments string outputColumnName = nameof(IidSpikePrediction.Prediction); string inputColumnName = nameof(IidSpikeData.Value); - var args = new IidSpikeDetector.Arguments() - { - Source = inputColumnName, - Name = outputColumnName, - Confidence = 95, // The confidence for spike detection in the range [0, 100] - PvalueHistoryLength = Size / 4 // The size of the sliding window for computing the p-value; shorter windows are more sensitive to spikes. - }; - - // The transformed model. - ITransformer model = new IidSpikeEstimator(ml, args).Fit(dataView); + ITransformer model = ml.Transforms.IidSpikeEstimator(outputColumnName, inputColumnName, 95, Size / 4).Fit(dataView); // Create a time series prediction engine from the model. 
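The time-series hunks above drop the IidSpikeDetector.Arguments / IidChangePointDetector.Arguments objects in favor of fluent extension methods on ml.Transforms. A compact sketch of the new call shape, assuming the (outputColumnName, inputColumnName, confidence, pvalueHistoryLength) ordering shown in the diff; the TimePoint and SpikePrediction types are invented for illustration:

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms.TimeSeries;

// Invented input/output types for this sketch; the prediction vector is [alert, raw score, p-value].
public class TimePoint
{
    public float Value { get; set; }
}

public class SpikePrediction
{
    [VectorType(3)]
    public double[] Prediction { get; set; }
}

public static class IidSpikeSketch
{
    public static void Run()
    {
        var ml = new MLContext();

        // A flat series with one obvious spike in the middle.
        var points = new List<TimePoint>();
        for (int i = 0; i < 40; i++)
            points.Add(new TimePoint { Value = i == 20 ? 10f : 1f });
        var dataView = ml.Data.ReadFromEnumerable(points);

        // One fluent call replaces the old Arguments object.
        var transformed = ml.Transforms.IidSpikeEstimator(
                nameof(SpikePrediction.Prediction), nameof(TimePoint.Value), 95, points.Count / 4)
            .Fit(dataView)
            .Transform(dataView);

        foreach (var p in ml.CreateEnumerable<SpikePrediction>(transformed, reuseRowObject: false))
            Console.WriteLine($"Alert: {p.Prediction[0]}, score: {p.Prediction[1]}, p-value: {p.Prediction[2]}");
    }
}
```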
var engine = model.CreateTimeSeriesPredictionFunction(ml); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs index 4b25bdc805..f85b68bb7e 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs @@ -57,7 +57,7 @@ public static void LogisticRegression() IDataView data = reader.Read(dataFilePath); - var (trainData, testData) = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2); + var split = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2); var pipeline = ml.Transforms.Concatenate("Text", "workclass", "education", "marital-status", "relationship", "ethnicity", "sex", "native-country") @@ -66,9 +66,9 @@ public static void LogisticRegression() "education-num", "capital-gain", "capital-loss", "hours-per-week")) .Append(ml.BinaryClassification.Trainers.LogisticRegression()); - var model = pipeline.Fit(trainData); + var model = pipeline.Fit(split.TrainSet); - var dataWithPredictions = model.Transform(testData); + var dataWithPredictions = model.Transform(split.TestSet); var metrics = ml.BinaryClassification.Evaluate(dataWithPredictions); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs index 79052c2314..73471a0c52 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs @@ -1,7 +1,8 @@ using System; using System.Linq; using Microsoft.Data.DataView; -using Microsoft.ML.Learners; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers; using Microsoft.ML.SamplesUtils; using Microsoft.ML.Trainers.HalLearners; diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs index 205f7cc4ce..77e2e63d27 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs @@ -1,6 +1,6 @@ using System; using System.Linq; -using Microsoft.ML.Learners; +using Microsoft.ML.Trainers; namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs index 341827bde3..0f06b9c905 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs @@ -54,7 +54,7 @@ public static void SDCA_BinaryClassification() // Step 3: Run Cross-Validation on this pipeline. var cvResults = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment"); - var accuracies = cvResults.Select(r => r.metrics.Accuracy); + var accuracies = cvResults.Select(r => r.Metrics.Accuracy); Console.WriteLine(accuracies.Average()); // If we wanted to specify more advanced parameters for the algorithm, @@ -70,7 +70,7 @@ public static void SDCA_BinaryClassification() // Run Cross-Validation on this second pipeline. 
var cvResults_advancedPipeline = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment", numFolds: 3); - accuracies = cvResults_advancedPipeline.Select(r => r.metrics.Accuracy); + accuracies = cvResults_advancedPipeline.Select(r => r.Metrics.Accuracy); Console.WriteLine(accuracies.Average()); } diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs index efc52a90dc..223bab2277 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs @@ -1,11 +1,9 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.TimeSeries; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; namespace Microsoft.ML.Samples.Dynamic { @@ -49,19 +47,9 @@ public static void SsaChangePointDetectorTransform() // Setup SsaChangePointDetector arguments var inputColumnName = nameof(SsaChangePointData.Value); var outputColumnName = nameof(ChangePointPrediction.Prediction); - var args = new SsaChangePointDetector.Arguments() - { - Source = inputColumnName, - Name = outputColumnName, - Confidence = 95, // The confidence for spike detection in the range [0, 100] - ChangeHistoryLength = 8, // The length of the window for detecting a change in trend; shorter windows are more sensitive to spikes. - TrainingWindowSize = TrainingSize, // The number of points from the beginning of the sequence used for training. - SeasonalWindowSize = SeasonalitySize + 1 // An upper bound on the largest relevant seasonality in the input time series." - - }; // The transformed data. - var transformedData = new SsaChangePointEstimator(ml, args).Fit(dataView).Transform(dataView); + var transformedData = ml.Transforms.SsaChangePointEstimator(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView).Transform(dataView); // Getting the data of the newly created column as an IEnumerable of ChangePointPrediction. var predictionColumn = ml.CreateEnumerable(transformedData, reuseRowObject: false); @@ -120,19 +108,9 @@ public static void SsaChangePointDetectorPrediction() // Setup SsaChangePointDetector arguments var inputColumnName = nameof(SsaChangePointData.Value); var outputColumnName = nameof(ChangePointPrediction.Prediction); - var args = new SsaChangePointDetector.Arguments() - { - Source = inputColumnName, - Name = outputColumnName, - Confidence = 95, // The confidence for spike detection in the range [0, 100] - ChangeHistoryLength = 8, // The length of the window for detecting a change in trend; shorter windows are more sensitive to spikes. - TrainingWindowSize = TrainingSize, // The number of points from the beginning of the sequence used for training. - SeasonalWindowSize = SeasonalitySize + 1 // An upper bound on the largest relevant seasonality in the input time series." - - }; // Train the change point detector. - ITransformer model = new SsaChangePointEstimator(ml, args).Fit(dataView); + ITransformer model = ml.Transforms.SsaChangePointEstimator(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView); // Create a prediction engine from the model for feeding new data. 
var engine = model.CreateTimeSeriesPredictionFunction(ml); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs index db8bfcad70..9ebf1a41d2 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs @@ -1,11 +1,9 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.TimeSeries; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; namespace Microsoft.ML.Samples.Dynamic { @@ -56,18 +54,9 @@ public static void SsaSpikeDetectorTransform() // Setup IidSpikeDetector arguments var inputColumnName = nameof(SsaSpikeData.Value); var outputColumnName = nameof(SsaSpikePrediction.Prediction); - var args = new SsaSpikeDetector.Arguments() - { - Source = inputColumnName, - Name = outputColumnName, - Confidence = 95, // The confidence for spike detection in the range [0, 100] - PvalueHistoryLength = 8, // The size of the sliding window for computing the p-value; shorter windows are more sensitive to spikes. - TrainingWindowSize = TrainingSize, // The number of points from the beginning of the sequence used for training. - SeasonalWindowSize = SeasonalitySize + 1 // An upper bound on the largest relevant seasonality in the input time series." - }; // The transformed data. - var transformedData = new SsaSpikeEstimator(ml, args).Fit(dataView).Transform(dataView); + var transformedData = ml.Transforms.SsaSpikeEstimator(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView).Transform(dataView); // Getting the data of the newly created column as an IEnumerable of SsaSpikePrediction. var predictionColumn = ml.CreateEnumerable(transformedData, reuseRowObject: false); @@ -127,18 +116,9 @@ public static void SsaSpikeDetectorPrediction() // Setup IidSpikeDetector arguments var inputColumnName = nameof(SsaSpikeData.Value); var outputColumnName = nameof(SsaSpikePrediction.Prediction); - var args = new SsaSpikeDetector.Arguments() - { - Source = inputColumnName, - Name = outputColumnName, - Confidence = 95, // The confidence for spike detection in the range [0, 100] - PvalueHistoryLength = 8, // The size of the sliding window for computing the p-value; shorter windows are more sensitive to spikes. - TrainingWindowSize = TrainingSize, // The number of points from the beginning of the sequence used for training. - SeasonalWindowSize = SeasonalitySize + 1 // An upper bound on the largest relevant seasonality in the input time series." - }; // Train the change point detector. - ITransformer model = new SsaSpikeEstimator(ml, args).Fit(dataView); + ITransformer model = ml.Transforms.SsaSpikeEstimator(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView); // Create a prediction engine from the model for feeding new data. 
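The SSA-based detectors follow the same pattern, with the training window and an upper bound on the seasonality passed positionally (output, input, confidence, pvalueHistoryLength, trainingWindowSize, seasonalWindowSize). A hedged sketch with an invented seasonal series; only the ml.Transforms.SsaSpikeEstimator call shape comes from the hunks above:

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms.TimeSeries;

public class SeriesPoint
{
    public float Value { get; set; }
}

public class SeriesPrediction
{
    [VectorType(3)]
    public double[] Prediction { get; set; }
}

public static class SsaSpikeSketch
{
    public static void Run()
    {
        const int SeasonalitySize = 5;
        const int TrainingSize = SeasonalitySize * 6;

        var ml = new MLContext();

        // A periodic series with a single spike injected after the training window.
        var points = new List<SeriesPoint>();
        for (int i = 0; i < TrainingSize + 2 * SeasonalitySize; i++)
            points.Add(new SeriesPoint { Value = i % SeasonalitySize + (i == TrainingSize + 2 ? 10f : 0f) });
        var dataView = ml.Data.ReadFromEnumerable(points);

        var transformed = ml.Transforms.SsaSpikeEstimator(
                nameof(SeriesPrediction.Prediction), nameof(SeriesPoint.Value),
                95, 8, TrainingSize, SeasonalitySize + 1)
            .Fit(dataView)
            .Transform(dataView);

        foreach (var p in ml.CreateEnumerable<SeriesPrediction>(transformed, reuseRowObject: false))
            Console.WriteLine($"Alert: {p.Prediction[0]}, score: {p.Prediction[1]}, p-value: {p.Prediction[2]}");
    }
}
```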
var engine = model.CreateTimeSeriesPredictionFunction(ml); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs new file mode 100644 index 0000000000..ee2e3fdd94 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs @@ -0,0 +1,46 @@ +using Microsoft.ML; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification +{ + public static class AveragedPerceptron + { + public static void Example() + { + // In this examples we will use the adult income dataset. The goal is to predict + // if a person's income is above $50K or not, based on different pieces of information about that person. + // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult + + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Download and featurize the dataset + var data = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext); + + // Leave out 10% of data for testing + var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1); + + // Create data training pipeline + var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron( + "IsOver50K", "Features", numIterations: 10); + + // Fit this pipeline to the training data + var model = pipeline.Fit(trainTestData.TrainSet); + + // Evaluate how the model is doing on the test data + var dataWithPredictions = model.Transform(trainTestData.TestSet); + var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(dataWithPredictions, "IsOver50K"); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Output: + // Accuracy: 0.86 + // AUC: 0.91 + // F1 Score: 0.68 + // Negative Precision: 0.90 + // Negative Recall: 0.91 + // Positive Precision: 0.70 + // Positive Recall: 0.66 + } + } +} \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs new file mode 100644 index 0000000000..ac9296a96d --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs @@ -0,0 +1,58 @@ +using Microsoft.ML; +using Microsoft.ML.Trainers.Online; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification +{ + public static class AveragedPerceptronWithOptions + { + public static void Example() + { + // In this examples we will use the adult income dataset. The goal is to predict + // if a person's income is above $50K or not, based on different pieces of information about that person. + // For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult + + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. 
+ var mlContext = new MLContext(seed: 0); + + // Download and featurize the dataset + var data = SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext); + + // Leave out 10% of data for testing + var trainTestData = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1); + + // Define the trainer options + var options = new AveragedPerceptronTrainer.Options() + { + LossFunction = new SmoothedHingeLoss.Arguments(), + LearningRate = 0.1f, + DoLazyUpdates = false, + RecencyGain = 0.1f, + NumberOfIterations = 10, + LabelColumn = "IsOver50K", + FeatureColumn = "Features" + }; + + // Create data training pipeline + var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(options); + + // Fit this pipeline to the training data + var model = pipeline.Fit(trainTestData.TrainSet); + + // Evaluate how the model is doing on the test data + var dataWithPredictions = model.Transform(trainTestData.TestSet); + var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(dataWithPredictions, "IsOver50K"); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Output: + // Accuracy: 0.86 + // AUC: 0.90 + // F1 Score: 0.66 + // Negative Precision: 0.89 + // Negative Recall: 0.93 + // Positive Precision: 0.72 + // Positive Recall: 0.61 + } + } +} \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorization.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorization.cs new file mode 100644 index 0000000000..91de5a8be8 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorization.cs @@ -0,0 +1,62 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML.Data; +using static Microsoft.ML.SamplesUtils.DatasetUtils; + +namespace Microsoft.ML.Samples.Dynamic +{ + public partial class MatrixFactorizationExample + { + // This example first creates in-memory data and then uses it to train a matrix factorization model with default parameters. Afterward, quality metrics are reported. + public static void MatrixFactorization() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + var mlContext = new MLContext(seed: 0, conc: 1); + + // Get a small in-memory dataset. + var data = GetRecommendationData(); + + // Convert the in-memory matrix into an IDataView so that ML.NET components can consume it. + var dataView = mlContext.Data.ReadFromEnumerable(data); + + // Create a matrix factorization trainer which may consume "Value" as the training label, "MatrixColumnIndex" as the + // matrix's column index, and "MatrixRowIndex" as the matrix's row index. Here nameof(...) is used to extract field + // names in the MatrixElement class. + var pipeline = mlContext.Recommendation().Trainers.MatrixFactorization(nameof(MatrixElement.MatrixColumnIndex), + nameof(MatrixElement.MatrixRowIndex), nameof(MatrixElement.Value), 10, 0.2, 10); + + // Train a matrix factorization model. + var model = pipeline.Fit(dataView); + + // Apply the trained model to the training set. + var prediction = model.Transform(dataView); + + // Calculate regression metrics for the prediction result. + var metrics = mlContext.Recommendation().Evaluate(prediction, + label: nameof(MatrixElement.Value), score: nameof(MatrixElementForScore.Score)); + + // Print out some metrics for checking the model's quality. 
+ Console.WriteLine($"L1 - {metrics.L1}"); // 0.17208 + Console.WriteLine($"L2 - {metrics.L2}"); // 0.04766 + Console.WriteLine($"LossFunction - {metrics.LossFn}"); // 0.04766 + Console.WriteLine($"RMS - {metrics.Rms}"); //0.21831 + Console.WriteLine($"RSquared - {metrics.RSquared}"); // 0.97616 + + // Create two two entries for making prediction. Of course, the prediction value, Score, is unknown so it can be anything + // (here we use Score=0 and it will be overwritten by the true prediction). If any of row and column indexes are out-of-range + // (e.g., MatrixColumnIndex=99999), the prediction value will be NaN. + var testMatrix = new List() { + new MatrixElementForScore() { MatrixColumnIndex = 1, MatrixRowIndex = 7, Score = 0 }, + new MatrixElementForScore() { MatrixColumnIndex = 3, MatrixRowIndex = 6, Score = 0 } }; + + // Again, convert the test data to a format supported by ML.NET. + var testDataView = mlContext.Data.ReadFromEnumerable(testMatrix); + // Feed the test data into the model and then iterate through all predictions. + foreach (var pred in mlContext.CreateEnumerable(model.Transform(testDataView), false)) + Console.WriteLine($"Predicted value at row {pred.MatrixRowIndex - 1} and column {pred.MatrixColumnIndex - 1} is {pred.Score}"); + // Predicted value at row 7 and column 1 is 2.876928 + // Predicted value at row 6 and column 3 is 3.587935 + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorizationWithOptions.cs similarity index 50% rename from docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs rename to docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorizationWithOptions.cs index 449aa46017..eb776087bf 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Recommendation/MatrixFactorizationWithOptions.cs @@ -2,65 +2,29 @@ using System.Collections.Generic; using Microsoft.ML.Data; using Microsoft.ML.Trainers; +using static Microsoft.ML.SamplesUtils.DatasetUtils; namespace Microsoft.ML.Samples.Dynamic { - public class MatrixFactorizationExample + public partial class MatrixFactorizationExample { - // The following variables defines the shape of a matrix. Its shape is _synthesizedMatrixRowCount-by-_synthesizedMatrixColumnCount. - // Because in ML.NET key type's minimal value is zero, the first row index is always zero in C# data structure (e.g., MatrixColumnIndex=0 - // and MatrixRowIndex=0 in MatrixElement below specifies the value at the upper-left corner in the training matrix). If user's row index - // starts with 1, their row index 1 would be mapped to the 2nd row in matrix factorization module and their first row may contain no values. - // This behavior is also true to column index. - const int _synthesizedMatrixFirstColumnIndex = 1; - const int _synthesizedMatrixFirstRowIndex = 1; - const int _synthesizedMatrixColumnCount = 60; - const int _synthesizedMatrixRowCount = 100; - - // A data structure used to encode a single value in matrix - internal class MatrixElement - { - // Matrix column index is at most _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex. - [KeyType(Count = _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex)] - public uint MatrixColumnIndex; - // Matrix row index is at most _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex. 
- [KeyType(Count = _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex)] - public uint MatrixRowIndex; - // The value at the column MatrixColumnIndex and row MatrixRowIndex. - public float Value; - } - - // A data structure used to encode prediction result. Comparing with MatrixElement, The field Value in MatrixElement is - // renamed to Score because Score is the default name of matrix factorization's output. - internal class MatrixElementForScore - { - [KeyType(Count = _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex)] - public uint MatrixColumnIndex; - [KeyType(Count = _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex)] - public uint MatrixRowIndex; - public float Score; - } // This example first creates in-memory data and then use it to train a matrix factorization model. Afterward, quality metrics are reported. - public static void MatrixFactorizationInMemoryData() + public static void MatrixFactorizationWithOptions() { - // Create an in-memory matrix as a list of tuples (column index, row index, value). - var dataMatrix = new List(); - for (uint i = _synthesizedMatrixFirstColumnIndex; i < _synthesizedMatrixFirstColumnIndex + _synthesizedMatrixColumnCount; ++i) - for (uint j = _synthesizedMatrixFirstRowIndex; j < _synthesizedMatrixFirstRowIndex + _synthesizedMatrixRowCount; ++j) - dataMatrix.Add(new MatrixElement() { MatrixColumnIndex = i, MatrixRowIndex = j, Value = (i + j) % 5 }); - // Create a new context for ML.NET operations. It can be used for exception tracking and logging, // as a catalog of available operations and as the source of randomness. var mlContext = new MLContext(seed: 0, conc: 1); + // Get a small in-memory dataset. + var data = GetRecommendationData(); + // Convert the in-memory matrix into an IDataView so that ML.NET components can consume it. - var dataView = mlContext.Data.ReadFromEnumerable(dataMatrix); + var dataView = mlContext.Data.ReadFromEnumerable(data); // Create a matrix factorization trainer which may consume "Value" as the training label, "MatrixColumnIndex" as the // matrix's column index, and "MatrixRowIndex" as the matrix's row index. Here nameof(...) is used to extract field // names' in MatrixElement class. - var options = new MatrixFactorizationTrainer.Options { MatrixColumnIndexColumnName = nameof(MatrixElement.MatrixColumnIndex), @@ -68,7 +32,8 @@ public static void MatrixFactorizationInMemoryData() LabelColumnName = nameof(MatrixElement.Value), NumIterations = 10, NumThreads = 1, - K = 32, + ApproximationRank = 32, + LearningRate = 0.3 }; var pipeline = mlContext.Recommendation().Trainers.MatrixFactorization(options); @@ -84,11 +49,11 @@ public static void MatrixFactorizationInMemoryData() label: nameof(MatrixElement.Value), score: nameof(MatrixElementForScore.Score)); // Print out some metrics for checking the model's quality. - Console.WriteLine($"L1 - {metrics.L1}"); - Console.WriteLine($"L2 - {metrics.L2}"); - Console.WriteLine($"LossFunction - {metrics.LossFn}"); - Console.WriteLine($"RMS - {metrics.Rms}"); - Console.WriteLine($"RSquared - {metrics.RSquared}"); + Console.WriteLine($"L1 - {metrics.L1}"); // 0.16375 + Console.WriteLine($"L2 - {metrics.L2}"); // 0.04407 + Console.WriteLine($"LossFunction - {metrics.LossFn}"); // 0.04407 + Console.WriteLine($"RMS - {metrics.Rms}"); // 0.2099 + Console.WriteLine($"RSquared - {metrics.RSquared}"); // 0.97797 // Create two two entries for making prediction. 
Of course, the prediction value, Score, is unknown so it can be anything // (here we use Score=0 and it will be overwritten by the true prediction). If any of row and column indexes are out-of-range @@ -101,8 +66,10 @@ public static void MatrixFactorizationInMemoryData() var testDataView = mlContext.Data.ReadFromEnumerable(testMatrix); // Feed the test data into the model and then iterate through all predictions. - foreach (var pred in mlContext.CreateEnumerable(testDataView, false)) - Console.WriteLine($"Predicted value at row {pred.MatrixRowIndex} and column {pred.MatrixColumnIndex} is {pred.Score}"); + foreach (var pred in mlContext.CreateEnumerable(model.Transform(testDataView), false)) + Console.WriteLine($"Predicted value at row {pred.MatrixRowIndex-1} and column {pred.MatrixColumnIndex-1} is {pred.Score}"); + // Predicted value at row 7 and column 1 is 2.828761 + // Predicted value at row 6 and column 3 is 3.642226 } } } diff --git a/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs b/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs index 50d5b8c6aa..914788eae8 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs @@ -1,7 +1,7 @@ using System; using Microsoft.ML.Data; -using Microsoft.ML.Learners; using Microsoft.ML.StaticPipe; +using Microsoft.ML.Trainers; namespace Microsoft.ML.Samples.Static { diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index 6ebef55827..1e96439a1e 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -7,6 +7,7 @@ using System.Linq; using Microsoft.Data.DataView; using Microsoft.ML.Data; +using Microsoft.ML.Model; namespace Microsoft.ML.Core.Data { @@ -263,7 +264,7 @@ public interface IDataReaderEstimator /// The transformer is a component that transforms data. /// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'. /// - public interface ITransformer + public interface ITransformer : ICanSaveModel { /// /// Schema propagation for transformers. diff --git a/src/Microsoft.ML.Data/Model/ModelHeader.cs b/src/Microsoft.ML.Core/Data/ModelHeader.cs similarity index 99% rename from src/Microsoft.ML.Data/Model/ModelHeader.cs rename to src/Microsoft.ML.Core/Data/ModelHeader.cs index 88fcf9bdad..d1b572e8bc 100644 --- a/src/Microsoft.ML.Data/Model/ModelHeader.cs +++ b/src/Microsoft.ML.Core/Data/ModelHeader.cs @@ -10,7 +10,8 @@ namespace Microsoft.ML.Model { - [StructLayout(LayoutKind.Explicit, Size = ModelHeader.Size)] + [BestFriend] + [StructLayout(LayoutKind.Explicit, Size = Size)] internal struct ModelHeader { /// diff --git a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs b/src/Microsoft.ML.Core/Data/ModelLoadContext.cs similarity index 100% rename from src/Microsoft.ML.Data/Model/ModelLoadContext.cs rename to src/Microsoft.ML.Core/Data/ModelLoadContext.cs diff --git a/src/Microsoft.ML.Data/Model/ModelLoading.cs b/src/Microsoft.ML.Core/Data/ModelLoading.cs similarity index 98% rename from src/Microsoft.ML.Data/Model/ModelLoading.cs rename to src/Microsoft.ML.Core/Data/ModelLoading.cs index 06ebc0bf06..d561389e39 100644 --- a/src/Microsoft.ML.Data/Model/ModelLoading.cs +++ b/src/Microsoft.ML.Core/Data/ModelLoading.cs @@ -10,6 +10,12 @@ namespace Microsoft.ML.Model { + /// + /// Signature for a repository based model loader. This is the dual of . 
+ /// + [BestFriend] + internal delegate void SignatureLoadModel(ModelLoadContext ctx); + public sealed partial class ModelLoadContext : IDisposable { public const string ModelStreamName = "Model.key"; diff --git a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs b/src/Microsoft.ML.Core/Data/ModelSaveContext.cs similarity index 100% rename from src/Microsoft.ML.Data/Model/ModelSaveContext.cs rename to src/Microsoft.ML.Core/Data/ModelSaveContext.cs diff --git a/src/Microsoft.ML.Data/Model/ModelSaving.cs b/src/Microsoft.ML.Core/Data/ModelSaving.cs similarity index 100% rename from src/Microsoft.ML.Data/Model/ModelSaving.cs rename to src/Microsoft.ML.Core/Data/ModelSaving.cs diff --git a/src/Microsoft.ML.Data/Model/Repository.cs b/src/Microsoft.ML.Core/Data/Repository.cs similarity index 97% rename from src/Microsoft.ML.Data/Model/Repository.cs rename to src/Microsoft.ML.Core/Data/Repository.cs index dd3ebfe43a..42cf955b84 100644 --- a/src/Microsoft.ML.Data/Model/Repository.cs +++ b/src/Microsoft.ML.Core/Data/Repository.cs @@ -10,13 +10,11 @@ namespace Microsoft.ML.Model { - /// - /// Signature for a repository based model loader. This is the dual of ICanSaveModel. - /// - public delegate void SignatureLoadModel(ModelLoadContext ctx); - /// /// For saving a model into a repository. + /// Classes implementing should do an explicit implementation of . + /// Classes inheriting from a base class should overwrite the function invoked by + /// in that base class, if there is one. /// public interface ICanSaveModel { @@ -293,6 +291,8 @@ protected Entry AddEntry(string pathEnt, Stream stream) public sealed class RepositoryWriter : Repository { + private const string DirTrainingInfo = "TrainingInfo"; + private ZipArchive _archive; private Queue> _closed; @@ -301,7 +301,7 @@ public static RepositoryWriter CreateNew(Stream stream, IExceptionContext ectx = Contracts.CheckValueOrNull(ectx); ectx.CheckValue(stream, nameof(stream)); var rep = new RepositoryWriter(stream, ectx, useFileSystem); - using (var ent = rep.CreateEntry(ModelFileUtils.DirTrainingInfo, "Version.txt")) + using (var ent = rep.CreateEntry(DirTrainingInfo, "Version.txt")) using (var writer = Utils.OpenWriter(ent.Stream)) writer.WriteLine(typeof(RepositoryWriter).Assembly.GetName().Version); return rep; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs index 25d00518f3..a42a84d559 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs @@ -942,7 +942,7 @@ private static Stream OpenStream(string filename) return OpenStream(files); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index ff5348c58c..ccd25a0039 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -502,7 +502,7 @@ private static IDataLoader LoadTransforms(ModelLoadContext ctx, IDataLoader srcL }); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 795fa2c66e..2a8a25df2b 100644 --- 
a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -834,7 +834,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent) OutputSchema = ComputeOutputSchema(); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -1283,7 +1283,7 @@ internal static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiS internal static IDataView ReadFile(IHostEnvironment env, Arguments args, IMultiStreamSource fileSource) => new TextLoader(env, args, fileSource).Read(fileSource); - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -1420,7 +1420,7 @@ public RowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int return Cursor.CreateSet(_reader, _files, active, n); } - public void Save(ModelSaveContext ctx) => _reader.Save(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => ((ICanSaveModel)_reader).Save(ctx); } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs index bf2af44b4c..24a29b1cc9 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.Data { // REVIEW: this class is public, as long as the Wrappers.cs in tests still rely on it. // It needs to become internal. - public sealed class TransformWrapper : ITransformer, ICanSaveModel + public sealed class TransformWrapper : ITransformer { public const string LoaderSignature = "TransformWrapper"; private const string TransformDirTemplate = "Step_{0:000}"; @@ -46,7 +46,7 @@ public Schema GetOutputSchema(Schema inputSchema) return output.Schema; } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { if (!_allowSave) throw _host.Except("Saving is not permitted."); diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index 6b5ee01d8d..6e52f83e75 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -51,7 +51,7 @@ internal interface ITransformerChainAccessor /// A chain of transformers (possibly empty) that end with a . /// For an empty chain, is always . /// - public sealed class TransformerChain : ITransformer, ICanSaveModel, IEnumerable, ITransformerChainAccessor + public sealed class TransformerChain : ITransformer, IEnumerable, ITransformerChainAccessor where TLastTransformer : class, ITransformer { private readonly ITransformer[] _transformers; @@ -165,7 +165,7 @@ public TransformerChain Append(TNewLast transformer, Transfo return new TransformerChain(_transformers.AppendElement(transformer), _scopes.AppendElement(scope)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); @@ -181,7 +181,7 @@ public void Save(ModelSaveContext ctx) } /// - /// The loading constructor of transformer chain. Reverse of . + /// The loading constructor of transformer chain. Reverse of . 
/// internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx) { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs index 625bcf71e7..f3b7629bc4 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs @@ -513,7 +513,7 @@ public static TransposeLoader Create(IHostEnvironment env, ModelLoadContext ctx, }); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs index 9f7290e0d1..c34881b2d6 100644 --- a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs +++ b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs @@ -140,7 +140,7 @@ public Impl(IHostEnvironment env, string name, IDataView input, OneToOneColumn c Metadata.Seal(); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.Assert(false, "Shouldn't serialize this!"); throw Host.ExceptNotSupp("Shouldn't serialize this"); diff --git a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs index 9e726ad9d4..f1cf418431 100644 --- a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs +++ b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs @@ -96,7 +96,7 @@ public Impl(IHostEnvironment env, string name, IDataView input, _conv = conv; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.Assert(false, "Shouldn't serialize this!"); throw Host.ExceptNotSupp("Shouldn't serialize this"); diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs index e3a2377009..23b0211404 100644 --- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs +++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs @@ -131,7 +131,7 @@ public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadCont return h.Apply("Loading Model", ch => new RowToRowMapperTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs index 3b88307eb6..affe071745 100644 --- a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs +++ b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs @@ -223,7 +223,7 @@ public static ChooseColumnsByIndexTransform Create(IHostEnvironment env, ModelLo return h.Apply("Loading Model", ch => new ChooseColumnsByIndexTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index c4e5ad76cd..7bade429f1 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -948,7 +948,7 @@ public static BinaryPerInstanceEvaluator 
Create(IHostEnvironment env, ModelLoadC return new BinaryPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -960,7 +960,7 @@ public override void Save(ModelSaveContext ctx) // float: _threshold // byte: _useRaw - base.Save(ctx); + base.SaveModel(ctx); ctx.SaveStringOrNull(_probCol); Contracts.Assert(FloatUtils.IsFinite(_threshold)); ctx.Writer.Write(_threshold); diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs index 067dccfb50..f1c4b19b0a 100644 --- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs @@ -628,13 +628,13 @@ public static ClusteringPerInstanceEvaluator Create(IHostEnvironment env, ModelL return new ClusteringPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { // *** Binary format ** // base // int: number of clusters - base.Save(ctx); + base.SaveModel(ctx); Host.Assert(_numClusters > 0); ctx.Writer.Write(_numClusters); } diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs index 88fa1ee285..f15c514e0e 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs @@ -510,7 +510,12 @@ protected PerInstanceEvaluatorBase(IHostEnvironment env, ModelLoadContext ctx, throw Host.ExceptSchemaMismatch(nameof(schema), "score", ScoreCol); } - public virtual void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + /// + /// Derived class, for example A, should overwrite so that (()A).Save(ctx) can correctly dump A. 
+ /// + private protected virtual void SaveModel(ModelSaveContext ctx) { // *** Binary format ** // int: Id of the score column name diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs index 7e78be4a75..bb0a27d87d 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs @@ -631,7 +631,7 @@ public static MultiClassPerInstanceEvaluator Create(IHostEnvironment env, ModelL return new MultiClassPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -642,7 +642,7 @@ public override void Save(ModelSaveContext ctx) // int: number of classes // int[]: Ids of the class names - base.Save(ctx); + base.SaveModel(ctx); Host.Assert(_numClasses > 0); ctx.Writer.Write(_numClasses); for (int i = 0; i < _numClasses; i++) diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs index 1e232aff0d..6757da940c 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs @@ -426,7 +426,7 @@ public static MultiOutputRegressionPerInstanceEvaluator Create(IHostEnvironment return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -434,7 +434,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format ** // base - base.Save(ctx); + base.SaveModel(ctx); } private protected override Func GetDependenciesCore(Func activeOutput) diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs index 1332c58948..1f16938682 100644 --- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs @@ -324,7 +324,7 @@ public static QuantileRegressionPerInstanceEvaluator Create(IHostEnvironment env return new QuantileRegressionPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -335,7 +335,7 @@ public override void Save(ModelSaveContext ctx) // int: _scoreSize // int[]: Ids of the quantile names - base.Save(ctx); + base.SaveModel(ctx); Host.Assert(_scoreSize > 0); ctx.Writer.Write(_scoreSize); var quantiles = _quantiles.GetValues(); diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index ed215e8387..b420b1ad27 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -597,11 +597,11 @@ public static RankerPerInstanceTransform Create(IHostEnvironment env, ModelLoadC return h.Apply("Loading Model", ch => new RankerPerInstanceTransform(h, ctx, input)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - _transform.Save(ctx); + 
((ICanSaveModel)_transform).Save(ctx); } public long? GetRowCount() @@ -715,7 +715,7 @@ public Transform(IHostEnvironment env, ModelLoadContext ctx, IDataView input) _bindings = new Bindings(Host, input.Schema, false, LabelCol, ScoreCol, GroupCol, _truncationLevel); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.AssertValue(ctx); @@ -725,7 +725,7 @@ public override void Save(ModelSaveContext ctx) // int: _labelGains.Length // double[]: _labelGains - base.Save(ctx); + base.SaveModel(ctx); Host.Assert(0 < _truncationLevel && _truncationLevel < 100); ctx.Writer.Write(_truncationLevel); ctx.Writer.WriteDoubleArray(_labelGains); diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs index 69df01da67..2015f30eea 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs @@ -234,7 +234,7 @@ public static RegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelL return new RegressionPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -242,7 +242,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format ** // base - base.Save(ctx); + base.SaveModel(ctx); } private protected override Func GetDependenciesCore(Func activeOutput) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index ef942c5cdf..2ee036d760 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -407,7 +407,7 @@ private static ValueMapperCalibratedModelParameters Crea return new ValueMapperCalibratedModelParameters(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -462,7 +462,7 @@ private static FeatureWeightsCalibratedModelParameters C return new FeatureWeightsCalibratedModelParameters(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -528,7 +528,7 @@ private static ParameterMixingCalibratedModelParameters return new ParameterMixingCalibratedModelParameters(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -698,7 +698,7 @@ private static SchemaBindableCalibratedModelParameters C return new SchemaBindableCalibratedModelParameters(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -1488,7 +1488,7 @@ private static PlattCalibrator Create(IHostEnvironment env, ModelLoadContext ctx return new PlattCalibrator(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs index ca384c0e4e..f53bb6993e 100644 --- a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs +++ b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs @@ -200,7 +200,7 @@ internal 
CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, strin bool ITransformer.IsRowToRowMapper => true; - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -245,7 +245,7 @@ internal Mapper(CalibratorTransformer parent, TCalibrator calibrato private protected override Func GetDependenciesCore(Func activeOutput) => col => col == _scoreColIndex; - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); protected override Schema.DetachedColumn[] GetOutputColumnsCore() { diff --git a/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs index 55e30203d4..d90539cc80 100644 --- a/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs +++ b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs @@ -146,7 +146,7 @@ private protected virtual void PredictionEngineCore(IHostEnvironment env, DataVi disposer = inputRow.Dispose; } - protected virtual Func TransformerChecker(IExceptionContext ectx, ITransformer transformer) + private protected virtual Func TransformerChecker(IExceptionContext ectx, ITransformer transformer) { ectx.CheckValue(transformer, nameof(transformer)); ectx.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper"); diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs index afb1f9bb51..5f837507db 100644 --- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs +++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs @@ -163,7 +163,7 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) return new RowMapper(env, this, schema); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index 60f5864f9f..612231ebf8 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -114,7 +114,7 @@ public static Bindings Create(ModelLoadContext ctx, return Create(env, bindable, input, roles, suffix, user: false); } - public override void Save(ModelSaveContext ctx) + internal override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -205,7 +205,7 @@ private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); - _bindings.Save(ctx); + _bindings.SaveModel(ctx); } void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index 2f57bab274..25366f02f8 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -165,7 +165,7 @@ private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadConte return h.Apply("Loading Model", ch => new LabelNameBindableMapper(h, ctx)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git 
a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index a72b09906c..bf2ec96e10 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -160,7 +160,7 @@ public static BindingsImpl Create(ModelLoadContext ctx, Schema input, return new BindingsImpl(input, rowMapper, suffix, scoreKind, false, scoreColIndex, predColType); } - public override void Save(ModelSaveContext ctx) + internal override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -335,7 +335,7 @@ private protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDa private protected override void SaveCore(ModelSaveContext ctx) { Host.AssertValue(ctx); - Bindings.Save(ctx); + Bindings.SaveModel(ctx); } void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index cbbb2bbe19..64f71aba83 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -134,7 +134,11 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) return (IRowToRowMapper)Scorer.ApplyToData(Host, new EmptyDataView(Host, inputSchema)); } - protected void SaveModel(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); + + protected void SaveModelCore(ModelSaveContext ctx) { // *** Binary format *** // @@ -157,7 +161,7 @@ protected void SaveModel(ModelSaveContext ctx) /// Those are all the transformers that work with one feature column. /// /// The model used to transform the data. 
- public abstract class SingleFeaturePredictionTransformerBase : PredictionTransformerBase, ISingleFeaturePredictionTransformer, ICanSaveModel + public abstract class SingleFeaturePredictionTransformerBase : PredictionTransformerBase, ISingleFeaturePredictionTransformer where TModel : class { /// @@ -226,7 +230,7 @@ public sealed override Schema GetOutputSchema(Schema inputSchema) return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - public void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -235,7 +239,7 @@ public void Save(ModelSaveContext ctx) protected virtual void SaveCore(ModelSaveContext ctx) { - SaveModel(ctx); + SaveModelCore(ctx); ctx.SaveStringOrNull(FeatureColumn); } diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs index fb68553672..082b083a40 100644 --- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs @@ -49,7 +49,7 @@ private protected RowToRowScorerBase(IHost host, ModelLoadContext ctx, IDataView ctx.LoadModel(host, out Bindable, "SchemaBindableMapper"); } - public sealed override void Save(ModelSaveContext ctx) + private protected sealed override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -417,7 +417,7 @@ protected void SaveBase(ModelSaveContext ctx) } } - public abstract void Save(ModelSaveContext ctx); + internal abstract void SaveModel(ModelSaveContext ctx); protected override ColumnType GetColumnTypeCore(int iinfo) { diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 5d5e9e3178..5b62f45a7e 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -76,7 +76,9 @@ protected SchemaBindablePredictorWrapperBase(IHostEnvironment env, ModelLoadCont ScoreType = GetScoreType(Predictor, out ValueMapper); } - public virtual void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected virtual void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -283,11 +285,11 @@ public static SchemaBindablePredictorWrapper Create(IHostEnvironment env, ModelL return new SchemaBindablePredictorWrapper(env, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); - base.Save(ctx); + base.SaveModel(ctx); } private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) @@ -390,11 +392,11 @@ public static SchemaBindableBinaryPredictorWrapper Create(IHostEnvironment env, return new SchemaBindableBinaryPredictorWrapper(env, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); - base.Save(ctx); + base.SaveModel(ctx); } private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) @@ -631,7 +633,7 @@ private SchemaBindableQuantileRegressionPredictor(IHostEnvironment env, ModelLoa Contracts.CheckDecode(Utils.Size(_quantiles) > 
0); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); @@ -641,7 +643,7 @@ public override void Save(ModelSaveContext ctx) // int: the number of quantiles // double[]: the quantiles - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.WriteDoubleArray(_quantiles); } diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index d8b4c53265..586ac66bd0 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -25,6 +25,31 @@ public abstract class TrainCatalogBase [BestFriend] internal IHostEnvironment Environment => Host; + /// + /// A pair of datasets, for the train and test set. + /// + public struct TrainTestData + { + /// + /// Training set. + /// + public readonly IDataView TrainSet; + /// + /// Testing set. + /// + public readonly IDataView TestSet; + /// + /// Create pair of datasets. + /// + /// Training set. + /// Testing set. + internal TrainTestData(IDataView trainSet, IDataView testSet) + { + TrainSet = trainSet; + TestSet = testSet; + } + } + /// /// Split the dataset into the train set and test set according to the given fraction. /// Respects the if provided. @@ -37,8 +62,7 @@ public abstract class TrainCatalogBase /// Optional parameter used in combination with the . /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. - /// A pair of datasets, for the train and test set. - public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null) + public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null) { Host.CheckValue(data, nameof(data)); Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); @@ -61,14 +85,71 @@ public abstract class TrainCatalogBase Complement = false }, data); - return (trainFilter, testFilter); + return new TrainTestData(trainFilter, testFilter); + } + + /// + /// Results for specific cross-validation fold. + /// + protected internal struct CrossValidationResult + { + /// + /// Model trained during cross validation fold. + /// + public readonly ITransformer Model; + /// + /// Scored test set with for this fold. + /// + public readonly IDataView Scores; + /// + /// Fold number. + /// + public readonly int Fold; + + public CrossValidationResult(ITransformer model, IDataView scores, int fold) + { + Model = model; + Scores = scores; + Fold = fold; + } + } + /// + /// Results of running cross-validation. + /// + /// Type of metric class. + public sealed class CrossValidationResult where T : class + { + /// + /// Metrics for this cross-validation fold. + /// + public readonly T Metrics; + /// + /// Model trained during cross-validation fold. + /// + public readonly ITransformer Model; + /// + /// The scored hold-out set for this fold. + /// + public readonly IDataView ScoredHoldOutSet; + /// + /// Fold number. + /// + public readonly int Fold; + + internal CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold) + { + Model = model; + Metrics = metrics; + ScoredHoldOutSet = scores; + Fold = fold; + } } /// /// Train the on folds of the data sequentially. 
/// Return each model and each scored test dataset. /// - protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator estimator, + protected internal CrossValidationResult[] CrossValidateTrain(IDataView data, IEstimator estimator, int numFolds, string stratificationColumn, uint? seed = null) { Host.CheckValue(data, nameof(data)); @@ -78,7 +159,7 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate EnsureStratificationColumn(ref data, ref stratificationColumn, seed); - Func foldFunction = + Func foldFunction = fold => { var trainFilter = new RangeFilter(Host, new RangeFilter.Options @@ -98,17 +179,17 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate var model = estimator.Fit(trainFilter); var scoredTest = model.Transform(testFilter); - return (scoredTest, model); + return new CrossValidationResult(model, scoredTest, fold); }; // Sequential per-fold training. // REVIEW: we could have a parallel implementation here. We would need to // spawn off a separate host per fold in that case. - var result = new List<(IDataView scores, ITransformer model)>(); + var result = new CrossValidationResult[numFolds]; for (int fold = 0; fold < numFolds; fold++) - result.Add(foldFunction(fold)); + result[fold] = foldFunction(fold); - return result.ToArray(); + return result; } protected internal TrainCatalogBase(IHostEnvironment env, string registrationName) @@ -263,13 +344,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (BinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated( + public CrossValidationResult[] CrossValidateNonCalibrated( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null, uint? seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, + EvaluateNonCalibrated(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } /// @@ -287,13 +369,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string /// train to the test set. /// If not present in dataset we will generate random filled column based on provided . /// Per-fold results: metrics, models, scored datasets. - public (CalibratedBinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + public CrossValidationResult[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null, uint? 
seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, + Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } } @@ -369,12 +452,13 @@ public ClusteringMetrics Evaluate(IDataView data, /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (ClusteringMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + public CrossValidationResult[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = null, string featuresColumn = null, string stratificationColumn = null, uint? seed = null) { var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (Evaluate(x.scoredTestSet, label: labelColumn, features: featuresColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, + Evaluate(x.Scores, label: labelColumn, features: featuresColumn), x.Scores, x.Fold)).ToArray(); } } @@ -444,13 +528,14 @@ public MultiClassClassifierMetrics Evaluate(IDataView data, string label = Defau /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (MultiClassClassifierMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + public CrossValidationResult[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null, uint? seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, + Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } } @@ -511,13 +596,14 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + public CrossValidationResult[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null, uint? 
seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, + Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } } diff --git a/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs b/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs index cb61c00db1..ee3289396f 100644 --- a/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs @@ -129,7 +129,7 @@ private BootstrapSamplingTransformer(IHost host, ModelLoadContext ctx, IDataView Host.CheckDecode(_poolSize >= 0); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs index f9b5d6bf8e..8206e66723 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs @@ -263,7 +263,7 @@ private static VersionInfo GetVersionInfo() private const int VersionAddedAliases = 0x00010002; private const int VersionTransformer = 0x00010003; - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -864,7 +864,7 @@ private protected override Func GetDependenciesCore(Func a protected override Schema.DetachedColumn[] GetOutputColumnsCore() => _columns.Select(x => x.MakeSchemaColumn()).ToArray(); - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs index 7dfdf17d23..07e5784bad 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs @@ -166,7 +166,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { ctx.SetVersionInfo(GetVersionInfo()); SaveColumns(ctx); diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs index c3e6b50018..00bd42fee0 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs @@ -124,7 +124,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) /// /// The allows the user to specify columns to drop or keep from a given input. 
/// - public sealed class ColumnSelectingTransformer : ITransformer, ICanSaveModel + public sealed class ColumnSelectingTransformer : ITransformer { internal const string Summary = "Selects which columns from the dataset to keep."; internal const string UserName = "Select Columns Transform"; @@ -417,7 +417,9 @@ private static IDataTransform Create(IHostEnvironment env, Options options, IDat return new SelectColumnsDataTransform(env, transform, new Mapper(transform, input.Schema), input); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + internal void SaveModel(ModelSaveContext ctx) { ctx.SetVersionInfo(GetVersionInfo()); @@ -678,7 +680,7 @@ public RowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int return cursors; } - public void Save(ModelSaveContext ctx) => _transform.Save(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => _transform.SaveModel(ctx); public Func GetDependencies(Func activeOutput) { diff --git a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs index 461a3f61e9..f6f46d62ac 100644 --- a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs @@ -172,7 +172,7 @@ private FeatureContributionCalculatingTransformer(IHostEnvironment env, ModelLoa Normalize = ctx.Reader.ReadBoolByte(); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs index 54fb453894..c14cdd49d9 100644 --- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs @@ -163,7 +163,7 @@ public static Bindings Create(ModelLoadContext ctx, Schema input) return new Bindings(useCounter, states, input, false, names); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -309,7 +309,7 @@ public static GenerateNumberTransform Create(IHostEnvironment env, ModelLoadCont return h.Apply("Loading Model", ch => new GenerateNumberTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index f98f4baa47..d60794d2d8 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -278,7 +278,7 @@ private HashingTransformer(IHost host, ModelLoadContext ctx) TextModelHelper.LoadAll(Host, ctx, columnsLength, out _keyValues, out _kvTypes); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index ffb6c61da3..eaaf566e95 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -144,7 +144,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static 
IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs index d2c7a5e260..cc4ed51905 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs @@ -143,7 +143,7 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(KeyToVectorMappingTransformer).Assembly.FullName); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs index 408e62e330..cf0b9790ea 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs @@ -120,7 +120,7 @@ public static LabelConvertTransform Create(IHostEnvironment env, ModelLoadContex }); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs index 21b0b13bce..e848312ee9 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs @@ -98,7 +98,7 @@ public static LabelIndicatorTransform Create(IHostEnvironment env, ch => new LabelIndicatorTransform(h, args, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs index ffc04cbc80..956f201cd6 100644 --- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs @@ -169,7 +169,7 @@ public static NAFilter Create(IHostEnvironment env, ModelLoadContext ctx, IDataV return h.Apply("Loading Model", ch => new NAFilter(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/NopTransform.cs b/src/Microsoft.ML.Data/Transforms/NopTransform.cs index 92f701d704..234c060be2 100644 --- a/src/Microsoft.ML.Data/Transforms/NopTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NopTransform.cs @@ -89,7 +89,7 @@ private NopTransform(IHost host, ModelLoadContext ctx, IDataView input) // Nothing :) } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs index 4888aeb4c4..f15f3e77fc 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs @@ -369,7 +369,9 @@ private AffineColumnFunction(IHost host) Host = host; } - public abstract void 
Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken); public bool CanSaveOnnx(OnnxContext ctx) => true; @@ -485,7 +487,9 @@ private CdfColumnFunction(IHost host) Host = host; } - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => null; @@ -614,7 +618,9 @@ protected BinColumnFunction(IHost host) Host = host; } - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => null; diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs index 54b4c11ce4..e2050bdcdc 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs @@ -568,7 +568,7 @@ private void GetResult(ref TFloat input, ref TFloat value) value = (input - Offset) * Scale; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { AffineNormSerializationUtils.SaveModel(ctx, 1, null, new[] { Scale }, new[] { Offset }, saveText: true); } @@ -628,7 +628,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, scales, offsets, (offsets != null && nz.Count < cv / 2) ? nz.ToArray() : null); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { AffineNormSerializationUtils.SaveModel(ctx, Scale.Length, null, Scale, Offset, saveText: true); } @@ -892,7 +892,7 @@ private void GetResult(ref TFloat input, ref TFloat value) value = CdfUtils.Cdf(val, Mean, Stddev); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -947,7 +947,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, mean, stddev, useLog); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -1071,7 +1071,7 @@ public ImplOne(IHost host, TFloat[] binUpperBounds, bool fixZero) return new ImplOne(host, binUpperBounds[0], fixZero); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -1157,7 +1157,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, binUpperBounds, fixZero); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs index 482e5d8f71..7a4abcdfbb 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs @@ -570,7 +570,7 @@ private void 
GetResult(ref TFloat input, ref TFloat value) value = (input - Offset) * Scale; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { AffineNormSerializationUtils.SaveModel(ctx, 1, null, new[] { Scale }, new[] { Offset }, saveText: true); } @@ -629,7 +629,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, scales, offsets, (offsets != null && nz.Count < cv / 2) ? nz.ToArray() : null); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { AffineNormSerializationUtils.SaveModel(ctx, Scale.Length, null, Scale, Offset, saveText: true); } @@ -896,7 +896,7 @@ private void GetResult(ref TFloat input, ref TFloat value) value = CdfUtils.Cdf(val, Mean, Stddev); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -951,7 +951,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, mean, stddev, useLog); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -1076,7 +1076,7 @@ public ImplOne(IHost host, TFloat[] binUpperBounds, bool fixZero) return new ImplOne(host, binUpperBounds[0], fixZero); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -1162,7 +1162,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, binUpperBounds, fixZero); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index bff9503614..78edabbe0c 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -540,7 +540,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs index f1cf1f52b3..4a4a94f7de 100644 --- a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs @@ -110,7 +110,7 @@ private protected override Func GetDependenciesCore(Func a return col => active[col]; } - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); } } } diff --git a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs index c5e95a7683..f226d28a30 100644 --- a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs @@ -133,7 +133,9 @@ protected 
PerGroupTransformBase(IHostEnvironment env, ModelLoadContext ctx, IDat GroupCol = ctx.LoadNonEmptyString(); } - public virtual void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected virtual void SaveModel(ModelSaveContext ctx) { Host.AssertValue(ctx); diff --git a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs index 838f1698b2..e6f356ffe4 100644 --- a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs @@ -177,7 +177,7 @@ public static RangeFilter Create(IHostEnvironment env, ModelLoadContext ctx, IDa return h.Apply("Loading Model", ch => new RangeFilter(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs index 0ed952c379..0f6b6a5dc3 100644 --- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs @@ -159,7 +159,7 @@ public static RowShufflingTransformer Create(IHostEnvironment env, ModelLoadCont return h.Apply("Loading Model", ch => new RowShufflingTransformer(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs index 104cdad96d..13622d6063 100644 --- a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs @@ -13,7 +13,7 @@ namespace Microsoft.ML.Data /// /// Base class for transformer which produce new columns, but doesn't affect existing ones. 
/// - public abstract class RowToRowTransformerBase : ITransformer, ICanSaveModel + public abstract class RowToRowTransformerBase : ITransformer { protected readonly IHost Host; @@ -23,7 +23,9 @@ protected RowToRowTransformerBase(IHost host) Host = host; } - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public bool IsRowToRowMapper => true; @@ -109,7 +111,9 @@ Func IRowMapper.GetDependencies(Func activeOutput) [BestFriend] private protected abstract Func GetDependenciesCore(Func activeOutput); - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public ITransformer GetTransformer() => _parent; } diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs index 6c8a316a7c..b83e701597 100644 --- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs @@ -168,7 +168,7 @@ public static SkipTakeFilter Create(IHostEnvironment env, ModelLoadContext ctx, } ///Saves class data to context - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs index 9b57d6d031..037521e7e1 100644 --- a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs @@ -324,7 +324,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index ec8dec1af9..0ab2b61c27 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -43,7 +43,9 @@ protected TransformBase(IHost host, IDataView input) Source = input; } - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public abstract long? 
GetRowCount(); @@ -388,7 +390,7 @@ public static Bindings Create(OneToOneTransformBase parent, ModelLoadContext ctx return new Bindings(parent, infos, inputSchema, false, names); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 5bdbdb0a29..7dcd1ab641 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -206,7 +206,7 @@ internal TypeConvertingTransformer(IHostEnvironment env, params TypeConvertingEs _columns = columns.ToArray(); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs index b6469b55cc..c74650a7a0 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs @@ -740,7 +740,7 @@ protected static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorT return ColumnTypeExtensions.PrimitiveTypeFromKind(kind); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index 011d491608..dacd543f29 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -641,7 +641,7 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info return termMap; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Ensemble/Batch.cs b/src/Microsoft.ML.Ensemble/Batch.cs index caaf6bf4f3..756b095c72 100644 --- a/src/Microsoft.ML.Ensemble/Batch.cs +++ b/src/Microsoft.ML.Ensemble/Batch.cs @@ -4,7 +4,7 @@ using Microsoft.ML.Data; -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { internal sealed class Batch { diff --git a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs index 2af9b96f3b..342a33b9bd 100644 --- a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs +++ b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs @@ -7,7 +7,7 @@ using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { internal static class EnsembleUtils { diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs index cf88c501b8..6ce5673797 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs @@ -10,15 +10,13 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble; -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(void), typeof(EnsembleCreator), null, 
typeof(SignatureEntryPointModule), "CreateEnsemble")] -namespace Microsoft.ML.EntryPoints +namespace Microsoft.ML.Trainers.Ensemble { /// /// A component to combine given models into an ensemble model. diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs b/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs index 83d8fd0a96..075b3659ec 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs @@ -2,16 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.DiversityMeasure; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.DiversityMeasure; [assembly: EntryPointModule(typeof(DisagreementDiversityFactory))] [assembly: EntryPointModule(typeof(RegressionDisagreementDiversityFactory))] [assembly: EntryPointModule(typeof(MultiDisagreementDiversityFactory))] -namespace Microsoft.ML.Ensemble.EntryPoints +namespace Microsoft.ML.Trainers.Ensemble { [TlcModule.Component(Name = DisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)] internal sealed class DisagreementDiversityFactory : ISupportBinaryDiversityMeasureFactory diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs index 6cf8b418bc..afec3af471 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs @@ -3,12 +3,12 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML; -using Microsoft.ML.Ensemble.EntryPoints; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(void), typeof(Ensemble), null, typeof(SignatureEntryPointModule), "TrainEnsemble")] -namespace Microsoft.ML.Ensemble.EntryPoints +namespace Microsoft.ML.Trainers.Ensemble { internal static class Ensemble { diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs b/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs index c2d7862a31..8af31a0723 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/FeatureSelector.cs @@ -2,15 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.FeatureSelector; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.FeatureSelector; [assembly: EntryPointModule(typeof(AllFeatureSelectorFactory))] [assembly: EntryPointModule(typeof(RandomFeatureSelector))] -namespace Microsoft.ML.Ensemble.EntryPoints +namespace Microsoft.ML.Trainers.Ensemble { [TlcModule.Component(Name = AllFeatureSelector.LoadName, FriendlyName = AllFeatureSelector.UserName)] public sealed class AllFeatureSelectorFactory : ISupportFeatureSelectorFactory diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs b/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs index 7583d72cfb..5af5cdf487 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs @@ -2,9 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Trainers.Ensemble; [assembly: EntryPointModule(typeof(AverageFactory))] [assembly: EntryPointModule(typeof(MedianFactory))] @@ -18,7 +17,7 @@ [assembly: EntryPointModule(typeof(VotingFactory))] [assembly: EntryPointModule(typeof(WeightedAverage))] -namespace Microsoft.ML.Ensemble.EntryPoints +namespace Microsoft.ML.Trainers.Ensemble { [TlcModule.Component(Name = Average.LoadName, FriendlyName = Average.UserName)] public sealed class AverageFactory : ISupportBinaryOutputCombinerFactory, ISupportRegressionOutputCombinerFactory diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs index c80dba1738..2bafa85a2f 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs @@ -4,13 +4,13 @@ using Microsoft.Data.DataView; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.EntryPoints; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Calibration; +using Microsoft.ML.Trainers.Ensemble; [assembly: EntryPointModule(typeof(PipelineEnsemble))] -namespace Microsoft.ML.Ensemble.EntryPoints +namespace Microsoft.ML.Trainers.Ensemble { internal static class PipelineEnsemble { diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs b/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs index 10b1551c62..867833f999 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/SubModelSelector.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.SubModelSelector; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.SubModelSelector; [assembly: EntryPointModule(typeof(AllSelectorFactory))] [assembly: EntryPointModule(typeof(AllSelectorMultiClassFactory))] @@ -16,7 +15,7 @@ [assembly: EntryPointModule(typeof(BestPerformanceSelector))] [assembly: EntryPointModule(typeof(BestPerformanceSelectorMultiClass))] -namespace Microsoft.ML.Ensemble.EntryPoints +namespace Microsoft.ML.Trainers.Ensemble { [TlcModule.Component(Name = AllSelector.LoadName, FriendlyName = AllSelector.UserName)] public sealed class AllSelectorFactory : ISupportBinarySubModelSelectorFactory, ISupportRegressionSubModelSelectorFactory diff --git a/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs b/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs index 733573d39f..d583a1b125 100644 --- a/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs +++ b/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { internal sealed class FeatureSubsetModel { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs index 05b85b37ec..e78e63cecc 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs @@ -4,15 +4,15 @@ using System; using Microsoft.ML; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(Average), null, typeof(SignatureCombiner), Average.UserName)] [assembly: LoadableClass(typeof(Average), null, typeof(SignatureLoadModel), Average.UserName, Average.LoaderSignature)] -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { - public sealed class Average : BaseAverager, ICanSaveModel, IRegressionOutputCombiner + public sealed class Average : BaseAverager, IRegressionOutputCombiner { public const string UserName = "Average"; public const string LoadName = "Average"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs index 1ffbe13ec1..1d8f56e3fe 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs @@ -5,9 +5,9 @@ using System; using Microsoft.ML.Model; -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { - public abstract class BaseAverager : IBinaryOutputCombiner + public abstract class BaseAverager : IBinaryOutputCombiner, ICanSaveModel { protected readonly IHost Host; public BaseAverager(IHostEnvironment env, string name) @@ -30,7 +30,7 @@ protected BaseAverager(IHostEnvironment env, string name, ModelLoadContext ctx) Host.CheckDecode(cbFloat == sizeof(Single)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs index 73e6f4ea7e..044ccfb44e 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs +++ 
b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs @@ -8,7 +8,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Numeric; -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { public abstract class BaseMultiAverager : BaseMultiCombiner { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs index 5e8296862b..d0ee4583b3 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs @@ -9,9 +9,9 @@ using Microsoft.ML.Model; using Microsoft.ML.Numeric; -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { - public abstract class BaseMultiCombiner : IMultiClassOutputCombiner + public abstract class BaseMultiCombiner : IMultiClassOutputCombiner, ICanSaveModel { protected readonly IHost Host; @@ -49,7 +49,7 @@ internal BaseMultiCombiner(IHostEnvironment env, string name, ModelLoadContext c Normalize = ctx.Reader.ReadBoolByte(); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs index ba75e05080..59700cb18c 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs @@ -7,7 +7,7 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { internal abstract class BaseScalarStacking : BaseStacking { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs index 81d0c545f9..a172a6bce4 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs @@ -13,9 +13,9 @@ using Microsoft.ML.Model; using Microsoft.ML.Training; -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { - internal abstract class BaseStacking : IStackingTrainer + internal abstract class BaseStacking : IStackingTrainer, ICanSaveModel { public abstract class ArgumentsBase { @@ -67,7 +67,7 @@ private protected BaseStacking(IHostEnvironment env, string name, ModelLoadConte CheckMeta(); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.Check(Meta != null, "Can't save an untrained Stacking combiner"); Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs index bbbbecf217..e054cd2fa8 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs @@ -7,7 +7,7 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { /// /// Signature for combiners. 
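
The recurring change in the hunks above and below is the same refactoring applied file by file: the public `Save(ModelSaveContext)` override is replaced by an explicit `ICanSaveModel.Save` implementation that forwards to a `private protected` (or `internal`) `SaveModel`, so serialization stays reachable through the interface while dropping out of each type's public surface. Below is a minimal self-contained sketch of that pattern; the `ModelSaveContext`, `ICanSaveModel`, and transformer names here are simplified stand-ins for illustration, not the actual ML.NET types.

```csharp
using System;
using System.Collections.Generic;

// Simplified stand-ins for the real ML.NET types; only the save-related shape is kept.
public sealed class ModelSaveContext
{
    public List<string> Writer { get; } = new List<string>();
}

public interface ICanSaveModel
{
    void Save(ModelSaveContext ctx);
}

// Base transformer: the interface method is implemented explicitly and forwards to an
// abstract SaveModel, so derived types override SaveModel instead of exposing Save publicly.
public abstract class TransformerBase : ICanSaveModel
{
    void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);

    // 'private protected' keeps the override reachable only from derived types in the same assembly.
    private protected abstract void SaveModel(ModelSaveContext ctx);
}

public sealed class ExampleTransformer : TransformerBase
{
    private protected override void SaveModel(ModelSaveContext ctx)
        => ctx.Writer.Add("example transformer state");
}

public static class Program
{
    public static void Main()
    {
        var transformer = new ExampleTransformer();
        var ctx = new ModelSaveContext();

        // transformer.Save(ctx) no longer compiles: Save is visible only through the interface,
        // which is why call sites go through a cast.
        ((ICanSaveModel)transformer).Save(ctx);

        Console.WriteLine(string.Join(", ", ctx.Writer)); // prints: example transformer state
    }
}
```

The cast in `Main` mirrors the call-site rewrites in this patch, such as `((ICanSaveModel)_transform).Save(ctx)`, wherever a member previously invoked the public `Save` directly.
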
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs index 3b94196eb9..88a9d73ca3 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs @@ -4,14 +4,14 @@ using System; using Microsoft.ML; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(Median), null, typeof(SignatureCombiner), Median.UserName, Median.LoadName)] [assembly: LoadableClass(typeof(Median), null, typeof(SignatureLoadModel), Median.UserName, Median.LoaderSignature)] -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { /// /// Generic interface for combining outputs of multiple models @@ -59,7 +59,7 @@ public static Median Create(IHostEnvironment env, ModelLoadContext ctx) return new Median(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs index 1ba5cdf028..3cf424b50f 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs @@ -5,18 +5,18 @@ using System; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(MultiAverage), typeof(MultiAverage.Arguments), typeof(SignatureCombiner), Average.UserName, MultiAverage.LoadName)] [assembly: LoadableClass(typeof(MultiAverage), null, typeof(SignatureLoadModel), Average.UserName, MultiAverage.LoadName, MultiAverage.LoaderSignature)] -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { - public sealed class MultiAverage : BaseMultiAverager, ICanSaveModel + public sealed class MultiAverage : BaseMultiAverager { public const string LoadName = "MultiAverage"; public const string LoaderSignature = "MultiAverageCombiner"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs index 477a658355..3b44413397 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs @@ -5,21 +5,21 @@ using System; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(MultiMedian), typeof(MultiMedian.Arguments), typeof(SignatureCombiner), Median.UserName, MultiMedian.LoadName)] [assembly: LoadableClass(typeof(MultiMedian), null, typeof(SignatureLoadModel), Median.UserName, MultiMedian.LoaderSignature)] -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { /// /// Generic interface for combining outputs of multiple models /// - public sealed class MultiMedian : BaseMultiCombiner, ICanSaveModel + public sealed class MultiMedian : BaseMultiCombiner { public const string LoadName = "MultiMedian"; public const string LoaderSignature = "MultiMedianCombiner"; diff --git 
a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs index 3ec8ea85c8..01011a99d4 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs @@ -6,10 +6,10 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, MultiStacking.LoadName)] @@ -17,10 +17,10 @@ [assembly: LoadableClass(typeof(MultiStacking), null, typeof(SignatureLoadModel), Stacking.UserName, MultiStacking.LoaderSignature)] -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { using TVectorPredictor = IPredictorProducing>; - internal sealed class MultiStacking : BaseStacking>, ICanSaveModel, IMultiClassOutputCombiner + internal sealed class MultiStacking : BaseStacking>, IMultiClassOutputCombiner { public const string LoadName = "MultiStacking"; public const string LoaderSignature = "MultiStackingCombiner"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs index a8ba932926..f943a4176c 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs @@ -5,19 +5,19 @@ using System; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; using Microsoft.ML.Numeric; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureCombiner), Voting.UserName, MultiVoting.LoadName)] [assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureLoadModel), Voting.UserName, MultiVoting.LoaderSignature)] -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { // REVIEW: Why is MultiVoting based on BaseMultiCombiner? Normalizing the model outputs // is senseless, so the base adds no real functionality. 
- public sealed class MultiVoting : BaseMultiCombiner, ICanSaveModel + public sealed class MultiVoting : BaseMultiCombiner { public const string LoadName = "MultiVoting"; public const string LoaderSignature = "MultiVotingCombiner"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs index deef23abcd..aecc3963bc 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs @@ -6,10 +6,10 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(MultiWeightedAverage), typeof(MultiWeightedAverage.Arguments), typeof(SignatureCombiner), MultiWeightedAverage.UserName, MultiWeightedAverage.LoadName)] @@ -17,12 +17,12 @@ [assembly: LoadableClass(typeof(MultiWeightedAverage), null, typeof(SignatureLoadModel), MultiWeightedAverage.UserName, MultiWeightedAverage.LoaderSignature)] -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { /// /// Generic interface for combining outputs of multiple models /// - public sealed class MultiWeightedAverage : BaseMultiAverager, IWeightedAverager, ICanSaveModel + public sealed class MultiWeightedAverage : BaseMultiAverager, IWeightedAverager { public const string UserName = "Multi Weighted Average"; public const string LoadName = "MultiWeightedAverage"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs index ab57cfe9aa..ae9ef67db8 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs @@ -4,10 +4,10 @@ using System; using Microsoft.ML; using Microsoft.ML.CommandLine; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(RegressionStacking), typeof(RegressionStacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, RegressionStacking.LoadName)] @@ -15,11 +15,11 @@ [assembly: LoadableClass(typeof(RegressionStacking), null, typeof(SignatureLoadModel), Stacking.UserName, RegressionStacking.LoaderSignature)] -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { using TScalarPredictor = IPredictorProducing; - internal sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel + internal sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner { public const string LoadName = "RegressionStacking"; public const string LoaderSignature = "RegressionStacking"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs index 891963dbea..93fe1d3240 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs @@ -5,18 +5,18 @@ using System; using Microsoft.ML; using Microsoft.ML.CommandLine; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Model; +using 
Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(Stacking), typeof(Stacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, Stacking.LoadName)] [assembly: LoadableClass(typeof(Stacking), null, typeof(SignatureLoadModel), Stacking.UserName, Stacking.LoaderSignature)] -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { using TScalarPredictor = IPredictorProducing; - internal sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel + internal sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner { public const string UserName = "Stacking"; public const string LoadName = "Stacking"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs index b18de67b18..17c569d8ce 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs @@ -4,14 +4,14 @@ using System; using Microsoft.ML; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(Voting), null, typeof(SignatureCombiner), Voting.UserName, Voting.LoadName)] [assembly: LoadableClass(typeof(Voting), null, typeof(SignatureLoadModel), Voting.UserName, Voting.LoaderSignature)] -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { public sealed class Voting : IBinaryOutputCombiner, ICanSaveModel { @@ -57,7 +57,7 @@ public static Voting Create(IHostEnvironment env, ModelLoadContext ctx) return new Voting(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs index ec1d2dcd11..e5aa458768 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs @@ -6,10 +6,10 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(WeightedAverage), typeof(WeightedAverage.Arguments), typeof(SignatureCombiner), WeightedAverage.UserName, WeightedAverage.LoadName)] @@ -17,9 +17,9 @@ [assembly: LoadableClass(typeof(WeightedAverage), null, typeof(SignatureLoadModel), WeightedAverage.UserName, WeightedAverage.LoaderSignature)] -namespace Microsoft.ML.Ensemble.OutputCombiners +namespace Microsoft.ML.Trainers.Ensemble { - public sealed class WeightedAverage : BaseAverager, IWeightedAverager, ICanSaveModel + public sealed class WeightedAverage : BaseAverager, IWeightedAverager { public const string UserName = "Weighted Average"; public const string LoadName = "WeightedAverage"; diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index 2e5da36c3a..4f44a42204 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -10,18 +10,17 @@ using Microsoft.Data.DataView; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; using 
Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(SchemaBindablePipelineEnsembleBase), null, typeof(SignatureLoadModel), SchemaBindablePipelineEnsembleBase.UserName, SchemaBindablePipelineEnsembleBase.LoaderSignature)] -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { /// /// This class represents an ensemble predictor, where each predictor has its own featurization pipeline. It is @@ -482,7 +481,7 @@ protected SchemaBindablePipelineEnsembleBase(IHostEnvironment env, ModelLoadCont _inputCols[i] = ctx.LoadNonEmptyString(); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs index b4106a7b1d..c2f40e0b1a 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs @@ -6,7 +6,7 @@ using System.Collections.Concurrent; using System.Collections.Generic; -namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure +namespace Microsoft.ML.Trainers.Ensemble.DiversityMeasure { internal abstract class BaseDisagreementDiversityMeasure : IDiversityMeasure { diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs index bb0127003f..4cb1cbd883 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs @@ -4,13 +4,13 @@ using System; using Microsoft.ML; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.DiversityMeasure; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.DiversityMeasure; [assembly: LoadableClass(typeof(DisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure), DisagreementDiversityMeasure.UserName, DisagreementDiversityMeasure.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure +namespace Microsoft.ML.Trainers.Ensemble.DiversityMeasure { internal sealed class DisagreementDiversityMeasure : BaseDisagreementDiversityMeasure, IBinaryDiversityMeasure { diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs index b182ad9963..04c9e7f2d2 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure +namespace Microsoft.ML.Trainers.Ensemble.DiversityMeasure { internal sealed class ModelDiversityMetric { diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs index ffb6839818..ccdd4d6b0b 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs +++ 
b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs @@ -5,14 +5,14 @@ using System; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.DiversityMeasure; using Microsoft.ML.Numeric; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.DiversityMeasure; [assembly: LoadableClass(typeof(MultiDisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure), DisagreementDiversityMeasure.UserName, MultiDisagreementDiversityMeasure.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure +namespace Microsoft.ML.Trainers.Ensemble.DiversityMeasure { internal sealed class MultiDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure>, IMulticlassDiversityMeasure { diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs index 2b60e44f1c..c1b37411b6 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs @@ -4,13 +4,13 @@ using System; using Microsoft.ML; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.DiversityMeasure; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.DiversityMeasure; [assembly: LoadableClass(typeof(RegressionDisagreementDiversityMeasure), null, typeof(SignatureEnsembleDiversityMeasure), DisagreementDiversityMeasure.UserName, RegressionDisagreementDiversityMeasure.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure +namespace Microsoft.ML.Trainers.Ensemble.DiversityMeasure { internal sealed class RegressionDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure, IRegressionDiversityMeasure { diff --git a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs index f2ffadf877..52c0a4752b 100644 --- a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/AllFeatureSelector.cs @@ -5,13 +5,13 @@ using System; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.FeatureSelector; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.FeatureSelector; [assembly: LoadableClass(typeof(AllFeatureSelector), null, typeof(SignatureEnsembleFeatureSelector), AllFeatureSelector.UserName, AllFeatureSelector.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.FeatureSelector +namespace Microsoft.ML.Trainers.Ensemble.FeatureSelector { internal sealed class AllFeatureSelector : IFeatureSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs index 93c4bd7603..5841d8c126 100644 --- a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs @@ -7,15 +7,15 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.FeatureSelector; using Microsoft.ML.EntryPoints; +using 
Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.FeatureSelector; using Microsoft.ML.Training; [assembly: LoadableClass(typeof(RandomFeatureSelector), typeof(RandomFeatureSelector.Arguments), typeof(SignatureEnsembleFeatureSelector), RandomFeatureSelector.UserName, RandomFeatureSelector.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.FeatureSelector +namespace Microsoft.ML.Trainers.Ensemble.FeatureSelector { internal class RandomFeatureSelector : IFeatureSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs index 8ac30f3818..96642ccc8f 100644 --- a/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs @@ -6,10 +6,10 @@ using System.Collections.Concurrent; using System.Collections.Generic; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.Selector.DiversityMeasure; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Trainers.Ensemble.DiversityMeasure; -namespace Microsoft.ML.Ensemble.Selector +namespace Microsoft.ML.Trainers.Ensemble { internal interface IDiversityMeasure { diff --git a/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs index e4eb986294..6ccc6da5d7 100644 --- a/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/IFeatureSelector.cs @@ -6,7 +6,7 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; -namespace Microsoft.ML.Ensemble.Selector +namespace Microsoft.ML.Trainers.Ensemble { internal interface IFeatureSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs index e5b35082ee..6f910f6a44 100644 --- a/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs @@ -7,7 +7,7 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; -namespace Microsoft.ML.Ensemble.Selector +namespace Microsoft.ML.Trainers.Ensemble { internal interface ISubModelSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs b/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs index 6ba7002508..8ffef7aba1 100644 --- a/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/ISubsetSelector.cs @@ -7,7 +7,7 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; -namespace Microsoft.ML.Ensemble.Selector +namespace Microsoft.ML.Trainers.Ensemble { internal interface ISubsetSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs index f88df3bfee..ce62ee9180 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelector.cs @@ -4,12 +4,12 @@ using System; using Microsoft.ML; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.SubModelSelector; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.SubModelSelector; [assembly: LoadableClass(typeof(AllSelector), null, typeof(SignatureEnsembleSubModelSelector), AllSelector.UserName, AllSelector.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { internal sealed class AllSelector : BaseSubModelSelector, IBinarySubModelSelector, IRegressionSubModelSelector { 
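
Two patterns repeat across the combiner and selector files above: namespaces move from Microsoft.ML.Ensemble.* to Microsoft.ML.Trainers.Ensemble.* (with the matching using directives in the LoadableClass files rewritten), and sealed combiners such as MultiAverage, MultiMedian, MultiVoting, MultiStacking, and WeightedAverage stop declaring ICanSaveModel themselves because their base classes now implement it. A minimal sketch of the second pattern follows, again with hypothetical names rather than the real ML.NET types.

```csharp
// Sketch of dropping the redundant interface on derived combiners once the base class
// implements ICanSaveModel; all names here are hypothetical stand-ins.
public sealed class ModelSaveContext { }

public interface ICanSaveModel
{
    void Save(ModelSaveContext ctx);
}

public abstract class DemoBaseCombiner : ICanSaveModel
{
    // One explicit implementation in the base class covers every derived combiner.
    void ICanSaveModel.Save(ModelSaveContext ctx)
    {
        // ... shared save logic ...
    }
}

// Before: "DemoAverage : DemoBaseCombiner, ICanSaveModel".
// After: the interface is inherited from the base class, so listing it again adds nothing.
public sealed class DemoAverage : DemoBaseCombiner
{
}
```

Types with no such base, like Voting in the hunk above, keep the interface in their declaration and only switch to the explicit implementation of Save.
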
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs index 2158b05733..6905579c3b 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/AllSelectorMultiClass.cs @@ -5,13 +5,13 @@ using System; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.SubModelSelector; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.SubModelSelector; [assembly: LoadableClass(typeof(AllSelectorMultiClass), null, typeof(SignatureEnsembleSubModelSelector), AllSelectorMultiClass.UserName, AllSelectorMultiClass.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { internal sealed class AllSelectorMultiClass : BaseSubModelSelector>, IMulticlassSubModelSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs index 55add475f5..a13ef47b35 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs @@ -8,7 +8,7 @@ using System.Reflection; using Microsoft.ML.CommandLine; -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { internal abstract class BaseBestPerformanceSelector : SubModelDataSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs index e1edba726e..8f6d783c76 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs @@ -6,11 +6,11 @@ using System.Collections.Concurrent; using System.Collections.Generic; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.Selector.DiversityMeasure; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Trainers.Ensemble.DiversityMeasure; using Microsoft.ML.Training; -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { internal abstract class BaseDiverseSelector : SubModelDataSelector where TDiversityMetric : class, IDiversityMeasure diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs index 8d1189d049..87a61192cd 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs @@ -8,7 +8,7 @@ using Microsoft.Data.DataView; using Microsoft.ML.Data; -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { internal abstract class BaseSubModelSelector : ISubModelSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs index e2a0147f01..869da7307e 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs +++ 
b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs @@ -7,20 +7,17 @@ using System.Collections.Generic; using Microsoft.ML; using Microsoft.ML.CommandLine; -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.DiversityMeasure; -using Microsoft.ML.Ensemble.Selector.SubModelSelector; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.DiversityMeasure; +using Microsoft.ML.Trainers.Ensemble.SubModelSelector; [assembly: LoadableClass(typeof(BestDiverseSelectorBinary), typeof(BestDiverseSelectorBinary.Arguments), typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorBinary.UserName, BestDiverseSelectorBinary.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { - using TScalarPredictor = IPredictorProducing; - internal sealed class BestDiverseSelectorBinary : BaseDiverseSelector, IBinarySubModelSelector { public const string UserName = "Best Diverse Selector"; diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs index a25fe4d8d1..9a39b4f5b3 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs @@ -8,20 +8,17 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.DiversityMeasure; -using Microsoft.ML.Ensemble.Selector.SubModelSelector; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.DiversityMeasure; +using Microsoft.ML.Trainers.Ensemble.SubModelSelector; [assembly: LoadableClass(typeof(BestDiverseSelectorMultiClass), typeof(BestDiverseSelectorMultiClass.Arguments), typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorMultiClass.UserName, BestDiverseSelectorMultiClass.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { - using TVectorPredictor = IPredictorProducing>; - internal sealed class BestDiverseSelectorMultiClass : BaseDiverseSelector, IDiversityMeasure>>, IMulticlassSubModelSelector { public const string UserName = "Best Diverse Selector"; diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs index ef310eab7e..2022e7d057 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs @@ -7,20 +7,17 @@ using System.Collections.Generic; using Microsoft.ML; using Microsoft.ML.CommandLine; -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.DiversityMeasure; -using Microsoft.ML.Ensemble.Selector.SubModelSelector; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.DiversityMeasure; +using 
Microsoft.ML.Trainers.Ensemble.SubModelSelector; [assembly: LoadableClass(typeof(BestDiverseSelectorRegression), typeof(BestDiverseSelectorRegression.Arguments), typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorRegression.UserName, BestDiverseSelectorRegression.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { - using TScalarPredictor = IPredictorProducing; - internal sealed class BestDiverseSelectorRegression : BaseDiverseSelector, IRegressionSubModelSelector { public const string UserName = "Best Diverse Selector"; diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs index be7b7e8f35..de6fe8874d 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceRegressionSelector.cs @@ -6,15 +6,15 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.SubModelSelector; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.SubModelSelector; [assembly: LoadableClass(typeof(BestPerformanceRegressionSelector), typeof(BestPerformanceRegressionSelector.Arguments), typeof(SignatureEnsembleSubModelSelector), BestPerformanceRegressionSelector.UserName, BestPerformanceRegressionSelector.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { internal sealed class BestPerformanceRegressionSelector : BaseBestPerformanceSelector, IRegressionSubModelSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs index 9b24798276..f11b30556f 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelector.cs @@ -6,15 +6,15 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.SubModelSelector; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.SubModelSelector; [assembly: LoadableClass(typeof(BestPerformanceSelector), typeof(BestPerformanceSelector.Arguments), typeof(SignatureEnsembleSubModelSelector), BestPerformanceSelector.UserName, BestPerformanceSelector.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { internal sealed class BestPerformanceSelector : BaseBestPerformanceSelector, IBinarySubModelSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs index 36f9635c87..34ce8719db 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestPerformanceSelectorMultiClass.cs @@ -6,15 +6,15 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; 
using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.SubModelSelector; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.SubModelSelector; [assembly: LoadableClass(typeof(BestPerformanceSelectorMultiClass), typeof(BestPerformanceSelectorMultiClass.Arguments), typeof(SignatureEnsembleSubModelSelector), BestPerformanceSelectorMultiClass.UserName, BestPerformanceSelectorMultiClass.LoadName)] -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { internal sealed class BestPerformanceSelectorMultiClass : BaseBestPerformanceSelector>, IMulticlassSubModelSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs index 0ba6497f25..ff8dd74354 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/SubModelDataSelector.cs @@ -6,7 +6,7 @@ using Microsoft.ML.CommandLine; using Microsoft.ML.Internal.Internallearn; -namespace Microsoft.ML.Ensemble.Selector.SubModelSelector +namespace Microsoft.ML.Trainers.Ensemble.SubModelSelector { internal abstract class SubModelDataSelector : BaseSubModelSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs index 43203e9ca9..1c1c15c3cd 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/AllInstanceSelector.cs @@ -5,16 +5,16 @@ using System; using System.Collections.Generic; using Microsoft.ML; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.SubsetSelector; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.SubsetSelector; [assembly: LoadableClass(typeof(AllInstanceSelector), typeof(AllInstanceSelector.Arguments), typeof(SignatureEnsembleDataSelector), AllInstanceSelector.UserName, AllInstanceSelector.LoadName)] [assembly: EntryPointModule(typeof(AllInstanceSelector))] -namespace Microsoft.ML.Ensemble.Selector.SubsetSelector +namespace Microsoft.ML.Trainers.Ensemble.SubsetSelector { internal sealed class AllInstanceSelector : BaseSubsetSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs index a17ea20d88..3305574d36 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs @@ -6,10 +6,9 @@ using System.Collections.Generic; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.EntryPoints; using Microsoft.ML.Transforms; -namespace Microsoft.ML.Ensemble.Selector.SubsetSelector +namespace Microsoft.ML.Trainers.Ensemble.SubsetSelector { internal abstract class BaseSubsetSelector : ISubsetSelector where TArgs : BaseSubsetSelector.ArgumentsBase diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs index c3a98e47c8..8006ba2195 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs +++ 
b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs @@ -6,9 +6,9 @@ using System.Collections.Generic; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.SubsetSelector; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.SubsetSelector; using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(BootstrapSelector), typeof(BootstrapSelector.Arguments), @@ -16,7 +16,7 @@ [assembly: EntryPointModule(typeof(BootstrapSelector))] -namespace Microsoft.ML.Ensemble.Selector.SubsetSelector +namespace Microsoft.ML.Trainers.Ensemble.SubsetSelector { internal sealed class BootstrapSelector : BaseSubsetSelector { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs index a142189808..a7f52e6b7b 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/RandomPartitionSelector.cs @@ -6,9 +6,9 @@ using System.Collections.Generic; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.SubsetSelector; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.Ensemble.SubsetSelector; using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(RandomPartitionSelector), typeof(RandomPartitionSelector.Arguments), @@ -16,7 +16,7 @@ [assembly: EntryPointModule(typeof(RandomPartitionSelector))] -namespace Microsoft.ML.Ensemble.Selector.SubsetSelector +namespace Microsoft.ML.Trainers.Ensemble.SubsetSelector { internal sealed class RandomPartitionSelector : BaseSubsetSelector { diff --git a/src/Microsoft.ML.Ensemble/Subset.cs b/src/Microsoft.ML.Ensemble/Subset.cs index 743be33df6..e1e579fd7b 100644 --- a/src/Microsoft.ML.Ensemble/Subset.cs +++ b/src/Microsoft.ML.Ensemble/Subset.cs @@ -5,7 +5,7 @@ using System.Collections; using Microsoft.ML.Data; -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { internal sealed class Subset { diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs index 83637c333f..1e734c875c 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs @@ -7,12 +7,8 @@ using System.Linq; using Microsoft.ML; using Microsoft.ML.CommandLine; -using Microsoft.ML.Ensemble; -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.OutputCombiners; -using Microsoft.ML.Ensemble.Selector; using Microsoft.ML.Internal.Internallearn; -using Microsoft.ML.Learners; +using Microsoft.ML.Trainers.Ensemble; using Microsoft.ML.Trainers.Online; using Microsoft.ML.Training; @@ -23,7 +19,7 @@ [assembly: LoadableClass(typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments), typeof(SignatureModelCombiner), "Binary Classification Ensemble Model Combiner", EnsembleTrainer.LoadNameValue, "pe", "ParallelEnsemble")] -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { using TDistPredictor = IDistPredictorProducing; using TScalarPredictor = IPredictorProducing; diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs index e63f2059e5..d425db8e4f 
100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs @@ -9,16 +9,15 @@ using Microsoft.Data.DataView; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; // These are for deserialization from a model repository. [assembly: LoadableClass(typeof(EnsembleDistributionModelParameters), null, typeof(SignatureLoadModel), EnsembleDistributionModelParameters.UserName, EnsembleDistributionModelParameters.LoaderSignature)] -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { using TDistPredictor = IDistPredictorProducing; diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs index 032ff39695..9eb56d5bfd 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs @@ -7,17 +7,16 @@ using Microsoft.Data.DataView; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(EnsembleModelParameters), null, typeof(SignatureLoadModel), EnsembleModelParameters.UserName, EnsembleModelParameters.LoaderSignature)] [assembly: EntryPointModule(typeof(EnsembleModelParameters))] -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { /// /// A class for artifacts of ensembled models. diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs index 5e79bbe41e..f8568cc2d7 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs @@ -6,12 +6,12 @@ using System.Collections.Generic; using System.IO; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { public abstract class EnsembleModelParametersBase : ModelParametersBase, IPredictorProducing, ICanSaveInTextFormat, ICanSaveSummary diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs index ef7ebc6d4d..b9aad6228f 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs @@ -8,15 +8,13 @@ using System.Threading.Tasks; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble.OutputCombiners; -using Microsoft.ML.Ensemble.Selector; -using Microsoft.ML.Ensemble.Selector.SubsetSelector; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Trainers.Ensemble.SubsetSelector; using Microsoft.ML.Training; -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { using Stopwatch = System.Diagnostics.Stopwatch; diff --git a/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs b/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs 
index 85f55d8111..48af1cfcfa 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs @@ -4,7 +4,7 @@ using System.Collections.Generic; -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { public delegate void SignatureModelCombiner(PredictionKind kind); diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs index 134e58d51d..163b256f1e 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs @@ -7,17 +7,14 @@ using Microsoft.Data.DataView; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.Model; +using Microsoft.ML.Trainers.Ensemble; [assembly: LoadableClass(typeof(EnsembleMultiClassModelParameters), null, typeof(SignatureLoadModel), EnsembleMultiClassModelParameters.UserName, EnsembleMultiClassModelParameters.LoaderSignature)] -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { - using TVectorPredictor = IPredictorProducing>; - public sealed class EnsembleMultiClassModelParameters : EnsembleModelParametersBase>, IValueMapper { internal const string UserName = "Ensemble Multiclass Executor"; diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index d194dfc947..4420b15ba9 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -8,12 +8,8 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble; -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.OutputCombiners; -using Microsoft.ML.Ensemble.Selector; using Microsoft.ML.Internal.Internallearn; -using Microsoft.ML.Learners; +using Microsoft.ML.Trainers.Ensemble; using Microsoft.ML.Training; [assembly: LoadableClass(MulticlassDataPartitionEnsembleTrainer.Summary, typeof(MulticlassDataPartitionEnsembleTrainer), @@ -25,7 +21,7 @@ [assembly: LoadableClass(typeof(MulticlassDataPartitionEnsembleTrainer), typeof(MulticlassDataPartitionEnsembleTrainer.Arguments), typeof(SignatureModelCombiner), "Multiclass Classification Ensemble Model Combiner", MulticlassDataPartitionEnsembleTrainer.LoadNameValue)] -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { using TVectorPredictor = IPredictorProducing>; /// diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index 569d4cd829..e3b7976d51 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -7,13 +7,8 @@ using System.Linq; using Microsoft.ML; using Microsoft.ML.CommandLine; -using Microsoft.ML.Data; -using Microsoft.ML.Ensemble; -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.OutputCombiners; -using Microsoft.ML.Ensemble.Selector; using Microsoft.ML.Internal.Internallearn; -using Microsoft.ML.Learners; +using Microsoft.ML.Trainers.Ensemble; using 
Microsoft.ML.Trainers.Online; using Microsoft.ML.Training; @@ -25,7 +20,7 @@ [assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments), typeof(SignatureModelCombiner), "Regression Ensemble Model Combiner", RegressionEnsembleTrainer.LoadNameValue)] -namespace Microsoft.ML.Ensemble +namespace Microsoft.ML.Trainers.Ensemble { using TScalarPredictor = IPredictorProducing; internal sealed class RegressionEnsembleTrainer : EnsembleTrainerBase ComputeTests(double[] scores) - { - List result = new List() - { - new TestResult("DominationLoss", ComputeDominationLoss(scores), 1, true, TestResult.ValueOperator.Sum), - }; - - return result; - } - } -#endif -} diff --git a/src/Microsoft.ML.FastTree/Application/LogLossApplication.cs b/src/Microsoft.ML.FastTree/Application/LogLossApplication.cs deleted file mode 100644 index 9dd896a401..0000000000 --- a/src/Microsoft.ML.FastTree/Application/LogLossApplication.cs +++ /dev/null @@ -1,169 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -namespace Microsoft.ML.Trainers.FastTree.Internal -{ -#if OLD_DATALOAD - public class LogLossCommandLineArgs : TrainingCommandLineArgs - { - public enum LogLossMode { Pairwise, Wholepage }; - [Argument(ArgumentType.LastOccurenceWins, HelpText = "Which style of log loss to use", ShortName = "llm")] - public LogLossMode loglossmode = LogLossMode.Pairwise; - [Argument(ArgumentType.LastOccurenceWins, HelpText = "log loss cefficient", ShortName = "llc")] - public double loglosscoef = 1.0; - } - - public class LogLossTrainingApplication : ApplicationBase - { - new LogLossCommandLineArgs cmd; - private const string RegistrationName = "LogLossApplication"; - - public LogLossTrainingApplication(IHostEnvironment env, string args, TrainingApplicationData data) - : base(env, RegistrationName, args, data) - { - base.cmd = this.cmd = new LogLossCommandLineArgs(); - } - - public override ObjectiveFunction ConstructObjFunc() - { - return new LogLossObjectiveFunction(TrainSet, cmd); - } - - public override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch) - { - OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch); - // optimizationAlgorithm.AdjustTreeOutputsOverride = new NoOutputOptimization(); - var lossCalculator = new LogLossTest(optimizationAlgorithm.TrainingScores, cmd.loglosscoef); - int lossIndex = (cmd.loglossmode == LogLossCommandLineArgs.LogLossMode.Pairwise) ? 
0 : 1; - - optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, lossIndex, cmd.numPostBracketSteps, cmd.minStepSize); - - return optimizationAlgorithm; - } - - protected override void PrepareLabels(IChannel ch) - { - } - - protected override Test ConstructTestForTrainingData() - { - return new LogLossTest(ConstructScoreTracker(TrainSet), cmd.loglosscoef); - } - - protected override void InitializeTests() - { - Tests.Add(new LogLossTest(ConstructScoreTracker(TrainSet), cmd.loglosscoef)); - if (ValidSet != null) - Tests.Add(new LogLossTest(ConstructScoreTracker(ValidSet), cmd.loglosscoef)); - - if (TestSets != null) - { - for (int t = 0; t < TestSets.Length; ++t) - { - Tests.Add(new LogLossTest(ConstructScoreTracker(TestSets[t]), cmd.loglosscoef)); - } - } - } - } - - public class LogLossObjectiveFunction : RankingObjectiveFunction - { - private LogLossCommandLineArgs.LogLossMode _mode; - private double _coef = 1.0; - - public LogLossObjectiveFunction(Dataset trainSet, LogLossCommandLineArgs cmd) - : base(trainSet, trainSet.Ratings, cmd) - { - _mode = cmd.loglossmode; - _coef = cmd.loglosscoef; - } - - protected override void GetGradientInOneQuery(int query, int threadIndex) - { - int begin = Dataset.Boundaries[query]; - int end = Dataset.Boundaries[query + 1]; - short[] labels = Dataset.Ratings; - - if (end - begin <= 1) - return; - Array.Clear(_gradient, begin, end - begin); - - for (int d1 = begin; d1 < end - 1; ++d1) - { - int stop = (_mode == LogLossCommandLineArgs.LogLossMode.Pairwise) ? d1 + 2 : end; - for (int d2 = d1 + 1; d2 < stop; ++d2) - { - short labelDiff = (short)(labels[d1] - labels[d2]); - if (labelDiff == 0) - continue; - double delta = (_coef * labelDiff) / (1.0 + Math.Exp(_coef * labelDiff * (_scores[d1] - _scores[d2]))); - - _gradient[d1] += delta; - _gradient[d2] -= delta; - } - } - } - } - - public class LogLossTest : Test - { - protected double _coef; - public LogLossTest(ScoreTracker scoreTracker, double coef) - : base(scoreTracker) - { - _coef = coef; - } - - public override IEnumerable ComputeTests(double[] scores) - { - Object _lock = new Object(); - double pairedLoss = 0.0; - double allPairLoss = 0.0; - short maxLabel = 0; - short minLabel = 10000; - - for (int query = 0; query < Dataset.Boundaries.Length - 1; query++) - { - int start = Dataset.Boundaries[query]; - int length = Dataset.Boundaries[query + 1] - start; - for (int i = start; i < start + length; i++) - for (int j = i + 1; j < start + length; j++) - allPairLoss += Math.Log((1.0 + Math.Exp(-_coef * (Dataset.Ratings[i] - Dataset.Ratings[j]) * (scores[i] - scores[j])))); - //allPairLoss += Math.Max(0.0, _coef - (Dataset.Ratings[i] - Dataset.Ratings[j]) * (scores[i] - scores[j])); - - for (int i = start; i < start + length - 1; i++) - { - pairedLoss += Math.Log((1.0 + Math.Exp(-_coef * (Dataset.Ratings[i] - Dataset.Ratings[i + 1]) * (scores[i] - scores[i + 1])))); - // pairedLoss += Math.Max(0.0, _coef - (Dataset.Ratings[i] - Dataset.Ratings[i + 1]) * (scores[i] - scores[i + 1])); - } - } - - for (int i = 0; i < Dataset.Ratings.Length; i++) - { - if (Dataset.Ratings[i] > maxLabel) - maxLabel = Dataset.Ratings[i]; - if (Dataset.Ratings[i] < minLabel) - minLabel = Dataset.Ratings[i]; - } - List result = new List() - { - new TestResult("paired loss", pairedLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average), - new TestResult("all pairs loss", allPairLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average), - new TestResult("coefficient", _coef, 1, true, 
TestResult.ValueOperator.Constant), - new TestResult("max Label", maxLabel, 1, false, TestResult.ValueOperator.Max), - new TestResult("min Label", minLabel, 1, true, TestResult.ValueOperator.Min), - }; - - return result; - } - } - - public class NoOutputOptimization : IStepSearch - { - public NoOutputOptimization() { } - - public void AdjustTreeOutputs(IChannel ch, RegressionTree tree, DocumentPartitioning partitioning, ScoreTracker trainingScores) { } - } -#endif -} diff --git a/src/Microsoft.ML.FastTree/Application/SizeAdjustedLogLossApplication.cs b/src/Microsoft.ML.FastTree/Application/SizeAdjustedLogLossApplication.cs deleted file mode 100644 index 0cb8c9496f..0000000000 --- a/src/Microsoft.ML.FastTree/Application/SizeAdjustedLogLossApplication.cs +++ /dev/null @@ -1,388 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -namespace Microsoft.ML.Trainers.FastTree.Internal -{ -#if OLD_DATALOAD - public class SizeAdjustedLogLossCommandLineArgs : TrainingCommandLineArgs - { - public enum LogLossMode - { - Pairwise, - Wholepage - }; - - public enum CostFunctionMode - { - SizeAdjustedWinratePredictor, - SizeAdjustedPageOrdering - }; - - [Argument(ArgumentType.LastOccurenceWins, HelpText = "Which style of log loss to use", ShortName = "llm")] - public LogLossMode loglossmode = LogLossMode.Pairwise; - - [Argument(ArgumentType.LastOccurenceWins, HelpText = "Which style of cost function to use", ShortName = "cfm")] - public CostFunctionMode costFunctionMode = CostFunctionMode.SizeAdjustedPageOrdering; - - // REVIEW: If we ever want to expose this application in TLC, the natural thing would be for these to - // be loaded from a column from the input data view. However I'll keep them around for now, so that when we - // do migrate there's some clear idea of how it is to be done. 
- [Argument(ArgumentType.AtMostOnce, HelpText = "tab seperated file which contains the max and min score from the model", ShortName = "srange")] - public string scoreRangeFileName = null; - - [Argument(ArgumentType.MultipleUnique, HelpText = "TSV filename with float size values associated with train bin file", ShortName = "trsl")] - public string[] trainSizeLabelFilenames = null; - - [Argument(ArgumentType.LastOccurenceWins, HelpText = "TSV filename with float size values associated with validation bin file", ShortName = "vasl")] - public string validSizeLabelFilename = null; - - [Argument(ArgumentType.MultipleUnique, HelpText = "TSV filename with float size values associated with test bin file", ShortName = "tesl")] - public string[] testSizeLabelFilenames = null; - - [Argument(ArgumentType.LastOccurenceWins, HelpText = "log loss cefficient", ShortName = "llc")] - public double loglosscoef = 1.0; - - } - - public class SizeAdjustedLogLossUtil - { - public const string sizeLabelName = "size"; - - public static float[] GetSizeLabels(Dataset set) - { - if (set == null) - { - return null; - } - float[] labels = set.Skeleton.GetData(sizeLabelName); - if (labels == null) - { - labels = set.Ratings.Select(x => (float)x).ToArray(); - } - return labels; - } - } - - public class SizeAdjustedLogLossTrainingApplication : ApplicationBase - { - new SizeAdjustedLogLossCommandLineArgs cmd; - float[] trainSetSizeLabels; - private const string RegistrationName = "SizeAdjustedLogLossApplication"; - - public SizeAdjustedLogLossTrainingApplication(IHostEnvironment env, string args, TrainingApplicationData data) - : base(env, RegistrationName, args, data) - { - base.cmd = this.cmd = new SizeAdjustedLogLossCommandLineArgs(); - } - - public override ObjectiveFunction ConstructObjFunc() - { - return new SizeAdjustedLogLossObjectiveFunction(TrainSet, cmd); - } - - public override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch) - { - OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch); - // optimizationAlgorithm.AdjustTreeOutputsOverride = new NoOutputOptimization(); // For testing purposes - this will not use line search and thus the scores won't be scaled. - var lossCalculator = new SizeAdjustedLogLossTest(optimizationAlgorithm.TrainingScores, cmd.scoreRangeFileName, trainSetSizeLabels, cmd.loglosscoef); - - // The index of the label signifies which index from TestResult would be used as a loss. For every query, we compute both wholepage and pairwise loss with the two cost function modes, this index lets us pick the appropriate one. 
- int lossIndex = 0; - if (cmd.loglossmode == SizeAdjustedLogLossCommandLineArgs.LogLossMode.Wholepage && cmd.costFunctionMode == SizeAdjustedLogLossCommandLineArgs.CostFunctionMode.SizeAdjustedPageOrdering) - { - lossIndex = 1; - } - else if (cmd.loglossmode == SizeAdjustedLogLossCommandLineArgs.LogLossMode.Pairwise && cmd.costFunctionMode == SizeAdjustedLogLossCommandLineArgs.CostFunctionMode.SizeAdjustedWinratePredictor) - { - lossIndex = 2; - } - else if (cmd.loglossmode == SizeAdjustedLogLossCommandLineArgs.LogLossMode.Wholepage && cmd.costFunctionMode == SizeAdjustedLogLossCommandLineArgs.CostFunctionMode.SizeAdjustedWinratePredictor) - { - lossIndex = 3; - } - - optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, lossIndex, cmd.numPostBracketSteps, cmd.minStepSize); - - return optimizationAlgorithm; - } - - private static IEnumerable LoadSizeLabels(string filename) - { - using (StreamReader reader = new StreamReader(new FileStream(filename, FileMode.Open, FileAccess.Read, FileShare.Read))) - { - string line = reader.ReadLine(); - Contracts.Check(line != null && line.Trim() == "m:Size", "Regression label file should contain only one column m:Size"); - while ((line = reader.ReadLine()) != null) - { - float val = float.Parse(line.Trim(), CultureInfo.InvariantCulture); - yield return val; - } - } - } - - protected override void PrepareLabels(IChannel ch) - { - trainSetSizeLabels = SizeAdjustedLogLossUtil.GetSizeLabels(TrainSet); - } - - protected override Test ConstructTestForTrainingData() - { - return new SizeAdjustedLogLossTest(ConstructScoreTracker(TrainSet), cmd.scoreRangeFileName, SizeAdjustedLogLossUtil.GetSizeLabels(TrainSet), cmd.loglosscoef); - } - - protected override void ProcessBinFile(int trainValidTest, int index, DatasetBinFile bin) - { - string labelPath = null; - switch (trainValidTest) - { - case 0: - if (cmd.trainSizeLabelFilenames != null && index < cmd.trainSetFilenames.Length) - { - labelPath = cmd.trainSizeLabelFilenames[index]; - } - break; - case 1: - if (cmd.validSizeLabelFilename != null) - { - labelPath = cmd.validSizeLabelFilename; - } - break; - case 2: - if (cmd.testSizeLabelFilenames != null && index < cmd.testSizeLabelFilenames.Length) - { - labelPath = cmd.testSizeLabelFilenames[index]; - } - break; - } - // If we have no labels, return. 
- if (labelPath == null) - { - return; - } - float[] labels = LoadSizeLabels(labelPath).ToArray(); - bin.DatasetSkeleton.SetData(SizeAdjustedLogLossUtil.sizeLabelName, labels, false); - } - - protected override void InitializeTests() - { - Tests.Add(new SizeAdjustedLogLossTest(ConstructScoreTracker(TrainSet), cmd.scoreRangeFileName, SizeAdjustedLogLossUtil.GetSizeLabels(TrainSet), cmd.loglosscoef)); - if (ValidSet != null) - { - Tests.Add(new SizeAdjustedLogLossTest(ConstructScoreTracker(ValidSet), cmd.scoreRangeFileName, SizeAdjustedLogLossUtil.GetSizeLabels(ValidSet), cmd.loglosscoef)); - } - - if (TestSets != null && TestSets.Length > 0) - { - for (int t = 0; t < TestSets.Length; ++t) - { - Tests.Add(new SizeAdjustedLogLossTest(ConstructScoreTracker(TestSets[t]), cmd.scoreRangeFileName, SizeAdjustedLogLossUtil.GetSizeLabels(TestSets[t]), cmd.loglosscoef)); - } - } - } - - public override void PrintIterationMessage(IChannel ch, IProgressChannel pch) - { - base.PrintIterationMessage(ch, pch); - } - } - - public class SizeAdjustedLogLossObjectiveFunction : RankingObjectiveFunction - { - private SizeAdjustedLogLossCommandLineArgs.LogLossMode _mode; - private SizeAdjustedLogLossCommandLineArgs.CostFunctionMode _algo; - private double _llc; - - public SizeAdjustedLogLossObjectiveFunction(Dataset trainSet, SizeAdjustedLogLossCommandLineArgs cmd) - : base(trainSet, trainSet.Ratings, cmd) - { - _mode = cmd.loglossmode; - _algo = cmd.costFunctionMode; - _llc = cmd.loglosscoef; - } - - protected override void GetGradientInOneQuery(int query, int threadIndex) - { - int begin = Dataset.Boundaries[query]; - int end = Dataset.Boundaries[query + 1]; - short[] labels = Dataset.Ratings; - float[] sizes = SizeAdjustedLogLossUtil.GetSizeLabels(Dataset); - - Contracts.Check(Dataset.NumDocs == sizes.Length, "Mismatch between dataset and labels"); - - if (end - begin <= 1) - { - return; - } - Array.Clear(_gradient, begin, end - begin); - - for (int d1 = begin; d1 < end - 1; ++d1) - { - for (int d2 = d1 + 1; d2 < end; ++d2) - { - float size = sizes[d1]; - - //Compute Lij - float sizeAdjustedLoss = 0.0F; - for (int d3 = d2; d3 < end; ++d3) - { - size -= sizes[d3]; - if (size >= 0.0F && labels[d3] > 0) - { - sizeAdjustedLoss = 1.0F; - } - else if (size < 0.0F && labels[d3] > 0) - { - sizeAdjustedLoss = (1.0F + (size / sizes[d3])); - } - - if (size <= 0.0F || sizeAdjustedLoss > 0.0F) - { - // Exit condition- we have reached size or size adjusted loss is already populated. 
- break; - } - } - - double scoreDiff = _scores[d1] - _scores[d2]; - float labelDiff = ((float)labels[d1] - sizeAdjustedLoss); - double delta = 0.0; - if (_algo == SizeAdjustedLogLossCommandLineArgs.CostFunctionMode.SizeAdjustedPageOrdering) - { - delta = (_llc * labelDiff) / (1.0 + Math.Exp(_llc * labelDiff * scoreDiff)); - } - else - { - delta = (double)labels[d1] - ((double)(labels[d1] + sizeAdjustedLoss) / (1.0 + Math.Exp(-scoreDiff))); - } - - _gradient[d1] += delta; - _gradient[d2] -= delta; - - if (_mode == SizeAdjustedLogLossCommandLineArgs.LogLossMode.Pairwise) - { - break; - } - } - } - } - } - - public class SizeAdjustedLogLossTest : Test - { - protected string _scoreRangeFileName = null; - static double maxScore = double.MinValue; - static double minScore = double.MaxValue; - private float[] _sizeLabels; - private double _llc; - - public SizeAdjustedLogLossTest(ScoreTracker scoreTracker, string scoreRangeFileName, float[] sizeLabels, double loglossCoeff) - : base(scoreTracker) - { - Contracts.Check(scoreTracker.Dataset.NumDocs == sizeLabels.Length, "Mismatch between dataset and labels"); - _sizeLabels = sizeLabels; - _scoreRangeFileName = scoreRangeFileName; - _llc = loglossCoeff; - } - - public override IEnumerable ComputeTests(double[] scores) - { - Object _lock = new Object(); - double pairedPageOrderingLoss = 0.0; - double allPairPageOrderingLoss = 0.0; - double pairedSAWRPredictLoss = 0.0; - double allPairSAWRPredictLoss = 0.0; - - short maxLabel = 0; - short minLabel = 10000; - for (int query = 0; query < Dataset.Boundaries.Length - 1; ++query) - { - int begin = Dataset.Boundaries[query]; - int end = Dataset.Boundaries[query + 1]; - - if (end - begin <= 1) - { - continue; - } - - for (int d1 = begin; d1 < end - 1; ++d1) - { - bool firstTime = false; - for (int d2 = d1 + 1; d2 < end; ++d2) - { - float size = _sizeLabels[d1]; - - //Compute Lij - float sizeAdjustedLoss = 0.0F; - for (int d3 = d2; d3 < end; ++d3) - { - size -= _sizeLabels[d3]; - if (size >= 0.0F && Dataset.Ratings[d3] > 0) - { - sizeAdjustedLoss = 1.0F; - } - else if (size < 0.0F && Dataset.Ratings[d3] > 0) - { - sizeAdjustedLoss = (1.0F + (size / _sizeLabels[d3])); - } - - if (size <= 0.0F || sizeAdjustedLoss > 0.0F) - { - // Exit condition- we have reached size or size adjusted loss is already populated. 
- break; - } - } - //Compute page ordering loss - double scoreDiff = scores[d1] - scores[d2]; - float labelDiff = ((float)Dataset.Ratings[d1] - sizeAdjustedLoss); - double pageOrderingLoss = Math.Log(1.0 + Math.Exp(-_llc * labelDiff * scoreDiff)); - - // Compute SAWR predict loss - double sawrPredictLoss = (double)((Dataset.Ratings[d1] + sizeAdjustedLoss) * Math.Log(1.0 + Math.Exp(scoreDiff))) - ((double)Dataset.Ratings[d1] * scoreDiff); - if (!firstTime) - { - pairedPageOrderingLoss += pageOrderingLoss; - pairedSAWRPredictLoss += sawrPredictLoss; - firstTime = true; - } - allPairPageOrderingLoss += pageOrderingLoss; - allPairSAWRPredictLoss += sawrPredictLoss; - } - } - } - - for (int i = 0; i < Dataset.Ratings.Length; i++) - { - if (Dataset.Ratings[i] > maxLabel) - maxLabel = Dataset.Ratings[i]; - if (Dataset.Ratings[i] < minLabel) - minLabel = Dataset.Ratings[i]; - if (scores[i] > maxScore) - maxScore = scores[i]; - if (scores[i] < minScore) - minScore = scores[i]; - } - - if (_scoreRangeFileName != null) - { - using (StreamWriter sw = File.CreateText(_scoreRangeFileName)) - { - sw.WriteLine(string.Format("{0}\t{1}", minScore, maxScore)); - } - } - - List result = new List() - { - // The index of the label signifies which index from TestResult would be used as a loss. For every query, we compute both wholepage and pairwise loss with the two cost function modes, this index lets us pick the appropriate one. - new TestResult("page ordering paired loss", pairedPageOrderingLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average), - new TestResult("page ordering all pairs loss", allPairPageOrderingLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average), - new TestResult("SAWR predict paired loss", pairedSAWRPredictLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average), - new TestResult("SAWR predict all pairs loss", allPairSAWRPredictLoss, Dataset.NumDocs, true, TestResult.ValueOperator.Average), - new TestResult("max Label", maxLabel, 1, false, TestResult.ValueOperator.Max), - new TestResult("min Label", minLabel, 1, true, TestResult.ValueOperator.Min), - }; - - return result; - } - } -#endif -} diff --git a/src/Microsoft.ML.FastTree/Application/WinLossSurplusApplication.cs b/src/Microsoft.ML.FastTree/Application/WinLossSurplusApplication.cs deleted file mode 100644 index 1708033161..0000000000 --- a/src/Microsoft.ML.FastTree/Application/WinLossSurplusApplication.cs +++ /dev/null @@ -1,181 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. 
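For reference, the core arithmetic of the size-adjusted log loss objective removed above is a standard pairwise logistic gradient: for a document pair (d1, d2) it computes a delta from the score difference and the size-adjusted label difference, then adds it to d1's gradient and subtracts it from d2's. A minimal sketch of that update, simplified to scalars and with names of my own choosing (not part of the removed file):

```csharp
using System;

// Hedged sketch of the per-pair update in the deleted SizeAdjustedLogLossObjectiveFunction.
// labelDiff = label(d1) - sizeAdjustedLoss(d1, d2); llc is the "log loss coefficient" option.
internal static class SizeAdjustedLogLossSketch
{
    // Negative derivative of log(1 + exp(-llc * labelDiff * scoreDiff)) with respect to
    // scoreDiff, i.e. the pseudo-response boosting fits for d1 (d2 gets the opposite sign).
    public static double PairwiseDelta(double scoreD1, double scoreD2, double labelDiff, double llc)
    {
        double scoreDiff = scoreD1 - scoreD2;
        return (llc * labelDiff) / (1.0 + Math.Exp(llc * labelDiff * scoreDiff));
    }
}
```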
- -namespace Microsoft.ML.Trainers.FastTree.Internal -{ -#if OLD_DATALOAD - public class WinLossSurplusCommandLineArgs : TrainingCommandLineArgs - { - [Argument(ArgumentType.AtMostOnce, HelpText = "Scaling Factor for win loss surplus", ShortName = "wls")] - public double winlossScaleFactor = 1.0; - } - - public class WinLossSurplusTrainingApplication : RankingApplication - { - new WinLossSurplusCommandLineArgs cmd; - private const string RegistrationName = "WinLossSurplusApplication"; - - public WinLossSurplusTrainingApplication(IHostEnvironment env, string args, TrainingApplicationData data) - : base(env, RegistrationName, args, data) - { - base.cmd = this.cmd = new WinLossSurplusCommandLineArgs(); - } - - public override ObjectiveFunction ConstructObjFunc() - { - return new WinLossSurplusObjectiveFunction(TrainSet, TrainSet.Ratings, cmd); - } - - public override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch) - { - OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch); - var lossCalculator = new WinLossSurplusTest(optimizationAlgorithm.TrainingScores, TrainSet.Ratings, cmd.sortingAlgorithm, cmd.winlossScaleFactor); - optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 0, cmd.numPostBracketSteps, cmd.minStepSize); - return optimizationAlgorithm; - } - - protected override Test ConstructTestForTrainingData() - { - return new WinLossSurplusTest(ConstructScoreTracker(TrainSet), TrainSet.Ratings, cmd.sortingAlgorithm, cmd.winlossScaleFactor); - } - - public override void PrintIterationMessage(IChannel ch, IProgressChannel pch) - { - // REVIEW: Shift this to use progress channels. -#if OLD_TRACING - if (PruningTest != null) - { - if (PruningTest is TestWindowWithTolerance) - { - if (PruningTest.BestIteration != -1) - ch.Info("Iteration {0} \t(Best tolerated validation moving average WinLossSurplus {1}:{2:00.00}~{3:00.00})", - Ensemble.NumTrees, - PruningTest.BestIteration, - (PruningTest as TestWindowWithTolerance).BestAverageValue, - (PruningTest as TestWindowWithTolerance).CurrentAverageValue); - else - ch.Info("Iteration {0}", Ensemble.NumTrees); - } - else - { - ch.Info("Iteration {0} \t(best validation WinLoss {1}:{2:00.00}>{3:00.00})", - Ensemble.NumTrees, - PruningTest.BestIteration, - PruningTest.BestResult.FinalValue, - PruningTest.ComputeTests().First().FinalValue); - } - } - else - base.PrintIterationMessage(ch, pch); -#else - base.PrintIterationMessage(ch, pch); -#endif - } - - protected override Test CreateStandardTest(Dataset dataset) - { - return new WinLossSurplusTest( - ConstructScoreTracker(dataset), - dataset.Ratings, - cmd.sortingAlgorithm, - cmd.winlossScaleFactor); - } - - protected override Test CreateSpecialTrainSetTest() - { - return CreateStandardTest(TrainSet); - } - - protected override Test CreateSpecialValidSetTest() - { - return CreateStandardTest(ValidSet); - } - - protected override string GetTestGraphHeader() - { - return "Eval:\tFileName\tMaxSurplus\tSurplus@100\tSurplus@200\tSurplus@300\tSurplus@400\tSurplus@500\tSurplus@1000\tMaxSurplusPos\tPercentTop\n"; - } - } - - public class WinLossSurplusObjectiveFunction : LambdaRankObjectiveFunction - { - public WinLossSurplusObjectiveFunction(Dataset trainSet, short[] labels, WinLossSurplusCommandLineArgs cmd) - : base(trainSet, labels, cmd) - { - } - - protected override void GetGradientInOneQuery(int query, int threadIndex) - { - int begin = Dataset.Boundaries[query]; - int numDocuments = Dataset.Boundaries[query + 1] - 
Dataset.Boundaries[query]; - - Array.Clear(_gradient, begin, numDocuments); - Array.Clear(_weights, begin, numDocuments); - - double inverseMaxDCG = _inverseMaxDCGT[query]; - - int[] permutation = _permutationBuffers[threadIndex]; - - short[] labels = Labels; - double[] scoresToUse = _scores; - - // Keep track of top 3 labels for later use - //GetTopQueryLabels(query, permutation, false); - unsafe - { - fixed (int* pPermutation = permutation) - fixed (short* pLabels = labels) - fixed (double* pScores = scoresToUse, pLambdas = _gradient, pWeights = _weights, pDiscount = _discount) - fixed (double* pGain = _gain, pGainLabels = _gainLabels, pSigmoidTable = _sigmoidTable) - fixed (int* pOneTwoThree = _oneTwoThree) - { - // calculates the permutation that orders "scores" in descending order, without modifying "scores" - Array.Copy(_oneTwoThree, permutation, numDocuments); -#if USE_FASTTREENATIVE - double lambdaSum = 0; - - C_Sort(pPermutation, &pScores[begin], &pLabels[begin], numDocuments); - - // Keep track of top 3 labels for later use - GetTopQueryLabels(query, permutation, true); - - int numActualResults = numDocuments; - - C_GetSurplusDerivatives(numDocuments, begin, pPermutation, pLabels, - pScores, pLambdas, pWeights, pDiscount, - pGainLabels, - pSigmoidTable, _minScore, _maxScore, _sigmoidTable.Length, _scoreToSigmoidTableFactor, - _costFunctionParam, _distanceWeight2, &lambdaSum, double.MinValue); - - if (_normalizeQueryLambdas) - { - if (lambdaSum > 0) - { - double normFactor = (10 * Math.Log(1 + lambdaSum)) / lambdaSum; - - for (int i = begin; i < begin + numDocuments; ++i) - { - pLambdas[i] = pLambdas[i] * normFactor; - pWeights[i] = pWeights[i] * normFactor; - } - } - } -#else - throw new Exception("Shifted NDCG / ContinuousWeightedRanknet / WinLossSurplus / distanceWeight2 / normalized lambdas are only supported by unmanaged code"); -#endif - } - } - } - - [DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)] - private unsafe static extern void C_GetSurplusDerivatives(int numDocuments, int begin, int* pPermutation, short* pLabels, - double* pScores, double* pLambdas, double* pWeights, double* pDiscount, - double* pGainLabels, double* lambdaTable, double minScore, double maxScore, - int lambdaTableLength, double scoreToLambdaTableFactor, - char costFunctionParam, [MarshalAs(UnmanagedType.U1)] bool distanceWeight2, - double* pLambdaSum, double doubleMinValue); - - } -#endif -} diff --git a/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs b/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs index 7725cb7aeb..7df57005b7 100644 --- a/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs +++ b/src/Microsoft.ML.FastTree/BinFile/BinFinder.cs @@ -7,7 +7,7 @@ using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// A class that bins vectors of doubles into a specified number of equal mass bins. 
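The `_normalizeQueryLambdas` branch in the removed WinLossSurplus objective rescales each query's lambdas and weights so that their per-query sum grows only logarithmically with the raw lambda sum. A minimal managed sketch of that rescaling, with a hypothetical class and method name (the original operates on pinned pointers inside unsafe code):

```csharp
using System;

// Hedged sketch of the query-level lambda normalization shown in the deleted code above.
internal static class LambdaNormalizationSketch
{
    public static void NormalizeQuery(double[] lambdas, double[] weights, int begin, int count, double lambdaSum)
    {
        if (lambdaSum <= 0)
            return;

        // After scaling, the query's lambda sum becomes 10 * ln(1 + lambdaSum).
        double normFactor = (10 * Math.Log(1 + lambdaSum)) / lambdaSum;
        for (int i = begin; i < begin + count; ++i)
        {
            lambdas[i] *= normFactor;
            weights[i] *= normFactor;
        }
    }
}
```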
diff --git a/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs b/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs index 9b28df134c..257b9e83b5 100644 --- a/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs +++ b/src/Microsoft.ML.FastTree/BinFile/IniFileParserInterface.cs @@ -7,7 +7,7 @@ using System.Runtime.InteropServices; using System.Text; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { internal sealed class IniFileParserInterface { diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index e84ec0117b..072ee2e20f 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -6,7 +6,6 @@ using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Internal.Internallearn; -using Microsoft.ML.Trainers.FastTree.Internal; using Float = System.Single; namespace Microsoft.ML.Trainers.FastTree diff --git a/src/Microsoft.ML.FastTree/Dataset/Dataset.cs b/src/Microsoft.ML.FastTree/Dataset/Dataset.cs index 330ef028ca..1e5de9b229 100644 --- a/src/Microsoft.ML.FastTree/Dataset/Dataset.cs +++ b/src/Microsoft.ML.FastTree/Dataset/Dataset.cs @@ -8,7 +8,7 @@ using System.Threading.Tasks; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// A dataset of features. diff --git a/src/Microsoft.ML.FastTree/Dataset/DatasetUtils.cs b/src/Microsoft.ML.FastTree/Dataset/DatasetUtils.cs index 542b097381..617e6559e0 100644 --- a/src/Microsoft.ML.FastTree/Dataset/DatasetUtils.cs +++ b/src/Microsoft.ML.FastTree/Dataset/DatasetUtils.cs @@ -5,7 +5,7 @@ using System.Collections.Generic; using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// Loads training/validation/test sets from file diff --git a/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs index c6f10843d8..68c33a0d2d 100644 --- a/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs +++ b/src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs @@ -8,7 +8,7 @@ using System.Runtime.InteropServices; using System.Security; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { #if USE_SINGLE_PRECISION using FloatType = System.Single; diff --git a/src/Microsoft.ML.FastTree/Dataset/Feature.cs b/src/Microsoft.ML.FastTree/Dataset/Feature.cs index c74ef5baf4..3013b536cf 100644 --- a/src/Microsoft.ML.FastTree/Dataset/Feature.cs +++ b/src/Microsoft.ML.FastTree/Dataset/Feature.cs @@ -4,7 +4,7 @@ using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// diff --git a/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs index 73c6b646bf..2d370499a4 100644 --- a/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs +++ b/src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs @@ -15,7 +15,7 @@ using Microsoft.ML.Internal.CpuMath; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// Holds statistics per bin value for a feature. 
These are yielded by diff --git a/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs b/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs index 2294525a9f..4968ad6a58 100644 --- a/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs +++ b/src/Microsoft.ML.FastTree/Dataset/FeatureHistogram.cs @@ -5,7 +5,7 @@ using System; using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { #if USE_SINGLE_PRECISION using FloatType = System.Single; diff --git a/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs b/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs index 9f7623d0ad..ff6e990dca 100644 --- a/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs +++ b/src/Microsoft.ML.FastTree/Dataset/FileObjectStore.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { #if !NO_STORE /// diff --git a/src/Microsoft.ML.FastTree/Dataset/IntArray.cs b/src/Microsoft.ML.FastTree/Dataset/IntArray.cs index e2c8de872b..05ed42c957 100644 --- a/src/Microsoft.ML.FastTree/Dataset/IntArray.cs +++ b/src/Microsoft.ML.FastTree/Dataset/IntArray.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { #if USE_SINGLE_PRECISION using FloatType = System.Single; diff --git a/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs index fcfa6c938a..43eacf22f0 100644 --- a/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs +++ b/src/Microsoft.ML.FastTree/Dataset/NHotFeatureFlock.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
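Most of the hunks in this file-by-file run make a single mechanical change: FastTree's types move out of `Microsoft.ML.Trainers.FastTree.Internal` and into `Microsoft.ML.Trainers.FastTree`, with type names left untouched. For a caller, only the using directive changes. A hypothetical consumer file as a sketch (the class and method here are mine, not part of the patch):

```csharp
// Before this patch the same code would have needed:
//   using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Trainers.FastTree;

internal static class FastTreeNamespaceMigration
{
    // ScoreTracker and LassoFit are among the public types relocated in the hunks that follow.
    public static string MovedTypes() => $"{nameof(ScoreTracker)}, {nameof(LassoFit)}";
}
```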
-namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// This is a feature flock that misuses a property of diff --git a/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs index da8592fc14..605d912328 100644 --- a/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs +++ b/src/Microsoft.ML.FastTree/Dataset/OneHotFeatureFlock.cs @@ -4,7 +4,7 @@ using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// A feature flock for a set of features where per example at most one of the features has a diff --git a/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs index 3b9502a8db..ac176c20ec 100644 --- a/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs +++ b/src/Microsoft.ML.FastTree/Dataset/RepeatIntArray.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { #if USE_SINGLE_PRECISION using FloatType = Single; diff --git a/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs index 156f5e227c..bfeb19f2f2 100644 --- a/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs +++ b/src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs @@ -7,7 +7,7 @@ using System.Runtime.InteropServices; using System.Security; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { #if USE_SINGLE_PRECISION using FloatType = System.Single; diff --git a/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs b/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs index c1cf5fe4d4..30f609faa4 100644 --- a/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs +++ b/src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs @@ -5,7 +5,7 @@ using System.Linq; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// The singleton feature flock is the simplest possible sort of flock, that is, a flock diff --git a/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs b/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs index b3dc61b07a..b9e5cc0f28 100644 --- a/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs +++ b/src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs @@ -7,7 +7,7 @@ using System.Runtime.InteropServices; using System.Security; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { #if USE_SINGLE_PRECISION using FloatType = System.Single; diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index fc94e6099c..c2a90c4f3e 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -23,7 +23,6 @@ using Microsoft.ML.Model; using Microsoft.ML.Model.Onnx; using Microsoft.ML.Model.Pfa; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Conversions; diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index ff47449238..4ae7ab7555 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -15,7 +15,6 @@ using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Model; 
using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; [assembly: LoadableClass(FastTreeBinaryClassificationTrainer.Summary, typeof(FastTreeBinaryClassificationTrainer), typeof(FastTreeBinaryClassificationTrainer.Options), diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 485a7639ca..059cd8ab93 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -17,7 +17,6 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; // REVIEW: Do we really need all these names? diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index e26ce00e82..1399b8bebe 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Linq; using System.Text; using Microsoft.Data.DataView; @@ -13,7 +12,6 @@ using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; [assembly: LoadableClass(FastTreeRegressionTrainer.Summary, typeof(FastTreeRegressionTrainer), typeof(FastTreeRegressionTrainer.Options), diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index b8ea0ce030..547d48f7ca 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -14,7 +14,6 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; [assembly: LoadableClass(FastTreeTweedieTrainer.Summary, typeof(FastTreeTweedieTrainer), typeof(FastTreeTweedieTrainer.Options), diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index 635862e178..79f8265dbb 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -14,7 +14,6 @@ using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; [assembly: LoadableClass(BinaryClassificationGamTrainer.Summary, diff --git a/src/Microsoft.ML.FastTree/GamModelParameters.cs b/src/Microsoft.ML.FastTree/GamModelParameters.cs index 38c931f03f..a726edd1bc 100644 --- a/src/Microsoft.ML.FastTree/GamModelParameters.cs +++ b/src/Microsoft.ML.FastTree/GamModelParameters.cs @@ -18,7 +18,6 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; [assembly: LoadableClass(typeof(GamModelParametersBase.VisualizationCommand), typeof(GamModelParametersBase.VisualizationCommand.Arguments), typeof(SignatureCommand), diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 100cc6b61a..4c3fbe021b 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -2,7 
+2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using Microsoft.Data.DataView; using Microsoft.ML; using Microsoft.ML.CommandLine; @@ -11,7 +10,6 @@ using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; [assembly: LoadableClass(RegressionGamTrainer.Summary, diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 87bcc370ba..11953bbe36 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -14,15 +14,12 @@ using Microsoft.ML.Internal.CpuMath; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; -using Timer = Microsoft.ML.Trainers.FastTree.Internal.Timer; [assembly: LoadableClass(typeof(void), typeof(Gam), null, typeof(SignatureEntryPointModule), "GAM")] namespace Microsoft.ML.Trainers.FastTree { - using AutoResetEvent = System.Threading.AutoResetEvent; using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo; /// diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 943b10f621..aea3a991bf 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using Microsoft.ML.Core.Data; -using Microsoft.ML.Trainers.FastTree.Internal; namespace Microsoft.ML.Trainers.FastTree { diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index d6a31ec6a0..13d2261ea9 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -15,7 +15,6 @@ using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; [assembly: LoadableClass(FastForestClassification.Summary, typeof(FastForestClassification), typeof(FastForestClassification.Options), diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index de83cabeca..daed4b26c1 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -13,7 +13,6 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; [assembly: LoadableClass(FastForestRegression.Summary, typeof(FastForestRegression), typeof(FastForestRegression.Options), diff --git a/src/Microsoft.ML.FastTree/RegressionTree.cs b/src/Microsoft.ML.FastTree/RegressionTree.cs index 867d6015ef..96a0784602 100644 --- a/src/Microsoft.ML.FastTree/RegressionTree.cs +++ b/src/Microsoft.ML.FastTree/RegressionTree.cs @@ -4,7 +4,7 @@ using System.Collections.Generic; using System.Collections.Immutable; -using Microsoft.ML.Trainers.FastTree.Internal; +using Microsoft.ML.Trainers.FastTree; namespace Microsoft.ML.FastTree { diff --git a/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs 
b/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs index a25b370586..7816dc993b 100644 --- a/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs +++ b/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs @@ -17,14 +17,13 @@ using Microsoft.ML.CommandLine; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; [assembly: LoadableClass(typeof(SumupPerformanceCommand), typeof(SumupPerformanceCommand.Arguments), typeof(SignatureCommand), "", "FastTreeSumupPerformance", "ftsumup")] namespace Microsoft.ML.Trainers.FastTree { -using Stopwatch = System.Diagnostics.Stopwatch; + using Stopwatch = System.Diagnostics.Stopwatch; /// /// This is an internal utility command to measure the performance of the IntArray sumup operation. diff --git a/src/Microsoft.ML.FastTree/Training/Applications/GradientWrappers.cs b/src/Microsoft.ML.FastTree/Training/Applications/GradientWrappers.cs index a1e65a455b..b360b9ea4a 100644 --- a/src/Microsoft.ML.FastTree/Training/Applications/GradientWrappers.cs +++ b/src/Microsoft.ML.FastTree/Training/Applications/GradientWrappers.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// Trivial weights wrapper. Creates proxy class holding the targets. diff --git a/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs b/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs index 2808e24cb7..644add5903 100644 --- a/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs +++ b/src/Microsoft.ML.FastTree/Training/Applications/ObjectiveFunction.cs @@ -7,7 +7,7 @@ using System.Linq; using System.Threading.Tasks; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public abstract class ObjectiveFunctionBase { diff --git a/src/Microsoft.ML.FastTree/Training/BaggingProvider.cs b/src/Microsoft.ML.FastTree/Training/BaggingProvider.cs index 3fd5125620..0dc50ce83b 100644 --- a/src/Microsoft.ML.FastTree/Training/BaggingProvider.cs +++ b/src/Microsoft.ML.FastTree/Training/BaggingProvider.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public class BaggingProvider { diff --git a/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs b/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs index e9284a6a5e..7fd960a190 100644 --- a/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs +++ b/src/Microsoft.ML.FastTree/Training/DcgCalculator.cs @@ -7,7 +7,7 @@ using System.Threading.Tasks; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public sealed class DcgCalculator { diff --git a/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs b/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs index 05e2955ec9..e0fb3ff926 100644 --- a/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs +++ b/src/Microsoft.ML.FastTree/Training/DcgPermutationComparer.cs @@ -4,7 +4,7 @@ using System.Collections.Generic; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public abstract class DcgPermutationComparer : IComparer { diff --git a/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs 
b/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs index 3a52acdebf..6e06480014 100644 --- a/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs +++ b/src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs @@ -7,7 +7,7 @@ using System.Linq; using System.Threading.Tasks; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { #if USE_SINGLE_PRECISION using FloatType = System.Single; diff --git a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs index 7f3c687819..fb87ab95f2 100644 --- a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs +++ b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { internal interface IEnsembleCompressor { diff --git a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs index eba82d0cfe..69b2acd798 100644 --- a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs +++ b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using System.Diagnostics; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// This implementation is based on: diff --git a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoFit.cs b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoFit.cs index 7a1e7ac321..a29752f612 100644 --- a/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoFit.cs +++ b/src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoFit.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public sealed class LassoFit { diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs index 0678eb4060..72cf5af52c 100644 --- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs +++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { //Accelerated gradient descent score tracker internal class AcceleratedGradientDescent : GradientDescent diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs index 044ef7bcf3..bde9917c53 100644 --- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs +++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { // Conjugate gradient descent internal class ConjugateGradientDescent : GradientDescent diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs index b2cf4f584b..5543763351 100644 --- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs +++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { internal class GradientDescent : OptimizationAlgorithm { diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs index 1925272122..9e53353316 100644 --- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs +++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// This is dummy optimizer. 
As Random forest does not have any boosting based optimization, this is place holder to be consistent diff --git a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs index 299d13a300..986a9c13f7 100644 --- a/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs +++ b/src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs @@ -5,7 +5,7 @@ using System; using System.Collections.Generic; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { //An interface that can be implemnted on public interface IFastTrainingScoresUpdate diff --git a/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs b/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs index fccb0673eb..965026556c 100644 --- a/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs +++ b/src/Microsoft.ML.FastTree/Training/Parallel/IParallelTraining.cs @@ -4,7 +4,6 @@ using System; using Microsoft.ML.EntryPoints; -using Microsoft.ML.Trainers.FastTree.Internal; namespace Microsoft.ML.Trainers.FastTree { diff --git a/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs b/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs index 734cda69d9..52adef9a50 100644 --- a/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs +++ b/src/Microsoft.ML.FastTree/Training/Parallel/SingleTrainer.cs @@ -15,9 +15,8 @@ namespace Microsoft.ML.Trainers.FastTree { - using Microsoft.ML.Trainers.FastTree.Internal; - using LeafSplitCandidates = Internal.LeastSquaresRegressionTreeLearner.LeafSplitCandidates; - using SplitInfo = Internal.LeastSquaresRegressionTreeLearner.SplitInfo; + using LeafSplitCandidates = LeastSquaresRegressionTreeLearner.LeafSplitCandidates; + using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo; internal sealed class SingleTrainer : IParallelTraining { diff --git a/src/Microsoft.ML.FastTree/Training/RegressionTreeNodeDocuments.cs b/src/Microsoft.ML.FastTree/Training/RegressionTreeNodeDocuments.cs index 75fe766cc6..8849aaa53e 100644 --- a/src/Microsoft.ML.FastTree/Training/RegressionTreeNodeDocuments.cs +++ b/src/Microsoft.ML.FastTree/Training/RegressionTreeNodeDocuments.cs @@ -5,7 +5,7 @@ using System.Collections.Generic; using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { // RegressionTreeNodeDocuments represents an association between a node in a regression // tree and documents belonging to that node. 
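The `SingleTrainer.cs` hunk above also shortens two using aliases: with the `Internal` namespace gone, an alias to a nested type no longer needs the extra qualifier. A self-contained sketch of that alias pattern, using stand-in types of my own rather than ML.NET's:

```csharp
namespace FastTreeAliasSketch
{
    // Stand-in for a learner that exposes a nested result type.
    internal class LeastSquaresLearnerStandIn
    {
        internal sealed class SplitInfo
        {
            public double Gain;
        }
    }
}

namespace FastTreeAliasSketchConsumer
{
    // Equivalent in spirit to: using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo;
    using SplitInfo = FastTreeAliasSketch.LeastSquaresLearnerStandIn.SplitInfo;

    internal static class AliasUsage
    {
        internal static double GainOf(SplitInfo split) => split.Gain;
    }
}
```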
diff --git a/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs b/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs index 297aa06e29..882fa4a2d6 100644 --- a/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs +++ b/src/Microsoft.ML.FastTree/Training/ScoreTracker.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using System.Threading.Tasks; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public class ScoreTracker { diff --git a/src/Microsoft.ML.FastTree/Training/StepSearch.cs b/src/Microsoft.ML.FastTree/Training/StepSearch.cs index bffa5213ca..ca767a22ed 100644 --- a/src/Microsoft.ML.FastTree/Training/StepSearch.cs +++ b/src/Microsoft.ML.FastTree/Training/StepSearch.cs @@ -5,7 +5,7 @@ using System; using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { internal interface IStepSearch { diff --git a/src/Microsoft.ML.FastTree/Training/Test.cs b/src/Microsoft.ML.FastTree/Training/Test.cs index c76903d522..88f985f67a 100644 --- a/src/Microsoft.ML.FastTree/Training/Test.cs +++ b/src/Microsoft.ML.FastTree/Training/Test.cs @@ -8,7 +8,7 @@ using System.Threading; using System.Threading.Tasks; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public sealed class TestResult : IComparable { diff --git a/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs b/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs index 3009bd500a..ae7f2a3b2f 100644 --- a/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs +++ b/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { internal class RandomForestLeastSquaresTreeLearner : LeastSquaresRegressionTreeLearner { diff --git a/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs b/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs index 823942f824..2cf8f44a75 100644 --- a/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs +++ b/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs @@ -9,7 +9,7 @@ using Microsoft.ML.Internal.CpuMath; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { #if USE_SINGLE_PRECISION using FloatType = System.Single; diff --git a/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs b/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs index e5f633b9c9..93dd45ca60 100644 --- a/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs +++ b/src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs @@ -4,7 +4,7 @@ using System; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { internal abstract class TreeLearner { diff --git a/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs b/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs index 0320a23ded..a183ddc9f6 100644 --- a/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs +++ b/src/Microsoft.ML.FastTree/Training/WinLossCalculator.cs @@ -8,7 +8,7 @@ using System.Threading.Tasks; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace 
Microsoft.ML.Trainers.FastTree { public sealed class WinLossCalculator { diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs index f93111201a..7a8d5dcf23 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs @@ -7,7 +7,7 @@ using Microsoft.ML.Model; using Float = System.Single; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { internal class InternalQuantileRegressionTree : InternalRegressionTree { @@ -50,7 +50,7 @@ public InternalQuantileRegressionTree(byte[] buffer, ref int position) _instanceWeights = buffer.ToDoubleArray(ref position); } - public override void Save(ModelSaveContext ctx) + internal override void Save(ModelSaveContext ctx) { // *** Binary format *** // double[]: Labels Distribution. diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs index 4727919022..31f65ae92e 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs @@ -15,7 +15,7 @@ using Microsoft.ML.Model.Pfa; using Newtonsoft.Json.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// Note that is shared between FastTree and LightGBM assemblies, /// so has . @@ -409,7 +409,7 @@ protected void Save(ModelSaveContext ctx, TreeType code) writer.WriteDoubleArray(_previousLeafValue); } - public virtual void Save(ModelSaveContext ctx) + internal virtual void Save(ModelSaveContext ctx) { Save(ctx, TreeType.Regression); } diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs index 421f788ef6..62fd670fc8 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs @@ -14,7 +14,7 @@ using Microsoft.ML.Model.Pfa; using Newtonsoft.Json.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { [BestFriend] internal class InternalTreeEnsemble @@ -56,7 +56,7 @@ public InternalTreeEnsemble(ModelLoadContext ctx, bool usingDefaultValues, bool _firstInputInitializationContent = ctx.LoadStringOrNull(); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { // *** Binary format *** // int: Number of trees diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs index f215dec82a..ab451165c8 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs @@ -4,16 +4,14 @@ using System.Collections.Generic; using Microsoft.ML; -using Microsoft.ML.Calibrator; using Microsoft.ML.Data; -using Microsoft.ML.Ensemble; using Microsoft.ML.Internal.Calibration; -using Microsoft.ML.Internal.Internallearn; -using Microsoft.ML.Trainers.FastTree.Internal; +using Microsoft.ML.Trainers.Ensemble; +using Microsoft.ML.Trainers.FastTree; [assembly: LoadableClass(typeof(TreeEnsembleCombiner), null, typeof(SignatureModelCombiner), "Fast Tree Model Combiner", "FastTreeCombiner")] -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace 
Microsoft.ML.Trainers.FastTree { public sealed class TreeEnsembleCombiner : IModelCombiner { diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index e1aa26f56f..18a0cf1aec 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -393,7 +393,7 @@ public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, ModelLoadConte _totalLeafCount = CountLeaves(_ensemble); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.FastTree/Utils/Algorithms.cs b/src/Microsoft.ML.FastTree/Utils/Algorithms.cs index 9ad94b359a..553e0e7553 100644 --- a/src/Microsoft.ML.FastTree/Utils/Algorithms.cs +++ b/src/Microsoft.ML.FastTree/Utils/Algorithms.cs @@ -5,7 +5,7 @@ using System; using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public static class Algorithms { diff --git a/src/Microsoft.ML.FastTree/Utils/BlockingThreadPool.cs b/src/Microsoft.ML.FastTree/Utils/BlockingThreadPool.cs index 76d623a850..e096de7005 100644 --- a/src/Microsoft.ML.FastTree/Utils/BlockingThreadPool.cs +++ b/src/Microsoft.ML.FastTree/Utils/BlockingThreadPool.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// This class wraps the standard .NET ThreadPool and adds the following functionality: diff --git a/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs b/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs index cd0d03cfc8..526fddba09 100644 --- a/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs +++ b/src/Microsoft.ML.FastTree/Utils/BufferPoolManager.cs @@ -10,7 +10,7 @@ using System.Linq; using System.Runtime.InteropServices; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// This class enables basic buffer pooling. 
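The other recurring change in this part of the patch tightens the serialization surface: `public override void Save(ModelSaveContext ctx)` implementations become non-public, either as `private protected ... SaveModel(...)` overrides or as explicit `void ICanSaveModel.Save(...)` implementations, so saving is no longer callable as an ordinary public method. The general shape of that refactoring, sketched with hypothetical stand-ins (`ICanPersist`, `TransformerBase`) rather than the real `ICanSaveModel`/`ModelSaveContext` pair:

```csharp
using System.IO;

// Hypothetical types illustrating the access-tightening pattern only.
public interface ICanPersist
{
    void Save(BinaryWriter writer);
}

public abstract class TransformerBase : ICanPersist
{
    // Callers persist through the interface; there is no public Save method to override.
    void ICanPersist.Save(BinaryWriter writer) => SaveModel(writer);

    // Derived classes override this instead, mirroring "private protected override void SaveModel".
    private protected abstract void SaveModel(BinaryWriter writer);
}

public sealed class ExampleTransformer : TransformerBase
{
    private protected override void SaveModel(BinaryWriter writer)
        => writer.Write("example-transformer-state");
}
```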
diff --git a/src/Microsoft.ML.FastTree/Utils/CompressUtils.cs b/src/Microsoft.ML.FastTree/Utils/CompressUtils.cs index 0588e386ae..4c87a666a9 100644 --- a/src/Microsoft.ML.FastTree/Utils/CompressUtils.cs +++ b/src/Microsoft.ML.FastTree/Utils/CompressUtils.cs @@ -7,7 +7,7 @@ using System.IO; using System.IO.Compression; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { internal struct BufferBlock { diff --git a/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs b/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs index 939306cbaa..a9400739dd 100644 --- a/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs +++ b/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs @@ -8,7 +8,7 @@ using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { internal static class FastTreeIniFileUtils { diff --git a/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs b/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs index 88d8825205..0b912e58bc 100644 --- a/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs +++ b/src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public static class LinqExtensions { diff --git a/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs b/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs index fc31ec097e..8a5fda81ea 100644 --- a/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs +++ b/src/Microsoft.ML.FastTree/Utils/MD5Hasher.cs @@ -7,7 +7,7 @@ using System.Security.Cryptography; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public struct MD5Hash { diff --git a/src/Microsoft.ML.FastTree/Utils/MappedObjectPool.cs b/src/Microsoft.ML.FastTree/Utils/MappedObjectPool.cs index bcfb68e210..ad9b54e6cd 100644 --- a/src/Microsoft.ML.FastTree/Utils/MappedObjectPool.cs +++ b/src/Microsoft.ML.FastTree/Utils/MappedObjectPool.cs @@ -5,7 +5,7 @@ using System; using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// Implements a paging mechanism on indexed objects. 
diff --git a/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs b/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs index 59adf3a453..9a7800d5ac 100644 --- a/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs +++ b/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs @@ -5,7 +5,7 @@ using System; using System.Linq; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// This class defines a psuedorandom function, mapping a number to diff --git a/src/Microsoft.ML.FastTree/Utils/StreamExtensions.cs b/src/Microsoft.ML.FastTree/Utils/StreamExtensions.cs index c12ba5d066..125b49fecb 100644 --- a/src/Microsoft.ML.FastTree/Utils/StreamExtensions.cs +++ b/src/Microsoft.ML.FastTree/Utils/StreamExtensions.cs @@ -5,7 +5,7 @@ using System.IO; using System.IO.Compression; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public static class StreamExtensions { diff --git a/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs b/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs index d479c2c53b..aa0e06ffc0 100644 --- a/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs +++ b/src/Microsoft.ML.FastTree/Utils/ThreadTaskManager.cs @@ -7,7 +7,7 @@ using System.Linq; using System.Threading.Tasks; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { internal static class ThreadTaskManager { diff --git a/src/Microsoft.ML.FastTree/Utils/Timer.cs b/src/Microsoft.ML.FastTree/Utils/Timer.cs index 35770a5e50..feb8d631cc 100644 --- a/src/Microsoft.ML.FastTree/Utils/Timer.cs +++ b/src/Microsoft.ML.FastTree/Utils/Timer.cs @@ -6,7 +6,7 @@ using System.Text; using System.Threading; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { using Stopwatch = System.Diagnostics.Stopwatch; diff --git a/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs b/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs index 108aeeac87..1452a3bbba 100644 --- a/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs +++ b/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs @@ -7,7 +7,7 @@ using System.Text; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { /// /// This class contains extension methods that support binary serialization of some base C# types diff --git a/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs b/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs index 8941d50da0..3b54b1a33b 100644 --- a/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs +++ b/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs @@ -5,7 +5,7 @@ using System; using System.Text; -namespace Microsoft.ML.Trainers.FastTree.Internal +namespace Microsoft.ML.Trainers.FastTree { public class VectorUtils { diff --git a/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs b/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs index bc30deabfc..ce3bad8b99 100644 --- a/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs +++ b/src/Microsoft.ML.HalLearners/ComputeLRTrainingStdThroughHal.cs @@ -7,7 +7,7 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Trainers.HalLearners; -namespace Microsoft.ML.Learners +namespace Microsoft.ML.Trainers { using Mkl = OlsLinearRegressionTrainer.Mkl; diff --git a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs index 1b635d49fc..564a5df3c6 100644 --- 
a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs +++ b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs @@ -2,11 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Trainers.HalLearners; -using Microsoft.ML.Trainers.SymSgd; using Microsoft.ML.Transforms.Projections; namespace Microsoft.ML diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index 18a5d97b97..31c01b6324 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -15,7 +15,6 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Model; using Microsoft.ML.Trainers.HalLearners; using Microsoft.ML.Training; diff --git a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs index 8b7377bbe1..e23619e81c 100644 --- a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs +++ b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs @@ -16,8 +16,7 @@ using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; -using Microsoft.ML.Trainers.SymSgd; +using Microsoft.ML.Trainers.HalLearners; using Microsoft.ML.Training; using Microsoft.ML.Transforms; @@ -29,7 +28,7 @@ [assembly: LoadableClass(typeof(void), typeof(SymSgdClassificationTrainer), null, typeof(SignatureEntryPointModule), SymSgdClassificationTrainer.LoadNameValue)] -namespace Microsoft.ML.Trainers.SymSgd +namespace Microsoft.ML.Trainers.HalLearners { using TPredictor = CalibratedModelParametersBase; diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs index f865ba4fcd..fc298c1ed1 100644 --- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs +++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs @@ -464,7 +464,7 @@ private static void TrainModels(IHostEnvironment env, IChannel ch, float[][] col } } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs index cf5388d8da..fc9ed6c699 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs @@ -141,7 +141,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs index 8829d1dd2f..dab2dc17c2 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs @@ -138,7 +138,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", 
ColumnPairs[col].inputColumnName, TextType.Instance.ToString(), inputSchema[srcCol].Type.ToString()); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs index a34fff871c..8359c4d9a6 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs @@ -252,7 +252,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs index 665cfd516d..a23c0d0a93 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs @@ -235,7 +235,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs index cc0a5da52e..da26e1f5dc 100644 --- a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs @@ -195,7 +195,7 @@ public ColInfoEx(ModelLoadContext ctx) Interleave = ctx.Reader.ReadBoolByte(); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -306,7 +306,7 @@ public static VectorToImageTransform Create(IHostEnvironment env, ModelLoadConte }); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index e73bf4acb9..bf01aecbfd 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -13,7 +13,6 @@ using Microsoft.ML.LightGBM; using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; [assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(Options), diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index a3ec3ffef5..3b2a1b56af 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -12,7 +12,7 @@ using Microsoft.ML.Internal.Calibration; using Microsoft.ML.LightGBM; using Microsoft.ML.Trainers; -using Microsoft.ML.Trainers.FastTree.Internal; +using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Training; [assembly: LoadableClass(LightGbmMulticlassTrainer.Summary, typeof(LightGbmMulticlassTrainer), 
typeof(Options), diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 04e78e7a4d..559ed6b5e1 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -11,7 +11,6 @@ using Microsoft.ML.LightGBM; using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; [assembly: LoadableClass(LightGbmRankingTrainer.UserName, typeof(LightGbmRankingTrainer), typeof(Options), diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 676627aa84..2068746cea 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -10,7 +10,6 @@ using Microsoft.ML.LightGBM; using Microsoft.ML.Model; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; [assembly: LoadableClass(LightGbmRegressorTrainer.Summary, typeof(LightGbmRegressorTrainer), typeof(Options), diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index 1ce9b0ccf4..695157f91b 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -8,7 +8,7 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Trainers.FastTree.Internal; +using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Training; namespace Microsoft.ML.LightGBM diff --git a/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs b/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs index 04dc9f9f62..cf8b9e4c8e 100644 --- a/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs +++ b/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using System.Globalization; using System.Linq; -using Microsoft.ML.Trainers.FastTree.Internal; +using Microsoft.ML.Trainers.FastTree; namespace Microsoft.ML.LightGBM { diff --git a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs index 30f20846c4..22bb76c019 100644 --- a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs @@ -261,7 +261,7 @@ internal OnnxTransformer(IHostEnvironment env, string[] outputColumnNames, strin { } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.AssertValue(ctx); @@ -368,7 +368,7 @@ private protected override Func GetDependenciesCore(Func a return col => Enumerable.Range(0, _parent.Outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col); } - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); private interface INamedOnnxValueGetter { diff --git a/src/Microsoft.ML.PCA/PcaTransformer.cs b/src/Microsoft.ML.PCA/PcaTransformer.cs index b6b134a1e0..dd17255d0a 100644 --- a/src/Microsoft.ML.PCA/PcaTransformer.cs +++ b/src/Microsoft.ML.PCA/PcaTransformer.cs @@ -140,7 +140,7 @@ public TransformInfo(ModelLoadContext ctx) Contracts.CheckDecode(MeanProjected == null || (MeanProjected.Length == Rank && FloatUtils.IsFinite(MeanProjected))); } - public void Save(ModelSaveContext ctx) + internal void 
Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -279,7 +279,7 @@ private static PrincipalComponentAnalysisTransformer Create(IHostEnvironment env return new PrincipalComponentAnalysisTransformer(host, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index 0f733273fb..9f5eaa97d3 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -406,7 +406,7 @@ public RowCursor[] GetRowCursorSet(IEnumerable columnsNee return new RowCursor[] { GetRowCursor(columnsNeeded, rand) }; } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs index 8f2c83a17f..98c83a59e2 100644 --- a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs +++ b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs @@ -253,7 +253,7 @@ public static PartitionedFileLoader Create(IHostEnvironment env, ModelLoadContex ch => new PartitionedFileLoader(host, ctx, files)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Parquet/PartitionedPathParser.cs b/src/Microsoft.ML.Parquet/PartitionedPathParser.cs index 06ae2b9700..70a01be64b 100644 --- a/src/Microsoft.ML.Parquet/PartitionedPathParser.cs +++ b/src/Microsoft.ML.Parquet/PartitionedPathParser.cs @@ -148,7 +148,7 @@ public static SimplePartitionedPathParser Create(IHostEnvironment env, ModelLoad ch => new SimplePartitionedPathParser(host, ctx)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); @@ -261,7 +261,7 @@ public static ParquetPartitionedPathParser Create(IHostEnvironment env, ModelLoa ch => new ParquetPartitionedPathParser(host, ctx)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs index 8642188c2f..a0007d9108 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs @@ -16,7 +16,7 @@ using Microsoft.ML.Recommender.Internal; using Microsoft.ML.Trainers.Recommender; -[assembly: LoadableClass(typeof(MatrixFactorizationPredictor), null, typeof(SignatureLoadModel), "Matrix Factorization Predictor Executor", MatrixFactorizationPredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(MatrixFactorizationModelParameters), null, typeof(SignatureLoadModel), "Matrix Factorization Predictor Executor", MatrixFactorizationModelParameters.LoaderSignature)] [assembly: LoadableClass(typeof(MatrixFactorizationPredictionTransformer), typeof(MatrixFactorizationPredictionTransformer), null, typeof(SignatureLoadModel), "", MatrixFactorizationPredictionTransformer.LoaderSignature)] @@ -24,12 +24,15 @@ namespace Microsoft.ML.Trainers.Recommender { /// - /// stores two factor matrices, P and Q, for approximating the 
training matrix, R, by P * Q, - /// where * is a matrix multiplication. This predictor expects two inputs, row index and column index, and produces the (approximated) + /// Model parameters for matrix factorization recommender. + /// + /// + /// stores two factor matrices, P and Q, for approximating the training matrix, R, by P * Q, + /// where * is a matrix multiplication. This model expects two inputs, row index and column index, and produces the (approximated) /// value at the location specified by the two inputs in R. More specifically, if input row and column indices are u and v, respectively. /// The output (a scalar) would be the inner product product of the u-th row in P and the v-th column in Q. - /// - public sealed class MatrixFactorizationPredictor : IPredictor, ICanSaveModel, ICanSaveInTextFormat, ISchemaBindableMapper + /// + public sealed class MatrixFactorizationModelParameters : IPredictor, ICanSaveModel, ICanSaveInTextFormat, ISchemaBindableMapper { internal const string LoaderSignature = "MFPredictor"; internal const string RegistrationName = "MatrixFactorizationPredictor"; @@ -43,33 +46,42 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(MatrixFactorizationPredictor).Assembly.FullName); + loaderAssemblyName: typeof(MatrixFactorizationModelParameters).Assembly.FullName); } private const uint VersionNoMinCount = 0x00010002; private readonly IHost _host; - // The number of rows. - private readonly int _numberOfRows; - // The number of columns. - private readonly int _numberofColumns; - // The rank of the factor matrices. - private readonly int _approximationRank; - // Packed _numberOfRows by _approximationRank matrix. - private readonly float[] _leftFactorMatrix; - // Packed _approximationRank by _numberofColumns matrix. - private readonly float[] _rightFactorMatrix; - - public PredictionKind PredictionKind - { - get { return PredictionKind.Recommendation; } - } + /// The number of rows. + public readonly int NumberOfRows; + /// The number of columns. + public readonly int NumberOfColumns; + /// The rank of the factor matrices. + public readonly int ApproximationRank; + /// + /// Left approximation matrix + /// + /// + /// This is two dimensional matrix with size of * flattened into one-dimensional matrix. + /// Row by row. + /// + public readonly IReadOnlyList LeftFactorMatrix; + /// + /// Left approximation matrix + /// + /// + /// This is two dimensional matrix with size of * flattened into one-dimensional matrix. + /// Row by row. 
+ /// + public readonly IReadOnlyList RightFactorMatrix; + + public PredictionKind PredictionKind => PredictionKind.Recommendation; - public ColumnType OutputType { get { return NumberType.Float; } } + private ColumnType OutputType => NumberType.Float; - public ColumnType MatrixColumnIndexType { get; } - public ColumnType MatrixRowIndexType { get; } + internal ColumnType MatrixColumnIndexType { get; } + internal ColumnType MatrixRowIndexType { get; } - internal MatrixFactorizationPredictor(IHostEnvironment env, SafeTrainingAndModelBuffer buffer, KeyType matrixColumnIndexType, KeyType matrixRowIndexType) + internal MatrixFactorizationModelParameters(IHostEnvironment env, SafeTrainingAndModelBuffer buffer, KeyType matrixColumnIndexType, KeyType matrixRowIndexType) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(RegistrationName); @@ -78,18 +90,19 @@ internal MatrixFactorizationPredictor(IHostEnvironment env, SafeTrainingAndModel _host.CheckValue(buffer, nameof(buffer)); _host.CheckValue(matrixColumnIndexType, nameof(matrixColumnIndexType)); _host.CheckValue(matrixRowIndexType, nameof(matrixRowIndexType)); - - buffer.Get(out _numberOfRows, out _numberofColumns, out _approximationRank, out _leftFactorMatrix, out _rightFactorMatrix); - _host.Assert(_numberofColumns == matrixColumnIndexType.GetCountAsInt32(_host)); - _host.Assert(_numberOfRows == matrixRowIndexType.GetCountAsInt32(_host)); - _host.Assert(_leftFactorMatrix.Length == _numberOfRows * _approximationRank); - _host.Assert(_rightFactorMatrix.Length == _numberofColumns * _approximationRank); + buffer.Get(out NumberOfRows, out NumberOfColumns, out ApproximationRank, out var leftFactorMatrix, out var rightFactorMatrix); + LeftFactorMatrix = leftFactorMatrix; + RightFactorMatrix = rightFactorMatrix; + _host.Assert(NumberOfColumns == matrixColumnIndexType.GetCountAsInt32(_host)); + _host.Assert(NumberOfRows == matrixRowIndexType.GetCountAsInt32(_host)); + _host.Assert(LeftFactorMatrix.Count == NumberOfRows * ApproximationRank); + _host.Assert(RightFactorMatrix.Count == ApproximationRank * NumberOfColumns); MatrixColumnIndexType = matrixColumnIndexType; MatrixRowIndexType = matrixRowIndexType; } - private MatrixFactorizationPredictor(IHostEnvironment env, ModelLoadContext ctx) + private MatrixFactorizationModelParameters(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(RegistrationName); @@ -100,49 +113,49 @@ private MatrixFactorizationPredictor(IHostEnvironment env, ModelLoadContext ctx) // float[m * k]: the left factor matrix // float[k * n]: the right factor matrix - _numberOfRows = ctx.Reader.ReadInt32(); - _host.CheckDecode(_numberOfRows > 0); + NumberOfRows = ctx.Reader.ReadInt32(); + _host.CheckDecode(NumberOfRows > 0); if (ctx.Header.ModelVerWritten < VersionNoMinCount) { ulong mMin = ctx.Reader.ReadUInt64(); // We no longer support non zero Min for KeyType. _host.CheckDecode(mMin == 0); - _host.CheckDecode((ulong)_numberOfRows <= ulong.MaxValue - mMin); + _host.CheckDecode((ulong)NumberOfRows <= ulong.MaxValue - mMin); } - _numberofColumns = ctx.Reader.ReadInt32(); - _host.CheckDecode(_numberofColumns > 0); + NumberOfColumns = ctx.Reader.ReadInt32(); + _host.CheckDecode(NumberOfColumns > 0); if (ctx.Header.ModelVerWritten < VersionNoMinCount) { ulong nMin = ctx.Reader.ReadUInt64(); // We no longer support non zero Min for KeyType. 
_host.CheckDecode(nMin == 0); - _host.CheckDecode((ulong)_numberofColumns <= ulong.MaxValue - nMin); + _host.CheckDecode((ulong)NumberOfColumns <= ulong.MaxValue - nMin); } - _approximationRank = ctx.Reader.ReadInt32(); - _host.CheckDecode(_approximationRank > 0); + ApproximationRank = ctx.Reader.ReadInt32(); + _host.CheckDecode(ApproximationRank > 0); - _leftFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(_numberOfRows * _approximationRank)); - _rightFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(_numberofColumns * _approximationRank)); + LeftFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(NumberOfRows * ApproximationRank)); + RightFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(NumberOfColumns * ApproximationRank)); - MatrixColumnIndexType = new KeyType(typeof(uint), _numberofColumns); - MatrixRowIndexType = new KeyType(typeof(uint), _numberOfRows); + MatrixColumnIndexType = new KeyType(typeof(uint), NumberOfColumns); + MatrixRowIndexType = new KeyType(typeof(uint), NumberOfRows); } /// /// Load model from the given context /// - public static MatrixFactorizationPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static MatrixFactorizationModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new MatrixFactorizationPredictor(env, ctx); + return new MatrixFactorizationModelParameters(env, ctx); } /// /// Save model to the given context /// - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); @@ -154,16 +167,16 @@ public void Save(ModelSaveContext ctx) // float[m * k]: the left factor matrix // float[k * n]: the right factor matrix - _host.Check(_numberOfRows > 0, "Number of rows must be positive"); - _host.Check(_numberofColumns > 0, "Number of columns must be positive"); - _host.Check(_approximationRank > 0, "Number of latent factors must be positive"); - ctx.Writer.Write(_numberOfRows); - ctx.Writer.Write(_numberofColumns); - ctx.Writer.Write(_approximationRank); - _host.Check(Utils.Size(_leftFactorMatrix) == _numberOfRows * _approximationRank, "Unexpected matrix size of a factor matrix (matrix P in LIBMF paper)"); - _host.Check(Utils.Size(_rightFactorMatrix) == _numberofColumns * _approximationRank, "Unexpected matrix size of a factor matrix (matrix Q in LIBMF paper)"); - Utils.WriteSinglesNoCount(ctx.Writer, _leftFactorMatrix.AsSpan(0, _numberOfRows * _approximationRank)); - Utils.WriteSinglesNoCount(ctx.Writer, _rightFactorMatrix.AsSpan(0, _numberofColumns * _approximationRank)); + _host.Check(NumberOfRows > 0, "Number of rows must be positive"); + _host.Check(NumberOfColumns > 0, "Number of columns must be positive"); + _host.Check(ApproximationRank > 0, "Number of latent factors must be positive"); + ctx.Writer.Write(NumberOfRows); + ctx.Writer.Write(NumberOfColumns); + ctx.Writer.Write(ApproximationRank); + _host.Check(Utils.Size(LeftFactorMatrix) == NumberOfRows * ApproximationRank, "Unexpected matrix size of a factor matrix (matrix P in LIBMF paper)"); + _host.Check(Utils.Size(RightFactorMatrix) == NumberOfColumns * ApproximationRank, "Unexpected matrix size of a factor matrix (matrix Q in LIBMF paper)"); + Utils.WriteSinglesNoCount(ctx.Writer, LeftFactorMatrix as float[]); + Utils.WriteSinglesNoCount(ctx.Writer, RightFactorMatrix as float[]); } /// @@ -172,20 +185,20 @@ public void 
Save(ModelSaveContext ctx) void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { writer.WriteLine("# Imputed matrix is P * Q'"); - writer.WriteLine("# P in R^({0} x {1}), rows correpond to Y item", _numberOfRows, _approximationRank); - for (int i = 0; i < _leftFactorMatrix.Length; ++i) + writer.WriteLine("# P in R^({0} x {1}), rows correpond to Y item", NumberOfRows, ApproximationRank); + for (int i = 0; i < LeftFactorMatrix.Count; ++i) { - writer.Write(_leftFactorMatrix[i].ToString("G")); - if (i % _approximationRank == _approximationRank - 1) + writer.Write(LeftFactorMatrix[i].ToString("G")); + if (i % ApproximationRank == ApproximationRank - 1) writer.WriteLine(); else writer.Write('\t'); } - writer.WriteLine("# Q in R^({0} x {1}), rows correpond to X item", _numberofColumns, _approximationRank); - for (int i = 0; i < _rightFactorMatrix.Length; ++i) + writer.WriteLine("# Q in R^({0} x {1}), rows correpond to X item", NumberOfColumns, ApproximationRank); + for (int i = 0; i < RightFactorMatrix.Count; ++i) { - writer.Write(_rightFactorMatrix[i].ToString("G")); - if (i % _approximationRank == _approximationRank - 1) + writer.Write(RightFactorMatrix[i].ToString("G")); + if (i % ApproximationRank == ApproximationRank - 1) writer.WriteLine(); else writer.Write('\t'); @@ -241,7 +254,7 @@ private void MapperCore(in uint srcCol, ref uint srcRow, ref float dst) // training. For higher-than-expected values, the predictor version would return // 0, rather than NaN as we do here. It is in my mind an open question as to what // is actually correct. - if (srcRow == 0 || srcRow > _numberOfRows || srcCol == 0 || srcCol > _numberofColumns) + if (srcRow == 0 || srcRow > NumberOfRows || srcCol == 0 || srcCol > NumberOfColumns) { dst = float.NaN; return; @@ -251,15 +264,15 @@ private void MapperCore(in uint srcCol, ref uint srcRow, ref float dst) private float Score(int columnIndex, int rowIndex) { - _host.Assert(0 <= rowIndex && rowIndex < _numberOfRows); - _host.Assert(0 <= columnIndex && columnIndex < _numberofColumns); + _host.Assert(0 <= rowIndex && rowIndex < NumberOfRows); + _host.Assert(0 <= columnIndex && columnIndex < NumberOfColumns); float score = 0; // Starting position of the rowIndex-th row in the left factor factor matrix - int rowOffset = rowIndex * _approximationRank; + int rowOffset = rowIndex * ApproximationRank; // Starting position of the columnIndex-th column in the right factor factor matrix - int columnOffset = columnIndex * _approximationRank; - for (int i = 0; i < _approximationRank; i++) - score += _leftFactorMatrix[rowOffset + i] * _rightFactorMatrix[columnOffset + i]; + int columnOffset = columnIndex * ApproximationRank; + for (int i = 0; i < ApproximationRank; i++) + score += LeftFactorMatrix[rowOffset + i] * RightFactorMatrix[columnOffset + i]; return score; } @@ -276,7 +289,7 @@ ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSc private sealed class RowMapper : ISchemaBoundRowMapper { - private readonly MatrixFactorizationPredictor _parent; + private readonly MatrixFactorizationModelParameters _parent; // The tail "ColumnIndex" means the column index in IDataView private readonly int _matrixColumnIndexColumnIndex; private readonly int _matrixRowIndexCololumnIndex; @@ -289,7 +302,7 @@ private sealed class RowMapper : ISchemaBoundRowMapper public RoleMappedSchema InputRoleMappedSchema { get; } - public RowMapper(IHostEnvironment env, MatrixFactorizationPredictor parent, RoleMappedSchema schema, Schema outputSchema) + 
public RowMapper(IHostEnvironment env, MatrixFactorizationModelParameters parent, RoleMappedSchema schema, Schema outputSchema) { Contracts.AssertValue(parent); _env = env; @@ -375,13 +388,16 @@ public Row GetRow(Row input, Func active) } } - public sealed class MatrixFactorizationPredictionTransformer : PredictionTransformerBase, ICanSaveModel + /// + /// Trains a . It factorizes the training matrix into the product of two low-rank matrices. + /// + public sealed class MatrixFactorizationPredictionTransformer : PredictionTransformerBase { - public const string LoaderSignature = "MaFactPredXf"; - public string MatrixColumnIndexColumnName { get; } - public string MatrixRowIndexColumnName { get; } - public ColumnType MatrixColumnIndexColumnType { get; } - public ColumnType MatrixRowIndexColumnType { get; } + internal const string LoaderSignature = "MaFactPredXf"; + internal string MatrixColumnIndexColumnName { get; } + internal string MatrixRowIndexColumnName { get; } + internal ColumnType MatrixColumnIndexColumnType { get; } + internal ColumnType MatrixRowIndexColumnType { get; } /// /// Build a transformer based on matrix factorization predictor (model) and the input schema (trainSchema). The created @@ -395,7 +411,7 @@ public sealed class MatrixFactorizationPredictionTransformer : PredictionTransfo /// The name of the column used as role in matrix factorization world /// The name of the column used as role in matrix factorization world /// A string attached to the output column name of this transformer - public MatrixFactorizationPredictionTransformer(IHostEnvironment env, MatrixFactorizationPredictor model, Schema trainSchema, + internal MatrixFactorizationPredictionTransformer(IHostEnvironment env, MatrixFactorizationModelParameters model, Schema trainSchema, string matrixColumnIndexColumnName, string matrixRowIndexColumnName, string scoreColumnNameSuffix = "") : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MatrixFactorizationPredictionTransformer)), model, trainSchema) { @@ -432,7 +448,7 @@ private RoleMappedSchema GetSchema() /// The counter constructor of re-creating from the context where /// the original transform is saved. /// - public MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoadContext ctx) + private MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoadContext ctx) : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(MatrixFactorizationPredictionTransformer)), ctx) { // *** Binary format *** @@ -458,6 +474,10 @@ public MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoad Scorer = new GenericScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } + /// + /// Schema propagation for transformers. + /// Returns the output schema of the data, if the input schema is like the one provided. 
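// Illustrative sketch (not part of this patch), assuming a trained
// MatrixFactorizationModelParameters instance named 'model': with the factor matrices now
// exposed as public IReadOnlyList<float> members, a prediction can be recomputed by hand the
// same way the Score method above does it, i.e. as the inner product of the rowIndex-th row of
// the left (m-by-k) factor matrix with the columnIndex-th slice of the right factor matrix,
// both stored as flattened arrays of length (rows-or-columns * ApproximationRank).
static float RecomputeScore(MatrixFactorizationModelParameters model, int rowIndex, int columnIndex)
{
    int k = model.ApproximationRank;
    float score = 0;
    for (int i = 0; i < k; i++)
        score += model.LeftFactorMatrix[rowIndex * k + i] * model.RightFactorMatrix[columnIndex * k + i];
    return score;
}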
+ /// public override Schema GetOutputSchema(Schema inputSchema) { if (!inputSchema.TryGetColumnIndex(MatrixColumnIndexColumnName, out int xCol)) @@ -468,7 +488,7 @@ public override Schema GetOutputSchema(Schema inputSchema) return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - public void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs index c3696d22bf..e09eab6ace 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs @@ -83,15 +83,37 @@ namespace Microsoft.ML.Trainers /// /// /// /// /// - public sealed class MatrixFactorizationTrainer : TrainerBase, + public sealed class MatrixFactorizationTrainer : TrainerBase, IEstimator { - public enum LossFunctionType { SquareLossRegression = 0, SquareLossOneClass = 12 }; + /// + /// Type of loss function. + /// + public enum LossFunctionType + { + /// + /// Used in traditional collaborative filtering problem with squared loss. + /// + /// + /// See Equation (1). + /// + SquareLossRegression = 0, + /// + /// Used in implicit-feedback recommendation problem. + /// + /// + /// See Equation (3). + /// + SquareLossOneClass = 12 + }; + /// + /// Advanced options for the . + /// public sealed class Options { /// @@ -110,40 +132,69 @@ public sealed class Options public string LabelColumnName; /// - /// Loss function minimized for finding factor matrices. Two values are allowed, 0 or 12. The values 0 means traditional collaborative filtering - /// problem with squared loss. The value 12 triggers one-class matrix factorization for implicit-feedback recommendation problem. + /// Loss function minimized for finding factor matrices. /// + /// + /// Two values are allowed, or . + /// The means traditional collaborative filtering problem with squared loss. + /// The triggers one-class matrix factorization for implicit-feedback recommendation problem. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Loss function minimized for finding factor matrices.")] [TGUI(SuggestedSweeps = "0,12")] [TlcModule.SweepableDiscreteParam("LossFunction", new object[] { LossFunctionType.SquareLossRegression, LossFunctionType.SquareLossOneClass })] - public LossFunctionType LossFunction = LossFunctionType.SquareLossRegression; + public LossFunctionType LossFunction = Defaults.LossFunction; + /// + /// Regularization parameter. + /// + /// + /// It's the weight of factor matrices Frobenius norms in the objective function minimized by matrix factorization's algorithm. A small value could cause over-fitting. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Regularization parameter. " + - "It's the weight of factor matrices' norms in the objective function minimized by matrix factorization's algorithm. " + + "It's the weight of factor matrices Frobenius norms in the objective function minimized by matrix factorization's algorithm. " + "A small value could cause over-fitting.")] [TGUI(SuggestedSweeps = "0.01,0.05,0.1,0.5,1")] [TlcModule.SweepableDiscreteParam("Lambda", new object[] { 0.01f, 0.05f, 0.1f, 0.5f, 1f })] - public double Lambda = 0.1; + public double Lambda = Defaults.Lambda; + /// + /// Rank of approximation matrixes. 
+ /// + /// + /// If input data has size of m-by-n we would build two approximation matrixes m-by-k and k-by-n where k is approximation rank. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Latent space dimension (denoted by k). If the factorized matrix is m-by-n, " + "two factor matrices found by matrix factorization are m-by-k and k-by-n, respectively. " + - "This value is also known as the rank of matrix factorization because k is generally much smaller than m and n.")] + "This value is also known as the rank of matrix factorization because k is generally much smaller than m and n.", ShortName = "K")] [TGUI(SuggestedSweeps = "8,16,64,128")] [TlcModule.SweepableDiscreteParam("K", new object[] { 8, 16, 64, 128 })] - public int K = 8; + public int ApproximationRank = Defaults.ApproximationRank; + /// + /// Number of training iterations. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Training iterations; that is, the times that the training algorithm iterates through the whole training data once.", ShortName = "iter")] [TGUI(SuggestedSweeps = "10,20,40")] [TlcModule.SweepableDiscreteParam("NumIterations", new object[] { 10, 20, 40 })] - public int NumIterations = 20; - + public int NumIterations = Defaults.NumIterations; + + /// + /// Initial learning rate. It specifies the speed of the training algorithm. + /// + /// + /// Small value may increase the number of iterations needed to achieve a reasonable result. + /// Large value may lead to numerical difficulty such as a infinity value. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate. It specifies the speed of the training algorithm. " + - "Small value may increase the number of iterations needed to achieve a reasonable result. Large value may lead to numerical difficulty such as a infinity value.")] + "Small value may increase the number of iterations needed to achieve a reasonable result. Large value may lead to numerical difficulty such as a infinity value.", ShortName = "Eta")] [TGUI(SuggestedSweeps = "0.001,0.01,0.1")] [TlcModule.SweepableDiscreteParam("Eta", new object[] { 0.001f, 0.01f, 0.1f })] - public double Eta = 0.1; + public double LearningRate = Defaults.LearningRate; /// + /// Importance of unobserved entries' loss in one-class matrix factorization. Applicable if set to + /// + /// /// Importance of unobserved (i.e., negative) entries' loss in one-class matrix factorization. /// In general, only a few of matrix entries (e.g., less than 1%) in the training are observed (i.e., positive). /// To balance the contributions from unobserved and obverved in the overall loss function, this parameter is @@ -154,32 +205,57 @@ public sealed class Options /// Alpha = (# of observed entries) / (# of unobserved entries) can make observed and unobserved entries equally important /// in the minimized loss function. However, the best setting in machine learning is alwasy data-depedent so user still needs to /// try multiple values. - /// + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Importance of unobserved entries' loss in one-class matrix factorization.")] [TGUI(SuggestedSweeps = "1,0.01,0.0001,0.000001")] [TlcModule.SweepableDiscreteParam("Alpha", new object[] { 1f, 0.01f, 0.0001f, 0.000001f })] - public double Alpha = 0.0001; + public double Alpha = Defaults.Alpha; /// - /// Desired negative entries value in one-class matrix factorization. 
In one-class matrix factorization, all matrix values observed are one - /// (which can be viewed as positive cases in binary classification) while unobserved values (which can be viewed as negative cases in binary - /// classification) need to be specified manually using this option. + /// Desired negative entries value in one-class matrix factorization. Applicable if set to /// + /// + /// In one-class matrix factorization, all matrix values observed are one (which can be viewed as positive cases in binary classification) + /// while unobserved values (which can be viewed as negative cases in binary classification) need to be specified manually using this option. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Desired negative entries' value in one-class matrix factorization")] [TGUI(SuggestedSweeps = "0.000001,0,0001,0.01")] [TlcModule.SweepableDiscreteParam("C", new object[] { 0.000001f, 0.0001f, 0.01f })] - public double C = 0.000001f; + public double C = Defaults.C; + /// + /// Number of threads will be used during training. If unspecified all aviable threads will be use. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Number of threads can be used in the training procedure.", ShortName = "t")] public int? NumThreads; + /// + /// Suppress writing additional information to output. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Suppress writing additional information to output.")] - public bool Quiet; + public bool Quiet = Defaults.Quiet; + /// + /// Force the factor matrices to be non-negative. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Force the factor matrices to be non-negative.", ShortName = "nn")] - public bool NonNegative; + public bool NonNegative = Defaults.NonNegative; }; + [BestFriend] + internal static class Defaults + { + public const bool Quiet = false; + public const bool NonNegative = false; + public const double C = 0.000001f; + public const double Alpha = 0.0001f; + public const double LearningRate = 0.1; + public const int NumIterations = 20; + public const int ApproximationRank = 8; + public const double Lambda = 0.1; + public const LossFunctionType LossFunction = LossFunctionType.SquareLossRegression; + } + internal const string Summary = "From pairs of row/column indices and a value of a matrix, this trains a predictor capable of filling in unknown entries of the matrix, " + "using a low-rank matrix factorization. This technique is often used in recommender system, where the row and column indices indicate users and items, " + "and the values of the matrix are ratings. "; @@ -197,7 +273,8 @@ public sealed class Options private readonly bool _doNmf; public override PredictionKind PredictionKind => PredictionKind.Recommendation; - public const string LoadNameValue = "MatrixFactorization"; + + internal const string LoadNameValue = "MatrixFactorization"; /// /// The row index, column index, and label columns needed to specify the training matrix. This trainer uses tuples of (row index, column index, label value) to specify a matrix. 
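// Illustrative sketch (not part of this patch): populating the renamed Options members defined
// above. ApproximationRank replaces the old K field, LearningRate replaces Eta, and the values
// shown are the Defaults introduced by this change; the matrix row/column index column names
// (declared elsewhere in Options) would be set the same way.
var mfOptions = new MatrixFactorizationTrainer.Options
{
    LabelColumnName = "Label",
    LossFunction = MatrixFactorizationTrainer.LossFunctionType.SquareLossRegression,
    ApproximationRank = 8,   // formerly Options.K
    LearningRate = 0.1,      // formerly Options.Eta
    NumIterations = 20,
    Lambda = 0.1,
    Alpha = 0.0001,
    C = 0.000001,
    NumThreads = null,       // null lets the trainer use all available threads
    Quiet = false,
    NonNegative = false
};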
@@ -238,18 +315,18 @@ internal MatrixFactorizationTrainer(IHostEnvironment env, Options options) : bas { const string posError = "Parameter must be positive"; Host.CheckValue(options, nameof(options)); - Host.CheckUserArg(options.K > 0, nameof(options.K), posError); + Host.CheckUserArg(options.ApproximationRank > 0, nameof(options.ApproximationRank), posError); Host.CheckUserArg(!options.NumThreads.HasValue || options.NumThreads > 0, nameof(options.NumThreads), posError); Host.CheckUserArg(options.NumIterations > 0, nameof(options.NumIterations), posError); Host.CheckUserArg(options.Lambda > 0, nameof(options.Lambda), posError); - Host.CheckUserArg(options.Eta > 0, nameof(options.Eta), posError); + Host.CheckUserArg(options.LearningRate > 0, nameof(options.LearningRate), posError); Host.CheckUserArg(options.Alpha > 0, nameof(options.Alpha), posError); _fun = (int)options.LossFunction; _lambda = options.Lambda; - _k = options.K; + _k = options.ApproximationRank; _iter = options.NumIterations; - _eta = options.Eta; + _eta = options.LearningRate; _alpha = options.Alpha; _c = options.C; _threads = options.NumThreads ?? Environment.ProcessorCount; @@ -270,21 +347,27 @@ internal MatrixFactorizationTrainer(IHostEnvironment env, Options options) : bas /// The name of the column hosting the matrix's column IDs. /// The name of the column hosting the matrix's row IDs. /// The name of the label column. + /// Rank of approximation matrixes. + /// Initial learning rate. It specifies the speed of the training algorithm. + /// Number of training iterations. [BestFriend] internal MatrixFactorizationTrainer(IHostEnvironment env, string matrixColumnIndexColumnName, string matrixRowIndexColumnName, - string labelColumn = DefaultColumnNames.Label) + string labelColumn = DefaultColumnNames.Label, + int approximationRank = Defaults.ApproximationRank, + double learningRate = Defaults.LearningRate, + int numIterations = Defaults.NumIterations) : base(env, LoadNameValue) { var args = new Options(); _fun = (int)args.LossFunction; - _lambda = args.Lambda; - _k = args.K; - _iter = args.NumIterations; - _eta = args.Eta; + _k = approximationRank; + _iter = numIterations; + _eta = learningRate; _alpha = args.Alpha; + _lambda = args.Lambda; _c = args.C; _threads = args.NumThreads ?? Environment.ProcessorCount; _quiet = args.Quiet; @@ -301,7 +384,7 @@ internal MatrixFactorizationTrainer(IHostEnvironment env, /// Train a matrix factorization model based on training data, validation data, and so on in the given context. /// /// The information collection needed for training. for details. 
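// Illustrative usage sketch (not part of this patch): the approximationRank, learningRate and
// numIterations parameters added to this constructor are forwarded by the recommendation
// catalog extension updated later in this diff. The Recommendation() accessor, the column
// names and 'dataView' are assumptions made only for the example.
var mlContext = new MLContext(seed: 0);
var trainer = mlContext.Recommendation().Trainers.MatrixFactorization(
    matrixColumnIndexColumnName: "MatrixColumnIndex",
    matrixRowIndexColumnName: "MatrixRowIndex",
    labelColumn: "Label",
    approximationRank: 16,
    learningRate: 0.1,
    numIterations: 20);
var model = trainer.Fit(dataView);   // returns a MatrixFactorizationPredictionTransformer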
- private protected override MatrixFactorizationPredictor Train(TrainContext context) + private protected override MatrixFactorizationModelParameters Train(TrainContext context) { Host.CheckValue(context, nameof(context)); using (var ch = Host.Start("Training")) @@ -310,7 +393,7 @@ private protected override MatrixFactorizationPredictor Train(TrainContext conte } } - private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, RoleMappedData validData = null) + private MatrixFactorizationModelParameters TrainCore(IChannel ch, RoleMappedData data, RoleMappedData validData = null) { Host.AssertValue(ch); ch.AssertValue(data); @@ -321,7 +404,7 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, var labelCol = data.Schema.Label.Value; if (labelCol.Type != NumberType.R4 && labelCol.Type != NumberType.R8) throw ch.Except("Column '{0}' for label should be floating point, but is instead {1}", labelCol.Name, labelCol.Type); - MatrixFactorizationPredictor predictor; + MatrixFactorizationModelParameters predictor; if (validData != null) { ch.CheckValue(validData, nameof(validData)); @@ -362,7 +445,7 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, using (var buffer = PrepareBuffer()) { buffer.Train(ch, rowCount, colCount, cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter); - predictor = new MatrixFactorizationPredictor(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type); + predictor = new MatrixFactorizationModelParameters(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type); } } else @@ -380,7 +463,7 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, buffer.TrainWithValidation(ch, rowCount, colCount, cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter, validCursor, validLabelGetter, validMatrixRowIndexGetter, validMatrixColumnIndexGetter); - predictor = new MatrixFactorizationPredictor(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type); + predictor = new MatrixFactorizationModelParameters(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type); } } } @@ -403,7 +486,7 @@ private SafeTrainingAndModelBuffer PrepareBuffer() /// The validation data set. public MatrixFactorizationPredictionTransformer Train(IDataView trainData, IDataView validationData = null) { - MatrixFactorizationPredictor model = null; + MatrixFactorizationModelParameters model = null; var roles = new List>(); roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Label, LabelName)); @@ -427,6 +510,10 @@ public MatrixFactorizationPredictionTransformer Train(IDataView trainData, IData /// The training data set. public MatrixFactorizationPredictionTransformer Fit(IDataView input) => Train(input); + /// + /// Schema propagation for transformers. Returns the output schema of the data, if + /// the input schema is like the one provided. 
+ /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); diff --git a/src/Microsoft.ML.Recommender/RecommenderCatalog.cs b/src/Microsoft.ML.Recommender/RecommenderCatalog.cs index 448595ec41..637e4962ea 100644 --- a/src/Microsoft.ML.Recommender/RecommenderCatalog.cs +++ b/src/Microsoft.ML.Recommender/RecommenderCatalog.cs @@ -55,11 +55,24 @@ internal RecommendationTrainers(RecommendationCatalog catalog) /// The name of the column hosting the matrix's column IDs. /// The name of the column hosting the matrix's row IDs. /// The name of the label column. + /// Rank of approximation matrixes. + /// Initial learning rate. It specifies the speed of the training algorithm. + /// Number of training iterations. + /// + /// + /// + /// public MatrixFactorizationTrainer MatrixFactorization( string matrixColumnIndexColumnName, string matrixRowIndexColumnName, - string labelColumn = DefaultColumnNames.Label) - => new MatrixFactorizationTrainer(Owner.Environment, matrixColumnIndexColumnName, matrixRowIndexColumnName, labelColumn); + string labelColumn = DefaultColumnNames.Label, + int approximationRank = MatrixFactorizationTrainer.Defaults.ApproximationRank, + double learningRate = MatrixFactorizationTrainer.Defaults.LearningRate, + int numIterations = MatrixFactorizationTrainer.Defaults.NumIterations) + => new MatrixFactorizationTrainer(Owner.Environment, matrixColumnIndexColumnName, matrixRowIndexColumnName, labelColumn, + approximationRank, learningRate, numIterations); /// /// Train a matrix factorization model. It factorizes the training matrix into the product of two low-rank matrices. @@ -74,7 +87,7 @@ public MatrixFactorizationTrainer MatrixFactorization( /// /// /// /// public MatrixFactorizationTrainer MatrixFactorization( @@ -115,13 +128,13 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + public CrossValidationResult[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null, uint? 
seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } } } diff --git a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs index 2ce8ae224f..ba7f0b2fce 100644 --- a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs +++ b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Command; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; +using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; using Microsoft.ML.Tools; @@ -20,7 +21,7 @@ using Microsoft.ML.ExperimentVisualization; #endif -namespace Microsoft.ML.Internal.Internallearn.ResultProcessor +namespace Microsoft.ML.ResultProcessor { using Float = System.Single; /// diff --git a/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs b/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs new file mode 100644 index 0000000000..83fafd8658 --- /dev/null +++ b/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.ML.Data; + +namespace Microsoft.ML.SamplesUtils +{ + /// + /// Utilities for creating console outputs in samples' code. + /// + public static class ConsoleUtils + { + /// + /// Pretty-print BinaryClassificationMetrics objects. + /// + /// Binary classification metrics. + public static void PrintMetrics(BinaryClassificationMetrics metrics) + { + Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}"); + Console.WriteLine($"AUC: {metrics.Auc:F2}"); + Console.WriteLine($"F1 Score: {metrics.F1Score:F2}"); + Console.WriteLine($"Negative Precision: {metrics.NegativePrecision:F2}"); + Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}"); + Console.WriteLine($"Positive Precision: {metrics.PositivePrecision:F2}"); + Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}"); + } + } +} diff --git a/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj b/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj index e4d6c5d504..0bdb047d42 100644 --- a/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj +++ b/src/Microsoft.ML.SamplesUtils/Microsoft.ML.SamplesUtils.csproj @@ -6,7 +6,9 @@ + + diff --git a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs index edaf2d55c5..918cd26a9d 100644 --- a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs +++ b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs @@ -7,6 +7,7 @@ using System.IO; using System.Net; using Microsoft.Data.DataView; +using Microsoft.ML; using Microsoft.ML.Data; namespace Microsoft.ML.SamplesUtils @@ -17,7 +18,12 @@ public static class DatasetUtils /// Downloads the housing dataset from the ML.NET repo. 
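// Illustrative sketch (not part of this patch): consuming the CrossValidationResult array that
// CrossValidate now returns instead of per-fold tuples. The Metrics/Model member names and the
// Rms regression metric are assumptions about the surrounding API; 'mlContext', 'dataView' and
// 'pipeline' are placeholders. (Requires System.Linq.)
var cvResults = mlContext.Recommendation().CrossValidate(dataView, pipeline, numFolds: 5, labelColumn: "Value");
var averageRms = cvResults.Select(fold => fold.Metrics.Rms).Average();
var bestModel = cvResults.OrderBy(fold => fold.Metrics.Rms).First().Model;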
/// public static string DownloadHousingRegressionDataset() - => Download("https://raw.githubusercontent.com/dotnet/machinelearning/024bd4452e1d3660214c757237a19d6123f951ca/test/data/housing.txt", "housing.txt"); + { + var fileName = "housing.txt"; + if (!File.Exists(fileName)) + Download("https://raw.githubusercontent.com/dotnet/machinelearning/024bd4452e1d3660214c757237a19d6123f951ca/test/data/housing.txt", fileName); + return fileName; + } public static IDataView LoadHousingRegressionDataset(MLContext mlContext) { @@ -81,6 +87,57 @@ public static string DownloadSentimentDataset() public static string DownloadAdultDataset() => Download("https://raw.githubusercontent.com/dotnet/machinelearning/244a8c2ac832657af282aa312d568211698790aa/test/data/adult.train", "adult.txt"); + public static IDataView LoadFeaturizedAdultDataset(MLContext mlContext) + { + // Download the file + string dataFile = DownloadAdultDataset(); + + // Define the columns to read + var reader = mlContext.Data.CreateTextLoader( + columns: new[] + { + new TextLoader.Column("age", DataKind.R4, 0), + new TextLoader.Column("workclass", DataKind.TX, 1), + new TextLoader.Column("fnlwgt", DataKind.R4, 2), + new TextLoader.Column("education", DataKind.TX, 3), + new TextLoader.Column("education-num", DataKind.R4, 4), + new TextLoader.Column("marital-status", DataKind.TX, 5), + new TextLoader.Column("occupation", DataKind.TX, 6), + new TextLoader.Column("relationship", DataKind.TX, 7), + new TextLoader.Column("ethnicity", DataKind.TX, 8), + new TextLoader.Column("sex", DataKind.TX, 9), + new TextLoader.Column("capital-gain", DataKind.R4, 10), + new TextLoader.Column("capital-loss", DataKind.R4, 11), + new TextLoader.Column("hours-per-week", DataKind.R4, 12), + new TextLoader.Column("native-country", DataKind.R4, 13), + new TextLoader.Column("IsOver50K", DataKind.BL, 14), + }, + separatorChar: ',', + hasHeader: true + ); + + // Create data featurizing pipeline + var pipeline = + // Convert categorical features to one-hot vectors + mlContext.Transforms.Categorical.OneHotEncoding("workclass") + .Append(mlContext.Transforms.Categorical.OneHotEncoding("education")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("marital-status")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("occupation")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("relationship")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("ethnicity")) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("native-country")) + // Combine all features into one feature vector + .Append(mlContext.Transforms.Concatenate("Features", "workclass", "education", "marital-status", + "occupation", "relationship", "ethnicity", "native-country", "age", "education-num", + "capital-gain", "capital-loss", "hours-per-week")) + // Min-max normalized all the features + .Append(mlContext.Transforms.Normalize("Features")); + + var data = reader.Read(dataFile); + var featurizedData = pipeline.Fit(data).Transform(data); + return featurizedData; + } + /// /// Downloads the breast cancer dataset from the ML.NET repo. 
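// Illustrative end-to-end sketch (not part of this patch) exercising the two new sample
// utilities: LoadFeaturizedAdultDataset above yields a "Features" vector plus the boolean
// "IsOver50K" label, and ConsoleUtils.PrintMetrics (added earlier in this diff) pretty-prints
// the binary metrics. The SDCA trainer choice is an assumption, and evaluating on the training
// data is only to keep the sketch short.
var mlContext = new MLContext(seed: 0);
var adultData = Microsoft.ML.SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext);
var model = mlContext.BinaryClassification.Trainers
    .StochasticDualCoordinateAscent("IsOver50K", "Features")
    .Fit(adultData);
var metrics = mlContext.BinaryClassification.Evaluate(model.Transform(adultData), "IsOver50K");
Microsoft.ML.SamplesUtils.ConsoleUtils.PrintMetrics(metrics);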
/// @@ -121,14 +178,14 @@ public static string DownloadTensorFlowSentimentModel() string remotePath = "https://github.com/dotnet/machinelearning-testdata/raw/master/Microsoft.ML.TensorFlow.TestModels/sentiment_model/"; string path = "sentiment_model"; - if(!Directory.Exists(path)) + if (!Directory.Exists(path)) Directory.CreateDirectory(path); string varPath = Path.Combine(path, "variables"); if (!Directory.Exists(varPath)) Directory.CreateDirectory(varPath); - Download(Path.Combine(remotePath, "saved_model.pb"), Path.Combine(path,"saved_model.pb")); + Download(Path.Combine(remotePath, "saved_model.pb"), Path.Combine(path, "saved_model.pb")); Download(Path.Combine(remotePath, "imdb_word_index.csv"), Path.Combine(path, "imdb_word_index.csv")); Download(Path.Combine(remotePath, "variables", "variables.data-00000-of-00001"), Path.Combine(varPath, "variables.data-00000-of-00001")); Download(Path.Combine(remotePath, "variables", "variables.index"), Path.Combine(varPath, "variables.index")); @@ -221,7 +278,7 @@ public static IEnumerable GetTopicsData() public class SampleTemperatureData { - public DateTime Date {get; set; } + public DateTime Date { get; set; } public float Temperature { get; set; } } @@ -374,7 +431,7 @@ public class BinaryLabelFloatFeatureVectorSample public float[] Features; } - public static IEnumerable GenerateBinaryLabelFloatFeatureVectorSamples(int exampleCount) + public static IEnumerable GenerateBinaryLabelFloatFeatureVectorSamples(int exampleCount) { var rnd = new Random(0); var data = new List(); @@ -405,7 +462,7 @@ public class FloatLabelFloatFeatureVectorSample public float[] Features; } - public static IEnumerable GenerateFloatLabelFloatFeatureVectorSamples(int exampleCount, double naRate = 0) + public static IEnumerable GenerateFloatLabelFloatFeatureVectorSamples(int exampleCount, double naRate = 0) { var rnd = new Random(0); var data = new List(); @@ -446,17 +503,20 @@ public class FfmExample public float[] Field2; } - public static IEnumerable GenerateFfmSamples(int exampleCount) + public static IEnumerable GenerateFfmSamples(int exampleCount) { var rnd = new Random(0); var data = new List(); for (int i = 0; i < exampleCount; ++i) { // Initialize an example with a random label and an empty feature vector. - var sample = new FfmExample() { Label = rnd.Next() % 2 == 0, + var sample = new FfmExample() + { + Label = rnd.Next() % 2 == 0, Field0 = new float[_simpleBinaryClassSampleFeatureLength], Field1 = new float[_simpleBinaryClassSampleFeatureLength], - Field2 = new float[_simpleBinaryClassSampleFeatureLength] }; + Field2 = new float[_simpleBinaryClassSampleFeatureLength] + }; // Fill feature vector according the assigned label. for (int j = 0; j < 10; ++j) { @@ -546,5 +606,49 @@ public static List GenerateRandomMulticlassClas } return examples; } + + // The following variables defines the shape of a matrix. Its shape is _synthesizedMatrixRowCount-by-_synthesizedMatrixColumnCount. + // Because in ML.NET key type's minimal value is zero, the first row index is always zero in C# data structure (e.g., MatrixColumnIndex=0 + // and MatrixRowIndex=0 in MatrixElement below specifies the value at the upper-left corner in the training matrix). If user's row index + // starts with 1, their row index 1 would be mapped to the 2nd row in matrix factorization module and their first row may contain no values. + // This behavior is also true to column index. 
+ private const int _synthesizedMatrixFirstColumnIndex = 1; + private const int _synthesizedMatrixFirstRowIndex = 1; + private const int _synthesizedMatrixColumnCount = 60; + private const int _synthesizedMatrixRowCount = 100; + + // A data structure used to encode a single value in matrix + public class MatrixElement + { + // Matrix column index is at most _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex. + [KeyType(Count = _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex)] + public uint MatrixColumnIndex; + // Matrix row index is at most _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex. + [KeyType(Count = _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex)] + public uint MatrixRowIndex; + // The value at the column MatrixColumnIndex and row MatrixRowIndex. + public float Value; + } + + // A data structure used to encode prediction result. Comparing with MatrixElement, The field Value in MatrixElement is + // renamed to Score because Score is the default name of matrix factorization's output. + public class MatrixElementForScore + { + [KeyType(Count = _synthesizedMatrixColumnCount + _synthesizedMatrixFirstColumnIndex)] + public uint MatrixColumnIndex; + [KeyType(Count = _synthesizedMatrixRowCount + _synthesizedMatrixFirstRowIndex)] + public uint MatrixRowIndex; + public float Score; + } + + // Create an in-memory matrix as a list of tuples (column index, row index, value). + public static List GetRecommendationData() + { + var dataMatrix = new List(); + for (uint i = _synthesizedMatrixFirstColumnIndex; i < _synthesizedMatrixFirstColumnIndex + _synthesizedMatrixColumnCount; ++i) + for (uint j = _synthesizedMatrixFirstRowIndex; j < _synthesizedMatrixFirstRowIndex + _synthesizedMatrixRowCount; ++j) + dataMatrix.Add(new MatrixElement() { MatrixColumnIndex = i, MatrixRowIndex = j, Value = (i + j) % 5 }); + return dataMatrix; + } } } diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index dcfe680c4d..ad81336f58 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -511,6 +511,10 @@ public FieldAwareFactorizationMachinePredictionTransformer Train(IDataView train public FieldAwareFactorizationMachinePredictionTransformer Fit(IDataView input) => Train(input); + /// + /// Schema propagation for transformers. Returns the output schema of the data, if + /// the input schema is like the one provided. 
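// Illustrative sketch (not part of this patch): feeding the synthetic matrix produced by
// GetRecommendationData above into the matrix factorization trainer. ReadFromEnumerable and
// the Recommendation() accessor are assumptions about the surrounding ML.NET API; the column
// names come from the MatrixElement fields.
var mlContext = new MLContext(seed: 0);
var dataView = mlContext.Data.ReadFromEnumerable(
    Microsoft.ML.SamplesUtils.DatasetUtils.GetRecommendationData());
var model = mlContext.Recommendation().Trainers
    .MatrixFactorization("MatrixColumnIndex", "MatrixRowIndex", labelColumn: "Value")
    .Fit(dataView);
// model.Transform(dataView) adds a "Score" column, matching the MatrixElementForScore shape above.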
+ /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs index 4bf005b25a..9aac84dae1 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs @@ -283,7 +283,7 @@ public float[] GetLatentWeights() } } - public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase, ICanSaveModel + public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase { public const string LoaderSignature = "FAFMPredXfer"; @@ -387,7 +387,7 @@ public override Schema GetOutputSchema(Schema inputSchema) /// Saves the transformer to file. /// /// The that facilitates saving to the . - public void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs index 3a88c62128..df52e48433 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs @@ -14,11 +14,11 @@ using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Model; using Microsoft.ML.Model.Onnx; using Microsoft.ML.Model.Pfa; using Microsoft.ML.Numeric; +using Microsoft.ML.Trainers; using Newtonsoft.Json.Linq; // This is for deserialization from a model repository. 
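[Editor's note] For downstream consumers, the practical effect of the `Microsoft.ML.Learners` to `Microsoft.ML.Trainers` namespace moves in this patch is a one-line using-directive change; behavior is unchanged. A minimal sketch (the caller code itself is hypothetical, not part of the change):

```csharp
// using Microsoft.ML.Learners;   // before this change
using Microsoft.ML.Trainers;      // after: LinearModelParameters, LogisticRegression, etc. resolve here

internal static class NamespaceMoveSketch
{
    // Referencing the type from its new namespace is enough to confirm the move compiles.
    internal static readonly System.Type LinearModelType = typeof(LinearModelParameters);
}
```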
@@ -36,7 +36,7 @@ "Poisson Regression Executor", PoissonRegressionModelParameters.LoaderSignature)] -namespace Microsoft.ML.Learners +namespace Microsoft.ML.Trainers { public abstract class LinearModelParameters : ModelParametersBase, IValueMapper, diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictorUtils.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictorUtils.cs index f9f64f3689..b61f443e02 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictorUtils.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictorUtils.cs @@ -13,7 +13,7 @@ using Microsoft.ML.Internal.Utilities; using Float = System.Single; -namespace Microsoft.ML.Learners +namespace Microsoft.ML.Trainers { /// /// Helper methods for linear predictors diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index c6c6602c52..ca335e033f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -14,7 +14,7 @@ using Microsoft.ML.Numeric; using Microsoft.ML.Training; -namespace Microsoft.ML.Learners +namespace Microsoft.ML.Trainers { public abstract class LbfgsTrainerBase : TrainerEstimatorBase where TTransformer : ISingleFeaturePredictionTransformer diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs index 3cdb24c7df..d5fe66bdd7 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs @@ -13,8 +13,8 @@ using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Numeric; +using Microsoft.ML.Trainers; using Microsoft.ML.Training; [assembly: LoadableClass(LogisticRegression.Summary, typeof(LogisticRegression), typeof(LogisticRegression.Options), @@ -26,7 +26,7 @@ [assembly: LoadableClass(typeof(void), typeof(LogisticRegression), null, typeof(SignatureEntryPointModule), LogisticRegression.LoadNameValue)] -namespace Microsoft.ML.Learners +namespace Microsoft.ML.Trainers { /// diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 0ed3f008da..9d809dd3fa 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -14,7 +14,6 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Model; using Microsoft.ML.Model.Onnx; using Microsoft.ML.Model.Pfa; @@ -35,7 +34,7 @@ "Multiclass LR Executor", MulticlassLogisticRegressionModelParameters.LoaderSignature)] -namespace Microsoft.ML.Learners +namespace Microsoft.ML.Trainers { /// /// diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs index d2efdc8d57..e5c2854da7 100644 --- 
a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs @@ -12,15 +12,15 @@ using Microsoft.ML.Internal.CpuMath; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Model; +using Microsoft.ML.Trainers; // This is for deserialization from a model repository. [assembly: LoadableClass(typeof(LinearModelStatistics), null, typeof(SignatureLoadModel), "Linear Model Statistics", LinearModelStatistics.LoaderSignature)] -namespace Microsoft.ML.Learners +namespace Microsoft.ML.Trainers { /// /// Represents a coefficient statistics object. @@ -166,7 +166,7 @@ internal static LinearModelStatistics Create(IHostEnvironment env, ModelLoadCont return new LinearModelStatistics(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.AssertValue(_env); _env.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index fdd32f2f20..448a078a8b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -14,7 +14,7 @@ using Microsoft.ML.Trainers.Online; using Microsoft.ML.Training; -namespace Microsoft.ML.Learners +namespace Microsoft.ML.Trainers { using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index cf7a083985..53f0f35fca 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -16,7 +16,6 @@ using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Model; using Microsoft.ML.Model.Pfa; using Microsoft.ML.Trainers; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index 5e3dd33f8b..344f9e5f6a 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -11,7 +11,6 @@ using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Model; using Microsoft.ML.Trainers; using Microsoft.ML.Training; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index bf7a88d27d..9234cd2df4 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -9,7 +9,6 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Numeric; // TODO: Check if it works properly if Averaged is set to false @@ -18,38 +17,89 @@ namespace Microsoft.ML.Trainers.Online { public abstract class AveragedLinearArguments : OnlineLinearArguments { + /// + /// Learning rate. 
+ /// [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", SortOrder = 50)] [TGUI(Label = "Learning rate", SuggestedSweeps = "0.01,0.1,0.5,1.0")] [TlcModule.SweepableDiscreteParam("LearningRate", new object[] { 0.01, 0.1, 0.5, 1.0 })] public float LearningRate = AveragedDefaultArgs.LearningRate; + /// + /// Determine whether to decrease the or not. + /// + /// + /// to decrease the as iterations progress; otherwise, . + /// Default is . + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Decrease learning rate", ShortName = "decreaselr", SortOrder = 50)] [TGUI(Label = "Decrease Learning Rate", Description = "Decrease learning rate as iterations progress")] [TlcModule.SweepableDiscreteParam("DecreaseLearningRate", new object[] { false, true })] public bool DecreaseLearningRate = AveragedDefaultArgs.DecreaseLearningRate; + /// + /// Number of examples after which weights will be reset to the current average. + /// + /// + /// Default is , which disables this feature. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Number of examples after which weights will be reset to the current average", ShortName = "numreset")] public long? ResetWeightsAfterXExamples = null; + /// + /// Determines when to update averaged weights. + /// + /// + /// to update averaged weights only when loss is nonzero. + /// to update averaged weights on every example. + /// Default is . + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Instead of updating averaged weights on every example, only update when loss is nonzero", ShortName = "lazy")] public bool DoLazyUpdates = true; + /// + /// L2 weight for regularization. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)] [TGUI(Label = "L2 Regularization Weight")] [TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)] public float L2RegularizerWeight = AveragedDefaultArgs.L2RegularizerWeight; + /// + /// Extra weight given to more recent updates. + /// + /// + /// Default is 0, i.e. no extra gain. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Extra weight given to more recent updates", ShortName = "rg")] public float RecencyGain = 0; + /// + /// Determines whether is multiplicative or additive. + /// + /// + /// means is multiplicative. + /// means is additive. + /// Default is . + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Whether Recency Gain is multiplicative (vs. additive)", ShortName = "rgm")] public bool RecencyGainMulti = false; + /// + /// Determines whether to do averaging or not. + /// + /// + /// to do averaging; otherwise, . + /// Default is . + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Do averaging?", ShortName = "avg")] public bool Averaged = true; + /// + /// The inexactness tolerance for averaging. 
+ /// [Argument(ArgumentType.AtMostOnce, HelpText = "The inexactness tolerance for averaging", ShortName = "avgtol")] - public float AveragedTolerance = (float)1e-2; + internal float AveragedTolerance = (float)1e-2; [BestFriend] internal class AveragedDefaultArgs : OnlineDefaultArgs diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index b3659974fa..d3dbdf619e 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -11,7 +11,6 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Internallearn; -using Microsoft.ML.Learners; using Microsoft.ML.Numeric; using Microsoft.ML.Trainers.Online; using Microsoft.ML.Training; @@ -25,12 +24,12 @@ namespace Microsoft.ML.Trainers.Online { - // This is an averaged perceptron classifier. - // Configurable subcomponents: - // - Loss function. By default, hinge loss (aka max-margin avgd perceptron) - // - Feature normalization. By default, rescaling between min and max values for every feature - // - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration. - /// + /// + /// This is averaged perceptron trainer. + /// + /// + /// For usage details, please see + /// public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer, LinearBinaryModelParameters> { public const string LoadNameValue = "AveragedPerceptron"; @@ -42,12 +41,21 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer + /// The custom loss. + /// [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] public ISupportClassificationLossFactory LossFunction = new HingeLoss.Arguments(); + /// + /// The calibrator for producing probabilities. Default is exponential (aka Platt) calibration. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); + /// + /// The maximum number of examples to use when training the calibrator. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public int MaxCalibrationExamples = 1000000; @@ -101,9 +109,9 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, Options options) /// The name of the feature column. /// The optional name of the weights column. /// The learning rate. - /// Wheather to decrease learning rate as iterations progress. + /// Whether to decrease learning rate as iterations progress. /// L2 Regularization Weight. - /// The number of training iteraitons. + /// The number of training iterations. internal AveragedPerceptronTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -121,7 +129,7 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, LearningRate = learningRate, DecreaseLearningRate = decreaseLearningRate, L2RegularizerWeight = l2RegularizerWeight, - NumIterations = numIterations, + NumberOfIterations = numIterations, LossFunction = new TrivialFactory(lossFunction ?? 
new HingeLoss()) }) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index 184a6554aa..2cd75e623c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -12,7 +12,6 @@ using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Numeric; using Microsoft.ML.Trainers.Online; using Microsoft.ML.Training; @@ -240,7 +239,7 @@ internal LinearSvmTrainer(IHostEnvironment env, LabelColumn = labelColumn, FeatureColumn = featureColumn, InitialWeights = weightsColumn, - NumIterations = numIterations, + NumberOfIterations = numIterations, }) { } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index 39ed31a6ec..b683aefb07 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -10,7 +10,6 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; -using Microsoft.ML.Learners; using Microsoft.ML.Numeric; using Microsoft.ML.Trainers.Online; using Microsoft.ML.Training; @@ -115,7 +114,7 @@ internal OnlineGradientDescentTrainer(IHostEnvironment env, LearningRate = learningRate, DecreaseLearningRate = decreaseLearningRate, L2RegularizerWeight = l2RegularizerWeight, - NumIterations = numIterations, + NumberOfIterations = numIterations, LabelColumn = labelColumn, FeatureColumn = featureColumn, InitialWeights = weightsColumn, diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index 662afe9a01..4dbdc8da57 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -11,7 +11,6 @@ using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Numeric; using Microsoft.ML.Training; @@ -20,27 +19,41 @@ namespace Microsoft.ML.Trainers.Online public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel { + /// + /// Number of training iterations through the data. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter", SortOrder = 50)] [TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")] [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)] - public int NumIterations = OnlineDefaultArgs.NumIterations; + public int NumberOfIterations = OnlineDefaultArgs.NumIterations; + /// + /// Initial weights and bias, comma-separated. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Initial Weights and bias, comma-separated", ShortName = "initweights")] [TGUI(NoSweep = true)] public string InitialWeights; - [Argument(ArgumentType.AtMostOnce, HelpText = "Init weights diameter", ShortName = "initwts", SortOrder = 140)] + /// + /// Initial weights and bias scale. + /// + /// + /// This property is only used if the provided value is positive and is not specified. 
+ /// The weights and bias will be randomly selected from InitialWeights * [-0.5,0.5] interval with uniform distribution. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Init weights diameter", ShortName = "initwts, initWtsDiameter", SortOrder = 140)] [TGUI(Label = "Initial Weights Scale", SuggestedSweeps = "0,0.1,0.5,1")] [TlcModule.SweepableFloatParamAttribute("InitWtsDiameter", 0.0f, 1.0f, numSteps: 5)] - public float InitWtsDiameter = 0; + public float InitialWeightsDiameter = 0; + /// + /// to shuffle data for each training iteration; otherwise, . + /// Default is . + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to shuffle for each training iteration", ShortName = "shuf")] [TlcModule.SweepableDiscreteParamAttribute("Shuffle", new object[] { false, true })] public bool Shuffle = true; - [Argument(ArgumentType.AtMostOnce, HelpText = "Size of cache when trained in Scope", ShortName = "cache")] - public int StreamingCacheSize = 1000000; - [BestFriend] internal class OnlineDefaultArgs { @@ -135,13 +148,13 @@ protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters pre Weights = new VBuffer(numFeatures, weightValues); Bias = float.Parse(weightStr[numFeatures], CultureInfo.InvariantCulture); } - else if (parent.Args.InitWtsDiameter > 0) + else if (parent.Args.InitialWeightsDiameter > 0) { var weightValues = new float[numFeatures]; for (int i = 0; i < numFeatures; i++) - weightValues[i] = parent.Args.InitWtsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); + weightValues[i] = parent.Args.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); Weights = new VBuffer(numFeatures, weightValues); - Bias = parent.Args.InitWtsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); + Bias = parent.Args.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); } else if (numFeatures <= 1000) Weights = VBufferUtils.CreateDense(numFeatures); @@ -239,9 +252,8 @@ private protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironme : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights)) { Contracts.CheckValue(args, nameof(args)); - Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive); - Contracts.CheckUserArg(args.InitWtsDiameter >= 0, nameof(args.InitWtsDiameter), UserErrorNonNegative); - Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive); + Contracts.CheckUserArg(args.NumberOfIterations > 0, nameof(args.NumberOfIterations), UserErrorPositive); + Contracts.CheckUserArg(args.InitialWeightsDiameter >= 0, nameof(args.InitialWeightsDiameter), UserErrorNonNegative); Args = args; Name = name; @@ -300,7 +312,7 @@ private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state) var cursorFactory = new FloatLabelCursor.Factory(data, cursorOpt); long numBad = 0; - while (state.Iteration < Args.NumIterations) + while (state.Iteration < Args.NumberOfIterations) { state.BeginIteration(ch); @@ -318,7 +330,7 @@ private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state) { ch.Warning( "Skipped {0} instances with missing features during training (over {1} iterations; {2} inst/iter)", - numBad, Args.NumIterations, numBad / Args.NumIterations); + numBad, Args.NumberOfIterations, numBad / Args.NumberOfIterations); } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml 
b/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml index 8e8f5dc2ba..292aeface5 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml @@ -25,44 +25,5 @@ - - - - Averaged Perceptron Binary Classifier. - - - Perceptron is a classification algorithm that makes its predictions based on a linear function. - I.e., for an instance with feature values f0, f1,..., f_D-1, , the prediction is given by the sign of sigma[0,D-1] ( w_i * f_i), where w_0, w_1,...,w_D-1 are the weights computed by the algorithm. - - Perceptron is an online algorithm, i.e., it processes the instances in the training set one at a time. - The weights are initialized to be 0, or some random values. Then, for each example in the training set, the value of sigma[0, D-1] (w_i * f_i) is computed. - If this value has the same sign as the label of the current example, the weights remain the same. If they have opposite signs, - the weights vector is updated by either subtracting or adding (if the label is negative or positive, respectively) the feature vector of the current example, - multiplied by a factor 0 < a <= 1, called the learning rate. In a generalization of this algorithm, the weights are updated by adding the feature vector multiplied by the learning rate, - and by the gradient of some loss function (in the specific case described above, the loss is hinge-loss, whose gradient is 1 when it is non-zero). - - - In Averaged Perceptron (AKA voted-perceptron), the weight vectors are stored, - together with a weight that counts the number of iterations it survived (this is equivalent to storing the weight vector after every iteration, regardless of whether it was updated or not). - The prediction is then calculated by taking the weighted average of all the sums sigma[0, D-1] (w_i * f_i) or the different weight vectors. 
- - For more information see: - Wikipedia entry for Perceptron - Large Margin Classification Using the Perceptron Algorithm - - - - - - new AveragedPerceptronBinaryClassifier - { - NumIterations = 10, - L2RegularizerWeight = 0.01f, - LossFunction = new ExpLossClassificationLossFunction() - } - - - - diff --git a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs index caf4a37166..b389fc6034 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs @@ -10,7 +10,6 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Numeric; using Microsoft.ML.Trainers; using Microsoft.ML.Training; diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs index 191c906b78..ca07f7fbe5 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs @@ -18,7 +18,6 @@ using Microsoft.ML.Internal.CpuMath; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Numeric; using Microsoft.ML.Trainers; using Microsoft.ML.Training; diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index aceabb7b82..ff49788028 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -14,7 +14,6 @@ using Microsoft.ML.Internal.CpuMath; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Numeric; using Microsoft.ML.Trainers; using Microsoft.ML.Training; diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 49b65abe89..9524c48365 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -12,7 +12,6 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.Trainers; using Microsoft.ML.Training; diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs index 10fa3b75c8..d9265513f8 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs @@ -10,7 +10,7 @@ using Microsoft.ML.Training; using Microsoft.ML.Transforms; -namespace Microsoft.ML.Learners +namespace Microsoft.ML.Trainers { public abstract class StochasticTrainerBase : TrainerEstimatorBase where TTransformer : ISingleFeaturePredictionTransformer diff --git a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs index 6442e4eb01..d78464acde 100644 --- a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs @@ -5,7 +5,6 @@ using System; using Microsoft.ML.Data; using 
Microsoft.ML.Internal.Calibration; -using Microsoft.ML.Learners; using Microsoft.ML.Trainers; using Microsoft.ML.Trainers.Online; using Microsoft.ML.Training; @@ -191,17 +190,46 @@ public static SdcaMultiClassTrainer StochasticDualCoordinateAscent(this Multicla } /// - /// Predict a target using a linear binary classification model trained with the AveragedPerceptron trainer. + /// Predict a target using a linear binary classification model trained with averaged perceptron trainer. /// + /// + /// Perceptron is a classification algorithm that makes its predictions by finding a separating hyperplane. + /// For instance, with feature values f0, f1,..., f_D-1, the prediction is given by determining what side of the hyperplane the point falls into. + /// That is the same as the sign of sigma[0, D-1] (w_i * f_i), where w_0, w_1,..., w_D-1 are the weights computed by the algorithm. + /// + /// The perceptron is an online algorithm, which means it processes the instances in the training set one at a time. + /// It starts with a set of initial weights (zero, random, or initialized from a previous learner). Then, for each example in the training set, the weighted sum of the features (sigma[0, D-1] (w_i * f_i)) is computed. + /// If this value has the same sign as the label of the current example, the weights remain the same.If they have opposite signs, + /// the weights vector is updated by either subtracting or adding (if the label is negative or positive, respectively) the feature vector of the current example, + /// multiplied by a factor 0 < a <= 1, called the learning rate.In a generalization of this algorithm, the weights are updated by adding the feature vector multiplied by the learning rate, + /// and by the gradient of some loss function (in the specific case described above, the loss is hinge-loss, whose gradient is 1 when it is non-zero). + /// + /// In Averaged Perceptron (AKA voted-perceptron), the weight vectors are stored, + /// together with a weight that counts the number of iterations it survived (this is equivalent to storing the weight vector after every iteration, regardless of whether it was updated or not). + /// The prediction is then calculated by taking the weighted average of all the sums sigma[0, D-1] (w_i * f_i) or the different weight vectors. + /// + /// For more information see Wikipedia entry for Perceptron + /// or Large Margin Classification Using the Perceptron Algorithm + /// /// The binary classification catalog trainer object. /// The name of the label column, or dependent variable. /// The features, or independent variables. - /// The custom loss. + /// The custom loss. If , hinge loss will be used resulting in max-margin averaged perceptron. /// The optional example weights. - /// The learning Rate. - /// Decrease learning rate as iterations progress. - /// L2 regularization weight. + /// Learning rate. + /// + /// to decrease the as iterations progress; otherwise, . + /// Default is . + /// + /// L2 weight for regularization. /// Number of training iterations through the data. + /// + /// + /// + /// + /// public static AveragedPerceptronTrainer AveragedPerceptron( this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, string labelColumn = DefaultColumnNames.Label, @@ -220,10 +248,18 @@ public static AveragedPerceptronTrainer AveragedPerceptron( } /// - /// Predict a target using a linear binary classification model trained with the AveragedPerceptron trainer. 
+ /// Predict a target using a linear binary classification model trained with averaged perceptron trainer using advanced options. + /// For usage details, please see /// /// The binary classification catalog trainer object. - /// Advanced arguments to the algorithm. + /// Trainer options. + /// + /// + /// + /// + /// public static AveragedPerceptronTrainer AveragedPerceptron( this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, AveragedPerceptronTrainer.Options options) { @@ -293,7 +329,7 @@ public static OnlineGradientDescentTrainer OnlineGradientDescent(this Regression } /// - /// Predict a target using a linear binary classification model trained with the trainer. + /// Predict a target using a linear binary classification model trained with the trainer. /// /// The binary classificaiton catalog trainer object. /// The label column name, or dependent variable. @@ -302,7 +338,7 @@ public static OnlineGradientDescentTrainer OnlineGradientDescent(this Regression /// Enforce non-negative weights. /// Weight of L1 regularization term. /// Weight of L2 regularization term. - /// Memory size for . Low=faster, less accurate. + /// Memory size for . Low=faster, less accurate. /// Threshold for optimizer convergence. /// /// @@ -327,7 +363,7 @@ public static LogisticRegression LogisticRegression(this BinaryClassificationCat } /// - /// Predict a target using a linear binary classification model trained with the trainer. + /// Predict a target using a linear binary classification model trained with the trainer. /// /// The binary classificaiton catalog trainer object. /// Advanced arguments to the algorithm. @@ -341,7 +377,7 @@ public static LogisticRegression LogisticRegression(this BinaryClassificationCat } /// - /// Predict a target using a linear regression model trained with the trainer. + /// Predict a target using a linear regression model trained with the trainer. /// /// The regression catalog trainer object. /// The labelColumn, or dependent variable. @@ -350,7 +386,7 @@ public static LogisticRegression LogisticRegression(this BinaryClassificationCat /// Weight of L1 regularization term. /// Weight of L2 regularization term. /// Threshold for optimizer convergence. - /// Memory size for . Low=faster, less accurate. + /// Memory size for . Low=faster, less accurate. /// Enforce non-negative weights. public static PoissonRegression PoissonRegression(this RegressionCatalog.RegressionTrainers catalog, string labelColumn = DefaultColumnNames.Label, @@ -368,7 +404,7 @@ public static PoissonRegression PoissonRegression(this RegressionCatalog.Regress } /// - /// Predict a target using a linear regression model trained with the trainer. + /// Predict a target using a linear regression model trained with the trainer. /// /// The regression catalog trainer object. /// Advanced arguments to the algorithm. @@ -382,7 +418,7 @@ public static PoissonRegression PoissonRegression(this RegressionCatalog.Regress } /// - /// Predict a target using a linear multiclass classification model trained with the trainer. + /// Predict a target using a linear multiclass classification model trained with the trainer. /// /// The . /// The labelColumn, or dependent variable. @@ -391,7 +427,7 @@ public static PoissonRegression PoissonRegression(this RegressionCatalog.Regress /// Enforce non-negative weights. /// Weight of L1 regularization term. /// Weight of L2 regularization term. - /// Memory size for . Low=faster, less accurate. + /// Memory size for . Low=faster, less accurate. 
/// Threshold for optimizer convergence. public static MulticlassLogisticRegression LogisticRegression(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, string labelColumn = DefaultColumnNames.Label, @@ -409,7 +445,7 @@ public static MulticlassLogisticRegression LogisticRegression(this MulticlassCla } /// - /// Predict a target using a linear multiclass classification model trained with the trainer. + /// Predict a target using a linear multiclass classification model trained with the trainer. /// /// The . /// Advanced arguments to the algorithm. diff --git a/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs b/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs index 5de22b6e89..7e2ed821c7 100644 --- a/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs +++ b/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs @@ -6,7 +6,6 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Calibration; -using Microsoft.ML.Learners; using Microsoft.ML.StaticPipe.Runtime; using Microsoft.ML.Trainers; @@ -20,7 +19,7 @@ namespace Microsoft.ML.StaticPipe public static class LbfgsBinaryClassificationStaticExtensions { /// - /// Predict a target using a linear binary classification model trained with the trainer. + /// Predict a target using a linear binary classification model trained with the trainer. /// /// The binary classificaiton catalog trainer object. /// The label, or dependent variable. @@ -29,7 +28,7 @@ public static class LbfgsBinaryClassificationStaticExtensions /// Enforce non-negative weights. /// Weight of L1 regularization term. /// Weight of L2 regularization term. - /// Memory size for . Low=faster, less accurate. + /// Memory size for . Low=faster, less accurate. /// Threshold for optimizer convergence. /// A delegate that is called every time the /// method is called on the @@ -66,7 +65,7 @@ public static (Scalar score, Scalar probability, Scalar pred } /// - /// Predict a target using a linear binary classification model trained with the trainer. + /// Predict a target using a linear binary classification model trained with the trainer. /// /// The binary classificaiton catalog trainer object. /// The label, or dependent variable. @@ -116,7 +115,7 @@ public static (Scalar score, Scalar probability, Scalar pred public static class LbfgsRegressionExtensions { /// - /// Predict a target using a linear regression model trained with the trainer. + /// Predict a target using a linear regression model trained with the trainer. /// /// The regression catalog trainer object. /// The label, or dependent variable. @@ -125,7 +124,7 @@ public static class LbfgsRegressionExtensions /// Enforce non-negative weights. /// Weight of L1 regularization term. /// Weight of L2 regularization term. - /// Memory size for . Low=faster, less accurate. + /// Memory size for . Low=faster, less accurate. /// Threshold for optimizer convergence. /// A delegate that is called every time the /// method is called on the @@ -162,7 +161,7 @@ public static Scalar PoissonRegression(this RegressionCatalog.RegressionT } /// - /// Predict a target using a linear regression model trained with the trainer. + /// Predict a target using a linear regression model trained with the trainer. /// /// The regression catalog trainer object. /// The label, or dependent variable. 
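[Editor's note] The remarks added above document the averaged perceptron catalog entry points; a short usage sketch may help readers of this change. Parameter and property names follow the signatures shown in this diff (including the `NumberOfIterations` rename); the `MLContext`, column names, hyper-parameter values, and omitted using directives are illustrative assumptions only.

```csharp
// using Microsoft.ML; using Microsoft.ML.Trainers.Online; (other usings omitted in this sketch)
var mlContext = new MLContext();

// Simple overload: the defaults give hinge loss, i.e. a max-margin averaged perceptron.
var simpleTrainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
    labelColumn: "Label", featureColumn: "Features", numIterations: 10);

// Advanced-options overload, using the Options type and the renamed NumberOfIterations property.
var optionsTrainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
    new AveragedPerceptronTrainer.Options
    {
        LossFunction = new HingeLoss.Arguments(),
        LearningRate = 0.1f,
        DecreaseLearningRate = false,
        NumberOfIterations = 10
    });
```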
@@ -212,7 +211,7 @@ public static Scalar PoissonRegression(this RegressionCatalog.RegressionT public static class LbfgsMulticlassExtensions { /// - /// Predict a target using a linear multiclass classification model trained with the trainer. + /// Predict a target using a linear multiclass classification model trained with the trainer. /// /// The multiclass classification catalog trainer object. /// The label, or dependent variable. @@ -221,7 +220,7 @@ public static class LbfgsMulticlassExtensions /// Enforce non-negative weights. /// Weight of L1 regularization term. /// Weight of L2 regularization term. - /// Memory size for . Low=faster, less accurate. + /// Memory size for . Low=faster, less accurate. /// Threshold for optimizer convergence. /// A delegate that is called every time the /// method is called on the @@ -258,7 +257,7 @@ public static (Vector score, Key predictedLabel) } /// - /// Predict a target using a linear multiclass classification model trained with the trainer. + /// Predict a target using a linear multiclass classification model trained with the trainer. /// /// The multiclass classification catalog trainer object. /// The label, or dependent variable. diff --git a/src/Microsoft.ML.StaticPipe/MatrixFactorizationStatic.cs b/src/Microsoft.ML.StaticPipe/MatrixFactorizationStatic.cs index 0ca0f30854..2ecdf6e438 100644 --- a/src/Microsoft.ML.StaticPipe/MatrixFactorizationStatic.cs +++ b/src/Microsoft.ML.StaticPipe/MatrixFactorizationStatic.cs @@ -32,7 +32,7 @@ public static class MatrixFactorizationExtensions public static Scalar MatrixFactorization(this RegressionCatalog.RegressionTrainers catalog, Scalar label, Key matrixColumnIndex, Key matrixRowIndex, MatrixFactorizationTrainer.Options options, - Action onFit = null) + Action onFit = null) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(matrixColumnIndex, nameof(matrixColumnIndex)); diff --git a/src/Microsoft.ML.StaticPipe/OnlineLearnerStatic.cs b/src/Microsoft.ML.StaticPipe/OnlineLearnerStatic.cs index fea4867428..38165451a6 100644 --- a/src/Microsoft.ML.StaticPipe/OnlineLearnerStatic.cs +++ b/src/Microsoft.ML.StaticPipe/OnlineLearnerStatic.cs @@ -3,8 +3,8 @@ // See the LICENSE file in the project root for more information. 
using System; -using Microsoft.ML.Learners; using Microsoft.ML.StaticPipe.Runtime; +using Microsoft.ML.Trainers; using Microsoft.ML.Trainers.Online; namespace Microsoft.ML.StaticPipe diff --git a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs index 444a9fd7bc..62bc7f5e48 100644 --- a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs @@ -4,7 +4,6 @@ using System; using Microsoft.ML.Internal.Calibration; -using Microsoft.ML.Learners; using Microsoft.ML.StaticPipe.Runtime; using Microsoft.ML.Trainers; diff --git a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs index 357246a658..53992cf4ff 100644 --- a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs @@ -49,8 +49,8 @@ public static (DataView trainSet, DataView testSet) TrainTestSplit(this stratName = indexer.Get(column); } - var (trainData, testData) = catalog.TrainTestSplit(data.AsDynamic, testFraction, stratName, seed); - return (new DataView(env, trainData, data.Shape), new DataView(env, testData, data.Shape)); + var split = catalog.TrainTestSplit(data.AsDynamic, testFraction, stratName, seed); + return (new DataView(env, split.TrainSet, data.Shape), new DataView(env, split.TestSet, data.Shape)); } /// @@ -105,9 +105,9 @@ public static (RegressionMetrics metrics, Transformer ( - x.metrics, - new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), - new DataView(env, x.scoredTestData, estimator.Shape))) + x.Metrics, + new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), + new DataView(env, x.ScoredHoldOutSet, estimator.Shape))) .ToArray(); } @@ -163,9 +163,9 @@ public static (MultiClassClassifierMetrics metrics, Transformer ( - x.metrics, - new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), - new DataView(env, x.scoredTestData, estimator.Shape))) + x.Metrics, + new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), + new DataView(env, x.ScoredHoldOutSet, estimator.Shape))) .ToArray(); } @@ -221,9 +221,9 @@ public static (BinaryClassificationMetrics metrics, Transformer ( - x.metrics, - new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), - new DataView(env, x.scoredTestData, estimator.Shape))) + x.Metrics, + new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), + new DataView(env, x.ScoredHoldOutSet, estimator.Shape))) .ToArray(); } @@ -279,9 +279,9 @@ public static (CalibratedBinaryClassificationMetrics metrics, Transformer ( - x.metrics, - new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), - new DataView(env, x.scoredTestData, estimator.Shape))) + x.Metrics, + new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), + new DataView(env, x.ScoredHoldOutSet, estimator.Shape))) .ToArray(); } } diff --git a/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs index 7e0ca2c035..469988ec7c 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs @@ -9,7 +9,7 @@ using Microsoft.ML.CommandLine; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Sweeper.Algorithms; -using Microsoft.ML.Trainers.FastTree.Internal; +using Microsoft.ML.Trainers.FastTree; using Float = System.Single; [assembly: LoadableClass(typeof(KdoSweeper), 
typeof(KdoSweeper.Arguments), typeof(SignatureSweeper), diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs index 6c1738b297..8a81844eb0 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs @@ -13,7 +13,6 @@ using Microsoft.ML.Sweeper; using Microsoft.ML.Sweeper.Algorithms; using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Float = System.Single; [assembly: LoadableClass(typeof(SmacSweeper), typeof(SmacSweeper.Arguments), typeof(SignatureSweeper), diff --git a/src/Microsoft.ML.Sweeper/ConfigRunner.cs b/src/Microsoft.ML.Sweeper/ConfigRunner.cs index 082a3fb4c9..cbe2a99900 100644 --- a/src/Microsoft.ML.Sweeper/ConfigRunner.cs +++ b/src/Microsoft.ML.Sweeper/ConfigRunner.cs @@ -12,7 +12,7 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Sweeper; -using ResultProcessorInternal = Microsoft.ML.Internal.Internallearn.ResultProcessor; +using ResultProcessorInternal = Microsoft.ML.ResultProcessor; [assembly: LoadableClass(typeof(LocalExeConfigRunner), typeof(LocalExeConfigRunner.Arguments), typeof(SignatureConfigRunner), "Local Sweep Config Runner", "Local")] diff --git a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs index 8b03b89cb9..a55f8d4647 100644 --- a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs +++ b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs @@ -8,7 +8,7 @@ using Microsoft.ML.CommandLine; using Microsoft.ML.Data; using Microsoft.ML.Sweeper; -using ResultProcessor = Microsoft.ML.Internal.Internallearn.ResultProcessor; +using ResultProcessor = Microsoft.ML.ResultProcessor; [assembly: LoadableClass(typeof(InternalSweepResultEvaluator), typeof(InternalSweepResultEvaluator.Arguments), typeof(SignatureSweepResultEvaluator), "TLC Sweep Result Evaluator", "TlcEvaluator", "Tlc")] diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index bef2bf1b1f..cfe041445e 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -667,15 +667,6 @@ internal static (TFDataType[] tfInputTypes, TFShape[] tfInputShapes) GetInputInf var tfInput = new TFOutput(session.Graph[inputs[i]]); tfInputTypes[i] = tfInput.OutputType; tfInputShapes[i] = session.Graph.GetTensorShape(tfInput); - if (tfInputShapes[i].NumDimensions != -1) - { - var newShape = new long[tfInputShapes[i].NumDimensions]; - newShape[0] = tfInputShapes[i][0] == -1 ? BatchSize : tfInputShapes[i][0]; - - for (int j = 1; j < tfInputShapes[i].NumDimensions; j++) - newShape[j] = tfInputShapes[i][j]; - tfInputShapes[i] = new TFShape(newShape); - } } return (tfInputTypes, tfInputShapes); } @@ -698,7 +689,14 @@ internal static (TFDataType[] tfOutputTypes, ColumnType[] outputTypes) GetOutput { var tfOutput = new TFOutput(session.Graph[outputs[i]]); var shape = session.Graph.GetTensorShape(tfOutput); + + // The transformer can only retreive the output as fixed length vector with shape of kind [-1, d1, d2, d3, ...] + // i.e. the first dimension (if unknown) is assumed to be batch dimension. + // If there are other dimension that are unknown the transformer will return a variable length vector. + // This is the work around in absence of reshape transformer. int[] dims = shape.NumDimensions > 0 ? shape.ToIntArray().Skip(shape[0] == -1 ? 
1 : 0).ToArray() : new[] { 0 }; + for (int j = 0; j < dims.Length; j++) + dims[j] = dims[j] == -1 ? 0 : dims[j]; var type = TensorFlowUtils.Tf2MlNetType(tfOutput.OutputType); outputTypes[i] = new VectorType(type, dims); tfOutputTypes[i] = tfOutput.OutputType; @@ -709,7 +707,7 @@ internal static (TFDataType[] tfOutputTypes, ColumnType[] outputTypes) GetOutput private protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.AssertValue(ctx); ctx.CheckAtModel(); @@ -837,14 +835,22 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) : var originalShape = _parent.TFInputShapes[i]; var shape = originalShape.ToIntArray(); - var colTypeDims = vecType.Dimensions.Prepend(1).Select(dim => (long)dim).ToArray(); + var colTypeDims = vecType.Dimensions.Select(dim => (long)dim).ToArray(); if (shape == null) _fullySpecifiedShapes[i] = new TFShape(colTypeDims); - else if (vecType.Dimensions.Length == 1) + else { // If the column is one dimension we make sure that the total size of the TF shape matches. // Compute the total size of the known dimensions of the shape. - int valCount = shape.Where(x => x > 0).Aggregate((x, y) => x * y); + int valCount = 1; + int numOfUnkDim = 0; + foreach (var s in shape) + { + if (s > 0) + valCount *= s; + else + numOfUnkDim++; + } // The column length should be divisible by this, so that the other dimensions can be integral. int typeValueCount = type.GetValueCount(); if (typeValueCount % valCount != 0) @@ -853,8 +859,8 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) : // If the shape is multi-dimensional, we should be able to create the length of the vector by plugging // in a single value for the unknown shapes. For example, if the shape is [?,?,3], then there should exist a value // d such that d*d*3 is equal to the length of the input column. - var d = originalShape.NumDimensions > 2 ? Math.Pow(typeValueCount / valCount, 1.0 / (originalShape.NumDimensions - 2)) : 1; - if (originalShape.NumDimensions > 2 && d - (int)d != 0) + var d = numOfUnkDim > 0 ? Math.Pow(typeValueCount / valCount, 1.0 / numOfUnkDim) : 0; + if (d - (int)d != 0) throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}."); // Fill in the unknown dimensions. @@ -863,21 +869,10 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) : l[ishape] = originalShape[ishape] == -1 ? (int)d : originalShape[ishape]; _fullySpecifiedShapes[i] = new TFShape(l); } - else - { - if (shape.Select((dim, j) => dim != -1 && dim != colTypeDims[j]).Any(b => b)) - throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is {vecType.ToString()}."); - - // Fill in the unknown dimensions. - var l = new long[originalShape.NumDimensions]; - for (int ishape = 0; ishape < originalShape.NumDimensions; ishape++) - l[ishape] = originalShape[ishape] == -1 ? 
colTypeDims[ishape] : originalShape[ishape]; - _fullySpecifiedShapes[i] = new TFShape(l); - } } } - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); private class OutputCache { diff --git a/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs b/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs index 8e8f7e137b..8eec3a20e3 100644 --- a/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs +++ b/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs @@ -5,12 +5,10 @@ using System.Collections.Generic; using Microsoft.ML.Core.Data; using Microsoft.ML.StaticPipe.Runtime; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; namespace Microsoft.ML.StaticPipe { - using IidBase = Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; - using SsaBase = Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; /// /// Static API extension methods for . @@ -25,7 +23,7 @@ public OutColumn( Scalar input, int confidence, int changeHistoryLength, - IidBase.MartingaleType martingale, + MartingaleType martingale, double eps) : base(new Reconciler(confidence, changeHistoryLength, martingale, eps), input) { @@ -37,13 +35,13 @@ private sealed class Reconciler : EstimatorReconciler { private readonly int _confidence; private readonly int _changeHistoryLength; - private readonly IidBase.MartingaleType _martingale; + private readonly MartingaleType _martingale; private readonly double _eps; public Reconciler( int confidence, int changeHistoryLength, - IidBase.MartingaleType martingale, + MartingaleType martingale, double eps) { _confidence = confidence; @@ -60,11 +58,11 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); var outCol = (OutColumn)toOutput[0]; - return new IidChangePointEstimator(env, + return new MLContext().Transforms.IidChangePointEstimator( outputNames[outCol], + inputNames[outCol.Input], _confidence, _changeHistoryLength, - inputNames[outCol.Input], _martingale, _eps); } @@ -77,7 +75,7 @@ public static Vector IidChangePointDetect( this Scalar input, int confidence, int changeHistoryLength, - IidBase.MartingaleType martingale = IidBase.MartingaleType.Power, + MartingaleType martingale = MartingaleType.Power, double eps = 0.1) => new OutColumn(input, confidence, changeHistoryLength, martingale, eps); } @@ -93,7 +91,7 @@ private sealed class OutColumn : Vector public OutColumn(Scalar input, int confidence, int pvalueHistoryLength, - IidBase.AnomalySide side) + AnomalySide side) : base(new Reconciler(confidence, pvalueHistoryLength, side), input) { Input = input; @@ -104,12 +102,12 @@ private sealed class Reconciler : EstimatorReconciler { private readonly int _confidence; private readonly int _pvalueHistoryLength; - private readonly IidBase.AnomalySide _side; + private readonly AnomalySide _side; public Reconciler( int confidence, int pvalueHistoryLength, - IidBase.AnomalySide side) + AnomalySide side) { _confidence = confidence; _pvalueHistoryLength = pvalueHistoryLength; @@ -124,11 +122,11 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); var outCol = (OutColumn)toOutput[0]; - return new IidSpikeEstimator(env, + return new MLContext().Transforms.IidSpikeEstimator( outputNames[outCol], + inputNames[outCol.Input], _confidence, _pvalueHistoryLength, - inputNames[outCol.Input], _side); } } @@ 
-140,7 +138,7 @@ public static Vector IidSpikeDetect( this Scalar input, int confidence, int pvalueHistoryLength, - IidBase.AnomalySide side = IidBase.AnomalySide.TwoSided + AnomalySide side = AnomalySide.TwoSided ) => new OutColumn(input, confidence, pvalueHistoryLength, side); } @@ -158,8 +156,8 @@ public OutColumn(Scalar input, int changeHistoryLength, int trainingWindowSize, int seasonalityWindowSize, - ErrorFunctionUtils.ErrorFunction errorFunction, - SsaBase.MartingaleType martingale, + ErrorFunction errorFunction, + MartingaleType martingale, double eps) : base(new Reconciler(confidence, changeHistoryLength, trainingWindowSize, seasonalityWindowSize, errorFunction, martingale, eps), input) { @@ -173,8 +171,8 @@ private sealed class Reconciler : EstimatorReconciler private readonly int _changeHistoryLength; private readonly int _trainingWindowSize; private readonly int _seasonalityWindowSize; - private readonly ErrorFunctionUtils.ErrorFunction _errorFunction; - private readonly SsaBase.MartingaleType _martingale; + private readonly ErrorFunction _errorFunction; + private readonly MartingaleType _martingale; private readonly double _eps; public Reconciler( @@ -182,8 +180,8 @@ public Reconciler( int changeHistoryLength, int trainingWindowSize, int seasonalityWindowSize, - ErrorFunctionUtils.ErrorFunction errorFunction, - SsaBase.MartingaleType martingale, + ErrorFunction errorFunction, + MartingaleType martingale, double eps) { _confidence = confidence; @@ -203,13 +201,13 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); var outCol = (OutColumn)toOutput[0]; - return new SsaChangePointEstimator(env, + return new MLContext().Transforms.SsaChangePointEstimator( outputNames[outCol], + inputNames[outCol.Input], _confidence, _changeHistoryLength, _trainingWindowSize, _seasonalityWindowSize, - inputNames[outCol.Input], _errorFunction, _martingale, _eps); @@ -225,8 +223,8 @@ public static Vector SsaChangePointDetect( int changeHistoryLength, int trainingWindowSize, int seasonalityWindowSize, - ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference, - SsaBase.MartingaleType martingale = SsaBase.MartingaleType.Power, + ErrorFunction errorFunction = ErrorFunction.SignedDifference, + MartingaleType martingale = MartingaleType.Power, double eps = 0.1) => new OutColumn(input, confidence, changeHistoryLength, trainingWindowSize, seasonalityWindowSize, errorFunction, martingale, eps); } @@ -244,8 +242,8 @@ public OutColumn(Scalar input, int pvalueHistoryLength, int trainingWindowSize, int seasonalityWindowSize, - SsaBase.AnomalySide side, - ErrorFunctionUtils.ErrorFunction errorFunction) + AnomalySide side, + ErrorFunction errorFunction) : base(new Reconciler(confidence, pvalueHistoryLength, trainingWindowSize, seasonalityWindowSize, side, errorFunction), input) { Input = input; @@ -258,16 +256,16 @@ private sealed class Reconciler : EstimatorReconciler private readonly int _pvalueHistoryLength; private readonly int _trainingWindowSize; private readonly int _seasonalityWindowSize; - private readonly SsaBase.AnomalySide _side; - private readonly ErrorFunctionUtils.ErrorFunction _errorFunction; + private readonly AnomalySide _side; + private readonly ErrorFunction _errorFunction; public Reconciler( int confidence, int pvalueHistoryLength, int trainingWindowSize, int seasonalityWindowSize, - SsaBase.AnomalySide side, - ErrorFunctionUtils.ErrorFunction errorFunction) + AnomalySide side, + ErrorFunction 
errorFunction) { _confidence = confidence; _pvalueHistoryLength = pvalueHistoryLength; @@ -285,13 +283,13 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); var outCol = (OutColumn)toOutput[0]; - return new SsaSpikeEstimator(env, + return new MLContext().Transforms.SsaSpikeEstimator( outputNames[outCol], + inputNames[outCol.Input], _confidence, _pvalueHistoryLength, _trainingWindowSize, _seasonalityWindowSize, - inputNames[outCol.Input], _side, _errorFunction); } @@ -306,8 +304,8 @@ public static Vector SsaSpikeDetect( int changeHistoryLength, int trainingWindowSize, int seasonalityWindowSize, - SsaBase.AnomalySide side = SsaBase.AnomalySide.TwoSided, - ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference + AnomalySide side = AnomalySide.TwoSided, + ErrorFunction errorFunction = ErrorFunction.SignedDifference ) => new OutColumn(input, confidence, changeHistoryLength, trainingWindowSize, seasonalityWindowSize, side, errorFunction); } diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs index 4a3593ca14..cf5e602cf9 100644 --- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs +++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs @@ -11,22 +11,21 @@ using Microsoft.ML.Internal.CpuMath; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeries; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; [assembly: LoadableClass(typeof(AdaptiveSingularSpectrumSequenceModeler), typeof(AdaptiveSingularSpectrumSequenceModeler), null, typeof(SignatureLoadModel), "SSA Sequence Modeler", AdaptiveSingularSpectrumSequenceModeler.LoaderSignature)] -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// This class implements basic Singular Spectrum Analysis (SSA) model for modeling univariate time-series. /// For the details of the model, refer to http://arxiv.org/pdf/1206.6910.pdf. 
/// - public sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase + internal sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase { - public const string LoaderSignature = "SSAModel"; + internal const string LoaderSignature = "SSAModel"; public enum RankSelectionMethod { @@ -35,7 +34,7 @@ public enum RankSelectionMethod Fast } - public sealed class SsaForecastResult : ForecastResultBase + internal sealed class SsaForecastResult : ForecastResultBase { public VBuffer ForecastStandardDeviation; public VBuffer UpperBound; @@ -465,7 +464,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo _xSmooth = new CpuAlignedVector(_windowSize, CpuMathUtils.GetVectorAlignment()); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -1517,7 +1516,7 @@ internal override SequenceModelerBase Clone() /// /// The input forecast object /// The confidence level in [0, 1) - public static void ComputeForecastIntervals(ref SsaForecastResult forecast, Single confidenceLevel = 0.95f) + internal static void ComputeForecastIntervals(ref SsaForecastResult forecast, Single confidenceLevel = 0.95f) { Contracts.CheckParam(0 <= confidenceLevel && confidenceLevel < 1, nameof(confidenceLevel), "The confidence level must be in [0, 1)."); Contracts.CheckValue(forecast, nameof(forecast)); diff --git a/src/Microsoft.ML.TimeSeries/EigenUtils.cs b/src/Microsoft.ML.TimeSeries/EigenUtils.cs index 8d0b6794e5..fc8224b94e 100644 --- a/src/Microsoft.ML.TimeSeries/EigenUtils.cs +++ b/src/Microsoft.ML.TimeSeries/EigenUtils.cs @@ -8,10 +8,10 @@ using Microsoft.ML.Internal.Utilities; using Float = System.Single; -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { //REVIEW: improve perf with SSE and Multithreading - public static class EigenUtils + internal static class EigenUtils { //Compute the Eigen-decomposition of a symmetric matrix //REVIEW: use matrix/vector operations, not Array Math diff --git a/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs b/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs index 15e7a84a8f..29653b2620 100644 --- a/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs +++ b/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs @@ -10,25 +10,26 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; [assembly: LoadableClass(ExponentialAverageTransform.Summary, typeof(ExponentialAverageTransform), typeof(ExponentialAverageTransform.Arguments), typeof(SignatureDataTransform), ExponentialAverageTransform.UserName, ExponentialAverageTransform.LoaderSignature, ExponentialAverageTransform.ShortName)] [assembly: LoadableClass(ExponentialAverageTransform.Summary, typeof(ExponentialAverageTransform), null, typeof(SignatureLoadDataTransform), ExponentialAverageTransform.UserName, ExponentialAverageTransform.LoaderSignature)] -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// ExponentialAverageTransform is a weighted average of the values: ExpAvg(y_t) = a * y_t + (1-a) * ExpAvg(y_(t-1)). 
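An aside on the summary comment just above: the recurrence ExpAvg(y_t) = a * y_t + (1-a) * ExpAvg(y_(t-1)) is easy to sanity-check outside the transform. Below is a minimal standalone sketch; treating the transform's Decay argument as the coefficient a is an assumption of this sketch, not something the patch states.

```csharp
using System;

// Minimal sketch of the exponential-average recurrence documented above.
// Assumption: Decay plays the role of 'a' in ExpAvg(y_t) = a*y_t + (1-a)*ExpAvg(y_(t-1)).
class ExpAvgSketch
{
    static void Main()
    {
        const float a = 0.9f;          // matches the Decay default in the transform's Arguments
        float? expAvg = null;
        foreach (var y in new[] { 1f, 2f, 4f, 8f })
        {
            expAvg = expAvg == null ? y : a * y + (1 - a) * expAvg.Value;
            Console.WriteLine(expAvg); // 1, 1.9, 3.79, 7.579
        }
    }
}
```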
/// - public sealed class ExponentialAverageTransform : SequentialTransformBase + internal sealed class ExponentialAverageTransform : SequentialTransformBase { public const string Summary = "Applies a Exponential average on a time series."; public const string LoaderSignature = "ExpAverageTransform"; public const string UserName = "Exponential Average Transform"; public const string ShortName = "ExpAvg"; +#pragma warning disable 0649 public sealed class Arguments : TransformInputBase { [Argument(ArgumentType.Required, HelpText = "The name of the source column", ShortName = "src", @@ -43,6 +44,7 @@ public sealed class Arguments : TransformInputBase ShortName = "d", SortOrder = 4)] public Single Decay = 0.9f; } +#pragma warning restore 0649 private static VersionInfo GetVersionInfo() { @@ -77,7 +79,7 @@ public ExponentialAverageTransform(IHostEnvironment env, ModelLoadContext ctx, I Host.CheckDecode(WindowSize == 1); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(WindowSize >= 1); @@ -89,7 +91,7 @@ public override void Save(ModelSaveContext ctx) // // Single _decay - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write(_decay); } diff --git a/src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs b/src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs new file mode 100644 index 0000000000..e013f7c671 --- /dev/null +++ b/src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs @@ -0,0 +1,87 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using Microsoft.ML.Transforms.TimeSeries; + +namespace Microsoft.ML +{ + public static class TimeSeriesCatalog + { + /// + /// Create a new instance of + /// + /// The transform's catalog. + /// Name of the column resulting from the transformation of . + /// Column is a vector of type double and size 4. The vector contains Alert, Raw Score, P-Value and Martingale score as first four values. + /// Name of column to transform. If set to , the value of the will be used as source. + /// The confidence for change point detection in the range [0, 100]. + /// The length of the sliding window on p-values for computing the martingale score. + /// The martingale used for scoring. + /// The epsilon parameter for the Power martingale. + public static IidChangePointEstimator IidChangePointEstimator(this TransformsCatalog catalog, string outputColumnName, string inputColumnName, + int confidence, int changeHistoryLength, MartingaleType martingale = MartingaleType.Power, double eps = 0.1) + => new IidChangePointEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, confidence, changeHistoryLength, inputColumnName, martingale, eps); + + /// + /// Create a new instance of + /// + /// The transform's catalog. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. + /// The confidence for spike detection in the range [0, 100]. + /// The size of the sliding window for computing the p-value. + /// The argument that determines whether to detect positive or negative anomalies, or both. 
+ public static IidSpikeEstimator IidSpikeEstimator(this TransformsCatalog catalog, string outputColumnName, string inputColumnName, + int confidence, int pvalueHistoryLength, AnomalySide side = AnomalySide.TwoSided) + => new IidSpikeEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, confidence, pvalueHistoryLength, inputColumnName, side); + + /// + /// Create a new instance of + /// + /// The transform's catalog. + /// Name of the column resulting from the transformation of . + /// Column is a vector of type double and size 4. The vector contains Alert, Raw Score, P-Value and Martingale score as first four values. + /// Name of column to transform. If set to , the value of the will be used as source. + /// The confidence for change point detection in the range [0, 100]. + /// The number of points from the beginning of the sequence used for training. + /// The size of the sliding window for computing the p-value. + /// An upper bound on the largest relevant seasonality in the input time-series. + /// The function used to compute the error between the expected and the observed value. + /// The martingale used for scoring. + /// The epsilon parameter for the Power martingale. + public static SsaChangePointEstimator SsaChangePointEstimator(this TransformsCatalog catalog, string outputColumnName, string inputColumnName, + int confidence, int changeHistoryLength, int trainingWindowSize, int seasonalityWindowSize, ErrorFunction errorFunction = ErrorFunction.SignedDifference, + MartingaleType martingale = MartingaleType.Power, double eps = 0.1) + => new SsaChangePointEstimator(CatalogUtils.GetEnvironment(catalog), new SsaChangePointDetector.Options + { + Name = outputColumnName, + Source = inputColumnName ?? outputColumnName, + Confidence = confidence, + ChangeHistoryLength = changeHistoryLength, + TrainingWindowSize = trainingWindowSize, + SeasonalWindowSize = seasonalityWindowSize, + Martingale = martingale, + PowerMartingaleEpsilon = eps, + ErrorFunction = errorFunction + }); + + /// + /// Create a new instance of + /// + /// The transform's catalog. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. + /// The confidence for spike detection in the range [0, 100]. + /// The size of the sliding window for computing the p-value. + /// The number of points from the beginning of the sequence used for training. + /// An upper bound on the largest relevant seasonality in the input time-series. + /// The vector contains Alert, Raw Score, P-Value as first three values. + /// The argument that determines whether to detect positive or negative anomalies, or both. + /// The function used to compute the error between the expected and the observed value. 
+ public static SsaSpikeEstimator SsaSpikeEstimator(this TransformsCatalog catalog, string outputColumnName, string inputColumnName, int confidence, int pvalueHistoryLength, + int trainingWindowSize, int seasonalityWindowSize, AnomalySide side = AnomalySide.TwoSided, ErrorFunction errorFunction = ErrorFunction.SignedDifference) + => new SsaSpikeEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, confidence, pvalueHistoryLength, trainingWindowSize, seasonalityWindowSize, inputColumnName, side, errorFunction); + } +} diff --git a/src/Microsoft.ML.TimeSeries/FftUtils.cs b/src/Microsoft.ML.TimeSeries/FftUtils.cs index a02575c031..28a19a89ed 100644 --- a/src/Microsoft.ML.TimeSeries/FftUtils.cs +++ b/src/Microsoft.ML.TimeSeries/FftUtils.cs @@ -6,7 +6,7 @@ using System.Runtime.InteropServices; using System.Security; -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// The utility functions that wrap the native Discrete Fast Fourier Transform functionality from Intel MKL. diff --git a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs index eeff5db3d2..25cf99347b 100644 --- a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs @@ -5,107 +5,193 @@ using System; using System.IO; using Microsoft.Data.DataView; +using Microsoft.ML.Core.Data; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeries; -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// - /// This transform computes the p-values and martingale scores for a supposedly i.i.d input sequence of floats. In other words, it assumes + /// The is the wrapper to that computes the p-values and martingale scores for a supposedly i.i.d input sequence of floats. In other words, it assumes /// the input sequence represents the raw anomaly score which might have been computed via another process. /// - public abstract class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase + public class IidAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveModel { - public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env) - : base(args, name, env) + /// + /// Whether a call to should succeed, on an + /// appropriate schema. + /// + public bool IsRowToRowMapper => InternalTransform.IsRowToRowMapper; + + /// + /// Creates a clone of the transfomer. Used for taking the snapshot of the state. + /// + /// + IStatefulTransformer IStatefulTransformer.Clone() => InternalTransform.Clone(); + + /// + /// Schema propagation for transformers. + /// Returns the output schema of the data, if the input schema is like the one provided. + /// + public Schema GetOutputSchema(Schema inputSchema) => InternalTransform.GetOutputSchema(inputSchema); + + /// + /// Constructs a row-to-row mapper based on an input schema. If + /// is false, then an exception should be thrown. If the input schema is in any way + /// unsuitable for constructing the mapper, an exception should likewise be thrown. + /// + /// The input schema for which we should get the mapper. + /// The row to row mapper. + public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) => InternalTransform.GetRowToRowMapper(inputSchema); + + /// + /// Same as but also supports mechanism to save the state. + /// + /// The input schema for which we should get the mapper. 
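Pausing the diff briefly: the four TimeSeriesCatalog extension methods added in ExtensionsCatalog.cs above are the new public entry points. The sketch below shows one hedged way to call them; the Value/Prediction column names, the toy data, and the ReadFromEnumerable call are illustrative assumptions (data-loading names changed across pre-1.0 releases), while the IidSpikeEstimator signature itself comes from the patch.

```csharp
using System.Linq;
using Microsoft.ML;

class TimeSeriesCatalogSketch
{
    private sealed class Input { public float Value { get; set; } }

    static void Main()
    {
        var ml = new MLContext();

        // Toy series: flat values with a spike at the end.
        var data = Enumerable.Range(0, 20).Select(i => new Input { Value = i == 19 ? 100f : 5f });
        var dataView = ml.Data.ReadFromEnumerable(data); // assumed pre-1.0 name; may differ by version

        // Detect spikes on "Value". Per the XML docs added above, the output vector's
        // first three slots are Alert, Raw Score and P-Value.
        var pipeline = ml.Transforms.IidSpikeEstimator(
            outputColumnName: "Prediction", inputColumnName: "Value",
            confidence: 95, pvalueHistoryLength: 10);

        var transformed = pipeline.Fit(dataView).Transform(dataView);
    }
}
```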
+ /// The row to row mapper. + public IRowToRowMapper GetStatefulRowToRowMapper(Schema inputSchema) => ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema); + + /// + /// Take the data in, make transformations, output the data. + /// Note that 's are lazy, so no actual transformations happen here, just schema validation. + /// + public IDataView Transform(IDataView input) => InternalTransform.Transform(input); + + /// + /// For saving a model into a repository. + /// + public virtual void Save(ModelSaveContext ctx) { - InitialWindowSize = 0; - StateRef = new State(); - StateRef.InitState(WindowSize, InitialWindowSize, this, Host); + InternalTransform.SaveThis(ctx); } - public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name) - : base(env, ctx, name) - { - Host.CheckDecode(InitialWindowSize == 0); - StateRef = new State(ctx.Reader); - StateRef.InitState(this, Host); - } - - public override Schema GetOutputSchema(Schema inputSchema) - { - Host.CheckValue(inputSchema, nameof(inputSchema)); + /// + /// Creates a row mapper from Schema. + /// + internal IStatefulRowMapper MakeRowMapper(Schema schema) => InternalTransform.MakeRowMapper(schema); - if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName); + /// + /// Creates an IDataTransform from an IDataView. + /// + internal IDataTransform MakeDataTransform(IDataView input) => InternalTransform.MakeDataTransform(input); - var colType = inputSchema[col].Type; - if (colType != NumberType.R4) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, "float", colType.ToString()); + internal IidAnomalyDetectionBase InternalTransform; - return Transform(new EmptyDataView(Host, inputSchema)).Schema; + internal IidAnomalyDetectionBaseWrapper(ArgumentsBase args, string name, IHostEnvironment env) + { + InternalTransform = new IidAnomalyDetectionBase(args, name, env, this); } - public override void Save(ModelSaveContext ctx) + internal IidAnomalyDetectionBaseWrapper(IHostEnvironment env, ModelLoadContext ctx, string name) { - ctx.CheckAtModel(); - Host.Assert(InitialWindowSize == 0); - base.Save(ctx); - - // *** Binary format *** - // - // State: StateRef - StateRef.Save(ctx.Writer); + InternalTransform = new IidAnomalyDetectionBase(env, ctx, name, this); } - public sealed class State : AnomalyDetectionStateBase + /// + /// This transform computes the p-values and martingale scores for a supposedly i.i.d input sequence of floats. In other words, it assumes + /// the input sequence represents the raw anomaly score which might have been computed via another process. 
+ /// + internal class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase { - public State() - { - } + internal IidAnomalyDetectionBaseWrapper Parent; - internal State(BinaryReader reader) : base(reader) + public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env, IidAnomalyDetectionBaseWrapper parent) + : base(args, name, env) { - WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); - InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); + InitialWindowSize = 0; + StateRef = new State(); + StateRef.InitState(WindowSize, InitialWindowSize, this, Host); + Parent = parent; } - internal override void Save(BinaryWriter writer) + public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name, IidAnomalyDetectionBaseWrapper parent) + : base(env, ctx, name) { - base.Save(writer); - TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer); - TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer); + Host.CheckDecode(InitialWindowSize == 0); + StateRef = new State(ctx.Reader); + StateRef.InitState(this, Host); + Parent = parent; } - private protected override void CloneCore(StateBase state) + public override Schema GetOutputSchema(Schema inputSchema) { - base.CloneCore(state); - Contracts.Assert(state is State); - var stateLocal = state as State; - stateLocal.WindowedBuffer = WindowedBuffer.Clone(); - stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone(); - } + Host.CheckValue(inputSchema, nameof(inputSchema)); - private protected override void LearnStateFromDataCore(FixedSizeQueue data) - { - // This method is empty because there is no need for initial tuning for this transform. + if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName); + + var colType = inputSchema[col].Type; + if (colType != NumberType.R4) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, NumberType.R4.ToString(), colType.ToString()); + + return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - private protected override void InitializeAnomalyDetector() + private protected override void SaveModel(ModelSaveContext ctx) { - // This method is empty because there is no need for any extra initialization for this transform. + Parent.Save(ctx); } - private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration) + internal void SaveThis(ModelSaveContext ctx) { - // This transform treats the input sequenence as the raw anomaly score. 
- return (double)input; + ctx.CheckAtModel(); + Host.Assert(InitialWindowSize == 0); + base.SaveModel(ctx); + + // *** Binary format *** + // + // State: StateRef + StateRef.Save(ctx.Writer); } - public override void Consume(float value) + internal sealed class State : AnomalyDetectionStateBase { + public State() + { + } + + internal State(BinaryReader reader) : base(reader) + { + WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); + InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); + } + + internal override void Save(BinaryWriter writer) + { + base.Save(writer); + TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer); + TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer); + } + + private protected override void CloneCore(State state) + { + base.CloneCore(state); + Contracts.Assert(state is State); + var stateLocal = state as State; + stateLocal.WindowedBuffer = WindowedBuffer.Clone(); + stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone(); + } + + private protected override void LearnStateFromDataCore(FixedSizeQueue data) + { + // This method is empty because there is no need for initial tuning for this transform. + } + + private protected override void InitializeAnomalyDetector() + { + // This method is empty because there is no need for any extra initialization for this transform. + } + + private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration) + { + // This transform treats the input sequenence as the raw anomaly score. + return (double)input; + } + + public override void Consume(float value) + { + } } } } diff --git a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs index cb0874e2ce..9f5fc85a2d 100644 --- a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs @@ -12,11 +12,9 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeries; -using Microsoft.ML.TimeSeriesProcessing; -using static Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; +using Microsoft.ML.Transforms.TimeSeries; -[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), typeof(IidChangePointDetector.Options), typeof(SignatureDataTransform), IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature, IidChangePointDetector.ShortName)] [assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), null, typeof(SignatureLoadDataTransform), @@ -28,19 +26,19 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(IidChangePointDetector), null, typeof(SignatureLoadRowMapper), IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)] -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// This class implements the change point detector transform for an i.i.d. sequence based on adaptive kernel density estimation and martingales. 
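As a usage note for the class summary above: through the new catalog extension, this detector is reached as sketched below, continuing the earlier TimeSeriesCatalogSketch (column names remain illustrative assumptions).

```csharp
// Continues the earlier sketch: `ml` is an MLContext and `dataView` has a float column
// named "Value". MartingaleType is the enum moved to Microsoft.ML.Transforms.TimeSeries in this patch.
var changePoints = ml.Transforms.IidChangePointEstimator(
        outputColumnName: "Prediction", inputColumnName: "Value",
        confidence: 95, changeHistoryLength: 10,
        martingale: MartingaleType.Power, eps: 0.1)
    .Fit(dataView)
    .Transform(dataView);
// Per the ExtensionsCatalog docs, "Prediction" carries Alert, Raw Score, P-Value and
// Martingale score as its first four values.
```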
/// - public sealed class IidChangePointDetector : IidAnomalyDetectionBase + public sealed class IidChangePointDetector : IidAnomalyDetectionBaseWrapper, IStatefulTransformer { internal const string Summary = "This transform detects the change-points in an i.i.d. sequence using adaptive kernel density estimation and martingales."; - public const string LoaderSignature = "IidChangePointDetector"; - public const string UserName = "IID Change Point Detection"; - public const string ShortName = "ichgpnt"; + internal const string LoaderSignature = "IidChangePointDetector"; + internal const string UserName = "IID Change Point Detection"; + internal const string ShortName = "ichgpnt"; - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Required, HelpText = "The name of the source column.", ShortName = "src", SortOrder = 1, Purpose = SpecialPurpose.ColumnName)] @@ -59,7 +57,7 @@ public sealed class Arguments : TransformInputBase public double Confidence = 95; [Argument(ArgumentType.AtMostOnce, HelpText = "The martingale used for scoring.", ShortName = "mart", SortOrder = 103)] - public MartingaleType Martingale = SequentialAnomalyDetectionTransformBase.MartingaleType.Power; + public MartingaleType Martingale = MartingaleType.Power; [Argument(ArgumentType.AtMostOnce, HelpText = "The epsilon parameter for the Power martingale.", ShortName = "eps", SortOrder = 104)] @@ -68,27 +66,27 @@ public sealed class Arguments : TransformInputBase private sealed class BaseArguments : ArgumentsBase { - public BaseArguments(Arguments args) + public BaseArguments(Options options) { - Source = args.Source; - Name = args.Name; - Side = SequentialAnomalyDetectionTransformBase.AnomalySide.TwoSided; - WindowSize = args.ChangeHistoryLength; - Martingale = args.Martingale; - PowerMartingaleEpsilon = args.PowerMartingaleEpsilon; - AlertOn = SequentialAnomalyDetectionTransformBase.AlertingScore.MartingaleScore; + Source = options.Source; + Name = options.Name; + Side = AnomalySide.TwoSided; + WindowSize = options.ChangeHistoryLength; + Martingale = options.Martingale; + PowerMartingaleEpsilon = options.PowerMartingaleEpsilon; + AlertOn = AlertingScore.MartingaleScore; } public BaseArguments(IidChangePointDetector transform) { - Source = transform.InputColumnName; - Name = transform.OutputColumnName; + Source = transform.InternalTransform.InputColumnName; + Name = transform.InternalTransform.OutputColumnName; Side = AnomalySide.TwoSided; - WindowSize = transform.WindowSize; - Martingale = transform.Martingale; - PowerMartingaleEpsilon = transform.PowerMartingaleEpsilon; + WindowSize = transform.InternalTransform.WindowSize; + Martingale = transform.InternalTransform.Martingale; + PowerMartingaleEpsilon = transform.InternalTransform.PowerMartingaleEpsilon; AlertOn = AlertingScore.MartingaleScore; - AlertThreshold = transform.AlertThreshold; + AlertThreshold = transform.InternalTransform.AlertThreshold; } } @@ -103,39 +101,39 @@ private static VersionInfo GetVersionInfo() } // Factory method for SignatureDataTransform. 
- private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - return new IidChangePointDetector(env, args).MakeDataTransform(input); + return new IidChangePointDetector(env, options).MakeDataTransform(input); } - internal override IStatefulTransformer Clone() + IStatefulTransformer IStatefulTransformer.Clone() { var clone = (IidChangePointDetector)MemberwiseClone(); - clone.StateRef = (State)clone.StateRef.Clone(); - clone.StateRef.InitState(clone, Host); + clone.InternalTransform.StateRef = (IidAnomalyDetectionBase.State)clone.InternalTransform.StateRef.Clone(); + clone.InternalTransform.StateRef.InitState(clone.InternalTransform, InternalTransform.Host); return clone; } - internal IidChangePointDetector(IHostEnvironment env, Arguments args) - : base(new BaseArguments(args), LoaderSignature, env) + internal IidChangePointDetector(IHostEnvironment env, Options options) + : base(new BaseArguments(options), LoaderSignature, env) { - switch (Martingale) + switch (InternalTransform.Martingale) { case MartingaleType.None: - AlertThreshold = Double.MaxValue; + InternalTransform.AlertThreshold = Double.MaxValue; break; case MartingaleType.Power: - AlertThreshold = Math.Exp(WindowSize * LogPowerMartigaleBettingFunc(1 - args.Confidence / 100, PowerMartingaleEpsilon)); + InternalTransform.AlertThreshold = Math.Exp(InternalTransform.WindowSize * InternalTransform.LogPowerMartigaleBettingFunc(1 - options.Confidence / 100, InternalTransform.PowerMartingaleEpsilon)); break; case MartingaleType.Mixture: - AlertThreshold = Math.Exp(WindowSize * LogMixtureMartigaleBettingFunc(1 - args.Confidence / 100)); + InternalTransform.AlertThreshold = Math.Exp(InternalTransform.WindowSize * InternalTransform.LogMixtureMartigaleBettingFunc(1 - options.Confidence / 100)); break; default: - throw Host.ExceptParam(nameof(args.Martingale), + throw InternalTransform.Host.ExceptParam(nameof(options.Martingale), "The martingale type can be only (0) None, (1) Power or (2) Mixture."); } } @@ -166,8 +164,8 @@ internal IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx) // *** Binary format *** // - Host.CheckDecode(ThresholdScore == AlertingScore.MartingaleScore); - Host.CheckDecode(Side == AnomalySide.TwoSided); + InternalTransform.Host.CheckDecode(InternalTransform.ThresholdScore == AlertingScore.MartingaleScore); + InternalTransform.Host.CheckDecode(InternalTransform.Side == AnomalySide.TwoSided); } private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector transform) @@ -177,12 +175,12 @@ private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector tran public override void Save(ModelSaveContext ctx) { - Host.CheckValue(ctx, nameof(ctx)); + InternalTransform.Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - Host.Assert(ThresholdScore == AlertingScore.MartingaleScore); - Host.Assert(Side == AnomalySide.TwoSided); + InternalTransform.Host.Assert(InternalTransform.ThresholdScore == AlertingScore.MartingaleScore); + InternalTransform.Host.Assert(InternalTransform.Side == AnomalySide.TwoSided); // *** Binary format *** // @@ -219,10 +217,10 @@ public sealed class IidChangePointEstimator : TrivialEstimatorName of column to transform. 
If set to , the value of the will be used as source. /// The martingale used for scoring. /// The epsilon parameter for the Power martingale. - public IidChangePointEstimator(IHostEnvironment env, string outputColumnName, int confidence, + internal IidChangePointEstimator(IHostEnvironment env, string outputColumnName, int confidence, int changeHistoryLength, string inputColumnName, MartingaleType martingale = MartingaleType.Power, double eps = 0.1) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)), - new IidChangePointDetector(env, new IidChangePointDetector.Arguments + new IidChangePointDetector(env, new IidChangePointDetector.Options { Name = outputColumnName, Source = inputColumnName ?? outputColumnName, @@ -234,28 +232,32 @@ public IidChangePointEstimator(IHostEnvironment env, string outputColumnName, in { } - public IidChangePointEstimator(IHostEnvironment env, IidChangePointDetector.Arguments args) + internal IidChangePointEstimator(IHostEnvironment env, IidChangePointDetector.Options options) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)), - new IidChangePointDetector(env, args)) + new IidChangePointDetector(env, options)) { } + /// + /// Returns the of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. + /// public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - if (!inputSchema.TryFindColumn(Transformer.InputColumnName, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName); + if (!inputSchema.TryFindColumn(Transformer.InternalTransform.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InternalTransform.InputColumnName); if (col.ItemType != NumberType.R4) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName, "float", col.GetTypeString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InternalTransform.InputColumnName, "float", col.GetTypeString()); var metadata = new List() { new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) }; var resultDic = inputSchema.ToDictionary(x => x.Name); - resultDic[Transformer.OutputColumnName] = new SchemaShape.Column( - Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); + resultDic[Transformer.InternalTransform.OutputColumnName] = new SchemaShape.Column( + Transformer.InternalTransform.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); return new SchemaShape(resultDic.Values); } diff --git a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs index 20cecfc285..813043606a 100644 --- a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs @@ -11,11 +11,9 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeries; -using Microsoft.ML.TimeSeriesProcessing; -using static Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; +using Microsoft.ML.Transforms.TimeSeries; -[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), typeof(IidSpikeDetector.Arguments), 
typeof(SignatureDataTransform), +[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), typeof(IidSpikeDetector.Options), typeof(SignatureDataTransform), IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature, IidSpikeDetector.ShortName)] [assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), null, typeof(SignatureLoadDataTransform), @@ -27,19 +25,19 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(IidSpikeDetector), null, typeof(SignatureLoadRowMapper), IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature)] -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// This class implements the spike detector transform for an i.i.d. sequence based on adaptive kernel density estimation. /// - public sealed class IidSpikeDetector : IidAnomalyDetectionBase + public sealed class IidSpikeDetector : IidAnomalyDetectionBaseWrapper, IStatefulTransformer { internal const string Summary = "This transform detects the spikes in a i.i.d. sequence using adaptive kernel density estimation."; - public const string LoaderSignature = "IidSpikeDetector"; - public const string UserName = "IID Spike Detection"; - public const string ShortName = "ispike"; + internal const string LoaderSignature = "IidSpikeDetector"; + internal const string UserName = "IID Spike Detection"; + internal const string ShortName = "ispike"; - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Required, HelpText = "The name of the source column.", ShortName = "src", SortOrder = 1, Purpose = SpecialPurpose.ColumnName)] @@ -64,24 +62,24 @@ public sealed class Arguments : TransformInputBase private sealed class BaseArguments : ArgumentsBase { - public BaseArguments(Arguments args) + public BaseArguments(Options options) { - Source = args.Source; - Name = args.Name; - Side = args.Side; - WindowSize = args.PvalueHistoryLength; - AlertThreshold = 1 - args.Confidence / 100; - AlertOn = SequentialAnomalyDetectionTransformBase.AlertingScore.PValueScore; + Source = options.Source; + Name = options.Name; + Side = options.Side; + WindowSize = options.PvalueHistoryLength; + AlertThreshold = 1 - options.Confidence / 100; + AlertOn = AlertingScore.PValueScore; Martingale = MartingaleType.None; } public BaseArguments(IidSpikeDetector transform) { - Source = transform.InputColumnName; - Name = transform.OutputColumnName; - Side = transform.Side; - WindowSize = transform.WindowSize; - AlertThreshold = transform.AlertThreshold; + Source = transform.InternalTransform.InputColumnName; + Name = transform.InternalTransform.OutputColumnName; + Side = transform.InternalTransform.Side; + WindowSize = transform.InternalTransform.WindowSize; + AlertThreshold = transform.InternalTransform.AlertThreshold; AlertOn = AlertingScore.PValueScore; Martingale = MartingaleType.None; } @@ -99,25 +97,25 @@ private static VersionInfo GetVersionInfo() } // Factory method for SignatureDataTransform. 
- private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - return new IidSpikeDetector(env, args).MakeDataTransform(input); + return new IidSpikeDetector(env, options).MakeDataTransform(input); } - internal override IStatefulTransformer Clone() + IStatefulTransformer IStatefulTransformer.Clone() { var clone = (IidSpikeDetector)MemberwiseClone(); - clone.StateRef = (State)clone.StateRef.Clone(); - clone.StateRef.InitState(clone, Host); + clone.InternalTransform.StateRef = (IidAnomalyDetectionBase.State)clone.InternalTransform.StateRef.Clone(); + clone.InternalTransform.StateRef.InitState(clone.InternalTransform, InternalTransform.Host); return clone; } - internal IidSpikeDetector(IHostEnvironment env, Arguments args) - : base(new BaseArguments(args), LoaderSignature, env) + internal IidSpikeDetector(IHostEnvironment env, Options options) + : base(new BaseArguments(options), LoaderSignature, env) { // This constructor is empty. } @@ -142,14 +140,15 @@ private static IidSpikeDetector Create(IHostEnvironment env, ModelLoadContext ct return new IidSpikeDetector(env, ctx); } - public IidSpikeDetector(IHostEnvironment env, ModelLoadContext ctx) + internal IidSpikeDetector(IHostEnvironment env, ModelLoadContext ctx) : base(env, ctx, LoaderSignature) { // *** Binary format *** // - Host.CheckDecode(ThresholdScore == AlertingScore.PValueScore); + InternalTransform.Host.CheckDecode(InternalTransform.ThresholdScore == AlertingScore.PValueScore); } + private IidSpikeDetector(IHostEnvironment env, IidSpikeDetector transform) : base(new BaseArguments(transform), LoaderSignature, env) { @@ -157,11 +156,11 @@ private IidSpikeDetector(IHostEnvironment env, IidSpikeDetector transform) public override void Save(ModelSaveContext ctx) { - Host.CheckValue(ctx, nameof(ctx)); + InternalTransform.Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - Host.Assert(ThresholdScore == AlertingScore.PValueScore); + InternalTransform.Host.Assert(InternalTransform.ThresholdScore == AlertingScore.PValueScore); // *** Binary format *** // @@ -197,9 +196,9 @@ public sealed class IidSpikeEstimator : TrivialEstimator /// The size of the sliding window for computing the p-value. /// Name of column to transform. If set to , the value of the will be used as source. /// The argument that determines whether to detect positive or negative anomalies, or both. 
- public IidSpikeEstimator(IHostEnvironment env, string outputColumnName, int confidence, int pvalueHistoryLength, string inputColumnName, AnomalySide side = AnomalySide.TwoSided) + internal IidSpikeEstimator(IHostEnvironment env, string outputColumnName, int confidence, int pvalueHistoryLength, string inputColumnName, AnomalySide side = AnomalySide.TwoSided) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidSpikeDetector)), - new IidSpikeDetector(env, new IidSpikeDetector.Arguments + new IidSpikeDetector(env, new IidSpikeDetector.Options { Name = outputColumnName, Source = inputColumnName, @@ -210,26 +209,30 @@ public IidSpikeEstimator(IHostEnvironment env, string outputColumnName, int conf { } - public IidSpikeEstimator(IHostEnvironment env, IidSpikeDetector.Arguments args) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidSpikeEstimator)), new IidSpikeDetector(env, args)) + internal IidSpikeEstimator(IHostEnvironment env, IidSpikeDetector.Options options) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidSpikeEstimator)), new IidSpikeDetector(env, options)) { } + /// + /// Schema propagation for transformers. + /// Returns the output schema of the data, if the input schema is like the one provided. + /// public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - if (!inputSchema.TryFindColumn(Transformer.InputColumnName, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName); + if (!inputSchema.TryFindColumn(Transformer.InternalTransform.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InternalTransform.InputColumnName); if (col.ItemType != NumberType.R4) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName, "float", col.GetTypeString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InternalTransform.InputColumnName, "float", col.GetTypeString()); var metadata = new List() { new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) }; var resultDic = inputSchema.ToDictionary(x => x.Name); - resultDic[Transformer.OutputColumnName] = new SchemaShape.Column( - Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); + resultDic[Transformer.InternalTransform.OutputColumnName] = new SchemaShape.Column( + Transformer.InternalTransform.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); return new SchemaShape(resultDic.Values); } diff --git a/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs b/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs index f207c7b75a..1e04173d49 100644 --- a/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs +++ b/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs @@ -10,20 +10,20 @@ using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; [assembly: LoadableClass(MovingAverageTransform.Summary, typeof(MovingAverageTransform), typeof(MovingAverageTransform.Arguments), typeof(SignatureDataTransform), "Moving Average Transform", MovingAverageTransform.LoaderSignature, "MoAv")] [assembly: LoadableClass(MovingAverageTransform.Summary, typeof(MovingAverageTransform), null, 
typeof(SignatureLoadDataTransform), "Moving Average Transform", MovingAverageTransform.LoaderSignature)] -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// MovingAverageTransform is a weighted average of the values in /// the sliding window. /// - public sealed class MovingAverageTransform : SequentialTransformBase + internal sealed class MovingAverageTransform : SequentialTransformBase { public const string Summary = "Applies a moving average on a time series. Only finite values are taken into account."; public const string LoaderSignature = "MovingAverageTransform"; @@ -94,7 +94,7 @@ public MovingAverageTransform(IHostEnvironment env, ModelLoadContext ctx, IDataV Host.CheckDecode(_weights == null || Utils.Size(_weights) == WindowSize + 1 - _lag); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(WindowSize >= 1); @@ -107,7 +107,7 @@ public override void Save(ModelSaveContext ctx) // int: _lag // Single[]: _weights - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write(_lag); Host.Assert(_weights == null || Utils.Size(_weights) == WindowSize + 1 - _lag); ctx.Writer.WriteSingleArray(_weights); diff --git a/src/Microsoft.ML.TimeSeries/PValueTransform.cs b/src/Microsoft.ML.TimeSeries/PValueTransform.cs index d8f3adcb29..528f8b7fd0 100644 --- a/src/Microsoft.ML.TimeSeries/PValueTransform.cs +++ b/src/Microsoft.ML.TimeSeries/PValueTransform.cs @@ -10,20 +10,20 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; [assembly: LoadableClass(PValueTransform.Summary, typeof(PValueTransform), typeof(PValueTransform.Arguments), typeof(SignatureDataTransform), PValueTransform.UserName, PValueTransform.LoaderSignature, PValueTransform.ShortName)] [assembly: LoadableClass(PValueTransform.Summary, typeof(PValueTransform), null, typeof(SignatureLoadDataTransform), PValueTransform.UserName, PValueTransform.LoaderSignature)] -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// PValueTransform is a sequential transform that computes the empirical p-value of the current value in the series based on the other values in /// the sliding window. 
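For intuition about the summary above, here is a hedged sketch of an empirical p-value over a sliding window. The transform's actual computation (including the seed-driven tie handling and the _isPositiveSide flag visible in this file) is not restated here; this is only the common textbook estimate.

```csharp
using System.Collections.Generic;

static class EmpiricalPValueSketch
{
    // Fraction of window values at least as extreme as the current value,
    // with +1 smoothing so the estimate stays in (0, 1].
    public static double PValue(IReadOnlyList<float> window, float current, bool positiveSide = true)
    {
        int atLeastAsExtreme = 0;
        foreach (var v in window)
            if (positiveSide ? v >= current : v <= current)
                atLeastAsExtreme++;
        return (atLeastAsExtreme + 1.0) / (window.Count + 1.0);
    }
}
```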
/// - public sealed class PValueTransform : SequentialTransformBase + internal sealed class PValueTransform : SequentialTransformBase { internal const string Summary = "This P-Value transform calculates the p-value of the current input in the sequence with regard to the values in the sliding window."; public const string LoaderSignature = "PValueTransform"; @@ -91,7 +91,7 @@ public PValueTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView inp Host.CheckDecode(WindowSize >= 1); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(WindowSize >= 1); @@ -103,7 +103,7 @@ public override void Save(ModelSaveContext ctx) // int: _percentile // byte: _isPositiveSide - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write(_seed); ctx.Writer.WriteBoolByte(_isPositiveSide); } diff --git a/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs b/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs index 868e446899..a4e8c792d0 100644 --- a/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs +++ b/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs @@ -10,20 +10,20 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; [assembly: LoadableClass(PercentileThresholdTransform.Summary, typeof(PercentileThresholdTransform), typeof(PercentileThresholdTransform.Arguments), typeof(SignatureDataTransform), PercentileThresholdTransform.UserName, PercentileThresholdTransform.LoaderSignature, PercentileThresholdTransform.ShortName)] [assembly: LoadableClass(PercentileThresholdTransform.Summary, typeof(PercentileThresholdTransform), null, typeof(SignatureLoadDataTransform), PercentileThresholdTransform.UserName, PercentileThresholdTransform.LoaderSignature)] -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// PercentileThresholdTransform is a sequential transform that decides whether the current value of the time-series belongs to the 'percentile' % of the top values in /// the sliding window. The output of the transform will be a boolean flag. 
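Similarly, the decision described in that summary can be sketched outside the transform; the exact percentile definition (interpolation, tie handling) used by PercentileThresholdTransform is not shown in this hunk, so treat the following as an approximation.

```csharp
using System.Collections.Generic;

static class PercentileThresholdSketch
{
    // True when fewer than `percentile` percent of the window values exceed the current value,
    // i.e. the current value sits in the top `percentile` percent of the window.
    public static bool IsInTopPercentile(IReadOnlyList<float> window, float current, double percentile)
    {
        int greater = 0;
        foreach (var v in window)
            if (v > current)
                greater++;
        return greater < window.Count * percentile / 100.0;
    }
}
```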
/// - public sealed class PercentileThresholdTransform : SequentialTransformBase + internal sealed class PercentileThresholdTransform : SequentialTransformBase { public const string Summary = "Detects the values of time-series that are in the top percentile of the sliding window."; public const string LoaderSignature = "PercentThrTransform"; @@ -86,7 +86,7 @@ public PercentileThresholdTransform(IHostEnvironment env, ModelLoadContext ctx, Host.CheckDecode(MinPercentile <= _percentile && _percentile <= MaxPercentile); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(MinPercentile <= _percentile && _percentile <= MaxPercentile); @@ -98,7 +98,7 @@ public override void Save(ModelSaveContext ctx) // // Double: _percentile - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write(_percentile); } diff --git a/src/Microsoft.ML.TimeSeries/PolynomialUtils.cs b/src/Microsoft.ML.TimeSeries/PolynomialUtils.cs index 094984dcb2..7ba7dd0d25 100644 --- a/src/Microsoft.ML.TimeSeries/PolynomialUtils.cs +++ b/src/Microsoft.ML.TimeSeries/PolynomialUtils.cs @@ -8,9 +8,9 @@ using System.Numerics; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { - public static class PolynomialUtils + internal static class PolynomialUtils { // Part 1: Computing the polynomial real and complex roots from its real coefficients diff --git a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs index e0c352e632..5aab239fef 100644 --- a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs +++ b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs @@ -10,7 +10,7 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -namespace Microsoft.ML.TimeSeries +namespace Microsoft.ML.Transforms.TimeSeries { internal interface IStatefulRowToRowMapper : IRowToRowMapper { @@ -18,8 +18,17 @@ internal interface IStatefulRowToRowMapper : IRowToRowMapper internal interface IStatefulTransformer : ITransformer { + /// + /// Same as but also supports mechanism to save the state. + /// + /// The input schema for which we should get the mapper. + /// The row to row mapper. IRowToRowMapper GetStatefulRowToRowMapper(Schema inputSchema); + /// + /// Creates a clone of the transfomer. Used for taking the snapshot of the state. + /// + /// IStatefulTransformer Clone(); } @@ -90,6 +99,10 @@ private static ITransformer CloneTransformers(ITransformer transformer) return transformer is IStatefulTransformer ? ((IStatefulTransformer)transformer).Clone() : transformer; } + /// + /// Contructor for creating time series specific prediction engine. 
It allows update the time series model to be updated with the observations + /// seen at prediction time via + /// public TimeSeriesPredictionFunction(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) : base(env, CloneTransformers(transformer), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition) @@ -200,7 +213,7 @@ private IRowToRowMapper GetRowToRowMapper(Schema inputSchema) return new CompositeRowToRowMapper(inputSchema, mappers); } - protected override Func TransformerChecker(IExceptionContext ectx, ITransformer transformer) + private protected override Func TransformerChecker(IExceptionContext ectx, ITransformer transformer) { ectx.CheckValue(transformer, nameof(transformer)); ectx.CheckParam(IsRowToRowMapper(transformer), nameof(transformer), "Must be a row to row mapper or " + nameof(IStatefulTransformer)); diff --git a/src/Microsoft.ML.TimeSeries/Properties/AssemblyInfo.cs b/src/Microsoft.ML.TimeSeries/Properties/AssemblyInfo.cs index 56899c0981..a16f8e17fd 100644 --- a/src/Microsoft.ML.TimeSeries/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.TimeSeries/Properties/AssemblyInfo.cs @@ -5,4 +5,7 @@ using System.Runtime.CompilerServices; using Microsoft.ML; +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries.Tests" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)] + +[assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs b/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs index d66b58c9f8..b6da0925cc 100644 --- a/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs @@ -6,13 +6,13 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// The base container class for the forecast result on a sequence of type . /// /// The type of the elements in the sequence - public abstract class ForecastResultBase + internal abstract class ForecastResultBase { public VBuffer PointForecast; } @@ -22,7 +22,7 @@ public abstract class ForecastResultBase /// /// The type of the elements in the input sequence /// The type of the elements in the output sequence - public abstract class SequenceModelerBase : ICanSaveModel + internal abstract class SequenceModelerBase : ICanSaveModel { private protected SequenceModelerBase() { @@ -75,6 +75,8 @@ private protected SequenceModelerBase() /// /// Implementation of . 
/// - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); } } diff --git a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs index 91a554b96d..4bd1cbc87d 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs @@ -12,682 +12,680 @@ using Microsoft.ML.Internal.CpuMath; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeries; +using Microsoft.ML.Transforms.TimeSeries; -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { - // REVIEW: This base class and its children classes generate one output column of type VBuffer to output 3 different anomaly scores as well as - // the alert flag. Ideally these 4 output information should be put in four seaparate columns instead of one VBuffer<> column. However, this is not currently - // possible due to our design restriction. This must be fixed in the next version and will potentially affect the children classes. - /// - /// The base class for sequential anomaly detection transforms that supports the p-value as well as the martingales scores computation from the sequence of - /// raw anomaly scores whose calculation is specified by the children classes. This class also provides mechanism for the threshold-based alerting on - /// the raw anomaly score, the p-value score or the martingale score. Currently, this class supports Power and Mixture martingales. - /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf + /// The type of the martingale. /// - /// The type of the input sequence - /// The type of the state object for sequential anomaly detection. Must be a class inherited from AnomalyDetectionStateBase - public abstract class SequentialAnomalyDetectionTransformBase : SequentialTransformerBase, TState> - where TState : SequentialAnomalyDetectionTransformBase.AnomalyDetectionStateBase, new() + public enum MartingaleType : byte { /// - /// The type of the martingale. + /// (None) No martingale is used. /// - public enum MartingaleType : byte - { - /// - /// (None) No martingale is used. - /// - None, - /// - /// (Power) The Power martingale is used. - /// - Power, - /// - /// (Mixture) The Mixture martingale is used. - /// - Mixture - } - + None, /// - /// The side of anomaly detection. + /// (Power) The Power martingale is used. /// - public enum AnomalySide : byte - { - /// - /// (Positive) Only positive anomalies are detected. - /// - Positive, - /// - /// (Negative) Only negative anomalies are detected. - /// - Negative, - /// - /// (TwoSided) Both positive and negative anomalies are detected. - /// - TwoSided - } + Power, + /// + /// (Mixture) The Mixture martingale is used. + /// + Mixture + } + /// + /// The side of anomaly detection. + /// + public enum AnomalySide : byte + { /// - /// The score that should be thresholded to generate alerts. + /// (Positive) Only positive anomalies are detected. /// - public enum AlertingScore : byte - { - /// - /// (RawScore) The raw anomaly score is thresholded. - /// - RawScore, - /// - /// (PValueScore) The p-value score is thresholded. - /// - PValueScore, - /// - /// (MartingaleScore) The martingale score is thresholded. 
- /// - MartingaleScore - } + Positive, + /// + /// (Negative) Only negative anomalies are detected. + /// + Negative, + /// + /// (TwoSided) Both positive and negative anomalies are detected. + /// + TwoSided + } + /// + /// The score that should be thresholded to generate alerts. + /// + internal enum AlertingScore : byte + { /// - /// The base class that can be inherited by the 'Argument' classes in the derived classes containing the shared input parameters. + /// (RawScore) The raw anomaly score is thresholded. /// - public abstract class ArgumentsBase - { - [Argument(ArgumentType.Required, HelpText = "The name of the source column", ShortName = "src", - SortOrder = 1, Purpose = SpecialPurpose.ColumnName)] - public string Source; + RawScore, + /// + /// (PValueScore) The p-value score is thresholded. + /// + PValueScore, + /// + /// (MartingaleScore) The martingale score is thresholded. + /// + MartingaleScore + } - [Argument(ArgumentType.Required, HelpText = "The name of the new column", ShortName = "name", - SortOrder = 2)] - public string Name; + /// + /// The base class that can be inherited by the 'Argument' classes in the derived classes containing the shared input parameters. + /// + internal abstract class ArgumentsBase + { + [Argument(ArgumentType.Required, HelpText = "The name of the source column", ShortName = "src", + SortOrder = 1, Purpose = SpecialPurpose.ColumnName)] + public string Source; - [Argument(ArgumentType.AtMostOnce, HelpText = "The argument that determines whether to detect positive or negative anomalies, or both", ShortName = "side", - SortOrder = 3)] - public AnomalySide Side = AnomalySide.TwoSided; + [Argument(ArgumentType.Required, HelpText = "The name of the new column", ShortName = "name", + SortOrder = 2)] + public string Name; - [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the sliding window for computing the p-value.", ShortName = "wnd", - SortOrder = 4)] - public int WindowSize = 1; + [Argument(ArgumentType.AtMostOnce, HelpText = "The argument that determines whether to detect positive or negative anomalies, or both", ShortName = "side", + SortOrder = 3)] + public AnomalySide Side = AnomalySide.TwoSided; - [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the initial window for computing the p-value as well as training if needed. The default value is set to 0, which means there is no initial window considered.", - ShortName = "initwnd", SortOrder = 5)] - public int InitialWindowSize = 0; + [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the sliding window for computing the p-value.", ShortName = "wnd", + SortOrder = 4)] + public int WindowSize = 1; - [Argument(ArgumentType.AtMostOnce, HelpText = "The martingale used for scoring", - ShortName = "martingale", SortOrder = 6)] - public MartingaleType Martingale = MartingaleType.Power; + [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the initial window for computing the p-value as well as training if needed. 
The default value is set to 0, which means there is no initial window considered.", + ShortName = "initwnd", SortOrder = 5)] + public int InitialWindowSize = 0; - [Argument(ArgumentType.AtMostOnce, HelpText = "The argument that determines whether anomalies should be detected based on the raw anomaly score, the p-value or the martingale score", - ShortName = "alert", SortOrder = 7)] - public AlertingScore AlertOn = AlertingScore.MartingaleScore; + [Argument(ArgumentType.AtMostOnce, HelpText = "The martingale used for scoring", + ShortName = "martingale", SortOrder = 6)] + public MartingaleType Martingale = MartingaleType.Power; - [Argument(ArgumentType.AtMostOnce, HelpText = "The epsilon parameter for the Power martingale", - ShortName = "eps", SortOrder = 8)] - public Double PowerMartingaleEpsilon = 0.1; + [Argument(ArgumentType.AtMostOnce, HelpText = "The argument that determines whether anomalies should be detected based on the raw anomaly score, the p-value or the martingale score", + ShortName = "alert", SortOrder = 7)] + public AlertingScore AlertOn = AlertingScore.MartingaleScore; - [Argument(ArgumentType.Required, HelpText = "The threshold for alerting", - ShortName = "thr", SortOrder = 9)] - public Double AlertThreshold; - } + [Argument(ArgumentType.AtMostOnce, HelpText = "The epsilon parameter for the Power martingale", + ShortName = "eps", SortOrder = 8)] + public Double PowerMartingaleEpsilon = 0.1; - // Determines the side of anomaly detection for this transform. - protected AnomalySide Side; + [Argument(ArgumentType.Required, HelpText = "The threshold for alerting", + ShortName = "thr", SortOrder = 9)] + public Double AlertThreshold; + } - // Determines the type of martingale used by this transform. - protected MartingaleType Martingale; + // REVIEW: This base class and its children classes generate one output column of type VBuffer to output 3 different anomaly scores as well as + // the alert flag. Ideally these 4 outputs should be put in four separate columns instead of one VBuffer<> column. However, this is not currently + // possible due to our design restriction. This must be fixed in the next version and will potentially affect the children classes. + /// + /// The base class for sequential anomaly detection transforms that supports the p-value as well as the martingales scores computation from the sequence of + /// raw anomaly scores whose calculation is specified by the children classes. This class also provides a mechanism for the threshold-based alerting on + /// the raw anomaly score, the p-value score or the martingale score. Currently, this class supports Power and Mixture martingales. + /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf + /// + /// The type of the input sequence + /// The type of the state object for sequential anomaly detection. Must be a class inherited from AnomalyDetectionStateBase + internal abstract class SequentialAnomalyDetectionTransformBase : SequentialTransformerBase, TState> + where TState : SequentialAnomalyDetectionTransformBase.AnomalyDetectionStateBase, new() + { + // Determines the side of anomaly detection for this transform. + internal AnomalySide Side; - // The epsilon parameter used by the Power martingale. - protected Double PowerMartingaleEpsilon; + // Determines the type of martingale used by this transform. + internal MartingaleType Martingale; - // Determines the score that should be thresholded to generate alerts by this transform. - protected AlertingScore ThresholdScore; + // The epsilon parameter used by the Power martingale.
+ internal Double PowerMartingaleEpsilon; - // Determines the threshold for generating alerts. - protected Double AlertThreshold; + // Determines the score that should be thresholded to generate alerts by this transform. + internal AlertingScore ThresholdScore; - // The size of the VBuffer in the dst column. - private int _outputLength; + // Determines the threshold for generating alerts. + internal Double AlertThreshold; - private static int GetOutputLength(AlertingScore alertingScore, IHostEnvironment host) - { - switch (alertingScore) + // The size of the VBuffer in the dst column. + internal int OutputLength; + + // The minimum value for p-values. The smaller p-values are ceiled to this value. + internal const Double MinPValue = 1e-8; + + // The maximun value for p-values. The larger p-values are floored to this value. + internal const Double MaxPValue = 1 - MinPValue; + + private static int GetOutputLength(AlertingScore alertingScore, IHostEnvironment host) + { + switch (alertingScore) + { + case AlertingScore.RawScore: + return 2; + case AlertingScore.PValueScore: + return 3; + case AlertingScore.MartingaleScore: + return 4; + default: + throw host.Except("The alerting score can be only (0) RawScore, (1) PValueScore or (2) MartingaleScore."); + } + } + + private protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env, + AnomalySide anomalySide, MartingaleType martingale, AlertingScore alertingScore, Double powerMartingaleEpsilon, + Double alertThreshold) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), windowSize, initialWindowSize, outputColumnName, inputColumnName, new VectorType(NumberType.R8, GetOutputLength(alertingScore, env))) { - case AlertingScore.RawScore: - return 2; - case AlertingScore.PValueScore: - return 3; - case AlertingScore.MartingaleScore: - return 4; - default: - throw host.Except("The alerting score can be only (0) RawScore, (1) PValueScore or (2) MartingaleScore."); + Host.CheckUserArg(Enum.IsDefined(typeof(MartingaleType), martingale), nameof(ArgumentsBase.Martingale), "Value is undefined."); + Host.CheckUserArg(Enum.IsDefined(typeof(AnomalySide), anomalySide), nameof(ArgumentsBase.Side), "Value is undefined."); + Host.CheckUserArg(Enum.IsDefined(typeof(AlertingScore), alertingScore), nameof(ArgumentsBase.AlertOn), "Value is undefined."); + Host.CheckUserArg(martingale != MartingaleType.None || alertingScore != AlertingScore.MartingaleScore, nameof(ArgumentsBase.Martingale), "A martingale type should be specified if alerting is based on the martingale score."); + Host.CheckUserArg(windowSize > 0 || alertingScore == AlertingScore.RawScore, nameof(ArgumentsBase.AlertOn), + "When there is no windowed buffering (i.e., " + nameof(ArgumentsBase.WindowSize) + " = 0), the alert can be generated only based on the raw score (i.e., " + + nameof(ArgumentsBase.AlertOn) + " = " + nameof(AlertingScore.RawScore) + ")"); + Host.CheckUserArg(0 < powerMartingaleEpsilon && powerMartingaleEpsilon < 1, nameof(ArgumentsBase.PowerMartingaleEpsilon), "Should be in (0,1)."); + Host.CheckUserArg(alertThreshold >= 0, nameof(ArgumentsBase.AlertThreshold), "Must be non-negative."); + Host.CheckUserArg(alertingScore != AlertingScore.PValueScore || (0 <= alertThreshold && alertThreshold <= 1), nameof(ArgumentsBase.AlertThreshold), "Must be in [0,1]."); + + ThresholdScore = alertingScore; + Side = anomalySide; + Martingale = martingale; + PowerMartingaleEpsilon = 
powerMartingaleEpsilon; + AlertThreshold = alertThreshold; + OutputLength = GetOutputLength(ThresholdScore, Host); } - } - private protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env, - AnomalySide anomalySide, MartingaleType martingale, AlertingScore alertingScore, Double powerMartingaleEpsilon, - Double alertThreshold) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), windowSize, initialWindowSize, outputColumnName, inputColumnName, new VectorType(NumberType.R8, GetOutputLength(alertingScore, env))) - { - Host.CheckUserArg(Enum.IsDefined(typeof(MartingaleType), martingale), nameof(ArgumentsBase.Martingale), "Value is undefined."); - Host.CheckUserArg(Enum.IsDefined(typeof(AnomalySide), anomalySide), nameof(ArgumentsBase.Side), "Value is undefined."); - Host.CheckUserArg(Enum.IsDefined(typeof(AlertingScore), alertingScore), nameof(ArgumentsBase.AlertOn), "Value is undefined."); - Host.CheckUserArg(martingale != MartingaleType.None || alertingScore != AlertingScore.MartingaleScore, nameof(ArgumentsBase.Martingale), "A martingale type should be specified if alerting is based on the martingale score."); - Host.CheckUserArg(windowSize > 0 || alertingScore == AlertingScore.RawScore, nameof(ArgumentsBase.AlertOn), - "When there is no windowed buffering (i.e., " + nameof(ArgumentsBase.WindowSize) + " = 0), the alert can be generated only based on the raw score (i.e., " - + nameof(ArgumentsBase.AlertOn) + " = " + nameof(AlertingScore.RawScore) + ")"); - Host.CheckUserArg(0 < powerMartingaleEpsilon && powerMartingaleEpsilon < 1, nameof(ArgumentsBase.PowerMartingaleEpsilon), "Should be in (0,1)."); - Host.CheckUserArg(alertThreshold >= 0, nameof(ArgumentsBase.AlertThreshold), "Must be non-negative."); - Host.CheckUserArg(alertingScore != AlertingScore.PValueScore || (0 <= alertThreshold && alertThreshold <= 1), nameof(ArgumentsBase.AlertThreshold), "Must be in [0,1]."); - - ThresholdScore = alertingScore; - Side = anomalySide; - Martingale = martingale; - PowerMartingaleEpsilon = powerMartingaleEpsilon; - AlertThreshold = alertThreshold; - _outputLength = GetOutputLength(ThresholdScore, Host); - } + private protected SequentialAnomalyDetectionTransformBase(ArgumentsBase args, string name, IHostEnvironment env) + : this(args.WindowSize, args.InitialWindowSize, args.Source, args.Name, name, env, args.Side, args.Martingale, + args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold) + { + } - private protected SequentialAnomalyDetectionTransformBase(ArgumentsBase args, string name, IHostEnvironment env) - : this(args.WindowSize, args.InitialWindowSize, args.Source, args.Name, name, env, args.Side, args.Martingale, - args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold) - { - } + private protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), ctx) + { + // *** Binary format *** + // + // byte: _martingale + // byte: _alertingScore + // byte: _anomalySide + // Double: _powerMartingaleEpsilon + // Double: _alertThreshold + + byte temp; + temp = ctx.Reader.ReadByte(); + Host.CheckDecode(Enum.IsDefined(typeof(MartingaleType), temp)); + Martingale = (MartingaleType)temp; + + temp = ctx.Reader.ReadByte(); + Host.CheckDecode(Enum.IsDefined(typeof(AlertingScore), temp)); + ThresholdScore = (AlertingScore)temp; + + Host.CheckDecode(Martingale != 
MartingaleType.None || ThresholdScore != AlertingScore.MartingaleScore); + Host.CheckDecode(WindowSize > 0 || ThresholdScore == AlertingScore.RawScore); + + temp = ctx.Reader.ReadByte(); + Host.CheckDecode(Enum.IsDefined(typeof(AnomalySide), temp)); + Side = (AnomalySide)temp; + + PowerMartingaleEpsilon = ctx.Reader.ReadDouble(); + Host.CheckDecode(0 < PowerMartingaleEpsilon && PowerMartingaleEpsilon < 1); + + AlertThreshold = ctx.Reader.ReadDouble(); + Host.CheckDecode(AlertThreshold >= 0); + Host.CheckDecode(ThresholdScore != AlertingScore.PValueScore || (0 <= AlertThreshold && AlertThreshold <= 1)); + + OutputLength = GetOutputLength(ThresholdScore, Host); + } - private protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), ctx) - { - // *** Binary format *** - // - // byte: _martingale - // byte: _alertingScore - // byte: _anomalySide - // Double: _powerMartingaleEpsilon - // Double: _alertThreshold - - byte temp; - temp = ctx.Reader.ReadByte(); - Host.CheckDecode(Enum.IsDefined(typeof(MartingaleType), temp)); - Martingale = (MartingaleType)temp; - - temp = ctx.Reader.ReadByte(); - Host.CheckDecode(Enum.IsDefined(typeof(AlertingScore), temp)); - ThresholdScore = (AlertingScore)temp; - - Host.CheckDecode(Martingale != MartingaleType.None || ThresholdScore != AlertingScore.MartingaleScore); - Host.CheckDecode(WindowSize > 0 || ThresholdScore == AlertingScore.RawScore); - - temp = ctx.Reader.ReadByte(); - Host.CheckDecode(Enum.IsDefined(typeof(AnomalySide), temp)); - Side = (AnomalySide)temp; - - PowerMartingaleEpsilon = ctx.Reader.ReadDouble(); - Host.CheckDecode(0 < PowerMartingaleEpsilon && PowerMartingaleEpsilon < 1); - - AlertThreshold = ctx.Reader.ReadDouble(); - Host.CheckDecode(AlertThreshold >= 0); - Host.CheckDecode(ThresholdScore != AlertingScore.PValueScore || (0 <= AlertThreshold && AlertThreshold <= 1)); - - _outputLength = GetOutputLength(ThresholdScore, Host); - } + private protected override void SaveModel(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + + Host.Assert(Enum.IsDefined(typeof(MartingaleType), Martingale)); + Host.Assert(Enum.IsDefined(typeof(AlertingScore), ThresholdScore)); + Host.Assert(Martingale != MartingaleType.None || ThresholdScore != AlertingScore.MartingaleScore); + Host.Assert(WindowSize > 0 || ThresholdScore == AlertingScore.RawScore); + Host.Assert(Enum.IsDefined(typeof(AnomalySide), Side)); + Host.Assert(0 < PowerMartingaleEpsilon && PowerMartingaleEpsilon < 1); + Host.Assert(AlertThreshold >= 0); + Host.Assert(ThresholdScore != AlertingScore.PValueScore || (0 <= AlertThreshold && AlertThreshold <= 1)); + + // *** Binary format *** + // + // byte: _martingale + // byte: _alertingScore + // byte: _anomalySide + // Double: _powerMartingaleEpsilon + // Double: _alertThreshold + + base.SaveModel(ctx); + ctx.Writer.Write((byte)Martingale); + ctx.Writer.Write((byte)ThresholdScore); + ctx.Writer.Write((byte)Side); + ctx.Writer.Write(PowerMartingaleEpsilon); + ctx.Writer.Write(AlertThreshold); + } - public override void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - - Host.Assert(Enum.IsDefined(typeof(MartingaleType), Martingale)); - Host.Assert(Enum.IsDefined(typeof(AlertingScore), ThresholdScore)); - Host.Assert(Martingale != MartingaleType.None || ThresholdScore != AlertingScore.MartingaleScore); - Host.Assert(WindowSize > 0 || ThresholdScore == 
AlertingScore.RawScore); - Host.Assert(Enum.IsDefined(typeof(AnomalySide), Side)); - Host.Assert(0 < PowerMartingaleEpsilon && PowerMartingaleEpsilon < 1); - Host.Assert(AlertThreshold >= 0); - Host.Assert(ThresholdScore != AlertingScore.PValueScore || (0 <= AlertThreshold && AlertThreshold <= 1)); - - // *** Binary format *** - // - // byte: _martingale - // byte: _alertingScore - // byte: _anomalySide - // Double: _powerMartingaleEpsilon - // Double: _alertThreshold - - base.Save(ctx); - ctx.Writer.Write((byte)Martingale); - ctx.Writer.Write((byte)ThresholdScore); - ctx.Writer.Write((byte)Side); - ctx.Writer.Write(PowerMartingaleEpsilon); - ctx.Writer.Write(AlertThreshold); - } + /// + /// Calculates the betting function for the Power martingale in the log scale. + /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf. + /// + /// The p-value + /// The epsilon + /// The Power martingale betting function value in the natural logarithmic scale. + internal Double LogPowerMartigaleBettingFunc(Double p, Double epsilon) + { + Host.Assert(MinPValue > 0); + Host.Assert(MaxPValue < 1); + Host.Assert(MinPValue <= p && p <= MaxPValue); + Host.Assert(0 < epsilon && epsilon < 1); - // The minimum value for p-values. The smaller p-values are ceiled to this value. - private const Double MinPValue = 1e-8; + return Math.Log(epsilon) + (epsilon - 1) * Math.Log(p); + } - // The maximun value for p-values. The larger p-values are floored to this value. - private const Double MaxPValue = 1 - MinPValue; + /// + /// Calculates the betting function for the Mixture martingale in the log scale. + /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf. + /// + /// The p-value + /// The Mixure (marginalized over epsilon) martingale betting function value in the natural logarithmic scale. + internal Double LogMixtureMartigaleBettingFunc(Double p) + { + Host.Assert(MinPValue > 0); + Host.Assert(MaxPValue < 1); + Host.Assert(MinPValue <= p && p <= MaxPValue); - /// - /// Calculates the betting function for the Power martingale in the log scale. - /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf. - /// - /// The p-value - /// The epsilon - /// The Power martingale betting function value in the natural logarithmic scale. - protected Double LogPowerMartigaleBettingFunc(Double p, Double epsilon) - { - Host.Assert(MinPValue > 0); - Host.Assert(MaxPValue < 1); - Host.Assert(MinPValue <= p && p <= MaxPValue); - Host.Assert(0 < epsilon && epsilon < 1); + Double logP = Math.Log(p); + return Math.Log(p * logP + 1 - p) - 2 * Math.Log(-logP) - logP; + } - return Math.Log(epsilon) + (epsilon - 1) * Math.Log(p); - } + internal override IStatefulRowMapper MakeRowMapper(Schema schema) => new Mapper(Host, this, schema); - /// - /// Calculates the betting function for the Mixture martingale in the log scale. - /// For more details, please refer to http://arxiv.org/pdf/1204.3251.pdf. - /// - /// The p-value - /// The Mixure (marginalized over epsilon) martingale betting function value in the natural logarithmic scale. 
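// Editor's illustrative sketch (not part of the patch): the two betting functions from the hunk
// above, reproduced standalone so the math is easy to check against http://arxiv.org/pdf/1204.3251.pdf.
// The local names are hypothetical; the patch keeps them as the internal instance methods
// LogPowerMartigaleBettingFunc / LogMixtureMartigaleBettingFunc guarded by Host asserts.
using System;

internal static class MartingaleBettingSketch
{
    // Power martingale betting function in the natural log scale:
    // log(eps * p^(eps - 1)) = log(eps) + (eps - 1) * log(p), for p and eps in (0, 1).
    public static double LogPowerBetting(double p, double epsilon)
        => Math.Log(epsilon) + (epsilon - 1) * Math.Log(p);

    // Mixture martingale betting function (the Power martingale marginalized over epsilon),
    // also in the natural log scale, for p in (0, 1).
    public static double LogMixtureBetting(double p)
    {
        double logP = Math.Log(p);
        return Math.Log(p * logP + 1 - p) - 2 * Math.Log(-logP) - logP;
    }

    public static void Main()
    {
        // Small p-values (unlikely under the history) give a positive log update, large p-values a
        // negative one, so the running martingale grows only when anomalies accumulate.
        Console.WriteLine(LogPowerBetting(0.01, 0.1)); // ~ 1.84
        Console.WriteLine(LogPowerBetting(0.9, 0.1));  // ~ -2.21
        Console.WriteLine(LogMixtureBetting(0.01));    // ~ 1.49
    }
}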
- protected Double LogMixtureMartigaleBettingFunc(Double p) - { - Host.Assert(MinPValue > 0); - Host.Assert(MaxPValue < 1); - Host.Assert(MinPValue <= p && p <= MaxPValue); + internal sealed class Mapper : IStatefulRowMapper + { + private readonly IHost _host; + private readonly SequentialAnomalyDetectionTransformBase _parent; + private readonly Schema _parentSchema; + private readonly int _inputColumnIndex; + private readonly VBuffer> _slotNames; + private AnomalyDetectionStateBase State { get; set; } + + public Mapper(IHostEnvironment env, SequentialAnomalyDetectionTransformBase parent, Schema inputSchema) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(Mapper)); + _host.CheckValue(inputSchema, nameof(inputSchema)); + _host.CheckValue(parent, nameof(parent)); - Double logP = Math.Log(p); - return Math.Log(p * logP + 1 - p) - 2 * Math.Log(-logP) - logP; - } + if (!inputSchema.TryGetColumnIndex(parent.InputColumnName, out _inputColumnIndex)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName); - /// - /// The base state class for sequential anomaly detection: this class implements the p-values and martinagle calculations for anomaly detection - /// given that the raw anomaly score calculation is specified by the derived classes. - /// - public abstract class AnomalyDetectionStateBase : StateBase - { - // A reference to the parent transform. - protected SequentialAnomalyDetectionTransformBase Parent; + var colType = inputSchema[_inputColumnIndex].Type; + if (colType != NumberType.R4) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName, "float", colType.ToString()); - // A windowed buffer to cache the update values to the martingale score in the log scale. - private FixedSizeQueue LogMartingaleUpdateBuffer { get; set; } + _parent = parent; + _parentSchema = inputSchema; + _slotNames = new VBuffer>(4, new[] { "Alert".AsMemory(), "Raw Score".AsMemory(), + "P-Value Score".AsMemory(), "Martingale Score".AsMemory() }); - // A windowed buffer to cache the raw anomaly scores for p-value calculation. - private FixedSizeQueue RawScoreBuffer { get; set; } + State = (AnomalyDetectionStateBase)_parent.StateRef; + } - // The current martingale score in the log scale. - private Double _logMartingaleValue; + public Schema.DetachedColumn[] GetOutputColumns() + { + var meta = new MetadataBuilder(); + meta.AddSlotNames(_parent.OutputLength, GetSlotNames); + var info = new Schema.DetachedColumn[1]; + info[0] = new Schema.DetachedColumn(_parent.OutputColumnName, new VectorType(NumberType.R8, _parent.OutputLength), meta.GetMetadata()); + return info; + } - // Sum of the squared Euclidean distances among the raw socres in the buffer. - // Used for computing the optimal bandwidth for Kernel Density Estimation of p-values. 
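// Editor's illustrative sketch (not part of the patch): how the single VBuffer output column that
// the Mapper in the hunk above wires up is laid out. GetOutputLength earlier in this file returns
// 2, 3 or 4 depending on the alerting score, which selects a prefix of the four slot names; the
// enum and switch here are local copies for illustration only.
using System;

internal static class OutputLayoutSketch
{
    internal enum AlertingScore : byte { RawScore, PValueScore, MartingaleScore }

    private static readonly string[] SlotNames = { "Alert", "Raw Score", "P-Value Score", "Martingale Score" };

    public static string[] GetSlots(AlertingScore alertOn)
    {
        int length;
        switch (alertOn)
        {
            case AlertingScore.RawScore: length = 2; break;        // [Alert, Raw Score]
            case AlertingScore.PValueScore: length = 3; break;     // + P-Value Score
            case AlertingScore.MartingaleScore: length = 4; break; // + Martingale Score
            default: throw new ArgumentOutOfRangeException(nameof(alertOn));
        }

        var slots = new string[length];
        Array.Copy(SlotNames, slots, length);
        return slots;
    }
}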
- private Double _sumSquaredDist; + public void GetSlotNames(ref VBuffer> dst) => _slotNames.CopyTo(ref dst, 0, _parent.OutputLength); - private int _martingaleAlertCounter; + public Func GetDependencies(Func activeOutput) + { + if (activeOutput(0)) + return col => col == _inputColumnIndex; + else + return col => false; + } - protected Double LatestMartingaleScore => Math.Exp(_logMartingaleValue); + public void Save(ModelSaveContext ctx) => _parent.SaveModel(ctx); - private protected AnomalyDetectionStateBase() { } + public Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) + { + disposer = null; + var getters = new Delegate[1]; + if (activeOutput(0)) + getters[0] = MakeGetter(input, State); - private protected override void CloneCore(StateBase state) - { - base.CloneCore(state); - Contracts.Assert(state is AnomalyDetectionStateBase); - var stateLocal = state as AnomalyDetectionStateBase; - stateLocal.LogMartingaleUpdateBuffer = LogMartingaleUpdateBuffer.Clone(); - stateLocal.RawScoreBuffer = RawScoreBuffer.Clone(); - } + return getters; + } - private protected AnomalyDetectionStateBase(BinaryReader reader) : base(reader) - { - LogMartingaleUpdateBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueDouble(reader, Host); - RawScoreBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); - _logMartingaleValue = reader.ReadDouble(); - _sumSquaredDist = reader.ReadDouble(); - _martingaleAlertCounter = reader.ReadInt32(); - } + private delegate void ProcessData(ref TInput src, ref VBuffer dst); - internal override void Save(BinaryWriter writer) - { - base.Save(writer); - TimeSeriesUtils.SerializeFixedSizeQueue(LogMartingaleUpdateBuffer, writer); - TimeSeriesUtils.SerializeFixedSizeQueue(RawScoreBuffer, writer); - writer.Write(_logMartingaleValue); - writer.Write(_sumSquaredDist); - writer.Write(_martingaleAlertCounter); - } + private Delegate MakeGetter(Row input, AnomalyDetectionStateBase state) + { + _host.AssertValue(input); + var srcGetter = input.GetGetter(_inputColumnIndex); + ProcessData processData = _parent.WindowSize > 0 ? + (ProcessData)state.Process : state.ProcessWithoutBuffer; - private Double ComputeKernelPValue(Double rawScore) - { - int i; - int n = RawScoreBuffer.Count; + ValueGetter> valueGetter = (ref VBuffer dst) => + { + TInput src = default; + srcGetter(ref src); + processData(ref src, ref dst); + }; + return valueGetter; + } - if (n == 0) - return 0.5; + public Action CreatePinger(Row input, Func activeOutput, out Action disposer) + { + disposer = null; + Action pinger = null; + if (activeOutput(0)) + pinger = MakePinger(input, State); - Double pValue = 0; - Double bandWidth = Math.Sqrt(2) * ((n == 1) ? 
1 : Math.Sqrt(_sumSquaredDist) / n); - bandWidth = Math.Max(bandWidth, 1e-6); + return pinger; + } - Double diff; - for (i = 0; i < n; ++i) + private Action MakePinger(Row input, AnomalyDetectionStateBase state) { - diff = rawScore - RawScoreBuffer[i]; - pValue -= ProbabilityFunctions.Erf(diff / bandWidth); - _sumSquaredDist += diff * diff; + _host.AssertValue(input); + var srcGetter = input.GetGetter(_inputColumnIndex); + Action pinger = (long rowPosition) => + { + TInput src = default; + srcGetter(ref src); + state.UpdateState(ref src, rowPosition, _parent.WindowSize > 0); + }; + return pinger; } - pValue = 0.5 + pValue / (2 * n); - if (RawScoreBuffer.IsFull) + public void CloneState() { - for (i = 1; i < n; ++i) + if (Interlocked.Increment(ref _parent.StateRefCount) > 1) { - diff = RawScoreBuffer[0] - RawScoreBuffer[i]; - _sumSquaredDist -= diff * diff; + State = (AnomalyDetectionStateBase)_parent.StateRef.Clone(); } - - diff = RawScoreBuffer[0] - rawScore; - _sumSquaredDist -= diff * diff; } - return pValue; + public ITransformer GetTransformer() + { + return _parent; + } } - - private protected override void SetNaOutput(ref VBuffer dst) + /// + /// The base state class for sequential anomaly detection: this class implements the p-values and martinagle calculations for anomaly detection + /// given that the raw anomaly score calculation is specified by the derived classes. + /// + internal abstract class AnomalyDetectionStateBase : SequentialTransformerBase, TState>.StateBase { - var outputLength = Parent._outputLength; - var editor = VBufferEditor.Create(ref dst, outputLength); + // A reference to the parent transform. + protected SequentialAnomalyDetectionTransformBase Parent; - for (int i = 0; i < outputLength; ++i) - editor.Values[i] = Double.NaN; - - dst = editor.Commit(); - } + // A windowed buffer to cache the update values to the martingale score in the log scale. + private FixedSizeQueue LogMartingaleUpdateBuffer { get; set; } - private protected override sealed void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref VBuffer dst) - { - var outputLength = Parent._outputLength; - Host.Assert(outputLength >= 2); + // A windowed buffer to cache the raw anomaly scores for p-value calculation. + private FixedSizeQueue RawScoreBuffer { get; set; } - var result = VBufferEditor.Create(ref dst, outputLength); - float rawScore = 0; + // The current martingale score in the log scale. + private Double _logMartingaleValue; - for (int i = 0; i < outputLength; ++i) - result.Values[i] = Double.NaN; + // Sum of the squared Euclidean distances among the raw socres in the buffer. + // Used for computing the optimal bandwidth for Kernel Density Estimation of p-values. 
+ private Double _sumSquaredDist; - // Step 1: Computing the raw anomaly score - result.Values[1] = ComputeRawAnomalyScore(ref input, windowedBuffer, iteration); + private int _martingaleAlertCounter; - if (Double.IsNaN(result.Values[1])) - result.Values[0] = 0; - else - { - if (WindowSize > 0) - { - // Step 2: Computing the p-value score - rawScore = (float)result.Values[1]; - if (Parent.ThresholdScore == AlertingScore.RawScore) - { - switch (Parent.Side) - { - case AnomalySide.Negative: - rawScore = (float)(-result.Values[1]); - break; + protected Double LatestMartingaleScore => Math.Exp(_logMartingaleValue); - case AnomalySide.Positive: - break; + private protected AnomalyDetectionStateBase() { } - default: - rawScore = (float)Math.Abs(result.Values[1]); - break; - } - } - else - { - result.Values[2] = ComputeKernelPValue(rawScore); - - switch (Parent.Side) - { - case AnomalySide.Negative: - result.Values[2] = 1 - result.Values[2]; - break; - - case AnomalySide.Positive: - break; - - default: - result.Values[2] = Math.Min(result.Values[2], 1 - result.Values[2]); - break; - } + private protected override void CloneCore(TState state) + { + base.CloneCore(state); + Contracts.Assert(state is AnomalyDetectionStateBase); + var stateLocal = state as AnomalyDetectionStateBase; + stateLocal.LogMartingaleUpdateBuffer = LogMartingaleUpdateBuffer.Clone(); + stateLocal.RawScoreBuffer = RawScoreBuffer.Clone(); + } - // Keeping the p-value in the safe range - if (result.Values[2] < MinPValue) - result.Values[2] = MinPValue; - else if (result.Values[2] > MaxPValue) - result.Values[2] = MaxPValue; + private protected AnomalyDetectionStateBase(BinaryReader reader) : base(reader) + { + LogMartingaleUpdateBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueDouble(reader, Host); + RawScoreBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); + _logMartingaleValue = reader.ReadDouble(); + _sumSquaredDist = reader.ReadDouble(); + _martingaleAlertCounter = reader.ReadInt32(); + } - RawScoreBuffer.AddLast(rawScore); + internal override void Save(BinaryWriter writer) + { + base.Save(writer); + TimeSeriesUtils.SerializeFixedSizeQueue(LogMartingaleUpdateBuffer, writer); + TimeSeriesUtils.SerializeFixedSizeQueue(RawScoreBuffer, writer); + writer.Write(_logMartingaleValue); + writer.Write(_sumSquaredDist); + writer.Write(_martingaleAlertCounter); + } - // Step 3: Computing the martingale value - if (Parent.Martingale != MartingaleType.None && Parent.ThresholdScore == AlertingScore.MartingaleScore) - { - Double martingaleUpdate = 0; - switch (Parent.Martingale) - { - case MartingaleType.Power: - martingaleUpdate = Parent.LogPowerMartigaleBettingFunc(result.Values[2], Parent.PowerMartingaleEpsilon); - break; + private Double ComputeKernelPValue(Double rawScore) + { + int i; + int n = RawScoreBuffer.Count; - case MartingaleType.Mixture: - martingaleUpdate = Parent.LogMixtureMartigaleBettingFunc(result.Values[2]); - break; - } + if (n == 0) + return 0.5; - if (LogMartingaleUpdateBuffer.Count == 0) - { - for (int i = 0; i < LogMartingaleUpdateBuffer.Capacity; ++i) - LogMartingaleUpdateBuffer.AddLast(martingaleUpdate); - _logMartingaleValue += LogMartingaleUpdateBuffer.Capacity * martingaleUpdate; - } - else - { - _logMartingaleValue += martingaleUpdate; - _logMartingaleValue -= LogMartingaleUpdateBuffer.PeekFirst(); - LogMartingaleUpdateBuffer.AddLast(martingaleUpdate); - } + Double pValue = 0; + Double bandWidth = Math.Sqrt(2) * ((n == 1) ? 
1 : Math.Sqrt(_sumSquaredDist) / n); + bandWidth = Math.Max(bandWidth, 1e-6); - result.Values[3] = Math.Exp(_logMartingaleValue); - } - } + Double diff; + for (i = 0; i < n; ++i) + { + diff = rawScore - RawScoreBuffer[i]; + pValue -= ProbabilityFunctions.Erf(diff / bandWidth); + _sumSquaredDist += diff * diff; } - // Generating alert - bool alert = false; - - if (RawScoreBuffer.IsFull) // No alert until the buffer is completely full. + pValue = 0.5 + pValue / (2 * n); + if (RawScoreBuffer.IsFull) { - switch (Parent.ThresholdScore) + for (i = 1; i < n; ++i) { - case AlertingScore.RawScore: - alert = rawScore >= Parent.AlertThreshold; - break; - case AlertingScore.PValueScore: - alert = result.Values[2] <= Parent.AlertThreshold; - break; - case AlertingScore.MartingaleScore: - alert = (Parent.Martingale != MartingaleType.None) && (result.Values[3] >= Parent.AlertThreshold); - - if (alert) - { - if (_martingaleAlertCounter > 0) - alert = false; - else - _martingaleAlertCounter = Parent.WindowSize; - } - - _martingaleAlertCounter--; - _martingaleAlertCounter = _martingaleAlertCounter < 0 ? 0 : _martingaleAlertCounter; - break; + diff = RawScoreBuffer[0] - RawScoreBuffer[i]; + _sumSquaredDist -= diff * diff; } + + diff = RawScoreBuffer[0] - rawScore; + _sumSquaredDist -= diff * diff; } - result.Values[0] = Convert.ToDouble(alert); + return pValue; } - dst = result.Commit(); - } - - private protected override sealed void InitializeStateCore(bool disk = false) - { - Parent = (SequentialAnomalyDetectionTransformBase)ParentTransform; - Host.Assert(WindowSize >= 0); - - if (disk == false) + private protected override void SetNaOutput(ref VBuffer dst) { - if (Parent.Martingale != MartingaleType.None) - LogMartingaleUpdateBuffer = new FixedSizeQueue(WindowSize == 0 ? 1 : WindowSize); - else - LogMartingaleUpdateBuffer = new FixedSizeQueue(1); + var outputLength = Parent.OutputLength; + var editor = VBufferEditor.Create(ref dst, outputLength); - RawScoreBuffer = new FixedSizeQueue(WindowSize == 0 ? 1 : WindowSize); + for (int i = 0; i < outputLength; ++i) + editor.Values[i] = Double.NaN; - _logMartingaleValue = 0; + dst = editor.Commit(); } - InitializeAnomalyDetector(); - } + private protected override sealed void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref VBuffer dst) + { + var outputLength = Parent.OutputLength; + Host.Assert(outputLength >= 2); - /// - /// The abstract method that realizes the initialization functionality for the anomaly detector. - /// - private protected abstract void InitializeAnomalyDetector(); + var result = VBufferEditor.Create(ref dst, outputLength); + float rawScore = 0; - /// - /// The abstract method that realizes the main logic for calculating the raw anomaly score bfor the current input given a windowed buffer - /// - /// A reference to the input object. - /// A reference to the windowed buffer. - /// A long number that indicates the number of times ComputeRawAnomalyScore has been called so far (starting value = 0). - /// The raw anomaly score for the input. The Assumption is the higher absolute value of the raw score, the more anomalous the input is. - /// The sign of the score determines whether it's a positive anomaly or a negative one. 
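// Editor's illustrative sketch (not part of the patch): the kernel-density p-value estimate from
// ComputeKernelPValue above, written over a plain array instead of the rolling RawScoreBuffer, so
// the incremental _sumSquaredDist bookkeeping (which roughly maintains the pairwise squared
// distances inside the window) is recomputed from scratch here. Erf is a simple Abramowitz-Stegun
// approximation standing in for ProbabilityFunctions.Erf.
using System;

internal static class KernelPValueSketch
{
    public static double Estimate(double rawScore, double[] history)
    {
        int n = history.Length;
        if (n == 0)
            return 0.5;

        // Bandwidth derived from the pairwise squared distances of the buffered scores.
        double sumSquaredDist = 0;
        for (int i = 0; i < n; ++i)
            for (int j = i + 1; j < n; ++j)
                sumSquaredDist += (history[i] - history[j]) * (history[i] - history[j]);

        double bandWidth = Math.Sqrt(2) * (n == 1 ? 1 : Math.Sqrt(sumSquaredDist) / n);
        bandWidth = Math.Max(bandWidth, 1e-6);

        // Average upper-tail mass of the score under kernels centered on the buffered history.
        double pValue = 0;
        foreach (double h in history)
            pValue -= Erf((rawScore - h) / bandWidth);

        return 0.5 + pValue / (2 * n);
    }

    // Abramowitz & Stegun 7.1.26 approximation of the error function.
    private static double Erf(double x)
    {
        double sign = Math.Sign(x);
        x = Math.Abs(x);
        double t = 1 / (1 + 0.3275911 * x);
        double poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t + 0.254829592) * t;
        return sign * (1 - poly * Math.Exp(-x * x));
    }
}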
- private protected abstract Double ComputeRawAnomalyScore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration); - } + for (int i = 0; i < outputLength; ++i) + result.Values[i] = Double.NaN; - private protected override IStatefulRowMapper MakeRowMapper(Schema schema) => new Mapper(Host, this, schema); + // Step 1: Computing the raw anomaly score + result.Values[1] = ComputeRawAnomalyScore(ref input, windowedBuffer, iteration); - private sealed class Mapper : IStatefulRowMapper - { - private readonly IHost _host; - private readonly SequentialAnomalyDetectionTransformBase _parent; - private readonly Schema _parentSchema; - private readonly int _inputColumnIndex; - private readonly VBuffer> _slotNames; - private TState State { get; set; } - - public Mapper(IHostEnvironment env, SequentialAnomalyDetectionTransformBase parent, Schema inputSchema) - { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(Mapper)); - _host.CheckValue(inputSchema, nameof(inputSchema)); - _host.CheckValue(parent, nameof(parent)); + if (Double.IsNaN(result.Values[1])) + result.Values[0] = 0; + else + { + if (WindowSize > 0) + { + // Step 2: Computing the p-value score + rawScore = (float)result.Values[1]; + if (Parent.ThresholdScore == AlertingScore.RawScore) + { + switch (Parent.Side) + { + case AnomalySide.Negative: + rawScore = (float)(-result.Values[1]); + break; - if (!inputSchema.TryGetColumnIndex(parent.InputColumnName, out _inputColumnIndex)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName); + case AnomalySide.Positive: + break; - var colType = inputSchema[_inputColumnIndex].Type; - if (colType != NumberType.R4) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName, "float", colType.ToString()); + default: + rawScore = (float)Math.Abs(result.Values[1]); + break; + } + } + else + { + result.Values[2] = ComputeKernelPValue(rawScore); - _parent = parent; - _parentSchema = inputSchema; - _slotNames = new VBuffer>(4, new[] { "Alert".AsMemory(), "Raw Score".AsMemory(), - "P-Value Score".AsMemory(), "Martingale Score".AsMemory() }); + switch (Parent.Side) + { + case AnomalySide.Negative: + result.Values[2] = 1 - result.Values[2]; + break; - State = _parent.StateRef; - } + case AnomalySide.Positive: + break; - public Schema.DetachedColumn[] GetOutputColumns() - { - var meta = new MetadataBuilder(); - meta.AddSlotNames(_parent._outputLength, GetSlotNames); - var info = new Schema.DetachedColumn[1]; - info[0] = new Schema.DetachedColumn(_parent.OutputColumnName, new VectorType(NumberType.R8, _parent._outputLength), meta.GetMetadata()); - return info; - } + default: + result.Values[2] = Math.Min(result.Values[2], 1 - result.Values[2]); + break; + } - public void GetSlotNames(ref VBuffer> dst) => _slotNames.CopyTo(ref dst, 0, _parent._outputLength); + // Keeping the p-value in the safe range + if (result.Values[2] < SequentialAnomalyDetectionTransformBase.MinPValue) + result.Values[2] = SequentialAnomalyDetectionTransformBase.MinPValue; + else if (result.Values[2] > SequentialAnomalyDetectionTransformBase.MaxPValue) + result.Values[2] = SequentialAnomalyDetectionTransformBase.MaxPValue; - public Func GetDependencies(Func activeOutput) - { - if (activeOutput(0)) - return col => col == _inputColumnIndex; - else - return col => false; - } + RawScoreBuffer.AddLast(rawScore); + + // Step 3: Computing the martingale value + if (Parent.Martingale != MartingaleType.None && Parent.ThresholdScore == 
AlertingScore.MartingaleScore) + { + Double martingaleUpdate = 0; + switch (Parent.Martingale) + { + case MartingaleType.Power: + martingaleUpdate = Parent.LogPowerMartigaleBettingFunc(result.Values[2], Parent.PowerMartingaleEpsilon); + break; + + case MartingaleType.Mixture: + martingaleUpdate = Parent.LogMixtureMartigaleBettingFunc(result.Values[2]); + break; + } + + if (LogMartingaleUpdateBuffer.Count == 0) + { + for (int i = 0; i < LogMartingaleUpdateBuffer.Capacity; ++i) + LogMartingaleUpdateBuffer.AddLast(martingaleUpdate); + _logMartingaleValue += LogMartingaleUpdateBuffer.Capacity * martingaleUpdate; + } + else + { + _logMartingaleValue += martingaleUpdate; + _logMartingaleValue -= LogMartingaleUpdateBuffer.PeekFirst(); + LogMartingaleUpdateBuffer.AddLast(martingaleUpdate); + } - public void Save(ModelSaveContext ctx) => _parent.Save(ctx); + result.Values[3] = Math.Exp(_logMartingaleValue); + } + } + } - public Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) - { - disposer = null; - var getters = new Delegate[1]; - if (activeOutput(0)) - getters[0] = MakeGetter(input, State); + // Generating alert + bool alert = false; - return getters; - } + if (RawScoreBuffer.IsFull) // No alert until the buffer is completely full. + { + switch (Parent.ThresholdScore) + { + case AlertingScore.RawScore: + alert = rawScore >= Parent.AlertThreshold; + break; + case AlertingScore.PValueScore: + alert = result.Values[2] <= Parent.AlertThreshold; + break; + case AlertingScore.MartingaleScore: + alert = (Parent.Martingale != MartingaleType.None) && (result.Values[3] >= Parent.AlertThreshold); + + if (alert) + { + if (_martingaleAlertCounter > 0) + alert = false; + else + _martingaleAlertCounter = Parent.WindowSize; + } + + _martingaleAlertCounter--; + _martingaleAlertCounter = _martingaleAlertCounter < 0 ? 0 : _martingaleAlertCounter; + break; + } + } - private delegate void ProcessData(ref TInput src, ref VBuffer dst); + result.Values[0] = Convert.ToDouble(alert); + } - private Delegate MakeGetter(Row input, TState state) - { - _host.AssertValue(input); - var srcGetter = input.GetGetter(_inputColumnIndex); - ProcessData processData = _parent.WindowSize > 0 ? - (ProcessData)state.Process : state.ProcessWithoutBuffer; + dst = result.Commit(); + } - ValueGetter> valueGetter = (ref VBuffer dst) => + private protected override sealed void InitializeStateCore(bool disk = false) { - TInput src = default; - srcGetter(ref src); - processData(ref src, ref dst); - }; - return valueGetter; - } + Parent = (SequentialAnomalyDetectionTransformBase)ParentTransform; + Host.Assert(WindowSize >= 0); - public Action CreatePinger(Row input, Func activeOutput, out Action disposer) - { - disposer = null; - Action pinger = null; - if (activeOutput(0)) - pinger = MakePinger(input, State); + if (disk == false) + { + if (Parent.Martingale != MartingaleType.None) + LogMartingaleUpdateBuffer = new FixedSizeQueue(WindowSize == 0 ? 1 : WindowSize); + else + LogMartingaleUpdateBuffer = new FixedSizeQueue(1); - return pinger; - } + RawScoreBuffer = new FixedSizeQueue(WindowSize == 0 ? 
1 : WindowSize); - private Action MakePinger(Row input, TState state) - { - _host.AssertValue(input); - var srcGetter = input.GetGetter(_inputColumnIndex); - Action pinger = (long rowPosition) => - { - TInput src = default; - srcGetter(ref src); - state.UpdateState(ref src, rowPosition, _parent.WindowSize > 0); - }; - return pinger; - } + _logMartingaleValue = 0; + } - public void CloneState() - { - if (Interlocked.Increment(ref _parent.StateRefCount) > 1) - { - State = (TState)_parent.StateRef.Clone(); + InitializeAnomalyDetector(); } - } - public ITransformer GetTransformer() - { - return _parent; + /// + /// The abstract method that realizes the initialization functionality for the anomaly detector. + /// + private protected abstract void InitializeAnomalyDetector(); + + /// + /// The abstract method that realizes the main logic for calculating the raw anomaly score bfor the current input given a windowed buffer + /// + /// A reference to the input object. + /// A reference to the windowed buffer. + /// A long number that indicates the number of times ComputeRawAnomalyScore has been called so far (starting value = 0). + /// The raw anomaly score for the input. The Assumption is the higher absolute value of the raw score, the more anomalous the input is. + /// The sign of the score determines whether it's a positive anomaly or a negative one. + private protected abstract Double ComputeRawAnomalyScore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration); } } - } } diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs index 79e2c81846..728d3bc710 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs @@ -11,7 +11,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Transforms; -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// The box class that is used to box the TInput and TOutput for the LambdaTransform. @@ -38,7 +38,7 @@ public DataBox(T value) /// The input type of the sequential processing. /// The dst type of the sequential processing. /// The state type of the sequential processing. Must be a class inherited from StateBase - public abstract class SequentialTransformBase : TransformBase + internal abstract class SequentialTransformBase : TransformBase where TState : SequentialTransformBase.StateBase, new() { /// @@ -302,7 +302,7 @@ private protected SequentialTransformBase(IHostEnvironment env, ModelLoadContext _transform = CreateLambdaTransform(Host, input, OutputColumnName, InputColumnName, InitFunction, WindowSize > 0, ct); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(InitialWindowSize >= 0); diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs index dd25afe637..1c795b1bee 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs @@ -13,25 +13,25 @@ using Microsoft.ML.Model; using Microsoft.ML.Model.Onnx; using Microsoft.ML.Model.Pfa; -using Microsoft.ML.TimeSeries; -using Microsoft.ML.Transforms; -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { + /// /// The base class for sequential processing transforms. This class implements the basic sliding window buffering. 
The derived classes need to specify the transform logic, /// the initialization logic and the learning logic via implementing the abstract methods TransformCore(), InitializeStateCore() and LearnStateFromDataCore(), respectively /// /// The input type of the sequential processing. /// The dst type of the sequential processing. - /// The state type of the sequential processing. Must be a class inherited from StateBase - public abstract class SequentialTransformerBase : IStatefulTransformer, ICanSaveModel - where TState : SequentialTransformerBase.StateBase, new() + /// The dst type of the sequential processing. + internal abstract class SequentialTransformerBase : IStatefulTransformer + where TState : SequentialTransformerBase.StateBase, new() { + /// /// The base class for encapsulating the State object for sequential processing. This class implements a windowed buffer. /// - public abstract class StateBase + internal class StateBase { // Ideally this class should be private. However, due to the current constraints with the LambdaTransform, we need to have // access to the state class when inheriting from SequentialTransformerBase. @@ -61,7 +61,7 @@ public abstract class StateBase /// protected long RowCounter { get; private set; } - private protected StateBase() + public StateBase() { } @@ -204,7 +204,7 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) /// The abstract method that specifies the NA value for the dst type. /// /// - private protected abstract void SetNaOutput(ref TOutput dst); + private protected virtual void SetNaOutput(ref TOutput dst) { } /// /// The abstract method that realizes the main logic for the transform. @@ -213,41 +213,52 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) /// A reference to the dst object. /// A reference to the windowed buffer. /// A long number that indicates the number of times TransformCore has been called so far (starting value = 0). - private protected abstract void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref TOutput dst); + private protected virtual void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref TOutput dst) + { + + } /// /// The abstract method that realizes the logic for initializing the state object. /// - private protected abstract void InitializeStateCore(bool disk = false); + private protected virtual void InitializeStateCore(bool disk = false) + { + + } /// /// The abstract method that realizes the logic for learning the parameters and the initial state object from data. /// /// A queue of data points used for training - private protected abstract void LearnStateFromDataCore(FixedSizeQueue data); + private protected virtual void LearnStateFromDataCore(FixedSizeQueue data) + { + } - public abstract void Consume(TInput value); + public virtual void Consume(TInput value) + { - public StateBase Clone() + } + + public TState Clone() { - var clone = (StateBase)MemberwiseClone(); + var clone = (TState)MemberwiseClone(); CloneCore(clone); return clone; } - private protected virtual void CloneCore(StateBase state) + private protected virtual void CloneCore(TState state) { state.WindowedBuffer = WindowedBuffer.Clone(); state.InitialWindowedBuffer = InitialWindowedBuffer.Clone(); } } - private protected readonly IHost Host; + internal readonly IHost Host; /// /// The window size for buffering. 
/// - private protected readonly int WindowSize; + internal readonly int WindowSize; /// /// The number of datapoints from the beginning of the sequence that are used for learning the initial state. @@ -260,7 +271,7 @@ private protected virtual void CloneCore(StateBase state) public bool IsRowToRowMapper => false; - public TState StateRef { get; set; } + internal TState StateRef { get; set; } public int StateRefCount; @@ -320,7 +331,9 @@ private protected SequentialTransformerBase(IHost host, ModelLoadContext ctx) OutputColumnType = bs.LoadTypeDescriptionOrNull(ctx.Reader.BaseStream); } - public virtual void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected virtual void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(InitialWindowSize >= 0); @@ -344,9 +357,9 @@ public virtual void Save(ModelSaveContext ctx) public abstract Schema GetOutputSchema(Schema inputSchema); - private protected abstract IStatefulRowMapper MakeRowMapper(Schema schema); + internal abstract IStatefulRowMapper MakeRowMapper(Schema schema); - private protected SequentialDataTransform MakeDataTransform(IDataView input) + internal SequentialDataTransform MakeDataTransform(IDataView input) { Host.CheckValue(input, nameof(input)); return new SequentialDataTransform(Host, this, input, MakeRowMapper(input.Schema)); @@ -450,9 +463,9 @@ protected override RowCursor GetRowCursorCore(IEnumerable columns public override RowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int n, Random rand = null) => new RowCursor[] { GetRowCursorCore(columnsNeeded, rand) }; - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { - _parent.Save(ctx); + (_parent as ICanSaveModel).Save(ctx); } IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView newSource) @@ -637,7 +650,7 @@ public static TimeSeriesRowToRowMapperTransform Create(IHostEnvironment env, Mod return h.Apply("Loading Model", ch => new TimeSeriesRowToRowMapperTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.TimeSeries/SlidingWindowTransform.cs b/src/Microsoft.ML.TimeSeries/SlidingWindowTransform.cs index 96cddd0207..103d63f0fa 100644 --- a/src/Microsoft.ML.TimeSeries/SlidingWindowTransform.cs +++ b/src/Microsoft.ML.TimeSeries/SlidingWindowTransform.cs @@ -7,19 +7,19 @@ using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; [assembly: LoadableClass(SlidingWindowTransform.Summary, typeof(SlidingWindowTransform), typeof(SlidingWindowTransform.Arguments), typeof(SignatureDataTransform), SlidingWindowTransform.UserName, SlidingWindowTransform.LoaderSignature, SlidingWindowTransform.ShortName)] [assembly: LoadableClass(SlidingWindowTransform.Summary, typeof(SlidingWindowTransform), null, typeof(SignatureLoadDataTransform), SlidingWindowTransform.UserName, SlidingWindowTransform.LoaderSignature)] -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// Outputs a sliding window on a time series of type Single. 
/// - public sealed class SlidingWindowTransform : SlidingWindowTransformBase + internal sealed class SlidingWindowTransform : SlidingWindowTransformBase { public const string Summary = "Returns the last values for a time series [y(t-d-l+1), y(t-d-l+2), ..., y(t-l-1), y(t-l)] where d is the size of the window, l the lag and y is a Float."; public const string LoaderSignature = "SlideWinTransform"; @@ -49,7 +49,7 @@ public SlidingWindowTransform(IHostEnvironment env, ModelLoadContext ctx, IDataV // } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -57,7 +57,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format *** // - base.Save(ctx); + base.SaveModel(ctx); } } } diff --git a/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs b/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs index 1555c98afa..677ddc2488 100644 --- a/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs @@ -6,13 +6,11 @@ using Microsoft.Data.DataView; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; -using Microsoft.ML.Data.Conversion; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; -using Microsoft.ML.Transforms; -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// SlidingWindowTransformBase outputs a sliding window as a VBuffer from a series of any type. @@ -22,7 +20,7 @@ namespace Microsoft.ML.TimeSeriesProcessing /// and l is the delay. /// - public abstract class SlidingWindowTransformBase : SequentialTransformBase, SlidingWindowTransformBase.StateSlide> + internal abstract class SlidingWindowTransformBase : SequentialTransformBase, SlidingWindowTransformBase.StateSlide> { /// /// Defines what should be done about the first rows. @@ -104,13 +102,13 @@ private TInput GetNaValue() int index; sch.TryGetColumnIndex(InputColumnName, out index); ColumnType col = sch[index].Type; - TInput nanValue = Conversions.Instance.GetNAOrDefault(col); + TInput nanValue = Data.Conversion.Conversions.Instance.GetNAOrDefault(col); // We store the nan_value here to avoid getting it each time a state is instanciated. 
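// Editor's illustrative sketch (not part of the patch): what the sliding-window output described
// by the Summary string above looks like on a plain array. For position t, window size d and lag l
// it yields [y(t-d-l+1), ..., y(t-l)]; positions before the start of the series are padded with an
// NA value, mirroring the cached nan_value in the hunk above. Illustration only, not the
// transform's buffering code.
using System;

internal static class SlidingWindowSketch
{
    public static float[] WindowAt(float[] y, int t, int windowSize, int lag, float naValue = float.NaN)
    {
        var window = new float[windowSize];
        for (int i = 0; i < windowSize; ++i)
        {
            int idx = t - lag - windowSize + 1 + i;   // runs from t-d-l+1 up to t-l
            window[i] = idx >= 0 ? y[idx] : naValue;  // pad rows before the series start
        }
        return window;
    }

    public static void Main()
    {
        var y = new float[] { 10f, 11f, 12f, 13f, 14f };
        // Window of size 3 with lag 1 at t = 4: [y(1), y(2), y(3)] = [11, 12, 13].
        Console.WriteLine(string.Join(", ", WindowAt(y, t: 4, windowSize: 3, lag: 1)));
    }
}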
return nanValue; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(WindowSize >= 1); @@ -123,7 +121,7 @@ public override void Save(ModelSaveContext ctx) // Int32 lag // byte begin - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write(_lag); ctx.Writer.Write((byte)_begin); } diff --git a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs index eac9a8ebeb..3a7411693f 100644 --- a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs @@ -6,31 +6,32 @@ using System.IO; using Microsoft.Data.DataView; using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeries; +using Microsoft.ML.Transforms.TimeSeries; -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { + public enum ErrorFunction : byte + { + SignedDifference, + AbsoluteDifference, + SignedProportion, + AbsoluteProportion, + SquaredDifference + } + /// /// Provides the utility functions for different error functions for computing deviation. /// - public static class ErrorFunctionUtils + internal static class ErrorFunctionUtils { public const string ErrorFunctionHelpText = "The error function should be either (0) SignedDifference, (1) AbsoluteDifference, (2) SignedProportion" + " (3) AbsoluteProportion or (4) SquaredDifference."; - public enum ErrorFunction : byte - { - SignedDifference, - AbsoluteDifference, - SignedProportion, - AbsoluteProportion, - SquaredDifference - } - - public static Double SignedDifference(Double actual, Double predicted) + public static double SignedDifference(double actual, Double predicted) { return actual - predicted; } @@ -82,12 +83,70 @@ public static Func GetErrorFunction(ErrorFunction errorF } /// - /// This base class that implements the general anomaly detection transform based on Singular Spectrum modeling of the time-series. + /// The wrapper to that implements the general anomaly detection transform based on Singular Spectrum modeling of the time-series. /// For the details of the Singular Spectrum Analysis (SSA), refer to http://arxiv.org/pdf/1206.6910.pdf. /// - public abstract class SsaAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase + public class SsaAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveModel { - public abstract class SsaArguments : ArgumentsBase + /// + /// Whether a call to should succeed, on an + /// appropriate schema. + /// + public bool IsRowToRowMapper => InternalTransform.IsRowToRowMapper; + + /// + /// Creates a clone of the transfomer. Used for taking the snapshot of the state. + /// + /// + IStatefulTransformer IStatefulTransformer.Clone() => InternalTransform.Clone(); + + /// + /// Schema propagation for transformers. + /// Returns the output schema of the data, if the input schema is like the one provided. + /// + public Schema GetOutputSchema(Schema inputSchema) => InternalTransform.GetOutputSchema(inputSchema); + + /// + /// Constructs a row-to-row mapper based on an input schema. If + /// is false, then an exception should be thrown. If the input schema is in any way + /// unsuitable for constructing the mapper, an exception should likewise be thrown. + /// + /// The input schema for which we should get the mapper. + /// The row to row mapper. 
+ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) => InternalTransform.GetRowToRowMapper(inputSchema); + + /// + /// Same as but also supports mechanism to save the state. + /// + /// The input schema for which we should get the mapper. + /// The row to row mapper. + public IRowToRowMapper GetStatefulRowToRowMapper(Schema inputSchema) => ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema); + + /// + /// Take the data in, make transformations, output the data. + /// Note that 's are lazy, so no actual transformations happen here, just schema validation. + /// + public IDataView Transform(IDataView input) => InternalTransform.Transform(input); + + /// + /// For saving a model into a repository. + /// + public virtual void Save(ModelSaveContext ctx) => InternalTransform.SaveThis(ctx); + + /// + /// Creates a row mapper from Schema. + /// + internal IStatefulRowMapper MakeRowMapper(Schema schema) => InternalTransform.MakeRowMapper(schema); + + /// + /// Creates an IDataTransform from an IDataView. + /// + internal IDataTransform MakeDataTransform(IDataView input) => InternalTransform.MakeDataTransform(input); + + /// + /// Options for SSA Anomaly Detection. + /// + internal abstract class SsaOptions : ArgumentsBase { [Argument(ArgumentType.Required, HelpText = "The inner window size for SSA in [2, windowSize]", ShortName = "swnd", SortOrder = 11)] public int SeasonalWindowSize; @@ -96,178 +155,204 @@ public abstract class SsaArguments : ArgumentsBase public Single DiscountFactor = 1; [Argument(ArgumentType.AtMostOnce, HelpText = "The function used to compute the error between the expected and the observed value", ShortName = "err", SortOrder = 13)] - public ErrorFunctionUtils.ErrorFunction ErrorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference; + public ErrorFunction ErrorFunction = ErrorFunction.SignedDifference; [Argument(ArgumentType.AtMostOnce, HelpText = "The flag determing whether the model is adaptive", ShortName = "adp", SortOrder = 14)] public bool IsAdaptive = false; } - protected readonly int SeasonalWindowSize; - protected readonly Single DiscountFactor; - protected readonly bool IsAdaptive; - protected readonly ErrorFunctionUtils.ErrorFunction ErrorFunction; - protected readonly Func ErrorFunc; - protected SequenceModelerBase Model; + internal SsaAnomalyDetectionBase InternalTransform; - public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment env) - : base(args.WindowSize, 0, args.Source, args.Name, name, env, args.Side, args.Martingale, args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold) + internal SsaAnomalyDetectionBaseWrapper(SsaOptions options, string name, IHostEnvironment env) { - Host.CheckUserArg(2 <= args.SeasonalWindowSize, nameof(args.SeasonalWindowSize), "Must be at least 2."); - Host.CheckUserArg(0 <= args.DiscountFactor && args.DiscountFactor <= 1, nameof(args.DiscountFactor), "Must be in the range [0, 1]."); - Host.CheckUserArg(Enum.IsDefined(typeof(ErrorFunctionUtils.ErrorFunction), args.ErrorFunction), nameof(args.ErrorFunction), ErrorFunctionUtils.ErrorFunctionHelpText); - - SeasonalWindowSize = args.SeasonalWindowSize; - DiscountFactor = args.DiscountFactor; - ErrorFunction = args.ErrorFunction; - ErrorFunc = ErrorFunctionUtils.GetErrorFunction(ErrorFunction); - IsAdaptive = args.IsAdaptive; - // Creating the master SSA model - Model = new AdaptiveSingularSpectrumSequenceModeler(Host, args.InitialWindowSize, SeasonalWindowSize + 1, SeasonalWindowSize, - DiscountFactor, 
AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalWindowSize / 2, false, false); - - StateRef = new State(); - StateRef.InitState(WindowSize, InitialWindowSize, this, Host); + InternalTransform = new SsaAnomalyDetectionBase(options, name, env, this); } - public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name) - : base(env, ctx, name) + internal SsaAnomalyDetectionBaseWrapper(IHostEnvironment env, ModelLoadContext ctx, string name) { - // *** Binary format *** - // - // int: _seasonalWindowSize - // float: _discountFactor - // byte: _errorFunction - // bool: _isAdaptive - // AdaptiveSingularSpectrumSequenceModeler: _model - - Host.CheckDecode(InitialWindowSize == 0); - - SeasonalWindowSize = ctx.Reader.ReadInt32(); - Host.CheckDecode(2 <= SeasonalWindowSize); - - DiscountFactor = ctx.Reader.ReadSingle(); - Host.CheckDecode(0 <= DiscountFactor && DiscountFactor <= 1); - - byte temp; - temp = ctx.Reader.ReadByte(); - Host.CheckDecode(Enum.IsDefined(typeof(ErrorFunctionUtils.ErrorFunction), temp)); - ErrorFunction = (ErrorFunctionUtils.ErrorFunction)temp; - ErrorFunc = ErrorFunctionUtils.GetErrorFunction(ErrorFunction); - - IsAdaptive = ctx.Reader.ReadBoolean(); - StateRef = new State(ctx.Reader); - - ctx.LoadModel, SignatureLoadModel>(env, out Model, "SSA"); - Host.CheckDecode(Model != null); - StateRef.InitState(this, Host); + InternalTransform = new SsaAnomalyDetectionBase(env, ctx, name); } - public override Schema GetOutputSchema(Schema inputSchema) + /// + /// This base class implements the general anomaly detection transform based on Singular Spectrum modeling of the time-series. + /// For the details of the Singular Spectrum Analysis (SSA), refer to http://arxiv.org/pdf/1206.6910.pdf.
+ /// + internal sealed class SsaAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase { - Host.CheckValue(inputSchema, nameof(inputSchema)); - - if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName); + internal SsaAnomalyDetectionBaseWrapper Parent; + internal readonly int SeasonalWindowSize; + internal readonly Single DiscountFactor; + internal readonly bool IsAdaptive; + internal readonly ErrorFunction ErrorFunction; + internal readonly Func ErrorFunc; + internal SequenceModelerBase Model; + + public SsaAnomalyDetectionBase(SsaOptions options, string name, IHostEnvironment env, SsaAnomalyDetectionBaseWrapper parent) + : base(options.WindowSize, 0, options.Source, options.Name, name, env, options.Side, options.Martingale, options.AlertOn, options.PowerMartingaleEpsilon, options.AlertThreshold) + { + Host.CheckUserArg(2 <= options.SeasonalWindowSize, nameof(options.SeasonalWindowSize), "Must be at least 2."); + Host.CheckUserArg(0 <= options.DiscountFactor && options.DiscountFactor <= 1, nameof(options.DiscountFactor), "Must be in the range [0, 1]."); + Host.CheckUserArg(Enum.IsDefined(typeof(ErrorFunction), options.ErrorFunction), nameof(options.ErrorFunction), ErrorFunctionUtils.ErrorFunctionHelpText); + + SeasonalWindowSize = options.SeasonalWindowSize; + DiscountFactor = options.DiscountFactor; + ErrorFunction = options.ErrorFunction; + ErrorFunc = ErrorFunctionUtils.GetErrorFunction(ErrorFunction); + IsAdaptive = options.IsAdaptive; + // Creating the master SSA model + Model = new AdaptiveSingularSpectrumSequenceModeler(Host, options.InitialWindowSize, SeasonalWindowSize + 1, SeasonalWindowSize, + DiscountFactor, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalWindowSize / 2, false, false); + + StateRef = new State(); + StateRef.InitState(WindowSize, InitialWindowSize, this, Host); + Parent = parent; + } - var colType = inputSchema[col].Type; - if (colType != NumberType.R4) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, "float", colType.ToString()); + public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name) + : base(env, ctx, name) + { + // *** Binary format *** + // + // int: _seasonalWindowSize + // float: _discountFactor + // byte: _errorFunction + // bool: _isAdaptive + // AdaptiveSingularSpectrumSequenceModeler: _model + + Host.CheckDecode(InitialWindowSize == 0); + + SeasonalWindowSize = ctx.Reader.ReadInt32(); + Host.CheckDecode(2 <= SeasonalWindowSize); + + DiscountFactor = ctx.Reader.ReadSingle(); + Host.CheckDecode(0 <= DiscountFactor && DiscountFactor <= 1); + + byte temp; + temp = ctx.Reader.ReadByte(); + Host.CheckDecode(Enum.IsDefined(typeof(ErrorFunction), temp)); + ErrorFunction = (ErrorFunction)temp; + ErrorFunc = ErrorFunctionUtils.GetErrorFunction(ErrorFunction); + + IsAdaptive = ctx.Reader.ReadBoolean(); + StateRef = new State(ctx.Reader); + + ctx.LoadModel, SignatureLoadModel>(env, out Model, "SSA"); + Host.CheckDecode(Model != null); + StateRef.InitState(this, Host); + } - return Transform(new EmptyDataView(Host, inputSchema)).Schema; - } + public override Schema GetOutputSchema(Schema inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); - public override void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - - Host.Assert(InitialWindowSize == 0); - Host.Assert(2 <= SeasonalWindowSize); - Host.Assert(0 <= 
DiscountFactor && DiscountFactor <= 1); - Host.Assert(Enum.IsDefined(typeof(ErrorFunctionUtils.ErrorFunction), ErrorFunction)); - Host.Assert(Model != null); - - // *** Binary format *** - // - // int: _seasonalWindowSize - // float: _discountFactor - // byte: _errorFunction - // bool: _isAdaptive - // State: StateRef - // AdaptiveSingularSpectrumSequenceModeler: _model - - base.Save(ctx); - ctx.Writer.Write(SeasonalWindowSize); - ctx.Writer.Write(DiscountFactor); - ctx.Writer.Write((byte)ErrorFunction); - ctx.Writer.Write(IsAdaptive); - StateRef.Save(ctx.Writer); - - ctx.SaveModel(Model, "SSA"); - } + if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName); - public sealed class State : AnomalyDetectionStateBase - { - private SequenceModelerBase _model; - private SsaAnomalyDetectionBase _parentAnomalyDetector; + var colType = inputSchema[col].Type; + if (colType != NumberType.R4) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, "float", colType.ToString()); - public State() - { + return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - internal State(BinaryReader reader) : base(reader) + private protected override void SaveModel(ModelSaveContext ctx) { - WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); - InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); + Parent.Save(ctx); } - internal override void Save(BinaryWriter writer) + internal void SaveThis(ModelSaveContext ctx) { - base.Save(writer); - TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer); - TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer); + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + + Host.Assert(InitialWindowSize == 0); + Host.Assert(2 <= SeasonalWindowSize); + Host.Assert(0 <= DiscountFactor && DiscountFactor <= 1); + Host.Assert(Enum.IsDefined(typeof(ErrorFunction), ErrorFunction)); + Host.Assert(Model != null); + + // *** Binary format *** + // + // int: _seasonalWindowSize + // float: _discountFactor + // byte: _errorFunction + // bool: _isAdaptive + // State: StateRef + // AdaptiveSingularSpectrumSequenceModeler: _model + + base.SaveModel(ctx); + ctx.Writer.Write(SeasonalWindowSize); + ctx.Writer.Write(DiscountFactor); + ctx.Writer.Write((byte)ErrorFunction); + ctx.Writer.Write(IsAdaptive); + StateRef.Save(ctx.Writer); + + ctx.SaveModel(Model, "SSA"); } - private protected override void CloneCore(StateBase state) + internal sealed class State : AnomalyDetectionStateBase { - base.CloneCore(state); - Contracts.Assert(state is State); - var stateLocal = state as State; - stateLocal.WindowedBuffer = WindowedBuffer.Clone(); - stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone(); - if (_model != null) + private SequenceModelerBase _model; + private SsaAnomalyDetectionBase _parentAnomalyDetector; + + public State() { - _parentAnomalyDetector.Model = _parentAnomalyDetector.Model.Clone(); - _model = _parentAnomalyDetector.Model; } - } - private protected override void LearnStateFromDataCore(FixedSizeQueue data) - { - // This method is empty because there is no need to implement a training logic here. 
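The save path set up above is intentionally circular: the wrapper's Save calls InternalTransform.SaveThis, while the internal transform's SaveModel calls Parent.Save, so the bytes are always written by SaveThis no matter which of the two objects the framework asks to serialize. A small sketch of that round trip with assumed toy types, not the ML.NET ones:

```csharp
using System;
using System.Collections.Generic;

public sealed class SaveContext { public List<string> Lines = new List<string>(); }

internal sealed class InternalPart
{
    internal WrapperPart Parent;
    private readonly int _windowSize;

    internal InternalPart(int windowSize, WrapperPart parent) { _windowSize = windowSize; Parent = parent; }

    // Called when the framework saves the *internal* object: defer to the wrapper...
    internal void SaveModel(SaveContext ctx) => Parent.Save(ctx);

    // ...which immediately comes back here, where the actual serialization lives.
    internal void SaveThis(SaveContext ctx) => ctx.Lines.Add($"windowSize={_windowSize}");
}

public sealed class WrapperPart
{
    internal InternalPart InternalTransform;
    public WrapperPart(int windowSize) => InternalTransform = new InternalPart(windowSize, this);

    // Called when the framework saves the *wrapper*.
    public void Save(SaveContext ctx) => InternalTransform.SaveThis(ctx);
}

internal static class SaveDelegationDemo
{
    private static void Main()
    {
        var wrapper = new WrapperPart(windowSize: 10);
        var viaWrapper = new SaveContext();
        var viaInternal = new SaveContext();

        wrapper.Save(viaWrapper);                          // wrapper -> SaveThis
        wrapper.InternalTransform.SaveModel(viaInternal);  // internal -> Parent.Save -> SaveThis

        Console.WriteLine(viaWrapper.Lines[0] == viaInternal.Lines[0]); // True: same content either way
    }
}
```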
- } + internal State(BinaryReader reader) : base(reader) + { + WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); + InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); + } - private protected override void InitializeAnomalyDetector() - { - _parentAnomalyDetector = (SsaAnomalyDetectionBase)Parent; - _model = _parentAnomalyDetector.Model; - } + internal override void Save(BinaryWriter writer) + { + base.Save(writer); + TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer); + TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer); + } - private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration) - { - // Get the prediction for the next point opn the series - Single expectedValue = 0; - _model.PredictNext(ref expectedValue); + private protected override void CloneCore(State state) + { + base.CloneCore(state); + Contracts.Assert(state is State); + var stateLocal = state as State; + stateLocal.WindowedBuffer = WindowedBuffer.Clone(); + stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone(); + if (_model != null) + { + _parentAnomalyDetector.Model = _parentAnomalyDetector.Model.Clone(); + _model = _parentAnomalyDetector.Model; + } + } - if (PreviousPosition == -1) - // Feed the current point to the model - _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive); + private protected override void LearnStateFromDataCore(FixedSizeQueue data) + { + // This method is empty because there is no need to implement a training logic here. + } - // Return the error as the raw anomaly score - return _parentAnomalyDetector.ErrorFunc(input, expectedValue); - } + private protected override void InitializeAnomalyDetector() + { + _parentAnomalyDetector = (SsaAnomalyDetectionBase)Parent; + _model = _parentAnomalyDetector.Model; + } + + private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration) + { + // Get the prediction for the next point opn the series + Single expectedValue = 0; + _model.PredictNext(ref expectedValue); + + if (PreviousPosition == -1) + // Feed the current point to the model + _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive); - public override void Consume(Single input) => _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive); + // Return the error as the raw anomaly score + return _parentAnomalyDetector.ErrorFunc(input, expectedValue); + } + + public override void Consume(Single input) => _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive); + } } } } diff --git a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs index 3f5e3cb009..e6ea241e26 100644 --- a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs @@ -12,11 +12,9 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeries; -using Microsoft.ML.TimeSeriesProcessing; -using static Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; +using Microsoft.ML.Transforms.TimeSeries; -[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(IDataTransform), typeof(SsaChangePointDetector), typeof(SsaChangePointDetector.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(IDataTransform), typeof(SsaChangePointDetector), 
typeof(SsaChangePointDetector.Options), typeof(SignatureDataTransform), SsaChangePointDetector.UserName, SsaChangePointDetector.LoaderSignature, SsaChangePointDetector.ShortName)] [assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(IDataTransform), typeof(SsaChangePointDetector), null, typeof(SignatureLoadDataTransform), @@ -28,20 +26,20 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(SsaChangePointDetector), null, typeof(SignatureLoadRowMapper), SsaChangePointDetector.UserName, SsaChangePointDetector.LoaderSignature)] -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// This class implements the change point detector transform based on Singular Spectrum modeling of the time-series. /// For the details of the Singular Spectrum Analysis (SSA), refer to http://arxiv.org/pdf/1206.6910.pdf. /// - public sealed class SsaChangePointDetector : SsaAnomalyDetectionBase + public sealed class SsaChangePointDetector : SsaAnomalyDetectionBaseWrapper, IStatefulTransformer { internal const string Summary = "This transform detects the change-points in a seasonal time-series using Singular Spectrum Analysis (SSA)."; internal const string LoaderSignature = "SsaChangePointDetector"; internal const string UserName = "SSA Change Point Detection"; internal const string ShortName = "chgpnt"; - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Required, HelpText = "The name of the source column.", ShortName = "src", SortOrder = 1, Purpose = SpecialPurpose.ColumnName)] @@ -67,32 +65,32 @@ public sealed class Arguments : TransformInputBase public int SeasonalWindowSize = 10; [Argument(ArgumentType.AtMostOnce, HelpText = "The function used to compute the error between the expected and the observed value.", ShortName = "err", SortOrder = 103)] - public ErrorFunctionUtils.ErrorFunction ErrorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference; + public ErrorFunction ErrorFunction = ErrorFunction.SignedDifference; [Argument(ArgumentType.AtMostOnce, HelpText = "The martingale used for scoring.", ShortName = "mart", SortOrder = 104)] - public MartingaleType Martingale = SequentialAnomalyDetectionTransformBase.MartingaleType.Power; + public MartingaleType Martingale = MartingaleType.Power; [Argument(ArgumentType.AtMostOnce, HelpText = "The epsilon parameter for the Power martingale.", ShortName = "eps", SortOrder = 105)] public double PowerMartingaleEpsilon = 0.1; } - private sealed class BaseArguments : SsaArguments + private sealed class BaseArguments : SsaOptions { - public BaseArguments(Arguments args) + public BaseArguments(Options options) { - Source = args.Source; - Name = args.Name; - Side = SequentialAnomalyDetectionTransformBase.AnomalySide.TwoSided; - WindowSize = args.ChangeHistoryLength; - InitialWindowSize = args.TrainingWindowSize; - SeasonalWindowSize = args.SeasonalWindowSize; - Martingale = args.Martingale; - PowerMartingaleEpsilon = args.PowerMartingaleEpsilon; - AlertOn = SequentialAnomalyDetectionTransformBase.AlertingScore.MartingaleScore; + Source = options.Source; + Name = options.Name; + Side = AnomalySide.TwoSided; + WindowSize = options.ChangeHistoryLength; + InitialWindowSize = options.TrainingWindowSize; + SeasonalWindowSize = options.SeasonalWindowSize; + Martingale = options.Martingale; + PowerMartingaleEpsilon = options.PowerMartingaleEpsilon; + AlertOn = AlertingScore.MartingaleScore; DiscountFactor = 1; IsAdaptive = false; - 
ErrorFunction = args.ErrorFunction; + ErrorFunction = options.ErrorFunction; } } @@ -106,48 +104,48 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(SsaChangePointDetector).Assembly.FullName); } - internal SsaChangePointDetector(IHostEnvironment env, Arguments args, IDataView input) - : this(env, args) + internal SsaChangePointDetector(IHostEnvironment env, Options options, IDataView input) + : this(env, options) { - Model.Train(new RoleMappedData(input, null, InputColumnName)); + InternalTransform.Model.Train(new RoleMappedData(input, null, InternalTransform.InputColumnName)); } // Factory method for SignatureDataTransform. - private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - return new SsaChangePointDetector(env, args, input).MakeDataTransform(input); + return new SsaChangePointDetector(env, options, input).MakeDataTransform(input); } - internal override IStatefulTransformer Clone() + IStatefulTransformer IStatefulTransformer.Clone() { var clone = (SsaChangePointDetector)MemberwiseClone(); - clone.Model = clone.Model.Clone(); - clone.StateRef = (State)clone.StateRef.Clone(); - clone.StateRef.InitState(clone, Host); + clone.InternalTransform.Model = clone.InternalTransform.Model.Clone(); + clone.InternalTransform.StateRef = (SsaAnomalyDetectionBase.State)clone.InternalTransform.StateRef.Clone(); + clone.InternalTransform.StateRef.InitState(clone.InternalTransform, InternalTransform.Host); return clone; } - internal SsaChangePointDetector(IHostEnvironment env, Arguments args) - : base(new BaseArguments(args), LoaderSignature, env) + internal SsaChangePointDetector(IHostEnvironment env, Options options) + : base(new BaseArguments(options), LoaderSignature, env) { - switch (Martingale) + switch (InternalTransform.Martingale) { case MartingaleType.None: - AlertThreshold = Double.MaxValue; + InternalTransform.AlertThreshold = Double.MaxValue; break; case MartingaleType.Power: - AlertThreshold = Math.Exp(WindowSize * LogPowerMartigaleBettingFunc(1 - args.Confidence / 100, PowerMartingaleEpsilon)); + InternalTransform.AlertThreshold = Math.Exp(InternalTransform.WindowSize * InternalTransform.LogPowerMartigaleBettingFunc(1 - options.Confidence / 100, InternalTransform.PowerMartingaleEpsilon)); break; case MartingaleType.Mixture: - AlertThreshold = Math.Exp(WindowSize * LogMixtureMartigaleBettingFunc(1 - args.Confidence / 100)); + InternalTransform.AlertThreshold = Math.Exp(InternalTransform.WindowSize * InternalTransform.LogMixtureMartigaleBettingFunc(1 - options.Confidence / 100)); break; default: - Host.Assert(!Enum.IsDefined(typeof(MartingaleType), Martingale)); - throw Host.ExceptUserArg(nameof(args.Martingale), "Value not defined."); + InternalTransform.Host.Assert(!Enum.IsDefined(typeof(MartingaleType), InternalTransform.Martingale)); + throw InternalTransform.Host.ExceptUserArg(nameof(options.Martingale), "Value not defined."); } } @@ -177,22 +175,22 @@ internal SsaChangePointDetector(IHostEnvironment env, ModelLoadContext ctx) // *** Binary format *** // - Host.CheckDecode(ThresholdScore == AlertingScore.MartingaleScore); - Host.CheckDecode(Side == AnomalySide.TwoSided); - Host.CheckDecode(DiscountFactor == 1); - Host.CheckDecode(IsAdaptive == false); + 
InternalTransform.Host.CheckDecode(InternalTransform.ThresholdScore == AlertingScore.MartingaleScore); + InternalTransform.Host.CheckDecode(InternalTransform.Side == AnomalySide.TwoSided); + InternalTransform.Host.CheckDecode(InternalTransform.DiscountFactor == 1); + InternalTransform.Host.CheckDecode(InternalTransform.IsAdaptive == false); } public override void Save(ModelSaveContext ctx) { - Host.CheckValue(ctx, nameof(ctx)); + InternalTransform.Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - Host.Assert(ThresholdScore == AlertingScore.MartingaleScore); - Host.Assert(Side == AnomalySide.TwoSided); - Host.Assert(DiscountFactor == 1); - Host.Assert(IsAdaptive == false); + InternalTransform.Host.Assert(InternalTransform.ThresholdScore == AlertingScore.MartingaleScore); + InternalTransform.Host.Assert(InternalTransform.Side == AnomalySide.TwoSided); + InternalTransform.Host.Assert(InternalTransform.DiscountFactor == 1); + InternalTransform.Host.Assert(InternalTransform.IsAdaptive == false); // *** Binary format *** // @@ -219,7 +217,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sch public sealed class SsaChangePointEstimator : IEstimator { private readonly IHost _host; - private readonly SsaChangePointDetector.Arguments _args; + private readonly SsaChangePointDetector.Options _options; /// /// Create a new instance of @@ -235,16 +233,16 @@ public sealed class SsaChangePointEstimator : IEstimator /// The function used to compute the error between the expected and the observed value. /// The martingale used for scoring. /// The epsilon parameter for the Power martingale. - public SsaChangePointEstimator(IHostEnvironment env, string outputColumnName, + internal SsaChangePointEstimator(IHostEnvironment env, string outputColumnName, int confidence, int changeHistoryLength, int trainingWindowSize, int seasonalityWindowSize, string inputColumnName = null, - ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference, + ErrorFunction errorFunction = ErrorFunction.SignedDifference, MartingaleType martingale = MartingaleType.Power, double eps = 0.1) - : this(env, new SsaChangePointDetector.Arguments + : this(env, new SsaChangePointDetector.Options { Name = outputColumnName, Source = inputColumnName ?? outputColumnName, @@ -259,38 +257,45 @@ public SsaChangePointEstimator(IHostEnvironment env, string outputColumnName, { } - public SsaChangePointEstimator(IHostEnvironment env, SsaChangePointDetector.Arguments args) + internal SsaChangePointEstimator(IHostEnvironment env, SsaChangePointDetector.Options options) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(SsaChangePointEstimator)); - _host.CheckNonEmpty(args.Name, nameof(args.Name)); - _host.CheckNonEmpty(args.Source, nameof(args.Source)); + _host.CheckNonEmpty(options.Name, nameof(options.Name)); + _host.CheckNonEmpty(options.Source, nameof(options.Source)); - _args = args; + _options = options; } + /// + /// Train and return a transformer. + /// public SsaChangePointDetector Fit(IDataView input) { _host.CheckValue(input, nameof(input)); - return new SsaChangePointDetector(_host, _args, input); + return new SsaChangePointDetector(_host, _options, input); } + /// + /// Schema propagation for transformers. + /// Returns the output schema of the data, if the input schema is like the one provided. 
+ /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - if (!inputSchema.TryFindColumn(_args.Source, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source); + if (!inputSchema.TryFindColumn(_options.Source, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _options.Source); if (col.ItemType != NumberType.R4) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source, "float", col.GetTypeString()); + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _options.Source, "float", col.GetTypeString()); var metadata = new List() { new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) }; var resultDic = inputSchema.ToDictionary(x => x.Name); - resultDic[_args.Name] = new SchemaShape.Column( - _args.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); + resultDic[_options.Name] = new SchemaShape.Column( + _options.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); return new SchemaShape(resultDic.Values); } diff --git a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs index 02b81458b3..399bc6b6e0 100644 --- a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs @@ -11,11 +11,9 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Model; -using Microsoft.ML.TimeSeries; -using Microsoft.ML.TimeSeriesProcessing; -using static Microsoft.ML.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; +using Microsoft.ML.Transforms.TimeSeries; -[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(IDataTransform), typeof(SsaSpikeDetector), typeof(SsaSpikeDetector.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(IDataTransform), typeof(SsaSpikeDetector), typeof(SsaSpikeDetector.Options), typeof(SignatureDataTransform), SsaSpikeDetector.UserName, SsaSpikeDetector.LoaderSignature, SsaSpikeDetector.ShortName)] [assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(IDataTransform), typeof(SsaSpikeDetector), null, typeof(SignatureLoadDataTransform), @@ -27,20 +25,20 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(SsaSpikeDetector), null, typeof(SignatureLoadRowMapper), SsaSpikeDetector.UserName, SsaSpikeDetector.LoaderSignature)] -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// This class implements the spike detector transform based on Singular Spectrum modeling of the time-series. /// For the details of the Singular Spectrum Analysis (SSA), refer to http://arxiv.org/pdf/1206.6910.pdf. 
/// - public sealed class SsaSpikeDetector : SsaAnomalyDetectionBase + public sealed class SsaSpikeDetector : SsaAnomalyDetectionBaseWrapper, IStatefulTransformer { internal const string Summary = "This transform detects the spikes in a seasonal time-series using Singular Spectrum Analysis (SSA)."; - public const string LoaderSignature = "SsaSpikeDetector"; - public const string UserName = "SSA Spike Detection"; - public const string ShortName = "spike"; + internal const string LoaderSignature = "SsaSpikeDetector"; + internal const string UserName = "SSA Spike Detection"; + internal const string ShortName = "spike"; - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Required, HelpText = "The name of the source column.", ShortName = "src", SortOrder = 1, Purpose = SpecialPurpose.ColumnName)] @@ -70,24 +68,24 @@ public sealed class Arguments : TransformInputBase public int SeasonalWindowSize = 10; [Argument(ArgumentType.AtMostOnce, HelpText = "The function used to compute the error between the expected and the observed value.", ShortName = "err", SortOrder = 103)] - public ErrorFunctionUtils.ErrorFunction ErrorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference; + public ErrorFunction ErrorFunction = ErrorFunction.SignedDifference; } - private sealed class BaseArguments : SsaArguments + private sealed class BaseArguments : SsaOptions { - public BaseArguments(Arguments args) + public BaseArguments(Options options) { - Source = args.Source; - Name = args.Name; - Side = args.Side; - WindowSize = args.PvalueHistoryLength; - InitialWindowSize = args.TrainingWindowSize; - SeasonalWindowSize = args.SeasonalWindowSize; - AlertThreshold = 1 - args.Confidence / 100; - AlertOn = SequentialAnomalyDetectionTransformBase.AlertingScore.PValueScore; + Source = options.Source; + Name = options.Name; + Side = options.Side; + WindowSize = options.PvalueHistoryLength; + InitialWindowSize = options.TrainingWindowSize; + SeasonalWindowSize = options.SeasonalWindowSize; + AlertThreshold = 1 - options.Confidence / 100; + AlertOn = AlertingScore.PValueScore; DiscountFactor = 1; IsAdaptive = false; - ErrorFunction = args.ErrorFunction; + ErrorFunction = options.ErrorFunction; Martingale = MartingaleType.None; } } @@ -103,24 +101,24 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(SsaSpikeDetector).Assembly.FullName); } - internal SsaSpikeDetector(IHostEnvironment env, Arguments args, IDataView input) - : base(new BaseArguments(args), LoaderSignature, env) + internal SsaSpikeDetector(IHostEnvironment env, Options options, IDataView input) + : base(new BaseArguments(options), LoaderSignature, env) { - Model.Train(new RoleMappedData(input, null, InputColumnName)); + InternalTransform.Model.Train(new RoleMappedData(input, null, InternalTransform.InputColumnName)); } // Factory method for SignatureDataTransform. 
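In the BaseArguments mappings above, the user-facing confidence (a percentage) becomes the alerting threshold: the spike detector alerts when the p-value falls below 1 - confidence / 100, while the change-point detector earlier in this diff pushes the same quantity through a martingale betting function before exponentiating. A tiny worked sketch of the p-value side, with made-up numbers:

```csharp
using System;

internal static class ConfidenceToThreshold
{
    private static void Main()
    {
        // Hypothetical user setting: "raise an alert when we are at least 99% confident".
        double confidence = 99.0;

        // Mirrors AlertThreshold = 1 - options.Confidence / 100 from the spike detector's
        // BaseArguments: alerts fire when the computed p-value drops below this threshold.
        double pValueThreshold = 1 - confidence / 100; // ~0.01

        double observedPValue = 0.004;
        bool alert = observedPValue < pValueThreshold;

        Console.WriteLine($"threshold={pValueThreshold}, p={observedPValue}, alert={alert}"); // alert=True
    }
}
```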
- private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - return new SsaSpikeDetector(env, args, input).MakeDataTransform(input); + return new SsaSpikeDetector(env, options, input).MakeDataTransform(input); } - internal SsaSpikeDetector(IHostEnvironment env, Arguments args) - : base(new BaseArguments(args), LoaderSignature, env) + internal SsaSpikeDetector(IHostEnvironment env, Options options) + : base(new BaseArguments(options), LoaderSignature, env) { // This constructor is empty. } @@ -135,12 +133,12 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, return new SsaSpikeDetector(env, ctx).MakeDataTransform(input); } - internal override IStatefulTransformer Clone() + IStatefulTransformer IStatefulTransformer.Clone() { var clone = (SsaSpikeDetector)MemberwiseClone(); - clone.Model = clone.Model.Clone(); - clone.StateRef = (State)clone.StateRef.Clone(); - clone.StateRef.InitState(clone, Host); + clone.InternalTransform.Model = clone.InternalTransform.Model.Clone(); + clone.InternalTransform.StateRef = (SsaAnomalyDetectionBase.State)clone.InternalTransform.StateRef.Clone(); + clone.InternalTransform.StateRef.InitState(clone.InternalTransform, InternalTransform.Host); return clone; } @@ -160,20 +158,20 @@ internal SsaSpikeDetector(IHostEnvironment env, ModelLoadContext ctx) // *** Binary format *** // - Host.CheckDecode(ThresholdScore == AlertingScore.PValueScore); - Host.CheckDecode(DiscountFactor == 1); - Host.CheckDecode(IsAdaptive == false); + InternalTransform.Host.CheckDecode(InternalTransform.ThresholdScore == AlertingScore.PValueScore); + InternalTransform.Host.CheckDecode(InternalTransform.DiscountFactor == 1); + InternalTransform.Host.CheckDecode(InternalTransform.IsAdaptive == false); } public override void Save(ModelSaveContext ctx) { - Host.CheckValue(ctx, nameof(ctx)); + InternalTransform.Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - Host.Assert(ThresholdScore == AlertingScore.PValueScore); - Host.Assert(DiscountFactor == 1); - Host.Assert(IsAdaptive == false); + InternalTransform.Host.Assert(InternalTransform.ThresholdScore == AlertingScore.PValueScore); + InternalTransform.Host.Assert(InternalTransform.DiscountFactor == 1); + InternalTransform.Host.Assert(InternalTransform.IsAdaptive == false); // *** Binary format *** // @@ -200,7 +198,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sch public sealed class SsaSpikeEstimator : IEstimator { private readonly IHost _host; - private readonly SsaSpikeDetector.Arguments _args; + private readonly SsaSpikeDetector.Options _options; /// /// Create a new instance of @@ -215,7 +213,7 @@ public sealed class SsaSpikeEstimator : IEstimator /// The vector contains Alert, Raw Score, P-Value as first three values. /// The argument that determines whether to detect positive or negative anomalies, or both. /// The function used to compute the error between the expected and the observed value. 
- public SsaSpikeEstimator(IHostEnvironment env, + internal SsaSpikeEstimator(IHostEnvironment env, string outputColumnName, int confidence, int pvalueHistoryLength, @@ -223,8 +221,8 @@ public SsaSpikeEstimator(IHostEnvironment env, int seasonalityWindowSize, string inputColumnName = null, AnomalySide side = AnomalySide.TwoSided, - ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference) - : this(env, new SsaSpikeDetector.Arguments + ErrorFunction errorFunction = ErrorFunction.SignedDifference) + : this(env, new SsaSpikeDetector.Options { Source = inputColumnName ?? outputColumnName, Name = outputColumnName, @@ -238,38 +236,45 @@ public SsaSpikeEstimator(IHostEnvironment env, { } - public SsaSpikeEstimator(IHostEnvironment env, SsaSpikeDetector.Arguments args) + internal SsaSpikeEstimator(IHostEnvironment env, SsaSpikeDetector.Options options) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(SsaSpikeEstimator)); - _host.CheckNonEmpty(args.Name, nameof(args.Name)); - _host.CheckNonEmpty(args.Source, nameof(args.Source)); + _host.CheckNonEmpty(options.Name, nameof(options.Name)); + _host.CheckNonEmpty(options.Source, nameof(options.Source)); - _args = args; + _options = options; } + /// + /// Train and return a transformer. + /// public SsaSpikeDetector Fit(IDataView input) { _host.CheckValue(input, nameof(input)); - return new SsaSpikeDetector(_host, _args, input); + return new SsaSpikeDetector(_host, _options, input); } + /// + /// Schema propagation for transformers. + /// Returns the output schema of the data, if the input schema is like the one provided. + /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - if (!inputSchema.TryFindColumn(_args.Source, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source); + if (!inputSchema.TryFindColumn(_options.Source, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _options.Source); if (col.ItemType != NumberType.R4) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source, "float", col.GetTypeString()); + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _options.Source, "float", col.GetTypeString()); var metadata = new List() { new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) }; var resultDic = inputSchema.ToDictionary(x => x.Name); - resultDic[_args.Name] = new SchemaShape.Column( - _args.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); + resultDic[_options.Name] = new SchemaShape.Column( + _options.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); return new SchemaShape(resultDic.Values); } diff --git a/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs b/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs index 9e4564baa7..e128a4730c 100644 --- a/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs +++ b/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs @@ -3,11 +3,11 @@ // See the LICENSE file in the project root for more information. 
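The estimator documentation above describes the produced column as a vector whose first three slots are Alert, Raw Score and P-Value (the change-point variant appends a martingale score as well). A minimal consuming POCO might look like the sketch below; it assumes the usual [VectorType] annotation pattern, and the class and member names are illustrative only, not prescribed by the API:

```csharp
using Microsoft.ML.Data;

// Hypothetical output class for reading the spike detector's vector column.
public class SsaSpikePrediction
{
    // Slot 0: alert indicator (0/1), slot 1: raw anomaly score, slot 2: p-value.
    [VectorType(3)]
    public double[] Prediction { get; set; }
}

internal static class PredictionReading
{
    internal static bool IsAlert(SsaSpikePrediction p) => p.Prediction[0] == 1;
    internal static double PValue(SsaSpikePrediction p) => p.Prediction[2];
}
```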
using Microsoft.ML.EntryPoints; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; [assembly: EntryPointModule(typeof(TimeSeriesProcessingEntryPoints))] -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// Entry points for text anylytics transforms. @@ -15,7 +15,7 @@ namespace Microsoft.ML.TimeSeriesProcessing internal static class TimeSeriesProcessingEntryPoints { [TlcModule.EntryPoint(Desc = ExponentialAverageTransform.Summary, UserName = ExponentialAverageTransform.UserName, ShortName = ExponentialAverageTransform.ShortName)] - public static CommonOutputs.TransformOutput ExponentialAverage(IHostEnvironment env, ExponentialAverageTransform.Arguments input) + internal static CommonOutputs.TransformOutput ExponentialAverage(IHostEnvironment env, ExponentialAverageTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ExponentialAverageTransform", input); var xf = new ExponentialAverageTransform(h, input, input.Data); @@ -26,32 +26,38 @@ public static CommonOutputs.TransformOutput ExponentialAverage(IHostEnvironment }; } - [TlcModule.EntryPoint(Desc = TimeSeriesProcessing.IidChangePointDetector.Summary, UserName = TimeSeriesProcessing.IidChangePointDetector.UserName, ShortName = TimeSeriesProcessing.IidChangePointDetector.ShortName)] - public static CommonOutputs.TransformOutput IidChangePointDetector(IHostEnvironment env, IidChangePointDetector.Arguments input) + [TlcModule.EntryPoint(Desc = TimeSeries.IidChangePointDetector.Summary, + UserName = TimeSeries.IidChangePointDetector.UserName, + ShortName = TimeSeries.IidChangePointDetector.ShortName)] + internal static CommonOutputs.TransformOutput IidChangePointDetector(IHostEnvironment env, IidChangePointDetector.Options options) { - var h = EntryPointUtils.CheckArgsAndCreateHost(env, "IidChangePointDetector", input); - var view = new IidChangePointEstimator(h, input).Fit(input.Data).Transform(input.Data); + var h = EntryPointUtils.CheckArgsAndCreateHost(env, "IidChangePointDetector", options); + var view = new IidChangePointEstimator(h, options).Fit(options.Data).Transform(options.Data); return new CommonOutputs.TransformOutput() { - Model = new TransformModelImpl(h, view, input.Data), + Model = new TransformModelImpl(h, view, options.Data), OutputData = view }; } - [TlcModule.EntryPoint(Desc = TimeSeriesProcessing.IidSpikeDetector.Summary, UserName = TimeSeriesProcessing.IidSpikeDetector.UserName, ShortName = TimeSeriesProcessing.IidSpikeDetector.ShortName)] - public static CommonOutputs.TransformOutput IidSpikeDetector(IHostEnvironment env, IidSpikeDetector.Arguments input) + [TlcModule.EntryPoint(Desc = TimeSeries.IidSpikeDetector.Summary, + UserName = TimeSeries.IidSpikeDetector.UserName, + ShortName = TimeSeries.IidSpikeDetector.ShortName)] + internal static CommonOutputs.TransformOutput IidSpikeDetector(IHostEnvironment env, IidSpikeDetector.Options options) { - var h = EntryPointUtils.CheckArgsAndCreateHost(env, "IidSpikeDetector", input); - var view = new IidSpikeEstimator(h, input).Fit(input.Data).Transform(input.Data); + var h = EntryPointUtils.CheckArgsAndCreateHost(env, "IidSpikeDetector", options); + var view = new IidSpikeEstimator(h, options).Fit(options.Data).Transform(options.Data); return new CommonOutputs.TransformOutput() { - Model = new TransformModelImpl(h, view, input.Data), + Model = new TransformModelImpl(h, view, options.Data), OutputData = view }; } - [TlcModule.EntryPoint(Desc = 
TimeSeriesProcessing.PercentileThresholdTransform.Summary, UserName = TimeSeriesProcessing.PercentileThresholdTransform.UserName, ShortName = TimeSeriesProcessing.PercentileThresholdTransform.ShortName)] - public static CommonOutputs.TransformOutput PercentileThresholdTransform(IHostEnvironment env, PercentileThresholdTransform.Arguments input) + [TlcModule.EntryPoint(Desc = TimeSeries.PercentileThresholdTransform.Summary, + UserName = TimeSeries.PercentileThresholdTransform.UserName, + ShortName = TimeSeries.PercentileThresholdTransform.ShortName)] + internal static CommonOutputs.TransformOutput PercentileThresholdTransform(IHostEnvironment env, PercentileThresholdTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "PercentileThresholdTransform", input); var view = new PercentileThresholdTransform(h, input, input.Data); @@ -62,8 +68,10 @@ public static CommonOutputs.TransformOutput PercentileThresholdTransform(IHostEn }; } - [TlcModule.EntryPoint(Desc = TimeSeriesProcessing.PValueTransform.Summary, UserName = TimeSeriesProcessing.PValueTransform.UserName, ShortName = TimeSeriesProcessing.PValueTransform.ShortName)] - public static CommonOutputs.TransformOutput PValueTransform(IHostEnvironment env, PValueTransform.Arguments input) + [TlcModule.EntryPoint(Desc = TimeSeries.PValueTransform.Summary, + UserName = TimeSeries.PValueTransform.UserName, + ShortName = TimeSeries.PValueTransform.ShortName)] + internal static CommonOutputs.TransformOutput PValueTransform(IHostEnvironment env, PValueTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "PValueTransform", input); var view = new PValueTransform(h, input, input.Data); @@ -74,8 +82,10 @@ public static CommonOutputs.TransformOutput PValueTransform(IHostEnvironment env }; } - [TlcModule.EntryPoint(Desc = TimeSeriesProcessing.SlidingWindowTransform.Summary, UserName = TimeSeriesProcessing.SlidingWindowTransform.UserName, ShortName = TimeSeriesProcessing.SlidingWindowTransform.ShortName)] - public static CommonOutputs.TransformOutput SlidingWindowTransform(IHostEnvironment env, SlidingWindowTransform.Arguments input) + [TlcModule.EntryPoint(Desc = TimeSeries.SlidingWindowTransform.Summary, + UserName = TimeSeries.SlidingWindowTransform.UserName, + ShortName = TimeSeries.SlidingWindowTransform.ShortName)] + internal static CommonOutputs.TransformOutput SlidingWindowTransform(IHostEnvironment env, SlidingWindowTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "SlidingWindowTransform", input); var view = new SlidingWindowTransform(h, input, input.Data); @@ -86,26 +96,30 @@ public static CommonOutputs.TransformOutput SlidingWindowTransform(IHostEnvironm }; } - [TlcModule.EntryPoint(Desc = TimeSeriesProcessing.SsaChangePointDetector.Summary, UserName = TimeSeriesProcessing.SsaChangePointDetector.UserName, ShortName = TimeSeriesProcessing.SsaChangePointDetector.ShortName)] - public static CommonOutputs.TransformOutput SsaChangePointDetector(IHostEnvironment env, SsaChangePointDetector.Arguments input) + [TlcModule.EntryPoint(Desc = TimeSeries.SsaChangePointDetector.Summary, + UserName = TimeSeries.SsaChangePointDetector.UserName, + ShortName = TimeSeries.SsaChangePointDetector.ShortName)] + internal static CommonOutputs.TransformOutput SsaChangePointDetector(IHostEnvironment env, SsaChangePointDetector.Options options) { - var h = EntryPointUtils.CheckArgsAndCreateHost(env, "SsaChangePointDetector", input); - var view = new SsaChangePointEstimator(h, 
input).Fit(input.Data).Transform(input.Data); + var h = EntryPointUtils.CheckArgsAndCreateHost(env, "SsaChangePointDetector", options); + var view = new SsaChangePointEstimator(h, options).Fit(options.Data).Transform(options.Data); return new CommonOutputs.TransformOutput() { - Model = new TransformModelImpl(h, view, input.Data), + Model = new TransformModelImpl(h, view, options.Data), OutputData = view }; } - [TlcModule.EntryPoint(Desc = TimeSeriesProcessing.SsaSpikeDetector.Summary, UserName = TimeSeriesProcessing.SsaSpikeDetector.UserName, ShortName = TimeSeriesProcessing.SsaSpikeDetector.ShortName)] - public static CommonOutputs.TransformOutput SsaSpikeDetector(IHostEnvironment env, SsaSpikeDetector.Arguments input) + [TlcModule.EntryPoint(Desc = TimeSeries.SsaSpikeDetector.Summary, + UserName = TimeSeries.SsaSpikeDetector.UserName, + ShortName = TimeSeries.SsaSpikeDetector.ShortName)] + public static CommonOutputs.TransformOutput SsaSpikeDetector(IHostEnvironment env, SsaSpikeDetector.Options options) { - var h = EntryPointUtils.CheckArgsAndCreateHost(env, "SsaSpikeDetector", input); - var view = new SsaSpikeEstimator(h, input).Fit(input.Data).Transform(input.Data); + var h = EntryPointUtils.CheckArgsAndCreateHost(env, "SsaSpikeDetector", options); + var view = new SsaSpikeEstimator(h, options).Fit(options.Data).Transform(options.Data); return new CommonOutputs.TransformOutput() { - Model = new TransformModelImpl(h, view, input.Data), + Model = new TransformModelImpl(h, view, options.Data), OutputData = view }; } diff --git a/src/Microsoft.ML.TimeSeries/TimeSeriesUtils.cs b/src/Microsoft.ML.TimeSeries/TimeSeriesUtils.cs index 82334f7976..e8eb79f1ba 100644 --- a/src/Microsoft.ML.TimeSeries/TimeSeriesUtils.cs +++ b/src/Microsoft.ML.TimeSeries/TimeSeriesUtils.cs @@ -2,7 +2,7 @@ using System.IO; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.TimeSeries +namespace Microsoft.ML.Transforms.TimeSeries { internal static class TimeSeriesUtils { diff --git a/src/Microsoft.ML.TimeSeries/TrajectoryMatrix.cs b/src/Microsoft.ML.TimeSeries/TrajectoryMatrix.cs index 537b0f7307..2ea67dae96 100644 --- a/src/Microsoft.ML.TimeSeries/TrajectoryMatrix.cs +++ b/src/Microsoft.ML.TimeSeries/TrajectoryMatrix.cs @@ -5,7 +5,7 @@ using System; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.TimeSeriesProcessing +namespace Microsoft.ML.Transforms.TimeSeries { /// /// This class encapsulates the trajectory matrix of a time-series used in Singular Spectrum Analysis (SSA). @@ -27,7 +27,7 @@ namespace Microsoft.ML.TimeSeriesProcessing /// This class does not explicitly store the trajectory matrix though. Furthermore, since the trajectory matrix is /// a Hankel matrix, its multiplication by an arbitrary vector is implemented efficiently using the Discrete Fast Fourier Transform. /// - public sealed class TrajectoryMatrix + internal sealed class TrajectoryMatrix { /// /// The time series data diff --git a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs index e639ad7a6d..609d9c842f 100644 --- a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs +++ b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs @@ -19,7 +19,7 @@ namespace Microsoft.ML.Transforms /// /// The type that describes what 'source' columns are consumed from the input . /// The type that describes what new columns are added by this transform. 
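For CustomMappingTransformer, the TSrc and TDst type parameters documented above are ordinary classes whose members name the columns consumed and produced, and the mapping itself is an Action<TSrc, TDst> applied row by row. A minimal sketch with hypothetical row types:

```csharp
using System;

// Hypothetical TSrc/TDst pair for a custom mapping: TSrc names the columns the
// mapping reads, TDst names the columns it adds.
public class InputRow
{
    public float Value { get; set; }
}

public class MappedRow
{
    public bool IsPositive { get; set; }
}

internal static class CustomMappingSketch
{
    // The shape expected by the transformer: an Action that fills the destination
    // object from the source object for each row.
    internal static readonly Action<InputRow, MappedRow> MapAction =
        (src, dst) => dst.IsPositive = src.Value > 0;
}
```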
- public sealed class CustomMappingTransformer : ITransformer, ICanSaveModel + public sealed class CustomMappingTransformer : ITransformer where TSrc : class, new() where TDst : class, new() { @@ -60,7 +60,8 @@ public CustomMappingTransformer(IHostEnvironment env, Action mapActi AddedSchema = outSchema; } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + internal void SaveModel(ModelSaveContext ctx) { if (_contractName == null) throw _host.Except("Empty contract name for a transform: the transform cannot be saved"); @@ -174,8 +175,8 @@ Schema.DetachedColumn[] IRowMapper.GetOutputColumns() return Enumerable.Range(0, dstRow.Schema.Count).Select(x => new Schema.DetachedColumn(dstRow.Schema[x])).ToArray(); } - public void Save(ModelSaveContext ctx) - => _parent.Save(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) + => _parent.SaveModel(ctx); public ITransformer GetTransformer() { diff --git a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs index af04cd10d5..40915c484c 100644 --- a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs +++ b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs @@ -104,7 +104,7 @@ private GaussianFourierSampler(IHostEnvironment env, ModelLoadContext ctx) _host.CheckDecode(FloatUtils.IsFinite(_gamma)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { ctx.SetVersionInfo(GetVersionInfo()); @@ -185,7 +185,7 @@ private LaplacianFourierSampler(IHostEnvironment env, ModelLoadContext ctx) _host.CheckDecode(FloatUtils.IsFinite(_a)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs index f48a4329b7..5c54d5eccc 100644 --- a/src/Microsoft.ML.Transforms/GcnTransform.cs +++ b/src/Microsoft.ML.Transforms/GcnTransform.cs @@ -306,7 +306,7 @@ private LpNormalizingTransformer(IHost host, ModelLoadContext ctx) _columns[i] = new ColumnInfoLoaded(ctx, ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, ctx.Header.ModelVerWritten >= VerVectorNormalizerSupported); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs index e36a2c4c4a..4c4ac807da 100644 --- a/src/Microsoft.ML.Transforms/GroupTransform.cs +++ b/src/Microsoft.ML.Transforms/GroupTransform.cs @@ -137,7 +137,7 @@ private GroupTransform(IHost host, ModelLoadContext ctx, IDataView input) _groupBinding = new GroupBinding(input.Schema, host, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -305,7 +305,7 @@ private Schema BuildOutputSchema(Schema sourceSchema) return schemaBuilder.GetSchema(); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { _ectx.AssertValue(ctx); diff --git a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs index b0a4edbea8..61f2e86adb 100644 --- a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs +++ b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs @@ -285,7 +285,7 @@ public static 
HashJoiningTransform Create(IHostEnvironment env, ModelLoadContext return h.Apply("Loading Model", ch => new HashJoiningTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs index 4364ba6097..036287d3ed 100644 --- a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs +++ b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs @@ -83,7 +83,7 @@ internal KeyToBinaryVectorMappingTransformer(IHostEnvironment env, params (strin { } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Transforms/LambdaTransform.cs b/src/Microsoft.ML.Transforms/LambdaTransform.cs index 2907a41941..c823ae3999 100644 --- a/src/Microsoft.ML.Transforms/LambdaTransform.cs +++ b/src/Microsoft.ML.Transforms/LambdaTransform.cs @@ -204,7 +204,7 @@ public static ITransformTemplate CreateFilter(IHostEnvironment env /// * a custom save action that serializes the transform 'state' to the binary writer. /// * a custom load action that de-serializes the transform from the binary reader. This must be a public static method of a public class. /// - internal abstract class LambdaTransformBase + internal abstract class LambdaTransformBase : ICanSaveModel { private readonly Action _saveAction; private readonly byte[] _loadFuncBytes; @@ -248,7 +248,7 @@ protected LambdaTransformBase(IHostEnvironment env, string name, LambdaTransform AssertConsistentSerializable(); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Check(CanSave(), "Cannot save this transform as it was not specified as being savable"); diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs index a5813aa392..84593a7805 100644 --- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs @@ -135,7 +135,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sch /// /// Saves the transform. 
/// - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs index 4cfbddf91d..74e54e896e 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs @@ -116,7 +116,7 @@ public static MissingValueIndicatorTransform Create(IHostEnvironment env, ModelL }); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index eb91e3eeed..31d775dc68 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -131,7 +131,7 @@ internal static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sc /// /// Saves the transform. /// - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index b99676cbc4..21c0054fff 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -492,7 +492,7 @@ private void WriteTypeAndValue(Stream stream, BinarySaver saver, ColumnType t throw Host.Except("We do not know how to serialize terms of type '{0}'", type); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Transforms/OneHotEncoding.cs b/src/Microsoft.ML.Transforms/OneHotEncoding.cs index 31505972ed..bc9bec1555 100644 --- a/src/Microsoft.ML.Transforms/OneHotEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotEncoding.cs @@ -26,7 +26,7 @@ namespace Microsoft.ML.Transforms.Categorical { /// - public sealed class OneHotEncodingTransformer : ITransformer, ICanSaveModel + public sealed class OneHotEncodingTransformer : ITransformer { public enum OutputKind : byte { @@ -165,7 +165,7 @@ internal OneHotEncodingTransformer(ValueToKeyMappingEstimator term, IEstimator _transformer.Transform(input); - public void Save(ModelSaveContext ctx) => _transformer.Save(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => (_transformer as ICanSaveModel).Save(ctx); public bool IsRowToRowMapper => _transformer.IsRowToRowMapper; diff --git a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs index 6c61b29762..df693a9c1b 100644 --- a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs @@ -24,7 +24,7 @@ namespace Microsoft.ML.Transforms.Categorical /// /// Produces a column of indicator vectors. The mapping between a value and a corresponding index is done through hashing. 
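As the summary above states, OneHotHashEncodingTransformer chooses the indicator slot by hashing the value instead of building a dictionary. The toy function below only illustrates the idea; it is not the library's hash (ML.NET's hashing transforms use a murmur-style hash), and all names here are made up:

```csharp
using System;

// Toy illustration of hashing-based one-hot encoding: hash the value, mask it down
// to 2^bits slots, and set a single 1 at that slot. Distinct values can collide,
// which is the usual trade-off for a fixed, dictionary-free vector size.
internal static class HashedOneHotSketch
{
    internal static float[] Encode(string value, int bits)
    {
        int slots = 1 << bits;

        // Simple FNV-1a style hash, purely for the sketch.
        uint hash = 2166136261;
        foreach (char c in value)
            hash = (hash ^ c) * 16777619;

        var vector = new float[slots];
        vector[hash % (uint)slots] = 1;
        return vector;
    }

    private static void Main()
    {
        var v = Encode("Workclass=Private", bits: 4); // 16-slot indicator vector
        Console.WriteLine(Array.IndexOf(v, 1f));      // index of the single active slot
    }
}
```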
/// - public sealed class OneHotHashEncodingTransformer : ITransformer, ICanSaveModel + public sealed class OneHotHashEncodingTransformer : ITransformer { internal sealed class Column : OneToOneColumn { @@ -189,7 +189,7 @@ internal OneHotHashEncodingTransformer(HashingEstimator hash, IEstimator public IDataView Transform(IDataView input) => _transformer.Transform(input); - public void Save(ModelSaveContext ctx) => _transformer.Save(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => (_transformer as ICanSaveModel).Save(ctx); /// /// Whether a call to should succeed, on an appropriate schema. diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index 3927ae6daf..c0da05f872 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -280,7 +280,7 @@ public static OptionalColumnTransform Create(IHostEnvironment env, ModelLoadCont return h.Apply("Loading Model", ch => new OptionalColumnTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs index 63abdc723f..91ac258ace 100644 --- a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs +++ b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs @@ -27,7 +27,7 @@ namespace Microsoft.ML.Transforms /// some other file, then apply this transform to that dataview, it may of course have a different /// result. This is distinct from most transforms that produce results based on data alone. /// - public sealed class ProduceIdTransform : RowToRowTransformBase + internal sealed class ProduceIdTransform : RowToRowTransformBase { public sealed class Arguments { @@ -60,7 +60,7 @@ public static Bindings Create(ModelLoadContext ctx, Schema input) return new Bindings(input, true, name); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -127,7 +127,7 @@ public static ProduceIdTransform Create(IHostEnvironment env, ModelLoadContext c return h.Apply("Loading Model", ch => new ProduceIdTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs index 754ac9b573..171e8a6427 100644 --- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs +++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs @@ -162,7 +162,7 @@ public TransformInfo(IHostEnvironment env, ModelLoadContext ctx, string director InitializeFourierCoefficients(roundedUpNumFeatures, roundedUpD); } - public void Save(ModelSaveContext ctx, string directoryName) + internal void Save(ModelSaveContext ctx, string directoryName) { Contracts.AssertValue(ctx); @@ -463,7 +463,7 @@ private static RandomFourierFeaturizingTransformer Create(IHostEnvironment env, return new RandomFourierFeaturizingTransformer(host, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs 
b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 7141d697ec..f759edd697 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -340,7 +340,7 @@ internal LdaSummary GetLdaSummary(VBuffer> mapping) } } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); long memBlockSize = 0; @@ -733,7 +733,7 @@ private static LatentDirichletAllocationTransformer Create(IHostEnvironment env, }); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs index 7d01048a44..84e95b54d4 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs @@ -256,7 +256,7 @@ internal NgramHashingTransformer(IHostEnvironment env, IDataView input, params N } } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -605,7 +605,7 @@ private protected override Func GetDependenciesCore(Func a return col => active[col]; } - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); protected override Schema.DetachedColumn[] GetOutputColumnsCore() { diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index b32a96aea9..acb31ed104 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -162,7 +162,7 @@ public TransformInfo(ModelLoadContext ctx, bool readWeighting) NonEmptyLevels = ctx.Reader.ReadBoolArray(NgramLength); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -451,7 +451,7 @@ private static NgramExtractingTransformer Create(IHostEnvironment env, ModelLoad return new NgramExtractingTransformer(host, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index f633c711b4..46a2b56a78 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -203,7 +203,7 @@ internal StopWordsRemovingTransformer(IHostEnvironment env, params StopWordsRemo _columns = columns; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -477,7 +477,7 @@ private protected override Func GetDependenciesCore(Func a return col => active[col]; } - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); } } @@ -850,7 +850,7 @@ internal CustomStopWordsRemovingTransformer(IHostEnvironment env, string stopwor LoadStopWords(ch, stopwords.AsMemory(), dataFile, stopwordsColumn, loader, out _stopWordsMap); } 
- public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs index 985c6eb6fa..2ea6f6c071 100644 --- a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs +++ b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs @@ -562,7 +562,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat return estimator.Fit(data).Transform(data) as IDataTransform; } - private sealed class Transformer : ITransformer, ICanSaveModel + private sealed class Transformer : ITransformer { private const string TransformDirTemplate = "Step_{0:000}"; @@ -607,7 +607,7 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray()); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs index 75bab0f8b4..e5764fafd8 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs @@ -118,7 +118,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, TextNormalizingEstimator.ExpectedColumnType, type.ToString()); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs index 8ef8109c52..bd23d4396e 100644 --- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs +++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs @@ -136,7 +136,7 @@ private TokenizingByCharactersTransformer(IHost host, ModelLoadContext ctx) : private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) => Create(env, ctx).MakeDataTransform(input); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index 8a6c8108b5..42fa72033e 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -284,7 +284,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs index 1d46ed420e..bf05a60e84 100644 --- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs +++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs @@ -150,7 +150,7 @@ 
private WordTokenizingTransformer(IHost host, ModelLoadContext ctx) : private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) => Create(env, ctx).MakeDataTransform(input); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/UngroupTransform.cs b/src/Microsoft.ML.Transforms/UngroupTransform.cs index 2dd0112dab..e0a594d8d6 100644 --- a/src/Microsoft.ML.Transforms/UngroupTransform.cs +++ b/src/Microsoft.ML.Transforms/UngroupTransform.cs @@ -138,7 +138,7 @@ private UngroupTransform(IHost host, ModelLoadContext ctx, IDataView input) _ungroupBinding = UngroupBinding.Create(ctx, host, input.Schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -365,7 +365,7 @@ public static UngroupBinding Create(ModelLoadContext ctx, IExceptionContext ectx return new UngroupBinding(ectx, inputSchema, mode, pivotColumns); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { _ectx.AssertValue(ctx); diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 356268709c..0c60d474be 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -4,19 +4,19 @@ Data.IDataViewArrayConverter Create an array variable of IDataView Microsoft.ML. Data.PredictorModelArrayConverter Create an array variable of PredictorModel Microsoft.ML.EntryPoints.MacroUtils MakeArray Microsoft.ML.EntryPoints.MacroUtils+ArrayIPredictorModelInput Microsoft.ML.EntryPoints.MacroUtils+ArrayIPredictorModelOutput Data.TextLoader Import a dataset from a text file Microsoft.ML.EntryPoints.ImportTextData TextLoader Microsoft.ML.EntryPoints.ImportTextData+LoaderInput Microsoft.ML.EntryPoints.ImportTextData+Output Models.AnomalyDetectionEvaluator Evaluates an anomaly detection scored dataset. Microsoft.ML.Data.Evaluate AnomalyDetection Microsoft.ML.Data.AnomalyDetectionMamlEvaluator+Arguments Microsoft.ML.EntryPoints.CommonOutputs+CommonEvaluateOutput -Models.AnomalyPipelineEnsemble Combine anomaly detection models into an ensemble Microsoft.ML.EntryPoints.EnsembleCreator CreateAnomalyPipelineEnsemble Microsoft.ML.EntryPoints.EnsembleCreator+PipelineAnomalyInput Microsoft.ML.EntryPoints.CommonOutputs+AnomalyDetectionOutput +Models.AnomalyPipelineEnsemble Combine anomaly detection models into an ensemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator CreateAnomalyPipelineEnsemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator+PipelineAnomalyInput Microsoft.ML.EntryPoints.CommonOutputs+AnomalyDetectionOutput Models.BinaryClassificationEvaluator Evaluates a binary classification scored dataset. 
Microsoft.ML.Data.Evaluate Binary Microsoft.ML.Data.BinaryClassifierMamlEvaluator+Arguments Microsoft.ML.EntryPoints.CommonOutputs+ClassificationEvaluateOutput -Models.BinaryEnsemble Combine binary classifiers into an ensemble Microsoft.ML.EntryPoints.EnsembleCreator CreateBinaryEnsemble Microsoft.ML.EntryPoints.EnsembleCreator+ClassifierInput Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Models.BinaryPipelineEnsemble Combine binary classification models into an ensemble Microsoft.ML.EntryPoints.EnsembleCreator CreateBinaryPipelineEnsemble Microsoft.ML.EntryPoints.EnsembleCreator+PipelineClassifierInput Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Models.BinaryEnsemble Combine binary classifiers into an ensemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator CreateBinaryEnsemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator+ClassifierInput Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Models.BinaryPipelineEnsemble Combine binary classification models into an ensemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator CreateBinaryPipelineEnsemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator+PipelineClassifierInput Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Models.ClassificationEvaluator Evaluates a multi class classification scored dataset. Microsoft.ML.Data.Evaluate MultiClass Microsoft.ML.Data.MultiClassMamlEvaluator+Arguments Microsoft.ML.EntryPoints.CommonOutputs+ClassificationEvaluateOutput Models.ClusterEvaluator Evaluates a clustering scored dataset. Microsoft.ML.Data.Evaluate Clustering Microsoft.ML.Data.ClusteringMamlEvaluator+Arguments Microsoft.ML.EntryPoints.CommonOutputs+CommonEvaluateOutput Models.CrossValidationResultsCombiner Combine the metric data views returned from cross validation. Microsoft.ML.EntryPoints.CrossValidationMacro CombineMetrics Microsoft.ML.EntryPoints.CrossValidationMacro+CombineMetricsInput Microsoft.ML.EntryPoints.CrossValidationMacro+CombinedOutput Models.CrossValidator Cross validation for general learning Microsoft.ML.EntryPoints.CrossValidationMacro CrossValidate Microsoft.ML.EntryPoints.CrossValidationMacro+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.EntryPoints.CrossValidationMacro+Output] Models.CrossValidatorDatasetSplitter Split the dataset into the specified number of cross-validation folds (train and test sets) Microsoft.ML.EntryPoints.CVSplit Split Microsoft.ML.EntryPoints.CVSplit+Input Microsoft.ML.EntryPoints.CVSplit+Output Models.DatasetTransformer Applies a TransformModel to a dataset. Microsoft.ML.EntryPoints.ModelOperations Apply Microsoft.ML.EntryPoints.ModelOperations+ApplyTransformModelInput Microsoft.ML.EntryPoints.ModelOperations+ApplyTransformModelOutput -Models.EnsembleSummary Summarize a pipeline ensemble predictor. Microsoft.ML.Ensemble.EntryPoints.PipelineEnsemble Summarize Microsoft.ML.EntryPoints.SummarizePredictor+Input Microsoft.ML.Ensemble.EntryPoints.PipelineEnsemble+SummaryOutput +Models.EnsembleSummary Summarize a pipeline ensemble predictor. 
Microsoft.ML.Trainers.Ensemble.PipelineEnsemble Summarize Microsoft.ML.EntryPoints.SummarizePredictor+Input Microsoft.ML.Trainers.Ensemble.PipelineEnsemble+SummaryOutput Models.FixedPlattCalibrator Apply a Platt calibrator with a fixed slope and offset to an input model Microsoft.ML.Internal.Calibration.Calibrate FixedPlatt Microsoft.ML.Internal.Calibration.Calibrate+FixedPlattInput Microsoft.ML.EntryPoints.CommonOutputs+CalibratorOutput -Models.MultiClassPipelineEnsemble Combine multiclass classifiers into an ensemble Microsoft.ML.EntryPoints.EnsembleCreator CreateMultiClassPipelineEnsemble Microsoft.ML.EntryPoints.EnsembleCreator+PipelineClassifierInput Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput +Models.MultiClassPipelineEnsemble Combine multiclass classifiers into an ensemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator CreateMultiClassPipelineEnsemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator+PipelineClassifierInput Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput Models.MultiOutputRegressionEvaluator Evaluates a multi output regression scored dataset. Microsoft.ML.Data.Evaluate MultiOutputRegression Microsoft.ML.Data.MultiOutputRegressionMamlEvaluator+Arguments Microsoft.ML.EntryPoints.CommonOutputs+CommonEvaluateOutput Models.NaiveCalibrator Apply a Naive calibrator to an input model Microsoft.ML.Internal.Calibration.Calibrate Naive Microsoft.ML.Internal.Calibration.Calibrate+NoArgumentsInput Microsoft.ML.EntryPoints.CommonOutputs+CalibratorOutput Models.OneVersusAll One-vs-All macro (OVA) Microsoft.ML.EntryPoints.OneVersusAllMacro OneVersusAll Microsoft.ML.EntryPoints.OneVersusAllMacro+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.EntryPoints.OneVersusAllMacro+Output] @@ -26,23 +26,23 @@ Models.PAVCalibrator Apply a PAV calibrator to an input model Microsoft.ML.Inter Models.PlattCalibrator Apply a Platt calibrator to an input model Microsoft.ML.Internal.Calibration.Calibrate Platt Microsoft.ML.Internal.Calibration.Calibrate+NoArgumentsInput Microsoft.ML.EntryPoints.CommonOutputs+CalibratorOutput Models.QuantileRegressionEvaluator Evaluates a quantile regression scored dataset. Microsoft.ML.Data.Evaluate QuantileRegression Microsoft.ML.Data.QuantileRegressionMamlEvaluator+Arguments Microsoft.ML.EntryPoints.CommonOutputs+CommonEvaluateOutput Models.RankerEvaluator Evaluates a ranking scored dataset. Microsoft.ML.Data.Evaluate Ranking Microsoft.ML.Data.RankerMamlEvaluator+Arguments Microsoft.ML.EntryPoints.CommonOutputs+CommonEvaluateOutput -Models.RegressionEnsemble Combine regression models into an ensemble Microsoft.ML.EntryPoints.EnsembleCreator CreateRegressionEnsemble Microsoft.ML.EntryPoints.EnsembleCreator+RegressionInput Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput +Models.RegressionEnsemble Combine regression models into an ensemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator CreateRegressionEnsemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator+RegressionInput Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Models.RegressionEvaluator Evaluates a regression scored dataset. 
Microsoft.ML.Data.Evaluate Regression Microsoft.ML.Data.RegressionMamlEvaluator+Arguments Microsoft.ML.EntryPoints.CommonOutputs+CommonEvaluateOutput -Models.RegressionPipelineEnsemble Combine regression models into an ensemble Microsoft.ML.EntryPoints.EnsembleCreator CreateRegressionPipelineEnsemble Microsoft.ML.EntryPoints.EnsembleCreator+PipelineRegressionInput Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput +Models.RegressionPipelineEnsemble Combine regression models into an ensemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator CreateRegressionPipelineEnsemble Microsoft.ML.Trainers.Ensemble.EnsembleCreator+PipelineRegressionInput Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Models.Summarizer Summarize a linear regression predictor. Microsoft.ML.EntryPoints.SummarizePredictor Summarize Microsoft.ML.EntryPoints.SummarizePredictor+Input Microsoft.ML.EntryPoints.CommonOutputs+SummaryOutput Models.TrainTestEvaluator General train test for any supported evaluator Microsoft.ML.EntryPoints.TrainTestMacro TrainTest Microsoft.ML.EntryPoints.TrainTestMacro+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.EntryPoints.TrainTestMacro+Output] -TimeSeriesProcessingEntryPoints.ExponentialAverage Applies a Exponential average on a time series. Microsoft.ML.TimeSeriesProcessing.TimeSeriesProcessingEntryPoints ExponentialAverage Microsoft.ML.TimeSeriesProcessing.ExponentialAverageTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -TimeSeriesProcessingEntryPoints.IidChangePointDetector This transform detects the change-points in an i.i.d. sequence using adaptive kernel density estimation and martingales. Microsoft.ML.TimeSeriesProcessing.TimeSeriesProcessingEntryPoints IidChangePointDetector Microsoft.ML.TimeSeriesProcessing.IidChangePointDetector+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -TimeSeriesProcessingEntryPoints.IidSpikeDetector This transform detects the spikes in a i.i.d. sequence using adaptive kernel density estimation. Microsoft.ML.TimeSeriesProcessing.TimeSeriesProcessingEntryPoints IidSpikeDetector Microsoft.ML.TimeSeriesProcessing.IidSpikeDetector+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -TimeSeriesProcessingEntryPoints.PercentileThresholdTransform Detects the values of time-series that are in the top percentile of the sliding window. Microsoft.ML.TimeSeriesProcessing.TimeSeriesProcessingEntryPoints PercentileThresholdTransform Microsoft.ML.TimeSeriesProcessing.PercentileThresholdTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -TimeSeriesProcessingEntryPoints.PValueTransform This P-Value transform calculates the p-value of the current input in the sequence with regard to the values in the sliding window. Microsoft.ML.TimeSeriesProcessing.TimeSeriesProcessingEntryPoints PValueTransform Microsoft.ML.TimeSeriesProcessing.PValueTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -TimeSeriesProcessingEntryPoints.SlidingWindowTransform Returns the last values for a time series [y(t-d-l+1), y(t-d-l+2), ..., y(t-l-1), y(t-l)] where d is the size of the window, l the lag and y is a Float. 
Microsoft.ML.TimeSeriesProcessing.TimeSeriesProcessingEntryPoints SlidingWindowTransform Microsoft.ML.TimeSeriesProcessing.SlidingWindowTransformBase`1+Arguments[System.Single] Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -TimeSeriesProcessingEntryPoints.SsaChangePointDetector This transform detects the change-points in a seasonal time-series using Singular Spectrum Analysis (SSA). Microsoft.ML.TimeSeriesProcessing.TimeSeriesProcessingEntryPoints SsaChangePointDetector Microsoft.ML.TimeSeriesProcessing.SsaChangePointDetector+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -TimeSeriesProcessingEntryPoints.SsaSpikeDetector This transform detects the spikes in a seasonal time-series using Singular Spectrum Analysis (SSA). Microsoft.ML.TimeSeriesProcessing.TimeSeriesProcessingEntryPoints SsaSpikeDetector Microsoft.ML.TimeSeriesProcessing.SsaSpikeDetector+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +TimeSeriesProcessingEntryPoints.ExponentialAverage Applies a Exponential average on a time series. Microsoft.ML.Transforms.TimeSeries.TimeSeriesProcessingEntryPoints ExponentialAverage Microsoft.ML.Transforms.TimeSeries.ExponentialAverageTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +TimeSeriesProcessingEntryPoints.IidChangePointDetector This transform detects the change-points in an i.i.d. sequence using adaptive kernel density estimation and martingales. Microsoft.ML.Transforms.TimeSeries.TimeSeriesProcessingEntryPoints IidChangePointDetector Microsoft.ML.Transforms.TimeSeries.IidChangePointDetector+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +TimeSeriesProcessingEntryPoints.IidSpikeDetector This transform detects the spikes in a i.i.d. sequence using adaptive kernel density estimation. Microsoft.ML.Transforms.TimeSeries.TimeSeriesProcessingEntryPoints IidSpikeDetector Microsoft.ML.Transforms.TimeSeries.IidSpikeDetector+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +TimeSeriesProcessingEntryPoints.PercentileThresholdTransform Detects the values of time-series that are in the top percentile of the sliding window. Microsoft.ML.Transforms.TimeSeries.TimeSeriesProcessingEntryPoints PercentileThresholdTransform Microsoft.ML.Transforms.TimeSeries.PercentileThresholdTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +TimeSeriesProcessingEntryPoints.PValueTransform This P-Value transform calculates the p-value of the current input in the sequence with regard to the values in the sliding window. Microsoft.ML.Transforms.TimeSeries.TimeSeriesProcessingEntryPoints PValueTransform Microsoft.ML.Transforms.TimeSeries.PValueTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +TimeSeriesProcessingEntryPoints.SlidingWindowTransform Returns the last values for a time series [y(t-d-l+1), y(t-d-l+2), ..., y(t-l-1), y(t-l)] where d is the size of the window, l the lag and y is a Float. Microsoft.ML.Transforms.TimeSeries.TimeSeriesProcessingEntryPoints SlidingWindowTransform Microsoft.ML.Transforms.TimeSeries.SlidingWindowTransformBase`1+Arguments[System.Single] Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +TimeSeriesProcessingEntryPoints.SsaChangePointDetector This transform detects the change-points in a seasonal time-series using Singular Spectrum Analysis (SSA). 
Microsoft.ML.Transforms.TimeSeries.TimeSeriesProcessingEntryPoints SsaChangePointDetector Microsoft.ML.Transforms.TimeSeries.SsaChangePointDetector+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +TimeSeriesProcessingEntryPoints.SsaSpikeDetector This transform detects the spikes in a seasonal time-series using Singular Spectrum Analysis (SSA). Microsoft.ML.Transforms.TimeSeries.TimeSeriesProcessingEntryPoints SsaSpikeDetector Microsoft.ML.Transforms.TimeSeries.SsaSpikeDetector+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Trainers.AveragedPerceptronBinaryClassifier Averaged Perceptron Binary Classifier. Microsoft.ML.Trainers.Online.AveragedPerceptronTrainer TrainBinary Microsoft.ML.Trainers.Online.AveragedPerceptronTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.EnsembleBinaryClassifier Train binary ensemble. Microsoft.ML.Ensemble.EntryPoints.Ensemble CreateBinaryEnsemble Microsoft.ML.Ensemble.EnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.EnsembleClassification Train multiclass ensemble. Microsoft.ML.Ensemble.EntryPoints.Ensemble CreateMultiClassEnsemble Microsoft.ML.Ensemble.MulticlassDataPartitionEnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput -Trainers.EnsembleRegression Train regression ensemble. Microsoft.ML.Ensemble.EntryPoints.Ensemble CreateRegressionEnsemble Microsoft.ML.Ensemble.RegressionEnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput +Trainers.EnsembleBinaryClassifier Train binary ensemble. Microsoft.ML.Trainers.Ensemble.Ensemble CreateBinaryEnsemble Microsoft.ML.Trainers.Ensemble.EnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Trainers.EnsembleClassification Train multiclass ensemble. Microsoft.ML.Trainers.Ensemble.Ensemble CreateMultiClassEnsemble Microsoft.ML.Trainers.Ensemble.MulticlassDataPartitionEnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput +Trainers.EnsembleRegression Train regression ensemble. Microsoft.ML.Trainers.Ensemble.Ensemble CreateRegressionEnsemble Microsoft.ML.Trainers.Ensemble.RegressionEnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.FastForestBinaryClassifier Uses a random forest learner to perform binary classification. Microsoft.ML.Trainers.FastTree.FastForest TrainBinary Microsoft.ML.Trainers.FastTree.FastForestClassification+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.FastForestRegressor Trains a random forest to fit target values using least-squares. Microsoft.ML.Trainers.FastTree.FastForest TrainRegression Microsoft.ML.Trainers.FastTree.FastForestRegression+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.FastTreeBinaryClassifier Uses a logit-boost boosted tree learner to perform binary classification. Microsoft.ML.Trainers.FastTree.FastTree TrainBinary Microsoft.ML.Trainers.FastTree.FastTreeBinaryClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput @@ -58,8 +58,8 @@ Trainers.LightGbmClassifier Train a LightGBM multi class model. Microsoft.ML.Lig Trainers.LightGbmRanker Train a LightGBM ranking model. 
Microsoft.ML.LightGBM.LightGbm TrainRanking Microsoft.ML.LightGBM.Options Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput Trainers.LightGbmRegressor LightGBM Regression Microsoft.ML.LightGBM.LightGbm TrainRegression Microsoft.ML.LightGBM.Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.LinearSvmBinaryClassifier Train a linear SVM. Microsoft.ML.Trainers.Online.LinearSvmTrainer TrainLinearSvm Microsoft.ML.Trainers.Online.LinearSvmTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.LogisticRegressionBinaryClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Learners.LogisticRegression TrainBinary Microsoft.ML.Learners.LogisticRegression+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.LogisticRegressionClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Learners.LogisticRegression TrainMultiClass Microsoft.ML.Learners.MulticlassLogisticRegression+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput +Trainers.LogisticRegressionBinaryClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Trainers.LogisticRegression TrainBinary Microsoft.ML.Trainers.LogisticRegression+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Trainers.LogisticRegressionClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Trainers.LogisticRegression TrainMultiClass Microsoft.ML.Trainers.MulticlassLogisticRegression+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput Trainers.NaiveBayesClassifier Train a MultiClassNaiveBayesTrainer. Microsoft.ML.Trainers.MultiClassNaiveBayesTrainer TrainMultiClassNaiveBayesTrainer Microsoft.ML.Trainers.MultiClassNaiveBayesTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput Trainers.OnlineGradientDescentRegressor Train a Online gradient descent perceptron. Microsoft.ML.Trainers.Online.OnlineGradientDescentTrainer TrainRegression Microsoft.ML.Trainers.Online.OnlineGradientDescentTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.OrdinaryLeastSquaresRegressor Train an OLS regression model. Microsoft.ML.Trainers.HalLearners.OlsLinearRegressionTrainer TrainRegression Microsoft.ML.Trainers.HalLearners.OlsLinearRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput @@ -69,7 +69,7 @@ Trainers.StochasticDualCoordinateAscentBinaryClassifier Train an SDCA binary mod Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. 
Microsoft.ML.Trainers.Sdca TrainMultiClass Microsoft.ML.Trainers.SdcaMultiClassTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput Trainers.StochasticDualCoordinateAscentRegressor The SDCA linear regression trainer. Microsoft.ML.Trainers.Sdca TrainRegression Microsoft.ML.Trainers.SdcaRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer TrainBinary Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.SymSgdBinaryClassifier Train a symbolic SGD. Microsoft.ML.Trainers.SymSgd.SymSgdClassificationTrainer TrainSymSgd Microsoft.ML.Trainers.SymSgd.SymSgdClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Trainers.SymSgdBinaryClassifier Train a symbolic SGD. Microsoft.ML.Trainers.HalLearners.SymSgdClassificationTrainer TrainSymSgd Microsoft.ML.Trainers.HalLearners.SymSgdClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Transforms.ApproximateBootstrapSampler Approximate bootstrap sampling. Microsoft.ML.Transforms.BootstrapSample GetSample Microsoft.ML.Transforms.BootstrapSamplingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.BinaryPredictionScoreColumnsRenamer For binary prediction, it renames the PredictedLabel and Score columns to include the name of the positive class. Microsoft.ML.EntryPoints.ScoreModel RenameBinaryPredictionScoreColumns Microsoft.ML.EntryPoints.ScoreModel+RenameBinaryPredictionScoreColumnsInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.BinNormalizer The values are assigned into equidensity bins and a value is mapped to its bin_number/number_of_bins. 
Microsoft.ML.Data.Normalize Bin Microsoft.ML.Transforms.Normalizers.NormalizeTransform+BinArguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 5154957bf1..5d720ce6de 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -4306,7 +4306,7 @@ } }, { - "Name": "NumIterations", + "Name": "NumberOfIterations", "Type": "Int", "Desc": "Number of iterations", "Aliases": [ @@ -4325,11 +4325,12 @@ } }, { - "Name": "InitWtsDiameter", + "Name": "InitialWeightsDiameter", "Type": "Float", "Desc": "Init weights diameter", "Aliases": [ - "initwts" + "initwts", + "initWtsDiameter" ], "Required": false, "SortOrder": 140.0, @@ -4467,18 +4468,6 @@ true ] } - }, - { - "Name": "StreamingCacheSize", - "Type": "Int", - "Desc": "Size of cache when trained in Scope", - "Aliases": [ - "cache" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 1000000 } ], "Outputs": [ @@ -13247,7 +13236,7 @@ } }, { - "Name": "NumIterations", + "Name": "NumberOfIterations", "Type": "Int", "Desc": "Number of iterations", "Aliases": [ @@ -13266,11 +13255,12 @@ } }, { - "Name": "InitWtsDiameter", + "Name": "InitialWeightsDiameter", "Type": "Float", "Desc": "Init weights diameter", "Aliases": [ - "initwts" + "initwts", + "initWtsDiameter" ], "Required": false, "SortOrder": 140.0, @@ -13353,18 +13343,6 @@ ] } }, - { - "Name": "StreamingCacheSize", - "Type": "Int", - "Desc": "Size of cache when trained in Scope", - "Aliases": [ - "cache" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 1000000 - }, { "Name": "BatchSize", "Type": "Int", @@ -14272,7 +14250,7 @@ } }, { - "Name": "NumIterations", + "Name": "NumberOfIterations", "Type": "Int", "Desc": "Number of iterations", "Aliases": [ @@ -14291,11 +14269,12 @@ } }, { - "Name": "InitWtsDiameter", + "Name": "InitialWeightsDiameter", "Type": "Float", "Desc": "Init weights diameter", "Aliases": [ - "initwts" + "initwts", + "initWtsDiameter" ], "Required": false, "SortOrder": 140.0, @@ -14410,18 +14389,6 @@ true ] } - }, - { - "Name": "StreamingCacheSize", - "Type": "Int", - "Desc": "Size of cache when trained in Scope", - "Aliases": [ - "cache" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 1000000 } ], "Outputs": [ diff --git a/test/Microsoft.ML.Benchmarks.Tests/BenchmarksTest.cs b/test/Microsoft.ML.Benchmarks.Tests/BenchmarksTest.cs index 940ca83dd6..9dc6854e8b 100644 --- a/test/Microsoft.ML.Benchmarks.Tests/BenchmarksTest.cs +++ b/test/Microsoft.ML.Benchmarks.Tests/BenchmarksTest.cs @@ -12,7 +12,7 @@ using BenchmarkDotNet.Loggers; using BenchmarkDotNet.Running; using Microsoft.ML.Benchmarks.Harness; -using Microsoft.ML.Internal.CpuMath; +using Microsoft.ML.TestFramework.Attributes; using Xunit; using Xunit.Abstractions; @@ -29,15 +29,6 @@ public class BenchmarksTest private ITestOutputHelper Output { get; } - public static bool CanExecute => -#if DEBUG - false; // BenchmarkDotNet does not allow running the benchmarks in Debug, so this test is disabled for DEBUG -#elif NET461 - false; // We are currently not running Benchmarks for FullFramework -#else - Environment.Is64BitProcess; // we don't support 32 bit yet -#endif - public static TheoryData GetBenchmarks() { TheoryData benchmarks = new TheoryData(); @@ -54,7 +45,7 @@ where Attribute.IsDefined(type, 
typeof(CIBenchmark)) return benchmarks; } - [ConditionalTheory(typeof(BenchmarksTest), nameof(CanExecute))] + [BenchmarkTheory] [MemberData(nameof(GetBenchmarks))] public void BenchmarksProjectIsNotBroken(Type type) { diff --git a/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs b/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs index 3c92045b31..870480ebbc 100644 --- a/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs +++ b/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs @@ -6,8 +6,8 @@ using Microsoft.ML.Benchmarks.Harness; using Microsoft.ML.Data; using Microsoft.ML.Internal.Calibration; -using Microsoft.ML.Learners; using Microsoft.ML.TestFramework; +using Microsoft.ML.Trainers; namespace Microsoft.ML.Benchmarks { diff --git a/test/Microsoft.ML.Benchmarks/README.md b/test/Microsoft.ML.Benchmarks/README.md index 4cefdfbaab..0f84b4ca46 100644 --- a/test/Microsoft.ML.Benchmarks/README.md +++ b/test/Microsoft.ML.Benchmarks/README.md @@ -94,7 +94,7 @@ you can debug this test locally by: build.cmd -release -buildNative 2- Changing the configuration in Visual Studio from Debug -> Release -3- Changing the annotation in the `BenchmarksProjectIsNotBroken` to replace `ConditionalTheory` with `Theory`, as below. +3- Changing the annotation in the `BenchmarksProjectIsNotBroken` to replace `BenchmarkTheory` with `Theory`, as below. ```cs [Theory] diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs index 7c67d13ac4..c08118bee6 100644 --- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs +++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs @@ -9,7 +9,6 @@ using Microsoft.Data.DataView; using Microsoft.ML.Benchmarks.Harness; using Microsoft.ML.Data; -using Microsoft.ML.Learners; using Microsoft.ML.TestFramework; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Code/BestFriendOnPublicDeclarationTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/BestFriendOnPublicDeclarationTest.cs new file mode 100644 index 0000000000..ddd8fb5d72 --- /dev/null +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/BestFriendOnPublicDeclarationTest.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Reflection; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.ML.CodeAnalyzer.Tests.Helpers; +using Xunit; + +namespace Microsoft.ML.InternalCodeAnalyzer.Tests +{ + public sealed class BestFriendOnPublicDeclarationTest : DiagnosticVerifier + { + private readonly Lazy SourceAttribute = TestUtils.LazySource("BestFriendAttribute.cs"); + private readonly Lazy SourceDeclaration = TestUtils.LazySource("BestFriendOnPublicDeclaration.cs"); + + [Fact] + public void BestFriendOnPublicDeclaration() + { + Solution solution = null; + var projA = CreateProject("ProjectA", ref solution, SourceDeclaration.Value, SourceAttribute.Value); + + var analyzer = new BestFriendOnPublicDeclarationsAnalyzer(); + + var refs = new List { + RefFromType(), RefFromType(), + MetadataReference.CreateFromFile(Assembly.Load("netstandard, Version=2.0.0.0").Location), + MetadataReference.CreateFromFile(Assembly.Load("System.Runtime, Version=0.0.0.0").Location) + }; + + var comp = projA.GetCompilationAsync().Result.WithReferences(refs.ToArray()); + var compilationWithAnalyzers = comp.WithAnalyzers(ImmutableArray.Create((DiagnosticAnalyzer)analyzer)); + var allDiags = compilationWithAnalyzers.GetAnalyzerDiagnosticsAsync().Result; + + var projectTrees = new HashSet(projA.Documents.Select(r => r.GetSyntaxTreeAsync().Result)); + var diags = allDiags + .Where(d => d.Location == Location.None || d.Location.IsInMetadata || projectTrees.Contains(d.Location.SourceTree)) + .OrderBy(d => d.Location.SourceSpan.Start).ToArray(); + + var diag = analyzer.SupportedDiagnostics[0]; + var expected = new DiagnosticResult[] { + diag.CreateDiagnosticResult(8, 6, "PublicClass"), + diag.CreateDiagnosticResult(11, 10, "PublicField"), + diag.CreateDiagnosticResult(14, 10, "PublicProperty"), + diag.CreateDiagnosticResult(20, 10, "PublicMethod"), + diag.CreateDiagnosticResult(26, 10, "PublicDelegate"), + diag.CreateDiagnosticResult(29, 10, "PublicClass"), + diag.CreateDiagnosticResult(35, 6, "PublicStruct"), + diag.CreateDiagnosticResult(40, 6, "PublicEnum"), + diag.CreateDiagnosticResult(47, 6, "PublicInterface"), + diag.CreateDiagnosticResult(102, 10, "PublicMethod") + }; + + VerifyDiagnosticResults(diags, analyzer, expected); + } + } +} + diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/DiagnosticVerifier.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/DiagnosticVerifier.cs index 7bf7e5693f..248a4645f5 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/DiagnosticVerifier.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/DiagnosticVerifier.cs @@ -15,6 +15,7 @@ using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.StaticPipe; +using Microsoft.ML.Transforms.Conversions; using Xunit; namespace Microsoft.ML.CodeAnalyzer.Tests.Helpers @@ -267,7 +268,7 @@ private static string FormatDiagnostics(DiagnosticAnalyzer analyzer, params Diag private static readonly MetadataReference MSDataDataViewReference = RefFromType(); private static readonly MetadataReference MLNetCoreReference = RefFromType(); - private static readonly MetadataReference MLNetDataReference = RefFromType(); + private static readonly MetadataReference MLNetDataReference = RefFromType(); private static readonly MetadataReference MLNetStaticPipeReference = RefFromType(); protected static MetadataReference RefFromType() diff --git 
a/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/BestFriendOnPublicDeclaration.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/BestFriendOnPublicDeclaration.cs new file mode 100644 index 0000000000..1515f21e1f --- /dev/null +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/BestFriendOnPublicDeclaration.cs @@ -0,0 +1,107 @@ +using System; +using Microsoft.ML; + +namespace TestNamespace +{ + // all of the best friend declaration should fail the diagnostic + + [BestFriend] + public class PublicClass + { + [BestFriend] + public int PublicField; + + [BestFriend] + public string PublicProperty + { + get { return string.Empty; } + } + + [BestFriend] + public bool PublicMethod() + { + return true; + } + + [BestFriend] + public delegate string PublicDelegate(); + + [BestFriend] + public PublicClass() + { + } + } + + [BestFriend] + public struct PublicStruct + { + } + + [BestFriend] + public enum PublicEnum + { + EnumValue1, + EnumValue2 + } + + [BestFriend] + public interface PublicInterface + { + } + + // these should work + + [BestFriend] + internal class InternalClass + { + [BestFriend] + internal int InternalField; + + [BestFriend] + internal string InternalProperty + { + get { return string.Empty; } + } + + [BestFriend] + internal bool InternalMethod() + { + return true; + } + + [BestFriend] + internal delegate string InternalDelegate(); + + [BestFriend] + internal InternalClass() + { + } + } + + [BestFriend] + internal struct InternalStruct + { + } + + [BestFriend] + internal enum InternalEnum + { + EnumValue1, + EnumValue2 + } + + [BestFriend] + internal interface InternalInterface + { + } + + // this should fail the diagnostic + // a repro for https://github.com/dotnet/machinelearning/pull/2434#discussion_r254770946 + internal class InternalClassWithPublicMember + { + [BestFriend] + public void PublicMethod() + { + } + } +} diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 8fcb00842e..512329a123 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -12,28 +12,27 @@ using Microsoft.ML.Core.Tests.UnitTests; using Microsoft.ML.Data; using Microsoft.ML.Data.IO; -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.EntryPoints; using Microsoft.ML.EntryPoints.JsonUtils; using Microsoft.ML.ImageAnalytics; using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.LightGBM; using Microsoft.ML.Model.Onnx; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Trainers; +using Microsoft.ML.Trainers.Ensemble; using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Trainers.HalLearners; using Microsoft.ML.Trainers.PCA; -using Microsoft.ML.Trainers.SymSgd; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Categorical; using Microsoft.ML.Transforms.Conversions; using Microsoft.ML.Transforms.Normalizers; using Microsoft.ML.Transforms.Projections; using Microsoft.ML.Transforms.Text; +using Microsoft.ML.Transforms.TimeSeries; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Xunit; @@ -1313,7 +1312,7 @@ public void EntryPointMulticlassPipelineEnsemble() } } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] + 
[LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")] public void EntryPointPipelineEnsembleGetSummary() { var dataPath = GetDataPath("breast-cancer-withheader.txt"); @@ -1917,14 +1916,14 @@ public void EntryPointEvaluateRanking() } } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void EntryPointLightGbmBinary() { Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryModelParameters).Assembly); TestEntryPointRoutine("breast-cancer.txt", "Trainers.LightGbmBinaryClassifier"); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void EntryPointLightGbmMultiClass() { Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryModelParameters).Assembly); @@ -3650,7 +3649,7 @@ public void EntryPointWordEmbeddings() } } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only + [TensorFlowFact] public void EntryPointTensorFlowTransform() { Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransformer).Assembly); @@ -4070,7 +4069,7 @@ public void TestSimpleTrainExperiment() } } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline + [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")] public void TestCrossValidationMacro() { var dataPath = GetDataPath(TestDatasets.generatedRegressionDatasetmacro.trainFilename); @@ -5448,11 +5447,10 @@ public void TestOvaMacroWithUncalibratedLearner() 'RecencyGainMulti': false, 'Averaged': true, 'AveragedTolerance': 0.01, - 'NumIterations': 1, + 'NumberOfIterations': 1, 'InitialWeights': null, - 'InitWtsDiameter': 0.0, + 'InitialWeightsDiameter': 0.0, 'Shuffle': false, - 'StreamingCacheSize': 1000000, 'LabelColumn': 'Label', 'TrainingData': '$Var_9ccc8bce4f6540eb8a244ab40585602a', 'FeatureColumn': 'Features', @@ -5538,7 +5536,7 @@ public void TestOvaMacroWithUncalibratedLearner() } } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only + [TensorFlowFact] public void TestTensorFlowEntryPoint() { var dataPath = GetDataPath("Train-Tiny-28x28.txt"); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestUtilities.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestUtilities.cs index 9c0bd79e8e..e94de49320 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestUtilities.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestUtilities.cs @@ -7,7 +7,6 @@ using System.Linq; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.RunTests; -using Microsoft.ML.Trainers.FastTree.Internal; using Xunit; using Xunit.Abstractions; diff --git a/test/Microsoft.ML.Functional.Tests/Common.cs b/test/Microsoft.ML.Functional.Tests/Common.cs new file mode 100644 index 0000000000..29088298d3 --- /dev/null +++ b/test/Microsoft.ML.Functional.Tests/Common.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.Data.DataView; +using Microsoft.ML.Data; +using Microsoft.ML.SamplesUtils; +using Microsoft.ML.Trainers.HalLearners; +using Xunit; + +namespace Microsoft.ML.Functional.Tests +{ + internal static class Common + { + public static void CheckMetrics(RegressionMetrics metrics) + { + // Perform sanity checks on the metrics + Assert.True(metrics.Rms >= 0); + Assert.True(metrics.L1 >= 0); + Assert.True(metrics.L2 >= 0); + Assert.True(metrics.RSquared <= 1); + } + } +} diff --git a/test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj b/test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj new file mode 100644 index 0000000000..106db8f36c --- /dev/null +++ b/test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj @@ -0,0 +1,52 @@ + + + + + false + false + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/Microsoft.ML.Functional.Tests/Prediction.cs b/test/Microsoft.ML.Functional.Tests/Prediction.cs new file mode 100644 index 0000000000..24cc049e8f --- /dev/null +++ b/test/Microsoft.ML.Functional.Tests/Prediction.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework; +using Xunit; + +namespace Microsoft.ML.Functional.Tests +{ + public class PredictionScenarios + { + /// + /// Reconfigurable predictions: The following should be possible: A user trains a binary classifier, + /// and through the test evaluator gets a PR curve, then, based on the PR curve, picks a new threshold + /// and configures the scorer (or more precisely instantiates a new scorer over the same model parameters) + /// with some threshold derived from that. + /// + [Fact] + public void ReconfigurablePrediction() + { + var mlContext = new MLContext(seed: 789); + + // Get the dataset and create train and test splits + var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true) + .Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename)); + var split = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.2); + + // Create a pipeline to train on the housing data + var pipeline = mlContext.Transforms.Concatenate("Features", new string[] { + "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", + "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"}) + .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue")) + .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares()); + + var model = pipeline.Fit(split.TrainSet); + + var scoredTest = model.Transform(split.TestSet); + var metrics = mlContext.Regression.Evaluate(scoredTest); + + Common.CheckMetrics(metrics); + + // Todo #2465: Allow the setting of threshold and thresholdColumn for scoring. + // This is no longer possible in the API + //var newModel = new BinaryPredictionTransformer>(ml, model.Model, trainData.Schema, model.FeatureColumn, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability); + //var newScoredTest = newModel.Transform(pipeline.Transform(testData)); + //var newMetrics = mlContext.BinaryClassification.Evaluate(scoredTest); + // And the Threshold and ThresholdColumn properties are not settable.
+ //var predictor = model.LastTransformer; + //predictor.Threshold = 0.01; // Not possible + } + } +} diff --git a/test/Microsoft.ML.Functional.Tests/Validation.cs b/test/Microsoft.ML.Functional.Tests/Validation.cs new file mode 100644 index 0000000000..b04eff387a --- /dev/null +++ b/test/Microsoft.ML.Functional.Tests/Validation.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Data.DataView; +using Microsoft.ML.Data; +using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework; +using Microsoft.ML.Trainers.HalLearners; +using Xunit; + +namespace Microsoft.ML.Functional.Tests +{ + public class ValidationScenarios + { + /// + /// Cross-validation: Have a mechanism to do cross validation, that is, you come up with + /// a data source (optionally with a stratification column), come up with an instantiable transform + /// and trainer pipeline, and it will handle (1) splitting up the data, (2) training the separate + /// pipelines on in-fold data, (3) scoring on the out-fold data, (4) returning the set of + /// metrics, trained pipelines, and scored test data for each fold. + /// + [Fact] + void CrossValidation() + { + var mlContext = new MLContext(seed: 789); + + // Get the dataset + var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true) + .Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename)); + + // Create a pipeline to train on the housing data + var pipeline = mlContext.Transforms.Concatenate("Features", new string[] { + "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", + "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"}) + .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue")) + .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares()); + + // Compute the CV result + var cvResult = mlContext.Regression.CrossValidate(data, pipeline, numFolds: 5); + + // Check that the results are valid + Assert.IsType(cvResult[0].Metrics); + Assert.IsType>>(cvResult[0].Model); + Assert.True(cvResult[0].ScoredHoldOutSet is IDataView); + Assert.Equal(5, cvResult.Length); + + // And validate the metrics + foreach (var result in cvResult) + Common.CheckMetrics(result.Metrics); + } + } +} diff --git a/test/Microsoft.ML.OnnxTransformTest/DnnImageFeaturizerTest.cs b/test/Microsoft.ML.OnnxTransformTest/DnnImageFeaturizerTest.cs index 70ec3dffc8..2ebb55ee4d 100644 --- a/test/Microsoft.ML.OnnxTransformTest/DnnImageFeaturizerTest.cs +++ b/test/Microsoft.ML.OnnxTransformTest/DnnImageFeaturizerTest.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Model; using Microsoft.ML.RunTests; using Microsoft.ML.StaticPipe; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.StaticPipe; using Xunit; @@ -57,16 +58,9 @@ public DnnImageFeaturizerTests(ITestOutputHelper helper) : base(helper) { } - // Onnx is only supported on x64 Windows - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] + [OnnxFact] void TestDnnImageFeaturizer() { - // Onnxruntime supports Ubuntu 16.04, but not CentOS - // Do not execute on CentOS image - if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - return; - - var samplevector = GetSampleArrayData(); var dataView = DataViewConstructionUtils.CreateFromList(Env, @@
-97,13 +91,9 @@ void TestDnnImageFeaturizer() catch (InvalidOperationException) { } } - // Onnx is only supported on x64 Windows - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] + [OnnxFact] public void OnnxStatic() { - if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - return; - var env = new MLContext(null, 1); var imageHeight = 224; var imageWidth = 224; @@ -141,13 +131,9 @@ public void OnnxStatic() } // Onnx is only supported on x64 Windows - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] + [OnnxFact] public void TestOldSavingAndLoading() { - if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - return; - - var samplevector = GetSampleArrayData(); var dataView = ML.Data.ReadFromEnumerable( diff --git a/test/Microsoft.ML.OnnxTransformTest/Microsoft.ML.OnnxTransformTest.csproj b/test/Microsoft.ML.OnnxTransformTest/Microsoft.ML.OnnxTransformTest.csproj index a153655b27..a5dbaca9f8 100644 --- a/test/Microsoft.ML.OnnxTransformTest/Microsoft.ML.OnnxTransformTest.csproj +++ b/test/Microsoft.ML.OnnxTransformTest/Microsoft.ML.OnnxTransformTest.csproj @@ -11,7 +11,7 @@ - + diff --git a/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs index 904282a7a2..6b95e217db 100644 --- a/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs +++ b/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Model; using Microsoft.ML.RunTests; using Microsoft.ML.StaticPipe; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Tools; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.StaticPipe; @@ -21,24 +22,6 @@ namespace Microsoft.ML.Tests { - - /// - /// A Fact attribute for Onnx unit tests. Onnxruntime only supported - /// on Windows, Linux (Ubuntu 16.04) and 64-bit platforms. - /// - public class OnnxFact : FactAttribute - { - public OnnxFact() - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || - RuntimeInformation.IsOSPlatform(OSPlatform.OSX) || - !Environment.Is64BitProcess) - { - Skip = "Require 64 bit and Windows or Linux (Ubuntu 16.04)."; - } - } - } - public class OnnxTransformTests : TestDataPipeBase { private const int inputSize = 150528; @@ -137,16 +120,12 @@ void TestSimpleCase() catch (ArgumentOutOfRangeException) { } catch (InvalidOperationException) { } } - - // x86 not supported - [ConditionalTheory(typeof(Environment), nameof(Environment.Is64BitProcess))] + + [OnnxTheory] [InlineData(null, false)] [InlineData(null, true)] void TestOldSavingAndLoading(int? 
gpuDeviceId, bool fallbackToCpu) { - if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - return; - var modelFile = "squeezenet/00000001/model.onnx"; var samplevector = GetSampleArrayData(); diff --git a/test/Microsoft.ML.Predictor.Tests/TestParallelFasttreeInterface.cs b/test/Microsoft.ML.Predictor.Tests/TestParallelFasttreeInterface.cs index 6c246fa4a0..776431faba 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestParallelFasttreeInterface.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestParallelFasttreeInterface.cs @@ -6,7 +6,7 @@ using Microsoft.ML; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.RunTests; -using Microsoft.ML.Trainers.FastTree.Internal; +using Microsoft.ML.Trainers.FastTree; using Xunit; using Xunit.Abstractions; diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index 339bede3e5..f509c71908 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; +using Microsoft.ML.TestFramework.Attributes; using Float = System.Single; namespace Microsoft.ML.RunTests @@ -14,17 +15,15 @@ namespace Microsoft.ML.RunTests using Microsoft.Data.DataView; using Microsoft.ML; using Microsoft.ML.Data; - using Microsoft.ML.Ensemble; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; - using Microsoft.ML.Learners; using Microsoft.ML.LightGBM; using Microsoft.ML.TestFramework; + using Microsoft.ML.Trainers; + using Microsoft.ML.Trainers.Ensemble; using Microsoft.ML.Trainers.FastTree; - using Microsoft.ML.Trainers.FastTree.Internal; + using Microsoft.ML.Trainers.HalLearners; using Microsoft.ML.Trainers.Online; - using Microsoft.ML.Trainers.SymSgd; - using Microsoft.ML.Transforms.Categorical; using Xunit; using Xunit.Abstractions; using TestLearners = TestLearnersBase; @@ -162,7 +161,7 @@ public void EarlyStoppingTest() /// /// Multiclass Logistic Regression test. /// - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline + [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")] [TestCategory("Multiclass")] [TestCategory("Logistic Regression")] public void MulticlassLRTest() @@ -174,7 +173,7 @@ public void MulticlassLRTest() /// /// Multiclass Logistic Regression with non-negative coefficients test. /// - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline + [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")] [TestCategory("Multiclass")] [TestCategory("Logistic Regression")] public void MulticlassLRNonNegativeTest() @@ -198,7 +197,7 @@ public void MulticlassSdcaTest() /// /// Multiclass Logistic Regression test with a tree featurizer. 
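The `[OnnxFact]`, `[OnnxTheory]`, `[X64Fact]`, and `[LessThanNetCore30OrNotNetCoreFact]` attributes used in these hunks replace the older `ConditionalFact` plus early-`return` pattern. Under the hood this is plain xunit: a custom attribute derives from `FactAttribute` and sets `Skip` in its constructor when the environment does not qualify, exactly as the now-removed in-test `OnnxFact` class did. For illustration, a hypothetical attribute (not part of this change) following the same pattern:

```csharp
using System;
using System.Runtime.InteropServices;
using Xunit;

// Hypothetical example: assigning FactAttribute.Skip makes xunit report the
// test as skipped (with this message) instead of executing it.
public sealed class Windows64BitFactAttribute : FactAttribute
{
    public Windows64BitFactAttribute()
    {
        if (!Environment.Is64BitProcess || !RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
            Skip = "Requires a 64-bit Windows process.";
    }
}
```

Centralizing these checks in attributes keeps the skip reason visible in the test results and removes silent early returns from test bodies.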
/// - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + [X64Fact("x86 output differs from Baseline")] [TestCategory("Multiclass")] [TestCategory("Logistic Regression")] [TestCategory("FastTree")] @@ -256,7 +255,7 @@ public void KMeansClusteringTest() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + [X64Fact("x86 output differs from Baseline")] [TestCategory("Binary")] [TestCategory("SDCA")] public void LinearClassifierTest() @@ -277,7 +276,7 @@ public void LinearClassifierTest() /// ///A test for binary classifiers /// - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + [X64Fact("x86 output differs from Baseline")] [TestCategory("Binary")] public void BinaryClassifierLogisticRegressionTest() { @@ -287,7 +286,7 @@ public void BinaryClassifierLogisticRegressionTest() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + [X64Fact("x86 output differs from Baseline")] [TestCategory("Binary")] public void BinaryClassifierSymSgdTest() { @@ -299,7 +298,7 @@ public void BinaryClassifierSymSgdTest() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + [X64Fact("x86 output differs from Baseline")] [TestCategory("Binary")] public void BinaryClassifierTesterThresholdingTest() { @@ -325,7 +324,7 @@ public void BinaryClassifierLogisticRegressionNormTest() /// ///A test for binary classifiers with non-negative coefficients /// - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCoreAnd64BitProcess))] // netcore3.0 and x86 output differs from Baseline + [LessThanNetCore30OrNotNetCoreAndX64Fact("netcoreapp3.0 and x86 output differs from Baseline")] [TestCategory("Binary")] public void BinaryClassifierLogisticRegressionNonNegativeTest() { @@ -338,7 +337,7 @@ public void BinaryClassifierLogisticRegressionNonNegativeTest() /// ///A test for binary classifiers /// - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline + [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")] [TestCategory("Binary")] public void BinaryClassifierLogisticRegressionBinNormTest() { @@ -351,7 +350,7 @@ public void BinaryClassifierLogisticRegressionBinNormTest() /// ///A test for binary classifiers /// - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCoreAnd64BitProcess))] // x86 output differs from Baseline and flaky on netcore 3.0 + [LessThanNetCore30OrNotNetCoreAndX64Fact("x86 output differs from Baseline and flaky on netcore 3.0")] [TestCategory("Binary")] public void BinaryClassifierLogisticRegressionGaussianNormTest() { @@ -388,7 +387,7 @@ public void BinaryClassifierFastRankClassificationTest() /// ///A test for binary classifiers /// - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + [X64Fact("x86 output differs from Baseline")] [TestCategory("Binary")] [TestCategory("FastForest")] public void FastForestClassificationTest() @@ -455,7 +454,7 @@ public void WeightingFastForestRegressionPredictorsTest() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + 
[X64Fact("x86 output differs from Baseline")] [TestCategory("Binary")] [TestCategory("FastTree")] public void FastTreeBinaryClassificationTest() @@ -474,7 +473,7 @@ public void FastTreeBinaryClassificationTest() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] [TestCategory("Binary")] [TestCategory("LightGBM")] public void LightGBMClassificationTest() @@ -490,7 +489,7 @@ public void LightGBMClassificationTest() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] [TestCategory("Binary"), TestCategory("LightGBM")] public void GossLightGBMTest() { @@ -500,7 +499,7 @@ public void GossLightGBMTest() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] [TestCategory("Binary")] [TestCategory("LightGBM")] public void DartLightGBMTest() @@ -514,7 +513,7 @@ public void DartLightGBMTest() /// /// A test for multi class classifiers. /// - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] [TestCategory("Multiclass")] [TestCategory("LightGBM")] public void MultiClassifierLightGBMKeyLabelTest() @@ -528,7 +527,7 @@ public void MultiClassifierLightGBMKeyLabelTest() /// /// A test for multi class classifiers. /// - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] [TestCategory("Multiclass")] [TestCategory("LightGBM")] public void MultiClassifierLightGBMKeyLabelU404Test() @@ -542,7 +541,7 @@ public void MultiClassifierLightGBMKeyLabelU404Test() /// /// A test for regression. /// - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] [TestCategory("Regression")] [TestCategory("LightGBM")] public void RegressorLightGBMTest() @@ -556,7 +555,7 @@ public void RegressorLightGBMTest() /// /// A test for regression. /// - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] [TestCategory("Regression")] [TestCategory("LightGBM")] public void RegressorLightGBMMAETest() @@ -570,7 +569,7 @@ public void RegressorLightGBMMAETest() /// /// A test for regression. /// - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] [TestCategory("Regression")] [TestCategory("LightGBM")] public void RegressorLightGBMRMSETest() @@ -602,8 +601,7 @@ public void RankingLightGBMTest() Done(); } - // x86 fails. Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216 - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] + [X64Fact("x86 fails. Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216")] public void TestTreeEnsembleCombiner() { var dataPath = GetDataPath("breast-cancer.txt"); @@ -624,8 +622,7 @@ public void TestTreeEnsembleCombiner() CombineAndTestTreeEnsembles(dataView, fastTrees); } - // x86 fails. Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216 - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] + [X64Fact("x86 fails. 
Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216")] public void TestTreeEnsembleCombinerWithCategoricalSplits() { var dataPath = GetDataPath("adult.tiny.with-schema.txt"); @@ -727,8 +724,7 @@ private void CombineAndTestTreeEnsembles(IDataView idv, PredictorModel[] fastTre } } - // x86 fails. Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216 - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] + [X64Fact("x86 fails. Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216")] public void TestEnsembleCombiner() { var dataPath = GetDataPath("breast-cancer.txt"); @@ -748,7 +744,7 @@ public void TestEnsembleCombiner() { FeatureColumn = "Features", LabelColumn = DefaultColumnNames.Label, - NumIterations = 2, + NumberOfIterations = 2, TrainingData = dataView, NormalizeFeatures = NormalizeOption.No }).PredictorModel, @@ -773,8 +769,7 @@ public void TestEnsembleCombiner() CombineAndTestEnsembles(dataView, "pe", "oc=average", PredictionKind.BinaryClassification, predictors); } - // x86 fails. Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216 - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] + [X64Fact("x86 fails. Associated GitHubIssue: https://github.com/dotnet/machinelearning/issues/1216")] public void TestMultiClassEnsembleCombiner() { var dataPath = GetDataPath("breast-cancer.txt"); @@ -943,7 +938,7 @@ private void CombineAndTestEnsembles(IDataView idv, string name, string options, } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + [X64Fact("x86 output differs from Baseline")] [TestCategory("Binary")] [TestCategory("FastTree")] public void FastTreeBinaryClassificationCategoricalSplitTest() @@ -982,7 +977,7 @@ public void FastTreeRegressionCategoricalSplitTest() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + [X64Fact("x86 output differs from Baseline")] [TestCategory("Binary")] [TestCategory("FastTree")] public void FastTreeBinaryClassificationNoOpGroupIdTest() @@ -1002,7 +997,7 @@ public void FastTreeBinaryClassificationNoOpGroupIdTest() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + [X64Fact("x86 output differs from Baseline")] [TestCategory("Binary")] [TestCategory("FastTree")] public void FastTreeHighMinDocsTest() @@ -1577,7 +1572,7 @@ public IList GetDatasetsForCalibratorTest() /// ///A test for no calibrators /// - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline + [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")] [TestCategory("Calibrator")] public void DefaultCalibratorPerceptronTest() { @@ -1589,7 +1584,7 @@ public void DefaultCalibratorPerceptronTest() /// ///A test for PAV calibrators /// - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline + [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")] [TestCategory("Calibrator")] public void PAVCalibratorPerceptronTest() { @@ -1601,7 +1596,7 @@ public void PAVCalibratorPerceptronTest() /// ///A test for random calibrators /// - [ConditionalFact(typeof(BaseTestBaseline), 
nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCoreAnd64BitProcess))] // netcore3.0 and x86 output differs from Baseline + [LessThanNetCore30OrNotNetCoreAndX64Fact("netcoreapp3.0 and x86 output differs from Baseline")] [TestCategory("Calibrator")] public void RandomCalibratorPerceptronTest() { diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 8931f53697..5f76129a89 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -10,12 +10,11 @@ using Microsoft.ML.FactorizationMachine; using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Internal.Internallearn; -using Microsoft.ML.Learners; using Microsoft.ML.LightGBM; using Microsoft.ML.LightGBM.StaticPipe; using Microsoft.ML.RunTests; -using Microsoft.ML.SamplesUtils; using Microsoft.ML.StaticPipe; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Trainers; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.KMeans; @@ -450,7 +449,7 @@ public void FastTreeRegression() Assert.InRange(metrics.LossFn, 0, double.PositiveInfinity); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void LightGbmBinaryClassification() { var env = new MLContext(seed: 0); @@ -490,7 +489,7 @@ public void LightGbmBinaryClassification() Assert.InRange(metrics.Auprc, 0, 1); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void LightGbmRegression() { var env = new MLContext(seed: 0); @@ -806,7 +805,7 @@ public void FastTreeRanking() Assert.InRange(metrics.Ndcg[2], 36.5, 37); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void LightGBMRanking() { var env = new MLContext(seed: 0); @@ -847,7 +846,7 @@ public void LightGBMRanking() Assert.InRange(metrics.Ndcg[2], 36.5, 37); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void MultiClassLightGBM() { var env = new MLContext(seed: 0); @@ -968,7 +967,7 @@ public void HogwildSGDBinaryClassification() Assert.InRange(metrics.Auprc, 0, 1); } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCoreAnd64BitProcess))] // netcore3.0 and x86 output differs from Baseline. This test is being fixed as part of issue #1441. + [LessThanNetCore30OrNotNetCoreAndX64Fact("netcoreapp3.0 and x86 output differs from Baseline. Being tracked as part of https://github.com/dotnet/machinelearning/issues/1441")] public void MatrixFactorization() { // Create a new context for ML.NET operations. It can be used for exception tracking and logging, @@ -986,7 +985,7 @@ public void MatrixFactorization() // The parameter that will be into the onFit method below. The obtained predictor will be assigned to this variable // so that we will be able to touch it. - MatrixFactorizationPredictor pred = null; + MatrixFactorizationModelParameters pred = null; // Create a statically-typed matrix factorization estimator. The MatrixFactorization's input and output defined in MatrixFactorizationStatic // tell what (aks a Scalar) is expected. Notice that only one thread is used for deterministic outcome. 
@@ -1019,7 +1018,7 @@ public void MatrixFactorization() Assert.InRange(metrics.L2, 0, 0.5); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void MultiClassLightGbmStaticPipelineWithInMemoryData() { // Create a general context for ML.NET operations. It can be used for exception tracking and logging, @@ -1027,7 +1026,7 @@ public void MultiClassLightGbmStaticPipelineWithInMemoryData() var mlContext = new MLContext(seed: 1, conc: 1); // Create in-memory examples as C# native class. - var examples = DatasetUtils.GenerateRandomMulticlassClassificationExamples(1000); + var examples = SamplesUtils.DatasetUtils.GenerateRandomMulticlassClassificationExamples(1000); // Convert native C# class to IDataView, a consumble format to ML.NET functions. var dataView = mlContext.Data.ReadFromEnumerable(examples); @@ -1079,7 +1078,7 @@ public void MultiClassLightGbmStaticPipelineWithInMemoryData() Assert.Equal(0.86309523809523814, metrics.AccuracyMicro, 6); // Convert prediction in ML.NET format to native C# class. - var nativePredictions = mlContext.CreateEnumerable(prediction.AsDynamic, false).ToList(); + var nativePredictions = mlContext.CreateEnumerable(prediction.AsDynamic, false).ToList(); // Get schema object of the prediction. It contains metadata such as the mapping from predicted label index // (e.g., 1) to its actual label (e.g., "AA"). @@ -1087,7 +1086,7 @@ public void MultiClassLightGbmStaticPipelineWithInMemoryData() // Retrieve the mapping from labels to label indexes. var labelBuffer = new VBuffer>(); - schema[nameof(DatasetUtils.MulticlassClassificationExample.PredictedLabelIndex)].Metadata.GetValue("KeyValues", ref labelBuffer); + schema[nameof(SamplesUtils.DatasetUtils.MulticlassClassificationExample.PredictedLabelIndex)].Metadata.GetValue("KeyValues", ref labelBuffer); var nativeLabels = labelBuffer.DenseValues().ToList(); // nativeLabels[nativePrediction.PredictedLabelIndex-1] is the original label indexed by nativePrediction.PredictedLabelIndex. // Show prediction result for the 3rd example. diff --git a/test/Microsoft.ML.TestFramework/Attributes/BenchmarkTheoryAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/BenchmarkTheoryAttribute.cs new file mode 100644 index 0000000000..fae467176e --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/BenchmarkTheoryAttribute.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A theory for BenchmarkDotNet tests. 
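The in-memory LightGBM sample above pulls scored rows back into C# objects with `CreateEnumerable`. A minimal sketch of that conversion, assuming a hypothetical output class (the class name and column choice here are illustrative, not from this change) whose public fields match columns of the scored data:

```csharp
using System.Linq;
using Microsoft.ML;

// Hypothetical output type; public fields map to columns of the scored IDataView.
public class ScoredRow
{
    public float Score;
}

// 'scoredData' is an IDataView produced by model.Transform(...);
// passing 'false' requests a fresh object per row rather than reusing one buffer.
var rows = mlContext.CreateEnumerable<ScoredRow>(scoredData, false).ToList();
```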
+ /// + public sealed class BenchmarkTheoryAttribute : EnvironmentSpecificTheoryAttribute + { +#if DEBUG + private const string SkipMessage = "BenchmarkDotNet does not allow running the benchmarks in Debug, so this test is disabled for DEBUG"; + private readonly bool _isEnvironmentSupported = false; +#elif NET461 + private const string SkipMessage = "We are currently not running Benchmarks for FullFramework"; + private readonly bool _isEnvironmentSupported = false; +#else + private const string SkipMessage = "We don't support 32 bit yet"; + private readonly bool _isEnvironmentSupported = System.Environment.Is64BitProcess; +#endif + + public BenchmarkTheoryAttribute() : base(SkipMessage) + { + } + + protected override bool IsEnvironmentSupported() => _isEnvironmentSupported; + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/Attributes/EnvironmentSpecificFactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/EnvironmentSpecificFactAttribute.cs new file mode 100644 index 0000000000..fe754b6eb2 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/EnvironmentSpecificFactAttribute.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Xunit; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A base class for environment-specific fact attributes. + /// + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)] + public abstract class EnvironmentSpecificFactAttribute : FactAttribute + { + private readonly string _skipMessage; + + /// + /// Creates a new instance of the class. + /// + /// The message to be used when skipping the test marked with this attribute. + protected EnvironmentSpecificFactAttribute(string skipMessage) + { + _skipMessage = skipMessage ?? throw new ArgumentNullException(nameof(skipMessage)); + } + + public sealed override string Skip => IsEnvironmentSupported() ? null : _skipMessage; + + /// + /// A method used to evaluate whether to skip a test marked with this attribute. Skips iff this method evaluates to false. + /// + protected abstract bool IsEnvironmentSupported(); + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/Attributes/EnvironmentSpecificTheoryAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/EnvironmentSpecificTheoryAttribute.cs new file mode 100644 index 0000000000..c8f4b44c56 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/EnvironmentSpecificTheoryAttribute.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Xunit; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A base class for environment-specific fact attributes. + /// + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)] + public abstract class EnvironmentSpecificTheoryAttribute : TheoryAttribute + { + private readonly string _skipMessage; + + /// + /// Creates a new instance of the class. + /// + /// The message to be used when skipping the test marked with this attribute. + protected EnvironmentSpecificTheoryAttribute(string skipMessage) + { + _skipMessage = skipMessage ?? 
throw new ArgumentNullException(nameof(skipMessage)); + } + + public sealed override string Skip => IsEnvironmentSupported() ? null : _skipMessage; + + /// + /// A method used to evaluate whether to skip a test marked with this attribute. Skips iff this method evaluates to false. + /// + protected abstract bool IsEnvironmentSupported(); + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/Attributes/LessThanNetCore30OrNotNetCoreAndX64FactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/LessThanNetCore30OrNotNetCoreAndX64FactAttribute.cs new file mode 100644 index 0000000000..4d8814b5d6 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/LessThanNetCore30OrNotNetCoreAndX64FactAttribute.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A fact for tests requiring x64 environment and either .NET Core version lower than 3.0 or framework other than .NET Core. + /// + public sealed class LessThanNetCore30OrNotNetCoreAndX64FactAttribute : EnvironmentSpecificFactAttribute + { + public LessThanNetCore30OrNotNetCoreAndX64FactAttribute(string skipMessage) : base(skipMessage) + { + } + + /// + protected override bool IsEnvironmentSupported() + { + return Environment.Is64BitProcess && AppDomain.CurrentDomain.GetData("FX_PRODUCT_VERSION") == null; + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/Attributes/LessThanNetCore30OrNotNetCoreFactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/LessThanNetCore30OrNotNetCoreFactAttribute.cs new file mode 100644 index 0000000000..7fdfdf2708 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/LessThanNetCore30OrNotNetCoreFactAttribute.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A fact for tests requiring either .NET Core version lower than 3.0 or framework other than .NET Core. + /// + public sealed class LessThanNetCore30OrNotNetCoreFactAttribute : EnvironmentSpecificFactAttribute + { + public LessThanNetCore30OrNotNetCoreFactAttribute(string skipMessage) : base(skipMessage) + { + } + + /// + protected override bool IsEnvironmentSupported() + { + return AppDomain.CurrentDomain.GetData("FX_PRODUCT_VERSION") == null; + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/Attributes/LightGBMFactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/LightGBMFactAttribute.cs new file mode 100644 index 0000000000..726d50e970 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/LightGBMFactAttribute.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A fact for tests requiring LightGBM. 
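These two base classes reduce a new environment gate to a skip message plus a single predicate: derive from `EnvironmentSpecificFactAttribute` (or the `Theory` variant), pass the message to the base constructor, and implement `IsEnvironmentSupported`. The `LessThanNetCore30OrNotNetCore*` variants encode their condition as one probe, `AppDomain.CurrentDomain.GetData("FX_PRODUCT_VERSION")`, which (per the attribute's own naming and check) is only non-null on .NET Core 3.0, so a null result means an older .NET Core or a non-.NET Core framework. For illustration, a hypothetical Linux-only gate (not part of this change) built the same way:

```csharp
using System.Runtime.InteropServices;

namespace Microsoft.ML.TestFramework.Attributes
{
    /// <summary>
    /// Hypothetical example: a fact for tests that only run on Linux.
    /// </summary>
    public sealed class LinuxOnlyFactAttribute : EnvironmentSpecificFactAttribute
    {
        public LinuxOnlyFactAttribute() : base("This test only runs on Linux.")
        {
        }

        /// <inheritdoc />
        protected override bool IsEnvironmentSupported()
        {
            return RuntimeInformation.IsOSPlatform(OSPlatform.Linux);
        }
    }
}
```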
+ /// + public sealed class LightGBMFactAttribute : EnvironmentSpecificFactAttribute + { + public LightGBMFactAttribute() : base("LightGBM is 64-bit only") + { + } + + /// + protected override bool IsEnvironmentSupported() + { + return Environment.Is64BitProcess; + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/Attributes/MatrixFactorizationFactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/MatrixFactorizationFactAttribute.cs new file mode 100644 index 0000000000..8ba0bba3e1 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/MatrixFactorizationFactAttribute.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A fact for tests requiring matrix factorization. + /// + public sealed class MatrixFactorizationFactAttribute : EnvironmentSpecificFactAttribute + { + public MatrixFactorizationFactAttribute() : base("Disabled - this test is being fixed as part of https://github.com/dotnet/machinelearning/issues/1441") + { + } + + /// + protected override bool IsEnvironmentSupported() + { + return Environment.Is64BitProcess; + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/Attributes/OnnxFactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/OnnxFactAttribute.cs new file mode 100644 index 0000000000..d2ed6763d6 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/OnnxFactAttribute.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A fact for tests requiring Onnx. + /// + public sealed class OnnxFactAttribute : EnvironmentSpecificFactAttribute + { + public OnnxFactAttribute() : base("Onnx is 64-bit Windows only") + { + } + + /// + protected override bool IsEnvironmentSupported() + { + return Environment.Is64BitProcess && RuntimeInformation.IsOSPlatform(OSPlatform.Windows); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/Attributes/OnnxTheoryAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/OnnxTheoryAttribute.cs new file mode 100644 index 0000000000..979654c9c3 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/OnnxTheoryAttribute.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A fact for tests requiring Onnx. 
+ /// + public sealed class OnnxTheoryAttribute : EnvironmentSpecificTheoryAttribute + { + public OnnxTheoryAttribute() : base("Onnx is 64-bit Windows only") + { + } + + /// + protected override bool IsEnvironmentSupported() + { + return Environment.Is64BitProcess && RuntimeInformation.IsOSPlatform(OSPlatform.Windows); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/Attributes/TensorflowFactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/TensorflowFactAttribute.cs new file mode 100644 index 0000000000..7711e4d05f --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/TensorflowFactAttribute.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A fact for tests requiring TensorFlow. + /// + public sealed class TensorFlowFactAttribute : EnvironmentSpecificFactAttribute + { + public TensorFlowFactAttribute() : base("TensorFlow is 64-bit only") + { + } + + /// + protected override bool IsEnvironmentSupported() + { + return Environment.Is64BitProcess; + } + } +} diff --git a/test/Microsoft.ML.TestFramework/Attributes/X64FactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/X64FactAttribute.cs new file mode 100644 index 0000000000..dcf8814bb8 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/X64FactAttribute.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A fact for tests requiring X64 environment. + /// + public sealed class X64FactAttribute : EnvironmentSpecificFactAttribute + { + public X64FactAttribute(string skipMessage) : base(skipMessage) + { + } + + /// + protected override bool IsEnvironmentSupported() + { + return Environment.Is64BitProcess; + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs index b3c2faf3be..43d89d9982 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs @@ -24,20 +24,8 @@ public abstract partial class BaseTestBaseline : BaseTestClass { public const int DigitsOfPrecision = 7; - public static bool NotFullFramework { get; set;} = true; - - public static bool LessThanNetCore30OrNotNetCore { get; } = AppDomain.CurrentDomain.GetData("FX_PRODUCT_VERSION") == null ? true : false; - - public static bool LessThanNetCore30AndNotFullFramework { get; set; } = LessThanNetCore30OrNotNetCore; - - public static bool LessThanNetCore30OrNotNetCoreAnd64BitProcess { get; } = LessThanNetCore30OrNotNetCore && Environment.Is64BitProcess; - protected BaseTestBaseline(ITestOutputHelper output) : base(output) { -#if NETFRAMEWORK - NotFullFramework = false; - LessThanNetCore30AndNotFullFramework = false; -#endif } internal const string RawSuffix = ".raw"; @@ -777,9 +765,6 @@ public void RunMTAThread(ThreadStart fn) } }); t.IsBackground = true; -#if !CORECLR // CoreCLR does not support apartment state settings for threads. 
- t.SetApartmentState(ApartmentState.MTA); -#endif t.Start(); t.Join(); if (inner != null) @@ -802,11 +787,8 @@ public void RunMTAThread(ThreadStart fn) protected static StreamWriter OpenWriter(string path, bool append = false, Encoding encoding = null, int bufferSize = 1024) { Contracts.CheckNonWhiteSpace(path, nameof(path)); -#if CORECLR + return Utils.OpenWriter(File.Open(path, append ? FileMode.Append : FileMode.OpenOrCreate), encoding, bufferSize, false); -#else - return new StreamWriter(path, append); -#endif } /// @@ -817,11 +799,8 @@ protected static StreamWriter OpenWriter(string path, bool append = false, Encod protected static StreamReader OpenReader(string path) { Contracts.CheckNonWhiteSpace(path, nameof(path)); -#if CORECLR + return new StreamReader(File.Open(path, FileMode.Open, FileAccess.Read, FileShare.Read)); -#else - return new StreamReader(path); -#endif } /// diff --git a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs index 57b2983f44..16e2ef0325 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs @@ -9,7 +9,7 @@ namespace Microsoft.ML.RunTests { - using ResultProcessor = Microsoft.ML.Internal.Internallearn.ResultProcessor.ResultProcessor; + using ResultProcessor = ResultProcessor.ResultProcessor; /// /// This is a base test class designed to support running trainings and related diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 61149b7022..fa7c4c8ea0 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -1096,23 +1096,6 @@ protected Func GetColumnComparer(Row r1, Row r2, int col, ColumnType type, return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); } -#if !CORECLR // REVIEW: Port Picture type to CoreTLC. - if (type is PictureType) - { - var g1 = r1.GetGetter(col); - var g2 = r2.GetGetter(col); - Picture v1 = null; - Picture v2 = null; - return - () => - { - g1(ref v1); - g2(ref v2); - return ComparePicture(v1, v2); - }; - } -#endif - throw Contracts.Except("Unknown type in GetColumnComparer: '{0}'", type); } @@ -1250,36 +1233,5 @@ protected void VerifyVecEquality(ValueGetter> vecGetter, ValueGett vecNGetter(ref fvn); Assert.True(CompareVec(in fv, in fvn, size, compare)); } - -#if !CORECLR - // REVIEW: Port Picture type to Core TLC. 
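With the `CORECLR` conditionals removed, the baseline helpers now always go through `File.Open` with explicit sharing. A minimal sketch of the resulting read pattern; the stated motivation (letting other readers keep the same baseline file open) is an inference, not something this change documents:

```csharp
using System.IO;

// FileShare.Read allows other processes to keep reading the same baseline
// file while this reader holds it open.
using (var reader = new StreamReader(File.Open(path, FileMode.Open, FileAccess.Read, FileShare.Read)))
{
    string line;
    while ((line = reader.ReadLine()) != null)
    {
        // ... compare against the expected baseline ...
    }
}
```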
- protected bool ComparePicture(Picture v1, Picture v2) - { - if (v1 == null || v2 == null) - return v1 == v2; - - var p1 = v1.Contents.Pixels; - var p2 = v2.Contents.Pixels; - - if (p1.Width != p2.Width) - return false; - if (p1.Height != p2.Height) - return false; - if (p1.PixelFormat != p2.PixelFormat) - return false; - - for (int y = 0; y < p1.Height; y++) - { - for (int x = 0; x < p1.Width; x++) - { - var x1 = p1.GetPixel(x, y); - var x2 = p2.GetPixel(x, y); - if (x1 != x2) - return false; - } - } - return true; - } -#endif } } diff --git a/test/Microsoft.ML.TestFramework/Datasets.cs b/test/Microsoft.ML.TestFramework/Datasets.cs index 6d6ba61191..1bdfa5048b 100644 --- a/test/Microsoft.ML.TestFramework/Datasets.cs +++ b/test/Microsoft.ML.TestFramework/Datasets.cs @@ -158,7 +158,24 @@ public static class TestDatasets name = "housing", trainFilename = "housing.txt", testFilename = "housing.txt", - loaderSettings = "loader=Text{col=Label:0 col=Features:~ header=+}" + loaderSettings = "loader=Text{col=Label:0 col=Features:~ header=+}", + GetLoaderColumns = () => + { + return new[] { + new TextLoader.Column("MedianHomeValue", DataKind.R4, 0), + new TextLoader.Column("CrimesPerCapita", DataKind.R4, 1), + new TextLoader.Column("PercentResidental", DataKind.R4, 2), + new TextLoader.Column("PercentNonRetail", DataKind.R4, 3), + new TextLoader.Column("CharlesRiver", DataKind.R4, 4), + new TextLoader.Column("NitricOxides", DataKind.R4, 5), + new TextLoader.Column("RoomsPerDwelling", DataKind.R4, 6), + new TextLoader.Column("PercentPre40s", DataKind.R4, 7), + new TextLoader.Column("EmploymentDistance", DataKind.R4, 8), + new TextLoader.Column("HighwayDistance", DataKind.R4, 9), + new TextLoader.Column("TaxRate", DataKind.R4, 10), + new TextLoader.Column("TeacherRatio", DataKind.R4, 11), + }; + } }; public static TestDataset generatedRegressionDatasetmacro = new TestDataset diff --git a/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs index 1b8cbdcb60..51f9c7e465 100644 --- a/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs +++ b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs @@ -3,9 +3,9 @@ // See the LICENSE file in the project root for more information. 
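The housing entry in `TestDatasets` now exposes its schema through `GetLoaderColumns`, so tests can build a typed `TextLoader` instead of parsing the `loaderSettings` string; this is how the functional cross-validation test earlier in this change consumes it. A minimal sketch, assuming the same test-framework helpers used there (`TestDatasets`, `BaseTestClass.GetDataPath`):

```csharp
var mlContext = new MLContext(seed: 1);

// Column definitions come from the dataset description rather than a loader-settings string.
var data = mlContext.Data
    .CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true)
    .Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
```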
using Microsoft.ML.Data; -using Microsoft.ML.Ensemble; using Microsoft.ML.EntryPoints; -using Microsoft.ML.Learners; +using Microsoft.ML.Trainers; +using Microsoft.ML.Trainers.Ensemble; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.KMeans; using Microsoft.ML.Trainers.PCA; diff --git a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj index a8a0ad7000..d24892a1a8 100644 --- a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj +++ b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj @@ -1,8 +1,4 @@  - - CORECLR - - diff --git a/test/Microsoft.ML.TestFramework/TestCommandBase.cs b/test/Microsoft.ML.TestFramework/TestCommandBase.cs index cb3a88782f..4475568a6b 100644 --- a/test/Microsoft.ML.TestFramework/TestCommandBase.cs +++ b/test/Microsoft.ML.TestFramework/TestCommandBase.cs @@ -15,6 +15,7 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; using Microsoft.ML.TestFramework; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Tools; using Xunit; using Xunit.Abstractions; @@ -838,7 +839,7 @@ public void CommandCrossValidation() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + [X64Fact("x86 output differs from Baseline")] public void CommandCrossValidationKeyLabelWithFloatKeyValues() { RunMTAThread(() => @@ -1170,7 +1171,7 @@ public void CommandTrainMlrWithLabelNames() Done(); } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline + [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")] [TestCategory(Cat), TestCategory("Multiclass"), TestCategory("Logistic Regression")] public void CommandTrainMlrWithStats() { diff --git a/test/Microsoft.ML.Tests/FeatureContributionTests.cs b/test/Microsoft.ML.Tests/FeatureContributionTests.cs index 446aed03a0..7833b5afae 100644 --- a/test/Microsoft.ML.Tests/FeatureContributionTests.cs +++ b/test/Microsoft.ML.Tests/FeatureContributionTests.cs @@ -10,6 +10,7 @@ using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Trainers; using Microsoft.ML.Training; using Microsoft.ML.Transforms; @@ -47,7 +48,7 @@ public void TestOrdinaryLeastSquaresRegression() TestFeatureContribution(ML.Regression.Trainers.OrdinaryLeastSquares(), GetSparseDataset(numberOfInstances: 100), "LeastSquaresRegression"); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void TestLightGbmRegression() { TestFeatureContribution(ML.Regression.Trainers.LightGbm(), GetSparseDataset(numberOfInstances: 100), "LightGbmRegression"); @@ -104,7 +105,7 @@ public void TestFastTreeRanking() TestFeatureContribution(ML.Ranking.Trainers.FastTree(), GetSparseDataset(TaskType.Ranking, 100), "FastTreeRanking"); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void TestLightGbmRanking() { TestFeatureContribution(ML.Ranking.Trainers.LightGbm(), GetSparseDataset(TaskType.Ranking, 100), "LightGbmRanking"); @@ -141,7 +142,7 @@ public void TestFastTreeBinary() TestFeatureContribution(ML.BinaryClassification.Trainers.FastTree(), GetSparseDataset(TaskType.BinaryClassification, 100), 
"FastTreeBinary"); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void TestLightGbmBinary() { TestFeatureContribution(ML.BinaryClassification.Trainers.LightGbm(), GetSparseDataset(TaskType.BinaryClassification, 100), "LightGbmBinary"); diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 2d56666b7f..37f4b25c1e 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -46,7 +46,7 @@ - - + + diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index f5b062c31a..4d8a06972f 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -11,10 +11,11 @@ using Google.Protobuf; using Microsoft.Data.DataView; using Microsoft.ML.Data; -using Microsoft.ML.Learners; using Microsoft.ML.Model.Onnx; using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Tools; +using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; using Microsoft.ML.UniversalModelFormat.Onnx; using Newtonsoft.Json; @@ -118,7 +119,7 @@ private class BreastCancerMulticlassExample public float[] Features; } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // Tracked by https://github.com/dotnet/machinelearning/issues/2087 + [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline. Tracked by https://github.com/dotnet/machinelearning/issues/2087")] public void KmeansOnnxConversionTest() { // Create a new context for ML.NET operations. It can be used for exception tracking and logging, @@ -330,7 +331,7 @@ public void LogisticRegressionOnnxConversionTest() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void LightGbmBinaryClassificationOnnxConversionTest() { // Step 1: Create and train a ML.NET pipeline. 
diff --git a/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs b/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs index e6fa53c584..6be9b0adfe 100644 --- a/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs +++ b/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs @@ -8,8 +8,8 @@ using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Learners; using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers; using Xunit; using Xunit.Abstractions; diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index 01501677be..eab7d56fdb 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -10,7 +10,6 @@ using Microsoft.ML; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Learners; using Microsoft.ML.RunTests; using Microsoft.ML.StaticPipe; using Microsoft.ML.TestFramework; diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index 39957cc031..dd05dfd549 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -11,9 +11,9 @@ using Microsoft.Data.DataView; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Learners; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; +using Microsoft.ML.Trainers; using Microsoft.ML.Transforms.Categorical; using Microsoft.ML.Transforms.Normalizers; using Microsoft.ML.Transforms.Text; @@ -428,12 +428,12 @@ private void CrossValidationOn(string dataPath) .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent()); // Split the data 90:10 into train and test sets, train and evaluate. - var (trainData, testData) = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1); + var split = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1); // Train the model. - var model = pipeline.Fit(trainData); + var model = pipeline.Fit(split.TrainSet); // Compute quality metrics on the test set. - var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(testData)); + var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet)); Console.WriteLine(metrics.AccuracyMicro); // Now run the 5-fold cross-validation experiment, using the same pipeline. @@ -441,7 +441,7 @@ private void CrossValidationOn(string dataPath) // The results object is an array of 5 elements. For each of the 5 folds, we have metrics, model and scored test data. // Let's compute the average micro-accuracy. - var microAccuracies = cvResults.Select(r => r.metrics.AccuracyMicro); + var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro); Console.WriteLine(microAccuracies.Average()); } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs deleted file mode 100644 index a4e3afc2cc..0000000000 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs +++ /dev/null @@ -1,36 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. 
-// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.RunTests; -using Microsoft.ML.Trainers; -using Xunit; - -namespace Microsoft.ML.Tests.Scenarios.Api -{ - public partial class ApiScenariosTests - { - /// - /// Cross-validation: Have a mechanism to do cross validation, that is, you come up with - /// a data source (optionally with stratification column), come up with an instantiable transform - /// and trainer pipeline, and it will handle (1) splitting up the data, (2) training the separate - /// pipelines on in-fold data, (3) scoring on the out-fold data, (4) returning the set of - /// evaluations and optionally trained pipes. (People always want metrics out of xfold, - /// they sometimes want the actual models too.) - /// - [Fact] - void CrossValidation() - { - var ml = new MLContext(seed: 1, conc: 1); - - var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true); - - // Pipeline. - var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") - .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( - new SdcaBinaryTrainer.Options { ConvergenceTolerance = 1f, NumThreads = 1, })); - - var cvResult = ml.BinaryClassification.CrossValidate(data, pipeline); - } - } -} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs index fd61f0eb77..5223cc0e38 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs @@ -94,7 +94,7 @@ public void FastTreeClassificationIntrospectiveTraining() public void FastForestRegressionIntrospectiveTraining() { var ml = new MLContext(seed: 1, conc: 1); - var data = DatasetUtils.GenerateFloatLabelFloatFeatureVectorSamples(1000); + var data = SamplesUtils.DatasetUtils.GenerateFloatLabelFloatFeatureVectorSamples(1000); var dataView = ml.Data.ReadFromEnumerable(data); RegressionPredictionTransformer pred = null; diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs deleted file mode 100644 index 254dd73e45..0000000000 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs +++ /dev/null @@ -1,47 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Data; -using Microsoft.ML.RunTests; -using Microsoft.ML.Trainers; -using Xunit; - -namespace Microsoft.ML.Tests.Scenarios.Api -{ - public partial class ApiScenariosTests - { - /// - /// Reconfigurable predictions: The following should be possible: A user trains a binary classifier, - /// and through the test evaluator gets a PR curve, the based on the PR curve picks a new threshold - /// and configures the scorer (or more precisely instantiates a new scorer over the same predictor) - /// with some threshold derived from that. 
- /// - [Fact] - public void ReconfigurablePrediction() - { - var ml = new MLContext(seed: 1, conc: 1); - var dataReader = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true); - - var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true); - var testData = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.testFilename), hasHeader: true); - - // Pipeline. - var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") - .Fit(data); - - var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( - new SdcaBinaryTrainer.Options { NumThreads = 1 }); - - var trainData = ml.Data.Cache(pipeline.Transform(data)); // Cache the data right before the trainer to boost the training speed. - var model = trainer.Fit(trainData); - - var scoredTest = model.Transform(pipeline.Transform(testData)); - var metrics = ml.BinaryClassification.Evaluate(scoredTest); - - var newModel = new BinaryPredictionTransformer>(ml, model.Model, trainData.Schema, model.FeatureColumn, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability); - var newScoredTest = newModel.Transform(pipeline.Transform(testData)); - var newMetrics = ml.BinaryClassification.Evaluate(scoredTest); - } - } -} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs index 6d70b43d61..7a328c73d6 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs @@ -6,7 +6,7 @@ using Microsoft.ML.Data; using Microsoft.ML.RunTests; using Microsoft.ML.Trainers; -using Microsoft.ML.Trainers.SymSgd; +using Microsoft.ML.Trainers.HalLearners; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index bc43d688fb..ac54587b65 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -182,7 +182,7 @@ public void TrainAveragedPerceptronWithCache() var cached = mlContext.Data.Cache(xf); var estimator = mlContext.BinaryClassification.Trainers.AveragedPerceptron( - new AveragedPerceptronTrainer.Options { NumIterations = 2 }); + new AveragedPerceptronTrainer.Options { NumberOfIterations = 2 }); estimator.Fit(cached).Transform(cached); @@ -312,22 +312,22 @@ public void TestTrainTestSplit() // Let's test what train test properly works with seed. // In order to do that, let's split same dataset, but in one case we will use default seed value, // and in other case we set seed to be specific value. - var (simpleTrain, simpleTest) = mlContext.BinaryClassification.TrainTestSplit(input); - var (simpleTrainWithSeed, simpleTestWithSeed) = mlContext.BinaryClassification.TrainTestSplit(input, seed: 10); + var simpleSplit = mlContext.BinaryClassification.TrainTestSplit(input); + var splitWithSeed = mlContext.BinaryClassification.TrainTestSplit(input, seed: 10); // Since test fraction is 0.1, it's much faster to compare test subsets of split. - var simpleTestWorkClass = getWorkclass(simpleTest); + var simpleTestWorkClass = getWorkclass(simpleSplit.TestSet); - var simpleWithSeedTestWorkClass = getWorkclass(simpleTestWithSeed); + var simpleWithSeedTestWorkClass = getWorkclass(splitWithSeed.TestSet); // Validate we get different test sets. 
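Both tests touched here cache the featurized data immediately before handing it to an iterative trainer: `mlContext.Data.Cache(...)` in `TrainAveragedPerceptronWithCache`, and the deleted `ReconfigurablePrediction` scenario did the same thing ahead of SDCA. A minimal sketch of that pattern, assuming a fitted featurization `pipeline` and the options-configured averaged perceptron used in these tests:

```csharp
// Featurize once, cache the result, then let the iterative trainer make
// multiple passes over the cached data instead of re-running the transforms.
var featurized = pipeline.Transform(data);
var cached = mlContext.Data.Cache(featurized);

var trainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
    new AveragedPerceptronTrainer.Options { NumberOfIterations = 2 });

var model = trainer.Fit(cached);
```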
Assert.NotEqual(simpleTestWorkClass, simpleWithSeedTestWorkClass); // Now let's do same thing but with presence of stratificationColumn. // Rows with same values in this stratificationColumn should end up in same subset (train or test). // So let's break dataset by "Workclass" column. - var (stratTrain, stratTest) = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn: "Workclass"); - var stratTrainWorkclass = getWorkclass(stratTrain); - var stratTestWorkClass = getWorkclass(stratTest); + var stratSplit = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn: "Workclass"); + var stratTrainWorkclass = getWorkclass(stratSplit.TrainSet); + var stratTestWorkClass = getWorkclass(stratSplit.TestSet); // Let's get unique values for "Workclass" column from train subset. var uniqueTrain = stratTrainWorkclass.GroupBy(x => x.ToString()).Select(x => x.First()).ToList(); // and from test subset. @@ -337,9 +337,9 @@ public void TestTrainTestSplit() // Let's do same thing, but this time we will choose different seed. // Stratification column should still break dataset properly without same values in both subsets. - var (stratWithSeedTrain, stratWithSeedTest) = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn:"Workclass", seed: 1000000); - var stratTrainWithSeedWorkclass = getWorkclass(stratWithSeedTrain); - var stratTestWithSeedWorkClass = getWorkclass(stratWithSeedTest); + var stratSeed = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn:"Workclass", seed: 1000000); + var stratTrainWithSeedWorkclass = getWorkclass(stratSeed.TrainSet); + var stratTestWithSeedWorkClass = getWorkclass(stratSeed.TestSet); // Let's get unique values for "Workclass" column from train subset. var uniqueSeedTrain = stratTrainWithSeedWorkclass.GroupBy(x => x.ToString()).Select(x => x.First()).ToList(); // and from test subset. diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index ecea5411a6..9954669bd3 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -3,8 +3,6 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Data; -using Microsoft.ML.Learners; -using Microsoft.ML.Trainers; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.Online; using Xunit; @@ -133,7 +131,7 @@ public void OvaLinearSvm() // Pipeline var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll( - mlContext.BinaryClassification.Trainers.LinearSupportVectorMachines(new LinearSvmTrainer.Options { NumIterations = 100 }), + mlContext.BinaryClassification.Trainers.LinearSupportVectorMachines(new LinearSvmTrainer.Options { NumberOfIterations = 100 }), useProbabilities: false); var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs index c36fb058a5..8d615d9330 100644 --- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
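The stratified-split assertions above rest on one contract: every distinct value of the stratification column lands entirely in the train set or entirely in the test set. A minimal sketch of checking that contract directly, assuming the same `mlContext` and `input` as in the test and the `GetColumn` extension (with the `mlContext` argument) for reading a text column out of an `IDataView`:

```csharp
using System;
using System.Linq;
using Xunit;

var split = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn: "Workclass");

// Distinct stratification keys on each side of the split.
var trainKeys = split.TrainSet.GetColumn<ReadOnlyMemory<char>>(mlContext, "Workclass")
    .Select(v => v.ToString()).Distinct();
var testKeys = split.TestSet.GetColumn<ReadOnlyMemory<char>>(mlContext, "Workclass")
    .Select(v => v.ToString()).Distinct();

// No stratification key may appear on both sides.
Assert.Empty(trainKeys.Intersect(testKeys));
```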
-using System; using System.IO; using Microsoft.ML.Data; using Microsoft.ML.ImageAnalytics; -using Microsoft.ML.Trainers; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Conversions; using Xunit; @@ -15,7 +14,7 @@ namespace Microsoft.ML.Scenarios { public partial class ScenariosTests { - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only + [TensorFlowFact] public void TensorFlowTransforCifarEndToEndTest() { var imageHeight = 32; diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index fd8d3183bb..f006c5a7f0 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -10,6 +10,7 @@ using Microsoft.ML.Data; using Microsoft.ML.ImageAnalytics; using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Conversions; using Microsoft.ML.Transforms.Normalizers; @@ -29,7 +30,7 @@ private class TestData public float[] b; } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only + [TensorFlowFact] public void TensorFlowTransformMatrixMultiplicationTest() { var modelLocation = "model_matmul/frozen_saved_model.pb"; @@ -74,6 +75,117 @@ public void TensorFlowTransformMatrixMultiplicationTest() } } + private class ShapeData + { + // Data will be passed as 1-D vector. + // Intended data shape [5], model shape [None] + [VectorType(5)] + public float[] OneDim; + + // Data will be passed as flat vector. + // Intended data shape [2,2], model shape [2, None] + [VectorType(4)] + public float[] TwoDim; + + // Data will be passed as 3-D vector. + // Intended data shape [1, 2, 2], model shape [1, None, 2] + [VectorType(1, 2, 2)] + public float[] ThreeDim; + + // Data will be passed as flat vector. + // Intended data shape [1, 2, 2, 3], model shape [1, None, None, 3] + [VectorType(12)] + public float[] FourDim; + + // Data will be passed as 4-D vector. 
+            // Intended data shape [2, 2, 2, 2], model shape [2, 2, 2, 2]
+            [VectorType(2, 2, 2, 2)]
+            public float[] FourDimKnown;
+        }
+
+        private List<ShapeData> GetShapeData()
+        {
+            return new List<ShapeData>(new ShapeData[] {
+                new ShapeData() { OneDim = new[] { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f },
+                                  TwoDim = new[] { 1.0f, 2.0f, 3.0f, 4.0f },
+                                  ThreeDim = new[] { 11.0f, 12.0f, 13.0f, 14.0f },
+                                  FourDim = new[]{ 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f,
+                                                   27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f },
+                                  FourDimKnown = new[]{ 41.0f , 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f,
+                                                        49.0f , 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f}
+                },
+                new ShapeData() { OneDim = new[] { 100.1f, 100.2f, 100.3f, 100.4f, 100.5f },
+                                  TwoDim = new[] { 101.0f, 102.0f, 103.0f, 104.0f },
+                                  ThreeDim = new[] { 111.0f, 112.0f, 113.0f, 114.0f },
+                                  FourDim = new[]{ 121.0f, 122.0f, 123.0f, 124.0f, 125.0f, 126.0f,
+                                                   127.0f, 128.0f, 129.0f, 130.0f, 131.0f, 132.0f},
+                                  FourDimKnown = new[]{ 141.0f , 142.0f, 143.0f, 144.0f, 145.0f, 146.0f, 147.0f, 148.0f,
+                                                        149.0f , 150.0f, 151.0f, 152.0f, 153.0f, 154.0f, 155.0f, 156.0f }
+                }
+            });
+        }
+
+        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
+        public void TensorFlowTransformInputShapeTest()
+        {
+            var modelLocation = "model_shape_test";
+            var mlContext = new MLContext(seed: 1, conc: 1);
+            var data = GetShapeData();
+            // Pipeline
+            var loader = mlContext.Data.ReadFromEnumerable(data);
+            var inputs = new string[] { "OneDim", "TwoDim", "ThreeDim", "FourDim", "FourDimKnown" };
+            var outputs = new string[] { "o_OneDim", "o_TwoDim", "o_ThreeDim", "o_FourDim", "o_FourDimKnown" };
+
+            var trans = mlContext.Transforms.ScoreTensorFlowModel(modelLocation, outputs, inputs).Fit(loader).Transform(loader);
+
+            using (var cursor = trans.GetRowCursorForAllColumns())
+            {
+                int outColIndex = 5;
+                var oneDimgetter = cursor.GetGetter<VBuffer<float>>(outColIndex);
+                var twoDimgetter = cursor.GetGetter<VBuffer<float>>(outColIndex + 1);
+                var threeDimgetter = cursor.GetGetter<VBuffer<float>>(outColIndex + 2);
+                var fourDimgetter = cursor.GetGetter<VBuffer<float>>(outColIndex + 3);
+                var fourDimKnowngetter = cursor.GetGetter<VBuffer<float>>(outColIndex + 4);
+
+                VBuffer<float> oneDim = default;
+                VBuffer<float> twoDim = default;
+                VBuffer<float> threeDim = default;
+                VBuffer<float> fourDim = default;
+                VBuffer<float> fourDimKnown = default;
+                foreach (var sample in data)
+                {
+                    Assert.True(cursor.MoveNext());
+
+                    oneDimgetter(ref oneDim);
+                    twoDimgetter(ref twoDim);
+                    threeDimgetter(ref threeDim);
+                    fourDimgetter(ref fourDim);
+                    fourDimKnowngetter(ref fourDimKnown);
+
+                    var oneDimValues = oneDim.GetValues();
+                    Assert.Equal(sample.OneDim.Length, oneDimValues.Length);
+                    Assert.True(oneDimValues.SequenceEqual(sample.OneDim));
+
+                    var twoDimValues = twoDim.GetValues();
+                    Assert.Equal(sample.TwoDim.Length, twoDimValues.Length);
+                    Assert.True(twoDimValues.SequenceEqual(sample.TwoDim));
+
+                    var threeDimValues = threeDim.GetValues();
+                    Assert.Equal(sample.ThreeDim.Length, threeDimValues.Length);
+                    Assert.True(threeDimValues.SequenceEqual(sample.ThreeDim));
+
+                    var fourDimValues = fourDim.GetValues();
+                    Assert.Equal(sample.FourDim.Length, fourDimValues.Length);
+                    Assert.True(fourDimValues.SequenceEqual(sample.FourDim));
+
+                    var fourDimKnownValues = fourDimKnown.GetValues();
+                    Assert.Equal(sample.FourDimKnown.Length, fourDimKnownValues.Length);
+                    Assert.True(fourDimKnownValues.SequenceEqual(sample.FourDimKnown));
+                }
+                Assert.False(cursor.MoveNext());
+            }
+        }
+
         private class TypesData
         {
             [VectorType(2)]
@@ -103,7 +215,7 @@ private class TypesData
         /// <summary>
        /// Test to ensure the supported datatypes can be passed to TensorFlow.
        /// </summary>
-        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
+        [TensorFlowFact]
         public void TensorFlowTransformInputOutputTypesTest()
         {
             // This is an identity model which returns the same output as input.
@@ -142,7 +254,7 @@ public void TensorFlowTransformInputOutputTypesTest()
             var loader = mlContext.Data.ReadFromEnumerable(data);
 
-            var inputs = new string[]{"f64", "f32", "i64", "i32", "i16", "i8", "u64", "u32", "u16", "u8","b"};
+            var inputs = new string[] { "f64", "f32", "i64", "i32", "i16", "i8", "u64", "u32", "u16", "u8", "b" };
             var outputs = new string[] { "o_f64", "o_f32", "o_i64", "o_i32", "o_i16", "o_i8", "o_u64", "o_u32", "o_u16", "o_u8", "o_b" };
 
             var trans = mlContext.Transforms.ScoreTensorFlowModel(model_location, outputs, inputs).Fit(loader).Transform(loader); ;
@@ -160,7 +272,7 @@ public void TensorFlowTransformInputOutputTypesTest()
             var u8getter = cursor.GetGetter<VBuffer<byte>>(20);
             var boolgetter = cursor.GetGetter<VBuffer<bool>>(21);
-            
+
             VBuffer<double> f64 = default;
             VBuffer<float> f32 = default;
             VBuffer<long> i64 = default;
@@ -297,7 +409,7 @@ public void TensorFlowTransformInceptionTest()
             }
         }
 
-        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
+        [TensorFlowFact]
         public void TensorFlowInputsOutputsSchemaTest()
         {
             var mlContext = new MLContext(seed: 1, conc: 1);
@@ -374,7 +486,7 @@ public void TensorFlowInputsOutputsSchemaTest()
             }
         }
 
-        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
+        [TensorFlowFact]
         public void TensorFlowTransformMNISTConvTest()
         {
             var mlContext = new MLContext(seed: 1, conc: 1);
@@ -412,7 +524,7 @@ public void TensorFlowTransformMNISTConvTest()
             Assert.Equal(5, GetMaxIndexForOnePrediction(onePrediction));
         }
 
-        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
+        [TensorFlowFact]
         public void TensorFlowTransformMNISTLRTrainingTest()
         {
             const double expectedMicroAccuracy = 0.72173913043478266;
@@ -449,7 +561,7 @@ public void TensorFlowTransformMNISTLRTrainingTest()
                     ReTrain = true
                 }))
                 .Append(mlContext.Transforms.Concatenate("Features", "Prediction"))
-                .Append(mlContext.Transforms.Conversion.MapValueToKey("KeyLabel","Label", maxNumKeys: 10))
+                .Append(mlContext.Transforms.Conversion.MapValueToKey("KeyLabel", "Label", maxNumKeys: 10))
                 .Append(mlContext.MulticlassClassification.Trainers.LightGbm("KeyLabel", "Features"));
 
             var trainedModel = pipe.Fit(trainData);
@@ -497,7 +609,7 @@ private void CleanUp(string model_location)
             }
         }
 
-        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
+        [TensorFlowFact]
         public void TensorFlowTransformMNISTConvTrainingTest()
         {
             ExecuteTFTransformMNISTConvTrainingTest(false, null, 0.74782608695652175, 0.608843537414966);
@@ -592,7 +704,7 @@ private void ExecuteTFTransformMNISTConvTrainingTest(bool shuffle, int? shuffleS
             }
         }
 
-        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
+        [TensorFlowFact]
         public void TensorFlowTransformMNISTConvSavedModelTest()
         {
             // This test trains a multi-class classifier pipeline where a pre-trained TensorFlow model is used for featurization.
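Note: the [TensorFlowFact] attribute used throughout these hunks comes from the new Microsoft.ML.TestFramework.Attributes namespace and is not defined in this patch. As a hedged sketch (names and shape assumed, not taken from this diff), it presumably folds the old inline 64-bit check into a reusable xunit fact attribute along these lines:

using System;
using Xunit;

namespace Microsoft.ML.TestFramework.Attributes
{
    // Hypothetical sketch only: skip TensorFlow tests when the process is not 64-bit,
    // mirroring the ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess)) checks replaced above.
    public sealed class TensorFlowFactAttribute : FactAttribute
    {
        public TensorFlowFactAttribute()
        {
            if (!Environment.Is64BitProcess)
                Skip = "TensorFlow is 64-bit only.";
        }
    }
}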
@@ -715,7 +827,7 @@ public class MNISTPrediction public float[] PredictedLabels; } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only + [TensorFlowFact] public void TensorFlowTransformCifar() { var modelLocation = "cifar_model/frozen_model.pb"; @@ -761,7 +873,7 @@ public void TensorFlowTransformCifar() } } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only + [TensorFlowFact] public void TensorFlowTransformCifarSavedModel() { var modelLocation = "cifar_saved_model"; @@ -803,7 +915,7 @@ public void TensorFlowTransformCifarSavedModel() } // This test has been created as result of https://github.com/dotnet/machinelearning/issues/2156. - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only + [TensorFlowFact] public void TensorFlowGettingSchemaMultipleTimes() { var modelLocation = "cifar_saved_model"; @@ -816,7 +928,7 @@ public void TensorFlowGettingSchemaMultipleTimes() } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] + [TensorFlowFact] public void TensorFlowTransformCifarInvalidShape() { var modelLocation = "cifar_model/frozen_model.pb"; @@ -861,7 +973,7 @@ public class TensorFlowSentiment public float[] Prediction; } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] + [TensorFlowFact] public void TensorFlowSentimentClassificationTest() { var mlContext = new MLContext(seed: 1, conc: 1); diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs index a7d75f86cb..9a05f95bbf 100644 --- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Model; using Microsoft.ML.RunTests; using Microsoft.ML.StaticPipe; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Tools; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.StaticPipe; @@ -56,7 +57,7 @@ public TensorFlowEstimatorTests(ITestOutputHelper output) : base(output) { } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only + [TensorFlowFact] void TestSimpleCase() { var modelFile = "model_matmul/frozen_saved_model.pb"; @@ -96,7 +97,7 @@ void TestSimpleCase() catch (InvalidOperationException) { } } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only + [TensorFlowFact] void TestOldSavingAndLoading() { var modelFile = "model_matmul/frozen_saved_model.pb"; @@ -132,7 +133,7 @@ void TestOldSavingAndLoading() } } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline + [TensorFlowFact] void TestCommandLine() { // typeof helps to load the TensorFlowTransformer type. 
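Several of the TensorFlow tests above consume scored columns through the same row-cursor pattern. A minimal, self-contained sketch of that pattern follows; the helper name and column index are placeholders, and the usings are assumed to match the test files above rather than taken from this patch:

using Microsoft.Data.DataView;
using Microsoft.ML.Data;

internal static class CursorReadingSketch
{
    // Hypothetical helper, not part of this patch: read one VBuffer<float> output column
    // row by row, the way the tests above read the TensorFlow transform's outputs.
    public static void ReadFloatVectorColumn(IDataView transformed, int outputColumnIndex)
    {
        using (var cursor = transformed.GetRowCursorForAllColumns())
        {
            var getter = cursor.GetGetter<VBuffer<float>>(outputColumnIndex);
            VBuffer<float> buffer = default;
            while (cursor.MoveNext())
            {
                getter(ref buffer);
                var values = buffer.GetValues(); // dense values for the current row
                // ... assertions on 'values' would go here ...
            }
        }
    }
}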
@@ -140,7 +141,7 @@ void TestCommandLine() Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=a:R4:0-3 col=b:R4:0-3} xf=TFTransform{inputs=a inputs=b outputs=c modellocation={model_matmul/frozen_saved_model.pb}}" }), (int)0); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only + [TensorFlowFact] public void TestTensorFlowStatic() { var modelLocation = "cifar_model/frozen_model.pb"; @@ -182,7 +183,7 @@ public void TestTensorFlowStatic() } } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] + [TensorFlowFact] public void TestTensorFlowStaticWithSchema() { const string modelLocation = "cifar_model/frozen_model.pb"; diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs index b3f4207de4..3de1a3125a 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs @@ -6,8 +6,7 @@ using Microsoft.ML.Calibrator; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; -using Microsoft.ML.Learners; -using Microsoft.ML.Trainers.Online; +using Microsoft.ML.Trainers; using Xunit; namespace Microsoft.ML.Tests.TrainerEstimators diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs index 134b52f96e..4e61c3480b 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs @@ -2,12 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Linq; using Microsoft.Data.DataView; using Microsoft.ML.Core.Data; using Microsoft.ML.Data; using Microsoft.ML.Internal.Calibration; -using Microsoft.ML.Learners; using Microsoft.ML.Trainers; using Xunit; diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs index 7eaf3130a4..2749961baf 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs @@ -10,6 +10,7 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Trainers; using Xunit; @@ -17,7 +18,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators { public partial class TrainerEstimators : TestDataPipeBase { - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // This test is being fixed as part of issue #1441. + [MatrixFactorizationFact] public void MatrixFactorization_Estimator() { string labelColumnName = "Label"; @@ -39,18 +40,18 @@ public void MatrixFactorization_Estimator() MatrixRowIndexColumnName = matrixRowIndexColumnName, LabelColumnName = labelColumnName, NumIterations = 3, - NumThreads = 1, - K = 4, + NumThreads = 1, + ApproximationRank = 4, }; - var est = new MatrixFactorizationTrainer(Env, options); + var est = ML.Recommendation().Trainers.MatrixFactorization(options); TestEstimatorCore(est, data, invalidInput: invalidData); Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // This test is being fixed as part of issue #1441. 
+        [MatrixFactorizationFact]
         public void MatrixFactorizationSimpleTrainAndPredict()
         {
             var mlContext = new MLContext(seed: 1, conc: 1);
@@ -68,13 +69,14 @@ public void MatrixFactorizationSimpleTrainAndPredict()
             var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.trainFilename)));
 
             // Create a pipeline with a single operator.
-            var options = new MatrixFactorizationTrainer.Options {
+            var options = new MatrixFactorizationTrainer.Options
+            {
                 MatrixColumnIndexColumnName = userColumnName,
                 MatrixRowIndexColumnName = itemColumnName,
                 LabelColumnName = labelColumnName,
                 NumIterations = 3,
                 NumThreads = 1, // To eliminate randomness, # of threads must be 1.
-                K = 7,
+                ApproximationRank = 7,
             };
 
             var pipeline = mlContext.Recommendation().Trainers.MatrixFactorization(options);
@@ -82,6 +84,20 @@ public void MatrixFactorizationSimpleTrainAndPredict()
             // Train a matrix factorization model.
             var model = pipeline.Fit(data);
 
+            // Let's validate content of the model.
+            Assert.Equal(model.Model.ApproximationRank, options.ApproximationRank);
+            var leftMatrix = model.Model.LeftFactorMatrix;
+            var rightMatrix = model.Model.RightFactorMatrix;
+            Assert.Equal(leftMatrix.Count, model.Model.NumberOfRows * model.Model.ApproximationRank);
+            Assert.Equal(rightMatrix.Count, model.Model.NumberOfColumns * model.Model.ApproximationRank);
+            // MF produces different matrices on different platforms, so at least test their content on Windows.
+            if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+            {
+                Assert.Equal(leftMatrix[0], (double)0.3091519, 5);
+                Assert.Equal(leftMatrix[leftMatrix.Count - 1], (double)0.5639161, 5);
+                Assert.Equal(rightMatrix[0], (double)0.243584976, 5);
+                Assert.Equal(rightMatrix[rightMatrix.Count - 1], (double)0.380032182, 5);
+            }
             // Read the test data set as an IDataView
             var testData = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.testFilename)));
@@ -174,7 +190,7 @@ internal class MatrixElementForScore
             public float Score;
         }
 
-        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // This test is being fixed as part of issue #1441.
+        [MatrixFactorizationFact]
         public void MatrixFactorizationInMemoryData()
         {
             // Create an in-memory matrix as a list of tuples (column index, row index, value).
@@ -197,7 +213,7 @@ public void MatrixFactorizationInMemoryData()
                 LabelColumnName = nameof(MatrixElement.Value),
                 NumIterations = 10,
                 NumThreads = 1, // To eliminate randomness, # of threads must be 1.
-                K = 32,
+                ApproximationRank = 32,
             };
 
             var pipeline = mlContext.Recommendation().Trainers.MatrixFactorization(options);
@@ -263,7 +279,7 @@ internal class MatrixElementZeroBasedForScore
             public float Score;
         }
 
-        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // This test is being fixed as part of issue #1441.
+        [MatrixFactorizationFact]
         public void MatrixFactorizationInMemoryDataZeroBaseIndex()
         {
             // Create an in-memory matrix as a list of tuples (column index, row index, value).
@@ -287,8 +303,8 @@ public void MatrixFactorizationInMemoryDataZeroBaseIndex()
                 LabelColumnName = nameof(MatrixElement.Value),
                 NumIterations = 100,
                 NumThreads = 1, // To eliminate randomness, # of threads must be 1.
- K = 32, - Eta = 0.5, + ApproximationRank = 32, + LearningRate = 0.5, }; var pipeline = mlContext.Recommendation().Trainers.MatrixFactorization(options); @@ -377,7 +393,7 @@ private class OneClassMatrixElementZeroBasedForScore public float Score; } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // This test is being fixed as part of issue #1441. + [MatrixFactorizationFact] public void OneClassMatrixFactorizationInMemoryDataZeroBaseIndex() { // Create an in-memory matrix as a list of tuples (column index, row index, value). For one-class matrix @@ -409,7 +425,7 @@ public void OneClassMatrixFactorizationInMemoryDataZeroBaseIndex() NumIterations = 100, NumThreads = 1, // To eliminate randomness, # of threads must be 1. Lambda = 0.025, // Let's test non-default regularization coefficient. - K = 16, + ApproximationRank = 16, Alpha = 0.01, // Importance coefficient of loss function over matrix elements not specified in the input matrix. C = 0.15, // Desired value for matrix elements not specified in the input matrix. }; @@ -449,7 +465,7 @@ public void OneClassMatrixFactorizationInMemoryDataZeroBaseIndex() CompareNumbersWithTolerance(0.141411, testResults[1].Score, digitsOfPrecision: 5); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // This test is being fixed as part of issue #1441. + [MatrixFactorizationFact] public void MatrixFactorizationBackCompat() { // This test is meant to check backwards compatibility after the change that removed Min and Contiguous from KeyType. diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs index 5b1a698e95..1b4dbf86e2 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs @@ -4,10 +4,9 @@ using System.Linq; using Microsoft.ML.Data; +using Microsoft.ML.Trainers.HalLearners; using Microsoft.ML.Internal.Calibration; -using Microsoft.ML.Learners; using Microsoft.ML.Trainers; -using Microsoft.ML.Trainers.SymSgd; using Xunit; namespace Microsoft.ML.Tests.TrainerEstimators diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs index 2cde2c918e..e6ced83d4d 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; +using System.Linq; using Microsoft.Data.DataView; using Microsoft.ML.Data; -using Microsoft.ML.SamplesUtils; using Microsoft.ML.Trainers.FastTree; -using System; -using System.Linq; using Xunit; namespace Microsoft.ML.Tests.TrainerEstimators @@ -18,7 +17,7 @@ public partial class TrainerEstimators public void TreeEnsembleFeaturizerOutputSchemaTest() { // Create data set - var data = DatasetUtils.GenerateBinaryLabelFloatFeatureVectorSamples(1000).ToList(); + var data = SamplesUtils.DatasetUtils.GenerateBinaryLabelFloatFeatureVectorSamples(1000).ToList(); var dataView = ML.Data.ReadFromEnumerable(data); // Define a tree model whose trees will be extracted to construct a tree featurizer. 
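Taken together, the MatrixFactorizationTrainer option renames in the hunks above (K becomes ApproximationRank, Eta becomes LearningRate) and the switch to catalog-based construction give callers roughly the following shape. This is a hedged recap assembled from the added lines; the column names and the trainingData variable are placeholders for illustration:

var mlContext = new MLContext(seed: 1, conc: 1);

var options = new MatrixFactorizationTrainer.Options
{
    MatrixColumnIndexColumnName = "MatrixColumnIndex", // placeholder column names
    MatrixRowIndexColumnName = "MatrixRowIndex",
    LabelColumnName = "Value",
    NumIterations = 10,
    NumThreads = 1,          // single-threaded to keep results deterministic
    ApproximationRank = 32,  // formerly 'K'
    LearningRate = 0.5       // formerly 'Eta'
};

// Trainers are now created through the MLContext catalog rather than constructed directly.
var pipeline = mlContext.Recommendation().Trainers.MatrixFactorization(options);
var model = pipeline.Fit(trainingData); // trainingData: an IDataView containing the three columns above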
@@ -38,8 +37,8 @@ public void TreeEnsembleFeaturizerOutputSchemaTest() // To get output schema, we need to create RoleMappedSchema for calling Bind(...). var roleMappedSchema = new RoleMappedSchema(dataView.Schema, - label: nameof(DatasetUtils.BinaryLabelFloatFeatureVectorSample.Label), - feature: nameof(DatasetUtils.BinaryLabelFloatFeatureVectorSample.Features)); + label: nameof(SamplesUtils.DatasetUtils.BinaryLabelFloatFeatureVectorSample.Label), + feature: nameof(SamplesUtils.DatasetUtils.BinaryLabelFloatFeatureVectorSample.Features)); // Retrieve output schema. var boundMapper = (treeFeaturizer as ISchemaBindableMapper).Bind(Env, roleMappedSchema); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index ddaf368dd8..9e995ceee0 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -10,6 +10,7 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.LightGBM; using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Transforms.Conversions; using Xunit; @@ -42,7 +43,7 @@ public void FastTreeBinaryEstimator() Done(); } - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void LightGBMBinaryEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); @@ -128,7 +129,7 @@ public void FastTreeRankerEstimator() /// /// LightGbmRankingTrainer TrainerEstimator test /// - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void LightGBMRankerEstimator() { var (pipe, dataView) = GetRankingPipeline(); @@ -161,7 +162,7 @@ public void FastTreeRegressorEstimator() /// /// LightGbmRegressorTrainer TrainerEstimator test /// - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void LightGBMRegressorEstimator() { var dataView = GetRegressionPipeline(); @@ -237,7 +238,7 @@ public void FastForestRegressorEstimator() /// /// LightGbmMulticlass TrainerEstimator test /// - [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only + [LightGBMFact] public void LightGbmMultiClassEstimator() { var (pipeline, dataView) = GetMultiClassPipeline(); @@ -331,7 +332,7 @@ private void LightGbmHelper(bool useSoftmax, out string modelString, out List data = new List(); var dataView = env.Data.ReadFromEnumerable(data); - var args = new SsaChangePointDetector.Arguments() + var args = new SsaChangePointDetector.Options() { Confidence = 95, Source = "Value", @@ -134,7 +134,7 @@ public void ChangePointDetectionWithSeasonality() } } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] + [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")] public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() { const int ChangeHistorySize = 10; @@ -157,7 +157,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() // Pipeline. 
var pipeline = ml.Transforms.Text.FeaturizeText("Text_Featurized", "Text") - .Append(new SsaChangePointEstimator(ml, new SsaChangePointDetector.Arguments() + .Append(new SsaChangePointEstimator(ml, new SsaChangePointDetector.Options() { Confidence = 95, Source = "Value", @@ -210,7 +210,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() Assert.Equal(0.12216401100158691, prediction2.Change[1], precision: 5); // Raw score } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] + [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")] public void ChangePointDetectionWithSeasonalityPredictionEngine() { const int ChangeHistorySize = 10; @@ -233,7 +233,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngine() // Pipeline. var pipeline = ml.Transforms.Text.FeaturizeText("Text_Featurized", "Text") - .Append(new SsaChangePointEstimator(ml, new SsaChangePointDetector.Arguments() + .Append(new SsaChangePointEstimator(ml, new SsaChangePointDetector.Options() { Confidence = 95, Source = "Value", diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs index e65c9da724..fa77255e3c 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs @@ -5,7 +5,7 @@ using System.Collections.Generic; using Microsoft.ML.Data; using Microsoft.ML.RunTests; -using Microsoft.ML.TimeSeriesProcessing; +using Microsoft.ML.Transforms.TimeSeries; using Xunit; using Xunit.Abstractions; diff --git a/tools-local/Microsoft.ML.InternalCodeAnalyzer/BestFriendOnPublicDeclarationsAnalyzer.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/BestFriendOnPublicDeclarationsAnalyzer.cs new file mode 100644 index 0000000000..658e957c16 --- /dev/null +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/BestFriendOnPublicDeclarationsAnalyzer.cs @@ -0,0 +1,70 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.Diagnostics;
+
+namespace Microsoft.ML.InternalCodeAnalyzer
+{
+    [DiagnosticAnalyzer(LanguageNames.CSharp)]
+    public sealed class BestFriendOnPublicDeclarationsAnalyzer : DiagnosticAnalyzer
+    {
+        private const string Category = "Access";
+        internal const string DiagnosticId = "MSML_BestFriendOnPublicDeclaration";
+
+        private const string Title = "Public declarations should not have " + AttributeName + " attribute.";
+        private const string Format = "The " + AttributeName + " should not be applied to publicly visible members.";
+
+        private const string Description =
+            "The " + AttributeName + " attribute is not valid on public identifiers.";
+
+        private static DiagnosticDescriptor Rule =
+            new DiagnosticDescriptor(DiagnosticId, Title, Format, Category,
+                DiagnosticSeverity.Warning, isEnabledByDefault: true, description: Description);
+
+        private const string AttributeName = "Microsoft.ML.BestFriendAttribute";
+
+        public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics =>
+            ImmutableArray.Create(Rule);
+
+        public override void Initialize(AnalysisContext context)
+        {
+            context.EnableConcurrentExecution();
+            context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
+
+            context.RegisterCompilationStartAction(CompilationStart);
+        }
+
+        private void CompilationStart(CompilationStartAnalysisContext context)
+        {
+            var list = new List<string> { AttributeName, "Microsoft.ML.Internal.CpuMath.Core.BestFriendAttribute" };
+
+            foreach (var attributeName in list)
+            {
+                var attribute = context.Compilation.GetTypeByMetadataName(attributeName);
+
+                if (attribute == null)
+                    continue;
+
+                context.RegisterSymbolAction(c => AnalyzeCore(c, attribute), SymbolKind.NamedType, SymbolKind.Method, SymbolKind.Field, SymbolKind.Property);
+            }
+        }
+
+        private void AnalyzeCore(SymbolAnalysisContext context, INamedTypeSymbol attributeType)
+        {
+            if (context.Symbol.DeclaredAccessibility != Accessibility.Public)
+                return;
+
+            var attribute = context.Symbol.GetAttributes().FirstOrDefault(a => a.AttributeClass == attributeType);
+            if (attribute == null)
+                return;
+
+            var diagnostic = Diagnostic.Create(Rule, attribute.ApplicationSyntaxReference.GetSyntax().GetLocation(), context.Symbol.Name);
+            context.ReportDiagnostic(diagnostic);
+        }
+    }
+}
\ No newline at end of file
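For context on what this new analyzer reports, here is a hedged, hypothetical example of code it would flag. The namespace and members are invented for illustration; BestFriendAttribute itself is internal to ML.NET, so a snippet like this only compiles inside an assembly that can see that attribute:

namespace Microsoft.ML.Example // hypothetical code, not part of this patch
{
    public static class MathUtils
    {
        // Flagged with MSML_BestFriendOnPublicDeclaration: the member is already public,
        // so [BestFriend] (which exists to expose internals to friend assemblies) is redundant here.
        [BestFriend]
        public static float Clamp(float value, float min, float max)
            => value < min ? min : (value > max ? max : value);

        // Not flagged: [BestFriend] on an internal member is the intended usage,
        // and the analyzer returns early for any symbol whose declared accessibility is not public.
        [BestFriend]
        internal static float Square(float value) => value * value;
    }
}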