Skip to content

Commit eb62d83

Browse files
committed
Make ML.NET build on the new Microsoft.Data.DataView assembly.
Fix #1860
1 parent bb6e6e5 commit eb62d83

21 files changed

+163
-24
lines changed

pkg/Microsoft.ML/Microsoft.ML.nupkgproj

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
</PropertyGroup>
77

88
<ItemGroup>
9+
<ProjectReference Include="../Microsoft.Data.DataView/Microsoft.Data.DataView.nupkgproj" />
910
<ProjectReference Include="../Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.nupkgproj" />
1011

1112
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />

src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs

+74
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
6+
57
namespace Microsoft.ML.Data
68
{
79
/// <summary>
@@ -76,5 +78,77 @@ public static bool SameSizeAndItemType(this ColumnType columnType, ColumnType ot
7678
return false;
7779
return vectorType.Size == otherVectorType.Size;
7880
}
81+
82+
/// <summary>
/// Maps a raw .NET <see cref="Type"/> to the matching <see cref="PrimitiveType"/> instance.
/// </summary>
/// <param name="type">The .NET type to look up.</param>
/// <returns>
/// The text type for <see cref="string"/> or <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/>,
/// the bool, time span, date-time or date-time-offset type for their respective .NET types,
/// and otherwise whatever <see cref="NumberTypeFromType"/> produces for <paramref name="type"/>.
/// </returns>
public static PrimitiveType PrimitiveTypeFromType(Type type)
{
    // Both text representations share the single text column type.
    bool isText = type == typeof(string) || type == typeof(ReadOnlyMemory<char>);
    if (isText)
    {
        return TextType.Instance;
    }

    if (type == typeof(bool))
    {
        return BoolType.Instance;
    }

    if (type == typeof(TimeSpan))
    {
        return TimeSpanType.Instance;
    }

    if (type == typeof(DateTime))
    {
        return DateTimeType.Instance;
    }

    if (type == typeof(DateTimeOffset))
    {
        return DateTimeOffsetType.Instance;
    }

    // Everything else is handed to the numeric lookup.
    return NumberTypeFromType(type);
}
96+
97+
/// <summary>
/// Maps a <see cref="DataKind"/> to the matching <see cref="PrimitiveType"/> instance.
/// </summary>
/// <param name="kind">The data kind to look up.</param>
/// <returns>
/// The text, bool, time span, date-time or date-time-offset type for the
/// TX, BL, TS, DT and DZ kinds respectively; any other kind is forwarded to
/// <see cref="NumberTypeFromKind"/>.
/// </returns>
public static PrimitiveType PrimitiveTypeFromKind(DataKind kind)
{
    switch (kind)
    {
        case DataKind.TX:
            return TextType.Instance;
        case DataKind.BL:
            return BoolType.Instance;
        case DataKind.TS:
            return TimeSpanType.Instance;
        case DataKind.DT:
            return DateTimeType.Instance;
        case DataKind.DZ:
            return DateTimeOffsetType.Instance;
        default:
            // Not one of the non-numeric kinds: defer to the numeric lookup.
            return NumberTypeFromKind(kind);
    }
}
111+
112+
/// <summary>
/// Resolves the <see cref="NumberType"/> for a raw .NET <see cref="Type"/> by first
/// converting it to a <see cref="DataKind"/> and then calling <see cref="NumberTypeFromKind"/>.
/// </summary>
/// <param name="type">The .NET type to look up.</param>
/// <returns>The number type produced by <see cref="NumberTypeFromKind"/> for the resolved kind.</returns>
/// <exception cref="InvalidOperationException">
/// Thrown when no <see cref="DataKind"/> can be obtained for <paramref name="type"/>.
/// </exception>
public static NumberType NumberTypeFromType(Type type)
{
    if (!type.TryGetDataKind(out DataKind resolvedKind))
    {
        // Callers are not expected to reach this point; assert first, then fail explicitly.
        Contracts.Assert(false);
        throw new InvalidOperationException($"Bad type in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromType)}: {type}");
    }

    return NumberTypeFromKind(resolvedKind);
}
121+
122+
/// <summary>
/// Maps a numeric <see cref="DataKind"/> to the corresponding <see cref="NumberType"/> instance.
/// </summary>
/// <param name="kind">One of the I1, U1, I2, U2, I4, U4, I8, U8, R4, R8 or UG kinds.</param>
/// <returns>The <see cref="NumberType"/> member with the same name as <paramref name="kind"/>.</returns>
/// <exception cref="InvalidOperationException">
/// Thrown when <paramref name="kind"/> is not one of the kinds listed above.
/// </exception>
public static NumberType NumberTypeFromKind(DataKind kind)
{
    switch (kind)
    {
        case DataKind.I1: return NumberType.I1;
        case DataKind.U1: return NumberType.U1;
        case DataKind.I2: return NumberType.I2;
        case DataKind.U2: return NumberType.U2;
        case DataKind.I4: return NumberType.I4;
        case DataKind.U4: return NumberType.U4;
        case DataKind.I8: return NumberType.I8;
        case DataKind.U8: return NumberType.U8;
        case DataKind.R4: return NumberType.R4;
        case DataKind.R8: return NumberType.R8;
        case DataKind.UG: return NumberType.UG;

        default:
            // Callers are not expected to pass any other kind; assert first, then fail explicitly.
            Contracts.Assert(false);
            throw new InvalidOperationException($"Bad data kind in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromKind)}: {kind}");
    }
}
79153
}
80154
}

