Skip to content

Made loop bound checking in hardware intrinsics more efficient #994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)

Vector256<float> scalarVector256 = Avx.SetAllVector256(scalar);

while (pDstCurrent + 8 <= pDstEnd)
while (pDstCurrent <= pDstEnd - 8)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
dstVector = Avx.Add(dstVector, scalarVector256);
Expand Down Expand Up @@ -460,7 +460,7 @@ public static unsafe void ScaleU(float scale, Span<float> dst)

Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);

while (pDstCurrent + 8 <= pEnd)
while (pDstCurrent <= pEnd - 8)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);

Expand Down Expand Up @@ -505,7 +505,7 @@ public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> ds

Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);

while (pDstCurrent + 8 <= pDstEnd)
while (pDstCurrent <= pDstEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Multiply(srcVector, scaleVector256);
Expand Down Expand Up @@ -550,7 +550,7 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
Vector256<float> a256 = Avx.SetAllVector256(a);
Vector256<float> b256 = Avx.SetAllVector256(b);

while (pDstCurrent + 8 <= pDstEnd)
while (pDstCurrent <= pDstEnd - 8)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
dstVector = Avx.Add(dstVector, b256);
Expand Down Expand Up @@ -596,7 +596,7 @@ public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> ds

Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);

while (pDstCurrent + 8 <= pEnd)
while (pDstCurrent <= pEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
Expand Down Expand Up @@ -652,7 +652,7 @@ public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float

Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);

while (pResCurrent + 8 <= pResEnd)
while (pResCurrent <= pResEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
Expand Down Expand Up @@ -708,7 +708,7 @@ public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx

Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);

