Convert ValueMapper to use 'in' parameters #1475

Merged
merged 2 commits into from
Nov 1, 2018
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Core/Data/IValueMapper.cs
@@ -9,13 +9,13 @@ namespace Microsoft.ML.Runtime.Data
/// <summary>
/// Delegate type to map/convert a value.
/// </summary>
- public delegate void ValueMapper<TSrc, TDst>(ref TSrc src, ref TDst dst);
+ public delegate void ValueMapper<TSrc, TDst>(in TSrc src, ref TDst dst);
Member

Just to confirm, is TSrc expected to always be a readonly struct? Passing by 'in' causes hidden defensive copies whenever an instance member is called on a non-readonly struct, which can cause unexpected perf regressions in some cases.
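
The hidden-copy concern can be illustrated with a minimal sketch. `Counter` and `HiddenCopyDemo` are hypothetical names invented for this example and are not part of ML.NET. The sketch assumes a C# 7.2 or later compiler.

```csharp
using System;

// Hypothetical mutable (non-readonly) struct, used only for illustration.
public struct Counter
{
    private int _count;

    // Mutating instance member: the compiler cannot assume that calling it
    // leaves 'this' unchanged.
    public int Increment() => ++_count;
}

public static class HiddenCopyDemo
{
    // Each call on the 'in' parameter runs against a fresh defensive copy,
    // so both calls start from _count == 0 and the method returns 1.
    public static int ViaIn(in Counter c)
    {
        c.Increment();
        return c.Increment();
    }

    // With 'ref' there is no copy; the second call observes the first,
    // so the method returns 2 for a freshly created Counter.
    public static int ViaRef(ref Counter c)
    {
        c.Increment();
        return c.Increment();
    }
}
```

Besides the extra copies, the sketch shows that mutations made through an 'in' parameter are silently discarded, which is one reason to prefer readonly structs for anything passed this way.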

Contributor

As a practical matter, in most applications they will be some sort of VBuffer, with some isolated instances of other types here and there, as you've observed. Yet from what I see, these other types also tend to be readonly structs (e.g., TimeSpan, which you pointed out elsewhere, and Int32 here). I just ran a pretty simple benchmark like so (using BenchmarkDotNet, my new favorite project):

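    // Same addition, with arguments passed by readonly reference (Foo) or by value (Bar).
    private static int Foo(in byte a, in byte b) => a + b;
    private static int Bar(byte a, byte b) => a + b;

    // Each benchmark makes 1000 calls, so the reported times are per 1000 calls.
    [Benchmark]
    public void Foo() { for (int i = 0; i < 1000; ++i) Foo(2, 5); }
    [Benchmark]
    public void Bar() { for (int i = 0; i < 1000; ++i) Bar(2, 5); }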

At least on my computer this gives:

| Method | Mean     | Error     | StdDev    |
|--------|----------|-----------|-----------|
| Foo    | 239.6 ns | 2.8371 ns | 2.6538 ns |
| Bar    | 238.2 ns | 0.8780 ns | 0.7332 ns |

So it seems that 'in' is at least not harmful for small primitives like int, while still giving some benefit for larger types like VBuffer.

Member Author

In my limited knowledge of the area, I would guess that >95% of the time ValueMapper will be used for primitives, reference types, and readonly structs.

The most common case is to use these types:

    using BL = Boolean;
    using DT = DateTime;
    using DZ = DateTimeOffset;
    using R4 = Single;
    using R8 = Double;
    using I1 = SByte;
    using I2 = Int16;
    using I4 = Int32;
    using I8 = Int64;
    using SB = StringBuilder;
    using TX = ReadOnlyMemory<char>;
    using TS = TimeSpan;
    using U1 = Byte;
    using U2 = UInt16;
    using U4 = UInt32;
    using U8 = UInt64;
    using UG = UInt128;

Along with VBuffer, which just got made readonly in #1454.

Skimming this list, I found that UInt128 was not marked readonly. So I'm opening #1496 to mark it (and other obviously readonly structs) as readonly.
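
As a rough sketch of what marking a struct readonly buys here: `Id128` below is a hypothetical 16-byte struct invented for illustration. It is not the UInt128 type from the code base, whose actual members are not shown on this page.

```csharp
// Hypothetical 128-bit value type. Because the struct is declared 'readonly',
// the compiler knows that no instance member can modify 'this'.
public readonly struct Id128
{
    public readonly ulong Lo;
    public readonly ulong Hi;

    public Id128(ulong lo, ulong hi)
    {
        Lo = lo;
        Hi = hi;
    }

    public bool IsZero() => Lo == 0 && Hi == 0;
}

public static class ReadOnlyStructDemo
{
    // Instance calls on 'in' parameters of a readonly struct need no
    // defensive copy, so only references are passed and used here.
    public static bool BothZero(in Id128 a, in Id128 b) => a.IsZero() && b.IsZero();
}
```

If the `readonly` modifier were dropped from the struct declaration, the same code would still compile, but each `IsZero()` call on an 'in' parameter could copy the full 16 bytes first.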

Member

Is there any concern with the primitive types not being readonly structs for full framework and/or netstandard right now?


/// <summary>
/// Delegate type to map/convert among three values, for example, one input with two
/// outputs, or two inputs with one output.
/// </summary>
- public delegate void ValueMapper<TVal1, TVal2, TVal3>(ref TVal1 val1, ref TVal2 val2, ref TVal3 val3);
+ public delegate void ValueMapper<TVal1, TVal2, TVal3>(in TVal1 val1, ref TVal2 val2, ref TVal3 val3);
Member Author

I inspected all the usages of this delegate, and most times it is used by the below interface IValueMapperDist, which takes 1 input and returns 2 outputs.
The only case where this was used for 2 inputs and 1 output was in one place in matrix factorization code (the two inputs were column and row I believe).

I didn't feel it was necessary to split this delegate. If someone does feel it is necessary, I think a separate issue should be logged and addressed separately.
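
The two usage shapes of the three-value delegate can be sketched as follows. The delegate declaration is copied from the diff above. `ThreeValueDemo` and both lambdas are made up for illustration and do not reproduce the calibrator or matrix factorization code.

```csharp
public delegate void ValueMapper<TVal1, TVal2, TVal3>(in TVal1 val1, ref TVal2 val2, ref TVal3 val3);

public static class ThreeValueDemo
{
    // One input, two outputs: a toy "score" and a toy "probability".
    public static readonly ValueMapper<float, float, float> ScoreAndProb =
        (in float src, ref float score, ref float prob) =>
        {
            score = 2f * src;
            prob = score > 0f ? 1f : 0f;
        };

    // Two inputs, one output: since only the first parameter is 'in',
    // the second input still has to be passed by 'ref'.
    public static readonly ValueMapper<int, int, float> RowTimesCol =
        (in int row, ref int col, ref float dst) => dst = row * col;
}
```

In the two-input case the delegate signature cannot express that the second argument is read-only, which is the asymmetry a split into two delegate types would remove.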


/// <summary>
/// Interface for mapping a single input value (of an indicated ColumnType) to
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs
@@ -239,7 +239,7 @@ private static void ShowMetadataValue<T>(IndentingTextWriter itw, ISchema schema
var value = default(T);
var sb = default(StringBuilder);
schema.GetMetadata(kind, col, ref value);
- conv(ref value, ref sb);
+ conv(in value, ref sb);

itw.Write(": '{0}'", sb);
}
@@ -292,7 +292,7 @@ private static void ShowMetadataValueVec<T>(IndentingTextWriter itw, ISchema sch
else
itw.Write(", ");
var val = item.Value;
- conv(ref val, ref sb);
+ conv(in val, ref sb);
itw.Write("[{0}] '{1}'", item.Key, sb);
count++;
}
238 changes: 119 additions & 119 deletions src/Microsoft.ML.Data/Data/Conversion.cs

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Data/DataViewUtils.cs
@@ -1342,7 +1342,7 @@ public static ValueGetter<ReadOnlyMemory<char>> GetSingleValueGetter<T>(IRow cur
if (!Conversions.Instance.TryGetStringConversion<T>(colType, out conversion))
{
var error = $"Cannot display {colType}";
- conversion = (ref T src, ref StringBuilder builder) =>
+ conversion = (in T src, ref StringBuilder builder) =>
{
if (builder == null)
builder = new StringBuilder();
@@ -1357,7 +1357,7 @@ public static ValueGetter<ReadOnlyMemory<char>> GetSingleValueGetter<T>(IRow cur
(ref ReadOnlyMemory<char> value) =>
{
floatGetter(ref v);
- conversion(ref v, ref dst);
+ conversion(in v, ref dst);
string text = dst.ToString();
value = text.AsMemory();
};
@@ -1384,7 +1384,7 @@ public static ValueGetter<ReadOnlyMemory<char>> GetVectorFlatteningGetter<T>(IRo
x =>
{
var v = x.Value;
- conversion(ref v, ref dst);
+ conversion(in v, ref dst);
return dst.ToString();
}));
value = string.Format("<{0}{1}>", stringRep, suffix).AsMemory();
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Data/RowCursorUtils.cs
@@ -99,7 +99,7 @@ private static ValueGetter<TDst> GetGetterAsCore<TSrc, TDst>(ColumnType typeSrc,
(ref TDst dst) =>
{
getter(ref src);
- conv(ref src, ref dst);
+ conv(in src, ref dst);
};
}

@@ -134,7 +134,7 @@ private static ValueGetter<StringBuilder> GetGetterAsStringBuilderCore<TSrc>(Col
(ref StringBuilder dst) =>
{
getter(ref src);
- conv(ref src, ref dst);
+ conv(in src, ref dst);
};
}

@@ -278,7 +278,7 @@ private static ValueGetter<VBuffer<TDst>> GetVecGetterAsCore<TSrc, TDst>(VectorT
// REVIEW: This would be faster if there were loops for each std conversion.
// Consider adding those to the Conversions class.
for (int i = 0; i < count; i++)
- conv(ref src.Values[i], ref values[i]);
+ conv(in src.Values[i], ref values[i]);

if (!src.IsDense)
{
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
@@ -1228,7 +1228,7 @@ private TableOfContentsEntry CreateRowIndexEntry(string rowIndexName)
int count = _header.RowCount <= int.MaxValue ? (int)_header.RowCount : 0;
KeyType type = new KeyType(DataKind.U8, 0, count);
// We are mapping the row index as expressed as a long, into a key value, so we must increment by one.
- ValueMapper<long, ulong> mapper = (ref long src, ref ulong dst) => dst = (ulong)(src + 1);
+ ValueMapper<long, ulong> mapper = (in long src, ref ulong dst) => dst = (ulong)(src + 1);
var entry = new TableOfContentsEntry(this, rowIndexName, type, mapper);
return entry;
}
@@ -1710,7 +1710,7 @@ private void Get(ref T value)
{
Ectx.Check(_curr != null, _badCursorState);
long src = _curr.RowIndexLim - _remaining - 1;
- _mapper(ref src, ref value);
+ _mapper(in src, ref value);
}

public override Delegate GetGetter()
@@ -610,7 +610,7 @@ private ValueGetter<TValue> GetterDelegateCore<TValue>(int col, ColumnType type)

return (ref TValue value) =>
{
- conv(ref _colValues[col], ref value);
+ conv(in _colValues[col], ref value);
};
}

@@ -1015,7 +1015,7 @@ public int GatherFields(ReadOnlyMemory<char> lineSpan, ReadOnlySpan<char> span,
int csrc = default;
try
{
- Conversions.Instance.Convert(ref spanT, ref csrc);
+ Conversions.Instance.Convert(in spanT, ref csrc);
}
catch
{
32 changes: 16 additions & 16 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
@@ -116,28 +116,28 @@ protected ValueWriterBase(PrimitiveType type, int source, char sep)
Conv = Conversions.Instance.GetStringConversion<T>(type);

var d = default(T);
- Conv(ref d, ref Sb);
+ Conv(in d, ref Sb);
Default = Sb.ToString();
}

- protected void MapText(ref ReadOnlyMemory<char> src, ref StringBuilder sb)
+ protected void MapText(in ReadOnlyMemory<char> src, ref StringBuilder sb)
{
TextSaverUtils.MapText(src.Span, ref sb, Sep);
}

- protected void MapTimeSpan(ref TimeSpan src, ref StringBuilder sb)
+ protected void MapTimeSpan(in TimeSpan src, ref StringBuilder sb)
Member

Just a note: some of these don't seem worthwhile to pass by in, since they are often smaller than a register in size.

Contributor

As we see, though, we benefit from having a uniform signature no matter what the input type is, even if not all types benefit individually, since we can then use the same high-level code for all types.
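
The uniform-signature argument can be sketched with one generic helper that works for any source type. The two-type delegate matches the one in this PR. `UniformMapping` and `MapAll` are hypothetical names, loosely modeled on the element-wise loop in GetVecGetterAsCore.

```csharp
using System;

public delegate void ValueMapper<TSrc, TDst>(in TSrc src, ref TDst dst);

public static class UniformMapping
{
    // One piece of high-level code serves every TSrc, whether it is a
    // one-byte primitive or a large struct, because the delegate shape
    // is the same in all cases.
    public static void MapAll<TSrc, TDst>(ValueMapper<TSrc, TDst> conv, TSrc[] src, TDst[] dst)
    {
        if (src.Length != dst.Length)
            throw new ArgumentException("src and dst must have the same length.");

        for (int i = 0; i < src.Length; i++)
            conv(in src[i], ref dst[i]); // array elements are passed by reference, not copied
    }
}
```

A caller might pass `(in byte s, ref int d) => d = s + 1` for bytes or a mapper over a larger struct type. The loop body is identical in both cases.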

Member Author

These have to be changed to in because these methods are used as ValueMapper delegates.

else if (type.IsTimeSpan)
{
ValueMapper<TimeSpan, StringBuilder> c = MapTimeSpan;
Conv = (ValueMapper<T, StringBuilder>)(Delegate)c;
}
else if (type.IsDateTime)
{
ValueMapper<DateTime, StringBuilder> c = MapDateTime;
Conv = (ValueMapper<T, StringBuilder>)(Delegate)c;
}
else if (type.IsDateTimeZone)
{
ValueMapper<DateTimeOffset, StringBuilder> c = MapDateTimeZone;
Conv = (ValueMapper<T, StringBuilder>)(Delegate)c;
}
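
A self-contained sketch of the pattern quoted above, under stated assumptions: `Writer<T>` is a hypothetical stand-in for the writer class, and the fallback branch is invented for illustration. Only the TimeSpan branch and the body of MapTimeSpan mirror code shown on this page.

```csharp
using System;
using System.Text;

public delegate void ValueMapper<TSrc, TDst>(in TSrc src, ref TDst dst);

public sealed class Writer<T>
{
    public readonly ValueMapper<T, StringBuilder> Conv;

    public Writer()
    {
        if (typeof(T) == typeof(TimeSpan))
        {
            // Method-group conversion: this compiles only because the parameter
            // modifiers of MapTimeSpan (in, ref) match the delegate's exactly.
            ValueMapper<TimeSpan, StringBuilder> c = MapTimeSpan;
            // T is TimeSpan at run time, so the cast through Delegate succeeds.
            Conv = (ValueMapper<T, StringBuilder>)(Delegate)c;
        }
        else
        {
            // Hypothetical fallback for every other type.
            Conv = (in T src, ref StringBuilder sb) =>
            {
                if (sb == null)
                    sb = new StringBuilder();
                else
                    sb.Clear();
                sb.Append(src);
            };
        }
    }

    private static void MapTimeSpan(in TimeSpan src, ref StringBuilder sb)
    {
        if (sb == null)
            sb = new StringBuilder();
        else
            sb.Clear();
        sb.AppendFormat("\"{0:c}\"", src);
    }
}
```

If MapTimeSpan still took `ref TimeSpan src`, the method-group assignment would fail to compile, which is why the small-struct overloads had to change together with the delegate.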

{
- TextSaverUtils.MapTimeSpan(ref src, ref sb);
+ TextSaverUtils.MapTimeSpan(in src, ref sb);
}

- protected void MapDateTime(ref DateTime src, ref StringBuilder sb)
+ protected void MapDateTime(in DateTime src, ref StringBuilder sb)
{
- TextSaverUtils.MapDateTime(ref src, ref sb);
+ TextSaverUtils.MapDateTime(in src, ref sb);
}

- protected void MapDateTimeZone(ref DateTimeOffset src, ref StringBuilder sb)
+ protected void MapDateTimeZone(in DateTimeOffset src, ref StringBuilder sb)
{
- TextSaverUtils.MapDateTimeZone(ref src, ref sb);
+ TextSaverUtils.MapDateTimeZone(in src, ref sb);
}
}

@@ -170,15 +170,15 @@ public override void WriteData(Action<StringBuilder, int> appendItem, out int le
{
for (int i = 0; i < _src.Length; i++)
{
- Conv(ref _src.Values[i], ref Sb);
+ Conv(in _src.Values[i], ref Sb);
appendItem(Sb, i);
}
}
else
{
for (int i = 0; i < _src.Count; i++)
{
- Conv(ref _src.Values[i], ref Sb);
+ Conv(in _src.Values[i], ref Sb);
appendItem(Sb, _src.Indices[i]);
}
}
@@ -195,7 +195,7 @@ public override void WriteHeader(Action<StringBuilder, int> appendItem, out int
var name = _slotNames.Values[i];
if (name.IsEmpty)
continue;
- MapText(ref name, ref Sb);
+ MapText(in name, ref Sb);
int index = _slotNames.IsDense ? i : _slotNames.Indices[i];
appendItem(Sb, index);
}
@@ -218,15 +218,15 @@ public ValueWriter(IRowCursor cursor, PrimitiveType type, int source, char sep)
public override void WriteData(Action<StringBuilder, int> appendItem, out int length)
{
_getSrc(ref _src);
- Conv(ref _src, ref Sb);
+ Conv(in _src, ref Sb);
appendItem(Sb, 0);
length = 1;
}

public override void WriteHeader(Action<StringBuilder, int> appendItem, out int length)
{
var span = _columnName.AsMemory();
- MapText(ref span, ref Sb);
+ MapText(in span, ref Sb);
appendItem(Sb, 0);
length = 1;
}
@@ -846,7 +846,7 @@ internal static void MapText(ReadOnlySpan<char> span, ref StringBuilder sb, char
}
}

- internal static void MapTimeSpan(ref TimeSpan src, ref StringBuilder sb)
+ internal static void MapTimeSpan(in TimeSpan src, ref StringBuilder sb)
{
if (sb == null)
sb = new StringBuilder();
@@ -856,7 +856,7 @@ internal static void MapTimeSpan(ref TimeSpan src, ref StringBuilder sb)
sb.AppendFormat("\"{0:c}\"", src);
}

- internal static void MapDateTime(ref DateTime src, ref StringBuilder sb)
+ internal static void MapDateTime(in DateTime src, ref StringBuilder sb)
{
if (sb == null)
sb = new StringBuilder();
@@ -866,7 +866,7 @@ internal static void MapDateTime(ref DateTime src, ref StringBuilder sb)
sb.AppendFormat("\"{0:o}\"", src);
}

- internal static void MapDateTimeZone(ref DateTimeOffset src, ref StringBuilder sb)
+ internal static void MapDateTimeZone(in DateTimeOffset src, ref StringBuilder sb)
{
if (sb == null)
sb = new StringBuilder();
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs
@@ -165,7 +165,7 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou
(ref T2 v2) =>
{
getSrc(ref v1);
- _map1(ref v1, ref v2);
+ _map1(in v1, ref v2);
};
return getter;
}
@@ -178,8 +178,8 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou
(ref T3 v3) =>
{
getSrc(ref v1);
- _map1(ref v1, ref v2);
- _map2(ref v2, ref v3);
+ _map1(in v1, ref v2);
+ _map2(in v2, ref v3);
};
return getter;
}
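
The two-stage getter above chains one mapper's output into the next mapper's input. A minimal sketch of the same chaining as a standalone helper follows. `MapperChain.Compose` is a hypothetical name and does not exist in LambdaColumnMapper.

```csharp
public delegate void ValueMapper<TSrc, TDst>(in TSrc src, ref TDst dst);

public static class MapperChain
{
    // Builds a T1 -> T3 mapper from a T1 -> T2 mapper and a T2 -> T3 mapper.
    public static ValueMapper<T1, T3> Compose<T1, T2, T3>(
        ValueMapper<T1, T2> map1, ValueMapper<T2, T3> map2)
    {
        // Intermediate buffer reused across calls (so the returned
        // delegate is not safe to call from several threads at once).
        T2 mid = default(T2);

        return (in T1 src, ref T3 dst) =>
        {
            map1(in src, ref mid);  // written by the first stage via 'ref'
            map2(in mid, ref dst);  // then read by the second stage via 'in'
        };
    }
}
```

The intermediate value is passed by 'ref' to the stage that writes it and by 'in' to the stage that reads it, which is the same role change `v2` goes through in the diff.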
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataView/LambdaFilter.cs
@@ -170,7 +170,7 @@ public RowCursor(Impl<T1, T2> parent, IRowCursor input, bool[] active)
_pred =
(in T1 src) =>
{
- conv(ref _src, ref val);
+ conv(in _src, ref val);
return pred(in val);
};
}
16 changes: 8 additions & 8 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
@@ -372,7 +372,7 @@ private static IDataView AddTextColumn<TSrc>(IHostEnvironment env, IDataView inp
{
Contracts.Check(typeSrc.RawType == typeof(TSrc));
return LambdaColumnMapper.Create(env, registrationName, input, inputColName, outputColName, typeSrc, TextType.Instance,
- (ref TSrc src, ref ReadOnlyMemory<char> dst) => dst = value.AsMemory());
+ (in TSrc src, ref ReadOnlyMemory<char> dst) => dst = value.AsMemory());
}

/// <summary>
@@ -406,7 +406,7 @@ private static IDataView AddKeyColumn<TSrc>(IHostEnvironment env, IDataView inpu
{
Contracts.Check(typeSrc.RawType == typeof(TSrc));
return LambdaColumnMapper.Create(env, registrationName, input, inputColName, outputColName, typeSrc,
- new KeyType(DataKind.U4, 0, keyCount), (ref TSrc src, ref uint dst) =>
+ new KeyType(DataKind.U4, 0, keyCount), (in TSrc src, ref uint dst) =>
{
if (value < 0 || value > keyCount)
dst = 0;
@@ -507,7 +507,7 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int
if (def.Equals(default(T)))
{
mapper =
- (ref VBuffer<T> src, ref VBuffer<T> dst) =>
+ (in VBuffer<T> src, ref VBuffer<T> dst) =>
{
Contracts.Assert(src.Length == Utils.Size(map));

@@ -533,7 +533,7 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int
naIndices.Add(j);
}
mapper =
- (ref VBuffer<T> src, ref VBuffer<T> dst) =>
+ (in VBuffer<T> src, ref VBuffer<T> dst) =>
{
Contracts.Assert(src.Length == Utils.Size(map));
var values = dst.Values;
@@ -622,7 +622,7 @@ public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, s
{
var keyMapperCur = keyValueMappers[i];
ValueMapper<uint, uint> mapper =
- (ref uint src, ref uint dst) =>
+ (in uint src, ref uint dst) =>
{
if (src == 0 || src > keyMapperCur.Length)
dst = 0;
@@ -653,7 +653,7 @@ public static void ReconcileKeyValuesWithNoNames(IHostEnvironment env, IDataView
if (!views[i].Schema.TryGetColumnIndex(columnName, out var index))
throw env.Except($"Data view {i} doesn't contain a column '{columnName}'");
ValueMapper<uint, uint> mapper =
- (ref uint src, ref uint dst) =>
+ (in uint src, ref uint dst) =>
{
if (src > keyCount)
dst = 0;
@@ -689,7 +689,7 @@ public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] vi
{
var keyMapperCur = keyValueMappers[i];
ValueMapper<VBuffer<uint>, VBuffer<uint>> mapper =
- (ref VBuffer<uint> src, ref VBuffer<uint> dst) =>
+ (in VBuffer<uint> src, ref VBuffer<uint> dst) =>
{
var values = dst.Values;
if (Utils.Size(values) < src.Count)
@@ -984,7 +984,7 @@ private static IDataView AddVarLengthColumn<TSrc>(IHostEnvironment env, IDataVie
{
return LambdaColumnMapper.Create(env, "ChangeToVarLength", idv, variableSizeVectorColumnName,
variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType(typeSrc.ItemType.AsPrimitive),
- (ref VBuffer<TSrc> src, ref VBuffer<TSrc> dst) => src.CopyTo(ref dst));
+ (in VBuffer<TSrc> src, ref VBuffer<TSrc> dst) => src.CopyTo(ref dst));
}

private static List<string> GetMetricNames(IChannel ch, Schema schema, IRow row, Func<int, bool> ignoreCol,
@@ -1105,7 +1105,7 @@ protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMa
{
perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, schema.Label.Name,
schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.R8,
- (ref uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1 + (double)labelType.AsKey.Min);
+ (in uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1 + (double)labelType.AsKey.Min);
}

var perInstSchema = perInst.Schema;
@@ -509,7 +509,7 @@ private IDataView ExtractRelevantIndex(IDataView data)
var name = data.Schema.GetColumnName(i);
var index = _index ?? type.VectorSize / 2;
output = LambdaColumnMapper.Create(Host, "Quantile Regression", output, name, name, type, NumberType.R8,
- (ref VBuffer<Double> src, ref Double dst) => dst = src.GetItemOrDefault(index));
+ (in VBuffer<Double> src, ref Double dst) => dst = src.GetItemOrDefault(index));
output = new ChooseColumnsByIndexTransform(Host,
new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = new[] { i } }, output);
}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -250,9 +250,9 @@ public ValueMapper<TIn, TOut, TDist> GetMapper<TIn, TOut, TDist>()
Host.Check(typeof(TDist) == typeof(Float));
var map = GetMapper<TIn, Float>();
ValueMapper<TIn, Float, Float> del =
- (ref TIn src, ref Float score, ref Float prob) =>
+ (in TIn src, ref Float score, ref Float prob) =>
{
- map(ref src, ref score);
+ map(in src, ref score);
prob = Calibrator.PredictProbability(score);
};
return (ValueMapper<TIn, TOut, TDist>)(Delegate)del;
@@ -153,7 +153,7 @@ private ValueGetter<TDst> GetValueGetter<TSrc, TDst>(IRow input, int colSrc)
(ref TDst dst) =>
{
featureGetter(ref features);
- map(ref features, ref dst);
+ map(in features, ref dst);
};
}

@@ -546,7 +546,7 @@ private static void EnsureCachedResultValueMapper(ValueMapper<VBuffer<Float>, Fl
if (featureGetter != null)
featureGetter(ref features);

- mapper(ref features, ref score, ref prob);
+ mapper(in features, ref score, ref prob);
cachedPosition = input.Position;
}
}
@@ -667,7 +667,7 @@ protected override Delegate GetPredictionGetter(IRow input, int colSrc)
{
featureGetter(ref features);
Contracts.Check(features.Length == featureCount || featureCount == 0);
- map(ref features, ref value);
+ map(in features, ref value);
};
return del;
}