diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 9923cb7635..5f16d116a0 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -273,6 +273,14 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.AutoML.Samples EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Samples.GPU", "docs\samples\Microsoft.ML.Samples.GPU\Microsoft.ML.Samples.GPU.csproj", "{3C8F910B-7F23-4D25-B521-6D5AC9570ADD}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Featurizers", "src\Microsoft.ML.Featurizers\Microsoft.ML.Featurizers.csproj", "{E2DD0721-5B0F-4606-8182-4C7EFB834518}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Featurizers", "Microsoft.ML.Featurizers", "{1BA5C784-52E8-4A87-8525-26B2452F2882}" + ProjectSection(SolutionItems) = preProject + pkg\Microsoft.ML.Featurizers\Microsoft.ML.Featurizers.nupkgproj = pkg\Microsoft.ML.Featurizers\Microsoft.ML.Featurizers.nupkgproj + pkg\Microsoft.ML.Featurizers\Microsoft.ML.Featurizers.symbols.nupkgproj = pkg\Microsoft.ML.Featurizers\Microsoft.ML.Featurizers.symbols.nupkgproj + EndProjectSection +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeGenerator", "src\Microsoft.ML.CodeGenerator\Microsoft.ML.CodeGenerator.csproj", "{56CB0850-7341-4D71-9AE4-9EFC472D93DD}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeGenerator.Tests", "test\Microsoft.ML.CodeGenerator.Tests\Microsoft.ML.CodeGenerator.Tests.csproj", "{46CC5637-3DDF-4100-93FC-44BB87B2DB81}" @@ -1690,6 +1698,30 @@ Global {46CC5637-3DDF-4100-93FC-44BB87B2DB81}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU {46CC5637-3DDF-4100-93FC-44BB87B2DB81}.Release-netfx|x64.ActiveCfg = Release-netfx|Any CPU {46CC5637-3DDF-4100-93FC-44BB87B2DB81}.Release-netfx|x64.Build.0 = Release-netfx|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + 
{E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug|x64.ActiveCfg = Debug|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug|x64.Build.0 = Debug|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netcoreapp3_0|Any CPU.ActiveCfg = Debug-netcoreapp3_0|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netcoreapp3_0|Any CPU.Build.0 = Debug-netcoreapp3_0|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netcoreapp3_0|x64.ActiveCfg = Debug-netcoreapp3_0|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netcoreapp3_0|x64.Build.0 = Debug-netcoreapp3_0|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netfx|x64.ActiveCfg = Debug-netfx|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netfx|x64.Build.0 = Debug-netfx|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release|Any CPU.Build.0 = Release|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release|x64.ActiveCfg = Release|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release|x64.Build.0 = Release|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netcoreapp3_0|Any CPU.ActiveCfg = Release-netcoreapp3_0|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netcoreapp3_0|Any CPU.Build.0 = Release-netcoreapp3_0|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netcoreapp3_0|x64.ActiveCfg = Release-netcoreapp3_0|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netcoreapp3_0|x64.Build.0 = Release-netcoreapp3_0|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU + 
{E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netfx|x64.ActiveCfg = Release-netfx|Any CPU + {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netfx|x64.Build.0 = Release-netfx|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1779,6 +1811,8 @@ Global {56CB0850-7341-4D71-9AE4-9EFC472D93DD} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {46CC5637-3DDF-4100-93FC-44BB87B2DB81} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {3817A875-278C-4140-BF66-3C4A8CA55F0D} = {D3D38B03-B557-484D-8348-8BADEE4DF592} + {E2DD0721-5B0F-4606-8182-4C7EFB834518} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {1BA5C784-52E8-4A87-8525-26B2452F2882} = {D3D38B03-B557-484D-8348-8BADEE4DF592} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/build/BranchInfo.props b/build/BranchInfo.props index e87b2897a4..ce3d93c486 100644 --- a/build/BranchInfo.props +++ b/build/BranchInfo.props @@ -30,12 +30,12 @@ 1 4 0 - preview3 + preview2 0 16 0 - preview3 + preview2 diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/CategoryImputer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/CategoryImputer.cs new file mode 100644 index 0000000000..f89ddc6cfb --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/CategoryImputer.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class CategoryImputer + { + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for + // exception tracking and logging, as well as the source of randomness. + var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. 
+ var samples = new List() + { + new InputData(){ Feature1 = 1f }, + + new InputData(){ Feature1 = float.NaN }, + + new InputData(){ Feature1 = 1f }, + + new InputData(){ Feature1 = float.NaN }, + + new InputData(){ Feature1 = 9f }, + }; + + // Convert training data to IDataView. + var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for filling in the missing values in the feature1 column + var pipeline = mlContext.Transforms.CatagoryImputerTransformer("Feature1"); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. The NaN values should be filled in with the most frequent value, 1. + // We can extract the newly created columns as an IEnumerable of TransformedData. + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Feature1); + + // Expected output: + // Features column obtained post-transformation. 
+ // 1 + // 1 + // 1 + // 1 + // 9 + } + + private class InputData + { + public float Feature1; + } + + private sealed class TransformedData + { + public float Feature1 { get; set; } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs new file mode 100644 index 0000000000..9bf851b0ba --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class DateTimeTransformer + { + private class DateTimeInput + { + public long Date; + } + + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for + // exception tracking and logging, as well as the source of randomness. + var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. + // Future Date - 2025 June 30 + var samples = new[] { new DateTimeInput() { Date = 1751241600 } }; + + // Convert training data to IDataView. + var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for splitting the time features into individual columns + var pipeline = mlContext.Transforms.DateTimeTransformer("Date", "DTC"); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. 
We should have created 21 more columns with all the + // DateTime information split into its own columns + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Date + ", " + featureRow.DTCYear + ", " + featureRow.DTCMonth + ", " + + featureRow.DTCDay + ", " + featureRow.DTCHour + ", " + featureRow.DTCMinute + ", " + + featureRow.DTCSecond + ", " + featureRow.DTCAmPm + ", " + featureRow.DTCHour12 + ", " + + featureRow.DTCDayOfWeek + ", " + featureRow.DTCDayOfQuarter + ", " + featureRow.DTCDayOfYear + + ", " + featureRow.DTCWeekOfMonth + ", " + featureRow.DTCQuarterOfYear + ", " + featureRow.DTCHalfOfYear + + ", " + featureRow.DTCWeekIso + ", " + featureRow.DTCYearIso + ", " + featureRow.DTCMonthLabel + ", " + + featureRow.DTCAmPmLabel + ", " + featureRow.DTCDayOfWeekLabel + ", " + featureRow.DTCHolidayName + ", " + + featureRow.DTCIsPaidTimeOff); + + // Expected output: + // Features columns obtained post-transformation. 
+ // 1751241600, 2025, 6, 30, 0, 0, 0, 0, 0, 1, 91, 180, 4, 2, 1, 27, 2025, June, am, Monday, , 0 + } + + // These columns start with DTC because that is the prefix we picked + private sealed class TransformedData + { + public long Date { get; set; } + public int DTCYear { get; set; } + public byte DTCMonth { get; set; } + public byte DTCDay { get; set; } + public byte DTCHour { get; set; } + public byte DTCMinute { get; set; } + public byte DTCSecond { get; set; } + public byte DTCAmPm { get; set; } + public byte DTCHour12 { get; set; } + public byte DTCDayOfWeek { get; set; } + public byte DTCDayOfQuarter { get; set; } + public ushort DTCDayOfYear { get; set; } + public ushort DTCWeekOfMonth { get; set; } + public byte DTCQuarterOfYear { get; set; } + public byte DTCHalfOfYear { get; set; } + public byte DTCWeekIso { get; set; } + public int DTCYearIso { get; set; } + public string DTCMonthLabel { get; set; } + public string DTCAmPmLabel { get; set; } + public string DTCDayOfWeekLabel { get; set; } + public string DTCHolidayName { get; set; } + public byte DTCIsPaidTimeOff { get; set; } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs new file mode 100644 index 0000000000..680b338056 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs @@ -0,0 +1,80 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class DateTimeTransformerDropColumns + { + private class DateTimeInput + { + public long Date; + } + + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for + // exception tracking and logging, as well as the source of randomness. 
+ var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. + // Future Date - 2025 June 30 + var samples = new[] { new DateTimeInput() { Date = 1751241600 } }; + + // Convert training data to IDataView. + var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for splitting the time features into individual columns + // All the columns listed here will be dropped. + var pipeline = mlContext.Transforms.DateTimeTransformer("Date", "DTC", DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff, + DateTimeTransformerEstimator.ColumnsProduced.Day, DateTimeTransformerEstimator.ColumnsProduced.QuarterOfYear, + DateTimeTransformerEstimator.ColumnsProduced.AmPm, DateTimeTransformerEstimator.ColumnsProduced.HolidayName); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. We should have created 16 more columns with all the + // DateTime information split into its own columns + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Date + ", " + featureRow.DTCYear + ", " + featureRow.DTCMonth + ", " + + featureRow.DTCHour + ", " + featureRow.DTCMinute + ", " + featureRow.DTCSecond + ", " + + featureRow.DTCHour12 + ", " + featureRow.DTCDayOfWeek + ", " + featureRow.DTCDayOfQuarter + ", " + + featureRow.DTCDayOfYear + ", " + featureRow.DTCWeekOfMonth + ", " + featureRow.DTCHalfOfYear + + ", " + featureRow.DTCWeekIso + ", " + featureRow.DTCYearIso + ", " + featureRow.DTCMonthLabel + ", " + + featureRow.DTCAmPmLabel + ", " + featureRow.DTCDayOfWeekLabel); + + // Expected output: + // Features columns obtained post-transformation. 
+ // 1751241600, 2025, 6, 0, 0, 0, 0, 1, 91, 180, 4, 1, 27, 2025, June, am, Monday + } + + // These columns start with DTC because that is the prefix we picked + private sealed class TransformedData + { + public long Date { get; set; } + public int DTCYear { get; set; } + public byte DTCMonth { get; set; } + public byte DTCHour { get; set; } + public byte DTCMinute { get; set; } + public byte DTCSecond { get; set; } + public byte DTCHour12 { get; set; } + public byte DTCDayOfWeek { get; set; } + public byte DTCDayOfQuarter { get; set; } + public ushort DTCDayOfYear { get; set; } + public ushort DTCWeekOfMonth { get; set; } + public byte DTCHalfOfYear { get; set; } + public byte DTCWeekIso { get; set; } + public int DTCYearIso { get; set; } + public string DTCMonthLabel { get; set; } + public string DTCAmPmLabel { get; set; } + public string DTCDayOfWeekLabel { get; set; } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScaler.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScaler.cs new file mode 100644 index 0000000000..d46449447f --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScaler.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class RobustScaler + { + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for + // exception tracking and logging, as well as the source of randomness. + var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. + var samples = new List() + { + new InputData(){ Feature1 = 1f }, + + new InputData(){ Feature1 = 3f }, + + new InputData(){ Feature1 = 5f }, + + new InputData(){ Feature1 = 7f }, + + new InputData(){ Feature1 = 9f }, + }; + + // Convert training data to IDataView. 
+ var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for centering and scaling the feature1 column + var pipeline = mlContext.Transforms.RobustScalerTransformer("Feature1"); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. The values should be centered around 0 and scaled. + // We can extract the newly created columns as an IEnumerable of TransformedData. + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Feature1); + + // Expected output: + // Features column obtained post-transformation. + // -1 + // -0.5 + // 0 + // 0.5 + // 1 + } + + private class InputData + { + public float Feature1; + } + + private sealed class TransformedData + { + public float Feature1 { get; set; } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithCenter.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithCenter.cs new file mode 100644 index 0000000000..7a868ba7ee --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithCenter.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class RobustScalerWithCenter + { + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for + // exception tracking and logging, as well as the source of randomness. + var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. 
+ var samples = new List() + { + new InputData(){ Feature1 = 1f }, + + new InputData(){ Feature1 = 3f }, + + new InputData(){ Feature1 = 5f }, + + new InputData(){ Feature1 = 7f }, + + new InputData(){ Feature1 = 9f }, + }; + + // Convert training data to IDataView. + var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for Centering the feature1 column + var pipeline = mlContext.Transforms.RobustScalerTransformer("Feature1", scale: false); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. The values should be centered around 0. + // We can extract the newly created columns as an IEnumerable of TransformedData. + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Feature1); + + // Expected output: + // Features column obtained post-transformation. + // -4 + // -2 + // 0 + // 2 + // 4 + } + + private class InputData + { + public float Feature1; + } + + private sealed class TransformedData + { + public float Feature1 { get; set; } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithScale.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithScale.cs new file mode 100644 index 0000000000..4318a67e9f --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithScale.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class RobustScalerWithScale + { + public static void Example() + { + // Create a new ML context, for ML.NET operations. 
It can be used for + // exception tracking and logging, as well as the source of randomness. + var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. + var samples = new List() + { + new InputData(){ Feature1 = 1f }, + + new InputData(){ Feature1 = 3f }, + + new InputData(){ Feature1 = 5f }, + + new InputData(){ Feature1 = 7f }, + + new InputData(){ Feature1 = 9f }, + }; + + // Convert training data to IDataView. + var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for scaling the feature1 column + var pipeline = mlContext.Transforms.RobustScalerTransformer("Feature1", center: false); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. The values should be scaled by the range * ratio. + // We can extract the newly created columns as an IEnumerable of TransformedData. + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Feature1); + + // Expected output: + // Features column obtained post-transformation. 
+ // 0.25 + // 0.75 + // 1.25 + // 1.75 + // 2.25 + } + + private class InputData + { + public float Feature1; + } + + private sealed class TransformedData + { + public float Feature1 { get; set; } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerBackFill.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerBackFill.cs new file mode 100644 index 0000000000..50408cc4e4 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerBackFill.cs @@ -0,0 +1,71 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class TimeSeriesImputerBackFill + { + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for + // exception tracking and logging, as well as the source of randomness. + var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. + var samples = new[] { new InputData() { Date = 0, GrainA = "A", DataA = float.NaN }, + new InputData() { Date = 1, GrainA = "A", DataA = float.NaN }, + new InputData() { Date = 3, GrainA = "A", DataA = 5.0f }, + new InputData() { Date = 5, GrainA = "A", DataA = float.NaN }, + new InputData() { Date = 7, GrainA = "A", DataA = 2.0f }}; + + // Convert training data to IDataView. + var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for imputing the missing rows and values in the columns using the "BackFill" strategy. + var pipeline = mlContext.Transforms.TimeSeriesImputer("Date", new string[] { "GrainA" }, TimeSeriesImputerEstimator.ImputationStrategy.BackFill); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. 
The NaN values should be filled in with the next value that was not NaN, + // and rows should be created to fill in the missing gaps in the time column. + // We can extract the newly created columns as an IEnumerable of TransformedData. + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Date + ", " + featureRow.GrainA + ", " + featureRow.DataA + ", " + featureRow.IsRowImputed); + + // Expected output: + // Features column obtained post-transformation. + // 0, A, 5, False + // 1, A, 5, False + // 2, A, 5, True + // 3, A, 5, False + // 4, A, 2, True + // 5, A, 2, False + // 6, A, 2, True + // 7, A, 2, False + } + + private class InputData + { + public long Date; + public string GrainA; + public float DataA; + } + + private sealed class TransformedData + { + public long Date { get; set; } + public string GrainA { get; set; } + public float DataA { get; set; } + public bool IsRowImputed { get; set; } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerForwardFill.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerForwardFill.cs new file mode 100644 index 0000000000..6a22b98b4d --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerForwardFill.cs @@ -0,0 +1,71 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class TimeSeriesImputerForwardFill + { + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for + // exception tracking and logging, as well as the source of randomness. 
+ var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. + var samples = new[] { new InputData() { Date = 0, GrainA = "A", DataA = 2.0f }, + new InputData() { Date = 1, GrainA = "A", DataA = float.NaN }, + new InputData() { Date = 3, GrainA = "A", DataA = 5.0f }, + new InputData() { Date = 5, GrainA = "A", DataA = float.NaN }, + new InputData() { Date = 7, GrainA = "A", DataA = float.NaN }}; + + // Convert training data to IDataView. + var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for imputing the missing rows and values in the columns using the default "ForwardFill" strategy. + var pipeline = mlContext.Transforms.TimeSeriesImputer("Date", new string[] { "GrainA" }); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. The NaN values should be filled in with last value that was not NaN, + // and rows should be created to fill in the missing gaps in the time column. + // We can extract the newly created columns as an IEnumerable of TransformedData. + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Date + ", " + featureRow.GrainA + ", " + featureRow.DataA + ", " + featureRow.IsRowImputed); + + // Expected output: + // Features column obtained post-transformation. 
+ // 0, A, 2, False + // 1, A, 2, False + // 2, A, 2, True + // 3, A, 5, False + // 4, A, 5, True + // 5, A, 5, False + // 6, A, 5, True + // 7, A, 5, False + } + + private class InputData + { + public long Date; + public string GrainA; + public float DataA; + } + + private sealed class TransformedData + { + public long Date { get; set; } + public string GrainA { get; set; } + public float DataA { get; set; } + public bool IsRowImputed { get; set; } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerMedian.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerMedian.cs new file mode 100644 index 0000000000..dc310ba285 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerMedian.cs @@ -0,0 +1,71 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class TimeSeriesImputerMedian + { + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for + // exception tracking and logging, as well as the source of randomness. + var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. + var samples = new[] { new InputData() { Date = 0, GrainA = "A", DataA = 2.0f }, + new InputData() { Date = 1, GrainA = "A", DataA = float.NaN }, + new InputData() { Date = 3, GrainA = "A", DataA = 5.0f }, + new InputData() { Date = 5, GrainA = "A", DataA = float.NaN }, + new InputData() { Date = 7, GrainA = "A", DataA = float.NaN }}; + + // Convert training data to IDataView. + var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for imputing the missing rows and values in the columns using the "Median" strategy. 
+ var pipeline = mlContext.Transforms.TimeSeriesImputer("Date", new string[] { "GrainA" }, TimeSeriesImputerEstimator.ImputationStrategy.Median); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. The NaN values should be filled in with the column median, + // and rows should be created to fill in the missing gaps in the time column. + // We can extract the newly created columns as an IEnumerable of TransformedData. + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Date + ", " + featureRow.GrainA + ", " + featureRow.DataA + ", " + featureRow.IsRowImputed); + + // Expected output: + // Features column obtained post-transformation. + // 0, A, 2, False + // 1, A, 3.5, False + // 2, A, 3.5, True + // 3, A, 5, False + // 4, A, 3.5, True + // 5, A, 3.5, False + // 6, A, 3.5, True + // 7, A, 3.5, False + } + + private class InputData + { + public long Date; + public string GrainA; + public float DataA; + } + + private sealed class TransformedData + { + public long Date { get; set; } + public string GrainA { get; set; } + public float DataA { get; set; } + public bool IsRowImputed { get; set; } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerMultipleColumns.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerMultipleColumns.cs new file mode 100644 index 0000000000..d2b8016647 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerMultipleColumns.cs @@ -0,0 +1,80 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace
Samples.Dynamic
+{
+    public static class ToStringMultipleColumns
+    {
+        public static void Example()
+        {
+            // Create a new ML context, for ML.NET operations. It can be used for
+            // exception tracking and logging, as well as the source of randomness.
+            var mlContext = new MLContext();
+
+            // Create a small dataset as an IEnumerable.
+            var samples = new List<InputData>()
+            {
+                new InputData(){ Feature1 = 0.1f, Feature2 = 1.1, Feature3 = 1 },
+
+                new InputData(){ Feature1 = 0.2f, Feature2 = 1.2, Feature3 = 2 },
+
+                new InputData(){ Feature1 = 0.3f, Feature2 = 1.3, Feature3 = 3 },
+
+                new InputData(){ Feature1 = 0.4f, Feature2 = 1.4, Feature3 = 4 },
+
+                new InputData(){ Feature1 = 0.5f, Feature2 = 1.5, Feature3 = 5 },
+
+                new InputData(){ Feature1 = 0.6f, Feature2 = 1.6, Feature3 = 6 },
+            };
+
+            // Convert training data to IDataView.
+            var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+            // A pipeline for converting the "Feature1", "Feature2" and
+            // "Feature3" columns into their string representations.
+            // Each pair names a single column, which is read back below as a string.
+            var pipeline = mlContext.Transforms.ToStringTransformer(new InputOutputColumnPair("Feature1"),
+                new InputOutputColumnPair("Feature2"), new InputOutputColumnPair("Feature3"));
+
+            // The transformed data.
+            var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+            // Now let's take a look at what this did.
+            // We can extract the newly created columns as an IEnumerable of
+            // TransformedData.
+            var featuresColumn = mlContext.Data.CreateEnumerable<TransformedData>(
+                transformedData, reuseRowObject: false);
+
+            // And we can write out a few rows
+            Console.WriteLine("Features column obtained post-transformation.");
+            foreach (var featureRow in featuresColumn)
+                Console.WriteLine(featureRow.Feature1 + " " + featureRow.Feature2 + " " + featureRow.Feature3);
+
+            // Expected output:
+            //  Features column obtained post-transformation.
+            //  0.100000 1.100000 1
+            //  0.200000 1.200000 2
+            //  0.300000 1.300000 3
+            //  0.400000 1.400000 4
+            //  0.500000 1.500000 5
+            //  0.600000 1.600000 6
+        }
+
+        private class InputData
+        {
+            public float Feature1;
+            public double Feature2;
+            public int Feature3;
+        }
+
+        private sealed class TransformedData
+        {
+            public string Feature1 { get; set; }
+            public string Feature2 { get; set; }
+            public string Feature3 { get; set; }
+        }
+    }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerSingleColumn.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerSingleColumn.cs
new file mode 100644
index 0000000000..1c11931d90
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerSingleColumn.cs
@@ -0,0 +1,79 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+    public static class ToStringSingleColumn
+    {
+        public static void Example()
+        {
+            // Create a new ML context, for ML.NET operations. It can be used for
+            // exception tracking and logging, as well as the source of randomness.
+            var mlContext = new MLContext();
+
+            // Create a small dataset as an IEnumerable.
+            var samples = new List<InputData>()
+            {
+                new InputData(){ Feature1 = 0.1f, Feature2 = 1.1, Feature3 = 1 },
+
+                new InputData(){ Feature1 = 0.2f, Feature2 = 1.2, Feature3 = 2 },
+
+                new InputData(){ Feature1 = 0.3f, Feature2 = 1.3, Feature3 = 3 },
+
+                new InputData(){ Feature1 = 0.4f, Feature2 = 1.4, Feature3 = 4 },
+
+                new InputData(){ Feature1 = 0.5f, Feature2 = 1.5, Feature3 = 5 },
+
+                new InputData(){ Feature1 = 0.6f, Feature2 = 1.6, Feature3 = 6 },
+            };
+
+            // Convert training data to IDataView.
+            var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+            // A pipeline for converting the "Feature1" column into its string representation,
+            // written to a new "Feature1Output" column.
+            var pipeline = mlContext.Transforms.ToStringTransformer("Feature1Output", "Feature1");
+
+            // The transformed data.
+            var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+            // Now let's take a look at what this did.
+            // We can extract the newly created columns as an IEnumerable of
+            // TransformedData.
+            var featuresColumn = mlContext.Data.CreateEnumerable<TransformedData>(
+                transformedData, reuseRowObject: false);
+
+            // And we can write out a few rows
+            Console.WriteLine("Features column obtained post-transformation.");
+            foreach (var featureRow in featuresColumn)
+                Console.WriteLine(featureRow.Feature1Output);
+
+            // Expected output:
+            //  Features column obtained post-transformation.
+            //  0.100000
+            //  0.200000
+            //  0.300000
+            //  0.400000
+            //  0.500000
+            //  0.600000
+        }
+
+        private class InputData
+        {
+            public float Feature1;
+            public double Feature2;
+            public int Feature3;
+        }
+
+        private sealed class TransformedData
+        {
+            public float Feature1 { get; set; }
+            public double Feature2 { get; set; }
+            public int Feature3 { get; set; }
+            public string Feature1Output { get; set; }
+        }
+    }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj
index 71b9436701..8aec0bb5a0 100644
--- a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj
+++ b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj
@@ -11,6 +11,7 @@
+
diff --git a/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.nupkgproj b/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.nupkgproj
new file mode 100644
index 0000000000..44f4babb81
--- /dev/null
+++ b/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.nupkgproj
@@ -0,0 +1,14 @@
+
+
+
+ netstandard2.0
+ ML.NET additional featurizers for AutoML
+
+
+
+
+
+
+
+
+
diff --git
a/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.symbols.nupkgproj b/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.symbols.nupkgproj
new file mode 100644
index 0000000000..483e51c61a
--- /dev/null
+++ b/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.symbols.nupkgproj
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
index 4bf0dfe445..2d7ae0c6b6 100644
--- a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
@@ -40,9 +40,8 @@
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Dnn" + PublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)]
-
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.AutoML" + PublicKey.Value)]
-
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Featurizers" + PublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.MetaLinearLearner" + InternalPublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "TreeVisualizer" + InternalPublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "TMSNlearnPrediction" + InternalPublicKey.Value)]
diff --git a/src/Microsoft.ML.Core/Utilities/Utils.cs b/src/Microsoft.ML.Core/Utilities/Utils.cs
index fbdb3a791d..4d17be39f7 100644
--- a/src/Microsoft.ML.Core/Utilities/Utils.cs
+++ b/src/Microsoft.ML.Core/Utilities/Utils.cs
@@ -966,6 +966,17 @@ private static MethodInfo MarshalInvokeCheckAndCreate(Type genArg, Delegat
             return meth;
         }
 
+        /// <summary>
+        /// Constructs the generic method behind <paramref name="func"/> over <paramref name="genArgs"/>
+        /// and rejects it when the constructed method does not return <typeparamref name="TRet"/>.
+        /// </summary>
+        private static MethodInfo MarshalInvokeCheckAndCreate<TRet>(Delegate func, Type[] genArgs)
+        {
+            var meth = MarshalActionInvokeCheckAndCreate(func, genArgs);
+            if (meth.ReturnType != typeof(TRet))
+                throw Contracts.ExceptParam(nameof(func), "Cannot be generic on return type");
+            return meth;
+        }
+
         // REVIEW: n-argument versions? The multi-column re-application problem?
         // Think about how to address these.
@@ -1092,6 +1103,28 @@ public static TRet MarshalInvoke
+        /// <summary>
+        /// A 1 argument and n type version of <c>MarshalInvoke</c>.
+        /// </summary>
+        public static TRet MarshalInvoke<TArg1, TRet>(
+            Func<TArg1, TRet> func,
+            Type[] genArgs, TArg1 arg1)
+        {
+            var meth = MarshalInvokeCheckAndCreate<TRet>(func, genArgs);
+            return (TRet)meth.Invoke(func.Target, new object[] { arg1 });
+        }
+
+        /// <summary>
+        /// A 2 argument and n type version of <c>MarshalInvoke</c>.
+        /// </summary>
+        public static TRet MarshalInvoke<TArg1, TArg2, TRet>(
+            Func<TArg1, TArg2, TRet> func,
+            Type[] genArgs, TArg1 arg1, TArg2 arg2)
+        {
+            var meth = MarshalInvokeCheckAndCreate<TRet>(func, genArgs);
+            return (TRet)meth.Invoke(func.Target, new object[] { arg1, arg2 });
+        }
+
         private static MethodInfo MarshalActionInvokeCheckAndCreate(Type genArg, Delegate func)
         {
             Contracts.CheckValue(genArg, nameof(genArg));
@@ -1104,6 +1137,21 @@ private static MethodInfo MarshalActionInvokeCheckAndCreate(Type genArg, Delegat
             return meth;
         }
 
+        /// <summary>
+        /// Checks that <paramref name="func"/> targets a generic method with exactly as many type parameters
+        /// as <paramref name="typeArguments"/> has entries, and returns that method constructed over them.
+        /// </summary>
+        private static MethodInfo MarshalActionInvokeCheckAndCreate(Delegate func, params Type[] typeArguments)
+        {
+            Contracts.CheckValue(typeArguments, nameof(typeArguments));
+            Contracts.CheckValue(func, nameof(func));
+            var meth = func.GetMethodInfo();
+            Contracts.CheckParam(meth.IsGenericMethod, nameof(func), "Should be generic but is not");
+            Contracts.CheckParam(meth.GetGenericArguments().Length == typeArguments.Length, nameof(func),
+                "Method should have exactly the same number of generic type parameters as list passed in but it does not.");
+            meth = meth.GetGenericMethodDefinition().MakeGenericMethod(typeArguments);
+            return meth;
+        }
+
         ///
         /// This is akin to , except applied to
         /// instead of .
diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
index efb86af7ca..b106683166 100644
--- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
+++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
@@ -367,7 +367,7 @@ public Cursor(IChannelProvider provider, DataViewRowCursor input, RowToRowMapper
                 : base(provider, input)
             {
                 var pred = parent.GetActiveOutputColumns(active);
-                _getters = parent._mapper.CreateGetters(input, pred, out _disposer);
+                _getters = parent._mapperFactory == null ? parent._mapper.CreateGetters(input, pred, out _disposer) : parent._mapperFactory.Invoke(input.Schema).CreateGetters(input, pred, out _disposer);
                 _active = active;
                 _bindings = parent._bindings;
             }
diff --git a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
index db1be394a8..8cff3c9239 100644
--- a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
@@ -44,9 +44,8 @@
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet101" + PublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet18" + PublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet50" + PublicKey.Value)]
-
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Experimental" + PublicKey.Value)]
-
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Featurizers" + PublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.MetaLinearLearner" + InternalPublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "TMSNlearnPrediction" + InternalPublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.CntkWrapper" + InternalPublicKey.Value)]
diff --git a/src/Microsoft.ML.Featurizers/CategoryImputer.cs b/src/Microsoft.ML.Featurizers/CategoryImputer.cs
new file mode 100644
index 0000000000..bb63c56164
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/CategoryImputer.cs
@@ -0,0 +1,1599 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Security;
+using System.Text;
+using Microsoft.ML;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.Data;
+using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Transforms;
+using static Microsoft.ML.Featurizers.CommonExtensions;
+
+[assembly: LoadableClass(typeof(CategoryImputerTransformer), null, typeof(SignatureLoadModel),
+    CategoryImputerTransformer.UserName, CategoryImputerTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(CategoryImputerTransformer), null, typeof(SignatureLoadRowMapper),
+   CategoryImputerTransformer.UserName, CategoryImputerTransformer.LoaderSignature)]
+
+[assembly: EntryPointModule(typeof(CategoryImputerEntrypoint))]
+
+namespace Microsoft.ML.Featurizers
+{
+    // Extension methods on TransformsCatalog that create a CategoryImputerEstimator.
+    public static class CategoryImputerExtensionClass
+    {
+        /// <summary>
+        /// Create a <see cref="CategoryImputerEstimator"/>, which fills in the missing values in a column with the most frequent value.
+        /// </summary>
+        /// <param name="catalog">Transform Catalog</param>
+        /// <param name="outputColumnName">Output column name</param>
+        /// <param name="inputColumnName">Input column name, if null defaults to <paramref name="outputColumnName"/></param>
+        /// <returns><see cref="CategoryImputerEstimator"/></returns>
+        public static CategoryImputerEstimator CatagoryImputerTransformer(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null)
+            => CategoryImputerEstimator.Create(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName);
+
+        ///
+        /// Create a <see cref="CategoryImputerEstimator"/>, which fills in the missing values in a column with the most frequent value.
+        ///
+        /// <param name="catalog">Transform Catalog</param>
+        /// <param name="columns">List of <see cref="InputOutputColumnPair"/> to fill in missing values</param>
+        /// <returns>A <see cref="CategoryImputerEstimator"/> configured for <paramref name="columns"/>.</returns>
+        public static CategoryImputerEstimator CatagoryImputerTransformer(this TransformsCatalog catalog, params InputOutputColumnPair[] columns)
+            => CategoryImputerEstimator.Create(CatalogUtils.GetEnvironment(catalog), columns);
+    }
+
+    /// <summary>
+    /// The CategoryImputer replaces missing values with the most common value in that column.
+    /// </summary>
+    /// <remarks>
+    /// <format type="text/markdown"><![CDATA[
+    /// <xref:Microsoft.ML.Featurizers.CategoryImputerEstimator> is not a trivial estimator and needs training.
+    ///
+    /// ]]>
+    /// </format>
+    /// </remarks>
+    /// <seealso cref="CategoryImputerExtensionClass.CatagoryImputerTransformer(TransformsCatalog, string, string)"/>
+    /// <seealso cref="CategoryImputerExtensionClass.CatagoryImputerTransformer(TransformsCatalog, InputOutputColumnPair[])"/>
+    public sealed class CategoryImputerEstimator : IEstimator<CategoryImputerTransformer>
+    {
+        private readonly Options _options;
+
+        private readonly IHost _host;
+
+        #region Options
+        internal sealed class Column : OneToOneColumn
+        {
+            internal static Column Parse(string str)
+            {
+                Contracts.AssertNonEmpty(str);
+
+                var res = new Column();
+                if (res.TryParse(str))
+                    return res;
+                return null;
+            }
+
+            internal bool TryUnparse(StringBuilder sb)
+            {
+                Contracts.AssertValue(sb);
+                return TryUnparseCore(sb);
+            }
+        }
+
+        internal sealed class Options : TransformInputBase
+        {
+            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition (optional form: name:src)",
+                Name = "Column", ShortName = "col", SortOrder = 1)]
+            public Column[] Columns;
+        }
+
+        #endregion
+
+        internal static CategoryImputerEstimator Create(IHostEnvironment env, string outputColumnName, string inputColumnName)
+        {
+            return new CategoryImputerEstimator(env, outputColumnName, inputColumnName);
+        }
+
+        internal static CategoryImputerEstimator Create(IHostEnvironment env, params InputOutputColumnPair[] columns)
+        {
+            // A pair without an explicit input column reads from the column it writes to.
+            var columnOptions = columns.Select(x => new Column { Name = x.OutputColumnName, Source = x.InputColumnName ?? x.OutputColumnName }).ToArray();
+            return new CategoryImputerEstimator(env, new Options { Columns = columnOptions });
+        }
+
+        internal CategoryImputerEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            _host = env.Register(nameof(CategoryImputerEstimator));
+            _options = new Options
+            {
+                Columns = new[] { new Column { Name = outputColumnName, Source = inputColumnName ?? outputColumnName } }
+            };
+        }
+
+        internal CategoryImputerEstimator(IHostEnvironment env, Options options)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            Contracts.CheckValue(options, nameof(options));
+            _host = env.Register(nameof(CategoryImputerEstimator));
+
+            // Normalize in place: a column without a source reads from its own name.
+            foreach (var columnPair in options.Columns)
+            {
+                columnPair.Source = columnPair.Source ?? columnPair.Name;
+            }
+
+            _options = options;
+        }
+
+        /// <summary>
+        /// Trains a <see cref="CategoryImputerTransformer"/> on <paramref name="input"/> for the configured columns.
+        /// </summary>
+        public CategoryImputerTransformer Fit(IDataView input)
+        {
+            return new CategoryImputerTransformer(_host, input, _options.Columns);
+        }
+
+        /// <summary>
+        /// Returns <paramref name="inputSchema"/> with one column added (or replaced) per configured output,
+        /// each copying the kind, item type, key-ness and annotations of its source column.
+        /// </summary>
+        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+        {
+            var columns = inputSchema.ToDictionary(x => x.Name);
+
+            foreach (var column in _options.Columns)
+            {
+                // NOTE(review): the indexer throws KeyNotFoundException when the source column is absent.
+                var inputColumn = columns[column.Source];
+                columns[column.Name] = new SchemaShape.Column(column.Name, inputColumn.Kind,
+                    inputColumn.ItemType, inputColumn.IsKey, inputColumn.Annotations);
+            }
+
+            return new SchemaShape(columns.Values);
+        }
+    }
+
+    public sealed class CategoryImputerTransformer : RowToRowTransformerBase, IDisposable
+    {
+        #region Class data members
+
+        internal const string Summary = "Fills in missing values in a column based on the most frequent value";
+        internal const string UserName = "CategoryImputer";
+        internal const string ShortName = "CategoryImputer";
+        internal const string LoadName = "CategoryImputer";
+        internal const string LoaderSignature = "CategoryImputer";
+
+        private readonly TypedColumn[] _columns;
+
+        #endregion
+
+        // Trains one native transformer per configured column from the data in 'input'.
+        internal CategoryImputerTransformer(IHostEnvironment host, IDataView input, CategoryImputerEstimator.Column[] columns) :
+            base(host.Register(nameof(CategoryImputerTransformer)))
+        {
+            var schema = input.Schema;
+
+            _columns = columns.Select(x => TypedColumn.CreateTypedColumn(x.Name, x.Source, schema[x.Source].Type.RawType.ToString())).ToArray();
+            foreach (var column in _columns)
+            {
+                column.CreateTransformerFromEstimator(input);
+            }
+        }
+
+        // Factory method for SignatureLoadModel.
+        internal CategoryImputerTransformer(IHostEnvironment host, ModelLoadContext ctx) :
+            base(host.Register(nameof(CategoryImputerTransformer)))
+        {
+            Host.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            // *** Binary format ***
+            // int number of column pairs
+            // for each column pair:
+            //      string output column name
+            //      string input column name
+            //      column type
+            //      int length of c++ byte array
+            //      byte array from c++
+
+            var columnCount = ctx.Reader.ReadInt32();
+
+            _columns = new TypedColumn[columnCount];
+            for (int i = 0; i < columnCount; i++)
+            {
+                // Read into locals so the stream order (name, source, type) is explicit
+                // rather than implied by argument evaluation order.
+                var name = ctx.Reader.ReadString();
+                var source = ctx.Reader.ReadString();
+                var type = ctx.Reader.ReadString();
+                _columns[i] = TypedColumn.CreateTypedColumn(name, source, type);
+
+                // Load the C++ state and create the C++ transformer.
+                var dataLength = ctx.Reader.ReadInt32();
+                var data = ctx.Reader.ReadByteArray(dataLength);
+                _columns[i].CreateTransformerFromSavedData(data);
+            }
+        }
+
+        // Factory method for SignatureLoadRowMapper.
+        private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
+            => new CategoryImputerTransformer(env, ctx).MakeRowMapper(inputSchema);
+
+        private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
+
+        // Version stamp written by SaveModel and checked by the loading constructor.
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "CATIMP T",
+                verWrittenCur: 0x00010001,
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature,
+                loaderAssemblyName: typeof(CategoryImputerTransformer).Assembly.FullName);
+        }
+
+        // Writes the column list and each column's native transformer state to the model stream.
+        private protected override void SaveModel(ModelSaveContext ctx)
+        {
+            Host.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel();
+            ctx.SetVersionInfo(GetVersionInfo());
+
+            // *** Binary format ***
+            // int number of column pairs
+            // for each column pair:
+            //      string output column name
+            //      string input column name
+            //      column type
+            //      int length of c++ byte array
+            //      byte array from c++
+
+            // _columns is an array, so use Length instead of the LINQ Count() extension.
+            ctx.Writer.Write(_columns.Length);
+            foreach (var column in _columns)
+            {
+                ctx.Writer.Write(column.Name);
+                ctx.Writer.Write(column.Source);
+                ctx.Writer.Write(column.Type);
+
+                // Save C++ state
+                var data = column.CreateTransformerSaveData();
+                ctx.Writer.Write(data.Length);
+                ctx.Writer.Write(data);
+            }
+        }
+
+        // Disposes every per-column wrapper.
+        public void Dispose()
+        {
+            foreach (var column in _columns)
+            {
+                column.Dispose();
+            }
+        }
+
+        #region ColumnInfo
+
+        // REVIEW: Since we can't do overloading on the native side due to the C style exports,
+        // this was the best way I could think handle it to allow for any conversions needed based on the data type.
+
+        #region BaseClass
+
+        internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle);
+        internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle);
+        internal delegate bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+
+        // Type-erased base for the per-column wrappers: holds the column names and the
+        // raw-type string, and implements the save/load plumbing shared by all element types.
+        internal abstract class TypedColumn : IDisposable
+        {
+            internal readonly string Name;
+            internal readonly string Source;
+            internal readonly string Type;
+            internal TypedColumn(string name, string source, string type)
+            {
+                Name = name;
+                Source = source;
+                Type = type;
+            }
+
+            internal abstract void CreateTransformerFromEstimator(IDataView input);
+            private protected abstract unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize);
+            private protected abstract bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+            public abstract void Dispose();
+
+            // Asks the native side for the serialized transformer state and returns a managed copy.
+            internal byte[] CreateTransformerSaveData()
+            {
+                var success = CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+                if (!success)
+                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+                // The handle owns the native buffer for the duration of the copy
+                // (presumably releasing it on dispose - confirm against SaveDataSafeHandle).
+                using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize))
+                {
+                    byte[] savedData = new byte[bufferSize.ToInt32()];
+                    Marshal.Copy(buffer, savedData, 0, savedData.Length);
+                    return savedData;
+                }
+            }
+
+            // Pins the saved bytes and hands them to the typed subclass to rebuild its native transformer.
+            internal unsafe void CreateTransformerFromSavedData(byte[] data)
+            {
+                Contracts.CheckValue(data, nameof(data));
+
+                fixed (byte* rawData = data)
+                {
+                    IntPtr dataSize = new IntPtr(data.Length);
+                    CreateTransformerFromSavedDataHelper(rawData, dataSize);
+                }
+            }
+
+            // Maps the string form of a CLR type (Type.ToString()) to the matching typed wrapper.
+            internal static TypedColumn CreateTypedColumn(string name, string source, string type)
+            {
+                if (type == typeof(sbyte).ToString())
+                    return new Int8TypedColumn(name, source);
+                if (type == typeof(short).ToString())
+                    return new Int16TypedColumn(name, source);
+                if (type == typeof(int).ToString())
+                    return new Int32TypedColumn(name, source);
+                if (type == typeof(long).ToString())
+                    return new Int64TypedColumn(name, source);
+                if (type == typeof(byte).ToString())
+                    return new UInt8TypedColumn(name, source);
+                if (type == typeof(ushort).ToString())
+                    return new UInt16TypedColumn(name, source);
+                if (type == typeof(uint).ToString())
+                    return new UInt32TypedColumn(name, source);
+                if (type == typeof(ulong).ToString())
+                    return new UInt64TypedColumn(name, source);
+                if (type == typeof(float).ToString())
+                    return new FloatTypedColumn(name, source);
+                if (type == typeof(double).ToString())
+                    return new DoubleTypedColumn(name, source);
+                if (type == typeof(string).ToString())
+                    return new StringTypedColumn(name, source);
+                if (type == typeof(ReadOnlyMemory<char>).ToString())
+                    return new ReadOnlyCharTypedColumn(name, source);
+
+                // NotSupportedException derives from Exception, so existing catch blocks still apply.
+                throw new NotSupportedException($"Unsupported type {type}");
+            }
+        }
+
+        // Element-typed layer: declares the native entry points each concrete column binds
+        // and drives the shared train-then-create sequence.
+        internal abstract class TypedColumn<T> : TypedColumn
+        {
+            internal TypedColumn(string name, string source, string type) :
+                base(name, source, type)
+            {
+            }
+
+            internal abstract T Transform(T input);
+            private protected abstract bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
+            private protected abstract bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+            private protected abstract bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle);
+            private protected abstract bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle);
+            private protected abstract bool FitHelper(TransformerEstimatorSafeHandle estimator, T input, out FitResult fitResult, out IntPtr errorHandle);
+            private protected abstract bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+
+            // Creates a native estimator, feeds it the Source column until it reports Complete,
+            // then returns a handle to the transformer built from it. The estimator handle is
+            // disposed when the using block exits.
+            private protected TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase(IDataView input)
+            {
+                var success = CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
+                if (!success)
+                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+                using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorHelper))
+                {
+                    var fitResult = FitResult.Continue;
+                    // Each outer iteration is one full pass over the column; repeat until Complete.
+                    while (fitResult != FitResult.Complete)
+                    {
+                        using (var data = input.GetColumn<T>(Source).GetEnumerator())
+                        {
+                            while (fitResult == FitResult.Continue && data.MoveNext())
+                            {
+                                success = FitHelper(estimatorHandler, data.Current, out fitResult, out errorHandle);
+                                if (!success)
+                                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+                            }
+
+                            success = CompleteTrainingHelper(estimatorHandler, out fitResult, out errorHandle);
+                            if (!success)
+                                throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+                        }
+                    }
+
+                    success = CreateTransformerFromEstimatorHelper(estimatorHandler, out IntPtr transformer, out errorHandle);
+                    if (!success)
+                        throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+                    return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerHelper);
+                }
+            }
+        }
+
+        #endregion
+
+        #region Int8Column
+
+        internal sealed class Int8TypedColumn : TypedColumn<sbyte>
+        {
+            private TransformerEstimatorSafeHandle _transformerHandler;
+            internal Int8TypedColumn(string name, string source) :
+                base(name, source, typeof(sbyte).ToString())
+            {
+            }
+
+            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+            private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+            private
static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + _transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, sbyte* input, out sbyte output, out IntPtr errorHandle); + internal unsafe override sbyte 
Transform(sbyte input) + { + sbyte* interopInput = input == 0 ? null : &input; + var success = TransformDataNative(_transformerHandler, interopInput, out sbyte output, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return output; + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, sbyte* input, out FitResult fitResult, out IntPtr errorHandle); + private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, sbyte input, out FitResult fitResult, out IntPtr errorHandle) + { + sbyte* interopInput = input == 0 ? 
null : &input; + return FitNative(estimator, interopInput, out fitResult, out errorHandle); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #region Int16Column + + internal sealed class Int16TypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + + internal Int16TypedColumn(string name, string source) : + base(name, source, typeof(short).ToString()) + { + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out 
IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + _transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, short* input, out short output, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_DestroyTransformedData", 
CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle); + internal unsafe override short Transform(short input) + { + short* interopInput = input == 0 ? null : &input; + var success = TransformDataNative(_transformerHandler, interopInput, out short output, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return output; + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, short* input, out FitResult fitResult, out IntPtr errorHandle); + private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, short input, out FitResult fitResult, out IntPtr errorHandle) + { + short* interopInput = input == 0 ? 
null : &input;
                return FitNative(estimator, interopInput, out fitResult, out errorHandle);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
                CompleteTrainingNative(estimator, out fitResult, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
        }

        #endregion

        #region Int32Column

        /// <summary>
        /// Typed column that binds the <c>CatImputerFeaturizer_int32_t_*</c> entry points of the native
        /// "Featurizers" library for <see cref="int"/> values.
        /// </summary>
        internal sealed class Int32TypedColumn : TypedColumn
        {
            // Native transformer handle. It is assigned only by CreateTransformerFromEstimator and
            // CreateTransformerFromSavedDataHelper, so it is null until one of them has run.
            private TransformerEstimatorSafeHandle _transformerHandler;

            internal Int32TypedColumn(string name, string source) :
                base(name, source, typeof(int).ToString())
            {
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Stores the transformer handle returned by <c>CreateTransformerFromEstimatorBase</c> for <paramref name="input"/>.
            /// </summary>
            internal override void CreateTransformerFromEstimator(IDataView input)
            {
                _transformerHandler = CreateTransformerFromEstimatorBase(input);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Recreates the native transformer from a saved buffer and wraps it in a safe handle.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
            {
                var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
                if (!result)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                // The safe handle is given DestroyTransformerNative as its release callback.
                _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, int* input, out int output, out IntPtr errorHandle);

            /// <summary>
            /// Runs the native int32 transformer on a single value.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            internal unsafe override int Transform(int input)
            {
                // A zero input is handed to native code as a null pointer instead of a pointer to the value.
                // NOTE(review): presumably null is how the native side recognises a missing value - confirm.
                int* interopInput = input == 0 ? null : &input;
                var success = TransformDataNative(_transformerHandler, interopInput, out int output, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return output;
            }

            /// <summary>
            /// Releases the native transformer handle if one exists and is still open.
            /// </summary>
            public override void Dispose()
            {
                // The field is null until a transformer has been created; guard so Dispose never throws.
                if (_transformerHandler != null && !_transformerHandler.IsClosed)
                    _transformerHandler.Dispose();
            }

            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
                CreateEstimatorNative(out estimator, out errorHandle);

            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
                DestroyEstimatorNative(estimator, out errorHandle);

            private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
                DestroyTransformerNative(transformer, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, int* input, out FitResult fitResult, out IntPtr errorHandle);
            private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, int input, out FitResult fitResult, out IntPtr errorHandle)
            {
                // Same zero-to-null convention as Transform.
                int* interopInput = input == 0 ? null : &input;
                return FitNative(estimator, interopInput, out fitResult, out errorHandle);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
                CompleteTrainingNative(estimator, out fitResult, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
        }

        #endregion

        #region Int64Column

        internal sealed class Int64TypedColumn : TypedColumn
        {
            private TransformerEstimatorSafeHandle _transformerHandler;
            internal Int64TypedColumn(string name, string source) :
                base(name, source, typeof(long).ToString())
            {
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyEstimatorNative(IntPtr estimator, out
IntPtr errorHandle); // Should ONLY be called by safe handle

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Stores the transformer handle returned by <c>CreateTransformerFromEstimatorBase</c> for <paramref name="input"/>.
            /// </summary>
            internal override void CreateTransformerFromEstimator(IDataView input)
            {
                _transformerHandler = CreateTransformerFromEstimatorBase(input);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Recreates the native transformer from a saved buffer and wraps it in a safe handle.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
            {
                var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
                if (!result)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                // The safe handle is given DestroyTransformerNative as its release callback.
                _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, long* input, out long output, out IntPtr errorHandle);

            /// <summary>
            /// Runs the native int64 transformer on a single value.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            internal unsafe override long Transform(long input)
            {
                // A zero input is handed to native code as a null pointer instead of a pointer to the value.
                // NOTE(review): presumably null is how the native side recognises a missing value - confirm.
                long* interopInput = input == 0 ? null : &input;
                var success = TransformDataNative(_transformerHandler, interopInput, out long output, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return output;
            }

            /// <summary>
            /// Releases the native transformer handle if one exists and is still open.
            /// </summary>
            public override void Dispose()
            {
                // The handle is only assigned by the two Create* methods above, so it may still be
                // null here; guard so Dispose never throws.
                if (_transformerHandler != null && !_transformerHandler.IsClosed)
                    _transformerHandler.Dispose();
            }

            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
                CreateEstimatorNative(out estimator, out errorHandle);

            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
                DestroyEstimatorNative(estimator, out errorHandle);

            private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
                DestroyTransformerNative(transformer, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, long* input, out FitResult fitResult, out IntPtr errorHandle);
            private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, long input, out FitResult fitResult, out IntPtr errorHandle)
            {
                long* interopInput = input == 0 ?
null : &input;
                return FitNative(estimator, interopInput, out fitResult, out errorHandle);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
                CompleteTrainingNative(estimator, out fitResult, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
        }

        #endregion

        #region UInt8Column

        /// <summary>
        /// Typed column that binds the <c>CatImputerFeaturizer_uint8_t_*</c> entry points of the native
        /// "Featurizers" library for <see cref="byte"/> values.
        /// </summary>
        internal sealed class UInt8TypedColumn : TypedColumn
        {
            // Native transformer handle. It is assigned only by CreateTransformerFromEstimator and
            // CreateTransformerFromSavedDataHelper, so it is null until one of them has run.
            private TransformerEstimatorSafeHandle _transformerHandler;

            internal UInt8TypedColumn(string name, string source) :
                base(name, source, typeof(byte).ToString())
            {
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Stores the transformer handle returned by <c>CreateTransformerFromEstimatorBase</c> for <paramref name="input"/>.
            /// </summary>
            internal override void CreateTransformerFromEstimator(IDataView input)
            {
                _transformerHandler = CreateTransformerFromEstimatorBase(input);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Recreates the native transformer from a saved buffer and wraps it in a safe handle.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
            {
                var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
                if (!result)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                // The safe handle is given DestroyTransformerNative as its release callback.
                _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, byte* input, out byte output, out IntPtr errorHandle);

            /// <summary>
            /// Runs the native uint8 transformer on a single value.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            internal unsafe override byte Transform(byte input)
            {
                // A zero input is handed to native code as a null pointer instead of a pointer to the value.
                // NOTE(review): presumably null is how the native side recognises a missing value - confirm.
                byte* interopInput = input == 0 ? null : &input;
                var success = TransformDataNative(_transformerHandler, interopInput, out byte output, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return output;
            }

            /// <summary>
            /// Releases the native transformer handle if one exists and is still open.
            /// </summary>
            public override void Dispose()
            {
                // The field is null until a transformer has been created; guard so Dispose never throws.
                if (_transformerHandler != null && !_transformerHandler.IsClosed)
                    _transformerHandler.Dispose();
            }

            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
                CreateEstimatorNative(out estimator, out errorHandle);

            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
                DestroyEstimatorNative(estimator, out errorHandle);

            private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
                DestroyTransformerNative(transformer, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, byte* input, out FitResult fitResult, out IntPtr errorHandle);
            private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, byte input, out FitResult fitResult, out IntPtr errorHandle)
            {
                // Same zero-to-null convention as Transform.
                byte* interopInput = input == 0 ? null : &input;
                return FitNative(estimator, interopInput, out fitResult, out errorHandle);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
                CompleteTrainingNative(estimator, out fitResult, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
        }

        #endregion

        #region UInt16Column

        internal sealed class UInt16TypedColumn : TypedColumn
        {
            private TransformerEstimatorSafeHandle _transformerHandler;
            internal UInt16TypedColumn(string name, string source) :
                base(name, source, typeof(ushort).ToString())
            {
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyEstimatorNative(IntPtr estimator,
out IntPtr errorHandle); // Should ONLY be called by safe handle

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Stores the transformer handle returned by <c>CreateTransformerFromEstimatorBase</c> for <paramref name="input"/>.
            /// </summary>
            internal override void CreateTransformerFromEstimator(IDataView input)
            {
                _transformerHandler = CreateTransformerFromEstimatorBase(input);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Recreates the native transformer from a saved buffer and wraps it in a safe handle.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
            {
                var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
                if (!result)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                // The safe handle is given DestroyTransformerNative as its release callback.
                _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ushort* input, out ushort output, out IntPtr errorHandle);

            /// <summary>
            /// Runs the native uint16 transformer on a single value.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            internal unsafe override ushort Transform(ushort input)
            {
                // A zero input is handed to native code as a null pointer instead of a pointer to the value.
                // NOTE(review): presumably null is how the native side recognises a missing value - confirm.
                ushort* interopInput = input == 0 ? null : &input;
                var success = TransformDataNative(_transformerHandler, interopInput, out ushort output, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return output;
            }

            /// <summary>
            /// Releases the native transformer handle if one exists and is still open.
            /// </summary>
            public override void Dispose()
            {
                // The handle is only assigned by the two Create* methods above, so it may still be
                // null here; guard so Dispose never throws.
                if (_transformerHandler != null && !_transformerHandler.IsClosed)
                    _transformerHandler.Dispose();
            }

            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
                CreateEstimatorNative(out estimator, out errorHandle);

            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
                DestroyEstimatorNative(estimator, out errorHandle);

            private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
                DestroyTransformerNative(transformer, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, ushort* input, out FitResult fitResult, out IntPtr errorHandle);
            private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, ushort input, out FitResult fitResult, out IntPtr errorHandle)
            {
                ushort* interopInput = input == 0 ?
null : &input;
                return FitNative(estimator, interopInput, out fitResult, out errorHandle);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
                CompleteTrainingNative(estimator, out fitResult, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
        }

        #endregion

        #region UInt32Column

        /// <summary>
        /// Typed column that binds the <c>CatImputerFeaturizer_uint32_t_*</c> entry points of the native
        /// "Featurizers" library for <see cref="uint"/> values.
        /// </summary>
        internal sealed class UInt32TypedColumn : TypedColumn
        {
            // Native transformer handle. It is assigned only by CreateTransformerFromEstimator and
            // CreateTransformerFromSavedDataHelper, so it is null until one of them has run.
            private TransformerEstimatorSafeHandle _transformerHandler;

            internal UInt32TypedColumn(string name, string source) :
                base(name, source, typeof(uint).ToString())
            {
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Stores the transformer handle returned by <c>CreateTransformerFromEstimatorBase</c> for <paramref name="input"/>.
            /// </summary>
            internal override void CreateTransformerFromEstimator(IDataView input)
            {
                _transformerHandler = CreateTransformerFromEstimatorBase(input);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Recreates the native transformer from a saved buffer and wraps it in a safe handle.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
            {
                var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
                if (!result)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                // The safe handle is given DestroyTransformerNative as its release callback.
                _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, uint* input, out uint output, out IntPtr errorHandle);

            /// <summary>
            /// Runs the native uint32 transformer on a single value.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            internal unsafe override uint Transform(uint input)
            {
                // A zero input is handed to native code as a null pointer instead of a pointer to the value.
                // NOTE(review): presumably null is how the native side recognises a missing value - confirm.
                uint* interopInput = input == 0 ? null : &input;
                var success = TransformDataNative(_transformerHandler, interopInput, out uint output, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return output;
            }

            /// <summary>
            /// Releases the native transformer handle if one exists and is still open.
            /// </summary>
            public override void Dispose()
            {
                // The field is null until a transformer has been created; guard so Dispose never throws.
                if (_transformerHandler != null && !_transformerHandler.IsClosed)
                    _transformerHandler.Dispose();
            }

            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
                CreateEstimatorNative(out estimator, out errorHandle);

            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
                DestroyEstimatorNative(estimator, out errorHandle);

            private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
                DestroyTransformerNative(transformer, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, uint* input, out FitResult fitResult, out IntPtr errorHandle);
            private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, uint input, out FitResult fitResult, out IntPtr errorHandle)
            {
                // Same zero-to-null convention as Transform.
                uint* interopInput = input == 0 ? null : &input;
                return FitNative(estimator, interopInput, out fitResult, out errorHandle);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
                CompleteTrainingNative(estimator, out fitResult, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
        }

        #endregion

        #region UInt64Column

        internal sealed class UInt64TypedColumn : TypedColumn
        {
            private TransformerEstimatorSafeHandle _transformerHandler;

            internal UInt64TypedColumn(string name, string source) :
                base(name, source, typeof(ulong).ToString())
            {
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyEstimatorNative(IntPtr
estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Stores the transformer handle returned by <c>CreateTransformerFromEstimatorBase</c> for <paramref name="input"/>.
            /// </summary>
            internal override void CreateTransformerFromEstimator(IDataView input)
            {
                _transformerHandler = CreateTransformerFromEstimatorBase(input);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Recreates the native transformer from a saved buffer and wraps it in a safe handle.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
            {
                var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
                if (!result)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                // The safe handle is given DestroyTransformerNative as its release callback.
                _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
            }

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ulong* input, out ulong output, out IntPtr errorHandle);

            /// <summary>
            /// Runs the native uint64 transformer on a single value.
            /// </summary>
            /// <exception cref="Exception">The native call reported failure.</exception>
            internal unsafe override ulong Transform(ulong input)
            {
                // A zero input is handed to native code as a null pointer instead of a pointer to the value.
                // NOTE(review): presumably null is how the native side recognises a missing value - confirm.
                ulong* interopInput = input == 0 ? null : &input;
                var success = TransformDataNative(_transformerHandler, interopInput, out ulong output, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return output;
            }

            /// <summary>
            /// Releases the native transformer handle if one exists and is still open.
            /// </summary>
            public override void Dispose()
            {
                // The handle is only assigned by the two Create* methods above, so it may still be
                // null here; guard so Dispose never throws.
                if (_transformerHandler != null && !_transformerHandler.IsClosed)
                    _transformerHandler.Dispose();
            }

            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
                CreateEstimatorNative(out estimator, out errorHandle);

            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
                DestroyEstimatorNative(estimator, out errorHandle);

            private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
                DestroyTransformerNative(transformer, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, ulong* input, out FitResult fitResult, out IntPtr errorHandle);
            private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, ulong input, out FitResult fitResult, out IntPtr errorHandle)
            {
                ulong* interopInput = input == 0 ?
null : &input; + return FitNative(estimator, interopInput, out fitResult, out errorHandle); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #region FloatColumn + + internal sealed class FloatTypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + + internal FloatTypedColumn(string name, string source) : + base(name, source, typeof(float).ToString()) + { + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, 
out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + _transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, in float input, out float output, out IntPtr errorHandle); + internal override float Transform(float input) + { + var success = 
TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return output; + } + + public override void Dispose() + { + if (_transformerHandler != null && !_transformerHandler.IsClosed) // Dispose must not throw when no handle was ever assigned + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool FitNative(TransformerEstimatorSafeHandle estimator, in float input, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool FitHelper(TransformerEstimatorSafeHandle estimator, float input, out FitResult fitResult, out IntPtr errorHandle) => + FitNative(estimator, input, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult
fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #region DoubleColumn + + internal sealed class DoubleTypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + + internal DoubleTypedColumn(string name, string source) : + base(name, source, typeof(double).ToString()) + { + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + _transformerHandler = 
CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, in double input, out double output, out IntPtr errorHandle); + internal override double Transform(double input) + { + var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return output; + } + + public override void Dispose() + { + if (_transformerHandler != null && !_transformerHandler.IsClosed) // Dispose must not throw when no handle was ever assigned + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool
CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool FitNative(TransformerEstimatorSafeHandle estimator, in double input, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool FitHelper(TransformerEstimatorSafeHandle estimator, double input, out FitResult fitResult, out IntPtr errorHandle) => + FitNative(estimator, input, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool 
CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #region StringColumn + + internal sealed class StringTypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + + internal StringTypedColumn(string name, string source) : + base(name, source, typeof(string).ToString()) + { + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + _transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var 
result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, byte* input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyTransformedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle); + internal unsafe override string Transform(string input) + { + // Convert to byte array with NullPointer at end or nullptr. + fixed (byte* interopInput = input == null? 
null : Encoding.UTF8.GetBytes(input + char.MinValue)) + { + var result = TransformDataNative(_transformerHandler, interopInput, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + + if (!result) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative)) + { + byte[] buffer = new byte[outputSize.ToInt32()]; + Marshal.Copy(output, buffer, 0, buffer.Length); + return Encoding.UTF8.GetString(buffer); + } + } + } + + public override void Dispose() + { + if (_transformerHandler != null && !_transformerHandler.IsClosed) // Dispose must not throw when no handle was ever assigned + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, byte* input, out FitResult fitResult, out IntPtr errorHandle); + private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, string input, out FitResult fitResult, out IntPtr errorHandle) + { + fixed (byte* interopInput = input == null ?
null : Encoding.UTF8.GetBytes(input + char.MinValue)) + { + return FitNative(estimator, interopInput, out fitResult, out errorHandle); + } + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #region ReadOnlyCharColumn + + internal sealed class ReadOnlyCharTypedColumn : TypedColumn> + { + private TransformerEstimatorSafeHandle _transformerHandler; + + internal ReadOnlyCharTypedColumn(string name, string source) : + base(name, source, typeof(ReadOnlyMemory).ToString()) + { + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] 
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + _transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, byte* input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + 
[DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyTransformedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle); + internal unsafe override ReadOnlyMemory<char> Transform(ReadOnlyMemory<char> input) + { + var inputAsString = input.ToString(); // an empty input is passed to the native call as a null pointer + fixed (byte* interopInput = string.IsNullOrEmpty(inputAsString) ? null : Encoding.UTF8.GetBytes(inputAsString + char.MinValue)) + { + var result = TransformDataNative(_transformerHandler, interopInput, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + + if (!result) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative)) // wrap first so the native buffer is released on every return path + { + if (outputSize.ToInt32() == 0) + return ReadOnlyMemory<char>.Empty; + + byte[] buffer = new byte[outputSize.ToInt32()]; + Marshal.Copy(output, buffer, 0, buffer.Length); + return new ReadOnlyMemory<char>(Encoding.UTF8.GetString(buffer).ToCharArray()); + } + } + } + + public override void Dispose() + { + if (_transformerHandler != null && !_transformerHandler.IsClosed) // Dispose must not throw when no handle was ever assigned + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out
errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, byte* input, out FitResult fitResult, out IntPtr errorHandle); + private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, ReadOnlyMemory input, out FitResult fitResult, out IntPtr errorHandle) + { + var inputAsString = input.ToString(); + fixed (byte* interopInput = string.IsNullOrEmpty(inputAsString) ? null : Encoding.UTF8.GetBytes(inputAsString + char.MinValue)) + { + return FitNative(estimator, interopInput, out fitResult, out errorHandle); + } + } + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #endregion + + private sealed class Mapper : MapperBase + { + + #region Class data members + + private readonly 
CategoryImputerTransformer _parent; + + #endregion + + public Mapper(CategoryImputerTransformer parent, DataViewSchema inputSchema) : + base(parent.Host.Register(nameof(Mapper)), inputSchema, parent) + { + _parent = parent; + } + + protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() + { + return _parent._columns.Select(x => new DataViewSchema.DetachedColumn(x.Name, ColumnTypeExtensions.PrimitiveTypeFromType(Type.GetType(x.Type)))).ToArray(); + } + + private Delegate MakeGetter<T>(DataViewRow input, int iinfo) + { + // Resolve the source column, its getter and the typed column once here instead of on every row. + var inputColumn = input.Schema[_parent._columns[iinfo].Source]; + var srcGetterScalar = input.GetGetter<T>(inputColumn); + var typedColumn = (TypedColumn<T>)_parent._columns[iinfo]; + + ValueGetter<T> result = (ref T dst) => + { + T value = default; + srcGetterScalar(ref value); + dst = typedColumn.Transform(value); + }; + + return result; + } + + protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer) + { + disposer = null; + Type inputType = input.Schema[_parent._columns[iinfo].Source].Type.RawType; + return Utils.MarshalInvoke(MakeGetter<int>, inputType, input, iinfo); + } + + private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput) + { + var active = new bool[InputSchema.Count]; + for (int i = 0; i < InputSchema.Count; i++) + { + if (_parent._columns.Any(x => x.Source == InputSchema[i].Name)) + { + active[i] = true; + } + } + + return col => active[col]; + } + + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); + } + } + + internal static class CategoryImputerEntrypoint + { + [TlcModule.EntryPoint(Name = "Transforms.CategoryImputer", + Desc = CategoryImputerTransformer.Summary, + UserName = CategoryImputerTransformer.UserName, + ShortName = CategoryImputerTransformer.ShortName)] + public static CommonOutputs.TransformOutput ImputeToKey(IHostEnvironment env, CategoryImputerEstimator.Options input) + { + var h = EntryPointUtils.CheckArgsAndCreateHost(env,
CategoryImputerTransformer.ShortName, input); + var xf = new CategoryImputerEstimator(h, input).Fit(input.Data).Transform(input.Data); + return new CommonOutputs.TransformOutput() + { + Model = new TransformModelImpl(h, xf, input.Data), + OutputData = xf + }; + } + } +} diff --git a/src/Microsoft.ML.Featurizers/Common.cs b/src/Microsoft.ML.Featurizers/Common.cs new file mode 100644 index 0000000000..1ab76c5d78 --- /dev/null +++ b/src/Microsoft.ML.Featurizers/Common.cs @@ -0,0 +1,190 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Security; +using System.Text; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.ML.Featurizers +{ + #region Native Function Declarations + + #endregion + + internal enum FitResult : byte + { + Complete = 1, Continue, ResetAndContinue + } + + // Not all these types are currently supported. This is so the ordering will align with the native code. + internal enum TypeId : uint + { + String = 1, SByte, Short, Int, Long, Byte, UShort, + UInt, ULong, Float16, Float32, Double, Complex64, + Complex128, BFloat16, Bool, Timepoint, Duration, + + LastStaticValue, + Tensor = 0x1001 | LastStaticValue + 1, + SparseTensor = 0x1001 | LastStaticValue + 2, + Tabular = 0x1001 | LastStaticValue + 3, + Nullable = 0x1001 | LastStaticValue + 4, + Vector = 0x1001 | LastStaticValue + 5, + MapId = 0x1002 | LastStaticValue + 6 + }; + + [StructLayout(LayoutKind.Sequential, Pack = 1)] + internal unsafe struct NativeBinaryArchiveData + { + public byte* Data; + public IntPtr DataSize; + } + + #region SafeHandles + + internal class ErrorInfoSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + [DllImport("Featurizers", EntryPoint = "DestroyErrorInfo", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyErrorInfo(IntPtr error); + + public ErrorInfoSafeHandle(IntPtr handle) : base(true) + { + SetHandle(handle); + } + + protected override bool
ReleaseHandle() + { + return DestroyErrorInfo(handle); + } + } + + internal class ErrorInfoStringSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + [DllImport("Featurizers", EntryPoint = "DestroyErrorInfoString", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyErrorInfoString(IntPtr errorString, IntPtr errorStringSize); + + private readonly IntPtr _length; // size passed back to DestroyErrorInfoString on release + public ErrorInfoStringSafeHandle(IntPtr handle, IntPtr length) : base(true) + { + SetHandle(handle); + _length = length; + } + + protected override bool ReleaseHandle() + { + return DestroyErrorInfoString(handle, _length); + } + } + + internal delegate bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle); + internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + private readonly DestroyTransformedDataNative _destroySaveDataHandler; // native destroy callback supplied by the caller + private readonly IntPtr _dataSize; + + public TransformedDataSafeHandle(IntPtr handle, IntPtr dataSize, DestroyTransformedDataNative destroyCppTransformerEstimator) : base(true) + { + SetHandle(handle); + _dataSize = dataSize; + _destroySaveDataHandler = destroyCppTransformerEstimator; + } + + protected override bool ReleaseHandle() + { + // Not sure what to do with error stuff here. There shouldn't ever be one though. + return _destroySaveDataHandler(handle, _dataSize, out _); + } + } + + internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle); + internal class TransformerEstimatorSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + private readonly DestroyCppTransformerEstimator _destroyCppTransformerEstimator; + public TransformerEstimatorSafeHandle(IntPtr handle, DestroyCppTransformerEstimator destroyCppTransformerEstimator) : base(true) + { + SetHandle(handle); + _destroyCppTransformerEstimator = destroyCppTransformerEstimator; + } + + protected override bool ReleaseHandle() + { + // Not sure what to do with error stuff here.
There shouldn't ever be one though. + return _destroyCppTransformerEstimator(handle, out _); + } + } + + // Destroying saved data is always the same. + internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle); + + internal class SaveDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + private readonly IntPtr _dataSize; + + // TODO: Update with correct entry point. + [DllImport("Featurizers", EntryPoint = "DestroyTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerSaveDataNative(IntPtr buffer, IntPtr bufferSize, out IntPtr error); + + public SaveDataSafeHandle(IntPtr handle, IntPtr dataSize) : base(true) + { + SetHandle(handle); + _dataSize = dataSize; + } + + protected override bool ReleaseHandle() + { + // Not sure what to do with error stuff here. There shouldn't ever be one though. + return DestroyTransformerSaveDataNative(handle, _dataSize, out _); + } + } + + #endregion + + internal static class CommonExtensions + { + [DllImport("Featurizers", EntryPoint = "GetErrorInfoString", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool GetErrorInfoString(IntPtr error, out IntPtr errorHandleString, out IntPtr errorHandleStringSize); + + internal static string GetErrorDetailsAndFreeNativeMemory(IntPtr errorHandle) + { + using (var error = new ErrorInfoSafeHandle(errorHandle)) + { + GetErrorInfoString(errorHandle, out IntPtr errorHandleString, out IntPtr errorHandleStringSize); + using (var errorString = new ErrorInfoStringSafeHandle(errorHandleString, errorHandleStringSize)) + { + byte[] buffer = new byte[errorHandleStringSize.ToInt32()]; + Marshal.Copy(errorHandleString, buffer, 0, buffer.Length); + + return Encoding.UTF8.GetString(buffer); + } + } + } + internal static TypeId GetNativeTypeIdFromType(this Type type) // maps a managed type to its TypeId; throws InvalidOperationException for unmapped types + { + if (type == typeof(sbyte)) +
return TypeId.SByte; + else if (type == typeof(short)) + return TypeId.Short; + else if (type == typeof(int)) + return TypeId.Int; + else if (type == typeof(long)) + return TypeId.Long; + else if (type == typeof(byte)) + return TypeId.Byte; + else if (type == typeof(ushort)) + return TypeId.UShort; + else if (type == typeof(uint)) + return TypeId.UInt; + else if (type == typeof(ulong)) + return TypeId.ULong; + else if (type == typeof(float)) + return TypeId.Float32; + else if (type == typeof(double)) + return TypeId.Double; + else if (type == typeof(bool)) + return TypeId.Bool; + else if (type == typeof(ReadOnlyMemory<char>)) + return TypeId.String; + + throw new InvalidOperationException($"Unsupported type {type}"); + } + } +} diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs new file mode 100644 index 0000000000..7e0745396c --- /dev/null +++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs @@ -0,0 +1,781 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information.
+ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Security; +using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Featurizers; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms; +using Microsoft.Win32.SafeHandles; +using static Microsoft.ML.Featurizers.CommonExtensions; + +[assembly: LoadableClass(typeof(DateTimeTransformer), null, typeof(SignatureLoadModel), + DateTimeTransformer.UserName, DateTimeTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(DateTimeTransformer), null, typeof(SignatureLoadRowMapper), + DateTimeTransformer.UserName, DateTimeTransformer.LoaderSignature)] + +[assembly: EntryPointModule(typeof(DateTimeTransformerEntrypoint))] + +namespace Microsoft.ML.Featurizers +{ + + public static class DateTimeTransformerExtensionClass + { + /// + /// Create a , which splits up the input column specified by + /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc. + /// This transformer will append the to all the output columns. If is empty, + /// then all the columns are returned. Otherwise, the columns listed in the array will be dropped from the return value. 
+ /// + /// Transform catalog + /// Input column name + /// Prefix to add to the generated columns + /// List of columns to drop, if any + /// + public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, params DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop) + => new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop); + + /// + /// Create a , which splits up the input column specified by + /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc. + /// This transformer will append the to all the output columns. If is empty, + /// then all the columns are returned. Otherwise, the columns listed in the array will be dropped from the return value. If you specify a country, + /// Holiday details will be looked up for that country as well. + /// + /// Transform catalog + /// Input column name + /// Prefix to add to the generated columns + /// List of columns to drop, if any + /// Country name to get holiday details for + /// + public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop = null, DateTimeTransformerEstimator.Countries country = DateTimeTransformerEstimator.Countries.None) + => new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop, country); + + #region ColumnsProduced static extentions + + internal static Type GetRawColumnType(this DateTimeTransformerEstimator.ColumnsProduced column) + { + switch (column) + { + case DateTimeTransformerEstimator.ColumnsProduced.Year: + case DateTimeTransformerEstimator.ColumnsProduced.YearIso: + return typeof(int); + case DateTimeTransformerEstimator.ColumnsProduced.DayOfYear: + case 
DateTimeTransformerEstimator.ColumnsProduced.WeekOfMonth: + return typeof(ushort); + case DateTimeTransformerEstimator.ColumnsProduced.MonthLabel: + case DateTimeTransformerEstimator.ColumnsProduced.AmPmLabel: + case DateTimeTransformerEstimator.ColumnsProduced.DayOfWeekLabel: + case DateTimeTransformerEstimator.ColumnsProduced.HolidayName: + return typeof(ReadOnlyMemory); + default: + return typeof(byte); + } + } + + #endregion + } + + /// + /// The DateTimeTransformerEstimator splits up a date into all of its sub parts as individual columns. It generates these fields with a user specified prefix: + /// int Year, byte Month, byte Day, byte Hour, byte Minute, byte Second, byte AmPm, byte Hour12, byte DayOfWeek, byte DayOfQuarter, + /// ushort DayOfYear, ushort WeekOfMonth, byte QuarterOfYear, byte HalfOfYear, byte WeekIso, int YearIso, string MonthLabel, string AmPmLabel, + /// string DayOfWeekLabel, string HolidayName, byte IsPaidTimeOff + /// + /// You can optionally specify a country and it will pull holiday information about the country as well + /// + /// + /// is a trivial estimator and does not need training. + /// + /// + /// ]]> + /// + /// + /// + /// + public sealed class DateTimeTransformerEstimator : IEstimator + { + private readonly Options _options; + + private readonly IHost _host; + + #region Options + internal sealed class Options: TransformInputBase + { + [Argument(ArgumentType.Required, HelpText = "Input column", Name = "Source", ShortName = "src", SortOrder = 1)] + public string Source; + + // This transformer adds columns + [Argument(ArgumentType.Required, HelpText = "Output column prefix", Name = "Prefix", ShortName = "pre", SortOrder = 2)] + public string Prefix; + + [Argument(ArgumentType.MultipleUnique, HelpText = "Columns to drop after the DateTime Expansion", Name = "ColumnsToDrop", ShortName = "drop", SortOrder = 3)] + public ColumnsProduced[] ColumnsToDrop; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Country to get holidays for. 
Defaults to none if not passed", Name = "Country", ShortName = "ctry", SortOrder = 4)] + public Countries Country = Countries.None; + } + + #endregion + + public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop, Countries country = Countries.None) + { + + Contracts.CheckValue(env, nameof(env)); + _host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator"); + _host.CheckValue(inputColumnName, nameof(inputColumnName), "Input column should not be null."); + + _options = new Options + { + Source = inputColumnName, + Prefix = columnPrefix, + ColumnsToDrop = columnsToDrop == null ? Array.Empty() : columnsToDrop, + Country = country + }; + } + + internal DateTimeTransformerEstimator(IHostEnvironment env, Options options) + { + + Contracts.CheckValue(env, nameof(env)); + _host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator"); + + _options = options; + _options.ColumnsToDrop = _options.ColumnsToDrop == null ? 
Array.Empty() : _options.ColumnsToDrop; + } + + public DateTimeTransformer Fit(IDataView input) + { + return new DateTimeTransformer(_host, _options.Source, _options.Prefix, _options.ColumnsToDrop, _options.Country); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + var columns = inputSchema.ToDictionary(x => x.Name); + + foreach (ColumnsProduced column in Enum.GetValues(typeof(ColumnsProduced))) + if (_options.ColumnsToDrop == null || !_options.ColumnsToDrop.Contains(column)) + { + columns[_options.Prefix + column.ToString()] = new SchemaShape.Column(_options.Prefix + column.ToString(), SchemaShape.Column.VectorKind.Scalar, + ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType()), false, null); + } + + return new SchemaShape(columns.Values); + } + + #region Enums + public enum ColumnsProduced : byte + { + Year = 1, Month, Day, Hour, Minute, Second, AmPm, Hour12, DayOfWeek, DayOfQuarter, DayOfYear, + WeekOfMonth, QuarterOfYear, HalfOfYear, WeekIso, YearIso, MonthLabel, AmPmLabel, DayOfWeekLabel, + HolidayName, IsPaidTimeOff + }; + + public enum Countries : byte + { + None = 1, + Argentina, Australia, Austria, Belarus, Belgium, Brazil, Canada, Colombia, Croatia, Czech, Denmark, + England, Finland, France, Germany, Hungary, India, Ireland, IsleofMan, Italy, Japan, Mexico, Netherlands, + NewZealand, NorthernIreland, Norway, Poland, Portugal, Scotland, Slovenia, SouthAfrica, Spain, Sweden, Switzerland, + Ukraine, UnitedKingdom, UnitedStates, Wales + } + + #endregion + } + + public sealed class DateTimeTransformer : RowToRowTransformerBase, IDisposable + { + #region Class data members + + internal const string Summary = "Splits a date time value into each individual component"; + internal const string UserName = "DateTime Transform"; + internal const string ShortName = "DateTimeTransform"; + internal const string LoadName = "DateTimeTransform"; + internal const string LoaderSignature = "DateTimeTransform"; + private 
DateTimeTypedColumn _column; + + private DateTimeTransformerEstimator.ColumnsProduced[] _columnsToDrop; + private byte[] _activeColumnMapping; + + #endregion + + public DateTimeTransformer(IHostEnvironment env, string inputColumnName, string columnPrefix, DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop, DateTimeTransformerEstimator.Countries country ) : + base(env.Register(nameof(DateTimeTransformer))) + { + + _columnsToDrop = columnsToDrop; + var activeColumnLength = Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced)).Length - (_columnsToDrop == null ? 0 : _columnsToDrop.Length); + _activeColumnMapping = new byte[activeColumnLength]; + var index = 0; + foreach(DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced))) + { + if (_columnsToDrop == null || !_columnsToDrop.Contains(column)) + { + _activeColumnMapping[index++] = (byte)column; + } + } + + _column = new DateTimeTypedColumn(inputColumnName, columnPrefix); + _column.CreateTransformerFromEstimator(country); + } + + // Factory method for SignatureLoadModel. 
+ internal DateTimeTransformer(IHostEnvironment host, ModelLoadContext ctx) : + base(host.Register(nameof(DateTimeTransformer))) + { + + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + // *** Binary format *** + // name of input column + // column prefix + // byte length of columns to drop array + // byte array of columns to drop + // length of C++ state array + // C++ byte state array + + _column = new DateTimeTypedColumn(ctx.Reader.ReadString(), ctx.Reader.ReadString()); + + var dropColumnsLength = ctx.Reader.ReadInt32(); + if (dropColumnsLength > 0) + { + _columnsToDrop = new DateTimeTransformerEstimator.ColumnsProduced[dropColumnsLength]; + //read in enum bytes + for (int i = 0; i < dropColumnsLength; i++) + _columnsToDrop[i] = (DateTimeTransformerEstimator.ColumnsProduced)ctx.Reader.ReadByte(); + } + + _activeColumnMapping = new byte[Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced)).Length - dropColumnsLength]; + var index = 0; + foreach (DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced))) + { + if (_columnsToDrop == null || !_columnsToDrop.Contains(column)) + { + _activeColumnMapping[index++] = (byte)column; + } + } + + var dataLength = ctx.Reader.ReadInt32(); + var data = ctx.Reader.ReadByteArray(dataLength); + _column.CreateTransformerFromSavedData(data); + } + + // Factory method for SignatureLoadRowMapper. 
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) + => new DateTimeTransformer(env, ctx).MakeRowMapper(inputSchema); + + private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema); + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "DATETI T", + verWrittenCur: 0x00010001, + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(DateTimeTransformer).Assembly.FullName); + } + + private protected override void SaveModel(ModelSaveContext ctx) + { + + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // name of input column + // column prefix + // byte length of columns to drop array + // byte array of columns to drop + // length of C++ state array + // C++ byte state array + + ctx.Writer.Write(_column.Source); + ctx.Writer.Write(_column.Prefix); + + ctx.Writer.Write(_columnsToDrop == null ? 0 : _columnsToDrop.Length); + if (_columnsToDrop != null) + { + foreach (var toDrop in _columnsToDrop) + ctx.Writer.Write((byte)toDrop); + } + + var data = _column.CreateTransformerSaveData(); + ctx.Writer.Write(data.Length); + ctx.Writer.Write(data); + } + + public void Dispose() + { + _column.Dispose(); + } + + #region C++ Safe handle classes + + internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + private readonly DestroyTransformedDataNative _destroyTransformedDataHandler; + + public TransformedDataSafeHandle(IntPtr handle, DestroyTransformedDataNative destroyTransformedDataHandler) : base(true) + { + SetHandle(handle); + _destroyTransformedDataHandler = destroyTransformedDataHandler; + } + + protected override bool ReleaseHandle() + { + // Not sure what to do with error stuff here. There shoudln't ever be one though. 
+ return _destroyTransformedDataHandler(handle, out IntPtr errorHandle); + } + } + + #endregion + + #region TimePoint + + [StructLayoutAttribute(LayoutKind.Sequential)] + internal struct TimePoint + { + public int Year; + public byte Month; + public byte Day; + public byte Hour; + public byte Minute; + public byte Second; + public byte AmPm; + public byte Hour12; + public byte DayOfWeek; + public byte DayOfQuarter; + public ushort DayOfYear; + public ushort WeekOfMonth; + public byte QuarterOfYear; + public byte HalfOfYear; + public byte WeekIso; + public int YearIso; + public string MonthLabel; + public string AmPmLabel; + public string DayOfWeekLabel; + public string HolidayName; + public byte IsPaidTimeOff; + + internal unsafe TimePoint(byte* rawData) + { + int intPtrSize = sizeof(IntPtr); + + Year = *(int*)rawData; + rawData += 4; + + Month = *rawData++; + Day = *rawData++; + Hour = *rawData++; + Minute = *rawData++; + Second = *rawData++; + AmPm = *rawData++; + Hour12 = *rawData++; + DayOfWeek = *rawData++; + DayOfQuarter = *rawData++; + DayOfYear = *(ushort*)rawData; + rawData += 2; + + WeekOfMonth = *(ushort*)rawData; + rawData += 2; + + QuarterOfYear = *rawData++; + HalfOfYear = *rawData++; + WeekIso = *rawData++; + YearIso = *(int*)rawData; + rawData += 4; + + // Convert char * to string + MonthLabel = GetStringFromPointer(ref rawData, intPtrSize); + AmPmLabel = GetStringFromPointer(ref rawData, intPtrSize); + DayOfWeekLabel = GetStringFromPointer(ref rawData, intPtrSize); + HolidayName = GetStringFromPointer(ref rawData, intPtrSize); + IsPaidTimeOff = *rawData; + } + + // Converts a pointer to a native char* to a string and increments pointer by to the next value. + // The length of the string is stored at byte* + sizeof(IntPtr). 
+ private static unsafe string GetStringFromPointer(ref byte* rawData, int intPtrSize) + { + byte[] buffer; + if (intPtrSize == 4) // 32 bit machine + buffer = new byte[*(uint*)(rawData + intPtrSize)]; + else // 64 bit machine + buffer = new byte[*(ulong*)(rawData + intPtrSize)]; + + if (buffer.Length == 0) + { + rawData += intPtrSize * 2; + return string.Empty; + } + + Marshal.Copy(new IntPtr(*(int**)rawData), buffer, 0, buffer.Length); + rawData += intPtrSize * 2; + + return Encoding.UTF8.GetString(buffer); + } + + }; + + #endregion + + #region BaseClass + + internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle); + internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle); + internal delegate bool DestroyTransformedDataNative(IntPtr output, out IntPtr errorHandle); + + internal abstract class TypedColumn : IDisposable + { + internal readonly string Source; + internal readonly string Prefix; + + internal TypedColumn(string source, string prefix) + { + Source = source; + Prefix = prefix; + } + + internal abstract void CreateTransformerFromEstimator(DateTimeTransformerEstimator.Countries country); + private protected abstract unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize); + private protected unsafe abstract bool CreateEstimatorHelper(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle); + private protected abstract bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + private protected abstract bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle); + private protected abstract bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle); + private protected abstract bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle); + public abstract void Dispose(); + 
+ private protected unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase(DateTimeTransformerEstimator.Countries country) + { + bool success; + IntPtr errorHandle; + IntPtr estimator; + if (country == DateTimeTransformerEstimator.Countries.None) + { + success = CreateEstimatorHelper(null, null, out estimator, out errorHandle); + } + else + { + fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) + char.MinValue)) + fixed (byte* countryPointer = Encoding.UTF8.GetBytes(Enum.GetName(typeof(DateTimeTransformerEstimator.Countries), country) + char.MinValue)) + { + success = CreateEstimatorHelper(countryPointer, dataRootDir, out estimator, out errorHandle); + } + } + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorHelper)) + { + + success = CreateTransformerFromEstimatorHelper(estimatorHandler, out IntPtr transformer, out errorHandle); + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerHelper); + } + } + + internal byte[] CreateTransformerSaveData() + { + + var success = CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize)) + { + byte[] savedData = new byte[bufferSize.ToInt32()]; + Marshal.Copy(buffer, savedData, 0, savedData.Length); + return savedData; + } + } + + internal unsafe void CreateTransformerFromSavedData(byte[] data) + { + fixed (byte* rawData = data) + { + IntPtr dataSize = new IntPtr(data.Count()); + CreateTransformerFromSavedDataHelper(rawData, dataSize); + } + } + } + + internal abstract class TypedColumn : 
TypedColumn + { + internal TypedColumn(string source, string prefix) : + base(source, prefix) + { + } + + internal abstract TimePoint Transform(T input); + + } + + #endregion + + #region DateTimeTypedColumn + + internal sealed class DateTimeTypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + internal DateTimeTypedColumn(string source, string prefix) : + base(source, prefix) + { + } + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateEstimator"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateEstimatorNative(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyTransformer"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + internal override unsafe void CreateTransformerFromEstimator(DateTimeTransformerEstimator.Countries country) + { + _transformerHandler = CreateTransformerFromEstimatorBase(country); + } + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerFromSavedDataWithDataRoot"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, byte* dataRootDir, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void 
CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) + char.MinValue)) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, dataRootDir, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + } + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_Transform"), SuppressUnmanagedCodeSecurity] + private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, long input, out IntPtr output, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyTransformedData"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, out IntPtr errorHandle); + internal override TimePoint Transform(long input) + { + var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + using (var handler = new TransformedDataSafeHandle(output, DestroyTransformedDataNative)) + { + unsafe + { + return new TimePoint((byte*)output.ToPointer()); + } + } + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected unsafe override bool CreateEstimatorHelper(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(countryName, dataRootDir, out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + 
CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + private sealed class Mapper : MapperBase + { + + #region Class data members + + private readonly DateTimeTransformer _parent; + private ConcurrentDictionary _cache; + private ConcurrentQueue _oldestKeys; + + #endregion + + public Mapper(DateTimeTransformer parent, DataViewSchema inputSchema) : + base(parent.Host.Register(nameof(Mapper)), inputSchema, parent) + { + _parent = parent; + _cache = new ConcurrentDictionary(); + _oldestKeys = new ConcurrentQueue(); + } + + protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() + { + var columns = new List(); + + foreach (DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced))) + if (_parent._columnsToDrop == null || !_parent._columnsToDrop.Contains(column)) + { + columns.Add(new DataViewSchema.DetachedColumn(_parent._column.Prefix + column.ToString(), + ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType()))); + } + + return 
columns.ToArray(); + } + + private Delegate MakeGetter(DataViewRow input, int iinfo) + { + ValueGetter result = (ref T dst) => + { + long dateTime = default; + var getter = input.GetGetter(input.Schema[_parent._column.Source]); + getter(ref dateTime); + + if (!_cache.TryGetValue(dateTime, out TimePoint timePoint)){ + _cache[dateTime] = _parent._column.Transform(dateTime); + _oldestKeys.Enqueue(dateTime); + timePoint = _cache[dateTime]; + + // If more than 100 cached items, remove 20 + if(_cache.Count > 100) + { + for(int i = 0; i < 20; i++) + { + long key; + while (!_oldestKeys.TryDequeue(out key)) { } + while (!_cache.TryRemove(key, out TimePoint removedValue)) { } + } + } + } + + if (iinfo == 0) + dst = (T)Convert.ChangeType(timePoint.Year, typeof(T)); + else if (iinfo == 1) + dst = (T)Convert.ChangeType(timePoint.Month, typeof(T)); + else if (iinfo == 2) + dst = (T)Convert.ChangeType(timePoint.Day, typeof(T)); + else if (iinfo == 3) + dst = (T)Convert.ChangeType(timePoint.Hour, typeof(T)); + else if (iinfo == 4) + dst = (T)Convert.ChangeType(timePoint.Minute, typeof(T)); + else if (iinfo == 5) + dst = (T)Convert.ChangeType(timePoint.Second, typeof(T)); + else if (iinfo == 6) + dst = (T)Convert.ChangeType(timePoint.AmPm, typeof(T)); + else if (iinfo == 7) + dst = (T)Convert.ChangeType(timePoint.Hour12, typeof(T)); + else if (iinfo == 8) + dst = (T)Convert.ChangeType(timePoint.DayOfWeek, typeof(T)); + else if (iinfo == 9) + dst = (T)Convert.ChangeType(timePoint.DayOfQuarter, typeof(T)); + else if (iinfo == 10) + dst = (T)Convert.ChangeType(timePoint.DayOfYear, typeof(T)); + else if (iinfo == 11) + dst = (T)Convert.ChangeType(timePoint.WeekOfMonth, typeof(T)); + else if (iinfo == 12) + dst = (T)Convert.ChangeType(timePoint.QuarterOfYear, typeof(T)); + else if (iinfo == 13) + dst = (T)Convert.ChangeType(timePoint.HalfOfYear, typeof(T)); + else if (iinfo == 14) + dst = (T)Convert.ChangeType(timePoint.WeekIso, typeof(T)); + else if (iinfo == 15) + dst = 
(T)Convert.ChangeType(timePoint.YearIso, typeof(T)); + else if (iinfo == 16) + dst = (T)Convert.ChangeType(timePoint.MonthLabel.AsMemory(), typeof(T)); + else if (iinfo == 17) + dst = (T)Convert.ChangeType(timePoint.AmPmLabel.AsMemory(), typeof(T)); + else if (iinfo == 18) + dst = (T)Convert.ChangeType(timePoint.DayOfWeekLabel.AsMemory(), typeof(T)); + else if (iinfo == 19) + dst = (T)Convert.ChangeType(timePoint.HolidayName.AsMemory(), typeof(T)); + else + dst = (T)Convert.ChangeType(timePoint.IsPaidTimeOff, typeof(T)); + }; + + return result; + } + + protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) + { + disposer = null; + + var outputColumn = (int)_parent._activeColumnMapping[iinfo]; + + // Have to subtract 1 from the output column since the enum starts and 1 and not 0. + return Utils.MarshalInvoke(MakeGetter, ((DateTimeTransformerEstimator.ColumnsProduced)outputColumn).GetRawColumnType(), input, outputColumn - 1); + } + + private protected override Func GetDependenciesCore(Func activeOutput) + { + var active = new bool[InputSchema.Count]; + for (int i = 0; i < InputSchema.Count; i++) + { + if (InputSchema[i].Name.Equals(_parent._column.Source)) + { + active[i] = true; + } + } + + return col => active[col]; + } + + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); + } + } + + internal static class DateTimeTransformerEntrypoint + { + [TlcModule.EntryPoint(Name = "Transforms.DateTimeSplitter", + Desc = DateTimeTransformer.Summary, + UserName = DateTimeTransformer.UserName, + ShortName = DateTimeTransformer.ShortName)] + public static CommonOutputs.TransformOutput DateTimeSplit(IHostEnvironment env, DateTimeTransformerEstimator.Options input) + { + var h = EntryPointUtils.CheckArgsAndCreateHost(env, DateTimeTransformer.ShortName, input); + var xf = new DateTimeTransformerEstimator(h, input).Fit(input.Data).Transform(input.Data); + return new 
CommonOutputs.TransformOutput() + { + Model = new TransformModelImpl(h, xf, input.Data), + OutputData = xf + }; + } + } +} diff --git a/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj b/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj new file mode 100644 index 0000000000..f3a35c5d85 --- /dev/null +++ b/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj @@ -0,0 +1,18 @@ + + + + netstandard2.0 + Microsoft.ML.Featurizers + true + + + + + + + + + + + + diff --git a/src/Microsoft.ML.Featurizers/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Featurizers/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..f7c9d934e5 --- /dev/null +++ b/src/Microsoft.ML.Featurizers/Properties/AssemblyInfo.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; +using Microsoft.ML; + +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.EntryPoints" + PublicKey.Value)] + +[assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.Featurizers/RobustScaler.cs b/src/Microsoft.ML.Featurizers/RobustScaler.cs new file mode 100644 index 0000000000..739a25c2b2 --- /dev/null +++ b/src/Microsoft.ML.Featurizers/RobustScaler.cs @@ -0,0 +1,1616 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Linq; +using System.Runtime.InteropServices; +using System.Security; +using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Featurizers; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms; +using static Microsoft.ML.Featurizers.CommonExtensions; + +[assembly: LoadableClass(typeof(RobustScalerTransformer), null, typeof(SignatureLoadModel), + RobustScalerTransformer.UserName, RobustScalerTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(RobustScalerTransformer), null, typeof(SignatureLoadRowMapper), + RobustScalerTransformer.UserName, RobustScalerTransformer.LoaderSignature)] + +[assembly: EntryPointModule(typeof(RobustScalerEntrypoint))] + +namespace Microsoft.ML.Featurizers +{ + public static class RobustScalerExtensionClass + { + public static RobustScalerEstimator RobustScalerTransformer(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null, bool center = true, bool scale = true, float quantileMin = 25.0f, float quantileMax = 75.0f) + { + var options = new RobustScalerEstimator.Options + { + Columns = new RobustScalerEstimator.Column[1] { new RobustScalerEstimator.Column() { Name = outputColumnName, Source = inputColumnName ?? 
outputColumnName } }, + Center = center, + Scale = scale, + QuantileMin = quantileMin, + QuantileMax = quantileMax + }; + + return new RobustScalerEstimator(CatalogUtils.GetEnvironment(catalog), options); + } + + public static RobustScalerEstimator RobustScalerTransformer(this TransformsCatalog catalog, InputOutputColumnPair[] columns, bool center = true, bool scale = true, float quantileMin = 25.0f, float quantileMax = 75.0f) + { + var options = new RobustScalerEstimator.Options + { + Columns = columns.Select(x => new RobustScalerEstimator.Column { Name = x.OutputColumnName, Source = x.InputColumnName ?? x.OutputColumnName }).ToArray(), + Center = center, + Scale = scale, + QuantileMin = quantileMin, + QuantileMax = quantileMax + }; + + return new RobustScalerEstimator(CatalogUtils.GetEnvironment(catalog), options); + } + } + + /// + /// RobustScalar Featurizer scales features using statistics that are robust to outliers, by removing the median and scaling the data according to the quantile range + /// (defaults to IQR: Interquartile Range). Centering and scaling happen independently on each feature by computing the relevant statistics on the samples in the training set. + /// Median and interquartile range are then stored to be used on later data using the transform method. + /// + /// + /// is not a trivial estimator and needs training. + /// + /// + /// ]]> + /// + /// + /// + /// + public sealed class RobustScalerEstimator : IEstimator + { + private Options _options; + + private readonly IHost _host; + + // For determining what the output type is. 
private static readonly Type[] _floatTypes = new Type[] { typeof(byte), typeof(sbyte), typeof(short), typeof(ushort), typeof(float) };

        #region Options

        /// <summary>
        /// One output/source column pair in command-line form.
        /// </summary>
        internal sealed class Column : OneToOneColumn
        {
            /// <summary>
            /// Parses <paramref name="str"/> into a column; returns <see langword="null"/> when parsing fails.
            /// </summary>
            internal static Column Parse(string str)
            {
                Contracts.AssertNonEmpty(str);

                var res = new Column();
                return res.TryParse(str) ? res : null;
            }

            /// <summary>
            /// Writes this column into <paramref name="sb"/>; returns whether that succeeded.
            /// </summary>
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);
                return TryUnparseCore(sb);
            }
        }

        /// <summary>
        /// Settings for <see cref="RobustScalerEstimator"/>.
        /// </summary>
        internal sealed class Options : TransformInputBase
        {
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition (optional form: name:src)",
                Name = "Column", ShortName = "col", SortOrder = 1)]
            public Column[] Columns;

            [Argument(ArgumentType.AtMostOnce, HelpText = "If True, center the data before scaling.",
                Name = "Center", ShortName = "ctr", SortOrder = 2)]
            public bool Center = true;

            [Argument(ArgumentType.AtMostOnce, HelpText = "If True, scale the data to interquartile range.",
                Name = "Scale", ShortName = "sc", SortOrder = 3)]
            public bool Scale = true;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Min for the quantile range used to calculate scale.",
                Name = "QuantileMin", ShortName = "min", SortOrder = 4)]
            public float QuantileMin = 25.0f;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Max for the quantile range used to calculate scale.",
                Name = "QuantileMax", ShortName = "max", SortOrder = 5)]
            public float QuantileMax = 75.0f;
        }

        #endregion

        /// <summary>
        /// Validates <paramref name="options"/> and stores them for <see cref="Fit"/>.
        /// </summary>
        /// <param name="env">Host environment used to register this estimator.</param>
        /// <param name="options">Columns and scaling settings. The quantile range must satisfy
        /// 0 &lt;= QuantileMin &lt; QuantileMax &lt;= 100 and at least one column is required.</param>
        internal RobustScalerEstimator(IHostEnvironment env, Options options)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(RobustScalerEstimator));

            // Check for null before the range check below dereferences options.
            Contracts.CheckValue(options, nameof(options));
            Contracts.Check(options.QuantileMin >= 0.0f && options.QuantileMin < options.QuantileMax && options.QuantileMax <= 100.0f, "Invalid QuantileRange provided");
            Contracts.CheckNonEmpty(options.Columns, nameof(options.Columns));

            _options = options;
        }

        /// <summary>
        /// Trains a <see cref="RobustScalerTransformer"/> on <paramref name="input"/>.
        /// </summary>
        public RobustScalerTransformer Fit(IDataView input)
        {
            return new RobustScalerTransformer(_host, input, _options);
        }

        /// <summary>
        /// Computes the output schema: each configured output column keeps the kind, key flag and
        /// annotations of its source column, with item type float for the small numeric types listed
        /// in <c>_floatTypes</c> and double for every other supported type.
        /// </summary>
        /// <param name="inputSchema">Schema of the data the estimator will be fitted on.</param>
        /// <returns>The input schema with the output columns added or replaced.</returns>
        /// <exception cref="InvalidOperationException">A source column is missing or has an unsupported type.</exception>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Contracts.CheckValue(inputSchema, nameof(inputSchema));
            var columns = inputSchema.ToDictionary(x => x.Name);

            foreach (var column in _options.Columns)
            {
                // The indexer would throw a bare KeyNotFoundException with no column name in it.
                if (!columns.TryGetValue(column.Source, out var inputColumn))
                    throw new InvalidOperationException($"Input column {column.Source} was not found in the input schema.");

                var rawType = inputColumn.ItemType.RawType;
                if (!RobustScalerTransformer.TypedColumn.IsColumnTypeSupported(rawType))
                    throw new InvalidOperationException($"Type {rawType.ToString()} for column {column.Name} not a supported type.");

                var outputType = _floatTypes.Contains(rawType) ? typeof(float) : typeof(double);
                columns[column.Name] = new SchemaShape.Column(column.Name, inputColumn.Kind,
                    ColumnTypeExtensions.PrimitiveTypeFromType(outputType), inputColumn.IsKey, inputColumn.Annotations);
            }

            return new SchemaShape(columns.Values);
        }
    }

    /// <summary>
    /// Transformer produced by <see cref="RobustScalerEstimator"/>. It holds one typed column wrapper
    /// per configured column; each wrapper owns a native transformer handle.
    /// </summary>
    public sealed class RobustScalerTransformer : RowToRowTransformerBase, IDisposable
    {
        #region Class data members

        internal const string Summary = "Removes the median and scales the data according to the quantile range.";
        internal const string UserName = "RobustScalerTransformer";
        internal const string ShortName = "RobustScalerTransformer";
        internal const string LoadName = "RobustScalerTransformer";
        internal const string LoaderSignature = "RobustScalerTransformer";

        private readonly TypedColumn[] _columns;

        // Only set when trained from an estimator; stays null for a transformer loaded from a model.
        private readonly RobustScalerEstimator.Options _options;

        #endregion

        /// <summary>
        /// Trains a native transformer for every configured column using the data in <paramref name="input"/>.
        /// </summary>
        internal RobustScalerTransformer(IHostEnvironment host, IDataView input, RobustScalerEstimator.Options options) :
            base(host.Register(nameof(RobustScalerTransformer)))
        {
            Host.CheckValue(input, nameof(input));
            Host.CheckValue(options, nameof(options));

            var schema = input.Schema;
            _options = options;

            _columns = options.Columns
                .Select(x => TypedColumn.CreateTypedColumn(x.Name, x.Source, schema[x.Source].Type.RawType.ToString(), this))
                .ToArray();

            foreach (var column in _columns)
            {
                column.CreateTransformerFromEstimator(input);
            }
        }

        // Factory method for SignatureLoadModel.
        internal RobustScalerTransformer(IHostEnvironment host, ModelLoadContext ctx) :
            base(host.Register(nameof(RobustScalerTransformer)))
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            // *** Binary format ***
            // int number of column pairs
            // for each column pair:
            //      string output column name
            //      string input column name
            //      column type
            //      int length of c++ byte array
            //      byte array from c++

            var columnCount = ctx.Reader.ReadInt32();

            _columns = new TypedColumn[columnCount];
            for (int i = 0; i < columnCount; i++)
            {
                // Read into named locals so the order matches the binary format above.
                var name = ctx.Reader.ReadString();
                var source = ctx.Reader.ReadString();
                var type = ctx.Reader.ReadString();
                _columns[i] = TypedColumn.CreateTypedColumn(name, source, type, this);

                // Load the C++ state and create the C++ transformer.
                var dataLength = ctx.Reader.ReadInt32();
                var data = ctx.Reader.ReadByteArray(dataLength);
                _columns[i].CreateTransformerFromSavedData(data);
            }
        }

        // Factory method for SignatureLoadRowMapper.
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
            => new RobustScalerTransformer(env, ctx).MakeRowMapper(inputSchema);

        private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);

        /// <summary>
        /// Version stamp written to and checked against saved models.
        /// </summary>
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "RbScal T",
                verWrittenCur: 0x00010001,
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(RobustScalerTransformer).Assembly.FullName);
        }

        /// <summary>
        /// Writes the column definitions and each column's native transformer state to <paramref name="ctx"/>.
        /// </summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int number of column pairs
            // for each column pair:
            //      string output column name
            //      string input column name
            //      column type
            //      int length of c++ byte array
            //      byte array from c++

            ctx.Writer.Write(_columns.Length);
            foreach (var column in _columns)
            {
                ctx.Writer.Write(column.Name);
                ctx.Writer.Write(column.Source);
                ctx.Writer.Write(column.Type);

                // Save C++ state
                var data = column.CreateTransformerSaveData();
                ctx.Writer.Write(data.Length);
                ctx.Writer.Write(data);
            }
        }

        /// <summary>
        /// Disposes every column wrapper.
        /// </summary>
        public void Dispose()
        {
            foreach (var column in _columns)
            {
                column.Dispose();
            }
        }

        #region ColumnInfo

        #region BaseClass

        /// <summary>
        /// Type-erased base for a single column: stores the output name, source name and the
        /// source type name, and converts the native transformer state to and from a byte array.
        /// </summary>
        internal abstract class TypedColumn : IDisposable
        {
            internal readonly string Name;
            internal readonly string Source;
            internal readonly string Type;

            private static readonly Type[] _supportedTypes = new[] { typeof(sbyte), typeof(short), typeof(int), typeof(long), typeof(byte), typeof(ushort),
                typeof(uint), typeof(ulong), typeof(float), typeof(double) };

            internal TypedColumn(string name, string source, string type)
            {
                Name = name;
                Source = source;
                Type = type;
            }

            internal abstract void CreateTransformerFromEstimator(IDataView input);
            private protected abstract unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize);
            private protected abstract bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
            public abstract void Dispose();

            public abstract Type ReturnType();

            /// <summary>
            /// Asks the native transformer for its state and copies it into a managed array.
            /// </summary>
            /// <returns>A managed copy of the native buffer.</returns>
            internal byte[] CreateTransformerSaveData()
            {
                var success = CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                // The handle is never read; it is created so the buffer is handed to SaveDataSafeHandle
                // and disposed when this scope exits, including when the copy throws.
                using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize))
                {
                    byte[] savedData = new byte[bufferSize.ToInt32()];
                    Marshal.Copy(buffer, savedData, 0, savedData.Length);
                    return savedData;
                }
            }

            /// <summary>
            /// Pins <paramref name="data"/> and passes it to the type-specific helper that
            /// recreates the native transformer.
            /// </summary>
            internal unsafe void CreateTransformerFromSavedData(byte[] data)
            {
                fixed (byte* rawData = data)
                {
                    IntPtr dataSize = new IntPtr(data.Length);
                    CreateTransformerFromSavedDataHelper(rawData, dataSize);
                }
            }

            /// <summary>
            /// Returns whether <paramref name="type"/> is one of the supported numeric types.
            /// </summary>
            internal static bool IsColumnTypeSupported(Type type)
            {
                return _supportedTypes.Contains(type);
            }

            /// <summary>
            /// Creates the wrapper whose source type matches <paramref name="type"/>.
            /// </summary>
            /// <param name="name">Output column name.</param>
            /// <param name="source">Source column name.</param>
            /// <param name="type">Source type as produced by <c>Type.ToString()</c>; this is also the form stored in saved models.</param>
            /// <param name="parent">Transformer that owns the column.</param>
            /// <exception cref="InvalidOperationException"><paramref name="type"/> is not supported.</exception>
            internal static TypedColumn CreateTypedColumn(string name, string source, string type, RobustScalerTransformer parent)
            {
                if (type == typeof(sbyte).ToString())
                    return new Int8TypedColumn(name, source, parent);
                if (type == typeof(short).ToString())
                    return new Int16TypedColumn(name, source, parent);
                if (type == typeof(int).ToString())
                    return new Int32TypedColumn(name, source, parent);
                if (type == typeof(long).ToString())
                    return new Int64TypedColumn(name, source, parent);
                if (type == typeof(byte).ToString())
                    return new UInt8TypedColumn(name, source, parent);
                if (type == typeof(ushort).ToString())
                    return new UInt16TypedColumn(name, source, parent);
                if (type == typeof(uint).ToString())
                    return new UInt32TypedColumn(name, source, parent);
                if (type == typeof(ulong).ToString())
                    return new UInt64TypedColumn(name, source, parent);
                if (type == typeof(float).ToString())
                    return new FloatTypedColumn(name, source, parent);
                if (type == typeof(double).ToString())
                    return new DoubleTypedColumn(name, source, parent);

                throw new InvalidOperationException($"Column {name} has an unsupported type {type}.");
            }
        }

        /// <summary>
        /// Strongly typed base: declares the per-type native entry points and drives training.
        /// </summary>
        /// <typeparam name="TSourceType">Element type of the source column.</typeparam>
        /// <typeparam name="TOutputType">Element type produced by <see cref="Transform"/>.</typeparam>
        internal abstract class TypedColumn<TSourceType, TOutputType> : TypedColumn
        {
            internal TypedColumn(string name, string source, string type) :
                base(name, source, type)
            {
            }

            internal abstract TOutputType Transform(TSourceType input);
            private protected abstract bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
            private protected abstract bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
            private protected abstract bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle);
            private protected abstract bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle);
            private protected abstract bool FitHelper(TransformerEstimatorSafeHandle estimator, TSourceType input, out FitResult fitResult, out IntPtr errorHandle);
            private protected abstract bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
            private protected abstract bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle);

            // Creates a native estimator, feeds it the source column until training completes,
            // then returns a handle to the transformer built from it.
            private protected TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase(IDataView input)
            {
                var success = CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                using (var estimatorHandle = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorHelper))
                {
                    if (!IsTrainingComplete(estimatorHandle))
                    {
+ var fitResult = FitResult.Continue; + while (fitResult != FitResult.Complete) + { + fitResult = FitResult.Continue; + using (var data = input.GetColumn(Source).GetEnumerator()) + { + while (fitResult == FitResult.Continue && data.MoveNext()) + { + { + success = FitHelper(estimatorHandle, data.Current, out fitResult, out errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + } + + success = CompleteTrainingHelper(estimatorHandle, out fitResult, out errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + } + } + + success = CreateTransformerFromEstimatorHelper(estimatorHandle, out IntPtr transformer, out errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerHelper); + } + } + } + + #endregion + + #region Int8Column + + internal sealed class Int8TypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + private RobustScalerTransformer _parent; + internal Int8TypedColumn(string name, string source, RobustScalerTransformer parent) : + base(name, source, typeof(sbyte).ToString()) + { + _parent = parent; + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = 
"RobustScalarFeaturizer_int8_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + _transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, sbyte input, out float output, out IntPtr errorHandle); + internal unsafe override float Transform(sbyte input) + { + var success = TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle); + if (!success) + throw new 
Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return output; + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) + { + if (_parent._options.Scale) + return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle); + else + return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle); + } + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, sbyte input, out FitResult fitResult, out IntPtr errorHandle); + private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, sbyte input, out FitResult fitResult, out IntPtr errorHandle) + { + return FitNative(estimator, input, out fitResult, out errorHandle); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out 
IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle); + private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle) + { + var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return isTrainingComplete; + } + + public override Type ReturnType() + { + return typeof(float); + } + } + + #endregion + + #region UInt8Column + + internal sealed class UInt8TypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + private RobustScalerTransformer _parent; + internal UInt8TypedColumn(string name, string source, RobustScalerTransformer parent) : + base(name, source, typeof(byte).ToString()) + { + _parent = parent; + } + + [DllImport("Featurizers", EntryPoint = 
"RobustScalarFeaturizer_uint8_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + _transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new 
Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, byte input, out float output, out IntPtr errorHandle); + internal unsafe override float Transform(byte input) + { + var success = TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return output; + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) + { + if (_parent._options.Scale) + return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle); + else + return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle); + } + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_Fit", CallingConvention = CallingConvention.Cdecl), 
SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, byte input, out FitResult fitResult, out IntPtr errorHandle); + private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, byte input, out FitResult fitResult, out IntPtr errorHandle) + { + return FitNative(estimator, input, out fitResult, out errorHandle); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle); + private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle) + { + var success = 
IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return isTrainingComplete; + } + + public override Type ReturnType() + { + return typeof(float); + } + } + + #endregion + + #region Int16Column + + internal sealed class Int16TypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + private RobustScalerTransformer _parent; + internal Int16TypedColumn(string name, string source, RobustScalerTransformer parent) : + base(name, source, typeof(short).ToString()) + { + _parent = parent; + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + 
_transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, short input, out float output, out IntPtr errorHandle); + internal unsafe override float Transform(short input) + { + var success = TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return output; + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) + { + if (_parent._options.Scale) + return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle); + else + return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle); + } + + private protected override bool 
CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, short input, out FitResult fitResult, out IntPtr errorHandle); + private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, short input, out FitResult fitResult, out IntPtr errorHandle) + { + return FitNative(estimator, input, out fitResult, out errorHandle); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool 
CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
            private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
            {
                var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return isTrainingComplete;
            }

            public override Type ReturnType()
            {
                return typeof(float);
            }
        }

        #endregion

        #region UInt16Column

        /// <summary>
        /// Typed column for <see cref="ushort"/> input. Every operation is forwarded to the
        /// <c>RobustScalarFeaturizer_uint16_t_*</c> exports of the native "Featurizers" library;
        /// transformed values are returned as <see cref="float"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): all exports are declared with a C# <c>bool</c> return (and some with
        /// <c>bool</c> / <c>out bool</c> parameters) under default marshalling, i.e. a 4-byte BOOL.
        /// If the native side uses a 1-byte C++ <c>bool</c> these declarations need
        /// <c>MarshalAs(UnmanagedType.I1)</c> - confirm against the native header.
        /// </remarks>
        internal sealed class UInt16TypedColumn : TypedColumn
        {
            // Native transformer handle. Has no initializer: it stays null until
            // CreateTransformerFromEstimator or CreateTransformerFromSavedDataHelper assigns it.
            private TransformerEstimatorSafeHandle _transformerHandler;

            // Owning transformer; only read for its _options when the estimator is created.
            private readonly RobustScalerTransformer _parent;

            internal UInt16TypedColumn(string name, string source, RobustScalerTransformer parent) :
                base(name, source, typeof(ushort).ToString())
            {
                _parent = parent;
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Stores the transformer handle that the base-class helper returns for <paramref name="input"/>.
            /// </summary>
            internal override void CreateTransformerFromEstimator(IDataView input)
            {
                _transformerHandler = CreateTransformerFromEstimatorBase(input);
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Creates the native transformer from a serialized buffer and wraps the returned pointer
            /// in a safe handle that is given <c>DestroyTransformerNative</c> as its release callback.
            /// </summary>
            /// <exception cref="Exception">The native call returned <c>false</c>.</exception>
            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
            {
                var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
                if (!result)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ushort input, out float output, out IntPtr errorHandle);

            /// <summary>Passes one value through the native transformer and returns its output.</summary>
            /// <exception cref="Exception">The native call returned <c>false</c>.</exception>
            internal override float Transform(ushort input)
            {
                var success = TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return output;
            }

            /// <summary>
            /// Releases the native transformer handle, if one was ever created.
            /// </summary>
            public override void Dispose()
            {
                // The handle is still null when the column is disposed before a transformer was
                // created or loaded; guard so that Dispose never throws a NullReferenceException.
                if (_transformerHandler != null && !_transformerHandler.IsClosed)
                    _transformerHandler.Dispose();
            }

            /// <summary>
            /// Creates the native estimator from the parent's options. When scaling is disabled both
            /// quantile arguments are passed as -1 (presumably the native "no scaling" sentinel - confirm).
            /// </summary>
            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
            {
                if (_parent._options.Scale)
                    return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
                else
                    return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
            }

            // The helpers below only bind the base-class hooks to this type's native exports.
            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
                DestroyEstimatorNative(estimator, out errorHandle);

            private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
                DestroyTransformerNative(transformer, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool FitNative(TransformerEstimatorSafeHandle estimator, ushort input, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool FitHelper(TransformerEstimatorSafeHandle estimator, ushort input, out FitResult fitResult, out IntPtr errorHandle)
            {
                return FitNative(estimator, input, out fitResult, out errorHandle);
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
                CompleteTrainingNative(estimator, out fitResult, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);

            /// <summary>Asks the native estimator whether training has finished.</summary>
            /// <exception cref="Exception">The native call returned <c>false</c>.</exception>
            private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
            {
                var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return isTrainingComplete;
            }

            /// <summary>Type of the values produced by <c>Transform</c>.</summary>
            public override Type ReturnType()
            {
                return typeof(float);
            }
        }

        #endregion

        #region Int32Column

        internal sealed class Int32TypedColumn : TypedColumn
        {
            private TransformerEstimatorSafeHandle _transformerHandler;
            private RobustScalerTransformer _parent;
            internal Int32TypedColumn(string name, string source, RobustScalerTransformer parent) :
                base(name, source, typeof(int).ToString())
            {
                _parent = parent;
            }

            [DllImport("Featurizers", EntryPoint =
"RobustScalarFeaturizer_int32_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + _transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new 
Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, int input, out double output, out IntPtr errorHandle); + internal unsafe override double Transform(int input) + { + var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return output; + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) + { + if (_parent._options.Scale) + return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle); + else + return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle); + } + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_Fit", CallingConvention = CallingConvention.Cdecl), 
SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, int input, out FitResult fitResult, out IntPtr errorHandle); + private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, int input, out FitResult fitResult, out IntPtr errorHandle) + { + return FitNative(estimator, input, out fitResult, out errorHandle); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle); + private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle) + { + var success = 
IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return isTrainingComplete;
            }

            public override Type ReturnType()
            {
                return typeof(double);
            }
        }

        #endregion

        #region UInt32Column

        /// <summary>
        /// Typed column for <see cref="uint"/> input. Every operation is forwarded to the
        /// <c>RobustScalarFeaturizer_uint32_t_*</c> exports of the native "Featurizers" library;
        /// transformed values are returned as <see cref="double"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): all exports are declared with a C# <c>bool</c> return (and some with
        /// <c>bool</c> / <c>out bool</c> parameters) under default marshalling, i.e. a 4-byte BOOL.
        /// If the native side uses a 1-byte C++ <c>bool</c> these declarations need
        /// <c>MarshalAs(UnmanagedType.I1)</c> - confirm against the native header.
        /// </remarks>
        internal sealed class UInt32TypedColumn : TypedColumn
        {
            // Native transformer handle. Has no initializer: it stays null until
            // CreateTransformerFromEstimator or CreateTransformerFromSavedDataHelper assigns it.
            private TransformerEstimatorSafeHandle _transformerHandler;

            // Owning transformer; only read for its _options when the estimator is created.
            private readonly RobustScalerTransformer _parent;

            internal UInt32TypedColumn(string name, string source, RobustScalerTransformer parent) :
                base(name, source, typeof(uint).ToString())
            {
                _parent = parent;
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Stores the transformer handle that the base-class helper returns for <paramref name="input"/>.
            /// </summary>
            internal override void CreateTransformerFromEstimator(IDataView input)
            {
                _transformerHandler = CreateTransformerFromEstimatorBase(input);
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Creates the native transformer from a serialized buffer and wraps the returned pointer
            /// in a safe handle that is given <c>DestroyTransformerNative</c> as its release callback.
            /// </summary>
            /// <exception cref="Exception">The native call returned <c>false</c>.</exception>
            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
            {
                var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
                if (!result)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, uint input, out double output, out IntPtr errorHandle);

            /// <summary>Passes one value through the native transformer and returns its output.</summary>
            /// <exception cref="Exception">The native call returned <c>false</c>.</exception>
            internal override double Transform(uint input)
            {
                var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return output;
            }

            /// <summary>
            /// Releases the native transformer handle, if one was ever created.
            /// </summary>
            public override void Dispose()
            {
                // The handle is still null when the column is disposed before a transformer was
                // created or loaded; guard so that Dispose never throws a NullReferenceException.
                if (_transformerHandler != null && !_transformerHandler.IsClosed)
                    _transformerHandler.Dispose();
            }

            /// <summary>
            /// Creates the native estimator from the parent's options. When scaling is disabled both
            /// quantile arguments are passed as -1 (presumably the native "no scaling" sentinel - confirm).
            /// </summary>
            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
            {
                if (_parent._options.Scale)
                    return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
                else
                    return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
            }

            // The helpers below only bind the base-class hooks to this type's native exports.
            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
                DestroyEstimatorNative(estimator, out errorHandle);

            private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
                DestroyTransformerNative(transformer, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool FitNative(TransformerEstimatorSafeHandle estimator, uint input, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool FitHelper(TransformerEstimatorSafeHandle estimator, uint input, out FitResult fitResult, out IntPtr errorHandle)
            {
                return FitNative(estimator, input, out fitResult, out errorHandle);
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
                CompleteTrainingNative(estimator, out fitResult, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);

            /// <summary>Asks the native estimator whether training has finished.</summary>
            /// <exception cref="Exception">The native call returned <c>false</c>.</exception>
            private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
            {
                var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return isTrainingComplete;
            }

            /// <summary>Type of the values produced by <c>Transform</c>.</summary>
            public override Type ReturnType()
            {
                return typeof(double);
            }
        }

        #endregion

        #region Int64Column

        internal sealed class Int64TypedColumn : TypedColumn
        {
            private TransformerEstimatorSafeHandle _transformerHandler;
            private RobustScalerTransformer _parent;
            internal Int64TypedColumn(string name, string source, RobustScalerTransformer parent) :
                base(name, source, typeof(long).ToString())
            {
                _parent = parent;
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

            [DllImport("Featurizers", EntryPoint =
"RobustScalarFeaturizer_int64_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + _transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, long input, out double output, out IntPtr errorHandle); + internal unsafe override double Transform(long input) + { + var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle); + if (!success) + throw new 
Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return output; + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) + { + if (_parent._options.Scale) + return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle); + else + return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle); + } + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, long input, out FitResult fitResult, out IntPtr errorHandle); + private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, long input, out FitResult fitResult, out IntPtr errorHandle) + { + return FitNative(estimator, input, out fitResult, out errorHandle); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out 
IntPtr errorHandle);
            private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
                CompleteTrainingNative(estimator, out fitResult, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
            private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
            {
                var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return isTrainingComplete;
            }

            public override Type ReturnType()
            {
                return typeof(double);
            }
        }

        #endregion

        #region UInt64Column

        /// <summary>
        /// Typed column for <see cref="ulong"/> input. Every operation is forwarded to the
        /// <c>RobustScalarFeaturizer_uint64_t_*</c> exports of the native "Featurizers" library;
        /// transformed values are returned as <see cref="double"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): all exports are declared with a C# <c>bool</c> return (and some with
        /// <c>bool</c> / <c>out bool</c> parameters) under default marshalling, i.e. a 4-byte BOOL.
        /// If the native side uses a 1-byte C++ <c>bool</c> these declarations need
        /// <c>MarshalAs(UnmanagedType.I1)</c> - confirm against the native header.
        /// </remarks>
        internal sealed class UInt64TypedColumn : TypedColumn
        {
            // Native transformer handle. Has no initializer: it stays null until
            // CreateTransformerFromEstimator or CreateTransformerFromSavedDataHelper assigns it.
            private TransformerEstimatorSafeHandle _transformerHandler;

            // Owning transformer; only read for its _options when the estimator is created.
            private readonly RobustScalerTransformer _parent;

            internal UInt64TypedColumn(string name, string source, RobustScalerTransformer parent) :
                base(name, source, typeof(ulong).ToString())
            {
                _parent = parent;
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Stores the transformer handle that the base-class helper returns for <paramref name="input"/>.
            /// </summary>
            internal override void CreateTransformerFromEstimator(IDataView input)
            {
                _transformerHandler = CreateTransformerFromEstimatorBase(input);
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

            /// <summary>
            /// Creates the native transformer from a serialized buffer and wraps the returned pointer
            /// in a safe handle that is given <c>DestroyTransformerNative</c> as its release callback.
            /// </summary>
            /// <exception cref="Exception">The native call returned <c>false</c>.</exception>
            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
            {
                var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
                if (!result)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ulong input, out double output, out IntPtr errorHandle);

            /// <summary>Passes one value through the native transformer and returns its output.</summary>
            /// <exception cref="Exception">The native call returned <c>false</c>.</exception>
            internal override double Transform(ulong input)
            {
                var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return output;
            }

            /// <summary>
            /// Releases the native transformer handle, if one was ever created.
            /// </summary>
            public override void Dispose()
            {
                // The handle is still null when the column is disposed before a transformer was
                // created or loaded; guard so that Dispose never throws a NullReferenceException.
                if (_transformerHandler != null && !_transformerHandler.IsClosed)
                    _transformerHandler.Dispose();
            }

            /// <summary>
            /// Creates the native estimator from the parent's options. When scaling is disabled both
            /// quantile arguments are passed as -1 (presumably the native "no scaling" sentinel - confirm).
            /// </summary>
            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
            {
                if (_parent._options.Scale)
                    return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
                else
                    return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
            }

            // The helpers below only bind the base-class hooks to this type's native exports.
            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
                DestroyEstimatorNative(estimator, out errorHandle);

            private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
                DestroyTransformerNative(transformer, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool FitNative(TransformerEstimatorSafeHandle estimator, ulong input, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool FitHelper(TransformerEstimatorSafeHandle estimator, ulong input, out FitResult fitResult, out IntPtr errorHandle)
            {
                return FitNative(estimator, input, out fitResult, out errorHandle);
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
            private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
                CompleteTrainingNative(estimator, out fitResult, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);

            /// <summary>Asks the native estimator whether training has finished.</summary>
            /// <exception cref="Exception">The native call returned <c>false</c>.</exception>
            private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
            {
                var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
                if (!success)
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

                return isTrainingComplete;
            }

            /// <summary>Type of the values produced by <c>Transform</c>.</summary>
            public override Type ReturnType()
            {
                return typeof(double);
            }
        }

        #endregion

        #region FloatColumn

        internal sealed class FloatTypedColumn : TypedColumn
        {
            private TransformerEstimatorSafeHandle _transformerHandler;
            private RobustScalerTransformer _parent;
            internal FloatTypedColumn(string name, string source, RobustScalerTransformer parent) :
                base(name, source, typeof(float).ToString())
            {
                _parent = parent;
            }

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
            [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
            internal override void CreateTransformerFromEstimator(IDataView input)
            {
_transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, float input, out float output, out IntPtr errorHandle); + internal unsafe override float Transform(float input) + { + var success = TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return output; + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) + { + if (_parent._options.Scale) + return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle); + else + return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle); + } + + private protected override bool 
CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, float input, out FitResult fitResult, out IntPtr errorHandle); + private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, float input, out FitResult fitResult, out IntPtr errorHandle) + { + return FitNative(estimator, input, out fitResult, out errorHandle); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool 
CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle); + private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle) + { + var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return isTrainingComplete; + } + + public override Type ReturnType() + { + return typeof(float); + } + } + + #endregion + + #region DoubleColumn + + internal sealed class DoubleTypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + private RobustScalerTransformer _parent; + internal DoubleTypedColumn(string name, string source, RobustScalerTransformer parent) : + base(name, source, typeof(double).ToString()) + { + _parent = parent; + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = 
"RobustScalarFeaturizer_double_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator(IDataView input) + { + _transformerHandler = CreateTransformerFromEstimatorBase(input); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, double input, out double output, out IntPtr errorHandle); + internal unsafe override double Transform(double input) + { + var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle); + if (!success) + 
throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return output; + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) + { + if (_parent._options.Scale) + return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle); + else + return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle); + } + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, double input, out FitResult fitResult, out IntPtr errorHandle); + private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, double input, out FitResult fitResult, out IntPtr errorHandle) + { + return FitNative(estimator, input, out fitResult, out errorHandle); + } + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out 
FitResult fitResult, out IntPtr errorHandle); + private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) => + CompleteTrainingNative(estimator, out fitResult, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle); + private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle) + { + var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return isTrainingComplete; + } + + public override Type ReturnType() + { + return typeof(double); + } + } + + #endregion + + #endregion // Column Info + + private sealed class Mapper : MapperBase + { + + #region Class data members + + private readonly RobustScalerTransformer _parent; + + #endregion + + public Mapper(RobustScalerTransformer parent, DataViewSchema inputSchema) : + base(parent.Host.Register(nameof(Mapper)), inputSchema, parent) + { + _parent = parent; + } + + protected override 
DataViewSchema.DetachedColumn[] GetOutputColumnsCore() + { + return _parent._columns.Select(x => new DataViewSchema.DetachedColumn(x.Name, ColumnTypeExtensions.PrimitiveTypeFromType(x.ReturnType()))).ToArray(); + } + + private Delegate MakeGetter(DataViewRow input, int iinfo) + { + ValueGetter result = (ref TOutputType dst) => + { + var inputColumn = input.Schema[_parent._columns[iinfo].Source]; + var srcGetterScalar = input.GetGetter(inputColumn); + + TSourceType value = default; + srcGetterScalar(ref value); + + dst = ((TypedColumn)_parent._columns[iinfo]).Transform(value); + + }; + + return result; + } + + protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) + { + disposer = null; + Type inputType = input.Schema[_parent._columns[iinfo].Source].Type.RawType; + Type outputType = _parent._columns[iinfo].ReturnType(); + + return Utils.MarshalInvoke(MakeGetter, new Type[] { inputType, outputType }, input, iinfo); + } + + private protected override Func GetDependenciesCore(Func activeOutput) + { + var active = new bool[InputSchema.Count]; + for (int i = 0; i < InputSchema.Count; i++) + { + if (_parent._columns.Any(x => x.Source == InputSchema[i].Name)) + { + active[i] = true; + } + } + + return col => active[col]; + } + + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); + } + } + + internal static class RobustScalerEntrypoint + { + [TlcModule.EntryPoint(Name = "Transforms.RobustScaler", + Desc = RobustScalerTransformer.Summary, + UserName = RobustScalerTransformer.UserName, + ShortName = RobustScalerTransformer.ShortName)] + public static CommonOutputs.TransformOutput RobustScaler(IHostEnvironment env, RobustScalerEstimator.Options input) + { + var h = EntryPointUtils.CheckArgsAndCreateHost(env, RobustScalerTransformer.ShortName, input); + var xf = new RobustScalerEstimator(h, input).Fit(input.Data).Transform(input.Data); + return new CommonOutputs.TransformOutput() + { + 
Model = new TransformModelImpl(h, xf, input.Data), + OutputData = xf + }; + } + } +} diff --git a/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs b/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs new file mode 100644 index 0000000000..6d899564ca --- /dev/null +++ b/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs @@ -0,0 +1,637 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Security; +using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Featurizers; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms; +using static Microsoft.ML.Featurizers.CommonExtensions; + +[assembly: LoadableClass(typeof(TimeSeriesImputerTransformer), null, typeof(SignatureLoadModel), + TimeSeriesImputerTransformer.UserName, TimeSeriesImputerTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(IDataTransform), typeof(TimeSeriesImputerTransformer), null, typeof(SignatureLoadDataTransform), + TimeSeriesImputerTransformer.UserName, TimeSeriesImputerTransformer.LoaderSignature)] + +[assembly: EntryPointModule(typeof(TimeSeriesTransformerEntrypoint))] + +namespace Microsoft.ML.Featurizers +{ + public static class TimeSeriesImputerExtensionClass + { + /// + /// Create a , Imputes missing rows and column data per grain. Operates on all columns in the IDataView. Currently only numeric columns are supported. + /// + /// The transform catalog. + /// Column representing the time series. Should be of type + /// List of columns to use as grains + /// Mode of imputation for missing values in column. 
If not passed defaults to forward fill + public static TimeSeriesImputerEstimator TimeSeriesImputer(this TransformsCatalog catalog, string timeSeriesColumn, string[] grainColumns, TimeSeriesImputerEstimator.ImputationStrategy imputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill) => + new TimeSeriesImputerEstimator(CatalogUtils.GetEnvironment(catalog), timeSeriesColumn, grainColumns, null, TimeSeriesImputerEstimator.FilterMode.NoFilter, imputeMode, false); + + /// + /// Create a , Imputes missing rows and column data per grain. Operates on a filtered list of columns in the IDataView. + /// If a column is not imputed but rows are added then it will be filled with the default value for that data type. Currently only numeric columns are supported for imputation. + /// + /// The transform catalog. + /// Column representing the time series. Should be of type + /// List of columns to use as grains + /// List of columns to filter. If is than columns in the list will be ignored. + /// If is than values in the list are the only columns imputed. + /// Whether the list should include or exclude those columns. + /// Mode of imputation for missing values in column. If not passed defaults to forward fill + /// Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column. If false, will stop and throw an error. 
+ public static TimeSeriesImputerEstimator TimeSeriesImputer(this TransformsCatalog catalog, string timeSeriesColumn, string[] grainColumns, string[] filterColumns, TimeSeriesImputerEstimator.FilterMode filterMode = TimeSeriesImputerEstimator.FilterMode.Exclude, TimeSeriesImputerEstimator.ImputationStrategy imputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill, bool suppressTypeErrors = false) => + new TimeSeriesImputerEstimator(CatalogUtils.GetEnvironment(catalog), timeSeriesColumn, grainColumns, filterColumns, filterMode, imputeMode, suppressTypeErrors); + } + + /// + /// Imputes missing rows and column data per grain, based on the dates in the date column. Operates on a filtered list of columns in the IDataView. + /// If a column is not imputed but rows are added then it will be filled with the default value for that data type. Currently only numeric columns are supported for imputation. + /// A new column is added to the schema after this operation is run. The column is called "IsRowImputed" and is a boolean value representing if the row was created as a result + /// of this transformer or not. + /// + /// + /// is not a trivial estimator and needs training. 
+ /// + /// + /// ]]> + /// + /// + /// + /// + public sealed class TimeSeriesImputerEstimator : IEstimator + { + private Options _options; + internal const string IsRowImputedColumnName = "IsRowImputed"; + + private readonly IHost _host; + private static readonly List _currentSupportedTypes = new List { typeof(sbyte), typeof(byte), typeof(short), typeof(ushort), typeof(int), typeof(uint), + typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(string), typeof(ReadOnlyMemory)}; + + #region Options + internal sealed class Options : TransformInputBase + { + [Argument(ArgumentType.Required, HelpText = "Column representing the time", Name = "TimeSeriesColumn", ShortName = "time", SortOrder = 1)] + public string TimeSeriesColumn; + + [Argument((ArgumentType.MultipleUnique | ArgumentType.Required), HelpText = "List of grain columns", Name = "GrainColumns", ShortName = "grains", SortOrder = 2)] + public string[] GrainColumns; + + // This transformer adds columns + [Argument(ArgumentType.MultipleUnique, HelpText = "Columns to filter", Name = "FilterColumns", ShortName = "filters", SortOrder = 2)] + public string[] FilterColumns; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Filter mode. Either include or exclude", Name = "FilterMode", ShortName = "fmode", SortOrder = 3)] + public FilterMode FilterMode = FilterMode.Exclude; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Mode for imputing, defaults to ForwardFill if not provided", Name = "ImputeMode", ShortName = "mode", SortOrder = 3)] + public ImputationStrategy ImputeMode = ImputationStrategy.ForwardFill; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column. 
If false, will stop and throw an error.", Name = "SupressTypeErrors", ShortName = "error", SortOrder = 3)] + public bool SupressTypeErrors = false; + } + + #endregion + + #region Class Enums + + public enum ImputationStrategy : byte + { + ForwardFill = 1, BackFill, Median, Interpolate + }; + + public enum FilterMode : byte + { + NoFilter = 1, Include, Exclude + }; + + #endregion + + internal TimeSeriesImputerEstimator(IHostEnvironment env, string timeSeriesColumn, string[] grainColumns, string[] filterColumns, FilterMode filterMode, ImputationStrategy imputeMode, bool supressTypeErrors) + { + Contracts.CheckValue(env, nameof(env)); + _host = Contracts.CheckRef(env, nameof(env)).Register("TimeSeriesImputerEstimator"); + _host.CheckValue(timeSeriesColumn, nameof(timeSeriesColumn), "TimePoint column should not be null."); + _host.CheckNonEmpty(grainColumns, nameof(grainColumns), "Need at least one grain column."); + if (filterMode == FilterMode.Include) + _host.CheckNonEmpty(filterColumns, nameof(filterColumns), "Need at least 1 filter column if a FilterMode is specified"); + + _options = new Options + { + TimeSeriesColumn = timeSeriesColumn, + GrainColumns = grainColumns, + FilterColumns = filterColumns == null ? 
new string[] { } : filterColumns, + FilterMode = filterMode, + ImputeMode = imputeMode, + SupressTypeErrors = supressTypeErrors + }; + } + + internal TimeSeriesImputerEstimator(IHostEnvironment env, Options options) + { + Contracts.CheckValue(env, nameof(env)); + _host = Contracts.CheckRef(env, nameof(env)).Register("TimeSeriesImputerEstimator"); + _host.CheckValue(options.TimeSeriesColumn, nameof(options.TimeSeriesColumn), "TimePoint column should not be null."); + _host.CheckValue(options.GrainColumns, nameof(options.GrainColumns), "Grain columns should not be null."); + _host.CheckNonEmpty(options.GrainColumns, nameof(options.GrainColumns), "Need at least one grain column."); + if (options.FilterMode != FilterMode.NoFilter) + _host.CheckNonEmpty(options.FilterColumns, nameof(options.FilterColumns), "Need at least 1 filter column if a FilterMode is specified"); + + _options = options; + } + + public TimeSeriesImputerTransformer Fit(IDataView input) + { + // If we are not suppressing type errors make sure columns to impute only contain supported types. + if (!_options.SupressTypeErrors) + { + var columns = input.Schema.Where(x => !_options.GrainColumns.Contains(x.Name)); + if (_options.FilterMode == FilterMode.Exclude) + columns = columns.Where(x => !_options.FilterColumns.Contains(x.Name)); + else if (_options.FilterMode == FilterMode.Include) + columns = columns.Where(x => _options.FilterColumns.Contains(x.Name)); + + foreach (var column in columns) + { + if (!_currentSupportedTypes.Contains(column.Type.RawType)) + throw new InvalidOperationException($"Type {column.Type.RawType.ToString()} for column {column.Name} not a supported type."); + } + } + + return new TimeSeriesImputerTransformer(_host, _options.TimeSeriesColumn, _options.GrainColumns, _options.FilterColumns, _options.FilterMode, _options.ImputeMode, _options.SupressTypeErrors, input); + } + + // Add one column called WasColumnImputed, otherwise everything stays the same. 
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + var columns = inputSchema.ToDictionary(x => x.Name); + columns[IsRowImputedColumnName] = new SchemaShape.Column(IsRowImputedColumnName, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false); + return new SchemaShape(columns.Values); + } + } + + public sealed class TimeSeriesImputerTransformer : ITransformer, IDisposable + { + #region Class data members + + internal const string Summary = "Fills in missing row and values"; + internal const string UserName = "TimeSeriesImputer"; + internal const string ShortName = "TimeSeriesImputer"; + internal const string LoadName = "TimeSeriesImputer"; + internal const string LoaderSignature = "TimeSeriesImputer"; + + private readonly IHost _host; + private readonly string _timeSeriesColumn; + private readonly string[] _grainColumns; + private readonly string[] _dataColumns; + private readonly string[] _allColumnNames; + private readonly bool _suppressTypeErrors; + private readonly TimeSeriesImputerEstimator.ImputationStrategy _imputeMode; + internal TransformerEstimatorSafeHandle TransformerHandle; + + #endregion + + // Normal constructor. 
/// <summary>
/// Builds a transformer by training the native imputer on <paramref name="input"/>.
/// Resolves which schema columns are data columns, fixes the column order that is
/// handed to the native code, and then creates the native transformer handle.
/// </summary>
/// <param name="host">Environment used to register this transformer's host.</param>
/// <param name="timeSeriesColumn">Name of the column holding the time value.</param>
/// <param name="grainColumns">Names of the grain columns.</param>
/// <param name="filterColumns">Column names consulted when <paramref name="filterMode"/> is Include or Exclude. Null is treated as an empty list.</param>
/// <param name="filterMode">Whether <paramref name="filterColumns"/> are the only data columns (Include), are removed from the data columns (Exclude), or are ignored (any other value).</param>
/// <param name="imputeMode">Imputation strategy forwarded to the native estimator.</param>
/// <param name="suppressTypeErrors">Flag forwarded to the native estimator.</param>
/// <param name="input">Data whose schema selects the data columns and which is passed to training.</param>
internal TimeSeriesImputerTransformer(IHostEnvironment host, string timeSeriesColumn, string[] grainColumns, string[] filterColumns, TimeSeriesImputerEstimator.FilterMode filterMode, TimeSeriesImputerEstimator.ImputationStrategy imputeMode, bool suppressTypeErrors, IDataView input)
{
    // Fail with a named argument instead of a NullReferenceException further down.
    Contracts.CheckValue(host, nameof(host));
    _host = host.Register(nameof(TimeSeriesImputerTransformer));
    _host.CheckValue(timeSeriesColumn, nameof(timeSeriesColumn));
    _host.CheckValue(grainColumns, nameof(grainColumns));
    _host.CheckValue(input, nameof(input));

    _timeSeriesColumn = timeSeriesColumn;
    _grainColumns = grainColumns;
    _imputeMode = imputeMode;
    _suppressTypeErrors = suppressTypeErrors;

    // Sets give O(1) membership tests per schema column. The default string comparer is
    // ordinal, which matches the array Contains calls this replaces.
    var filterSet = new HashSet<string>(filterColumns ?? Array.Empty<string>());
    var grainSet = new HashSet<string>(grainColumns);

    IEnumerable<string> candidateColumns;
    switch (filterMode)
    {
        case TimeSeriesImputerEstimator.FilterMode.Exclude:
            candidateColumns = input.Schema.Where(x => !filterSet.Contains(x.Name)).Select(x => x.Name);
            break;
        case TimeSeriesImputerEstimator.FilterMode.Include:
            candidateColumns = input.Schema.Where(x => filterSet.Contains(x.Name)).Select(x => x.Name);
            break;
        default:
            candidateColumns = input.Schema.Select(x => x.Name);
            break;
    }

    // Time series and grain columns should never be included in the data columns.
    _dataColumns = candidateColumns.Where(x => x != timeSeriesColumn && !grainSet.Contains(x)).ToArray();

    // Make one array of all the columns in the order: time series column,
    // all grain columns, all data columns. The leading 1 is for the time series column.
    _allColumnNames = new string[1 + _grainColumns.Length + _dataColumns.Length];
    _allColumnNames[0] = _timeSeriesColumn;
    Array.Copy(_grainColumns, 0, _allColumnNames, 1, _grainColumns.Length);
    Array.Copy(_dataColumns, 0, _allColumnNames, 1 + _grainColumns.Length, _dataColumns.Length);

    TransformerHandle = CreateTransformerFromEstimator(input);
}

// Factory method for SignatureLoadModel.
/// <summary>
/// Rebuilds a transformer from a saved model: reads the column layout and impute mode,
/// then recreates the native transformer from the serialized native state.
/// The suppress-type-errors flag is not part of the saved format, so it keeps its default value.
/// </summary>
/// <param name="host">Environment used to register this transformer's host.</param>
/// <param name="ctx">Load context positioned at a model written by <c>Save</c>.</param>
internal TimeSeriesImputerTransformer(IHostEnvironment host, ModelLoadContext ctx)
{
    Contracts.CheckValue(host, nameof(host));
    _host = host.Register(nameof(TimeSeriesImputerTransformer));
    _host.CheckValue(ctx, nameof(ctx));
    // Reject streams that were not written with this transformer's version info.
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // name of time series column
    // length of grain column array
    // all column names in grain column array
    // length of data column array
    // all column names in data column array
    // byte value of impute mode
    // length of C++ state array
    // C++ byte state array

    // Reads an int32 count followed by that many strings.
    string[] ReadColumnNames()
    {
        var names = new string[ctx.Reader.ReadInt32()];
        for (int i = 0; i < names.Length; i++)
            names[i] = ctx.Reader.ReadString();
        return names;
    }

    _timeSeriesColumn = ctx.Reader.ReadString();
    _grainColumns = ReadColumnNames();
    _dataColumns = ReadColumnNames();
    _imputeMode = (TimeSeriesImputerEstimator.ImputationStrategy)ctx.Reader.ReadByte();

    // Same ordering as the training constructor: time series column, grain columns, data columns.
    _allColumnNames = new string[1 + _grainColumns.Length + _dataColumns.Length];
    _allColumnNames[0] = _timeSeriesColumn;
    Array.Copy(_grainColumns, 0, _allColumnNames, 1, _grainColumns.Length);
    Array.Copy(_dataColumns, 0, _allColumnNames, 1 + _grainColumns.Length, _dataColumns.Length);

    var nativeData = ctx.Reader.ReadByteArray();
    TransformerHandle = CreateTransformerFromSavedData(nativeData);
}

/// <summary>
/// Creates a native transformer from its serialized state and wraps it in a safe handle
/// that destroys it through <c>DestroyTransformerNative</c>.
/// </summary>
/// <param name="data">Serialized native transformer state; must not be null.</param>
/// <returns>A safe handle owning the newly created native transformer.</returns>
/// <exception cref="Exception">The native call reported failure; the message carries the native error details.</exception>
private unsafe TransformerEstimatorSafeHandle CreateTransformerFromSavedData(byte[] data)
{
    // Name the offending argument rather than failing inside LINQ with "source".
    _host.CheckValue(data, nameof(data));

    fixed (byte* rawData = data)
    {
        // Array length is read directly; no LINQ enumeration of the buffer is needed.
        IntPtr dataSize = new IntPtr(data.Length);
        var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
        if (!result)
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

        return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
    }
}

// Factory method for SignatureLoadDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + { + return (IDataTransform)(new TimeSeriesImputerTransformer(env, ctx).Transform(input)); + } + + private unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimator(IDataView input) + { + IntPtr estimator; + IntPtr errorHandle; + bool success; + + var allColumns = input.Schema.Where(x => _allColumnNames.Contains(x.Name)).Select(x => TypedColumn.CreateTypedColumn(x, _dataColumns)).ToDictionary(x => x.Column.Name); + + // Create buffer to hold binary data + var columnBuffer = new byte[1024]; + + // Create TypeId[] for types of grain and data columns; + var dataColumnTypes = new TypeId[_dataColumns.Length]; + var grainColumnTypes = new TypeId[_grainColumns.Length]; + + foreach (var column in _grainColumns.Select((value, index) => new { index, value })) + grainColumnTypes[column.index] = allColumns[column.value].GetTypeId(); + + foreach (var column in _dataColumns.Select((value, index) => new { index, value })) + dataColumnTypes[column.index] = allColumns[column.value].GetTypeId(); + + fixed (bool* suppressErrors = &_suppressTypeErrors) + fixed (TypeId* rawDataColumnTypes = dataColumnTypes) + fixed (TypeId* rawGrainColumnTypes = grainColumnTypes) + { + success = CreateEstimatorNative(rawGrainColumnTypes, new IntPtr(grainColumnTypes.Length), rawDataColumnTypes, new IntPtr(dataColumnTypes.Length), _imputeMode, suppressErrors, out estimator, out errorHandle); + } + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorNative)) + { + var fitResult = FitResult.Continue; + while (fitResult != FitResult.Complete) + { + using (var cursor = input.GetRowCursorForAllColumns()) + { + // Initialize getters for start of loop + foreach (var column in allColumns.Values) + column.InitializeGetter(cursor); + + while ((fitResult == FitResult.Continue 
|| fitResult == FitResult.ResetAndContinue) && cursor.MoveNext()) + { + BuildColumnByteArray(allColumns, ref columnBuffer, out int bufferLength); + + fixed (byte* bufferPointer = columnBuffer) + { + var binaryArchiveData = new NativeBinaryArchiveData() { Data = bufferPointer, DataSize = new IntPtr(bufferLength) }; + success = FitNative(estimatorHandler, binaryArchiveData, out fitResult, out errorHandle); + } + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + success = CompleteTrainingNative(estimatorHandler, out fitResult, out errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + } + + success = CreateTransformerFromEstimatorNative(estimatorHandler, out IntPtr transformer, out errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + } + + private void BuildColumnByteArray(Dictionary allColumns, ref byte[] columnByteBuffer, out int bufferLength) + { + bufferLength = 0; + foreach (var column in _allColumnNames) + { + var bytes = allColumns[column].GetSerializedValue(); + var byteLength = bytes.Length; + if (byteLength + bufferLength >= columnByteBuffer.Length) + Array.Resize(ref columnByteBuffer, columnByteBuffer.Length * 2); + + Array.Copy(bytes, 0, columnByteBuffer, bufferLength, byteLength); + bufferLength += byteLength; + } + } + + public bool IsRowToRowMapper => false; + + // Schema not changed + public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) + { + var columns = inputSchema.ToDictionary(x => x.Name); + var schemaBuilder = new DataViewSchema.Builder(); + schemaBuilder.AddColumns(inputSchema.AsEnumerable()); + schemaBuilder.AddColumn(TimeSeriesImputerEstimator.IsRowImputedColumnName, BooleanDataViewType.Instance); + + return schemaBuilder.ToSchema(); + } + + public IRowToRowMapper GetRowToRowMapper(DataViewSchema 
inputSchema) => throw new InvalidOperationException("Not a RowToRowMapper."); + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "TimeIm T", + verWrittenCur: 0x00010001, + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(TimeSeriesImputerTransformer).Assembly.FullName); + } + + public void Save(ModelSaveContext ctx) + { + _host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // name of time series column + // length of grain column array + // all column names in grain column array + // length of data column array + // all column names in data column array + // byte value of impute mode + // length of C++ state array + // C++ byte state array + + ctx.Writer.Write(_timeSeriesColumn); + ctx.Writer.Write(_grainColumns.Length); + foreach (var column in _grainColumns) + ctx.Writer.Write(column); + ctx.Writer.Write(_dataColumns.Length); + foreach (var column in _dataColumns) + ctx.Writer.Write(column); + ctx.Writer.Write((byte)_imputeMode); + var data = CreateTransformerSaveData(); + ctx.Writer.Write(data.Length); + ctx.Writer.Write(data); + } + + private byte[] CreateTransformerSaveData() + { + var success = CreateTransformerSaveDataNative(TransformerHandle, out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize)) + { + byte[] savedData = new byte[bufferSize.ToInt32()]; + Marshal.Copy(buffer, savedData, 0, savedData.Length); + return savedData; + } + } + + public IDataView Transform(IDataView input) => MakeDataTransform(input); + + internal TimeSeriesImputerDataView MakeDataTransform(IDataView input) + { + _host.CheckValue(input, nameof(input)); + + return new TimeSeriesImputerDataView(_host, input, 
_timeSeriesColumn, _grainColumns, _dataColumns, _allColumnNames, this); + } + + internal TransformerEstimatorSafeHandle CloneTransformer() => CreateTransformerFromSavedData(CreateTransformerSaveData()); + + public void Dispose() + { + if (!TransformerHandle.IsClosed) + TransformerHandle.Close(); + } + + #region C++ function declarations + // TODO: Update entry points + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateEstimatorNative(TypeId* grainTypes, IntPtr grainTypesSize, TypeId* dataTypes, IntPtr dataTypesSize, TimeSeriesImputerEstimator.ImputationStrategy strategy, bool* suppressTypeErrors, out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle 
estimator, NativeBinaryArchiveData data, out FitResult fitResult, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + + #endregion + + #region Typed Columns + + private abstract class TypedColumn + { + internal readonly DataViewSchema.Column Column; + internal TypedColumn(DataViewSchema.Column column) + { + Column = column; + } + + internal abstract void InitializeGetter(DataViewRowCursor cursor); + internal abstract byte[] GetSerializedValue(); + internal abstract TypeId GetTypeId(); + + internal static TypedColumn CreateTypedColumn(DataViewSchema.Column column, string[] optionalColumns) + { + var type = column.Type.RawType.ToString(); + if (type == typeof(sbyte).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(short).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(int).ToString()) + return new NumericTypedColumn(column, 
optionalColumns.Contains(column.Name)); + else if (type == typeof(long).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(byte).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(ushort).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(uint).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(ulong).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(float).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(double).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(ReadOnlyMemory).ToString()) + return new StringTypedColumn(column, optionalColumns.Contains(column.Name)); + + throw new InvalidOperationException($"Unsupported type {type}"); + } + } + + private abstract class TypedColumn : TypedColumn + { + private ValueGetter _getter; + private T _value; + + internal TypedColumn(DataViewSchema.Column column) : + base(column) + { + _value = default; + } + + internal override void InitializeGetter(DataViewRowCursor cursor) + { + _getter = cursor.GetGetter(Column); + } + + internal T GetValue() + { + _getter(ref _value); + return _value; + } + + internal override TypeId GetTypeId() + { + return typeof(T).GetNativeTypeIdFromType(); + } + } + + private class NumericTypedColumn : TypedColumn + { + private readonly bool _isNullable; + + internal NumericTypedColumn(DataViewSchema.Column column, bool isNullable = false) : + base(column) + { + _isNullable = isNullable; + } + + internal override byte[] GetSerializedValue() + { + dynamic value = GetValue(); + byte[] bytes; + if (value.GetType() == typeof(byte)) + bytes 
= new byte[1] { value }; + if (BitConverter.IsLittleEndian) + bytes = BitConverter.GetBytes(value); + // Will need to enable this when Jin's pr goes in. return ((IEnumerable)BitConverter.GetBytes(value)).Reverse().ToArray(); + else + bytes = BitConverter.GetBytes(value); + + if (_isNullable && value.GetType() != typeof(float) && value.GetType() != typeof(double)) + return new byte[1] { Convert.ToByte(true) }.Concat(bytes).ToArray(); + else + return bytes; + } + } + + private class StringTypedColumn : TypedColumn> + { + private readonly bool _isNullable; + + internal StringTypedColumn(DataViewSchema.Column column, bool isNullable = false) : + base(column) + { + _isNullable = isNullable; + } + + internal override byte[] GetSerializedValue() + { + var value = GetValue().ToString(); + var stringBytes = Encoding.UTF8.GetBytes(value); + if (_isNullable) + return new byte[] { Convert.ToByte(true) }.Concat(BitConverter.GetBytes(stringBytes.Length)).Concat(stringBytes).ToArray(); + return BitConverter.GetBytes(stringBytes.Length).Concat(stringBytes).ToArray(); + } + } + + #endregion + } + + internal static class TimeSeriesTransformerEntrypoint + { + [TlcModule.EntryPoint(Name = "Transforms.TimeSeriesImputer", + Desc = TimeSeriesImputerTransformer.Summary, + UserName = TimeSeriesImputerTransformer.UserName, + ShortName = TimeSeriesImputerTransformer.ShortName)] + public static CommonOutputs.TransformOutput TimeSeriesImputer(IHostEnvironment env, TimeSeriesImputerEstimator.Options input) + { + var h = EntryPointUtils.CheckArgsAndCreateHost(env, TimeSeriesImputerTransformer.ShortName, input); + var xf = new TimeSeriesImputerEstimator(h, input).Fit(input.Data).Transform(input.Data); + return new CommonOutputs.TransformOutput() + { + Model = new TransformModelImpl(h, xf, input.Data), + OutputData = xf + }; + } + } +} diff --git a/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs b/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs new file mode 100644 index 
0000000000..ff9fa5d8ec --- /dev/null +++ b/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs @@ -0,0 +1,761 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Security; +using System.Text; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; +using Microsoft.ML.Runtime; +using Microsoft.Win32.SafeHandles; +using static Microsoft.ML.Featurizers.CommonExtensions; +using static Microsoft.ML.Featurizers.TimeSeriesImputerEstimator; + +namespace Microsoft.ML.Transforms +{ + + internal sealed class TimeSeriesImputerDataView : IDataTransform + { + #region Typed Columns + private TimeSeriesImputerTransformer _parent; + public class SharedColumnState + { + public bool SourceCanMoveNext { get; set; } + public int TransformedDataPosition { get; set; } + public NativeBinaryArchiveData[] TransformedData { get; set; } + public byte[] ColumnBuffer { get; set; } + public TransformedDataSafeHandle TransformedDataHandler { get; set; } + } + + private abstract class TypedColumn + { + private protected SharedColumnState SharedState; + + internal readonly DataViewSchema.Column Column; + internal readonly bool IsImputed; + internal TypedColumn(DataViewSchema.Column column, bool isImputed, SharedColumnState state) + { + Column = column; + SharedState = state; + IsImputed = isImputed; + } + + internal abstract Delegate GetGetter(); + internal abstract void InitializeGetter(DataViewRowCursor cursor, TransformerEstimatorSafeHandle transformerParent, string timeSeriesColumn, + string[] grainColumns, string[] dataColumns, string[] allColumnNames, Dictionary allColumns); + + internal abstract TypeId GetTypeId(); + internal abstract byte[] GetSerializedValue(); + internal abstract unsafe int 
GetDataSizeInBytes(byte* data, int currentOffset);
            internal abstract void QueueNonImputedColumnValue();

            /// <summary>
            /// Advances the shared position within the current batch of transformed rows. When there is
            /// no batch, or the batch is used up, the source cursor is advanced instead.
            /// </summary>
            /// <param name="cursor">Source cursor to advance when the current batch is missing or exhausted.</param>
            /// <returns>
            /// <see langword="false"/> once the source cursor is exhausted and no transformed row is left;
            /// otherwise <see langword="true"/>.
            /// </returns>
            public bool MoveNext(DataViewRowCursor cursor)
            {
                SharedState.TransformedDataPosition++;

                // TransformedData stays null until a getter has produced the first batch, so a missing
                // batch must be treated the same as an exhausted one.
                var transformedData = SharedState.TransformedData;
                var batchExhausted = transformedData == null || SharedState.TransformedDataPosition >= transformedData.Length;

                if (batchExhausted)
                    SharedState.SourceCanMoveNext = cursor.MoveNext();

                if (!SharedState.SourceCanMoveNext && batchExhausted)
                {
                    // The handle is only created together with a batch; it is null for an empty source.
                    var handler = SharedState.TransformedDataHandler;
                    if (handler != null && !handler.IsClosed)
                        handler.Dispose();
                    return false;
                }

                return true;
            }

            /// <summary>
            /// Creates the typed column wrapper that matches the raw CLR type of <paramref name="column"/>.
            /// </summary>
            /// <param name="column">Schema column to wrap.</param>
            /// <param name="optionalColumns">Column names whose wrappers are created as nullable.</param>
            /// <param name="allImputedColumns">Column names whose wrappers are flagged as imputed.</param>
            /// <param name="state">State object shared by every column of one cursor.</param>
            internal static TypedColumn CreateTypedColumn(DataViewSchema.Column column, string[] optionalColumns, string[] allImputedColumns, SharedColumnState state)
            {
                var type = column.Type.RawType.ToString();
                if (type == typeof(sbyte).ToString())
                    return new SByteTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
                else if (type == typeof(short).ToString())
                    return new ShortTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
                else if (type == typeof(int).ToString())
                    return new IntTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
                else if (type == typeof(long).ToString())
                    return new LongTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
                else if (type == typeof(byte).ToString())
                    return new ByteTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
                else if (type == typeof(ushort).ToString())
                    return new UShortTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
                else if (type == typeof(uint).ToString())
                    return new UIntTypedColumn(column, optionalColumns.Contains(column.Name),
allImputedColumns.Contains(column.Name), state); + else if (type == typeof(ulong).ToString()) + return new ULongTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(float).ToString()) + return new FloatTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(double).ToString()) + return new DoubleTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(ReadOnlyMemory).ToString()) + return new StringTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(bool).ToString()) + return new BoolTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + + throw new InvalidOperationException($"Unsupported type {type}"); + } + } + + private abstract class TypedColumn : TypedColumn + { + private ValueGetter _getter; + private ValueGetter _sourceGetter; + private long _position; + private T _cache; + + // When columns are not being imputed, we need to store the column values in memory until they are used. 
+ private protected Queue SourceQueue; + + internal TypedColumn(DataViewSchema.Column column, bool isImputed, SharedColumnState state) : + base(column, isImputed, state) + { + SourceQueue = new Queue(); + _position = -1; + } + + internal override Delegate GetGetter() + { + return _getter; + } + + internal override unsafe void InitializeGetter(DataViewRowCursor cursor, TransformerEstimatorSafeHandle transformer, string timeSeriesColumn, + string[] grainColumns, string[] dataColumns, string[] allImputedColumnNames, Dictionary allColumns) + { + if (Column.Name != IsRowImputedColumnName) + _sourceGetter = cursor.GetGetter(Column); + + _getter = (ref T dst) => + { + IntPtr errorHandle = IntPtr.Zero; + bool success = false; + if (SharedState.TransformedData == null || SharedState.TransformedDataPosition >= SharedState.TransformedData.Length) + { + // Free native memory if we are about to get more + if (SharedState.TransformedData != null && SharedState.TransformedDataPosition >= SharedState.TransformedData.Length) + SharedState.TransformedDataHandler.Dispose(); + + var outputDataSize = IntPtr.Zero; + NativeBinaryArchiveData* outputData = default; + while(outputDataSize == IntPtr.Zero && SharedState.SourceCanMoveNext) + { + BuildColumnByteArray(allColumns, allImputedColumnNames, out int bufferLength); + QueueDataForNonImputedColumns(allColumns, allImputedColumnNames); + fixed (byte* bufferPointer = SharedState.ColumnBuffer) + { + var binaryArchiveData = new NativeBinaryArchiveData() { Data = bufferPointer, DataSize = new IntPtr(bufferLength) }; + success = TransformDataNative(transformer, binaryArchiveData, out outputData, out outputDataSize, out errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + if (outputDataSize == IntPtr.Zero) + SharedState.SourceCanMoveNext = cursor.MoveNext(); + } + + if (!SharedState.SourceCanMoveNext) + success = FlushDataNative(transformer, out outputData, out outputDataSize, out 
errorHandle); + + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + if (outputDataSize.ToInt32() > 0) + { + SharedState.TransformedDataHandler = new TransformedDataSafeHandle((IntPtr)outputData, outputDataSize); + SharedState.TransformedData = new NativeBinaryArchiveData[outputDataSize.ToInt32()]; + for (int i = 0; i < outputDataSize.ToInt32(); i++) + { + SharedState.TransformedData[i] = *(outputData + i); + } + SharedState.TransformedDataPosition = 0; + } + } + + // Base case where we didn't impute anything + if (!allImputedColumnNames.Contains(Column.Name)) + { + var imputedData = SharedState.TransformedData[SharedState.TransformedDataPosition]; + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(imputedData.Data, 0)) + { + dst = default; + } + else + { + if (_position != cursor.Position) + { + _position = cursor.Position; + _cache = SourceQueue.Dequeue(); + } + dst = _cache; + } + } + else + { + var imputedData = SharedState.TransformedData[SharedState.TransformedDataPosition]; + int offset = 0; + foreach (var columnName in allImputedColumnNames) + { + var col = allColumns[columnName]; + if (col.Column.Name == Column.Name) + { + dst = GetDataFromNativeBinaryArchiveData(imputedData.Data, offset); + return; + } + + offset += col.GetDataSizeInBytes(imputedData.Data, offset); + } + + // This should never be hit. 
dst = default;
                    }
                };
            }

            /// <summary>
            /// Asks every column that is not listed in <paramref name="allImputedColumnNames"/> to queue
            /// its current source value.
            /// </summary>
            private void QueueDataForNonImputedColumns(Dictionary<string, TypedColumn> allColumns, string[] allImputedColumnNames)
            {
                foreach (var column in allColumns.Where(x => !allImputedColumnNames.Contains(x.Value.Column.Name)).Select(x => x.Value))
                    column.QueueNonImputedColumnValue();
            }

            /// <summary>
            /// Reads the current source value and appends it to <c>SourceQueue</c>.
            /// </summary>
            internal override void QueueNonImputedColumnValue()
            {
                SourceQueue.Enqueue(GetSourceValue());
            }

            /// <summary>
            /// Serializes the current value of each named column (skipping <c>IsRowImputedColumnName</c>)
            /// and packs the bytes back to back into <c>SharedState.ColumnBuffer</c>.
            /// </summary>
            /// <param name="allColumns">Typed column wrappers keyed by column name.</param>
            /// <param name="columns">Names of the columns to serialize, in output order.</param>
            /// <param name="bufferLength">Number of bytes written to the shared buffer.</param>
            private void BuildColumnByteArray(Dictionary<string, TypedColumn> allColumns, string[] columns, out int bufferLength)
            {
                bufferLength = 0;
                foreach (var column in columns.Where(x => x != IsRowImputedColumnName))
                {
                    var bytes = allColumns[column].GetSerializedValue();
                    var byteLength = bytes.Length;

                    var requiredLength = bufferLength + byteLength;
                    if (requiredLength >= SharedState.ColumnBuffer.Length)
                    {
                        // Keep doubling until the value fits. A single doubling is not enough when one
                        // serialized value exceeds the doubled size, and Array.Copy below would throw.
                        var buffer = SharedState.ColumnBuffer;
                        var newLength = Math.Max(buffer.Length, 1);
                        while (newLength <= requiredLength)
                            newLength *= 2;
                        Array.Resize(ref buffer, newLength);
                        SharedState.ColumnBuffer = buffer;
                    }

                    Array.Copy(bytes, 0, SharedState.ColumnBuffer, bufferLength, byteLength);
                    bufferLength += byteLength;
                }
            }

            /// <summary>
            /// Reads the value of this column at the source cursor's current row.
            /// </summary>
            private protected T GetSourceValue()
            {
                T value = default;
                _sourceGetter(ref value);
                return value;
            }

            internal override TypeId GetTypeId()
            {
                return typeof(T).GetNativeTypeIdFromType();
            }

            /// <summary>
            /// Decodes a value of type <typeparamref name="T"/> from <paramref name="data"/> starting at
            /// byte <paramref name="offset"/>.
            /// </summary>
            internal unsafe abstract T GetDataFromNativeBinaryArchiveData(byte* data, int offset);
        }

        private abstract class NumericTypedColumn<T> : TypedColumn<T>
        {
            private protected readonly bool IsNullable;

            internal NumericTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
                base(column, isImputed, state)
            {
                IsNullable = isNullable;
            }

            /// <summary>
            /// Converts the current source value to its byte representation. For nullable columns other
            /// than float and double the result is prefixed with a single <c>true</c> flag byte.
            /// </summary>
            internal override byte[] GetSerializedValue()
            {
                dynamic value = GetSourceValue();
                byte[] bytes;

                // BitConverter.GetBytes has no byte or sbyte overload; dynamic binding would pick a wider
                // integer overload and emit more than one byte. Both 8-bit types are therefore handled
                // explicitly, and the chain must be if/else-if so the result is not overwritten.
                if (value.GetType() == typeof(byte))
                    bytes = new byte[1] { value };
                else if (value.GetType() == typeof(sbyte))
                    bytes = new byte[1] { unchecked((byte)(sbyte)value) };
                else if (BitConverter.IsLittleEndian)
                    bytes = BitConverter.GetBytes(value);
                else
                    // Same call as the little-endian branch: no byte swapping is done here.
                    bytes = BitConverter.GetBytes(value);

                if (IsNullable && value.GetType() != typeof(float) &&
value.GetType() != typeof(double)) + return new byte[1] { Convert.ToByte(true) }.Concat(bytes).ToArray(); + else + return bytes; + } + + internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset) + { + if (IsNullable && typeof(T) != typeof(float) && typeof(T) != typeof(double)) + return Marshal.SizeOf(default(T)) + sizeof(bool); + else + return Marshal.SizeOf(default(T)); + } + } + + private class ByteTypedColumn : NumericTypedColumn + { + internal ByteTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override byte GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(byte*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(byte*)(data + offset); + } + } + + private class SByteTypedColumn : NumericTypedColumn + { + internal SByteTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override sbyte GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(sbyte*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(sbyte*)(data + offset); + } + } + + private class ShortTypedColumn : NumericTypedColumn + { + internal ShortTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override short GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(short*)(data + offset + sizeof(bool)); + else + return default; 
+ } + else + return *(short*)(data + offset); + } + } + + private class UShortTypedColumn : NumericTypedColumn + { + internal UShortTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override ushort GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(ushort*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(ushort*)(data + offset); + } + } + + private class IntTypedColumn : NumericTypedColumn + { + internal IntTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override int GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(int*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(int*)(data + offset); + } + } + + private class UIntTypedColumn : NumericTypedColumn + { + internal UIntTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override uint GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(uint*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(uint*)(data + offset); + } + } + + private class LongTypedColumn : NumericTypedColumn + { + internal LongTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override long 
GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(long*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(long*)(data + offset); + } + } + + private class ULongTypedColumn : NumericTypedColumn + { + internal ULongTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override ulong GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(ulong*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(ulong*)(data + offset); + } + } + + private class FloatTypedColumn : NumericTypedColumn + { + internal FloatTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override float GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + var bytes = new byte[sizeof(float)]; + Marshal.Copy((IntPtr)(data + offset), bytes, 0, sizeof(float)); + return BitConverter.ToSingle(bytes, 0); + } + } + + private class DoubleTypedColumn : NumericTypedColumn + { + internal DoubleTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override double GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + var bytes = new byte[sizeof(double)]; + Marshal.Copy((IntPtr)(data + offset), bytes, 0, sizeof(double)); + return BitConverter.ToDouble(bytes, 0); + } + } + + private class BoolTypedColumn : NumericTypedColumn + { + internal BoolTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, 
isNullable, isImputed, state) + { + } + + internal unsafe override bool GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(bool*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(bool*)(data + offset); + } + + internal static unsafe bool GetBoolFromNativeBinaryArchiveData(byte* data, int offset) + { + return *(bool*)(data + offset); + } + + internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset) + { + return sizeof(bool); + } + } + + private class StringTypedColumn : TypedColumn> + { + private readonly bool _isNullable; + internal StringTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isImputed, state) + { + _isNullable = isNullable; + } + + internal override byte[] GetSerializedValue() + { + var value = GetSourceValue().ToString(); + var stringBytes = Encoding.UTF8.GetBytes(value); + if (_isNullable) + return new byte[] { Convert.ToByte(true)}.Concat(BitConverter.GetBytes(stringBytes.Length)).Concat(stringBytes).ToArray(); + return BitConverter.GetBytes(stringBytes.Length).Concat(stringBytes).ToArray(); + } + + internal unsafe override ReadOnlyMemory GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (_isNullable) + { + if (!BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) // If value not present return empty string + return new ReadOnlyMemory("".ToCharArray()); + + var size = *(uint*)(data + offset + 1); // Add 1 for the byte bool flag + + var bytes = new byte[size]; + Marshal.Copy((IntPtr)(data + offset + sizeof(uint) + 1), bytes, 0, (int)size); + return Encoding.UTF8.GetString(bytes).AsMemory(); + } + else + { + var size = *(uint*)(data + offset); + + var bytes = new byte[size]; + Marshal.Copy((IntPtr)(data + offset + sizeof(uint)), bytes, 0, (int)size); + return Encoding.UTF8.GetString(bytes).AsMemory(); + } + 
}

            /// <summary>
            /// Returns the number of bytes the string value starting at <paramref name="currentOffset"/>
            /// occupies in <paramref name="data"/>.
            /// </summary>
            /// <param name="data">Start of the native binary archive for the current row.</param>
            /// <param name="currentOffset">Byte offset at which this column's value begins.</param>
            internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset)
            {
                if (_isNullable)
                {
                    // Nullable layout, as read by GetDataFromNativeBinaryArchiveData:
                    // [bool present][uint length][length bytes].
                    // NOTE(review): assumes an absent value is encoded as the flag byte alone, matching
                    // the reader, which touches nothing past the flag - confirm against the native format.
                    if (!BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, currentOffset))
                        return sizeof(bool);

                    // The length prefix sits after the flag byte, not at currentOffset.
                    var nullableSize = *(uint*)(data + currentOffset + sizeof(bool));
                    return sizeof(bool) + (int)nullableSize + sizeof(uint);
                }

                var size = *(uint*)(data + currentOffset);
                return (int)size + sizeof(uint);
            }
        }

        #endregion

        #region Native Exports

        [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Transform"), SuppressUnmanagedCodeSecurity]
        private static extern unsafe bool TransformDataNative(TransformerEstimatorSafeHandle transformer, /*in*/ NativeBinaryArchiveData data, out NativeBinaryArchiveData* outputData, out IntPtr outputDataSize, out IntPtr errorHandle);

        // Flush takes no input row, so it cannot share the Transform export, whose signature has an
        // extra NativeBinaryArchiveData argument.
        // NOTE(review): the export name follows the naming pattern of the other entry points - confirm
        // it against the native Featurizers library.
        [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Flush"), SuppressUnmanagedCodeSecurity]
        private static extern unsafe bool FlushDataNative(TransformerEstimatorSafeHandle transformer, out NativeBinaryArchiveData* outputData, out IntPtr outputDataSize, out IntPtr errorHandle);

        [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
        private static extern unsafe bool DestroyTransformedDataNative(IntPtr data, IntPtr dataSize, out IntPtr errorHandle);

        #endregion

        #region Native SafeHandles

        /// <summary>
        /// Safe handle over a native array of transformed rows; releasing it calls
        /// <c>DestroyTransformedDataNative</c> with the pointer and the element count given at construction.
        /// </summary>
        internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
        {
            private readonly IntPtr _size;
            public TransformedDataSafeHandle(IntPtr handle, IntPtr size) : base(true)
            {
                SetHandle(handle);
                _size = size;
            }

            protected override bool ReleaseHandle()
            {
                // Not sure what to do with error stuff here. There shouldn't ever be one though.
+ return DestroyTransformedDataNative(handle, _size, out IntPtr errorHandle); + } + } + + #endregion + + private readonly IHostEnvironment _host; + private readonly IDataView _source; + private readonly string _timeSeriesColumn; + private readonly string[] _dataColumns; + private readonly string[] _grainColumns; + private readonly string[] _allImputedColumnNames; + private readonly DataViewSchema _schema; + + internal TimeSeriesImputerDataView(IHostEnvironment env, IDataView input, string timeSeriesColumn, string[] grainColumns, string[] dataColumns, string[] allColumnNames, TimeSeriesImputerTransformer parent) + { + _host = env; + _source = input; + + _timeSeriesColumn = timeSeriesColumn; + _grainColumns = grainColumns; + _dataColumns = dataColumns; + _allImputedColumnNames = new string[] { IsRowImputedColumnName }.Concat(allColumnNames).ToArray(); + _parent = parent; + // Build new schema. + var schemaColumns = _source.Schema.ToDictionary(x => x.Name); + var schemaBuilder = new DataViewSchema.Builder(); + schemaBuilder.AddColumns(_source.Schema.AsEnumerable()); + schemaBuilder.AddColumn(IsRowImputedColumnName, BooleanDataViewType.Instance); + + _schema = schemaBuilder.ToSchema(); + } + + public bool CanShuffle => false; + + public DataViewSchema Schema => _schema; + + public IDataView Source => _source; + + public DataViewRowCursor GetRowCursor(IEnumerable columnsNeeded, Random rand = null) + { + _host.AssertValueOrNull(rand); + + var input = _source.GetRowCursorForAllColumns(); + return new Cursor(_host, input, _parent.CloneTransformer(), _timeSeriesColumn, _grainColumns, _dataColumns, _allImputedColumnNames, _schema); + } + + // Can't use parallel cursors so this defaults to calling non-parallel version + public DataViewRowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int n, Random rand = null) => + new DataViewRowCursor[] { GetRowCursor(columnsNeeded, rand) }; + + // Since we may add rows we don't know the row count + public long? 
GetRowCount() { return null; }

        /// <summary>
        /// Saves the model by delegating to the parent transformer.
        /// </summary>
        public void Save(ModelSaveContext ctx)
        {
            _parent.Save(ctx);
        }

        /// <summary>
        /// Cursor over the imputed data. It wraps a cursor over the source view and a transformer
        /// handle, and exposes one typed column wrapper per output schema column.
        /// </summary>
        private sealed class Cursor : DataViewRowCursor
        {
            private readonly IChannelProvider _ch;
            private readonly DataViewRowCursor _input;
            private long _position;
            private bool _isGood;
            private bool _disposed;
            private readonly Dictionary<string, TypedColumn> _allColumns;
            private readonly DataViewSchema _schema;
            private readonly TransformerEstimatorSafeHandle _transformer;

            public Cursor(IChannelProvider provider, DataViewRowCursor input, TransformerEstimatorSafeHandle transformer, string timeSeriesColumn,
                string[] grainColumns, string[] dataColumns, string[] allImputedColumnNames, DataViewSchema schema)
            {
                _ch = provider;
                _ch.CheckValue(input, nameof(input));

                _input = input;
                _position = -1;
                _schema = schema;
                _transformer = transformer;

                // One state object is shared by every column so they all see the same batch of
                // transformed rows and the same position within it.
                var sharedState = new SharedColumnState()
                {
                    SourceCanMoveNext = true,
                    ColumnBuffer = new byte[1024]
                };

                _allColumns = _schema.Select(x => TypedColumn.CreateTypedColumn(x, dataColumns, allImputedColumnNames, sharedState)).ToDictionary(x => x.Column.Name);
                // The "is row imputed" column is always non-nullable and always flagged as imputed.
                _allColumns[IsRowImputedColumnName] = new BoolTypedColumn(_schema[IsRowImputedColumnName], false, true, sharedState);

                foreach (var column in _allColumns.Values)
                {
                    column.InitializeGetter(_input, transformer, timeSeriesColumn, grainColumns, dataColumns, allImputedColumnNames, _allColumns);
                }
            }

            /// <summary>
            /// Returns a getter that yields a row id built from the cursor's current <c>Position</c>.
            /// The getter fails its check when the cursor is not on a valid row.
            /// </summary>
            public sealed override ValueGetter<DataViewRowId> GetIdGetter()
            {
                return
                    (ref DataViewRowId val) =>
                    {
                        _ch.Check(_isGood, RowCursorUtils.FetchValueStateError);
                        val = new DataViewRowId((ulong)Position, 0);
                    };
            }

            public sealed override DataViewSchema Schema => _schema;

            /// <summary>
            /// Since rows will be generated all columns are active
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                return true;
            }

            /// <summary>
            /// Closes the transformer handle and disposes the wrapped source cursor. Calling it more
            /// than once is safe.
            /// </summary>
            protected override void Dispose(bool disposing)
            {
                if (_disposed)
                    return;

                if (!_transformer.IsClosed)
                    _transformer.Close();

                // The source cursor was created solely for this cursor, so it is released here.
                if (disposing)
                    _input.Dispose();

                _disposed = true;
                base.Dispose(disposing);
            }
/// + /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row. + /// This throws if the column is not active in this row, or if the type + /// differs from this column's type. + /// + /// is the column's content type. + /// is the output column whose getter should be returned. + public override ValueGetter GetGetter(DataViewSchema.Column column) + { + _ch.Check(IsColumnActive(column)); + + var fn = _allColumns[column.Name].GetGetter() as ValueGetter; + if (fn == null) + throw _ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); + return fn; + } + + public override bool MoveNext() + { + _position++; + _isGood = _allColumns[IsRowImputedColumnName].MoveNext(_input); + return _isGood; + } + + public sealed override long Position => _position; + + public sealed override long Batch => _input.Batch; + } + } +} diff --git a/src/Microsoft.ML.Featurizers/ToStringTransformer.cs b/src/Microsoft.ML.Featurizers/ToStringTransformer.cs new file mode 100644 index 0000000000..2ea4687d6b --- /dev/null +++ b/src/Microsoft.ML.Featurizers/ToStringTransformer.cs @@ -0,0 +1,1630 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
public static class ToStringTransformerExtensionClass
{
    /// <summary>
    /// Create a <see cref="ToStringTransformerEstimator"/>, which converts the input column specified by
    /// <paramref name="inputColumnName"/> into a string representation of its contents.
    /// </summary>
    /// <param name="catalog">The transform catalog.</param>
    /// <param name="outputColumnName">Name of the column resulting from the transformation of
    /// <paramref name="inputColumnName"/>. This column holds text.</param>
    /// <param name="inputColumnName">Name of column to convert to its string representation. If set to
    /// <see langword="null"/>, the value of <paramref name="outputColumnName"/> will be used as source.
    /// This column's data type can be scalar of numeric, text, and boolean.</param>
    public static ToStringTransformerEstimator ToStringTransformer(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null)
        => ToStringTransformerEstimator.Create(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName);

    /// <summary>
    /// Create a <see cref="ToStringTransformerEstimator"/>, which converts each input column named in
    /// <paramref name="columns"/> into a string representation of its contents and stores it in the paired
    /// output column. The input column data type can be scalar of numeric, text, and boolean.
    /// </summary>
    /// <param name="catalog">The transform catalog.</param>
    /// <param name="columns">Input/output column pairs. Each output column holds text.</param>
    public static ToStringTransformerEstimator ToStringTransformer(this TransformsCatalog catalog, params InputOutputColumnPair[] columns)
        => ToStringTransformerEstimator.Create(CatalogUtils.GetEnvironment(catalog), columns);
}

/// <summary>
/// Converts one or more input columns into string representations of their contents.
/// Supports input columns of numeric, text, and boolean data type.
/// </summary>
/// <remarks>
/// This is a trivial estimator that doesn't need training. The resulting <see cref="ToStringTransformer"/>
/// can be applied to one or more columns, in which case it turns each input value into its
/// appropriate string representation.
/// </remarks>
public sealed class ToStringTransformerEstimator : IEstimator<ToStringTransformer>
{
    private readonly Options _options;

    private readonly IHost _host;

    #region Options

    internal sealed class Column : OneToOneColumn
    {
        /// <summary>Parses a column definition; returns null when the text cannot be parsed.</summary>
        internal static Column Parse(string str)
        {
            Contracts.AssertNonEmpty(str);

            var res = new Column();
            if (res.TryParse(str))
                return res;
            return null;
        }

        internal bool TryUnparse(StringBuilder sb)
        {
            Contracts.AssertValue(sb);
            return TryUnparseCore(sb);
        }
    }

    internal sealed class Options : TransformInputBase
    {
        [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition (optional form: name:src)",
            Name = "Column", ShortName = "col", SortOrder = 1)]
        public Column[] Columns;
    }

    #endregion

    internal static ToStringTransformerEstimator Create(IHostEnvironment env, string outputColumnName, string inputColumnName)
    {
        return new ToStringTransformerEstimator(env, outputColumnName, inputColumnName);
    }

    internal static ToStringTransformerEstimator Create(IHostEnvironment env, params InputOutputColumnPair[] columns)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(columns, nameof(columns));

        // A pair without an input name reads from the column it writes to.
        var columnOptions = columns
            .Select(x => new Column { Name = x.OutputColumnName, Source = x.InputColumnName ?? x.OutputColumnName })
            .ToArray();
        return new ToStringTransformerEstimator(env, new Options { Columns = columnOptions });
    }

    internal ToStringTransformerEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName)
    {
        Contracts.CheckValue(env, nameof(env));
        _host = env.Register(nameof(ToStringTransformerEstimator));
        // Without a name the output column could not be added to the schema later on.
        _host.CheckNonEmpty(outputColumnName, nameof(outputColumnName));

        _options = new Options
        {
            Columns = new Column[1] { new Column() { Name = outputColumnName, Source = inputColumnName ?? outputColumnName } }
        };
    }

    internal ToStringTransformerEstimator(IHostEnvironment env, Options options)
    {
        Contracts.CheckValue(env, nameof(env));
        _host = env.Register(nameof(ToStringTransformerEstimator));
        _host.CheckValue(options, nameof(options));
        _host.CheckValue(options.Columns, nameof(options.Columns));

        // Normalize in place: a column without a source reads from the column it writes to.
        foreach (var columnPair in options.Columns)
        {
            columnPair.Source = columnPair.Source ?? columnPair.Name;
        }
        _options = options;
    }

    /// <summary>
    /// Builds the transformer directly from the configured columns and the schema of
    /// <paramref name="input"/>.
    /// </summary>
    public ToStringTransformer Fit(IDataView input)
    {
        _host.CheckValue(input, nameof(input));
        return new ToStringTransformer(_host, _options.Columns, input);
    }

    /// <summary>
    /// Returns <paramref name="inputSchema"/> with every configured output column added (or replaced)
    /// as a scalar string column.
    /// </summary>
    public SchemaShape GetOutputSchema(SchemaShape inputSchema)
    {
        _host.CheckValue(inputSchema, nameof(inputSchema));

        var columns = inputSchema.ToDictionary(x => x.Name);

        foreach (var column in _options.Columns)
        {
            // Report a missing source column here rather than letting Fit fail on the schema lookup.
            if (!inputSchema.TryFindColumn(column.Source, out _))
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", column.Source);

            columns[column.Name] = new SchemaShape.Column(column.Name, SchemaShape.Column.VectorKind.Scalar,
                ColumnTypeExtensions.PrimitiveTypeFromType(typeof(string)), false, null);
        }

        return new SchemaShape(columns.Values);
    }
}
/// <summary>
/// Creates one typed column per configured column, choosing the implementation from the raw type
/// of the source column in <paramref name="input"/>'s schema.
/// </summary>
internal ToStringTransformer(IHostEnvironment host, ToStringTransformerEstimator.Column[] columns, IDataView input) :
    base(host.Register(nameof(ToStringTransformer)))
{
    Host.CheckValue(columns, nameof(columns));
    Host.CheckValue(input, nameof(input));

    var schema = input.Schema;

    _columns = columns.Select(x => TypedColumn.CreateTypedColumn(x.Name, x.Source, schema[x.Source].Type.RawType.ToString())).ToArray();
    foreach (var column in _columns)
    {
        // No training is required so directly create the transformer
        column.CreateTransformerFromEstimator();
    }
}

