diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index bccbd3e6f8..ea9d2571ad 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -107,10 +107,10 @@ public sealed class Arguments : DataCommand.ArgumentsBase public CrossValidationCommand(IHostEnvironment env, Arguments args) : base(env, args, RegistrationName) { - Host.CheckUserArg(Args.NumFolds >= 2, nameof(Args.NumFolds), "Number of folds must be greater than or equal to 2."); + Host.CheckUserArg(ImplOptions.NumFolds >= 2, nameof(ImplOptions.NumFolds), "Number of folds must be greater than or equal to 2."); TrainUtils.CheckTrainer(Host, args.Trainer, args.DataFile); - Utils.CheckOptionalUserDirectory(Args.SummaryFilename, nameof(Args.SummaryFilename)); - Utils.CheckOptionalUserDirectory(Args.OutputDataFile, nameof(Args.OutputDataFile)); + Utils.CheckOptionalUserDirectory(ImplOptions.SummaryFilename, nameof(ImplOptions.SummaryFilename)); + Utils.CheckOptionalUserDirectory(ImplOptions.OutputDataFile, nameof(ImplOptions.OutputDataFile)); } // This is for "forking" the host environment. @@ -124,7 +124,7 @@ public override void Run() using (var ch = Host.Start(LoadName)) using (var server = InitServer(ch)) { - var settings = CmdParser.GetSettings(Host, Args, new Arguments()); + var settings = CmdParser.GetSettings(Host, ImplOptions, new Arguments()); string cmd = string.Format("maml.exe {0} {1}", LoadName, settings); ch.Info(cmd); @@ -139,7 +139,7 @@ public override void Run() protected override void SendTelemetryCore(IPipe pipe) { - SendTelemetryComponent(pipe, Args.Trainer); + SendTelemetryComponent(pipe, ImplOptions.Trainer); base.SendTelemetryCore(pipe); } @@ -148,17 +148,17 @@ private void RunCore(IChannel ch, string cmd) Host.AssertValue(ch); IPredictor inputPredictor = null; - if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor)) + if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor)) ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized."); ch.Trace("Constructing data pipeline"); IDataLoader loader = CreateRawLoader(); // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform. - var preXf = Args.PreTransforms; - if (!string.IsNullOrEmpty(Args.OutputDataFile)) + var preXf = ImplOptions.PreTransforms; + if (!string.IsNullOrEmpty(ImplOptions.OutputDataFile)) { - string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name); + string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(ImplOptions.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name); if (name == null) { preXf = preXf.Concat( @@ -182,24 +182,24 @@ private void RunCore(IChannel ch, string cmd) IDataView pipe = loader; var stratificationColumn = GetSplitColumn(ch, loader, ref pipe); - var scorer = Args.Scorer; - var evaluator = Args.Evaluator; + var scorer = ImplOptions.Scorer; + var evaluator = ImplOptions.Evaluator; Func validDataCreator = null; - if (Args.ValidationFile != null) + if (ImplOptions.ValidationFile != null) { validDataCreator = () => { // Fork the command. 
var impl = new CrossValidationCommand(this); - return impl.CreateRawLoader(dataFile: Args.ValidationFile); + return impl.CreateRawLoader(dataFile: ImplOptions.ValidationFile); }; } FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn, - Args, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator, - validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(Args.OutputDataFile)); + ImplOptions, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator, + validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(ImplOptions.OutputDataFile)); var tasks = fold.GetCrossValidationTasks(); var eval = evaluator?.CreateComponent(Host) ?? @@ -218,32 +218,32 @@ private void RunCore(IChannel ch, string cmd) throw ch.Except("No overall metrics found"); var overall = eval.GetOverallResults(overallList.ToArray()); - MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds); + MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, ImplOptions.NumFolds); eval.PrintAdditionalMetrics(ch, tasks.Select(t => t.Result.Metrics).ToArray()); Dictionary[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray(); SendTelemetryMetric(metricValues); // Save the per-instance results. - if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) + if (!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile)) { - var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, Args.CollateMetrics, - Args.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames); + var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, ImplOptions.CollateMetrics, + ImplOptions.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames); if (variableSizeVectorColumnNames.Length > 0) { ch.Warning("Detected columns of variable length: {0}. 
Consider setting collateMetrics- for meaningful per-Folds results.", string.Join(", ", variableSizeVectorColumnNames)); } - if (Args.CollateMetrics) + if (ImplOptions.CollateMetrics) { ch.Assert(perInstance.Length == 1); - MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]); + MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, perInstance[0]); } else { int i = 0; foreach (var idv in perInstance) { - MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, i), idv); + MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(ImplOptions.OutputDataFile, i), idv); i++; } } @@ -265,20 +265,20 @@ private RoleMappedData ApplyAllTransformsToData(IHostEnvironment env, IChannel c /// private RoleMappedData CreateRoleMappedData(IHostEnvironment env, IChannel ch, IDataView data, ITrainer trainer) { - foreach (var kvp in Args.Transforms) + foreach (var kvp in ImplOptions.Transforms) data = kvp.Value.CreateComponent(env, data); var schema = data.Schema; - string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label); - string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features); - string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.WeightColumn), Args.WeightColumn, DefaultColumnNames.Weight); - string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name); - string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId); + string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.LabelColumn), ImplOptions.LabelColumn, DefaultColumnNames.Label); + string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features); + string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.WeightColumn), ImplOptions.WeightColumn, DefaultColumnNames.Weight); + string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name); + string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId); - TrainUtils.AddNormalizerIfNeeded(env, ch, trainer, ref data, features, Args.NormalizeFeatures); + TrainUtils.AddNormalizerIfNeeded(env, ch, trainer, ref data, features, ImplOptions.NormalizeFeatures); // Training pipe and examples. - var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumns); + var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns); return new RoleMappedData(data, label, features, group, weight, name, customCols); } @@ -291,11 +291,11 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output // If no stratification column was specified, but we have a group column of type Single, Double or // Key (contiguous) use it. 
string stratificationColumn = null; - if (!string.IsNullOrWhiteSpace(Args.StratificationColumn)) - stratificationColumn = Args.StratificationColumn; + if (!string.IsNullOrWhiteSpace(ImplOptions.StratificationColumn)) + stratificationColumn = ImplOptions.StratificationColumn; else { - string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId); + string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId); int index; if (group != null && schema.TryGetColumnIndex(group, out index)) { diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs index 8e8a390f1e..ae9ebe80cf 100644 --- a/src/Microsoft.ML.Data/Commands/DataCommand.cs +++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs @@ -60,11 +60,11 @@ public abstract class ArgumentsBase } [BestFriend] - internal abstract class ImplBase<TArgs> : ICommand - where TArgs : ArgumentsBase + internal abstract class ImplBase<TOptions> : ICommand + where TOptions : ArgumentsBase { protected readonly IHost Host; - protected readonly TArgs Args; + protected readonly TOptions ImplOptions; private readonly ServerChannel.IServerFactory _serverFactory; protected ServerChannel.IServer InitServer(IChannel ch) @@ -78,35 +78,35 @@ protected ServerChannel.IServer InitServer(IChannel ch) /// The degree of concurrency is passed in the conc parameter. If it is null, the value /// of args.Parallel is used. If that is null, zero is used (which means "automatic"). /// - protected ImplBase(IHostEnvironment env, TArgs args, string name, int? conc = null) + protected ImplBase(IHostEnvironment env, TOptions options, string name, int? conc = null) { Contracts.CheckValue(env, nameof(env)); // Note that env may be null here, which is OK since the CheckXxx methods are extension // methods designed to allow null. - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckParam(conc == null || conc >= 0, nameof(conc), "Degree of concurrency must be non-negative (or null)"); - conc = conc ?? args.Parallel; - env.CheckUserArg(!(conc < 0), nameof(args.Parallel), "Degree of parallelism must be non-negative (or null)"); + conc = conc ?? options.Parallel; + env.CheckUserArg(!(conc < 0), nameof(options.Parallel), "Degree of parallelism must be non-negative (or null)"); // Capture the environment options from args. 
- env = env.Register(name, args.RandomSeed, args.Verbose, conc); + env = env.Register(name, options.RandomSeed, options.Verbose, conc); env.CheckNonWhiteSpace(name, nameof(name)); Host = env.Register(name); - Args = args; - _serverFactory = args.Server; - Utils.CheckOptionalUserDirectory(args.OutputModelFile, nameof(args.OutputModelFile)); + ImplOptions = options; + _serverFactory = options.Server; + Utils.CheckOptionalUserDirectory(options.OutputModelFile, nameof(options.OutputModelFile)); } - protected ImplBase(ImplBase<TArgs> impl, string name) + protected ImplBase(ImplBase<TOptions> impl, string name) { Contracts.CheckValue(impl, nameof(impl)); Contracts.AssertValue(impl.Host); - impl.Host.AssertValue(impl.Args); + impl.Host.AssertValue(impl.ImplOptions); impl.Host.AssertValue(name); - Args = impl.Args; + ImplOptions = impl.ImplOptions; Host = impl.Host.Register(name); } @@ -135,9 +135,9 @@ protected virtual void SendTelemetryCore(IPipe pipe) { Contracts.AssertValue(pipe); - if (Args.Transforms != null) { - foreach (var transform in Args.Transforms) + if (ImplOptions.Transforms != null) { + foreach (var transform in ImplOptions.Transforms) SendTelemetryComponent(pipe, transform.Value); } } @@ -221,9 +221,9 @@ protected void SaveLoader(IDataLoader loader, string path) protected IDataLoader CreateAndSaveLoader(Func defaultLoaderFactory = null) { var loader = CreateLoader(defaultLoaderFactory); - if (!string.IsNullOrWhiteSpace(Args.OutputModelFile)) { - using (var file = Host.CreateOutputFile(Args.OutputModelFile)) + if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile)) { + using (var file = Host.CreateOutputFile(ImplOptions.OutputModelFile)) LoaderUtils.SaveLoader(loader, file); } return loader; } @@ -258,7 +258,7 @@ protected void LoadModelObjects( // First handle the case where there is no input model file. // Everything must come from the command line. - using (var file = Host.OpenInputFile(Args.InputModelFile)) + using (var file = Host.OpenInputFile(ImplOptions.InputModelFile)) using (var strm = file.OpenReadStream()) using (var rep = RepositoryReader.Open(strm, Host)) { @@ -274,28 +274,28 @@ } // Next create the loader. - var loaderFactory = Args.Loader; + var loaderFactory = ImplOptions.Loader; IDataLoader trainPipe = null; if (loaderFactory != null) { // The loader is overridden from the command line. - pipe = loaderFactory.CreateComponent(Host, new MultiFileSource(Args.DataFile)); - if (Args.LoadTransforms == true) + pipe = loaderFactory.CreateComponent(Host, new MultiFileSource(ImplOptions.DataFile)); + if (ImplOptions.LoadTransforms == true) { - Host.CheckUserArg(!string.IsNullOrWhiteSpace(Args.InputModelFile), nameof(Args.InputModelFile)); + Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile)); pipe = LoadTransformChain(pipe); } } else { - var loadTrans = Args.LoadTransforms ?? true; - pipe = LoadLoader(rep, Args.DataFile, loadTrans); + var loadTrans = ImplOptions.LoadTransforms ?? true; + pipe = LoadLoader(rep, ImplOptions.DataFile, loadTrans); if (loadTrans) trainPipe = pipe; } - if (Utils.Size(Args.Transforms) > 0) - pipe = CompositeDataLoader.Create(Host, pipe, Args.Transforms); + if (Utils.Size(ImplOptions.Transforms) > 0) + pipe = CompositeDataLoader.Create(Host, pipe, ImplOptions.Transforms); // Next consider loading the training data's role mapped schema. 
trainSchema = null; @@ -331,7 +331,7 @@ protected IDataLoader CreateLoader(Func 0) data = ColumnSelectingTransformer.CreateKeep(Host, data, keepColumns); } IDataSaver saver; - if (Args.Saver != null) - saver = Args.Saver.CreateComponent(Host); + if (ImplOptions.Saver != null) + saver = ImplOptions.Saver.CreateComponent(Host); else - saver = new TextSaver(Host, new TextSaver.Arguments() { Dense = Args.Dense }); + saver = new TextSaver(Host, new TextSaver.Arguments() { Dense = ImplOptions.Dense }); var cols = new List(); for (int i = 0; i < data.Schema.Count; i++) { - if (!Args.KeepHidden && data.Schema[i].IsHidden) + if (!ImplOptions.KeepHidden && data.Schema[i].IsHidden) continue; var type = data.Schema[i].Type; if (saver.IsColumnSavable(type)) @@ -156,9 +156,9 @@ private void RunCore(IChannel ch) Host.NotSensitive().Check(cols.Count > 0, "No valid columns to save"); // Send the first N lines to console. - if (Args.Rows > 0) + if (ImplOptions.Rows > 0) { - var args = new SkipTakeFilter.TakeOptions() { Count = Args.Rows }; + var args = new SkipTakeFilter.TakeOptions() { Count = ImplOptions.Rows }; data = SkipTakeFilter.Create(Host, args, data); } var textSaver = saver as TextSaver; diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index 4747047f09..536ffc5a48 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -82,9 +82,9 @@ public sealed class Arguments : DataCommand.ArgumentsBase public ScoreCommand(IHostEnvironment env, Arguments args) : base(env, args, nameof(ScoreCommand)) { - Host.CheckUserArg(!string.IsNullOrWhiteSpace(Args.InputModelFile), nameof(Args.InputModelFile), "The input model file is required."); - Host.CheckUserArg(!string.IsNullOrWhiteSpace(Args.OutputDataFile), nameof(Args.OutputDataFile), "The output data file is required."); - Utils.CheckOptionalUserDirectory(Args.OutputDataFile, nameof(Args.OutputDataFile)); + Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile), "The input model file is required."); + Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile), nameof(ImplOptions.OutputDataFile), "The output data file is required."); + Utils.CheckOptionalUserDirectory(ImplOptions.OutputDataFile, nameof(ImplOptions.OutputDataFile)); } public override void Run() @@ -107,17 +107,17 @@ private void RunCore(IChannel ch) ch.AssertValue(loader); ch.Trace("Creating pipeline"); - var scorer = Args.Scorer; + var scorer = ImplOptions.Scorer; ch.Assert(scorer == null || scorer is ICommandLineComponentFactory, "ScoreCommand should only be used from the command line."); var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorerFactorySettings: scorer as ICommandLineComponentFactory); ch.AssertValue(bindable); // REVIEW: We probably ought to prefer role mappings from the training schema. 
string feat = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, - nameof(Args.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features); + nameof(ImplOptions.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features); string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, - nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId); - var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumns); + nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId); + var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns); var schema = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true); var mapper = bindable.Bind(Host, schema); @@ -127,19 +127,19 @@ private void RunCore(IChannel ch) loader = CompositeDataLoader.ApplyTransform(Host, loader, "Scorer", scorer.ToString(), (env, view) => scorer.CreateComponent(env, view, mapper, trainSchema)); - loader = CompositeDataLoader.Create(Host, loader, Args.PostTransform); + loader = CompositeDataLoader.Create(Host, loader, ImplOptions.PostTransform); - if (!string.IsNullOrWhiteSpace(Args.OutputModelFile)) + if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile)) { ch.Trace("Saving the data pipe"); - SaveLoader(loader, Args.OutputModelFile); + SaveLoader(loader, ImplOptions.OutputModelFile); } ch.Trace("Creating saver"); IDataSaver writer; - if (Args.Saver == null) + if (ImplOptions.Saver == null) { - var ext = Path.GetExtension(Args.OutputDataFile); + var ext = Path.GetExtension(ImplOptions.OutputDataFile); var isText = ext == ".txt" || ext == ".tlc"; if (isText) { @@ -152,24 +152,24 @@ private void RunCore(IChannel ch) } else { - writer = Args.Saver.CreateComponent(Host); + writer = ImplOptions.Saver.CreateComponent(Host); } ch.Assert(writer != null); var outputIsBinary = writer is BinaryWriter; bool outputAllColumns = - Args.OutputAllColumns == true - || (Args.OutputAllColumns == null && Utils.Size(Args.OutputColumns) == 0 && outputIsBinary); + ImplOptions.OutputAllColumns == true + || (ImplOptions.OutputAllColumns == null && Utils.Size(ImplOptions.OutputColumns) == 0 && outputIsBinary); bool outputNamesAndLabels = - Args.OutputAllColumns == true || Utils.Size(Args.OutputColumns) == 0; + ImplOptions.OutputAllColumns == true || Utils.Size(ImplOptions.OutputColumns) == 0; - if (Args.OutputAllColumns == true && Utils.Size(Args.OutputColumns) != 0) - ch.Warning(nameof(Args.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(Args.OutputColumns) + " specified."); + if (ImplOptions.OutputAllColumns == true && Utils.Size(ImplOptions.OutputColumns) != 0) + ch.Warning(nameof(ImplOptions.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(ImplOptions.OutputColumns) + " specified."); - if (!outputAllColumns && Utils.Size(Args.OutputColumns) != 0) + if (!outputAllColumns && Utils.Size(ImplOptions.OutputColumns) != 0) { - foreach (var outCol in Args.OutputColumns) + foreach (var outCol in ImplOptions.OutputColumns) { if (!loader.Schema.TryGetColumnIndex(outCol, out int dummyColIndex)) throw ch.ExceptUserArg(nameof(Arguments.OutputColumns), "Column '{0}' not found.", outCol); @@ -183,7 +183,7 @@ private void RunCore(IChannel ch) var cols = new List(); for (int i = 0; i < loader.Schema.Count; i++) { - if (!Args.KeepHidden && loader.Schema[i].IsHidden) + if (!ImplOptions.KeepHidden && loader.Schema[i].IsHidden) continue; if 
(!(outputAllColumns || ShouldAddColumn(loader.Schema, i, maxScoreId, outputNamesAndLabels))) continue; @@ -200,7 +200,7 @@ private void RunCore(IChannel ch) ch.Check(cols.Count > 0, "No valid columns to save"); ch.Trace("Scoring and saving data"); - using (var file = Host.CreateOutputFile(Args.OutputDataFile)) + using (var file = Host.CreateOutputFile(ImplOptions.OutputDataFile)) using (var stream = file.CreateWriteStream()) writer.SaveData(stream, loader, cols.ToArray()); } @@ -229,7 +229,7 @@ private bool ShouldAddColumn(DataViewSchema schema, int i, uint scoreSet, bool o break; } } - if (Args.OutputColumns != null && Array.FindIndex(Args.OutputColumns, schema[i].Name.Equals) >= 0) + if (ImplOptions.OutputColumns != null && Array.FindIndex(ImplOptions.OutputColumns, schema[i].Name.Equals) >= 0) return true; return false; } diff --git a/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs b/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs index 85030a53b5..ea3ef9ba69 100644 --- a/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs @@ -69,7 +69,7 @@ private void RunCore(IChannel ch) IDataLoader loader = CreateAndSaveLoader(); using (var schemaWriter = new StringWriter()) { - RunOnData(schemaWriter, Args, loader); + RunOnData(schemaWriter, ImplOptions, loader); var str = schemaWriter.ToString(); ch.AssertNonEmpty(str); ch.Info(str); diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs index 40f6621ab5..97d00f6d7d 100644 --- a/src/Microsoft.ML.Data/Commands/TestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs @@ -63,9 +63,9 @@ public sealed class Arguments : DataCommand.ArgumentsBase public TestCommand(IHostEnvironment env, Arguments args) : base(env, args, nameof(TestCommand)) { - Host.CheckUserArg(!string.IsNullOrEmpty(Args.InputModelFile), nameof(Args.InputModelFile), "The input model file is required."); - Utils.CheckOptionalUserDirectory(Args.SummaryFilename, nameof(Args.SummaryFilename)); - Utils.CheckOptionalUserDirectory(Args.OutputDataFile, nameof(Args.OutputDataFile)); + Host.CheckUserArg(!string.IsNullOrEmpty(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile), "The input model file is required."); + Utils.CheckOptionalUserDirectory(ImplOptions.SummaryFilename, nameof(ImplOptions.SummaryFilename)); + Utils.CheckOptionalUserDirectory(ImplOptions.OutputDataFile, nameof(ImplOptions.OutputDataFile)); } public override void Run() @@ -74,7 +74,7 @@ public override void Run() using (var ch = Host.Start(command)) using (var server = InitServer(ch)) { - var settings = CmdParser.GetSettings(Host, Args, new Arguments()); + var settings = CmdParser.GetSettings(Host, ImplOptions, new Arguments()); ch.Info("maml.exe {0} {1}", command, settings); SendTelemetry(Host); @@ -98,25 +98,25 @@ private void RunCore(IChannel ch) ch.Trace("Binding columns"); var schema = loader.Schema; - string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn), - Args.LabelColumn, DefaultColumnNames.Label); - string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.FeatureColumn), - Args.FeatureColumn, DefaultColumnNames.Features); - string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), - Args.GroupColumn, DefaultColumnNames.GroupId); - string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.WeightColumn), - Args.WeightColumn, DefaultColumnNames.Weight); - string name = 
TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.NameColumn), - Args.NameColumn, DefaultColumnNames.Name); - var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumns); + string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.LabelColumn), + ImplOptions.LabelColumn, DefaultColumnNames.Label); + string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.FeatureColumn), + ImplOptions.FeatureColumn, DefaultColumnNames.Features); + string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.GroupColumn), + ImplOptions.GroupColumn, DefaultColumnNames.GroupId); + string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.WeightColumn), + ImplOptions.WeightColumn, DefaultColumnNames.Weight); + string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.NameColumn), + ImplOptions.NameColumn, DefaultColumnNames.Name); + var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns); // Score. ch.Trace("Scoring and evaluating"); - ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory, "TestCommand should only be used from the command line."); - IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, loader, features, group, customCols, Host, trainSchema); + ch.Assert(ImplOptions.Scorer == null || ImplOptions.Scorer is ICommandLineComponentFactory, "TestCommand should only be used from the command line."); + IDataScorerTransform scorePipe = ScoreUtils.GetScorer(ImplOptions.Scorer, predictor, loader, features, group, customCols, Host, trainSchema); // Evaluate. - var evaluator = Args.Evaluator?.CreateComponent(Host) ?? + var evaluator = ImplOptions.Evaluator?.CreateComponent(Host) ?? 
EvaluateUtils.GetEvaluator(Host, scorePipe.Schema); var data = new RoleMappedData(scorePipe, label, null, group, weight, name, customCols); var metrics = evaluator.Evaluate(data); @@ -125,16 +125,16 @@ private void RunCore(IChannel ch) if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall)) throw ch.Except("No overall metrics found"); overall = evaluator.GetOverallResults(overall); - MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1); + MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, 1); evaluator.PrintAdditionalMetrics(ch, metrics); Dictionary[] metricValues = { metrics }; SendTelemetryMetric(metricValues); - if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) + if (!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile)) { var perInst = evaluator.GetPerInstanceMetrics(data); var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols); var idv = evaluator.GetPerInstanceDataViewToSave(perInstData); - MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv); + MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, idv); } } } diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index bd95d48993..7b0353a382 100644 --- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -113,7 +113,7 @@ public override void Run() using (var ch = Host.Start(command)) using (var server = InitServer(ch)) { - var settings = CmdParser.GetSettings(Host, Args, new Arguments()); + var settings = CmdParser.GetSettings(Host, ImplOptions, new Arguments()); string cmd = string.Format("maml.exe {0} {1}", command, settings); ch.Info(cmd); @@ -141,7 +141,7 @@ private void RunCore(IChannel ch, string cmd) ITrainer trainer = _trainer.CreateComponent(Host); IPredictor inputPredictor = null; - if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor)) + if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor)) ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized."); ch.Trace("Constructing data pipeline"); @@ -154,16 +154,16 @@ private void RunCore(IChannel ch, string cmd) var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn), _weightColumn, DefaultColumnNames.Weight); var name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn), _nameColumn, DefaultColumnNames.Name); - TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref view, feature, Args.NormalizeFeatures); + TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref view, feature, ImplOptions.NormalizeFeatures); ch.Trace("Binding columns"); - var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumns); + var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns); var data = new RoleMappedData(view, label, feature, group, weight, name, customCols); // REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands. 
RoleMappedData validData = null; - if (!string.IsNullOrWhiteSpace(Args.ValidationFile)) + if (!string.IsNullOrWhiteSpace(ImplOptions.ValidationFile)) { if (!trainer.Info.SupportsValidation) { @@ -172,7 +172,7 @@ private void RunCore(IChannel ch, string cmd) else { ch.Trace("Constructing the validation pipeline"); - IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile); + IDataView validPipe = CreateRawLoader(dataFile: ImplOptions.ValidationFile); validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, validPipe); validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames()); } @@ -183,23 +183,23 @@ private void RunCore(IChannel ch, string cmd) // indirectly use validation set to improve the model but the learned model should totally independent of test set. // Similar to validation set, the trainer can report the scores computed using test set. RoleMappedData testDataUsedInTrainer = null; - if (!string.IsNullOrWhiteSpace(Args.TestFile)) + if (!string.IsNullOrWhiteSpace(ImplOptions.TestFile)) { // In contrast to the if-else block for validation above, we do not throw a warning if test file is provided // because this is TrainTest command. if (trainer.Info.SupportsTest) { ch.Trace("Constructing the test pipeline"); - IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: Args.TestFile); + IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: ImplOptions.TestFile); testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, testPipeUsedInTrainer); testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames()); } } var predictor = TrainUtils.Train(Host, ch, data, trainer, validData, - Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor, testDataUsedInTrainer); + ImplOptions.Calibrator, ImplOptions.MaxCalibrationExamples, ImplOptions.CacheData, inputPredictor, testDataUsedInTrainer); - using (var file = Host.CreateOutputFile(Args.OutputModelFile)) + using (var file = Host.CreateOutputFile(ImplOptions.OutputModelFile)) TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd); } } diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs index 643be09886..b1343cb494 100644 --- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs @@ -98,7 +98,7 @@ public override void Run() using (var ch = Host.Start(LoadName)) using (var server = InitServer(ch)) { - var settings = CmdParser.GetSettings(Host, Args, new Arguments()); + var settings = CmdParser.GetSettings(Host, ImplOptions, new Arguments()); string cmd = string.Format("maml.exe {0} {1}", LoadName, settings); ch.Info(cmd); @@ -113,7 +113,7 @@ public override void Run() protected override void SendTelemetryCore(IPipe pipe) { - SendTelemetryComponent(pipe, Args.Trainer); + SendTelemetryComponent(pipe, ImplOptions.Trainer); base.SendTelemetryCore(pipe); } @@ -123,10 +123,10 @@ private void RunCore(IChannel ch, string cmd) Host.AssertNonEmpty(cmd); ch.Trace("Constructing trainer"); - ITrainer trainer = Args.Trainer.CreateComponent(Host); + ITrainer trainer = ImplOptions.Trainer.CreateComponent(Host); IPredictor inputPredictor = null; - if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor)) + if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor)) ch.Warning("No input model file specified 
or model file did not contain a predictor. The model state cannot be initialized."); ch.Trace("Constructing the training pipeline"); @@ -134,24 +134,24 @@ private void RunCore(IChannel ch, string cmd) var schema = trainPipe.Schema; string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), - Args.LabelColumn, DefaultColumnNames.Label); + ImplOptions.LabelColumn, DefaultColumnNames.Label); string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn), - Args.FeatureColumn, DefaultColumnNames.Features); + ImplOptions.FeatureColumn, DefaultColumnNames.Features); string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), - Args.GroupColumn, DefaultColumnNames.GroupId); + ImplOptions.GroupColumn, DefaultColumnNames.GroupId); string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn), - Args.WeightColumn, DefaultColumnNames.Weight); + ImplOptions.WeightColumn, DefaultColumnNames.Weight); string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn), - Args.NameColumn, DefaultColumnNames.Name); + ImplOptions.NameColumn, DefaultColumnNames.Name); - TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, Args.NormalizeFeatures); + TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, ImplOptions.NormalizeFeatures); ch.Trace("Binding columns"); - var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumns); + var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns); var data = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols); RoleMappedData validData = null; - if (!string.IsNullOrWhiteSpace(Args.ValidationFile)) + if (!string.IsNullOrWhiteSpace(ImplOptions.ValidationFile)) { if (!trainer.Info.SupportsValidation) { @@ -160,7 +160,7 @@ private void RunCore(IChannel ch, string cmd) else { ch.Trace("Constructing the validation pipeline"); - IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile); + IDataView validPipe = CreateRawLoader(dataFile: ImplOptions.ValidationFile); validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe); validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames()); } @@ -171,42 +171,42 @@ private void RunCore(IChannel ch, string cmd) // indirectly use validation set to improve the model but the learned model should totally independent of test set. // Similar to validation set, the trainer can report the scores computed using test set. RoleMappedData testDataUsedInTrainer = null; - if (!string.IsNullOrWhiteSpace(Args.TestFile)) + if (!string.IsNullOrWhiteSpace(ImplOptions.TestFile)) { // In contrast to the if-else block for validation above, we do not throw a warning if test file is provided // because this is TrainTest command. 
if (trainer.Info.SupportsTest) { ch.Trace("Constructing the test pipeline"); - IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: Args.TestFile); + IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: ImplOptions.TestFile); testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, testPipeUsedInTrainer); testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames()); } } var predictor = TrainUtils.Train(Host, ch, data, trainer, validData, - Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor, testDataUsedInTrainer); + ImplOptions.Calibrator, ImplOptions.MaxCalibrationExamples, ImplOptions.CacheData, inputPredictor, testDataUsedInTrainer); IDataLoader testPipe; - bool hasOutfile = !string.IsNullOrEmpty(Args.OutputModelFile); + bool hasOutfile = !string.IsNullOrEmpty(ImplOptions.OutputModelFile); var tempFilePath = hasOutfile ? null : Path.GetTempFileName(); - using (var file = new SimpleFileHandle(ch, hasOutfile ? Args.OutputModelFile : tempFilePath, true, !hasOutfile)) + using (var file = new SimpleFileHandle(ch, hasOutfile ? ImplOptions.OutputModelFile : tempFilePath, true, !hasOutfile)) { TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd); ch.Trace("Constructing the testing pipeline"); using (var stream = file.OpenReadStream()) using (var rep = RepositoryReader.Open(stream, ch)) - testPipe = LoadLoader(rep, Args.TestFile, true); + testPipe = LoadLoader(rep, ImplOptions.TestFile, true); } // Score. ch.Trace("Scoring and evaluating"); - ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory, "TrainTestCommand should only be used from the command line."); - IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema); + ch.Assert(ImplOptions.Scorer == null || ImplOptions.Scorer is ICommandLineComponentFactory, "TrainTestCommand should only be used from the command line."); + IDataScorerTransform scorePipe = ScoreUtils.GetScorer(ImplOptions.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema); // Evaluate. - var evaluator = Args.Evaluator?.CreateComponent(Host) ?? + var evaluator = ImplOptions.Evaluator?.CreateComponent(Host) ?? 
EvaluateUtils.GetEvaluator(Host, scorePipe.Schema); var dataEval = new RoleMappedData(scorePipe, label, features, group, weight, name, customCols, opt: true); @@ -216,16 +216,16 @@ private void RunCore(IChannel ch, string cmd) if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall)) throw ch.Except("No overall metrics found"); overall = evaluator.GetOverallResults(overall); - MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1); + MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, 1); evaluator.PrintAdditionalMetrics(ch, metrics); Dictionary[] metricValues = { metrics }; SendTelemetryMetric(metricValues); - if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) + if (!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile)) { var perInst = evaluator.GetPerInstanceMetrics(dataEval); var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols); var idv = evaluator.GetPerInstanceDataViewToSave(perInstData); - MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv); + MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, idv); } } } diff --git a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs index cb211bcc75..2952c9a6ea 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs @@ -124,13 +124,13 @@ private void Run(IChannel ch) IPredictor rawPred; RoleMappedSchema trainSchema; - if (string.IsNullOrEmpty(Args.InputModelFile)) + if (string.IsNullOrEmpty(ImplOptions.InputModelFile)) { loader = CreateLoader(); rawPred = null; trainSchema = null; - Host.CheckUserArg(Args.LoadPredictor != true, nameof(Args.LoadPredictor), - "Cannot be set to true unless " + nameof(Args.InputModelFile) + " is also specifified."); + Host.CheckUserArg(ImplOptions.LoadPredictor != true, nameof(ImplOptions.LoadPredictor), + "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified."); } else LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader); @@ -209,11 +209,11 @@ private void Run(IChannel ch) writer.Write(pfaDoc.ToString(_formatting)); } - if (!string.IsNullOrWhiteSpace(Args.OutputModelFile)) + if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile)) { ch.Trace("Saving the data pipe"); // Should probably include "end"? 
- SaveLoader(loader, Args.OutputModelFile); + SaveLoader(loader, ImplOptions.OutputModelFile); } } } diff --git a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs index 6a53a5d2ce..4f20940a00 100644 --- a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs +++ b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs @@ -101,10 +101,10 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI } } - public TolerantEarlyStoppingCriterion(Options args, bool lowerIsBetter) - : base(args, lowerIsBetter) + public TolerantEarlyStoppingCriterion(Options options, bool lowerIsBetter) + : base(options, lowerIsBetter) { - Contracts.CheckUserArg(EarlyStoppingCriterionOptions.Threshold >= 0, nameof(args.Threshold), "Must be non-negative."); + Contracts.CheckUserArg(EarlyStoppingCriterionOptions.Threshold >= 0, nameof(options.Threshold), "Must be non-negative."); } public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) @@ -253,8 +253,8 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI } } - public LPEarlyStoppingCriterion(Options args, bool lowerIsBetter) - : base(args, lowerIsBetter) { } + public LPEarlyStoppingCriterion(Options options, bool lowerIsBetter) + : base(options, lowerIsBetter) { } public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) { @@ -291,8 +291,8 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI } } - public PQEarlyStoppingCriterion(Options args, bool lowerIsBetter) - : base(args, lowerIsBetter) { } + public PQEarlyStoppingCriterion(Options options, bool lowerIsBetter) + : base(options, lowerIsBetter) { } public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs index 1707f87b44..56dcc6e0c0 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs @@ -12,8 +12,8 @@ namespace Microsoft.ML.Trainers.Ensemble { internal abstract class BaseMultiAverager : BaseMultiCombiner { - private protected BaseMultiAverager(IHostEnvironment env, string name, OptionsBase args) - : base(env, name, args) + private protected BaseMultiAverager(IHostEnvironment env, string name, OptionsBase options) + : base(env, name, options) { } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs index 3207aa9d3d..a6d4e65abe 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs @@ -24,14 +24,14 @@ public abstract class OptionsBase protected readonly bool Normalize; - internal BaseMultiCombiner(IHostEnvironment env, string name, OptionsBase args) + internal BaseMultiCombiner(IHostEnvironment env, string name, OptionsBase options) { Contracts.AssertValue(env); env.AssertNonWhiteSpace(name); Host = env.Register(name); - Host.CheckValue(args, nameof(args)); + Host.CheckValue(options, nameof(options)); - Normalize = args.Normalize; + Normalize = options.Normalize; } internal BaseMultiCombiner(IHostEnvironment env, string name, ModelLoadContext ctx) diff --git 
a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs index 3305574d36..cf19364daa 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BaseSubsetSelector.cs @@ -10,8 +10,8 @@ namespace Microsoft.ML.Trainers.Ensemble.SubsetSelector { - internal abstract class BaseSubsetSelector<TArgs> : ISubsetSelector - where TArgs : BaseSubsetSelector<TArgs>.ArgumentsBase + internal abstract class BaseSubsetSelector<TOptions> : ISubsetSelector + where TOptions : BaseSubsetSelector<TOptions>.ArgumentsBase { public abstract class ArgumentsBase { @@ -20,7 +20,7 @@ public abstract class ArgumentsBase } protected readonly IHost Host; - protected readonly TArgs Args; + protected readonly TOptions BaseSubsetSelectorOptions; protected readonly IFeatureSelector FeatureSelector; protected int Size; @@ -28,15 +28,15 @@ public abstract class ArgumentsBase protected int BatchSize; protected Single ValidationDatasetProportion; - protected BaseSubsetSelector(TArgs args, IHostEnvironment env, string name) + protected BaseSubsetSelector(TOptions options, IHostEnvironment env, string name) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckNonWhiteSpace(name, nameof(name)); Host = env.Register(name); - Args = args; - FeatureSelector = Args.FeatureSelector.CreateComponent(Host); + BaseSubsetSelectorOptions = options; + FeatureSelector = BaseSubsetSelectorOptions.FeatureSelector.CreateComponent(Host); } public void Initialize(RoleMappedData data, int size, int batchSize, Single validationDatasetProportion) diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 724e741de5..847e9b6208 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -9,12 +9,12 @@ namespace Microsoft.ML.Trainers.FastTree { - public abstract class BoostingFastTreeTrainerBase<TArgs, TTransformer, TModel> : FastTreeTrainerBase<TArgs, TTransformer, TModel> + public abstract class BoostingFastTreeTrainerBase<TOptions, TTransformer, TModel> : FastTreeTrainerBase<TOptions, TTransformer, TModel> where TTransformer : ISingleFeaturePredictionTransformer<TModel> - where TArgs : BoostedTreeArgs, new() + where TOptions : BoostedTreeOptions, new() where TModel : class { - private protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) : base(env, args, label) + private protected BoostingFastTreeTrainerBase(IHostEnvironment env, TOptions options, SchemaShape.Column label) : base(env, options, label) { } @@ -32,11 +32,11 @@ private protected BoostingFastTreeTrainerBase(IHostEnvironment env, FastTreeTrainerOptions.LearningRates = learningRate; } - private protected override void CheckArgs(IChannel ch) + private protected override void CheckOptions(IChannel ch) { - if (FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent) + if (FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeOptions.OptimizationAlgorithmType.AcceleratedGradientDescent) FastTreeTrainerOptions.UseLineSearch = true; - if (FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.ConjugateGradientDescent) + if (FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeOptions.OptimizationAlgorithmType.ConjugateGradientDescent) FastTreeTrainerOptions.UseLineSearch = true; if (FastTreeTrainerOptions.CompressEnsemble && 
FastTreeTrainerOptions.WriteLastEnsemble) @@ -57,7 +57,7 @@ private protected override void CheckArgs(IChannel ch) if (FastTreeTrainerOptions.UseTolerantPruning && (!FastTreeTrainerOptions.EnablePruning || !HasValidSet)) throw ch.Except("Cannot perform tolerant pruning (prtol) without pruning (pruning) and a validation set (valid)"); - base.CheckArgs(ch); + base.CheckOptions(ch); } private protected override TreeLearner ConstructTreeLearner(IChannel ch) @@ -79,13 +79,13 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm( switch (FastTreeTrainerOptions.OptimizationAlgorithm) { - case BoostedTreeArgs.OptimizationAlgorithmType.GradientDescent: + case BoostedTreeOptions.OptimizationAlgorithmType.GradientDescent: optimizationAlgorithm = new GradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper); break; - case BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent: + case BoostedTreeOptions.OptimizationAlgorithmType.AcceleratedGradientDescent: optimizationAlgorithm = new AcceleratedGradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper); break; - case BoostedTreeArgs.OptimizationAlgorithmType.ConjugateGradientDescent: + case BoostedTreeOptions.OptimizationAlgorithmType.ConjugateGradientDescent: optimizationAlgorithm = new ConjugateGradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper); break; default: diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 933fae22f5..583e3797a9 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -95,7 +95,7 @@ public abstract class FastTreeTrainerBase : // random for active features selection private Random _featureSelectionRandom; - private protected string InnerArgs => CmdParser.GetSettings(Host, FastTreeTrainerOptions, new TOptions()); + private protected string InnerOptions => CmdParser.GetSettings(Host, FastTreeTrainerOptions, new TOptions()); public override TrainerInfo Info { get; } @@ -236,7 +236,7 @@ private protected void TrainCore(IChannel ch) { using (Timer.Time(TimerEvent.TotalInitialization)) { - CheckArgs(ch); + CheckOptions(ch); PrintPrologInfo(ch); Initialize(ch); @@ -270,7 +270,7 @@ private protected virtual void PrintExecutionTimes(IChannel ch) ch.Info("Execution time breakdown:\n{0}", Timer.GetString()); } - private protected virtual void CheckArgs(IChannel ch) + private protected virtual void CheckOptions(IChannel ch) { FastTreeTrainerOptions.Check(ch); @@ -2814,7 +2814,7 @@ public abstract class TreeEnsembleModelParameters : int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees; // Inner args is used only for documentation purposes when saving comments to INI files. - private protected readonly string InnerArgs; + private protected readonly string InnerOptions; // The total number of features used in training (takes the value of zero if the // written version of the loaded model is less than VerNumFeaturesSerialized) @@ -2865,7 +2865,7 @@ private protected TreeEnsembleModelParameters(IHostEnvironment env, string name, // the trained ensemble to, for instance, resize arrays so that they are of the length // the actual number of leaves/nodes, or remove unnecessary arrays, and so forth. 
TrainedEnsemble = trainedEnsemble; - InnerArgs = innerArgs; + InnerOptions = innerArgs; NumFeatures = numFeatures; MaxSplitFeatIdx = trainedEnsemble.GetMaxFeatureIndex(); @@ -2895,7 +2895,7 @@ private protected TreeEnsembleModelParameters(IHostEnvironment env, string name, TrainedEnsemble = new InternalTreeEnsemble(ctx, usingDefaultValues, categoricalSplits); MaxSplitFeatIdx = TrainedEnsemble.GetMaxFeatureIndex(); - InnerArgs = ctx.LoadStringOrNull(); + InnerOptions = ctx.LoadStringOrNull(); if (ctx.Header.ModelVerWritten >= VerNumFeaturesSerialized) { NumFeatures = ctx.Reader.ReadInt32(); @@ -2924,7 +2924,7 @@ private protected override void SaveCore(ModelSaveContext ctx) // int: Number of features (VerNumFeaturesSerialized) // specific stuff TrainedEnsemble.Save(ctx); - ctx.SaveStringOrNull(InnerArgs); + ctx.SaveStringOrNull(InnerOptions); Host.Assert(NumFeatures >= 0); ctx.Writer.Write(NumFeatures); } @@ -3003,7 +3003,7 @@ void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, I { Host.CheckValue(writer, nameof(writer)); var ensembleIni = FastTreeIniFileUtils.TreeEnsembleToIni(Host, TrainedEnsemble, schema, calibrator, - InnerArgs, appendFeatureGain: true, includeZeroGainFeatures: false); + InnerOptions, appendFeatureGain: true, includeZeroGainFeatures: false); writer.WriteLine(ensembleIni); } diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index f3afaa18e0..11bb7a3e2c 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -24,7 +24,7 @@ internal interface IFastTreeTrainerFactory : IComponentFactory public sealed partial class FastTreeBinaryClassificationTrainer { [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)] - public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory + public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory { /// /// Option for using derivatives optimized for unbalanced sets. @@ -40,7 +40,7 @@ public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory public sealed partial class FastTreeRegressionTrainer { [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)] - public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory + public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory { public Options() { @@ -54,7 +54,7 @@ public Options() public sealed partial class FastTreeTweedieTrainer { [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)] - public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory + public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory { // REVIEW: It is possible to estimate this index parameter from the distribution of data, using // a combination of univariate optimization and grid search, following section 4.2 of the paper. 
However @@ -71,7 +71,7 @@ public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory public sealed partial class FastTreeRankingTrainer { [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)] - public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory + public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory { [Argument(ArgumentType.LastOccurenceWins, HelpText = "Comma seperated list of gains associated to each relevance label.", ShortName = "gains")] [TGUI(NoSweep = true)] @@ -442,7 +442,7 @@ internal virtual void Check(IExceptionContext ectx) } } - public abstract class BoostedTreeArgs : TreeOptions + public abstract class BoostedTreeOptions : TreeOptions { // REVIEW: TLC FR likes to call it bestStepRegressionTrees which might be more appropriate. //Use the second derivative for split gains (not just outputs). Use MaxTreeOutput to "clip" cases where the second derivative is too close to zero. diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index a12de3562c..798031cb13 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -178,7 +178,7 @@ private protected override CalibratedModelParametersBase= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)"); diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index b9cf757375..0b522005d0 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -102,14 +102,14 @@ private protected override FastTreeTweedieModelParameters TrainModelCore(TrainCo ConvertData(trainData); TrainCore(ch); } - return new FastTreeTweedieModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); + return new FastTreeTweedieModelParameters(Host, TrainedEnsemble, FeatureCount, InnerOptions); } - private protected override void CheckArgs(IChannel ch) + private protected override void CheckOptions(IChannel ch) { Contracts.AssertValue(ch); - base.CheckArgs(ch); + base.CheckOptions(ch); // REVIEW: In order to properly support early stopping, the early stopping metric should be a subcomponent, not just // a simple integer, because the metric that we might want is parameterized by this floating point "index" parameter. 
For now diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index c0073d5053..fb7262ece8 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -117,14 +117,14 @@ private protected override ObjectiveFunctionBase CreateObjectiveFunction() return new FastTreeBinaryClassificationTrainer.ObjectiveImpl( TrainSet, ConvertTargetsToBool(TrainSet.Targets), - Args.LearningRates, + GamTrainerOptions.LearningRates, 0, _sigmoidParameter, - Args.UnbalancedSets, - Args.MaxOutput, - Args.GetDerivativesSampleRate, + GamTrainerOptions.UnbalancedSets, + GamTrainerOptions.MaxOutput, + GamTrainerOptions.GetDerivativesSampleRate, false, - Args.RngSeed, + GamTrainerOptions.RngSeed, ParallelTraining ); } @@ -134,7 +134,7 @@ private protected override void DefinePruningTest() var validTest = new BinaryClassificationTest(ValidSetScore, ConvertTargetsToBool(ValidSet.Targets), _sigmoidParameter); // As per FastTreeClassification.ConstructOptimizationAlgorithm() - PruningLossIndex = Args.UnbalancedSets ? 3 /*Unbalanced sets loss*/ : 1 /*normal loss*/; + PruningLossIndex = GamTrainerOptions.UnbalancedSets ? 3 /*Unbalanced sets loss*/ : 1 /*normal loss*/; PruningTest = new TestHistory(validTest, PruningLossIndex); } diff --git a/src/Microsoft.ML.FastTree/GamModelParameters.cs b/src/Microsoft.ML.FastTree/GamModelParameters.cs index b227743dd8..18621930bc 100644 --- a/src/Microsoft.ML.FastTree/GamModelParameters.cs +++ b/src/Microsoft.ML.FastTree/GamModelParameters.cs @@ -888,9 +888,9 @@ private Context Init(IChannel ch) calibrated = rawPred as CalibratedModelParametersBase; } var pred = rawPred as GamModelParametersBase; - ch.CheckUserArg(pred != null, nameof(Args.InputModelFile), "Predictor was not a " + nameof(GamModelParametersBase)); + ch.CheckUserArg(pred != null, nameof(ImplOptions.InputModelFile), "Predictor was not a " + nameof(GamModelParametersBase)); var data = new RoleMappedData(loader, schema.GetColumnRoleNames(), opt: true); - if (hadCalibrator && !string.IsNullOrWhiteSpace(Args.OutputModelFile)) + if (hadCalibrator && !string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile)) ch.Warning("If you save the GAM model, only the GAM model, not the wrapping calibrator, will be saved."); return new Context(ch, pred, data, InitEvaluator(pred)); @@ -936,11 +936,11 @@ private void Run(IChannel ch) sch?.Register("setEffect", context.SetEffect); // Getting the metrics. 
sch?.Register("metrics", context.GetMetrics); - sch?.Register("canSave", () => !string.IsNullOrEmpty(Args.OutputModelFile)); - sch?.Register("save", () => context.SaveIfNeeded(Host, ch, Args.OutputModelFile)); + sch?.Register("canSave", () => !string.IsNullOrEmpty(ImplOptions.OutputModelFile)); + sch?.Register("save", () => context.SaveIfNeeded(Host, ch, ImplOptions.OutputModelFile)); sch?.Register("quit", () => { - var retVal = context.SaveIfNeeded(Host, ch, Args.OutputModelFile); + var retVal = context.SaveIfNeeded(Host, ch, ImplOptions.OutputModelFile); ev.Set(); return retVal; }); diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 93a4e06a1b..c3defd7990 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -76,12 +76,12 @@ private protected override RegressionGamModelParameters TrainModelCore(TrainCont private protected override ObjectiveFunctionBase CreateObjectiveFunction() { - return new FastTreeRegressionTrainer.ObjectiveImpl(TrainSet, Args); + return new FastTreeRegressionTrainer.ObjectiveImpl(TrainSet, GamTrainerOptions); } private protected override void DefinePruningTest() { - var validTest = new RegressionTest(ValidSetScore, Args.PruningMetrics); + var validTest = new RegressionTest(ValidSetScore, GamTrainerOptions.PruningMetrics); // Because we specify pruning metrics as L2 by default, the results array will have 1 value PruningLossIndex = 0; PruningTest = new TestHistory(validTest, PruningLossIndex); diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index b7787367e4..e7bb739eec 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -49,9 +49,9 @@ namespace Microsoft.ML.Trainers.FastTree /// ]]> /// /// - public abstract partial class GamTrainerBase : TrainerEstimatorBase + public abstract partial class GamTrainerBase : TrainerEstimatorBase where TTransformer: ISingleFeaturePredictionTransformer - where TArgs : GamTrainerBase.OptionsBase, new() + where TOptions : GamTrainerBase.OptionsBase, new() where TPredictor : class { public abstract class OptionsBase : LearnerInputBaseWithWeight @@ -110,7 +110,7 @@ public abstract class OptionsBase : LearnerInputBaseWithWeight private const string RegisterName = "GamTraining"; //Parameters of training - private protected readonly TArgs Args; + private protected readonly TOptions GamTrainerOptions; private readonly double _gainConfidenceInSquaredStandardDeviations; private readonly double _entropyCoefficient; @@ -156,44 +156,44 @@ private protected GamTrainerBase(IHostEnvironment env, int maxBins) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { - Args = new TArgs(); - Args.NumIterations = numIterations; - Args.LearningRates = learningRate; - Args.MaxBins = maxBins; + GamTrainerOptions = new TOptions(); + GamTrainerOptions.NumIterations = numIterations; + GamTrainerOptions.LearningRates = learningRate; + GamTrainerOptions.MaxBins = maxBins; - Args.LabelColumn = label.Name; - Args.FeatureColumn = featureColumn; + GamTrainerOptions.LabelColumn = label.Name; + GamTrainerOptions.FeatureColumn = featureColumn; if (weightColumn != null) - Args.WeightColumn = weightColumn; + GamTrainerOptions.WeightColumn = weightColumn; Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false, supportValid: 
true); - _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - Args.GainConfidenceLevel) * 0.5), 2); - _entropyCoefficient = Args.EntropyCoefficient * 1e-6; + _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - GamTrainerOptions.GainConfidenceLevel) * 0.5), 2); + _entropyCoefficient = GamTrainerOptions.EntropyCoefficient * 1e-6; InitializeThreads(); } - private protected GamTrainerBase(IHostEnvironment env, TArgs args, string name, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + private protected GamTrainerBase(IHostEnvironment env, TOptions options, string name, SchemaShape.Column label) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), + label, TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn)) { Contracts.CheckValue(env, nameof(env)); - Host.CheckValue(args, nameof(args)); + Host.CheckValue(options, nameof(options)); - Host.CheckParam(args.LearningRates > 0, nameof(args.LearningRates), "Must be positive."); - Host.CheckParam(args.NumThreads == null || args.NumThreads > 0, nameof(args.NumThreads), "Must be positive."); - Host.CheckParam(0 <= args.EntropyCoefficient && args.EntropyCoefficient <= 1, nameof(args.EntropyCoefficient), "Must be in [0, 1]."); - Host.CheckParam(0 <= args.GainConfidenceLevel && args.GainConfidenceLevel < 1, nameof(args.GainConfidenceLevel), "Must be in [0, 1)."); - Host.CheckParam(0 < args.MaxBins, nameof(args.MaxBins), "Must be posittive."); - Host.CheckParam(0 < args.NumIterations, nameof(args.NumIterations), "Must be positive."); - Host.CheckParam(0 < args.MinDocuments, nameof(args.MinDocuments), "Must be positive."); + Host.CheckParam(options.LearningRates > 0, nameof(options.LearningRates), "Must be positive."); + Host.CheckParam(options.NumThreads == null || options.NumThreads > 0, nameof(options.NumThreads), "Must be positive."); + Host.CheckParam(0 <= options.EntropyCoefficient && options.EntropyCoefficient <= 1, nameof(options.EntropyCoefficient), "Must be in [0, 1]."); + Host.CheckParam(0 <= options.GainConfidenceLevel && options.GainConfidenceLevel < 1, nameof(options.GainConfidenceLevel), "Must be in [0, 1)."); + Host.CheckParam(0 < options.MaxBins, nameof(options.MaxBins), "Must be positive."); + Host.CheckParam(0 < options.NumIterations, nameof(options.NumIterations), "Must be positive."); + Host.CheckParam(0 < options.MinDocuments, nameof(options.MinDocuments), "Must be positive."); - Args = args; + GamTrainerOptions = options; Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false, supportValid: true); - _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - Args.GainConfidenceLevel) * 0.5), 2); - _entropyCoefficient = Args.EntropyCoefficient * 1e-6; + _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - GamTrainerOptions.GainConfidenceLevel) * 0.5), 2); + _entropyCoefficient = GamTrainerOptions.EntropyCoefficient * 1e-6; InitializeThreads(); } @@ -234,8 +234,8 @@ private void ConvertData(RoleMappedData trainData, RoleMappedData validationData trainData.CheckOptFloatWeight(); CheckLabel(trainData); - var useTranspose = UseTranspose(Args.DiskTranspose, trainData); - var instanceConverter = new ExamplesToFastTreeBins(Host,
Args.MaxBins, useTranspose, !Args.FeatureFlocks, Args.MinDocuments, float.PositiveInfinity); + var useTranspose = UseTranspose(GamTrainerOptions.DiskTranspose, trainData); + var instanceConverter = new ExamplesToFastTreeBins(Host, GamTrainerOptions.MaxBins, useTranspose, !GamTrainerOptions.FeatureFlocks, GamTrainerOptions.MinDocuments, float.PositiveInfinity); ParallelTraining.InitEnvironment(); TrainSet = instanceConverter.FindBinsAndReturnDataset(trainData, PredictionKind, ParallelTraining, null, false); @@ -275,7 +275,7 @@ private void TrainCore(IChannel ch) private void TrainMainEffectsModel(IChannel ch) { Contracts.AssertValue(ch); - int iterations = Args.NumIterations; + int iterations = GamTrainerOptions.NumIterations; ch.Info("Starting to train ..."); @@ -341,7 +341,7 @@ private void TrainingIteration(int globalFeatureIndex, double[] gradient, double // Compute the split for the feature _histogram[flockIndex].FindBestSplitForFeature(_leafSplitHelper, _leafSplitCandidates, _leafSplitCandidates.Targets.Length, sumTargets, sumWeights, - globalFeatureIndex, flockIndex, subFeatureIndex, Args.MinDocuments, HasWeights, + globalFeatureIndex, flockIndex, subFeatureIndex, GamTrainerOptions.MinDocuments, HasWeights, _gainConfidenceInSquaredStandardDeviations, _entropyCoefficient, TrainSet.Flocks[flockIndex].Trust(subFeatureIndex), 0); @@ -404,8 +404,8 @@ private void UpdateScoresForSet(Dataset dataset, double[] scores, int iteration) private void CombineGraphs(IChannel ch) { // Prune backwards to the best iteration - int bestIteration = Args.NumIterations; - if (Args.EnablePruning && PruningTest != null) + int bestIteration = GamTrainerOptions.NumIterations; + if (GamTrainerOptions.EnablePruning && PruningTest != null) { ch.Info("Pruning"); var finalResult = PruningTest.ComputeTests().ToArray()[PruningLossIndex]; @@ -416,8 +416,8 @@ private void CombineGraphs(IChannel ch) bestIteration = PruningTest.BestIteration; bestLoss = PruningTest.BestResult.FinalValue; } - if (bestIteration != Args.NumIterations) - ch.Info($"Best Iteration ({lossFunctionName}): {bestIteration} @ {bestLoss:G6} (vs {Args.NumIterations} @ {finalResult.FinalValue:G6})."); + if (bestIteration != GamTrainerOptions.NumIterations) + ch.Info($"Best Iteration ({lossFunctionName}): {bestIteration} @ {bestLoss:G6} (vs {GamTrainerOptions.NumIterations} @ {finalResult.FinalValue:G6})."); else ch.Info("No pruning necessary. 
More iterations may be necessary."); } @@ -557,8 +557,8 @@ private void ConvertTreeToGraph(int globalFeatureIndex, int iteration) { SplitInfo splitinfo = _leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex]; _subGraph.Splits[globalFeatureIndex][iteration].SplitPoint = splitinfo.Threshold; - _subGraph.Splits[globalFeatureIndex][iteration].LteValue = Args.LearningRates * splitinfo.LteOutput; - _subGraph.Splits[globalFeatureIndex][iteration].GtValue = Args.LearningRates * splitinfo.GTOutput; + _subGraph.Splits[globalFeatureIndex][iteration].LteValue = GamTrainerOptions.LearningRates * splitinfo.LteOutput; + _subGraph.Splits[globalFeatureIndex][iteration].GtValue = GamTrainerOptions.LearningRates * splitinfo.GTOutput; } private void InitializeGamHistograms() @@ -573,7 +573,7 @@ private void Initialize(IChannel ch) using (Timer.Time(TimerEvent.InitializeTraining)) { InitializeGamHistograms(); - _subGraph = new SubGraph(TrainSet.NumFeatures, Args.NumIterations); + _subGraph = new SubGraph(TrainSet.NumFeatures, GamTrainerOptions.NumIterations); _leafSplitCandidates = new LeastSquaresRegressionTreeLearner.LeafSplitCandidates(TrainSet); _leafSplitHelper = new LeafSplitHelper(HasWeights); } @@ -583,7 +583,7 @@ private void InitializeThreads() { ParallelTraining = new SingleTrainer(); - int numThreads = Args.NumThreads ?? Environment.ProcessorCount; + int numThreads = GamTrainerOptions.NumThreads ?? Environment.ProcessorCount; if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor) using (var ch = Host.Start("GamTrainer")) { diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 04bb1a03d1..3b7659145d 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -4,8 +4,8 @@ namespace Microsoft.ML.Trainers.FastTree { - public abstract class RandomForestTrainerBase : FastTreeTrainerBase - where TArgs : FastForestOptionsBase, new() + public abstract class RandomForestTrainerBase : FastTreeTrainerBase + where TOptions : FastForestOptionsBase, new() where TModel : class where TTransformer: ISingleFeaturePredictionTransformer { @@ -14,8 +14,8 @@ public abstract class RandomForestTrainerBase : Fas /// /// Constructor invoked by the maml code-path. /// - private protected RandomForestTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label, bool quantileEnabled = false) - : base(env, args, label) + private protected RandomForestTrainerBase(IHostEnvironment env, TOptions options, SchemaShape.Column label, bool quantileEnabled = false) + : base(env, options, label) { _quantileEnabled = quantileEnabled; } @@ -72,14 +72,14 @@ private protected override TreeLearner ConstructTreeLearner(IChannel ch) internal abstract class RandomForestObjectiveFunction : ObjectiveFunctionBase { - protected RandomForestObjectiveFunction(Dataset trainData, TArgs args, double maxStepSize) + protected RandomForestObjectiveFunction(Dataset trainData, TOptions options, double maxStepSize) : base(trainData, 1, // No learning rate in random forests. 1, // No shrinkage in random forests. maxStepSize, 1, // No derivative sampling in random forests. false, // Improvements to quasi-newton step not relevant to RF. 
- args.RngSeed) + options.RngSeed) { } } diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 59eaf439d6..c8be0b057f 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -191,7 +191,7 @@ private protected override FastForestClassificationModelParameters TrainModelCor // calibrator, transform the scores using that. // REVIEW: Need a way to signal the outside world that we prefer simple sigmoid? - return new FastForestClassificationModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); + return new FastForestClassificationModelParameters(Host, TrainedEnsemble, FeatureCount, InnerOptions); } private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index bb52c0bad8..c38718a148 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -313,7 +313,7 @@ private protected override FastForestRegressionModelParameters TrainModelCore(Tr ConvertData(trainData); TrainCore(ch); } - return new FastForestRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs, FastTreeTrainerOptions.QuantileSampleCount); + return new FastForestRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerOptions, FastTreeTrainerOptions.QuantileSampleCount); } private protected override void PrepareLabels(IChannel ch) diff --git a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs index 8cf005f339..a0fc2ba344 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs @@ -42,14 +42,14 @@ public interface IBoosterParameter /// public sealed class Options : LearnerInputBaseWithGroupId { - public abstract class BoosterParameter : IBoosterParameter - where TArgs : class, new() + public abstract class BoosterParameter : IBoosterParameter + where TOptions : class, new() { - protected TArgs Args { get; } + protected TOptions BoosterParameterOptions { get; } - protected BoosterParameter(TArgs args) + protected BoosterParameter(TOptions options) { - Args = args; + BoosterParameterOptions = options; } /// @@ -57,7 +57,7 @@ protected BoosterParameter(TArgs args) /// internal virtual void UpdateParameters(Dictionary res) { - FieldInfo[] fields = Args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + FieldInfo[] fields = BoosterParameterOptions.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); foreach (var field in fields) { var attribute = field.GetCustomAttribute(false); @@ -65,7 +65,7 @@ internal virtual void UpdateParameters(Dictionary res) if (attribute == null) continue; - res[GetArgName(field.Name)] = field.GetValue(Args); + res[GetArgName(field.Name)] = field.GetValue(BoosterParameterOptions); } } @@ -174,11 +174,11 @@ public class Options : ISupportBoosterParameterFactory internal TreeBooster(Options options) : base(options) { - Contracts.CheckUserArg(Args.MinSplitGain >= 0, nameof(Args.MinSplitGain), "must be >= 0."); - Contracts.CheckUserArg(Args.MinChildWeight >= 0, nameof(Args.MinChildWeight), "must be >= 0."); - Contracts.CheckUserArg(Args.Subsample > 0 && Args.Subsample <= 1, nameof(Args.Subsample), "must be in (0,1]."); - 
Contracts.CheckUserArg(Args.FeatureFraction > 0 && Args.FeatureFraction <= 1, nameof(Args.FeatureFraction), "must be in (0,1]."); - Contracts.CheckUserArg(Args.ScalePosWeight > 0 && Args.ScalePosWeight <= 1, nameof(Args.ScalePosWeight), "must be in (0,1]."); + Contracts.CheckUserArg(BoosterParameterOptions.MinSplitGain >= 0, nameof(BoosterParameterOptions.MinSplitGain), "must be >= 0."); + Contracts.CheckUserArg(BoosterParameterOptions.MinChildWeight >= 0, nameof(BoosterParameterOptions.MinChildWeight), "must be >= 0."); + Contracts.CheckUserArg(BoosterParameterOptions.Subsample > 0 && BoosterParameterOptions.Subsample <= 1, nameof(BoosterParameterOptions.Subsample), "must be in (0,1]."); + Contracts.CheckUserArg(BoosterParameterOptions.FeatureFraction > 0 && BoosterParameterOptions.FeatureFraction <= 1, nameof(BoosterParameterOptions.FeatureFraction), "must be in (0,1]."); + Contracts.CheckUserArg(BoosterParameterOptions.ScalePosWeight > 0 && BoosterParameterOptions.ScalePosWeight <= 1, nameof(BoosterParameterOptions.ScalePosWeight), "must be in (0,1]."); } internal override void UpdateParameters(Dictionary res) @@ -220,9 +220,9 @@ public sealed class Options : TreeBooster.Options internal DartBooster(Options options) : base(options) { - Contracts.CheckUserArg(Args.DropRate > 0 && Args.DropRate < 1, nameof(Args.DropRate), "must be in (0,1)."); - Contracts.CheckUserArg(Args.MaxDrop > 0, nameof(Args.MaxDrop), "must be > 0."); - Contracts.CheckUserArg(Args.SkipDrop >= 0 && Args.SkipDrop < 1, nameof(Args.SkipDrop), "must be in [0,1)."); + Contracts.CheckUserArg(BoosterParameterOptions.DropRate > 0 && BoosterParameterOptions.DropRate < 1, nameof(BoosterParameterOptions.DropRate), "must be in (0,1)."); + Contracts.CheckUserArg(BoosterParameterOptions.MaxDrop > 0, nameof(BoosterParameterOptions.MaxDrop), "must be > 0."); + Contracts.CheckUserArg(BoosterParameterOptions.SkipDrop >= 0 && BoosterParameterOptions.SkipDrop < 1, nameof(BoosterParameterOptions.SkipDrop), "must be in [0,1)."); } internal override void UpdateParameters(Dictionary res) @@ -257,9 +257,9 @@ public sealed class Options : TreeBooster.Options internal GossBooster(Options options) : base(options) { - Contracts.CheckUserArg(Args.TopRate > 0 && Args.TopRate < 1, nameof(Args.TopRate), "must be in (0,1)."); - Contracts.CheckUserArg(Args.OtherRate >= 0 && Args.OtherRate < 1, nameof(Args.TopRate), "must be in [0,1)."); - Contracts.Check(Args.TopRate + Args.OtherRate <= 1, "Sum of topRate and otherRate cannot be larger than 1."); + Contracts.CheckUserArg(BoosterParameterOptions.TopRate > 0 && BoosterParameterOptions.TopRate < 1, nameof(BoosterParameterOptions.TopRate), "must be in (0,1)."); + Contracts.CheckUserArg(BoosterParameterOptions.OtherRate >= 0 && BoosterParameterOptions.OtherRate < 1, nameof(BoosterParameterOptions.OtherRate), "must be in [0,1)."); + Contracts.Check(BoosterParameterOptions.TopRate + BoosterParameterOptions.OtherRate <= 1, "Sum of topRate and otherRate cannot be larger than 1."); } internal override void UpdateParameters(Dictionary res) diff --git a/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs b/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs index eee3c553fe..3120b24c01 100644 --- a/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs @@ -200,13 +200,13 @@ private void Run(IChannel ch) if (_model == null) { - if (string.IsNullOrEmpty(Args.InputModelFile)) + if (string.IsNullOrEmpty(ImplOptions.InputModelFile)) { loader = CreateLoader(); rawPred =
null; trainSchema = null; - Host.CheckUserArg(Args.LoadPredictor != true, nameof(Args.LoadPredictor), - "Cannot be set to true unless " + nameof(Args.InputModelFile) + " is also specifified."); + Host.CheckUserArg(ImplOptions.LoadPredictor != true, nameof(ImplOptions.LoadPredictor), + "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified."); } else LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader); @@ -220,7 +220,7 @@ private void Run(IChannel ch) var assembly = System.Reflection.Assembly.GetExecutingAssembly(); var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location); var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion, - ModelVersion, _domain, Args.OnnxVersion); + ModelVersion, _domain, ImplOptions.OnnxVersion); // Get the transform chain. IDataView source; @@ -281,13 +281,13 @@ private void Run(IChannel ch) } } - if (!string.IsNullOrWhiteSpace(Args.OutputModelFile)) + if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile)) { Contracts.Assert(loader != null); ch.Trace("Saving the data pipe"); // Should probably include "end"? - SaveLoader(loader, Args.OutputModelFile); + SaveLoader(loader, ImplOptions.OutputModelFile); } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index bfe29b5e0c..7d5b0d50b0 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -15,10 +15,10 @@ namespace Microsoft.ML.Trainers { - public abstract class LbfgsTrainerBase<TArgs, TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel> + public abstract class LbfgsTrainerBase<TOptions, TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel> where TTransformer : ISingleFeaturePredictionTransformer<TModel> where TModel : class - where TArgs : LbfgsTrainerBase<TArgs, TTransformer, TModel>.OptionsBase, new () + where TOptions : LbfgsTrainerBase<TOptions, TTransformer, TModel>.OptionsBase, new () { public abstract class OptionsBase : LearnerInputBaseWithWeight { @@ -103,7 +103,7 @@ internal static class Defaults } } - private const string RegisterName = nameof(LbfgsTrainerBase<TArgs, TTransformer, TModel>); + private const string RegisterName = nameof(LbfgsTrainerBase<TOptions, TTransformer, TModel>); private protected int NumFeatures; private protected VBuffer<float> CurrentWeights; @@ -113,7 +113,7 @@ internal static class Defaults private IPredictor _srcPredictor; - private protected readonly TArgs Args; + private protected readonly TOptions LbfgsTrainerOptions; private protected readonly float L2Weight; private protected readonly float L1Weight; private protected readonly float OptTol; @@ -160,7 +160,7 @@ internal LbfgsTrainerBase(IHostEnvironment env, float optimizationTolerance, int memorySize, bool enforceNoNegativity) - : this(env, new TArgs + : this(env, new TOptions { FeatureColumn = featureColumn, LabelColumn = labelColumn.Name, @@ -176,48 +176,48 @@ internal LbfgsTrainerBase(IHostEnvironment env, } internal LbfgsTrainerBase(IHostEnvironment env, - TArgs args, + TOptions options, SchemaShape.Column labelColumn, - Action<TArgs> advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + Action<TOptions> advancedSettings = null) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), + labelColumn,
TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn)) { - Host.CheckValue(args, nameof(args)); - Args = args; + Host.CheckValue(options, nameof(options)); + LbfgsTrainerOptions = options; // Apply the advanced args, if the user supplied any. - advancedSettings?.Invoke(args); - - args.FeatureColumn = FeatureColumn.Name; - args.LabelColumn = LabelColumn.Name; - args.WeightColumn = WeightColumn.Name; - Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null, - nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); - Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative"); - Host.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative"); - Host.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive"); - Host.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive"); - Host.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive"); - Host.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative"); - Host.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative"); - - Host.CheckParam(!(Args.L2Weight < 0), nameof(Args.L2Weight), "Must be non-negative, if provided."); - Host.CheckParam(!(Args.L1Weight < 0), nameof(Args.L1Weight), "Must be non-negative, if provided"); - Host.CheckParam(!(Args.OptTol <= 0), nameof(Args.OptTol), "Must be positive, if provided."); - Host.CheckParam(!(Args.MemorySize <= 0), nameof(Args.MemorySize), "Must be positive, if provided."); - - L2Weight = Args.L2Weight; - L1Weight = Args.L1Weight; - OptTol = Args.OptTol; - MemorySize =Args.MemorySize; - MaxIterations = Args.MaxIterations; - SgdInitializationTolerance = Args.SgdInitializationTolerance; - Quiet = Args.Quiet; - InitWtsDiameter = Args.InitWtsDiameter; - UseThreads = Args.UseThreads; - NumThreads = Args.NumThreads; - DenseOptimizer = Args.DenseOptimizer; - EnforceNonNegativity = Args.EnforceNonNegativity; + advancedSettings?.Invoke(options); + + options.FeatureColumn = FeatureColumn.Name; + options.LabelColumn = LabelColumn.Name; + options.WeightColumn = WeightColumn.Name; + Host.CheckUserArg(!LbfgsTrainerOptions.UseThreads || LbfgsTrainerOptions.NumThreads > 0 || LbfgsTrainerOptions.NumThreads == null, + nameof(LbfgsTrainerOptions.NumThreads), "numThreads must be positive (or empty for default)"); + Host.CheckUserArg(LbfgsTrainerOptions.L2Weight >= 0, nameof(LbfgsTrainerOptions.L2Weight), "Must be non-negative"); + Host.CheckUserArg(LbfgsTrainerOptions.L1Weight >= 0, nameof(LbfgsTrainerOptions.L1Weight), "Must be non-negative"); + Host.CheckUserArg(LbfgsTrainerOptions.OptTol > 0, nameof(LbfgsTrainerOptions.OptTol), "Must be positive"); + Host.CheckUserArg(LbfgsTrainerOptions.MemorySize > 0, nameof(LbfgsTrainerOptions.MemorySize), "Must be positive"); + Host.CheckUserArg(LbfgsTrainerOptions.MaxIterations > 0, nameof(LbfgsTrainerOptions.MaxIterations), "Must be positive"); + Host.CheckUserArg(LbfgsTrainerOptions.SgdInitializationTolerance >= 0, nameof(LbfgsTrainerOptions.SgdInitializationTolerance), "Must be non-negative"); + Host.CheckUserArg(LbfgsTrainerOptions.NumThreads == null || LbfgsTrainerOptions.NumThreads.Value >= 0, nameof(LbfgsTrainerOptions.NumThreads), "Must be non-negative"); + + Host.CheckParam(!(LbfgsTrainerOptions.L2Weight < 0), nameof(LbfgsTrainerOptions.L2Weight), "Must be non-negative, if provided."); + 
Host.CheckParam(!(LbfgsTrainerOptions.L1Weight < 0), nameof(LbfgsTrainerOptions.L1Weight), "Must be non-negative, if provided"); + Host.CheckParam(!(LbfgsTrainerOptions.OptTol <= 0), nameof(LbfgsTrainerOptions.OptTol), "Must be positive, if provided."); + Host.CheckParam(!(LbfgsTrainerOptions.MemorySize <= 0), nameof(LbfgsTrainerOptions.MemorySize), "Must be positive, if provided."); + + L2Weight = LbfgsTrainerOptions.L2Weight; + L1Weight = LbfgsTrainerOptions.L1Weight; + OptTol = LbfgsTrainerOptions.OptTol; + MemorySize = LbfgsTrainerOptions.MemorySize; + MaxIterations = LbfgsTrainerOptions.MaxIterations; + SgdInitializationTolerance = LbfgsTrainerOptions.SgdInitializationTolerance; + Quiet = LbfgsTrainerOptions.Quiet; + InitWtsDiameter = LbfgsTrainerOptions.InitWtsDiameter; + UseThreads = LbfgsTrainerOptions.UseThreads; + NumThreads = LbfgsTrainerOptions.NumThreads; + DenseOptimizer = LbfgsTrainerOptions.DenseOptimizer; + EnforceNonNegativity = LbfgsTrainerOptions.EnforceNonNegativity; if (EnforceNonNegativity && ShowTrainingStats) { @@ -232,7 +232,7 @@ internal LbfgsTrainerBase(IHostEnvironment env, _srcPredictor = default; } - private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn, + private static TOptions ArgsInit(string featureColumn, SchemaShape.Column labelColumn, string weightColumn, float l1Weight, float l2Weight, @@ -240,7 +240,7 @@ private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColu int memorySize, bool enforceNoNegativity) { - var args = new TArgs + var args = new TOptions { FeatureColumn = featureColumn, LabelColumn = labelColumn.Name, diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs index 2a268e5c23..960e227983 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs @@ -92,7 +92,7 @@ internal LogisticRegression(IHostEnvironment env, Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); _posWeight = 0; - ShowTrainingStats = Args.ShowTrainingStats; + ShowTrainingStats = LbfgsTrainerOptions.ShowTrainingStats; } /// @@ -102,7 +102,7 @@ internal LogisticRegression(IHostEnvironment env, Options options) : base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) { _posWeight = 0; - ShowTrainingStats = Args.ShowTrainingStats; + ShowTrainingStats = LbfgsTrainerOptions.ShowTrainingStats; } private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification; @@ -355,11 +355,11 @@ private protected override void ComputeTrainingStatistics(IChannel ch, FloatLabe } } - if (Args.StdComputer == null) + if (LbfgsTrainerOptions.StdComputer == null) _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance); else { - var std = Args.StdComputer.ComputeStd(hessian, weightIndices, numParams, CurrentWeights.Length, ch, L2Weight); + var std = LbfgsTrainerOptions.StdComputer.ComputeStd(hessian, weightIndices, numParams, CurrentWeights.Length, ch, L2Weight); _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance, std); } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 59043db03a..815b7b30ba 100644 ---
a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -95,7 +95,7 @@ internal MulticlassLogisticRegression(IHostEnvironment env, Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - ShowTrainingStats = Args.ShowTrainingStats; + ShowTrainingStats = LbfgsTrainerOptions.ShowTrainingStats; } /// @@ -104,7 +104,7 @@ internal MulticlassLogisticRegression(IHostEnvironment env, internal MulticlassLogisticRegression(IHostEnvironment env, Options options) : base(env, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumn)) { - ShowTrainingStats = Args.ShowTrainingStats; + ShowTrainingStats = LbfgsTrainerOptions.ShowTrainingStats; } private protected override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index c49d53cdbf..ccef0f3e20 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -21,7 +21,7 @@ public abstract class MetaMulticlassTrainer : ITrainerEsti where TTransformer : ISingleFeaturePredictionTransformer where TModel : class { - public abstract class ArgumentsBase + public abstract class OptionsBase { [Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 4, SignatureType = typeof(SignatureBinaryClassifierTrainer))] [TGUI(Label = "Predictor Type", Description = "Type of underlying binary predictor")] @@ -42,7 +42,7 @@ public abstract class ArgumentsBase /// public readonly SchemaShape.Column LabelColumn; - private protected readonly ArgumentsBase Args; + private protected readonly OptionsBase Args; private protected readonly IHost Host; private protected readonly ICalibratorTrainer Calibrator; private protected readonly TScalarTrainer Trainer; @@ -55,20 +55,20 @@ public abstract class ArgumentsBase public TrainerInfo Info { get; } /// - /// Initializes the from the class. + /// Initializes the from the class. /// /// The private instance of the . - /// The legacy arguments class. + /// The legacy arguments class. /// The component name. /// The label column for the metalinear trainer and the binary trainer. /// The binary estimator. /// The calibrator. If a calibrator is not explicitly provided, it will default to - internal MetaMulticlassTrainer(IHostEnvironment env, ArgumentsBase args, string name, string labelColumn = null, + internal MetaMulticlassTrainer(IHostEnvironment env, OptionsBase options, string name, string labelColumn = null, TScalarTrainer singleEstimator = null, ICalibratorTrainer calibrator = null) { Host = Contracts.CheckRef(env, nameof(env)).Register(name); - Host.CheckValue(args, nameof(args)); - Args = args; + Host.CheckValue(options, nameof(options)); + Args = options; if (labelColumn != null) LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true); @@ -76,8 +76,8 @@ internal MetaMulticlassTrainer(IHostEnvironment env, ArgumentsBase args, string Trainer = singleEstimator ?? CreateTrainer(); Calibrator = calibrator ?? 
new PlattCalibratorTrainer(env); - if (args.Calibrator != null) - Calibrator = args.Calibrator.CreateComponent(Host); + if (options.Calibrator != null) + Calibrator = options.Calibrator.CreateComponent(Host); // Regarding caching, no matter what the internal predictor, we're performing many passes // simply by virtue of this being a meta-trainer, so we will still cache. diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index a4a3d50316..6f25976ee2 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -52,7 +52,7 @@ public sealed class Ova : MetaMulticlassTrainer /// Options passed to OVA. /// - internal sealed class Options : ArgumentsBase + internal sealed class Options : OptionsBase { /// /// Whether to use probabilities (vs. raw outputs) to identify top-score category. diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index 6587678e99..30fc0e36aa 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -64,7 +64,7 @@ public sealed class Pkpd : MetaMulticlassTrainer /// Options passed to PKPD. /// - internal sealed class Options : ArgumentsBase + internal sealed class Options : OptionsBase { } /// diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs index b1678fd7b8..31507cdd15 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs @@ -146,10 +146,10 @@ private protected virtual int ComputeNumThreads(FloatLabelCursor.Factory cursorF } } - public abstract class SdcaTrainerBase : StochasticTrainerBase + public abstract class SdcaTrainerBase : StochasticTrainerBase where TTransformer : ISingleFeaturePredictionTransformer where TModel : class - where TArgs : SdcaTrainerBase.OptionsBase, new() + where TOptions : SdcaTrainerBase.OptionsBase, new() { // REVIEW: Making it even faster and more accurate: // 1. Train with not-too-many threads. nt = 2 or 4 seems to be good enough. Didn't seem additional benefit over more threads. @@ -239,16 +239,16 @@ private protected enum MetricKind // substantial additional benefits in terms of accuracy. 
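One pattern worth calling out before the SDCA constants and constructor below: explicit constructor arguments override whatever the options object carried, and any hyperparameter still null afterwards is tuned from the data in TrainCore. A minimal sketch of that precedence rule, with hypothetical stand-in types:

// Illustrative only; mirrors the "l2Const ?? options.L2Const" pattern below.
sealed class SdcaLikeOptions
{
    public float? L2Const;
    public float? L1Threshold;
    public int? MaxIterations;
}

static class SdcaOptionsMerge
{
    public static SdcaLikeOptions Apply(SdcaLikeOptions options, float? l2Const, float? l1Threshold, int? maxIterations)
    {
        options.L2Const = l2Const ?? options.L2Const;
        options.L1Threshold = l1Threshold ?? options.L1Threshold;
        options.MaxIterations = maxIterations ?? options.MaxIterations;
        return options; // anything still null is auto-tuned later
    }
}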
private const long MaxDualTableSize = 1L << 50; private const float L2LowerBound = 1e-09f; - private protected readonly TArgs Args; + private protected readonly TOptions SdcaTrainerOptions; private protected ISupportSdcaLoss Loss; - private protected override bool ShuffleData => Args.Shuffle; + private protected override bool ShuffleData => SdcaTrainerOptions.Shuffle; - private const string RegisterName = nameof(SdcaTrainerBase); + private const string RegisterName = nameof(SdcaTrainerBase); - private static TArgs ArgsInit(string featureColumnName, SchemaShape.Column labelColumn) + private static TOptions ArgsInit(string featureColumnName, SchemaShape.Column labelColumn) { - var args = new TArgs(); + var args = new TOptions(); args.FeatureColumn = featureColumnName; args.LabelColumn = labelColumn.Name; @@ -262,15 +262,15 @@ internal SdcaTrainerBase(IHostEnvironment env, string featureColumnName, SchemaS { } - internal SdcaTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label, SchemaShape.Column weight = default, + internal SdcaTrainerBase(IHostEnvironment env, TOptions options, SchemaShape.Column label, SchemaShape.Column weight = default, float? l2Const = null, float? l1Threshold = null, int? maxIterations = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, weight) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), label, weight) { - Args = args; - Args.L2Const = l2Const ?? args.L2Const; - Args.L1Threshold = l1Threshold ?? args.L1Threshold; - Args.MaxIterations = maxIterations ?? args.MaxIterations; - Args.Check(env); + SdcaTrainerOptions = options; + SdcaTrainerOptions.L2Const = l2Const ?? options.L2Const; + SdcaTrainerOptions.L1Threshold = l1Threshold ?? options.L1Threshold; + SdcaTrainerOptions.MaxIterations = maxIterations ?? options.MaxIterations; + SdcaTrainerOptions.Check(env); } private protected float WDot(in VBuffer features, in VBuffer weights, float bias) @@ -292,9 +292,9 @@ private protected sealed override TModel TrainCore(IChannel ch, RoleMappedData d var cursorFactory = new FloatLabelCursor.Factory(data, cursorOpt); int numThreads; - if (Args.NumThreads.HasValue) + if (SdcaTrainerOptions.NumThreads.HasValue) { - numThreads = Args.NumThreads.Value; + numThreads = SdcaTrainerOptions.NumThreads.Value; Host.CheckUserArg(numThreads > 0, nameof(OptionsBase.NumThreads), "The number of threads must be either null or a positive integer."); if (0 < Host.ConcurrencyFactor && Host.ConcurrencyFactor < numThreads) { @@ -313,8 +313,8 @@ private protected sealed override TModel TrainCore(IChannel ch, RoleMappedData d ch.Info("Using {0} threads to train.", numThreads); int checkFrequency = 0; - if (Args.CheckFrequency.HasValue) - checkFrequency = Args.CheckFrequency.Value; + if (SdcaTrainerOptions.CheckFrequency.HasValue) + checkFrequency = SdcaTrainerOptions.CheckFrequency.Value; else { checkFrequency = numThreads; @@ -414,19 +414,19 @@ private protected sealed override TModel TrainCore(IChannel ch, RoleMappedData d ch.Check(count > 0, "Training set has 0 instances, aborting training."); // Tune the default hyperparameters based on dataset size. 
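The tuning cascade renamed in the next hunk fills unset hyperparameters in a fixed order: iterations first, then L2 (whose default is computed from the tuned iteration count), then L1 (driven by feature count). The sketch below shows only that ordering; the heuristic formulas are placeholders, not ML.NET's actual ones.

using System;

// Placeholder heuristics; only the null-then-tune ordering reflects the diff.
static class SdcaDefaultsSketch
{
    public static (int maxIterations, float l2Const, float l1Threshold) Tune(
        int? maxIterations, float? l2Const, float? l1Threshold,
        long rowCount, int numThreads, int numFeatures)
    {
        long capped = Math.Min(rowCount / Math.Max(1, numThreads), 1_000_000L);
        int iters = maxIterations ?? (int)Math.Max(1L, capped);                          // placeholder
        float l2 = l2Const ?? Math.Max(1e-9f, 1f / (iters * (float)Math.Max(1L, rowCount))); // placeholder; 1e-9 echoes L2LowerBound
        float l1 = l1Threshold ?? (numFeatures > 10_000 ? 0.01f : 0f);                   // placeholder
        return (iters, l2, l1);
    }
}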
- if (Args.MaxIterations == null) - Args.MaxIterations = TuneDefaultMaxIterations(ch, count, numThreads); + if (SdcaTrainerOptions.MaxIterations == null) + SdcaTrainerOptions.MaxIterations = TuneDefaultMaxIterations(ch, count, numThreads); - Contracts.Assert(Args.MaxIterations.HasValue); - if (Args.L2Const == null) - Args.L2Const = TuneDefaultL2(ch, Args.MaxIterations.Value, count, numThreads); + Contracts.Assert(SdcaTrainerOptions.MaxIterations.HasValue); + if (SdcaTrainerOptions.L2Const == null) + SdcaTrainerOptions.L2Const = TuneDefaultL2(ch, SdcaTrainerOptions.MaxIterations.Value, count, numThreads); - Contracts.Assert(Args.L2Const.HasValue); - if (Args.L1Threshold == null) - Args.L1Threshold = TuneDefaultL1(ch, numFeatures); + Contracts.Assert(SdcaTrainerOptions.L2Const.HasValue); + if (SdcaTrainerOptions.L1Threshold == null) + SdcaTrainerOptions.L1Threshold = TuneDefaultL1(ch, numFeatures); - ch.Assert(Args.L1Threshold.HasValue); - var l1Threshold = Args.L1Threshold.Value; + ch.Assert(SdcaTrainerOptions.L1Threshold.HasValue); + var l1Threshold = SdcaTrainerOptions.L1Threshold.Value; var l1ThresholdZero = l1Threshold == 0; var weights = new VBuffer[weightSetCount]; var bestWeights = new VBuffer[weightSetCount]; @@ -455,8 +455,8 @@ private protected sealed override TModel TrainCore(IChannel ch, RoleMappedData d int bestIter = 0; var bestPrimalLoss = double.PositiveInfinity; - ch.Assert(Args.L2Const.HasValue); - var l2Const = Args.L2Const.Value; + ch.Assert(SdcaTrainerOptions.L2Const.HasValue); + var l2Const = SdcaTrainerOptions.L2Const.Value; float lambdaNInv = 1 / (l2Const * count); DualsTableBase duals = null; @@ -519,8 +519,8 @@ private protected sealed override TModel TrainCore(IChannel ch, RoleMappedData d ch.AssertValue(metricNames); ch.AssertValue(metrics); ch.Assert(metricNames.Length == metrics.Length); - ch.Assert(Args.MaxIterations.HasValue); - var maxIterations = Args.MaxIterations.Value; + ch.Assert(SdcaTrainerOptions.MaxIterations.HasValue); + var maxIterations = SdcaTrainerOptions.MaxIterations.Value; var rands = new Random[maxIterations]; for (int i = 0; i < maxIterations; i++) @@ -544,7 +544,7 @@ private protected sealed override TModel TrainCore(IChannel ch, RoleMappedData d int idx = (int)longIdx; var features = cursor.Features; var normSquared = VectorUtils.NormSquared(features); - if (Args.BiasLearningRate == 0) + if (SdcaTrainerOptions.BiasLearningRate == 0) normSquared += 1; if (featureNormSquared != null) @@ -755,17 +755,17 @@ private protected virtual void TrainWithoutLock(IProgressChannelProvider progres VBuffer[] weights, float[] biasUnreg, VBuffer[] l1IntermediateWeights, float[] l1IntermediateBias, float[] featureNormSquared) { Contracts.AssertValueOrNull(progress); - Contracts.Assert(Args.L1Threshold.HasValue); + Contracts.Assert(SdcaTrainerOptions.L1Threshold.HasValue); Contracts.AssertValueOrNull(idToIdx); Contracts.AssertValueOrNull(invariants); Contracts.AssertValueOrNull(featureNormSquared); int maxUpdateTrials = 2 * numThreads; - var l1Threshold = Args.L1Threshold.Value; + var l1Threshold = SdcaTrainerOptions.L1Threshold.Value; bool l1ThresholdZero = l1Threshold == 0; - var lr = Args.BiasLearningRate * Args.L2Const.Value; + var lr = SdcaTrainerOptions.BiasLearningRate * SdcaTrainerOptions.L2Const.Value; var pch = progress != null ? progress.StartProgressChannel("Dual update") : null; using (pch) - using (var cursor = Args.Shuffle ? cursorFactory.Create(rand) : cursorFactory.Create()) + using (var cursor = SdcaTrainerOptions.Shuffle ? 
cursorFactory.Create(rand) : cursorFactory.Create()) { long rowCount = 0; if (pch != null) @@ -784,7 +784,7 @@ private protected virtual void TrainWithoutLock(IProgressChannelProvider progres { Contracts.Assert(featureNormSquared == null); var featuresNormSquared = VectorUtils.NormSquared(features); - if (Args.BiasLearningRate == 0) + if (SdcaTrainerOptions.BiasLearningRate == 0) featuresNormSquared += 1; invariant = Loss.ComputeDualUpdateInvariant(featuresNormSquared * lambdaNInv * GetInstanceWeight(cursor)); @@ -830,7 +830,7 @@ private protected virtual void TrainWithoutLock(IProgressChannelProvider progres //Thresholding: if |v[j]| < threshold, turn off weights[j] //If not, shrink: w[j] = v[i] - sign(v[j]) * threshold l1IntermediateBias[0] += primalUpdate; - if (Args.BiasLearningRate == 0) + if (SdcaTrainerOptions.BiasLearningRate == 0) { biasReg[0] = Math.Abs(l1IntermediateBias[0]) - l1Threshold > 0.0 ? l1IntermediateBias[0] - Math.Sign(l1IntermediateBias[0]) * l1Threshold @@ -950,10 +950,10 @@ private protected virtual bool CheckConvergence( Host.Assert(idToIdx == null || row == duals.Length); } - Contracts.Assert(Args.L2Const.HasValue); - Contracts.Assert(Args.L1Threshold.HasValue); - Double l2Const = Args.L2Const.Value; - Double l1Threshold = Args.L1Threshold.Value; + Contracts.Assert(SdcaTrainerOptions.L2Const.HasValue); + Contracts.Assert(SdcaTrainerOptions.L1Threshold.HasValue); + Double l2Const = SdcaTrainerOptions.L2Const.Value; + Double l1Threshold = SdcaTrainerOptions.L1Threshold.Value; Double l1Regularizer = l1Threshold * l2Const * (VectorUtils.L1Norm(in weights[0]) + Math.Abs(biasReg[0])); var l2Regularizer = l2Const * (VectorUtils.NormSquared(weights[0]) + biasReg[0] * biasReg[0]) * 0.5; var newLoss = lossSum.Sum / count + l2Regularizer + l1Regularizer; @@ -965,9 +965,9 @@ private protected virtual bool CheckConvergence( var dualityGap = metrics[(int)MetricKind.DualityGap] = newLoss - newDualLoss; metrics[(int)MetricKind.BiasUnreg] = biasUnreg[0]; metrics[(int)MetricKind.BiasReg] = biasReg[0]; - metrics[(int)MetricKind.L1Sparsity] = Args.L1Threshold == 0 ? 1 : (Double)firstWeights.GetValues().Count(w => w != 0) / weights.Length; + metrics[(int)MetricKind.L1Sparsity] = SdcaTrainerOptions.L1Threshold == 0 ? 1 : (Double)firstWeights.GetValues().Count(w => w != 0) / weights.Length; - bool converged = dualityGap / newLoss < Args.ConvergenceTolerance; + bool converged = dualityGap / newLoss < SdcaTrainerOptions.ConvergenceTolerance; if (metrics[(int)MetricKind.Loss] < bestPrimalLoss) { @@ -1403,13 +1403,13 @@ public void Add(Double summand) /// where is and is . 
/// public abstract class SdcaBinaryTrainerBase<TModelParameters> : - SdcaTrainerBase<SdcaBinaryTrainerBase<TModelParameters>.BinaryArgumentBase, BinaryPredictionTransformer<TModelParameters>, TModelParameters> + SdcaTrainerBase<SdcaBinaryTrainerBase<TModelParameters>.BinaryOptionsBase, BinaryPredictionTransformer<TModelParameters>, TModelParameters> where TModelParameters : class { private readonly ISupportSdcaClassificationLoss _loss; private readonly float _positiveInstanceWeight; - private protected override bool ShuffleData => Args.Shuffle; + private protected override bool ShuffleData => SdcaTrainerOptions.Shuffle; private readonly SchemaShape.Column[] _outputColumns; @@ -1419,7 +1419,7 @@ public abstract class SdcaBinaryTrainerBase<TModelParameters> : public override TrainerInfo Info { get; } - public class BinaryArgumentBase : OptionsBase + public class BinaryOptionsBase : OptionsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")] public float PositiveInstanceWeight = 1; @@ -1458,17 +1458,17 @@ private protected SdcaBinaryTrainerBase(IHostEnvironment env, _loss = loss ?? new LogLossFactory().CreateComponent(env); Loss = _loss; Info = new TrainerInfo(calibration: false); - _positiveInstanceWeight = Args.PositiveInstanceWeight; + _positiveInstanceWeight = SdcaTrainerOptions.PositiveInstanceWeight; _outputColumns = ComputeSdcaBinaryClassifierSchemaShape(); } - private protected SdcaBinaryTrainerBase(IHostEnvironment env, BinaryArgumentBase options, ISupportSdcaClassificationLoss loss = null, bool doCalibration = false) + private protected SdcaBinaryTrainerBase(IHostEnvironment env, BinaryOptionsBase options, ISupportSdcaClassificationLoss loss = null, bool doCalibration = false) : base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) { _loss = loss ?? new LogLossFactory().CreateComponent(env); Loss = _loss; Info = new TrainerInfo(calibration: doCalibration); - _positiveInstanceWeight = Args.PositiveInstanceWeight; + _positiveInstanceWeight = SdcaTrainerOptions.PositiveInstanceWeight; _outputColumns = ComputeSdcaBinaryClassifierSchemaShape(); } @@ -1523,7 +1523,7 @@ public sealed class SdcaBinaryTrainer : /// /// Configuration for training logistic regression using SDCA. /// - public sealed class Options : BinaryArgumentBase + public sealed class Options : BinaryOptionsBase { } @@ -1582,7 +1582,7 @@ public sealed class SdcaNonCalibratedBinaryTrainer : SdcaBinaryTrainerBase /// General configuration for training a linear model using SDCA. /// - public sealed class Options : BinaryArgumentBase + public sealed class Options : BinaryOptionsBase { [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory(); @@ -1646,7 +1646,7 @@ internal sealed class LegacySdcaBinaryTrainer : SdcaBinaryTrainerBase /// Legacy configuration for SDCA in the legacy framework.
/// - public sealed class Options : BinaryArgumentBase + public sealed class Options : BinaryOptionsBase { [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory(); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index c278780321..2ca0daeeb9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -69,7 +69,7 @@ internal SdcaMultiClassTrainer(IHostEnvironment env, { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - _loss = loss ?? Args.LossFunction.CreateComponent(env); + _loss = loss ?? SdcaTrainerOptions.LossFunction.CreateComponent(env); Loss = _loss; } @@ -122,7 +122,7 @@ private protected override void TrainWithoutLock(IProgressChannelProvider progre VBuffer[] weights, Float[] biasUnreg, VBuffer[] l1IntermediateWeights, Float[] l1IntermediateBias, Float[] featureNormSquared) { Contracts.AssertValueOrNull(progress); - Contracts.Assert(Args.L1Threshold.HasValue); + Contracts.Assert(SdcaTrainerOptions.L1Threshold.HasValue); Contracts.AssertValueOrNull(idToIdx); Contracts.AssertValueOrNull(invariants); Contracts.AssertValueOrNull(featureNormSquared); @@ -131,13 +131,13 @@ private protected override void TrainWithoutLock(IProgressChannelProvider progre Contracts.Assert(Utils.Size(biasUnreg) == numClasses); int maxUpdateTrials = 2 * numThreads; - var l1Threshold = Args.L1Threshold.Value; + var l1Threshold = SdcaTrainerOptions.L1Threshold.Value; bool l1ThresholdZero = l1Threshold == 0; - var lr = Args.BiasLearningRate * Args.L2Const.Value; + var lr = SdcaTrainerOptions.BiasLearningRate * SdcaTrainerOptions.L2Const.Value; var pch = progress != null ? progress.StartProgressChannel("Dual update") : null; using (pch) - using (var cursor = Args.Shuffle ? cursorFactory.Create(rand) : cursorFactory.Create()) + using (var cursor = SdcaTrainerOptions.Shuffle ? cursorFactory.Create(rand) : cursorFactory.Create()) { long rowCount = 0; if (pch != null) @@ -161,7 +161,7 @@ private protected override void TrainWithoutLock(IProgressChannelProvider progre else { normSquared = VectorUtils.NormSquared(in features); - if (Args.BiasLearningRate == 0) + if (SdcaTrainerOptions.BiasLearningRate == 0) normSquared += 1; invariant = _loss.ComputeDualUpdateInvariant(2 * normSquared * lambdaNInv * GetInstanceWeight(cursor)); @@ -231,7 +231,7 @@ private protected override void TrainWithoutLock(IProgressChannelProvider progre //Thresholding: if |v[j]| < threshold, turn off weights[j] //If not, shrink: w[j] = v[i] - sign(v[j]) * threshold l1IntermediateBias[iClass] -= primalUpdate; - if (Args.BiasLearningRate == 0) + if (SdcaTrainerOptions.BiasLearningRate == 0) { biasReg[iClass] = Math.Abs(l1IntermediateBias[iClass]) - l1Threshold > 0.0 ? 
l1IntermediateBias[iClass] - Math.Sign(l1IntermediateBias[iClass]) * l1Threshold @@ -352,10 +352,10 @@ private protected override bool CheckConvergence( Host.Assert(idToIdx == null || row * numClasses == duals.Length); } - Contracts.Assert(Args.L2Const.HasValue); - Contracts.Assert(Args.L1Threshold.HasValue); - Double l2Const = Args.L2Const.Value; - Double l1Threshold = Args.L1Threshold.Value; + Contracts.Assert(SdcaTrainerOptions.L2Const.HasValue); + Contracts.Assert(SdcaTrainerOptions.L1Threshold.HasValue); + Double l2Const = SdcaTrainerOptions.L2Const.Value; + Double l1Threshold = SdcaTrainerOptions.L1Threshold.Value; Double weightsL1Norm = 0; Double weightsL2NormSquared = 0; @@ -367,7 +367,7 @@ private protected override bool CheckConvergence( biasRegularizationAdjustment += biasReg[iClass] * biasUnreg[iClass]; } - Double l1Regularizer = Args.L1Threshold.Value * l2Const * weightsL1Norm; + Double l1Regularizer = SdcaTrainerOptions.L1Threshold.Value * l2Const * weightsL1Norm; var l2Regularizer = l2Const * weightsL2NormSquared * 0.5; var newLoss = lossSum.Sum / count + l2Regularizer + l1Regularizer; @@ -379,10 +379,10 @@ private protected override bool CheckConvergence( metrics[(int)MetricKind.DualityGap] = dualityGap; metrics[(int)MetricKind.BiasUnreg] = biasUnreg[0]; metrics[(int)MetricKind.BiasReg] = biasReg[0]; - metrics[(int)MetricKind.L1Sparsity] = Args.L1Threshold == 0 ? 1 : weights.Sum( + metrics[(int)MetricKind.L1Sparsity] = SdcaTrainerOptions.L1Threshold == 0 ? 1 : weights.Sum( weight => weight.GetValues().Count(w => w != 0)) / (numClasses * numFeatures); - bool converged = dualityGap / newLoss < Args.ConvergenceTolerance; + bool converged = dualityGap / newLoss < SdcaTrainerOptions.ConvergenceTolerance; if (metrics[(int)MetricKind.Loss] < bestPrimalLoss) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 95690246f0..e9afb80fc9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -73,7 +73,7 @@ internal SdcaRegressionTrainer(IHostEnvironment env, { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - _loss = loss ?? Args.LossFunction.CreateComponent(env); + _loss = loss ?? 
SdcaTrainerOptions.LossFunction.CreateComponent(env); Loss = _loss; } diff --git a/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs b/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs index 281fdbc0ce..e4e45dbb2c 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs @@ -35,32 +35,32 @@ public class OptionsBase public int Retries = 10; } - private readonly OptionsBase _args; + private readonly OptionsBase _options; protected readonly IValueGenerator[] SweepParameters; protected readonly IHost Host; - protected SweeperBase(OptionsBase args, IHostEnvironment env, string name) + protected SweeperBase(OptionsBase options, IHostEnvironment env, string name) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(name, nameof(name)); Host = env.Register(name); - Host.CheckValue(args, nameof(args)); - Host.CheckNonEmpty(args.SweptParameters, nameof(args.SweptParameters)); + Host.CheckValue(options, nameof(options)); + Host.CheckNonEmpty(options.SweptParameters, nameof(options.SweptParameters)); - _args = args; + _options = options; - SweepParameters = args.SweptParameters.Select(p => p.CreateComponent(Host)).ToArray(); + SweepParameters = options.SweptParameters.Select(p => p.CreateComponent(Host)).ToArray(); } - protected SweeperBase(OptionsBase args, IHostEnvironment env, IValueGenerator[] sweepParameters, string name) + protected SweeperBase(OptionsBase options, IHostEnvironment env, IValueGenerator[] sweepParameters, string name) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(name, nameof(name)); Host = env.Register(name); - Host.CheckValue(args, nameof(args)); + Host.CheckValue(options, nameof(options)); Host.CheckValue(sweepParameters, nameof(sweepParameters)); - _args = args; + _options = options; SweepParameters = sweepParameters; } @@ -76,7 +76,7 @@ public virtual ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable public sealed class UniformRandomSweeper : SweeperBase { - public UniformRandomSweeper(IHostEnvironment env, OptionsBase args) - : base(args, env, "UniformRandom") + public UniformRandomSweeper(IHostEnvironment env, OptionsBase options) + : base(options, env, "UniformRandom") { } - public UniformRandomSweeper(IHostEnvironment env, OptionsBase args, IValueGenerator[] sweepParameters) - : base(args, env, sweepParameters, "UniformRandom") + public UniformRandomSweeper(IHostEnvironment env, OptionsBase options, IValueGenerator[] sweepParameters) + : base(options, env, sweepParameters, "UniformRandom") { } diff --git a/src/Microsoft.ML.Sweeper/ConfigRunner.cs b/src/Microsoft.ML.Sweeper/ConfigRunner.cs index 5f2e5269f7..a301bcf521 100644 --- a/src/Microsoft.ML.Sweeper/ConfigRunner.cs +++ b/src/Microsoft.ML.Sweeper/ConfigRunner.cs @@ -33,7 +33,7 @@ public abstract class ExeConfigRunnerBase : IConfigRunner public abstract class OptionsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "Command pattern for the sweeps", ShortName = "pattern")] - public string ArgsPattern; + public string OptionsPattern; [Argument(ArgumentType.AtMostOnce, HelpText = "output folder for the outputs of the sweeps", ShortName = "outfolder")] public string OutputFolderName; @@ -53,7 +53,7 @@ public abstract class OptionsBase } protected string Exe; - protected readonly string ArgsPattern; + protected readonly string OptionsPattern; protected readonly string OutputFolder; protected readonly string Prefix; protected readonly ISweepResultEvaluator ResultProcessor; @@ -63,17 +63,17 @@ public abstract class OptionsBase private readonly bool 
diff --git a/src/Microsoft.ML.Sweeper/ConfigRunner.cs b/src/Microsoft.ML.Sweeper/ConfigRunner.cs
index 5f2e5269f7..a301bcf521 100644
--- a/src/Microsoft.ML.Sweeper/ConfigRunner.cs
+++ b/src/Microsoft.ML.Sweeper/ConfigRunner.cs
@@ -33,7 +33,7 @@ public abstract class ExeConfigRunnerBase : IConfigRunner
        public abstract class OptionsBase
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Command pattern for the sweeps", ShortName = "pattern")]
-            public string ArgsPattern;
+            public string OptionsPattern;

            [Argument(ArgumentType.AtMostOnce, HelpText = "output folder for the outputs of the sweeps", ShortName = "outfolder")]
            public string OutputFolderName;
@@ -53,7 +53,7 @@ public abstract class OptionsBase
        }

        protected string Exe;
-        protected readonly string ArgsPattern;
+        protected readonly string OptionsPattern;
        protected readonly string OutputFolder;
        protected readonly string Prefix;
        protected readonly ISweepResultEvaluator ResultProcessor;
@@ -63,17 +63,17 @@ public abstract class OptionsBase
        private readonly bool _calledFromUnitTestSuite;

-        protected ExeConfigRunnerBase(OptionsBase args, IHostEnvironment env, string registrationName)
+        protected ExeConfigRunnerBase(OptionsBase options, IHostEnvironment env, string registrationName)
        {
            Contracts.AssertValue(env);
            Host = env.Register(registrationName);
-            Host.CheckUserArg(!string.IsNullOrEmpty(args.ArgsPattern), nameof(args.ArgsPattern), "The command pattern is missing");
-            Host.CheckUserArg(!string.IsNullOrEmpty(args.OutputFolderName), nameof(args.OutputFolderName), "Please specify an output folder");
-            ArgsPattern = args.ArgsPattern;
-            OutputFolder = GetOutputFolderPath(args.OutputFolderName);
-            Prefix = string.IsNullOrEmpty(args.Prefix) ? "" : args.Prefix;
-            ResultProcessor = args.ResultProcessor.CreateComponent(Host);
-            _calledFromUnitTestSuite = args.CalledFromUnitTestSuite;
+            Host.CheckUserArg(!string.IsNullOrEmpty(options.OptionsPattern), nameof(options.OptionsPattern), "The command pattern is missing");
+            Host.CheckUserArg(!string.IsNullOrEmpty(options.OutputFolderName), nameof(options.OutputFolderName), "Please specify an output folder");
+            OptionsPattern = options.OptionsPattern;
+            OutputFolder = GetOutputFolderPath(options.OutputFolderName);
+            Prefix = string.IsNullOrEmpty(options.Prefix) ? "" : options.Prefix;
+            ResultProcessor = options.ResultProcessor.CreateComponent(Host);
+            _calledFromUnitTestSuite = options.CalledFromUnitTestSuite;
            RunNums = new List();
        }

@@ -151,10 +151,10 @@ public virtual string GetOutputFolderPath(string folderName)
        // you get lr=$ only as argument because $LR is variable and empty.
        protected string GetCommandLine(ParameterSet sweep)
        {
-            var arguments = ArgsPattern;
+            var options = OptionsPattern;
            foreach (var parameterValue in sweep)
-                arguments = arguments.Replace("$" + parameterValue.Name + "$", parameterValue.ValueText);
-            return arguments;
+                options = options.Replace("$" + parameterValue.Name + "$", parameterValue.ValueText);
+            return options;
        }

        public IEnumerable RunConfigs(ParameterSet[] sweeps, int min)
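GetCommandLine is a plain token substitution over the renamed OptionsPattern: each swept parameter's value replaces its $Name$ token. A self-contained sketch of that behavior (hypothetical pattern and values; the ParameterSet is simplified to a dictionary):

    using System;
    using System.Collections.Generic;

    // Each swept parameter value replaces its $Name$ token in the pattern.
    string optionsPattern = "maml.exe Train lr=$LR$ nl=$NL$";
    var sweep = new Dictionary<string, string> { ["LR"] = "0.1", ["NL"] = "20" };
    string cmd = optionsPattern;
    foreach (var p in sweep)
        cmd = cmd.Replace("$" + p.Key + "$", p.Value);
    Console.WriteLine(cmd);   // maml.exe Train lr=0.1 nl=20
    // A token with no matching swept parameter stays in the string; the comment in the
    // code above warns separately that a shell may expand $LR itself, leaving "lr=$".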
diff --git a/src/Microsoft.ML.Sweeper/Parameters.cs b/src/Microsoft.ML.Sweeper/Parameters.cs
index c9de3d2914..6f484c3dee 100644
--- a/src/Microsoft.ML.Sweeper/Parameters.cs
+++ b/src/Microsoft.ML.Sweeper/Parameters.cs
@@ -13,24 +13,24 @@
 using Microsoft.ML.Sweeper;
 using Float = System.Single;

-[assembly: LoadableClass(typeof(LongValueGenerator), typeof(LongParamArguments), typeof(SignatureSweeperParameter),
+[assembly: LoadableClass(typeof(LongValueGenerator), typeof(LongParamOptions), typeof(SignatureSweeperParameter),
    "Long parameter", "lp")]
-[assembly: LoadableClass(typeof(FloatValueGenerator), typeof(FloatParamArguments), typeof(SignatureSweeperParameter),
+[assembly: LoadableClass(typeof(FloatValueGenerator), typeof(FloatParamOptions), typeof(SignatureSweeperParameter),
    "Float parameter", "fp")]
-[assembly: LoadableClass(typeof(DiscreteValueGenerator), typeof(DiscreteParamArguments), typeof(SignatureSweeperParameter),
+[assembly: LoadableClass(typeof(DiscreteValueGenerator), typeof(DiscreteParamOptions), typeof(SignatureSweeperParameter),
    "Discrete parameter", "dp")]

 namespace Microsoft.ML.Sweeper
 {
    public delegate void SignatureSweeperParameter();

-    public abstract class BaseParamArguments
+    public abstract class BaseParamOptions
    {
        [Argument(ArgumentType.Required, HelpText = "Parameter name", ShortName = "n")]
        public string Name;
    }

-    public abstract class NumericParamArguments : BaseParamArguments
+    public abstract class NumericParamOptions : BaseParamOptions
    {
        [Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of steps for grid runthrough.", ShortName = "steps")]
        public int NumSteps = 100;
@@ -42,7 +42,7 @@ public abstract class NumericParamArguments : BaseParamArguments
        public bool LogBase = false;
    }

-    public class FloatParamArguments : NumericParamArguments
+    public class FloatParamOptions : NumericParamOptions
    {
        [Argument(ArgumentType.Required, HelpText = "Minimum value")]
        public Float Min;
@@ -51,7 +51,7 @@ public class FloatParamArguments : NumericParamArguments
        public Float Max;
    }

-    public class LongParamArguments : NumericParamArguments
+    public class LongParamOptions : NumericParamOptions
    {
        [Argument(ArgumentType.Required, HelpText = "Minimum value")]
        public long Min;
@@ -60,7 +60,7 @@ public class LongParamArguments : NumericParamArguments
        public long Max;
    }

-    public class DiscreteParamArguments : BaseParamArguments
+    public class DiscreteParamOptions : BaseParamOptions
    {
        [Argument(ArgumentType.Multiple, HelpText = "Values", ShortName = "v")]
        public string[] Values = null;
@@ -211,39 +211,39 @@ public interface INumericValueGenerator : IValueGenerator
    ///
    public class LongValueGenerator : INumericValueGenerator
    {
-        private readonly LongParamArguments _args;
+        private readonly LongParamOptions _options;
        private IParameterValue[] _gridValues;

-        public string Name { get { return _args.Name; } }
+        public string Name { get { return _options.Name; } }

-        public LongValueGenerator(LongParamArguments args)
+        public LongValueGenerator(LongParamOptions options)
        {
-            Contracts.Check(args.Min < args.Max, "min must be less than max");
+            Contracts.Check(options.Min < options.Max, "min must be less than max");
            // REVIEW: this condition can be relaxed if we change the math below to deal with it
-            Contracts.Check(!args.LogBase || args.Min > 0, "min must be positive if log scale is used");
-            Contracts.Check(!args.LogBase || args.StepSize == null || args.StepSize > 1, "StepSize must be greater than 1 if log scale is used");
-            Contracts.Check(args.LogBase || args.StepSize == null || args.StepSize > 0, "StepSize must be greater than 0 if linear scale is used");
-            _args = args;
+            Contracts.Check(!options.LogBase || options.Min > 0, "min must be positive if log scale is used");
+            Contracts.Check(!options.LogBase || options.StepSize == null || options.StepSize > 1, "StepSize must be greater than 1 if log scale is used");
+            Contracts.Check(options.LogBase || options.StepSize == null || options.StepSize > 0, "StepSize must be greater than 0 if linear scale is used");
+            _options = options;
        }

        // REVIEW: Is Float accurate enough?
        public IParameterValue CreateFromNormalized(Double normalizedValue)
        {
            long val;
-            if (_args.LogBase)
+            if (_options.LogBase)
            {
                // REVIEW: review the math below, it only works for positive Min and Max
-                var logBase = !_args.StepSize.HasValue
-                    ? Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1))
-                    : _args.StepSize.Value;
-                var logMax = Math.Log(_args.Max, logBase);
-                var logMin = Math.Log(_args.Min, logBase);
-                val = (long)(_args.Min * Math.Pow(logBase, normalizedValue * (logMax - logMin)));
+                var logBase = !_options.StepSize.HasValue
+                    ? Math.Pow(1.0 * _options.Max / _options.Min, 1.0 / (_options.NumSteps - 1))
+                    : _options.StepSize.Value;
+                var logMax = Math.Log(_options.Max, logBase);
+                var logMin = Math.Log(_options.Min, logBase);
+                val = (long)(_options.Min * Math.Pow(logBase, normalizedValue * (logMax - logMin)));
            }
            else
-                val = (long)(_args.Min + normalizedValue * (_args.Max - _args.Min));
+                val = (long)(_options.Min + normalizedValue * (_options.Max - _options.Min));

-            return new LongParameterValue(_args.Name, val);
+            return new LongParameterValue(_options.Name, val);
        }

        private void EnsureParameterValues()
@@ -252,39 +252,39 @@ private void EnsureParameterValues()
                return;

            var result = new List();
-            if ((_args.StepSize == null && _args.NumSteps > (_args.Max - _args.Min)) ||
-                (_args.StepSize != null && _args.StepSize <= 1))
+            if ((_options.StepSize == null && _options.NumSteps > (_options.Max - _options.Min)) ||
+                (_options.StepSize != null && _options.StepSize <= 1))
            {
-                for (long i = _args.Min; i <= _args.Max; i++)
-                    result.Add(new LongParameterValue(_args.Name, i));
+                for (long i = _options.Min; i <= _options.Max; i++)
+                    result.Add(new LongParameterValue(_options.Name, i));
            }
            else
            {
-                if (_args.LogBase)
+                if (_options.LogBase)
                {
                    // REVIEW: review the math below, it only works for positive Min and Max
-                    var logBase = _args.StepSize ?? Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1));
+                    var logBase = _options.StepSize ?? Math.Pow(1.0 * _options.Max / _options.Min, 1.0 / (_options.NumSteps - 1));
                    long prevValue = long.MinValue;
-                    var maxPlusEpsilon = _args.Max * Math.Sqrt(logBase);
-                    for (Double value = _args.Min; value <= maxPlusEpsilon; value *= logBase)
+                    var maxPlusEpsilon = _options.Max * Math.Sqrt(logBase);
+                    for (Double value = _options.Min; value <= maxPlusEpsilon; value *= logBase)
                    {
                        var longValue = (long)value;
                        if (longValue > prevValue)
-                            result.Add(new LongParameterValue(_args.Name, longValue));
+                            result.Add(new LongParameterValue(_options.Name, longValue));
                        prevValue = longValue;
                    }
                }
                else
                {
-                    var stepSize = _args.StepSize ?? (Double)(_args.Max - _args.Min) / (_args.NumSteps - 1);
+                    var stepSize = _options.StepSize ?? (Double)(_options.Max - _options.Min) / (_options.NumSteps - 1);
                    long prevValue = long.MinValue;
-                    var maxPlusEpsilon = _args.Max + stepSize / 2;
-                    for (Double value = _args.Min; value <= maxPlusEpsilon; value += stepSize)
+                    var maxPlusEpsilon = _options.Max + stepSize / 2;
+                    for (Double value = _options.Min; value <= maxPlusEpsilon; value += stepSize)
                    {
                        var longValue = (long)value;
                        if (longValue > prevValue)
-                            result.Add(new LongParameterValue(_args.Name, longValue));
+                            result.Add(new LongParameterValue(_options.Name, longValue));
                        prevValue = longValue;
                    }
                }
@@ -314,27 +314,27 @@ public Float NormalizeValue(IParameterValue value)
        {
            var valueTyped = value as LongParameterValue;
            Contracts.Check(valueTyped != null, "LongValueGenerator could not normalize parameter because it is not of the correct type");
-            Contracts.Check(_args.Min <= valueTyped.Value && valueTyped.Value <= _args.Max, "Value not in correct range");
+            Contracts.Check(_options.Min <= valueTyped.Value && valueTyped.Value <= _options.Max, "Value not in correct range");

-            if (_args.LogBase)
+            if (_options.LogBase)
            {
-                Float logBase = (Float)(_args.StepSize ?? Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1)));
-                return (Float)((Math.Log(valueTyped.Value, logBase) - Math.Log(_args.Min, logBase)) / (Math.Log(_args.Max, logBase) - Math.Log(_args.Min, logBase)));
+                Float logBase = (Float)(_options.StepSize ?? Math.Pow(1.0 * _options.Max / _options.Min, 1.0 / (_options.NumSteps - 1)));
+                return (Float)((Math.Log(valueTyped.Value, logBase) - Math.Log(_options.Min, logBase)) / (Math.Log(_options.Max, logBase) - Math.Log(_options.Min, logBase)));
            }
            else
-                return (Float)(valueTyped.Value - _args.Min) / (_args.Max - _args.Min);
+                return (Float)(valueTyped.Value - _options.Min) / (_options.Max - _options.Min);
        }

        public bool InRange(IParameterValue value)
        {
            var valueTyped = value as LongParameterValue;
            Contracts.Check(valueTyped != null, "Parameter should be of type LongParameterValue");
-            return (_args.Min <= valueTyped.Value && valueTyped.Value <= _args.Max);
+            return (_options.Min <= valueTyped.Value && valueTyped.Value <= _options.Max);
        }

        public string ToStringParameter(IHostEnvironment env)
        {
-            return $" p=lp{{{CmdParser.GetSettings(env, _args, new LongParamArguments())}}}";
+            return $" p=lp{{{CmdParser.GetSettings(env, _options, new LongParamOptions())}}}";
        }
    }
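The log branch above is exactly what the TestLongValueSweep and TestFloatValueSweep cases later in this diff exercise. A worked instance of the math, using the same inputs as the "(bla, 10, 1000, true, 10, 3)" InlineData rows:

    using System;

    // min=10, max=1000, StepSize=10, NumSteps=3, normalized value 0.5.
    double min = 10, max = 1000, normalizedValue = 0.5;
    double logBase = 10;                         // StepSize is set, so it is used directly
    double logMax = Math.Log(max, logBase);      // ~3 (2.9999999999999996 in doubles)
    double logMin = Math.Log(min, logBase);      // 1
    double raw = min * Math.Pow(logBase, normalizedValue * (logMax - logMin));
    long longVal = (long)raw;                    // 99: the cast truncates 99.999...
    float floatVal = (float)raw;                 // 100f: float rounding absorbs the error
    // This is why the long-valued test below expects "99" while the float test expects "100".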
@@ -343,39 +343,39 @@ public string ToStringParameter(IHostEnvironment env)
    ///
    public class FloatValueGenerator : INumericValueGenerator
    {
-        private readonly FloatParamArguments _args;
+        private readonly FloatParamOptions _options;
        private IParameterValue[] _gridValues;

-        public string Name { get { return _args.Name; } }
+        public string Name { get { return _options.Name; } }

-        public FloatValueGenerator(FloatParamArguments args)
+        public FloatValueGenerator(FloatParamOptions options)
        {
-            Contracts.Check(args.Min < args.Max, "min must be less than max");
+            Contracts.Check(options.Min < options.Max, "min must be less than max");
            // REVIEW: this condition can be relaxed if we change the math below to deal with it
-            Contracts.Check(!args.LogBase || args.Min > 0, "min must be positive if log scale is used");
-            Contracts.Check(!args.LogBase || args.StepSize == null || args.StepSize > 1, "StepSize must be greater than 1 if log scale is used");
-            Contracts.Check(args.LogBase || args.StepSize == null || args.StepSize > 0, "StepSize must be greater than 0 if linear scale is used");
-            _args = args;
+            Contracts.Check(!options.LogBase || options.Min > 0, "min must be positive if log scale is used");
+            Contracts.Check(!options.LogBase || options.StepSize == null || options.StepSize > 1, "StepSize must be greater than 1 if log scale is used");
+            Contracts.Check(options.LogBase || options.StepSize == null || options.StepSize > 0, "StepSize must be greater than 0 if linear scale is used");
+            _options = options;
        }

        // REVIEW: Is Float accurate enough?
        public IParameterValue CreateFromNormalized(Double normalizedValue)
        {
            Float val;
-            if (_args.LogBase)
+            if (_options.LogBase)
            {
                // REVIEW: review the math below, it only works for positive Min and Max
-                var logBase = !_args.StepSize.HasValue
-                    ? Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1))
-                    : _args.StepSize.Value;
-                var logMax = Math.Log(_args.Max, logBase);
-                var logMin = Math.Log(_args.Min, logBase);
-                val = (Float)(_args.Min * Math.Pow(logBase, normalizedValue * (logMax - logMin)));
+                var logBase = !_options.StepSize.HasValue
+                    ? Math.Pow(1.0 * _options.Max / _options.Min, 1.0 / (_options.NumSteps - 1))
+                    : _options.StepSize.Value;
+                var logMax = Math.Log(_options.Max, logBase);
+                var logMin = Math.Log(_options.Min, logBase);
+                val = (Float)(_options.Min * Math.Pow(logBase, normalizedValue * (logMax - logMin)));
            }
            else
-                val = (Float)(_args.Min + normalizedValue * (_args.Max - _args.Min));
+                val = (Float)(_options.Min + normalizedValue * (_options.Max - _options.Min));

-            return new FloatParameterValue(_args.Name, val);
+            return new FloatParameterValue(_options.Name, val);
        }

        private void EnsureParameterValues()
@@ -384,31 +384,31 @@ private void EnsureParameterValues()
                return;

            var result = new List();
-            if (_args.LogBase)
+            if (_options.LogBase)
            {
                // REVIEW: review the math below, it only works for positive Min and Max
-                var logBase = _args.StepSize ?? Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1));
+                var logBase = _options.StepSize ?? Math.Pow(1.0 * _options.Max / _options.Min, 1.0 / (_options.NumSteps - 1));
                Float prevValue = Float.NegativeInfinity;
-                var maxPlusEpsilon = _args.Max * Math.Sqrt(logBase);
-                for (Double value = _args.Min; value <= maxPlusEpsilon; value *= logBase)
+                var maxPlusEpsilon = _options.Max * Math.Sqrt(logBase);
+                for (Double value = _options.Min; value <= maxPlusEpsilon; value *= logBase)
                {
                    var floatValue = (Float)value;
                    if (floatValue > prevValue)
-                        result.Add(new FloatParameterValue(_args.Name, floatValue));
+                        result.Add(new FloatParameterValue(_options.Name, floatValue));
                    prevValue = floatValue;
                }
            }
            else
            {
-                var stepSize = _args.StepSize ?? (Double)(_args.Max - _args.Min) / (_args.NumSteps - 1);
+                var stepSize = _options.StepSize ?? (Double)(_options.Max - _options.Min) / (_options.NumSteps - 1);
                Float prevValue = Float.NegativeInfinity;
-                var maxPlusEpsilon = _args.Max + stepSize / 2;
-                for (Double value = _args.Min; value <= maxPlusEpsilon; value += stepSize)
+                var maxPlusEpsilon = _options.Max + stepSize / 2;
+                for (Double value = _options.Min; value <= maxPlusEpsilon; value += stepSize)
                {
                    var floatValue = (Float)value;
                    if (floatValue > prevValue)
-                        result.Add(new FloatParameterValue(_args.Name, floatValue));
+                        result.Add(new FloatParameterValue(_options.Name, floatValue));
                    prevValue = floatValue;
                }
            }
@@ -438,27 +438,27 @@ public Float NormalizeValue(IParameterValue value)
        {
            var valueTyped = value as FloatParameterValue;
            Contracts.Check(valueTyped != null, "FloatValueGenerator could not normalize parameter because it is not of the correct type");
-            Contracts.Check(_args.Min <= valueTyped.Value && valueTyped.Value <= _args.Max, "Value not in correct range");
+            Contracts.Check(_options.Min <= valueTyped.Value && valueTyped.Value <= _options.Max, "Value not in correct range");

-            if (_args.LogBase)
+            if (_options.LogBase)
            {
-                Float logBase = (Float)(_args.StepSize ?? Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1)));
-                return (Float)((Math.Log(valueTyped.Value, logBase) - Math.Log(_args.Min, logBase)) / (Math.Log(_args.Max, logBase) - Math.Log(_args.Min, logBase)));
+                Float logBase = (Float)(_options.StepSize ?? Math.Pow(1.0 * _options.Max / _options.Min, 1.0 / (_options.NumSteps - 1)));
+                return (Float)((Math.Log(valueTyped.Value, logBase) - Math.Log(_options.Min, logBase)) / (Math.Log(_options.Max, logBase) - Math.Log(_options.Min, logBase)));
            }
            else
-                return (valueTyped.Value - _args.Min) / (_args.Max - _args.Min);
+                return (valueTyped.Value - _options.Min) / (_options.Max - _options.Min);
        }

        public bool InRange(IParameterValue value)
        {
            var valueTyped = value as FloatParameterValue;
            Contracts.Check(valueTyped != null, "Parameter should be of type FloatParameterValue");
-            return (_args.Min <= valueTyped.Value && valueTyped.Value <= _args.Max);
+            return (_options.Min <= valueTyped.Value && valueTyped.Value <= _options.Max);
        }

        public string ToStringParameter(IHostEnvironment env)
        {
-            return $" p=fp{{{CmdParser.GetSettings(env, _args, new FloatParamArguments())}}}";
+            return $" p=fp{{{CmdParser.GetSettings(env, _options, new FloatParamOptions())}}}";
        }
    }
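On the linear scale, NormalizeValue and CreateFromNormalized are inverses up to floating-point error, which is the property the round-trip tests later in this diff check. A compact sketch of that identity (hypothetical values):

    // Linear branch round trip; mirrors TestFloatValueGeneratorRoundTrip below.
    float min = 1f, max = 5f;
    float value = 3.25f;
    float normalized = (value - min) / (max - min);      // NormalizeValue, linear branch
    float recovered = min + normalized * (max - min);    // CreateFromNormalized, linear branch
    // recovered == 3.25f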
@@ -467,27 +467,27 @@ public string ToStringParameter(IHostEnvironment env)
    ///
    public class DiscreteValueGenerator : IValueGenerator
    {
-        private readonly DiscreteParamArguments _args;
+        private readonly DiscreteParamOptions _options;

-        public string Name { get { return _args.Name; } }
+        public string Name { get { return _options.Name; } }

-        public DiscreteValueGenerator(DiscreteParamArguments args)
+        public DiscreteValueGenerator(DiscreteParamOptions options)
        {
-            Contracts.Check(args.Values.Length > 0);
-            _args = args;
+            Contracts.Check(options.Values.Length > 0);
+            _options = options;
        }

        // REVIEW: Is Float accurate enough?
        public IParameterValue CreateFromNormalized(Double normalizedValue)
        {
-            return new StringParameterValue(_args.Name, _args.Values[(int)(_args.Values.Length * normalizedValue)]);
+            return new StringParameterValue(_options.Name, _options.Values[(int)(_options.Values.Length * normalizedValue)]);
        }

        public IParameterValue this[int i]
        {
            get
            {
-                return new StringParameterValue(_args.Name, _args.Values[i]);
+                return new StringParameterValue(_options.Name, _options.Values[i]);
            }
        }

@@ -495,13 +495,13 @@ public int Count
        {
            get
            {
-                return _args.Values.Length;
+                return _options.Values.Length;
            }
        }

        public string ToStringParameter(IHostEnvironment env)
        {
-            return $" p=dp{{{CmdParser.GetSettings(env, _args, new DiscreteParamArguments())}}}";
+            return $" p=dp{{{CmdParser.GetSettings(env, _options, new DiscreteParamOptions())}}}";
        }
    }

@@ -526,10 +526,10 @@ public bool TryParseParameter(string paramValue, Type paramType, string paramNam
            if (paramValue.Contains(','))
            {
-                var generatorArgs = new DiscreteParamArguments();
-                generatorArgs.Name = paramName;
-                generatorArgs.Values = paramValue.Split(',');
-                sweepValues = new DiscreteValueGenerator(generatorArgs);
+                var generatorOptions = new DiscreteParamOptions();
+                generatorOptions.Name = paramName;
+                generatorOptions.Values = paramValue.Split(',');
+                sweepValues = new DiscreteValueGenerator(generatorOptions);
                return true;
            }

@@ -619,17 +619,17 @@ public bool TryParseParameter(string paramValue, Type paramType, string paramNam
            long max;
            if (!long.TryParse(minStr, out min) || !long.TryParse(maxStr, out max))
                return false;
-            var generatorArgs = new Microsoft.ML.Sweeper.LongParamArguments();
-            generatorArgs.Name = paramName;
-            generatorArgs.Min = min;
-            generatorArgs.Max = max;
-            generatorArgs.NumSteps = numSteps;
-            generatorArgs.StepSize = (stepSize > 0 ? stepSize : new Nullable());
-            generatorArgs.LogBase = logBase;
+            var generatorOptions = new Microsoft.ML.Sweeper.LongParamOptions();
+            generatorOptions.Name = paramName;
+            generatorOptions.Min = min;
+            generatorOptions.Max = max;
+            generatorOptions.NumSteps = numSteps;
+            generatorOptions.StepSize = (stepSize > 0 ? stepSize : new Nullable());
+            generatorOptions.LogBase = logBase;
            try
            {
-                sweepValues = new LongValueGenerator(generatorArgs);
+                sweepValues = new LongValueGenerator(generatorOptions);
            }
            catch (Exception e)
            {
@@ -643,17 +643,17 @@ public bool TryParseParameter(string paramValue, Type paramType, string paramNam
            Float maxF;
            if (!Float.TryParse(minStr, out minF) || !Float.TryParse(maxStr, out maxF))
                return false;
-            var floatArgs = new FloatParamArguments();
-            floatArgs.Name = paramName;
-            floatArgs.Min = minF;
-            floatArgs.Max = maxF;
-            floatArgs.NumSteps = numSteps;
-            floatArgs.StepSize = (stepSize > 0 ? stepSize : new Nullable());
-            floatArgs.LogBase = logBase;
+            var floatOptions = new FloatParamOptions();
+            floatOptions.Name = paramName;
+            floatOptions.Min = minF;
+            floatOptions.Max = maxF;
+            floatOptions.NumSteps = numSteps;
+            floatOptions.StepSize = (stepSize > 0 ? stepSize : new Nullable());
+            floatOptions.LogBase = logBase;
            try
            {
-                sweepValues = new FloatValueGenerator(floatArgs);
+                sweepValues = new FloatValueGenerator(floatOptions);
            }
            catch (Exception e)
            {
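The comma branch above is the simplest path through TryParseParameter: any value text containing ',' turns into a discrete sweep. A minimal sketch with hypothetical inputs, using the indexer and Count shown in the DiscreteValueGenerator hunk:

    // "one,two,three" -> a DiscreteValueGenerator over three values.
    // (usings as in Parameters.cs)
    var generatorOptions = new DiscreteParamOptions();
    generatorOptions.Name = "myParam";                  // hypothetical parameter name
    generatorOptions.Values = "one,two,three".Split(',');
    var sweepValues = new DiscreteValueGenerator(generatorOptions);
    // sweepValues.Count == 3; sweepValues[1].ValueText == "two"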
diff --git a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
index 5b5f298009..82149cea5b 100644
--- a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
+++ b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
@@ -28,12 +28,12 @@ public sealed class Options
        private readonly IHost _host;

-        public InternalSweepResultEvaluator(IHostEnvironment env, Options args)
+        public InternalSweepResultEvaluator(IHostEnvironment env, Options options)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register("InternalSweepResultEvaluator");
-            _host.CheckNonEmpty(args.Metric, nameof(args.Metric));
-            _metric = FindMetric(args.Metric, out _maximizing);
+            _host.CheckNonEmpty(options.Metric, nameof(options.Metric));
+            _metric = FindMetric(options.Metric, out _maximizing);
        }

        private string FindMetric(string userMetric, out bool maximizing)
diff --git a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs
index 0466c3d5bf..9bc6b4220e 100644
--- a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs
+++ b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs
@@ -72,13 +72,13 @@ private static VersionInfo GetVersionInfo()
        private readonly float _gamma;

-        public GaussianFourierSampler(IHostEnvironment env, Options args, float avgDist)
+        public GaussianFourierSampler(IHostEnvironment env, Options options, float avgDist)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(LoadName);
-            _host.CheckValue(args, nameof(args));
+            _host.CheckValue(options, nameof(options));

-            _gamma = args.Gamma / avgDist;
+            _gamma = options.Gamma / avgDist;
        }

        private static GaussianFourierSampler Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -153,13 +153,13 @@ private static VersionInfo GetVersionInfo()
        private readonly IHost _host;
        private readonly float _a;

-        public LaplacianFourierSampler(IHostEnvironment env, Options args, float avgDist)
+        public LaplacianFourierSampler(IHostEnvironment env, Options options, float avgDist)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);
-            _host.CheckValue(args, nameof(args));
+            _host.CheckValue(options, nameof(options));

-            _a = args.A / avgDist;
+            _a = options.A / avgDist;
        }

        private static LaplacianFourierSampler Create(IHostEnvironment env, ModelLoadContext ctx)
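Both samplers divide the user-facing bandwidth option by avgDist, so the same Gamma (or A) behaves comparably across differently scaled feature spaces. A sketch of the effect (hypothetical numbers; avgDist is a data statistic computed by the caller, not shown in this diff):

    // The constructor stores a data-normalized bandwidth rather than the raw option value.
    float gammaOption = 1.0f;    // options.Gamma as supplied by the user (hypothetical)
    float avgDist = 4.0f;        // average distance estimate from the data (hypothetical)
    float effectiveGamma = gammaOption / avgDist;   // what GaussianFourierSampler keeps in _gamma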
diff --git a/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs b/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs
index 9f88d678d7..8563b4f0d5 100644
--- a/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs
+++ b/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs
@@ -44,7 +44,7 @@ public void RandomGridSweeperReturnsDistinctValuesWhenProposeSweep()
        private static DiscreteValueGenerator CreateDiscreteValueGenerator()
        {
-            var args = new DiscreteParamArguments()
+            var args = new DiscreteParamOptions()
            {
                Name = "TestParam",
                Values = new string[] { "one", "two" }
diff --git a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs
index 07dd0db895..7e460fc712 100644
--- a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs
+++ b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs
@@ -28,7 +28,7 @@ public TestSweeper(ITestOutputHelper output) : base(output)
        [InlineData("bla", 10, 1000, true, 10, 3, "99")]
        public void TestLongValueSweep(string name, int min, int max, bool logBase, int stepSize, int numSteps, string valueText)
        {
-            var paramSweep = new LongValueGenerator(new LongParamArguments() { Name = name, Min = min, Max = max, LogBase = logBase, StepSize = stepSize, NumSteps = numSteps });
+            var paramSweep = new LongValueGenerator(new LongParamOptions() { Name = name, Min = min, Max = max, LogBase = logBase, StepSize = stepSize, NumSteps = numSteps });
            IParameterValue value = paramSweep.CreateFromNormalized(0.5);
            Assert.Equal(name, value.Name);
            Assert.Equal(valueText, value.ValueText);
@@ -37,7 +37,7 @@ public void TestLongValueSweep(string name, int min, int max, bool logBase, int
        [Fact]
        public void TestLongValueGeneratorRoundTrip()
        {
-            var paramSweep = new LongValueGenerator(new LongParamArguments() { Name = "bla", Min = 0, Max = 17 });
+            var paramSweep = new LongValueGenerator(new LongParamOptions() { Name = "bla", Min = 0, Max = 17 });
            var value = new LongParameterValue("bla", 5);
            float normalizedValue = paramSweep.NormalizeValue(value);
            IParameterValue unNormalizedValue = paramSweep.CreateFromNormalized(normalizedValue);
@@ -55,7 +55,7 @@ public void TestLongValueGeneratorRoundTrip()
        [InlineData("bla", 10, 1000, true, 10, 3, "100")]
        public void TestFloatValueSweep(string name, int min, int max, bool logBase, int stepSize, int numSteps, string valueText)
        {
-            var paramSweep = new FloatValueGenerator(new FloatParamArguments() { Name = name, Min = min, Max = max, LogBase = logBase, StepSize = stepSize, NumSteps = numSteps });
+            var paramSweep = new FloatValueGenerator(new FloatParamOptions() { Name = name, Min = min, Max = max, LogBase = logBase, StepSize = stepSize, NumSteps = numSteps });
            IParameterValue value = paramSweep.CreateFromNormalized(0.5);
            Assert.Equal(name, value.Name);
            Assert.Equal(valueText, value.ValueText);
@@ -64,7 +64,7 @@ public void TestFloatValueSweep(string name, int min, int max, bool logBase, int
        [Fact]
        public void TestFloatValueGeneratorRoundTrip()
        {
-            var paramSweep = new FloatValueGenerator(new FloatParamArguments() { Name = "bla", Min = 1, Max = 5 });
+            var paramSweep = new FloatValueGenerator(new FloatParamOptions() { Name = "bla", Min = 1, Max = 5 });
            var random = new Random(123);
            var normalizedValue = (float)random.NextDouble();
            var value = (FloatParameterValue)paramSweep.CreateFromNormalized(normalizedValue);
@@ -80,7 +80,7 @@ public void TestFloatValueGeneratorRoundTrip()
        [InlineData(0.75, "baz")]
        public void TestDiscreteValueSweep(double normalizedValue, string expected)
        {
-            var paramSweep = new DiscreteValueGenerator(new DiscreteParamArguments() { Name = "bla", Values = new[] { "foo", "bar", "baz" } });
+            var paramSweep = new DiscreteValueGenerator(new DiscreteParamOptions() { Name = "bla", Values = new[] { "foo", "bar", "baz" } });
            var value = paramSweep.CreateFromNormalized(normalizedValue);
            Assert.Equal("bla", value.Name);
            Assert.Equal(expected, value.ValueText);
@@ -94,9 +94,9 @@ public void TestRandomSweeper()
        {
            SweptParameters = new[] {
                ComponentFactoryUtils.CreateFromFunction(
-                    environ => new LongValueGenerator(new LongParamArguments() { Name = "foo", Min = 10, Max = 20 })),
+                    environ => new LongValueGenerator(new LongParamOptions() { Name = "foo", Min = 10, Max = 20 })),
                ComponentFactoryUtils.CreateFromFunction(
-                    environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 100, Max = 200 }))
+                    environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 100, Max = 200 }))
            }
        };

@@ -135,9 +135,9 @@ public void TestSimpleSweeperAsync()
        {
            SweptParameters = new IComponentFactory[] {
                ComponentFactoryUtils.CreateFromFunction(
-                    environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5 })),
+                    environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5 })),
                ComponentFactoryUtils.CreateFromFunction(
-                    environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
+                    environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
            }
        });

@@ -157,9 +157,9 @@ public void TestSimpleSweeperAsync()
        var gridArgs = new RandomGridSweeper.Options();
        gridArgs.SweptParameters = new IComponentFactory[] {
            ComponentFactoryUtils.CreateFromFunction(
-                environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
+                environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
            ComponentFactoryUtils.CreateFromFunction(
-                environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 100, LogBase = true }))
+                environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 100, LogBase = true }))
        };
        var gridSweeper = new SimpleAsyncSweeper(env, gridArgs);
        paramSets.Clear();
@@ -188,9 +188,9 @@ public void TestDeterministicSweeperAsyncCancellation()
        {
            SweptParameters = new IComponentFactory[] {
                ComponentFactoryUtils.CreateFromFunction(
-                    t => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
+                    t => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                ComponentFactoryUtils.CreateFromFunction(
-                    t => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
+                    t => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
            }
        }));

@@ -238,9 +238,9 @@ public void TestDeterministicSweeperAsync()
        {
            SweptParameters = new IComponentFactory[] {
                ComponentFactoryUtils.CreateFromFunction(
-                    t => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
+                    t => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                ComponentFactoryUtils.CreateFromFunction(
-                    t => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
+                    t => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
            }
        }));
@@ -310,9 +310,9 @@ public void TestDeterministicSweeperAsyncParallel()
        {
            SweptParameters = new IComponentFactory[] {
                ComponentFactoryUtils.CreateFromFunction(
-                    t => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
+                    t => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                ComponentFactoryUtils.CreateFromFunction(
-                    t => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
+                    t => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
            }
        }));

@@ -361,9 +361,9 @@ public async Task TestNelderMeadSweeperAsync()
        {
            var param = new IComponentFactory[] {
                ComponentFactoryUtils.CreateFromFunction(
-                    innerEnviron => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
+                    innerEnviron => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                ComponentFactoryUtils.CreateFromFunction(
-                    innerEnviron => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
+                    innerEnviron => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
            };

            var nelderMeadSweeperArgs = new NelderMeadSweeper.Options()
@@ -431,9 +431,9 @@ public void TestRandomGridSweeper()
        {
            SweptParameters = new[] {
                ComponentFactoryUtils.CreateFromFunction(
-                    environ => new LongValueGenerator(new LongParamArguments() { Name = "foo", Min = 10, Max = 20, NumSteps = 3 })),
+                    environ => new LongValueGenerator(new LongParamOptions() { Name = "foo", Min = 10, Max = 20, NumSteps = 3 })),
                ComponentFactoryUtils.CreateFromFunction(
-                    environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 100, Max = 10000, LogBase = true, StepSize = 10 }))
+                    environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 100, Max = 10000, LogBase = true, StepSize = 10 }))
            }
        };
        var sweeper = new RandomGridSweeper(env, args);
@@ -539,9 +539,9 @@ public void TestNelderMeadSweeper()
        var env = new MLContext(42);
        var param = new IComponentFactory[] {
            ComponentFactoryUtils.CreateFromFunction(
-                environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
+                environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
            ComponentFactoryUtils.CreateFromFunction(
-                environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
+                environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
        };

        var args = new NelderMeadSweeper.Options()
@@ -595,9 +595,9 @@ public void TestNelderMeadSweeperWithDefaultFirstBatchSweeper()
        var env = new MLContext(42);
        var param = new IComponentFactory[] {
            ComponentFactoryUtils.CreateFromFunction(
-                environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
+                environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
            ComponentFactoryUtils.CreateFromFunction(
-                environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
+                environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
        };

        var args = new NelderMeadSweeper.Options();
@@ -647,9 +647,9 @@ public void TestSmacSweeper()
            NumberInitialPopulation = 20,
            SweptParameters = new IComponentFactory[] {
                ComponentFactoryUtils.CreateFromFunction(
-                    environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
+                    environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                ComponentFactoryUtils.CreateFromFunction(
-                    environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 100, LogBase = true }))
+                    environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 100, LogBase = true }))
            }
        };
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
index 8fab3a9331..55a4f7c9ae 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
@@ -208,7 +208,7 @@ public void TweedieRegressorEstimator()
            new FastTreeTweedieTrainer.Options
            {
                EntropyCoefficient = 0.3,
-                OptimizationAlgorithm = BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent,
+                OptimizationAlgorithm = BoostedTreeOptions.OptimizationAlgorithmType.AcceleratedGradientDescent,
            });

        TestEstimatorCore(trainer, dataView);