|
| 1 | +// Licensed to the .NET Foundation under one or more agreements. |
| 2 | +// The .NET Foundation licenses this file to you under the MIT license. |
| 3 | +// See the LICENSE file in the project root for more information. |
| 4 | + |
| 5 | +using System; |
| 6 | +using System.Collections.Generic; |
| 7 | +using System.Linq; |
| 8 | +using Microsoft.ML.Runtime; |
| 9 | + |
namespace Microsoft.ML.Data.DataView
{
    /// <summary>
    /// Base class for data views whose computed columns are produced one batch of source rows at a time.
    /// </summary>
    /// <remarks>
    /// The cursor created by this view opens two cursors over the source data. The "input" cursor
    /// defines the current row and serves pass-through source columns. The "look-ahead" cursor runs
    /// ahead of it: each time the input cursor reaches the first row of a batch, the look-ahead cursor
    /// is advanced through that whole batch, feeding every row to <see cref="ProcessExample"/> and then
    /// calling <see cref="ProcessBatch"/> once, before the first row of the batch is exposed.
    /// </remarks>
    /// <typeparam name="TInput">Type of the per-row value read through the look-ahead getter.</typeparam>
    /// <typeparam name="TBatch">Type of the batch state object created by <see cref="CreateBatch"/>.</typeparam>
    internal abstract class BatchDataViewMapperBase<TInput, TBatch> : IDataView
    {
        private readonly IDataView _source;
        protected readonly IHost Host;

        /// <summary>
        /// Shuffling is never supported; see the comment in <see cref="GetRowCursor"/>.
        /// </summary>
        public bool CanShuffle => false;

        /// <summary>
        /// The output schema, as described by <see cref="SchemaBindings"/>.
        /// </summary>
        public DataViewSchema Schema => SchemaBindings.AsSchema;

        /// <summary>
        /// Registers a host with <paramref name="env"/> and stores the source data view.
        /// </summary>
        /// <param name="env">Environment used to register <see cref="Host"/>. Must not be null.</param>
        /// <param name="registrationName">Name passed to the environment when registering the host.</param>
        /// <param name="input">The source data view. Must not be null.</param>
        protected BatchDataViewMapperBase(IHostEnvironment env, string registrationName, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            Host = env.Register(registrationName);
            // Fail here rather than with a null reference on the first GetRowCount/GetRowCursor call.
            Host.CheckValue(input, nameof(input));
            _source = input;
        }

        /// <summary>
        /// Returns the row count reported by the source data view.
        /// </summary>
        public long? GetRowCount() => _source.GetRowCount();

        /// <summary>
        /// Creates a cursor over the requested columns.
        /// </summary>
        /// <param name="columnsNeeded">The output columns the caller wants active. Must not be null.</param>
        /// <param name="rand">Accepted for interface compatibility but never forwarded to the source.</param>
        /// <returns>
        /// A wrapper around a plain source cursor when no computed column is requested; otherwise a
        /// batching cursor.
        /// </returns>
        public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            Host.CheckValue(columnsNeeded, nameof(columnsNeeded));
            Host.CheckValueOrNull(rand);

            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, SchemaBindings.AsSchema);

            // If we aren't selecting any of the output columns, don't construct our cursor.
            // Note that because we cannot support random due to the inherently
            // stratified nature, neither can we allow the base data to be shuffled,
            // even if it supports shuffling.
            if (!SchemaBindings.AnyNewColumnsActive(predicate))
            {
                var activeInput = SchemaBindings.GetActiveInput(predicate);
                var inputCursor = _source.GetRowCursor(_source.Schema.Where(c => activeInput[c.Index]), null);
                return new BindingsWrappedRowCursor(Host, inputCursor, SchemaBindings);
            }

            var active = SchemaBindings.GetActive(predicate);
            Host.Assert(active.Length == SchemaBindings.ColumnCount);

