Skip to content

Commit 5d819b7

Browse files
committed
Prediction engine now uses IRowToRowMapper
1 parent 8935b43 commit 5d819b7

File tree

10 files changed

+488
-241
lines changed

10 files changed

+488
-241
lines changed

src/Microsoft.ML.Api/ComponentCreation.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.IO;
8+
using Microsoft.ML.Core.Data;
89
using Microsoft.ML.Runtime.CommandLine;
910
using Microsoft.ML.Runtime.Data;
1011
using Microsoft.ML.Runtime.Model;
@@ -189,6 +190,26 @@ public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(th
189190
return new PredictionEngine<TSrc, TDst>(env, dataPipe, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
190191
}
191192

193+
/// <summary>
194+
/// Create an on-demand prediction engine.
195+
/// </summary>
196+
/// <param name="env">The host environment to use.</param>
197+
/// <param name="transformer">The transformer.</param>
198+
/// <param name="ignoreMissingColumns">Whether to ignore missing columns in the data view.</param>
199+
/// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TSrc"/> type.</param>
200+
/// <param name="outputSchemaDefinition">The optional output schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TDst"/> type.</param>
201+
public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(this IHostEnvironment env, ITransformer transformer,
202+
bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
203+
where TSrc : class
204+
where TDst : class, new()
205+
{
206+
Contracts.CheckValue(env, nameof(env));
207+
env.CheckValue(transformer, nameof(transformer));
208+
env.CheckValueOrNull(inputSchemaDefinition);
209+
env.CheckValueOrNull(outputSchemaDefinition);
210+
return new PredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
211+
}
212+
192213
/// <summary>
193214
/// Create a prediction engine.
194215
/// This encapsulates the 'classic' prediction problem, where the input is denoted by the float array of features,

src/Microsoft.ML.Api/DataViewConstructionUtils.cs

Lines changed: 363 additions & 167 deletions
Large diffs are not rendered by default.

src/Microsoft.ML.Api/MapTransform.cs

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ internal sealed class MapTransform<TSrc, TDst> : LambdaTransformBase, ITransform
2424
where TDst : class, new()
2525
{
2626
private const string RegistrationNameTemplate = "MapTransform<{0}, {1}>";
27-
28-
private readonly IDataView _source;
2927
private readonly Action<TSrc, TDst> _mapAction;
3028
private readonly MergedSchema _schema;
3129

@@ -56,20 +54,20 @@ public MapTransform(IHostEnvironment env, IDataView source, Action<TSrc, TDst> m
5654
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
5755
: base(env, RegistrationName, saveAction, loadFunc)
5856
{
59-
Host.AssertValue(source, "source");
60-
Host.AssertValue(mapAction, "mapAction");
57+
Host.AssertValue(source);
58+
Host.AssertValue(mapAction);
6159
Host.AssertValueOrNull(inputSchemaDefinition);
6260
Host.AssertValueOrNull(outputSchemaDefinition);
6361

64-
_source = source;
62+
Source = source;
6563
_mapAction = mapAction;
6664
_inputSchemaDefinition = inputSchemaDefinition;
6765
_typedSource = TypedCursorable<TSrc>.Create(Host, Source, false, inputSchemaDefinition);
6866
var outSchema = outputSchemaDefinition == null
6967
? InternalSchemaDefinition.Create(typeof(TDst), SchemaDefinition.Direction.Write)
7068
: InternalSchemaDefinition.Create(typeof(TDst), outputSchemaDefinition);
7169

72-
_schema = MergedSchema.Create(_source.Schema, outSchema);
70+
_schema = MergedSchema.Create(Source.Schema, outSchema);
7371
}
7472

7573
/// <summary>
@@ -81,26 +79,20 @@ private MapTransform(IHostEnvironment env, MapTransform<TSrc, TDst> transform, I
8179
Host.AssertValue(transform);
8280
Host.AssertValue(newSource);
8381

84-
_source = newSource;
82+
Source = newSource;
8583
_mapAction = transform._mapAction;
8684
_typedSource = TypedCursorable<TSrc>.Create(Host, newSource, false, transform._inputSchemaDefinition);
8785

8886
_schema = MergedSchema.Create(newSource.Schema, transform._schema.AddedSchema);
8987
}
9088

91-
public bool CanShuffle
92-
{
93-
get { return _source.CanShuffle; }
94-
}
89+
public bool CanShuffle => Source.CanShuffle;
9590

96-
public ISchema Schema
97-
{
98-
get { return _schema; }
99-
}
91+
public ISchema Schema => _schema;
10092

10193
public long? GetRowCount(bool lazy = true)
10294
{
103-
return _source.GetRowCount(lazy);
95+
return Source.GetRowCount(lazy);
10496
}
10597

10698
public IRowCursor GetRowCursor(Func<int, bool> predicate, IRandom rand = null)
@@ -136,10 +128,7 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Fun
136128
return cursors;
137129
}
138130

139-
public IDataView Source
140-
{
141-
get { return _source; }
142-
}
131+
public IDataView Source { get; }
143132

144133
public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
145134
{
@@ -161,13 +150,13 @@ public Func<int, bool> GetDependencies(Func<int, bool> predicate)
161150
return _typedSource.GetDependencies(srcPredicate);
162151
}
163152

164-
ISchema IRowToRowMapper.InputSchema => _source.Schema;
153+
ISchema IRowToRowMapper.InputSchema => Source.Schema;
165154

166155
public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
167156
{
168157
Host.CheckValue(input, nameof(input));
169158
Host.CheckValue(active, nameof(active));
170-
Host.CheckParam(input.Schema == _source.Schema, nameof(input), "Schema of input row must be the same as the schema the mapper is bound to");
159+
Host.CheckParam(input.Schema == Source.Schema, nameof(input), "Schema of input row must be the same as the schema the mapper is bound to");
171160

172161
var src = new TSrc();
173162
var dst = new TDst();
@@ -234,7 +223,7 @@ public ValueGetter<TValue> GetGetter<TValue>(int col)
234223
() =>
235224
{
236225
if (!IsGood)
237-
throw Contracts.Except("Getter is called when the cursor is {0}, which is not allowed.", Input.State);
226+
throw Ch.Except("Getter is called when the cursor is {0}, which is not allowed.", Input.State);
238227
};
239228
return _row.GetGetterCore<TValue>(col, isGood);
240229
}
@@ -253,7 +242,6 @@ public override void Dispose()
253242

254243
private sealed class Row : IRow
255244
{
256-
private readonly ISchema _schema;
257245
private readonly IRow<TSrc> _input;
258246
private readonly IRow _appendedRow;
259247
private readonly bool[] _active;
@@ -269,15 +257,15 @@ private sealed class Row : IRow
269257

270258
public long Position => _input.Position;
271259

272-
public ISchema Schema => _schema;
260+
public ISchema Schema { get; }
273261

274262
public Row(IRow<TSrc> input, MapTransform<TSrc, TDst> parent, Func<int, bool> active, TSrc src, TDst dst)
275263
{
276264
_input = input;
277265
_parent = parent;
278-
_schema = parent.Schema;
266+
Schema = parent.Schema;
279267

280-
_active = Utils.BuildArray(_schema.ColumnCount, active);
268+
_active = Utils.BuildArray(Schema.ColumnCount, active);
281269
_src = src;
282270
_dst = dst;
283271

@@ -321,7 +309,7 @@ public ValueGetter<UInt128> GetIdGetter()
321309

322310
public bool IsColumnActive(int col)
323311
{
324-
_parent.Host.Check(0 <= col && col < _schema.ColumnCount);
312+
_parent.Host.Check(0 <= col && col < Schema.ColumnCount);
325313
return _active[col];
326314
}
327315
}

src/Microsoft.ML.Api/PredictionEngine.cs

Lines changed: 56 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
using System.IO;
99
using Microsoft.ML.Runtime.Data;
1010
using Microsoft.ML.Runtime.Model;
11+
using Microsoft.ML.Core.Data;
12+
using System;
1113

1214
namespace Microsoft.ML.Runtime.Api
1315
{
@@ -38,21 +40,7 @@ internal BatchPredictionEngine(IHostEnvironment env, Stream modelStream, bool ig
3840

3941
// Initialize pipe.
4042
_srcDataView = DataViewConstructionUtils.CreateFromEnumerable(env, new TSrc[] { }, inputSchemaDefinition);
41-
42-
// Load transforms.
43-
var pipe = env.LoadTransforms(modelStream, _srcDataView);
44-
45-
// Load predictor (if present) and apply default scorer.
46-
// REVIEW: distinguish the case of predictor / no predictor?
47-
var predictor = env.LoadPredictorOrNull(modelStream);
48-
if (predictor != null)
49-
{
50-
var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, modelStream);
51-
pipe = roles != null
52-
? env.CreateDefaultScorer(new RoleMappedData(pipe, roles, opt: true), predictor)
53-
: env.CreateDefaultScorer(new RoleMappedData(pipe, label: null, "Features"), predictor);
54-
}
55-
43+
var pipe = DataViewConstructionUtils.LoadPipeWithPredictor(env, modelStream, _srcDataView);
5644
_pipeEngine = new PipeEngine<TDst>(env, pipe, ignoreMissingColumns, outputSchemaDefinition);
5745
}
5846

@@ -150,24 +138,64 @@ public sealed class PredictionEngine<TSrc, TDst>
150138
where TSrc : class
151139
where TDst : class, new()
152140
{
153-
private readonly BatchPredictionEngine<TSrc, TDst> _engine;
141+
private readonly DataViewConstructionUtils.InputRow<TSrc> _inputRow;
142+
private readonly IRow<TDst> _outputRow;
143+
private readonly Action _disposer;
144+
private TDst _result;
154145

155146
internal PredictionEngine(IHostEnvironment env, Stream modelStream, bool ignoreMissingColumns,
156147
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
148+
: this(env, StreamChecker(env, modelStream), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
157149
{
158-
Contracts.CheckValue(env, nameof(env));
159-
var singleThreadedEnv = env.Register("SingleThreaded", conc: 1);
160-
_engine = new BatchPredictionEngine<TSrc, TDst>(singleThreadedEnv, modelStream, ignoreMissingColumns,
161-
inputSchemaDefinition, outputSchemaDefinition);
150+
}
151+
152+
private static Func<ISchema, IRowToRowMapper> StreamChecker(IHostEnvironment env, Stream modelStream)
153+
{
154+
env.CheckValue(modelStream, nameof(modelStream));
155+
return schema =>
156+
{
157+
var pipe = DataViewConstructionUtils.LoadPipeWithPredictor(env, modelStream, new EmptyDataView(env, schema));
158+
var transformer = new TransformWrapper(env, pipe);
159+
env.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
160+
return transformer.GetRowToRowMapper(schema);
161+
};
162162
}
163163

164164
internal PredictionEngine(IHostEnvironment env, IDataView dataPipe, bool ignoreMissingColumns,
165165
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
166+
: this(env, new TransformWrapper(env, env.CheckRef(dataPipe, nameof(dataPipe))), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
167+
{
168+
}
169+
170+
internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
171+
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
172+
: this(env, TransformerChecker(env, transformer), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
173+
{
174+
}
175+
176+
private static Func<ISchema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
177+
{
178+
ectx.CheckValue(transformer, nameof(transformer));
179+
ectx.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
180+
return transformer.GetRowToRowMapper;
181+
}
182+
183+
private PredictionEngine(IHostEnvironment env, Func<ISchema, IRowToRowMapper> makeMapper, bool ignoreMissingColumns,
184+
SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition)
166185
{
167186
Contracts.CheckValue(env, nameof(env));
168-
var singleThreadedEnv = env.Register("SingleThreaded", conc: 1);
169-
_engine = new BatchPredictionEngine<TSrc, TDst>(singleThreadedEnv, dataPipe, ignoreMissingColumns,
170-
inputSchemaDefinition, outputSchemaDefinition);
187+
env.AssertValue(makeMapper);
188+
189+
_inputRow = DataViewConstructionUtils.CreateInputRow<TSrc>(env, inputSchemaDefinition);
190+
var mapper = makeMapper(_inputRow.Schema);
191+
var cursorable = TypedCursorable<TDst>.Create(env, new EmptyDataView(env, mapper.Schema), ignoreMissingColumns, outputSchemaDefinition);
192+
var outputRow = mapper.GetRow(_inputRow, col => true, out _disposer);
193+
_outputRow = cursorable.GetRow(outputRow);
194+
}
195+
196+
~PredictionEngine()
197+
{
198+
_disposer?.Invoke();
171199
}
172200

173201
/// <summary>
@@ -178,21 +206,11 @@ internal PredictionEngine(IHostEnvironment env, IDataView dataPipe, bool ignoreM
178206
public TDst Predict(TSrc example)
179207
{
180208
Contracts.CheckValue(example, nameof(example));
181-
int count = 0;
182-
TDst result = null;
183-
foreach (var item in _engine.Predict(new[] { example }, true))
184-
{
185-
if (count == 0)
186-
result = item;
187-
188-
count++;
189-
if (count > 1)
190-
break;
191-
}
192-
193-
if (count > 1)
194-
throw Contracts.Except("Prediction pipeline must return at most one prediction per example. If it isn't, use BatchPredictionEngine.");
195-
return result;
209+
_inputRow.AcceptValues(example);
210+
if (_result == null)
211+
_result = new TDst();
212+
_outputRow.FillValues(_result);
213+
return _result;
196214
}
197215
}
198216

src/Microsoft.ML.Api/PredictionFunction.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public PredictionFunction(IHostEnvironment env, ITransformer transformer)
2323
env.CheckValue(transformer, nameof(transformer));
2424

2525
IDataView dv = env.CreateDataView(new TSrc[0]);
26-
_engine = env.CreatePredictionEngine<TSrc, TDst>(transformer.Transform(dv));
26+
_engine = env.CreatePredictionEngine<TSrc, TDst>(transformer);
2727
}
2828

2929
public TDst Predict(TSrc example) => _engine.Predict(example);

src/Microsoft.ML.Api/TypedCursor.cs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,21 @@ public interface IRow<TRow> : IRow
2626
void FillValues(TRow row);
2727
}
2828

29+
/// <summary>
30+
/// This interface is an <see cref="IRow"/> with 'strongly typed' binding.
31+
/// It can accept values of type <typeparamref name="TRow"/> and present the value as a row.
32+
/// </summary>
33+
/// <typeparam name="TRow">The user-defined type whose instances supply the values presented as a row.</typeparam>
34+
public interface IInputRow<TRow> : IRow
35+
where TRow : class
36+
{
37+
/// <summary>
38+
/// Accepts the fields of the user-supplied <paramref name="row"/> object and publishes the instance as a row.
39+
/// </summary>
40+
/// <param name="row">The row object. Cannot be null.</param>
41+
void AcceptValues(TRow row);
42+
}
43+
2944
/// <summary>
3045
/// This interface provides cursoring through a <see cref="IDataView"/> via a 'strongly typed' binding.
3146
/// It can populate the user-supplied object's fields with the values of the current row.
@@ -239,11 +254,11 @@ private abstract class TypedRowBase : IRow<TRow>
239254
private readonly IRow _input;
240255
private readonly Action<TRow>[] _setters;
241256

242-
public long Batch { get { return _input.Batch; } }
257+
public long Batch => _input.Batch;
243258

244-
public long Position { get { return _input.Position; } }
259+
public long Position => _input.Position;
245260

246-
public ISchema Schema { get { return _input.Schema; } }
261+
public ISchema Schema => _input.Schema;
247262

248263
public TypedRowBase(TypedCursorable<TRow> parent, IRow input, string channelMessage)
249264
{

src/Microsoft.ML.Data/DataView/ChainedRowToRowMapper.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ public ChainedRowToRowMapper(ISchema inputSchema, IRowToRowMapper[] mappers)
3030
Contracts.CheckValueOrNull(mappers);
3131
_innerMappers = Utils.Size(mappers) > 0 ? mappers : _empty;
3232
InputSchema = inputSchema;
33+
Schema = Utils.Size(mappers) > 0 ? mappers[mappers.Length - 1].Schema : inputSchema;
3334
}
3435

3536
public Func<int, bool> GetDependencies(Func<int, bool> predicate)
@@ -67,7 +68,7 @@ public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
6768
// computed based on the dependencies of the next one in the chain.
6869
var deps = new Func<int, bool>[_innerMappers.Length];
6970
deps[deps.Length - 1] = active;
70-
for (int i = deps.Length - 1; i <= 1; --i)
71+
for (int i = deps.Length - 1; i >= 1; --i)
7172
deps[i - 1] = _innerMappers[i].GetDependencies(deps[i]);
7273

7374
IRow result = input;

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ protected PredictionTransformerBase(IHost host, ModelLoadContext ctx)
9393
/// <returns>The transformed <see cref="IDataView"/></returns>
9494
public abstract IDataView Transform(IDataView input);
9595

96+
public abstract IRowToRowMapper GetRowToRowMapper(ISchema inputSchema);
97+
9698
protected void SaveModel(ModelSaveContext ctx)
9799
{
98100
// *** Binary format ***
@@ -187,8 +189,6 @@ protected virtual void SaveCore(ModelSaveContext ctx)
187189
SaveModel(ctx);
188190
ctx.SaveStringOrNull(FeatureColumn);
189191
}
190-
191-
public abstract IRowToRowMapper GetRowToRowMapper(ISchema inputSchema);
192192
}
193193

194194
/// <summary>

src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ public Func<int, bool> GetDependencies(Func<int, bool> predicate)
485485
{
486486
for (int i = 0; i < Schema.ColumnCount; i++)
487487
{
488-
if (predicate(i) && _inputSchema.Feature != null)
488+
if (predicate(i) && InputRoleMappedSchema.Feature != null)
489489
return col => col == InputRoleMappedSchema.Feature.Index;
490490
}
491491
return col => false;

0 commit comments

Comments
 (0)