// Factory method for SignatureLoadModel.
internal ToStringTransformer(IHostEnvironment host, ModelLoadContext ctx) :
    base(host.Register(nameof(ToStringTransformer)))
{
    Host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());
    // *** Binary format ***
    // int number of column pairs
    // for each column pair:
    //      string output column name
    //      string input column name
    //      string representation of type
    //      int c++ state array length
    //      byte[] c++ state

    var columnCount = ctx.Reader.ReadInt32();
    // A negative count can only come from a corrupt stream; reject it before allocating.
    Host.CheckDecode(columnCount >= 0);

    _columns = new TypedColumn[columnCount];
    for (int i = 0; i < columnCount; i++)
    {
        // The three reads happen left to right: output name, input name, type.
        _columns[i] = TypedColumn.CreateTypedColumn(ctx.Reader.ReadString(), ctx.Reader.ReadString(), ctx.Reader.ReadString());

        var dataLength = ctx.Reader.ReadInt32();
        Host.CheckDecode(dataLength >= 0);
        var data = ctx.Reader.ReadByteArray(dataLength);
        _columns[i].CreateTransformerFromSavedData(data);
    }
}

// Factory method for SignatureLoadRowMapper.
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
    => new ToStringTransformer(env, ctx).MakeRowMapper(inputSchema);

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);

