Skip to content

Commit 4d6434f

Browse files
committed
adding in TimeSeriesImputer
1 parent 4cd8763 commit 4d6434f

File tree

6 files changed

+30
-694
lines changed

6 files changed

+30
-694
lines changed

src/Microsoft.ML.AutoMLFeaturizers/CategoryImputer.cs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -378,15 +378,6 @@ protected override bool ReleaseHandle()
378378

379379
#endregion
380380

381-
#region FitResult
382-
383-
internal enum FitResult : byte
384-
{
385-
Complete = 0, Continue, ResetAndContinue
386-
}
387-
388-
#endregion
389-
390381
#region ColumnInfo
391382

392383
// REVIEW: Since we can't do overloading on the native side due to the C style exports,

src/Microsoft.ML.AutoMLFeaturizers/Common.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace Microsoft.ML.AutoMLFeaturizers
1313

1414
internal enum FitResult : byte
1515
{
16-
Complete = 0, Continue, ResetAndContinue
16+
Complete = 1, Continue, ResetAndContinue
1717
}
1818

1919
internal enum TypeId : byte

src/Microsoft.ML.AutoMLFeaturizers/DateTimeTransformer.cs

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ namespace Microsoft.ML.AutoMLFeaturizers
3232
public static class DateTimeTransformerExtensionClass
3333
{
3434
public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, params DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop)
35-
=> DateTimeTransformerEstimator.Create(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop);
35+
=> new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop);
36+
37+
public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop = null, DateTimeTransformerEstimator.Countries country = DateTimeTransformerEstimator.Countries.None)
38+
=> new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop, country);
3639

3740
#region ColumnsProduced static extensions
3841

@@ -77,21 +80,19 @@ internal sealed class Options: TransformInputBase
7780

7881
[Argument(ArgumentType.MultipleUnique, HelpText = "Columns to drop after the DateTime Expansion", Name = "ColumnsToDrop", ShortName = "drop", SortOrder = 3)]
7982
public ColumnsProduced[] ColumnsToDrop;
83+
84+
[Argument(ArgumentType.AtMostOnce, HelpText = "Country to get holidays for. Defaults to none if not passed", Name = "Country", ShortName = "ctry", SortOrder = 4)]
85+
public Countries Country = Countries.None;
8086
}
8187

8288
#endregion
8389

84-
internal static DateTimeTransformerEstimator Create(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop)
85-
{
86-
return new DateTimeTransformerEstimator(env, inputColumnName, columnPrefix, columnsToDrop);
87-
}
88-
8990
// Using this to confirm the DLL exists. If it does, the call will just return false since no parameters are being passed.
9091
// Once we have a binary dependency on the dll we can remove this code.
9192
[DllImport("Featurizers", EntryPoint = "GetErrorInfoString"), SuppressUnmanagedCodeSecurity]
9293
private static extern bool CheckIfDllExists(IntPtr error, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
9394

94-
public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop)
95+
public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop, Countries country = Countries.None)
9596
{
9697
try
9798
{
@@ -110,7 +111,8 @@ public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName
110111
{
111112
Source = inputColumnName,
112113
Prefix = columnPrefix,
113-
ColumnsToDrop = columnsToDrop
114+
ColumnsToDrop = columnsToDrop == null ? Array.Empty<ColumnsProduced>() : columnsToDrop,
115+
Country = country
114116
};
115117
}
116118

@@ -129,6 +131,7 @@ internal DateTimeTransformerEstimator(IHostEnvironment env, Options options)
129131
_host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator");
130132

131133
_options = options;
134+
_options.ColumnsToDrop = _options.ColumnsToDrop == null ? Array.Empty<ColumnsProduced>() : _options.ColumnsToDrop;
132135
}
133136

134137
public DateTimeTransformer Fit(IDataView input)
@@ -151,14 +154,23 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
151154
return new SchemaShape(columns.Values);
152155
}
153156

154-
#region Column Enums
157+
#region Enums
155158
public enum ColumnsProduced : byte
156159
{
157-
Year = 0, Month, Day, Hour, Minute, Second, AmPm, Hour12, DayOfWeek, DayOfQuarter, DayOfYear,
160+
Year = 1, Month, Day, Hour, Minute, Second, AmPm, Hour12, DayOfWeek, DayOfQuarter, DayOfYear,
158161
WeekOfMonth, QuarterOfYear, HalfOfYear, WeekIso, YearIso, MonthLabel, AmPmLabel, DayOfWeekLabel,
159162
HolidayName, IsPaidTimeOff
160163
};
161164

165+
public enum Countries : byte
166+
{
167+
None = 1,
168+
Argentina, Australia, Austria, Belarus, Belgium, Brazil, Canada, Colombia, Croatia, Czech, Denmark,
169+
England, Finland, France, Germany, Hungary, India, Ireland, IsleofMan, Italy, Japan, Mexico, Netherlands,
170+
NewZealand, NorthernIreland, Norway, Poland, Portugal, Scotland, Slovenia, SouthAfrica, Spain, Sweden, Switzerland,
171+
Ukraine, UnitedKingdom, UnitedStates, Wales
172+
}
173+
162174
#endregion
163175
}
164176

@@ -473,14 +485,17 @@ internal unsafe TimePoint(byte* rawData)
473485
private static unsafe string GetStringFromPointer(ref byte* rawData, int intPtrSize)
474486
{
475487
byte[] buffer;
476-
byte* temp = rawData + intPtrSize;
477-
long tempSize = *(long*)(temp);
478-
int itempSize = *(int*)(temp);
479488
if (intPtrSize == 4) // 32 bit machine
480489
buffer = new byte[*(uint*)(rawData + intPtrSize)];
481490
else // 64 bit machine
482491
buffer = new byte[*(ulong*)(rawData + intPtrSize)];
483492

493+
if (buffer.Length == 0)
494+
{
495+
rawData += intPtrSize * 2;
496+
return string.Empty;
497+
}
498+
484499
Marshal.Copy(new IntPtr(*(int**)rawData), buffer, 0, buffer.Length);
485500
rawData += intPtrSize * 2;
486501

@@ -802,7 +817,7 @@ internal static class DateTimeTransformerEntrypoint
802817
public static CommonOutputs.TransformOutput DateTimeSplit(IHostEnvironment env, DateTimeTransformerEstimator.Options input)
803818
{
804819
var h = EntryPointUtils.CheckArgsAndCreateHost(env, DateTimeTransformer.ShortName, input);
805-
var xf = DateTimeTransformerEstimator.Create(h, input.Source, input.Prefix, input.ColumnsToDrop).Fit(input.Data).Transform(input.Data);
820+
var xf = new DateTimeTransformerEstimator(h, input).Fit(input.Data).Transform(input.Data);
806821
return new CommonOutputs.TransformOutput()
807822
{
808823
Model = new TransformModelImpl(h, xf, input.Data),

0 commit comments

Comments
 (0)