diff --git a/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs index 9b2162ece3..45219c4c1b 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs @@ -100,13 +100,14 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var isKey = Transformer.ValueColumnType is KeyType; var columnType = (isKey) ? PrimitiveType.FromKind(DataKind.U4) : Transformer.ValueColumnType; + var metadataShape = SchemaShape.Create(Transformer.ValueColumnMetadata.Schema); foreach (var (Input, Output) in _columns) { if (!inputSchema.TryFindColumn(Input, out var originalColumn)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Input); - // Get the type from TOutputType - var col = new SchemaShape.Column(Output, vectorKind, columnType, isKey, originalColumn.Metadata); + // Create the Value column + var col = new SchemaShape.Column(Output, vectorKind, columnType, isKey, metadataShape); resultDic[Output] = col; } return new SchemaShape(resultDic.Values); @@ -191,18 +192,14 @@ internal static IDataView CreateDataView(IHostEnvironment env, // set of values. This is used for generating the metadata of // the column. HashSet valueSet = new HashSet(); - HashSet keySet = new HashSet(); - for (int i = 0; i < values.Count(); ++i) + foreach (var v in values) { - var v = values.ElementAt(i); if (valueSet.Contains(v)) continue; valueSet.Add(v); - - var k = keys.ElementAt(i); - keySet.Add(k); } - var metaKeys = keySet.ToArray(); + + var metaKeys = valueSet.ToArray(); // Key Values are treated in one of two ways: // If the values are of type uint or ulong, these values are used directly as the keys types and no new keys are created. @@ -387,7 +384,7 @@ protected ValueMappingTransformer(IHostEnvironment env, IDataView lookupMap, Host.CheckNonEmpty(valueColumn, nameof(valueColumn), "A value column must be specified when passing in an IDataView for the value mapping"); _valueMap = CreateValueMapFromDataView(lookupMap, keyColumn, valueColumn); int valueColumnIdx = 0; - Host.Assert(lookupMap.Schema.TryGetColumnIndex(valueColumn, out valueColumnIdx)); + Host.Check(lookupMap.Schema.TryGetColumnIndex(valueColumn, out valueColumnIdx)); _valueMetadata = lookupMap.Schema[valueColumnIdx].Metadata; // Create the byte array of the original IDataView, this is used for saving out the data. diff --git a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs index 22cbd47265..d2f644270e 100644 --- a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs @@ -39,7 +39,7 @@ class TestWrong public class TestTermLookup { public string Label; - public int GroupId; + public int GroupId; [VectorType(2107)] public float[] Features; @@ -52,17 +52,17 @@ public void ValueMapOneValueTest() var data = new[] { new TestClass() { A = "bar", B = "test", C = "foo" } }; var dataView = ComponentCreation.CreateDataView(Env, data); - IEnumerable> keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; - IEnumerable values = new List() { 1, 2, 3, 4 }; + var keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; + var values = new List() { 1, 2, 3, 4 }; var estimator = new ValueMappingEstimator, int>(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); var t = estimator.Fit(dataView); var result = t.Transform(dataView); var cursor = result.GetRowCursor((col) => true); - var getterD = cursor.GetGetter(3); - var getterE = cursor.GetGetter(4); - var getterF = cursor.GetGetter(5); + var getterD = cursor.GetGetter(result.Schema["D"].Index); + var getterE = cursor.GetGetter(result.Schema["E"].Index); + var getterF = cursor.GetGetter(result.Schema["F"].Index); cursor.MoveNext(); int dValue = 0; @@ -93,9 +93,9 @@ public void ValueMapVectorValueTest() var result = t.Transform(dataView); var cursor = result.GetRowCursor((col) => true); - var getterD = cursor.GetGetter>(3); - var getterE = cursor.GetGetter>(4); - var getterF = cursor.GetGetter>(5); + var getterD = cursor.GetGetter>(result.Schema["D"].Index); + var getterE = cursor.GetGetter>(result.Schema["E"].Index); + var getterF = cursor.GetGetter>(result.Schema["F"].Index); cursor.MoveNext(); var valuesArray = values.ToArray(); @@ -116,17 +116,17 @@ public void ValueMappingMissingKey() var data = new[] { new TestClass() { A = "barTest", B = "test", C = "foo" } }; var dataView = ComponentCreation.CreateDataView(Env, data); - IEnumerable> keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; - IEnumerable values = new List() { 1, 2, 3, 4 }; + var keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; + var values = new List() { 1, 2, 3, 4 }; var estimator = new ValueMappingEstimator, int>(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); var t = estimator.Fit(dataView); var result = t.Transform(dataView); var cursor = result.GetRowCursor((col) => true); - var getterD = cursor.GetGetter(3); - var getterE = cursor.GetGetter(4); - var getterF = cursor.GetGetter(5); + var getterD = cursor.GetGetter(result.Schema["D"].Index); + var getterE = cursor.GetGetter(result.Schema["E"].Index); + var getterF = cursor.GetGetter(result.Schema["F"].Index); cursor.MoveNext(); int dValue = 1; @@ -146,8 +146,8 @@ void TestDuplicateKeys() var data = new[] { new TestClass() { A = "barTest", B = "test", C = "foo" } }; var dataView = ComponentCreation.CreateDataView(Env, data); - IEnumerable> keys = new List>() { "foo".AsMemory(), "foo".AsMemory() }; - IEnumerable values = new List() { 1, 2 }; + var keys = new List>() { "foo".AsMemory(), "foo".AsMemory() }; + var values = new List() { 1, 2 }; Assert.Throws(() => new ValueMappingEstimator, int>(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") })); } @@ -158,11 +158,11 @@ public void ValueMappingOutputSchema() var data = new[] { new TestClass() { A = "barTest", B = "test", C = "foo" } }; var dataView = ComponentCreation.CreateDataView(Env, data); - IEnumerable> keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; - IEnumerable values = new List() { 1, 2, 3, 4 }; + var keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; + var values = new List() { 1, 2, 3, 4 }; var estimator = new ValueMappingEstimator, int>(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); - var outputSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema)); + var outputSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema)); Assert.Equal(6, outputSchema.Count()); Assert.True(outputSchema.TryFindColumn("D", out SchemaShape.Column dColumn)); Assert.True(outputSchema.TryFindColumn("E", out SchemaShape.Column eColumn)); @@ -173,7 +173,7 @@ public void ValueMappingOutputSchema() Assert.Equal(typeof(int), eColumn.ItemType.RawType); Assert.False(eColumn.IsKey); - + Assert.Equal(typeof(int), fColumn.ItemType.RawType); Assert.False(fColumn.IsKey); } @@ -184,11 +184,11 @@ public void ValueMappingWithValuesAsKeyTypesOutputSchema() var data = new[] { new TestClass() { A = "bar", B = "test", C = "foo" } }; var dataView = ComponentCreation.CreateDataView(Env, data); - IEnumerable> keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; - IEnumerable> values = new List>() { "t".AsMemory(), "s".AsMemory(), "u".AsMemory(), "v".AsMemory() }; + var keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; + var values = new List>() { "t".AsMemory(), "s".AsMemory(), "u".AsMemory(), "v".AsMemory() }; var estimator = new ValueMappingEstimator, ReadOnlyMemory>(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); - var outputSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema)); + var outputSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema)); Assert.Equal(6, outputSchema.Count()); Assert.True(outputSchema.TryFindColumn("D", out SchemaShape.Column dColumn)); Assert.True(outputSchema.TryFindColumn("E", out SchemaShape.Column eColumn)); @@ -199,7 +199,7 @@ public void ValueMappingWithValuesAsKeyTypesOutputSchema() Assert.Equal(typeof(uint), eColumn.ItemType.RawType); Assert.True(eColumn.IsKey); - + Assert.Equal(typeof(uint), fColumn.ItemType.RawType); Assert.True(fColumn.IsKey); @@ -212,10 +212,10 @@ public void ValueMappingValuesAsUintKeyTypes() var data = new[] { new TestClass() { A = "bar", B = "test2", C = "wahoo" } }; var dataView = ComponentCreation.CreateDataView(Env, data); - IEnumerable> keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; + var keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; // These are the expected key type values - IEnumerable values = new List() { 51, 25, 42, 61 }; + var values = new List() { 51, 25, 42, 61 }; var estimator = new ValueMappingEstimator, uint>(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); @@ -223,11 +223,11 @@ public void ValueMappingValuesAsUintKeyTypes() var result = t.Transform(dataView); var cursor = result.GetRowCursor((col) => true); - var getterD = cursor.GetGetter(3); - var getterE = cursor.GetGetter(4); - var getterF = cursor.GetGetter(5); + var getterD = cursor.GetGetter(result.Schema["D"].Index); + var getterE = cursor.GetGetter(result.Schema["E"].Index); + var getterF = cursor.GetGetter(result.Schema["F"].Index); cursor.MoveNext(); - + // The expected values will contain the actual uints and are not generated. uint dValue = 1; getterD(ref dValue); @@ -251,10 +251,10 @@ public void ValueMappingValuesAsUlongKeyTypes() var data = new[] { new TestClass() { A = "bar", B = "test2", C = "wahoo" } }; var dataView = ComponentCreation.CreateDataView(Env, data); - IEnumerable> keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; + var keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; // These are the expected key type values - IEnumerable values = new List() { 51, Int32.MaxValue, 42, 61 }; + var values = new List() { 51, Int32.MaxValue, 42, 61 }; var estimator = new ValueMappingEstimator, ulong>(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); @@ -262,11 +262,11 @@ public void ValueMappingValuesAsUlongKeyTypes() var result = t.Transform(dataView); var cursor = result.GetRowCursor((col) => true); - var getterD = cursor.GetGetter(3); - var getterE = cursor.GetGetter(4); - var getterF = cursor.GetGetter(5); + var getterD = cursor.GetGetter(result.Schema["D"].Index); + var getterE = cursor.GetGetter(result.Schema["E"].Index); + var getterF = cursor.GetGetter(result.Schema["F"].Index); cursor.MoveNext(); - + // The expected values will contain the actual uints and are not generated. ulong dValue = 1; getterD(ref dValue); @@ -289,19 +289,19 @@ public void ValueMappingValuesAsStringKeyTypes() var data = new[] { new TestClass() { A = "bar", B = "test", C = "notfound" } }; var dataView = ComponentCreation.CreateDataView(Env, data); - IEnumerable> keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; + var keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; // Generating the list of strings for the key type values, note that foo1 is duplicated as intended to test that the same index value is returned - IEnumerable> values = new List>() { "foo1".AsMemory(), "foo2".AsMemory(), "foo1".AsMemory(), "foo3".AsMemory() }; + var values = new List>() { "foo1".AsMemory(), "foo2".AsMemory(), "foo1".AsMemory(), "foo3".AsMemory() }; var estimator = new ValueMappingEstimator, ReadOnlyMemory>(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); var t = estimator.Fit(dataView); var result = t.Transform(dataView); var cursor = result.GetRowCursor((col) => true); - var getterD = cursor.GetGetter(3); - var getterE = cursor.GetGetter(4); - var getterF = cursor.GetGetter(5); + var getterD = cursor.GetGetter(result.Schema["D"].Index); + var getterE = cursor.GetGetter(result.Schema["E"].Index); + var getterF = cursor.GetGetter(result.Schema["F"].Index); cursor.MoveNext(); // The expected values will contain the generated key type values starting from 1. @@ -320,6 +320,32 @@ public void ValueMappingValuesAsStringKeyTypes() Assert.Equal(0, fValue); } + [Fact] + public void ValueMappingValuesAsKeyTypesReverseLookup() + { + var data = new[] { new TestClass() { A = "bar", B = "test", C = "notfound" } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + + var keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; + + // Generating the list of strings for the key type values, note that foo1 is duplicated as intended to test that the same index value is returned + var values = new List>() { "foo1".AsMemory(), "foo2".AsMemory(), "foo1".AsMemory(), "foo3".AsMemory() }; + + var estimator = new ValueMappingEstimator, ReadOnlyMemory>(Env, keys, values, true, new[] { ("A", "D") }) + .Append(new KeyToValueMappingEstimator(Env, ("D", "DOutput"))); + var t = estimator.Fit(dataView); + + var result = t.Transform(dataView); + var cursor = result.GetRowCursor((col) => true); + var getterD = cursor.GetGetter>(result.Schema["DOutput"].Index); + cursor.MoveNext(); + + // The expected values will contain the generated key type values starting from 1. + ReadOnlyMemory dValue = default; + getterD(ref dValue); + Assert.Equal("foo2".AsMemory(), dValue); + } + [Fact] public void ValueMappingWorkout() { @@ -328,8 +354,8 @@ public void ValueMappingWorkout() var badData = new[] { new TestWrong() { A = "bar", B = 1.2f } }; var badDataView = ComponentCreation.CreateDataView(Env, badData); - IEnumerable> keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; - IEnumerable values = new List() { 1, 2, 3, 4 }; + var keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; + var values = new List() { 1, 2, 3, 4 }; // Workout on value mapping var est = ML.Transforms.Conversion.ValueMap(keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); @@ -362,7 +388,7 @@ void TestCommandLineNoLoaderWithColumnNames() + dataFile + @" col=A:B keyCol=foo valueCol=bar} in=f:\1.txt" }), (int)0); } - + [Fact] void TestCommandLineNoLoaderWithoutTreatValuesAsKeys() { @@ -377,10 +403,10 @@ void TestSavingAndLoading() { var data = new[] { new TestClass() { A = "bar", B = "foo", C = "test", } }; var dataView = ComponentCreation.CreateDataView(Env, data); - var est = new ValueMappingEstimator, int>(Env, - new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory() }, - new List() { 2, 43, 56 }, - new [] {("A","D"), ("B", "E")}); + var est = new ValueMappingEstimator, int>(Env, + new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory() }, + new List() { 2, 43, 56 }, + new[] { ("A", "D"), ("B", "E") }); var transformer = est.Fit(dataView); using (var ms = new MemoryStream()) { @@ -400,7 +426,7 @@ void TestValueMapBackCompatTermLookup() { // Model generated with: xf=drop{col=A} // Expected output: Features Label B C - var data = new[] { new TestTermLookup() { Label = "good", GroupId=1 } }; + var data = new[] { new TestTermLookup() { Label = "good", GroupId = 1 } }; var dataView = ComponentCreation.CreateDataView(Env, data); string termLookupModelPath = GetDataPath("backcompat/termlookup.zip"); using (FileStream fs = File.OpenRead(termLookupModelPath)) @@ -417,7 +443,7 @@ void TestValueMapBackCompatTermLookupKeyTypeValue() { // Model generated with: xf=drop{col=A} // Expected output: Features Label B C - var data = new[] { new TestTermLookup() { Label = "Good", GroupId=1 } }; + var data = new[] { new TestTermLookup() { Label = "Good", GroupId = 1 } }; var dataView = ComponentCreation.CreateDataView(Env, data); string termLookupModelPath = GetDataPath("backcompat/termlookup_with_key.zip"); using (FileStream fs = File.OpenRead(termLookupModelPath)) @@ -426,7 +452,7 @@ void TestValueMapBackCompatTermLookupKeyTypeValue() Assert.True(result.Schema.TryGetColumnIndex("Features", out int featureIdx)); Assert.True(result.Schema.TryGetColumnIndex("Label", out int labelIdx)); Assert.True(result.Schema.TryGetColumnIndex("GroupId", out int groupIdx)); - + Assert.True(result.Schema[labelIdx].Type is KeyType); Assert.Equal(5, result.Schema[labelIdx].Type.ItemType.GetKeyCount());