Skip to content

Commit f85e722

Browse files
authored
Role mapped improvements (#496)
* RoleMappedSchema/Data change to use constructors * Nuke all create methods, pointless "no-roles" constructor. * Nuke TrainUtils.CreateExamples/CreateExamplesOpt * Opportunistically improve code quality and reporting of Kmeans++
1 parent 52cc874 commit f85e722

File tree

57 files changed

+387
-437
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+387
-437
lines changed

src/Microsoft.ML.Api/ComponentCreation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public static RoleMappedData CreateExamples(this IHostEnvironment env, IDataView
5252
env.CheckValueOrNull(weight);
5353
env.CheckValueOrNull(custom);
5454

55-
return TrainUtils.CreateExamples(data, label, features, group, weight, name: null, custom: custom);
55+
return new RoleMappedData(data, label, features, group, weight, name: null, custom: custom);
5656
}
5757

5858
/// <summary>

src/Microsoft.ML.Api/GenerateCodeCommand.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ public void Run()
108108
{
109109
var roles = ModelFileUtils.LoadRoleMappingsOrNull(_host, fs);
110110
scorer = roles != null
111-
? _host.CreateDefaultScorer(RoleMappedData.CreateOpt(transformPipe, roles), pred)
112-
: _host.CreateDefaultScorer(_host.CreateExamples(transformPipe, "Features"), pred);
111+
? _host.CreateDefaultScorer(new RoleMappedData(transformPipe, roles, opt: true), pred)
112+
: _host.CreateDefaultScorer(new RoleMappedData(transformPipe, label: null, "Features"), pred);
113113
}
114114

115115
var nonScoreSb = new StringBuilder();

src/Microsoft.ML.Api/PredictionEngine.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ internal BatchPredictionEngine(IHostEnvironment env, Stream modelStream, bool ig
4949
{
5050
var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, modelStream);
5151
pipe = roles != null
52-
? env.CreateDefaultScorer(RoleMappedData.CreateOpt(pipe, roles), predictor)
53-
: env.CreateDefaultScorer(env.CreateExamples(pipe, "Features"), predictor);
52+
? env.CreateDefaultScorer(new RoleMappedData(pipe, roles, opt: true), predictor)
53+
: env.CreateDefaultScorer(new RoleMappedData(pipe, label: null, "Features"), predictor);
5454
}
5555

5656
_pipeEngine = new PipeEngine<TDst>(env, pipe, ignoreMissingColumns, outputSchemaDefinition);

src/Microsoft.ML.Core/Data/MetadataUtils.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,6 @@ public static bool HasSlotNames(this ISchema schema, int col, int vectorSize)
312312
public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer<DvText> slotNames)
313313
{
314314
Contracts.CheckValueOrNull(schema);
315-
Contracts.CheckValue(role.Value, nameof(role));
316315
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
317316

318317
IReadOnlyList<ColumnInfo> list;

src/Microsoft.ML.Core/Data/RoleMappedSchema.cs

Lines changed: 218 additions & 143 deletions
Large diffs are not rendered by default.

src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ private RoleMappedData ApplyAllTransformsToData(IHostEnvironment env, IChannel c
254254
RoleMappedData srcData, IDataView marker)
255255
{
256256
var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, srcData.Data, dstData, marker);
257-
return RoleMappedData.Create(pipe, srcData.Schema.GetColumnRoleNames());
257+
return new RoleMappedData(pipe, srcData.Schema.GetColumnRoleNames());
258258
}
259259

260260
/// <summary>
@@ -277,7 +277,7 @@ private RoleMappedData CreateRoleMappedData(IHostEnvironment env, IChannel ch, I
277277
// Training pipe and examples.
278278
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
279279

280-
return TrainUtils.CreateExamples(data, label, features, group, weight, name, customCols);
280+
return new RoleMappedData(data, label, features, group, weight, name, customCols);
281281
}
282282

