diff --git a/src/Microsoft.ML.Api/CustomMappingTransformer.cs b/src/Microsoft.ML.Api/CustomMappingTransformer.cs
index 3e4102aad3..a63c63c7a0 100644
--- a/src/Microsoft.ML.Api/CustomMappingTransformer.cs
+++ b/src/Microsoft.ML.Api/CustomMappingTransformer.cs
@@ -205,7 +205,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
var addedCols = DataViewConstructionUtils.GetSchemaColumns(Transformer.AddedSchema);
var addedSchemaShape = SchemaShape.Create(SchemaBuilder.MakeSchema(addedCols));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
var inputDef = InternalSchemaDefinition.Create(typeof(TSrc), Transformer.InputSchemaDefinition);
foreach (var col in inputDef.Columns)
{
@@ -223,7 +223,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
}
}
- foreach (var addedCol in addedSchemaShape.Columns)
+ foreach (var addedCol in addedSchemaShape)
result[addedCol.Name] = addedCol;
return new SchemaShape(result.Values);
diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs
index b82ad4f5a7..5994a671bd 100644
--- a/src/Microsoft.ML.Core/Data/IEstimator.cs
+++ b/src/Microsoft.ML.Core/Data/IEstimator.cs
@@ -5,8 +5,9 @@
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
-using System;
+using System.Collections;
using System.Collections.Generic;
+using System.Collections.Immutable;
using System.Linq;
namespace Microsoft.ML.Core.Data
@@ -16,13 +17,17 @@ namespace Microsoft.ML.Core.Data
/// This is more relaxed than the proper , since it's only a subset of the columns,
/// and also since it doesn't specify exact 's for vectors and keys.
///
- public sealed class SchemaShape
+ public sealed class SchemaShape : IReadOnlyList
{
- public readonly Column[] Columns;
+ private readonly Column[] _columns;
private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty());
- public sealed class Column
+ public int Count => _columns.Count();
+
+ public Column this[int index] => _columns[index];
+
+ public struct Column
{
public enum VectorKind
{
@@ -55,13 +60,13 @@ public enum VectorKind
///
public readonly SchemaShape Metadata;
- public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null)
+ [BestFriend]
+ internal Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValueOrNull(metadata);
Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key");
Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector");
-
Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key");
Name = name;
@@ -80,9 +85,10 @@ public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey,
/// - The columns of of is a superset of our columns.
/// - Each such metadata column is itself compatible with the input metadata column.
///
- public bool IsCompatibleWith(Column inputColumn)
+ [BestFriend]
+ internal bool IsCompatibleWith(Column inputColumn)
{
- Contracts.CheckValue(inputColumn, nameof(inputColumn));
+ Contracts.Check(inputColumn.IsValid, nameof(inputColumn));
if (Name != inputColumn.Name)
return false;
if (Kind != inputColumn.Kind)
@@ -91,7 +97,7 @@ public bool IsCompatibleWith(Column inputColumn)
return false;
if (IsKey != inputColumn.IsKey)
return false;
- foreach (var metaCol in Metadata.Columns)
+ foreach (var metaCol in Metadata)
{
if (!inputColumn.Metadata.TryFindColumn(metaCol.Name, out var inputMetaCol))
return false;
@@ -101,7 +107,8 @@ public bool IsCompatibleWith(Column inputColumn)
return true;
}
- public string GetTypeString()
+ [BestFriend]
+ internal string GetTypeString()
{
string result = ItemType.ToString();
if (IsKey)
@@ -112,13 +119,20 @@ public string GetTypeString()
result = $"VarVector<{result}>";
return result;
}
+
+ ///
+ /// Return if this structure is not identical to the default value of . If true,
+ /// it means this structure is initialized properly and therefore considered as valid.
+ ///
+ [BestFriend]
+ internal bool IsValid => Name != null;
}
public SchemaShape(IEnumerable columns)
{
Contracts.CheckValue(columns, nameof(columns));
- Columns = columns.ToArray();
- Contracts.CheckParam(columns.All(c => c != null), nameof(columns), "No items should be null.");
+ _columns = columns.ToArray();
+ Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly.");
}
///
@@ -151,7 +165,8 @@ internal static void GetColumnTypeShape(ColumnType type,
///
/// Create a schema shape out of the fully defined schema.
///
- public static SchemaShape Create(Schema schema)
+ [BestFriend]
+ internal static SchemaShape Create(Schema schema)
{
Contracts.CheckValue(schema, nameof(schema));
var cols = new List();
@@ -179,25 +194,23 @@ public static SchemaShape Create(Schema schema)
///
/// Returns if there is a column with a specified and if so stores it in .
///
- public bool TryFindColumn(string name, out Column column)
+ [BestFriend]
+ internal bool TryFindColumn(string name, out Column column)
{
Contracts.CheckValue(name, nameof(name));
- column = Columns.FirstOrDefault(x => x.Name == name);
- return column != null;
+ column = _columns.FirstOrDefault(x => x.Name == name);
+ return column.IsValid;
}
+ public IEnumerator GetEnumerator() => ((IEnumerable)_columns).GetEnumerator();
+
+ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+
// REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
// as an input to another schema shape. I started writing, but realized that there's more than one way to check for
// the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
}
- ///
- /// Exception class for schema validation errors.
- ///
- public class SchemaException : Exception
- {
- }
-
///
/// The 'data reader' takes a certain kind of input and turns it into an .
///
@@ -246,7 +259,6 @@ public interface ITransformer
///
/// Schema propagation for transformers.
/// Returns the output schema of the data, if the input schema is like the one provided.
- /// Throws if the input schema is not valid for the transformer.
///
Schema GetOutputSchema(Schema inputSchema);
@@ -288,7 +300,6 @@ public interface IEstimator
///
/// Schema propagation for estimators.
/// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
- /// Throws iff the input schema is not valid for the estimator.
///
SchemaShape GetOutputSchema(SchemaShape inputSchema);
}
diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
index 6e2bd5230f..3bb8f16ff0 100644
--- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs
+++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
@@ -367,7 +367,7 @@ public static bool IsNormalized(this Schema schema, int col)
/// of a scalar type, which we assume, if set, should be true.
public static bool IsNormalized(this SchemaShape.Column col)
{
- Contracts.CheckValue(col, nameof(col));
+ Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
return col.Metadata.TryFindColumn(Kinds.IsNormalized, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey
&& metaCol.ItemType == BoolType.Instance;
@@ -382,7 +382,7 @@ public static bool IsNormalized(this SchemaShape.Column col)
/// metadata of definite sized vectors of text.
public static bool HasSlotNames(this SchemaShape.Column col)
{
- Contracts.CheckValue(col, nameof(col));
+ Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
return col.Kind == SchemaShape.Column.VectorKind.Vector
&& col.Metadata.TryFindColumn(Kinds.SlotNames, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey
diff --git a/src/Microsoft.ML.Core/Utilities/Contracts.cs b/src/Microsoft.ML.Core/Utilities/Contracts.cs
index d567d6e883..cda7cbc539 100644
--- a/src/Microsoft.ML.Core/Utilities/Contracts.cs
+++ b/src/Microsoft.ML.Core/Utilities/Contracts.cs
@@ -758,6 +758,10 @@ public static void CheckAlive(this IHostEnvironment env)
public static void CheckValueOrNull(T val) where T : class
{
}
+
+ ///
+ /// This documents that the parameter can legally be null.
+ ///
[Conditional("INVARIANT_CHECKS")]
public static void CheckValueOrNull(this IExceptionContext ctx, T val) where T : class
{
diff --git a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs
index d94219a453..c67c6ae733 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs
@@ -30,22 +30,22 @@ public FakeSchema(IHostEnvironment env, SchemaShape inputShape)
{
_env = env;
_shape = inputShape;
- _colMap = Enumerable.Range(0, _shape.Columns.Length)
- .ToDictionary(idx => _shape.Columns[idx].Name, idx => idx);
+ _colMap = Enumerable.Range(0, _shape.Count)
+ .ToDictionary(idx => _shape[idx].Name, idx => idx);
}
- public int ColumnCount => _shape.Columns.Length;
+ public int ColumnCount => _shape.Count;
public string GetColumnName(int col)
{
_env.Check(0 <= col && col < ColumnCount);
- return _shape.Columns[col].Name;
+ return _shape[col].Name;
}
public ColumnType GetColumnType(int col)
{
_env.Check(0 <= col && col < ColumnCount);
- var inputCol = _shape.Columns[col];
+ var inputCol = _shape[col];
return MakeColumnType(inputCol);
}
@@ -66,7 +66,7 @@ private static ColumnType MakeColumnType(SchemaShape.Column inputCol)
public void GetMetadata(string kind, int col, ref TValue value)
{
_env.Check(0 <= col && col < ColumnCount);
- var inputCol = _shape.Columns[col];
+ var inputCol = _shape[col];
var metaShape = inputCol.Metadata;
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
throw _env.ExceptGetMetadata();
@@ -89,7 +89,7 @@ public void GetMetadata(string kind, int col, ref TValue value)
public ColumnType GetMetadataTypeOrNull(string kind, int col)
{
_env.Check(0 <= col && col < ColumnCount);
- var inputCol = _shape.Columns[col];
+ var inputCol = _shape[col];
var metaShape = inputCol.Metadata;
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
return null;
@@ -99,12 +99,12 @@ public ColumnType GetMetadataTypeOrNull(string kind, int col)
public IEnumerable> GetMetadataTypes(int col)
{
_env.Check(0 <= col && col < ColumnCount);
- var inputCol = _shape.Columns[col];
+ var inputCol = _shape[col];
var metaShape = inputCol.Metadata;
if (metaShape == null)
return Enumerable.Empty>();
- return metaShape.Columns.Select(c => new KeyValuePair(c.Name, MakeColumnType(c)));
+ return metaShape.Select(c => new KeyValuePair(c.Name, MakeColumnType(c)));
}
}
}
diff --git a/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs b/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs
index 6ffd089f24..7b1edd3398 100644
--- a/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs
+++ b/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs
@@ -113,7 +113,7 @@ public void Check(IExceptionContext ectx, SchemaShape shape)
private static Type GetTypeOrNull(SchemaShape.Column col)
{
- Contracts.AssertValue(col);
+ Contracts.Assert(col.IsValid);
Type vecType = null;
diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
index 1e49c32ed3..55d85201b7 100644
--- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
+++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
@@ -55,13 +55,11 @@ public abstract class TrainerEstimatorBase : ITrainerEstim
private protected TrainerEstimatorBase(IHost host,
SchemaShape.Column feature,
SchemaShape.Column label,
- SchemaShape.Column weight = null)
+ SchemaShape.Column weight = default)
{
Contracts.CheckValue(host, nameof(host));
Host = host;
- Host.CheckValue(feature, nameof(feature));
- Host.CheckValueOrNull(label);
- Host.CheckValueOrNull(weight);
+ Host.CheckParam(feature.IsValid, nameof(feature), "not initialized properly");
FeatureColumn = feature;
LabelColumn = label;
@@ -76,7 +74,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
CheckInputSchema(inputSchema);
- var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
+ var outColumns = inputSchema.ToDictionary(x => x.Name);
foreach (var col in GetOutputColumnsCore(inputSchema))
outColumns[col.Name] = col;
@@ -102,7 +100,7 @@ private void CheckInputSchema(SchemaShape inputSchema)
if (!FeatureColumn.IsCompatibleWith(featureCol))
throw Host.Except($"Feature column '{FeatureColumn.Name}' is not compatible");
- if (WeightColumn != null)
+ if (WeightColumn.IsValid)
{
if (!inputSchema.TryFindColumn(WeightColumn.Name, out var weightCol))
throw Host.Except($"Weight column '{WeightColumn.Name}' is not found");
@@ -112,7 +110,7 @@ private void CheckInputSchema(SchemaShape inputSchema)
// Special treatment for label column: we allow different types of labels, so the trainers
// may define their own requirements on the label column.
- if (LabelColumn != null)
+ if (LabelColumn.IsValid)
{
if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
throw Host.Except($"Label column '{LabelColumn.Name}' is not found");
@@ -122,8 +120,8 @@ private void CheckInputSchema(SchemaShape inputSchema)
protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol)
{
- Contracts.CheckValue(labelCol, nameof(labelCol));
- Contracts.AssertValue(LabelColumn);
+ Contracts.CheckParam(labelCol.IsValid, nameof(labelCol), "not initialized properly");
+ Host.Assert(LabelColumn.IsValid);
if (!LabelColumn.IsCompatibleWith(labelCol))
throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
@@ -133,20 +131,12 @@ protected TTransformer TrainTransformer(IDataView trainSet,
IDataView validationSet = null, IPredictor initPredictor = null)
{
var cachedTrain = Info.WantCaching ? new CacheDataView(Host, trainSet, prefetch: null) : trainSet;
+ var cachedValid = Info.WantCaching && validationSet != null ? new CacheDataView(Host, validationSet, prefetch: null) : validationSet;
- var trainRoles = MakeRoles(cachedTrain);
+ var trainRoleMapped = MakeRoles(cachedTrain);
+ var validRoleMapped = validationSet == null ? null : MakeRoles(cachedValid);
- RoleMappedData validRoles;
-
- if (validationSet == null)
- validRoles = null;
- else
- {
- var cachedValid = Info.WantCaching ? new CacheDataView(Host, validationSet, prefetch: null) : validationSet;
- validRoles = MakeRoles(cachedValid);
- }
-
- var pred = TrainModelCore(new TrainContext(trainRoles, validRoles, null, initPredictor));
+ var pred = TrainModelCore(new TrainContext(trainRoleMapped, validRoleMapped, null, initPredictor));
return MakeTransformer(pred, trainSet.Schema);
}
@@ -156,7 +146,7 @@ protected TTransformer TrainTransformer(IDataView trainSet,
protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema);
protected virtual RoleMappedData MakeRoles(IDataView data) =>
- new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name);
+ new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, weight: WeightColumn.Name);
IPredictor ITrainer.Train(TrainContext context) => ((ITrainer)this).Train(context);
}
@@ -178,16 +168,15 @@ public abstract class TrainerEstimatorBaseWithGroupId : Tr
public TrainerEstimatorBaseWithGroupId(IHost host,
SchemaShape.Column feature,
SchemaShape.Column label,
- SchemaShape.Column weight = null,
- SchemaShape.Column groupId = null)
+ SchemaShape.Column weight = default,
+ SchemaShape.Column groupId = default)
:base(host, feature, label, weight)
{
- Host.CheckValueOrNull(groupId);
GroupIdColumn = groupId;
}
protected override RoleMappedData MakeRoles(IDataView data) =>
- new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name);
+ new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, group: GroupIdColumn.Name, weight: WeightColumn.Name);
}
}
diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs
index 41e984f522..9f8dfb68ea 100644
--- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs
+++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs
@@ -366,7 +366,7 @@ public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn)
public static SchemaShape.Column MakeU4ScalarColumn(string columnName)
{
if (columnName == null)
- return null;
+ return default;
return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
}
@@ -386,7 +386,7 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn)
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true)
{
if (weightColumn == null || !isExplicit)
- return null;
+ return default;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs
index dc8ad9b389..81df252d58 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs
@@ -49,7 +49,7 @@ public ITransformer Fit(IDataView input)
private bool HasCategoricals(SchemaShape.Column col)
{
- _host.AssertValue(col);
+ _host.Assert(col.IsValid);
if (!col.Metadata.TryFindColumn(MetadataUtils.Kinds.CategoricalSlotRanges, out var mcol))
return false;
// The indices must be ints and of a definite size vector type. (Definite becuase
@@ -116,7 +116,7 @@ private SchemaShape.Column CheckInputsAndMakeColumn(
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
result[_name] = CheckInputsAndMakeColumn(inputSchema, _name, _source);
return new SchemaShape(result.Values);
}
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
index 6ab8fe9b7b..84f26e6059 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
@@ -49,7 +49,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
+ var resultDic = inputSchema.ToDictionary(x => x.Name);
foreach (var (Source, Name) in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(Source, out var originalColumn))
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
index afbbbe59b8..3263305457 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
@@ -98,14 +98,14 @@ public static ColumnSelectingEstimator DropColumns(IHostEnvironment env, params
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- if (!Transformer.IgnoreMissing && !ColumnSelectingTransformer.IsSchemaValid(inputSchema.Columns.Select(x => x.Name),
+ if (!Transformer.IgnoreMissing && !ColumnSelectingTransformer.IsSchemaValid(inputSchema.Select(x => x.Name),
Transformer.SelectColumns,
out IEnumerable invalidColumns))
{
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", string.Join(",", invalidColumns));
}
- var columns = inputSchema.Columns.Where(c => _selectPredicate(c.Name));
+ var columns = inputSchema.Where(c => _selectPredicate(c.Name));
return new SchemaShape(columns);
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs
index 60e799ddb6..4c7e353d9a 100644
--- a/src/Microsoft.ML.Data/Transforms/Hashing.cs
+++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs
@@ -1235,7 +1235,7 @@ public HashingEstimator(IHostEnvironment env, params HashingTransformer.ColumnIn
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in _columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
index fd6af73412..d1aad4d96e 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
@@ -512,7 +512,7 @@ public KeyToValueMappingEstimator(IHostEnvironment env, params (string input, st
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.input, out var col))
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs
index 363bc4c83f..93ca9bf874 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs
@@ -746,7 +746,7 @@ private KeyToVectorMappingEstimator(IHostEnvironment env, KeyToVectorMappingTran
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs
index 1434b680a7..7efbe73601 100644
--- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs
+++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs
@@ -213,7 +213,7 @@ public NormalizingTransformer Fit(IDataView input)
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in _columns)
{
diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
index 331e5fff71..c38bcd46df 100644
--- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
+++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
@@ -552,7 +552,7 @@ public TypeConvertingEstimator(IHostEnvironment env, params TypeConvertingTransf
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs
index 4d01a62541..24ff1c6915 100644
--- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs
+++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs
@@ -60,7 +60,7 @@ public ValueToKeyMappingEstimator(IHostEnvironment env, ValueToKeyMappingTransfo
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in _columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
@@ -77,7 +77,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
colInfo.TextKeyValues ? TextType.Instance : col.ItemType, col.IsKey);
}
- Contracts.AssertValue(kv);
+ Contracts.Assert(kv.IsValid);
if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta))
metadata = new SchemaShape(new[] { slotMeta, kv });
diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
index 9fe6023764..64c3191e5f 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
@@ -97,7 +97,7 @@ internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args)
protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
- Contracts.AssertValue(labelCol);
+ Contracts.Assert(labelCol.IsValid);
Action error =
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "R4 or a Key", labelCol.GetTypeString());
diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs
index 9c9cd820b4..5e282de6d7 100644
--- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs
+++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs
@@ -806,7 +806,7 @@ public VectorWhiteningTransformer Fit(IDataView input)
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colPair in _infos)
{
if (!inputSchema.TryFindColumn(colPair.Input, out var col))
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
index 7b89bdb173..a25dabe832 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
@@ -227,7 +227,7 @@ public ImageGrayscalingEstimator(IHostEnvironment env, params (string input, str
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.input, out var col))
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
index 81b3d339e5..ba99793604 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
@@ -230,7 +230,7 @@ public ImageLoadingEstimator(IHostEnvironment env, ImageLoaderTransform transfor
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var (input, output) in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(input, out var col))
diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
index ceecacb0f9..d8e7a37907 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
@@ -655,7 +655,7 @@ public ImagePixelExtractingEstimator(IHostEnvironment env, params ImagePixelExtr
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
index e5054b515f..d38af82ba8 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
@@ -447,7 +447,7 @@ public ImageResizingEstimator(IHostEnvironment env, ImageResizerTransform transf
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
index 389a8be9c2..7924d58f6f 100644
--- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
+++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
@@ -121,7 +121,7 @@ internal KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args)
}
private KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args, Action advancedSettings = null)
- : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
+ : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), default, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
{
Host.CheckValue(args, nameof(args));
diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
index 62b476b40f..c787a606fa 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
@@ -213,7 +213,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
Contracts.Assert(success);
- var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
+ var metadata = new SchemaShape(labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
.Concat(MetadataUtils.GetTrainerOutputMetadata()));
return new[]
{
diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
index 61b2c965b1..efb1c466c5 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
@@ -140,7 +140,7 @@ protected override void CheckDataValid(IChannel ch, RoleMappedData data)
protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
- Contracts.AssertValue(labelCol);
+ Contracts.Assert(labelCol.IsValid);
Action error =
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "R4 or a Key", labelCol.GetTypeString());
diff --git a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs
index e414261336..e2b060f506 100644
--- a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs
+++ b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs
@@ -480,8 +480,8 @@ public OnnxScoringEstimator(IHostEnvironment env, OnnxTransform transformer)
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
- var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
+ var resultDic = inputSchema.ToDictionary(x => x.Name);
for (var i = 0; i < Transformer.Inputs.Length; i++)
{
diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs
index d37cf58851..c1af69d6c7 100644
--- a/src/Microsoft.ML.PCA/PcaTrainer.cs
+++ b/src/Microsoft.ML.PCA/PcaTrainer.cs
@@ -111,7 +111,7 @@ internal RandomizedPcaTrainer(IHostEnvironment env, Arguments args)
private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featureColumn, string weightColumn,
int rank = 20, int oversampling = 20, bool center = true, int? seed = null)
- : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
+ : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), default, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
{
// if the args are not null, we got here from maml, and the internal ctor.
if (args != null)
@@ -152,7 +152,7 @@ private protected override PcaPredictor TrainModelCore(TrainContext context)
private static SchemaShape.Column MakeWeightColumn(string weightColumn)
{
if (weightColumn == null)
- return null;
+ return default;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
diff --git a/src/Microsoft.ML.PCA/PcaTransform.cs b/src/Microsoft.ML.PCA/PcaTransform.cs
index 1a6a5c85f2..8c65dd35aa 100644
--- a/src/Microsoft.ML.PCA/PcaTransform.cs
+++ b/src/Microsoft.ML.PCA/PcaTransform.cs
@@ -707,7 +707,7 @@ public PrincipalComponentAnalysisEstimator(IHostEnvironment env, params PcaTrans
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in _columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
index eb2db599a9..e42c6b002e 100644
--- a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
+++ b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
@@ -436,7 +436,7 @@ void CheckColumnsCompatible(SchemaShape.Column cachedColumn, string expectedColu
CheckColumnsCompatible(matrixRowIndexColumn, MatrixRowIndexName);
// Input columns just pass through so that output column dictionary contains all input columns.
- var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
+ var outColumns = inputSchema.ToDictionary(x => x.Name);
// Add columns produced by this estimator.
foreach (var col in GetOutputColumnsCore(inputSchema))
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
index 81ef9d03cf..be73a2de30 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
@@ -151,7 +151,7 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env,
FeatureColumns[i] = new SchemaShape.Column(featureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
- WeightColumn = weights != null ? new SchemaShape.Column(weights, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : null;
+ WeightColumn = weights != null ? new SchemaShape.Column(weights, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default;
}
///
@@ -461,7 +461,7 @@ public FieldAwareFactorizationMachinePredictionTransformer Train(IDataView train
roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Label, LabelColumn.Name));
- if (WeightColumn != null)
+ if (WeightColumn.IsValid)
roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Feature, WeightColumn.Name));
var trainingData = new RoleMappedData(trainData, roles);
@@ -500,10 +500,10 @@ void CheckColumnsCompatible(SchemaShape.Column column, string defaultName)
CheckColumnsCompatible(feat, DefaultColumnNames.Features);
}
- if (WeightColumn != null)
+ if (WeightColumn.IsValid)
CheckColumnsCompatible(WeightColumn, DefaultColumnNames.Weight);
- var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
+ var outColumns = inputSchema.ToDictionary(x => x.Name);
foreach (var col in GetOutputColumnsCore(inputSchema))
outColumns[col.Name] = col;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
index 0d12aee904..2ec7021cd4 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
@@ -191,7 +191,7 @@ internal LbfgsTrainerBase(IHostEnvironment env,
args.FeatureColumn = FeatureColumn.Name;
args.LabelColumn = LabelColumn.Name;
- args.WeightColumn = WeightColumn?.Name;
+ args.WeightColumn = WeightColumn.Name;
Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null,
nameof(Args.NumThreads), "numThreads must be positive (or empty for default)");
Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
index 6d8a411e7c..c856483f90 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
@@ -321,7 +321,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
Contracts.Assert(success);
- var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
+ var metadata = new SchemaShape(labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
.Concat(MetadataUtils.GetTrainerOutputMetadata()));
return new[]
{
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
index 15de478efd..e7aeade4fe 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
@@ -149,7 +149,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- if (LabelColumn != null)
+ if (LabelColumn.IsValid)
{
if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
throw Host.ExceptSchemaMismatch(nameof(labelCol), DefaultColumnNames.PredictedLabel, DefaultColumnNames.PredictedLabel);
@@ -158,7 +158,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
}
- var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
+ var outColumns = inputSchema.ToDictionary(x => x.Name);
foreach (var col in GetOutputColumnsCore(inputSchema))
outColumns[col.Name] = col;
@@ -167,12 +167,12 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
- if (LabelColumn != null)
+ if (LabelColumn.IsValid)
{
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
Contracts.Assert(success);
- var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
+ var metadata = new SchemaShape(labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
.Concat(MetadataForScoreColumn()));
return new[]
{
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
index 751a62b7b6..1a82526a9e 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
@@ -76,7 +76,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
Contracts.Assert(success);
- var predLabelMetadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
+ var predLabelMetadata = new SchemaShape(labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
.Concat(MetadataUtils.GetTrainerOutputMetadata()));
return new[]
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
index 91f68a85f6..c853aa7f79 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
@@ -163,7 +163,7 @@ protected override void CheckLabels(RoleMappedData data)
protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
- Contracts.AssertValue(labelCol);
+ Contracts.Assert(labelCol.IsValid);
Action error =
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "BL, R8, R4 or a Key", labelCol.GetTypeString());
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
index 42e447004a..7016568bed 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
@@ -259,13 +259,13 @@ private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColu
}
internal SdcaTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
- SchemaShape.Column weight = null, Action advancedSettings = null, float? l2Const = null,
+ SchemaShape.Column weight = default, Action advancedSettings = null, float? l2Const = null,
float? l1Threshold = null, int? maxIterations = null)
: this(env, ArgsInit(featureColumn, labelColumn, advancedSettings), labelColumn, weight, l2Const, l1Threshold, maxIterations)
{
}
- internal SdcaTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label, SchemaShape.Column weight = null,
+ internal SdcaTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label, SchemaShape.Column weight = default,
float? l2Const = null, float? l1Threshold = null, int? maxIterations = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, weight)
{
@@ -1520,7 +1520,7 @@ public SdcaBinaryTrainer(IHostEnvironment env, Arguments args)
protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
- Contracts.AssertValue(labelCol);
+ Contracts.Assert(labelCol.IsValid);
Action error =
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "BL, R8, R4 or a Key", labelCol.GetTypeString());
@@ -1535,7 +1535,7 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
private static SchemaShape.Column MakeWeightColumn(string weightColumn)
{
if (weightColumn == null)
- return null;
+ return default;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
index f60b822827..af6617d862 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
@@ -101,7 +101,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
Contracts.Assert(success);
- var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
+ var metadata = new SchemaShape(labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
.Concat(MetadataUtils.GetTrainerOutputMetadata()));
return new[]
{
@@ -112,7 +112,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
- Contracts.AssertValue(labelCol);
+ Contracts.Assert(labelCol.IsValid);
Action error =
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "R8, R4 or a Key", labelCol.GetTypeString());
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
index 2e031433e1..c755406ee2 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
@@ -86,7 +86,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
+ var outColumns = inputSchema.ToDictionary(x => x.Name);
var newColumns = new[]
{
@@ -327,7 +327,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
+ var outColumns = inputSchema.ToDictionary(x => x.Name);
var newColumns = new[]
{
diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs
index d419a3a8a1..792844ef26 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs
@@ -15,7 +15,7 @@ public abstract class StochasticTrainerBase : TrainerEstim
where TTransformer : ISingleFeaturePredictionTransformer
where TModel : IPredictor
{
- public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null)
+ public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = default)
: base(host, feature, label, weight)
{
}
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index b4ced664c1..16e4b2119f 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -1101,8 +1101,8 @@ private static TensorFlowTransform.Arguments CreateArguments(TensorFlowModelInfo
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
- var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
+ var resultDic = inputSchema.ToDictionary(x => x.Name);
for (var i = 0; i < _args.InputColumns.Length; i++)
{
var input = _args.InputColumns[i];
diff --git a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
index 3a899f4339..6a42483130 100644
--- a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
@@ -251,7 +251,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
var metadata = new List() {
new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
};
- var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
+ var resultDic = inputSchema.ToDictionary(x => x.Name);
resultDic[Transformer.OutputColumnName] = new SchemaShape.Column(
Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
diff --git a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
index 3fdf968025..107b2873cf 100644
--- a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
@@ -226,7 +226,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
var metadata = new List() {
new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
};
- var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
+ var resultDic = inputSchema.ToDictionary(x => x.Name);
resultDic[Transformer.OutputColumnName] = new SchemaShape.Column(
Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
diff --git a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs
index 9171c34662..4565f853b6 100644
--- a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs
@@ -282,7 +282,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
var metadata = new List() {
new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
};
- var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
+ var resultDic = inputSchema.ToDictionary(x => x.Name);
resultDic[_args.Name] = new SchemaShape.Column(
_args.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
diff --git a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs
index f98c8bf34b..00eb719b46 100644
--- a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs
@@ -260,7 +260,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
var metadata = new List() {
new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
};
- var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
+ var resultDic = inputSchema.ToDictionary(x => x.Name);
resultDic[_args.Name] = new SchemaShape.Column(
_args.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
diff --git a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs
index d597422854..eae71fae51 100644
--- a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs
+++ b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs
@@ -106,7 +106,7 @@ public CountFeatureSelectingEstimator(IHostEnvironment env, string inputColumn,
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colPair in _columns)
{
if (!inputSchema.TryFindColumn(colPair.Input, out var col))
diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs
index e35a41d5e3..5f390a27dd 100644
--- a/src/Microsoft.ML.Transforms/GcnTransform.cs
+++ b/src/Microsoft.ML.Transforms/GcnTransform.cs
@@ -795,7 +795,7 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col)
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colPair in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colPair.Input, out var col))
diff --git a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs
index 818806083c..314d98b259 100644
--- a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs
+++ b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs
@@ -465,7 +465,7 @@ private KeyToBinaryVectorMappingEstimator(IHostEnvironment env, KeyToBinaryVecto
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs
index 3c52227bec..45c299531d 100644
--- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs
@@ -371,7 +371,7 @@ public MissingValueDroppingEstimator(IHostEnvironment env, string input, string
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colPair in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colPair.input, out var col) || !Runtime.Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del))
diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
index 049911f48e..1503b248cc 100644
--- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
@@ -460,7 +460,7 @@ public MissingValueIndicatorEstimator(IHostEnvironment env, string input, string
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colPair in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colPair.input, out var col) || !Runtime.Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del))
diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs
index 41e7b5f857..bbc090f333 100644
--- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs
@@ -952,7 +952,7 @@ public MissingValueReplacingEstimator(IHostEnvironment env, params MissingValueR
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in _columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs
index 3c496e711c..517628655d 100644
--- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs
+++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs
@@ -167,7 +167,7 @@ public ITransformer Fit(IDataView input)
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colPair in _columns)
{
if (!inputSchema.TryFindColumn(colPair.input, out var col))
diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs
index 0984f40f97..0d16f6cfb1 100644
--- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs
+++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs
@@ -678,7 +678,7 @@ public RandomFourierFeaturizingEstimator(IHostEnvironment env, params RandomFour
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in _columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs
index 8978401c68..70b9e0914b 100644
--- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs
@@ -1145,7 +1145,7 @@ public LatentDirichletAllocationEstimator(IHostEnvironment env, params LatentDir
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in _columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
index 5f548c4632..063488aa45 100644
--- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
@@ -863,7 +863,7 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col)
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in _columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs
index 52ba66f0c6..2c274feec3 100644
--- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs
+++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs
@@ -582,7 +582,7 @@ public StopWordsRemovingEstimator(IHostEnvironment env, params StopWordsRemoving
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
@@ -1076,7 +1076,7 @@ public CustomStopWordsRemovingEstimator(IHostEnvironment env, (string input, str
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.input, out var col))
diff --git a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs
index bf4edae5de..1fa39930bd 100644
--- a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs
@@ -472,7 +472,7 @@ private static string GenerateColumnName(ISchema schema, string srcName, string
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var srcName in _inputColumns)
{
if (!inputSchema.TryFindColumn(srcName, out var col))
diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
index 51bc1555bf..e82198234c 100644
--- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
@@ -500,7 +500,7 @@ public TextNormalizingEstimator(IHostEnvironment env,
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.input, out var col))
diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
index 590b98af5e..4809b5708c 100644
--- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
+++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
@@ -587,7 +587,7 @@ public TokenizingByCharactersEstimator(IHostEnvironment env, bool useMarkerChara
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.input, out var col))
diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
index a2056e8af5..25f4b76604 100644
--- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
@@ -846,7 +846,7 @@ public WordEmbeddingsExtractingEstimator(IHostEnvironment env, string customMode
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in _columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
index 4419e08888..4631cd6674 100644
--- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
@@ -468,7 +468,7 @@ public WordTokenizingEstimator(IHostEnvironment env, params WordTokenizingTransf
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
index e2bc679830..7d726d7cbf 100644
--- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
+++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
@@ -151,9 +151,9 @@ protected void TestEstimatorCore(IEstimator estimator,
private void CheckSameSchemaShape(SchemaShape promised, SchemaShape delivered)
{
- Assert.True(promised.Columns.Length == delivered.Columns.Length);
- var sortedCols1 = promised.Columns.OrderBy(x => x.Name);
- var sortedCols2 = delivered.Columns.OrderBy(x => x.Name);
+ Assert.True(promised.Count == delivered.Count);
+ var sortedCols1 = promised.OrderBy(x => x.Name);
+ var sortedCols2 = delivered.OrderBy(x => x.Name);
foreach (var (x, y) in sortedCols1.Zip(sortedCols2, (x, y) => (x, y)))
{