private static VersionInfo GetVersionInfo()
{
    return new VersionInfo(
        modelSignature: "TOSTRI T",
        verWrittenCur: 0x00010001,
        verReadableCur: 0x00010001,
        verWeCanReadBack: 0x00010001,
        loaderSignature: LoaderSignature,
        loaderAssemblyName: typeof(ToStringTransformer).Assembly.FullName);
}

private protected override void SaveModel(ModelSaveContext ctx)
{
    Host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel();
    ctx.SetVersionInfo(GetVersionInfo());

    // *** Binary format ***
    // int number of column pairs
    // for each column pair:
    //      string output column name
    //      string input column name
    //      string representation of type
    //      int c++ state array length
    //      byte[] c++ state

    ctx.Writer.Write(_columns.Length);
    foreach (var column in _columns)
    {
        ctx.Writer.Write(column.Name);
        ctx.Writer.Write(column.Source);
        ctx.Writer.Write(column.Type);

        var data = column.CreateTransformerSaveData();
        ctx.Writer.Write(data.Length);
        ctx.Writer.Write(data);
    }
}

/// <summary>Disposes every typed column held by this transformer.</summary>
public void Dispose()
{
    foreach (var column in _columns)
    {
        column.Dispose();
    }
}
#region C++ Safe handle classes