src/Microsoft.ML.Core/Data/IEstimator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ internal static void GetColumnTypeShape(ColumnType type,
167167

168168
isKey = itemType is KeyType;
169169
if (isKey)
170-
itemType = PrimitiveType.FromType(itemType.RawType);
170+
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType);
171171
}
172172

173173
/// <summary>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
7+
namespace Microsoft.ML.Data
8+
{
9+
/// <summary>
/// Convenience extensions for adding slot-names and key-values metadata to a MetadataBuilder.
/// </summary>
public static class MetadataBuilderExtensions
{
    /// <summary>
    /// Registers slot names metadata on <paramref name="builder"/>.
    /// </summary>
    /// <param name="builder">The MetadataBuilder that receives the slot names.</param>
    /// <param name="size">The size of the slot names vector.</param>
    /// <param name="getter">The delegate that produces the slot names.</param>
    public static void AddSlotNames(this MetadataBuilder builder, int size, ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter)
    {
        // Slot names are stored as a text vector of the requested size.
        var slotNamesType = new VectorType(TextType.Instance, size);
        builder.Add(MetadataUtils.Kinds.SlotNames, slotNamesType, getter);
    }

    /// <summary>
    /// Registers key values metadata on <paramref name="builder"/>.
    /// </summary>
    /// <typeparam name="TValue">The value type of key values.</typeparam>
    /// <param name="builder">The MetadataBuilder that receives the key values.</param>
    /// <param name="size">The size of the key values vector.</param>
    /// <param name="valueType">The item type of the key values. Its raw type must match <typeparamref name="TValue"/>.</param>
    /// <param name="getter">The delegate that produces the key values.</param>
    public static void AddKeyValues<TValue>(this MetadataBuilder builder, int size, PrimitiveType valueType, ValueGetter<VBuffer<TValue>> getter)
    {
        // Key values are stored as a vector of the supplied item type and size.
        var keyValuesType = new VectorType(valueType, size);
        builder.Add(MetadataUtils.Kinds.KeyValues, keyValuesType, getter);
    }
}
31+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using Microsoft.ML.Internal.Utilities;
8+
9+
namespace Microsoft.ML.Data
10+
{
11+
/// <summary>
/// Internal helpers for building a <see cref="Schema"/> and for legacy column lookup by name.
/// </summary>
[BestFriend]
internal static class SchemaExtensions
{
    /// <summary>
    /// Builds a <see cref="Schema"/> by adding the given detached columns to a new SchemaBuilder.
    /// </summary>
    /// <param name="columns">The detached columns to add to the schema.</param>
    /// <returns>The schema produced by the builder.</returns>
    public static Schema MakeSchema(IEnumerable<Schema.DetachedColumn> columns)
    {
        var schemaBuilder = new SchemaBuilder();
        schemaBuilder.AddColumns(columns);
        return schemaBuilder.GetSchema();
    }

    /// <summary>
    /// Legacy method to get the column index.
    /// DO NOT USE: use <see cref="Schema.GetColumnOrNull"/> instead.
    /// </summary>
    /// <param name="schema">The schema to search.</param>
    /// <param name="name">The name of the column to look up.</param>
    /// <param name="col">Receives the column index, or -1 when the lookup yields no column.</param>
    /// <returns><c>true</c> when <paramref name="col"/> is non-negative; otherwise <c>false</c>.</returns>
    public static bool TryGetColumnIndex(this Schema schema, string name, out int col)
    {
        // A missing column leaves the index without a value, which becomes the -1 sentinel.
        int? foundIndex = schema.GetColumnOrNull(name)?.Index;
        col = foundIndex.GetValueOrDefault(-1);
        return col >= 0;
    }
}
31+
}

src/Microsoft.ML.Core/Microsoft.ML.Core.csproj

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
</PropertyGroup>
1010

1111
<ItemGroup>
12+
<ProjectReference Include="..\Microsoft.Data.DataView\Microsoft.Data.DataView.csproj" />
13+
1214
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
1315
<PackageReference Include="System.ComponentModel.Composition" Version="$(SystemComponentModelCompositionVersion)" />
1416
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />

src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public void Run()
8181
var srcToDstMap = new Dictionary<DataKind, HashSet<DataKind>>();
8282

8383
var kinds = Enum.GetValues(typeof(DataKind)).Cast<DataKind>().Distinct().OrderBy(k => k).ToArray();
84-
var types = kinds.Select(kind => PrimitiveType.FromKind(kind)).ToArray();
84+
var types = kinds.Select(kind => ColumnTypeExtensions.PrimitiveTypeFromKind(kind)).ToArray();
8585

8686
HashSet<DataKind> nonIdentity = null;
8787
// For each kind and its associated type.

src/Microsoft.ML.Data/Data/SchemaDefinition.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc
397397
itemType = new KeyType(dataType, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous);
398398
}
399399
else
400-
itemType = PrimitiveType.FromType(dataType);
400+
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(dataType);
401401