while (pIdxCurrent + 8 <= pEnd)
while (pIdxCurrent <= pEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Vector256<float> dstVector = Load8(pDstCurrent, pIdxCurrent);
Expand Down Expand Up @@ -755,7 +755,7 @@ public static unsafe void AddU(Span<float> src, Span<float> dst)
float* pDstCurrent = pdst;
float* pEnd = psrc + src.Length;

while (pSrcCurrent + 8 <= pEnd)
while (pSrcCurrent <= pEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
Expand Down Expand Up @@ -804,7 +804,7 @@ public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
float* pDstCurrent = pdst;
int* pEnd = pidx + idx.Length;

while (pIdxCurrent + 8 <= pEnd)
while (pIdxCurrent <= pEnd - 8)
{
Vector256<float> dstVector = Load8(pDstCurrent, pIdxCurrent);
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Expand Down Expand Up @@ -849,7 +849,7 @@ public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Sp
float* pDstCurrent = pdst;
float* pEnd = pdst + dst.Length;

while (pDstCurrent + 8 <= pEnd)
while (pDstCurrent <= pEnd - 8)
{
Vector256<float> src1Vector = Avx.LoadVector256(pSrc1Current);
Vector256<float> src2Vector = Avx.LoadVector256(pSrc2Current);
Expand Down Expand Up @@ -896,7 +896,7 @@ public static unsafe float SumU(Span<float> src)

Vector256<float> result256 = Avx.SetZeroVector256<float>();

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
result256 = Avx.Add(result256, Avx.LoadVector256(pSrcCurrent));
pSrcCurrent += 8;
Expand Down Expand Up @@ -934,7 +934,7 @@ public static unsafe float SumSqU(Span<float> src)

Vector256<float> result256 = Avx.SetZeroVector256<float>();

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
result256 = Avx.Add(result256, Avx.Multiply(srcVector, srcVector));
Expand Down Expand Up @@ -979,7 +979,7 @@ public static unsafe float SumSqDiffU(float mean, Span<float> src)
Vector256<float> result256 = Avx.SetZeroVector256<float>();
Vector256<float> meanVector256 = Avx.SetAllVector256(mean);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Subtract(srcVector, meanVector256);
Expand Down Expand Up @@ -1027,7 +1027,7 @@ public static unsafe float SumAbsU(Span<float> src)

Vector256<float> result256 = Avx.SetZeroVector256<float>();

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
result256 = Avx.Add(result256, Avx.And(srcVector, _absMask256));
Expand Down Expand Up @@ -1072,7 +1072,7 @@ public static unsafe float SumAbsDiffU(float mean, Span<float> src)
Vector256<float> result256 = Avx.SetZeroVector256<float>();
Vector256<float> meanVector256 = Avx.SetAllVector256(mean);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Subtract(srcVector, meanVector256);
Expand Down Expand Up @@ -1120,7 +1120,7 @@ public static unsafe float MaxAbsU(Span<float> src)

Vector256<float> result256 = Avx.SetZeroVector256<float>();

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
result256 = Avx.Max(result256, Avx.And(srcVector, _absMask256));
Expand Down Expand Up @@ -1165,7 +1165,7 @@ public static unsafe float MaxAbsDiffU(float mean, Span<float> src)
Vector256<float> result256 = Avx.SetZeroVector256<float>();
Vector256<float> meanVector256 = Avx.SetAllVector256(mean);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Subtract(srcVector, meanVector256);
Expand Down Expand Up @@ -1215,7 +1215,7 @@ public static unsafe float DotU(Span<float> src, Span<float> dst)

Vector256<float> result256 = Avx.SetZeroVector256<float>();

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
Expand Down Expand Up @@ -1272,7 +1272,7 @@ public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx

Vector256<float> result256 = Avx.SetZeroVector256<float>();

while (pIdxCurrent + 8 <= pIdxEnd)
while (pIdxCurrent <= pIdxEnd - 8)
{
Vector256<float> srcVector = Load8(pSrcCurrent, pIdxCurrent);
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
Expand Down Expand Up @@ -1327,7 +1327,7 @@ public static unsafe float Dist2(Span<float> src, Span<float> dst)

Vector256<float> sqDistanceVector256 = Avx.SetZeroVector256<float>();

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> distanceVector = Avx.Subtract(Avx.LoadVector256(pSrcCurrent),
Avx.LoadVector256(pDstCurrent));
Expand Down Expand Up @@ -1384,7 +1384,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, flo
Vector256<float> xPrimal256 = Avx.SetAllVector256(primalUpdate);
Vector256<float> xThreshold256 = Avx.SetAllVector256(threshold);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> xSrc = Avx.LoadVector256(pSrcCurrent);

Expand Down Expand Up @@ -1446,7 +1446,7 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span<float> src, Sp
Vector256<float> xPrimal256 = Avx.SetAllVector256(primalUpdate);
Vector256<float> xThreshold = Avx.SetAllVector256(threshold);

while (pIdxCurrent + 8 <= pIdxEnd)
while (pIdxCurrent <= pIdxEnd - 8)
{
Vector256<float> xSrc = Avx.LoadVector256(pSrcCurrent);

Expand Down
44 changes: 22 additions & 22 deletions src/Microsoft.ML.CpuMath/SseIntrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)

Vector128<float> scalarVector = Sse.SetAllVector128(scalar);

while (pDstCurrent + 4 <= pDstEnd)
while (pDstCurrent <= pDstEnd - 4)
{
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
dstVector = Sse.Add(dstVector, scalarVector);
Expand Down Expand Up @@ -446,7 +446,7 @@ public static unsafe void ScaleU(float scale, Span<float> dst)

Vector128<float> scaleVector = Sse.SetAllVector128(scale);

while (pDstCurrent + 4 <= pEnd)
while (pDstCurrent <= pEnd - 4)
{
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);

Expand Down Expand Up @@ -479,7 +479,7 @@ public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> ds

Vector128<float> scaleVector = Sse.SetAllVector128(scale);

while (pDstCurrent + 4 <= pDstEnd)
while (pDstCurrent <= pDstEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
srcVector = Sse.Multiply(srcVector, scaleVector);
Expand Down Expand Up @@ -512,7 +512,7 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
Vector128<float> aVector = Sse.SetAllVector128(a);
Vector128<float> bVector = Sse.SetAllVector128(b);

while (pDstCurrent + 4 <= pDstEnd)
while (pDstCurrent <= pDstEnd - 4)
{
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
dstVector = Sse.Add(dstVector, bVector);
Expand Down Expand Up @@ -545,7 +545,7 @@ public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> ds

Vector128<float> scaleVector = Sse.SetAllVector128(scale);

while (pDstCurrent + 4 <= pEnd)
while (pDstCurrent <= pEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Expand Down Expand Up @@ -586,7 +586,7 @@ public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float

Vector128<float> scaleVector = Sse.SetAllVector128(scale);

while (pResCurrent + 4 <= pResEnd)
while (pResCurrent <= pResEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Expand Down Expand Up @@ -627,7 +627,7 @@ public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx

Vector128<float> scaleVector = Sse.SetAllVector128(scale);

while (pIdxCurrent + 4 <= pEnd)
while (pIdxCurrent <= pEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Load4(pDstCurrent, pIdxCurrent);
Expand Down Expand Up @@ -659,7 +659,7 @@ public static unsafe void AddU(Span<float> src, Span<float> dst)
float* pDstCurrent = pdst;
float* pEnd = psrc + src.Length;

while (pSrcCurrent + 4 <= pEnd)
while (pSrcCurrent <= pEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Expand Down Expand Up @@ -696,7 +696,7 @@ public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
float* pDstCurrent = pdst;
int* pEnd = pidx + idx.Length;

while (pIdxCurrent + 4 <= pEnd)
while (pIdxCurrent <= pEnd - 4)
{
Vector128<float> dstVector = Load4(pDstCurrent, pIdxCurrent);
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Expand Down Expand Up @@ -729,7 +729,7 @@ public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Sp
float* pDstCurrent = pdst;
float* pEnd = pdst + dst.Length;

while (pDstCurrent + 4 <= pEnd)
while (pDstCurrent <= pEnd - 4)
{
Vector128<float> src1Vector = Sse.LoadVector128(pSrc1Current);
Vector128<float> src2Vector = Sse.LoadVector128(pSrc2Current);
Expand Down Expand Up @@ -764,7 +764,7 @@ public static unsafe float SumU(Span<float> src)

Vector128<float> result = Sse.SetZeroVector128();

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
result = Sse.Add(result, Sse.LoadVector128(pSrcCurrent));
pSrcCurrent += 4;
Expand All @@ -791,7 +791,7 @@ public static unsafe float SumSqU(Span<float> src)

Vector128<float> result = Sse.SetZeroVector128();

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
result = Sse.Add(result, Sse.Multiply(srcVector, srcVector));
Expand Down Expand Up @@ -823,7 +823,7 @@ public static unsafe float SumSqDiffU(float mean, Span<float> src)
Vector128<float> result = Sse.SetZeroVector128();
Vector128<float> meanVector = Sse.SetAllVector128(mean);

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
srcVector = Sse.Subtract(srcVector, meanVector);
Expand Down Expand Up @@ -856,7 +856,7 @@ public static unsafe float SumAbsU(Span<float> src)

Vector128<float> result = Sse.SetZeroVector128();

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
result = Sse.Add(result, Sse.And(srcVector, AbsMask128));
Expand Down Expand Up @@ -888,7 +888,7 @@ public static unsafe float SumAbsDiffU(float mean, Span<float> src)
Vector128<float> result = Sse.SetZeroVector128();
Vector128<float> meanVector = Sse.SetAllVector128(mean);

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
srcVector = Sse.Subtract(srcVector, meanVector);
Expand Down Expand Up @@ -921,7 +921,7 @@ public static unsafe float MaxAbsU(Span<float> src)

Vector128<float> result = Sse.SetZeroVector128();

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
result = Sse.Max(result, Sse.And(srcVector, AbsMask128));
Expand Down Expand Up @@ -953,7 +953,7 @@ public static unsafe float MaxAbsDiffU(float mean, Span<float> src)
Vector128<float> result = Sse.SetZeroVector128();
Vector128<float> meanVector = Sse.SetAllVector128(mean);

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
srcVector = Sse.Subtract(srcVector, meanVector);
Expand Down Expand Up @@ -988,7 +988,7 @@ public static unsafe float DotU(Span<float> src, Span<float> dst)

Vector128<float> result = Sse.SetZeroVector128();

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Expand Down Expand Up @@ -1029,7 +1029,7 @@ public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx

Vector128<float> result = Sse.SetZeroVector128();

while (pIdxCurrent + 4 <= pIdxEnd)
while (pIdxCurrent <= pIdxEnd - 4)
{
Vector128<float> srcVector = Load4(pSrcCurrent, pIdxCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Expand Down Expand Up @@ -1068,7 +1068,7 @@ public static unsafe float Dist2(Span<float> src, Span<float> dst)

Vector128<float> sqDistanceVector = Sse.SetZeroVector128();

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> distanceVector = Sse.Subtract(Sse.LoadVector128(pSrcCurrent),
Sse.LoadVector128(pDstCurrent));
Expand Down Expand Up @@ -1111,7 +1111,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, flo
Vector128<float> signMask = Sse.SetAllVector128(-0.0f); // 0x8000 0000
Vector128<float> xThreshold = Sse.SetAllVector128(threshold);

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> xSrc = Sse.LoadVector128(pSrcCurrent);

Expand Down Expand Up @@ -1156,7 +1156,7 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span<float> src, Sp
Vector128<float> signMask = Sse.SetAllVector128(-0.0f); // 0x8000 0000
Vector128<float> xThreshold = Sse.SetAllVector128(threshold);

while (pIdxCurrent + 4 <= pIdxEnd)
while (pIdxCurrent <= pIdxEnd - 4)
{
Vector128<float> xSrc = Sse.LoadVector128(pSrcCurrent);

Expand Down
Loading