/// <summary>
/// Safe handle over a native output buffer. Releasing the handle calls the supplied destroy
/// delegate with the buffer pointer and the size recorded at construction.
/// </summary>
internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
{
    private readonly DestroyTransformedDataNative _destroySaveDataHandler;
    private readonly IntPtr _dataSize;

    public TransformedDataSafeHandle(IntPtr handle, IntPtr dataSize, DestroyTransformedDataNative destroyCppTransformerEstimator) : base(true)
    {
        SetHandle(handle);
        _dataSize = dataSize;
        _destroySaveDataHandler = destroyCppTransformerEstimator;
    }

    protected override bool ReleaseHandle()
    {
        // Not sure what to do with error stuff here. There shouldn't ever be one though.
        return _destroySaveDataHandler(handle, _dataSize, out IntPtr errorHandle);
    }
}

#endregion

// REVIEW: Since we can't do overloading on the native side due to the C style exports,
// this was the best way I could think handle it to allow for any conversions needed based on the data type.

#region BaseClass

internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle);
internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle);
internal delegate bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);

/// <summary>
/// Type-erased base for a single output column. Concrete subclasses supply the native entry
/// points for one input type; this class sequences the calls and wraps the returned pointers
/// in safe handles.
/// </summary>
internal abstract class TypedColumn : IDisposable
{
    internal readonly string Name;
    internal readonly string Source;
    internal readonly string Type;