402402
// Get the column type.
403403
ColumnType columnType;

src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ private sealed class UnsafeTypeCodec<T> : SimpleCodec<T> where T : struct
158158
// Throws an exception if T is neither a TimeSpan nor a NumberType.
159159
private static ColumnType UnsafeColumnType(Type type)
160160
{
161-
return type == typeof(TimeSpan) ? (ColumnType)TimeSpanType.Instance : NumberType.FromType(type);
161+
return type == typeof(TimeSpan) ? (ColumnType)TimeSpanType.Instance : ColumnTypeExtensions.NumberTypeFromType(type);
162162
}
163163

164164
public UnsafeTypeCodec(CodecFactory factory)
@@ -1269,7 +1269,7 @@ private bool GetKeyCodec(ColumnType type, out IValueCodec codec)
12691269
throw Contracts.ExceptParam(nameof(type), "type must be a key type");
12701270
// Create the internal codec the key codec will use to do the actual reading/writing.
12711271
IValueCodec innerCodec;
1272-
if (!TryGetCodec(NumberType.FromType(type.RawType), out innerCodec))
1272+
if (!TryGetCodec(ColumnTypeExtensions.NumberTypeFromType(type.RawType), out innerCodec))
12731273
{
12741274
codec = default(IValueCodec);
12751275
return false;

src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ private Schema CreateSchema(IExceptionContext ectx, Column[] cols, IDataLoader s
316316
Contracts.AssertValue(subLoader);
317317

318318
var builder = new SchemaBuilder();
319-
builder.AddColumns(cols.Select(c => new Schema.DetachedColumn(c.Name, PrimitiveType.FromKind(c.Type.Value), null)));
319+
builder.AddColumns(cols.Select(c => new Schema.DetachedColumn(c.Name, ColumnTypeExtensions.PrimitiveTypeFromKind(c.Type.Value), null)));
320320
var colSchema = builder.GetSchema();
321321

322322
var subSchema = subLoader.Schema;

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
613613
{
614614
kind = col.Type ?? DataKind.Num;
615615
ch.CheckUserArg(Enum.IsDefined(typeof(DataKind), kind), nameof(Column.Type), "Bad item type");
616-
itemType = PrimitiveType.FromKind(kind);
616+
itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
617617
}
618618

619619
// This was checked above.
@@ -792,7 +792,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
792792
}
793793
}
794794
else
795-
itemType = PrimitiveType.FromKind(kind);
795+
itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
796796

