Skip to content

Commit 9c7dc9c

Browse files
committed
Refactor CpuMathUtils
- Allow it to take Spans instead of arrays. - Remove redundant overloads - When multiple spans are accepted, always use an explicit count parameter instead of one being chosen as having the right length. Working towards dotnet#608
1 parent 0b84350 commit 9c7dc9c

File tree

17 files changed

+320
-740
lines changed

17 files changed

+320
-740
lines changed

src/Microsoft.ML.Core/Utilities/Contracts.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,19 @@ public static void AssertNonWhiteSpace(this IExceptionContext ctx, string s, str
945945
DbgFailEmpty(ctx, msg);
946946
}
947947

948+
[Conditional("DEBUG")]
949+
public static void AssertNonEmpty<T>(ReadOnlySpan<T> args)
950+
{
951+
if (args.IsEmpty)
952+
DbgFail();
953+
}
954+
[Conditional("DEBUG")]
955+
public static void AssertNonEmpty<T>(Span<T> args)
956+
{
957+
if (args.IsEmpty)
958+
DbgFail();
959+
}
960+
948961
[Conditional("DEBUG")]
949962
public static void AssertNonEmpty<T>(ICollection<T> args)
950963
{

src/Microsoft.ML.CpuMath/AvxIntrinsics.cs

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -606,12 +606,12 @@ public static unsafe void Scale(float scale, Span<float> dst)
606606
}
607607
}
608608

609-
public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> dst)
609+
public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
610610
{
611611
fixed (float* psrc = src)
612612
fixed (float* pdst = dst)
613613
{
614-
float* pDstEnd = pdst + dst.Length;
614+
float* pDstEnd = pdst + count;
615615
float* pSrcCurrent = psrc;
616616
float* pDstCurrent = pdst;
617617

@@ -697,14 +697,14 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
697697
}
698698
}
699699

700-
public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> dst)
700+
public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
701701
{
702702
fixed (float* psrc = src)
703703
fixed (float* pdst = dst)
704704
{
705705
float* pSrcCurrent = psrc;
706706
float* pDstCurrent = pdst;
707-
float* pEnd = pdst + dst.Length;
707+
float* pEnd = pdst + count;
708708

709709
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
710710

@@ -751,13 +751,13 @@ public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> ds
751751
}
752752
}
753753

754-
public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float> dst, Span<float> result)
754+
public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<float> dst, Span<float> result, int count)
755755
{
756756
fixed (float* psrc = src)
757757
fixed (float* pdst = dst)
758758
fixed (float* pres = result)
759759
{
760-
float* pResEnd = pres + result.Length;
760+
float* pResEnd = pres + count;
761761
float* pSrcCurrent = psrc;
762762
float* pDstCurrent = pdst;
763763
float* pResCurrent = pres;
@@ -807,7 +807,7 @@ public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float
807807
}
808808
}
809809

810-
public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx, Span<float> dst)
810+
public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
811811
{
812812
fixed (float* psrc = src)
813813
fixed (int* pidx = idx)
@@ -816,7 +816,7 @@ public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx
816816
float* pSrcCurrent = psrc;
817817
int* pIdxCurrent = pidx;
818818
float* pDstCurrent = pdst;
819-
int* pEnd = pidx + idx.Length;
819+
int* pEnd = pidx + count;
820820

821821
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
822822

@@ -858,14 +858,14 @@ public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx
858858
}
859859
}
860860

861-
public static unsafe void AddU(Span<float> src, Span<float> dst)
861+
public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int count)
862862
{
863863
fixed (float* psrc = src)
864864
fixed (float* pdst = dst)
865865
{
866866
float* pSrcCurrent = psrc;
867867
float* pDstCurrent = pdst;
868-
float* pEnd = psrc + src.Length;
868+
float* pEnd = psrc + count;
869869

870870
while (pSrcCurrent + 8 <= pEnd)
871871
{
@@ -905,7 +905,7 @@ public static unsafe void AddU(Span<float> src, Span<float> dst)
905905
}
906906
}
907907

908-
public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
908+
public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
909909
{
910910
fixed (float* psrc = src)
911911
fixed (int* pidx = idx)
@@ -914,7 +914,7 @@ public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
914914
float* pSrcCurrent = psrc;
915915
int* pIdxCurrent = pidx;
916916
float* pDstCurrent = pdst;
917-
int* pEnd = pidx + idx.Length;
917+
int* pEnd = pidx + count;
918918

919919
while (pIdxCurrent + 8 <= pEnd)
920920
{
@@ -950,7 +950,7 @@ public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
950950
}
951951
}
952952

953-
public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Span<float> dst)
953+
public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan<float> src2, Span<float> dst, int count)
954954
{
955955
fixed (float* psrc1 = src1)
956956
fixed (float* psrc2 = src2)
@@ -959,7 +959,7 @@ public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Sp
959959
float* pSrc1Current = psrc1;
960960
float* pSrc2Current = psrc2;
961961
float* pDstCurrent = pdst;
962-
float* pEnd = pdst + dst.Length;
962+
float* pEnd = pdst + count;
963963

964964
while (pDstCurrent + 8 <= pEnd)
965965
{
@@ -999,7 +999,7 @@ public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Sp
999999
}
10001000
}
10011001

1002-
public static unsafe float SumU(Span<float> src)
1002+
public static unsafe float SumU(ReadOnlySpan<float> src)
10031003
{
10041004
fixed (float* psrc = src)
10051005
{
@@ -1037,7 +1037,7 @@ public static unsafe float SumU(Span<float> src)
10371037
}
10381038
}
10391039

