Skip to content

[AutoML] Early stopping in CLI based on the exploration time #3641

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
25 changes: 25 additions & 0 deletions src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,36 @@

using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;

namespace Microsoft.ML.AutoML
{
internal class BestResultUtil
{
public static RunDetail<BinaryClassificationMetrics> GetBestRun(IEnumerable<RunDetail<BinaryClassificationMetrics>> results,
BinaryClassificationMetric metric)
{
var metricsAgent = new BinaryMetricsAgent(null, metric);
var metricInfo = new OptimizingMetricInfo(metric);
return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
}

public static RunDetail<RegressionMetrics> GetBestRun(IEnumerable<RunDetail<RegressionMetrics>> results,
RegressionMetric metric)
{
var metricsAgent = new RegressionMetricsAgent(null, metric);
var metricInfo = new OptimizingMetricInfo(metric);
return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
}

public static RunDetail<MulticlassClassificationMetrics> GetBestRun(IEnumerable<RunDetail<MulticlassClassificationMetrics>> results,
MulticlassClassificationMetric metric)
{
var metricsAgent = new MultiMetricsAgent(null, metric);
var metricInfo = new OptimizingMetricInfo(metric);
return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
}

public static RunDetail<TMetrics> GetBestRun<TMetrics>(IEnumerable<RunDetail<TMetrics>> results,
IMetricsAgent<TMetrics> metricsAgent, bool isMetricMaximizing)
{
Expand Down
29 changes: 13 additions & 16 deletions src/mlnet/AutoML/AutoMLEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System;
using Microsoft.ML.AutoML;
using Microsoft.ML.CLI.Data;
using Microsoft.ML.CLI.ShellProgressBar;
Expand Down Expand Up @@ -44,47 +44,44 @@ public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation
return columnInference;
}

ExperimentResult<BinaryClassificationMetrics> IAutoMLEngine.ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar)
void IAutoMLEngine.ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressHandlers.BinaryClassificationHandler handler, ProgressBar progressBar)
{
var progressReporter = new ProgressHandlers.BinaryClassificationHandler(optimizationMetric, progressBar);
var result = context.Auto()
ExperimentResult<BinaryClassificationMetrics> result = context.Auto()
.CreateBinaryClassificationExperiment(new BinaryExperimentSettings()
{
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
CacheBeforeTrainer = this.cacheBeforeTrainer,
OptimizingMetric = optimizationMetric
})
.Execute(trainData, validationData, columnInformation, progressHandler: progressReporter);
.Execute(trainData, validationData, columnInformation, progressHandler: handler);

logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
return result;
}

ExperimentResult<RegressionMetrics> IAutoMLEngine.ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar)
void IAutoMLEngine.ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressHandlers.RegressionHandler handler, ProgressBar progressBar)
{
var progressReporter = new ProgressHandlers.RegressionHandler(optimizationMetric, progressBar);
var result = context.Auto()
ExperimentResult<RegressionMetrics> result = context.Auto()
.CreateRegressionExperiment(new RegressionExperimentSettings()
{
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
OptimizingMetric = optimizationMetric,
CacheBeforeTrainer = this.cacheBeforeTrainer
}).Execute(trainData, validationData, columnInformation, progressHandler: progressReporter);
}).Execute(trainData, validationData, columnInformation, progressHandler: handler);

logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
return result;
}

ExperimentResult<MulticlassClassificationMetrics> IAutoMLEngine.ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar)
void IAutoMLEngine.ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressHandlers.MulticlassClassificationHandler handler, ProgressBar progressBar)
{
var progressReporter = new ProgressHandlers.MulticlassClassificationHandler(optimizationMetric, progressBar);
var result = context.Auto()
ExperimentResult<MulticlassClassificationMetrics> result = context.Auto()
.CreateMulticlassClassificationExperiment(new MulticlassExperimentSettings()
{
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
CacheBeforeTrainer = this.cacheBeforeTrainer,
OptimizingMetric = optimizationMetric
}).Execute(trainData, validationData, columnInformation, progressHandler: progressReporter);
}).Execute(trainData, validationData, columnInformation, progressHandler: handler);

logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
return result;
}

}
Expand Down
8 changes: 5 additions & 3 deletions src/mlnet/AutoML/IAutoMLEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.ML.AutoML;
using Microsoft.ML.CLI.ShellProgressBar;
using Microsoft.ML.CLI.Utilities;
using Microsoft.ML.Data;

namespace Microsoft.ML.CLI.CodeGenerator
Expand All @@ -13,11 +15,11 @@ internal interface IAutoMLEngine
{
ColumnInferenceResults InferColumns(MLContext context, ColumnInformation columnInformation);

ExperimentResult<BinaryClassificationMetrics> ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar = null);
void ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressHandlers.BinaryClassificationHandler handler, ProgressBar progressBar = null);

ExperimentResult<MulticlassClassificationMetrics> ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar = null);
void ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressHandlers.MulticlassClassificationHandler handler, ProgressBar progressBar = null);

ExperimentResult<RegressionMetrics> ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar = null);
void ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressHandlers.RegressionHandler handler, ProgressBar progressBar = null);

}
}
Loading