Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,165 changes: 0 additions & 1,165 deletions src/Microsoft.ML.CpuMath/Avx.cs

This file was deleted.

192 changes: 136 additions & 56 deletions src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,6 @@ internal static class AvxIntrinsics

private static readonly Vector256<float> _absMask256 = Avx.StaticCast<int, float>(Avx.SetAllVector256(0x7FFFFFFF));

private const int Vector256Alignment = 32;

[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static bool HasCompatibleAlignment(AlignedArray alignedArray)
{
Contracts.AssertValue(alignedArray);
Contracts.Assert(alignedArray.Size > 0);
return (alignedArray.CbAlign % Vector256Alignment) == 0;
}

[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase)
{
Contracts.AssertValue(alignedArray);
float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase);
Contracts.Assert(((long)alignedBase % Vector256Alignment) == 0);
return alignedBase;
}

[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static Vector128<float> GetHigh(in Vector256<float> x)
=> Avx.ExtractVector128(x, 1);
Expand Down Expand Up @@ -170,19 +151,19 @@ private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<flo
}

// Multiply matrix times vector into vector.
public static unsafe void MatMulX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
Contracts.Assert(crow % 4 == 0);
Contracts.Assert(ccol % 4 == 0);

MatMulX(mat.Items, src.Items, dst.Items, crow, ccol);
MatMul(mat.Items, src.Items, dst.Items, crow, ccol);
}

public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int crow, int ccol)
public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
{
fixed (float* psrc = &src[0])
fixed (float* pdst = &dst[0])
fixed (float* pmat = &mat[0])
Contracts.Assert(crow % 4 == 0);
Contracts.Assert(ccol % 4 == 0);

fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
fixed (float* pmat = &MemoryMarshal.GetReference(mat))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
Expand Down Expand Up @@ -312,32 +293,134 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int cro
}

// Partial sparse source vector.
public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArray src,
int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray src,
int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
{
Contracts.Assert(HasCompatibleAlignment(mat));
Contracts.Assert(HasCompatibleAlignment(src));
Contracts.Assert(HasCompatibleAlignment(dst));
MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol);
}

public static unsafe void MatMulP(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> src,
int posMin, int iposMin, int iposEnd, Span<float> dst, int crow, int ccol)
{
Contracts.Assert(crow % 8 == 0);
Contracts.Assert(ccol % 8 == 0);

// REVIEW: For extremely sparse inputs, interchanging the loops would
// likely be more efficient.
fixed (float* pSrcStart = &src.Items[0])
fixed (float* pDstStart = &dst.Items[0])
fixed (float* pMatStart = &mat.Items[0])
fixed (int* pposSrc = &rgposSrc[0])
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
fixed (float* pmat = &MemoryMarshal.GetReference(mat))
fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
float* psrc = GetAlignedBase(src, pSrcStart);
float* pdst = GetAlignedBase(dst, pDstStart);
float* pmat = GetAlignedBase(mat, pMatStart);

int* pposMin = pposSrc + iposMin;
int* pposEnd = pposSrc + iposEnd;
float* pDstEnd = pdst + crow;
float* pm0 = pmat - posMin;
float* pSrcCurrent = psrc - posMin;
float* pDstCurrent = pdst;

while (pDstCurrent < pDstEnd)
nuint address = (nuint)(pDstCurrent);
int misalignment = (int)(address % 32);
int length = crow;
int remainder = 0;

if ((misalignment & 3) != 0)
{
while (pDstCurrent < pDstEnd)
{
Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
pDstCurrent += 8;
pm0 += 8 * ccol;
}
}
else
{
if (misalignment != 0)
{
misalignment >>= 2;
misalignment = 8 - misalignment;

Vector256<float> mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));

float* pm1 = pm0 + ccol;
float* pm2 = pm1 + ccol;
float* pm3 = pm2 + ccol;
Vector256<float> result = Avx.SetZeroVector256<float>();

int* ppos = pposMin;

while (ppos < pposEnd)
{
int col1 = *ppos;
int col2 = col1 + 4 * ccol;
Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we have a helper method for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's different — in the helper we already have, the indexes are contiguous:
return Avx.SetVector256(src[idx[7]], src[idx[6]], src[idx[5]], src[idx[4]], src[idx[3]], src[idx[2]], src[idx[1]], src[idx[0]]);

pm3[col1], pm2[col1], pm1[col1], pm0[col1]);

x1 = Avx.And(mask, x1);
Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
result = MultiplyAdd(x2, x1, result);
ppos++;
}