    internal TypedColumn(string name, string source, string type)
    {
        Name = name;
        Source = source;
        Type = type;
    }

    internal abstract void CreateTransformerFromEstimator();
    private protected abstract unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize);
    private protected abstract bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
    private protected abstract bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
    private protected abstract bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle);
    private protected abstract bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle);
    private protected abstract bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
    public abstract void Dispose();

    /// <summary>
    /// Creates a native estimator, builds a transformer from it, and returns the transformer
    /// wrapped in a safe handle. The estimator handle is disposed before returning.
    /// </summary>
    /// <exception cref="Exception">A native call reported failure; the message carries its error details.</exception>
    private protected TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase()
    {
        var success = CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
        if (!success)
        {
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
        }

        using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorHelper))
        {
            success = CreateTransformerFromEstimatorHelper(estimatorHandler, out IntPtr transformer, out errorHandle);
            if (!success)
            {
                throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
            }

            return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerHelper);
        }
    }

    /// <summary>
    /// Asks the native side for the transformer's saved state and copies it into a managed array.
    /// </summary>
    /// <exception cref="Exception">The native call reported failure.</exception>
    internal byte[] CreateTransformerSaveData()
    {
        var success = CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
        if (!success)
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

        // The handle exists only so the native buffer is released once the copy is done.
        using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize))
        {
            byte[] savedData = new byte[bufferSize.ToInt32()];
            Marshal.Copy(buffer, savedData, 0, savedData.Length);
            return savedData;
        }
    }

    /// <summary>
    /// Pins <paramref name="data"/> and hands it to the subclass to rebuild the native transformer.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    internal unsafe void CreateTransformerFromSavedData(byte[] data)
    {
        if (data == null)
            throw new ArgumentNullException(nameof(data));

        fixed (byte* rawData = data)
        {
            IntPtr dataSize = new IntPtr(data.Length);
            CreateTransformerFromSavedDataHelper(rawData, dataSize);
        }
    }

    /// <summary>
    /// Factory that maps the string form of a CLR type (<c>typeof(X).ToString()</c>) to the
    /// matching typed column.
    /// </summary>
    /// <exception cref="NotSupportedException"><paramref name="type"/> is not one of the handled types.</exception>
    internal static TypedColumn CreateTypedColumn(string name, string source, string type)
    {
        if (type == typeof(sbyte).ToString())
            return new Int8TypedColumn(name, source);
        if (type == typeof(short).ToString())
            return new Int16TypedColumn(name, source);
        if (type == typeof(int).ToString())
            return new Int32TypedColumn(name, source);
        if (type == typeof(long).ToString())
            return new Int64TypedColumn(name, source);
        if (type == typeof(byte).ToString())
            return new UInt8TypedColumn(name, source);
        if (type == typeof(ushort).ToString())
            return new UInt16TypedColumn(name, source);
        if (type == typeof(uint).ToString())
            return new UInt32TypedColumn(name, source);
        if (type == typeof(ulong).ToString())
            return new UInt64TypedColumn(name, source);
        if (type == typeof(float).ToString())
            return new FloatTypedColumn(name, source);
        if (type == typeof(double).ToString())
            return new DoubleTypedColumn(name, source);
        if (type == typeof(bool).ToString())
            return new BoolTypedColumn(name, source);
        if (type == typeof(string).ToString())
            return new StringTypedColumn(name, source);
        if (type == typeof(ReadOnlyMemory<char>).ToString())
            return new ReadOnlyCharTypedColumn(name, source);

        // NotSupportedException derives from Exception, so existing catch blocks still match.
        throw new NotSupportedException($"Unsupported type {type}");
    }
}

