Skip to content

Fixing renmants of argument keyword in public API #2636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 33 additions & 33 deletions src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ public sealed class Arguments : DataCommand.ArgumentsBase
public CrossValidationCommand(IHostEnvironment env, Arguments args)
: base(env, args, RegistrationName)
{
Host.CheckUserArg(Args.NumFolds >= 2, nameof(Args.NumFolds), "Number of folds must be greater than or equal to 2.");
Host.CheckUserArg(ImplOptions.NumFolds >= 2, nameof(ImplOptions.NumFolds), "Number of folds must be greater than or equal to 2.");
TrainUtils.CheckTrainer(Host, args.Trainer, args.DataFile);
Utils.CheckOptionalUserDirectory(Args.SummaryFilename, nameof(Args.SummaryFilename));
Utils.CheckOptionalUserDirectory(Args.OutputDataFile, nameof(Args.OutputDataFile));
Utils.CheckOptionalUserDirectory(ImplOptions.SummaryFilename, nameof(ImplOptions.SummaryFilename));
Utils.CheckOptionalUserDirectory(ImplOptions.OutputDataFile, nameof(ImplOptions.OutputDataFile));
}

// This is for "forking" the host environment.
Expand All @@ -124,7 +124,7 @@ public override void Run()
using (var ch = Host.Start(LoadName))
using (var server = InitServer(ch))
{
var settings = CmdParser.GetSettings(Host, Args, new Arguments());
var settings = CmdParser.GetSettings(Host, ImplOptions, new Arguments());
string cmd = string.Format("maml.exe {0} {1}", LoadName, settings);
ch.Info(cmd);

Expand All @@ -139,7 +139,7 @@ public override void Run()

protected override void SendTelemetryCore(IPipe<TelemetryMessage> pipe)
{
SendTelemetryComponent(pipe, Args.Trainer);
SendTelemetryComponent(pipe, ImplOptions.Trainer);
base.SendTelemetryCore(pipe);
}

Expand All @@ -148,17 +148,17 @@ private void RunCore(IChannel ch, string cmd)
Host.AssertValue(ch);

IPredictor inputPredictor = null;
if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor))
ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");

ch.Trace("Constructing data pipeline");
IDataLoader loader = CreateRawLoader();

// If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
var preXf = Args.PreTransforms;
if (!string.IsNullOrEmpty(Args.OutputDataFile))
var preXf = ImplOptions.PreTransforms;
if (!string.IsNullOrEmpty(ImplOptions.OutputDataFile))
{
string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(ImplOptions.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);
if (name == null)
{
preXf = preXf.Concat(
Expand All @@ -182,24 +182,24 @@ private void RunCore(IChannel ch, string cmd)

IDataView pipe = loader;
var stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
var scorer = Args.Scorer;
var evaluator = Args.Evaluator;
var scorer = ImplOptions.Scorer;
var evaluator = ImplOptions.Evaluator;

Func<IDataView> validDataCreator = null;
if (Args.ValidationFile != null)
if (ImplOptions.ValidationFile != null)
{
validDataCreator =
() =>
{
// Fork the command.
var impl = new CrossValidationCommand(this);
return impl.CreateRawLoader(dataFile: Args.ValidationFile);
return impl.CreateRawLoader(dataFile: ImplOptions.ValidationFile);
};
}

FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
Args, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(Args.OutputDataFile));
ImplOptions, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(ImplOptions.OutputDataFile));
var tasks = fold.GetCrossValidationTasks();

var eval = evaluator?.CreateComponent(Host) ??
Expand All @@ -218,32 +218,32 @@ private void RunCore(IChannel ch, string cmd)
throw ch.Except("No overall metrics found");

var overall = eval.GetOverallResults(overallList.ToArray());
MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds);
MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, ImplOptions.NumFolds);
eval.PrintAdditionalMetrics(ch, tasks.Select(t => t.Result.Metrics).ToArray());
Dictionary<string, IDataView>[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray();
SendTelemetryMetric(metricValues);

// Save the per-instance results.
if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
if (!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
{
var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, Args.CollateMetrics,
Args.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, ImplOptions.CollateMetrics,
ImplOptions.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
if (variableSizeVectorColumnNames.Length > 0)
{
ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.",
string.Join(", ", variableSizeVectorColumnNames));
}
if (Args.CollateMetrics)
if (ImplOptions.CollateMetrics)
{
ch.Assert(perInstance.Length == 1);
MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]);
MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, perInstance[0]);
}
else
{
int i = 0;
foreach (var idv in perInstance)
{
MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, i), idv);
MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(ImplOptions.OutputDataFile, i), idv);
i++;
}
}
Expand All @@ -265,20 +265,20 @@ private RoleMappedData ApplyAllTransformsToData(IHostEnvironment env, IChannel c
/// </summary>
private RoleMappedData CreateRoleMappedData(IHostEnvironment env, IChannel ch, IDataView data, ITrainer trainer)
{
foreach (var kvp in Args.Transforms)
foreach (var kvp in ImplOptions.Transforms)
data = kvp.Value.CreateComponent(env, data);

var schema = data.Schema;
string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.WeightColumn), Args.WeightColumn, DefaultColumnNames.Weight);
string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.LabelColumn), ImplOptions.LabelColumn, DefaultColumnNames.Label);
string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features);
string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.WeightColumn), ImplOptions.WeightColumn, DefaultColumnNames.Weight);
string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);
string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);

TrainUtils.AddNormalizerIfNeeded(env, ch, trainer, ref data, features, Args.NormalizeFeatures);
TrainUtils.AddNormalizerIfNeeded(env, ch, trainer, ref data, features, ImplOptions.NormalizeFeatures);

// Training pipe and examples.
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumns);
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);

return new RoleMappedData(data, label, features, group, weight, name, customCols);
}
Expand All @@ -291,11 +291,11 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output
// If no stratification column was specified, but we have a group column of type Single, Double or
// Key (contiguous) use it.
string stratificationColumn = null;
if (!string.IsNullOrWhiteSpace(Args.StratificationColumn))
stratificationColumn = Args.StratificationColumn;
if (!string.IsNullOrWhiteSpace(ImplOptions.StratificationColumn))
stratificationColumn = ImplOptions.StratificationColumn;
else
{
string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
int index;
if (group != null && schema.TryGetColumnIndex(group, out index))
{
Expand Down
Loading