283283
private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
@@ -568,7 +568,7 @@ private FoldResult RunFold(int fold)
568568
{
569569
using (var file = host.CreateOutputFile(modelFileName))
570570
{
571-
var rmd = RoleMappedData.Create(
571+
var rmd = new RoleMappedData(
572572
CompositeDataLoader.ApplyTransform(host, _loader, null, null,
573573
(e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
574574
trainData.Schema.GetColumnRoleNames());
@@ -581,17 +581,17 @@ private FoldResult RunFold(int fold)
581581
if (!evalComp.IsGood())
582582
evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
583583
var eval = evalComp.CreateInstance(host);
584-
// Note that this doesn't require the provided columns to exist (because of "Opt").
584+
// Note that this doesn't require the provided columns to exist (because of the "opt" parameter).
585585
// We don't normally expect the scorer to drop columns, but if it does, we should not require
586586
// all the columns in the test pipeline to still be present.
587-
var dataEval = RoleMappedData.CreateOpt(scorePipe, testData.Schema.GetColumnRoleNames());
587+
var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true);
588588

589589
var dict = eval.Evaluate(dataEval);
590590
RoleMappedData perInstance = null;
591591
if (_savePerInstance)
592592
{
593593
var perInst = eval.GetPerInstanceMetrics(dataEval);
594-
perInstance = RoleMappedData.CreateOpt(perInst, dataEval.Schema.GetColumnRoleNames());
594+
perInstance = new RoleMappedData(perInst, dataEval.Schema.GetColumnRoleNames(), opt: true);
595595
}
596596
ch.Done();
597597
return new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema);

src/Microsoft.ML.Data/Commands/DataCommand.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ protected void LoadModelObjects(
305305
// can be loaded with no data at all, to get their schemas.
306306
if (trainPipe == null)
307307
trainPipe = ModelFileUtils.LoadLoader(Host, rep, new MultiFileSource(null), loadTransforms: true);
308-
trainSchema = RoleMappedSchema.Create(trainPipe.Schema, trainRoleMappings);
308+
trainSchema = new RoleMappedSchema(trainPipe.Schema, trainRoleMappings);
309309
}
310310
// If the role mappings are null, an alternative would be to fail. However the idea
311311
// is that the scorer should always still succeed, although perhaps with reduced

src/Microsoft.ML.Data/Commands/EvaluateCommand.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
158158
evalComp = EvaluateUtils.GetEvaluatorType(ch, input.Schema);
159159

160160
var eval = evalComp.CreateInstance(env);
161-
var data = TrainUtils.CreateExamples(input, label, null, group, weight, null, customCols);
161+
var data = new RoleMappedData(input, label, null, group, weight, null, customCols);
162162
return eval.GetPerInstanceMetrics(data);
163163
}
164164
}
@@ -236,7 +236,7 @@ private void RunCore(IChannel ch)
236236
if (!evalComp.IsGood())
237237
evalComp = EvaluateUtils.GetEvaluatorType(ch, view.Schema);
238238
var evaluator = evalComp.CreateInstance(Host);
239-
var data = TrainUtils.CreateExamples(view, label, null, group, weight, name, customCols);
239+
var data = new RoleMappedData(view, label, null, group, weight, name, customCols);
240240
var metrics = evaluator.Evaluate(data);
241241
MetricWriter.PrintWarnings(ch, metrics);
242242
evaluator.PrintFoldResults(ch, metrics);
@@ -248,7 +248,7 @@ private void RunCore(IChannel ch)
248248
if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
249249
{
250250
var perInst = evaluator.GetPerInstanceMetrics(data);
251-
var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols);
251+
var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
252252
var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
253253
MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
254254
}

src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ public static void LoadModel(IHostEnvironment env, Stream modelStream, bool load
219219
if (roles != null)
220220
{
221221
var emptyView = ModelFileUtils.LoadPipeline(env, rep, new MultiFileSource(null));
222-
schema = RoleMappedSchema.CreateOpt(emptyView.Schema, roles);
222+
schema = new RoleMappedSchema(emptyView.Schema, roles, opt: true);
223223
}
224224
else
225225
{

src/Microsoft.ML.Data/Commands/ScoreCommand.cs

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,7 @@ private void RunCore(IChannel ch)
9797

9898
ch.Trace("Creating loader");
9999

100-
IPredictor predictor;
101-
IDataLoader loader;
102-
RoleMappedSchema trainSchema;
103-
LoadModelObjects(ch, true, out predictor, true, out trainSchema, out loader);
100+
LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var loader);
104101
ch.AssertValue(predictor);
105102
ch.AssertValueOrNull(trainSchema);
106103
ch.AssertValue(loader);
@@ -116,7 +113,7 @@ private void RunCore(IChannel ch)
116113
string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
117114
nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
118115
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
119-
var schema = TrainUtils.CreateRoleMappedSchemaOpt(loader.Schema, feat, group, customCols);
116+
var schema = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true);
120117
var mapper = bindable.Bind(Host, schema);
121118

122119
if (!scorer.IsGood())
@@ -153,22 +150,20 @@ private void RunCore(IChannel ch)
153150
Args.OutputAllColumns == true || Utils.Size(Args.OutputColumn) == 0;
154151