/// <summary>
/// Adds the strongly typed conversion entry point to <see cref="TypedColumn"/>.
/// </summary>
/// <typeparam name="T">Input value type converted to a string.</typeparam>
internal abstract class TypedColumn<T> : TypedColumn
{
    internal TypedColumn(string name, string source, string type) :
        base(name, source, type)
    {
    }

    /// <summary>Converts <paramref name="input"/> to its string representation.</summary>
    internal abstract string Transform(T input);
}

#endregion
#region Int8Column

/// <summary>
/// Typed column for <see cref="sbyte"/> input, bound to the <c>StringFeaturizer_int8_t_*</c>
/// exports of the "Featurizers" native library.
/// </summary>
internal sealed class Int8TypedColumn : TypedColumn<sbyte>
{
    // Null until CreateTransformerFromEstimator or CreateTransformerFromSavedData has run.
    private TransformerEstimatorSafeHandle _transformerHandler;

    internal Int8TypedColumn(string name, string source) :
        base(name, source, typeof(sbyte).ToString())
    {
    }

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
    private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_Transform"), SuppressUnmanagedCodeSecurity]
    private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, sbyte input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);

    internal override void CreateTransformerFromEstimator()
    {
        _transformerHandler = CreateTransformerFromEstimatorBase();
    }

    private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
    {
        var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
        if (!result)
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

        _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
    }

    /// <summary>
    /// Converts <paramref name="input"/> through the native transformer and decodes the
    /// returned bytes as UTF-8.
    /// </summary>
    /// <exception cref="Exception">The native call reported failure.</exception>
    internal override string Transform(sbyte input)
    {
        var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
        if (!success)
        {
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
        }

        // The handle releases the native output buffer once the bytes have been copied.
        using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
        {
            byte[] buffer = new byte[outputSize.ToInt32()];
            Marshal.Copy(output, buffer, 0, buffer.Length);
            return Encoding.UTF8.GetString(buffer);
        }
    }

    public override void Dispose()
    {
        // Guard against a column whose transformer was never created.
        if (_transformerHandler != null && !_transformerHandler.IsClosed)
            _transformerHandler.Dispose();
    }

    private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
        CreateEstimatorNative(out estimator, out errorHandle);

    private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
        CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

    private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
        DestroyEstimatorNative(estimator, out errorHandle);

    private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
        DestroyTransformerNative(transformer, out errorHandle);

    private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
        CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
}

#endregion

#region Int16Column

/// <summary>
/// Typed column for <see cref="short"/> input, bound to the <c>StringFeaturizer_int16_t_*</c>
/// exports of the "Featurizers" native library.
/// </summary>
internal sealed class Int16TypedColumn : TypedColumn<short>
{
    // Null until CreateTransformerFromEstimator or CreateTransformerFromSavedData has run.
    private TransformerEstimatorSafeHandle _transformerHandler;

    internal Int16TypedColumn(string name, string source) :
        base(name, source, typeof(short).ToString())
    {
    }

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
    private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_Transform"), SuppressUnmanagedCodeSecurity]
    private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, short input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);

    internal override void CreateTransformerFromEstimator()
    {
        _transformerHandler = CreateTransformerFromEstimatorBase();
    }

    private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
    {
        var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
        if (!result)
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

        _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
    }

    /// <summary>
    /// Converts <paramref name="input"/> through the native transformer and decodes the
    /// returned bytes as UTF-8.
    /// </summary>
    /// <exception cref="Exception">The native call reported failure.</exception>
    internal override string Transform(short input)
    {
        var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
        if (!success)
        {
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
        }

        // The handle releases the native output buffer once the bytes have been copied.
        using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
        {
            byte[] buffer = new byte[outputSize.ToInt32()];
            Marshal.Copy(output, buffer, 0, buffer.Length);
            return Encoding.UTF8.GetString(buffer);
        }
    }

    public override void Dispose()
    {
        // Guard against a column whose transformer was never created.
        if (_transformerHandler != null && !_transformerHandler.IsClosed)
            _transformerHandler.Dispose();
    }

    private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
        CreateEstimatorNative(out estimator, out errorHandle);

    private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
        CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

    private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
        DestroyEstimatorNative(estimator, out errorHandle);

    private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
        DestroyTransformerNative(transformer, out errorHandle);

    private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
        CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
}

#endregion
#region Int32Column

/// <summary>
/// Typed column for <see cref="int"/> input, bound to the <c>StringFeaturizer_int32_t_*</c>
/// exports of the "Featurizers" native library.
/// </summary>
internal sealed class Int32TypedColumn : TypedColumn<int>
{
    // Null until CreateTransformerFromEstimator or CreateTransformerFromSavedData has run.
    private TransformerEstimatorSafeHandle _transformerHandler;

    internal Int32TypedColumn(string name, string source) :
        base(name, source, typeof(int).ToString())
    {
    }

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
    private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_Transform"), SuppressUnmanagedCodeSecurity]
    private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, int input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);

    internal override void CreateTransformerFromEstimator()
    {
        _transformerHandler = CreateTransformerFromEstimatorBase();
    }

    private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
    {
        var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
        if (!result)
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

        _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
    }

    /// <summary>
    /// Converts <paramref name="input"/> through the native transformer and decodes the
    /// returned bytes as UTF-8.
    /// </summary>
    /// <exception cref="Exception">The native call reported failure.</exception>
    internal override string Transform(int input)
    {
        var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
        if (!success)
        {
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
        }

        // The handle releases the native output buffer once the bytes have been copied.
        using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
        {
            byte[] buffer = new byte[outputSize.ToInt32()];
            Marshal.Copy(output, buffer, 0, buffer.Length);
            return Encoding.UTF8.GetString(buffer);
        }
    }

    public override void Dispose()
    {
        // Guard against a column whose transformer was never created.
        if (_transformerHandler != null && !_transformerHandler.IsClosed)
            _transformerHandler.Dispose();
    }

    private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
        CreateEstimatorNative(out estimator, out errorHandle);

    private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
        CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

    private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
        DestroyEstimatorNative(estimator, out errorHandle);

    private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
        DestroyTransformerNative(transformer, out errorHandle);

    private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
        CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
}

#endregion
#region Int64Column

/// <summary>
/// Typed column for <see cref="long"/> input, bound to the <c>StringFeaturizer_int64_t_*</c>
/// exports of the "Featurizers" native library.
/// </summary>
internal sealed class Int64TypedColumn : TypedColumn<long>
{
    // Null until CreateTransformerFromEstimator or CreateTransformerFromSavedData has run.
    private TransformerEstimatorSafeHandle _transformerHandler;

    internal Int64TypedColumn(string name, string source) :
        base(name, source, typeof(long).ToString())
    {
    }

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
    private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_Transform"), SuppressUnmanagedCodeSecurity]
    private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, long input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);

    [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
    private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);

    internal override void CreateTransformerFromEstimator()
    {
        _transformerHandler = CreateTransformerFromEstimatorBase();
    }

    private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
    {
        var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
        if (!result)
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));

        _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
    }

    /// <summary>
    /// Converts <paramref name="input"/> through the native transformer and decodes the
    /// returned bytes as UTF-8.
    /// </summary>
    /// <exception cref="Exception">The native call reported failure.</exception>
    internal override string Transform(long input)
    {
        var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
        if (!success)
        {
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
        }

        // The handle releases the native output buffer once the bytes have been copied.
        using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
        {
            byte[] buffer = new byte[outputSize.ToInt32()];
            Marshal.Copy(output, buffer, 0, buffer.Length);
            return Encoding.UTF8.GetString(buffer);
        }
    }

    public override void Dispose()
    {
        // Guard against a column whose transformer was never created.
        if (_transformerHandler != null && !_transformerHandler.IsClosed)
            _transformerHandler.Dispose();
    }

    private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
        CreateEstimatorNative(out estimator, out errorHandle);

    private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
        CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);

    private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
        DestroyEstimatorNative(estimator, out errorHandle);

    private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
        DestroyTransformerNative(transformer, out errorHandle);

    private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
        CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
}

