From b186cd2600cfabf83eb1cbf7901a198dc9b38e70 Mon Sep 17 00:00:00 2001 From: Shahab Moradi Date: Fri, 15 Mar 2019 13:57:43 -0700 Subject: [PATCH] In-memory & self-contained sample template. --- .../Dynamic/Trainers/Regression/FastTree.cs | 50 +++++++++++++++++-- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTree.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTree.cs index 933ed4f744..082bc340f3 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTree.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTree.cs @@ -20,18 +20,48 @@ public static void Example() var examples = GenerateRandomDataPoints(1000); // Convert the examples list to an IDataView object, which is consumable by ML.NET API. - var data = mlContext.Data.LoadFromEnumerable(examples); + var trainingData = mlContext.Data.LoadFromEnumerable(examples); // Define the trainer. - var pipeline = mlContext.BinaryClassification.Trainers.FastTree(); + var pipeline = mlContext.Regression.Trainers.FastTree(); // Train the model. - var model = pipeline.Fit(data); + var model = pipeline.Fit(trainingData); + + // Create testing examples. Use different random seed to make it different from training data. + var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. + var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}"); + + // Expected output: + // Label: 0.985, Prediction: 0.938 + // Label: 0.155, Prediction: 0.131 + // Label: 0.515, Prediction: 0.517 + // Label: 0.566, Prediction: 0.519 + // Label: 0.096, Prediction: 0.089 + + // Evaluate the overall metrics + var metrics = mlContext.Regression.Evaluate(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Mean Absolute Error: 0.05 + // Mean Squared Error: 0.00 + // Root Mean Squared Error: 0.06 + // RSquared: 0.95 } - private static IEnumerable GenerateRandomDataPoints(int count) + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) { - var random = new Random(0); + var random = new Random(seed); float randomFloat() => (float)random.NextDouble(); for (int i = 0; i < count; i++) { @@ -45,11 +75,21 @@ private static IEnumerable GenerateRandomDataPoints(int count) } } + // Example with label and 50 feature values. A data set is a collection of such examples. private class DataPoint { public float Label { get; set; } [VectorType(50)] public float[] Features { get; set; } } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public float Label { get; set; } + // Predicted score from the trainer. + public float Score { get; set; } + } } } \ No newline at end of file