Skip to content

Commit bee7f17

Browse files
authored
Refactor CpuMathUtils (#1229)
* Refactor CpuMathUtils - Allow it to take Spans instead of arrays. - Remove redundant overloads - When multiple spans are accepted, always use an explicit count parameter instead of one being chosen as having the right length. Working towards #608 * Use MemoryMarshal.GetReference to avoid perf hit when pinning Span. * PR feedback
1 parent 7e9e468 commit bee7f17

File tree

17 files changed

+408
-824
lines changed

17 files changed

+408
-824
lines changed

src/Microsoft.ML.Core/Utilities/Contracts.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,19 @@ public static void AssertNonWhiteSpace(this IExceptionContext ctx, string s, str
945945
DbgFailEmpty(ctx, msg);
946946
}
947947

948+
[Conditional("DEBUG")]
949+
public static void AssertNonEmpty<T>(ReadOnlySpan<T> args)
950+
{
951+
if (args.IsEmpty)
952+
DbgFail();
953+
}
954+
[Conditional("DEBUG")]
955+
public static void AssertNonEmpty<T>(Span<T> args)
956+
{
957+
if (args.IsEmpty)
958+
DbgFail();
959+
}
960+
948961
[Conditional("DEBUG")]
949962
public static void AssertNonEmpty<T>(ICollection<T> args)
950963
{

src/Microsoft.ML.CpuMath/AvxIntrinsics.cs

Lines changed: 74 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
using System;
1313
using System.Runtime.CompilerServices;
14+
using System.Runtime.InteropServices;
1415
using System.Runtime.Intrinsics;
1516
using System.Runtime.Intrinsics.X86;
1617
using nuint = System.UInt64;
@@ -448,7 +449,7 @@ public static unsafe void MatMulTranPX(bool add, AlignedArray mat, int[] rgposSr
448449
// dst[i] += scale
449450
public static unsafe void AddScalarU(float scalar, Span<float> dst)
450451
{
451-
fixed (float* pdst = dst)
452+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
452453
{
453454
float* pDstEnd = pdst + dst.Length;
454455
float* pDstCurrent = pdst;
@@ -490,7 +491,7 @@ public static unsafe void Scale(float scale, Span<float> dst)
490491
{
491492
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
492493
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
493-
fixed (float* pd = dst)
494+
fixed (float* pd = &MemoryMarshal.GetReference(dst))
494495
{
495496
float* pDstCurrent = pd;
496497
int length = dst.Length;
@@ -606,12 +607,12 @@ public static unsafe void Scale(float scale, Span<float> dst)
606607
}
607608
}
608609

609-
public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> dst)
610+
public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
610611
{
611-
fixed (float* psrc = src)
612-
fixed (float* pdst = dst)
612+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
613+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
613614
{
614-
float* pDstEnd = pdst + dst.Length;
615+
float* pDstEnd = pdst + count;
615616
float* pSrcCurrent = psrc;
616617
float* pDstCurrent = pdst;
617618

@@ -654,7 +655,7 @@ public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> ds
654655
// dst[i] = a * (dst[i] + b)
655656
public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
656657
{
657-
fixed (float* pdst = dst)
658+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
658659
{
659660
float* pDstEnd = pdst + dst.Length;
660661
float* pDstCurrent = pdst;
@@ -697,14 +698,14 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
697698
}
698699
}
699700

700-
public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> dst)
701+
public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
701702
{
702-
fixed (float* psrc = src)
703-
fixed (float* pdst = dst)
703+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
704+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
704705
{
705706
float* pSrcCurrent = psrc;
706707
float* pDstCurrent = pdst;
707-
float* pEnd = pdst + dst.Length;
708+
float* pEnd = pdst + count;
708709

709710
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
710711

@@ -751,13 +752,13 @@ public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> ds
751752
}
752753
}
753754

754-
public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float> dst, Span<float> result)
755+
public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<float> dst, Span<float> result, int count)
755756
{
756-
fixed (float* psrc = src)
757-
fixed (float* pdst = dst)
758-
fixed (float* pres = result)
757+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
758+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
759+
fixed (float* pres = &MemoryMarshal.GetReference(result))
759760
{
760-
float* pResEnd = pres + result.Length;
761+
float* pResEnd = pres + count;
761762
float* pSrcCurrent = psrc;
762763
float* pDstCurrent = pdst;
763764
float* pResCurrent = pres;
@@ -807,16 +808,16 @@ public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float
807808
}
808809
}
809810

810-
public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx, Span<float> dst)
811+
public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
811812
{
812-
fixed (float* psrc = src)
813-
fixed (int* pidx = idx)
814-
fixed (float* pdst = dst)
813+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
814+
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
815+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
815816
{
816817
float* pSrcCurrent = psrc;
817818
int* pIdxCurrent = pidx;
818819
float* pDstCurrent = pdst;
819-
int* pEnd = pidx + idx.Length;
820+
int* pEnd = pidx + count;
820821

821822
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
822823

@@ -858,14 +859,14 @@ public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx
858859
}
859860
}
860861

861-
public static unsafe void AddU(Span<float> src, Span<float> dst)
862+
public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int count)
862863
{
863-
fixed (float* psrc = src)
864-
fixed (float* pdst = dst)
864+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
865+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
865866
{
866867
float* pSrcCurrent = psrc;
867868
float* pDstCurrent = pdst;
868-
float* pEnd = psrc + src.Length;
869+
float* pEnd = psrc + count;
869870

870871
while (pSrcCurrent + 8 <= pEnd)
871872
{
@@ -905,16 +906,16 @@ public static unsafe void AddU(Span<float> src, Span<float> dst)
905906
}
906907
}
907908

908-
public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
909+
public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
909910
{
910-
fixed (float* psrc = src)
911-
fixed (int* pidx = idx)
912-
fixed (float* pdst = dst)
911+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
912+
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
913+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
913914
{
914915
float* pSrcCurrent = psrc;
915916
int* pIdxCurrent = pidx;
916917
float* pDstCurrent = pdst;
917-
int* pEnd = pidx + idx.Length;
918+
int* pEnd = pidx + count;
918919

919920
while (pIdxCurrent + 8 <= pEnd)
920921
{
@@ -950,16 +951,16 @@ public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
950951
}
951952
}
952953

953-
public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Span<float> dst)
954+
public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan<float> src2, Span<float> dst, int count)
954955
{
955-
fixed (float* psrc1 = src1)
956-
fixed (float* psrc2 = src2)
957-
fixed (float* pdst = dst)
956+
fixed (float* psrc1 = &MemoryMarshal.GetReference(src1))
957+
fixed (float* psrc2 = &MemoryMarshal.GetReference(src2))
958+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
958959
{
959960
float* pSrc1Current = psrc1;
960961
float* pSrc2Current = psrc2;
961962
float* pDstCurrent = pdst;
962-
float* pEnd = pdst + dst.Length;
963+
float* pEnd = pdst + count;
963964

964965
while (pDstCurrent + 8 <= pEnd)
965966
{
@@ -999,9 +1000,9 @@ public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Sp
9991000
}
10001001
}
10011002

1002-
public static unsafe float SumU(Span<float> src)
1003+
public static unsafe float SumU(ReadOnlySpan<float> src)
10031004
{
1004-
fixed (float* psrc = src)
1005+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
10051006
{
10061007
float* pSrcEnd = psrc + src.Length;
10071008
float* pSrcCurrent = psrc;
@@ -1037,9 +1038,9 @@ public static unsafe float SumU(Span<float> src)
10371038
}
10381039
}
10391040

1040-
public static unsafe float SumSqU(Span<float> src)
1041+
public static unsafe float SumSqU(ReadOnlySpan<float> src)
10411042
{
1042-
fixed (float* psrc = src)
1043+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
10431044
{
10441045
float* pSrcEnd = psrc + src.Length;
10451046
float* pSrcCurrent = psrc;
@@ -1081,9 +1082,9 @@ public static unsafe float SumSqU(Span<float> src)
10811082
}
10821083
}
10831084

1084-
public static unsafe float SumSqDiffU(float mean, Span<float> src)
1085+
public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src)
10851086
{
1086-
fixed (float* psrc = src)
1087+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
10871088
{
10881089
float* pSrcEnd = psrc + src.Length;
10891090
float* pSrcCurrent = psrc;
@@ -1130,9 +1131,9 @@ public static unsafe float SumSqDiffU(float mean, Span<float> src)
11301131
}
11311132
}
11321133

1133-
public static unsafe float SumAbsU(Span<float> src)
1134+
public static unsafe float SumAbsU(ReadOnlySpan<float> src)
11341135
{
1135-
fixed (float* psrc = src)
1136+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
11361137
{
11371138
float* pSrcEnd = psrc + src.Length;
11381139
float* pSrcCurrent = psrc;
@@ -1174,9 +1175,9 @@ public static unsafe float SumAbsU(Span<float> src)
11741175
}
11751176
}
11761177

1177-
public static unsafe float SumAbsDiffU(float mean, Span<float> src)
1178+
public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan<float> src)
11781179
{
1179-
fixed (float* psrc = src)
1180+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
11801181
{
11811182
float* pSrcEnd = psrc + src.Length;
11821183
float* pSrcCurrent = psrc;
@@ -1223,9 +1224,9 @@ public static unsafe float SumAbsDiffU(float mean, Span<float> src)
12231224
}
12241225
}
12251226

