|
11 | 11 |
|
12 | 12 | using System;
|
13 | 13 | using System.Runtime.CompilerServices;
|
| 14 | +using System.Runtime.InteropServices; |
14 | 15 | using System.Runtime.Intrinsics;
|
15 | 16 | using System.Runtime.Intrinsics.X86;
|
16 | 17 | using nuint = System.UInt64;
|
@@ -448,7 +449,7 @@ public static unsafe void MatMulTranPX(bool add, AlignedArray mat, int[] rgposSr
|
448 | 449 | // dst[i] += scale
|
449 | 450 | public static unsafe void AddScalarU(float scalar, Span<float> dst)
|
450 | 451 | {
|
451 |
| - fixed (float* pdst = dst) |
| 452 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
452 | 453 | {
|
453 | 454 | float* pDstEnd = pdst + dst.Length;
|
454 | 455 | float* pDstCurrent = pdst;
|
@@ -490,7 +491,7 @@ public static unsafe void Scale(float scale, Span<float> dst)
|
490 | 491 | {
|
491 | 492 | fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
|
492 | 493 | fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
|
493 |
| - fixed (float* pd = dst) |
| 494 | + fixed (float* pd = &MemoryMarshal.GetReference(dst)) |
494 | 495 | {
|
495 | 496 | float* pDstCurrent = pd;
|
496 | 497 | int length = dst.Length;
|
@@ -606,12 +607,12 @@ public static unsafe void Scale(float scale, Span<float> dst)
|
606 | 607 | }
|
607 | 608 | }
|
608 | 609 |
|
609 |
| - public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> dst) |
| 610 | + public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count) |
610 | 611 | {
|
611 |
| - fixed (float* psrc = src) |
612 |
| - fixed (float* pdst = dst) |
| 612 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
| 613 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
613 | 614 | {
|
614 |
| - float* pDstEnd = pdst + dst.Length; |
| 615 | + float* pDstEnd = pdst + count; |
615 | 616 | float* pSrcCurrent = psrc;
|
616 | 617 | float* pDstCurrent = pdst;
|
617 | 618 |
|
@@ -654,7 +655,7 @@ public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> ds
|
654 | 655 | // dst[i] = a * (dst[i] + b)
|
655 | 656 | public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
|
656 | 657 | {
|
657 |
| - fixed (float* pdst = dst) |
| 658 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
658 | 659 | {
|
659 | 660 | float* pDstEnd = pdst + dst.Length;
|
660 | 661 | float* pDstCurrent = pdst;
|
@@ -697,14 +698,14 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
|
697 | 698 | }
|
698 | 699 | }
|
699 | 700 |
|
700 |
| - public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> dst) |
| 701 | + public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count) |
701 | 702 | {
|
702 |
| - fixed (float* psrc = src) |
703 |
| - fixed (float* pdst = dst) |
| 703 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
| 704 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
704 | 705 | {
|
705 | 706 | float* pSrcCurrent = psrc;
|
706 | 707 | float* pDstCurrent = pdst;
|
707 |
| - float* pEnd = pdst + dst.Length; |
| 708 | + float* pEnd = pdst + count; |
708 | 709 |
|
709 | 710 | Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
|
710 | 711 |
|
@@ -751,13 +752,13 @@ public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> ds
|
751 | 752 | }
|
752 | 753 | }
|
753 | 754 |
|
754 |
| - public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float> dst, Span<float> result) |
| 755 | + public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<float> dst, Span<float> result, int count) |
755 | 756 | {
|
756 |
| - fixed (float* psrc = src) |
757 |
| - fixed (float* pdst = dst) |
758 |
| - fixed (float* pres = result) |
| 757 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
| 758 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
| 759 | + fixed (float* pres = &MemoryMarshal.GetReference(result)) |
759 | 760 | {
|
760 |
| - float* pResEnd = pres + result.Length; |
| 761 | + float* pResEnd = pres + count; |
761 | 762 | float* pSrcCurrent = psrc;
|
762 | 763 | float* pDstCurrent = pdst;
|
763 | 764 | float* pResCurrent = pres;
|
@@ -807,16 +808,16 @@ public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float
|
807 | 808 | }
|
808 | 809 | }
|
809 | 810 |
|
810 |
| - public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx, Span<float> dst) |
| 811 | + public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count) |
811 | 812 | {
|
812 |
| - fixed (float* psrc = src) |
813 |
| - fixed (int* pidx = idx) |
814 |
| - fixed (float* pdst = dst) |
| 813 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
| 814 | + fixed (int* pidx = &MemoryMarshal.GetReference(idx)) |
| 815 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
815 | 816 | {
|
816 | 817 | float* pSrcCurrent = psrc;
|
817 | 818 | int* pIdxCurrent = pidx;
|
818 | 819 | float* pDstCurrent = pdst;
|
819 |
| - int* pEnd = pidx + idx.Length; |
| 820 | + int* pEnd = pidx + count; |
820 | 821 |
|
821 | 822 | Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
|
822 | 823 |
|
@@ -858,14 +859,14 @@ public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx
|
858 | 859 | }
|
859 | 860 | }
|
860 | 861 |
|
861 |
| - public static unsafe void AddU(Span<float> src, Span<float> dst) |
| 862 | + public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int count) |
862 | 863 | {
|
863 |
| - fixed (float* psrc = src) |
864 |
| - fixed (float* pdst = dst) |
| 864 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
| 865 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
865 | 866 | {
|
866 | 867 | float* pSrcCurrent = psrc;
|
867 | 868 | float* pDstCurrent = pdst;
|
868 |
| - float* pEnd = psrc + src.Length; |
| 869 | + float* pEnd = psrc + count; |
869 | 870 |
|
870 | 871 | while (pSrcCurrent + 8 <= pEnd)
|
871 | 872 | {
|
@@ -905,16 +906,16 @@ public static unsafe void AddU(Span<float> src, Span<float> dst)
|
905 | 906 | }
|
906 | 907 | }
|
907 | 908 |
|
908 |
| - public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst) |
| 909 | + public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count) |
909 | 910 | {
|
910 |
| - fixed (float* psrc = src) |
911 |
| - fixed (int* pidx = idx) |
912 |
| - fixed (float* pdst = dst) |
| 911 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
| 912 | + fixed (int* pidx = &MemoryMarshal.GetReference(idx)) |
| 913 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
913 | 914 | {
|
914 | 915 | float* pSrcCurrent = psrc;
|
915 | 916 | int* pIdxCurrent = pidx;
|
916 | 917 | float* pDstCurrent = pdst;
|
917 |
| - int* pEnd = pidx + idx.Length; |
| 918 | + int* pEnd = pidx + count; |
918 | 919 |
|
919 | 920 | while (pIdxCurrent + 8 <= pEnd)
|
920 | 921 | {
|
@@ -950,16 +951,16 @@ public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
|
950 | 951 | }
|
951 | 952 | }
|
952 | 953 |
|
953 |
| - public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Span<float> dst) |
| 954 | + public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan<float> src2, Span<float> dst, int count) |
954 | 955 | {
|
955 |
| - fixed (float* psrc1 = src1) |
956 |
| - fixed (float* psrc2 = src2) |
957 |
| - fixed (float* pdst = dst) |
| 956 | + fixed (float* psrc1 = &MemoryMarshal.GetReference(src1)) |
| 957 | + fixed (float* psrc2 = &MemoryMarshal.GetReference(src2)) |
| 958 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
958 | 959 | {
|
959 | 960 | float* pSrc1Current = psrc1;
|
960 | 961 | float* pSrc2Current = psrc2;
|
961 | 962 | float* pDstCurrent = pdst;
|
962 |
| - float* pEnd = pdst + dst.Length; |
| 963 | + float* pEnd = pdst + count; |
963 | 964 |
|
964 | 965 | while (pDstCurrent + 8 <= pEnd)
|
965 | 966 | {
|
@@ -999,9 +1000,9 @@ public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Sp
|
999 | 1000 | }
|
1000 | 1001 | }
|
1001 | 1002 |
|
1002 |
| - public static unsafe float SumU(Span<float> src) |
| 1003 | + public static unsafe float SumU(ReadOnlySpan<float> src) |
1003 | 1004 | {
|
1004 |
| - fixed (float* psrc = src) |
| 1005 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
1005 | 1006 | {
|
1006 | 1007 | float* pSrcEnd = psrc + src.Length;
|
1007 | 1008 | float* pSrcCurrent = psrc;
|
@@ -1037,9 +1038,9 @@ public static unsafe float SumU(Span<float> src)
|
1037 | 1038 | }
|
1038 | 1039 | }
|
1039 | 1040 |
|
1040 |
| - public static unsafe float SumSqU(Span<float> src) |
| 1041 | + public static unsafe float SumSqU(ReadOnlySpan<float> src) |
1041 | 1042 | {
|
1042 |
| - fixed (float* psrc = src) |
| 1043 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
1043 | 1044 | {
|
1044 | 1045 | float* pSrcEnd = psrc + src.Length;
|
1045 | 1046 | float* pSrcCurrent = psrc;
|
@@ -1081,9 +1082,9 @@ public static unsafe float SumSqU(Span<float> src)
|
1081 | 1082 | }
|
1082 | 1083 | }
|
1083 | 1084 |
|
1084 |
| - public static unsafe float SumSqDiffU(float mean, Span<float> src) |
| 1085 | + public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src) |
1085 | 1086 | {
|
1086 |
| - fixed (float* psrc = src) |
| 1087 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
1087 | 1088 | {
|
1088 | 1089 | float* pSrcEnd = psrc + src.Length;
|
1089 | 1090 | float* pSrcCurrent = psrc;
|
@@ -1130,9 +1131,9 @@ public static unsafe float SumSqDiffU(float mean, Span<float> src)
|
1130 | 1131 | }
|
1131 | 1132 | }
|
1132 | 1133 |
|
1133 |
| - public static unsafe float SumAbsU(Span<float> src) |
| 1134 | + public static unsafe float SumAbsU(ReadOnlySpan<float> src) |
1134 | 1135 | {
|
1135 |
| - fixed (float* psrc = src) |
| 1136 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
1136 | 1137 | {
|
1137 | 1138 | float* pSrcEnd = psrc + src.Length;
|
1138 | 1139 | float* pSrcCurrent = psrc;
|
@@ -1174,9 +1175,9 @@ public static unsafe float SumAbsU(Span<float> src)
|
1174 | 1175 | }
|
1175 | 1176 | }
|
1176 | 1177 |
|
1177 |
| - public static unsafe float SumAbsDiffU(float mean, Span<float> src) |
| 1178 | + public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan<float> src) |
1178 | 1179 | {
|
1179 |
| - fixed (float* psrc = src) |
| 1180 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
1180 | 1181 | {
|
1181 | 1182 | float* pSrcEnd = psrc + src.Length;
|
1182 | 1183 | float* pSrcCurrent = psrc;
|
@@ -1223,9 +1224,9 @@ public static unsafe float SumAbsDiffU(float mean, Span<float> src)
|
1223 | 1224 | }
|
1224 | 1225 | }
|
1225 | 1226 |
|
1226 |
| - public static unsafe float MaxAbsU(Span<float> src) |
| 1227 | + public static unsafe float MaxAbsU(ReadOnlySpan<float> src) |
1227 | 1228 | {
|
1228 |
| - fixed (float* psrc = src) |
| 1229 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
1229 | 1230 | {
|
1230 | 1231 | float* pSrcEnd = psrc + src.Length;
|
1231 | 1232 | float* pSrcCurrent = psrc;
|
@@ -1267,9 +1268,9 @@ public static unsafe float MaxAbsU(Span<float> src)
|
1267 | 1268 | }
|
1268 | 1269 | }
|
1269 | 1270 |
|
1270 |
| - public static unsafe float MaxAbsDiffU(float mean, Span<float> src) |
| 1271 | + public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src) |
1271 | 1272 | {
|
1272 |
| - fixed (float* psrc = src) |
| 1273 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
1273 | 1274 | {
|
1274 | 1275 | float* pSrcEnd = psrc + src.Length;
|
1275 | 1276 | float* pSrcCurrent = psrc;
|
@@ -1316,14 +1317,14 @@ public static unsafe float MaxAbsDiffU(float mean, Span<float> src)
|
1316 | 1317 | }
|
1317 | 1318 | }
|
1318 | 1319 |
|
1319 |
| - public static unsafe float DotU(Span<float> src, Span<float> dst) |
| 1320 | + public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count) |
1320 | 1321 | {
|
1321 |
| - fixed (float* psrc = src) |
1322 |
| - fixed (float* pdst = dst) |
| 1322 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
| 1323 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
1323 | 1324 | {
|
1324 | 1325 | float* pSrcCurrent = psrc;
|
1325 | 1326 | float* pDstCurrent = pdst;
|
1326 |
| - float* pSrcEnd = psrc + src.Length; |
| 1327 | + float* pSrcEnd = psrc + count; |
1327 | 1328 |
|
1328 | 1329 | Vector256<float> result256 = Avx.SetZeroVector256<float>();
|
1329 | 1330 |
|
@@ -1371,16 +1372,16 @@ public static unsafe float DotU(Span<float> src, Span<float> dst)
|
1371 | 1372 | }
|
1372 | 1373 | }
|
1373 | 1374 |
|
1374 |
| - public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx) |
| 1375 | + public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, ReadOnlySpan<int> idx, int count) |
1375 | 1376 | {
|
1376 |
| - fixed (float* psrc = src) |
1377 |
| - fixed (float* pdst = dst) |
1378 |
| - fixed (int* pidx = idx) |
| 1377 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
| 1378 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
| 1379 | + fixed (int* pidx = &MemoryMarshal.GetReference(idx)) |
1379 | 1380 | {
|
1380 | 1381 | float* pSrcCurrent = psrc;
|
1381 | 1382 | float* pDstCurrent = pdst;
|
1382 | 1383 | int* pIdxCurrent = pidx;
|
1383 |
| - int* pIdxEnd = pidx + idx.Length; |
| 1384 | + int* pIdxEnd = pidx + count; |
1384 | 1385 |
|
1385 | 1386 | Vector256<float> result256 = Avx.SetZeroVector256<float>();
|
1386 | 1387 |
|
@@ -1428,14 +1429,14 @@ public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx
|
1428 | 1429 | }
|
1429 | 1430 | }
|
1430 | 1431 |
|
1431 |
| - public static unsafe float Dist2(Span<float> src, Span<float> dst) |
| 1432 | + public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count) |
1432 | 1433 | {
|
1433 |
| - fixed (float* psrc = src) |
1434 |
| - fixed (float* pdst = dst) |
| 1434 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
| 1435 | + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) |
1435 | 1436 | {
|
1436 | 1437 | float* pSrcCurrent = psrc;
|
1437 | 1438 | float* pDstCurrent = pdst;
|
1438 |
| - float* pSrcEnd = psrc + src.Length; |
| 1439 | + float* pSrcEnd = psrc + count; |
1439 | 1440 |
|
1440 | 1441 | Vector256<float> sqDistanceVector256 = Avx.SetZeroVector256<float>();
|
1441 | 1442 |
|
@@ -1482,13 +1483,13 @@ public static unsafe float Dist2(Span<float> src, Span<float> dst)
|
1482 | 1483 | }
|
1483 | 1484 | }
|
1484 | 1485 |
|
1485 |
| - public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, float threshold, Span<float> v, Span<float> w) |
| 1486 | + public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlySpan<float> src, float threshold, Span<float> v, Span<float> w) |
1486 | 1487 | {
|
1487 |
| - fixed (float* psrc = src) |
1488 |
| - fixed (float* pdst1 = v) |
1489 |
| - fixed (float* pdst2 = w) |
| 1488 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
| 1489 | + fixed (float* pdst1 = &MemoryMarshal.GetReference(v)) |
| 1490 | + fixed (float* pdst2 = &MemoryMarshal.GetReference(w)) |
1490 | 1491 | {
|
1491 |
| - float* pSrcEnd = psrc + src.Length; |
| 1492 | + float* pSrcEnd = psrc + count; |
1492 | 1493 | float* pSrcCurrent = psrc;
|
1493 | 1494 | float* pDst1Current = pdst1;
|
1494 | 1495 | float* pDst2Current = pdst2;
|
@@ -1544,14 +1545,14 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, flo
|
1544 | 1545 | }
|
1545 | 1546 | }
|
1546 | 1547 |
|
1547 |
| - public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span<float> src, Span<int> indices, float threshold, Span<float> v, Span<float> w) |
| 1548 | + public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnlySpan<float> src, ReadOnlySpan<int> indices, float threshold, Span<float> v, Span<float> w) |
1548 | 1549 | {
|
1549 |
| - fixed (float* psrc = src) |
1550 |
| - fixed (int* pidx = indices) |
1551 |
| - fixed (float* pdst1 = v) |
1552 |
| - fixed (float* pdst2 = w) |
| 1550 | + fixed (float* psrc = &MemoryMarshal.GetReference(src)) |
| 1551 | + fixed (int* pidx = &MemoryMarshal.GetReference(indices)) |
| 1552 | + fixed (float* pdst1 = &MemoryMarshal.GetReference(v)) |
| 1553 | + fixed (float* pdst2 = &MemoryMarshal.GetReference(w)) |
1553 | 1554 | {
|
1554 |
| - int* pIdxEnd = pidx + indices.Length; |
| 1555 | + int* pIdxEnd = pidx + count; |
1555 | 1556 | float* pSrcCurrent = psrc;
|
1556 | 1557 | int* pIdxCurrent = pidx;
|
1557 | 1558 |
|
|
0 commit comments