
Improvements to the "Sum" SIMD algorithm #1112


Merged: 1 commit merged on Oct 25, 2018
111 changes: 91 additions & 20 deletions src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
@@ -1300,41 +1300,112 @@ public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan
}
}

public static unsafe float SumU(ReadOnlySpan<float> src)
public static unsafe float Sum(ReadOnlySpan<float> src)
{
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pSrc = &MemoryMarshal.GetReference(src))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
float* pValues = pSrc;
int length = src.Length;

Vector256<float> result256 = Avx.SetZeroVector256<float>();

while (pSrcCurrent + 8 <= pSrcEnd)
if (length < 8)
{
result256 = Avx.Add(result256, Avx.LoadVector256(pSrcCurrent));
pSrcCurrent += 8;
// Handle cases where we have less than 256-bits total and can't ever use SIMD acceleration.

float res = 0;

switch (length)
{
case 7: res += pValues[6]; goto case 6;
case 6: res += pValues[5]; goto case 5;
case 5: res += pValues[4]; goto case 4;
case 4: res += pValues[3]; goto case 3;
case 3: res += pValues[2]; goto case 2;
case 2: res += pValues[1]; goto case 1;
case 1: res += pValues[0]; break;
}

return res;
}

result256 = VectorSum256(in result256);
Vector128<float> resultPadded = Sse.AddScalar(Avx.GetLowerHalf(result256), GetHigh(result256));
Vector256<float> result = Avx.SetZeroVector256<float>();

Vector128<float> result128 = Sse.SetZeroVector128();
nuint address = (nuint)(pValues);
int misalignment = (int)(address % 32);
int remainder = 0;

if (pSrcCurrent + 4 <= pSrcEnd)
if ((misalignment & 3) != 0)
{
result128 = Sse.Add(result128, Sse.LoadVector128(pSrcCurrent));
pSrcCurrent += 4;
// Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations

remainder = length % 8;

for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 8)
{
result = Avx.Add(result, Avx.LoadVector256(pValues));
}
}
else
{
if (misalignment != 0)
{
// Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
// masking any elements that will be included in the first aligned read

result128 = SseIntrinsics.VectorSum128(in result128);
misalignment >>= 2;
misalignment = 8 - misalignment;

while (pSrcCurrent < pSrcEnd)
Vector256<float> mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
Vector256<float> temp = Avx.And(mask, Avx.LoadVector256(pValues));
result = Avx.Add(result, temp);

pValues += misalignment;
length -= misalignment;
}

if (length > 7)
{
// Handle all the 256-bit blocks that we can now that we have offset to an aligned address

remainder = length % 8;

for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 8)
{
// The JIT will only fold away unaligned loads due to the semantics behind
// the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
// modern hardware has unaligned loads that are as fast as aligned loads,
// when it doesn't cross a cache-line/page boundary, we will just assert
// that the alignment is correct and allow for the more-efficient codegen.

Contracts.Assert(((nuint)(pValues) % 32) == 0);
result = Avx.Add(result, Avx.LoadVector256(pValues));
}
}
else
{
// Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not
// 256-bit aligned. This means we can't do any aligned loads and will just end up doing two
// unaligned loads where we mask the input each time.
remainder = length;
}
}

if (remainder != 0)
{
result128 = Sse.AddScalar(result128, Sse.LoadScalarVector128(pSrcCurrent));
pSrcCurrent++;
// Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next
// unaligned load will read to the end of the array and then mask out any elements already processed
Member: should this be "next aligned load"

Member Author: No, we are moving back from an aligned address to an unaligned one.


pValues -= (8 - remainder);

Vector256<float> mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
Vector256<float> temp = Avx.And(mask, Avx.LoadVector256(pValues));
result = Avx.Add(result, temp);
}

return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded));
// Sum all the elements together and return the result
result = VectorSum256(in result);
return Sse.ConvertToSingle(Sse.AddScalar(Avx.GetLowerHalf(result), GetHigh(result)));
}
}
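
The traversal above can be hard to follow through the diff, so here is a scalar model (not part of the PR) of how the new Sum partitions the input: a masked leading block up to the first 32-byte boundary, aligned 8-float blocks in the middle, and a trailing block that steps the pointer back so the final unaligned load ends exactly at the end of the array, with the already-summed lanes masked out. The helper name and the `alignmentOffset` parameter are hypothetical, and the 4-byte-misaligned path that falls back to plain unaligned loads is left out.

```csharp
// Illustrative scalar model only; SumBlockedModel and alignmentOffset are hypothetical.
// alignmentOffset stands in for (address % 32) / 4, i.e. how many floats the start of
// the data sits past the previous 32-byte boundary.
static float SumBlockedModel(float[] values, int alignmentOffset)
{
    const int BlockSize = 8; // floats per 256-bit vector
    int length = values.Length;
    float sum = 0;

    if (length < BlockSize)
    {
        // Mirrors the switch-based small-length path: no vector work at all.
        for (int i = 0; i < length; i++)
            sum += values[i];
        return sum;
    }

    int index = 0;

    if (alignmentOffset != 0)
    {
        // Leading block: the vector code does one unaligned load here and uses
        // LeadingAlignmentMask to zero the lanes that the first aligned load will cover.
        int lead = BlockSize - alignmentOffset;
        for (; index < lead; index++)
            sum += values[index];
        length -= lead;
    }

    // Aligned middle: whole blocks of eight floats, read with aligned loads.
    int remainder = length % BlockSize;
    int alignedEnd = index + (length - remainder);
    for (; index < alignedEnd; index++)
        sum += values[index];

    // Trailing block: the vector code moves the pointer back to (end - 8), reloads,
    // and uses TrailingAlignmentMask to zero the (8 - remainder) lanes already summed,
    // so the load never runs past the end of the array. In scalar form that is simply
    // the remaining elements.
    for (; index < values.Length; index++)
        sum += values[index];

    return sum;
}
```

For example, with 19 elements starting 12 bytes past a 32-byte boundary: the leading block covers 5 elements, the aligned loop covers the next 8, remainder is 6, and the final unaligned load is stepped back to element 11 so it ends exactly at element 18, with its first two lanes (elements 11 and 12, already summed) masked out.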

4 changes: 2 additions & 2 deletions src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs
@@ -388,11 +388,11 @@ public static float Sum(ReadOnlySpan<float> src)

if (Avx.IsSupported)
{
return AvxIntrinsics.SumU(src);
return AvxIntrinsics.Sum(src);
}
else if (Sse.IsSupported)
{
return SseIntrinsics.SumU(src);
return SseIntrinsics.Sum(src);
}
else
{
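
The hunk above is cut off inside the final else branch, so the software fallback is not visible. Below is a minimal sketch of the dispatch shape, assuming the System.Runtime.Intrinsics.X86 namespace and a plain scalar loop as the fallback; the real fallback body in the repository may differ.

```csharp
// Sketch of the ISA dispatch in CpuMathUtils.Sum; the scalar fallback here is
// an assumption, since the diff does not show the real else body.
using System;
using System.Runtime.Intrinsics.X86;

public static class CpuMathUtilsSketch
{
    public static float Sum(ReadOnlySpan<float> src)
    {
        if (Avx.IsSupported)
        {
            return AvxIntrinsics.Sum(src);
        }
        else if (Sse.IsSupported)
        {
            return SseIntrinsics.Sum(src);
        }
        else
        {
            // Illustrative scalar fallback.
            float sum = 0;
            for (int i = 0; i < src.Length; i++)
            {
                sum += src[i];
            }
            return sum;
        }
    }
}
```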
2 changes: 1 addition & 1 deletion src/Microsoft.ML.CpuMath/Sse.cs
@@ -246,7 +246,7 @@ public static float Sum(ReadOnlySpan<float> src)
unsafe
{
fixed (float* psrc = &MemoryMarshal.GetReference(src))
return Thunk.SumU(psrc, src.Length);
return Thunk.Sum(psrc, src.Length);
}
}

97 changes: 86 additions & 11 deletions src/Microsoft.ML.CpuMath/SseIntrinsics.cs
@@ -1140,29 +1140,104 @@ public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan
}
}

public static unsafe float SumU(ReadOnlySpan<float> src)
public static unsafe float Sum(ReadOnlySpan<float> src)
Member: Assuming you're running on a machine supporting AVX -- unit tests would not hit this -- unless you ran them with the env variable set?

Member: I assume not since @fiigii change didn't go in yet.

Member Author: We have unit/perf tests that explicitly call these methods/code-paths
{
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pSrc = &MemoryMarshal.GetReference(src))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
float* pValues = pSrc;
int length = src.Length;

if (length < 4)
{
// Handle cases where we have less than 128-bits total and can't ever use SIMD acceleration.

float res = 0;

switch (length)
{
case 3: res += pValues[2]; goto case 2;
case 2: res += pValues[1]; goto case 1;
case 1: res += pValues[0]; break;
}

return res;
}

Vector128<float> result = Sse.SetZeroVector128();

while (pSrcCurrent + 4 <= pSrcEnd)
nuint address = (nuint)(pValues);
int misalignment = (int)(address % 16);
int remainder = 0;

if ((misalignment & 3) != 0)
{
result = Sse.Add(result, Sse.LoadVector128(pSrcCurrent));
pSrcCurrent += 4;
// Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations

remainder = length % 4;

for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 4)
{
result = Sse.Add(result, Sse.LoadVector128(pValues));
}
}
else
{
if (misalignment != 0)
{
// Handle cases where the data is not 128-bit aligned by doing an unaligned read and then
// masking any elements that will be included in the first aligned read

result = VectorSum128(in result);
misalignment >>= 2;
misalignment = 4 - misalignment;

while (pSrcCurrent < pSrcEnd)
Vector128<float> mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4));
Vector128<float> temp = Sse.And(mask, Sse.LoadVector128(pValues));
result = Sse.Add(result, temp);

pValues += misalignment;
length -= misalignment;
}

if (length > 3)
{
// Handle all the 128-bit blocks that we can now that we have offset to an aligned address

remainder = length % 4;

for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 4)
{
// If we aren't using the VEX-encoding, the JIT will only fold away aligned loads
// (due to semantics of the legacy encoding).
// We don't need an assert, since the instruction will throw for unaligned inputs.

result = Sse.Add(result, Sse.LoadAlignedVector128(pValues));
}
}
else
{
// Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not
// 128-bit aligned. This means we can't do any aligned loads and will just end up doing two
// unaligned loads where we mask the input each time.
remainder = length;
}
}

if (remainder != 0)
{
result = Sse.AddScalar(result, Sse.LoadScalarVector128(pSrcCurrent));
pSrcCurrent++;
// Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next
// unaligned load will read to the end of the array and then mask out any elements already processed

pValues -= (4 - remainder);

Vector128<float> mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4));
Vector128<float> temp = Sse.And(mask, Sse.LoadVector128(pValues));
result = Sse.Add(result, temp);
}

// Sum all the elements together and return the result
result = VectorSum128(in result);
return Sse.ConvertToSingle(result);
}
}
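
The thread above asks whether unit tests reach these code paths. One way to exercise the leading, aligned, and trailing branches is to slice a buffer at many offsets and lengths and compare against a scalar sum. The sketch below assumes xunit and the CpuMathUtils.Sum entry point from the earlier file; it is illustrative, not the PR's actual test suite.

```csharp
// Hypothetical test sketch, not the PR's actual tests.
using System;
using Xunit;

public class SumAlignmentTests
{
    [Fact]
    public void SumMatchesScalarAcrossOffsetsAndLengths()
    {
        var buffer = new float[64];
        var rng = new Random(42);
        for (int i = 0; i < buffer.Length; i++)
        {
            buffer[i] = (float)rng.NextDouble();
        }

        // Slicing at every offset shifts the span start across 32-bit, 128-bit and
        // 256-bit boundaries, so the masked leading/trailing paths and the aligned
        // inner loop all get exercised for some combination of offset and length.
        for (int offset = 0; offset < 16; offset++)
        {
            for (int length = 0; length <= 32; length++)
            {
                var span = new ReadOnlySpan<float>(buffer, offset, length);

                float expected = 0;
                foreach (float value in span)
                {
                    expected += value;
                }

                float actual = CpuMathUtils.Sum(span);

                // Vector and scalar sums associate additions differently, so allow a small tolerance.
                Assert.Equal(expected, actual, 3);
            }
        }
    }
}
```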
2 changes: 1 addition & 1 deletion src/Microsoft.ML.CpuMath/Thunk.cs
@@ -53,7 +53,7 @@ public static extern void MatMulP(/*const*/ float* pmat, /*const*/ int* pposSrc,
public static extern void AddSU(/*const*/ float* ps, /*const*/ int* pi, float* pd, int c);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float SumU(/*const*/ float* ps, int c);
public static extern float Sum(/*const*/ float* pValues, int length);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float SumSqU(/*const*/ float* ps, int c);
101 changes: 91 additions & 10 deletions src/Native/CpuMathNative/Sse.cpp
@@ -903,21 +903,102 @@ EXPORT_API(void) MulElementWiseU(_In_ const float * ps1, _In_ const float * ps2,
}
}

EXPORT_API(float) SumU(const float * ps, int c)
EXPORT_API(float) Sum(const float* pValues, int length)
{
const float * psLim = ps + c;
if (length < 4)
{
// Handle cases where we have less than 128-bits total and can't ever use SIMD acceleration.

__m128 res = _mm_setzero_ps();
for (; ps + 4 <= psLim; ps += 4)
res = _mm_add_ps(res, _mm_loadu_ps(ps));
float result = 0;

res = _mm_hadd_ps(res, res);
res = _mm_hadd_ps(res, res);
switch (length)
{
case 3: result += pValues[2];
case 2: result += pValues[1];
case 1: result += pValues[0];
}

for (; ps < psLim; ps++)
res = _mm_add_ss(res, _mm_load_ss(ps));
return result;
}

return _mm_cvtss_f32(res);
__m128 result = _mm_setzero_ps();

uintptr_t address = (uintptr_t)(pValues);
uintptr_t misalignment = address % 16;

int remainder = 0;

if ((misalignment & 3) != 0)
{
// Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations

remainder = length % 4;

for (const float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 4)
{
__m128 temp = _mm_loadu_ps(pValues);
result = _mm_add_ps(result, temp);
}
}
else
{
if (misalignment != 0)
{
// Handle cases where the data is not 128-bit aligned by doing an unaligned read and then
// masking any elements that will be included in the first aligned read

misalignment >>= 2;
misalignment = 4 - misalignment;

__m128 temp = _mm_loadu_ps(pValues);
__m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4));
temp = _mm_and_ps(temp, mask);
result = _mm_add_ps(result, temp);

pValues += misalignment;
length -= misalignment;
}

if (length > 3)
{
// Handle all the 128-bit blocks that we can now that we have offset to an aligned address

remainder = length % 4;

for (const float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 4)
{
__m128 temp = _mm_load_ps(pValues);
result = _mm_add_ps(result, temp);
}
}
else
{
// Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not
// 128-bit aligned. This means we can't do any aligned loads and will just end up doing two
// unaligned loads where we mask the input each time.
remainder = length;
}
}

if (remainder != 0)
{
// Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next
// unaligned load will read to the end of the array and then mask out any elements already processed

pValues -= (4 - remainder);

__m128 temp = _mm_loadu_ps(pValues);
__m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4));
temp = _mm_and_ps(temp, mask);
result = _mm_add_ps(result, temp);
}

// Sum all the elements together and return the result

result = _mm_add_ps(result, _mm_movehl_ps(result, result));
result = _mm_add_ps(result, _mm_shuffle_ps(result, result, 0xB1));

return _mm_cvtss_f32(result);
}
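
The final two _mm_add_ps lines perform the horizontal reduction: the movehl folds lanes 2 and 3 onto lanes 0 and 1, and the 0xB1 shuffle swaps lanes within each pair so the last add leaves the full total in lane 0. A rough C# counterpart using System.Runtime.Intrinsics (a standalone sketch, not the repository's own VectorSum128 helper):

```csharp
// Standalone sketch of the same horizontal reduction in C# intrinsics.
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

static class HorizontalSumSketch
{
    // Returns v0 + v1 + v2 + v3 for a Vector128<float>.
    public static float Sum(Vector128<float> v)
    {
        // movhlps: lanes (2, 3) fold onto lanes (0, 1).
        Vector128<float> partial = Sse.Add(v, Sse.MoveHighToLow(v, v));

        // 0xB1 swaps lanes within each pair (0<->1, 2<->3), completing the reduction.
        partial = Sse.Add(partial, Sse.Shuffle(partial, partial, 0xB1));

        // Lane 0 now holds the total.
        return Sse.ConvertToSingle(partial);
    }
}
```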

EXPORT_API(float) SumSqU(const float * ps, int c)