1226-
public static unsafe float MaxAbsU(Span<float> src)
1227+
public static unsafe float MaxAbsU(ReadOnlySpan<float> src)
12271228
{
1228-
fixed (float* psrc = src)
1229+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
12291230
{
12301231
float* pSrcEnd = psrc + src.Length;
12311232
float* pSrcCurrent = psrc;
@@ -1267,9 +1268,9 @@ public static unsafe float MaxAbsU(Span<float> src)
12671268
}
12681269
}
12691270

1270-
public static unsafe float MaxAbsDiffU(float mean, Span<float> src)
1271+
public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src)
12711272
{
1272-
fixed (float* psrc = src)
1273+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
12731274
{
12741275
float* pSrcEnd = psrc + src.Length;
12751276
float* pSrcCurrent = psrc;
@@ -1316,14 +1317,14 @@ public static unsafe float MaxAbsDiffU(float mean, Span<float> src)
13161317
}
13171318
}
13181319

1319-
public static unsafe float DotU(Span<float> src, Span<float> dst)
1320+
public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
13201321
{
1321-
fixed (float* psrc = src)
1322-
fixed (float* pdst = dst)
1322+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
1323+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
13231324
{
13241325
float* pSrcCurrent = psrc;
13251326
float* pDstCurrent = pdst;
1326-
float* pSrcEnd = psrc + src.Length;
1327+
float* pSrcEnd = psrc + count;
13271328

13281329
Vector256<float> result256 = Avx.SetZeroVector256<float>();
13291330

@@ -1371,16 +1372,16 @@ public static unsafe float DotU(Span<float> src, Span<float> dst)
13711372
}
13721373
}
13731374