797797
int cseg = ctx.Reader.ReadInt32();
798798
Contracts.CheckDecode(cseg > 0);

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ private ValueCreatorCache()
5656
_creatorsVec = new Func<RowSet, ColumnPipe>[DataKindExtensions.KindCount];
5757
for (var kind = DataKindExtensions.KindMin; kind < DataKindExtensions.KindLim; kind++)
5858
{
59-
var type = PrimitiveType.FromKind(kind);
59+
var type = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
6060
_creatorsOne[kind.ToIndex()] = GetCreatorOneCore(type);
6161
_creatorsVec[kind.ToIndex()] = GetCreatorVecCore(type);
6262
}

src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public sealed class InputRow<TRow> : InputRowBase<TRow>
8383
public override long Position => _position;
8484

8585
public InputRow(IHostEnvironment env, InternalSchemaDefinition schemaDef)
86-
: base(env, SchemaBuilder.MakeSchema(GetSchemaColumns(schemaDef)), schemaDef, MakePeeks(schemaDef), c => true)
86+
: base(env, SchemaExtensions.MakeSchema(GetSchemaColumns(schemaDef)), schemaDef, MakePeeks(schemaDef), c => true)
8787
{
8888
_position = -1;
8989
}
@@ -383,7 +383,7 @@ protected DataViewBase(IHostEnvironment env, string name, InternalSchemaDefiniti
383383
Host.AssertValue(schemaDefn);
384384

385385
_schemaDefn = schemaDefn;
386-
_schema = SchemaBuilder.MakeSchema(GetSchemaColumns(schemaDefn));
386+
_schema = SchemaExtensions.MakeSchema(GetSchemaColumns(schemaDefn));
387387
int n = schemaDefn.Columns.Length;
388388
_peeks = new Delegate[n];
389389
for (var i = 0; i < n; i++)
@@ -777,7 +777,7 @@ public MetadataInfo(string kind, T value, ColumnType metadataType = null)
777777
if (metadataType == null)
778778
{
779779
// Infer a type as best we can.
780-
var primitiveItemType = PrimitiveType.FromType(itemType);
780+
var primitiveItemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType);
781781
metadataType = isVector ? new VectorType(primitiveItemType) : (ColumnType)primitiveItemType;
782782
}
783783
else

src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us
254254
if (col.ColumnType == null)
255255
{
256256
// Infer a type as best we can.
257-
PrimitiveType itemType = PrimitiveType.FromType(dataItemType);
257+
PrimitiveType itemType = ColumnTypeExtensions.PrimitiveTypeFromType(dataItemType);
258258
colType = isVector ? new VectorType(itemType) : (ColumnType)itemType;
259259
}
260260
else

src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ public ColumnBindings(Schema input, Schema.DetachedColumn[] addedColumns)
717717

718718
// Create the output schema.
719719
var schemaColumns = indices.Select(idx => idx >= 0 ? new Schema.DetachedColumn(input[idx]) : addedColumns[~idx]);
720-
Schema = SchemaBuilder.MakeSchema(schemaColumns);
720+
Schema = SchemaExtensions.MakeSchema(schemaColumns);
721721

722722
// Memorize column maps.
723723
_colMap = indices.ToArray();

src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ private static Schema GenerateOutputSchema(IEnumerable<int> map,
574574
Schema inputSchema)
575575
{
576576
var outputColumns = map.Select(x => new Schema.DetachedColumn(inputSchema[x]));
577-
return SchemaBuilder.MakeSchema(outputColumns);
577+
return SchemaExtensions.MakeSchema(outputColumns);
578578
}
579579
}
580580

src/Microsoft.ML.Data/Transforms/Normalizer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ internal static ColumnType LoadType(ModelLoadContext ctx)
334334
DataKind itemKind = (DataKind)ctx.Reader.ReadByte();
335335
Contracts.CheckDecode(itemKind == DataKind.R4 || itemKind == DataKind.R8);
336336

