Skip to content

Replaces ChooseColumnsTransform and DropColumnsTransform with SelectColumnsTransform #1371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/Microsoft.ML.Data/Commands/SaveDataCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,10 @@ private void RunCore(IChannel ch)

if (!string.IsNullOrWhiteSpace(Args.Columns))
{
var args = new ChooseColumnsTransform.Arguments();
args.Column = Args.Columns
.Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries).Select(s => new ChooseColumnsTransform.Column() { Name = s }).ToArray();
if (Utils.Size(args.Column) > 0)
data = new ChooseColumnsTransform(Host, args, data);
var keepColumns = Args.Columns
.Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries).ToArray();
if (Utils.Size(keepColumns) > 0)
data = SelectColumnsTransform.CreateKeep(Host, data, keepColumns);
}

IDataSaver saver;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I
}

var copyColumn = new CopyColumnsTransform(env, copyCols.ToArray()).Transform(input.Data);
var dropColumn = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = copyCols.Select(c => c.Source).ToArray() }, copyColumn);
var dropColumn = SelectColumnsTransform.CreateDrop(env, copyColumn, copyCols.Select(c => c.Source).ToArray());
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, dropColumn, input.Data), OutputData = dropColumn };
}
}
Expand Down
75 changes: 29 additions & 46 deletions src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -703,59 +703,42 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
}
}

var args = new ChooseColumnsTransform.Arguments();
var cols = new List<ChooseColumnsTransform.Column>()
{
new ChooseColumnsTransform.Column()
{
Name = string.Format(FoldDrAtKFormat, _k),
Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtK
},
new ChooseColumnsTransform.Column()
{
Name = string.Format(FoldDrAtPFormat, _p),
Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr
},
new ChooseColumnsTransform.Column()
{
Name = string.Format(FoldDrAtNumAnomaliesFormat, numAnomalies),
Source=AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos
},
new ChooseColumnsTransform.Column()
{
Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK
},
new ChooseColumnsTransform.Column()
{
Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP
},
new ChooseColumnsTransform.Column()
{
Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos
},
new ChooseColumnsTransform.Column()
{
Name = BinaryClassifierEvaluator.Auc
}
};
var kFormatName = string.Format(FoldDrAtKFormat, _k);
var pFormatName = string.Format(FoldDrAtPFormat, _p);
var numAnomName = string.Format(FoldDrAtNumAnomaliesFormat, numAnomalies);

(string Source, string Name)[] cols =
{
(AnomalyDetectionEvaluator.OverallMetrics.DrAtK, kFormatName),
(AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr, pFormatName),
(AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos, numAnomName)
};

// List of columns to keep. Note that the order specified here determines the order of the output columns.
var colsToKeep = new List<string>();
colsToKeep.Add(kFormatName);
colsToKeep.Add(pFormatName);
colsToKeep.Add(numAnomName);
colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK);
colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP);
colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos);
colsToKeep.Add(BinaryClassifierEvaluator.Auc);

overall = new CopyColumnsTransform(Host, cols).Transform(overall);
IDataView fold = SelectColumnsTransform.CreateKeep(Host, overall, colsToKeep.ToArray());

args.Column = cols.ToArray();
IDataView fold = new ChooseColumnsTransform(Host, args, overall);
string weightedFold;
ch.Info(MetricWriter.GetPerFoldResults(Host, fold, out weightedFold));
}

