Skip to content

Commit e5d407c

Browse files
committed
updates with correct native nuget package:
1 parent 39dcb9b commit e5d407c

File tree

3 files changed

+63
-18
lines changed

3 files changed

+63
-18
lines changed

src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,14 @@ internal TimeSeriesImputerEstimator(IHostEnvironment env, string timeSeriesColum
107107
_host = Contracts.CheckRef(env, nameof(env)).Register("TimeSeriesImputerEstimator");
108108
_host.CheckValue(timeSeriesColumn, nameof(timeSeriesColumn), "TimePoint column should not be null.");
109109
_host.CheckNonEmpty(grainColumns, nameof(grainColumns), "Need at least one grain column.");
110-
if(filterMode != FilterMode.NoFilter)
110+
if(filterMode == FilterMode.Include)
111111
_host.CheckNonEmpty(filterColumns, nameof(filterColumns), "Need at least 1 filter column if a FilterMode is specified");
112112

113113
_options = new Options
114114
{
115115
TimeSeriesColumn = timeSeriesColumn,
116116
GrainColumns = grainColumns,
117-
FilterColumns = filterColumns,
117+
FilterColumns = filterColumns == null? new string[] { } : filterColumns,
118118
FilterMode = filterMode,
119119
ImputeMode = imputeMode,
120120
SupressTypeErrors = supressTypeErrors
Binary file not shown.

test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ public TimeSeriesImputerTests(ITestOutputHelper output) : base(output)
1919
{
2020
}
2121

22-
//private class TimeSeriesTwoGrainInput
23-
//{
24-
// public long date;
25-
// public string grainA;
26-
// public string grainB;
27-
// public int dataA;
28-
// public float dataB;
29-
// public uint dataC;
30-
//}
22+
private class TimeSeriesTwoGrainInput
23+
{
24+
public long date;
25+
public string grainA;
26+
public string grainB;
27+
public int dataA;
28+
public float dataB;
29+
public uint dataC;
30+
}
3131

3232
private class TimeSeriesOneGrainInput
3333
{
@@ -138,18 +138,17 @@ public void Median()
138138
var data = mlContext.Data.LoadFromEnumerable(dataList);
139139

140140
// Build the pipeline, fit, and transform it.
141-
var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" }, imputeMode: TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill);
141+
var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA" }, imputeMode: TimeSeriesImputerEstimator.ImputationStrategy.Median, filterColumns: null, suppressTypeErrors: true);
142142
var model = pipeline.Fit(data);
143-
mlContext.Model.Save(model, data.Schema, "D:/Repos/machinelearning/temp-nuget-folder/SaveNoTransform.zip");
144-
var loadedModel = mlContext.Model.Load("D:/Repos/machinelearning/temp-nuget-folder/SaveNoTransform.zip", out DataViewSchema loadedSchema);
143+
144+
dataList = new[] { new TimeSeriesOneGrainFloatInput() { date = 0, grainA = "B", dataA = float.NaN },
145+
new TimeSeriesOneGrainFloatInput() { date = 1, grainA = "B", dataA = 2 }};
146+
147+
data = mlContext.Data.LoadFromEnumerable(dataList);
145148

146149
var output = model.Transform(data);
147-
var loadedOutput = loadedModel.Transform(data);
148150

149151
var prev = output.Preview();
150-
var loadedPrev = loadedOutput.Preview();
151-
152-
Assert.Equal(prev.ColumnView.Length, loadedPrev.ColumnView.Length);
153152

154153
TestEstimatorCore(pipeline, data);
155154
Done();
@@ -177,5 +176,51 @@ public void Backfill()
177176
TestEstimatorCore(pipeline, data);
178177
Done();
179178
}
179+
180+
[Fact]
181+
public void BackfillTwoGrain()
182+
{
183+
MLContext mlContext = new MLContext(1);
184+
var dataList = new[] { new TimeSeriesTwoGrainInput() { date = 0, grainA = "A", grainB = "B", dataA = 0, dataB = 0.0f, dataC = 0 },
185+
new TimeSeriesTwoGrainInput() { date = 1, grainA = "A", grainB = "B", dataA = 0, dataB = float.NaN, dataC = 0 },
186+
new TimeSeriesTwoGrainInput() { date = 3, grainA = "A", grainB = "B", dataA = 0, dataB = 1.0f, dataC = 0 },
187+
new TimeSeriesTwoGrainInput() { date = 5, grainA = "A", grainB = "B", dataA = 0, dataB = float.NaN, dataC = 0 },
188+
new TimeSeriesTwoGrainInput() { date = 7, grainA = "A", grainB = "B", dataA = 0, dataB = 2.0f, dataC = 0 }};
189+
var data = mlContext.Data.LoadFromEnumerable(dataList);
190+
191+
// Build the pipeline, fit, and transform it.
192+
var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA", "grainB" }, TimeSeriesImputerEstimator.ImputationStrategy.BackFill);
193+
var model = pipeline.Fit(data);
194+
var output = model.Transform(data);
195+
var schema = output.Schema;
196+
var outputSchema = model.GetOutputSchema(data.Schema);
197+
var prev = output.Preview();
198+
199+
TestEstimatorCore(pipeline, data);
200+
Done();
201+
}
202+
203+
[Fact]
204+
public void ForwardFillTwoGrain()
205+
{
206+
MLContext mlContext = new MLContext(1);
207+
var dataList = new[] { new TimeSeriesTwoGrainInput() { date = 0, grainA = "A", grainB = "B", dataA = 0, dataB = 0.0f, dataC = 0 },
208+
new TimeSeriesTwoGrainInput() { date = 1, grainA = "A", grainB = "B", dataA = 0, dataB = float.NaN, dataC = 0 },
209+
new TimeSeriesTwoGrainInput() { date = 3, grainA = "A", grainB = "B", dataA = 0, dataB = 1.0f, dataC = 0 },
210+
new TimeSeriesTwoGrainInput() { date = 5, grainA = "A", grainB = "B", dataA = 0, dataB = float.NaN, dataC = 0 },
211+
new TimeSeriesTwoGrainInput() { date = 7, grainA = "A", grainB = "B", dataA = 0, dataB = 2.0f, dataC = 0 }};
212+
var data = mlContext.Data.LoadFromEnumerable(dataList);
213+
214+
// Build the pipeline, fit, and transform it.
215+
var pipeline = mlContext.Transforms.TimeSeriesImputer("date", new string[] { "grainA", "grainB" }, TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill);
216+
var model = pipeline.Fit(data);
217+
var output = model.Transform(data);
218+
var schema = output.Schema;
219+
var outputSchema = model.GetOutputSchema(data.Schema);
220+
var prev = output.Preview();
221+
222+
TestEstimatorCore(pipeline, data);
223+
Done();
224+
}
180225
}
181226
}

0 commit comments

Comments
 (0)