1374-
public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx)
1375+
public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, ReadOnlySpan<int> idx, int count)
13751376
{
1376-
fixed (float* psrc = src)
1377-
fixed (float* pdst = dst)
1378-
fixed (int* pidx = idx)
1377+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
1378+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
1379+
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
13791380
{
13801381
float* pSrcCurrent = psrc;
13811382
float* pDstCurrent = pdst;
13821383
int* pIdxCurrent = pidx;
1383-
int* pIdxEnd = pidx + idx.Length;
1384+
int* pIdxEnd = pidx + count;
13841385

13851386
Vector256<float> result256 = Avx.SetZeroVector256<float>();
13861387

@@ -1428,14 +1429,14 @@ public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx
14281429
}
14291430
}
14301431

1431-
public static unsafe float Dist2(Span<float> src, Span<float> dst)
1432+
public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
14321433
{
1433-
fixed (float* psrc = src)
1434-
fixed (float* pdst = dst)
1434+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
1435+
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
14351436
{
14361437
float* pSrcCurrent = psrc;
14371438
float* pDstCurrent = pdst;
1438-
float* pSrcEnd = psrc + src.Length;
1439+
float* pSrcEnd = psrc + count;
14391440

14401441
Vector256<float> sqDistanceVector256 = Avx.SetZeroVector256<float>();
14411442

@@ -1482,13 +1483,13 @@ public static unsafe float Dist2(Span<float> src, Span<float> dst)
14821483
}
14831484
}
14841485

1485-
public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, float threshold, Span<float> v, Span<float> w)
1486+
public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlySpan<float> src, float threshold, Span<float> v, Span<float> w)
14861487
{
1487-
fixed (float* psrc = src)
1488-
fixed (float* pdst1 = v)
1489-
fixed (float* pdst2 = w)
1488+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
1489+
fixed (float* pdst1 = &MemoryMarshal.GetReference(v))
1490+
fixed (float* pdst2 = &MemoryMarshal.GetReference(w))
14901491
{
1491-
float* pSrcEnd = psrc + src.Length;
1492+
float* pSrcEnd = psrc + count;
14921493
float* pSrcCurrent = psrc;
14931494
float* pDst1Current = pdst1;
14941495
float* pDst2Current = pdst2;
@@ -1544,14 +1545,14 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, flo
15441545
}
15451546
}
15461547

1547-
public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span<float> src, Span<int> indices, float threshold, Span<float> v, Span<float> w)
1548+
public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnlySpan<float> src, ReadOnlySpan<int> indices, float threshold, Span<float> v, Span<float> w)
15481549
{
1549-
fixed (float* psrc = src)
1550-
fixed (int* pidx = indices)
1551-
fixed (float* pdst1 = v)
1552-
fixed (float* pdst2 = w)
1550+
fixed (float* psrc = &MemoryMarshal.GetReference(src))
1551+
fixed (int* pidx = &MemoryMarshal.GetReference(indices))
1552+
fixed (float* pdst1 = &MemoryMarshal.GetReference(v))
1553+
fixed (float* pdst2 = &MemoryMarshal.GetReference(w))
15531554
{
1554-
int* pIdxEnd = pidx + indices.Length;
1555+
int* pIdxEnd = pidx + count;
15551556
float* pSrcCurrent = psrc;
15561557
int* pIdxCurrent = pidx;
15571558

0 commit comments

Comments
 (0)