/// <summary>
/// Wraps <paramref name="overall"/> in a column-dropping transform configured with the
/// NumAnomalies, ThreshAtK, ThreshAtP and ThreshAtNumPos overall-metric column names,
/// and returns the wrapped view.
/// </summary>
/// <param name="overall">The overall-metrics data view to wrap.</param>
/// <returns>The data view produced by the drop transform.</returns>
// NOTE(review): this text was captured from a diff view with the +/- markers stripped, so the
// body appears to contain both sides of the change: the DropColumnsTransform statements look
// like the removed lines and the SelectColumnsTransform.CreateDrop call looks like their
// replacement. Confirm against the merged file. Read literally, the first return wins and the
// second return is unreachable.
protected override IDataView GetOverallResultsCore(IDataView overall)
{
// Arguments-object form: the column names to drop are set on DropColumnsTransform.Arguments.Column.
var args = new DropColumnsTransform.Arguments();
args.Column = new[]
{
AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies,
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK,
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP,
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos
};
return new DropColumnsTransform(Host, args, overall);
// Factory form: the same four column names are passed directly to CreateDrop.
return SelectColumnsTransform.CreateDrop(Host,
overall,
AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies,
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK,
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP,
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos);
}

protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
Expand Down
58 changes: 23 additions & 35 deletions src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1333,43 +1333,33 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
if (!metrics.TryGetValue(MetricKinds.ConfusionMatrix, out conf))
throw ch.Except("No overall metrics found");

var args = new ChooseColumnsTransform.Arguments();
var cols = new List<ChooseColumnsTransform.Column>()
{
new ChooseColumnsTransform.Column()
{
Name = FoldAccuracy,
Source = BinaryClassifierEvaluator.Accuracy
},
new ChooseColumnsTransform.Column()
{
Name = FoldLogLoss,
Source = BinaryClassifierEvaluator.LogLoss
},
new ChooseColumnsTransform.Column()
{
Name = BinaryClassifierEvaluator.Entropy
},
new ChooseColumnsTransform.Column()
{
Name = FoldLogLosRed,
Source = BinaryClassifierEvaluator.LogLossReduction
},
new ChooseColumnsTransform.Column()
{
Name = BinaryClassifierEvaluator.Auc
}
};
(string Source, string Name)[] cols =
{
(BinaryClassifierEvaluator.Accuracy, FoldAccuracy),
(BinaryClassifierEvaluator.LogLoss, FoldLogLoss),
(BinaryClassifierEvaluator.LogLossReduction, FoldLogLosRed)
};

var colsToKeep = new List<string>();
colsToKeep.Add(FoldAccuracy);
colsToKeep.Add(FoldLogLoss);
colsToKeep.Add(BinaryClassifierEvaluator.Entropy);
colsToKeep.Add(FoldLogLosRed);
colsToKeep.Add(BinaryClassifierEvaluator.Auc);

int index;
if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out index))
cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.IsWeighted });
colsToKeep.Add(MetricKinds.ColumnNames.IsWeighted);
if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out index))
cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.StratCol });
colsToKeep.Add(MetricKinds.ColumnNames.StratCol);
if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out index))
cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.StratVal });
colsToKeep.Add(MetricKinds.ColumnNames.StratVal);

fold = new CopyColumnsTransform(Host, cols).Transform(fold);

// Keep only the columns listed in colsToKeep: the renamed copies plus the metric columns that are passed through unchanged.
fold = SelectColumnsTransform.CreateKeep(Host, fold, colsToKeep.ToArray());

args.Column = cols.ToArray();
fold = new ChooseColumnsTransform(Host, args, fold);
string weightedConf;
var unweightedConf = MetricWriter.GetConfusionTable(Host, conf, out weightedConf);
string weightedFold;
Expand All @@ -1386,9 +1376,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa

/// <summary>
/// Wraps <paramref name="overall"/> in a column-dropping transform configured with the
/// <c>BinaryClassifierEvaluator.Entropy</c> column name and returns the wrapped view.
/// </summary>
/// <param name="overall">The overall-metrics data view to wrap.</param>
/// <returns>The data view produced by the drop transform.</returns>
// NOTE(review): this text was captured from a diff view with the +/- markers stripped, so the
// body appears to contain both sides of the change: the DropColumnsTransform statements look
// like the removed lines and the SelectColumnsTransform.CreateDrop call looks like their
// replacement. Confirm against the merged file. Read literally, the second return is unreachable.
protected override IDataView GetOverallResultsCore(IDataView overall)
{
// Arguments-object form of the drop.
var args = new DropColumnsTransform.Arguments();
args.Column = new[] { BinaryClassifierEvaluator.Entropy };
return new DropColumnsTransform(Host, args, overall);
// Factory form of the same drop, given the same single column name.
return SelectColumnsTransform.CreateDrop(Host, overall, BinaryClassifierEvaluator.Entropy);
}

protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary<string, IDataView>[] metrics)
Expand Down
9 changes: 3 additions & 6 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string
variableSizeVectorColumnName, type);

// Drop the old column that does not have variable length.
idv = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = new[] { variableSizeVectorColumnName } }, idv);
idv = SelectColumnsTransform.CreateDrop(env, idv, variableSizeVectorColumnName);
}
return idv;
};
Expand Down Expand Up @@ -1059,8 +1059,7 @@ internal static IDataView GetOverallMetricsData(IHostEnvironment env, IDataView
{
if (Utils.Size(nonAveragedCols) > 0)
{
var dropArgs = new DropColumnsTransform.Arguments() { Column = nonAveragedCols.ToArray() };
data = new DropColumnsTransform(env, dropArgs, data);
data = SelectColumnsTransform.CreateDrop(env, data, nonAveragedCols.ToArray());
}
idvList.Add(data);
}
Expand Down Expand Up @@ -1734,9 +1733,7 @@ public static IDataView GetNonStratifiedMetrics(IHostEnvironment env, IDataView
var found = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal);
env.Check(found, "If stratification column exist, data view must also contain a StratVal column");

var dropArgs = new DropColumnsTransform.Arguments();
dropArgs.Column = new[] { data.Schema.GetColumnName(stratCol), data.Schema.GetColumnName(stratVal) };
data = new DropColumnsTransform(env, dropArgs, data);
data = SelectColumnsTransform.CreateDrop(env, data, data.Schema.GetColumnName(stratCol), data.Schema.GetColumnName(stratVal));
return data;
}
}
Expand Down
23 changes: 13 additions & 10 deletions src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,14 @@ private IDataView WrapPerInstance(RoleMappedData perInst)
var idv = perInst.Data;

// Make a list of column names that Maml outputs as part of the per-instance data view, and then wrap
// the per-instance data computed by the evaluator in a ChooseColumnsTransform.
var cols = new List<ChooseColumnsTransform.Column>();
// the per-instance data computed by the evaluator in a SelectColumnsTransform.
var cols = new List<(string Source, string Name)>();
var colsToKeep = new List<string>();

// If perInst is the result of cross-validation and contains a fold Id column, include it.
int foldCol;
if (perInst.Schema.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out foldCol))
cols.Add(new ChooseColumnsTransform.Column() { Source = MetricKinds.ColumnNames.FoldIndex });
colsToKeep.Add(MetricKinds.ColumnNames.FoldIndex);

// Maml always outputs a name column, if it doesn't exist add a GenerateNumberTransform.
if (perInst.Schema.Name == null)
Expand All @@ -228,22 +229,24 @@ private IDataView WrapPerInstance(RoleMappedData perInst)
args.Column = new[] { new GenerateNumberTransform.Column() { Name = "Instance" } };
args.UseCounter = true;
idv = new GenerateNumberTransform(Host, args, idv);
cols.Add(new ChooseColumnsTransform.Column() { Name = "Instance" });
colsToKeep.Add("Instance");
}
else
cols.Add(new ChooseColumnsTransform.Column() { Source = perInst.Schema.Name.Name, Name = "Instance" });
{
cols.Add((perInst.Schema.Name.Name, "Instance"));
colsToKeep.Add("Instance");
}

// Maml outputs the weight column if it exists.
if (perInst.Schema.Weight != null)
cols.Add(new ChooseColumnsTransform.Column() { Name = perInst.Schema.Weight.Name });
colsToKeep.Add(perInst.Schema.Weight.Name);

