Same implementation for Sparse Multiplication for aligned and unaligned arrays #1274
Changes from all commits: 6d6897f, 6fff9a6, e4478ad, 924ac0e, 96ca8a6
@@ -46,25 +46,6 @@ internal static class AvxIntrinsics

private static readonly Vector256<float> _absMask256 = Avx.StaticCast<int, float>(Avx.SetAllVector256(0x7FFFFFFF));

private const int Vector256Alignment = 32;

[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static bool HasCompatibleAlignment(AlignedArray alignedArray)
{
    Contracts.AssertValue(alignedArray);
    Contracts.Assert(alignedArray.Size > 0);
    return (alignedArray.CbAlign % Vector256Alignment) == 0;
}

[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase)
{
    Contracts.AssertValue(alignedArray);
    float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase);
    Contracts.Assert(((long)alignedBase % Vector256Alignment) == 0);
    return alignedBase;
}

[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static Vector128<float> GetHigh(in Vector256<float> x)
    => Avx.ExtractVector128(x, 1);
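Note (not part of the diff): the two deleted helpers above enforced 32-byte alignment on AlignedArray inputs. The rewritten kernels below instead measure misalignment at run time and handle the unaligned head and tail with masked loads. A minimal sketch of that check, using the same "address % 32" and "8 - (misalignment >> 2)" arithmetic that appears in MatMulP further down (the method name here is mine, not from the source):

    // Sketch only: how many floats remain before the next 32-byte (Vector256) boundary.
    static unsafe int FloatsBeforeAlignedBoundary(float* p)
    {
        int misalignmentBytes = (int)((nuint)p % 32); // offset into the current 32-byte block
        if (misalignmentBytes == 0)
            return 0;                                 // pointer is already 32-byte aligned
        return 8 - (misalignmentBytes >> 2);          // bytes -> floats, then distance to the boundary
    }

If the pointer is not even float-aligned ((misalignmentBytes & 3) != 0), the kernel below falls back to plain unaligned stores instead of the masked head/tail path.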
@@ -170,19 +151,19 @@ private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<flo
}

// Multiply matrix times vector into vector.
public static unsafe void MatMulX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
    Contracts.Assert(crow % 4 == 0);
    Contracts.Assert(ccol % 4 == 0);

    MatMulX(mat.Items, src.Items, dst.Items, crow, ccol);
    MatMul(mat.Items, src.Items, dst.Items, crow, ccol);
}
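For reference, the comment above ("multiply matrix times vector into vector") describes a plain dense product. A scalar sketch of what the vectorized kernel computes, assuming row-major layout with crow rows of ccol contiguous floats (an assumption on my part, not stated in this hunk):

    // Scalar equivalent of MatMul: dst[row] = dot(mat row, src).
    static void MatMulScalar(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
    {
        for (int row = 0; row < crow; row++)
        {
            float sum = 0;
            for (int col = 0; col < ccol; col++)
                sum += mat[row * ccol + col] * src[col]; // one matrix row dotted with src
            dst[row] = sum;
        }
    }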
public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int crow, int ccol)
public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
{
    fixed (float* psrc = &src[0])
    fixed (float* pdst = &dst[0])
    fixed (float* pmat = &mat[0])
    Contracts.Assert(crow % 4 == 0);
    Contracts.Assert(ccol % 4 == 0);

    fixed (float* psrc = &MemoryMarshal.GetReference(src))
    fixed (float* pdst = &MemoryMarshal.GetReference(dst))
    fixed (float* pmat = &MemoryMarshal.GetReference(mat))
    fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
    fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
    {
@@ -312,32 +293,134 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int cro
}

// Partial sparse source vector.
public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArray src,
    int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray src,
    int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
{
    Contracts.Assert(HasCompatibleAlignment(mat));
    Contracts.Assert(HasCompatibleAlignment(src));
    Contracts.Assert(HasCompatibleAlignment(dst));
    MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol);
}
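The sparse kernel below only visits the source positions listed in rgposSrc[iposMin..iposEnd). As I read the pointer arithmetic (pm0 = pmat - posMin, pSrcCurrent = psrc - posMin), a scalar sketch of the computation is roughly the following; the REVIEW note in the diff about interchanging loops refers to swapping these two loops so the sparse positions become the outer loop:

    // Scalar sketch of the partial sparse product: only the columns named in rgposSrc contribute.
    static void MatMulPScalar(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> src,
        int posMin, int iposMin, int iposEnd, Span<float> dst, int crow, int ccol)
    {
        for (int row = 0; row < crow; row++)
        {
            float sum = 0;
            for (int ipos = iposMin; ipos < iposEnd; ipos++)
            {
                int col = rgposSrc[ipos] - posMin;       // position rebased by posMin, as pm0/pSrcCurrent do below
                sum += mat[row * ccol + col] * src[col];
            }
            dst[row] = sum;
        }
    }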
public static unsafe void MatMulP(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> src,
    int posMin, int iposMin, int iposEnd, Span<float> dst, int crow, int ccol)
{
    Contracts.Assert(crow % 8 == 0);
    Contracts.Assert(ccol % 8 == 0);

    // REVIEW: For extremely sparse inputs, interchanging the loops would
    // likely be more efficient.
    fixed (float* pSrcStart = &src.Items[0])
    fixed (float* pDstStart = &dst.Items[0])
    fixed (float* pMatStart = &mat.Items[0])
    fixed (int* pposSrc = &rgposSrc[0])
    fixed (float* psrc = &MemoryMarshal.GetReference(src))
    fixed (float* pdst = &MemoryMarshal.GetReference(dst))
    fixed (float* pmat = &MemoryMarshal.GetReference(mat))
    fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc))
    fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
    fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
    {
        float* psrc = GetAlignedBase(src, pSrcStart);
        float* pdst = GetAlignedBase(dst, pDstStart);
        float* pmat = GetAlignedBase(mat, pMatStart);

        int* pposMin = pposSrc + iposMin;
        int* pposEnd = pposSrc + iposEnd;
        float* pDstEnd = pdst + crow;
        float* pm0 = pmat - posMin;
        float* pSrcCurrent = psrc - posMin;
        float* pDstCurrent = pdst;

        while (pDstCurrent < pDstEnd)
        nuint address = (nuint)(pDstCurrent);
        int misalignment = (int)(address % 32);
        int length = crow;
        int remainder = 0;

        if ((misalignment & 3) != 0)
        {
            while (pDstCurrent < pDstEnd)
            {
                Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
                pDstCurrent += 8;
                pm0 += 8 * ccol;
            }
        }
        else
        {
            if (misalignment != 0)
            {
                misalignment >>= 2;
                misalignment = 8 - misalignment;

                Vector256<float> mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));

                float* pm1 = pm0 + ccol;
                float* pm2 = pm1 + ccol;
                float* pm3 = pm2 + ccol;
                Vector256<float> result = Avx.SetZeroVector256<float>();

                int* ppos = pposMin;

                while (ppos < pposEnd)
                {
                    int col1 = *ppos;
                    int col2 = col1 + 4 * ccol;
                    Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
                        pm3[col1], pm2[col1], pm1[col1], pm0[col1]);

Review comment on this line: Don't we have a helper method for this?
Author reply: No, it's different; in the helper we already have, the indexes are contiguous.
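To illustrate the reply: a single vector load only works when the eight source floats sit next to each other, whereas here each lane comes from a different matrix row (stride ccol) at the same sparse column, so the lanes are assembled one by one. A hedged fragment reusing the kernel's locals (pm0..pm3, col1, col2), for illustration only:

    // Contiguous case: one load covers eight adjacent floats starting at pm0 + col1.
    Vector256<float> contiguous = Avx.LoadVector256(pm0 + col1);

    // Sparse/strided case (this kernel): the eight floats live a full row (ccol) apart,
    // so they are gathered element by element instead of loaded in one shot.
    Vector256<float> gathered = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
                                                 pm3[col1], pm2[col1], pm1[col1], pm0[col1]);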
                    x1 = Avx.And(mask, x1);
                    Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
                    result = MultiplyAdd(x2, x1, result);
                    ppos++;
                }

                Avx.Store(pDstCurrent, result);
                pDstCurrent += misalignment;
                pm0 += misalignment * ccol;
                length -= misalignment;
            }

            if (length > 7)
            {
                remainder = length % 8;
                while (pDstCurrent < pDstEnd)
                {
                    Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
                    pDstCurrent += 8;
                    pm0 += 8 * ccol;
                }
            }
            else
            {
                remainder = length;
            }

            if (remainder != 0)
            {
                pDstCurrent -= (8 - remainder);
                pm0 -= (8 - remainder) * ccol;
                Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
                Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));

                float* pm1 = pm0 + ccol;
                float* pm2 = pm1 + ccol;
                float* pm3 = pm2 + ccol;
                Vector256<float> result = Avx.SetZeroVector256<float>();

                int* ppos = pposMin;

                while (ppos < pposEnd)
                {
                    int col1 = *ppos;
                    int col2 = col1 + 4 * ccol;
                    Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
                        pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
                    x1 = Avx.And(x1, trailingMask);

                    Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
                    result = MultiplyAdd(x2, x1, result);
                    ppos++;
                }

                result = Avx.Add(result, Avx.And(leadingMask, Avx.LoadVector256(pDstCurrent)));

                Avx.Store(pDstCurrent, result);
                pDstCurrent += 8;
                pm0 += 8 * ccol;
            }
        }

        Vector256<float> SparseMultiplicationAcrossRow()
        {
            float* pm1 = pm0 + ccol;
            float* pm2 = pm1 + ccol;
@@ -351,33 +434,30 @@ public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArra
                int col1 = *ppos;
                int col2 = col1 + 4 * ccol;
                Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
                    pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
                    pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
                Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
                result = MultiplyAdd(x2, x1, result);

                ppos++;
            }

            Avx.StoreAligned(pDstCurrent, result);
            pDstCurrent += 8;
            pm0 += 8 * ccol;
            return result;
        }
    }
}
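MatMulP pins two mask tables (LeadingAlignmentMask, TrailingAlignmentMask) and indexes them by element count (misalignment * 8, remainder * 8) to blend partial vectors at the unaligned head and the remainder tail. The tables themselves are not shown in this diff; a sketch of the layout this indexing implies, which is my assumption rather than code copied from the source:

    // Assumed layout: row i of the leading table keeps the first i lanes, row i of the
    // trailing table keeps the last i lanes. Loading row i as a Vector256<float> and
    // And-ing it with a full vector zeroes the lanes outside the valid range, so a
    // partial head or tail can be computed without reading or writing past it.
    static uint[] BuildMaskTable(bool leading)
    {
        var table = new uint[8 * 8];
        for (int count = 0; count < 8; count++)
        {
            for (int lane = 0; lane < 8; lane++)
            {
                bool keep = leading ? lane < count : lane >= 8 - count;
                table[count * 8 + lane] = keep ? 0xFFFFFFFFu : 0u;
            }
        }
        return table;
    }

In the remainder path above, the trailing mask keeps the newly computed lanes while the leading mask preserves the lanes that were already stored, and the two are combined with Avx.Add before the final store.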
public static unsafe void MatMulTranX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
    Contracts.Assert(crow % 4 == 0);
    Contracts.Assert(ccol % 4 == 0);

    MatMulTranX(mat.Items, src.Items, dst.Items, crow, ccol);
    MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol);
}
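A hypothetical call site for the renamed overloads (AvxIntrinsics is internal, so this only makes sense from the same assembly or its tests; the square 8x8 sizes are illustrative and satisfy the asserts):

    int crow = 8, ccol = 8;
    float[] mat = new float[crow * ccol];
    float[] src = new float[ccol];
    float[] dst = new float[crow];

    // Arrays convert implicitly to ReadOnlySpan<float>/Span<float>, so callers that used the
    // old float[] overloads can call the new span-based MatMul/MatMulTran unchanged.
    AvxIntrinsics.MatMul(mat, src, dst, crow, ccol);
    AvxIntrinsics.MatMulTran(mat, src, dst, crow, ccol);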
public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int crow, int ccol)
public static unsafe void MatMulTran(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
{
    fixed (float* psrc = &src[0])
    fixed (float* pdst = &dst[0])
    fixed (float* pmat = &mat[0])
    Contracts.Assert(crow % 4 == 0);
    Contracts.Assert(ccol % 4 == 0);

    fixed (float* psrc = &MemoryMarshal.GetReference(src))
    fixed (float* pdst = &MemoryMarshal.GetReference(dst))
    fixed (float* pmat = &MemoryMarshal.GetReference(mat))
    fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
    fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
    {
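Throughout the new overloads, spans are pinned with MemoryMarshal.GetReference instead of &array[0], which also works when the span wraps stack or native memory and does not throw on an empty input. A minimal standalone sketch of the pattern (not from the diff):

    using System;
    using System.Runtime.InteropServices;

    internal static class PinningSketch
    {
        // Pins a ReadOnlySpan<float> the same way the rewritten kernels do, then reads it through a raw pointer.
        public static unsafe float Sum(ReadOnlySpan<float> values)
        {
            float total = 0;
            fixed (float* p = &MemoryMarshal.GetReference(values))
            {
                for (int i = 0; i < values.Length; i++)
                    total += p[i];
            }
            return total;
        }
    }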