337-
var itemType = PrimitiveType.FromKind(itemKind);
337+
var itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(itemKind);
338338
return isVector ? (ColumnType)(new VectorType(itemType, vectorSize)) : itemType;
339339
}
340340

src/Microsoft.ML.Data/Transforms/TypeConverting.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -369,10 +369,10 @@ internal static bool GetNewType(IExceptionContext ectx, ColumnType srcType, Data
369369
return false;
370370
}
371371
else if (!(srcType.GetItemType() is KeyType key))
372-
itemType = PrimitiveType.FromKind(kind);
372+
itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
373373
else if (!KeyType.IsValidDataKind(kind))
374374
{
375-
itemType = PrimitiveType.FromKind(kind);
375+
itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
376376
return false;
377377
}
378378
else

src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
9898
var resultDic = inputSchema.ToDictionary(x => x.Name);
9999
var vectorKind = Transformer.ValueColumnType is VectorType ? SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.Scalar;
100100
var isKey = Transformer.ValueColumnType is KeyType;
101-
var columnType = (isKey) ? PrimitiveType.FromKind(DataKind.U4) :
101+
var columnType = (isKey) ? ColumnTypeExtensions.PrimitiveTypeFromKind(DataKind.U4) :
102102
Transformer.ValueColumnType;
103103
var metadataShape = SchemaShape.Create(Transformer.ValueColumnMetadata.Schema);
104104
foreach (var (Input, Output) in _columns)
@@ -138,7 +138,7 @@ internal static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorTy
138138
if (!type.TryGetDataKind(out DataKind kind))
139139
throw new InvalidOperationException($"Unsupported type {type} used in mapping.");
140140

141-
return PrimitiveType.FromKind(kind);
141+
return ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
142142
}
143143

144144
/// <summary>
@@ -753,7 +753,7 @@ protected static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorT
753753
if (!type.TryGetDataKind(out DataKind kind))
754754
throw Contracts.Except($"Unsupported type {type} used in mapping.");
755755

756-
return PrimitiveType.FromKind(kind);
756+
return ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
757757
}
758758

759759
public override void Save(ModelSaveContext ctx)
@@ -1027,7 +1027,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore()
10271027
throw _parent.Host.ExceptNotSupp("Column '{0}' cannot be mapped to values when the column and the map values are both vector type.", _columns[i].Source);
10281028
var colType = _valueMap.ValueType;
10291029
if (_inputSchema[_columns[i].Source].Type is VectorType)
1030-
colType = new VectorType(PrimitiveType.FromType(_valueMap.ValueType.GetItemType().RawType));
1030+
colType = new VectorType(ColumnTypeExtensions.PrimitiveTypeFromType(_valueMap.ValueType.GetItemType().RawType));
10311031
result[i] = new Schema.DetachedColumn(_columns[i].Name, colType, _valueMetadata);
10321032
}
10331033
return result;

src/Microsoft.ML.OnnxTransform/OnnxUtils.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ public static PrimitiveType OnnxToMlNetType(System.Type type)
245245
{
246246
if (!_typeToKindMap.ContainsKey(type))
247247
throw Contracts.ExceptNotSupp("Onnx type not supported", type);
248-
return PrimitiveType.FromKind(_typeToKindMap[type]);
248+
return ColumnTypeExtensions.PrimitiveTypeFromKind(_typeToKindMap[type]);
249249
}
250250
}
251251
}

src/Microsoft.ML.Transforms/CustomMappingTransformer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ public CustomMappingEstimator(IHostEnvironment env, Action<TSrc, TDst> mapAction
200200
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
201201
{
202202
var addedCols = DataViewConstructionUtils.GetSchemaColumns(Transformer.AddedSchema);
203-
var addedSchemaShape = SchemaShape.Create(SchemaBuilder.MakeSchema(addedCols));
203+
var addedSchemaShape = SchemaShape.Create(SchemaExtensions.MakeSchema(addedCols));
204204

205205
var result = inputSchema.ToDictionary(x => x.Name);
206206
var inputDef = InternalSchemaDefinition.Create(typeof(TSrc), Transformer.InputSchemaDefinition);

0 commit comments

Comments
 (0)