-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Categorical estimator #899
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d68e59b
ac91db8
25d46e7
2570eb9
8174f3b
114fec9
9d51927
4711a8d
120d3e8
cbbf5d0
47a00a9
7bfedf5
d2d2a88
c645cbd
486d65a
49c5514
71c1d45
a397200
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,17 +2,14 @@ | |
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Float = System.Single; | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Text; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.CommandLine; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.EntryPoints; | ||
using Microsoft.ML.Runtime.Internal.Utilities; | ||
using Microsoft.ML.Runtime.Internal.Internallearn; | ||
|
||
[assembly: LoadableClass(CategoricalHashTransform.Summary, typeof(IDataTransform), typeof(CategoricalHashTransform), typeof(CategoricalHashTransform.Arguments), typeof(SignatureDataTransform), | ||
CategoricalHashTransform.UserName, "CategoricalHashTransform", "CatHashTransform", "CategoricalHash", "CatHash")] | ||
|
@@ -62,14 +59,11 @@ protected override bool TryParse(string str) | |
|
||
// We accept N:B:S where N is the new column name, B is the number of bits, | ||
// and S is source column names. | ||
string extra; | ||
if (!base.TryParse(str, out extra)) | ||
if (!TryParse(str, out string extra)) | ||
return false; | ||
if (extra == null) | ||
return true; | ||
|
||
int bits; | ||
if (!int.TryParse(extra, out bits)) | ||
if (!int.TryParse(extra, out int bits)) | ||
return false; | ||
HashBits = bits; | ||
return true; | ||
|
@@ -201,14 +195,81 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV | |
}; | ||
} | ||
|
||
return CategoricalTransform.CreateTransformCore( | ||
return CreateTransformCore( | ||
args.OutputKind, args.Column, | ||
args.Column.Select(col => col.OutputKind).ToList(), | ||
new HashTransform(h, hashArgs, input), | ||
h, | ||
env, | ||
args); | ||
} | ||
} | ||
|
||
/// <summary>
/// Wraps <paramref name="input"/> in a key-to-vector (or key-to-binary-vector) transform covering
/// every column whose effective output kind is not Key. When no column needs conversion,
/// <paramref name="input"/> is returned unchanged.
/// </summary>
/// <param name="argsOutputKind">Default output kind; used for a column whose entry in <paramref name="columnOutputKinds"/> is null.</param>
/// <param name="columns">Columns to process. Each one is sanitized with TrySanitize(); a failure throws a user-argument exception.</param>
/// <param name="columnOutputKinds">Per-column output kind overrides, parallel to <paramref name="columns"/> (the lengths are checked to match).</param>
/// <param name="input">The transform to build on top of.</param>
/// <param name="h">Host used to start the channel and to create the resulting transform.</param>
/// <param name="catHashArgs">Optional arguments; only InvertHash is read here, to decide whether to emit a warning.</param>
/// <returns>The created KeyToVectorTransform / KeyToBinaryVectorTransform, or <paramref name="input"/> itself.</returns>
private static IDataTransform CreateTransformCore(CategoricalTransform.OutputKind argsOutputKind, OneToOneColumn[] columns, | ||
List<CategoricalTransform.OutputKind?> columnOutputKinds, IDataTransform input, IHost h, Arguments catHashArgs = null) | ||
{ | ||
Contracts.CheckValue(columns, nameof(columns)); | ||
Contracts.CheckValue(columnOutputKinds, nameof(columnOutputKinds)); | ||
Contracts.CheckParam(columns.Length == columnOutputKinds.Count, nameof(columns)); | ||
|
||
using (var ch = h.Start("Create Transform Core")) | ||
{ | ||
// Create the KeyToVectorTransform, if needed. | ||
var cols = new List<KeyToVectorTransform.Column>(); | ||
// Starts from the global setting; any individual column that resolves to Bin also turns this on.
bool binaryEncoding = argsOutputKind == CategoricalTransform.OutputKind.Bin; | ||
for (int i = 0; i < columns.Length; i++) | ||
{ | ||
var column = columns[i]; | ||
// NOTE(review): this exception is raised through h, while the switch below uses ch; confirm whether that is intentional.
if (!column.TrySanitize()) | ||
throw h.ExceptUserArg(nameof(Column.Name)); | ||
|
||
bool? bag; | ||
// A per-column kind, when present, wins over the global one.
CategoricalTransform.OutputKind kind = columnOutputKinds[i] ?? argsOutputKind; | ||
switch (kind) | ||
{ | ||
default: | ||
throw ch.ExceptUserArg(nameof(Column.OutputKind)); | ||
case CategoricalTransform.OutputKind.Key: | ||
// Key output needs no vector conversion, so the column is not added to cols.
continue; | ||
case CategoricalTransform.OutputKind.Bin: | ||
// binaryEncoding is one flag for the whole call: a single Bin column routes every collected column
// through the binary branch below.
binaryEncoding = true; | ||
bag = false; | ||
break; | ||
case CategoricalTransform.OutputKind.Ind: | ||
bag = false; | ||
break; | ||
case CategoricalTransform.OutputKind.Bag: | ||
bag = true; | ||
break; | ||
} | ||
var col = new KeyToVectorTransform.Column(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
object initializer? #ByDesign There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's gonna be dead after hash transform become estimator. In reply to: 217236161 [](ancestors = 217236161) |
||
// The conversion is done in place: source and output column share the same name.
col.Name = column.Name; | ||
col.Source = column.Name; | ||
col.Bag = bag; | ||
cols.Add(col); | ||
} | ||
|
||
// Every column resolved to Key: there is nothing to convert, so the input is handed back untouched.
// NOTE(review): this return happens before ch.Done() is reached; confirm that is acceptable for the channel.
if (cols.Count == 0) | ||
return input; | ||
|
||
IDataTransform transform; | ||
if (binaryEncoding) | ||
{ | ||
// InvertHash is treated as 0 when catHashArgs is null.
if ((catHashArgs?.InvertHash ?? 0) != 0) | ||
ch.Warning("Invert hashing is being used with binary encoding."); | ||
|
||
// Only source and name are forwarded here; the per-column Bag value is not used on this path.
var keyToBinaryVecCols = cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.Source, x.Name)).ToArray(); | ||
transform = KeyToBinaryVectorTransform.Create(h, input, keyToBinaryVecCols); | ||
} | ||
else | ||
{ | ||
// == binds tighter than ??, so the fallback is (argsOutputKind == Bag). Every column added to cols had
// bag assigned in the switch above, so the fallback looks unreachable from this method (presumably defensive).
var keyToVecCols = cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.Source, x.Name, x.Bag ?? argsOutputKind == CategoricalTransform.OutputKind.Bag)).ToArray(); | ||
|
||
transform = KeyToVectorTransform.Create(h, input, keyToVecCols); | ||
} | ||
|
||
ch.Done(); | ||
return transform; | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the right way to do it? #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good to me.
In reply to: 217229486 [](ancestors = 217229486)