diff --git a/src/Microsoft.ML.Api/CustomMappingTransformer.cs b/src/Microsoft.ML.Api/CustomMappingTransformer.cs index 3e4102aad3..a63c63c7a0 100644 --- a/src/Microsoft.ML.Api/CustomMappingTransformer.cs +++ b/src/Microsoft.ML.Api/CustomMappingTransformer.cs @@ -205,7 +205,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var addedCols = DataViewConstructionUtils.GetSchemaColumns(Transformer.AddedSchema); var addedSchemaShape = SchemaShape.Create(SchemaBuilder.MakeSchema(addedCols)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); var inputDef = InternalSchemaDefinition.Create(typeof(TSrc), Transformer.InputSchemaDefinition); foreach (var col in inputDef.Columns) { @@ -223,7 +223,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) } } - foreach (var addedCol in addedSchemaShape.Columns) + foreach (var addedCol in addedSchemaShape) result[addedCol.Name] = addedCol; return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index b82ad4f5a7..5994a671bd 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -5,8 +5,9 @@ using Microsoft.ML.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; -using System; +using System.Collections; using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; namespace Microsoft.ML.Core.Data @@ -16,13 +17,17 @@ namespace Microsoft.ML.Core.Data /// This is more relaxed than the proper , since it's only a subset of the columns, /// and also since it doesn't specify exact 's for vectors and keys. /// - public sealed class SchemaShape + public sealed class SchemaShape : IReadOnlyList { - public readonly Column[] Columns; + private readonly Column[] _columns; private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty()); - public sealed class Column + public int Count => _columns.Count(); + + public Column this[int index] => _columns[index]; + + public struct Column { public enum VectorKind { @@ -55,13 +60,13 @@ public enum VectorKind /// public readonly SchemaShape Metadata; - public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null) + [BestFriend] + internal Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null) { Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckValueOrNull(metadata); Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key"); Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector"); - Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key"); Name = name; @@ -80,9 +85,10 @@ public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, /// - The columns of of is a superset of our columns. /// - Each such metadata column is itself compatible with the input metadata column. /// - public bool IsCompatibleWith(Column inputColumn) + [BestFriend] + internal bool IsCompatibleWith(Column inputColumn) { - Contracts.CheckValue(inputColumn, nameof(inputColumn)); + Contracts.Check(inputColumn.IsValid, nameof(inputColumn)); if (Name != inputColumn.Name) return false; if (Kind != inputColumn.Kind) @@ -91,7 +97,7 @@ public bool IsCompatibleWith(Column inputColumn) return false; if (IsKey != inputColumn.IsKey) return false; - foreach (var metaCol in Metadata.Columns) + foreach (var metaCol in Metadata) { if (!inputColumn.Metadata.TryFindColumn(metaCol.Name, out var inputMetaCol)) return false; @@ -101,7 +107,8 @@ public bool IsCompatibleWith(Column inputColumn) return true; } - public string GetTypeString() + [BestFriend] + internal string GetTypeString() { string result = ItemType.ToString(); if (IsKey) @@ -112,13 +119,20 @@ public string GetTypeString() result = $"VarVector<{result}>"; return result; } + + /// + /// Return if this structure is not identical to the default value of . If true, + /// it means this structure is initialized properly and therefore considered as valid. + /// + [BestFriend] + internal bool IsValid => Name != null; } public SchemaShape(IEnumerable columns) { Contracts.CheckValue(columns, nameof(columns)); - Columns = columns.ToArray(); - Contracts.CheckParam(columns.All(c => c != null), nameof(columns), "No items should be null."); + _columns = columns.ToArray(); + Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly."); } /// @@ -151,7 +165,8 @@ internal static void GetColumnTypeShape(ColumnType type, /// /// Create a schema shape out of the fully defined schema. /// - public static SchemaShape Create(Schema schema) + [BestFriend] + internal static SchemaShape Create(Schema schema) { Contracts.CheckValue(schema, nameof(schema)); var cols = new List(); @@ -179,25 +194,23 @@ public static SchemaShape Create(Schema schema) /// /// Returns if there is a column with a specified and if so stores it in . /// - public bool TryFindColumn(string name, out Column column) + [BestFriend] + internal bool TryFindColumn(string name, out Column column) { Contracts.CheckValue(name, nameof(name)); - column = Columns.FirstOrDefault(x => x.Name == name); - return column != null; + column = _columns.FirstOrDefault(x => x.Name == name); + return column.IsValid; } + public IEnumerator GetEnumerator() => ((IEnumerable)_columns).GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + // REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape // as an input to another schema shape. I started writing, but realized that there's more than one way to check for // the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'. } - /// - /// Exception class for schema validation errors. - /// - public class SchemaException : Exception - { - } - /// /// The 'data reader' takes a certain kind of input and turns it into an . /// @@ -246,7 +259,6 @@ public interface ITransformer /// /// Schema propagation for transformers. /// Returns the output schema of the data, if the input schema is like the one provided. - /// Throws if the input schema is not valid for the transformer. /// Schema GetOutputSchema(Schema inputSchema); @@ -288,7 +300,6 @@ public interface IEstimator /// /// Schema propagation for estimators. /// Returns the output schema shape of the estimator, if the input schema shape is like the one provided. - /// Throws iff the input schema is not valid for the estimator. /// SchemaShape GetOutputSchema(SchemaShape inputSchema); } diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index 6e2bd5230f..3bb8f16ff0 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -367,7 +367,7 @@ public static bool IsNormalized(this Schema schema, int col) /// of a scalar type, which we assume, if set, should be true. public static bool IsNormalized(this SchemaShape.Column col) { - Contracts.CheckValue(col, nameof(col)); + Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly"); return col.Metadata.TryFindColumn(Kinds.IsNormalized, out var metaCol) && metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey && metaCol.ItemType == BoolType.Instance; @@ -382,7 +382,7 @@ public static bool IsNormalized(this SchemaShape.Column col) /// metadata of definite sized vectors of text. public static bool HasSlotNames(this SchemaShape.Column col) { - Contracts.CheckValue(col, nameof(col)); + Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly"); return col.Kind == SchemaShape.Column.VectorKind.Vector && col.Metadata.TryFindColumn(Kinds.SlotNames, out var metaCol) && metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey diff --git a/src/Microsoft.ML.Core/Utilities/Contracts.cs b/src/Microsoft.ML.Core/Utilities/Contracts.cs index d567d6e883..cda7cbc539 100644 --- a/src/Microsoft.ML.Core/Utilities/Contracts.cs +++ b/src/Microsoft.ML.Core/Utilities/Contracts.cs @@ -758,6 +758,10 @@ public static void CheckAlive(this IHostEnvironment env) public static void CheckValueOrNull(T val) where T : class { } + + /// + /// This documents that the parameter can legally be null. + /// [Conditional("INVARIANT_CHECKS")] public static void CheckValueOrNull(this IExceptionContext ctx, T val) where T : class { diff --git a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs index d94219a453..c67c6ae733 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs @@ -30,22 +30,22 @@ public FakeSchema(IHostEnvironment env, SchemaShape inputShape) { _env = env; _shape = inputShape; - _colMap = Enumerable.Range(0, _shape.Columns.Length) - .ToDictionary(idx => _shape.Columns[idx].Name, idx => idx); + _colMap = Enumerable.Range(0, _shape.Count) + .ToDictionary(idx => _shape[idx].Name, idx => idx); } - public int ColumnCount => _shape.Columns.Length; + public int ColumnCount => _shape.Count; public string GetColumnName(int col) { _env.Check(0 <= col && col < ColumnCount); - return _shape.Columns[col].Name; + return _shape[col].Name; } public ColumnType GetColumnType(int col) { _env.Check(0 <= col && col < ColumnCount); - var inputCol = _shape.Columns[col]; + var inputCol = _shape[col]; return MakeColumnType(inputCol); } @@ -66,7 +66,7 @@ private static ColumnType MakeColumnType(SchemaShape.Column inputCol) public void GetMetadata(string kind, int col, ref TValue value) { _env.Check(0 <= col && col < ColumnCount); - var inputCol = _shape.Columns[col]; + var inputCol = _shape[col]; var metaShape = inputCol.Metadata; if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn)) throw _env.ExceptGetMetadata(); @@ -89,7 +89,7 @@ public void GetMetadata(string kind, int col, ref TValue value) public ColumnType GetMetadataTypeOrNull(string kind, int col) { _env.Check(0 <= col && col < ColumnCount); - var inputCol = _shape.Columns[col]; + var inputCol = _shape[col]; var metaShape = inputCol.Metadata; if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn)) return null; @@ -99,12 +99,12 @@ public ColumnType GetMetadataTypeOrNull(string kind, int col) public IEnumerable> GetMetadataTypes(int col) { _env.Check(0 <= col && col < ColumnCount); - var inputCol = _shape.Columns[col]; + var inputCol = _shape[col]; var metaShape = inputCol.Metadata; if (metaShape == null) return Enumerable.Empty>(); - return metaShape.Columns.Select(c => new KeyValuePair(c.Name, MakeColumnType(c))); + return metaShape.Select(c => new KeyValuePair(c.Name, MakeColumnType(c))); } } } diff --git a/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs b/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs index 6ffd089f24..7b1edd3398 100644 --- a/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs +++ b/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs @@ -113,7 +113,7 @@ public void Check(IExceptionContext ectx, SchemaShape shape) private static Type GetTypeOrNull(SchemaShape.Column col) { - Contracts.AssertValue(col); + Contracts.Assert(col.IsValid); Type vecType = null; diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index 1e49c32ed3..55d85201b7 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -55,13 +55,11 @@ public abstract class TrainerEstimatorBase : ITrainerEstim private protected TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, - SchemaShape.Column weight = null) + SchemaShape.Column weight = default) { Contracts.CheckValue(host, nameof(host)); Host = host; - Host.CheckValue(feature, nameof(feature)); - Host.CheckValueOrNull(label); - Host.CheckValueOrNull(weight); + Host.CheckParam(feature.IsValid, nameof(feature), "not initialized properly"); FeatureColumn = feature; LabelColumn = label; @@ -76,7 +74,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) CheckInputSchema(inputSchema); - var outColumns = inputSchema.Columns.ToDictionary(x => x.Name); + var outColumns = inputSchema.ToDictionary(x => x.Name); foreach (var col in GetOutputColumnsCore(inputSchema)) outColumns[col.Name] = col; @@ -102,7 +100,7 @@ private void CheckInputSchema(SchemaShape inputSchema) if (!FeatureColumn.IsCompatibleWith(featureCol)) throw Host.Except($"Feature column '{FeatureColumn.Name}' is not compatible"); - if (WeightColumn != null) + if (WeightColumn.IsValid) { if (!inputSchema.TryFindColumn(WeightColumn.Name, out var weightCol)) throw Host.Except($"Weight column '{WeightColumn.Name}' is not found"); @@ -112,7 +110,7 @@ private void CheckInputSchema(SchemaShape inputSchema) // Special treatment for label column: we allow different types of labels, so the trainers // may define their own requirements on the label column. - if (LabelColumn != null) + if (LabelColumn.IsValid) { if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol)) throw Host.Except($"Label column '{LabelColumn.Name}' is not found"); @@ -122,8 +120,8 @@ private void CheckInputSchema(SchemaShape inputSchema) protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol) { - Contracts.CheckValue(labelCol, nameof(labelCol)); - Contracts.AssertValue(LabelColumn); + Contracts.CheckParam(labelCol.IsValid, nameof(labelCol), "not initialized properly"); + Host.Assert(LabelColumn.IsValid); if (!LabelColumn.IsCompatibleWith(labelCol)) throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible"); @@ -133,20 +131,12 @@ protected TTransformer TrainTransformer(IDataView trainSet, IDataView validationSet = null, IPredictor initPredictor = null) { var cachedTrain = Info.WantCaching ? new CacheDataView(Host, trainSet, prefetch: null) : trainSet; + var cachedValid = Info.WantCaching && validationSet != null ? new CacheDataView(Host, validationSet, prefetch: null) : validationSet; - var trainRoles = MakeRoles(cachedTrain); + var trainRoleMapped = MakeRoles(cachedTrain); + var validRoleMapped = validationSet == null ? null : MakeRoles(cachedValid); - RoleMappedData validRoles; - - if (validationSet == null) - validRoles = null; - else - { - var cachedValid = Info.WantCaching ? new CacheDataView(Host, validationSet, prefetch: null) : validationSet; - validRoles = MakeRoles(cachedValid); - } - - var pred = TrainModelCore(new TrainContext(trainRoles, validRoles, null, initPredictor)); + var pred = TrainModelCore(new TrainContext(trainRoleMapped, validRoleMapped, null, initPredictor)); return MakeTransformer(pred, trainSet.Schema); } @@ -156,7 +146,7 @@ protected TTransformer TrainTransformer(IDataView trainSet, protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema); protected virtual RoleMappedData MakeRoles(IDataView data) => - new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name); + new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, weight: WeightColumn.Name); IPredictor ITrainer.Train(TrainContext context) => ((ITrainer)this).Train(context); } @@ -178,16 +168,15 @@ public abstract class TrainerEstimatorBaseWithGroupId : Tr public TrainerEstimatorBaseWithGroupId(IHost host, SchemaShape.Column feature, SchemaShape.Column label, - SchemaShape.Column weight = null, - SchemaShape.Column groupId = null) + SchemaShape.Column weight = default, + SchemaShape.Column groupId = default) :base(host, feature, label, weight) { - Host.CheckValueOrNull(groupId); GroupIdColumn = groupId; } protected override RoleMappedData MakeRoles(IDataView data) => - new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name); + new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, group: GroupIdColumn.Name, weight: WeightColumn.Name); } } diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index 41e984f522..9f8dfb68ea 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -366,7 +366,7 @@ public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn) public static SchemaShape.Column MakeU4ScalarColumn(string columnName) { if (columnName == null) - return null; + return default; return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); } @@ -386,7 +386,7 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn) public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true) { if (weightColumn == null || !isExplicit) - return null; + return default; return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } } diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs index dc8ad9b389..81df252d58 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs @@ -49,7 +49,7 @@ public ITransformer Fit(IDataView input) private bool HasCategoricals(SchemaShape.Column col) { - _host.AssertValue(col); + _host.Assert(col.IsValid); if (!col.Metadata.TryFindColumn(MetadataUtils.Kinds.CategoricalSlotRanges, out var mcol)) return false; // The indices must be ints and of a definite size vector type. (Definite becuase @@ -116,7 +116,7 @@ private SchemaShape.Column CheckInputsAndMakeColumn( public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); result[_name] = CheckInputsAndMakeColumn(inputSchema, _name, _source); return new SchemaShape(result.Values); } diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs index 6ab8fe9b7b..84f26e6059 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs @@ -49,7 +49,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + var resultDic = inputSchema.ToDictionary(x => x.Name); foreach (var (Source, Name) in Transformer.Columns) { if (!inputSchema.TryFindColumn(Source, out var originalColumn)) diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs index afbbbe59b8..3263305457 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs @@ -98,14 +98,14 @@ public static ColumnSelectingEstimator DropColumns(IHostEnvironment env, params public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - if (!Transformer.IgnoreMissing && !ColumnSelectingTransformer.IsSchemaValid(inputSchema.Columns.Select(x => x.Name), + if (!Transformer.IgnoreMissing && !ColumnSelectingTransformer.IsSchemaValid(inputSchema.Select(x => x.Name), Transformer.SelectColumns, out IEnumerable invalidColumns)) { throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", string.Join(",", invalidColumns)); } - var columns = inputSchema.Columns.Where(c => _selectPredicate(c.Name)); + var columns = inputSchema.Where(c => _selectPredicate(c.Name)); return new SchemaShape(columns); } } diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 60e799ddb6..4c7e353d9a 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -1235,7 +1235,7 @@ public HashingEstimator(IHostEnvironment env, params HashingTransformer.ColumnIn public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index fd6af73412..d1aad4d96e 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -512,7 +512,7 @@ public KeyToValueMappingEstimator(IHostEnvironment env, params (string input, st public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.input, out var col)) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs index 363bc4c83f..93ca9bf874 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs @@ -746,7 +746,7 @@ private KeyToVectorMappingEstimator(IHostEnvironment env, KeyToVectorMappingTran public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index 1434b680a7..7efbe73601 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -213,7 +213,7 @@ public NormalizingTransformer Fit(IDataView input) public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 331e5fff71..c38bcd46df 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -552,7 +552,7 @@ public TypeConvertingEstimator(IHostEnvironment env, params TypeConvertingTransf public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs index 4d01a62541..24ff1c6915 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs @@ -60,7 +60,7 @@ public ValueToKeyMappingEstimator(IHostEnvironment env, ValueToKeyMappingTransfo public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) @@ -77,7 +77,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, colInfo.TextKeyValues ? TextType.Instance : col.ItemType, col.IsKey); } - Contracts.AssertValue(kv); + Contracts.Assert(kv.IsValid); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) metadata = new SchemaShape(new[] { slotMeta, kv }); diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 9fe6023764..64c3191e5f 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -97,7 +97,7 @@ internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args) protected override void CheckLabelCompatible(SchemaShape.Column labelCol) { - Contracts.AssertValue(labelCol); + Contracts.Assert(labelCol.IsValid); Action error = () => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "R4 or a Key", labelCol.GetTypeString()); diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs index 9c9cd820b4..5e282de6d7 100644 --- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs +++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs @@ -806,7 +806,7 @@ public VectorWhiteningTransformer Fit(IDataView input) public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in _infos) { if (!inputSchema.TryFindColumn(colPair.Input, out var col)) diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs index 7b89bdb173..a25dabe832 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs @@ -227,7 +227,7 @@ public ImageGrayscalingEstimator(IHostEnvironment env, params (string input, str public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.input, out var col)) diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index 81b3d339e5..ba99793604 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -230,7 +230,7 @@ public ImageLoadingEstimator(IHostEnvironment env, ImageLoaderTransform transfor public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var (input, output) in Transformer.Columns) { if (!inputSchema.TryFindColumn(input, out var col)) diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index ceecacb0f9..d8e7a37907 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -655,7 +655,7 @@ public ImagePixelExtractingEstimator(IHostEnvironment env, params ImagePixelExtr public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index e5054b515f..d38af82ba8 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -447,7 +447,7 @@ public ImageResizingEstimator(IHostEnvironment env, ImageResizerTransform transf public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index 389a8be9c2..7924d58f6f 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -121,7 +121,7 @@ internal KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args) } private KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), default, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) { Host.CheckValue(args, nameof(args)); diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 62b476b40f..c787a606fa 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -213,7 +213,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol); Contracts.Assert(success); - var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) + var metadata = new SchemaShape(labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) .Concat(MetadataUtils.GetTrainerOutputMetadata())); return new[] { diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 61b2c965b1..efb1c466c5 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -140,7 +140,7 @@ protected override void CheckDataValid(IChannel ch, RoleMappedData data) protected override void CheckLabelCompatible(SchemaShape.Column labelCol) { - Contracts.AssertValue(labelCol); + Contracts.Assert(labelCol.IsValid); Action error = () => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "R4 or a Key", labelCol.GetTypeString()); diff --git a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs index e414261336..e2b060f506 100644 --- a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs @@ -480,8 +480,8 @@ public OnnxScoringEstimator(IHostEnvironment env, OnnxTransform transformer) public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); - var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); + var resultDic = inputSchema.ToDictionary(x => x.Name); for (var i = 0; i < Transformer.Inputs.Length; i++) { diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index d37cf58851..c1af69d6c7 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -111,7 +111,7 @@ internal RandomizedPcaTrainer(IHostEnvironment env, Arguments args) private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featureColumn, string weightColumn, int rank = 20, int oversampling = 20, bool center = true, int? seed = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), default, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { // if the args are not null, we got here from maml, and the internal ctor. if (args != null) @@ -152,7 +152,7 @@ private protected override PcaPredictor TrainModelCore(TrainContext context) private static SchemaShape.Column MakeWeightColumn(string weightColumn) { if (weightColumn == null) - return null; + return default; return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } diff --git a/src/Microsoft.ML.PCA/PcaTransform.cs b/src/Microsoft.ML.PCA/PcaTransform.cs index 1a6a5c85f2..8c65dd35aa 100644 --- a/src/Microsoft.ML.PCA/PcaTransform.cs +++ b/src/Microsoft.ML.PCA/PcaTransform.cs @@ -707,7 +707,7 @@ public PrincipalComponentAnalysisEstimator(IHostEnvironment env, params PcaTrans public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs index eb2db599a9..e42c6b002e 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs @@ -436,7 +436,7 @@ void CheckColumnsCompatible(SchemaShape.Column cachedColumn, string expectedColu CheckColumnsCompatible(matrixRowIndexColumn, MatrixRowIndexName); // Input columns just pass through so that output column dictionary contains all input columns. - var outColumns = inputSchema.Columns.ToDictionary(x => x.Name); + var outColumns = inputSchema.ToDictionary(x => x.Name); // Add columns produced by this estimator. foreach (var col in GetOutputColumnsCore(inputSchema)) diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index 81ef9d03cf..be73a2de30 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -151,7 +151,7 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, FeatureColumns[i] = new SchemaShape.Column(featureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); - WeightColumn = weights != null ? new SchemaShape.Column(weights, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : null; + WeightColumn = weights != null ? new SchemaShape.Column(weights, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default; } /// @@ -461,7 +461,7 @@ public FieldAwareFactorizationMachinePredictionTransformer Train(IDataView train roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Label, LabelColumn.Name)); - if (WeightColumn != null) + if (WeightColumn.IsValid) roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Feature, WeightColumn.Name)); var trainingData = new RoleMappedData(trainData, roles); @@ -500,10 +500,10 @@ void CheckColumnsCompatible(SchemaShape.Column column, string defaultName) CheckColumnsCompatible(feat, DefaultColumnNames.Features); } - if (WeightColumn != null) + if (WeightColumn.IsValid) CheckColumnsCompatible(WeightColumn, DefaultColumnNames.Weight); - var outColumns = inputSchema.Columns.ToDictionary(x => x.Name); + var outColumns = inputSchema.ToDictionary(x => x.Name); foreach (var col in GetOutputColumnsCore(inputSchema)) outColumns[col.Name] = col; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index 0d12aee904..2ec7021cd4 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -191,7 +191,7 @@ internal LbfgsTrainerBase(IHostEnvironment env, args.FeatureColumn = FeatureColumn.Name; args.LabelColumn = LabelColumn.Name; - args.WeightColumn = WeightColumn?.Name; + args.WeightColumn = WeightColumn.Name; Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null, nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative"); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 6d8a411e7c..c856483f90 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -321,7 +321,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol); Contracts.Assert(success); - var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) + var metadata = new SchemaShape(labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) .Concat(MetadataUtils.GetTrainerOutputMetadata())); return new[] { diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index 15de478efd..e7aeade4fe 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -149,7 +149,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - if (LabelColumn != null) + if (LabelColumn.IsValid) { if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol)) throw Host.ExceptSchemaMismatch(nameof(labelCol), DefaultColumnNames.PredictedLabel, DefaultColumnNames.PredictedLabel); @@ -158,7 +158,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible"); } - var outColumns = inputSchema.Columns.ToDictionary(x => x.Name); + var outColumns = inputSchema.ToDictionary(x => x.Name); foreach (var col in GetOutputColumnsCore(inputSchema)) outColumns[col.Name] = col; @@ -167,12 +167,12 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) { - if (LabelColumn != null) + if (LabelColumn.IsValid) { bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol); Contracts.Assert(success); - var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) + var metadata = new SchemaShape(labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) .Concat(MetadataForScoreColumn())); return new[] { diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 751a62b7b6..1a82526a9e 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -76,7 +76,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol); Contracts.Assert(success); - var predLabelMetadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) + var predLabelMetadata = new SchemaShape(labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) .Concat(MetadataUtils.GetTrainerOutputMetadata())); return new[] diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index 91f68a85f6..c853aa7f79 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -163,7 +163,7 @@ protected override void CheckLabels(RoleMappedData data) protected override void CheckLabelCompatible(SchemaShape.Column labelCol) { - Contracts.AssertValue(labelCol); + Contracts.Assert(labelCol.IsValid); Action error = () => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "BL, R8, R4 or a Key", labelCol.GetTypeString()); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs index 42e447004a..7016568bed 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs @@ -259,13 +259,13 @@ private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColu } internal SdcaTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn, - SchemaShape.Column weight = null, Action advancedSettings = null, float? l2Const = null, + SchemaShape.Column weight = default, Action advancedSettings = null, float? l2Const = null, float? l1Threshold = null, int? maxIterations = null) : this(env, ArgsInit(featureColumn, labelColumn, advancedSettings), labelColumn, weight, l2Const, l1Threshold, maxIterations) { } - internal SdcaTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label, SchemaShape.Column weight = null, + internal SdcaTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label, SchemaShape.Column weight = default, float? l2Const = null, float? l1Threshold = null, int? maxIterations = null) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, weight) { @@ -1520,7 +1520,7 @@ public SdcaBinaryTrainer(IHostEnvironment env, Arguments args) protected override void CheckLabelCompatible(SchemaShape.Column labelCol) { - Contracts.AssertValue(labelCol); + Contracts.Assert(labelCol.IsValid); Action error = () => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "BL, R8, R4 or a Key", labelCol.GetTypeString()); @@ -1535,7 +1535,7 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol) private static SchemaShape.Column MakeWeightColumn(string weightColumn) { if (weightColumn == null) - return null; + return default; return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index f60b822827..af6617d862 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -101,7 +101,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol); Contracts.Assert(success); - var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) + var metadata = new SchemaShape(labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) .Concat(MetadataUtils.GetTrainerOutputMetadata())); return new[] { @@ -112,7 +112,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc protected override void CheckLabelCompatible(SchemaShape.Column labelCol) { - Contracts.AssertValue(labelCol); + Contracts.Assert(labelCol.IsValid); Action error = () => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "R8, R4 or a Key", labelCol.GetTypeString()); diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index 2e031433e1..c755406ee2 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -86,7 +86,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var outColumns = inputSchema.Columns.ToDictionary(x => x.Name); + var outColumns = inputSchema.ToDictionary(x => x.Name); var newColumns = new[] { @@ -327,7 +327,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var outColumns = inputSchema.Columns.ToDictionary(x => x.Name); + var outColumns = inputSchema.ToDictionary(x => x.Name); var newColumns = new[] { diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs index d419a3a8a1..792844ef26 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs @@ -15,7 +15,7 @@ public abstract class StochasticTrainerBase : TrainerEstim where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { - public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null) + public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = default) : base(host, feature, label, weight) { } diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index b4ced664c1..16e4b2119f 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -1101,8 +1101,8 @@ private static TensorFlowTransform.Arguments CreateArguments(TensorFlowModelInfo public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); - var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); + var resultDic = inputSchema.ToDictionary(x => x.Name); for (var i = 0; i < _args.InputColumns.Length; i++) { var input = _args.InputColumns[i]; diff --git a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs index 3a899f4339..6a42483130 100644 --- a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs @@ -251,7 +251,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var metadata = new List() { new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) }; - var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + var resultDic = inputSchema.ToDictionary(x => x.Name); resultDic[Transformer.OutputColumnName] = new SchemaShape.Column( Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); diff --git a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs index 3fdf968025..107b2873cf 100644 --- a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs @@ -226,7 +226,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var metadata = new List() { new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) }; - var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + var resultDic = inputSchema.ToDictionary(x => x.Name); resultDic[Transformer.OutputColumnName] = new SchemaShape.Column( Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); diff --git a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs index 9171c34662..4565f853b6 100644 --- a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs @@ -282,7 +282,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var metadata = new List() { new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) }; - var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + var resultDic = inputSchema.ToDictionary(x => x.Name); resultDic[_args.Name] = new SchemaShape.Column( _args.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); diff --git a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs index f98c8bf34b..00eb719b46 100644 --- a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs @@ -260,7 +260,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var metadata = new List() { new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) }; - var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + var resultDic = inputSchema.ToDictionary(x => x.Name); resultDic[_args.Name] = new SchemaShape.Column( _args.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); diff --git a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs index d597422854..eae71fae51 100644 --- a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs @@ -106,7 +106,7 @@ public CountFeatureSelectingEstimator(IHostEnvironment env, string inputColumn, public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in _columns) { if (!inputSchema.TryFindColumn(colPair.Input, out var col)) diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs index e35a41d5e3..5f390a27dd 100644 --- a/src/Microsoft.ML.Transforms/GcnTransform.cs +++ b/src/Microsoft.ML.Transforms/GcnTransform.cs @@ -795,7 +795,7 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col) public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in Transformer.Columns) { if (!inputSchema.TryFindColumn(colPair.Input, out var col)) diff --git a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs index 818806083c..314d98b259 100644 --- a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs +++ b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs @@ -465,7 +465,7 @@ private KeyToBinaryVectorMappingEstimator(IHostEnvironment env, KeyToBinaryVecto public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs index 3c52227bec..45c299531d 100644 --- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs @@ -371,7 +371,7 @@ public MissingValueDroppingEstimator(IHostEnvironment env, string input, string public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in Transformer.Columns) { if (!inputSchema.TryFindColumn(colPair.input, out var col) || !Runtime.Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del)) diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index 049911f48e..1503b248cc 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -460,7 +460,7 @@ public MissingValueIndicatorEstimator(IHostEnvironment env, string input, string public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in Transformer.Columns) { if (!inputSchema.TryFindColumn(colPair.input, out var col) || !Runtime.Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del)) diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index 41e7b5f857..bbc090f333 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -952,7 +952,7 @@ public MissingValueReplacingEstimator(IHostEnvironment env, params MissingValueR public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs index 3c496e711c..517628655d 100644 --- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs @@ -167,7 +167,7 @@ public ITransformer Fit(IDataView input) public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in _columns) { if (!inputSchema.TryFindColumn(colPair.input, out var col)) diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs index 0984f40f97..0d16f6cfb1 100644 --- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs +++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs @@ -678,7 +678,7 @@ public RandomFourierFeaturizingEstimator(IHostEnvironment env, params RandomFour public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 8978401c68..70b9e0914b 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -1145,7 +1145,7 @@ public LatentDirichletAllocationEstimator(IHostEnvironment env, params LatentDir public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index 5f548c4632..063488aa45 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -863,7 +863,7 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col) public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index 52ba66f0c6..2c274feec3 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -582,7 +582,7 @@ public StopWordsRemovingEstimator(IHostEnvironment env, params StopWordsRemoving public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) @@ -1076,7 +1076,7 @@ public CustomStopWordsRemovingEstimator(IHostEnvironment env, (string input, str public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.input, out var col)) diff --git a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs index bf4edae5de..1fa39930bd 100644 --- a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs +++ b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs @@ -472,7 +472,7 @@ private static string GenerateColumnName(ISchema schema, string srcName, string public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var srcName in _inputColumns) { if (!inputSchema.TryFindColumn(srcName, out var col)) diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs index 51bc1555bf..e82198234c 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs @@ -500,7 +500,7 @@ public TextNormalizingEstimator(IHostEnvironment env, public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.input, out var col)) diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs index 590b98af5e..4809b5708c 100644 --- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs +++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs @@ -587,7 +587,7 @@ public TokenizingByCharactersEstimator(IHostEnvironment env, bool useMarkerChara public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.input, out var col)) diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index a2056e8af5..25f4b76604 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -846,7 +846,7 @@ public WordEmbeddingsExtractingEstimator(IHostEnvironment env, string customMode public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs index 4419e08888..4631cd6674 100644 --- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs +++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs @@ -468,7 +468,7 @@ public WordTokenizingEstimator(IHostEnvironment env, params WordTokenizingTransf public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index e2bc679830..7d726d7cbf 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -151,9 +151,9 @@ protected void TestEstimatorCore(IEstimator estimator, private void CheckSameSchemaShape(SchemaShape promised, SchemaShape delivered) { - Assert.True(promised.Columns.Length == delivered.Columns.Length); - var sortedCols1 = promised.Columns.OrderBy(x => x.Name); - var sortedCols2 = delivered.Columns.OrderBy(x => x.Name); + Assert.True(promised.Count == delivered.Count); + var sortedCols1 = promised.OrderBy(x => x.Name); + var sortedCols2 = delivered.OrderBy(x => x.Name); foreach (var (x, y) in sortedCols1.Zip(sortedCols2, (x, y) => (x, y))) {