1040-
public static unsafe float SumSqU(Span<float> src)
1040+
public static unsafe float SumSqU(ReadOnlySpan<float> src)
10411041
{
10421042
fixed (float* psrc = src)
10431043
{
@@ -1081,7 +1081,7 @@ public static unsafe float SumSqU(Span<float> src)
10811081
}
10821082
}
10831083

1084-
public static unsafe float SumSqDiffU(float mean, Span<float> src)
1084+
public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src)
10851085
{
10861086
fixed (float* psrc = src)
10871087
{
@@ -1130,7 +1130,7 @@ public static unsafe float SumSqDiffU(float mean, Span<float> src)
11301130
}
11311131
}
11321132

1133-
public static unsafe float SumAbsU(Span<float> src)
1133+
public static unsafe float SumAbsU(ReadOnlySpan<float> src)
11341134
{
11351135
fixed (float* psrc = src)
11361136
{
@@ -1174,7 +1174,7 @@ public static unsafe float SumAbsU(Span<float> src)
11741174
}
11751175
}
11761176

1177-
public static unsafe float SumAbsDiffU(float mean, Span<float> src)
1177+
public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan<float> src)
11781178
{
11791179
fixed (float* psrc = src)
11801180
{
@@ -1223,7 +1223,7 @@ public static unsafe float SumAbsDiffU(float mean, Span<float> src)
12231223
}
12241224
}
12251225

1226-
public static unsafe float MaxAbsU(Span<float> src)
1226+
public static unsafe float MaxAbsU(ReadOnlySpan<float> src)
12271227
{
12281228
fixed (float* psrc = src)
12291229
{
@@ -1267,7 +1267,7 @@ public static unsafe float MaxAbsU(Span<float> src)
12671267
}
12681268
}
12691269

1270-
public static unsafe float MaxAbsDiffU(float mean, Span<float> src)
1270+
public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src)
12711271
{
12721272
fixed (float* psrc = src)
12731273
{
@@ -1316,14 +1316,14 @@ public static unsafe float MaxAbsDiffU(float mean, Span<float> src)
13161316
}
13171317
}
13181318

1319-
public static unsafe float DotU(Span<float> src, Span<float> dst)
1319+
public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
13201320
{
13211321
fixed (float* psrc = src)
13221322
fixed (float* pdst = dst)
13231323
{
13241324
float* pSrcCurrent = psrc;
13251325
float* pDstCurrent = pdst;
1326-
float* pSrcEnd = psrc + src.Length;
1326+
float* pSrcEnd = psrc + count;
13271327

13281328
Vector256<float> result256 = Avx.SetZeroVector256<float>();
13291329

@@ -1371,7 +1371,7 @@ public static unsafe float DotU(Span<float> src, Span<float> dst)
13711371
}
13721372
}
13731373

1374-
public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx)
1374+
public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, ReadOnlySpan<int> idx, int count)
13751375
{
13761376
fixed (float* psrc = src)
13771377
fixed (float* pdst = dst)
@@ -1380,7 +1380,7 @@ public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx
13801380
float* pSrcCurrent = psrc;
13811381
float* pDstCurrent = pdst;
13821382
int* pIdxCurrent = pidx;
1383-
int* pIdxEnd = pidx + idx.Length;
1383+
int* pIdxEnd = pidx + count;
13841384

13851385
Vector256<float> result256 = Avx.SetZeroVector256<float>();
13861386

@@ -1428,14 +1428,14 @@ public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx
14281428
}
14291429
}
14301430

1431-
public static unsafe float Dist2(Span<float> src, Span<float> dst)
1431+
public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
14321432
{
14331433
fixed (float* psrc = src)
14341434
fixed (float* pdst = dst)
14351435
{
14361436
float* pSrcCurrent = psrc;
14371437
float* pDstCurrent = pdst;
1438-
float* pSrcEnd = psrc + src.Length;
1438+
float* pSrcEnd = psrc + count;
14391439

14401440
Vector256<float> sqDistanceVector256 = Avx.SetZeroVector256<float>();
14411441

@@ -1482,13 +1482,13 @@ public static unsafe float Dist2(Span<float> src, Span<float> dst)
14821482
}
14831483
}
14841484

1485-
public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, float threshold, Span<float> v, Span<float> w)
1485+
public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlySpan<float> src, float threshold, Span<float> v, Span<float> w)
14861486
{
14871487
fixed (float* psrc = src)
14881488
fixed (float* pdst1 = v)
14891489
fixed (float* pdst2 = w)
14901490
{
1491-
float* pSrcEnd = psrc + src.Length;
1491+
float* pSrcEnd = psrc + count;
14921492
float* pSrcCurrent = psrc;
14931493
float* pDst1Current = pdst1;
14941494
float* pDst2Current = pdst2;
@@ -1544,14 +1544,14 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, flo
15441544
}
15451545
}
15461546

1547-
public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span<float> src, Span<int> indices, float threshold, Span<float> v, Span<float> w)
1547+
public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnlySpan<float> src, ReadOnlySpan<int> indices, float threshold, Span<float> v, Span<float> w)
15481548
{
15491549
fixed (float* psrc = src)
15501550
fixed (int* pidx = indices)
15511551
fixed (float* pdst1 = v)
15521552
fixed (float* pdst2 = w)
15531553
{
1554-
int* pIdxEnd = pidx + indices.Length;
1554+
int* pIdxEnd = pidx + count;
15551555
float* pSrcCurrent = psrc;
15561556
int* pIdxCurrent = pidx;
15571557

0 commit comments

Comments
 (0)