diff --git a/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs b/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs index a16d5617ff..05cba5e8a7 100644 --- a/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs +++ b/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs @@ -4,11 +4,36 @@ using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Data; namespace Microsoft.ML.AutoML { internal class BestResultUtil { + public static RunDetail GetBestRun(IEnumerable> results, + BinaryClassificationMetric metric) + { + var metricsAgent = new BinaryMetricsAgent(null, metric); + var metricInfo = new OptimizingMetricInfo(metric); + return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing); + } + + public static RunDetail GetBestRun(IEnumerable> results, + RegressionMetric metric) + { + var metricsAgent = new RegressionMetricsAgent(null, metric); + var metricInfo = new OptimizingMetricInfo(metric); + return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing); + } + + public static RunDetail GetBestRun(IEnumerable> results, + MulticlassClassificationMetric metric) + { + var metricsAgent = new MultiMetricsAgent(null, metric); + var metricInfo = new OptimizingMetricInfo(metric); + return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing); + } + public static RunDetail GetBestRun(IEnumerable> results, IMetricsAgent metricsAgent, bool isMetricMaximizing) { diff --git a/src/mlnet/AutoML/AutoMLEngine.cs b/src/mlnet/AutoML/AutoMLEngine.cs index 7acfe5f84a..7ad28b3cd4 100644 --- a/src/mlnet/AutoML/AutoMLEngine.cs +++ b/src/mlnet/AutoML/AutoMLEngine.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Collections.Generic; +using System; using Microsoft.ML.AutoML; using Microsoft.ML.CLI.Data; using Microsoft.ML.CLI.ShellProgressBar; @@ -44,47 +44,44 @@ public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation return columnInference; } - ExperimentResult IAutoMLEngine.ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar) + void IAutoMLEngine.ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressHandlers.BinaryClassificationHandler handler, ProgressBar progressBar) { - var progressReporter = new ProgressHandlers.BinaryClassificationHandler(optimizationMetric, progressBar); - var result = context.Auto() + ExperimentResult result = context.Auto() .CreateBinaryClassificationExperiment(new BinaryExperimentSettings() { MaxExperimentTimeInSeconds = settings.MaxExplorationTime, CacheBeforeTrainer = this.cacheBeforeTrainer, OptimizingMetric = optimizationMetric }) - .Execute(trainData, validationData, columnInformation, progressHandler: progressReporter); + .Execute(trainData, validationData, columnInformation, progressHandler: handler); + logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline); - return result; } - ExperimentResult IAutoMLEngine.ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar) + void IAutoMLEngine.ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressHandlers.RegressionHandler handler, ProgressBar progressBar) { - var progressReporter = new ProgressHandlers.RegressionHandler(optimizationMetric, progressBar); - var result = context.Auto() + ExperimentResult result = context.Auto() .CreateRegressionExperiment(new RegressionExperimentSettings() { MaxExperimentTimeInSeconds = settings.MaxExplorationTime, OptimizingMetric = optimizationMetric, CacheBeforeTrainer = this.cacheBeforeTrainer - }).Execute(trainData, validationData, columnInformation, progressHandler: progressReporter); + }).Execute(trainData, validationData, columnInformation, progressHandler: handler); + logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline); - return result; } - ExperimentResult IAutoMLEngine.ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar) + void IAutoMLEngine.ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressHandlers.MulticlassClassificationHandler handler, ProgressBar progressBar) { - var progressReporter = new ProgressHandlers.MulticlassClassificationHandler(optimizationMetric, progressBar); - var result = context.Auto() + ExperimentResult result = context.Auto() .CreateMulticlassClassificationExperiment(new MulticlassExperimentSettings() { MaxExperimentTimeInSeconds = settings.MaxExplorationTime, CacheBeforeTrainer = this.cacheBeforeTrainer, OptimizingMetric = optimizationMetric - }).Execute(trainData, validationData, columnInformation, progressHandler: progressReporter); + }).Execute(trainData, validationData, columnInformation, progressHandler: handler); + logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline); - return result; } } diff --git a/src/mlnet/AutoML/IAutoMLEngine.cs b/src/mlnet/AutoML/IAutoMLEngine.cs index 882ff654fe..4af8fcf2df 100644 --- a/src/mlnet/AutoML/IAutoMLEngine.cs +++ b/src/mlnet/AutoML/IAutoMLEngine.cs @@ -2,9 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using Microsoft.ML.AutoML; using Microsoft.ML.CLI.ShellProgressBar; +using Microsoft.ML.CLI.Utilities; using Microsoft.ML.Data; namespace Microsoft.ML.CLI.CodeGenerator @@ -13,11 +15,11 @@ internal interface IAutoMLEngine { ColumnInferenceResults InferColumns(MLContext context, ColumnInformation columnInformation); - ExperimentResult ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar = null); + void ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressHandlers.BinaryClassificationHandler handler, ProgressBar progressBar = null); - ExperimentResult ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar = null); + void ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressHandlers.MulticlassClassificationHandler handler, ProgressBar progressBar = null); - ExperimentResult ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar = null); + void ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressHandlers.RegressionHandler handler, ProgressBar progressBar = null); } } diff --git a/src/mlnet/CodeGenerator/CodeGenerationHelper.cs b/src/mlnet/CodeGenerator/CodeGenerationHelper.cs index f2153fd785..a58088ad78 100644 --- a/src/mlnet/CodeGenerator/CodeGenerationHelper.cs +++ b/src/mlnet/CodeGenerator/CodeGenerationHelper.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Linq; @@ -36,7 +37,7 @@ public void GenerateCode() { Stopwatch watch = Stopwatch.StartNew(); var context = new MLContext(); - ConsumeAutoMLSDKLogs(context); + context.Log += ConsumeAutoMLSDKLog; var verboseLevel = Utils.GetVerbosity(settings.Verbosity); @@ -73,9 +74,14 @@ public void GenerateCode() // The reason why we are doing this way of defining 3 different results is because of the AutoML API // i.e there is no common class/interface to handle all three tasks together. - ExperimentResult binaryExperimentResult = default; - ExperimentResult multiclassExperimentResult = default; - ExperimentResult regressionExperimentResult = default; + List> completedBinaryRuns = new List>(); + List> completedMulticlassRuns = new List>(); + List> completedRegressionRuns = new List>(); + + ProgressHandlers.BinaryClassificationHandler binaryHandler = default; + ProgressHandlers.RegressionHandler regressionHandler = default; + ProgressHandlers.MulticlassClassificationHandler multiClassHandler = default; + if (verboseLevel > LogLevel.Trace) { Console.Write($"{Strings.ExplorePipeline}: "); @@ -90,6 +96,10 @@ public void GenerateCode() logger.Log(LogLevel.Trace, $"{Strings.ExplorePipeline}: {settings.MlTask}"); logger.Log(LogLevel.Trace, $"{Strings.FurtherLearning}: {Strings.LearningHttpLink}"); + + // TODO the below region needs more refactoring to be done especially with so many switch cases. + + #region RunAutoMLEngine try { var options = new ProgressBarOptions @@ -113,13 +123,16 @@ public void GenerateCode() { // TODO: It may be a good idea to convert the below Threads to Tasks or get rid of this progress bar all together and use an existing one in opensource. case TaskKind.BinaryClassification: - t = new Thread(() => SafeExecute(() => automlEngine.ExploreBinaryClassificationModels(context, trainData, validationData, columnInformation, new BinaryExperimentSettings().OptimizingMetric, pbar), out ex, out binaryExperimentResult, pbar)); + binaryHandler = new ProgressHandlers.BinaryClassificationHandler(new BinaryExperimentSettings().OptimizingMetric, completedBinaryRuns, pbar); + t = new Thread(() => SafeExecute(() => automlEngine.ExploreBinaryClassificationModels(context, trainData, validationData, columnInformation, new BinaryExperimentSettings().OptimizingMetric, binaryHandler, pbar), out ex, pbar)); break; case TaskKind.Regression: - t = new Thread(() => SafeExecute(() => automlEngine.ExploreRegressionModels(context, trainData, validationData, columnInformation, new RegressionExperimentSettings().OptimizingMetric, pbar), out ex, out regressionExperimentResult, pbar)); + regressionHandler = new ProgressHandlers.RegressionHandler(new RegressionExperimentSettings().OptimizingMetric, completedRegressionRuns, pbar); + t = new Thread(() => SafeExecute(() => automlEngine.ExploreRegressionModels(context, trainData, validationData, columnInformation, new RegressionExperimentSettings().OptimizingMetric, regressionHandler, pbar), out ex, pbar)); break; case TaskKind.MulticlassClassification: - t = new Thread(() => SafeExecute(() => automlEngine.ExploreMultiClassificationModels(context, trainData, validationData, columnInformation, new MulticlassExperimentSettings().OptimizingMetric, pbar), out ex, out multiclassExperimentResult, pbar)); + multiClassHandler = new ProgressHandlers.MulticlassClassificationHandler(new MulticlassExperimentSettings().OptimizingMetric, completedMulticlassRuns, pbar); + t = new Thread(() => SafeExecute(() => automlEngine.ExploreMultiClassificationModels(context, trainData, validationData, columnInformation, new MulticlassExperimentSettings().OptimizingMetric, multiClassHandler, pbar), out ex, pbar)); break; default: logger.Log(LogLevel.Error, Strings.UnsupportedMlTask); @@ -127,21 +140,24 @@ public void GenerateCode() } t.Start(); - if (!pbar.CompletedHandle.WaitOne(wait)) - pbar.Message = $"{nameof(FixedDurationBar)} did not signal {nameof(FixedDurationBar.CompletedHandle)} after {wait}"; - - if (t.IsAlive == true) + pbar.CompletedHandle.WaitOne(wait); + context.Log -= ConsumeAutoMLSDKLog; + switch (taskKind) { - string waitingMessage = Strings.WaitingForLastIteration; - string originalMessage = pbar.Message; - pbar.Message = waitingMessage; - t.Join(); - if (waitingMessage.Equals(pbar.Message)) - { - // Corner cases where thread was alive but has completed all iterations. - pbar.Message = originalMessage; - } + case TaskKind.BinaryClassification: + binaryHandler.Stop(); + break; + case TaskKind.Regression: + regressionHandler.Stop(); + break; + case TaskKind.MulticlassClassification: + multiClassHandler.Stop(); + break; + default: + logger.Log(LogLevel.Error, Strings.UnsupportedMlTask); + break; } + if (ex != null) { throw ex; @@ -150,30 +166,61 @@ public void GenerateCode() } else { + Exception ex = null; + Thread t = default; switch (taskKind) { + // TODO: It may be a good idea to convert the below Threads to Tasks or get rid of this progress bar all together and use an existing one in opensource. case TaskKind.BinaryClassification: - binaryExperimentResult = automlEngine.ExploreBinaryClassificationModels(context, trainData, validationData, columnInformation, new BinaryExperimentSettings().OptimizingMetric); + binaryHandler = new ProgressHandlers.BinaryClassificationHandler(new BinaryExperimentSettings().OptimizingMetric, completedBinaryRuns, null); + t = new Thread(() => SafeExecute(() => automlEngine.ExploreBinaryClassificationModels(context, trainData, validationData, columnInformation, new BinaryExperimentSettings().OptimizingMetric, binaryHandler, null), out ex, null)); break; case TaskKind.Regression: - regressionExperimentResult = automlEngine.ExploreRegressionModels(context, trainData, validationData, columnInformation, new RegressionExperimentSettings().OptimizingMetric); + regressionHandler = new ProgressHandlers.RegressionHandler(new RegressionExperimentSettings().OptimizingMetric, completedRegressionRuns, null); + t = new Thread(() => SafeExecute(() => automlEngine.ExploreRegressionModels(context, trainData, validationData, columnInformation, new RegressionExperimentSettings().OptimizingMetric, regressionHandler, null), out ex, null)); break; case TaskKind.MulticlassClassification: - multiclassExperimentResult = automlEngine.ExploreMultiClassificationModels(context, trainData, validationData, columnInformation, new MulticlassExperimentSettings().OptimizingMetric); + multiClassHandler = new ProgressHandlers.MulticlassClassificationHandler(new MulticlassExperimentSettings().OptimizingMetric, completedMulticlassRuns, null); + t = new Thread(() => SafeExecute(() => automlEngine.ExploreMultiClassificationModels(context, trainData, validationData, columnInformation, new MulticlassExperimentSettings().OptimizingMetric, multiClassHandler, null), out ex, null)); + break; + default: + logger.Log(LogLevel.Error, Strings.UnsupportedMlTask); + break; + } + t.Start(); + Thread.Sleep(wait); + context.Log -= ConsumeAutoMLSDKLog; + switch (taskKind) + { + case TaskKind.BinaryClassification: + binaryHandler.Stop(); + break; + case TaskKind.Regression: + regressionHandler.Stop(); + break; + case TaskKind.MulticlassClassification: + multiClassHandler.Stop(); break; default: logger.Log(LogLevel.Error, Strings.UnsupportedMlTask); break; } - } - + if (ex != null) + { + throw ex; + } + } } catch (Exception) { logger.Log(LogLevel.Error, $"{Strings.ExplorePipelineException}:"); throw; } + finally + { + context.Log -= ConsumeAutoMLSDKLog; + } var elapsedTime = watch.Elapsed.TotalSeconds; @@ -185,25 +232,55 @@ public void GenerateCode() switch (taskKind) { case TaskKind.BinaryClassification: - var bestBinaryIteration = binaryExperimentResult.BestRun; - bestPipeline = bestBinaryIteration.Pipeline; - bestModel = bestBinaryIteration.Model; - ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumnName, elapsedTime.ToString("F2"), binaryExperimentResult.RunDetails.Count()); - ConsolePrinter.PrintIterationSummary(binaryExperimentResult.RunDetails, new BinaryExperimentSettings().OptimizingMetric, 5); + if (completedBinaryRuns.Count > 0) + { + var binaryMetric = new BinaryExperimentSettings().OptimizingMetric; + var bestBinaryIteration = BestResultUtil.GetBestRun(completedBinaryRuns, binaryMetric); + bestPipeline = bestBinaryIteration.Pipeline; + bestModel = bestBinaryIteration.Model; + ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumnName, elapsedTime.ToString("F2"), completedBinaryRuns.Count()); + ConsolePrinter.PrintIterationSummary(completedBinaryRuns, binaryMetric, 5); + } + else + { + logger.Log(LogLevel.Error, string.Format(Strings.CouldNotFinshOnTime, settings.MaxExplorationTime)); + logger.Log(LogLevel.Info, Strings.Exiting); + return; + } break; case TaskKind.Regression: - var bestRegressionIteration = regressionExperimentResult.BestRun; - bestPipeline = bestRegressionIteration.Pipeline; - bestModel = bestRegressionIteration.Model; - ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumnName, elapsedTime.ToString("F2"), regressionExperimentResult.RunDetails.Count()); - ConsolePrinter.PrintIterationSummary(regressionExperimentResult.RunDetails, new RegressionExperimentSettings().OptimizingMetric, 5); + if (completedRegressionRuns.Count > 0) + { + var regressionMetric = new RegressionExperimentSettings().OptimizingMetric; + var bestRegressionIteration = BestResultUtil.GetBestRun(completedRegressionRuns, regressionMetric); + bestPipeline = bestRegressionIteration.Pipeline; + bestModel = bestRegressionIteration.Model; + ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumnName, elapsedTime.ToString("F2"), completedRegressionRuns.Count()); + ConsolePrinter.PrintIterationSummary(completedRegressionRuns, regressionMetric, 5); + } + else + { + logger.Log(LogLevel.Error, string.Format(Strings.CouldNotFinshOnTime, settings.MaxExplorationTime)); + logger.Log(LogLevel.Info, Strings.Exiting); + return; + } break; case TaskKind.MulticlassClassification: - var bestMulticlassIteration = multiclassExperimentResult.BestRun; - bestPipeline = bestMulticlassIteration.Pipeline; - bestModel = bestMulticlassIteration.Model; - ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumnName, elapsedTime.ToString("F2"), multiclassExperimentResult.RunDetails.Count()); - ConsolePrinter.PrintIterationSummary(multiclassExperimentResult.RunDetails, new MulticlassExperimentSettings().OptimizingMetric, 5); + if (completedMulticlassRuns.Count > 0) + { + var muliclassMetric = new MulticlassExperimentSettings().OptimizingMetric; + var bestMulticlassIteration = BestResultUtil.GetBestRun(completedMulticlassRuns, muliclassMetric); + bestPipeline = bestMulticlassIteration.Pipeline; + bestModel = bestMulticlassIteration.Model; + ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumnName, elapsedTime.ToString("F2"), completedMulticlassRuns.Count()); + ConsolePrinter.PrintIterationSummary(completedMulticlassRuns, muliclassMetric, 5); + } + else + { + logger.Log(LogLevel.Error, string.Format(Strings.CouldNotFinshOnTime, settings.MaxExplorationTime)); + logger.Log(LogLevel.Info, Strings.Exiting); + return; + } break; } } @@ -212,6 +289,7 @@ public void GenerateCode() logger.Log(LogLevel.Info, Strings.ErrorBestPipeline); throw; } + #endregion // Save the model var modelprojectDir = Path.Combine(settings.OutputPath.FullName, $"{settings.Name}.Model"); @@ -284,60 +362,26 @@ internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline p return (trainData, validationData); } - private void ConsumeAutoMLSDKLogs(MLContext context) + private static void ConsumeAutoMLSDKLog(object sender, LoggingEventArgs args) { - context.Log += (object sender, LoggingEventArgs loggingEventArgs) => - { - var logMessage = loggingEventArgs.Message; - if (logMessage.Contains(AutoMLLogger.ChannelName)) - { - logger.Trace(loggingEventArgs.Message); - } - }; - } - - private void SafeExecute(Func> p, out Exception ex, out ExperimentResult binaryExperimentResult, FixedDurationBar pbar) - { - try + var logMessage = args.Message; + if (logMessage.Contains(AutoMLLogger.ChannelName)) { - binaryExperimentResult = p.Invoke(); - ex = null; - } - catch (Exception e) - { - ex = e; - binaryExperimentResult = null; - return; - } - } - - private void SafeExecute(Func> p, out Exception ex, out ExperimentResult regressionExperimentResult, FixedDurationBar pbar) - { - try - { - regressionExperimentResult = p.Invoke(); - ex = null; - } - catch (Exception e) - { - ex = e; - regressionExperimentResult = null; - return; + logger.Trace(args.Message); } } - private void SafeExecute(Func> p, out Exception ex, out ExperimentResult multiClassExperimentResult, FixedDurationBar pbar) + private void SafeExecute(Action p, out Exception ex, FixedDurationBar pbar) { try { - multiClassExperimentResult = p.Invoke(); + p.Invoke(); ex = null; } catch (Exception e) { ex = e; - multiClassExperimentResult = null; - pbar.Dispose(); // or ((ManualResetEvent)pbar.CompletedHandle).Set(); + pbar?.Dispose(); return; } } diff --git a/src/mlnet/Program.cs b/src/mlnet/Program.cs index eb049ba827..568eaa4d4a 100644 --- a/src/mlnet/Program.cs +++ b/src/mlnet/Program.cs @@ -23,7 +23,7 @@ class Program public static void Main(string[] args) { var telemetry = new MlTelemetry(); - + int exitCode = 1; // Create handler outside so that commandline and the handler is decoupled and testable. var handler = CommandHandler.Create( (options) => @@ -63,6 +63,7 @@ public static void Main(string[] args) // Execute the command command.Execute(); + exitCode = 0; } catch (Exception e) { @@ -70,7 +71,6 @@ public static void Main(string[] args) logger.Log(LogLevel.Debug, e.ToString()); logger.Log(LogLevel.Info, Strings.LookIntoLogFile); logger.Log(LogLevel.Error, Strings.Exiting); - return; } }); @@ -101,6 +101,7 @@ public static void Main(string[] args) } parser.InvokeAsync(parseResult).Wait(); + Environment.Exit(exitCode); } } } diff --git a/src/mlnet/Strings.resx b/src/mlnet/Strings.resx index 81b4ad8e13..f220f9ca2c 100644 --- a/src/mlnet/Strings.resx +++ b/src/mlnet/Strings.resx @@ -198,4 +198,7 @@ Exception occured while saving the model + + {0} seconds was not enough to train at least one model for your dataset. Try with a longer time. Learn about recommended training time at https://aka.ms/cli-trainingtime + \ No newline at end of file diff --git a/src/mlnet/Utilities/ProgressHandlers.cs b/src/mlnet/Utilities/ProgressHandlers.cs index 39123432ba..1fc5996816 100644 --- a/src/mlnet/Utilities/ProgressHandlers.cs +++ b/src/mlnet/Utilities/ProgressHandlers.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using Microsoft.ML.AutoML; using Microsoft.ML.CLI.ShellProgressBar; using Microsoft.ML.Data; @@ -23,13 +24,16 @@ internal class RegressionHandler : IProgress> private readonly Func, double> GetScore; private RunDetail bestResult; private int iterationIndex; + private List> completedIterations; private ProgressBar progressBar; private string optimizationMetric = string.Empty; + private bool isStopped; - public RegressionHandler(RegressionMetric optimizationMetric, ShellProgressBar.ProgressBar progressBar) + public RegressionHandler(RegressionMetric optimizationMetric, List> completedIterations, ShellProgressBar.ProgressBar progressBar) { this.isMaximizing = new OptimizingMetricInfo(optimizationMetric).IsMaximizing; this.optimizationMetric = optimizationMetric.ToString(); + this.completedIterations = completedIterations; this.progressBar = progressBar; GetScore = (RunDetail result) => new RegressionMetricsAgent(null, optimizationMetric).GetScore(result?.ValidationMetrics); ConsolePrinter.PrintRegressionMetricsHeader(LogLevel.Trace); @@ -37,14 +41,28 @@ public RegressionHandler(RegressionMetric optimizationMetric, ShellProgressBar.P public void Report(RunDetail iterationResult) { - iterationIndex++; - UpdateBestResult(iterationResult); - if (progressBar != null) - progressBar.Message = $"Best quality({this.optimizationMetric}): {GetScore(bestResult):F4}, Best Algorithm: {bestResult?.TrainerName}, Last Algorithm: {iterationResult?.TrainerName}"; - ConsolePrinter.PrintMetrics(iterationIndex, iterationResult?.TrainerName, iterationResult?.ValidationMetrics, GetScore(bestResult), iterationResult?.RuntimeInSeconds, LogLevel.Trace); - if (iterationResult.Exception != null) + lock (this) { - ConsolePrinter.PrintException(iterationResult.Exception, LogLevel.Trace); + if (this.isStopped) + return; + + iterationIndex++; + completedIterations.Add(iterationResult); + UpdateBestResult(iterationResult); + if (progressBar != null) + progressBar.Message = $"Best quality({this.optimizationMetric}): {GetScore(bestResult):F4}, Best Algorithm: {bestResult?.TrainerName}, Last Algorithm: {iterationResult?.TrainerName}"; + ConsolePrinter.PrintMetrics(iterationIndex, iterationResult?.TrainerName, iterationResult?.ValidationMetrics, GetScore(bestResult), iterationResult?.RuntimeInSeconds, LogLevel.Trace); + if (iterationResult.Exception != null) + { + ConsolePrinter.PrintException(iterationResult.Exception, LogLevel.Trace); + } + } + } + public void Stop() + { + lock (this) + { + this.isStopped = true; } } @@ -65,11 +83,14 @@ internal class BinaryClassificationHandler : IProgress> completedIterations; + private bool isStopped; - public BinaryClassificationHandler(BinaryClassificationMetric optimizationMetric, ProgressBar progressBar) + public BinaryClassificationHandler(BinaryClassificationMetric optimizationMetric, List> completedIterations, ProgressBar progressBar) { this.isMaximizing = new OptimizingMetricInfo(optimizationMetric).IsMaximizing; this.optimizationMetric = optimizationMetric; + this.completedIterations = completedIterations; this.progressBar = progressBar; GetScore = (RunDetail result) => new BinaryMetricsAgent(null, optimizationMetric).GetScore(result?.ValidationMetrics); ConsolePrinter.PrintBinaryClassificationMetricsHeader(LogLevel.Trace); @@ -77,14 +98,20 @@ public BinaryClassificationHandler(BinaryClassificationMetric optimizationMetric public void Report(RunDetail iterationResult) { - iterationIndex++; - UpdateBestResult(iterationResult); - if (progressBar != null) - progressBar.Message = GetProgressBarMessage(iterationResult); - ConsolePrinter.PrintMetrics(iterationIndex, iterationResult?.TrainerName, iterationResult?.ValidationMetrics, GetScore(bestResult), iterationResult?.RuntimeInSeconds, LogLevel.Trace); - if (iterationResult.Exception != null) + lock (this) { - ConsolePrinter.PrintException(iterationResult.Exception, LogLevel.Trace); + if (this.isStopped) + return; + iterationIndex++; + completedIterations.Add(iterationResult); + UpdateBestResult(iterationResult); + if (progressBar != null) + progressBar.Message = GetProgressBarMessage(iterationResult); + ConsolePrinter.PrintMetrics(iterationIndex, iterationResult?.TrainerName, iterationResult?.ValidationMetrics, GetScore(bestResult), iterationResult?.RuntimeInSeconds, LogLevel.Trace); + if (iterationResult.Exception != null) + { + ConsolePrinter.PrintException(iterationResult.Exception, LogLevel.Trace); + } } } @@ -98,6 +125,14 @@ private string GetProgressBarMessage(RunDetail iter return $"Best {this.optimizationMetric}: {GetScore(bestResult):F4}, Best Algorithm: {bestResult?.TrainerName}, Last Algorithm: {iterationResult?.TrainerName}"; } + public void Stop() + { + lock (this) + { + this.isStopped = true; + } + } + private void UpdateBestResult(RunDetail iterationResult) { if (MetricComparator(GetScore(iterationResult), GetScore(bestResult), isMaximizing) > 0) @@ -115,11 +150,14 @@ internal class MulticlassClassificationHandler : IProgress> completedIterations; + private bool isStopped; - public MulticlassClassificationHandler(MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar) + public MulticlassClassificationHandler(MulticlassClassificationMetric optimizationMetric, List> completedIterations, ProgressBar progressBar) { this.isMaximizing = new OptimizingMetricInfo(optimizationMetric).IsMaximizing; this.optimizationMetric = optimizationMetric; + this.completedIterations = completedIterations; this.progressBar = progressBar; GetScore = (RunDetail result) => new MultiMetricsAgent(null, optimizationMetric).GetScore(result?.ValidationMetrics); ConsolePrinter.PrintMulticlassClassificationMetricsHeader(LogLevel.Trace); @@ -127,14 +165,31 @@ public MulticlassClassificationHandler(MulticlassClassificationMetric optimizati public void Report(RunDetail iterationResult) { - iterationIndex++; - UpdateBestResult(iterationResult); - if (progressBar != null) - progressBar.Message = GetProgressBarMessage(iterationResult); - ConsolePrinter.PrintMetrics(iterationIndex, iterationResult?.TrainerName, iterationResult?.ValidationMetrics, GetScore(bestResult), iterationResult?.RuntimeInSeconds, LogLevel.Trace); - if (iterationResult.Exception != null) + lock (this) + { + if (this.isStopped) + { + return; + } + + iterationIndex++; + completedIterations.Add(iterationResult); + UpdateBestResult(iterationResult); + if (progressBar != null) + progressBar.Message = GetProgressBarMessage(iterationResult); + ConsolePrinter.PrintMetrics(iterationIndex, iterationResult?.TrainerName, iterationResult?.ValidationMetrics, GetScore(bestResult), iterationResult?.RuntimeInSeconds, LogLevel.Trace); + if (iterationResult.Exception != null) + { + ConsolePrinter.PrintException(iterationResult.Exception, LogLevel.Trace); + } + } + } + + public void Stop() + { + lock (this) { - ConsolePrinter.PrintException(iterationResult.Exception, LogLevel.Trace); + this.isStopped = true; } } diff --git a/src/mlnet/strings.Designer.cs b/src/mlnet/strings.Designer.cs index 7ae8e6bebc..e9791c1318 100644 --- a/src/mlnet/strings.Designer.cs +++ b/src/mlnet/strings.Designer.cs @@ -69,6 +69,15 @@ internal static string BestPipeline { } } + /// + /// Looks up a localized string similar to {0} seconds was not enough to train at least one model for your dataset. Try with a longer time. Learn about recommended training time at https://aka.ms/cli-trainingtime. + /// + internal static string CouldNotFinshOnTime { + get { + return ResourceManager.GetString("CouldNotFinshOnTime", resourceCulture); + } + } + /// /// Looks up a localized string similar to Creating Data loader .... ///