Skip to content

Commit eabf20e

Browse files
committed
Fixed IDataView so each has its own transformer
Updates with the correct native NuGet package, including the signed native NuGet package. Null fix, added tests, removed DLL check.
1 parent 198c95c commit eabf20e

File tree

7 files changed

+266
-198
lines changed

7 files changed

+266
-198
lines changed

src/Microsoft.ML.Featurizers/CategoryImputer.cs

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,6 @@ internal static CategoryImputerTransformerEstimator Create(IHostEnvironment env,
8888

8989
internal CategoryImputerTransformerEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName)
9090
{
91-
try
92-
{
93-
CheckIfDllExists(IntPtr.Zero, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
94-
}
95-
catch
96-
{
97-
throw new Exception("Featurizers library not found");
98-
}
99-
10091
Contracts.CheckValue(env, nameof(env));
10192
_host = env.Register(nameof(CategoryImputerTransformerEstimator));
10293
_options = new Options
@@ -107,15 +98,6 @@ internal CategoryImputerTransformerEstimator(IHostEnvironment env, string output
10798

10899
internal CategoryImputerTransformerEstimator(IHostEnvironment env, Options options)
109100
{
110-
try
111-
{
112-
CheckIfDllExists(IntPtr.Zero, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
113-
}
114-
catch
115-
{
116-
throw new Exception("Featurizers library not found");
117-
}
118-
119101
Contracts.CheckValue(env, nameof(env));
120102
_host = env.Register(nameof(CategoryImputerTransformerEstimator));
121103

@@ -169,15 +151,6 @@ public sealed class CategoryImputerTransformer : RowToRowTransformerBase, IDispo
169151
internal CategoryImputerTransformer(IHostEnvironment host, IDataView input, CategoryImputerTransformerEstimator.Column[] columns) :
170152
base(host.Register(nameof(CategoryImputerTransformerEstimator)))
171153
{
172-
try
173-
{
174-
CheckIfDllExists(IntPtr.Zero, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
175-
}
176-
catch
177-
{
178-
throw new Exception("Featurizers library not found");
179-
}
180-
181154
var schema = input.Schema;
182155

183156
_columns = columns.Select(x => TypedColumn.CreateTypedColumn(x.Name, x.Source, schema[x.Source].Type.RawType.ToString())).ToArray();
@@ -191,15 +164,6 @@ internal CategoryImputerTransformer(IHostEnvironment host, IDataView input, Cate
191164
internal CategoryImputerTransformer(IHostEnvironment host, ModelLoadContext ctx) :
192165
base(host.Register(nameof(CategoryImputerTransformer)))
193166
{
194-
try
195-
{
196-
CheckIfDllExists(IntPtr.Zero, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
197-
}
198-
catch
199-
{
200-
throw new Exception("Featurizers library not found");
201-
}
202-
203167
Host.CheckValue(ctx, nameof(ctx));
204168
ctx.CheckAtModel(GetVersionInfo());
205169
// *** Binary format ***

src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,14 @@ internal TimeSeriesImputerEstimator(IHostEnvironment env, string timeSeriesColum
107107
_host = Contracts.CheckRef(env, nameof(env)).Register("TimeSeriesImputerEstimator");
108108
_host.CheckValue(timeSeriesColumn, nameof(timeSeriesColumn), "TimePoint column should not be null.");
109109
_host.CheckNonEmpty(grainColumns, nameof(grainColumns), "Need at least one grain column.");
110-
if(filterMode != FilterMode.NoFilter)
110+
if (filterMode == FilterMode.Include)
111111
_host.CheckNonEmpty(filterColumns, nameof(filterColumns), "Need at least 1 filter column if a FilterMode is specified");
112112

113113
_options = new Options
114114
{
115115
TimeSeriesColumn = timeSeriesColumn,
116116
GrainColumns = grainColumns,
117-
FilterColumns = filterColumns,
117+
FilterColumns = filterColumns == null ? new string[] { } : filterColumns,
118118
FilterMode = filterMode,
119119
ImputeMode = imputeMode,
120120
SupressTypeErrors = supressTypeErrors
@@ -262,7 +262,7 @@ private unsafe TransformerEstimatorSafeHandle CreateTransformerFromSavedData(byt
262262
if (!result)
263263
throw new Exception(CommonExtensions.GetErrorDetailsAndFreeNativeMemory(errorHandle));
264264

265-
return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
265+
return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
266266
}
267267
}
268268

@@ -433,6 +433,8 @@ internal TimeSeriesImputerDataView MakeDataTransform(IDataView input)
433433
return new TimeSeriesImputerDataView(_host, input, _timeSeriesColumn, _grainColumns, _dataColumns, _allColumnNames, this);
434434
}
435435

436+
internal TransformerEstimatorSafeHandle CloneTransformer() => CreateTransformerFromSavedData(CreateTransformerSaveData());
437+
436438
public void Dispose()
437439
{
438440
if (!TransformerHandle.IsClosed)

src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ internal TypedColumn(DataViewSchema.Column column, bool isImputed, SharedColumnS
4949
}
5050

5151
internal abstract Delegate GetGetter();
52-
internal abstract void InitializeGetter(DataViewRowCursor cursor, TimeSeriesImputerTransformer transformerParent, string timeSeriesColumn,
52+
internal abstract void InitializeGetter(DataViewRowCursor cursor, TimeSeriesImputerTransformer.TransformerEstimatorSafeHandle transformerParent, string timeSeriesColumn,
5353
string[] grainColumns, string[] dataColumns, string[] allColumnNames, Dictionary<string, TypedColumn> allColumns);
5454

5555
internal abstract TypeId GetTypeId();
@@ -129,7 +129,7 @@ internal override Delegate GetGetter()
129129
return _getter;
130130
}
131131

132-
internal override unsafe void InitializeGetter(DataViewRowCursor cursor, TimeSeriesImputerTransformer transformerParent, string timeSeriesColumn,
132+
internal override unsafe void InitializeGetter(DataViewRowCursor cursor, TimeSeriesImputerTransformer.TransformerEstimatorSafeHandle transformer, string timeSeriesColumn,
133133
string[] grainColumns, string[] dataColumns, string[] allImputedColumnNames, Dictionary<string, TypedColumn> allColumns)
134134
{
135135
if (Column.Name != IsRowImputedColumnName)
@@ -154,7 +154,7 @@ internal override unsafe void InitializeGetter(DataViewRowCursor cursor, TimeSer
154154
fixed (byte* bufferPointer = SharedState.ColumnBuffer)
155155
{
156156
var binaryArchiveData = new NativeBinaryArchiveData() { Data = bufferPointer, DataSize = new IntPtr(bufferLength) };
157-
success = TransformDataNative(transformerParent.TransformerHandle, binaryArchiveData, out outputData, out outputDataSize, out errorHandle);
157+
success = TransformDataNative(transformer, binaryArchiveData, out outputData, out outputDataSize, out errorHandle);
158158
if (!success)
159159
throw new Exception(CommonExtensions.GetErrorDetailsAndFreeNativeMemory(errorHandle));
160160
}
@@ -633,7 +633,7 @@ protected override bool ReleaseHandle()
633633
private readonly string[] _allImputedColumnNames;
634634
private readonly DataViewSchema _schema;
635635

636-
public TimeSeriesImputerDataView(IHostEnvironment env, IDataView input, string timeSeriesColumn, string[] grainColumns, string[] dataColumns, string[] allColumnNames, TimeSeriesImputerTransformer parent)
636+
internal TimeSeriesImputerDataView(IHostEnvironment env, IDataView input, string timeSeriesColumn, string[] grainColumns, string[] dataColumns, string[] allColumnNames, TimeSeriesImputerTransformer parent)
637637
{
638638
_host = env;
639639
_source = input;
@@ -664,7 +664,7 @@ public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columns
664664

665665
// Column WasColumnImputed is new, don't get its value from source
666666
var input = _source.GetRowCursorForAllColumns();//.GetRowCursor(columnsNeeded.Where(x => x.Name != IsRowImputedColumnName));
667-
return new Cursor(_host, input, _parent, _timeSeriesColumn, _grainColumns, _dataColumns, _allImputedColumnNames, _schema);
667+
return new Cursor(_host, input, _parent.CloneTransformer(), _timeSeriesColumn, _grainColumns, _dataColumns, _allImputedColumnNames, _schema);
668668
}
669669

670670
// Can't use parallel cursors so this defaults to calling non-parallel version
@@ -687,9 +687,9 @@ private sealed class Cursor : DataViewRowCursor
687687
private bool _isGood;
688688
private readonly Dictionary<string, TypedColumn> _allColumns;
689689
private readonly DataViewSchema _schema;
690-
private readonly TimeSeriesImputerTransformer _transformerParent;
690+
private readonly TimeSeriesImputerTransformer.TransformerEstimatorSafeHandle _transformer;
691691

692-
public Cursor(IChannelProvider provider, DataViewRowCursor input, TimeSeriesImputerTransformer transformerParent, string timeSeriesColumn,
692+
public Cursor(IChannelProvider provider, DataViewRowCursor input, TimeSeriesImputerTransformer.TransformerEstimatorSafeHandle transformer, string timeSeriesColumn,
693693
string[] grainColumns, string[] dataColumns, string[] allImputedColumnNames, DataViewSchema schema)
694694
{
695695
_ch = provider;
@@ -699,7 +699,7 @@ public Cursor(IChannelProvider provider, DataViewRowCursor input, TimeSeriesImpu
699699
var length = input.Schema.Count;
700700
_position = -1;
701701
_schema = schema;
702-
_transformerParent = transformerParent;
702+
_transformer = transformer;
703703

704704
var sharedState = new SharedColumnState()
705705
{
@@ -712,7 +712,7 @@ public Cursor(IChannelProvider provider, DataViewRowCursor input, TimeSeriesImpu
712712

713713
foreach (var column in _allColumns.Values)
714714
{
715-
column.InitializeGetter(_input, transformerParent, timeSeriesColumn, grainColumns, dataColumns, allImputedColumnNames, _allColumns);
715+
column.InitializeGetter(_input, transformer, timeSeriesColumn, grainColumns, dataColumns, allImputedColumnNames, _allColumns);
716716
}
717717
}
718718

@@ -736,6 +736,12 @@ public override bool IsColumnActive(DataViewSchema.Column column)
736736
return true;
737737
}
738738

739+
protected override void Dispose(bool disposing)
740+
{
741+
if (!_transformer.IsClosed)
742+
_transformer.Close();
743+
}
744+
739745
/// <summary>
740746
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
741747
/// This throws if the column is not active in this row, or if the type

src/Microsoft.ML.Featurizers/ToStringTransformer.cs

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,6 @@ internal static ToStringTransformerEstimator Create(IHostEnvironment env, params
130130

131131
internal ToStringTransformerEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName)
132132
{
133-
try
134-
{
135-
CheckIfDllExists(IntPtr.Zero, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
136-
}
137-
catch
138-
{
139-
throw new Exception("Featurizers library not found");
140-
}
141-
142133
Contracts.CheckValue(env, nameof(env));
143134
_host = env.Register(nameof(ToStringTransformerEstimator));
144135

@@ -150,15 +141,6 @@ internal ToStringTransformerEstimator(IHostEnvironment env, string outputColumnN
150141

151142
internal ToStringTransformerEstimator(IHostEnvironment env, Options options)
152143
{
153-
try
154-
{
155-
CheckIfDllExists(IntPtr.Zero, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
156-
}
157-
catch
158-
{
159-
throw new Exception("Featurizers library not found");
160-
}
161-
162144
Contracts.CheckValue(env, nameof(env));
163145
_host = env.Register(nameof(ToStringTransformerEstimator));
164146

@@ -211,15 +193,6 @@ public sealed class ToStringTransformer : RowToRowTransformerBase, IDisposable
211193
internal ToStringTransformer(IHostEnvironment host, ToStringTransformerEstimator.Column[] columns, IDataView input) :
212194
base(host.Register(nameof(ToStringTransformer)))
213195
{
214-
try
215-
{
216-
CheckIfDllExists(IntPtr.Zero, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
217-
}
218-
catch
219-
{
220-
throw new Exception("Featurizers library not found");
221-
}
222-
223196
var schema = input.Schema;
224197

225198
_columns = columns.Select(x => TypedColumn.CreateTypedColumn(x.Name, x.Source, schema[x.Source].Type.RawType.ToString())).ToArray();
@@ -234,15 +207,6 @@ internal ToStringTransformer(IHostEnvironment host, ToStringTransformerEstimator
234207
internal ToStringTransformer(IHostEnvironment host, ModelLoadContext ctx) :
235208
base(host.Register(nameof(ToStringTransformer)))
236209
{
237-
try
238-
{
239-
CheckIfDllExists(IntPtr.Zero, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
240-
}
241-
catch
242-
{
243-
throw new Exception("Featurizers library not found");
244-
}
245-
246210
Host.CheckValue(ctx, nameof(ctx));
247211
ctx.CheckAtModel(GetVersionInfo());
248212
// *** Binary format ***
Binary file not shown.

test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -257,17 +257,6 @@ public void HolidayTest()
257257
Assert.Equal("Christmas Day", row[20].Value.ToString()); // HolidayName
258258
Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff
259259

260-
mlContext.Model.Save(model, data.Schema, "D:/Repos/machinelearning/temp-nuget-folder/model.zip");
261-
var loadedModel = mlContext.Model.Load("D:/Repos/machinelearning/temp-nuget-folder/model.zip", out DataViewSchema loadedSchema);
262-
263-
var loadedOutput = loadedModel.Transform(data);
264-
var loadedRow = loadedOutput.Preview(1).RowView[0].Values;
265-
266-
// Assert the data from the first row from loaded data for holidays is what we expect
267-
Assert.Equal("Christmas Day", loadedRow[20].Value.ToString()); // HolidayName
268-
Assert.Equal((byte)0, loadedRow[21].Value); // IsPaidTimeOff
269-
270-
271260
TestEstimatorCore(pipeline, data);
272261
Done();
273262
}

0 commit comments

Comments
 (0)