            // REVIEW: We can get a different input predicate for the input cursor and for the lookahead cursor. The lookahead
            // cursor is only used for getting the values from the input column, so it only needs that column activated. The
            // other cursor is used to get source columns, so it needs the rest of them activated.
            var predInput = GetSchemaBindingDependencies(predicate);
            // Materialize once: the same column set is handed to two GetRowCursor calls, and a deferred
            // query would re-run the dependency predicate for each enumeration.
            var inputCols = _source.Schema.Where(c => predInput(c.Index)).ToArray();
            return new Cursor(this, _source.GetRowCursor(inputCols), _source.GetRowCursor(inputCols), active);
        }

        /// <summary>
        /// Returns a set containing the single cursor produced by <see cref="GetRowCursor"/>;
        /// the requested cursor count <paramref name="n"/> is not used.
        /// </summary>
        public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            return new[] { GetRowCursor(columnsNeeded, rand) };
        }

        /// <summary>
        /// Bindings that define the output schema and map each output column index either to a
        /// source column or to a computed column.
        /// </summary>
        protected abstract ColumnBindingsBase SchemaBindings { get; }

        /// <summary>
        /// Creates the batch state object. Called once per cursor, with the input cursor; the same
        /// instance is then passed to every <see cref="ProcessExample"/> and <see cref="ProcessBatch"/> call
        /// and to <see cref="CreateGetters"/>.
        /// </summary>
        protected abstract TBatch CreateBatch(DataViewRowCursor input);

        /// <summary>
        /// Called once per batch, after every row of that batch has been passed to <see cref="ProcessExample"/>.
        /// </summary>
        protected abstract void ProcessBatch(TBatch currentBatch);

        /// <summary>
        /// Called once for each row of a batch, with the value read from the look-ahead cursor for that row.
        /// </summary>
        protected abstract void ProcessExample(TBatch currentBatch, TInput currentInput);

        /// <summary>
        /// Returns a delegate that reports whether the look-ahead cursor is positioned on the last row
        /// of a batch. The cursor also evaluates it inside an assertion before the look-ahead cursor has
        /// been advanced for the first time, so it is expected to return true in that state.
        /// </summary>
        protected abstract Func<bool> GetLastInBatchDelegate(DataViewRowCursor lookAheadCursor);

        /// <summary>
        /// Returns a delegate that reports whether the given cursor (the input cursor, despite the
        /// parameter name) is positioned on the first row of a batch.
        /// </summary>
        protected abstract Func<bool> GetIsNewBatchDelegate(DataViewRowCursor lookAheadCursor);

        /// <summary>
        /// Returns the getter used to read the <typeparamref name="TInput"/> value of the look-ahead
        /// cursor's current row.
        /// </summary>
        protected abstract ValueGetter<TInput> GetLookAheadGetter(DataViewRowCursor lookAheadCursor);

        /// <summary>
        /// Creates the getters for the computed columns. The returned array is indexed by the column
        /// index that <see cref="SchemaBindings"/> maps a non-source output column to, and each entry that
        /// is requested through the cursor must be a <c>ValueGetter&lt;T&gt;</c> of the requested type.
        /// </summary>
        protected abstract Delegate[] CreateGetters(DataViewRowCursor input, TBatch currentBatch, bool[] active);

        /// <summary>
        /// Translates a predicate over output columns into a predicate over source column indices that
        /// selects the source columns to activate on both underlying cursors.
        /// </summary>
        protected abstract Func<int, bool> GetSchemaBindingDependencies(Func<int, bool> predicate);

        /// <summary>
        /// Cursor that walks the source with an input cursor and processes each batch ahead of time
        /// with a second, look-ahead cursor.
        /// </summary>
        private sealed class Cursor : RootCursorBase
        {
            private readonly BatchDataViewMapperBase<TInput, TBatch> _parent;
            private readonly DataViewRowCursor _lookAheadCursor;
            private readonly DataViewRowCursor _input;

            private readonly bool[] _active;
            private readonly Delegate[] _getters;

            private readonly TBatch _currentBatch;
            private readonly Func<bool> _lastInBatchInLookAheadCursorDel;
            private readonly Func<bool> _firstInBatchInInputCursorDel;
            private readonly ValueGetter<TInput> _inputGetterInLookAheadCursor;
            private TInput _currentInput;
            private bool _disposed;

            public override long Batch => 0;

            public override DataViewSchema Schema => _parent.Schema;

            /// <summary>
            /// Takes ownership of both source cursors and asks the parent for the batch state,
            /// the computed-column getters and the batch-boundary delegates.
            /// </summary>
            public Cursor(BatchDataViewMapperBase<TInput, TBatch> parent, DataViewRowCursor input, DataViewRowCursor lookAheadCursor, bool[] active)
                : base(parent.Host)
            {
                _parent = parent;
                _input = input;
                _lookAheadCursor = lookAheadCursor;
                _active = active;

                _currentBatch = _parent.CreateBatch(_input);

                _getters = _parent.CreateGetters(_input, _currentBatch, _active);

                _lastInBatchInLookAheadCursorDel = _parent.GetLastInBatchDelegate(_lookAheadCursor);
                _firstInBatchInInputCursorDel = _parent.GetIsNewBatchDelegate(_input);
                _inputGetterInLookAheadCursor = _parent.GetLookAheadGetter(_lookAheadCursor);
            }

            /// <summary>
            /// Disposes the two source cursors owned by this cursor, then the base cursor.
            /// </summary>
            protected override void Dispose(bool disposing)
            {
                if (_disposed)
                    return;
                if (disposing)
                {
                    _input.Dispose();
                    _lookAheadCursor.Dispose();
                }
                _disposed = true;
                base.Dispose(disposing);
            }

            /// <summary>
            /// Returns the getter for an active column: the input cursor's getter for a source column,
            /// or the parent-created getter for a computed column.
            /// </summary>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
            {
                Ch.CheckParam(IsColumnActive(column), nameof(column), "requested column is not active");

                var col = _parent.SchemaBindings.MapColumnIndex(out bool isSrc, column.Index);
                if (isSrc)
                {
                    Ch.AssertValue(_input);
                    return _input.GetGetter<TValue>(_input.Schema[col]);
                }

                Ch.AssertValue(_getters);
                var getter = _getters[col];
                Ch.Assert(getter != null);
                var fn = getter as ValueGetter<TValue>;
                if (fn == null)
                    throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
                return fn;
            }

            /// <summary>
            /// Returns a getter that produces a row id built from the current <c>Position</c>.
            /// </summary>
            public override ValueGetter<DataViewRowId> GetIdGetter()
            {
                return
                    (ref DataViewRowId val) =>
                    {
                        Ch.Check(IsGood, "Cannot call ID getter in current state");
                        val = new DataViewRowId((ulong)Position, 0);
                    };
            }

            /// <summary>
            /// Reports whether the given output column was requested when the cursor was created.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                Ch.Check(column.Index < _parent.SchemaBindings.AsSchema.Count);
                return _active[column.Index];
            }

            /// <summary>
            /// Advances the input cursor by one row. When that row starts a new batch, the look-ahead
            /// cursor is first run through the whole batch so the batch state is complete before
            /// the row is exposed.
            /// </summary>
            protected override bool MoveNextCore()
            {
                if (!_input.MoveNext())
                    return false;
                if (!_firstInBatchInInputCursorDel())
                    return true;

                // If we are here, this means that _input.MoveNext() has gotten us to the beginning of the next batch,
                // so now we need to look ahead at the entire next batch in the _lookAheadCursor.
                // The _lookAheadCursor's position should be on the last row of the previous batch (or -1).
                Ch.Assert(_lastInBatchInLookAheadCursorDel());

                var good = _lookAheadCursor.MoveNext();
                // The two cursors should have the same number of elements, so if _input.MoveNext() returned true,
                // then it must return true here too.
                Ch.Assert(good);

                // Feed every row of the batch to the parent; stop on the batch's last row or at end of data.
                do
                {
                    _inputGetterInLookAheadCursor(ref _currentInput);
                    _parent.ProcessExample(_currentBatch, _currentInput);
                } while (!_lastInBatchInLookAheadCursorDel() && _lookAheadCursor.MoveNext());

                _parent.ProcessBatch(_currentBatch);
                return true;
            }
        }
    }
}
0 commit comments