155152
if (Args.OutputAllColumns == true && Utils.Size(Args.OutputColumn) != 0)
156-
ch.Warning("outputAllColumns=+ always writes all columns irrespective of outputColumn specified.");
153+
ch.Warning(nameof(Args.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(Args.OutputColumn) + " specified.");
157154

158155
if (!outputAllColumns && Utils.Size(Args.OutputColumn) != 0)
159156
{
160157
foreach (var outCol in Args.OutputColumn)
161158
{
162-
int dummyColIndex;
163-
if (!loader.Schema.TryGetColumnIndex(outCol, out dummyColIndex))
159+
if (!loader.Schema.TryGetColumnIndex(outCol, out int dummyColIndex))
164160
throw ch.ExceptUserArg(nameof(Arguments.OutputColumn), "Column '{0}' not found.", outCol);
165161
}
166162
}
167163

168-
int colMax;
169164
uint maxScoreId = 0;
170165
if (!outputAllColumns)
171-
maxScoreId = loader.Schema.GetMaxMetadataKind(out colMax, MetadataUtils.Kinds.ScoreColumnSetId);
166+
maxScoreId = loader.Schema.GetMaxMetadataKind(out int colMax, MetadataUtils.Kinds.ScoreColumnSetId);
172167
ch.Assert(outputAllColumns || maxScoreId > 0); // score set IDs are one-based
173168
var cols = new List<int>();
174169
for (int i = 0; i < loader.Schema.ColumnCount; i++)
@@ -211,12 +206,12 @@ private bool ShouldAddColumn(ISchema schema, int i, uint scoreSet, bool outputNa
211206
{
212207
switch (schema.GetColumnName(i))
213208
{
214-
case "Label":
215-
case "Name":
216-
case "Names":
217-
return true;
218-
default:
219-
break;
209+
case "Label":
210+
case "Name":
211+
case "Names":
212+
return true;
213+
default:
214+
break;
220215
}
221216
}
222217
if (Args.OutputColumn != null && Array.FindIndex(Args.OutputColumn, schema.GetColumnName(i).Equals) >= 0)
@@ -229,8 +224,7 @@ public static class ScoreUtils
229224
{
230225
public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedData data, IHostEnvironment env, RoleMappedSchema trainSchema)
231226
{
232-
ISchemaBoundMapper mapper;
233-
var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, out mapper);
227+
var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, out var mapper);
234228
return sc.CreateInstance(env, data.Data, mapper, trainSchema);
235229
}
236230

@@ -247,9 +241,8 @@ public static IDataScorerTransform GetScorer(SubComponent<IDataScorerTransform,
247241
env.CheckValueOrNull(customColumns);
248242
env.CheckValueOrNull(trainSchema);
249243

250-
var schema = TrainUtils.CreateRoleMappedSchemaOpt(input.Schema, featureColName, groupColName, customColumns);
251-
ISchemaBoundMapper mapper;
252-
var sc = GetScorerComponentAndMapper(predictor, scorer, schema, env, out mapper);
244+
var schema = new RoleMappedSchema(input.Schema, label: null, feature: featureColName, group: groupColName, custom: customColumns, opt: true);
245+
var sc = GetScorerComponentAndMapper(predictor, scorer, schema, env, out var mapper);
253246
return sc.CreateInstance(env, input, mapper, trainSchema);
254247
}
255248

@@ -280,7 +273,7 @@ public static SubComponent<IDataScorerTransform, SignatureDataScorer> GetScorerC
280273
Contracts.AssertValue(mapper);
281274

282275
string loadName = null;
283-
DvText scoreKind = default(DvText);
276+
DvText scoreKind = default;
284277
if (mapper.OutputSchema.ColumnCount > 0 &&
285278
mapper.OutputSchema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreColumnKind, 0, ref scoreKind) &&
286279
scoreKind.HasChars)
@@ -311,10 +304,8 @@ public static ISchemaBindableMapper GetSchemaBindableMapper(IHostEnvironment env
311304
env.CheckValue(predictor, nameof(predictor));
312305
env.CheckValueOrNull(scorerSettings);
313306

314-
ISchemaBindableMapper bindable;
315-
316307
// See if we can instantiate a mapper using scorer arguments.
317-
if (scorerSettings.IsGood() && TryCreateBindableFromScorer(env, predictor, scorerSettings, out bindable))
308+
if (scorerSettings.IsGood() && TryCreateBindableFromScorer(env, predictor, scorerSettings, out var bindable))
318309
return bindable;
319310

320311
// The easy case is that the predictor implements the interface.

src/Microsoft.ML.Data/Commands/TestCommand.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ private void RunCore(IChannel ch)
114114
if (!evalComp.IsGood())
115115
evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
116116
var evaluator = evalComp.CreateInstance(Host);
117-
var data = TrainUtils.CreateExamples(scorePipe, label, null, group, weight, name, customCols);
117+
var data = new RoleMappedData(scorePipe, label, null, group, weight, name, customCols);
118118
var metrics = evaluator.Evaluate(data);
119119
MetricWriter.PrintWarnings(ch, metrics);
120120
evaluator.PrintFoldResults(ch, metrics);
@@ -128,7 +128,7 @@ private void RunCore(IChannel ch)
128128
if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
129129
{
130130
var perInst = evaluator.GetPerInstanceMetrics(data);
131-
var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols);
131+
var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
132132
var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
133133
MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
134134
}

0 commit comments

Comments
 (0)