Skip to content

Replacing TryGetColumnIndex with invocations of GetColumnOrNull #2088

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ public BoundBase(SchemaBindablePipelineEnsembleBase parent, RoleMappedSchema sch
for (int i = 0; i < Parent._inputCols.Length; i++)
{
var name = Parent._inputCols[i];
if (!InputRoleMappedSchema.Schema.TryGetColumnIndex(name, out int col))
throw Parent.Host.Except("Schema does not contain required input column '{0}'", name);
_inputColIndices.Add(col);
var col = InputRoleMappedSchema.Schema.GetColumnOrNull(name);
if (!col.HasValue)
throw Parent.Host.ExceptSchemaMismatch(nameof(InputRoleMappedSchema), "input", name);
_inputColIndices.Add(col.Value.Index);
}

Mappers = new ISchemaBoundRowMapper[Parent.PredictorModels.Length];
Expand All @@ -74,8 +75,10 @@ public BoundBase(SchemaBindablePipelineEnsembleBase parent, RoleMappedSchema sch
throw Parent.Host.Except("Predictor {0} is not a row to row mapper", i);

// Make sure there is a score column, and remember its index.
if (!Mappers[i].OutputSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out ScoreCols[i]))
var scoreCol = Mappers[i].OutputSchema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score);
if (!scoreCol.HasValue)
throw Parent.Host.Except("Predictor {0} does not contain a score column", i);
ScoreCols[i] = scoreCol.Value.Index;

// Get the pipeline.
var dv = new EmptyDataView(Parent.Host, schema.Schema);
Expand Down
31 changes: 16 additions & 15 deletions src/Microsoft.ML.EntryPoints/FeatureCombiner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,15 @@ private static string GetTerms(IDataView data, string colName)
{
Contracts.AssertValue(data);
Contracts.AssertNonWhiteSpace(colName);
int col;
var schema = data.Schema;
if (!schema.TryGetColumnIndex(colName, out col))
var col = schema.GetColumnOrNull(colName);
if (!col.HasValue)
return null;
var type = schema[col].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type;
var type = col.Value.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type;
if (type == null || !type.IsKnownSizeVector || !(type.ItemType is TextType))
return null;
var metadata = default(VBuffer<ReadOnlyMemory<char>>);
schema[col].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref metadata);
col.Value.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref metadata);
if (!metadata.IsDense)
return null;
var sb = new StringBuilder();
Expand Down Expand Up @@ -231,10 +231,11 @@ public static CommonOutputs.TransformOutput PrepareClassificationLabel(IHostEnvi
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

int labelCol;
if (!input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out labelCol))
throw host.Except($"Column '{input.LabelColumn}' not found.");
var labelType = input.Data.Schema[labelCol].Type;
var labelCol = input.Data.Schema.GetColumnOrNull(input.LabelColumn);
if (!labelCol.HasValue)
throw host.ExceptSchemaMismatch(nameof(input), "Label", input.LabelColumn);

