Skip to content

Commit a4af0ec

Browse files
mengaimsklausmhyaeldMS
authored
Add SrCnn entire API by implementing function (#5135)
* Draft PR for SrCnn batch detection API interface (#1) * POC Batch transform * SrCnn batch interface * Removed comment * Handled some APIreview comments. * Handled other review comments. * Resolved review comments. Added sample. Co-authored-by: Yael Dekel <[email protected]> * Implement SrCnn entire API by function * Fix bugs and add test * Resolve comments * Change names and add documentation * Handling review comments * Resolve the array allocating issue * Move modeler initializing to CreateBatch and other minor fix. * Fix 3 remaining comments * Fixed code analysis issue. * Fixed minor comments Co-authored-by: klausmh <[email protected]> Co-authored-by: Yael Dekel <[email protected]>
1 parent d58e8d1 commit a4af0ec

File tree

5 files changed

+1238
-5
lines changed

5 files changed

+1238
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using Microsoft.ML;
4+
using Microsoft.ML.Data;
5+
using Microsoft.ML.TimeSeries;
6+
7+
namespace Samples.Dynamic
8+
{
9+
public static class DetectEntireAnomalyBySrCnn
10+
{
11+
public static void Example()
12+
{
13+
// Create a new ML context, for ML.NET operations. It can be used for
14+
// exception tracking and logging,
15+
// as well as the source of randomness.
16+
var ml = new MLContext();
17+
18+
// Generate sample series data with an anomaly
19+
var data = new List<TimeSeriesData>();
20+
for (int index = 0; index < 20; index++)
21+
{
22+
data.Add(new TimeSeriesData { Value = 5 });
23+
}
24+
data.Add(new TimeSeriesData { Value = 10 });
25+
for (int index = 0; index < 5; index++)
26+
{
27+
data.Add(new TimeSeriesData { Value = 5 });
28+
}
29+
30+
// Convert data to IDataView.
31+
var dataView = ml.Data.LoadFromEnumerable(data);
32+
33+
// Setup the detection arguments
34+
string outputColumnName = nameof(SrCnnAnomalyDetection.Prediction);
35+
string inputColumnName = nameof(TimeSeriesData.Value);
36+
37+
// Do batch anomaly detection
38+
var outputDataView = ml.AnomalyDetection.DetectEntireAnomalyBySrCnn(dataView, outputColumnName, inputColumnName,
39+
threshold: 0.35, batchSize: 512, sensitivity: 90.0, detectMode: SrCnnDetectMode.AnomalyAndMargin);
40+
41+
// Getting the data of the newly created column as an IEnumerable of
42+
// SrCnnAnomalyDetection.
43+
var predictionColumn = ml.Data.CreateEnumerable<SrCnnAnomalyDetection>(
44+
outputDataView, reuseRowObject: false);
45+
46+
Console.WriteLine("Index\tData\tAnomaly\tAnomalyScore\tMag\tExpectedValue\tBoundaryUnit\tUpperBoundary\tLowerBoundary");
47+
48+
int k = 0;
49+
foreach (var prediction in predictionColumn)
50+
{
51+
PrintPrediction(k, data[k].Value, prediction);
52+
k++;
53+
}
54+
//Index Data Anomaly AnomalyScore Mag ExpectedValue BoundaryUnit UpperBoundary LowerBoundary
55+
//0 5.00 0 0.00 0.21 5.00 5.00 5.01 4.99
56+
//1 5.00 0 0.00 0.11 5.00 5.00 5.01 4.99
57+
//2 5.00 0 0.00 0.03 5.00 5.00 5.01 4.99
58+
//3 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99
59+
//4 5.00 0 0.00 0.03 5.00 5.00 5.01 4.99
60+
//5 5.00 0 0.00 0.06 5.00 5.00 5.01 4.99
61+
//6 5.00 0 0.00 0.02 5.00 5.00 5.01 4.99
62+
//7 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99
63+
//8 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99
64+
//9 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99
65+
//10 5.00 0 0.00 0.00 5.00 5.00 5.01 4.99
66+
//11 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99
67+
//12 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99
68+
//13 5.00 0 0.00 0.02 5.00 5.00 5.01 4.99
69+
//14 5.00 0 0.00 0.07 5.00 5.00 5.01 4.99
70+
//15 5.00 0 0.00 0.08 5.00 5.00 5.01 4.99
71+
//16 5.00 0 0.00 0.02 5.00 5.00 5.01 4.99
72+
//17 5.00 0 0.00 0.05 5.00 5.00 5.01 4.99
73+
//18 5.00 0 0.00 0.12 5.00 5.00 5.01 4.99
74+
//19 5.00 0 0.00 0.17 5.00 5.00 5.01 4.99
75+
//20 10.00 1 0.50 0.80 5.00 5.00 5.01 4.99
76+
//21 5.00 0 0.00 0.16 5.00 5.00 5.01 4.99
77+
//22 5.00 0 0.00 0.11 5.00 5.00 5.01 4.99
78+
//23 5.00 0 0.00 0.05 5.00 5.00 5.01 4.99
79+
//24 5.00 0 0.00 0.11 5.00 5.00 5.01 4.99
80+
//25 5.00 0 0.00 0.19 5.00 5.00 5.01 4.99
81+
}
82+
83+
private static void PrintPrediction(int idx, double value, SrCnnAnomalyDetection prediction) =>
84+
Console.WriteLine("{0}\t{1:0.00}\t{2}\t\t{3:0.00}\t{4:0.00}\t\t{5:0.00}\t\t{6:0.00}\t\t{7:0.00}\t\t{8:0.00}",
85+
idx, value, prediction.Prediction[0], prediction.Prediction[1], prediction.Prediction[2],
86+
prediction.Prediction[3], prediction.Prediction[4], prediction.Prediction[5], prediction.Prediction[6]);
87+
88+
private class TimeSeriesData
89+
{
90+
public double Value { get; set; }
91+
}
92+
93+
private class SrCnnAnomalyDetection
94+
{
95+
[VectorType]
96+
public double[] Prediction { get; set; }
97+
}
98+
}
99+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
using Microsoft.ML.Runtime;
9+
10+
namespace Microsoft.ML.Data.DataView
11+
{
12+
internal abstract class BatchDataViewMapperBase<TInput, TBatch> : IDataView
13+
{
14+
public bool CanShuffle => false;
15+
16+
public DataViewSchema Schema => SchemaBindings.AsSchema;
17+
18+
private readonly IDataView _source;
19+
protected readonly IHost Host;
20+
21+
protected BatchDataViewMapperBase(IHostEnvironment env, string registrationName, IDataView input)
22+
{
23+
Contracts.CheckValue(env, nameof(env));
24+
Host = env.Register(registrationName);
25+
_source = input;
26+
}
27+
28+
public long? GetRowCount() => _source.GetRowCount();
29+
30+
public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
31+
{
32+
Host.CheckValue(columnsNeeded, nameof(columnsNeeded));
33+
Host.CheckValueOrNull(rand);
34+
35+
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, SchemaBindings.AsSchema);
36+
37+
// If we aren't selecting any of the output columns, don't construct our cursor.
38+
// Note that because we cannot support random due to the inherently
39+
// stratified nature, neither can we allow the base data to be shuffled,
40+
// even if it supports shuffling.
41+
if (!SchemaBindings.AnyNewColumnsActive(predicate))
42+
{
43+
var activeInput = SchemaBindings.GetActiveInput(predicate);
44+
var inputCursor = _source.GetRowCursor(_source.Schema.Where(c => activeInput[c.Index]), null);
45+
return new BindingsWrappedRowCursor(Host, inputCursor, SchemaBindings);
46+
}
47+
var active = SchemaBindings.GetActive(predicate);
48+
Contracts.Assert(active.Length == SchemaBindings.ColumnCount);
49+
50+
// REVIEW: We can get a different input predicate for the input cursor and for the lookahead cursor. The lookahead
51+
// cursor is only used for getting the values from the input column, so it only needs that column activated. The
52+
// other cursor is used to get source columns, so it needs the rest of them activated.
53+
var predInput = GetSchemaBindingDependencies(predicate);
54+
var inputCols = _source.Schema.Where(c => predInput(c.Index));
55+
return new Cursor(this, _source.GetRowCursor(inputCols), _source.GetRowCursor(inputCols), active);
56+
}
57+
58+
public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
59+
{
60+
return new[] { GetRowCursor(columnsNeeded, rand) };
61+
}
62+
63+
protected abstract ColumnBindingsBase SchemaBindings { get; }
64+
protected abstract TBatch CreateBatch(DataViewRowCursor input);
65+
protected abstract void ProcessBatch(TBatch currentBatch);
66+
protected abstract void ProcessExample(TBatch currentBatch, TInput currentInput);
67+
protected abstract Func<bool> GetLastInBatchDelegate(DataViewRowCursor lookAheadCursor);
68+
protected abstract Func<bool> GetIsNewBatchDelegate(DataViewRowCursor lookAheadCursor);
69+
protected abstract ValueGetter<TInput> GetLookAheadGetter(DataViewRowCursor lookAheadCursor);
70+
protected abstract Delegate[] CreateGetters(DataViewRowCursor input, TBatch currentBatch, bool[] active);
71+
protected abstract Func<int, bool> GetSchemaBindingDependencies(Func<int, bool> predicate);
72+
73+
private sealed class Cursor : RootCursorBase
74+
{
75+
private readonly BatchDataViewMapperBase<TInput, TBatch> _parent;
76+
private readonly DataViewRowCursor _lookAheadCursor;
77+
private readonly DataViewRowCursor _input;
78+
79+
private readonly bool[] _active;
80+
private readonly Delegate[] _getters;
81+
82+
private readonly TBatch _currentBatch;
83+
private readonly Func<bool> _lastInBatchInLookAheadCursorDel;
84+
private readonly Func<bool> _firstInBatchInInputCursorDel;
85+
private readonly ValueGetter<TInput> _inputGetterInLookAheadCursor;
86+
private TInput _currentInput;
87+
88+
public override long Batch => 0;
89+
90+
public override DataViewSchema Schema => _parent.Schema;
91+
92+
public Cursor(BatchDataViewMapperBase<TInput, TBatch> parent, DataViewRowCursor input, DataViewRowCursor lookAheadCursor, bool[] active)
93+
: base(parent.Host)
94+
{
95+
_parent = parent;
96+
_input = input;
97+
_lookAheadCursor = lookAheadCursor;
98+
_active = active;
99+
100+
_currentBatch = _parent.CreateBatch(_input);
101+
102+
_getters = _parent.CreateGetters(_input, _currentBatch, _active);
103+
104+
_lastInBatchInLookAheadCursorDel = _parent.GetLastInBatchDelegate(_lookAheadCursor);
105+
_firstInBatchInInputCursorDel = _parent.GetIsNewBatchDelegate(_input);
106+
_inputGetterInLookAheadCursor = _parent.GetLookAheadGetter(_lookAheadCursor);
107+
}
108+
109+
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
110+
{
111+
Contracts.CheckParam(IsColumnActive(column), nameof(column), "requested column is not active");
112+
113+
var col = _parent.SchemaBindings.MapColumnIndex(out bool isSrc, column.Index);
114+
if (isSrc)
115+
{
116+
Contracts.AssertValue(_input);
117+
return _input.GetGetter<TValue>(_input.Schema[col]);
118+
}
119+
120+
Ch.AssertValue(_getters);
121+
var getter = _getters[col];
122+
Ch.Assert(getter != null);
123+
var fn = getter as ValueGetter<TValue>;
124+
if (fn == null)
125+
throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
126+
return fn;
127+
}
128+
129+
public override ValueGetter<DataViewRowId> GetIdGetter()
130+
{
131+
return
132+
(ref DataViewRowId val) =>
133+
{
134+
Ch.Check(IsGood, "Cannot call ID getter in current state");
135+
val = new DataViewRowId((ulong)Position, 0);
136+
};
137+
}
138+
139+
public override bool IsColumnActive(DataViewSchema.Column column)
140+
{
141+
Ch.Check(column.Index < _parent.SchemaBindings.AsSchema.Count);
142+
return _active[column.Index];
143+
}
144+
145+
protected override bool MoveNextCore()
146+
{
147+
if (!_input.MoveNext())
148+
return false;
149+
if (!_firstInBatchInInputCursorDel())
150+
return true;
151+
152+
// If we are here, this means that _input.MoveNext() has gotten us to the beginning of the next batch,
153+
// so now we need to look ahead at the entire next batch in the _lookAheadCursor.
154+
// The _lookAheadCursor's position should be on the last row of the previous batch (or -1).
155+
Ch.Assert(_lastInBatchInLookAheadCursorDel());
156+
157+
var good = _lookAheadCursor.MoveNext();
158+
// The two cursors should have the same number of elements, so if _input.MoveNext() returned true,
159+
// then it must return true here too.
160+
Ch.Assert(good);
161+
162+
do
163+
{
164+
_inputGetterInLookAheadCursor(ref _currentInput);
165+
_parent.ProcessExample(_currentBatch, _currentInput);
166+
} while (!_lastInBatchInLookAheadCursorDel() && _lookAheadCursor.MoveNext());
167+
168+
_parent.ProcessBatch(_currentBatch);
169+
return true;
170+
}
171+
}
172+
}
173+
}

src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs

+29-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
6-
using System.Reflection;
75
using Microsoft.ML.Data;
86
using Microsoft.ML.Runtime;
97
using Microsoft.ML.TimeSeries;
@@ -150,6 +148,35 @@ public static SrCnnAnomalyEstimator DetectAnomalyBySrCnn(this TransformsCatalog
150148
int windowSize = 64, int backAddWindowSize = 5, int lookaheadWindowSize = 5, int averageingWindowSize = 3, int judgementWindowSize = 21, double threshold = 0.3)
151149
=> new SrCnnAnomalyEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, windowSize, backAddWindowSize, lookaheadWindowSize, averageingWindowSize, judgementWindowSize, threshold, inputColumnName);
152150

151+
/// <summary>
152+
/// Create <see cref="SrCnnEntireAnomalyDetector"/>, which detects timeseries anomalies for entire input using SRCNN algorithm.
153+
/// </summary>
154+
/// <param name="catalog">The AnomalyDetectionCatalog.</param>
155+
/// <param name="input">Input DataView.</param>
156+
/// <param name="outputColumnName">Name of the column resulting from data processing of <paramref name="inputColumnName"/>.
157+
/// The column data is a vector of <see cref="System.Double"/>. The length of this vector varies depending on <paramref name="detectMode"/>.</param>
158+
/// <param name="inputColumnName">Name of column to process. The column data must be <see cref="System.Double"/>.</param>
159+
/// <param name="threshold">The threshold to determine anomaly, score larger than the threshold is considered as anomaly. Must be in [0,1]. Default value is 0.3.</param>
160+
/// <param name="batchSize">Divide the input data into batches to fit srcnn model.
161+
/// When set to -1, use the whole input to fit model instead of batch by batch, when set to a positive integer, use this number as batch size.
162+
/// Must be -1 or a positive integer no less than 12. Default value is 1024.</param>
163+
/// <param name="sensitivity">Sensitivity of boundaries, only useful when srCnnDetectMode is AnomalyAndMargin. Must be in [0,100]. Default value is 99.</param>
164+
/// <param name="detectMode">An enum type of <see cref="SrCnnDetectMode"/>.
165+
/// When set to AnomalyOnly, the output vector would be a 3-element Double vector of (IsAnomaly, RawScore, Mag).
166+
/// When set to AnomalyAndExpectedValue, the output vector would be a 4-element Double vector of (IsAnomaly, RawScore, Mag, ExpectedValue).
167+
/// When set to AnomalyAndMargin, the output vector would be a 7-element Double vector of (IsAnomaly, AnomalyScore, Mag, ExpectedValue, BoundaryUnit, UpperBoundary, LowerBoundary).
168+
/// Default value is AnomalyOnly.</param>
169+
/// <example>
170+
/// <format type="text/markdown">
171+
/// <![CDATA[
172+
/// [!code-csharp[DetectEntireAnomalyBySrCnn](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectEntireAnomalyBySrCnn.cs)]
173+
/// ]]>
174+
/// </format>
175+
/// </example>
176+
public static IDataView DetectEntireAnomalyBySrCnn(this AnomalyDetectionCatalog catalog, IDataView input, string outputColumnName, string inputColumnName,
177+
double threshold = 0.3, int batchSize = 1024, double sensitivity = 99, SrCnnDetectMode detectMode = SrCnnDetectMode.AnomalyOnly)
178+
=> new SrCnnEntireAnomalyDetector(CatalogUtils.GetEnvironment(catalog), input, inputColumnName, outputColumnName, threshold, batchSize, sensitivity, detectMode);
179+
153180
/// <summary>
154181
/// Create <see cref="RootCause"/>, which localizes root causes using decision tree algorithm.
155182
/// </summary>

0 commit comments

Comments
 (0)