// Get the other columns from the evaluator.
foreach (var col in GetPerInstanceColumnsToSave(perInst.Schema))
cols.Add(new ChooseColumnsTransform.Column() { Name = col });
colsToKeep.Add(col);

var chooseArgs = new ChooseColumnsTransform.Arguments();
chooseArgs.Column = cols.ToArray();
idv = new ChooseColumnsTransform(Host, chooseArgs, idv);
idv = new CopyColumnsTransform(Host, cols.ToArray()).Transform(idv);
idv = SelectColumnsTransform.CreateKeep(Host, idv, colsToKeep.ToArray());
return GetPerInstanceMetricsCore(idv, perInst.Schema);
}

Expand Down
12 changes: 2 additions & 10 deletions src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1051,22 +1051,14 @@ protected override IDataView GetOverallResultsCore(IDataView overall)
/// <summary>
/// Copies the <c>MultiClassClassifierEvaluator.TopKAccuracy</c> column to a column named
/// <c>string.Format(TopKAccuracyFormat, _outputTopKAcc)</c>, then wraps the result in a
/// column-dropping transform configured with the original TopKAccuracy column name.
/// </summary>
/// <param name="input">The data view containing the TopKAccuracy column.</param>
/// <returns>The data view produced by the drop transform.</returns>
// NOTE(review): this text was captured from a diff view with the +/- markers stripped, so the
// body appears to contain both sides of the change: the DropColumnsTransform statements look
// like the removed lines and the SelectColumnsTransform.CreateDrop call looks like their
// replacement. Confirm against the merged file. Read literally, the second return is unreachable.
private IDataView ChangeTopKAccColumnName(IDataView input)
{
// The tuple is presumably (source column, new column name) — verify against CopyColumnsTransform.
input = new CopyColumnsTransform(Host, (MultiClassClassifierEvaluator.TopKAccuracy, string.Format(TopKAccuracyFormat, _outputTopKAcc))).Transform(input);
// Arguments-object form of the drop.
var dropArgs = new DropColumnsTransform.Arguments
{
Column = new[] { MultiClassClassifierEvaluator.TopKAccuracy }
};
return new DropColumnsTransform(Host, dropArgs, input);
// Factory form of the same drop, given the same column name.
return SelectColumnsTransform.CreateDrop(Host, input, MultiClassClassifierEvaluator.TopKAccuracy );
}

/// <summary>
/// If the schema of <paramref name="input"/> contains the
/// <c>MultiClassClassifierEvaluator.PerClassLogLoss</c> column, wraps the view in a
/// column-dropping transform configured with that column name; otherwise returns
/// <paramref name="input"/> unchanged.
/// </summary>
/// <param name="input">The data view to inspect.</param>
/// <returns>The wrapped view when the column is present, or the original view when it is not.</returns>
// NOTE(review): this text was captured from a diff view with the +/- markers stripped, so the
// body appears to contain both sides of the change: the DropColumnsTransform statements look
// like the removed lines and the SelectColumnsTransform.CreateDrop assignment looks like their
// replacement. Confirm against the merged file. Read literally, both assignments would run in
// sequence, so the second drop would be applied to a view already wrapped by the first.
private IDataView DropPerClassColumn(IDataView input)
{
// perClassCol is only needed to satisfy the out parameter; the index itself is not used below.
if (input.Schema.TryGetColumnIndex(MultiClassClassifierEvaluator.PerClassLogLoss, out int perClassCol))
{
// Arguments-object form of the drop.
var args = new DropColumnsTransform.Arguments
{
Column = new[] { MultiClassClassifierEvaluator.PerClassLogLoss }
};
input = new DropColumnsTransform(Host, args, input);
// Factory form of the same drop, given the same column name.
input = SelectColumnsTransform.CreateDrop(Host, input, MultiClassClassifierEvaluator.PerClassLogLoss);
}
return input;
}
Expand Down
Loading