diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index 0c2655fff1..639d4df13a 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -53,9 +53,10 @@ public BoundBase(SchemaBindablePipelineEnsembleBase parent, RoleMappedSchema sch for (int i = 0; i < Parent._inputCols.Length; i++) { var name = Parent._inputCols[i]; - if (!InputRoleMappedSchema.Schema.TryGetColumnIndex(name, out int col)) - throw Parent.Host.Except("Schema does not contain required input column '{0}'", name); - _inputColIndices.Add(col); + var col = InputRoleMappedSchema.Schema.GetColumnOrNull(name); + if (!col.HasValue) + throw Parent.Host.ExceptSchemaMismatch(nameof(InputRoleMappedSchema), "input", name); + _inputColIndices.Add(col.Value.Index); } Mappers = new ISchemaBoundRowMapper[Parent.PredictorModels.Length]; @@ -74,8 +75,10 @@ public BoundBase(SchemaBindablePipelineEnsembleBase parent, RoleMappedSchema sch throw Parent.Host.Except("Predictor {0} is not a row to row mapper", i); // Make sure there is a score column, and remember its index. - if (!Mappers[i].OutputSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out ScoreCols[i])) + var scoreCol = Mappers[i].OutputSchema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + if (!scoreCol.HasValue) throw Parent.Host.Except("Predictor {0} does not contain a score column", i); + ScoreCols[i] = scoreCol.Value.Index; // Get the pipeline. var dv = new EmptyDataView(Parent.Host, schema.Schema); diff --git a/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs index ed3a554a64..00836338ff 100644 --- a/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs @@ -115,15 +115,15 @@ private static string GetTerms(IDataView data, string colName) { Contracts.AssertValue(data); Contracts.AssertNonWhiteSpace(colName); - int col; var schema = data.Schema; - if (!schema.TryGetColumnIndex(colName, out col)) + var col = schema.GetColumnOrNull(colName); + if (!col.HasValue) return null; - var type = schema[col].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; + var type = col.Value.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (type == null || !type.IsKnownSizeVector || !(type.ItemType is TextType)) return null; var metadata = default(VBuffer>); - schema[col].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref metadata); + col.Value.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref metadata); if (!metadata.IsDense) return null; var sb = new StringBuilder(); @@ -231,10 +231,11 @@ public static CommonOutputs.TransformOutput PrepareClassificationLabel(IHostEnvi host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - int labelCol; - if (!input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out labelCol)) - throw host.Except($"Column '{input.LabelColumn}' not found."); - var labelType = input.Data.Schema[labelCol].Type; + var labelCol = input.Data.Schema.GetColumnOrNull(input.LabelColumn); + if (!labelCol.HasValue) + throw host.ExceptSchemaMismatch(nameof(input), "Label", input.LabelColumn); + + var labelType = labelCol.Value.Type; if (labelType.IsKey || labelType is BoolType) { var nop = NopTransform.CreateIfNeeded(env, input.Data); @@ -266,10 +267,10 @@ public static CommonOutputs.TransformOutput ConvertPredictedLabel(IHostEnvironme host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - int predictedLabelCol; - if (!input.Data.Schema.TryGetColumnIndex(input.PredictedLabelColumn, out predictedLabelCol)) - throw host.Except($"Column '{input.PredictedLabelColumn}' not found."); - var predictedLabelType = input.Data.Schema[predictedLabelCol].Type; + var predictedLabelCol = input.Data.Schema.GetColumnOrNull(input.PredictedLabelColumn); + if (!predictedLabelCol.HasValue) + throw host.ExceptSchemaMismatch(nameof(input), "PredictedLabel",input.PredictedLabelColumn); + var predictedLabelType = predictedLabelCol.Value.Type; if (predictedLabelType is NumberType || predictedLabelType is BoolType) { var nop = NopTransform.CreateIfNeeded(env, input.Data); @@ -288,10 +289,10 @@ public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironm host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - int labelCol; - if (!input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out labelCol)) + var labelCol = input.Data.Schema.GetColumnOrNull(input.LabelColumn); + if (!labelCol.HasValue) throw host.Except($"Column '{input.LabelColumn}' not found."); - var labelType = input.Data.Schema[labelCol].Type; + var labelType = labelCol.Value.Type; if (labelType == NumberType.R4 || !(labelType is NumberType)) { var nop = NopTransform.CreateIfNeeded(env, input.Data); diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index f779580a8d..ba8b1242b4 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -1285,8 +1285,8 @@ public Context(IChannel ch, GamModelParametersBase pred, RoleMappedData data, IE new RoleMappedSchema.ColumnRole(MetadataUtils.Const.ScoreValueKind.Score).Bind(DefaultColumnNames.Score)); } - _data.Schema.Schema.TryGetColumnIndex(DefaultColumnNames.Features, out int featureIndex); - MetadataUtils.TryGetCategoricalFeatureIndices(_data.Schema.Schema, featureIndex, out _catsMap); + var featureCol = _data.Schema.Schema[DefaultColumnNames.Features]; + MetadataUtils.TryGetCategoricalFeatureIndices(_data.Schema.Schema, featureCol.Index, out _catsMap); } public FeatureInfo GetInfoForIndex(int index) => FeatureInfo.GetInfoForIndex(this, index); diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index fd0fbd832c..be7f501a4e 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -779,10 +779,11 @@ private static IDataView AppendLabelTransform(IHostEnvironment env, IChannel ch, ch.AssertValue(input); ch.AssertNonWhiteSpace(labelName); - int col; - if (!input.Schema.TryGetColumnIndex(labelName, out col)) - throw ch.Except("Label column '{0}' not found.", labelName); - ColumnType labelType = input.Schema[col].Type; + var col = input.Schema.GetColumnOrNull(labelName); + if (!col.HasValue) + throw ch.ExceptSchemaMismatch(nameof(input), "Label", labelName); + + ColumnType labelType = col.Value.Type; if (!labelType.IsKey) { if (labelPermutationSeed != 0) diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs index 9d7daf09ec..23c161ea00 100644 --- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs +++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs @@ -381,9 +381,12 @@ private static void GetColTypesAndIndex(IHostEnvironment env, IDataView inputDat for (int i = 0; i < columns.Length; i++) { - if (!inputSchema.TryGetColumnIndex(columns[i].Input, out cols[i])) + var col = inputSchema.GetColumnOrNull(columns[i].Input); + if (!col.HasValue) throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].Input); - srcTypes[i] = inputSchema[cols[i]].Type; + + cols[i] = col.Value.Index; + srcTypes[i] = col.Value.Type; var reason = TestColumn(srcTypes[i]); if (reason != null) throw env.ExceptParam(nameof(inputData.Schema), reason); diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index ec25a73c30..6d9987c35d 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -285,8 +285,8 @@ private CategoricalMetaData GetCategoricalMetaData(IChannel ch, RoleMappedData t ch.Info("Auto-tuning parameters: " + nameof(Args.UseCat) + " = " + useCat); if (useCat) { - trainData.Schema.Schema.TryGetColumnIndex(DefaultColumnNames.Features, out int featureIndex); - MetadataUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, featureIndex, out categoricalFeatures); + var featureCol = trainData.Schema.Schema[DefaultColumnNames.Features]; + MetadataUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, featureCol.Index, out categoricalFeatures); } var colType = trainData.Schema.Feature.Value.Type; int rawNumCol = colType.VectorSize; diff --git a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs index 05a792260c..dc04d13355 100644 --- a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs @@ -348,8 +348,10 @@ public Mapper(OnnxTransform parent, Schema inputSchema) : _inputTensorShapes[i] = inputShape.ToList(); _inputOnnxTypes[i] = inputNodeInfo.Type; - if (!inputSchema.TryGetColumnIndex(_parent.Inputs[i], out _inputColIndices[i])) - throw Host.Except($"Column {_parent.Inputs[i]} doesn't exist"); + var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]); + if (!col.HasValue) + throw Host.ExceptSchemaMismatch( nameof(inputSchema),"input", _parent.Inputs[i]); + _inputColIndices[i] = col.Value.Index; var type = inputSchema[_inputColIndices[i]].Type; _isInputVector[i] = type.IsVector; diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index b7c883d53b..e5996b00be 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -500,23 +500,23 @@ public void EntryPointCreateEnsemble() using (var curs3 = regScored.GetRowCursor(col => true)) using (var curs4 = zippedScores.GetRowCursor(col => true)) { - var found = curs1.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out int scoreCol); - Assert.True(found); - var avgScoreGetter = curs1.GetGetter(scoreCol); + var scoreColumn = curs1.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var avgScoreGetter = curs1.GetGetter(scoreColumn.Value.Index); - found = curs2.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreCol); - Assert.True(found); - var medScoreGetter = curs2.GetGetter(scoreCol); + scoreColumn = curs2.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var medScoreGetter = curs2.GetGetter(scoreColumn.Value.Index); - found = curs3.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreCol); - Assert.True(found); - var regScoreGetter = curs3.GetGetter(scoreCol); + scoreColumn = curs3.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var regScoreGetter = curs3.GetGetter(scoreColumn.Value.Index); var individualScoreGetters = new ValueGetter[nModels]; for (int i = 0; i < nModels; i++) { - curs4.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score + i, out scoreCol); - individualScoreGetters[i] = curs4.GetGetter(scoreCol); + scoreColumn = curs4.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score + i); + individualScoreGetters[i] = curs4.GetGetter(scoreColumn.Value.Index); } var scoreBuffer = new Single[nModels]; @@ -845,28 +845,28 @@ public void EntryPointPipelineEnsemble() }).ScoredData; // Make sure the scorers have the correct types. - var hasScoreCol = binaryScored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out int scoreIndex); - Assert.True(hasScoreCol, "Data scored with binary ensemble does not have a score column"); - var type = binaryScored.Schema[scoreIndex].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreColumnKind)?.Type; + var scoreCol = binaryScored.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreCol.HasValue, "Data scored with binary ensemble does not have a score column"); + var type = binaryScored.Schema[scoreCol.Value.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreColumnKind)?.Type; Assert.True(type is TextType, "Binary ensemble scored data does not have correct type of metadata."); var kind = default(ReadOnlyMemory); - binaryScored.Schema[scoreIndex].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref kind); + binaryScored.Schema[scoreCol.Value.Index].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref kind); Assert.True(ReadOnlyMemoryUtils.EqualsStr(MetadataUtils.Const.ScoreColumnKind.BinaryClassification, kind), $"Binary ensemble scored data column type should be '{MetadataUtils.Const.ScoreColumnKind.BinaryClassification}', but is instead '{kind}'"); - hasScoreCol = regressionScored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIndex); - Assert.True(hasScoreCol, "Data scored with regression ensemble does not have a score column"); - type = regressionScored.Schema[scoreIndex].Metadata.Schema[MetadataUtils.Kinds.ScoreColumnKind].Type; + scoreCol = regressionScored.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreCol.HasValue, "Data scored with regression ensemble does not have a score column"); + type = regressionScored.Schema[scoreCol.Value.Index].Metadata.Schema[MetadataUtils.Kinds.ScoreColumnKind].Type; Assert.True(type is TextType, "Regression ensemble scored data does not have correct type of metadata."); - regressionScored.Schema[scoreIndex].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref kind); + regressionScored.Schema[scoreCol.Value.Index].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref kind); Assert.True(ReadOnlyMemoryUtils.EqualsStr(MetadataUtils.Const.ScoreColumnKind.Regression, kind), $"Regression ensemble scored data column type should be '{MetadataUtils.Const.ScoreColumnKind.Regression}', but is instead '{kind}'"); - hasScoreCol = anomalyScored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIndex); - Assert.True(hasScoreCol, "Data scored with anomaly detection ensemble does not have a score column"); - type = anomalyScored.Schema[scoreIndex].Metadata.Schema[MetadataUtils.Kinds.ScoreColumnKind].Type; + scoreCol = anomalyScored.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreCol.HasValue, "Data scored with anomaly detection ensemble does not have a score column"); + type = anomalyScored.Schema[scoreCol.Value.Index].Metadata.Schema[MetadataUtils.Kinds.ScoreColumnKind].Type; Assert.True(type is TextType, "Anomaly detection ensemble scored data does not have correct type of metadata."); - anomalyScored.Schema[scoreIndex].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref kind); + anomalyScored.Schema[scoreCol.Value.Index].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref kind); Assert.True(ReadOnlyMemoryUtils.EqualsStr(MetadataUtils.Const.ScoreColumnKind.AnomalyDetection, kind), $"Anomaly detection ensemble scored data column type should be '{MetadataUtils.Const.ScoreColumnKind.AnomalyDetection}', but is instead '{kind}'"); @@ -898,36 +898,36 @@ public void EntryPointPipelineEnsemble() using (var curs4 = individualScores[4].GetRowCursor(col => true)) using (var cursSaved = scoredFromSaved.GetRowCursor(col => true)) { - var good = curs0.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out int col); - Assert.True(good); - var getter0 = curs0.GetGetter(col); - good = curs1.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter1 = curs1.GetGetter(col); - good = curs2.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter2 = curs2.GetGetter(col); - good = curs3.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter3 = curs3.GetGetter(col); - good = curs4.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter4 = curs4.GetGetter(col); - good = cursReg.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getterReg = cursReg.GetGetter(col); - good = cursBin.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getterBin = cursBin.GetGetter(col); - good = cursBinCali.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getterBinCali = cursBinCali.GetGetter(col); - good = cursSaved.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getterSaved = cursSaved.GetGetter(col); - good = cursAnom.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getterAnom = cursAnom.GetGetter(col); + var scoreColumn = curs0.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter0 = curs0.GetGetter(scoreColumn.Value.Index); + scoreColumn = curs1.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter1 = curs1.GetGetter(scoreColumn.Value.Index); + scoreColumn = curs2.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter2 = curs2.GetGetter(scoreColumn.Value.Index); + scoreColumn = curs3.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter3 = curs3.GetGetter(scoreColumn.Value.Index); + scoreColumn = curs4.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter4 = curs4.GetGetter(scoreColumn.Value.Index); + scoreColumn = cursReg.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getterReg = cursReg.GetGetter(scoreColumn.Value.Index); + scoreColumn = cursBin.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getterBin = cursBin.GetGetter(scoreColumn.Value.Index); + scoreColumn = cursBinCali.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getterBinCali = cursBinCali.GetGetter(scoreColumn.Value.Index); + scoreColumn = cursSaved.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getterSaved = cursSaved.GetGetter(scoreColumn.Value.Index); + scoreColumn = cursAnom.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getterAnom = cursAnom.GetGetter(scoreColumn.Value.Index); var c = new Average(Env).GetCombiner(); while (cursReg.MoveNext()) @@ -1121,33 +1121,33 @@ public void EntryPointPipelineEnsembleText() using (var curs4 = individualScores[4].GetRowCursor(col => true)) using (var cursSaved = scoredFromSaved.GetRowCursor(col => true)) { - var good = curs0.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out int col); - Assert.True(good); - var getter0 = curs0.GetGetter(col); - good = curs1.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter1 = curs1.GetGetter(col); - good = curs2.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter2 = curs2.GetGetter(col); - good = curs3.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter3 = curs3.GetGetter(col); - good = curs4.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter4 = curs4.GetGetter(col); - good = cursReg.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getterReg = cursReg.GetGetter(col); - good = cursBin.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getterBin = cursBin.GetGetter(col); - good = cursBinCali.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getterBinCali = cursBinCali.GetGetter(col); - good = cursSaved.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getterSaved = cursSaved.GetGetter(col); + var scoreColumn = curs0.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter0 = curs0.GetGetter(scoreColumn.Value.Index); + scoreColumn = curs1.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter1 = curs1.GetGetter(scoreColumn.Value.Index); + scoreColumn = curs2.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter2 = curs2.GetGetter(scoreColumn.Value.Index); + scoreColumn = curs3.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter3 = curs3.GetGetter(scoreColumn.Value.Index); + scoreColumn = curs4.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter4 = curs4.GetGetter(scoreColumn.Value.Index); + scoreColumn = cursReg.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getterReg = cursReg.GetGetter(scoreColumn.Value.Index); + scoreColumn = cursBin.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getterBin = cursBin.GetGetter(scoreColumn.Value.Index); + scoreColumn = cursBinCali.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getterBinCali = cursBinCali.GetGetter(scoreColumn.Value.Index); + scoreColumn = cursSaved.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getterSaved = cursSaved.GetGetter(scoreColumn.Value.Index); var c = new Average(Env).GetCombiner(); while (cursReg.MoveNext()) @@ -1282,27 +1282,27 @@ public void EntryPointMulticlassPipelineEnsemble() using (var curs3 = individualScores[3].GetRowCursor(col => true)) using (var curs4 = individualScores[4].GetRowCursor(col => true)) { - var good = curs0.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out int col); - Assert.True(good); - var getter0 = curs0.GetGetter>(col); - good = curs1.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter1 = curs1.GetGetter>(col); - good = curs2.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter2 = curs2.GetGetter>(col); - good = curs3.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter3 = curs3.GetGetter>(col); - good = curs4.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter4 = curs4.GetGetter>(col); - good = curs.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getter = curs.GetGetter>(col); - good = cursSaved.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out col); - Assert.True(good); - var getterSaved = cursSaved.GetGetter>(col); + var scoreColumn = curs0.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter0 = curs0.GetGetter>(scoreColumn.Value.Index); + scoreColumn = curs1.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter1 = curs1.GetGetter>(scoreColumn.Value.Index); + scoreColumn = curs2.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter2 = curs2.GetGetter>(scoreColumn.Value.Index); + scoreColumn = curs3.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter3 = curs3.GetGetter>(scoreColumn.Value.Index); + scoreColumn = curs4.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter4 = curs4.GetGetter>(scoreColumn.Value.Index); + scoreColumn = curs.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getter = curs.GetGetter>(scoreColumn.Value.Index); + scoreColumn = cursSaved.Schema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); + Assert.True(scoreColumn.HasValue); + var getterSaved = cursSaved.GetGetter>(scoreColumn.Value.Index); var c = new MultiAverage(Env, new MultiAverage.Arguments()).GetCombiner(); VBuffer score = default(VBuffer); @@ -1368,7 +1368,7 @@ public void EntryPointPipelineEnsembleGetSummary() for (int i = 0; i < nModels; i++) { var data = splitOutput.TrainData[i]; - data = new OneHotEncodingEstimator(Env,"Cat").Fit(data).Transform(data); + data = new OneHotEncodingEstimator(Env, "Cat").Fit(data).Transform(data); data = new ColumnConcatenatingTransformer(Env, new ColumnConcatenatingTransformer.ColumnInfo("Features", i % 2 == 0 ? new[] { "Features", "Cat" } : new[] { "Cat", "Features" })).Transform(data); if (i % 2 == 0) { @@ -1377,9 +1377,9 @@ public void EntryPointPipelineEnsembleGetSummary() TrainingData = data, NormalizeFeatures = NormalizeOption.Yes, NumThreads = 1, - ShowTrainingStats = true, + ShowTrainingStats = true, StdComputer = new ComputeLRTrainingStdThroughHal() - }; + }; predictorModels[i] = LogisticRegression.TrainBinary(Env, lrInput).PredictorModel; var transformModel = new TransformModelImpl(Env, data, splitOutput.TrainData[i]); @@ -1691,15 +1691,15 @@ public void EntryPointTextToKeyToText() ReadOnlyMemory catValue = default; uint catKey = 0; - bool success = loader.Schema.TryGetColumnIndex("Cat", out int catCol); - Assert.True(success); - var catGetter = cursor.GetGetter>(catCol); - success = loader.Schema.TryGetColumnIndex("CatValue", out int catValueCol); - Assert.True(success); - var catValueGetter = cursor.GetGetter>(catValueCol); - success = loader.Schema.TryGetColumnIndex("Key", out int keyCol); - Assert.True(success); - var keyGetter = cursor.GetGetter(keyCol); + var catColumn = loader.Schema.GetColumnOrNull("Cat"); + Assert.True(catColumn.HasValue); + var catGetter = cursor.GetGetter>(catColumn.Value.Index); + var catValueCol = loader.Schema.GetColumnOrNull("CatValue"); + Assert.True(catValueCol.HasValue); + var catValueGetter = cursor.GetGetter>(catValueCol.Value.Index); + var keyColumn = loader.Schema.GetColumnOrNull("Key"); + Assert.True(keyColumn.HasValue); + var keyGetter = cursor.GetGetter(keyColumn.Value.Index); while (cursor.MoveNext()) { @@ -1938,8 +1938,8 @@ public void EntryPointEvaluateRanking() using (var loader = new BinaryLoader(Env, new BinaryLoader.Arguments(), instanceMetricsPath)) { Assert.Equal(103, CountRows(loader)); - Assert.True(loader.Schema.TryGetColumnIndex("GroupId", out var groupCol)); - Assert.True(loader.Schema.TryGetColumnIndex("Label", out var labelCol)); + Assert.NotNull(loader.Schema.GetColumnOrNull("GroupId")); + Assert.NotNull(loader.Schema.GetColumnOrNull("Label")); } } @@ -2737,8 +2737,9 @@ public void EntryPointTrainTestMacroNoTransformInput() Assert.NotNull(metrics); using (var cursor = metrics.GetRowCursor(col => true)) { - Assert.True(cursor.Schema.TryGetColumnIndex("AUC", out int aucCol)); - var aucGetter = cursor.GetGetter(aucCol); + var aucCol = cursor.Schema.GetColumnOrNull("AUC"); + Assert.True(aucCol.HasValue); + var aucGetter = cursor.GetGetter(aucCol.Value.Index); Assert.True(cursor.MoveNext()); double auc = 0; aucGetter(ref auc); @@ -2841,8 +2842,9 @@ public void EntryPointTrainTestMacro() Assert.NotNull(metrics); using (var cursor = metrics.GetRowCursor(col => true)) { - Assert.True(cursor.Schema.TryGetColumnIndex("AUC", out int aucCol)); - var aucGetter = cursor.GetGetter(aucCol); + var aucCol = cursor.Schema.GetColumnOrNull("AUC"); + Assert.True(aucCol.HasValue); + var aucGetter = cursor.GetGetter(aucCol.Value.Index); Assert.True(cursor.MoveNext()); double auc = 0; aucGetter(ref auc); @@ -3001,28 +3003,25 @@ public void EntryPointChainedTrainTestMacros() Assert.NotNull(model); var metrics = runner.GetOutput("OverallMetrics"); - Assert.NotNull(metrics); - using (var cursor = metrics.GetRowCursor(col => true)) + + Action validateAuc = (metricsIdv) => { - Assert.True(cursor.Schema.TryGetColumnIndex("AUC", out int aucCol)); - var aucGetter = cursor.GetGetter(aucCol); - Assert.True(cursor.MoveNext()); - double auc = 0; - aucGetter(ref auc); - Assert.True(auc > 0.99); - } + Assert.NotNull(metricsIdv); + using (var cursor = metricsIdv.GetRowCursor(col => true)) + { + var aucCol = cursor.Schema.GetColumnOrNull("AUC"); + var aucGetter = cursor.GetGetter(aucCol.Value.Index); + Assert.True(cursor.MoveNext()); + double auc = 0; + aucGetter(ref auc); + Assert.True(auc > 0.99); + } + }; + + validateAuc(metrics); metrics = runner.GetOutput("OverallMetrics2"); - Assert.NotNull(metrics); - using (var cursor = metrics.GetRowCursor(col => true)) - { - Assert.True(cursor.Schema.TryGetColumnIndex("AUC", out int aucCol)); - var aucGetter = cursor.GetGetter(aucCol); - Assert.True(cursor.MoveNext()); - double auc = 0; - aucGetter(ref auc); - Assert.True(auc > 0.99); - } + validateAuc(metrics); } [Fact] @@ -3195,28 +3194,26 @@ public void EntryPointChainedCrossValMacros() Assert.NotNull(model[0]); var metrics = runner.GetOutput("OverallMetrics"); - Assert.NotNull(metrics); - using (var cursor = metrics.GetRowCursor(col => true)) + + Action aucValidate = (metricsIdv) => { - Assert.True(cursor.Schema.TryGetColumnIndex("AUC", out int aucCol)); - var aucGetter = cursor.GetGetter(aucCol); - Assert.True(cursor.MoveNext()); - double auc = 0; - aucGetter(ref auc); - Assert.True(auc > 0.99); - } + Assert.NotNull(metricsIdv); + using (var cursor = metrics.GetRowCursor(col => true)) + { + var aucColumn = cursor.Schema.GetColumnOrNull("AUC"); + Assert.True(aucColumn.HasValue); + var aucGetter = cursor.GetGetter(aucColumn.Value.Index); + Assert.True(cursor.MoveNext()); + double auc = 0; + aucGetter(ref auc); + Assert.True(auc > 0.99); + } + }; + + aucValidate(metrics); metrics = runner.GetOutput("OverallMetrics2"); - Assert.NotNull(metrics); - using (var cursor = metrics.GetRowCursor(col => true)) - { - Assert.True(cursor.Schema.TryGetColumnIndex("AUC", out int aucCol)); - var aucGetter = cursor.GetGetter(aucCol); - Assert.True(cursor.MoveNext()); - double auc = 0; - aucGetter(ref auc); - Assert.True(auc > 0.99); - } + aucValidate(metrics); } [Fact] @@ -3378,8 +3375,8 @@ public void EntryPointLinearPredictorSummary() NormalizeFeatures = NormalizeOption.Yes, NumThreads = 1, ShowTrainingStats = true, - StdComputer= new ComputeLRTrainingStdThroughHal() - }; + StdComputer = new ComputeLRTrainingStdThroughHal() + }; var model = LogisticRegression.TrainBinary(Env, lrInput).PredictorModel; var mcLrInput = new MulticlassLogisticRegression.Arguments @@ -3546,9 +3543,9 @@ public void EntryPointPrepareLabelConvertPredictedLabel() { ReadOnlyMemory predictedLabel = default; - var success = loader.Schema.TryGetColumnIndex("PredictedLabel", out int predictedLabelCol); - Assert.True(success); - var predictedLabelGetter = cursor.GetGetter>(predictedLabelCol); + var predictedLabelCol = loader.Schema.GetColumnOrNull("PredictedLabel"); + Assert.True(predictedLabelCol.HasValue); + var predictedLabelGetter = cursor.GetGetter>(predictedLabelCol.Value.Index); while (cursor.MoveNext()) { @@ -3602,17 +3599,24 @@ public void EntryPointTreeLeafFeaturizer() }); var view = treeLeaf.OutputData; - Assert.True(view.Schema.TryGetColumnIndex("Trees", out int treesCol)); - Assert.True(view.Schema.TryGetColumnIndex("Leaves", out int leavesCol)); - Assert.True(view.Schema.TryGetColumnIndex("Paths", out int pathsCol)); + var treesCol = view.Schema.GetColumnOrNull("Trees"); + Assert.True(treesCol.HasValue); + + var leavesCol = view.Schema.GetColumnOrNull("Leaves"); + Assert.True(leavesCol.HasValue); + + var pathsCol = view.Schema.GetColumnOrNull("Paths"); + Assert.True(pathsCol.HasValue); + + VBuffer treeValues = default(VBuffer); VBuffer leafIndicators = default(VBuffer); VBuffer pathIndicators = default(VBuffer); - using (var curs = view.GetRowCursor(c => c == treesCol || c == leavesCol || c == pathsCol)) + using (var curs = view.GetRowCursor(c => c == treesCol.Value.Index || c == leavesCol.Value.Index || c == pathsCol.Value.Index)) { - var treesGetter = curs.GetGetter>(treesCol); - var leavesGetter = curs.GetGetter>(leavesCol); - var pathsGetter = curs.GetGetter>(pathsCol); + var treesGetter = curs.GetGetter>(treesCol.Value.Index); + var leavesGetter = curs.GetGetter>(leavesCol.Value.Index); + var pathsGetter = curs.GetGetter>(pathsCol.Value.Index); while (curs.MoveNext()) { treesGetter(ref treeValues); @@ -3659,8 +3663,9 @@ public void EntryPointWordEmbeddings() var result = embedding.OutputData; using (var cursor = result.GetRowCursor((x => true))) { - Assert.True(result.Schema.TryGetColumnIndex("Features", out int featColumn)); - var featGetter = cursor.GetGetter>(featColumn); + var featColumn = result.Schema.GetColumnOrNull("Features"); + Assert.True(featColumn.HasValue); + var featGetter = cursor.GetGetter>(featColumn.Value.Index); VBuffer feat = default; while (cursor.MoveNext()) { @@ -4076,12 +4081,12 @@ public void TestSimpleTrainExperiment() var data = runner.GetOutput("Var_2130b277d4e0485f9cc5162c176767fa"); var schema = data.Schema; - var b = schema.TryGetColumnIndex("AUC", out int aucCol); - Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == aucCol)) + var aucCol = schema.GetColumnOrNull("AUC"); + Assert.True(aucCol.HasValue); + using (var cursor = data.GetRowCursor(col => col == aucCol.Value.Index)) { - var getter = cursor.GetGetter(aucCol); - b = cursor.MoveNext(); + var getter = cursor.GetGetter(aucCol.Value.Index); + var b = cursor.MoveNext(); Assert.True(b); double auc = 0; getter(ref auc); @@ -4256,20 +4261,22 @@ public void TestCrossValidationMacro() var data = runner.GetOutput("overallMetrics"); var schema = data.Schema; - var b = schema.TryGetColumnIndex("L1(avg)", out int metricCol); - Assert.True(b); - b = schema.TryGetColumnIndex("Fold Index", out int foldCol); - Assert.True(b); - b = schema.TryGetColumnIndex("IsWeighted", out int isWeightedCol); - using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol || col == isWeightedCol)) + var metricCol = schema.GetColumnOrNull("L1(avg)"); + Assert.True(metricCol.HasValue); + var foldCol = schema.GetColumnOrNull("Fold Index"); + Assert.True(foldCol.HasValue); + var isWeightedCol = schema.GetColumnOrNull("IsWeighted"); + Assert.True(isWeightedCol.HasValue); + using (var cursor = data.GetRowCursor(col => col == metricCol.Value.Index || col == foldCol.Value.Index || col == isWeightedCol.Value.Index)) { - var getter = cursor.GetGetter(metricCol); - var foldGetter = cursor.GetGetter>(foldCol); + var getter = cursor.GetGetter(metricCol.Value.Index); + var foldGetter = cursor.GetGetter>(foldCol.Value.Index); ReadOnlyMemory fold = default; - var isWeightedGetter = cursor.GetGetter(isWeightedCol); + var isWeightedGetter = cursor.GetGetter(isWeightedCol.Value.Index); bool isWeighted = default; double avg = 0; double weightedAvg = 0; + bool b; for (int w = 0; w < 2; w++) { // Get the average. @@ -4438,18 +4445,18 @@ public void TestCrossValidationMacroWithMultiClass() var data = runner.GetOutput("overallMetrics"); var schema = data.Schema; - var b = schema.TryGetColumnIndex("Accuracy(micro-avg)", out int metricCol); - Assert.True(b); - b = schema.TryGetColumnIndex("Fold Index", out int foldCol); - Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + var metricCol = schema.GetColumnOrNull("Accuracy(micro-avg)"); + Assert.True(metricCol.HasValue); + var foldCol = schema.GetColumnOrNull("Fold Index"); + Assert.True(foldCol.HasValue); + using (var cursor = data.GetRowCursor(col => col == metricCol.Value.Index || col == foldCol.Value.Index)) { - var getter = cursor.GetGetter(metricCol); - var foldGetter = cursor.GetGetter>(foldCol); + var getter = cursor.GetGetter(metricCol.Value.Index); + var foldGetter = cursor.GetGetter>(foldCol.Value.Index); ReadOnlyMemory fold = default; // Get the average. - b = cursor.MoveNext(); + var b = cursor.MoveNext(); Assert.True(b); double avg = 0; getter(ref avg); @@ -4483,14 +4490,14 @@ public void TestCrossValidationMacroWithMultiClass() var confusion = runner.GetOutput("confusionMatrix"); schema = confusion.Schema; - b = schema.TryGetColumnIndex("Count", out int countCol); - Assert.True(b); - b = schema.TryGetColumnIndex("Fold Index", out foldCol); - Assert.True(b); - var type = schema[countCol].Metadata.Schema[MetadataUtils.Kinds.SlotNames].Type; + var countCol = schema.GetColumnOrNull("Count"); + Assert.True(countCol.HasValue); + foldCol = schema.GetColumnOrNull("Fold Index"); + Assert.True(foldCol.HasValue); + var type = schema["Count"].Metadata.Schema[MetadataUtils.Kinds.SlotNames].Type; Assert.True(type is VectorType vecType && vecType.ItemType is TextType && vecType.Size == 10); var slotNames = default(VBuffer>); - schema[countCol].GetSlotNames(ref slotNames); + schema["Count"].GetSlotNames(ref slotNames); var slotNameValues = slotNames.GetValues(); for (int i = 0; i < slotNameValues.Length; i++) { @@ -4498,8 +4505,8 @@ public void TestCrossValidationMacroWithMultiClass() } using (var curs = confusion.GetRowCursor(col => true)) { - var countGetter = curs.GetGetter>(countCol); - var foldGetter = curs.GetGetter>(foldCol); + var countGetter = curs.GetGetter>(countCol.Value.Index); + var foldGetter = curs.GetGetter>(foldCol.Value.Index); var confCount = default(VBuffer); var foldIndex = default(ReadOnlyMemory); int rowCount = 0; @@ -4666,13 +4673,13 @@ public void TestCrossValidationMacroMultiClassWithWarnings() var warnings = runner.GetOutput("warning"); var schema = warnings.Schema; - var b = schema.TryGetColumnIndex("WarningText", out int warningCol); - Assert.True(b); - using (var cursor = warnings.GetRowCursor(col => col == warningCol)) + var warningCol = schema.GetColumnOrNull("WarningText"); + Assert.True(warningCol.HasValue); + using (var cursor = warnings.GetRowCursor(col => col == warningCol.Value.Index)) { - var getter = cursor.GetGetter>(warningCol); + var getter = cursor.GetGetter>(warningCol.Value.Index); - b = cursor.MoveNext(); + var b = cursor.MoveNext(); Assert.True(b); var warning = default(ReadOnlyMemory); getter(ref warning); @@ -4846,14 +4853,15 @@ public void TestCrossValidationMacroWithStratification() var data = runner.GetOutput("overallmetrics"); var schema = data.Schema; - var b = schema.TryGetColumnIndex("AUC", out int metricCol); - Assert.True(b); - b = schema.TryGetColumnIndex("Fold Index", out int foldCol); - Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + var metricCol = schema.GetColumnOrNull("AUC"); + Assert.True(metricCol.HasValue); + var foldCol = schema.GetColumnOrNull("Fold Index"); + Assert.True(foldCol.HasValue); + bool b; + using (var cursor = data.GetRowCursor(col => col == metricCol.Value.Index || col == foldCol.Value.Index)) { - var getter = cursor.GetGetter(metricCol); - var foldGetter = cursor.GetGetter>(foldCol); + var getter = cursor.GetGetter(metricCol.Value.Index); + var foldGetter = cursor.GetGetter>(foldCol.Value.Index); ReadOnlyMemory fold = default; // Get the verage. @@ -5149,14 +5157,15 @@ public void TestCrossValidationMacroWithNonDefaultNames() var data = runner.GetOutput("overallMetrics"); var schema = data.Schema; - var b = schema.TryGetColumnIndex("NDCG", out int metricCol); - Assert.True(b); - b = schema.TryGetColumnIndex("Fold Index", out int foldCol); - Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + var metricCol = schema.GetColumnOrNull("NDCG"); + Assert.True(metricCol.HasValue); + var foldCol = schema.GetColumnOrNull("Fold Index"); + Assert.True(foldCol.HasValue); + bool b; + using (var cursor = data.GetRowCursor(col => col == metricCol.Value.Index || col == foldCol.Value.Index)) { - var getter = cursor.GetGetter>(metricCol); - var foldGetter = cursor.GetGetter>(foldCol); + var getter = cursor.GetGetter>(metricCol.Value.Index); + var foldGetter = cursor.GetGetter>(foldCol.Value.Index); ReadOnlyMemory fold = default; // Get the verage. @@ -5203,10 +5212,11 @@ public void TestCrossValidationMacroWithNonDefaultNames() } data = runner.GetOutput("perInstanceMetric"); - Assert.True(data.Schema.TryGetColumnIndex("Instance", out int nameCol)); - using (var cursor = data.GetRowCursor(col => col == nameCol)) + var nameCol = data.Schema.GetColumnOrNull("Instance"); + Assert.True(nameCol.HasValue); + using (var cursor = data.GetRowCursor(col => col == nameCol.Value.Index)) { - var getter = cursor.GetGetter>(nameCol); + var getter = cursor.GetGetter>(nameCol.Value.Index); while (cursor.MoveNext()) { ReadOnlyMemory name = default; @@ -5366,11 +5376,12 @@ public void TestOvaMacro() var data = runner.GetOutput("overallMetrics"); var schema = data.Schema; - var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol); - Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == accCol)) + var accCol = schema.GetColumnOrNull(MultiClassClassifierEvaluator.AccuracyMacro); + Assert.True(accCol.HasValue); + bool b; + using (var cursor = data.GetRowCursor(col => col == accCol.Value.Index)) { - var getter = cursor.GetGetter(accCol); + var getter = cursor.GetGetter(accCol.Value.Index); b = cursor.MoveNext(); Assert.True(b); double acc = 0; @@ -5537,11 +5548,12 @@ public void TestOvaMacroWithUncalibratedLearner() var data = runner.GetOutput("overallMetrics"); var schema = data.Schema; - var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol); - Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == accCol)) + var accCol = schema.GetColumnOrNull(MultiClassClassifierEvaluator.AccuracyMacro); + Assert.True(accCol.HasValue); + bool b; + using (var cursor = data.GetRowCursor(col => col == accCol.Value.Index)) { - var getter = cursor.GetGetter(accCol); + var getter = cursor.GetGetter(accCol.Value.Index); b = cursor.MoveNext(); Assert.True(b); double acc = 0;