diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index 9923cb7635..5f16d116a0 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -273,6 +273,14 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.AutoML.Samples
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Samples.GPU", "docs\samples\Microsoft.ML.Samples.GPU\Microsoft.ML.Samples.GPU.csproj", "{3C8F910B-7F23-4D25-B521-6D5AC9570ADD}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Featurizers", "src\Microsoft.ML.Featurizers\Microsoft.ML.Featurizers.csproj", "{E2DD0721-5B0F-4606-8182-4C7EFB834518}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Featurizers", "Microsoft.ML.Featurizers", "{1BA5C784-52E8-4A87-8525-26B2452F2882}"
+ ProjectSection(SolutionItems) = preProject
+ pkg\Microsoft.ML.Featurizers\Microsoft.ML.Featurizers.nupkgproj = pkg\Microsoft.ML.Featurizers\Microsoft.ML.Featurizers.nupkgproj
+ pkg\Microsoft.ML.Featurizers\Microsoft.ML.Featurizers.symbols.nupkgproj = pkg\Microsoft.ML.Featurizers\Microsoft.ML.Featurizers.symbols.nupkgproj
+ EndProjectSection
+EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeGenerator", "src\Microsoft.ML.CodeGenerator\Microsoft.ML.CodeGenerator.csproj", "{56CB0850-7341-4D71-9AE4-9EFC472D93DD}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeGenerator.Tests", "test\Microsoft.ML.CodeGenerator.Tests\Microsoft.ML.CodeGenerator.Tests.csproj", "{46CC5637-3DDF-4100-93FC-44BB87B2DB81}"
@@ -1690,6 +1698,30 @@ Global
{46CC5637-3DDF-4100-93FC-44BB87B2DB81}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
{46CC5637-3DDF-4100-93FC-44BB87B2DB81}.Release-netfx|x64.ActiveCfg = Release-netfx|Any CPU
{46CC5637-3DDF-4100-93FC-44BB87B2DB81}.Release-netfx|x64.Build.0 = Release-netfx|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug|x64.Build.0 = Debug|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netcoreapp3_0|Any CPU.ActiveCfg = Debug-netcoreapp3_0|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netcoreapp3_0|Any CPU.Build.0 = Debug-netcoreapp3_0|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netcoreapp3_0|x64.ActiveCfg = Debug-netcoreapp3_0|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netcoreapp3_0|x64.Build.0 = Debug-netcoreapp3_0|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netfx|x64.ActiveCfg = Debug-netfx|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Debug-netfx|x64.Build.0 = Debug-netfx|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release|Any CPU.Build.0 = Release|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release|x64.ActiveCfg = Release|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release|x64.Build.0 = Release|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netcoreapp3_0|Any CPU.ActiveCfg = Release-netcoreapp3_0|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netcoreapp3_0|Any CPU.Build.0 = Release-netcoreapp3_0|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netcoreapp3_0|x64.ActiveCfg = Release-netcoreapp3_0|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netcoreapp3_0|x64.Build.0 = Release-netcoreapp3_0|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netfx|x64.ActiveCfg = Release-netfx|Any CPU
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518}.Release-netfx|x64.Build.0 = Release-netfx|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -1779,6 +1811,8 @@ Global
{56CB0850-7341-4D71-9AE4-9EFC472D93DD} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{46CC5637-3DDF-4100-93FC-44BB87B2DB81} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{3817A875-278C-4140-BF66-3C4A8CA55F0D} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
+ {E2DD0721-5B0F-4606-8182-4C7EFB834518} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {1BA5C784-52E8-4A87-8525-26B2452F2882} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
diff --git a/build/BranchInfo.props b/build/BranchInfo.props
index e87b2897a4..ce3d93c486 100644
--- a/build/BranchInfo.props
+++ b/build/BranchInfo.props
@@ -30,12 +30,12 @@
1
4
0
- preview3
+ preview2
0
16
0
- preview3
+ preview2
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/CategoryImputer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/CategoryImputer.cs
new file mode 100644
index 0000000000..f89ddc6cfb
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/CategoryImputer.cs
@@ -0,0 +1,69 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+ public static class CategoryImputer
+ {
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ var samples = new List()
+ {
+ new InputData(){ Feature1 = 1f },
+
+ new InputData(){ Feature1 = float.NaN },
+
+ new InputData(){ Feature1 = 1f },
+
+ new InputData(){ Feature1 = float.NaN },
+
+ new InputData(){ Feature1 = 9f },
+ };
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline for filling in the missing values in the feature1 column
+ var pipeline = mlContext.Transforms.CatagoryImputerTransformer("Feature1");
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this did. The NaN values should be filled in with the most frequent value, 1.
+ // We can extract the newly created columns as an IEnumerable of TransformedData.
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ Console.WriteLine(featureRow.Feature1);
+
+ // Expected output:
+ // Features column obtained post-transformation.
+ // 1
+ // 1
+ // 1
+ // 1
+ // 9
+ }
+
+ private class InputData
+ {
+ public float Feature1;
+ }
+
+ private sealed class TransformedData
+ {
+ public float Feature1 { get; set; }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs
new file mode 100644
index 0000000000..9bf851b0ba
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs
@@ -0,0 +1,84 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+ public static class DateTimeTransformer
+ {
+ private class DateTimeInput
+ {
+ public long Date;
+ }
+
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ // Future Date - 2025 June 30
+ var samples = new[] { new DateTimeInput() { Date = 1751241600 } };
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline for splitting the time features into individual columns
+ var pipeline = mlContext.Transforms.DateTimeTransformer("Date", "DTC");
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this did. We should have created 21 more columns with all the
+ // DateTime information split into its own columns
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ Console.WriteLine(featureRow.Date + ", " + featureRow.DTCYear + ", " + featureRow.DTCMonth + ", " +
+ featureRow.DTCDay + ", " + featureRow.DTCHour + ", " + featureRow.DTCMinute + ", " +
+ featureRow.DTCSecond + ", " + featureRow.DTCAmPm + ", " + featureRow.DTCHour12 + ", " +
+ featureRow.DTCDayOfWeek + ", " + featureRow.DTCDayOfQuarter + ", " + featureRow.DTCDayOfYear +
+ ", " + featureRow.DTCWeekOfMonth + ", " + featureRow.DTCQuarterOfYear + ", " + featureRow.DTCHalfOfYear +
+ ", " + featureRow.DTCWeekIso + ", " + featureRow.DTCYearIso + ", " + featureRow.DTCMonthLabel + ", " +
+ featureRow.DTCAmPmLabel + ", " + featureRow.DTCDayOfWeekLabel + ", " + featureRow.DTCHolidayName + ", " +
+ featureRow.DTCIsPaidTimeOff);
+
+ // Expected output:
+ // Features columns obtained post-transformation.
+ // 1751241600, 2025, 6, 30, 0, 0, 0, 0, 0, 1, 91, 180, 4, 2, 1, 27, 2025, June, am, Monday, , 0
+ }
+
+ // These columns start with DTC because that is the prefix we picked
+ private sealed class TransformedData
+ {
+ public long Date { get; set; }
+ public int DTCYear { get; set; }
+ public byte DTCMonth { get; set; }
+ public byte DTCDay { get; set; }
+ public byte DTCHour { get; set; }
+ public byte DTCMinute { get; set; }
+ public byte DTCSecond { get; set; }
+ public byte DTCAmPm { get; set; }
+ public byte DTCHour12 { get; set; }
+ public byte DTCDayOfWeek { get; set; }
+ public byte DTCDayOfQuarter { get; set; }
+ public ushort DTCDayOfYear { get; set; }
+ public ushort DTCWeekOfMonth { get; set; }
+ public byte DTCQuarterOfYear { get; set; }
+ public byte DTCHalfOfYear { get; set; }
+ public byte DTCWeekIso { get; set; }
+ public int DTCYearIso { get; set; }
+ public string DTCMonthLabel { get; set; }
+ public string DTCAmPmLabel { get; set; }
+ public string DTCDayOfWeekLabel { get; set; }
+ public string DTCHolidayName { get; set; }
+ public byte DTCIsPaidTimeOff { get; set; }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs
new file mode 100644
index 0000000000..680b338056
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs
@@ -0,0 +1,80 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+ public static class DateTimeTransformerDropColumns
+ {
+ private class DateTimeInput
+ {
+ public long Date;
+ }
+
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ // Future Date - 2025 June 30
+ var samples = new[] { new DateTimeInput() { Date = 1751241600 } };
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline for splitting the time features into individual columns
+ // All the columns listed here will be dropped.
+ var pipeline = mlContext.Transforms.DateTimeTransformer("Date", "DTC", DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff,
+ DateTimeTransformerEstimator.ColumnsProduced.Day, DateTimeTransformerEstimator.ColumnsProduced.QuarterOfYear,
+ DateTimeTransformerEstimator.ColumnsProduced.AmPm, DateTimeTransformerEstimator.ColumnsProduced.HolidayName);
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this did. We should have created 16 more columns with all the
+ // DateTime information split into its own columns
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ Console.WriteLine(featureRow.Date + ", " + featureRow.DTCYear + ", " + featureRow.DTCMonth + ", " +
+ featureRow.DTCHour + ", " + featureRow.DTCMinute + ", " + featureRow.DTCSecond + ", " +
+ featureRow.DTCHour12 + ", " + featureRow.DTCDayOfWeek + ", " + featureRow.DTCDayOfQuarter + ", " +
+ featureRow.DTCDayOfYear + ", " + featureRow.DTCWeekOfMonth + ", " + featureRow.DTCHalfOfYear +
+ ", " + featureRow.DTCWeekIso + ", " + featureRow.DTCYearIso + ", " + featureRow.DTCMonthLabel + ", " +
+ featureRow.DTCAmPmLabel + ", " + featureRow.DTCDayOfWeekLabel);
+
+ // Expected output:
+ // Features columns obtained post-transformation.
+ // 1751241600, 2025, 6, 30, 0, 0, 0, 0, 0, 1, 91, 180, 4, 2, 1, 27, 2025, June, am, Monday
+ }
+
+ // These columns start with DTC because that is the prefix we picked
+ private sealed class TransformedData
+ {
+ public long Date { get; set; }
+ public int DTCYear { get; set; }
+ public byte DTCMonth { get; set; }
+ public byte DTCHour { get; set; }
+ public byte DTCMinute { get; set; }
+ public byte DTCSecond { get; set; }
+ public byte DTCHour12 { get; set; }
+ public byte DTCDayOfWeek { get; set; }
+ public byte DTCDayOfQuarter { get; set; }
+ public ushort DTCDayOfYear { get; set; }
+ public ushort DTCWeekOfMonth { get; set; }
+ public byte DTCHalfOfYear { get; set; }
+ public byte DTCWeekIso { get; set; }
+ public int DTCYearIso { get; set; }
+ public string DTCMonthLabel { get; set; }
+ public string DTCAmPmLabel { get; set; }
+ public string DTCDayOfWeekLabel { get; set; }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScaler.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScaler.cs
new file mode 100644
index 0000000000..d46449447f
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScaler.cs
@@ -0,0 +1,69 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+ public static class RobustScaler
+ {
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ var samples = new List()
+ {
+ new InputData(){ Feature1 = 1f },
+
+ new InputData(){ Feature1 = 3f },
+
+ new InputData(){ Feature1 = 5f },
+
+ new InputData(){ Feature1 = 7f },
+
+ new InputData(){ Feature1 = 9f },
+ };
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline for centering and scaling the feature1 column
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("Feature1");
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this did. The values should be centered around 0 and scaled.
+ // We can extract the newly created columns as an IEnumerable of TransformedData.
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ Console.WriteLine(featureRow.Feature1);
+
+ // Expected output:
+ // Features column obtained post-transformation.
+ // -1
+ // -.5
+ // 0
+ // .5
+ // 1
+ }
+
+ private class InputData
+ {
+ public float Feature1;
+ }
+
+ private sealed class TransformedData
+ {
+ public float Feature1 { get; set; }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithCenter.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithCenter.cs
new file mode 100644
index 0000000000..7a868ba7ee
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithCenter.cs
@@ -0,0 +1,69 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+ public static class RobustScalerWithCenter
+ {
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ var samples = new List()
+ {
+ new InputData(){ Feature1 = 1f },
+
+ new InputData(){ Feature1 = 3f },
+
+ new InputData(){ Feature1 = 5f },
+
+ new InputData(){ Feature1 = 7f },
+
+ new InputData(){ Feature1 = 9f },
+ };
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline for Centering the feature1 column
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("Feature1", scale: false);
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this did. The values should be centered around 0.
+ // We can extract the newly created columns as an IEnumerable of TransformedData.
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ Console.WriteLine(featureRow.Feature1);
+
+ // Expected output:
+ // Features column obtained post-transformation.
+ // -4
+ // -2
+ // 0
+ // 2
+ // 4
+ }
+
+ private class InputData
+ {
+ public float Feature1;
+ }
+
+ private sealed class TransformedData
+ {
+ public float Feature1 { get; set; }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithScale.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithScale.cs
new file mode 100644
index 0000000000..4318a67e9f
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/RobustScalerWithScale.cs
@@ -0,0 +1,69 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+ public static class RobustScalerWithScale
+ {
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ var samples = new List()
+ {
+ new InputData(){ Feature1 = 1f },
+
+ new InputData(){ Feature1 = 3f },
+
+ new InputData(){ Feature1 = 5f },
+
+ new InputData(){ Feature1 = 7f },
+
+ new InputData(){ Feature1 = 9f },
+ };
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline for scaling the feature1 column
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("Feature1", center: false);
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this did. The values should be scaled by the range * ratio.
+ // We can extract the newly created columns as an IEnumerable of TransformedData.
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ Console.WriteLine(featureRow.Feature1);
+
+ // Expected output:
+ // Features column obtained post-transformation.
+ // 0.25
+ // .75
+ // 1.25
+ // 1.75
+ // 2.25
+ }
+
+ private class InputData
+ {
+ public float Feature1;
+ }
+
+ private sealed class TransformedData
+ {
+ public float Feature1 { get; set; }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerBackFill.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerBackFill.cs
new file mode 100644
index 0000000000..50408cc4e4
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerBackFill.cs
@@ -0,0 +1,71 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+ public static class TimeSeriesImputerBackFill
+ {
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ var samples = new[] { new InputData() { Date = 0, GrainA = "A", DataA = float.NaN },
+ new InputData() { Date = 1, GrainA = "A", DataA = float.NaN },
+ new InputData() { Date = 3, GrainA = "A", DataA = 5.0f },
+ new InputData() { Date = 5, GrainA = "A", DataA = float.NaN },
+ new InputData() { Date = 7, GrainA = "A", DataA = 2.0f }};
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline for imputing the missing rows and values in the columns using the "BackFill" strategy.
+ var pipeline = mlContext.Transforms.TimeSeriesImputer("Date", new string[] { "GrainA" }, TimeSeriesImputerEstimator.ImputationStrategy.BackFill);
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this did. The NaN values should be filled in with the next value that was not NaN,
+ // and rows should be created to fill in the missing gaps in the time column.
+ // We can extract the newly created columns as an IEnumerable of TransformedData.
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ Console.WriteLine(featureRow.Date + ", " + featureRow.GrainA + ", " + featureRow.DataA + ", " + featureRow.IsRowImputed);
+
+ // Expected output:
+ // Features column obtained post-transformation.
+ // 0, A, 5.0, false
+ // 1, A, 5.0, false
+ // 2, A, 5.0, true
+ // 3, A, 5.0, false
+ // 4, A, 2.0, true
+ // 5, A, 2.0, false
+ // 6, A, 2.0, true
+ // 7, A, 2.0, false
+ }
+
+ private class InputData
+ {
+ public long Date;
+ public string GrainA;
+ public float DataA;
+ }
+
+ private sealed class TransformedData
+ {
+ public long Date { get; set; }
+ public string GrainA { get; set; }
+ public float DataA { get; set; }
+ public bool IsRowImputed { get; set; }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerForwardFill.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerForwardFill.cs
new file mode 100644
index 0000000000..6a22b98b4d
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerForwardFill.cs
@@ -0,0 +1,71 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+ public static class TimeSeriesImputerForwardFill
+ {
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ var samples = new[] { new InputData() { Date = 0, GrainA = "A", DataA = 2.0f },
+ new InputData() { Date = 1, GrainA = "A", DataA = float.NaN },
+ new InputData() { Date = 3, GrainA = "A", DataA = 5.0f },
+ new InputData() { Date = 5, GrainA = "A", DataA = float.NaN },
+ new InputData() { Date = 7, GrainA = "A", DataA = float.NaN }};
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline for imputing the missing rows and values in the columns using the default "ForwardFill" strategy.
+ var pipeline = mlContext.Transforms.TimeSeriesImputer("Date", new string[] { "GrainA" });
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this did. The NaN values should be filled in with last value that was not NaN,
+ // and rows should be created to fill in the missing gaps in the time column.
+ // We can extract the newly created columns as an IEnumerable of TransformedData.
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ Console.WriteLine(featureRow.Date + ", " + featureRow.GrainA + ", " + featureRow.DataA + ", " + featureRow.IsRowImputed);
+
+ // Expected output:
+ // Features column obtained post-transformation.
+ // 0, A, 2.0, false
+ // 1, A, 2.0, false
+ // 2, A, 2.0, true
+ // 3, A, 5.0, false
+ // 4, A, 5.0, true
+ // 5, A, 5.0, false
+ // 6, A, 5.0, true
+ // 7, A, 5.0, false
+ }
+
+ private class InputData
+ {
+ public long Date;
+ public string GrainA;
+ public float DataA;
+ }
+
+ private sealed class TransformedData
+ {
+ public long Date { get; set; }
+ public string GrainA { get; set; }
+ public float DataA { get; set; }
+ public bool IsRowImputed { get; set; }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerMedian.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerMedian.cs
new file mode 100644
index 0000000000..dc310ba285
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/TimeSeriesImputerMedian.cs
@@ -0,0 +1,71 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+ public static class TimeSeriesImputerMedian
+ {
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ var samples = new[] { new InputData() { Date = 0, GrainA = "A", DataA = 2.0f },
+ new InputData() { Date = 1, GrainA = "A", DataA = float.NaN },
+ new InputData() { Date = 3, GrainA = "A", DataA = 5.0f },
+ new InputData() { Date = 5, GrainA = "A", DataA = float.NaN },
+ new InputData() { Date = 7, GrainA = "A", DataA = float.NaN }};
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline for imputing the missing rows and values in the columns using the "Median" strategy.
+ var pipeline = mlContext.Transforms.TimeSeriesImputer("Date", new string[] { "GrainA" }, TimeSeriesImputerEstimator.ImputationStrategy.Median);
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this did. The NaN values should be filled in with the column median,
+ // and rows should be created to fill in the missing gaps in the time column.
+ // We can extract the newly created columns as an IEnumerable of TransformedData.
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ Console.WriteLine(featureRow.Date + ", " + featureRow.GrainA + ", " + featureRow.DataA + ", " + featureRow.IsRowImputed);
+
+ // Expected output:
+ // Features column obtained post-transformation.
+ // 0, A, 2, false
+ // 1, A, 3.5, false
+ // 2, A, 3.5, true
+ // 3, A, 5, false
+ // 4, A, 3.5, true
+ // 5, A, 3.5, false
+ // 6, A, 3.5, true
+ // 7, A, 3.5, false
+ }
+
+ private class InputData
+ {
+ public long Date;
+ public string GrainA;
+ public float DataA;
+ }
+
+ private sealed class TransformedData
+ {
+ public long Date { get; set; }
+ public string GrainA { get; set; }
+ public float DataA { get; set; }
+ public bool IsRowImputed { get; set; }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerMultipleColumns.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerMultipleColumns.cs
new file mode 100644
index 0000000000..d2b8016647
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerMultipleColumns.cs
@@ -0,0 +1,80 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+ public static class ToStringMultipleColumns
+ {
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ var samples = new List()
+ {
+ new InputData(){ Feature1 = 0.1f, Feature2 = 1.1, Feature3 = 1 },
+
+ new InputData(){ Feature1 = 0.2f, Feature2 =1.2, Feature3 = 2 },
+
+ new InputData(){ Feature1 = 0.3f, Feature2 = 1.3, Feature3 = 3 },
+
+ new InputData(){ Feature1 = 0.4f, Feature2 = 1.4, Feature3 = 4 },
+
+ new InputData(){ Feature1 = 0.5f, Feature2 = 1.5, Feature3 = 5 },
+
+ new InputData(){ Feature1 = 0.6f, Feature2 = 1.6, Feature3 = 6 },
+ };
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline for converting the "Feature1", "Feature2" and
+ // "Feature3" columns into their string representations
+ //
+ var pipeline = mlContext.Transforms.ToStringTransformer(new InputOutputColumnPair("Feature1"),
+ new InputOutputColumnPair("Feature2"), new InputOutputColumnPair("Feature3"));
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this did.
+ // We can extract the newly created columns as an IEnumerable of
+ // TransformedData.
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ Console.WriteLine(featureRow.Feature1 + " " + featureRow.Feature2 + " " + featureRow.Feature3);
+
+ // Expected output:
+ // Features column obtained post-transformation.
+ // 0.100000 1.100000 1
+ // 0.200000 1.200000 2
+ // 0.300000 1.300000 3
+ // 0.400000 1.400000 4
+ // 0.500000 1.500000 5
+ // 0.600000 1.600000 6
+ }
+
+ private class InputData
+ {
+ public float Feature1;
+ public double Feature2;
+ public int Feature3;
+ }
+
+ private sealed class TransformedData
+ {
+ public string Feature1 { get; set; }
+ public string Feature2 { get; set; }
+ public string Feature3 { get; set; }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerSingleColumn.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerSingleColumn.cs
new file mode 100644
index 0000000000..1c11931d90
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/ToStringTransformerSingleColumn.cs
@@ -0,0 +1,79 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+
+namespace Samples.Dynamic
+{
+ public static class ToStringSingleColumn
+ {
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ var samples = new List()
+ {
+ new InputData(){ Feature1 = 0.1f, Feature2 = 1.1, Feature3 = 1 },
+
+ new InputData(){ Feature1 = 0.2f, Feature2 =1.2, Feature3 = 2 },
+
+ new InputData(){ Feature1 = 0.3f, Feature2 = 1.3, Feature3 = 3 },
+
+ new InputData(){ Feature1 = 0.4f, Feature2 = 1.4, Feature3 = 4 },
+
+ new InputData(){ Feature1 = 0.5f, Feature2 = 1.5, Feature3 = 5 },
+
+ new InputData(){ Feature1 = 0.6f, Feature2 = 1.6, Feature3 = 6 },
+ };
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline for converting the "Feature1" column into its string representations
+ //
+ var pipeline = mlContext.Transforms.ToStringTransformer("Feature1Output", "Feature1");
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this did.
+ // We can extract the newly created columns as an IEnumerable of
+ // TransformedData.
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ Console.WriteLine(featureRow.Feature1Output);
+
+ // Expected output:
+ // Features column obtained post-transformation.
+ // 0.100000
+ // 0.200000
+ // 0.300000
+ // 0.400000
+ // 0.500000
+ // 0.600000
+ }
+
+ private class InputData
+ {
+ public float Feature1;
+ public double Feature2;
+ public int Feature3;
+ }
+
+ private sealed class TransformedData
+ {
+ public float Feature1 { get; set; }
+ public double Feature2 { get; set; }
+ public int Feature3 { get; set; }
+ public string Feature1Output { get; set; }
+ }
+ }
+}
diff --git a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj
index 71b9436701..8aec0bb5a0 100644
--- a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj
+++ b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj
@@ -11,6 +11,7 @@
+
diff --git a/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.nupkgproj b/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.nupkgproj
new file mode 100644
index 0000000000..44f4babb81
--- /dev/null
+++ b/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.nupkgproj
@@ -0,0 +1,14 @@
+
+
+
+ netstandard2.0
+ ML.NET additional featurizers for AutoML
+
+
+
+
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.symbols.nupkgproj b/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.symbols.nupkgproj
new file mode 100644
index 0000000000..483e51c61a
--- /dev/null
+++ b/pkg/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.symbols.nupkgproj
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
index 4bf0dfe445..2d7ae0c6b6 100644
--- a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
@@ -40,9 +40,8 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Dnn" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)]
-
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.AutoML" + PublicKey.Value)]
-
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Featurizers" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.MetaLinearLearner" + InternalPublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "TreeVisualizer" + InternalPublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "TMSNlearnPrediction" + InternalPublicKey.Value)]
diff --git a/src/Microsoft.ML.Core/Utilities/Utils.cs b/src/Microsoft.ML.Core/Utilities/Utils.cs
index fbdb3a791d..4d17be39f7 100644
--- a/src/Microsoft.ML.Core/Utilities/Utils.cs
+++ b/src/Microsoft.ML.Core/Utilities/Utils.cs
@@ -966,6 +966,14 @@ private static MethodInfo MarshalInvokeCheckAndCreate(Type genArg, Delegat
return meth;
}
+ private static MethodInfo MarshalInvokeCheckAndCreate(Delegate func, Type[] genArgs)
+ {
+ var meth = MarshalActionInvokeCheckAndCreate(func, genArgs);
+ if (meth.ReturnType != typeof(TRet))
+ throw Contracts.ExceptParam(nameof(func), "Cannot be generic on return type");
+ return meth;
+ }
+
// REVIEW: n-argument versions? The multi-column re-application problem?
// Think about how to address these.
@@ -1092,6 +1100,28 @@ public static TRet MarshalInvoke
+ /// A 1 argument and n type version of .
+ ///
+ public static TRet MarshalInvoke(
+ Func func,
+ Type[] genArgs, TArg1 arg1)
+ {
+ var meth = MarshalInvokeCheckAndCreate(func, genArgs);
+ return (TRet)meth.Invoke(func.Target, new object[] { arg1});
+ }
+
+ ///
+ /// A 2 argument and n type version of .
+ ///
+ public static TRet MarshalInvoke(
+ Func func,
+ Type[] genArgs, TArg1 arg1, TArg2 arg2)
+ {
+ var meth = MarshalInvokeCheckAndCreate(func, genArgs);
+ return (TRet)meth.Invoke(func.Target, new object[] { arg1, arg2});
+ }
+
private static MethodInfo MarshalActionInvokeCheckAndCreate(Type genArg, Delegate func)
{
Contracts.CheckValue(genArg, nameof(genArg));
@@ -1104,6 +1134,18 @@ private static MethodInfo MarshalActionInvokeCheckAndCreate(Type genArg, Delegat
return meth;
}
+ private static MethodInfo MarshalActionInvokeCheckAndCreate(Delegate func, params Type[] typeArguments)
+ {
+ Contracts.CheckValue(typeArguments, nameof(typeArguments));
+ Contracts.CheckValue(func, nameof(func));
+ var meth = func.GetMethodInfo();
+ Contracts.CheckParam(meth.IsGenericMethod, nameof(func), "Should be generic but is not");
+ Contracts.CheckParam(meth.GetGenericArguments().Length == typeArguments.Length, nameof(func),
+ "Method should have exactly the same number of generic type parameters as list passed in but it does not.");
+ meth = meth.GetGenericMethodDefinition().MakeGenericMethod(typeArguments);
+ return meth;
+ }
+
///
/// This is akin to , except applied to
/// instead of .
diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
index efb86af7ca..b106683166 100644
--- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
+++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
@@ -367,7 +367,7 @@ public Cursor(IChannelProvider provider, DataViewRowCursor input, RowToRowMapper
: base(provider, input)
{
var pred = parent.GetActiveOutputColumns(active);
- _getters = parent._mapper.CreateGetters(input, pred, out _disposer);
+ _getters = parent._mapperFactory == null ? parent._mapper.CreateGetters(input, pred, out _disposer) : parent._mapperFactory.Invoke(input.Schema).CreateGetters(input, pred, out _disposer);
_active = active;
_bindings = parent._bindings;
}
diff --git a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
index db1be394a8..8cff3c9239 100644
--- a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
@@ -44,9 +44,8 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet101" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet18" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet50" + PublicKey.Value)]
-
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Experimental" + PublicKey.Value)]
-
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Featurizers" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.MetaLinearLearner" + InternalPublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "TMSNlearnPrediction" + InternalPublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.CntkWrapper" + InternalPublicKey.Value)]
diff --git a/src/Microsoft.ML.Featurizers/CategoryImputer.cs b/src/Microsoft.ML.Featurizers/CategoryImputer.cs
new file mode 100644
index 0000000000..bb63c56164
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/CategoryImputer.cs
@@ -0,0 +1,1599 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Security;
+using System.Text;
+using Microsoft.ML;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.Data;
+using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Transforms;
+using static Microsoft.ML.Featurizers.CommonExtensions;
+
+[assembly: LoadableClass(typeof(CategoryImputerTransformer), null, typeof(SignatureLoadModel),
+ CategoryImputerTransformer.UserName, CategoryImputerTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(CategoryImputerTransformer), null, typeof(SignatureLoadRowMapper),
+ CategoryImputerTransformer.UserName, CategoryImputerTransformer.LoaderSignature)]
+
+[assembly: EntryPointModule(typeof(CategoryImputerEntrypoint))]
+
+namespace Microsoft.ML.Featurizers
+{
+ public static class CategoryImputerExtensionClass
+ {
+ ///
+ /// Create a , which fills in the missing values in a column with the most frequent value.
+ ///
+ /// Transform Catalog
+ /// Output column name
+ /// Input column name, if null defaults to
+ ///
+ public static CategoryImputerEstimator CatagoryImputerTransformer(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null)
+ => CategoryImputerEstimator.Create(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName);
+
+ ///
+ /// Create a , which fills in the missing values in a column with the most frequent value.
+ ///
+ /// Transform Catalog
+ /// List of to fill in missing values
+ ///
+ public static CategoryImputerEstimator CatagoryImputerTransformer(this TransformsCatalog catalog, params InputOutputColumnPair[] columns)
+ => CategoryImputerEstimator.Create(CatalogUtils.GetEnvironment(catalog), columns);
+ }
+
+ ///
+ /// The CategoryImputer replaces missing values with the most common value in that column.
+ ///
+ ///
+ /// is not a trivial estimator and needs training.
+ ///
+ ///
+ /// ]]>
+ ///
+ ///
+ ///
+ ///
+ public sealed class CategoryImputerEstimator : IEstimator
+ {
+ private readonly Options _options;
+
+ private readonly IHost _host;
+
+ #region Options
+ internal sealed class Column : OneToOneColumn
+ {
+ internal static Column Parse(string str)
+ {
+ Contracts.AssertNonEmpty(str);
+
+ var res = new Column();
+ if (res.TryParse(str))
+ return res;
+ return null;
+ }
+
+ internal bool TryUnparse(StringBuilder sb)
+ {
+ Contracts.AssertValue(sb);
+ return TryUnparseCore(sb);
+ }
+ }
+
+ internal sealed class Options: TransformInputBase
+ {
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition (optional form: name:src)",
+ Name = "Column", ShortName = "col", SortOrder = 1)]
+ public Column[] Columns;
+ }
+
+ #endregion
+
+ internal static CategoryImputerEstimator Create(IHostEnvironment env, string outputColumnName, string inputColumnName)
+ {
+ return new CategoryImputerEstimator(env, outputColumnName, inputColumnName);
+ }
+
+ internal static CategoryImputerEstimator Create(IHostEnvironment env, params InputOutputColumnPair[] columns)
+ {
+ var columnOptions = columns.Select(x => new Column { Name = x.OutputColumnName, Source = x.InputColumnName ?? x.OutputColumnName }).ToArray();
+ return new CategoryImputerEstimator(env, new Options { Columns = columnOptions });
+ }
+
+ internal CategoryImputerEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(CategoryImputerEstimator));
+ _options = new Options
+ {
+ Columns = new Column[1] { new Column() { Name = outputColumnName, Source = inputColumnName ?? outputColumnName } }
+ };
+ }
+
+ internal CategoryImputerEstimator(IHostEnvironment env, Options options)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(CategoryImputerEstimator));
+
+ foreach (var columnPair in options.Columns)
+ {
+ columnPair.Source = columnPair.Source ?? columnPair.Name;
+ }
+
+ _options = options;
+ }
+
+ public CategoryImputerTransformer Fit(IDataView input)
+ {
+ return new CategoryImputerTransformer(_host, input, _options.Columns);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ var columns = inputSchema.ToDictionary(x => x.Name);
+
+ foreach (var column in _options.Columns)
+ {
+ var inputColumn = columns[column.Source];
+ columns[column.Name] = new SchemaShape.Column(column.Name, inputColumn.Kind,
+ inputColumn.ItemType, inputColumn.IsKey, inputColumn.Annotations);
+ }
+
+ return new SchemaShape(columns.Values);
+ }
+ }
+
+ public sealed class CategoryImputerTransformer : RowToRowTransformerBase, IDisposable
+ {
+ #region Class data members
+
+ internal const string Summary = "Fills in missing values in a column based on the most frequent value";
+ internal const string UserName = "CategoryImputer";
+ internal const string ShortName = "CategoryImputer";
+ internal const string LoadName = "CategoryImputer";
+ internal const string LoaderSignature = "CategoryImputer";
+
+ private readonly TypedColumn[] _columns;
+
+ #endregion
+
+ internal CategoryImputerTransformer(IHostEnvironment host, IDataView input, CategoryImputerEstimator.Column[] columns) :
+ base(host.Register(nameof(CategoryImputerEstimator)))
+ {
+ var schema = input.Schema;
+
+ _columns = columns.Select(x => TypedColumn.CreateTypedColumn(x.Name, x.Source, schema[x.Source].Type.RawType.ToString())).ToArray();
+ foreach (var column in _columns)
+ {
+ column.CreateTransformerFromEstimator(input);
+ }
+ }
+
+ // Factory method for SignatureLoadModel.
+ internal CategoryImputerTransformer(IHostEnvironment host, ModelLoadContext ctx) :
+ base(host.Register(nameof(CategoryImputerTransformer)))
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+ // *** Binary format ***
+ // int number of column pairs
+ // for each column pair:
+ // string output column name
+ // string input column name
+ // column type
+ // int length of c++ byte array
+ // byte array from c++
+
+ var columnCount = ctx.Reader.ReadInt32();
+
+ _columns = new TypedColumn[columnCount];
+ for (int i = 0; i < columnCount; i++)
+ {
+ _columns[i] = TypedColumn.CreateTypedColumn(ctx.Reader.ReadString(), ctx.Reader.ReadString(), ctx.Reader.ReadString());
+
+ // Load the C++ state and create the C++ transformer.
+ var dataLength = ctx.Reader.ReadInt32();
+ var data = ctx.Reader.ReadByteArray(dataLength);
+ _columns[i].CreateTransformerFromSavedData(data);
+ }
+ }
+
+ // Factory method for SignatureLoadRowMapper.
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
+ => new CategoryImputerTransformer(env, ctx).MakeRowMapper(inputSchema);
+
+ private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "CATIMP T",
+ verWrittenCur: 0x00010001,
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(CategoryImputerTransformer).Assembly.FullName);
+ }
+
+ private protected override void SaveModel(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ // int number of column pairs
+ // for each column pair:
+ // string output column name
+ // string input column name
+ // column type
+ // int length of c++ byte array
+ // byte array from c++
+
+ ctx.Writer.Write(_columns.Count());
+ foreach (var column in _columns)
+ {
+ ctx.Writer.Write(column.Name);
+ ctx.Writer.Write(column.Source);
+ ctx.Writer.Write(column.Type);
+
+ // Save C++ state
+ var data = column.CreateTransformerSaveData();
+ ctx.Writer.Write(data.Length);
+ ctx.Writer.Write(data);
+ }
+
+ }
+
+ public void Dispose()
+ {
+ foreach (var column in _columns)
+ {
+ column.Dispose();
+ }
+ }
+
+ #region ColumnInfo
+
+ // REVIEW: Since we can't do overloading on the native side due to the C style exports,
+ // this was the best way I could think handle it to allow for any conversions needed based on the data type.
+
+ #region BaseClass
+
+ internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle);
+ internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle);
+ internal delegate bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+
+ internal abstract class TypedColumn : IDisposable
+ {
+ internal readonly string Name;
+ internal readonly string Source;
+ internal readonly string Type;
+ internal TypedColumn(string name, string source, string type)
+ {
+ Name = name;
+ Source = source;
+ Type = type;
+ }
+
+ internal abstract void CreateTransformerFromEstimator(IDataView input);
+ private protected abstract unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize);
+ private protected abstract bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+ public abstract void Dispose();
+
+ internal byte[] CreateTransformerSaveData()
+ {
+
+ var success = CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+ if(!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize))
+ {
+ byte[] savedData = new byte[bufferSize.ToInt32()];
+ Marshal.Copy(buffer, savedData, 0, savedData.Length);
+ return savedData;
+ }
+ }
+
+ internal unsafe void CreateTransformerFromSavedData(byte[] data)
+ {
+ fixed(byte* rawData = data)
+ {
+ IntPtr dataSize = new IntPtr(data.Count());
+ CreateTransformerFromSavedDataHelper(rawData, dataSize);
+ }
+ }
+
+ internal static TypedColumn CreateTypedColumn(string name, string source, string type)
+ {
+ if (type == typeof(sbyte).ToString())
+ {
+ return new Int8TypedColumn(name, source);
+ }
+ else if (type == typeof(short).ToString())
+ {
+ return new Int16TypedColumn(name, source);
+ }
+ else if (type == typeof(int).ToString())
+ {
+ return new Int32TypedColumn(name, source);
+ }
+ else if (type == typeof(long).ToString())
+ {
+ return new Int64TypedColumn(name, source);
+ }
+ else if (type == typeof(byte).ToString())
+ {
+ return new UInt8TypedColumn(name, source);
+ }
+ else if (type == typeof(ushort).ToString())
+ {
+ return new UInt16TypedColumn(name, source);
+ }
+ else if (type == typeof(uint).ToString())
+ {
+ return new UInt32TypedColumn(name, source);
+ }
+ else if (type == typeof(ulong).ToString())
+ {
+ return new UInt64TypedColumn(name, source);
+ }
+ else if (type == typeof(float).ToString())
+ {
+ return new FloatTypedColumn(name, source);
+ }
+ else if (type == typeof(double).ToString())
+ {
+ return new DoubleTypedColumn(name, source);
+ }
+ else if (type == typeof(string).ToString())
+ {
+ return new StringTypedColumn(name, source);
+ }
+ else if (type == typeof(ReadOnlyMemory).ToString())
+ {
+ return new ReadOnlyCharTypedColumn(name, source);
+ }
+
+ throw new Exception($"Unsupported type {type}");
+ }
+ }
+
+ internal abstract class TypedColumn : TypedColumn
+ {
+ internal TypedColumn(string name, string source, string type) :
+ base(name, source, type)
+ {
+ }
+
+ internal abstract T Transform(T input);
+ private protected abstract bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
+ private protected abstract bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ private protected abstract bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle);
+ private protected abstract bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle);
+ private protected abstract bool FitHelper(TransformerEstimatorSafeHandle estimator, T input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected abstract bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase(IDataView input)
+ {
+ var success = CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorHelper))
+ {
+ var fitResult = FitResult.Continue;
+ while (fitResult != FitResult.Complete)
+ {
+ using (var data = input.GetColumn(Source).GetEnumerator())
+ {
+ while (fitResult == FitResult.Continue && data.MoveNext())
+ {
+ {
+ success = FitHelper(estimatorHandler, data.Current, out fitResult, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ }
+
+ success = CompleteTrainingHelper(estimatorHandler, out fitResult, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ }
+
+ success = CreateTransformerFromEstimatorHelper(estimatorHandler, out IntPtr transformer, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerHelper);
+ }
+ }
+ }
+
+ #endregion
+
+ #region Int8Column
+
+ internal sealed class Int8TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ internal Int8TypedColumn(string name, string source) :
+ base(name, source, typeof(sbyte).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, sbyte* input, out sbyte output, out IntPtr errorHandle);
+ internal unsafe override sbyte Transform(sbyte input)
+ {
+ sbyte* interopInput = input == 0 ? null : &input;
+ var success = TransformDataNative(_transformerHandler, interopInput, out sbyte output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, sbyte* input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, sbyte input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ sbyte* interopInput = input == 0 ? null : &input;
+ return FitNative(estimator, interopInput, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region Int16Column
+
+ internal sealed class Int16TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal Int16TypedColumn(string name, string source) :
+ base(name, source, typeof(short).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, short* input, out short output, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_DestroyTransformedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal unsafe override short Transform(short input)
+ {
+ short* interopInput = input == 0 ? null : &input;
+ var success = TransformDataNative(_transformerHandler, interopInput, out short output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, short* input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, short input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ short* interopInput = input == 0 ? null : &input;
+ return FitNative(estimator, interopInput, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region Int32Column
+
+ internal sealed class Int32TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ internal Int32TypedColumn(string name, string source) :
+ base(name, source, typeof(int).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, int* input, out int output, out IntPtr errorHandle);
+ internal unsafe override int Transform(int input)
+ {
+ int* interopInput = input == 0 ? null : &input;
+ var success = TransformDataNative(_transformerHandler, interopInput, out int output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, int* input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, int input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ int* interopInput = input == 0 ? null : &input;
+ return FitNative(estimator, interopInput, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region Int64Column
+
+ internal sealed class Int64TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ internal Int64TypedColumn(string name, string source) :
+ base(name, source, typeof(long).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, long* input, out long output, out IntPtr errorHandle);
+ internal unsafe override long Transform(long input)
+ {
+ long* interopInput = input == 0 ? null : &input;
+ var success = TransformDataNative(_transformerHandler, interopInput, out long output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, long* input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, long input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ long* interopInput = input == 0 ? null : &input;
+ return FitNative(estimator, interopInput, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_int64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region UInt8Column
+
+ internal sealed class UInt8TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ internal UInt8TypedColumn(string name, string source) :
+ base(name, source, typeof(byte).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, byte* input, out byte output, out IntPtr errorHandle);
+ internal unsafe override byte Transform(byte input)
+ {
+ byte* interopInput = input == 0 ? null : &input;
+ var success = TransformDataNative(_transformerHandler, interopInput, out byte output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, byte* input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, byte input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ byte* interopInput = input == 0 ? null : &input;
+ return FitNative(estimator, interopInput, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region UInt16Column
+
+ internal sealed class UInt16TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ internal UInt16TypedColumn(string name, string source) :
+ base(name, source, typeof(ushort).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ushort* input, out ushort output, out IntPtr errorHandle);
+ internal unsafe override ushort Transform(ushort input)
+ {
+ ushort* interopInput = input == 0 ? null : &input;
+ var success = TransformDataNative(_transformerHandler, interopInput, out ushort output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, ushort* input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, ushort input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ ushort* interopInput = input == 0 ? null : &input;
+ return FitNative(estimator, interopInput, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region UInt32Column
+
+ internal sealed class UInt32TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal UInt32TypedColumn(string name, string source) :
+ base(name, source, typeof(uint).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, uint* input, out uint output, out IntPtr errorHandle);
+ internal unsafe override uint Transform(uint input)
+ {
+ uint* interopInput = input == 0 ? null : &input;
+ var success = TransformDataNative(_transformerHandler, interopInput, out uint output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, uint* input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, uint input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ uint* interopInput = input == 0 ? null : &input;
+ return FitNative(estimator, interopInput, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region UInt64Column
+
+ internal sealed class UInt64TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal UInt64TypedColumn(string name, string source) :
+ base(name, source, typeof(ulong).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ulong* input, out ulong output, out IntPtr errorHandle);
+ internal unsafe override ulong Transform(ulong input)
+ {
+ ulong* interopInput = input == 0 ? null : &input;
+ var success = TransformDataNative(_transformerHandler, interopInput, out ulong output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, ulong* input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, ulong input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ ulong* interopInput = input == 0 ? null : &input;
+ return FitNative(estimator, interopInput, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_uint64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region FloatColumn
+
+ internal sealed class FloatTypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal FloatTypedColumn(string name, string source) :
+ base(name, source, typeof(float).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, in float input, out float output, out IntPtr errorHandle);
+ internal override float Transform(float input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool FitNative(TransformerEstimatorSafeHandle estimator, in float input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool FitHelper(TransformerEstimatorSafeHandle estimator, float input, out FitResult fitResult, out IntPtr errorHandle) =>
+ FitNative(estimator, input, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_float_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region DoubleColumn
+
+ internal sealed class DoubleTypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal DoubleTypedColumn(string name, string source) :
+ base(name, source, typeof(double).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, in double input, out double output, out IntPtr errorHandle);
+ internal override double Transform(double input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool FitNative(TransformerEstimatorSafeHandle estimator, in double input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool FitHelper(TransformerEstimatorSafeHandle estimator, double input, out FitResult fitResult, out IntPtr errorHandle) =>
+ FitNative(estimator, input, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_double_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region StringColumn
+
+ internal sealed class StringTypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal StringTypedColumn(string name, string source) :
+ base(name, source, typeof(string).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, byte* input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyTransformedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal unsafe override string Transform(string input)
+ {
+ // Convert to byte array with NullPointer at end or nullptr.
+ fixed (byte* interopInput = input == null? null : Encoding.UTF8.GetBytes(input + char.MinValue))
+ {
+ var result = TransformDataNative(_transformerHandler, interopInput, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+
+ if (!result)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, byte* input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, string input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ fixed (byte* interopInput = input == null ? null : Encoding.UTF8.GetBytes(input + char.MinValue))
+ {
+ return FitNative(estimator, interopInput, out fitResult, out errorHandle);
+ }
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region ReadOnlyCharColumn
+
+ internal sealed class ReadOnlyCharTypedColumn : TypedColumn>
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal ReadOnlyCharTypedColumn(string name, string source) :
+ base(name, source, typeof(ReadOnlyMemory).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, byte* input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_DestroyTransformedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal unsafe override ReadOnlyMemory Transform(ReadOnlyMemory input)
+ {
+ var inputAsString = input.ToString();
+ fixed (byte* interopInput = string.IsNullOrEmpty(inputAsString) ? null : Encoding.UTF8.GetBytes(inputAsString + char.MinValue))
+ {
+ var result = TransformDataNative(_transformerHandler, interopInput, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+
+ if (!result)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ if (outputSize.ToInt32() == 0)
+ return new ReadOnlyMemory(string.Empty.ToArray());
+
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return new ReadOnlyMemory(Encoding.UTF8.GetString(buffer).ToArray());
+ }
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, byte* input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, ReadOnlyMemory input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ var inputAsString = input.ToString();
+ fixed (byte* interopInput = string.IsNullOrEmpty(inputAsString) ? null : Encoding.UTF8.GetBytes(inputAsString + char.MinValue))
+ {
+ return FitNative(estimator, interopInput, out fitResult, out errorHandle);
+ }
+ }
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "CatImputerFeaturizer_string_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #endregion
+
+ private sealed class Mapper : MapperBase
+ {
+
+ #region Class data members
+
+ private readonly CategoryImputerTransformer _parent;
+
+ #endregion
+
+ public Mapper(CategoryImputerTransformer parent, DataViewSchema inputSchema) :
+ base(parent.Host.Register(nameof(Mapper)), inputSchema, parent)
+ {
+ _parent = parent;
+ }
+
+ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
+ {
+ return _parent._columns.Select(x => new DataViewSchema.DetachedColumn(x.Name, ColumnTypeExtensions.PrimitiveTypeFromType(Type.GetType(x.Type)))).ToArray();
+ }
+
+ private Delegate MakeGetter(DataViewRow input, int iinfo)
+ {
+ ValueGetter result = (ref T dst) =>
+ {
+ var inputColumn = input.Schema[_parent._columns[iinfo].Source];
+ var srcGetterScalar = input.GetGetter(inputColumn);
+
+ T value = default;
+ srcGetterScalar(ref value);
+
+ dst = ((TypedColumn)_parent._columns[iinfo]).Transform(value);
+
+ };
+
+ return result;
+ }
+
+ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer)
+ {
+ disposer = null;
+ Type inputType = input.Schema[_parent._columns[iinfo].Source].Type.RawType;
+ return Utils.MarshalInvoke(MakeGetter, inputType, input, iinfo);
+ }
+
+ private protected override Func GetDependenciesCore(Func activeOutput)
+ {
+ var active = new bool[InputSchema.Count];
+ for (int i = 0; i < InputSchema.Count; i++)
+ {
+ if (_parent._columns.Any(x => x.Source == InputSchema[i].Name))
+ {
+ active[i] = true;
+ }
+ }
+
+ return col => active[col];
+ }
+
+ private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
+ }
+ }
+
+ internal static class CategoryImputerEntrypoint
+ {
+ [TlcModule.EntryPoint(Name = "Transforms.CategoryImputer",
+ Desc = CategoryImputerTransformer.Summary,
+ UserName = CategoryImputerTransformer.UserName,
+ ShortName = CategoryImputerTransformer.ShortName)]
+ public static CommonOutputs.TransformOutput ImputeToKey(IHostEnvironment env, CategoryImputerEstimator.Options input)
+ {
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, CategoryImputerTransformer.ShortName, input);
+ var xf = new CategoryImputerEstimator(h, input).Fit(input.Data).Transform(input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModelImpl(h, xf, input.Data),
+ OutputData = xf
+ };
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Featurizers/Common.cs b/src/Microsoft.ML.Featurizers/Common.cs
new file mode 100644
index 0000000000..1ab76c5d78
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/Common.cs
@@ -0,0 +1,190 @@
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using System.Security;
+using System.Text;
+using Microsoft.Win32.SafeHandles;
+
+namespace Microsoft.ML.Featurizers
+{
+ #region Native Function Declarations
+
+ #endregion
+
+ internal enum FitResult : byte
+ {
+ Complete = 1, Continue, ResetAndContinue
+ }
+
+ // Not all these types are currently supported. This is so the ordering will allign with the native code.
+ internal enum TypeId : uint
+ {
+ String = 1, SByte, Short, Int, Long, Byte, UShort,
+ UInt, ULong, Float16, Float32, Double, Complex64,
+ Complex128, BFloat16, Bool, Timepoint, Duration,
+
+ LastStaticValue,
+ Tensor = 0x1001 | LastStaticValue + 1,
+ SparseTensor = 0x1001 | LastStaticValue + 2,
+ Tabular = 0x1001 | LastStaticValue + 3,
+ Nullable = 0x1001 | LastStaticValue + 4,
+ Vector = 0x1001 | LastStaticValue + 5,
+ MapId = 0x1002 | LastStaticValue + 6
+ };
+
+ [StructLayout(LayoutKind.Sequential, Pack = 1)]
+ internal unsafe struct NativeBinaryArchiveData
+ {
+ public byte* Data;
+ public IntPtr DataSize;
+ }
+
+ #region SafeHandles
+
+ internal class ErrorInfoSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ [DllImport("Featurizers", EntryPoint = "DestroyErrorInfo", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyErrorInfo(IntPtr error);
+
+ public ErrorInfoSafeHandle(IntPtr handle) : base(true)
+ {
+ SetHandle(handle);
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ return DestroyErrorInfo(handle);
+ }
+ }
+
+ internal class ErrorInfoStringSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ [DllImport("Featurizers", EntryPoint = "DestroyErrorInfoString", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyErrorInfoString(IntPtr errorString, IntPtr errorStringSize);
+
+ private IntPtr _length;
+ public ErrorInfoStringSafeHandle(IntPtr handle, IntPtr length) : base(true)
+ {
+ SetHandle(handle);
+ _length = length;
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ return DestroyErrorInfoString(handle, _length);
+ }
+ }
+
+ internal delegate bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ private DestroyTransformedDataNative _destroySaveDataHandler;
+ private IntPtr _dataSize;
+
+ public TransformedDataSafeHandle(IntPtr handle, IntPtr dataSize, DestroyTransformedDataNative destroyCppTransformerEstimator) : base(true)
+ {
+ SetHandle(handle);
+ _dataSize = dataSize;
+ _destroySaveDataHandler = destroyCppTransformerEstimator;
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ // Not sure what to do with error stuff here. There shoudln't ever be one though.
+ return _destroySaveDataHandler(handle, _dataSize, out IntPtr errorHandle);
+ }
+ }
+
+ internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle);
+ internal class TransformerEstimatorSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ private DestroyCppTransformerEstimator _destroyCppTransformerEstimator;
+ public TransformerEstimatorSafeHandle(IntPtr handle, DestroyCppTransformerEstimator destroyCppTransformerEstimator) : base(true)
+ {
+ SetHandle(handle);
+ _destroyCppTransformerEstimator = destroyCppTransformerEstimator;
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ // Not sure what to do with error stuff here. There shouldn't ever be one though.
+ return _destroyCppTransformerEstimator(handle, out IntPtr errorHandle);
+ }
+ }
+
+ // Destroying saved data is always the same.
+ internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle);
+
+ internal class SaveDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ private readonly IntPtr _dataSize;
+
+ // TODO: Update with correct entry point.
+ [DllImport("Featurizers", EntryPoint = "DestroyTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerSaveDataNative(IntPtr buffer, IntPtr bufferSize, out IntPtr error);
+
+ public SaveDataSafeHandle(IntPtr handle, IntPtr dataSize) : base(true)
+ {
+ SetHandle(handle);
+ _dataSize = dataSize;
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ // Not sure what to do with error stuff here. There shoudln't ever be one though.
+ return DestroyTransformerSaveDataNative(handle, _dataSize, out _);
+ }
+ }
+
+ #endregion
+
+ internal static class CommonExtensions
+ {
+ [DllImport("Featurizers", EntryPoint = "GetErrorInfoString", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool GetErrorInfoString(IntPtr error, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
+
+ internal static string GetErrorDetailsAndFreeNativeMemory(IntPtr errorHandle)
+ {
+ using (var error = new ErrorInfoSafeHandle(errorHandle))
+ {
+ GetErrorInfoString(errorHandle, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
+ using (var errorString = new ErrorInfoStringSafeHandle(errorHandleString, errorHandleStringSize))
+ {
+ byte[] buffer = new byte[errorHandleStringSize.ToInt32()];
+ Marshal.Copy(errorHandleString, buffer, 0, buffer.Length);
+
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+ }
+ internal static TypeId GetNativeTypeIdFromType(this Type type)
+ {
+ if (type == typeof(byte))
+ return TypeId.Byte;
+ else if (type == typeof(short))
+ return TypeId.Short;
+ else if (type == typeof(int))
+ return TypeId.Int;
+ else if (type == typeof(long))
+ return TypeId.Long;
+ else if (type == typeof(byte))
+ return TypeId.Byte;
+ else if (type == typeof(ushort))
+ return TypeId.UShort;
+ else if (type == typeof(uint))
+ return TypeId.UInt;
+ else if (type == typeof(ulong))
+ return TypeId.ULong;
+ else if (type == typeof(float))
+ return TypeId.Float32;
+ else if (type == typeof(double))
+ return TypeId.Double;
+ else if (type == typeof(bool))
+ return TypeId.Bool;
+ else if (type == typeof(ReadOnlyMemory))
+ return TypeId.String;
+
+ throw new InvalidOperationException($"Unsupported type {type}");
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs
new file mode 100644
index 0000000000..7e0745396c
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs
@@ -0,0 +1,781 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Reflection;
+using System.Runtime.InteropServices;
+using System.Security;
+using System.Text;
+using Microsoft.ML;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.Data;
+using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Transforms;
+using Microsoft.Win32.SafeHandles;
+using static Microsoft.ML.Featurizers.CommonExtensions;
+
+[assembly: LoadableClass(typeof(DateTimeTransformer), null, typeof(SignatureLoadModel),
+ DateTimeTransformer.UserName, DateTimeTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(DateTimeTransformer), null, typeof(SignatureLoadRowMapper),
+ DateTimeTransformer.UserName, DateTimeTransformer.LoaderSignature)]
+
+[assembly: EntryPointModule(typeof(DateTimeTransformerEntrypoint))]
+
+namespace Microsoft.ML.Featurizers
+{
+
+ public static class DateTimeTransformerExtensionClass
+ {
+ ///
+ /// Create a , which splits up the input column specified by
+ /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc.
+ /// This transformer will append the to all the output columns. If is empty,
+ /// then all the columns are returned. Otherwise, the columns listed in the array will be dropped from the return value.
+ ///
+ /// Transform catalog
+ /// Input column name
+ /// Prefix to add to the generated columns
+ /// List of columns to drop, if any
+ ///
+ public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, params DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop)
+ => new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop);
+
+ ///
+ /// Create a , which splits up the input column specified by
+ /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc.
+ /// This transformer will append the to all the output columns. If is empty,
+ /// then all the columns are returned. Otherwise, the columns listed in the array will be dropped from the return value. If you specify a country,
+ /// Holiday details will be looked up for that country as well.
+ ///
+ /// Transform catalog
+ /// Input column name
+ /// Prefix to add to the generated columns
+ /// List of columns to drop, if any
+ /// Country name to get holiday details for
+ ///
+ public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop = null, DateTimeTransformerEstimator.Countries country = DateTimeTransformerEstimator.Countries.None)
+ => new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop, country);
+
+ #region ColumnsProduced static extentions
+
+ internal static Type GetRawColumnType(this DateTimeTransformerEstimator.ColumnsProduced column)
+ {
+ switch (column)
+ {
+ case DateTimeTransformerEstimator.ColumnsProduced.Year:
+ case DateTimeTransformerEstimator.ColumnsProduced.YearIso:
+ return typeof(int);
+ case DateTimeTransformerEstimator.ColumnsProduced.DayOfYear:
+ case DateTimeTransformerEstimator.ColumnsProduced.WeekOfMonth:
+ return typeof(ushort);
+ case DateTimeTransformerEstimator.ColumnsProduced.MonthLabel:
+ case DateTimeTransformerEstimator.ColumnsProduced.AmPmLabel:
+ case DateTimeTransformerEstimator.ColumnsProduced.DayOfWeekLabel:
+ case DateTimeTransformerEstimator.ColumnsProduced.HolidayName:
+ return typeof(ReadOnlyMemory);
+ default:
+ return typeof(byte);
+ }
+ }
+
+ #endregion
+ }
+
+ ///
+ /// The DateTimeTransformerEstimator splits up a date into all of its sub parts as individual columns. It generates these fields with a user specified prefix:
+ /// int Year, byte Month, byte Day, byte Hour, byte Minute, byte Second, byte AmPm, byte Hour12, byte DayOfWeek, byte DayOfQuarter,
+ /// ushort DayOfYear, ushort WeekOfMonth, byte QuarterOfYear, byte HalfOfYear, byte WeekIso, int YearIso, string MonthLabel, string AmPmLabel,
+ /// string DayOfWeekLabel, string HolidayName, byte IsPaidTimeOff
+ ///
+ /// You can optionally specify a country and it will pull holiday information about the country as well
+ ///
+ ///
+ /// is a trivial estimator and does not need training.
+ ///
+ ///
+ /// ]]>
+ ///
+ ///
+ ///
+ ///
+ public sealed class DateTimeTransformerEstimator : IEstimator
+ {
+ private readonly Options _options;
+
+ private readonly IHost _host;
+
+ #region Options
+ internal sealed class Options: TransformInputBase
+ {
+ [Argument(ArgumentType.Required, HelpText = "Input column", Name = "Source", ShortName = "src", SortOrder = 1)]
+ public string Source;
+
+ // This transformer adds columns
+ [Argument(ArgumentType.Required, HelpText = "Output column prefix", Name = "Prefix", ShortName = "pre", SortOrder = 2)]
+ public string Prefix;
+
+ [Argument(ArgumentType.MultipleUnique, HelpText = "Columns to drop after the DateTime Expansion", Name = "ColumnsToDrop", ShortName = "drop", SortOrder = 3)]
+ public ColumnsProduced[] ColumnsToDrop;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Country to get holidays for. Defaults to none if not passed", Name = "Country", ShortName = "ctry", SortOrder = 4)]
+ public Countries Country = Countries.None;
+ }
+
+ #endregion
+
+ public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop, Countries country = Countries.None)
+ {
+
+ Contracts.CheckValue(env, nameof(env));
+ _host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator");
+ _host.CheckValue(inputColumnName, nameof(inputColumnName), "Input column should not be null.");
+
+ _options = new Options
+ {
+ Source = inputColumnName,
+ Prefix = columnPrefix,
+ ColumnsToDrop = columnsToDrop == null ? Array.Empty() : columnsToDrop,
+ Country = country
+ };
+ }
+
+ internal DateTimeTransformerEstimator(IHostEnvironment env, Options options)
+ {
+
+ Contracts.CheckValue(env, nameof(env));
+ _host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator");
+
+ _options = options;
+ _options.ColumnsToDrop = _options.ColumnsToDrop == null ? Array.Empty() : _options.ColumnsToDrop;
+ }
+
+ public DateTimeTransformer Fit(IDataView input)
+ {
+ return new DateTimeTransformer(_host, _options.Source, _options.Prefix, _options.ColumnsToDrop, _options.Country);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ var columns = inputSchema.ToDictionary(x => x.Name);
+
+ foreach (ColumnsProduced column in Enum.GetValues(typeof(ColumnsProduced)))
+ if (_options.ColumnsToDrop == null || !_options.ColumnsToDrop.Contains(column))
+ {
+ columns[_options.Prefix + column.ToString()] = new SchemaShape.Column(_options.Prefix + column.ToString(), SchemaShape.Column.VectorKind.Scalar,
+ ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType()), false, null);
+ }
+
+ return new SchemaShape(columns.Values);
+ }
+
+ #region Enums
+ public enum ColumnsProduced : byte
+ {
+ Year = 1, Month, Day, Hour, Minute, Second, AmPm, Hour12, DayOfWeek, DayOfQuarter, DayOfYear,
+ WeekOfMonth, QuarterOfYear, HalfOfYear, WeekIso, YearIso, MonthLabel, AmPmLabel, DayOfWeekLabel,
+ HolidayName, IsPaidTimeOff
+ };
+
+ public enum Countries : byte
+ {
+ None = 1,
+ Argentina, Australia, Austria, Belarus, Belgium, Brazil, Canada, Colombia, Croatia, Czech, Denmark,
+ England, Finland, France, Germany, Hungary, India, Ireland, IsleofMan, Italy, Japan, Mexico, Netherlands,
+ NewZealand, NorthernIreland, Norway, Poland, Portugal, Scotland, Slovenia, SouthAfrica, Spain, Sweden, Switzerland,
+ Ukraine, UnitedKingdom, UnitedStates, Wales
+ }
+
+ #endregion
+ }
+
+ public sealed class DateTimeTransformer : RowToRowTransformerBase, IDisposable
+ {
+ #region Class data members
+
+ internal const string Summary = "Splits a date time value into each individual component";
+ internal const string UserName = "DateTime Transform";
+ internal const string ShortName = "DateTimeTransform";
+ internal const string LoadName = "DateTimeTransform";
+ internal const string LoaderSignature = "DateTimeTransform";
+ private DateTimeTypedColumn _column;
+
+ private DateTimeTransformerEstimator.ColumnsProduced[] _columnsToDrop;
+ private byte[] _activeColumnMapping;
+
+ #endregion
+
+ public DateTimeTransformer(IHostEnvironment env, string inputColumnName, string columnPrefix, DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop, DateTimeTransformerEstimator.Countries country ) :
+ base(env.Register(nameof(DateTimeTransformer)))
+ {
+
+ _columnsToDrop = columnsToDrop;
+ var activeColumnLength = Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced)).Length - (_columnsToDrop == null ? 0 : _columnsToDrop.Length);
+ _activeColumnMapping = new byte[activeColumnLength];
+ var index = 0;
+ foreach(DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced)))
+ {
+ if (_columnsToDrop == null || !_columnsToDrop.Contains(column))
+ {
+ _activeColumnMapping[index++] = (byte)column;
+ }
+ }
+
+ _column = new DateTimeTypedColumn(inputColumnName, columnPrefix);
+ _column.CreateTransformerFromEstimator(country);
+ }
+
+ // Factory method for SignatureLoadModel.
+ internal DateTimeTransformer(IHostEnvironment host, ModelLoadContext ctx) :
+ base(host.Register(nameof(DateTimeTransformer)))
+ {
+
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+ // *** Binary format ***
+ // name of input column
+ // column prefix
+ // byte length of columns to drop array
+ // byte array of columns to drop
+ // length of C++ state array
+ // C++ byte state array
+
+ _column = new DateTimeTypedColumn(ctx.Reader.ReadString(), ctx.Reader.ReadString());
+
+ var dropColumnsLength = ctx.Reader.ReadInt32();
+ if (dropColumnsLength > 0)
+ {
+ _columnsToDrop = new DateTimeTransformerEstimator.ColumnsProduced[dropColumnsLength];
+ //read in enum bytes
+ for (int i = 0; i < dropColumnsLength; i++)
+ _columnsToDrop[i] = (DateTimeTransformerEstimator.ColumnsProduced)ctx.Reader.ReadByte();
+ }
+
+ _activeColumnMapping = new byte[Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced)).Length - dropColumnsLength];
+ var index = 0;
+ foreach (DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced)))
+ {
+ if (_columnsToDrop == null || !_columnsToDrop.Contains(column))
+ {
+ _activeColumnMapping[index++] = (byte)column;
+ }
+ }
+
+ var dataLength = ctx.Reader.ReadInt32();
+ var data = ctx.Reader.ReadByteArray(dataLength);
+ _column.CreateTransformerFromSavedData(data);
+ }
+
+ // Factory method for SignatureLoadRowMapper.
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
+ => new DateTimeTransformer(env, ctx).MakeRowMapper(inputSchema);
+
+ private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "DATETI T",
+ verWrittenCur: 0x00010001,
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(DateTimeTransformer).Assembly.FullName);
+ }
+
+ private protected override void SaveModel(ModelSaveContext ctx)
+ {
+
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ // name of input column
+ // column prefix
+ // byte length of columns to drop array
+ // byte array of columns to drop
+ // length of C++ state array
+ // C++ byte state array
+
+ ctx.Writer.Write(_column.Source);
+ ctx.Writer.Write(_column.Prefix);
+
+ ctx.Writer.Write(_columnsToDrop == null ? 0 : _columnsToDrop.Length);
+ if (_columnsToDrop != null)
+ {
+ foreach (var toDrop in _columnsToDrop)
+ ctx.Writer.Write((byte)toDrop);
+ }
+
+ var data = _column.CreateTransformerSaveData();
+ ctx.Writer.Write(data.Length);
+ ctx.Writer.Write(data);
+ }
+
+ public void Dispose()
+ {
+ _column.Dispose();
+ }
+
+ #region C++ Safe handle classes
+
+ internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ private readonly DestroyTransformedDataNative _destroyTransformedDataHandler;
+
+ public TransformedDataSafeHandle(IntPtr handle, DestroyTransformedDataNative destroyTransformedDataHandler) : base(true)
+ {
+ SetHandle(handle);
+ _destroyTransformedDataHandler = destroyTransformedDataHandler;
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ // Not sure what to do with error stuff here. There shoudln't ever be one though.
+ return _destroyTransformedDataHandler(handle, out IntPtr errorHandle);
+ }
+ }
+
+ #endregion
+
+ #region TimePoint
+
+ [StructLayoutAttribute(LayoutKind.Sequential)]
+ internal struct TimePoint
+ {
+ public int Year;
+ public byte Month;
+ public byte Day;
+ public byte Hour;
+ public byte Minute;
+ public byte Second;
+ public byte AmPm;
+ public byte Hour12;
+ public byte DayOfWeek;
+ public byte DayOfQuarter;
+ public ushort DayOfYear;
+ public ushort WeekOfMonth;
+ public byte QuarterOfYear;
+ public byte HalfOfYear;
+ public byte WeekIso;
+ public int YearIso;
+ public string MonthLabel;
+ public string AmPmLabel;
+ public string DayOfWeekLabel;
+ public string HolidayName;
+ public byte IsPaidTimeOff;
+
+ internal unsafe TimePoint(byte* rawData)
+ {
+ int intPtrSize = sizeof(IntPtr);
+
+ Year = *(int*)rawData;
+ rawData += 4;
+
+ Month = *rawData++;
+ Day = *rawData++;
+ Hour = *rawData++;
+ Minute = *rawData++;
+ Second = *rawData++;
+ AmPm = *rawData++;
+ Hour12 = *rawData++;
+ DayOfWeek = *rawData++;
+ DayOfQuarter = *rawData++;
+ DayOfYear = *(ushort*)rawData;
+ rawData += 2;
+
+ WeekOfMonth = *(ushort*)rawData;
+ rawData += 2;
+
+ QuarterOfYear = *rawData++;
+ HalfOfYear = *rawData++;
+ WeekIso = *rawData++;
+ YearIso = *(int*)rawData;
+ rawData += 4;
+
+ // Convert char * to string
+ MonthLabel = GetStringFromPointer(ref rawData, intPtrSize);
+ AmPmLabel = GetStringFromPointer(ref rawData, intPtrSize);
+ DayOfWeekLabel = GetStringFromPointer(ref rawData, intPtrSize);
+ HolidayName = GetStringFromPointer(ref rawData, intPtrSize);
+ IsPaidTimeOff = *rawData;
+ }
+
+ // Converts a pointer to a native char* to a string and increments pointer by to the next value.
+ // The length of the string is stored at byte* + sizeof(IntPtr).
+ private static unsafe string GetStringFromPointer(ref byte* rawData, int intPtrSize)
+ {
+ byte[] buffer;
+ if (intPtrSize == 4) // 32 bit machine
+ buffer = new byte[*(uint*)(rawData + intPtrSize)];
+ else // 64 bit machine
+ buffer = new byte[*(ulong*)(rawData + intPtrSize)];
+
+ if (buffer.Length == 0)
+ {
+ rawData += intPtrSize * 2;
+ return string.Empty;
+ }
+
+ Marshal.Copy(new IntPtr(*(int**)rawData), buffer, 0, buffer.Length);
+ rawData += intPtrSize * 2;
+
+ return Encoding.UTF8.GetString(buffer);
+ }
+
+ };
+
+ #endregion
+
+ #region BaseClass
+
+ internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle);
+ internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle);
+ internal delegate bool DestroyTransformedDataNative(IntPtr output, out IntPtr errorHandle);
+
+ internal abstract class TypedColumn : IDisposable
+ {
+ internal readonly string Source;
+ internal readonly string Prefix;
+
+ internal TypedColumn(string source, string prefix)
+ {
+ Source = source;
+ Prefix = prefix;
+ }
+
+ internal abstract void CreateTransformerFromEstimator(DateTimeTransformerEstimator.Countries country);
+ private protected abstract unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize);
+ private protected unsafe abstract bool CreateEstimatorHelper(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle);
+ private protected abstract bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ private protected abstract bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle);
+ private protected abstract bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle);
+ private protected abstract bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+ public abstract void Dispose();
+
+ private protected unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase(DateTimeTransformerEstimator.Countries country)
+ {
+ bool success;
+ IntPtr errorHandle;
+ IntPtr estimator;
+ if (country == DateTimeTransformerEstimator.Countries.None)
+ {
+ success = CreateEstimatorHelper(null, null, out estimator, out errorHandle);
+ }
+ else
+ {
+ fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) + char.MinValue))
+ fixed (byte* countryPointer = Encoding.UTF8.GetBytes(Enum.GetName(typeof(DateTimeTransformerEstimator.Countries), country) + char.MinValue))
+ {
+ success = CreateEstimatorHelper(countryPointer, dataRootDir, out estimator, out errorHandle);
+ }
+ }
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorHelper))
+ {
+
+ success = CreateTransformerFromEstimatorHelper(estimatorHandler, out IntPtr transformer, out errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerHelper);
+ }
+ }
+
+ internal byte[] CreateTransformerSaveData()
+ {
+
+ var success = CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize))
+ {
+ byte[] savedData = new byte[bufferSize.ToInt32()];
+ Marshal.Copy(buffer, savedData, 0, savedData.Length);
+ return savedData;
+ }
+ }
+
+ internal unsafe void CreateTransformerFromSavedData(byte[] data)
+ {
+ fixed (byte* rawData = data)
+ {
+ IntPtr dataSize = new IntPtr(data.Count());
+ CreateTransformerFromSavedDataHelper(rawData, dataSize);
+ }
+ }
+ }
+
+ internal abstract class TypedColumn : TypedColumn
+ {
+ internal TypedColumn(string source, string prefix) :
+ base(source, prefix)
+ {
+ }
+
+ internal abstract TimePoint Transform(T input);
+
+ }
+
+ #endregion
+
+ #region DateTimeTypedColumn
+
+ internal sealed class DateTimeTypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ internal DateTimeTypedColumn(string source, string prefix) :
+ base(source, prefix)
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateEstimatorNative(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override unsafe void CreateTransformerFromEstimator(DateTimeTransformerEstimator.Countries country)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(country);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerFromSavedDataWithDataRoot"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, byte* dataRootDir, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) + char.MinValue))
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, dataRootDir, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+ }
+
+ [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, long input, out IntPtr output, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, out IntPtr errorHandle);
+ internal override TimePoint Transform(long input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ using (var handler = new TransformedDataSafeHandle(output, DestroyTransformedDataNative))
+ {
+ unsafe
+ {
+ return new TimePoint((byte*)output.ToPointer());
+ }
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected unsafe override bool CreateEstimatorHelper(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(countryName, dataRootDir, out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ private sealed class Mapper : MapperBase
+ {
+
+ #region Class data members
+
+ private readonly DateTimeTransformer _parent;
+ private ConcurrentDictionary _cache;
+ private ConcurrentQueue _oldestKeys;
+
+ #endregion
+
+ public Mapper(DateTimeTransformer parent, DataViewSchema inputSchema) :
+ base(parent.Host.Register(nameof(Mapper)), inputSchema, parent)
+ {
+ _parent = parent;
+ _cache = new ConcurrentDictionary();
+ _oldestKeys = new ConcurrentQueue();
+ }
+
+ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
+ {
+ var columns = new List();
+
+ foreach (DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced)))
+ if (_parent._columnsToDrop == null || !_parent._columnsToDrop.Contains(column))
+ {
+ columns.Add(new DataViewSchema.DetachedColumn(_parent._column.Prefix + column.ToString(),
+ ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType())));
+ }
+
+ return columns.ToArray();
+ }
+
+ private Delegate MakeGetter(DataViewRow input, int iinfo)
+ {
+ ValueGetter result = (ref T dst) =>
+ {
+ long dateTime = default;
+ var getter = input.GetGetter(input.Schema[_parent._column.Source]);
+ getter(ref dateTime);
+
+ if (!_cache.TryGetValue(dateTime, out TimePoint timePoint)){
+ _cache[dateTime] = _parent._column.Transform(dateTime);
+ _oldestKeys.Enqueue(dateTime);
+ timePoint = _cache[dateTime];
+
+ // If more than 100 cached items, remove 20
+ if(_cache.Count > 100)
+ {
+ for(int i = 0; i < 20; i++)
+ {
+ long key;
+ while (!_oldestKeys.TryDequeue(out key)) { }
+ while (!_cache.TryRemove(key, out TimePoint removedValue)) { }
+ }
+ }
+ }
+
+ if (iinfo == 0)
+ dst = (T)Convert.ChangeType(timePoint.Year, typeof(T));
+ else if (iinfo == 1)
+ dst = (T)Convert.ChangeType(timePoint.Month, typeof(T));
+ else if (iinfo == 2)
+ dst = (T)Convert.ChangeType(timePoint.Day, typeof(T));
+ else if (iinfo == 3)
+ dst = (T)Convert.ChangeType(timePoint.Hour, typeof(T));
+ else if (iinfo == 4)
+ dst = (T)Convert.ChangeType(timePoint.Minute, typeof(T));
+ else if (iinfo == 5)
+ dst = (T)Convert.ChangeType(timePoint.Second, typeof(T));
+ else if (iinfo == 6)
+ dst = (T)Convert.ChangeType(timePoint.AmPm, typeof(T));
+ else if (iinfo == 7)
+ dst = (T)Convert.ChangeType(timePoint.Hour12, typeof(T));
+ else if (iinfo == 8)
+ dst = (T)Convert.ChangeType(timePoint.DayOfWeek, typeof(T));
+ else if (iinfo == 9)
+ dst = (T)Convert.ChangeType(timePoint.DayOfQuarter, typeof(T));
+ else if (iinfo == 10)
+ dst = (T)Convert.ChangeType(timePoint.DayOfYear, typeof(T));
+ else if (iinfo == 11)
+ dst = (T)Convert.ChangeType(timePoint.WeekOfMonth, typeof(T));
+ else if (iinfo == 12)
+ dst = (T)Convert.ChangeType(timePoint.QuarterOfYear, typeof(T));
+ else if (iinfo == 13)
+ dst = (T)Convert.ChangeType(timePoint.HalfOfYear, typeof(T));
+ else if (iinfo == 14)
+ dst = (T)Convert.ChangeType(timePoint.WeekIso, typeof(T));
+ else if (iinfo == 15)
+ dst = (T)Convert.ChangeType(timePoint.YearIso, typeof(T));
+ else if (iinfo == 16)
+ dst = (T)Convert.ChangeType(timePoint.MonthLabel.AsMemory(), typeof(T));
+ else if (iinfo == 17)
+ dst = (T)Convert.ChangeType(timePoint.AmPmLabel.AsMemory(), typeof(T));
+ else if (iinfo == 18)
+ dst = (T)Convert.ChangeType(timePoint.DayOfWeekLabel.AsMemory(), typeof(T));
+ else if (iinfo == 19)
+ dst = (T)Convert.ChangeType(timePoint.HolidayName.AsMemory(), typeof(T));
+ else
+ dst = (T)Convert.ChangeType(timePoint.IsPaidTimeOff, typeof(T));
+ };
+
+ return result;
+ }
+
+ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer)
+ {
+ disposer = null;
+
+ var outputColumn = (int)_parent._activeColumnMapping[iinfo];
+
+ // Have to subtract 1 from the output column since the enum starts and 1 and not 0.
+ return Utils.MarshalInvoke(MakeGetter, ((DateTimeTransformerEstimator.ColumnsProduced)outputColumn).GetRawColumnType(), input, outputColumn - 1);
+ }
+
+ private protected override Func GetDependenciesCore(Func activeOutput)
+ {
+ var active = new bool[InputSchema.Count];
+ for (int i = 0; i < InputSchema.Count; i++)
+ {
+ if (InputSchema[i].Name.Equals(_parent._column.Source))
+ {
+ active[i] = true;
+ }
+ }
+
+ return col => active[col];
+ }
+
+ private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
+ }
+ }
+
+ internal static class DateTimeTransformerEntrypoint
+ {
+ [TlcModule.EntryPoint(Name = "Transforms.DateTimeSplitter",
+ Desc = DateTimeTransformer.Summary,
+ UserName = DateTimeTransformer.UserName,
+ ShortName = DateTimeTransformer.ShortName)]
+ public static CommonOutputs.TransformOutput DateTimeSplit(IHostEnvironment env, DateTimeTransformerEstimator.Options input)
+ {
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, DateTimeTransformer.ShortName, input);
+ var xf = new DateTimeTransformerEstimator(h, input).Fit(input.Data).Transform(input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModelImpl(h, xf, input.Data),
+ OutputData = xf
+ };
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj b/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj
new file mode 100644
index 0000000000..f3a35c5d85
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj
@@ -0,0 +1,18 @@
+
+
+
+ netstandard2.0
+ Microsoft.ML.Featurizers
+ true
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/Microsoft.ML.Featurizers/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Featurizers/Properties/AssemblyInfo.cs
new file mode 100644
index 0000000000..f7c9d934e5
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/Properties/AssemblyInfo.cs
@@ -0,0 +1,11 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+using Microsoft.ML;
+
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.EntryPoints" + PublicKey.Value)]
+
+[assembly: WantsToBeBestFriends]
diff --git a/src/Microsoft.ML.Featurizers/RobustScaler.cs b/src/Microsoft.ML.Featurizers/RobustScaler.cs
new file mode 100644
index 0000000000..739a25c2b2
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/RobustScaler.cs
@@ -0,0 +1,1616 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Security;
+using System.Text;
+using Microsoft.ML;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.Data;
+using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Transforms;
+using static Microsoft.ML.Featurizers.CommonExtensions;
+
+[assembly: LoadableClass(typeof(RobustScalerTransformer), null, typeof(SignatureLoadModel),
+ RobustScalerTransformer.UserName, RobustScalerTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(RobustScalerTransformer), null, typeof(SignatureLoadRowMapper),
+ RobustScalerTransformer.UserName, RobustScalerTransformer.LoaderSignature)]
+
+[assembly: EntryPointModule(typeof(RobustScalerEntrypoint))]
+
+namespace Microsoft.ML.Featurizers
+{
+ public static class RobustScalerExtensionClass
+ {
+ public static RobustScalerEstimator RobustScalerTransformer(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null, bool center = true, bool scale = true, float quantileMin = 25.0f, float quantileMax = 75.0f)
+ {
+ var options = new RobustScalerEstimator.Options
+ {
+ Columns = new RobustScalerEstimator.Column[1] { new RobustScalerEstimator.Column() { Name = outputColumnName, Source = inputColumnName ?? outputColumnName } },
+ Center = center,
+ Scale = scale,
+ QuantileMin = quantileMin,
+ QuantileMax = quantileMax
+ };
+
+ return new RobustScalerEstimator(CatalogUtils.GetEnvironment(catalog), options);
+ }
+
+ public static RobustScalerEstimator RobustScalerTransformer(this TransformsCatalog catalog, InputOutputColumnPair[] columns, bool center = true, bool scale = true, float quantileMin = 25.0f, float quantileMax = 75.0f)
+ {
+ var options = new RobustScalerEstimator.Options
+ {
+ Columns = columns.Select(x => new RobustScalerEstimator.Column { Name = x.OutputColumnName, Source = x.InputColumnName ?? x.OutputColumnName }).ToArray(),
+ Center = center,
+ Scale = scale,
+ QuantileMin = quantileMin,
+ QuantileMax = quantileMax
+ };
+
+ return new RobustScalerEstimator(CatalogUtils.GetEnvironment(catalog), options);
+ }
+ }
+
+ ///
+ /// RobustScalar Featurizer scales features using statistics that are robust to outliers, by removing the median and scaling the data according to the quantile range
+ /// (defaults to IQR: Interquartile Range). Centering and scaling happen independently on each feature by computing the relevant statistics on the samples in the training set.
+ /// Median and interquartile range are then stored to be used on later data using the transform method.
+ ///
+ ///
+ /// is not a trivial estimator and needs training.
+ ///
+ ///
+ /// ]]>
+ ///
+ ///
+ ///
+ ///
+ public sealed class RobustScalerEstimator : IEstimator
+ {
+ private Options _options;
+
+ private readonly IHost _host;
+
+ // For determining what the output type is.
+ private static readonly Type[] _floatTypes = new Type[] { typeof(byte), typeof(sbyte), typeof(short), typeof(ushort), typeof(float) };
+
+ #region Options
+
+ internal sealed class Column : OneToOneColumn
+ {
+ internal static Column Parse(string str)
+ {
+ Contracts.AssertNonEmpty(str);
+
+ var res = new Column();
+ if (res.TryParse(str))
+ return res;
+ return null;
+ }
+
+ internal bool TryUnparse(StringBuilder sb)
+ {
+ Contracts.AssertValue(sb);
+ return TryUnparseCore(sb);
+ }
+ }
+
+ internal sealed class Options : TransformInputBase
+ {
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition (optional form: name:src)",
+ Name = "Column", ShortName = "col", SortOrder = 1)]
+ public Column[] Columns;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "If True, center the data before scaling.",
+ Name = "Center", ShortName = "ctr", SortOrder = 2)]
+ public bool Center = true;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "If True, scale the data to interquartile range.",
+ Name = "Scale", ShortName = "sc", SortOrder = 3)]
+ public bool Scale = true;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Min for the quantile range used to calculate scale.",
+ Name = "QuantileMin", ShortName = "min", SortOrder = 4)]
+ public float QuantileMin = 25.0f;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Max for the quantile range used to calculate scale.",
+ Name = "QuantileMax", ShortName = "max", SortOrder = 5)]
+ public float QuantileMax = 75.0f;
+ }
+
+ #endregion
+
+ internal RobustScalerEstimator(IHostEnvironment env, Options options)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(RobustScalerEstimator));
+ Contracts.Check(options.QuantileMin >= 0.0f && options.QuantileMin < options.QuantileMax && options.QuantileMax <= 100.0f, "Invalid QuantileRange provided");
+ Contracts.CheckNonEmpty(options.Columns, nameof(options.Columns));
+
+ _options = options;
+ }
+
+ public RobustScalerTransformer Fit(IDataView input)
+ {
+ return new RobustScalerTransformer(_host, input, _options);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ var columns = inputSchema.ToDictionary(x => x.Name);
+
+ foreach (var column in _options.Columns)
+ {
+ var inputColumn = columns[column.Source];
+
+ if (!RobustScalerTransformer.TypedColumn.IsColumnTypeSupported(inputColumn.ItemType.RawType))
+ throw new InvalidOperationException($"Type {inputColumn.ItemType.RawType.ToString()} for column {column.Name} not a supported type.");
+
+ if (_floatTypes.Contains(inputColumn.ItemType.RawType))
+ {
+ columns[column.Name] = new SchemaShape.Column(column.Name, inputColumn.Kind,
+ ColumnTypeExtensions.PrimitiveTypeFromType(typeof(float)), inputColumn.IsKey, inputColumn.Annotations);
+ }
+ else
+ {
+ columns[column.Name] = new SchemaShape.Column(column.Name, inputColumn.Kind,
+ ColumnTypeExtensions.PrimitiveTypeFromType(typeof(double)), inputColumn.IsKey, inputColumn.Annotations);
+ }
+
+ }
+ return new SchemaShape(columns.Values);
+ }
+ }
+
+ public sealed class RobustScalerTransformer : RowToRowTransformerBase, IDisposable
+ {
+ #region Class data members
+
+ internal const string Summary = "Removes the median and scales the data according to the quantile range.";
+ internal const string UserName = "RobustScalerTransformer";
+ internal const string ShortName = "RobustScalerTransformer";
+ internal const string LoadName = "RobustScalerTransformer";
+ internal const string LoaderSignature = "RobustScalerTransformer";
+
+ private TypedColumn[] _columns;
+ private RobustScalerEstimator.Options _options;
+
+ #endregion
+
+ internal RobustScalerTransformer(IHostEnvironment host, IDataView input, RobustScalerEstimator.Options options) :
+ base(host.Register(nameof(RobustScalerTransformer)))
+ {
+ var schema = input.Schema;
+ _options = options;
+
+ _columns = options.Columns.Select(x => TypedColumn.CreateTypedColumn(x.Name, x.Source, schema[x.Source].Type.RawType.ToString(), this)).ToArray();
+ foreach (var column in _columns)
+ {
+ column.CreateTransformerFromEstimator(input);
+ }
+ }
+
+ // Factory method for SignatureLoadModel.
+ internal RobustScalerTransformer(IHostEnvironment host, ModelLoadContext ctx) :
+ base(host.Register(nameof(RobustScalerTransformer)))
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+ // *** Binary format ***
+ // int number of column pairs
+ // for each column pair:
+ // string output column name
+ // string input column name
+ // column type
+ // int length of c++ byte array
+ // byte array from c++
+
+ var columnCount = ctx.Reader.ReadInt32();
+
+ _columns = new TypedColumn[columnCount];
+ for (int i = 0; i < columnCount; i++)
+ {
+ _columns[i] = TypedColumn.CreateTypedColumn(ctx.Reader.ReadString(), ctx.Reader.ReadString(), ctx.Reader.ReadString(), this);
+
+ // Load the C++ state and create the C++ transformer.
+ var dataLength = ctx.Reader.ReadInt32();
+ var data = ctx.Reader.ReadByteArray(dataLength);
+ _columns[i].CreateTransformerFromSavedData(data);
+ }
+ }
+
+ // Factory method for SignatureLoadRowMapper.
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
+ => new RobustScalerTransformer(env, ctx).MakeRowMapper(inputSchema);
+
+ private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "RbScal T",
+ verWrittenCur: 0x00010001,
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(RobustScalerTransformer).Assembly.FullName);
+ }
+
+ private protected override void SaveModel(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ // int number of column pairs
+ // for each column pair:
+ // string output column name
+ // string input column name
+ // column type
+ // int length of c++ byte array
+ // byte array from c++
+
+ ctx.Writer.Write(_columns.Count());
+ foreach (var column in _columns)
+ {
+ ctx.Writer.Write(column.Name);
+ ctx.Writer.Write(column.Source);
+ ctx.Writer.Write(column.Type);
+
+ // Save C++ state
+ var data = column.CreateTransformerSaveData();
+ ctx.Writer.Write(data.Length);
+ ctx.Writer.Write(data);
+ }
+ }
+
+ public void Dispose()
+ {
+ foreach (var column in _columns)
+ {
+ column.Dispose();
+ }
+ }
+
+ #region ColumnInfo
+
+ #region BaseClass
+
+ internal abstract class TypedColumn : IDisposable
+ {
+ internal readonly string Name;
+ internal readonly string Source;
+ internal readonly string Type;
+
+ private static readonly Type[] _supportedTypes = new[] { typeof(sbyte), typeof(short), typeof(int), typeof(long), typeof(byte), typeof(ushort),
+ typeof(uint), typeof(ulong), typeof(float), typeof(double) };
+
+ internal TypedColumn(string name, string source, string type)
+ {
+ Name = name;
+ Source = source;
+ Type = type;
+ }
+
+ internal abstract void CreateTransformerFromEstimator(IDataView input);
+ private protected abstract unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize);
+ private protected abstract bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+ public abstract void Dispose();
+
+ public abstract Type ReturnType();
+
+ internal byte[] CreateTransformerSaveData()
+ {
+
+ var success = CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize))
+ {
+ byte[] savedData = new byte[bufferSize.ToInt32()];
+ Marshal.Copy(buffer, savedData, 0, savedData.Length);
+ return savedData;
+ }
+ }
+
+ internal unsafe void CreateTransformerFromSavedData(byte[] data)
+ {
+ fixed (byte* rawData = data)
+ {
+ IntPtr dataSize = new IntPtr(data.Count());
+ CreateTransformerFromSavedDataHelper(rawData, dataSize);
+ }
+ }
+
+ internal static bool IsColumnTypeSupported(Type type)
+ {
+ return _supportedTypes.Contains(type);
+ }
+
+ internal static TypedColumn CreateTypedColumn(string name, string source, string type, RobustScalerTransformer parent)
+ {
+ if (type == typeof(sbyte).ToString())
+ {
+ return new Int8TypedColumn(name, source, parent);
+ }
+ else if (type == typeof(short).ToString())
+ {
+ return new Int16TypedColumn(name, source, parent);
+ }
+ else if (type == typeof(int).ToString())
+ {
+ return new Int32TypedColumn(name, source, parent);
+ }
+ else if (type == typeof(long).ToString())
+ {
+ return new Int64TypedColumn(name, source, parent);
+ }
+ else if (type == typeof(byte).ToString())
+ {
+ return new UInt8TypedColumn(name, source, parent);
+ }
+ else if (type == typeof(ushort).ToString())
+ {
+ return new UInt16TypedColumn(name, source, parent);
+ }
+ else if (type == typeof(uint).ToString())
+ {
+ return new UInt32TypedColumn(name, source, parent);
+ }
+ else if (type == typeof(ulong).ToString())
+ {
+ return new UInt64TypedColumn(name, source, parent);
+ }
+ else if (type == typeof(float).ToString())
+ {
+ return new FloatTypedColumn(name, source, parent);
+ }
+ else if (type == typeof(double).ToString())
+ {
+ return new DoubleTypedColumn(name, source, parent);
+ }
+
+ throw new InvalidOperationException($"Column {name} has an unsupported type {type}.");
+ }
+ }
+
+ internal abstract class TypedColumn : TypedColumn
+ {
+ internal TypedColumn(string name, string source, string type) :
+ base(name, source, type)
+ {
+ }
+
+ internal abstract TOutputType Transform(TSourceType input);
+ private protected abstract bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
+ private protected abstract bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ private protected abstract bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle);
+ private protected abstract bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle);
+ private protected abstract bool FitHelper(TransformerEstimatorSafeHandle estimator, TSourceType input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected abstract bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected abstract bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle);
+ private protected TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase(IDataView input)
+ {
+ var success = CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ using (var estimatorHandle = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorHelper))
+ {
+ if (!IsTrainingComplete(estimatorHandle))
+ {
+ var fitResult = FitResult.Continue;
+ while (fitResult != FitResult.Complete)
+ {
+ fitResult = FitResult.Continue;
+ using (var data = input.GetColumn(Source).GetEnumerator())
+ {
+ while (fitResult == FitResult.Continue && data.MoveNext())
+ {
+ {
+ success = FitHelper(estimatorHandle, data.Current, out fitResult, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ }
+
+ success = CompleteTrainingHelper(estimatorHandle, out fitResult, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ }
+ }
+
+ success = CreateTransformerFromEstimatorHelper(estimatorHandle, out IntPtr transformer, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerHelper);
+ }
+ }
+ }
+
+ #endregion
+
+ #region Int8Column
+
+ internal sealed class Int8TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ private RobustScalerTransformer _parent;
+ internal Int8TypedColumn(string name, string source, RobustScalerTransformer parent) :
+ base(name, source, typeof(sbyte).ToString())
+ {
+ _parent = parent;
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, sbyte input, out float output, out IntPtr errorHandle);
+ internal unsafe override float Transform(sbyte input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
+ {
+ if (_parent._options.Scale)
+ return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
+ else
+ return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
+ }
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, sbyte input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, sbyte input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ return FitNative(estimator, input, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int8_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
+ private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
+ {
+ var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return isTrainingComplete;
+ }
+
+ public override Type ReturnType()
+ {
+ return typeof(float);
+ }
+ }
+
+ #endregion
+
+ #region UInt8Column
+
+ internal sealed class UInt8TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ private RobustScalerTransformer _parent;
+ internal UInt8TypedColumn(string name, string source, RobustScalerTransformer parent) :
+ base(name, source, typeof(byte).ToString())
+ {
+ _parent = parent;
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, byte input, out float output, out IntPtr errorHandle);
+ internal unsafe override float Transform(byte input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
+ {
+ if (_parent._options.Scale)
+ return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
+ else
+ return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
+ }
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, byte input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, byte input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ return FitNative(estimator, input, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint8_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
+ private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
+ {
+ var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return isTrainingComplete;
+ }
+
+ public override Type ReturnType()
+ {
+ return typeof(float);
+ }
+ }
+
+ #endregion
+
+ #region Int16Column
+
+ internal sealed class Int16TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ private RobustScalerTransformer _parent;
+ internal Int16TypedColumn(string name, string source, RobustScalerTransformer parent) :
+ base(name, source, typeof(short).ToString())
+ {
+ _parent = parent;
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, short input, out float output, out IntPtr errorHandle);
+ internal unsafe override float Transform(short input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
+ {
+ if (_parent._options.Scale)
+ return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
+ else
+ return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
+ }
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, short input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, short input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ return FitNative(estimator, input, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int16_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
+ private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
+ {
+ var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return isTrainingComplete;
+ }
+
+ public override Type ReturnType()
+ {
+ return typeof(float);
+ }
+ }
+
+ #endregion
+
+ #region UInt16Column
+
+ internal sealed class UInt16TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ private RobustScalerTransformer _parent;
+ internal UInt16TypedColumn(string name, string source, RobustScalerTransformer parent) :
+ base(name, source, typeof(ushort).ToString())
+ {
+ _parent = parent;
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ushort input, out float output, out IntPtr errorHandle);
+ internal unsafe override float Transform(ushort input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
+ {
+ if (_parent._options.Scale)
+ return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
+ else
+ return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
+ }
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, ushort input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, ushort input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ return FitNative(estimator, input, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint16_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
+ private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
+ {
+ var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return isTrainingComplete;
+ }
+
+ public override Type ReturnType()
+ {
+ return typeof(float);
+ }
+ }
+
+ #endregion
+
+ #region Int32Column
+
+ internal sealed class Int32TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ private RobustScalerTransformer _parent;
+ internal Int32TypedColumn(string name, string source, RobustScalerTransformer parent) :
+ base(name, source, typeof(int).ToString())
+ {
+ _parent = parent;
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, int input, out double output, out IntPtr errorHandle);
+ internal unsafe override double Transform(int input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
+ {
+ if (_parent._options.Scale)
+ return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
+ else
+ return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
+ }
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, int input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, int input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ return FitNative(estimator, input, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int32_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
+ private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
+ {
+ var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return isTrainingComplete;
+ }
+
+ public override Type ReturnType()
+ {
+ return typeof(double);
+ }
+ }
+
+ #endregion
+
+ #region UInt32Column
+
+ internal sealed class UInt32TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ private RobustScalerTransformer _parent;
+ internal UInt32TypedColumn(string name, string source, RobustScalerTransformer parent) :
+ base(name, source, typeof(uint).ToString())
+ {
+ _parent = parent;
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, uint input, out double output, out IntPtr errorHandle);
+ internal unsafe override double Transform(uint input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
+ {
+ if (_parent._options.Scale)
+ return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
+ else
+ return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
+ }
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, uint input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, uint input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ return FitNative(estimator, input, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint32_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
+ private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
+ {
+ var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return isTrainingComplete;
+ }
+
+ public override Type ReturnType()
+ {
+ return typeof(double);
+ }
+ }
+
+ #endregion
+
+ #region Int64Column
+
+ internal sealed class Int64TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ private RobustScalerTransformer _parent;
+ internal Int64TypedColumn(string name, string source, RobustScalerTransformer parent) :
+ base(name, source, typeof(long).ToString())
+ {
+ _parent = parent;
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, long input, out double output, out IntPtr errorHandle);
+ internal unsafe override double Transform(long input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
+ {
+ if (_parent._options.Scale)
+ return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
+ else
+ return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
+ }
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, long input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, long input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ return FitNative(estimator, input, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_int64_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
+ private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
+ {
+ var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return isTrainingComplete;
+ }
+
+ public override Type ReturnType()
+ {
+ return typeof(double);
+ }
+ }
+
+ #endregion
+
+ #region UInt64Column
+
+ internal sealed class UInt64TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ private RobustScalerTransformer _parent;
+ internal UInt64TypedColumn(string name, string source, RobustScalerTransformer parent) :
+ base(name, source, typeof(ulong).ToString())
+ {
+ _parent = parent;
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ulong input, out double output, out IntPtr errorHandle);
+ internal unsafe override double Transform(ulong input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
+ {
+ if (_parent._options.Scale)
+ return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
+ else
+ return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
+ }
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, ulong input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, ulong input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ return FitNative(estimator, input, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_uint64_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
+ private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
+ {
+ var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return isTrainingComplete;
+ }
+
+ public override Type ReturnType()
+ {
+ return typeof(double);
+ }
+ }
+
+ #endregion
+
+ #region FloatColumn
+
+ internal sealed class FloatTypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ private RobustScalerTransformer _parent;
+ internal FloatTypedColumn(string name, string source, RobustScalerTransformer parent) :
+ base(name, source, typeof(float).ToString())
+ {
+ _parent = parent;
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, float input, out float output, out IntPtr errorHandle);
+ internal unsafe override float Transform(float input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out float output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
+ {
+ if (_parent._options.Scale)
+ return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
+ else
+ return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
+ }
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, float input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, float input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ return FitNative(estimator, input, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_float_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
+ private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
+ {
+ var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return isTrainingComplete;
+ }
+
+ public override Type ReturnType()
+ {
+ return typeof(float);
+ }
+ }
+
+ #endregion
+
+ #region DoubleColumn
+
+ internal sealed class DoubleTypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ private RobustScalerTransformer _parent;
+ internal DoubleTypedColumn(string name, string source, RobustScalerTransformer parent) :
+ base(name, source, typeof(double).ToString())
+ {
+ _parent = parent;
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(bool withCentering, float qRangeMin, float qRangeMax, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator(IDataView input)
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase(input);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_CreateTransformerFromSavedData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_Transform", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, double input, out double output, out IntPtr errorHandle);
+ internal unsafe override double Transform(double input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out double output, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return output;
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle)
+ {
+ if (_parent._options.Scale)
+ return CreateEstimatorNative(_parent._options.Center, _parent._options.QuantileMin, _parent._options.QuantileMax, out estimator, out errorHandle);
+ else
+ return CreateEstimatorNative(_parent._options.Center, -1, -1, out estimator, out errorHandle);
+ }
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, double input, out FitResult fitResult, out IntPtr errorHandle);
+ private protected unsafe override bool FitHelper(TransformerEstimatorSafeHandle estimator, double input, out FitResult fitResult, out IntPtr errorHandle)
+ {
+ return FitNative(estimator, input, out fitResult, out errorHandle);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+ private protected override bool CompleteTrainingHelper(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle) =>
+ CompleteTrainingNative(estimator, out fitResult, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "RobustScalarFeaturizer_double_t_IsTrainingComplete", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool IsTrainingCompleteNative(TransformerEstimatorSafeHandle transformer, out bool isTrainingComplete, out IntPtr errorHandle);
+ private protected override bool IsTrainingComplete(TransformerEstimatorSafeHandle estimatorHandle)
+ {
+ var success = IsTrainingCompleteNative(estimatorHandle, out bool isTrainingComplete, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return isTrainingComplete;
+ }
+
+ public override Type ReturnType()
+ {
+ return typeof(double);
+ }
+ }
+
+ #endregion
+
+ #endregion // Column Info
+
+ private sealed class Mapper : MapperBase
+ {
+
+ #region Class data members
+
+ private readonly RobustScalerTransformer _parent;
+
+ #endregion
+
+ public Mapper(RobustScalerTransformer parent, DataViewSchema inputSchema) :
+ base(parent.Host.Register(nameof(Mapper)), inputSchema, parent)
+ {
+ _parent = parent;
+ }
+
+ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
+ {
+ return _parent._columns.Select(x => new DataViewSchema.DetachedColumn(x.Name, ColumnTypeExtensions.PrimitiveTypeFromType(x.ReturnType()))).ToArray();
+ }
+
+ private Delegate MakeGetter(DataViewRow input, int iinfo)
+ {
+ ValueGetter result = (ref TOutputType dst) =>
+ {
+ var inputColumn = input.Schema[_parent._columns[iinfo].Source];
+ var srcGetterScalar = input.GetGetter(inputColumn);
+
+ TSourceType value = default;
+ srcGetterScalar(ref value);
+
+ dst = ((TypedColumn)_parent._columns[iinfo]).Transform(value);
+
+ };
+
+ return result;
+ }
+
+ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer)
+ {
+ disposer = null;
+ Type inputType = input.Schema[_parent._columns[iinfo].Source].Type.RawType;
+ Type outputType = _parent._columns[iinfo].ReturnType();
+
+ return Utils.MarshalInvoke(MakeGetter, new Type[] { inputType, outputType }, input, iinfo);
+ }
+
+ private protected override Func GetDependenciesCore(Func activeOutput)
+ {
+ var active = new bool[InputSchema.Count];
+ for (int i = 0; i < InputSchema.Count; i++)
+ {
+ if (_parent._columns.Any(x => x.Source == InputSchema[i].Name))
+ {
+ active[i] = true;
+ }
+ }
+
+ return col => active[col];
+ }
+
+ private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
+ }
+ }
+
+ internal static class RobustScalerEntrypoint
+ {
+ [TlcModule.EntryPoint(Name = "Transforms.RobustScaler",
+ Desc = RobustScalerTransformer.Summary,
+ UserName = RobustScalerTransformer.UserName,
+ ShortName = RobustScalerTransformer.ShortName)]
+ public static CommonOutputs.TransformOutput RobustScaler(IHostEnvironment env, RobustScalerEstimator.Options input)
+ {
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, RobustScalerTransformer.ShortName, input);
+ var xf = new RobustScalerEstimator(h, input).Fit(input.Data).Transform(input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModelImpl(h, xf, input.Data),
+ OutputData = xf
+ };
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs b/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs
new file mode 100644
index 0000000000..6d899564ca
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs
@@ -0,0 +1,637 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Security;
+using System.Text;
+using Microsoft.ML;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.Data;
+using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Transforms;
+using static Microsoft.ML.Featurizers.CommonExtensions;
+
+[assembly: LoadableClass(typeof(TimeSeriesImputerTransformer), null, typeof(SignatureLoadModel),
+ TimeSeriesImputerTransformer.UserName, TimeSeriesImputerTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IDataTransform), typeof(TimeSeriesImputerTransformer), null, typeof(SignatureLoadDataTransform),
+ TimeSeriesImputerTransformer.UserName, TimeSeriesImputerTransformer.LoaderSignature)]
+
+[assembly: EntryPointModule(typeof(TimeSeriesTransformerEntrypoint))]
+
+namespace Microsoft.ML.Featurizers
+{
+ public static class TimeSeriesImputerExtensionClass
+ {
+ ///
+ /// Create a , Imputes missing rows and column data per grain. Operates on all columns in the IDataView. Currently only numeric columns are supported.
+ ///
+ /// The transform catalog.
+ /// Column representing the time series. Should be of type
+ /// List of columns to use as grains
+ /// Mode of imputation for missing values in column. If not passed defaults to forward fill
+ public static TimeSeriesImputerEstimator TimeSeriesImputer(this TransformsCatalog catalog, string timeSeriesColumn, string[] grainColumns, TimeSeriesImputerEstimator.ImputationStrategy imputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill) =>
+ new TimeSeriesImputerEstimator(CatalogUtils.GetEnvironment(catalog), timeSeriesColumn, grainColumns, null, TimeSeriesImputerEstimator.FilterMode.NoFilter, imputeMode, false);
+
+ ///
+ /// Create a , Imputes missing rows and column data per grain. Operates on a filtered list of columns in the IDataView.
+ /// If a column is not imputed but rows are added then it will be filled with the default value for that data type. Currently only numeric columns are supported for imputation.
+ ///
+ /// The transform catalog.
+ /// Column representing the time series. Should be of type
+ /// List of columns to use as grains
+ /// List of columns to filter. If is than columns in the list will be ignored.
+ /// If is than values in the list are the only columns imputed.
+ /// Whether the list should include or exclude those columns.
+ /// Mode of imputation for missing values in column. If not passed defaults to forward fill
+ /// Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column. If false, will stop and throw an error.
+ public static TimeSeriesImputerEstimator TimeSeriesImputer(this TransformsCatalog catalog, string timeSeriesColumn, string[] grainColumns, string[] filterColumns, TimeSeriesImputerEstimator.FilterMode filterMode = TimeSeriesImputerEstimator.FilterMode.Exclude, TimeSeriesImputerEstimator.ImputationStrategy imputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill, bool suppressTypeErrors = false) =>
+ new TimeSeriesImputerEstimator(CatalogUtils.GetEnvironment(catalog), timeSeriesColumn, grainColumns, filterColumns, filterMode, imputeMode, suppressTypeErrors);
+ }
+
+ ///
+ /// Imputes missing rows and column data per grain, based on the dates in the date column. Operates on a filtered list of columns in the IDataView.
+ /// If a column is not imputed but rows are added then it will be filled with the default value for that data type. Currently only numeric columns are supported for imputation.
+ /// A new column is added to the schema after this operation is run. The column is called "IsRowImputed" and is a boolean value representing if the row was created as a result
+ /// of this transformer or not.
+ ///
+ ///
+ /// is not a trivial estimator and needs training.
+ ///
+ ///
+ /// ]]>
+ ///
+ ///
+ ///
+ ///
+ public sealed class TimeSeriesImputerEstimator : IEstimator
+ {
+ private Options _options;
+ internal const string IsRowImputedColumnName = "IsRowImputed";
+
+ private readonly IHost _host;
+ private static readonly List _currentSupportedTypes = new List { typeof(sbyte), typeof(byte), typeof(short), typeof(ushort), typeof(int), typeof(uint),
+ typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(string), typeof(ReadOnlyMemory)};
+
+ #region Options
+ internal sealed class Options : TransformInputBase
+ {
+ [Argument(ArgumentType.Required, HelpText = "Column representing the time", Name = "TimeSeriesColumn", ShortName = "time", SortOrder = 1)]
+ public string TimeSeriesColumn;
+
+ [Argument((ArgumentType.MultipleUnique | ArgumentType.Required), HelpText = "List of grain columns", Name = "GrainColumns", ShortName = "grains", SortOrder = 2)]
+ public string[] GrainColumns;
+
+ // This transformer adds columns
+ [Argument(ArgumentType.MultipleUnique, HelpText = "Columns to filter", Name = "FilterColumns", ShortName = "filters", SortOrder = 2)]
+ public string[] FilterColumns;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Filter mode. Either include or exclude", Name = "FilterMode", ShortName = "fmode", SortOrder = 3)]
+ public FilterMode FilterMode = FilterMode.Exclude;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Mode for imputing, defaults to ForwardFill if not provided", Name = "ImputeMode", ShortName = "mode", SortOrder = 3)]
+ public ImputationStrategy ImputeMode = ImputationStrategy.ForwardFill;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column. If false, will stop and throw an error.", Name = "SupressTypeErrors", ShortName = "error", SortOrder = 3)]
+ public bool SupressTypeErrors = false;
+ }
+
+ #endregion
+
+ #region Class Enums
+
+ public enum ImputationStrategy : byte
+ {
+ ForwardFill = 1, BackFill, Median, Interpolate
+ };
+
+ public enum FilterMode : byte
+ {
+ NoFilter = 1, Include, Exclude
+ };
+
+ #endregion
+
+ internal TimeSeriesImputerEstimator(IHostEnvironment env, string timeSeriesColumn, string[] grainColumns, string[] filterColumns, FilterMode filterMode, ImputationStrategy imputeMode, bool supressTypeErrors)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = Contracts.CheckRef(env, nameof(env)).Register("TimeSeriesImputerEstimator");
+ _host.CheckValue(timeSeriesColumn, nameof(timeSeriesColumn), "TimePoint column should not be null.");
+ _host.CheckNonEmpty(grainColumns, nameof(grainColumns), "Need at least one grain column.");
+ if (filterMode == FilterMode.Include)
+ _host.CheckNonEmpty(filterColumns, nameof(filterColumns), "Need at least 1 filter column if a FilterMode is specified");
+
+ _options = new Options
+ {
+ TimeSeriesColumn = timeSeriesColumn,
+ GrainColumns = grainColumns,
+ FilterColumns = filterColumns == null ? new string[] { } : filterColumns,
+ FilterMode = filterMode,
+ ImputeMode = imputeMode,
+ SupressTypeErrors = supressTypeErrors
+ };
+ }
+
+ internal TimeSeriesImputerEstimator(IHostEnvironment env, Options options)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = Contracts.CheckRef(env, nameof(env)).Register("TimeSeriesImputerEstimator");
+ _host.CheckValue(options.TimeSeriesColumn, nameof(options.TimeSeriesColumn), "TimePoint column should not be null.");
+ _host.CheckValue(options.GrainColumns, nameof(options.GrainColumns), "Grain columns should not be null.");
+ _host.CheckNonEmpty(options.GrainColumns, nameof(options.GrainColumns), "Need at least one grain column.");
+ if (options.FilterMode != FilterMode.NoFilter)
+ _host.CheckNonEmpty(options.FilterColumns, nameof(options.FilterColumns), "Need at least 1 filter column if a FilterMode is specified");
+
+ _options = options;
+ }
+
+ public TimeSeriesImputerTransformer Fit(IDataView input)
+ {
+ // If we are not suppressing type errors make sure columns to impute only contain supported types.
+ if (!_options.SupressTypeErrors)
+ {
+ var columns = input.Schema.Where(x => !_options.GrainColumns.Contains(x.Name));
+ if (_options.FilterMode == FilterMode.Exclude)
+ columns = columns.Where(x => !_options.FilterColumns.Contains(x.Name));
+ else if (_options.FilterMode == FilterMode.Include)
+ columns = columns.Where(x => _options.FilterColumns.Contains(x.Name));
+
+ foreach (var column in columns)
+ {
+ if (!_currentSupportedTypes.Contains(column.Type.RawType))
+ throw new InvalidOperationException($"Type {column.Type.RawType.ToString()} for column {column.Name} not a supported type.");
+ }
+ }
+
+ return new TimeSeriesImputerTransformer(_host, _options.TimeSeriesColumn, _options.GrainColumns, _options.FilterColumns, _options.FilterMode, _options.ImputeMode, _options.SupressTypeErrors, input);
+ }
+
+ // Add one column called WasColumnImputed, otherwise everything stays the same.
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ var columns = inputSchema.ToDictionary(x => x.Name);
+ columns[IsRowImputedColumnName] = new SchemaShape.Column(IsRowImputedColumnName, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false);
+ return new SchemaShape(columns.Values);
+ }
+ }
+
+ public sealed class TimeSeriesImputerTransformer : ITransformer, IDisposable
+ {
+ #region Class data members
+
+ internal const string Summary = "Fills in missing row and values";
+ internal const string UserName = "TimeSeriesImputer";
+ internal const string ShortName = "TimeSeriesImputer";
+ internal const string LoadName = "TimeSeriesImputer";
+ internal const string LoaderSignature = "TimeSeriesImputer";
+
+ private readonly IHost _host;
+ private readonly string _timeSeriesColumn;
+ private readonly string[] _grainColumns;
+ private readonly string[] _dataColumns;
+ private readonly string[] _allColumnNames;
+ private readonly bool _suppressTypeErrors;
+ private readonly TimeSeriesImputerEstimator.ImputationStrategy _imputeMode;
+ internal TransformerEstimatorSafeHandle TransformerHandle;
+
+ #endregion
+
+ // Normal constructor.
+ internal TimeSeriesImputerTransformer(IHostEnvironment host, string timeSeriesColumn, string[] grainColumns, string[] filterColumns, TimeSeriesImputerEstimator.FilterMode filterMode, TimeSeriesImputerEstimator.ImputationStrategy imputeMode, bool suppressTypeErrors, IDataView input)
+ {
+ _host = host.Register(nameof(TimeSeriesImputerTransformer));
+ _timeSeriesColumn = timeSeriesColumn;
+ _grainColumns = grainColumns;
+ _imputeMode = imputeMode;
+ _suppressTypeErrors = suppressTypeErrors;
+
+ IEnumerable tempDataColumns;
+
+ if (filterMode == TimeSeriesImputerEstimator.FilterMode.Exclude)
+ tempDataColumns = input.Schema.Where(x => !filterColumns.Contains(x.Name)).Select(x => x.Name);
+ else if (filterMode == TimeSeriesImputerEstimator.FilterMode.Include)
+ tempDataColumns = input.Schema.Where(x => filterColumns.Contains(x.Name)).Select(x => x.Name);
+ else
+ tempDataColumns = input.Schema.Select(x => x.Name);
+
+ // Time series and Grain columns should never be included in the data columns
+ _dataColumns = tempDataColumns.Where(x => x != timeSeriesColumn && !grainColumns.Contains(x)).ToArray();
+
+ // 1 is for the time series column. Make one array in the correct order of all the columns.
+ // Order is Timeseries column, All grain columns, All data columns.
+ _allColumnNames = new string[1 + _grainColumns.Length + _dataColumns.Length];
+ _allColumnNames[0] = _timeSeriesColumn;
+ Array.Copy(_grainColumns, 0, _allColumnNames, 1, _grainColumns.Length);
+ Array.Copy(_dataColumns, 0, _allColumnNames, 1 + _grainColumns.Length, _dataColumns.Length);
+
+ TransformerHandle = CreateTransformerFromEstimator(input);
+ }
+
+ // Factory method for SignatureLoadModel.
+ internal TimeSeriesImputerTransformer(IHostEnvironment host, ModelLoadContext ctx)
+ {
+ _host = host.Register(nameof(TimeSeriesImputerTransformer));
+
+ // *** Binary format ***
+ // name of time series column
+ // length of grain column array
+ // all column names in grain column array
+ // length of filter column array
+ // all column names in filter column array
+ // byte value of filter mode
+ // byte value of impute mode
+ // length of C++ state array
+ // C++ byte state array
+
+ _timeSeriesColumn = ctx.Reader.ReadString();
+
+ _grainColumns = new string[ctx.Reader.ReadInt32()];
+ for (int i = 0; i < _grainColumns.Length; i++)
+ _grainColumns[i] = ctx.Reader.ReadString();
+
+ _dataColumns = new string[ctx.Reader.ReadInt32()];
+ for (int i = 0; i < _dataColumns.Length; i++)
+ _dataColumns[i] = ctx.Reader.ReadString();
+
+ _imputeMode = (TimeSeriesImputerEstimator.ImputationStrategy)ctx.Reader.ReadByte();
+
+ _allColumnNames = new string[1 + _grainColumns.Length + _dataColumns.Length];
+ _allColumnNames[0] = _timeSeriesColumn;
+ Array.Copy(_grainColumns, 0, _allColumnNames, 1, _grainColumns.Length);
+ Array.Copy(_dataColumns, 0, _allColumnNames, 1 + _grainColumns.Length, _dataColumns.Length);
+
+ var nativeData = ctx.Reader.ReadByteArray();
+ TransformerHandle = CreateTransformerFromSavedData(nativeData);
+ }
+
+ private unsafe TransformerEstimatorSafeHandle CreateTransformerFromSavedData(byte[] data)
+ {
+ fixed (byte* rawData = data)
+ {
+ IntPtr dataSize = new IntPtr(data.Count());
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+ }
+
+ // Factory method for SignatureLoadDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ return (IDataTransform)(new TimeSeriesImputerTransformer(env, ctx).Transform(input));
+ }
+
+ private unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimator(IDataView input)
+ {
+ IntPtr estimator;
+ IntPtr errorHandle;
+ bool success;
+
+ var allColumns = input.Schema.Where(x => _allColumnNames.Contains(x.Name)).Select(x => TypedColumn.CreateTypedColumn(x, _dataColumns)).ToDictionary(x => x.Column.Name);
+
+ // Create buffer to hold binary data
+ var columnBuffer = new byte[1024];
+
+ // Create TypeId[] for types of grain and data columns;
+ var dataColumnTypes = new TypeId[_dataColumns.Length];
+ var grainColumnTypes = new TypeId[_grainColumns.Length];
+
+ foreach (var column in _grainColumns.Select((value, index) => new { index, value }))
+ grainColumnTypes[column.index] = allColumns[column.value].GetTypeId();
+
+ foreach (var column in _dataColumns.Select((value, index) => new { index, value }))
+ dataColumnTypes[column.index] = allColumns[column.value].GetTypeId();
+
+ fixed (bool* suppressErrors = &_suppressTypeErrors)
+ fixed (TypeId* rawDataColumnTypes = dataColumnTypes)
+ fixed (TypeId* rawGrainColumnTypes = grainColumnTypes)
+ {
+ success = CreateEstimatorNative(rawGrainColumnTypes, new IntPtr(grainColumnTypes.Length), rawDataColumnTypes, new IntPtr(dataColumnTypes.Length), _imputeMode, suppressErrors, out estimator, out errorHandle);
+ }
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorNative))
+ {
+ var fitResult = FitResult.Continue;
+ while (fitResult != FitResult.Complete)
+ {
+ using (var cursor = input.GetRowCursorForAllColumns())
+ {
+ // Initialize getters for start of loop
+ foreach (var column in allColumns.Values)
+ column.InitializeGetter(cursor);
+
+ while ((fitResult == FitResult.Continue || fitResult == FitResult.ResetAndContinue) && cursor.MoveNext())
+ {
+ BuildColumnByteArray(allColumns, ref columnBuffer, out int bufferLength);
+
+ fixed (byte* bufferPointer = columnBuffer)
+ {
+ var binaryArchiveData = new NativeBinaryArchiveData() { Data = bufferPointer, DataSize = new IntPtr(bufferLength) };
+ success = FitNative(estimatorHandler, binaryArchiveData, out fitResult, out errorHandle);
+ }
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ success = CompleteTrainingNative(estimatorHandler, out fitResult, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ }
+
+ success = CreateTransformerFromEstimatorNative(estimatorHandler, out IntPtr transformer, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+ }
+
+ private void BuildColumnByteArray(Dictionary allColumns, ref byte[] columnByteBuffer, out int bufferLength)
+ {
+ bufferLength = 0;
+ foreach (var column in _allColumnNames)
+ {
+ var bytes = allColumns[column].GetSerializedValue();
+ var byteLength = bytes.Length;
+ if (byteLength + bufferLength >= columnByteBuffer.Length)
+ Array.Resize(ref columnByteBuffer, columnByteBuffer.Length * 2);
+
+ Array.Copy(bytes, 0, columnByteBuffer, bufferLength, byteLength);
+ bufferLength += byteLength;
+ }
+ }
+
+ public bool IsRowToRowMapper => false;
+
+ // Schema not changed
+ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
+ {
+ var columns = inputSchema.ToDictionary(x => x.Name);
+ var schemaBuilder = new DataViewSchema.Builder();
+ schemaBuilder.AddColumns(inputSchema.AsEnumerable());
+ schemaBuilder.AddColumn(TimeSeriesImputerEstimator.IsRowImputedColumnName, BooleanDataViewType.Instance);
+
+ return schemaBuilder.ToSchema();
+ }
+
+ public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => throw new InvalidOperationException("Not a RowToRowMapper.");
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "TimeIm T",
+ verWrittenCur: 0x00010001,
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(TimeSeriesImputerTransformer).Assembly.FullName);
+ }
+
+ public void Save(ModelSaveContext ctx)
+ {
+ _host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ // name of time series column
+ // length of grain column array
+ // all column names in grain column array
+ // length of data column array
+ // all column names in data column array
+ // byte value of impute mode
+ // length of C++ state array
+ // C++ byte state array
+
+ ctx.Writer.Write(_timeSeriesColumn);
+ ctx.Writer.Write(_grainColumns.Length);
+ foreach (var column in _grainColumns)
+ ctx.Writer.Write(column);
+ ctx.Writer.Write(_dataColumns.Length);
+ foreach (var column in _dataColumns)
+ ctx.Writer.Write(column);
+ ctx.Writer.Write((byte)_imputeMode);
+ var data = CreateTransformerSaveData();
+ ctx.Writer.Write(data.Length);
+ ctx.Writer.Write(data);
+ }
+
+ private byte[] CreateTransformerSaveData()
+ {
+ var success = CreateTransformerSaveDataNative(TransformerHandle, out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize))
+ {
+ byte[] savedData = new byte[bufferSize.ToInt32()];
+ Marshal.Copy(buffer, savedData, 0, savedData.Length);
+ return savedData;
+ }
+ }
+
+ public IDataView Transform(IDataView input) => MakeDataTransform(input);
+
+ internal TimeSeriesImputerDataView MakeDataTransform(IDataView input)
+ {
+ _host.CheckValue(input, nameof(input));
+
+ return new TimeSeriesImputerDataView(_host, input, _timeSeriesColumn, _grainColumns, _dataColumns, _allColumnNames, this);
+ }
+
+ internal TransformerEstimatorSafeHandle CloneTransformer() => CreateTransformerFromSavedData(CreateTransformerSaveData());
+
+ public void Dispose()
+ {
+ if (!TransformerHandle.IsClosed)
+ TransformerHandle.Close();
+ }
+
+ #region C++ function declarations
+ // TODO: Update entry points
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateEstimatorNative(TypeId* grainTypes, IntPtr grainTypesSize, TypeId* dataTypes, IntPtr dataTypesSize, TimeSeriesImputerEstimator.ImputationStrategy strategy, bool* suppressTypeErrors, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, NativeBinaryArchiveData data, out FitResult fitResult, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+
+ #endregion
+
+ #region Typed Columns
+
+ private abstract class TypedColumn
+ {
+ internal readonly DataViewSchema.Column Column;
+ internal TypedColumn(DataViewSchema.Column column)
+ {
+ Column = column;
+ }
+
+ internal abstract void InitializeGetter(DataViewRowCursor cursor);
+ internal abstract byte[] GetSerializedValue();
+ internal abstract TypeId GetTypeId();
+
+ internal static TypedColumn CreateTypedColumn(DataViewSchema.Column column, string[] optionalColumns)
+ {
+ var type = column.Type.RawType.ToString();
+ if (type == typeof(sbyte).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(short).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(int).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(long).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(byte).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(ushort).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(uint).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(ulong).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(float).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(double).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(ReadOnlyMemory).ToString())
+ return new StringTypedColumn(column, optionalColumns.Contains(column.Name));
+
+ throw new InvalidOperationException($"Unsupported type {type}");
+ }
+ }
+
+ private abstract class TypedColumn : TypedColumn
+ {
+ private ValueGetter _getter;
+ private T _value;
+
+ internal TypedColumn(DataViewSchema.Column column) :
+ base(column)
+ {
+ _value = default;
+ }
+
+ internal override void InitializeGetter(DataViewRowCursor cursor)
+ {
+ _getter = cursor.GetGetter(Column);
+ }
+
+ internal T GetValue()
+ {
+ _getter(ref _value);
+ return _value;
+ }
+
+ internal override TypeId GetTypeId()
+ {
+ return typeof(T).GetNativeTypeIdFromType();
+ }
+ }
+
+ private class NumericTypedColumn : TypedColumn
+ {
+ private readonly bool _isNullable;
+
+ internal NumericTypedColumn(DataViewSchema.Column column, bool isNullable = false) :
+ base(column)
+ {
+ _isNullable = isNullable;
+ }
+
+ internal override byte[] GetSerializedValue()
+ {
+ dynamic value = GetValue();
+ byte[] bytes;
+ if (value.GetType() == typeof(byte))
+ bytes = new byte[1] { value };
+ if (BitConverter.IsLittleEndian)
+ bytes = BitConverter.GetBytes(value);
+ // Will need to enable this when Jin's pr goes in. return ((IEnumerable)BitConverter.GetBytes(value)).Reverse().ToArray();
+ else
+ bytes = BitConverter.GetBytes(value);
+
+ if (_isNullable && value.GetType() != typeof(float) && value.GetType() != typeof(double))
+ return new byte[1] { Convert.ToByte(true) }.Concat(bytes).ToArray();
+ else
+ return bytes;
+ }
+ }
+
+ private class StringTypedColumn : TypedColumn>
+ {
+ private readonly bool _isNullable;
+
+ internal StringTypedColumn(DataViewSchema.Column column, bool isNullable = false) :
+ base(column)
+ {
+ _isNullable = isNullable;
+ }
+
+ internal override byte[] GetSerializedValue()
+ {
+ var value = GetValue().ToString();
+ var stringBytes = Encoding.UTF8.GetBytes(value);
+ if (_isNullable)
+ return new byte[] { Convert.ToByte(true) }.Concat(BitConverter.GetBytes(stringBytes.Length)).Concat(stringBytes).ToArray();
+ return BitConverter.GetBytes(stringBytes.Length).Concat(stringBytes).ToArray();
+ }
+ }
+
+ #endregion
+ }
+
+ internal static class TimeSeriesTransformerEntrypoint
+ {
+ [TlcModule.EntryPoint(Name = "Transforms.TimeSeriesImputer",
+ Desc = TimeSeriesImputerTransformer.Summary,
+ UserName = TimeSeriesImputerTransformer.UserName,
+ ShortName = TimeSeriesImputerTransformer.ShortName)]
+ public static CommonOutputs.TransformOutput TimeSeriesImputer(IHostEnvironment env, TimeSeriesImputerEstimator.Options input)
+ {
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, TimeSeriesImputerTransformer.ShortName, input);
+ var xf = new TimeSeriesImputerEstimator(h, input).Fit(input.Data).Transform(input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModelImpl(h, xf, input.Data),
+ OutputData = xf
+ };
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs b/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs
new file mode 100644
index 0000000000..ff9fa5d8ec
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs
@@ -0,0 +1,761 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Security;
+using System.Text;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.Runtime;
+using Microsoft.Win32.SafeHandles;
+using static Microsoft.ML.Featurizers.CommonExtensions;
+using static Microsoft.ML.Featurizers.TimeSeriesImputerEstimator;
+
+namespace Microsoft.ML.Transforms
+{
+
+ internal sealed class TimeSeriesImputerDataView : IDataTransform
+ {
+ #region Typed Columns
+ private TimeSeriesImputerTransformer _parent;
+ public class SharedColumnState
+ {
+ public bool SourceCanMoveNext { get; set; }
+ public int TransformedDataPosition { get; set; }
+ public NativeBinaryArchiveData[] TransformedData { get; set; }
+ public byte[] ColumnBuffer { get; set; }
+ public TransformedDataSafeHandle TransformedDataHandler { get; set; }
+ }
+
+ private abstract class TypedColumn
+ {
+ private protected SharedColumnState SharedState;
+
+ internal readonly DataViewSchema.Column Column;
+ internal readonly bool IsImputed;
+ internal TypedColumn(DataViewSchema.Column column, bool isImputed, SharedColumnState state)
+ {
+ Column = column;
+ SharedState = state;
+ IsImputed = isImputed;
+ }
+
+ internal abstract Delegate GetGetter();
+ internal abstract void InitializeGetter(DataViewRowCursor cursor, TransformerEstimatorSafeHandle transformerParent, string timeSeriesColumn,
+ string[] grainColumns, string[] dataColumns, string[] allColumnNames, Dictionary allColumns);
+
+ internal abstract TypeId GetTypeId();
+ internal abstract byte[] GetSerializedValue();
+ internal abstract unsafe int GetDataSizeInBytes(byte* data, int currentOffset);
+ internal abstract void QueueNonImputedColumnValue();
+
+ public bool MoveNext(DataViewRowCursor cursor)
+ {
+ SharedState.TransformedDataPosition++;
+
+ if (SharedState.TransformedData == null || SharedState.TransformedDataPosition >= SharedState.TransformedData.Length)
+ SharedState.SourceCanMoveNext = cursor.MoveNext();
+
+ if (!SharedState.SourceCanMoveNext)
+ if (SharedState.TransformedDataPosition >= SharedState.TransformedData.Length)
+ {
+ if (!SharedState.TransformedDataHandler.IsClosed)
+ SharedState.TransformedDataHandler.Dispose();
+ return false;
+ }
+
+ return true;
+ }
+
+ internal static TypedColumn CreateTypedColumn(DataViewSchema.Column column, string[] optionalColumns, string[] allImputedColumns, SharedColumnState state)
+ {
+ var type = column.Type.RawType.ToString();
+ if (type == typeof(sbyte).ToString())
+ return new SByteTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(short).ToString())
+ return new ShortTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(int).ToString())
+ return new IntTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(long).ToString())
+ return new LongTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(byte).ToString())
+ return new ByteTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(ushort).ToString())
+ return new UShortTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(uint).ToString())
+ return new UIntTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(ulong).ToString())
+ return new ULongTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(float).ToString())
+ return new FloatTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(double).ToString())
+ return new DoubleTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(ReadOnlyMemory).ToString())
+ return new StringTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(bool).ToString())
+ return new BoolTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+
+ throw new InvalidOperationException($"Unsupported type {type}");
+ }
+ }
+
+ private abstract class TypedColumn : TypedColumn
+ {
+ private ValueGetter _getter;
+ private ValueGetter _sourceGetter;
+ private long _position;
+ private T _cache;
+
+ // When columns are not being imputed, we need to store the column values in memory until they are used.
+ private protected Queue SourceQueue;
+
+ internal TypedColumn(DataViewSchema.Column column, bool isImputed, SharedColumnState state) :
+ base(column, isImputed, state)
+ {
+ SourceQueue = new Queue();
+ _position = -1;
+ }
+
+ internal override Delegate GetGetter()
+ {
+ return _getter;
+ }
+
+ internal override unsafe void InitializeGetter(DataViewRowCursor cursor, TransformerEstimatorSafeHandle transformer, string timeSeriesColumn,
+ string[] grainColumns, string[] dataColumns, string[] allImputedColumnNames, Dictionary allColumns)
+ {
+ if (Column.Name != IsRowImputedColumnName)
+ _sourceGetter = cursor.GetGetter(Column);
+
+ _getter = (ref T dst) =>
+ {
+ IntPtr errorHandle = IntPtr.Zero;
+ bool success = false;
+ if (SharedState.TransformedData == null || SharedState.TransformedDataPosition >= SharedState.TransformedData.Length)
+ {
+ // Free native memory if we are about to get more
+ if (SharedState.TransformedData != null && SharedState.TransformedDataPosition >= SharedState.TransformedData.Length)
+ SharedState.TransformedDataHandler.Dispose();
+
+ var outputDataSize = IntPtr.Zero;
+ NativeBinaryArchiveData* outputData = default;
+ while(outputDataSize == IntPtr.Zero && SharedState.SourceCanMoveNext)
+ {
+ BuildColumnByteArray(allColumns, allImputedColumnNames, out int bufferLength);
+ QueueDataForNonImputedColumns(allColumns, allImputedColumnNames);
+ fixed (byte* bufferPointer = SharedState.ColumnBuffer)
+ {
+ var binaryArchiveData = new NativeBinaryArchiveData() { Data = bufferPointer, DataSize = new IntPtr(bufferLength) };
+ success = TransformDataNative(transformer, binaryArchiveData, out outputData, out outputDataSize, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ if (outputDataSize == IntPtr.Zero)
+ SharedState.SourceCanMoveNext = cursor.MoveNext();
+ }
+
+ if (!SharedState.SourceCanMoveNext)
+ success = FlushDataNative(transformer, out outputData, out outputDataSize, out errorHandle);
+
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ if (outputDataSize.ToInt32() > 0)
+ {
+ SharedState.TransformedDataHandler = new TransformedDataSafeHandle((IntPtr)outputData, outputDataSize);
+ SharedState.TransformedData = new NativeBinaryArchiveData[outputDataSize.ToInt32()];
+ for (int i = 0; i < outputDataSize.ToInt32(); i++)
+ {
+ SharedState.TransformedData[i] = *(outputData + i);
+ }
+ SharedState.TransformedDataPosition = 0;
+ }
+ }
+
+ // Base case where we didn't impute anything
+ if (!allImputedColumnNames.Contains(Column.Name))
+ {
+ var imputedData = SharedState.TransformedData[SharedState.TransformedDataPosition];
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(imputedData.Data, 0))
+ {
+ dst = default;
+ }
+ else
+ {
+ if (_position != cursor.Position)
+ {
+ _position = cursor.Position;
+ _cache = SourceQueue.Dequeue();
+ }
+ dst = _cache;
+ }
+ }
+ else
+ {
+ var imputedData = SharedState.TransformedData[SharedState.TransformedDataPosition];
+ int offset = 0;
+ foreach (var columnName in allImputedColumnNames)
+ {
+ var col = allColumns[columnName];
+ if (col.Column.Name == Column.Name)
+ {
+ dst = GetDataFromNativeBinaryArchiveData(imputedData.Data, offset);
+ return;
+ }
+
+ offset += col.GetDataSizeInBytes(imputedData.Data, offset);
+ }
+
+ // This should never be hit.
+ dst = default;
+ }
+ };
+ }
+
+ private void QueueDataForNonImputedColumns(Dictionary allColumns, string[] allImputedColumnNames)
+ {
+ foreach (var column in allColumns.Where(x => !allImputedColumnNames.Contains(x.Value.Column.Name)).Select(x => x.Value))
+ column.QueueNonImputedColumnValue();
+ }
+
+ internal override void QueueNonImputedColumnValue()
+ {
+ SourceQueue.Enqueue(GetSourceValue());
+ }
+
+ private void BuildColumnByteArray(Dictionary allColumns, string[] columns, out int bufferLength)
+ {
+ bufferLength = 0;
+ foreach(var column in columns.Where(x => x != IsRowImputedColumnName))
+ {
+ var bytes = allColumns[column].GetSerializedValue();
+ var byteLength = bytes.Length;
+ if (byteLength + bufferLength >= SharedState.ColumnBuffer.Length)
+ {
+ var buffer = SharedState.ColumnBuffer;
+ Array.Resize(ref buffer, SharedState.ColumnBuffer.Length * 2);
+ SharedState.ColumnBuffer = buffer;
+ }
+
+ Array.Copy(bytes, 0, SharedState.ColumnBuffer, bufferLength, byteLength);
+ bufferLength += byteLength;
+ }
+ }
+
+ private protected T GetSourceValue()
+ {
+ T value = default;
+ _sourceGetter(ref value);
+ return value;
+ }
+
+ internal override TypeId GetTypeId()
+ {
+ return typeof(T).GetNativeTypeIdFromType();
+ }
+
+ internal unsafe abstract T GetDataFromNativeBinaryArchiveData(byte* data, int offset);
+ }
+
+ private abstract class NumericTypedColumn : TypedColumn
+ {
+ private protected readonly bool IsNullable;
+
+ internal NumericTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isImputed, state)
+ {
+ IsNullable = isNullable;
+ }
+
+ internal override byte[] GetSerializedValue()
+ {
+ dynamic value = GetSourceValue();
+ byte[] bytes;
+ if (value.GetType() == typeof(byte))
+ bytes = new byte[1] { value };
+ if (BitConverter.IsLittleEndian)
+ bytes = BitConverter.GetBytes(value);
+ else
+ bytes = BitConverter.GetBytes(value);
+
+ if (IsNullable && value.GetType() != typeof(float) && value.GetType() != typeof(double))
+ return new byte[1] { Convert.ToByte(true) }.Concat(bytes).ToArray();
+ else
+ return bytes;
+ }
+
+ internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset)
+ {
+ if (IsNullable && typeof(T) != typeof(float) && typeof(T) != typeof(double))
+ return Marshal.SizeOf(default(T)) + sizeof(bool);
+ else
+ return Marshal.SizeOf(default(T));
+ }
+ }
+
+ private class ByteTypedColumn : NumericTypedColumn
+ {
+ internal ByteTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override byte GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(byte*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(byte*)(data + offset);
+ }
+ }
+
+ private class SByteTypedColumn : NumericTypedColumn
+ {
+ internal SByteTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override sbyte GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(sbyte*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(sbyte*)(data + offset);
+ }
+ }
+
+ private class ShortTypedColumn : NumericTypedColumn
+ {
+ internal ShortTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override short GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(short*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(short*)(data + offset);
+ }
+ }
+
+ private class UShortTypedColumn : NumericTypedColumn
+ {
+ internal UShortTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override ushort GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(ushort*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(ushort*)(data + offset);
+ }
+ }
+
+ private class IntTypedColumn : NumericTypedColumn
+ {
+ internal IntTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override int GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(int*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(int*)(data + offset);
+ }
+ }
+
+ private class UIntTypedColumn : NumericTypedColumn
+ {
+ internal UIntTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override uint GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(uint*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(uint*)(data + offset);
+ }
+ }
+
+ private class LongTypedColumn : NumericTypedColumn
+ {
+ internal LongTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override long GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(long*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(long*)(data + offset);
+ }
+ }
+
+ private class ULongTypedColumn : NumericTypedColumn
+ {
+ internal ULongTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override ulong GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(ulong*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(ulong*)(data + offset);
+ }
+ }
+
+ private class FloatTypedColumn : NumericTypedColumn
+ {
+ internal FloatTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override float GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ var bytes = new byte[sizeof(float)];
+ Marshal.Copy((IntPtr)(data + offset), bytes, 0, sizeof(float));
+ return BitConverter.ToSingle(bytes, 0);
+ }
+ }
+
+ private class DoubleTypedColumn : NumericTypedColumn
+ {
+ internal DoubleTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override double GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ var bytes = new byte[sizeof(double)];
+ Marshal.Copy((IntPtr)(data + offset), bytes, 0, sizeof(double));
+ return BitConverter.ToDouble(bytes, 0);
+ }
+ }
+
+ private class BoolTypedColumn : NumericTypedColumn
+ {
+ internal BoolTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override bool GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(bool*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(bool*)(data + offset);
+ }
+
+ internal static unsafe bool GetBoolFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ return *(bool*)(data + offset);
+ }
+
+ internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset)
+ {
+ return sizeof(bool);
+ }
+ }
+
+ private class StringTypedColumn : TypedColumn>
+ {
+ private readonly bool _isNullable;
+ internal StringTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isImputed, state)
+ {
+ _isNullable = isNullable;
+ }
+
+ internal override byte[] GetSerializedValue()
+ {
+ var value = GetSourceValue().ToString();
+ var stringBytes = Encoding.UTF8.GetBytes(value);
+ if (_isNullable)
+ return new byte[] { Convert.ToByte(true)}.Concat(BitConverter.GetBytes(stringBytes.Length)).Concat(stringBytes).ToArray();
+ return BitConverter.GetBytes(stringBytes.Length).Concat(stringBytes).ToArray();
+ }
+
+ internal unsafe override ReadOnlyMemory GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (_isNullable)
+ {
+ if (!BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) // If value not present return empty string
+ return new ReadOnlyMemory("".ToCharArray());
+
+ var size = *(uint*)(data + offset + 1); // Add 1 for the byte bool flag
+
+ var bytes = new byte[size];
+ Marshal.Copy((IntPtr)(data + offset + sizeof(uint) + 1), bytes, 0, (int)size);
+ return Encoding.UTF8.GetString(bytes).AsMemory();
+ }
+ else
+ {
+ var size = *(uint*)(data + offset);
+
+ var bytes = new byte[size];
+ Marshal.Copy((IntPtr)(data + offset + sizeof(uint)), bytes, 0, (int)size);
+ return Encoding.UTF8.GetString(bytes).AsMemory();
+ }
+ }
+
+ internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset)
+ {
+ var size = *(uint*)(data + currentOffset);
+ if (_isNullable)
+ return 1 + (int)size + sizeof(uint); // + 1 for the byte bool flag
+
+ return (int)size + sizeof(uint);
+ }
+ }
+
+ #endregion
+
+ #region Native Exports
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern unsafe bool TransformDataNative(TransformerEstimatorSafeHandle transformer, /*in*/ NativeBinaryArchiveData data, out NativeBinaryArchiveData* outputData, out IntPtr outputDataSize, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern unsafe bool FlushDataNative(TransformerEstimatorSafeHandle transformer, out NativeBinaryArchiveData* outputData, out IntPtr outputDataSize, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern unsafe bool DestroyTransformedDataNative(IntPtr data, IntPtr dataSize, out IntPtr errorHandle);
+
+ #endregion
+
+ #region Native SafeHandles
+
+ internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ private IntPtr _size;
+ public TransformedDataSafeHandle(IntPtr handle, IntPtr size) : base(true)
+ {
+ SetHandle(handle);
+ _size = size;
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ // Not sure what to do with error stuff here. There shoudln't ever be one though.
+ return DestroyTransformedDataNative(handle, _size, out IntPtr errorHandle);
+ }
+ }
+
+ #endregion
+
+ private readonly IHostEnvironment _host;
+ private readonly IDataView _source;
+ private readonly string _timeSeriesColumn;
+ private readonly string[] _dataColumns;
+ private readonly string[] _grainColumns;
+ private readonly string[] _allImputedColumnNames;
+ private readonly DataViewSchema _schema;
+
+ internal TimeSeriesImputerDataView(IHostEnvironment env, IDataView input, string timeSeriesColumn, string[] grainColumns, string[] dataColumns, string[] allColumnNames, TimeSeriesImputerTransformer parent)
+ {
+ _host = env;
+ _source = input;
+
+ _timeSeriesColumn = timeSeriesColumn;
+ _grainColumns = grainColumns;
+ _dataColumns = dataColumns;
+ _allImputedColumnNames = new string[] { IsRowImputedColumnName }.Concat(allColumnNames).ToArray();
+ _parent = parent;
+ // Build new schema.
+ var schemaColumns = _source.Schema.ToDictionary(x => x.Name);
+ var schemaBuilder = new DataViewSchema.Builder();
+ schemaBuilder.AddColumns(_source.Schema.AsEnumerable());
+ schemaBuilder.AddColumn(IsRowImputedColumnName, BooleanDataViewType.Instance);
+
+ _schema = schemaBuilder.ToSchema();
+ }
+
+ public bool CanShuffle => false;
+
+ public DataViewSchema Schema => _schema;
+
+ public IDataView Source => _source;
+
+ public DataViewRowCursor GetRowCursor(IEnumerable columnsNeeded, Random rand = null)
+ {
+ _host.AssertValueOrNull(rand);
+
+ var input = _source.GetRowCursorForAllColumns();
+ return new Cursor(_host, input, _parent.CloneTransformer(), _timeSeriesColumn, _grainColumns, _dataColumns, _allImputedColumnNames, _schema);
+ }
+
+ // Can't use parallel cursors so this defaults to calling non-parallel version
+ public DataViewRowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int n, Random rand = null) =>
+ new DataViewRowCursor[] { GetRowCursor(columnsNeeded, rand) };
+
+ // Since we may add rows we don't know the row count
+ public long? GetRowCount() { return null; }
+
+ public void Save(ModelSaveContext ctx)
+ {
+ _parent.Save(ctx);
+ }
+
+ private sealed class Cursor : DataViewRowCursor
+ {
+ private readonly IChannelProvider _ch;
+ private DataViewRowCursor _input;
+ private long _position;
+ private bool _isGood;
+ private readonly Dictionary _allColumns;
+ private readonly DataViewSchema _schema;
+ private readonly TransformerEstimatorSafeHandle _transformer;
+
+ public Cursor(IChannelProvider provider, DataViewRowCursor input, TransformerEstimatorSafeHandle transformer, string timeSeriesColumn,
+ string[] grainColumns, string[] dataColumns, string[] allImputedColumnNames, DataViewSchema schema)
+ {
+ _ch = provider;
+ _ch.CheckValue(input, nameof(input));
+
+ _input = input;
+ var length = input.Schema.Count;
+ _position = -1;
+ _schema = schema;
+ _transformer = transformer;
+
+ var sharedState = new SharedColumnState()
+ {
+ SourceCanMoveNext = true,
+ ColumnBuffer = new byte[1024]
+ };
+
+ _allColumns = _schema.Select(x => TypedColumn.CreateTypedColumn(x, dataColumns, allImputedColumnNames, sharedState)).ToDictionary(x => x.Column.Name); ;
+ _allColumns[IsRowImputedColumnName] = new BoolTypedColumn(_schema[IsRowImputedColumnName], false, true, sharedState);
+
+ foreach (var column in _allColumns.Values)
+ {
+ column.InitializeGetter(_input, transformer, timeSeriesColumn, grainColumns, dataColumns, allImputedColumnNames, _allColumns);
+ }
+ }
+
+ public sealed override ValueGetter GetIdGetter()
+ {
+ return
+ (ref DataViewRowId val) =>
+ {
+ _ch.Check(_isGood, RowCursorUtils.FetchValueStateError);
+ val = new DataViewRowId((ulong)Position, 0);
+ };
+ }
+
+ public sealed override DataViewSchema Schema => _schema;
+
+ ///
+ /// Since rows will be generated all columns are active
+ ///
+ public override bool IsColumnActive(DataViewSchema.Column column)
+ {
+ return true;
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ if (!_transformer.IsClosed)
+ _transformer.Close();
+ }
+
+ ///
+ /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
+ /// This throws if the column is not active in this row, or if the type
+ /// differs from this column's type.
+ ///
+ /// is the column's content type.
+ /// is the output column whose getter should be returned.
+ public override ValueGetter GetGetter(DataViewSchema.Column column)
+ {
+ _ch.Check(IsColumnActive(column));
+
+ var fn = _allColumns[column.Name].GetGetter() as ValueGetter;
+ if (fn == null)
+ throw _ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
+ return fn;
+ }
+
+ public override bool MoveNext()
+ {
+ _position++;
+ _isGood = _allColumns[IsRowImputedColumnName].MoveNext(_input);
+ return _isGood;
+ }
+
+ public sealed override long Position => _position;
+
+ public sealed override long Batch => _input.Batch;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Featurizers/ToStringTransformer.cs b/src/Microsoft.ML.Featurizers/ToStringTransformer.cs
new file mode 100644
index 0000000000..2ea4687d6b
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/ToStringTransformer.cs
@@ -0,0 +1,1630 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Security;
+using System.Text;
+using Microsoft.ML;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.Data;
+using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Transforms;
+using Microsoft.Win32.SafeHandles;
+using static Microsoft.ML.Featurizers.CommonExtensions;
+
+[assembly: LoadableClass(typeof(ToStringTransformer), null, typeof(SignatureLoadModel),
+ ToStringTransformer.UserName, ToStringTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(ToStringTransformer), null, typeof(SignatureLoadRowMapper),
+ ToStringTransformer.UserName, ToStringTransformer.LoaderSignature)]
+
+[assembly: EntryPointModule(typeof(ToStringTransformerEntrypoint))]
+
+namespace Microsoft.ML.Featurizers
+{
+ public static class ToStringTransformerExtensionClass
+ {
+ ///
+ /// Create a , which converts the input column specified by
+ /// into a string representation of its contents .
+ ///
+ /// The transform catalog.
+ /// Name of the column resulting from the transformation of .
+ /// This column's data type will be of type
+ /// Name of column to convert to its string representation. If set to , the value of the
+ /// will be used as source. This column's data type can be scalar of numeric, text, and boolean
+ ///
+ public static ToStringTransformerEstimator ToStringTransformer(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null)
+ => ToStringTransformerEstimator.Create(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName);
+
+ ///
+ /// Create a , which converts each input column in specified by
+ /// into a string representation of its contents and stores it in the column specified by .
+ /// The input column data type can be scalar of numeric, text, and boolean
+ ///
+ /// The transform catalog.
+ /// Array of . The output column data type will be of type
+ ///
+ public static ToStringTransformerEstimator ToStringTransformer(this TransformsCatalog catalog, params InputOutputColumnPair[] columns)
+ => ToStringTransformerEstimator.Create(CatalogUtils.GetEnvironment(catalog), columns);
+ }
+
+ ///
+ /// Converts one or more input columns into string representations of its contents. Supports input column's of data type numeric, text, and boolean
+ ///
+ ///
+ /// is a trivial estimator that doesn't need training.
+ /// The resulting converts one or more input columns into its appropriate string representation.
+ ///
+ /// The ToStringTransformer can be applied to one or more columns, in which case it turns each input type into its appropriate string represenation.
+ ///
+ /// ]]>
+ ///
+ ///
+ ///
+ ///
+ public sealed class ToStringTransformerEstimator : IEstimator
+ {
+ private Options _options;
+
+ private readonly IHost _host;
+
+ #region Options
+
+ internal sealed class Column : OneToOneColumn
+ {
+ internal static Column Parse(string str)
+ {
+ Contracts.AssertNonEmpty(str);
+
+ var res = new Column();
+ if (res.TryParse(str))
+ return res;
+ return null;
+ }
+
+ internal bool TryUnparse(StringBuilder sb)
+ {
+ Contracts.AssertValue(sb);
+ return TryUnparseCore(sb);
+ }
+ }
+
+ internal sealed class Options : TransformInputBase
+ {
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition (optional form: name:src)",
+ Name = "Column", ShortName = "col", SortOrder = 1)]
+ public Column[] Columns;
+ }
+ #endregion
+
+ internal static ToStringTransformerEstimator Create(IHostEnvironment env, string outputColumnName, string inputColumnName)
+ {
+ return new ToStringTransformerEstimator(env, outputColumnName, inputColumnName);
+ }
+
+ internal static ToStringTransformerEstimator Create(IHostEnvironment env, params InputOutputColumnPair[] columns)
+ {
+ var columnOptions = columns.Select(x => new Column { Name = x.OutputColumnName, Source = x.InputColumnName ?? x.OutputColumnName }).ToArray();
+ return new ToStringTransformerEstimator(env, new Options { Columns = columnOptions });
+ }
+
+ internal ToStringTransformerEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(ToStringTransformerEstimator));
+
+ _options = new Options
+ {
+ Columns = new Column[1] { new Column() { Name = outputColumnName, Source = inputColumnName ?? outputColumnName } }
+ };
+ }
+
+ internal ToStringTransformerEstimator(IHostEnvironment env, Options options)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(ToStringTransformerEstimator));
+
+ foreach(var columnPair in options.Columns)
+ {
+ columnPair.Source = columnPair.Source ?? columnPair.Name;
+ }
+ _options = options;
+
+ }
+
+ public ToStringTransformer Fit(IDataView input)
+ {
+ return new ToStringTransformer(_host, _options.Columns, input);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ var columns = inputSchema.ToDictionary(x => x.Name);
+
+ foreach (var column in _options.Columns)
+ {
+ columns[column.Name] = new SchemaShape.Column(column.Name, SchemaShape.Column.VectorKind.Scalar,
+ ColumnTypeExtensions.PrimitiveTypeFromType(typeof(string)), false, null);
+ }
+
+ return new SchemaShape(columns.Values);
+ }
+ }
+
+ public sealed class ToStringTransformer : RowToRowTransformerBase, IDisposable
+ {
+ #region Class data members
+
+ internal const string Summary = "Turns the given column into a column of its string representation";
+ internal const string UserName = "ToString Transform";
+ internal const string ShortName = "ToStringTransform";
+ internal const string LoadName = "ToStringTransform";
+ internal const string LoaderSignature = "ToStringTransform";
+
+ private TypedColumn[] _columns;
+
+ #endregion
+
+ internal ToStringTransformer(IHostEnvironment host, ToStringTransformerEstimator.Column[] columns, IDataView input) :
+ base(host.Register(nameof(ToStringTransformer)))
+ {
+ var schema = input.Schema;
+
+ _columns = columns.Select(x => TypedColumn.CreateTypedColumn(x.Name, x.Source, schema[x.Source].Type.RawType.ToString())).ToArray();
+ foreach (var column in _columns)
+ {
+ // No training is required so directly create the transformer
+ column.CreateTransformerFromEstimator();
+ }
+ }
+
+ // Factory method for SignatureLoadModel.
+ internal ToStringTransformer(IHostEnvironment host, ModelLoadContext ctx) :
+ base(host.Register(nameof(ToStringTransformer)))
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+ // *** Binary format ***
+ // int number of column pairs
+ // for each column pair:
+ // string output column name
+ // string input column name
+ // string representation of type
+
+ var columnCount = ctx.Reader.ReadInt32();
+
+ _columns = new TypedColumn[columnCount];
+ for (int i = 0; i < columnCount; i++)
+ {
+ _columns[i] = TypedColumn.CreateTypedColumn(ctx.Reader.ReadString(), ctx.Reader.ReadString(), ctx.Reader.ReadString());
+
+ var dataLength = ctx.Reader.ReadInt32();
+ var data = ctx.Reader.ReadByteArray(dataLength);
+ _columns[i].CreateTransformerFromSavedData(data);
+
+ }
+ }
+
+ // Factory method for SignatureLoadRowMapper.
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
+ => new ToStringTransformer(env, ctx).MakeRowMapper(inputSchema);
+
+ private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "TOSTRI T",
+ verWrittenCur: 0x00010001,
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ToStringTransformer).Assembly.FullName);
+ }
+
+ private protected override void SaveModel(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ // int number of column pairs
+ // for each column pair:
+ // string output column name
+ // string input column name
+ // string representation of type
+ // int c++ state array length
+
+ ctx.Writer.Write(_columns.Count());
+ foreach (var column in _columns)
+ {
+ ctx.Writer.Write(column.Name);
+ ctx.Writer.Write(column.Source);
+ ctx.Writer.Write(column.Type);
+
+ var data = column.CreateTransformerSaveData();
+ ctx.Writer.Write(data.Length);
+ ctx.Writer.Write(data);
+ }
+ }
+
+ public void Dispose()
+ {
+ foreach (var column in _columns)
+ {
+ column.Dispose();
+ }
+ }
+
+ #region C++ Safe handle classes
+
+ internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ private DestroyTransformedDataNative _destroySaveDataHandler;
+ private IntPtr _dataSize;
+
+ public TransformedDataSafeHandle(IntPtr handle, IntPtr dataSize, DestroyTransformedDataNative destroyCppTransformerEstimator) : base(true)
+ {
+ SetHandle(handle);
+ _dataSize = dataSize;
+ _destroySaveDataHandler = destroyCppTransformerEstimator;
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ // Not sure what to do with error stuff here. There shoudln't ever be one though.
+ return _destroySaveDataHandler(handle, _dataSize, out IntPtr errorHandle);
+ }
+ }
+
+ #endregion
+
+ #region ColumnInfo
+
+ // REVIEW: Since we can't do overloading on the native side due to the C style exports,
+ // this was the best way I could think handle it to allow for any conversions needed based on the data type.
+
+ #region BaseClass
+
+ internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle);
+ internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle);
+ internal delegate bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+
+ internal abstract class TypedColumn : IDisposable
+ {
+ internal readonly string Name;
+ internal readonly string Source;
+ internal readonly string Type;
+ internal TypedColumn(string name, string source, string type)
+ {
+ Name = name;
+ Source = source;
+ Type = type;
+ }
+
+ internal abstract void CreateTransformerFromEstimator();
+ private protected abstract unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize);
+ private protected abstract bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
+ private protected abstract bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ private protected abstract bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle);
+ private protected abstract bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle);
+ private protected abstract bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+ public abstract void Dispose();
+
+ private protected TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase()
+ {
+ var success = CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorHelper))
+ {
+
+ success = CreateTransformerFromEstimatorHelper(estimatorHandler, out IntPtr transformer, out errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerHelper);
+ }
+ }
+
+ internal byte[] CreateTransformerSaveData()
+ {
+
+ var success = CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize))
+ {
+ byte[] savedData = new byte[bufferSize.ToInt32()];
+ Marshal.Copy(buffer, savedData, 0, savedData.Length);
+ return savedData;
+ }
+ }
+
+ internal unsafe void CreateTransformerFromSavedData(byte[] data)
+ {
+ fixed (byte* rawData = data)
+ {
+ IntPtr dataSize = new IntPtr(data.Count());
+ CreateTransformerFromSavedDataHelper(rawData, dataSize);
+ }
+ }
+
+ internal static TypedColumn CreateTypedColumn(string name, string source, string type)
+ {
+ if (type == typeof(sbyte).ToString())
+ {
+ return new Int8TypedColumn(name, source);
+ }
+ else if (type == typeof(short).ToString())
+ {
+ return new Int16TypedColumn(name, source);
+ }
+ else if (type == typeof(int).ToString())
+ {
+ return new Int32TypedColumn(name, source);
+ }
+ else if (type == typeof(long).ToString())
+ {
+ return new Int64TypedColumn(name, source);
+ }
+ else if (type == typeof(byte).ToString())
+ {
+ return new UInt8TypedColumn(name, source);
+ }
+ else if (type == typeof(ushort).ToString())
+ {
+ return new UInt16TypedColumn(name, source);
+ }
+ else if (type == typeof(uint).ToString())
+ {
+ return new UInt32TypedColumn(name, source);
+ }
+ else if (type == typeof(ulong).ToString())
+ {
+ return new UInt64TypedColumn(name, source);
+ }
+ else if (type == typeof(float).ToString())
+ {
+ return new FloatTypedColumn(name, source);
+ }
+ else if (type == typeof(double).ToString())
+ {
+ return new DoubleTypedColumn(name, source);
+ }
+ else if (type == typeof(bool).ToString())
+ {
+ return new BoolTypedColumn(name, source);
+ }
+ else if (type == typeof(string).ToString())
+ {
+ return new StringTypedColumn(name, source);
+ }
+ else if (type == typeof(ReadOnlyMemory).ToString())
+ {
+ return new ReadOnlyCharTypedColumn(name, source);
+ }
+
+ throw new Exception($"Unsupported type {type}");
+ }
+ }
+
+ internal abstract class TypedColumn : TypedColumn
+ {
+ internal TypedColumn(string name, string source, string type) :
+ base(name, source, type)
+ {
+ }
+
+ internal abstract string Transform(T input);
+
+ }
+
+ #endregion
+
+ #region Int8Column
+
+ internal sealed class Int8TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ internal Int8TypedColumn(string name, string source) :
+ base(name, source, typeof(sbyte).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, sbyte input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(sbyte input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region Int16Column
+
+ internal sealed class Int16TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal Int16TypedColumn(string name, string source) :
+ base(name, source, typeof(short).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, short input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(short input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region Int32Column
+
+ internal sealed class Int32TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ internal Int32TypedColumn(string name, string source) :
+ base(name, source, typeof(int).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, int input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(int input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region Int64Column
+
+ internal sealed class Int64TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ internal Int64TypedColumn(string name, string source) :
+ base(name, source, typeof(long).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, long input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(long input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_int64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region UInt8Column
+
+ internal sealed class UInt8TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ internal UInt8TypedColumn(string name, string source) :
+ base(name, source, typeof(byte).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, byte input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(byte input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint8_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region UInt16Column
+
+ internal sealed class UInt16TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+ internal UInt16TypedColumn(string name, string source) :
+ base(name, source, typeof(ushort).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ushort input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(ushort input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint16_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region UInt32Column
+
+ internal sealed class UInt32TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal UInt32TypedColumn(string name, string source) :
+ base(name, source, typeof(uint).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, uint input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(uint input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint32_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region UInt64Column
+
+ internal sealed class UInt64TypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal UInt64TypedColumn(string name, string source) :
+ base(name, source, typeof(ulong).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, ulong input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(ulong input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_uint64_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region FloatColumn
+
+ internal sealed class FloatTypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal FloatTypedColumn(string name, string source) :
+ base(name, source, typeof(float).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_Transform"), SuppressUnmanagedCodeSecurity]
+
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, float input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(float input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_float_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region DoubleColumn
+
+ internal sealed class DoubleTypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal DoubleTypedColumn(string name, string source) :
+ base(name, source, typeof(double).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, double input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(double input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_double_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region BoolColumn
+
+ internal sealed class BoolTypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal BoolTypedColumn(string name, string source) :
+ base(name, source, typeof(bool).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, bool input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(bool input)
+ {
+ var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ if (!success)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_bool_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region StringColumn
+
+ internal sealed class StringTypedColumn : TypedColumn
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal StringTypedColumn(string name, string source) :
+ base(name, source, typeof(string).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, IntPtr input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(string input)
+ {
+ // Convert to byte array with NullPointer at end.
+ var rawData = Encoding.UTF8.GetBytes(input + char.MinValue);
+ bool result;
+ GCHandle handle = GCHandle.Alloc(rawData, GCHandleType.Pinned);
+ try
+ {
+ IntPtr rawDataPtr = handle.AddrOfPinnedObject();
+ result = TransformDataNative(_transformerHandler, rawDataPtr, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+
+ if (!result)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+ finally
+ {
+ handle.Free();
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_t_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #region ReadOnlyCharColumn
+
+ internal sealed class ReadOnlyCharTypedColumn : TypedColumn>
+ {
+ private TransformerEstimatorSafeHandle _transformerHandler;
+
+ internal ReadOnlyCharTypedColumn(string name, string source) :
+ base(name, source, typeof(ReadOnlyMemory).ToString())
+ {
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateEstimatorNative(out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+ internal override void CreateTransformerFromEstimator()
+ {
+ _transformerHandler = CreateTransformerFromEstimatorBase();
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+ private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize)
+ {
+ var result = CreateTransformerFromSavedDataNative(rawData, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyTransformer"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, IntPtr input, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformedDataNative(IntPtr output, IntPtr outputSize, out IntPtr errorHandle);
+ internal override string Transform(ReadOnlyMemory input)
+ {
+ var rawData = Encoding.UTF8.GetBytes(input.ToString() + char.MinValue);
+ bool result;
+ GCHandle handle = GCHandle.Alloc(rawData, GCHandleType.Pinned);
+ try
+ {
+ IntPtr rawDataPtr = handle.AddrOfPinnedObject();
+ result = TransformDataNative(_transformerHandler, rawDataPtr, out IntPtr output, out IntPtr outputSize, out IntPtr errorHandle);
+
+ if (!result)
+ {
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ if (outputSize.ToInt32() == 0)
+ return string.Empty;
+
+ using (var handler = new TransformedDataSafeHandle(output, outputSize, DestroyTransformedDataNative))
+ {
+ byte[] buffer = new byte[outputSize.ToInt32()];
+ Marshal.Copy(output, buffer, 0, buffer.Length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+ }
+ finally
+ {
+ handle.Free();
+ }
+ }
+
+ public override void Dispose()
+ {
+ if (!_transformerHandler.IsClosed)
+ _transformerHandler.Dispose();
+ }
+
+ private protected override bool CreateEstimatorHelper(out IntPtr estimator, out IntPtr errorHandle) =>
+ CreateEstimatorNative(out estimator, out errorHandle);
+
+ private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) =>
+ CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle);
+
+ private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) =>
+ DestroyEstimatorNative(estimator, out errorHandle);
+
+ private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) =>
+ DestroyTransformerNative(transformer, out errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "StringFeaturizer_string_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) =>
+ CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
+ }
+
+ #endregion
+
+ #endregion
+
+ private sealed class Mapper : MapperBase
+ {
+
+ #region Class data members
+
+ private readonly ToStringTransformer _parent;
+
+ #endregion
+
+ public Mapper(ToStringTransformer parent, DataViewSchema inputSchema) :
+ base(parent.Host.Register(nameof(Mapper)), inputSchema, parent)
+ {
+ _parent = parent;
+ }
+
+ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
+ {
+ return _parent._columns.Select(x => new DataViewSchema.DetachedColumn(x.Name, ColumnTypeExtensions.PrimitiveTypeFromType(typeof(string)))).ToArray();
+ }
+
+ private Delegate MakeGetter(DataViewRow input, int iinfo)
+ {
+ ValueGetter> result = (ref ReadOnlyMemory dst) =>
+ {
+ var inputColumn = input.Schema[_parent._columns[iinfo].Source];
+
+ var srcGetter = input.GetGetter(inputColumn);
+
+ T value = default;
+ srcGetter(ref value);
+ string transformed = ((TypedColumn)_parent._columns[iinfo]).Transform(value);
+ dst = new ReadOnlyMemory(transformed.ToArray());
+ };
+
+ return result;
+ }
+
+ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer)
+ {
+ disposer = null;
+ Type inputType = input.Schema[_parent._columns[iinfo].Source].Type.RawType;
+ return Utils.MarshalInvoke(MakeGetter, inputType, input, iinfo);
+ }
+
+ private protected override Func GetDependenciesCore(Func activeOutput)
+ {
+ var active = new bool[InputSchema.Count];
+ for (int i = 0; i < InputSchema.Count; i++)
+ {
+ if (_parent._columns.Any(x => x.Source == InputSchema[i].Name))
+ {
+ active[i] = true;
+ }
+ }
+
+ return col => active[col];
+ }
+
+ private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
+ }
+ }
+
+ internal static class ToStringTransformerEntrypoint
+ {
+ [TlcModule.EntryPoint(Name = "Transforms.ToString",
+ Desc = ToStringTransformer.Summary,
+ UserName = ToStringTransformer.UserName,
+ ShortName = ToStringTransformer.ShortName)]
+ public static CommonOutputs.TransformOutput ToString(IHostEnvironment env, ToStringTransformerEstimator.Options input)
+ {
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, ToStringTransformer.ShortName, input);
+ var xf = new ToStringTransformerEstimator(h, input).Fit(input.Data).Transform(input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModelImpl(h, xf, input.Data),
+ OutputData = xf
+ };
+ }
+ }
+
+}
diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
index 4f2bcc426a..fd4fcbc369 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
+++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
@@ -76,6 +76,8 @@ Transforms.BinaryPredictionScoreColumnsRenamer For binary prediction, it renames
Transforms.BinNormalizer The values are assigned into equidensity bins and a value is mapped to its bin_number/number_of_bins. Microsoft.ML.Data.Normalize Bin Microsoft.ML.Transforms.NormalizeTransform+BinArguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.CategoricalHashOneHotVectorizer Converts the categorical value into an indicator array by hashing the value and using the hash as an index in the bag. If the input column is a vector, a single indicator bag is returned for it. Microsoft.ML.Transforms.Categorical CatTransformHash Microsoft.ML.Transforms.OneHotHashEncodingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.CategoricalOneHotVectorizer Converts the categorical value into an indicator array by building a dictionary of categories based on the data and using the id in the dictionary as the index in the array. Microsoft.ML.Transforms.Categorical CatTransformDict Microsoft.ML.Transforms.OneHotEncodingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.CategoryImputer Fills in missing values in a column based on the most frequent value Microsoft.ML.Featurizers.CategoryImputerEntrypoint ImputeToKey Microsoft.ML.Featurizers.CategoryImputerTransformerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.CatImputer Microsoft.ML.Featurizers.CatImputerEntrypoint CatImputer Microsoft.ML.Featurizers.CatImputerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.CharacterTokenizer Character-oriented tokenizer where text is considered a sequence of characters. Microsoft.ML.Transforms.Text.TextAnalytics CharTokenize Microsoft.ML.Transforms.Text.TokenizingByCharactersTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.ColumnConcatenator Concatenates one or more columns of the same item type. Microsoft.ML.EntryPoints.SchemaManipulation ConcatColumns Microsoft.ML.Data.ColumnConcatenatingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.ColumnCopier Duplicates columns from the dataset Microsoft.ML.EntryPoints.SchemaManipulation CopyColumns Microsoft.ML.Transforms.ColumnCopyingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
@@ -85,6 +87,7 @@ Transforms.CombinerByContiguousGroupId Groups values of a scalar column into a v
Transforms.ConditionalNormalizer Normalize the columns only if needed Microsoft.ML.Data.Normalize IfNeeded Microsoft.ML.Transforms.NormalizeTransform+MinMaxArguments Microsoft.ML.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput]
Transforms.DatasetScorer Score a dataset with a predictor model Microsoft.ML.EntryPoints.ScoreModel Score Microsoft.ML.EntryPoints.ScoreModel+Input Microsoft.ML.EntryPoints.ScoreModel+Output
Transforms.DatasetTransformScorer Score a dataset with a transform model Microsoft.ML.EntryPoints.ScoreModel ScoreUsingTransform Microsoft.ML.EntryPoints.ScoreModel+InputTransformScorer Microsoft.ML.EntryPoints.ScoreModel+Output
+Transforms.DateTimeSplitter Splits a date time value into each individual component Microsoft.ML.Featurizers.DateTimeTransformerEntrypoint DateTimeSplit Microsoft.ML.Featurizers.DateTimeTransformerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.Dictionarizer Converts input values (words, numbers, etc.) to index in a dictionary. Microsoft.ML.Transforms.Text.TextAnalytics TermTransform Microsoft.ML.Transforms.ValueToKeyMappingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.FeatureCombiner Combines all the features into one feature column. Microsoft.ML.EntryPoints.FeatureCombiner PrepareFeatures Microsoft.ML.EntryPoints.FeatureCombiner+FeatureCombinerInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.FeatureContributionCalculationTransformer For each data point, calculates the contribution of individual features to the model prediction. Microsoft.ML.Transforms.FeatureContributionEntryPoint FeatureContributionCalculation Microsoft.ML.Transforms.FeatureContributionCalculatingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
@@ -119,6 +122,8 @@ Transforms.PcaCalculator PCA is a dimensionality-reduction transform which compu
Transforms.PermutationFeatureImportance Permutation Feature Importance (PFI) Microsoft.ML.Transforms.PermutationFeatureImportanceEntryPoints PermutationFeatureImportance Microsoft.ML.Transforms.PermutationFeatureImportanceArguments Microsoft.ML.Transforms.PermutationFeatureImportanceOutput
Transforms.PredictedLabelColumnOriginalValueConverter Transforms a predicted label column to its original values, unless it is of type bool. Microsoft.ML.EntryPoints.FeatureCombiner ConvertPredictedLabel Microsoft.ML.EntryPoints.FeatureCombiner+PredictedLabelInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.RandomNumberGenerator Adds a column with a generated number sequence. Microsoft.ML.Transforms.RandomNumberGenerator Generate Microsoft.ML.Transforms.GenerateNumberTransform+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.RobScal Microsoft.ML.Featurizers.RobScalEntrypoint RobScal Microsoft.ML.Featurizers.RobScalEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.RobustScaler Removes the median and scales the data according to the quantile range. Microsoft.ML.Featurizers.RobustScalerEntrypoint RobustScaler Microsoft.ML.Featurizers.RobustScalerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.RowRangeFilter Filters a dataview on a column of type Single, Double or Key (contiguous). Keeps the values that are in the specified min/max range. NaNs are always filtered out. If the input is a Key type, the min/max are considered percentages of the number of values. Microsoft.ML.EntryPoints.SelectRows FilterByRange Microsoft.ML.Transforms.RangeFilter+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.RowSkipAndTakeFilter Allows limiting input to a subset of rows at an optional offset. Can be used to implement data paging. Microsoft.ML.EntryPoints.SelectRows SkipAndTakeFilter Microsoft.ML.Transforms.SkipTakeFilter+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.RowSkipFilter Allows limiting input to a subset of rows by skipping a number of rows. Microsoft.ML.EntryPoints.SelectRows SkipFilter Microsoft.ML.Transforms.SkipTakeFilter+SkipOptions Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
@@ -130,6 +135,8 @@ Transforms.SentimentAnalyzer Uses a pretrained sentiment model to score input st
Transforms.TensorFlowScorer Transforms the data using the TensorFlow model. Microsoft.ML.Transforms.TensorFlowTransformer TensorFlowScorer Microsoft.ML.Transforms.TensorFlowEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.TextFeaturizer A transform that turns a collection of text documents into numerical feature vectors. The feature vectors are normalized counts of (word and/or character) n-grams in a given tokenized text. Microsoft.ML.Transforms.Text.TextAnalytics TextTransform Microsoft.ML.Transforms.Text.TextFeaturizingEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.TextToKeyConverter Converts input values (words, numbers, etc.) to index in a dictionary. Microsoft.ML.Transforms.Categorical TextToKey Microsoft.ML.Transforms.ValueToKeyMappingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.TimeSeriesImputer Fills in missing row and values Microsoft.ML.Featurizers.TimeSeriesTransformerEntrypoint TimeSeriesImputer Microsoft.ML.Featurizers.TimeSeriesImputerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.ToString Turns the given column into a column of its string representation Microsoft.ML.Featurizers.ToStringTransformerEntrypoint ToString Microsoft.ML.Featurizers.ToStringTransformerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.TrainTestDatasetSplitter Split the dataset into train and test sets Microsoft.ML.EntryPoints.TrainTestSplit Split Microsoft.ML.EntryPoints.TrainTestSplit+Input Microsoft.ML.EntryPoints.TrainTestSplit+Output
Transforms.TreeLeafFeaturizer Trains a tree ensemble, or loads it from a file, then maps a numeric feature vector to three outputs: 1. A vector containing the individual tree outputs of the tree ensemble. 2. A vector indicating the leaves that the feature vector falls on in the tree ensemble. 3. A vector indicating the paths that the feature vector falls on in the tree ensemble. If a both a model file and a trainer are specified - will use the model file. If neither are specified, will train a default FastTree model. This can handle key labels by training a regression model towards their optionally permuted indices. Microsoft.ML.Data.TreeFeaturize Featurizer Microsoft.ML.Data.TreeEnsembleFeaturizerTransform+ArgumentsForEntryPoint Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.TwoHeterogeneousModelCombiner Combines a TransformModel and a PredictorModel into a single PredictorModel. Microsoft.ML.EntryPoints.ModelOperations CombineTwoModels Microsoft.ML.EntryPoints.ModelOperations+SimplePredictorModelInput Microsoft.ML.EntryPoints.ModelOperations+PredictorModelOutput
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index c8e6d6e55c..69da1640b9 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -17301,6 +17301,158 @@
"ITransformOutput"
]
},
+ {
+ "Name": "Transforms.CategoryImputer",
+ "Desc": "Fills in missing values in a column based on the most frequent value",
+ "FriendlyName": "CategoryImputer",
+ "ShortName": "CategoryImputer",
+ "Inputs": [
+ {
+ "Name": "Column",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": {
+ "Kind": "Struct",
+ "Fields": [
+ {
+ "Name": "Name",
+ "Type": "String",
+ "Desc": "Name of the new column",
+ "Aliases": [
+ "name"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "Source",
+ "Type": "String",
+ "Desc": "Name of the source column",
+ "Aliases": [
+ "src"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": null
+ }
+ ]
+ }
+ },
+ "Desc": "New column definition (optional form: name:src)",
+ "Aliases": [
+ "col"
+ ],
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Input dataset",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "OutputData",
+ "Type": "DataView",
+ "Desc": "Transformed dataset"
+ },
+ {
+ "Name": "Model",
+ "Type": "TransformModel",
+ "Desc": "Transform model"
+ }
+ ],
+ "InputKind": [
+ "ITransformInput"
+ ],
+ "OutputKind": [
+ "ITransformOutput"
+ ]
+ },
+ {
+ "Name": "Transforms.CatImputer",
+ "Desc": "",
+ "FriendlyName": "CatImputerTransformer",
+ "ShortName": "CatImputerTransformer",
+ "Inputs": [
+ {
+ "Name": "Column",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": {
+ "Kind": "Struct",
+ "Fields": [
+ {
+ "Name": "Name",
+ "Type": "String",
+ "Desc": "Name of the new column",
+ "Aliases": [
+ "name"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "Source",
+ "Type": "String",
+ "Desc": "Name of the source column",
+ "Aliases": [
+ "src"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": null
+ }
+ ]
+ }
+ },
+ "Desc": "New column definition (optional form: name:src)",
+ "Aliases": [
+ "col"
+ ],
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Input dataset",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "OutputData",
+ "Type": "DataView",
+ "Desc": "Transformed dataset"
+ },
+ {
+ "Name": "Model",
+ "Type": "TransformModel",
+ "Desc": "Transform model"
+ }
+ ],
+ "InputKind": [
+ "ITransformInput"
+ ],
+ "OutputKind": [
+ "ITransformOutput"
+ ]
+ },
{
"Name": "Transforms.CharacterTokenizer",
"Desc": "Character-oriented tokenizer where text is considered a sequence of characters.",
@@ -18077,6 +18229,157 @@
}
]
},
+ {
+ "Name": "Transforms.DateTimeSplitter",
+ "Desc": "Splits a date time value into each individual component",
+ "FriendlyName": "DateTime Transform",
+ "ShortName": "DateTimeTransform",
+ "Inputs": [
+ {
+ "Name": "Source",
+ "Type": "String",
+ "Desc": "Input column",
+ "Aliases": [
+ "src"
+ ],
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Input dataset",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "Prefix",
+ "Type": "String",
+ "Desc": "Output column prefix",
+ "Aliases": [
+ "pre"
+ ],
+ "Required": true,
+ "SortOrder": 2.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "ColumnsToDrop",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": {
+ "Kind": "Enum",
+ "Values": [
+ "Year",
+ "Month",
+ "Day",
+ "Hour",
+ "Minute",
+ "Second",
+ "AmPm",
+ "Hour12",
+ "DayOfWeek",
+ "DayOfQuarter",
+ "DayOfYear",
+ "WeekOfMonth",
+ "QuarterOfYear",
+ "HalfOfYear",
+ "WeekIso",
+ "YearIso",
+ "MonthLabel",
+ "AmPmLabel",
+ "DayOfWeekLabel",
+ "HolidayName",
+ "IsPaidTimeOff"
+ ]
+ }
+ },
+ "Desc": "Columns to drop after the DateTime Expansion",
+ "Aliases": [
+ "drop"
+ ],
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "Country",
+ "Type": {
+ "Kind": "Enum",
+ "Values": [
+ "None",
+ "Argentina",
+ "Australia",
+ "Austria",
+ "Belarus",
+ "Belgium",
+ "Brazil",
+ "Canada",
+ "Colombia",
+ "Croatia",
+ "Czech",
+ "Denmark",
+ "England",
+ "Finland",
+ "France",
+ "Germany",
+ "Hungary",
+ "India",
+ "Ireland",
+ "IsleofMan",
+ "Italy",
+ "Japan",
+ "Mexico",
+ "Netherlands",
+ "NewZealand",
+ "NorthernIreland",
+ "Norway",
+ "Poland",
+ "Portugal",
+ "Scotland",
+ "Slovenia",
+ "SouthAfrica",
+ "Spain",
+ "Sweden",
+ "Switzerland",
+ "Ukraine",
+ "UnitedKingdom",
+ "UnitedStates",
+ "Wales"
+ ]
+ },
+ "Desc": "Country to get holidays for. Defaults to none if not passed",
+ "Aliases": [
+ "ctry"
+ ],
+ "Required": false,
+ "SortOrder": 4.0,
+ "IsNullable": false,
+ "Default": "None"
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "OutputData",
+ "Type": "DataView",
+ "Desc": "Transformed dataset"
+ },
+ {
+ "Name": "Model",
+ "Type": "TransformModel",
+ "Desc": "Transform model"
+ }
+ ],
+ "InputKind": [
+ "ITransformInput"
+ ],
+ "OutputKind": [
+ "ITransformOutput"
+ ]
+ },
{
"Name": "Transforms.Dictionarizer",
"Desc": "Converts input values (words, numbers, etc.) to index in a dictionary.",
@@ -21932,15 +22235,215 @@
]
},
{
- "Name": "Transforms.RowRangeFilter",
- "Desc": "Filters a dataview on a column of type Single, Double or Key (contiguous). Keeps the values that are in the specified min/max range. NaNs are always filtered out. If the input is a Key type, the min/max are considered percentages of the number of values.",
- "FriendlyName": "Range Filter",
- "ShortName": "RangeFilter",
+ "Name": "Transforms.RobScal",
+ "Desc": "",
+ "FriendlyName": "RobScalTransformer",
+ "ShortName": "RobScalTransformer",
"Inputs": [
{
"Name": "Column",
- "Type": "String",
- "Desc": "Column",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": {
+ "Kind": "Struct",
+ "Fields": [
+ {
+ "Name": "Name",
+ "Type": "String",
+ "Desc": "Name of the new column",
+ "Aliases": [
+ "name"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "Source",
+ "Type": "String",
+ "Desc": "Name of the source column",
+ "Aliases": [
+ "src"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": null
+ }
+ ]
+ }
+ },
+ "Desc": "New column definition (optional form: name:src)",
+ "Aliases": [
+ "col"
+ ],
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Input dataset",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "OutputData",
+ "Type": "DataView",
+ "Desc": "Transformed dataset"
+ },
+ {
+ "Name": "Model",
+ "Type": "TransformModel",
+ "Desc": "Transform model"
+ }
+ ],
+ "InputKind": [
+ "ITransformInput"
+ ],
+ "OutputKind": [
+ "ITransformOutput"
+ ]
+ },
+ {
+ "Name": "Transforms.RobustScaler",
+ "Desc": "Removes the median and scales the data according to the quantile range.",
+ "FriendlyName": "RobustScalerTransformer",
+ "ShortName": "RobustScalerTransformer",
+ "Inputs": [
+ {
+ "Name": "Column",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": {
+ "Kind": "Struct",
+ "Fields": [
+ {
+ "Name": "Name",
+ "Type": "String",
+ "Desc": "Name of the new column",
+ "Aliases": [
+ "name"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "Source",
+ "Type": "String",
+ "Desc": "Name of the source column",
+ "Aliases": [
+ "src"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": null
+ }
+ ]
+ }
+ },
+ "Desc": "New column definition (optional form: name:src)",
+ "Aliases": [
+ "col"
+ ],
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Input dataset",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "Center",
+ "Type": "Bool",
+ "Desc": "If True, center the data before scaling.",
+ "Aliases": [
+ "ctr"
+ ],
+ "Required": false,
+ "SortOrder": 2.0,
+ "IsNullable": false,
+ "Default": true
+ },
+ {
+ "Name": "Scale",
+ "Type": "Bool",
+ "Desc": "If True, scale the data to interquartile range.",
+ "Aliases": [
+ "sc"
+ ],
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": false,
+ "Default": true
+ },
+ {
+ "Name": "QuantileMin",
+ "Type": "Float",
+ "Desc": "Min for the quantile range used to calculate scale.",
+ "Aliases": [
+ "min"
+ ],
+ "Required": false,
+ "SortOrder": 4.0,
+ "IsNullable": false,
+ "Default": 25.0
+ },
+ {
+ "Name": "QuantileMax",
+ "Type": "Float",
+ "Desc": "Max for the quantile range used to calculate scale.",
+ "Aliases": [
+ "max"
+ ],
+ "Required": false,
+ "SortOrder": 5.0,
+ "IsNullable": false,
+ "Default": 75.0
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "OutputData",
+ "Type": "DataView",
+ "Desc": "Transformed dataset"
+ },
+ {
+ "Name": "Model",
+ "Type": "TransformModel",
+ "Desc": "Transform model"
+ }
+ ],
+ "InputKind": [
+ "ITransformInput"
+ ],
+ "OutputKind": [
+ "ITransformOutput"
+ ]
+ },
+ {
+ "Name": "Transforms.RowRangeFilter",
+ "Desc": "Filters a dataview on a column of type Single, Double or Key (contiguous). Keeps the values that are in the specified min/max range. NaNs are always filtered out. If the input is a Key type, the min/max are considered percentages of the number of values.",
+ "FriendlyName": "Range Filter",
+ "ShortName": "RangeFilter",
+ "Inputs": [
+ {
+ "Name": "Column",
+ "Type": "String",
+ "Desc": "Column",
"Aliases": [
"col"
],
@@ -22941,6 +23444,207 @@
"ITransformOutput"
]
},
+ {
+ "Name": "Transforms.TimeSeriesImputer",
+ "Desc": "Fills in missing row and values",
+ "FriendlyName": "TimeSeriesImputer",
+ "ShortName": "TimeSeriesImputer",
+ "Inputs": [
+ {
+ "Name": "TimeSeriesColumn",
+ "Type": "String",
+ "Desc": "Column representing the time",
+ "Aliases": [
+ "time"
+ ],
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Input dataset",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "GrainColumns",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "String"
+ },
+ "Desc": "List of grain columns",
+ "Aliases": [
+ "grains"
+ ],
+ "Required": true,
+ "SortOrder": 2.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "FilterColumns",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "String"
+ },
+ "Desc": "Columns to filter",
+ "Aliases": [
+ "filters"
+ ],
+ "Required": false,
+ "SortOrder": 2.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "FilterMode",
+ "Type": {
+ "Kind": "Enum",
+ "Values": [
+ "NoFilter",
+ "Include",
+ "Exclude"
+ ]
+ },
+ "Desc": "Filter mode. Either include or exclude",
+ "Aliases": [
+ "fmode"
+ ],
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": false,
+ "Default": "Exclude"
+ },
+ {
+ "Name": "ImputeMode",
+ "Type": {
+ "Kind": "Enum",
+ "Values": [
+ "ForwardFill",
+ "BackFill",
+ "Median",
+ "Interpolate"
+ ]
+ },
+ "Desc": "Mode for imputing, defaults to ForwardFill if not provided",
+ "Aliases": [
+ "mode"
+ ],
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": false,
+ "Default": "ForwardFill"
+ },
+ {
+ "Name": "SupressTypeErrors",
+ "Type": "Bool",
+ "Desc": "Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column. If false, will stop and throw an error.",
+ "Aliases": [
+ "error"
+ ],
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": false,
+ "Default": false
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "OutputData",
+ "Type": "DataView",
+ "Desc": "Transformed dataset"
+ },
+ {
+ "Name": "Model",
+ "Type": "TransformModel",
+ "Desc": "Transform model"
+ }
+ ],
+ "InputKind": [
+ "ITransformInput"
+ ],
+ "OutputKind": [
+ "ITransformOutput"
+ ]
+ },
+ {
+ "Name": "Transforms.ToString",
+ "Desc": "Turns the given column into a column of its string representation",
+ "FriendlyName": "ToString Transform",
+ "ShortName": "ToStringTransform",
+ "Inputs": [
+ {
+ "Name": "Column",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": {
+ "Kind": "Struct",
+ "Fields": [
+ {
+ "Name": "Name",
+ "Type": "String",
+ "Desc": "Name of the new column",
+ "Aliases": [
+ "name"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "Source",
+ "Type": "String",
+ "Desc": "Name of the source column",
+ "Aliases": [
+ "src"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": null
+ }
+ ]
+ }
+ },
+ "Desc": "New column definition (optional form: name:src)",
+ "Aliases": [
+ "col"
+ ],
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Input dataset",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "OutputData",
+ "Type": "DataView",
+ "Desc": "Transformed dataset"
+ },
+ {
+ "Name": "Model",
+ "Type": "TransformModel",
+ "Desc": "Transform model"
+ }
+ ],
+ "InputKind": [
+ "ITransformInput"
+ ],
+ "OutputKind": [
+ "ITransformOutput"
+ ]
+ },
{
"Name": "Transforms.TrainTestDatasetSplitter",
"Desc": "Split the dataset into train and test sets",
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
index 87a0c61f1c..cba56c6e68 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
@@ -7,6 +7,7 @@
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
+using Microsoft.ML.Featurizers;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Core.Tests.UnitTests;
using Microsoft.ML.Data;
@@ -328,6 +329,7 @@ public void EntryPointCatalogCheckDuplicateParams()
Env.ComponentCatalog.RegisterAssembly(typeof(SymbolicSgdLogisticRegressionBinaryTrainer).Assembly);
Env.ComponentCatalog.RegisterAssembly(typeof(SaveOnnxCommand).Assembly);
Env.ComponentCatalog.RegisterAssembly(typeof(TimeSeriesProcessingEntryPoints).Assembly);
+ Env.ComponentCatalog.RegisterAssembly(typeof(TimeSeriesImputerEstimator).Assembly);
Env.ComponentCatalog.RegisterAssembly(typeof(ParquetLoader).Assembly);
var catalog = Env.ComponentCatalog;
diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
index 453fb8c6c8..1241918e3e 100644
--- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
+++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
@@ -10,6 +10,7 @@
+
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs
index 36eb2d1b6b..026684e1b4 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs
@@ -5,6 +5,9 @@
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
using Microsoft.ML.Trainers;
+using Microsoft.ML.Transforms;
+using System;
+using System.Collections.Generic;
using Xunit;
namespace Microsoft.ML.Scenarios
diff --git a/test/Microsoft.ML.Tests/Transformers/CategoryImputerTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoryImputerTests.cs
new file mode 100644
index 0000000000..ad548fcda0
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Transformers/CategoryImputerTests.cs
@@ -0,0 +1,121 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.RunTests;
+using System;
+using System.Collections.Generic;
+using Xunit;
+using Xunit.Abstractions;
+using System.Linq;
+
+namespace Microsoft.ML.Tests.Transformers
+{
+ public class CategoryImputerTests : TestDataPipeBase
+ {
+ public CategoryImputerTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ private class SchemaAllTypes
+ {
+ public byte uint8_t;
+ public sbyte int8_t;
+ public short int16_t;
+ public ushort uint16_t;
+ public int int32_t;
+ public uint uint32_t;
+ public long int64_t;
+ public ulong uint64_t;
+ public float float_t;
+ public double double_t;
+ public string str;
+
+ internal SchemaAllTypes(byte num_i, float num_f, string s)
+ {
+ uint8_t = num_i;
+ int8_t = (sbyte)num_i;
+ int16_t = num_i;
+ uint16_t = num_i;
+ int32_t = num_i;
+ uint32_t = num_i;
+ int64_t = num_i;
+ uint64_t = num_i;
+ float_t = num_f;
+ double_t = num_f;
+ str = s;
+ }
+ }
+
+ private IDataView GetIDataView()
+ {
+ List dataList = new List();
+ dataList.Add(new SchemaAllTypes(1, 1.5f, "one"));
+ dataList.Add(new SchemaAllTypes(1, 1.5f, "one"));
+ dataList.Add(new SchemaAllTypes(2, 2.5f, "two"));
+ dataList.Add(new SchemaAllTypes(0, Single.NaN, null));
+ dataList.Add(new SchemaAllTypes(1, 1.5f, "one"));
+ dataList.Add(new SchemaAllTypes(1, 1.5f, "one"));
+ dataList.Add(new SchemaAllTypes(2, 2.5f, "two"));
+ dataList.Add(new SchemaAllTypes(0, Single.NaN, null));
+ dataList.Add(new SchemaAllTypes(1, 1.5f, "one"));
+ dataList.Add(new SchemaAllTypes(1, 1.5f, "one"));
+ dataList.Add(new SchemaAllTypes(2, 2.5f, "two"));
+ dataList.Add(new SchemaAllTypes(0, Single.NaN, null));
+
+ MLContext mlContext = new MLContext(1);
+ IDataView data = mlContext.Data.LoadFromEnumerable(dataList);
+ return data;
+ }
+ private void Test(string columnName, bool addNewCol, T mostFrequentValue)
+ {
+
+ string outputColName = addNewCol ? columnName + "_output" : columnName;
+ string inputColName = addNewCol ? columnName : null;
+
+ MLContext mlContext = new MLContext(1);
+ var data = GetIDataView();
+ var pipeline = mlContext.Transforms.CatagoryImputerTransformer(outputColName, inputColName);
+ TestEstimatorCore(pipeline, data);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+
+ List transformedColData = output.GetColumn(outputColName).ToList();
+ Assert.Equal(mostFrequentValue, transformedColData[3]);
+ Assert.Equal(mostFrequentValue, transformedColData[7]);
+ Assert.Equal(mostFrequentValue, transformedColData[11]);
+ }
+
+ [Fact]
+ public void TestAllTypes()
+ {
+ Test("int8_t", false, 1);
+ Test("int8_t", true, 1);
+ Test("uint8_t", false, 1);
+ Test("uint8_t", true, 1);
+ Test("int16_t", false, 1);
+ Test("int16_t", true, 1);
+ Test("uint16_t", false, 1);
+ Test("uint16_t", true, 1);
+ Test("int32_t", false, 1);
+ Test("int32_t", true, 1);
+ Test("uint32_t", false, 1);
+ Test("uint32_t", true, 1);
+ Test("int64_t", false, 1);
+ Test("int64_t", true, 1);
+ Test("uint64_t", false, 1);
+ Test("uint64_t", true, 1);
+ Test("float_t", false, 1.5f);
+ Test("float_t", true, 1.5f);
+ Test("double_t", false, 1.5);
+ Test("double_t", true, 1.5);
+ Test("str", false, "one");
+ Test("str", true, "one");
+
+ Done();
+ }
+
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs
new file mode 100644
index 0000000000..62a9f234c3
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs
@@ -0,0 +1,360 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using Microsoft.ML.RunTests;
+using Microsoft.ML.Featurizers;
+using System;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Tests.Transformers
+{
+ public class DateTimeTransformerTests : TestDataPipeBase
+ {
+ public DateTimeTransformerTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ private class DateTimeInput
+ {
+ public long date;
+ }
+
+ [Fact]
+ public void CorrectNumberOfColumnsAndSchema()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new DateTimeInput() { date = 0} };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var columnPrefix = "DTC_";
+ var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+
+ // Check the schema has 22 columns
+ Assert.Equal(22, schema.Count);
+
+ // Make sure names with prefix and order are correct
+ Assert.Equal($"{columnPrefix}Year", schema[1].Name);
+ Assert.Equal($"{columnPrefix}Month", schema[2].Name);
+ Assert.Equal($"{columnPrefix}Day", schema[3].Name);
+ Assert.Equal($"{columnPrefix}Hour", schema[4].Name);
+ Assert.Equal($"{columnPrefix}Minute", schema[5].Name);
+ Assert.Equal($"{columnPrefix}Second", schema[6].Name);
+ Assert.Equal($"{columnPrefix}AmPm", schema[7].Name);
+ Assert.Equal($"{columnPrefix}Hour12", schema[8].Name);
+ Assert.Equal($"{columnPrefix}DayOfWeek", schema[9].Name);
+ Assert.Equal($"{columnPrefix}DayOfQuarter", schema[10].Name);
+ Assert.Equal($"{columnPrefix}DayOfYear", schema[11].Name);
+ Assert.Equal($"{columnPrefix}WeekOfMonth", schema[12].Name);
+ Assert.Equal($"{columnPrefix}QuarterOfYear", schema[13].Name);
+ Assert.Equal($"{columnPrefix}HalfOfYear", schema[14].Name);
+ Assert.Equal($"{columnPrefix}WeekIso", schema[15].Name);
+ Assert.Equal($"{columnPrefix}YearIso", schema[16].Name);
+ Assert.Equal($"{columnPrefix}MonthLabel", schema[17].Name);
+ Assert.Equal($"{columnPrefix}AmPmLabel", schema[18].Name);
+ Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[19].Name);
+ Assert.Equal($"{columnPrefix}HolidayName", schema[20].Name);
+ Assert.Equal($"{columnPrefix}IsPaidTimeOff", schema[21].Name);
+
+ // Make sure types are correct
+ Assert.Equal(typeof(int), schema[1].Type.RawType);
+ Assert.Equal(typeof(byte), schema[2].Type.RawType);
+ Assert.Equal(typeof(byte), schema[3].Type.RawType);
+ Assert.Equal(typeof(byte), schema[4].Type.RawType);
+ Assert.Equal(typeof(byte), schema[5].Type.RawType);
+ Assert.Equal(typeof(byte), schema[6].Type.RawType);
+ Assert.Equal(typeof(byte), schema[7].Type.RawType);
+ Assert.Equal(typeof(byte), schema[8].Type.RawType);
+ Assert.Equal(typeof(byte), schema[9].Type.RawType);
+ Assert.Equal(typeof(byte), schema[10].Type.RawType);
+ Assert.Equal(typeof(ushort), schema[11].Type.RawType);
+ Assert.Equal(typeof(ushort), schema[12].Type.RawType);
+ Assert.Equal(typeof(byte), schema[13].Type.RawType);
+ Assert.Equal(typeof(byte), schema[14].Type.RawType);
+ Assert.Equal(typeof(byte), schema[15].Type.RawType);
+ Assert.Equal(typeof(int), schema[16].Type.RawType);
+ Assert.Equal(typeof(ReadOnlyMemory), schema[17].Type.RawType);
+ Assert.Equal(typeof(ReadOnlyMemory), schema[18].Type.RawType);
+ Assert.Equal(typeof(ReadOnlyMemory), schema[19].Type.RawType);
+ Assert.Equal(typeof(ReadOnlyMemory), schema[20].Type.RawType);
+ Assert.Equal(typeof(byte), schema[21].Type.RawType);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void DropOneColumn()
+ {
+ // TODO: This will fail until we figure out the C++ dll situation
+
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new DateTimeInput() { date = 0 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var columnPrefix = "DTC_";
+ var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+
+ // Check the schema has 21 columns
+ Assert.Equal(21, schema.Count);
+
+ // Make sure names with prefix and order are correct
+ Assert.Equal($"{columnPrefix}Year", schema[1].Name);
+ Assert.Equal($"{columnPrefix}Month", schema[2].Name);
+ Assert.Equal($"{columnPrefix}Day", schema[3].Name);
+ Assert.Equal($"{columnPrefix}Hour", schema[4].Name);
+ Assert.Equal($"{columnPrefix}Minute", schema[5].Name);
+ Assert.Equal($"{columnPrefix}Second", schema[6].Name);
+ Assert.Equal($"{columnPrefix}AmPm", schema[7].Name);
+ Assert.Equal($"{columnPrefix}Hour12", schema[8].Name);
+ Assert.Equal($"{columnPrefix}DayOfWeek", schema[9].Name);
+ Assert.Equal($"{columnPrefix}DayOfQuarter", schema[10].Name);
+ Assert.Equal($"{columnPrefix}DayOfYear", schema[11].Name);
+ Assert.Equal($"{columnPrefix}WeekOfMonth", schema[12].Name);
+ Assert.Equal($"{columnPrefix}QuarterOfYear", schema[13].Name);
+ Assert.Equal($"{columnPrefix}HalfOfYear", schema[14].Name);
+ Assert.Equal($"{columnPrefix}WeekIso", schema[15].Name);
+ Assert.Equal($"{columnPrefix}YearIso", schema[16].Name);
+ Assert.Equal($"{columnPrefix}MonthLabel", schema[17].Name);
+ Assert.Equal($"{columnPrefix}AmPmLabel", schema[18].Name);
+ Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[19].Name);
+ Assert.Equal($"{columnPrefix}HolidayName", schema[20].Name);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void DropManyColumns()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new DateTimeInput() { date = 0 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var columnPrefix = "DTC_";
+ var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff,
+ DateTimeTransformerEstimator.ColumnsProduced.Day, DateTimeTransformerEstimator.ColumnsProduced.QuarterOfYear, DateTimeTransformerEstimator.ColumnsProduced.AmPm);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+
+ // Check the schema has 18 columns
+ Assert.Equal(18, schema.Count);
+
+ // Make sure names with prefix and order are correct
+ Assert.Equal($"{columnPrefix}Year", schema[1].Name);
+ Assert.Equal($"{columnPrefix}Month", schema[2].Name);
+ Assert.Equal($"{columnPrefix}Hour", schema[3].Name);
+ Assert.Equal($"{columnPrefix}Minute", schema[4].Name);
+ Assert.Equal($"{columnPrefix}Second", schema[5].Name);
+ Assert.Equal($"{columnPrefix}Hour12", schema[6].Name);
+ Assert.Equal($"{columnPrefix}DayOfWeek", schema[7].Name);
+ Assert.Equal($"{columnPrefix}DayOfQuarter", schema[8].Name);
+ Assert.Equal($"{columnPrefix}DayOfYear", schema[9].Name);
+ Assert.Equal($"{columnPrefix}WeekOfMonth", schema[10].Name);
+ Assert.Equal($"{columnPrefix}HalfOfYear", schema[11].Name);
+ Assert.Equal($"{columnPrefix}WeekIso", schema[12].Name);
+ Assert.Equal($"{columnPrefix}YearIso", schema[13].Name);
+ Assert.Equal($"{columnPrefix}MonthLabel", schema[14].Name);
+ Assert.Equal($"{columnPrefix}AmPmLabel", schema[15].Name);
+ Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[16].Name);
+ Assert.Equal($"{columnPrefix}HolidayName", schema[17].Name);
+
+ // Make sure types are correct
+ Assert.Equal(typeof(int), schema[1].Type.RawType);
+ Assert.Equal(typeof(byte), schema[2].Type.RawType);
+ Assert.Equal(typeof(byte), schema[3].Type.RawType);
+ Assert.Equal(typeof(byte), schema[4].Type.RawType);
+ Assert.Equal(typeof(byte), schema[5].Type.RawType);
+ Assert.Equal(typeof(byte), schema[6].Type.RawType);
+ Assert.Equal(typeof(byte), schema[7].Type.RawType);
+ Assert.Equal(typeof(byte), schema[8].Type.RawType);
+ Assert.Equal(typeof(ushort), schema[9].Type.RawType);
+ Assert.Equal(typeof(ushort), schema[10].Type.RawType);
+ Assert.Equal(typeof(byte), schema[11].Type.RawType);
+ Assert.Equal(typeof(byte), schema[12].Type.RawType);
+ Assert.Equal(typeof(int), schema[13].Type.RawType);
+ Assert.Equal(typeof(ReadOnlyMemory), schema[14].Type.RawType);
+ Assert.Equal(typeof(ReadOnlyMemory), schema[15].Type.RawType);
+ Assert.Equal(typeof(ReadOnlyMemory), schema[16].Type.RawType);
+ Assert.Equal(typeof(ReadOnlyMemory), schema[17].Type.RawType);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void CanUseDateFromColumn()
+ {
+ // Future Date - 2025 June 30
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new DateTimeInput() { date = 1751241600 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+
+ // Get the data from the first row and make sure it matches expected
+ var row = output.Preview(1).RowView[0].Values;
+
+ // Assert the data from the first row is what we expect
+ Assert.Equal(2025, row[1].Value); // Year
+ Assert.Equal((byte)6, row[2].Value); // Month
+ Assert.Equal((byte)30, row[3].Value); // Day
+ Assert.Equal((byte)0, row[4].Value); // Hour
+ Assert.Equal((byte)0, row[5].Value); // Minute
+ Assert.Equal((byte)0, row[6].Value); // Second
+ Assert.Equal((byte)0, row[7].Value); // AmPm
+ Assert.Equal((byte)0, row[8].Value); // Hour12
+ Assert.Equal((byte)1, row[9].Value); // DayOfWeek
+ Assert.Equal((byte)91, row[10].Value); // DayOfQuarter
+ Assert.Equal((ushort)180, row[11].Value); // DayOfYear
+ Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth
+ Assert.Equal((byte)2, row[13].Value); // QuarterOfYear
+ Assert.Equal((byte)1, row[14].Value); // HalfOfYear
+ Assert.Equal((byte)27, row[15].Value); // WeekIso
+ Assert.Equal(2025, row[16].Value); // YearIso
+ Assert.Equal("June", row[17].Value.ToString()); // MonthLabel
+ Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel
+ Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel
+ Assert.Equal("", row[20].Value.ToString()); // HolidayName
+ Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void HolidayTest()
+ {
+ // Future Date - 2025 June 30
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new DateTimeInput() { date = 157161600 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC", country: DateTimeTransformerEstimator.Countries.Canada);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+
+ // Get the data from the first row and make sure it matches expected
+ var row = output.Preview(1).RowView[0].Values;
+
+ // Assert the data from the first row for holidays is what we expect
+ Assert.Equal("Christmas Day", row[20].Value.ToString()); // HolidayName
+ Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void ManyRowsTest()
+ {
+ // Future Date - 2025 June 30
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 12341 },
+ new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 1234 }, new DateTimeInput() { date = 1751241600 },
+ new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 12341 },
+ new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 1234 }};
+
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+
+ // Get the data from the first row and make sure it matches expected
+ var row = output.Preview().RowView[0].Values;
+
+ // Assert the data from the first row is what we expect
+ Assert.Equal(2025, row[1].Value); // Year
+ Assert.Equal((byte)6, row[2].Value); // Month
+ Assert.Equal((byte)30, row[3].Value); // Day
+ Assert.Equal((byte)0, row[4].Value); // Hour
+ Assert.Equal((byte)0, row[5].Value); // Minute
+ Assert.Equal((byte)0, row[6].Value); // Second
+ Assert.Equal((byte)0, row[7].Value); // AmPm
+ Assert.Equal((byte)0, row[8].Value); // Hour12
+ Assert.Equal((byte)1, row[9].Value); // DayOfWeek
+ Assert.Equal((byte)91, row[10].Value); // DayOfQuarter
+ Assert.Equal((ushort)180, row[11].Value); // DayOfYear
+ Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth
+ Assert.Equal((byte)2, row[13].Value); // QuarterOfYear
+ Assert.Equal((byte)1, row[14].Value); // HalfOfYear
+ Assert.Equal((byte)27, row[15].Value); // WeekIso
+ Assert.Equal(2025, row[16].Value); // YearIso
+ Assert.Equal("June", row[17].Value.ToString()); // MonthLabel
+ Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel
+ Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel
+ Assert.Equal("", row[20].Value.ToString()); // HolidayName
+ Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void EntryPointTest()
+ {
+ // Future Date - 2025 June 30
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new DateTimeInput() { date = 1751241600 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var options = new DateTimeTransformerEstimator.Options
+ {
+ ColumnsToDrop = null,
+ Source = "date",
+ Prefix = "pref_",
+ Data = data
+ };
+
+ var entryOutput = DateTimeTransformerEntrypoint.DateTimeSplit(mlContext.Transforms.GetEnvironment(), options);
+ var output = entryOutput.OutputData;
+
+ // Get the data from the first row and make sure it matches expected
+ var row = output.Preview(1).RowView[0].Values;
+
+ // Assert the data from the first row is what we expect
+ Assert.Equal(2025, row[1].Value); // Year
+ Assert.Equal((byte)6, row[2].Value); // Month
+ Assert.Equal((byte)30, row[3].Value); // Day
+ Assert.Equal((byte)0, row[4].Value); // Hour
+ Assert.Equal((byte)0, row[5].Value); // Minute
+ Assert.Equal((byte)0, row[6].Value); // Second
+ Assert.Equal((byte)0, row[7].Value); // AmPm
+ Assert.Equal((byte)0, row[8].Value); // Hour12
+ Assert.Equal((byte)1, row[9].Value); // DayOfWeek
+ Assert.Equal((byte)91, row[10].Value); // DayOfQuarter
+ Assert.Equal((ushort)180, row[11].Value); // DayOfYear
+ Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth
+ Assert.Equal((byte)2, row[13].Value); // QuarterOfYear
+ Assert.Equal((byte)1, row[14].Value); // HalfOfYear
+ Assert.Equal((byte)27, row[15].Value); // WeekIso
+ Assert.Equal(2025, row[16].Value); // YearIso
+ Assert.Equal("June", row[17].Value.ToString()); // MonthLabel
+ Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel
+ Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel
+ Assert.Equal("", row[20].Value.ToString()); // HolidayName
+ Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff
+
+ Done();
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Transformers/RobustScalerTests.cs b/test/Microsoft.ML.Tests/Transformers/RobustScalerTests.cs
new file mode 100644
index 0000000000..d4e415e9f8
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Transformers/RobustScalerTests.cs
@@ -0,0 +1,430 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.RunTests;
+using System;
+using System.Collections.Generic;
+using Xunit;
+using Xunit.Abstractions;
+using System.Linq;
+
+namespace Microsoft.ML.Tests.Transformers
+{
+ public class RobustScalerTests : TestDataPipeBase
+ {
+ public RobustScalerTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ [Fact]
+ public void TestInvalidType()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = "Invalid Type" }};
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, should error on fit and on GetOutputSchema
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA");
+
+ Assert.Throws(() => pipeline.Fit(data));
+ Assert.Throws(() => pipeline.GetOutputSchema(SchemaShape.Create(data.Schema)));
+
+ Done();
+ }
+
+ [Fact]
+ public void TestNoScale()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = 1f }, new { ColA = 3f }, new { ColA = 5f }, new { ColA = 7f }, new { ColA = 9f } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA", scale: false);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(float), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { -4f, -2f, 0f, 2f, 4f };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ float value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void TestNoScaleNoCenter()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = 1f }, new { ColA = 3f }, new { ColA = 5f }, new { ColA = 7f }, new { ColA = 9f } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA", scale: false, center: false);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(float), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { 1f, 3f, 5f, 7f, 9f };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ float value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void TestFloat()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = 1f }, new { ColA = 3f }, new { ColA = 5f }, new { ColA = 7f }, new { ColA = 9f } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(float), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { -1f, -0.5f, 0f, .5f, 1f };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ float value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void TestInt64()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = 1L }, new { ColA = 3L }, new { ColA = 5L }, new { ColA = 7L }, new { ColA = 9L } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+ var prev = output.Preview();
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(double), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { -1d, -0.5d, 0d, .5d, 1d };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ double value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void TestInt32()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = 1 }, new { ColA = 3 }, new { ColA = 5 }, new { ColA = 7 }, new { ColA = 9 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+ var prev = output.Preview();
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(double), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { -1d, -0.5d, 0d, .5d, 1d };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ double value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void TestInt16()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = (short)1 }, new { ColA = (short)3 }, new { ColA = (short)5 }, new { ColA = (short)7 }, new { ColA = (short)9 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+ var prev = output.Preview();
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(float), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { -1f, -0.5f, 0f, .5f, 1f };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ float value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void TestInt8()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = (sbyte)1 }, new { ColA = (sbyte)3 }, new { ColA = (sbyte)5 }, new { ColA = (sbyte)7 }, new { ColA = (sbyte)9 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+ var prev = output.Preview();
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(float), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { -1f, -0.5f, 0f, .5f, 1f };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ float value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void TestDouble()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = 1d }, new { ColA = 3d }, new { ColA = 5d }, new { ColA = 7d }, new { ColA = 9d } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(double), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { -1d, -0.5d, 0d, .5d, 1d };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ double value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void TestUInt64()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = (ulong)1 }, new { ColA = (ulong)3 }, new { ColA = (ulong)5 }, new { ColA = (ulong)7 }, new { ColA = (ulong)9 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+ var prev = output.Preview();
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(double), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { -1d, -0.5d, 0d, .5d, 1d };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ double value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void TestUInt32()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = (uint)1 }, new { ColA = (uint)3 }, new { ColA = (uint)5 }, new { ColA = (uint)7 }, new { ColA = (uint)9 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+ var prev = output.Preview();
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(double), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { -1d, -0.5d, 0d, .5d, 1d };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ double value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void TestUInt16()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = (ushort)1 }, new { ColA = (ushort)3 }, new { ColA = (ushort)5 }, new { ColA = (ushort)7 }, new { ColA = (ushort)9 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+ var prev = output.Preview();
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(float), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { -1f, -0.5f, 0f, .5f, 1f };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ float value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void TestUInt8()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new [] { new { ColA = (byte)1 }, new { ColA = (byte)3 }, new { ColA = (byte)5 }, new { ColA = (byte)7 }, new { ColA = (byte)9 } };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.RobustScalerTransformer("ColA");
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+ var prev = output.Preview();
+
+ Assert.Single(schema.Where(x => x.IsHidden == false));
+ Assert.Equal(typeof(float), schema["ColA"].Type.RawType);
+
+ var cursor = output.GetRowCursor(schema["ColA"]);
+ var expectedOutput = new[] { -1f, -0.5f, 0f, .5f, 1f };
+ var index = 0;
+ var getter = cursor.GetGetter(schema["ColA"]);
+ float value = default;
+
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ Assert.Equal(expectedOutput[index++], value);
+ }
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs b/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs
new file mode 100644
index 0000000000..a304832957
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs
@@ -0,0 +1,402 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using Microsoft.ML.RunTests;
+using Microsoft.ML.Featurizers;
+using System;
+using Xunit;
+using Xunit.Abstractions;
+using System.Drawing.Printing;
+using System.Linq;
+
+namespace Microsoft.ML.Tests.Transformers
+{
+ public class TimeSeriesImputerTests : TestDataPipeBase
+ {
+ public TimeSeriesImputerTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ private class TimeSeriesTwoGrainInput
+ {
+ public long date;
+ public string grainA;
+ public string grainB;
+ public float data;
+ }
+
+ private class TimeSeriesOneGrainInput
+ {
+ public long date;
+ public string grainA;
+ public int dataA;
+ public float dataB;
+ public uint dataC;
+ }
+
+ private class TimeSeriesOneGrainFloatInput
+ {
+ public long date;
+ public string grainA;
+ public float dataA;
+ }
+
+ [Fact]
+ public void NotImputeOneColumn()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] {
+ new TimeSeriesOneGrainInput() { date = 25, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 },
+ new TimeSeriesOneGrainInput() { date = 26, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 },
+ new TimeSeriesOneGrainInput() { date = 28, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 }
+ };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" }, new string[] { "dataB"});
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+
+ // We always output the same column as the input, plus adding a column saying whether the row was imputed or not.
+ Assert.Equal(6, schema.Count);
+ Assert.Equal("date", schema[0].Name);
+ Assert.Equal("grainA", schema[1].Name);
+ Assert.Equal("dataA", schema[2].Name);
+ Assert.Equal("dataB", schema[3].Name);
+ Assert.Equal("dataC", schema[4].Name);
+ Assert.Equal("IsRowImputed", schema[5].Name);
+
+ // We are imputing 1 row, so total rows should be 4.
+ var preview = output.Preview();
+ Assert.Equal(4, preview.RowView.Length);
+
+ // Row that was imputed should have date of 27
+ Assert.Equal(27L, preview.ColumnView[0].Values[2]);
+
+ // Since we are not imputing data on one column and a row is getting imputed, its value should be default(T)
+ Assert.Equal(default(float), preview.ColumnView[3].Values[2]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void ImputeOnlyOneColumn()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] {
+ new TimeSeriesOneGrainInput() { date = 25, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 },
+ new TimeSeriesOneGrainInput() { date = 26, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 },
+ new TimeSeriesOneGrainInput() { date = 28, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 }
+ };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" }, new string[] { "dataB"}, TimeSeriesImputerEstimator.FilterMode.Include);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+
+ // We always output the same column as the input, plus adding a column saying whether the row was imputed or not.
+ Assert.Equal(6, schema.Count);
+ Assert.Equal("date", schema[0].Name);
+ Assert.Equal("grainA", schema[1].Name);
+ Assert.Equal("dataA", schema[2].Name);
+ Assert.Equal("dataB", schema[3].Name);
+ Assert.Equal("dataC", schema[4].Name);
+ Assert.Equal("IsRowImputed", schema[5].Name);
+
+ // We are imputing 1 row, so total rows should be 4.
+ var preview = output.Preview();
+ Assert.Equal(4, preview.RowView.Length);
+
+ // Row that was imputed should have date of 27
+ Assert.Equal(27L, preview.ColumnView[0].Values[2]);
+
+ // Since we are not imputing data on two columns and a row is getting imputed, its value should be default(T)
+ Assert.Equal(default(int), preview.ColumnView[2].Values[2]);
+ Assert.Equal(default(uint), preview.ColumnView[4].Values[2]);
+
+ // Column that was imputed should have value of 2.0f
+ Assert.Equal(2.0f, preview.ColumnView[3].Values[2]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void Forwardfill()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new TimeSeriesOneGrainFloatInput() { date = 0, grainA = "A", dataA = 2.0f },
+ new TimeSeriesOneGrainFloatInput() { date = 1, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 3, grainA = "A", dataA = 5.0f },
+ new TimeSeriesOneGrainFloatInput() { date = 5, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 7, grainA = "A", dataA = float.NaN }};
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" });
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var prev = output.Preview();
+
+ // Should have 3 original columns + 1 more for IsRowImputed
+ Assert.Equal(4, output.Schema.Count);
+
+ // Imputing rows with dates 2,4,6, so should have length of 8
+ Assert.Equal(8, prev.RowView.Length);
+
+ // Check that imputed rows have the correct dates
+ Assert.Equal(2L, prev.ColumnView[0].Values[2]);
+ Assert.Equal(4L, prev.ColumnView[0].Values[4]);
+ Assert.Equal(6L, prev.ColumnView[0].Values[6]);
+
+ // Make sure grain was propagated correctly
+ Assert.Equal("A", prev.ColumnView[1].Values[2].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[4].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[6].ToString());
+
+ // Make sure forward fill is working as expected. All NA's should be replaced, and imputed rows should have correct values too
+ Assert.Equal(2.0f, prev.ColumnView[2].Values[1]);
+ Assert.Equal(2.0f, prev.ColumnView[2].Values[2]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[4]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[5]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[6]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[7]);
+
+ // Make sure IsRowImputed is true for row 2, 4,6 , false for the rest
+ Assert.Equal(false, prev.ColumnView[3].Values[0]);
+ Assert.Equal(false, prev.ColumnView[3].Values[1]);
+ Assert.Equal(true, prev.ColumnView[3].Values[2]);
+ Assert.Equal(false, prev.ColumnView[3].Values[3]);
+ Assert.Equal(true, prev.ColumnView[3].Values[4]);
+ Assert.Equal(false, prev.ColumnView[3].Values[5]);
+ Assert.Equal(true, prev.ColumnView[3].Values[6]);
+ Assert.Equal(false, prev.ColumnView[3].Values[7]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void EntryPoint()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new { ts = 1L, grain = 1970, c3 = 10, c4 = 19},
+ new { ts = 2L, grain = 1970, c3 = 13, c4 = 12},
+ new { ts = 3L, grain = 1970, c3 = 15, c4 = 16},
+ new { ts = 5L, grain = 1970, c3 = 20, c4 = 19}
+ };
+
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+ TimeSeriesImputerEstimator.Options options = new TimeSeriesImputerEstimator.Options() {
+ TimeSeriesColumn = "ts",
+ GrainColumns = new[] { "grain" },
+ FilterColumns = new[] { "c3", "c4" },
+ FilterMode = TimeSeriesImputerEstimator.FilterMode.Include,
+ ImputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill,
+ Data = data
+ };
+
+ var entryOutput = TimeSeriesTransformerEntrypoint.TimeSeriesImputer(mlContext.Transforms.GetEnvironment(), options);
+ // Build the pipeline, fit, and transform it.
+ var output = entryOutput.OutputData;
+
+ // Get the data from the first row and make sure it matches expected
+ var prev = output.Preview();
+
+ // Should have 4 original columns + 1 more for IsRowImputed
+ Assert.Equal(5, output.Schema.Count);
+
+ // Imputing rows with date 4 so should have length of 5
+ Assert.Equal(5, prev.RowView.Length);
+
+ // Check that imputed rows have the correct dates
+ Assert.Equal(4L, prev.ColumnView[0].Values[3]);
+
+ // Make sure grain was propagated correctly
+ Assert.Equal(1970, prev.ColumnView[1].Values[2]);
+
+ // Make sure forward fill is working as expected. All NA's should be replaced, and imputed rows should have correct values too
+ Assert.Equal(15, prev.ColumnView[2].Values[3]);
+ Assert.Equal(16, prev.ColumnView[3].Values[3]);
+
+ // Make sure IsRowImputed is true for row 4, false for the rest
+ Assert.Equal(false, prev.ColumnView[4].Values[0]);
+ Assert.Equal(false, prev.ColumnView[4].Values[1]);
+ Assert.Equal(false, prev.ColumnView[4].Values[2]);
+ Assert.Equal(true, prev.ColumnView[4].Values[3]);
+ Assert.Equal(false, prev.ColumnView[4].Values[4]);
+
+ Done();
+ }
+
+ [Fact]
+ public void Median()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new TimeSeriesOneGrainFloatInput() { date = 0, grainA = "A", dataA = 2.0f },
+ new TimeSeriesOneGrainFloatInput() { date = 1, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 3, grainA = "A", dataA = 5.0f },
+ new TimeSeriesOneGrainFloatInput() { date = 5, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 7, grainA = "A", dataA = float.NaN }};
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" }, imputeMode: TimeSeriesImputerEstimator.ImputationStrategy.Median, filterColumns: null, suppressTypeErrors: true);
+ var model = pipeline.Fit(data);
+
+ var output = model.Transform(data);
+
+ var prev = output.Preview();
+
+ // Should have 3 original columns + 1 more for IsRowImputed
+ Assert.Equal(4, output.Schema.Count);
+
+ // Imputing rows with dates 2,4,6, so should have length of 8
+ Assert.Equal(8, prev.RowView.Length);
+
+ // Check that imputed rows have the correct dates
+ Assert.Equal(2L, prev.ColumnView[0].Values[2]);
+ Assert.Equal(4L, prev.ColumnView[0].Values[4]);
+ Assert.Equal(6L, prev.ColumnView[0].Values[6]);
+
+ // Make sure grain was propagated correctly
+ Assert.Equal("A", prev.ColumnView[1].Values[2].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[4].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[6].ToString());
+
+ // Make sure Median is working as expected. All NA's should be replaced, and imputed rows should have correct values too
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[1]);
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[2]);
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[4]);
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[5]);
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[6]);
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[7]);
+
+ // Make sure IsRowImputed is true for row 2, 4,6 , false for the rest
+ Assert.Equal(false, prev.ColumnView[3].Values[0]);
+ Assert.Equal(false, prev.ColumnView[3].Values[1]);
+ Assert.Equal(true, prev.ColumnView[3].Values[2]);
+ Assert.Equal(false, prev.ColumnView[3].Values[3]);
+ Assert.Equal(true, prev.ColumnView[3].Values[4]);
+ Assert.Equal(false, prev.ColumnView[3].Values[5]);
+ Assert.Equal(true, prev.ColumnView[3].Values[6]);
+ Assert.Equal(false, prev.ColumnView[3].Values[7]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void Backfill()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new TimeSeriesOneGrainFloatInput() { date = 0, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 1, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 3, grainA = "A", dataA = 5.0f },
+ new TimeSeriesOneGrainFloatInput() { date = 5, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 7, grainA = "A", dataA = 2.0f }};
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" }, TimeSeriesImputerEstimator.ImputationStrategy.BackFill);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var prev = output.Preview();
+
+ // Should have 3 original columns + 1 more for IsRowImputed
+ Assert.Equal(4, output.Schema.Count);
+
+ // Imputing rows with dates 2,4,6, so should have length of 8
+ Assert.Equal(8, prev.RowView.Length);
+
+ // Check that imputed rows have the correct dates
+ Assert.Equal(2L, prev.ColumnView[0].Values[2]);
+ Assert.Equal(4L, prev.ColumnView[0].Values[4]);
+ Assert.Equal(6L, prev.ColumnView[0].Values[6]);
+
+ // Make sure grain was propagated correctly
+ Assert.Equal("A", prev.ColumnView[1].Values[2].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[4].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[6].ToString());
+
+ // Make sure backfill is working as expected. All NA's should be replaced, and imputed rows should have correct values too
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[0]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[1]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[2]);
+ Assert.Equal(2.0f, prev.ColumnView[2].Values[4]);
+ Assert.Equal(2.0f, prev.ColumnView[2].Values[5]);
+ Assert.Equal(2.0f, prev.ColumnView[2].Values[6]);
+
+ // Make sure IsRowImputed is true for row 2, 4,6 , false for the rest
+ Assert.Equal(false, prev.ColumnView[3].Values[0]);
+ Assert.Equal(false, prev.ColumnView[3].Values[1]);
+ Assert.Equal(true, prev.ColumnView[3].Values[2]);
+ Assert.Equal(false, prev.ColumnView[3].Values[3]);
+ Assert.Equal(true, prev.ColumnView[3].Values[4]);
+ Assert.Equal(false, prev.ColumnView[3].Values[5]);
+ Assert.Equal(true, prev.ColumnView[3].Values[6]);
+ Assert.Equal(false, prev.ColumnView[3].Values[7]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [Fact]
+ public void BackfillTwoGrain()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new TimeSeriesTwoGrainInput() { date = 0, grainA = "A", grainB = "A", data = float.NaN},
+ new TimeSeriesTwoGrainInput() { date = 1, grainA = "A", grainB = "A", data = 0.0f},
+ new TimeSeriesTwoGrainInput() { date = 3, grainA = "A", grainB = "B", data = 1.0f},
+ new TimeSeriesTwoGrainInput() { date = 5, grainA = "A", grainB = "B", data = float.NaN},
+ new TimeSeriesTwoGrainInput() { date = 7, grainA = "A", grainB = "B", data = 2.0f }};
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA", "grainB" }, TimeSeriesImputerEstimator.ImputationStrategy.BackFill);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var prev = output.Preview();
+
+ // Should have 4 original columns + 1 more for IsRowImputed
+ Assert.Equal(5, output.Schema.Count);
+
+ // Imputing rows with dates 4,6, so should have length of 8
+ Assert.Equal(7, prev.RowView.Length);
+
+ // Check that imputed rows have the correct dates
+ Assert.Equal(4L, prev.ColumnView[0].Values[3]);
+ Assert.Equal(6L, prev.ColumnView[0].Values[5]);
+
+ // Make sure grain was propagated correctly
+ Assert.Equal("A", prev.ColumnView[1].Values[3].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[5].ToString());
+ Assert.Equal("B", prev.ColumnView[2].Values[3].ToString());
+ Assert.Equal("B", prev.ColumnView[2].Values[5].ToString());
+
+ // Make sure backfill is working as expected. All NA's should be replaced, and imputed rows should have correct values too
+ Assert.Equal(0.0f, prev.ColumnView[3].Values[0]);
+ Assert.Equal(2.0f, prev.ColumnView[3].Values[3]);
+ Assert.Equal(2.0f, prev.ColumnView[3].Values[4]);
+ Assert.Equal(2.0f, prev.ColumnView[3].Values[5]);
+
+ // Make sure IsRowImputed is true for row 4,6 false for the rest
+ Assert.Equal(false, prev.ColumnView[4].Values[0]);
+ Assert.Equal(false, prev.ColumnView[4].Values[1]);
+ Assert.Equal(false, prev.ColumnView[4].Values[2]);
+ Assert.Equal(true, prev.ColumnView[4].Values[3]);
+ Assert.Equal(false, prev.ColumnView[4].Values[4]);
+ Assert.Equal(true, prev.ColumnView[4].Values[5]);
+ Assert.Equal(false, prev.ColumnView[4].Values[6]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Transformers/ToStringTransformerTests.cs b/test/Microsoft.ML.Tests/Transformers/ToStringTransformerTests.cs
new file mode 100644
index 0000000000..7f5c4e0f0f
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Transformers/ToStringTransformerTests.cs
@@ -0,0 +1,353 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.Data;
+using Microsoft.ML.RunTests;
+using Microsoft.ML.Transforms;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Tests.Transformers
+{
+ public class ToStringTransformerTests : TestDataPipeBase
+ {
+ private class BoolInput
+ {
+ public bool data;
+ }
+
+ private class ByteInput
+ {
+ public byte data;
+ }
+
+ private class SByteInput
+ {
+ public sbyte data;
+ }
+
+ private class ShortInput
+ {
+ public short data;
+ }
+
+ private class UShortInput
+ {
+ public ushort data;
+ }
+
+ private class IntInput
+ {
+ public int data;
+ }
+
+ private class UIntInput
+ {
+ public uint data;
+ }
+
+ private class LongInput
+ {
+ public long data;
+ }
+
+ private class ULongInput
+ {
+ public ulong data;
+ }
+
+ private class FloatInput
+ {
+ public float data;
+ }
+
+ private class DoubleInput
+ {
+ public double data;
+ }
+ private class StringInput
+ {
+ public string data;
+ }
+
+ public ToStringTransformerTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ [Fact]
+ public void TestBool()
+ {
+ MLContext mlContext = new MLContext(1);
+ var data = new[] { new BoolInput() { data = true }, new BoolInput() { data = false } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("output", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ Assert.Equal("True", rows[0].Values[1].Value.ToString());
+ Assert.Equal("False", rows[1].Values[1].Value.ToString());
+ Done();
+ }
+
+
+ [Fact]
+ public void TestByte()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new ByteInput() { data = byte.MinValue }, new ByteInput() { data = byte.MaxValue } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("output", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ Assert.Equal("0", rows[0].Values[1].Value.ToString());
+ Assert.Equal("255", rows[1].Values[1].Value.ToString());
+ Done();
+ }
+
+ [Fact]
+ public void TestSByte()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new SByteInput() { data = sbyte.MinValue }, new SByteInput() { data = sbyte.MaxValue } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("output", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ Assert.Equal("-128", rows[0].Values[1].Value.ToString());
+ Assert.Equal("127", rows[1].Values[1].Value.ToString());
+ Done();
+ }
+
+ [Fact]
+ public void TestUShort()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new UShortInput() { data = ushort.MinValue }, new UShortInput() { data = ushort.MaxValue } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("output", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ Assert.Equal("0", rows[0].Values[1].Value.ToString());
+ Assert.Equal("65535", rows[1].Values[1].Value.ToString());
+ Done();
+ }
+
+ [Fact]
+ public void TestInt()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new IntInput() { data = int.MinValue }, new IntInput() { data = int.MaxValue } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("output", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ Assert.Equal("-2147483648", rows[0].Values[1].Value.ToString());
+ Assert.Equal("2147483647", rows[1].Values[1].Value.ToString());
+ Done();
+ }
+
+ [Fact]
+ public void TestUInt()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new UIntInput() { data = uint.MinValue }, new UIntInput() { data = uint.MaxValue } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("output", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ Assert.Equal("0", rows[0].Values[1].Value.ToString());
+ Assert.Equal("4294967295", rows[1].Values[1].Value.ToString());
+ Done();
+ }
+
+ [Fact]
+ public void TestLong()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new LongInput() { data = long.MinValue }, new LongInput() { data = long.MaxValue } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("output", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ Assert.Equal("-9223372036854775808", rows[0].Values[1].Value.ToString());
+ Assert.Equal("9223372036854775807", rows[1].Values[1].Value.ToString());
+ Done();
+ }
+
+ [Fact]
+ public void TestULong()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new ULongInput() { data = ulong.MinValue }, new ULongInput() { data = ulong.MaxValue } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("output", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ Assert.Equal("0", rows[0].Values[1].Value.ToString());
+ Assert.Equal("18446744073709551615", rows[1].Values[1].Value.ToString());
+ Done();
+ }
+
+ [Fact]
+ public void TestFloat()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new FloatInput() { data = float.MinValue}, new FloatInput() { data = float.MaxValue }, new FloatInput() { data = float.NaN } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("output", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ Assert.Equal("-340282346638528859811704183484516925440.000000", rows[0].Values[1].Value.ToString());
+ Assert.Equal("340282346638528859811704183484516925440.000000", rows[1].Values[1].Value.ToString());
+ Done();
+ }
+
+ [Fact]
+ public void TestShort()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new ShortInput() { data = short.MinValue }, new ShortInput() { data = short.MaxValue } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("output", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ Assert.Equal("-32768", rows[0].Values[1].Value.ToString());
+ Assert.Equal("32767", rows[1].Values[1].Value.ToString());
+ Done();
+ }
+
+ [Fact]
+ public void TestDouble()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new DoubleInput() { data = double.MinValue}, new DoubleInput() { data = double.MaxValue }, new DoubleInput() { data = double.NaN } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("data.out", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ // Since we can't set the precision yet on the Native side and it returns the whole string value, only checking the first 10 places.
+ Assert.Equal(double.MinValue.ToString("F10").Substring(0,10), rows[0].Values[1].Value.ToString().Substring(0, 10));
+ Assert.Equal(double.MaxValue.ToString("F10").Substring(0, 10), rows[1].Values[1].Value.ToString().Substring(0, 10));
+
+ Done();
+ }
+
+ [Fact]
+ public void TestString()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new StringInput() { data = ""}, new StringInput() { data = "Long Dummy String Value" } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+ var pipeline = mlContext.Transforms.ToStringTransformer("output", "data");
+
+ TestEstimatorCore(pipeline, input);
+
+ var model = pipeline.Fit(input);
+ var output = model.Transform(input);
+ var rows = output.Preview().RowView;
+
+ Assert.Equal("", rows[0].Values[1].Value.ToString());
+ Assert.Equal("Long Dummy String Value", rows[1].Values[1].Value.ToString());
+ Done();
+
+
+
+ }
+
+ [Fact]
+ public void TestEntryPoint()
+ {
+ MLContext mlContext = new MLContext(1);
+
+ var data = new[] { new StringInput() { data = ""}, new StringInput() { data = "Long Dummy String Value" } };
+ IDataView input = mlContext.Data.LoadFromEnumerable(data);
+
+ var options = new ToStringTransformerEstimator.Options()
+ {
+ Columns = new ToStringTransformerEstimator.Column[1]
+ {
+ new ToStringTransformerEstimator.Column()
+ {
+ Name = "data"
+ }
+ },
+ Data = input
+ };
+
+ var output = ToStringTransformerEntrypoint.ToString(mlContext.Transforms.GetEnvironment(), options);
+
+ var rows = output.OutputData.Preview().RowView;
+
+ Assert.Equal("", rows[0].Values[1].Value.ToString());
+ Assert.Equal("Long Dummy String Value", rows[1].Values[1].Value.ToString());
+ Done();
+ }
+ }
+}
diff --git a/test/data/dates.txt b/test/data/dates.txt
new file mode 100644
index 0000000000..510371af4e
Binary files /dev/null and b/test/data/dates.txt differ