#endregion
DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator() + { + _transformerHandler = CreateTransformerFromEstimatorBase(); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_Transform"), SuppressUnmanagedCodeSecurity] + private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, byte input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle); + internal 
override string Transform(byte input) + { + var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative)) + { + byte[] buffer = new byte[outputSize.ToInt32()]; + Marshal.Copy(output, buffer, 0, buffer.Length); + return Encoding.UTF8.GetString(buffer); + } + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #region UInt16Column + 
+ internal sealed class UInt16TypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + internal UInt16TypedColumn(string name, string source) : + base(name, source, typeof(ushort).ToString()) + { + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_CreateEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator() + { + _transformerHandler = CreateTransformerFromEstimatorBase(); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity] + private static extern bool 
DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_Transform"), SuppressUnmanagedCodeSecurity] + private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ushort input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle); + internal override string Transform(ushort input) + { + var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative)) + { + byte[] buffer = new byte[outputSize.ToInt32()]; + Marshal.Copy(output, buffer, 0, buffer.Length); + return Encoding.UTF8.GetString(buffer); + } + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + 
[DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #region UInt32Column + + internal sealed class UInt32TypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + + internal UInt32TypedColumn(string name, string source) : + base(name, source, typeof(uint).ToString()) + { + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_CreateEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator() + { + _transformerHandler = CreateTransformerFromEstimatorBase(); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out 
IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_Transform"), SuppressUnmanagedCodeSecurity] + private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, uint input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle); + internal override string Transform(uint input) + { + var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative)) + { + byte[] buffer = new byte[outputSize.ToInt32()]; + Marshal.Copy(output, buffer, 0, buffer.Length); + return Encoding.UTF8.GetString(buffer); + } + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + 
CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #region UInt64Column + + internal sealed class UInt64TypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + + internal UInt64TypedColumn(string name, string source) : + base(name, source, typeof(ulong).ToString()) + { + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_CreateEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", 
EntryPoint = "StringFeaturizer_uint64_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator() + { + _transformerHandler = CreateTransformerFromEstimatorBase(); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_Transform"), SuppressUnmanagedCodeSecurity] + private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ulong input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle); + internal override string Transform(ulong input) + { + var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out 
IntPtr outputSize, out IntPtr errorHandle); + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative)) + { + byte[] buffer = new byte[outputSize.ToInt32()]; + Marshal.Copy(output, buffer, 0, buffer.Length); + return Encoding.UTF8.GetString(buffer); + } + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #region FloatColumn + + internal sealed class FloatTypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + + 
internal FloatTypedColumn(string name, string source) : + base(name, source, typeof(float).ToString()) + { + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_CreateEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator() + { + _transformerHandler = CreateTransformerFromEstimatorBase(); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = 
"StringFeaturizer_float_t_Transform"), SuppressUnmanagedCodeSecurity] + + private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, float input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle); + internal override string Transform(float input) + { + var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative)) + { + byte[] buffer = new byte[outputSize.ToInt32()]; + Marshal.Copy(output, buffer, 0, buffer.Length); + return Encoding.UTF8.GetString(buffer); + } + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_CreateTransformerSaveData", CallingConvention = 
CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #region DoubleColumn + + internal sealed class DoubleTypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + + internal DoubleTypedColumn(string name, string source) : + base(name, source, typeof(double).ToString()) + { + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_CreateEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + internal override void CreateTransformerFromEstimator() + { + _transformerHandler = CreateTransformerFromEstimatorBase(); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void 
CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_Transform"), SuppressUnmanagedCodeSecurity] + private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, double input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle); + internal override string Transform(double input) + { + var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle); + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative)) + { + byte[] buffer = new byte[outputSize.ToInt32()]; + Marshal.Copy(output, buffer, 0, buffer.Length); + return Encoding.UTF8.GetString(buffer); + } + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(out estimator, out errorHandle); + + private protected override bool 
CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + #region BoolColumn + + internal sealed class BoolTypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + + internal BoolTypedColumn(string name, string source) : + base(name, source, typeof(bool).ToString()) + { + } + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_CreateEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_DestroyEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity] + private 
static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+
+            // Stores the transformer handle produced by the base-class helper.
+            internal override void CreateTransformerFromEstimator() =>
+                _transformerHandler = CreateTransformerFromEstimatorBase();
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+            private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+
+            // Recreates the native transformer from a saved byte buffer and wraps the returned pointer
+            // in a safe handle that is given DestroyTransformerNative. Throws when the native call reports failure.
+            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+            {
+                bool created = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr nativeTransformer, out IntPtr nativeError);
+                if (created)
+                {
+                    _transformerHandler = new TransformerEstimatorSafeHandle(nativeTransformer, DestroyTransformerNative);
+                    return;
+                }
+
+                throw new Exception(GetErrorDetailsAndFreeNativeMemory(nativeError));
+            }
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_Transform"), SuppressUnmanagedCodeSecurity]
+            private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, bool input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+            private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+
+            // Runs the native transform on a single bool; a false return value is surfaced as an exception.
+            internal override string Transform(bool input)
+            {
+                if (!TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle))
+                {
+                    throw new
Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+                }
+                // The handle is constructed with DestroyTransformedDataNative; the using block disposes it
+                // once the bytes have been copied into managed memory.
+                using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+                {
+                    byte[] buffer = new byte[outputSize.ToInt32()];
+                    Marshal.Copy(output, buffer, 0, buffer.Length);
+                    return Encoding.UTF8.GetString(buffer);
+                }
+            }
+
+            // Releases the native transformer handle if one was created and is still open.
+            public override void Dispose()
+            {
+                // _transformerHandler is only assigned once a transformer has been created, so it can
+                // still be null here; Dispose must not throw NullReferenceException in that case.
+                if (_transformerHandler != null && !_transformerHandler.IsClosed)
+                    _transformerHandler.Dispose();
+            }
+
+            // The helpers below forward the base-class hooks to this type's P/Invoke bindings.
+            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+                CreateEstimatorNative(out estimator, out errorHandle);
+
+            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+                DestroyEstimatorNative(estimator, out errorHandle);
+
+            private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+                DestroyTransformerNative(transformer, out errorHandle);
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+        }
+
+        #endregion
+
+        #region StringColumn
+
+        // Column wrapper for string inputs, bound to the native "StringFeaturizer_string_*" exports of the
+        // "Featurizers" library. The generic argument is string, matching the Transform(string) override.
+        internal sealed class StringTypedColumn : TypedColumn<string>
+        {
+            // Safe handle around the native transformer; assigned when the transformer is created
+            // from the estimator or from saved data.
+            private TransformerEstimatorSafeHandle _transformerHandler;
+
+            internal StringTypedColumn(string name, string source) :
+                base(name, source,
typeof(string).ToString())
+            {
+            }
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+            private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+            private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+
+            // Stores the transformer handle produced by the base-class helper.
+            internal override void CreateTransformerFromEstimator() =>
+                _transformerHandler = CreateTransformerFromEstimatorBase();
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+            private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+
+            // Recreates the native transformer from a saved byte buffer and wraps the returned pointer
+            // in a safe handle that is given DestroyTransformerNative. Throws when the native call reports failure.
+            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+            {
+                bool created = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr nativeTransformer, out IntPtr nativeError);
+                if (created)
+                {
+                    _transformerHandler = new TransformerEstimatorSafeHandle(nativeTransformer, DestroyTransformerNative);
+                    return;
+                }
+
+                throw new Exception(GetErrorDetailsAndFreeNativeMemory(nativeError));
+            }
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_Transform"), SuppressUnmanagedCodeSecurity]
+            private static extern bool
TransformDataNative(TransformerEstimatorSafeHandle transformer, IntPtr input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+            private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+
+            // Passes the input to the native transform as pinned, null-terminated UTF-8 bytes and decodes
+            // the native output buffer as UTF-8. Throws when the native call reports failure.
+            internal override string Transform(string input)
+            {
+                // Encode as UTF-8 with a trailing '\0' so the native side receives a null-terminated string.
+                var rawData = Encoding.UTF8.GetBytes(input + char.MinValue);
+                bool result;
+                // Pin the managed array so its address stays valid for the duration of the native call.
+                GCHandle handle = GCHandle.Alloc(rawData, GCHandleType.Pinned);
+                try
+                {
+                    IntPtr rawDataPtr = handle.AddrOfPinnedObject();
+                    result = TransformDataNative(_transformerHandler, rawDataPtr, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+
+                    if (!result)
+                    {
+                        throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+                    }
+
+                    // Same guard as ReadOnlyCharTypedColumn.Transform: with a zero-length result there is
+                    // nothing to copy, and Marshal.Copy throws ArgumentNullException if the source pointer
+                    // is null (NOTE(review): assumes the native side may return a null buffer for an
+                    // empty result - confirm against the native API).
+                    if (outputSize.ToInt32() == 0)
+                        return string.Empty;
+
+                    using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+                    {
+                        byte[] buffer = new byte[outputSize.ToInt32()];
+                        Marshal.Copy(output, buffer, 0, buffer.Length);
+                        return Encoding.UTF8.GetString(buffer);
+                    }
+                }
+                finally
+                {
+                    handle.Free();
+                }
+            }
+
+            // Releases the native transformer handle if one was created and is still open.
+            public override void Dispose()
+            {
+                // _transformerHandler is only assigned once a transformer has been created, so it can
+                // still be null here; Dispose must not throw NullReferenceException in that case.
+                if (_transformerHandler != null && !_transformerHandler.IsClosed)
+                    _transformerHandler.Dispose();
+            }
+
+            // The helpers below forward the base-class hooks to this type's P/Invoke bindings.
+            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+                CreateEstimatorNative(out estimator, out errorHandle);
+
+            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+                DestroyEstimatorNative(estimator, out errorHandle);
+
+            private protected override bool DestroyTransformerHelper(IntPtr transformer, out
IntPtr errorHandle) =>
+                DestroyTransformerNative(transformer, out errorHandle);
+
+            // The string featurizer's exports are named "StringFeaturizer_string_*" (no "_t" suffix), as in
+            // every other binding of this class and in ReadOnlyCharTypedColumn's save-data binding; the
+            // previous "StringFeaturizer_string_t_..." name did not match that naming.
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+        }
+
+        #endregion
+
+        #region ReadOnlyCharColumn
+
+        // Column wrapper for ReadOnlyMemory<char> inputs. It binds to the same native
+        // "StringFeaturizer_string_*" exports as StringTypedColumn and converts the input to a string first.
+        internal sealed class ReadOnlyCharTypedColumn : TypedColumn<ReadOnlyMemory<char>>
+        {
+            // Safe handle around the native transformer; assigned when the transformer is created
+            // from the estimator or from saved data.
+            private TransformerEstimatorSafeHandle _transformerHandler;
+
+            internal ReadOnlyCharTypedColumn(string name, string source) :
+                base(name, source, typeof(ReadOnlyMemory<char>).ToString())
+            {
+            }
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+            private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+            private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+            private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+            internal override void CreateTransformerFromEstimator()
+            {
+                _transformerHandler = CreateTransformerFromEstimatorBase();
+            }
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+
private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+
+            // Recreates the native transformer from a saved byte buffer and wraps the returned pointer
+            // in a safe handle that is given DestroyTransformerNative. Throws when the native call reports failure.
+            private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+            {
+                var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+                if (!result)
+                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+                _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+            }
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+            private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_Transform"), SuppressUnmanagedCodeSecurity]
+            private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, IntPtr input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+            private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+
+            // Converts the memory to a string, passes it to the native transform as pinned, null-terminated
+            // UTF-8 bytes, and decodes the native output buffer as UTF-8.
+            internal override string Transform(ReadOnlyMemory<char> input)
+            {
+                // Append '\0' so the native side receives a null-terminated UTF-8 string.
+                var rawData = Encoding.UTF8.GetBytes(input.ToString() + char.MinValue);
+                bool result;
+                // Pin the managed array so its address stays valid for the duration of the native call.
+                GCHandle handle = GCHandle.Alloc(rawData, GCHandleType.Pinned);
+                try
+                {
+                    IntPtr rawDataPtr = handle.AddrOfPinnedObject();
+                    result = TransformDataNative(_transformerHandler, rawDataPtr, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+
+                    if (!result)
+                    {
+                        throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+                    }
+
+                    // A zero-length result is returned directly; no buffer is wrapped or copied.
+                    if (outputSize.ToInt32() == 0)
+                        return string.Empty;
+
+                    using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+
{
+                        byte[] buffer = new byte[outputSize.ToInt32()];
+                        Marshal.Copy(output, buffer, 0, buffer.Length);
+                        return Encoding.UTF8.GetString(buffer);
+                    }
+                }
+                finally
+                {
+                    // Always unpin the input bytes, including on the exception and early-return paths.
+                    handle.Free();
+                }
+            }
+
+            // Releases the native transformer handle if one was created and is still open.
+            public override void Dispose()
+            {
+                // _transformerHandler is only assigned once a transformer has been created, so it can
+                // still be null here; Dispose must not throw NullReferenceException in that case.
+                if (_transformerHandler != null && !_transformerHandler.IsClosed)
+                    _transformerHandler.Dispose();
+            }
+
+            // The helpers below forward the base-class hooks to this type's P/Invoke bindings.
+            private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+                CreateEstimatorNative(out estimator, out errorHandle);
+
+            private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+                CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+            private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+                DestroyEstimatorNative(estimator, out errorHandle);
+
+            private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+                DestroyTransformerNative(transformer, out errorHandle);
+
+            [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+            private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+            private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+                CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+        }
+
+        #endregion
+
+        #endregion
+
+        // Row mapper for ToStringTransformer: exposes one output column per configured column and
+        // fills it by calling that column's Transform on the source value.
+        private sealed class Mapper : MapperBase
+        {
+
+            #region Class data members
+
+            private readonly ToStringTransformer _parent;
+
+            #endregion
+
+            public Mapper(ToStringTransformer parent, DataViewSchema inputSchema) :
+                base(parent.Host.Register(nameof(Mapper)), inputSchema, parent)
+            {
+                _parent = parent;
+            }
+
+            protected override
DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
+            {
+                // One string-typed output column per configured column, named after the column's Name.
+                return _parent._columns.Select(x => new DataViewSchema.DetachedColumn(x.Name, ColumnTypeExtensions.PrimitiveTypeFromType(typeof(string)))).ToArray();
+            }
+
+            // Builds the getter for output column iinfo, where T is the raw type of its source column.
+            private Delegate MakeGetter<T>(DataViewRow input, int iinfo)
+            {
+                // Resolve the source column, its getter and the typed column once per getter
+                // instead of repeating the lookups for every row.
+                var inputColumn = input.Schema[_parent._columns[iinfo].Source];
+                var srcGetter = input.GetGetter<T>(inputColumn);
+                var typedColumn = (TypedColumn<T>)_parent._columns[iinfo];
+
+                ValueGetter<ReadOnlyMemory<char>> result = (ref ReadOnlyMemory<char> dst) =>
+                {
+                    T value = default;
+                    srcGetter(ref value);
+                    // AsMemory wraps the string directly instead of copying its characters into a new array.
+                    dst = typedColumn.Transform(value).AsMemory();
+                };
+
+                return result;
+            }
+
+            // Dispatches to MakeGetter<T> using the raw type of the source column; no disposer is needed.
+            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
+            {
+                disposer = null;
+                Type inputType = input.Schema[_parent._columns[iinfo].Source].Type.RawType;
+                return Utils.MarshalInvoke(MakeGetter<int>, inputType, input, iinfo);
+            }
+
+            // Marks as required every input column whose name is the Source of a configured column.
+            private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
+            {
+                var active = new bool[InputSchema.Count];
+                for (int i = 0; i < InputSchema.Count; i++)
+                {
+                    if (_parent._columns.Any(x => x.Source == InputSchema[i].Name))
+                    {
+                        active[i] = true;
+                    }
+                }
+
+                return col => active[col];
+            }
+
+            private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
+        }
+    }
+
+    internal static class ToStringTransformerEntrypoint
+    {
+        // Entry point "Transforms.ToString": fits the estimator on the input data and returns the
+        // transformed data together with the transform model.
+        [TlcModule.EntryPoint(Name = "Transforms.ToString",
+            Desc = ToStringTransformer.Summary,
+            UserName = ToStringTransformer.UserName,
+            ShortName = ToStringTransformer.ShortName)]
+        public static CommonOutputs.TransformOutput ToString(IHostEnvironment env, ToStringTransformerEstimator.Options input)
+        {
+            var h = EntryPointUtils.CheckArgsAndCreateHost(env, ToStringTransformer.ShortName, input);
+            var xf = new ToStringTransformerEstimator(h, input).Fit(input.Data).Transform(input.Data);
+            return new CommonOutputs.TransformOutput()
+            {
+                Model = new TransformModelImpl(h, xf,
input.Data), + OutputData = xf + }; + } + } + +} diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 4f2bcc426a..fd4fcbc369 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -76,6 +76,8 @@ Transforms.BinaryPredictionScoreColumnsRenamer For binary prediction, it renames Transforms.BinNormalizer The values are assigned into equidensity bins and a value is mapped to its bin_number/number_of_bins. Microsoft.ML.Data.Normalize Bin Microsoft.ML.Transforms.NormalizeTransform+BinArguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.CategoricalHashOneHotVectorizer Converts the categorical value into an indicator array by hashing the value and using the hash as an index in the bag. If the input column is a vector, a single indicator bag is returned for it. Microsoft.ML.Transforms.Categorical CatTransformHash Microsoft.ML.Transforms.OneHotHashEncodingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.CategoricalOneHotVectorizer Converts the categorical value into an indicator array by building a dictionary of categories based on the data and using the id in the dictionary as the index in the array. 
Microsoft.ML.Transforms.Categorical CatTransformDict Microsoft.ML.Transforms.OneHotEncodingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.CategoryImputer Fills in missing values in a column based on the most frequent value Microsoft.ML.Featurizers.CategoryImputerEntrypoint ImputeToKey Microsoft.ML.Featurizers.CategoryImputerTransformerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.CatImputer Microsoft.ML.Featurizers.CatImputerEntrypoint CatImputer Microsoft.ML.Featurizers.CatImputerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.CharacterTokenizer Character-oriented tokenizer where text is considered a sequence of characters. Microsoft.ML.Transforms.Text.TextAnalytics CharTokenize Microsoft.ML.Transforms.Text.TokenizingByCharactersTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ColumnConcatenator Concatenates one or more columns of the same item type. 
Microsoft.ML.EntryPoints.SchemaManipulation ConcatColumns Microsoft.ML.Data.ColumnConcatenatingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ColumnCopier Duplicates columns from the dataset Microsoft.ML.EntryPoints.SchemaManipulation CopyColumns Microsoft.ML.Transforms.ColumnCopyingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput @@ -85,6 +87,7 @@ Transforms.CombinerByContiguousGroupId Groups values of a scalar column into a v Transforms.ConditionalNormalizer Normalize the columns only if needed Microsoft.ML.Data.Normalize IfNeeded Microsoft.ML.Transforms.NormalizeTransform+MinMaxArguments Microsoft.ML.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput] Transforms.DatasetScorer Score a dataset with a predictor model Microsoft.ML.EntryPoints.ScoreModel Score Microsoft.ML.EntryPoints.ScoreModel+Input Microsoft.ML.EntryPoints.ScoreModel+Output Transforms.DatasetTransformScorer Score a dataset with a transform model Microsoft.ML.EntryPoints.ScoreModel ScoreUsingTransform Microsoft.ML.EntryPoints.ScoreModel+InputTransformScorer Microsoft.ML.EntryPoints.ScoreModel+Output +Transforms.DateTimeSplitter Splits a date time value into each individual component Microsoft.ML.Featurizers.DateTimeTransformerEntrypoint DateTimeSplit Microsoft.ML.Featurizers.DateTimeTransformerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.Dictionarizer Converts input values (words, numbers, etc.) to index in a dictionary. Microsoft.ML.Transforms.Text.TextAnalytics TermTransform Microsoft.ML.Transforms.ValueToKeyMappingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.FeatureCombiner Combines all the features into one feature column. 
Microsoft.ML.EntryPoints.FeatureCombiner PrepareFeatures Microsoft.ML.EntryPoints.FeatureCombiner+FeatureCombinerInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.FeatureContributionCalculationTransformer For each data point, calculates the contribution of individual features to the model prediction. Microsoft.ML.Transforms.FeatureContributionEntryPoint FeatureContributionCalculation Microsoft.ML.Transforms.FeatureContributionCalculatingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput @@ -119,6 +122,8 @@ Transforms.PcaCalculator PCA is a dimensionality-reduction transform which compu Transforms.PermutationFeatureImportance Permutation Feature Importance (PFI) Microsoft.ML.Transforms.PermutationFeatureImportanceEntryPoints PermutationFeatureImportance Microsoft.ML.Transforms.PermutationFeatureImportanceArguments Microsoft.ML.Transforms.PermutationFeatureImportanceOutput Transforms.PredictedLabelColumnOriginalValueConverter Transforms a predicted label column to its original values, unless it is of type bool. Microsoft.ML.EntryPoints.FeatureCombiner ConvertPredictedLabel Microsoft.ML.EntryPoints.FeatureCombiner+PredictedLabelInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.RandomNumberGenerator Adds a column with a generated number sequence. Microsoft.ML.Transforms.RandomNumberGenerator Generate Microsoft.ML.Transforms.GenerateNumberTransform+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.RobScal Microsoft.ML.Featurizers.RobScalEntrypoint RobScal Microsoft.ML.Featurizers.RobScalEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.RobustScaler Removes the median and scales the data according to the quantile range. 
Microsoft.ML.Featurizers.RobustScalerEntrypoint RobustScaler Microsoft.ML.Featurizers.RobustScalerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.RowRangeFilter Filters a dataview on a column of type Single, Double or Key (contiguous). Keeps the values that are in the specified min/max range. NaNs are always filtered out. If the input is a Key type, the min/max are considered percentages of the number of values. Microsoft.ML.EntryPoints.SelectRows FilterByRange Microsoft.ML.Transforms.RangeFilter+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.RowSkipAndTakeFilter Allows limiting input to a subset of rows at an optional offset. Can be used to implement data paging. Microsoft.ML.EntryPoints.SelectRows SkipAndTakeFilter Microsoft.ML.Transforms.SkipTakeFilter+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.RowSkipFilter Allows limiting input to a subset of rows by skipping a number of rows. Microsoft.ML.EntryPoints.SelectRows SkipFilter Microsoft.ML.Transforms.SkipTakeFilter+SkipOptions Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput @@ -130,6 +135,8 @@ Transforms.SentimentAnalyzer Uses a pretrained sentiment model to score input st Transforms.TensorFlowScorer Transforms the data using the TensorFlow model. Microsoft.ML.Transforms.TensorFlowTransformer TensorFlowScorer Microsoft.ML.Transforms.TensorFlowEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.TextFeaturizer A transform that turns a collection of text documents into numerical feature vectors. The feature vectors are normalized counts of (word and/or character) n-grams in a given tokenized text. Microsoft.ML.Transforms.Text.TextAnalytics TextTransform Microsoft.ML.Transforms.Text.TextFeaturizingEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.TextToKeyConverter Converts input values (words, numbers, etc.) to index in a dictionary. 
Microsoft.ML.Transforms.Categorical TextToKey Microsoft.ML.Transforms.ValueToKeyMappingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.TimeSeriesImputer Fills in missing row and values Microsoft.ML.Featurizers.TimeSeriesTransformerEntrypoint TimeSeriesImputer Microsoft.ML.Featurizers.TimeSeriesImputerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.ToString Turns the given column into a column of its string representation Microsoft.ML.Featurizers.ToStringTransformerEntrypoint ToString Microsoft.ML.Featurizers.ToStringTransformerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.TrainTestDatasetSplitter Split the dataset into train and test sets Microsoft.ML.EntryPoints.TrainTestSplit Split Microsoft.ML.EntryPoints.TrainTestSplit+Input Microsoft.ML.EntryPoints.TrainTestSplit+Output Transforms.TreeLeafFeaturizer Trains a tree ensemble, or loads it from a file, then maps a numeric feature vector to three outputs: 1. A vector containing the individual tree outputs of the tree ensemble. 2. A vector indicating the leaves that the feature vector falls on in the tree ensemble. 3. A vector indicating the paths that the feature vector falls on in the tree ensemble. If a both a model file and a trainer are specified - will use the model file. If neither are specified, will train a default FastTree model. This can handle key labels by training a regression model towards their optionally permuted indices. Microsoft.ML.Data.TreeFeaturize Featurizer Microsoft.ML.Data.TreeEnsembleFeaturizerTransform+ArgumentsForEntryPoint Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.TwoHeterogeneousModelCombiner Combines a TransformModel and a PredictorModel into a single PredictorModel. 
Microsoft.ML.EntryPoints.ModelOperations CombineTwoModels Microsoft.ML.EntryPoints.ModelOperations+SimplePredictorModelInput Microsoft.ML.EntryPoints.ModelOperations+PredictorModelOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index c8e6d6e55c..69da1640b9 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -17301,6 +17301,158 @@ "ITransformOutput" ] }, + { + "Name": "Transforms.CategoryImputer", + "Desc": "Fills in missing values in a column based on the most frequent value", + "FriendlyName": "CategoryImputer", + "ShortName": "CategoryImputer", + "Inputs": [ + { + "Name": "Column", + "Type": { + "Kind": "Array", + "ItemType": { + "Kind": "Struct", + "Fields": [ + { + "Name": "Name", + "Type": "String", + "Desc": "Name of the new column", + "Aliases": [ + "name" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Source", + "Type": "String", + "Desc": "Name of the source column", + "Aliases": [ + "src" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + } + ] + } + }, + "Desc": "New column definition (optional form: name:src)", + "Aliases": [ + "col" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, + { + "Name": "Transforms.CatImputer", + "Desc": "", + "FriendlyName": "CatImputerTransformer", + "ShortName": "CatImputerTransformer", + "Inputs": [ + { + "Name": 
"Column", + "Type": { + "Kind": "Array", + "ItemType": { + "Kind": "Struct", + "Fields": [ + { + "Name": "Name", + "Type": "String", + "Desc": "Name of the new column", + "Aliases": [ + "name" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Source", + "Type": "String", + "Desc": "Name of the source column", + "Aliases": [ + "src" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + } + ] + } + }, + "Desc": "New column definition (optional form: name:src)", + "Aliases": [ + "col" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, { "Name": "Transforms.CharacterTokenizer", "Desc": "Character-oriented tokenizer where text is considered a sequence of characters.", @@ -18077,6 +18229,157 @@ } ] }, + { + "Name": "Transforms.DateTimeSplitter", + "Desc": "Splits a date time value into each individual component", + "FriendlyName": "DateTime Transform", + "ShortName": "DateTimeTransform", + "Inputs": [ + { + "Name": "Source", + "Type": "String", + "Desc": "Input column", + "Aliases": [ + "src" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Prefix", + "Type": "String", + "Desc": "Output column prefix", + "Aliases": [ + "pre" + ], + "Required": true, + "SortOrder": 2.0, + "IsNullable": false + }, + { + "Name": "ColumnsToDrop", + "Type": { + "Kind": "Array", + "ItemType": 
{ + "Kind": "Enum", + "Values": [ + "Year", + "Month", + "Day", + "Hour", + "Minute", + "Second", + "AmPm", + "Hour12", + "DayOfWeek", + "DayOfQuarter", + "DayOfYear", + "WeekOfMonth", + "QuarterOfYear", + "HalfOfYear", + "WeekIso", + "YearIso", + "MonthLabel", + "AmPmLabel", + "DayOfWeekLabel", + "HolidayName", + "IsPaidTimeOff" + ] + } + }, + "Desc": "Columns to drop after the DateTime Expansion", + "Aliases": [ + "drop" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Country", + "Type": { + "Kind": "Enum", + "Values": [ + "None", + "Argentina", + "Australia", + "Austria", + "Belarus", + "Belgium", + "Brazil", + "Canada", + "Colombia", + "Croatia", + "Czech", + "Denmark", + "England", + "Finland", + "France", + "Germany", + "Hungary", + "India", + "Ireland", + "IsleofMan", + "Italy", + "Japan", + "Mexico", + "Netherlands", + "NewZealand", + "NorthernIreland", + "Norway", + "Poland", + "Portugal", + "Scotland", + "Slovenia", + "SouthAfrica", + "Spain", + "Sweden", + "Switzerland", + "Ukraine", + "UnitedKingdom", + "UnitedStates", + "Wales" + ] + }, + "Desc": "Country to get holidays for. Defaults to none if not passed", + "Aliases": [ + "ctry" + ], + "Required": false, + "SortOrder": 4.0, + "IsNullable": false, + "Default": "None" + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, { "Name": "Transforms.Dictionarizer", "Desc": "Converts input values (words, numbers, etc.) to index in a dictionary.", @@ -21932,15 +22235,215 @@ ] }, { - "Name": "Transforms.RowRangeFilter", - "Desc": "Filters a dataview on a column of type Single, Double or Key (contiguous). Keeps the values that are in the specified min/max range. NaNs are always filtered out. 
If the input is a Key type, the min/max are considered percentages of the number of values.", - "FriendlyName": "Range Filter", - "ShortName": "RangeFilter", + "Name": "Transforms.RobScal", + "Desc": "", + "FriendlyName": "RobScalTransformer", + "ShortName": "RobScalTransformer", "Inputs": [ { "Name": "Column", - "Type": "String", - "Desc": "Column", + "Type": { + "Kind": "Array", + "ItemType": { + "Kind": "Struct", + "Fields": [ + { + "Name": "Name", + "Type": "String", + "Desc": "Name of the new column", + "Aliases": [ + "name" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Source", + "Type": "String", + "Desc": "Name of the source column", + "Aliases": [ + "src" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + } + ] + } + }, + "Desc": "New column definition (optional form: name:src)", + "Aliases": [ + "col" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, + { + "Name": "Transforms.RobustScaler", + "Desc": "Removes the median and scales the data according to the quantile range.", + "FriendlyName": "RobustScalerTransformer", + "ShortName": "RobustScalerTransformer", + "Inputs": [ + { + "Name": "Column", + "Type": { + "Kind": "Array", + "ItemType": { + "Kind": "Struct", + "Fields": [ + { + "Name": "Name", + "Type": "String", + "Desc": "Name of the new column", + "Aliases": [ + "name" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Source", + "Type": "String", + 
"Desc": "Name of the source column", + "Aliases": [ + "src" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + } + ] + } + }, + "Desc": "New column definition (optional form: name:src)", + "Aliases": [ + "col" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Center", + "Type": "Bool", + "Desc": "If True, center the data before scaling.", + "Aliases": [ + "ctr" + ], + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": true + }, + { + "Name": "Scale", + "Type": "Bool", + "Desc": "If True, scale the data to interquartile range.", + "Aliases": [ + "sc" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": true + }, + { + "Name": "QuantileMin", + "Type": "Float", + "Desc": "Min for the quantile range used to calculate scale.", + "Aliases": [ + "min" + ], + "Required": false, + "SortOrder": 4.0, + "IsNullable": false, + "Default": 25.0 + }, + { + "Name": "QuantileMax", + "Type": "Float", + "Desc": "Max for the quantile range used to calculate scale.", + "Aliases": [ + "max" + ], + "Required": false, + "SortOrder": 5.0, + "IsNullable": false, + "Default": 75.0 + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, + { + "Name": "Transforms.RowRangeFilter", + "Desc": "Filters a dataview on a column of type Single, Double or Key (contiguous). Keeps the values that are in the specified min/max range. NaNs are always filtered out. 
If the input is a Key type, the min/max are considered percentages of the number of values.", + "FriendlyName": "Range Filter", + "ShortName": "RangeFilter", + "Inputs": [ + { + "Name": "Column", + "Type": "String", + "Desc": "Column", "Aliases": [ "col" ], @@ -22941,6 +23444,207 @@ "ITransformOutput" ] }, + { + "Name": "Transforms.TimeSeriesImputer", + "Desc": "Fills in missing row and values", + "FriendlyName": "TimeSeriesImputer", + "ShortName": "TimeSeriesImputer", + "Inputs": [ + { + "Name": "TimeSeriesColumn", + "Type": "String", + "Desc": "Column representing the time", + "Aliases": [ + "time" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "GrainColumns", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, + "Desc": "List of grain columns", + "Aliases": [ + "grains" + ], + "Required": true, + "SortOrder": 2.0, + "IsNullable": false + }, + { + "Name": "FilterColumns", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, + "Desc": "Columns to filter", + "Aliases": [ + "filters" + ], + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "FilterMode", + "Type": { + "Kind": "Enum", + "Values": [ + "NoFilter", + "Include", + "Exclude" + ] + }, + "Desc": "Filter mode. 
Either include or exclude", + "Aliases": [ + "fmode" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": "Exclude" + }, + { + "Name": "ImputeMode", + "Type": { + "Kind": "Enum", + "Values": [ + "ForwardFill", + "BackFill", + "Median", + "Interpolate" + ] + }, + "Desc": "Mode for imputing, defaults to ForwardFill if not provided", + "Aliases": [ + "mode" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": "ForwardFill" + }, + { + "Name": "SupressTypeErrors", + "Type": "Bool", + "Desc": "Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column. If false, will stop and throw an error.", + "Aliases": [ + "error" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": false + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, + { + "Name": "Transforms.ToString", + "Desc": "Turns the given column into a column of its string representation", + "FriendlyName": "ToString Transform", + "ShortName": "ToStringTransform", + "Inputs": [ + { + "Name": "Column", + "Type": { + "Kind": "Array", + "ItemType": { + "Kind": "Struct", + "Fields": [ + { + "Name": "Name", + "Type": "String", + "Desc": "Name of the new column", + "Aliases": [ + "name" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Source", + "Type": "String", + "Desc": "Name of the source column", + "Aliases": [ + "src" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + } + ] + } + }, + "Desc": "New column definition (optional form: name:src)", + "Aliases": [ + "col" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": 
"Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, { "Name": "Transforms.TrainTestDatasetSplitter", "Desc": "Split the dataset into train and test sets", diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 87a0c61f1c..cba56c6e68 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -7,6 +7,7 @@ using System.IO; using System.Linq; using System.Text.RegularExpressions; +using Microsoft.ML.Featurizers; using Microsoft.ML.Calibrators; using Microsoft.ML.Core.Tests.UnitTests; using Microsoft.ML.Data; @@ -328,6 +329,7 @@ public void EntryPointCatalogCheckDuplicateParams() Env.ComponentCatalog.RegisterAssembly(typeof(SymbolicSgdLogisticRegressionBinaryTrainer).Assembly); Env.ComponentCatalog.RegisterAssembly(typeof(SaveOnnxCommand).Assembly); Env.ComponentCatalog.RegisterAssembly(typeof(TimeSeriesProcessingEntryPoints).Assembly); + Env.ComponentCatalog.RegisterAssembly(typeof(TimeSeriesImputerEstimator).Assembly); Env.ComponentCatalog.RegisterAssembly(typeof(ParquetLoader).Assembly); var catalog = Env.ComponentCatalog; diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 453fb8c6c8..1241918e3e 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -10,6 +10,7 @@ + diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs 
b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs index 36eb2d1b6b..026684e1b4 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs @@ -5,6 +5,9 @@ using Microsoft.ML.Data; using Microsoft.ML.RunTests; using Microsoft.ML.Trainers; +using Microsoft.ML.Transforms; +using System; +using System.Collections.Generic; using Xunit; namespace Microsoft.ML.Scenarios diff --git a/test/Microsoft.ML.Tests/Transformers/CategoryImputerTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoryImputerTests.cs new file mode 100644 index 0000000000..ad548fcda0 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/CategoryImputerTests.cs @@ -0,0 +1,121 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; +using Microsoft.ML.RunTests; +using System; +using System.Collections.Generic; +using Xunit; +using Xunit.Abstractions; +using System.Linq; + +namespace Microsoft.ML.Tests.Transformers +{ + public class CategoryImputerTests : TestDataPipeBase + { + public CategoryImputerTests(ITestOutputHelper output) : base(output) + { + } + + private class SchemaAllTypes + { + public byte uint8_t; + public sbyte int8_t; + public short int16_t; + public ushort uint16_t; + public int int32_t; + public uint uint32_t; + public long int64_t; + public ulong uint64_t; + public float float_t; + public double double_t; + public string str; + + internal SchemaAllTypes(byte num_i, float num_f, string s) + { + uint8_t = num_i; + int8_t = (sbyte)num_i; + int16_t = num_i; + uint16_t = num_i; + int32_t = num_i; + uint32_t = num_i; + int64_t = num_i; + uint64_t = num_i; + float_t = num_f; + double_t = num_f; + str = s; + } + } + + private IDataView GetIDataView() + { + List dataList = new List(); + dataList.Add(new SchemaAllTypes(1, 1.5f, "one")); + dataList.Add(new SchemaAllTypes(1, 1.5f, "one")); + dataList.Add(new SchemaAllTypes(2, 2.5f, "two")); + dataList.Add(new SchemaAllTypes(0, Single.NaN, null)); + dataList.Add(new SchemaAllTypes(1, 1.5f, "one")); + dataList.Add(new SchemaAllTypes(1, 1.5f, "one")); + dataList.Add(new SchemaAllTypes(2, 2.5f, "two")); + dataList.Add(new SchemaAllTypes(0, Single.NaN, null)); + dataList.Add(new SchemaAllTypes(1, 1.5f, "one")); + dataList.Add(new SchemaAllTypes(1, 1.5f, "one")); + dataList.Add(new SchemaAllTypes(2, 2.5f, "two")); + dataList.Add(new SchemaAllTypes(0, Single.NaN, null)); + + MLContext mlContext = new MLContext(1); + IDataView data = mlContext.Data.LoadFromEnumerable(dataList); + return data; + } + private void Test(string columnName, bool addNewCol, T mostFrequentValue) + { + + string outputColName = addNewCol ? 
columnName + "_output" : columnName; + string inputColName = addNewCol ? columnName : null; + + MLContext mlContext = new MLContext(1); + var data = GetIDataView(); + var pipeline = mlContext.Transforms.CatagoryImputerTransformer(outputColName, inputColName); + TestEstimatorCore(pipeline, data); + var model = pipeline.Fit(data); + var output = model.Transform(data); + + List transformedColData = output.GetColumn(outputColName).ToList(); + Assert.Equal(mostFrequentValue, transformedColData[3]); + Assert.Equal(mostFrequentValue, transformedColData[7]); + Assert.Equal(mostFrequentValue, transformedColData[11]); + } + + [Fact] + public void TestAllTypes() + { + Test("int8_t", false, 1); + Test("int8_t", true, 1); + Test("uint8_t", false, 1); + Test("uint8_t", true, 1); + Test("int16_t", false, 1); + Test("int16_t", true, 1); + Test("uint16_t", false, 1); + Test("uint16_t", true, 1); + Test("int32_t", false, 1); + Test("int32_t", true, 1); + Test("uint32_t", false, 1); + Test("uint32_t", true, 1); + Test("int64_t", false, 1); + Test("int64_t", true, 1); + Test("uint64_t", false, 1); + Test("uint64_t", true, 1); + Test("float_t", false, 1.5f); + Test("float_t", true, 1.5f); + Test("double_t", false, 1.5); + Test("double_t", true, 1.5); + Test("str", false, "one"); + Test("str", true, "one"); + + Done(); + } + + } +} diff --git a/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs new file mode 100644 index 0000000000..62a9f234c3 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs @@ -0,0 +1,360 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Data; +using Microsoft.ML.RunTests; +using Microsoft.ML.Featurizers; +using System; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class DateTimeTransformerTests : TestDataPipeBase + { + public DateTimeTransformerTests(ITestOutputHelper output) : base(output) + { + } + + private class DateTimeInput + { + public long date; + } + + [Fact] + public void CorrectNumberOfColumnsAndSchema() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 0} }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var columnPrefix = "DTC_"; + var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + // Check the schema has 22 columns + Assert.Equal(22, schema.Count); + + // Make sure names with prefix and order are correct + Assert.Equal($"{columnPrefix}Year", schema[1].Name); + Assert.Equal($"{columnPrefix}Month", schema[2].Name); + Assert.Equal($"{columnPrefix}Day", schema[3].Name); + Assert.Equal($"{columnPrefix}Hour", schema[4].Name); + Assert.Equal($"{columnPrefix}Minute", schema[5].Name); + Assert.Equal($"{columnPrefix}Second", schema[6].Name); + Assert.Equal($"{columnPrefix}AmPm", schema[7].Name); + Assert.Equal($"{columnPrefix}Hour12", schema[8].Name); + Assert.Equal($"{columnPrefix}DayOfWeek", schema[9].Name); + Assert.Equal($"{columnPrefix}DayOfQuarter", schema[10].Name); + Assert.Equal($"{columnPrefix}DayOfYear", schema[11].Name); + Assert.Equal($"{columnPrefix}WeekOfMonth", schema[12].Name); + Assert.Equal($"{columnPrefix}QuarterOfYear", schema[13].Name); + Assert.Equal($"{columnPrefix}HalfOfYear", schema[14].Name); + Assert.Equal($"{columnPrefix}WeekIso", schema[15].Name); + Assert.Equal($"{columnPrefix}YearIso", schema[16].Name); + Assert.Equal($"{columnPrefix}MonthLabel", 
schema[17].Name); + Assert.Equal($"{columnPrefix}AmPmLabel", schema[18].Name); + Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[19].Name); + Assert.Equal($"{columnPrefix}HolidayName", schema[20].Name); + Assert.Equal($"{columnPrefix}IsPaidTimeOff", schema[21].Name); + + // Make sure types are correct + Assert.Equal(typeof(int), schema[1].Type.RawType); + Assert.Equal(typeof(byte), schema[2].Type.RawType); + Assert.Equal(typeof(byte), schema[3].Type.RawType); + Assert.Equal(typeof(byte), schema[4].Type.RawType); + Assert.Equal(typeof(byte), schema[5].Type.RawType); + Assert.Equal(typeof(byte), schema[6].Type.RawType); + Assert.Equal(typeof(byte), schema[7].Type.RawType); + Assert.Equal(typeof(byte), schema[8].Type.RawType); + Assert.Equal(typeof(byte), schema[9].Type.RawType); + Assert.Equal(typeof(byte), schema[10].Type.RawType); + Assert.Equal(typeof(ushort), schema[11].Type.RawType); + Assert.Equal(typeof(ushort), schema[12].Type.RawType); + Assert.Equal(typeof(byte), schema[13].Type.RawType); + Assert.Equal(typeof(byte), schema[14].Type.RawType); + Assert.Equal(typeof(byte), schema[15].Type.RawType); + Assert.Equal(typeof(int), schema[16].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[17].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[18].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[19].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[20].Type.RawType); + Assert.Equal(typeof(byte), schema[21].Type.RawType); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void DropOneColumn() + { + // TODO: This will fail until we figure out the C++ dll situation + + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 0 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. 
+ var columnPrefix = "DTC_"; + var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + // Check the schema has 21 columns + Assert.Equal(21, schema.Count); + + // Make sure names with prefix and order are correct + Assert.Equal($"{columnPrefix}Year", schema[1].Name); + Assert.Equal($"{columnPrefix}Month", schema[2].Name); + Assert.Equal($"{columnPrefix}Day", schema[3].Name); + Assert.Equal($"{columnPrefix}Hour", schema[4].Name); + Assert.Equal($"{columnPrefix}Minute", schema[5].Name); + Assert.Equal($"{columnPrefix}Second", schema[6].Name); + Assert.Equal($"{columnPrefix}AmPm", schema[7].Name); + Assert.Equal($"{columnPrefix}Hour12", schema[8].Name); + Assert.Equal($"{columnPrefix}DayOfWeek", schema[9].Name); + Assert.Equal($"{columnPrefix}DayOfQuarter", schema[10].Name); + Assert.Equal($"{columnPrefix}DayOfYear", schema[11].Name); + Assert.Equal($"{columnPrefix}WeekOfMonth", schema[12].Name); + Assert.Equal($"{columnPrefix}QuarterOfYear", schema[13].Name); + Assert.Equal($"{columnPrefix}HalfOfYear", schema[14].Name); + Assert.Equal($"{columnPrefix}WeekIso", schema[15].Name); + Assert.Equal($"{columnPrefix}YearIso", schema[16].Name); + Assert.Equal($"{columnPrefix}MonthLabel", schema[17].Name); + Assert.Equal($"{columnPrefix}AmPmLabel", schema[18].Name); + Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[19].Name); + Assert.Equal($"{columnPrefix}HolidayName", schema[20].Name); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void DropManyColumns() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 0 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. 
+ var columnPrefix = "DTC_"; + var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff, + DateTimeTransformerEstimator.ColumnsProduced.Day, DateTimeTransformerEstimator.ColumnsProduced.QuarterOfYear, DateTimeTransformerEstimator.ColumnsProduced.AmPm); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + // Check the schema has 18 columns + Assert.Equal(18, schema.Count); + + // Make sure names with prefix and order are correct + Assert.Equal($"{columnPrefix}Year", schema[1].Name); + Assert.Equal($"{columnPrefix}Month", schema[2].Name); + Assert.Equal($"{columnPrefix}Hour", schema[3].Name); + Assert.Equal($"{columnPrefix}Minute", schema[4].Name); + Assert.Equal($"{columnPrefix}Second", schema[5].Name); + Assert.Equal($"{columnPrefix}Hour12", schema[6].Name); + Assert.Equal($"{columnPrefix}DayOfWeek", schema[7].Name); + Assert.Equal($"{columnPrefix}DayOfQuarter", schema[8].Name); + Assert.Equal($"{columnPrefix}DayOfYear", schema[9].Name); + Assert.Equal($"{columnPrefix}WeekOfMonth", schema[10].Name); + Assert.Equal($"{columnPrefix}HalfOfYear", schema[11].Name); + Assert.Equal($"{columnPrefix}WeekIso", schema[12].Name); + Assert.Equal($"{columnPrefix}YearIso", schema[13].Name); + Assert.Equal($"{columnPrefix}MonthLabel", schema[14].Name); + Assert.Equal($"{columnPrefix}AmPmLabel", schema[15].Name); + Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[16].Name); + Assert.Equal($"{columnPrefix}HolidayName", schema[17].Name); + + // Make sure types are correct + Assert.Equal(typeof(int), schema[1].Type.RawType); + Assert.Equal(typeof(byte), schema[2].Type.RawType); + Assert.Equal(typeof(byte), schema[3].Type.RawType); + Assert.Equal(typeof(byte), schema[4].Type.RawType); + Assert.Equal(typeof(byte), schema[5].Type.RawType); + Assert.Equal(typeof(byte), schema[6].Type.RawType); + Assert.Equal(typeof(byte), schema[7].Type.RawType); + 
Assert.Equal(typeof(byte), schema[8].Type.RawType); + Assert.Equal(typeof(ushort), schema[9].Type.RawType); + Assert.Equal(typeof(ushort), schema[10].Type.RawType); + Assert.Equal(typeof(byte), schema[11].Type.RawType); + Assert.Equal(typeof(byte), schema[12].Type.RawType); + Assert.Equal(typeof(int), schema[13].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[14].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[15].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[16].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[17].Type.RawType); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void CanUseDateFromColumn() + { + // Future Date - 2025 June 30 + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 1751241600 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + + // Get the data from the first row and make sure it matches expected + var row = output.Preview(1).RowView[0].Values; + + // Assert the data from the first row is what we expect + Assert.Equal(2025, row[1].Value); // Year + Assert.Equal((byte)6, row[2].Value); // Month + Assert.Equal((byte)30, row[3].Value); // Day + Assert.Equal((byte)0, row[4].Value); // Hour + Assert.Equal((byte)0, row[5].Value); // Minute + Assert.Equal((byte)0, row[6].Value); // Second + Assert.Equal((byte)0, row[7].Value); // AmPm + Assert.Equal((byte)0, row[8].Value); // Hour12 + Assert.Equal((byte)1, row[9].Value); // DayOfWeek + Assert.Equal((byte)91, row[10].Value); // DayOfQuarter + Assert.Equal((ushort)180, row[11].Value); // DayOfYear + Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth + Assert.Equal((byte)2, row[13].Value); // QuarterOfYear + Assert.Equal((byte)1, row[14].Value); // HalfOfYear + 
Assert.Equal((byte)27, row[15].Value); // WeekIso + Assert.Equal(2025, row[16].Value); // YearIso + Assert.Equal("June", row[17].Value.ToString()); // MonthLabel + Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel + Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel + Assert.Equal("", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void HolidayTest() + { + // Past Date - 1974 December 25 (Christmas Day) + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 157161600 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC", country: DateTimeTransformerEstimator.Countries.Canada); + var model = pipeline.Fit(data); + var output = model.Transform(data); + + // Get the data from the first row and make sure it matches expected + var row = output.Preview(1).RowView[0].Values; + + // Assert the data from the first row for holidays is what we expect + Assert.Equal("Christmas Day", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void ManyRowsTest() + { + // Future Date - 2025 June 30 + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 12341 }, + new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 1234 }, new DateTimeInput() { date = 1751241600 }, + new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 12341 }, + new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 1234 }}; + + var data = 
mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + + // Get the data from the first row and make sure it matches expected + var row = output.Preview().RowView[0].Values; + + // Assert the data from the first row is what we expect + Assert.Equal(2025, row[1].Value); // Year + Assert.Equal((byte)6, row[2].Value); // Month + Assert.Equal((byte)30, row[3].Value); // Day + Assert.Equal((byte)0, row[4].Value); // Hour + Assert.Equal((byte)0, row[5].Value); // Minute + Assert.Equal((byte)0, row[6].Value); // Second + Assert.Equal((byte)0, row[7].Value); // AmPm + Assert.Equal((byte)0, row[8].Value); // Hour12 + Assert.Equal((byte)1, row[9].Value); // DayOfWeek + Assert.Equal((byte)91, row[10].Value); // DayOfQuarter + Assert.Equal((ushort)180, row[11].Value); // DayOfYear + Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth + Assert.Equal((byte)2, row[13].Value); // QuarterOfYear + Assert.Equal((byte)1, row[14].Value); // HalfOfYear + Assert.Equal((byte)27, row[15].Value); // WeekIso + Assert.Equal(2025, row[16].Value); // YearIso + Assert.Equal("June", row[17].Value.ToString()); // MonthLabel + Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel + Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel + Assert.Equal("", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void EntryPointTest() + { + // Future Date - 2025 June 30 + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 1751241600 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the entry point options and invoke the DateTimeSplit entry point directly (no pipeline/fit here).
+ var options = new DateTimeTransformerEstimator.Options + { + ColumnsToDrop = null, + Source = "date", + Prefix = "pref_", + Data = data + }; + + var entryOutput = DateTimeTransformerEntrypoint.DateTimeSplit(mlContext.Transforms.GetEnvironment(), options); + var output = entryOutput.OutputData; + + // Get the data from the first row and make sure it matches expected + var row = output.Preview(1).RowView[0].Values; + + // Assert the data from the first row is what we expect + Assert.Equal(2025, row[1].Value); // Year + Assert.Equal((byte)6, row[2].Value); // Month + Assert.Equal((byte)30, row[3].Value); // Day + Assert.Equal((byte)0, row[4].Value); // Hour + Assert.Equal((byte)0, row[5].Value); // Minute + Assert.Equal((byte)0, row[6].Value); // Second + Assert.Equal((byte)0, row[7].Value); // AmPm + Assert.Equal((byte)0, row[8].Value); // Hour12 + Assert.Equal((byte)1, row[9].Value); // DayOfWeek + Assert.Equal((byte)91, row[10].Value); // DayOfQuarter + Assert.Equal((ushort)180, row[11].Value); // DayOfYear + Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth + Assert.Equal((byte)2, row[13].Value); // QuarterOfYear + Assert.Equal((byte)1, row[14].Value); // HalfOfYear + Assert.Equal((byte)27, row[15].Value); // WeekIso + Assert.Equal(2025, row[16].Value); // YearIso + Assert.Equal("June", row[17].Value.ToString()); // MonthLabel + Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel + Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel + Assert.Equal("", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + Done(); + } + } +} diff --git a/test/Microsoft.ML.Tests/Transformers/RobustScalerTests.cs b/test/Microsoft.ML.Tests/Transformers/RobustScalerTests.cs new file mode 100644 index 0000000000..d4e415e9f8 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/RobustScalerTests.cs @@ -0,0 +1,430 @@ +// Licensed to the .NET Foundation under one or more agreements. 
+// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; +using Microsoft.ML.RunTests; +using System; +using System.Collections.Generic; +using Xunit; +using Xunit.Abstractions; +using System.Linq; + +namespace Microsoft.ML.Tests.Transformers +{ + public class RobustScalerTests : TestDataPipeBase + { + public RobustScalerTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public void TestInvalidType() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = "Invalid Type" }}; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, should error on fit and on GetOutputSchema + var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA"); + + Assert.Throws(() => pipeline.Fit(data)); + Assert.Throws(() => pipeline.GetOutputSchema(SchemaShape.Create(data.Schema))); + + Done(); + } + + [Fact] + public void TestNoScale() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = 1f }, new { ColA = 3f }, new { ColA = 5f }, new { ColA = 7f }, new { ColA = 9f } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. 
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA", scale: false); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(float), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { -4f, -2f, 0f, 2f, 4f }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + float value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void TestNoScaleNoCenter() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = 1f }, new { ColA = 3f }, new { ColA = 5f }, new { ColA = 7f }, new { ColA = 9f } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA", scale: false, center: false); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(float), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { 1f, 3f, 5f, 7f, 9f }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + float value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void TestFloat() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = 1f }, new { ColA = 3f }, new { ColA = 5f }, new { ColA = 7f }, new { ColA = 9f } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. 
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(float), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { -1f, -0.5f, 0f, .5f, 1f }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + float value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void TestInt64() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = 1L }, new { ColA = 3L }, new { ColA = 5L }, new { ColA = 7L }, new { ColA = 9L } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + var prev = output.Preview(); + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(double), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { -1d, -0.5d, 0d, .5d, 1d }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + double value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void TestInt32() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = 1 }, new { ColA = 3 }, new { ColA = 5 }, new { ColA = 7 }, new { ColA = 9 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. 
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + var prev = output.Preview(); + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(double), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { -1d, -0.5d, 0d, .5d, 1d }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + double value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void TestInt16() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = (short)1 }, new { ColA = (short)3 }, new { ColA = (short)5 }, new { ColA = (short)7 }, new { ColA = (short)9 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. 
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + var prev = output.Preview(); + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(float), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { -1f, -0.5f, 0f, .5f, 1f }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + float value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void TestInt8() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = (sbyte)1 }, new { ColA = (sbyte)3 }, new { ColA = (sbyte)5 }, new { ColA = (sbyte)7 }, new { ColA = (sbyte)9 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. 
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + var prev = output.Preview(); + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(float), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { -1f, -0.5f, 0f, .5f, 1f }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + float value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void TestDouble() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = 1d }, new { ColA = 3d }, new { ColA = 5d }, new { ColA = 7d }, new { ColA = 9d } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(double), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { -1d, -0.5d, 0d, .5d, 1d }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + double value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void TestUInt64() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = (ulong)1 }, new { ColA = (ulong)3 }, new { ColA = (ulong)5 }, new { ColA = (ulong)7 }, new { ColA = (ulong)9 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform 
it. + var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + var prev = output.Preview(); + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(double), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { -1d, -0.5d, 0d, .5d, 1d }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + double value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void TestUInt32() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = (uint)1 }, new { ColA = (uint)3 }, new { ColA = (uint)5 }, new { ColA = (uint)7 }, new { ColA = (uint)9 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. 
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + var prev = output.Preview(); + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(double), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { -1d, -0.5d, 0d, .5d, 1d }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + double value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void TestUInt16() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = (ushort)1 }, new { ColA = (ushort)3 }, new { ColA = (ushort)5 }, new { ColA = (ushort)7 }, new { ColA = (ushort)9 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. 
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + var prev = output.Preview(); + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(float), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { -1f, -0.5f, 0f, .5f, 1f }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + float value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void TestUInt8() + { + MLContext mlContext = new MLContext(1); + var dataList = new [] { new { ColA = (byte)1 }, new { ColA = (byte)3 }, new { ColA = (byte)5 }, new { ColA = (byte)7 }, new { ColA = (byte)9 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. 
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + var prev = output.Preview(); + + Assert.Single(schema.Where(x => x.IsHidden == false)); + Assert.Equal(typeof(float), schema["ColA"].Type.RawType); + + var cursor = output.GetRowCursor(schema["ColA"]); + var expectedOutput = new[] { -1f, -0.5f, 0f, .5f, 1f }; + var index = 0; + var getter = cursor.GetGetter(schema["ColA"]); + float value = default; + + while (cursor.MoveNext()) + { + getter(ref value); + Assert.Equal(expectedOutput[index++], value); + } + + TestEstimatorCore(pipeline, data); + Done(); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs b/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs new file mode 100644 index 0000000000..a304832957 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs @@ -0,0 +1,402 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
using System;
using System.Drawing.Printing; // NOTE(review): nothing in this file appears to use this namespace; candidate for removal.
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Featurizers;
using Microsoft.ML.RunTests;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.Tests.Transformers
{
    /// <summary>
    /// Tests for the TimeSeriesImputer estimator. Each test feeds a small series with gaps in the
    /// "date" column, then checks the rows that were inserted, the values that were filled in and
    /// the trailing "IsRowImputed" column.
    /// </summary>
    public class TimeSeriesImputerTests : TestDataPipeBase
    {
        public TimeSeriesImputerTests(ITestOutputHelper output) : base(output)
        {
        }

        private class TimeSeriesTwoGrainInput
        {
            public long date;
            public string grainA;
            public string grainB;
            public float data;
        }

        private class TimeSeriesOneGrainInput
        {
            public long date;
            public string grainA;
            public int dataA;
            public float dataB;
            public uint dataC;
        }

        private class TimeSeriesOneGrainFloatInput
        {
            public long date;
            public string grainA;
            public float dataA;
        }

        /// <summary>
        /// Builds the three-row, single-grain input (dates 25, 26, 28) shared by the column-filter tests.
        /// The missing date 27 is the one row the imputer is expected to insert.
        /// </summary>
        private static TimeSeriesOneGrainInput[] CreateOneGrainGapData()
        {
            return new[] {
                new TimeSeriesOneGrainInput() { date = 25, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 },
                new TimeSeriesOneGrainInput() { date = 26, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 },
                new TimeSeriesOneGrainInput() { date = 28, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 }
            };
        }

        /// <summary>
        /// Builds a five-row series for grain "A" at dates 0, 1, 3, 5, 7 carrying the given values.
        /// </summary>
        /// <param name="values">Exactly five values for the dataA column, in date order.</param>
        private static TimeSeriesOneGrainFloatInput[] CreateSingleGrainSeries(params float[] values)
        {
            long[] dates = { 0, 1, 3, 5, 7 };
            Assert.Equal(dates.Length, values.Length);

            var rows = new TimeSeriesOneGrainFloatInput[dates.Length];
            for (int i = 0; i < rows.Length; i++)
            {
                rows[i] = new TimeSeriesOneGrainFloatInput() { date = dates[i], grainA = "A", dataA = values[i] };
            }

            return rows;
        }

        /// <summary>
        /// Asserts that the output keeps every <see cref="TimeSeriesOneGrainInput"/> column in order
        /// and appends the "IsRowImputed" column.
        /// </summary>
        private static void AssertOneGrainOutputSchema(DataViewSchema schema)
        {
            string[] expectedNames = { "date", "grainA", "dataA", "dataB", "dataC", "IsRowImputed" };

            Assert.Equal(expectedNames.Length, schema.Count);
            for (int i = 0; i < expectedNames.Length; i++)
            {
                Assert.Equal(expectedNames[i], schema[i].Name);
            }
        }

        /// <summary>
        /// Asserts that every listed row of one preview column holds <paramref name="expected"/>.
        /// The comparison is on boxed values, exactly as a direct Assert.Equal against the preview cell would be.
        /// </summary>
        private static void AssertCellsEqual(DataDebuggerPreview preview, int column, object expected, params int[] rows)
        {
            var values = preview.ColumnView[column].Values;
            foreach (var row in rows)
            {
                Assert.Equal<object>(expected, values[row]);
            }
        }

        /// <summary>
        /// Asserts the complete content of the IsRowImputed column, including its length, so a missing
        /// or extra row cannot go unnoticed.
        /// </summary>
        private static void AssertImputedFlags(DataDebuggerPreview preview, int column, params bool[] expected)
        {
            var actual = preview.ColumnView[column].Values;

            Assert.Equal(expected.Length, actual.Length);
            for (int i = 0; i < expected.Length; i++)
            {
                Assert.Equal<object>(expected[i], actual[i]);
            }
        }

        /// <summary>
        /// Checks everything that is independent of the imputation strategy for the series produced by
        /// <see cref="CreateSingleGrainSeries"/>: the row count, the inserted dates, the propagated
        /// grain and the IsRowImputed flags.
        /// </summary>
        private static void AssertSingleGrainLayout(DataDebuggerPreview preview)
        {
            // Dates 2, 4 and 6 are inserted, so the 5 input rows become 8.
            Assert.Equal(8, preview.RowView.Length);

            // Check that imputed rows have the correct dates.
            var dates = preview.ColumnView[0].Values;
            Assert.Equal(2L, dates[2]);
            Assert.Equal(4L, dates[4]);
            Assert.Equal(6L, dates[6]);

            // Make sure grain was propagated correctly.
            var grains = preview.ColumnView[1].Values;
            foreach (var row in new[] { 2, 4, 6 })
            {
                Assert.Equal("A", grains[row].ToString());
            }

            // IsRowImputed is true for rows 2, 4 and 6, false for the rest.
            AssertImputedFlags(preview, 3, false, false, true, false, true, false, true, false);
        }

        /// <summary>
        /// Passes "dataB" as the filter column without an explicit filter mode and verifies that the
        /// inserted row carries default(float) in that column, i.e. the column was not imputed.
        /// </summary>
        [Fact]
        public void NotImputeOneColumn()
        {
            MLContext mlContext = new MLContext(1);
            var data = mlContext.Data.LoadFromEnumerable(CreateOneGrainGapData());

            // Build the pipeline, fit, and transform it.
            var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" }, new string[] { "dataB"});
            var model = pipeline.Fit(data);
            var output = model.Transform(data);

            // We always output the same columns as the input, plus a column saying whether the row was imputed or not.
            AssertOneGrainOutputSchema(output.Schema);

            // We are imputing 1 row, so total rows should be 4.
            var preview = output.Preview();
            Assert.Equal(4, preview.RowView.Length);

            // Row that was imputed should have date of 27.
            Assert.Equal(27L, preview.ColumnView[0].Values[2]);

            // Since we are not imputing data on one column and a row is getting imputed, its value should be default(T).
            Assert.Equal(default(float), preview.ColumnView[3].Values[2]);

            TestEstimatorCore(pipeline, data);
            Done();
        }

        /// <summary>
        /// Uses FilterMode.Include on "dataB" and verifies that only that column is filled in on the
        /// inserted row while the other data columns hold default(T).
        /// </summary>
        [Fact]
        public void ImputeOnlyOneColumn()
        {
            MLContext mlContext = new MLContext(1);
            var data = mlContext.Data.LoadFromEnumerable(CreateOneGrainGapData());

            // Build the pipeline, fit, and transform it.
            var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" }, new string[] { "dataB"}, TimeSeriesImputerEstimator.FilterMode.Include);
            var model = pipeline.Fit(data);
            var output = model.Transform(data);

            // We always output the same columns as the input, plus a column saying whether the row was imputed or not.
            AssertOneGrainOutputSchema(output.Schema);

            // We are imputing 1 row, so total rows should be 4.
            var preview = output.Preview();
            Assert.Equal(4, preview.RowView.Length);

            // Row that was imputed should have date of 27.
            Assert.Equal(27L, preview.ColumnView[0].Values[2]);

            // Since we are not imputing data on two columns and a row is getting imputed, their values should be default(T).
            Assert.Equal(default(int), preview.ColumnView[2].Values[2]);
            Assert.Equal(default(uint), preview.ColumnView[4].Values[2]);

            // Column that was imputed should have value of 2.0f.
            Assert.Equal(2.0f, preview.ColumnView[3].Values[2]);

            TestEstimatorCore(pipeline, data);
            Done();
        }

        /// <summary>
        /// Runs the imputer without specifying a strategy and expects every NaN and every inserted row
        /// to take the most recent earlier value.
        /// </summary>
        [Fact]
        public void Forwardfill()
        {
            MLContext mlContext = new MLContext(1);
            var data = mlContext.Data.LoadFromEnumerable(CreateSingleGrainSeries(2.0f, float.NaN, 5.0f, float.NaN, float.NaN));

            // Build the pipeline, fit, and transform it.
            var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" });
            var model = pipeline.Fit(data);
            var output = model.Transform(data);
            var prev = output.Preview();

            // Should have 3 original columns + 1 more for IsRowImputed.
            Assert.Equal(4, output.Schema.Count);
            AssertSingleGrainLayout(prev);

            // Make sure forward fill is working as expected. All NA's should be replaced, and imputed rows should have correct values too.
            AssertCellsEqual(prev, 2, 2.0f, 1, 2);
            AssertCellsEqual(prev, 2, 5.0f, 4, 5, 6, 7);

            TestEstimatorCore(pipeline, data);
            Done();
        }

        /// <summary>
        /// Calls the entry point directly (no estimator pipeline) with forward fill on two included
        /// integer columns and a numeric grain.
        /// </summary>
        [Fact]
        public void EntryPoint()
        {
            MLContext mlContext = new MLContext(1);
            var dataList = new[] { new { ts = 1L, grain = 1970, c3 = 10, c4 = 19},
                new { ts = 2L, grain = 1970, c3 = 13, c4 = 12},
                new { ts = 3L, grain = 1970, c3 = 15, c4 = 16},
                new { ts = 5L, grain = 1970, c3 = 20, c4 = 19}
            };

            var data = mlContext.Data.LoadFromEnumerable(dataList);
            TimeSeriesImputerEstimator.Options options = new TimeSeriesImputerEstimator.Options() {
                TimeSeriesColumn = "ts",
                GrainColumns = new[] { "grain" },
                FilterColumns = new[] { "c3", "c4" },
                FilterMode = TimeSeriesImputerEstimator.FilterMode.Include,
                ImputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill,
                Data = data
            };

            // Invoke the entry point and take the transformed data it returns.
            var entryOutput = TimeSeriesTransformerEntrypoint.TimeSeriesImputer(mlContext.Transforms.GetEnvironment(), options);
            var output = entryOutput.OutputData;
            var prev = output.Preview();

            // Should have 4 original columns + 1 more for IsRowImputed.
            Assert.Equal(5, output.Schema.Count);

            // Imputing the row with date 4, so should have length of 5.
            Assert.Equal(5, prev.RowView.Length);

            // Check that the imputed row has the correct date.
            Assert.Equal(4L, prev.ColumnView[0].Values[3]);

            // Make sure grain was propagated correctly.
            Assert.Equal(1970, prev.ColumnView[1].Values[2]);

            // Make sure forward fill is working as expected: the imputed row repeats the values of date 3.
            Assert.Equal(15, prev.ColumnView[2].Values[3]);
            Assert.Equal(16, prev.ColumnView[3].Values[3]);

            // IsRowImputed is true only for the inserted row (index 3).
            AssertImputedFlags(prev, 4, false, false, false, true, false);

            Done();
        }

        /// <summary>
        /// Uses the Median strategy: the two known values are 2 and 5, so every NaN and every inserted
        /// row is expected to become 3.5.
        /// </summary>
        [Fact]
        public void Median()
        {
            MLContext mlContext = new MLContext(1);
            var data = mlContext.Data.LoadFromEnumerable(CreateSingleGrainSeries(2.0f, float.NaN, 5.0f, float.NaN, float.NaN));

            // Build the pipeline, fit, and transform it.
            var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" }, imputeMode: TimeSeriesImputerEstimator.ImputationStrategy.Median, filterColumns: null, suppressTypeErrors: true);
            var model = pipeline.Fit(data);
            var output = model.Transform(data);
            var prev = output.Preview();

            // Should have 3 original columns + 1 more for IsRowImputed.
            Assert.Equal(4, output.Schema.Count);
            AssertSingleGrainLayout(prev);

            // Make sure Median is working as expected. All NA's should be replaced, and imputed rows should have correct values too.
            AssertCellsEqual(prev, 2, 3.5f, 1, 2, 4, 5, 6, 7);

            TestEstimatorCore(pipeline, data);
            Done();
        }

        /// <summary>
        /// Uses the BackFill strategy: every NaN and every inserted row is expected to take the next
        /// later known value.
        /// </summary>
        [Fact]
        public void Backfill()
        {
            MLContext mlContext = new MLContext(1);
            var data = mlContext.Data.LoadFromEnumerable(CreateSingleGrainSeries(float.NaN, float.NaN, 5.0f, float.NaN, 2.0f));

            // Build the pipeline, fit, and transform it.
            var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" }, TimeSeriesImputerEstimator.ImputationStrategy.BackFill);
            var model = pipeline.Fit(data);
            var output = model.Transform(data);
            var prev = output.Preview();

            // Should have 3 original columns + 1 more for IsRowImputed.
            Assert.Equal(4, output.Schema.Count);
            AssertSingleGrainLayout(prev);

            // Make sure backfill is working as expected. All NA's should be replaced, and imputed rows should have correct values too.
            AssertCellsEqual(prev, 2, 5.0f, 0, 1, 2);
            AssertCellsEqual(prev, 2, 2.0f, 4, 5, 6);

            TestEstimatorCore(pipeline, data);
            Done();
        }

        /// <summary>
        /// BackFill with a two-column grain. Grain (A, A) covers dates 0-1 with no gap; grain (A, B)
        /// covers dates 3, 5, 7, so only dates 4 and 6 are inserted.
        /// </summary>
        [Fact]
        public void BackfillTwoGrain()
        {
            MLContext mlContext = new MLContext(1);
            var dataList = new[] { new TimeSeriesTwoGrainInput() { date = 0, grainA = "A", grainB = "A", data = float.NaN},
                new TimeSeriesTwoGrainInput() { date = 1, grainA = "A", grainB = "A", data = 0.0f},
                new TimeSeriesTwoGrainInput() { date = 3, grainA = "A", grainB = "B", data = 1.0f},
                new TimeSeriesTwoGrainInput() { date = 5, grainA = "A", grainB = "B", data = float.NaN},
                new TimeSeriesTwoGrainInput() { date = 7, grainA = "A", grainB = "B", data = 2.0f }};
            var data = mlContext.Data.LoadFromEnumerable(dataList);

            // Build the pipeline, fit, and transform it.
            var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA", "grainB" }, TimeSeriesImputerEstimator.ImputationStrategy.BackFill);
            var model = pipeline.Fit(data);
            var output = model.Transform(data);
            var prev = output.Preview();

            // Should have 4 original columns + 1 more for IsRowImputed.
            Assert.Equal(5, output.Schema.Count);

            // Imputing rows with dates 4 and 6, so the 5 input rows become 7.
            Assert.Equal(7, prev.RowView.Length);

            // Check that imputed rows have the correct dates.
            Assert.Equal(4L, prev.ColumnView[0].Values[3]);
            Assert.Equal(6L, prev.ColumnView[0].Values[5]);

            // Make sure both grain columns were propagated correctly.
            Assert.Equal("A", prev.ColumnView[1].Values[3].ToString());
            Assert.Equal("A", prev.ColumnView[1].Values[5].ToString());
            Assert.Equal("B", prev.ColumnView[2].Values[3].ToString());
            Assert.Equal("B", prev.ColumnView[2].Values[5].ToString());

            // Make sure backfill is working as expected. All NA's should be replaced, and imputed rows should have correct values too.
            AssertCellsEqual(prev, 3, 0.0f, 0);
            AssertCellsEqual(prev, 3, 2.0f, 3, 4, 5);

            // IsRowImputed is true for rows 3 and 5 (dates 4 and 6), false for the rest.
            AssertImputedFlags(prev, 4, false, false, false, true, false, true, false);

            TestEstimatorCore(pipeline, data);
            Done();
        }
    }
}
diff --git a/test/Microsoft.ML.Tests/Transformers/ToStringTransformerTests.cs b/test/Microsoft.ML.Tests/Transformers/ToStringTransformerTests.cs
new file mode 100644
index 0000000000..7f5c4e0f0f
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Transformers/ToStringTransformerTests.cs
@@ -0,0 +1,353 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
using System.Globalization;
using Microsoft.ML.Data;
using Microsoft.ML.Featurizers;
using Microsoft.ML.RunTests;
using Microsoft.ML.Transforms;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.Tests.Transformers
{
    /// <summary>
    /// Tests for the ToStringTransformer featurizer: one test per supported input type, each feeding
    /// boundary values through the transformer and comparing the produced text, plus one test that
    /// goes through the entry point instead of the estimator.
    /// </summary>
    public class ToStringTransformerTests : TestDataPipeBase
    {
        private class BoolInput
        {
            public bool data;
        }

        private class ByteInput
        {
            public byte data;
        }

        private class SByteInput
        {
            public sbyte data;
        }

        private class ShortInput
        {
            public short data;
        }

        private class UShortInput
        {
            public ushort data;
        }

        private class IntInput
        {
            public int data;
        }

        private class UIntInput
        {
            public uint data;
        }

        private class LongInput
        {
            public long data;
        }

        private class ULongInput
        {
            public ulong data;
        }

        private class FloatInput
        {
            public float data;
        }

        private class DoubleInput
        {
            public double data;
        }

        private class StringInput
        {
            public string data;
        }

        public ToStringTransformerTests(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>
        /// Loads <paramref name="rows"/>, builds a ToStringTransformer from the "data" column into
        /// <paramref name="outputColumnName"/>, runs the shared estimator checks, then fits and
        /// transforms the data.
        /// </summary>
        /// <typeparam name="TRow">Input row type; must expose a "data" member.</typeparam>
        /// <param name="rows">Input rows, in the order their results are returned.</param>
        /// <param name="outputColumnName">Name of the column the transformer writes to.</param>
        /// <returns>The text of preview column 1 (the second column in the preview) for every row.</returns>
        private string[] RunToStringPipeline<TRow>(TRow[] rows, string outputColumnName = "output")
            where TRow : class
        {
            MLContext mlContext = new MLContext(1);
            IDataView input = mlContext.Data.LoadFromEnumerable(rows);
            var pipeline = mlContext.Transforms.ToStringTransformer(outputColumnName, "data");

            TestEstimatorCore(pipeline, input);

            var model = pipeline.Fit(input);
            var output = model.Transform(input);
            var preview = output.Preview().RowView;

            var text = new string[preview.Length];
            for (int i = 0; i < text.Length; i++)
            {
                text[i] = preview[i].Values[1].Value.ToString();
            }

            return text;
        }

        [Fact]
        public void TestBool()
        {
            var text = RunToStringPipeline(new[] { new BoolInput() { data = true }, new BoolInput() { data = false } });

            Assert.Equal("True", text[0]);
            Assert.Equal("False", text[1]);
            Done();
        }

        [Fact]
        public void TestByte()
        {
            var text = RunToStringPipeline(new[] { new ByteInput() { data = byte.MinValue }, new ByteInput() { data = byte.MaxValue } });

            Assert.Equal("0", text[0]);
            Assert.Equal("255", text[1]);
            Done();
        }

        [Fact]
        public void TestSByte()
        {
            var text = RunToStringPipeline(new[] { new SByteInput() { data = sbyte.MinValue }, new SByteInput() { data = sbyte.MaxValue } });

            Assert.Equal("-128", text[0]);
            Assert.Equal("127", text[1]);
            Done();
        }

        [Fact]
        public void TestUShort()
        {
            var text = RunToStringPipeline(new[] { new UShortInput() { data = ushort.MinValue }, new UShortInput() { data = ushort.MaxValue } });

            Assert.Equal("0", text[0]);
            Assert.Equal("65535", text[1]);
            Done();
        }

        [Fact]
        public void TestInt()
        {
            var text = RunToStringPipeline(new[] { new IntInput() { data = int.MinValue }, new IntInput() { data = int.MaxValue } });

            Assert.Equal("-2147483648", text[0]);
            Assert.Equal("2147483647", text[1]);
            Done();
        }

        [Fact]
        public void TestUInt()
        {
            var text = RunToStringPipeline(new[] { new UIntInput() { data = uint.MinValue }, new UIntInput() { data = uint.MaxValue } });

            Assert.Equal("0", text[0]);
            Assert.Equal("4294967295", text[1]);
            Done();
        }

        [Fact]
        public void TestLong()
        {
            var text = RunToStringPipeline(new[] { new LongInput() { data = long.MinValue }, new LongInput() { data = long.MaxValue } });

            Assert.Equal("-9223372036854775808", text[0]);
            Assert.Equal("9223372036854775807", text[1]);
            Done();
        }

        [Fact]
        public void TestULong()
        {
            var text = RunToStringPipeline(new[] { new ULongInput() { data = ulong.MinValue }, new ULongInput() { data = ulong.MaxValue } });

            Assert.Equal("0", text[0]);
            Assert.Equal("18446744073709551615", text[1]);
            Done();
        }

        [Fact]
        public void TestFloat()
        {
            // NOTE(review): the NaN row is run through the transformer but its text is not asserted;
            // the expected native representation is not visible from this file - confirm and add a check.
            var text = RunToStringPipeline(new[] { new FloatInput() { data = float.MinValue}, new FloatInput() { data = float.MaxValue }, new FloatInput() { data = float.NaN } });

            Assert.Equal("-340282346638528859811704183484516925440.000000", text[0]);
            Assert.Equal("340282346638528859811704183484516925440.000000", text[1]);
            Done();
        }

        [Fact]
        public void TestShort()
        {
            var text = RunToStringPipeline(new[] { new ShortInput() { data = short.MinValue }, new ShortInput() { data = short.MaxValue } });

            Assert.Equal("-32768", text[0]);
            Assert.Equal("32767", text[1]);
            Done();
        }

        [Fact]
        public void TestDouble()
        {
            // NOTE(review): as in TestFloat, the NaN row is exercised but not asserted.
            var text = RunToStringPipeline(
                new[] { new DoubleInput() { data = double.MinValue}, new DoubleInput() { data = double.MaxValue }, new DoubleInput() { data = double.NaN } },
                outputColumnName: "data.out");

            // Since we can't set the precision yet on the Native side and it returns the whole string value, only checking the first 10 places.
            // The expected text is formatted with the invariant culture so the comparison does not depend on the machine's locale.
            Assert.Equal(double.MinValue.ToString("F10", CultureInfo.InvariantCulture).Substring(0, 10), text[0].Substring(0, 10));
            Assert.Equal(double.MaxValue.ToString("F10", CultureInfo.InvariantCulture).Substring(0, 10), text[1].Substring(0, 10));

            Done();
        }

        [Fact]
        public void TestString()
        {
            var text = RunToStringPipeline(new[] { new StringInput() { data = ""}, new StringInput() { data = "Long Dummy String Value" } });

            Assert.Equal("", text[0]);
            Assert.Equal("Long Dummy String Value", text[1]);
            Done();
        }

        /// <summary>
        /// Calls the entry point directly with a single column named "data" and no separate source,
        /// and checks that string input comes back unchanged.
        /// </summary>
        [Fact]
        public void TestEntryPoint()
        {
            MLContext mlContext = new MLContext(1);

            var data = new[] { new StringInput() { data = ""}, new StringInput() { data = "Long Dummy String Value" } };
            IDataView input = mlContext.Data.LoadFromEnumerable(data);

            var options = new ToStringTransformerEstimator.Options()
            {
                Columns = new ToStringTransformerEstimator.Column[1]
                {
                    new ToStringTransformerEstimator.Column()
                    {
                        Name = "data"
                    }
                },
                Data = input
            };

            var output = ToStringTransformerEntrypoint.ToString(mlContext.Transforms.GetEnvironment(), options);

            var rows = output.OutputData.Preview().RowView;

            Assert.Equal("", rows[0].Values[1].Value.ToString());
            Assert.Equal("Long Dummy String Value", rows[1].Values[1].Value.ToString());
            Done();
        }
    }
}
diff --git a/test/data/dates.txt b/test/data/dates.txt
new file mode 100644
index 0000000000..510371af4e
Binary files /dev/null and b/test/data/dates.txt differ