Avx.Store(pDstCurrent, result);
pDstCurrent += misalignment;
pm0 += misalignment * ccol;
length -= misalignment;
}

if (length > 7)
{
remainder = length % 8;
while (pDstCurrent < pDstEnd)
{
Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
pDstCurrent += 8;
pm0 += 8 * ccol;
}
}
else
{
remainder = length;
}

if (remainder != 0)
{
pDstCurrent -= (8 - remainder);
pm0 -= (8 - remainder) * ccol;
Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));

float* pm1 = pm0 + ccol;
float* pm2 = pm1 + ccol;
float* pm3 = pm2 + ccol;
Vector256<float> result = Avx.SetZeroVector256<float>();

int* ppos = pposMin;

while (ppos < pposEnd)
{
int col1 = *ppos;
int col2 = col1 + 4 * ccol;
Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
x1 = Avx.And(x1, trailingMask);

Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
result = MultiplyAdd(x2, x1, result);
ppos++;
}

result = Avx.Add(result, Avx.And(leadingMask, Avx.LoadVector256(pDstCurrent)));

Avx.Store(pDstCurrent, result);
pDstCurrent += 8;
pm0 += 8 * ccol;
}
}

Vector256<float> SparseMultiplicationAcrossRow()
{
float* pm1 = pm0 + ccol;
float* pm2 = pm1 + ccol;
Expand All @@ -351,33 +434,30 @@ public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArra
int col1 = *ppos;
int col2 = col1 + 4 * ccol;
Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
result = MultiplyAdd(x2, x1, result);

ppos++;
}

Avx.StoreAligned(pDstCurrent, result);
pDstCurrent += 8;
pm0 += 8 * ccol;
return result;
}
}
}

public static unsafe void MatMulTranX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
Contracts.Assert(crow % 4 == 0);
Contracts.Assert(ccol % 4 == 0);

MatMulTranX(mat.Items, src.Items, dst.Items, crow, ccol);
MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol);
}

public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int crow, int ccol)
public static unsafe void MatMulTran(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
{
fixed (float* psrc = &src[0])
fixed (float* pdst = &dst[0])
fixed (float* pmat = &mat[0])
Contracts.Assert(crow % 4 == 0);
Contracts.Assert(ccol % 4 == 0);

fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
fixed (float* pmat = &MemoryMarshal.GetReference(mat))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, Al
if (!tran)
{
Contracts.Assert(crun <= dst.Size);
AvxIntrinsics.MatMulX(mat, src, dst, crun, src.Size);
AvxIntrinsics.MatMul(mat, src, dst, crun, src.Size);
}
else
{
Contracts.Assert(crun <= src.Size);
AvxIntrinsics.MatMulTranX(mat, src, dst, dst.Size, crun);
AvxIntrinsics.MatMulTran(mat, src, dst, dst.Size, crun);
}
}
else if (Sse.IsSupported)
Expand Down Expand Up @@ -109,12 +109,12 @@ public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray sr
if (Avx.IsSupported)
{
Contracts.Assert(crun <= dst.Size);
AvxIntrinsics.MatMulPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
AvxIntrinsics.MatMulP(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
}
else if (Sse.IsSupported)
{
Contracts.Assert(crun <= dst.Size);
SseIntrinsics.MatMulPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
SseIntrinsics.MatMulP(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
}
else
{
Expand Down
Loading