var labelType = labelCol.Value.Type;
if (labelType.IsKey || labelType is BoolType)
{
var nop = NopTransform.CreateIfNeeded(env, input.Data);
Expand Down Expand Up @@ -266,10 +267,10 @@ public static CommonOutputs.TransformOutput ConvertPredictedLabel(IHostEnvironme
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

int predictedLabelCol;
if (!input.Data.Schema.TryGetColumnIndex(input.PredictedLabelColumn, out predictedLabelCol))
throw host.Except($"Column '{input.PredictedLabelColumn}' not found.");
var predictedLabelType = input.Data.Schema[predictedLabelCol].Type;
var predictedLabelCol = input.Data.Schema.GetColumnOrNull(input.PredictedLabelColumn);
if (!predictedLabelCol.HasValue)
throw host.ExceptSchemaMismatch(nameof(input), "PredictedLabel",input.PredictedLabelColumn);
var predictedLabelType = predictedLabelCol.Value.Type;
if (predictedLabelType is NumberType || predictedLabelType is BoolType)
{
var nop = NopTransform.CreateIfNeeded(env, input.Data);
Expand All @@ -288,10 +289,10 @@ public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironm
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

int labelCol;
if (!input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out labelCol))
var labelCol = input.Data.Schema.GetColumnOrNull(input.LabelColumn);
if (!labelCol.HasValue)
throw host.Except($"Column '{input.LabelColumn}' not found.");
var labelType = input.Data.Schema[labelCol].Type;
var labelType = labelCol.Value.Type;
if (labelType == NumberType.R4 || !(labelType is NumberType))
{
var nop = NopTransform.CreateIfNeeded(env, input.Data);
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.FastTree/GamTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1285,8 +1285,8 @@ public Context(IChannel ch, GamModelParametersBase pred, RoleMappedData data, IE
new RoleMappedSchema.ColumnRole(MetadataUtils.Const.ScoreValueKind.Score).Bind(DefaultColumnNames.Score));
}

_data.Schema.Schema.TryGetColumnIndex(DefaultColumnNames.Features, out int featureIndex);
MetadataUtils.TryGetCategoricalFeatureIndices(_data.Schema.Schema, featureIndex, out _catsMap);
var featureCol = _data.Schema.Schema[DefaultColumnNames.Features];
MetadataUtils.TryGetCategoricalFeatureIndices(_data.Schema.Schema, featureCol.Index, out _catsMap);
}

public FeatureInfo GetInfoForIndex(int index) => FeatureInfo.GetInfoForIndex(this, index);
Expand Down
9 changes: 5 additions & 4 deletions src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -779,10 +779,11 @@ private static IDataView AppendLabelTransform(IHostEnvironment env, IChannel ch,
ch.AssertValue(input);
ch.AssertNonWhiteSpace(labelName);

int col;
if (!input.Schema.TryGetColumnIndex(labelName, out col))
throw ch.Except("Label column '{0}' not found.", labelName);
ColumnType labelType = input.Schema[col].Type;
var col = input.Schema.GetColumnOrNull(labelName);
if (!col.HasValue)
throw ch.ExceptSchemaMismatch(nameof(input), "Label", labelName);

ColumnType labelType = col.Value.Type;
if (!labelType.IsKey)
{
if (labelPermutationSeed != 0)
Expand Down
7 changes: 5 additions & 2 deletions src/Microsoft.ML.HalLearners/VectorWhitening.cs
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,12 @@ private static void GetColTypesAndIndex(IHostEnvironment env, IDataView inputDat

for (int i = 0; i < columns.Length; i++)
{
if (!inputSchema.TryGetColumnIndex(columns[i].Input, out cols[i]))
var col = inputSchema.GetColumnOrNull(columns[i].Input);
if (!col.HasValue)
throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].Input);
Copy link
Contributor

@TomFinley TomFinley Jan 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExceptSchemaMismatch [](start = 30, length = 20)

This is an example of an existing usage of ExceptSchemaMismatch; it is perhaps a pattern to follow in the other places where we are currently using Except. #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool!


In reply to: 246524124 [](ancestors = 246524124)

srcTypes[i] = inputSchema[cols[i]].Type;

cols[i] = col.Value.Index;
srcTypes[i] = col.Value.Type;
var reason = TestColumn(srcTypes[i]);
if (reason != null)
throw env.ExceptParam(nameof(inputData.Schema), reason);
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ private CategoricalMetaData GetCategoricalMetaData(IChannel ch, RoleMappedData t
ch.Info("Auto-tuning parameters: " + nameof(Args.UseCat) + " = " + useCat);
if (useCat)
{
trainData.Schema.Schema.TryGetColumnIndex(DefaultColumnNames.Features, out int featureIndex);
MetadataUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, featureIndex, out categoricalFeatures);
var featureCol = trainData.Schema.Schema[DefaultColumnNames.Features];
MetadataUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, featureCol.Index, out categoricalFeatures);
}
var colType = trainData.Schema.Feature.Value.Type;
int rawNumCol = colType.VectorSize;
Expand Down
6 changes: 4 additions & 2 deletions src/Microsoft.ML.OnnxTransform/OnnxTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,10 @@ public Mapper(OnnxTransform parent, Schema inputSchema) :
_inputTensorShapes[i] = inputShape.ToList();
_inputOnnxTypes[i] = inputNodeInfo.Type;

if (!inputSchema.TryGetColumnIndex(_parent.Inputs[i], out _inputColIndices[i]))
throw Host.Except($"Column {_parent.Inputs[i]} doesn't exist");
var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]);
if (!col.HasValue)
throw Host.ExceptSchemaMismatch( nameof(inputSchema),"input", _parent.Inputs[i]);
_inputColIndices[i] = col.Value.Index;

var type = inputSchema[_inputColIndices[i]].Type;
_isInputVector[i] = type.IsVector;
Expand Down
Loading