diff --git a/src/Microsoft.ML.CpuMath/Avx.cs b/src/Microsoft.ML.CpuMath/Avx.cs deleted file mode 100644 index f7769b3295..0000000000 --- a/src/Microsoft.ML.CpuMath/Avx.cs +++ /dev/null @@ -1,1165 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; - -namespace Microsoft.ML.Runtime.Internal.CpuMath -{ - /// - /// Keep Avx.cs in sync with Sse.cs. When making changes to one, use BeyondCompare or a similar tool - /// to view diffs and propagate appropriate changes to the other. - /// - public static class AvxUtils - { - public const int CbAlign = 32; - - private static bool Compat(AlignedArray a) - { - Contracts.AssertValue(a); - Contracts.Assert(a.Size > 0); - return a.CbAlign == CbAlign; - } - - private static unsafe float* Ptr(AlignedArray a, float* p) - { - Contracts.AssertValue(a); - float* q = p + a.GetBase((long)p); - Contracts.Assert(((long)q & (CbAlign - 1)) == 0); - return q; - } - - public static bool CheckAvx() - { - return Thunk.ChkAvx(); - } - - public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) - { - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(mat.Size == dst.Size * src.Size); - - unsafe - { - fixed (float* pmat = &mat.Items[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - if (!tran) - { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulX(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); - } - else - { - Contracts.Assert(0 <= crun && crun <= src.Size); - Thunk.MatMulTranX(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); - } - } - } - } - - public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, - int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) - { - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(srcValues)); - Contracts.Assert(Compat(dst)); - Contracts.AssertValue(rgposSrc); - Contracts.Assert(0 <= iposMin && iposMin <= iposLim && iposLim <= rgposSrc.Length); - Contracts.Assert(mat.Size == dst.Size * srcValues.Size); - - if (iposMin >= iposLim) - { - dst.ZeroItems(); - return; - } - Contracts.AssertNonEmpty(rgposSrc); - unsafe - { - fixed (float* pdst = &dst.Items[0]) - fixed (float* pmat = &mat.Items[0]) - fixed (float* psrc = &srcValues.Items[0]) - fixed (int* ppossrc = &rgposSrc[0]) - { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulPX(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); - } - } - } - - public static void MatTimesSrc(bool add, int[] starts, int[] indices, float[] coefs, - AlignedArray src, AlignedArray dst, int crow) - { - Contracts.AssertNonEmpty(starts); - Contracts.Assert(starts.Length == crow + 1); - Contracts.Assert(starts[0] == 0); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(starts[crow] == indices.Length); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(indices.Length == coefs.Length); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < crow && crow <= dst.Size); - Contracts.Assert(crow * src.Size >= coefs.Length); - - unsafe - { - fixed (int* pstarts = &starts[0]) - fixed (int* pindices = &indices[0]) - fixed (float* pcoefs = &coefs[0]) - fixed (float* psrc = &src.Items[0]) - 
fixed (float* pdst = &dst.Items[0]) - Thunk.MatMulRX(add, pstarts, pindices, pcoefs, Ptr(src, psrc), Ptr(dst, pdst), crow); - } - } - - public static void MatTimesSrc(bool add, int[] mprowiv, int[] mprowcol, - int[] mprowrun, int[] runs, float[] coefs, - AlignedArray src, AlignedArray dst, int crow) - { - Contracts.AssertNonEmpty(mprowiv); - Contracts.Assert(mprowiv.Length == crow); - Contracts.AssertNonEmpty(mprowcol); - Contracts.Assert(mprowcol.Length == crow); - Contracts.Assert(mprowrun == null || mprowrun.Length == crow); - Contracts.AssertNonEmpty(runs); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < crow && crow <= dst.Size); - - unsafe - { - fixed (int* pmprowiv = &mprowiv[0]) - fixed (int* pmprowcol = &mprowcol[0]) - fixed (int* pruns = &runs[0]) - fixed (float* pcoefs = &coefs[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - if (mprowrun == null) - { - Thunk.MatMulCX(add, pmprowiv, pmprowcol, pruns, pcoefs, - Ptr(src, psrc), Ptr(dst, pdst), crow); - } - else - { - fixed (int* pmprowrun = &mprowrun[0]) - { - Thunk.MatMulDX(add, pmprowiv, pmprowcol, pmprowrun, pruns, pcoefs, - Ptr(src, psrc), Ptr(dst, pdst), crow); - } - } - } - } - } - - public static void MeanOfSrc(bool add, int[] mprowcol, int[] mprowindices, - int[] indices, AlignedArray src, AlignedArray dst, int crow) - { - Contracts.AssertNonEmpty(mprowcol); - Contracts.Assert(mprowcol.Length == crow); - Contracts.Assert(mprowindices == null || mprowindices.Length == crow); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < crow && crow <= dst.Size); - - unsafe - { - fixed (int* pmprowcol = &mprowcol[0]) - fixed (int* pmprowindices = mprowindices) - fixed (int* pindices = &indices[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - // REVIEW: Implement using AVX - Thunk.MeanU(add, pmprowcol, pmprowindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), crow); - } - } - } - - public static void MaxOfSrc(bool add, int[] mprowcol, int[] mprowindices, - int[] indices, AlignedArray src, AlignedArray dst, int crow) - { - Contracts.AssertNonEmpty(mprowcol); - Contracts.Assert(mprowcol.Length == crow); - Contracts.Assert(mprowindices == null || mprowindices.Length == crow); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < crow && crow <= dst.Size); - - unsafe - { - fixed (int* pmprowcol = &mprowcol[0]) - fixed (int* pmprowindices = mprowindices) - fixed (int* pindices = &indices[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - // REVIEW: Implement using AVX - Thunk.MaxU(add, pmprowcol, pmprowindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), crow); - } - } - } - - public static void RespNormOfSrc(bool add, float alpha, float beta, bool avgOverFullKernel, float offset, - int[] mprowcol, int[] mprowindices, int[] indices, - AlignedArray src, AlignedArray dst, int crow) - { - Contracts.AssertNonEmpty(mprowcol); - Contracts.Assert(mprowcol.Length == crow); - Contracts.Assert(mprowindices == null || mprowindices.Length == crow); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < crow && crow <= dst.Size); - - unsafe - { - fixed (int* pmprowcol = &mprowcol[0]) - fixed (int* pmprowindices = mprowindices) - fixed (int* pindices = 
&indices[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - // REVIEW: Implement using AVX - Thunk.RespNormU(add, alpha, beta, avgOverFullKernel, offset, pmprowcol, pmprowindices, pindices, - Ptr(src, psrc), Ptr(dst, pdst), crow); - } - } - } - - public static void MatTranTimesSrc(bool add, int[] starts, int[] indices, float[] coefs, - AlignedArray src, AlignedArray dst, int ccol) - { - Contracts.AssertNonEmpty(starts); - Contracts.Assert(starts.Length == ccol + 1); - Contracts.Assert(starts[0] == 0); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(starts[ccol] == indices.Length); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(indices.Length == coefs.Length); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < ccol && ccol <= src.Size); - Contracts.Assert(dst.Size * ccol >= coefs.Length); - - unsafe - { - fixed (int* pstarts = &starts[0]) - fixed (int* pindices = &indices[0]) - fixed (float* pcoefs = &coefs[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.MatMulTranRX(add, pstarts, pindices, pcoefs, Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol); - } - } - - public static void MatTranTimesSrc(bool add, int[] mpcoliv, int[] mpcolrow, int[] mpcolrun, - int[] runs, float[] coefs, AlignedArray src, AlignedArray dst, int ccol) - { - Contracts.AssertNonEmpty(mpcoliv); - Contracts.Assert(mpcoliv.Length == ccol); - Contracts.AssertNonEmpty(mpcolrow); - Contracts.Assert(mpcolrow.Length == ccol); - Contracts.AssertNonEmpty(runs); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(mpcolrun == null || mpcolrun.Length == ccol); - Contracts.Assert(0 < ccol && ccol <= src.Size); - - unsafe - { - fixed (int* pmpcoliv = &mpcoliv[0]) - fixed (int* pmpcolrow = &mpcolrow[0]) - fixed (int* pruns = &runs[0]) - fixed (float* pcoefs = &coefs[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - if (mpcolrun == null) - { - Thunk.MatMulTranCX(add, pmpcoliv, pmpcolrow, pruns, pcoefs, - Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol); - } - else - { - fixed (int* pmpcolrun = &mpcolrun[0]) - { - Thunk.MatMulTranDX(add, pmpcoliv, pmpcolrow, pmpcolrun, pruns, pcoefs, - Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol); - } - } - } - } - } - - public static void MeanBackOfSrc(bool add, int[] mpcolrow, int[] mpcolindices, - int[] indices, AlignedArray src, AlignedArray dst, int ccol) - { - Contracts.AssertNonEmpty(mpcolrow); - Contracts.Assert(mpcolrow.Length == ccol); - Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < ccol && ccol <= src.Size); - - unsafe - { - fixed (int* pmpcolrow = &mpcolrow[0]) - fixed (int* pmpcolindices = mpcolindices) - fixed (int* pindices = &indices[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - // REVIEW: Implement using AVX - Thunk.MeanBackU(add, pmpcolrow, pmpcolindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol); - } - } - } - - public static void MaxBackOfSrc(bool add, int[] mpcolrow, int[] mpcolindices, - int[] indices, AlignedArray src, AlignedArray dst, AlignedArray val, int ccol) - { - Contracts.AssertNonEmpty(mpcolrow); - Contracts.Assert(mpcolrow.Length == ccol); - Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol); - 
Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(Compat(val)); - Contracts.Assert(0 < ccol && ccol <= src.Size); - Contracts.Assert(dst.Size == val.Size); - - unsafe - { - fixed (int* pmpcolrow = &mpcolrow[0]) - fixed (int* pmpcolindices = mpcolindices) - fixed (int* pindices = &indices[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - fixed (float* pval = &val.Items[0]) - { - // REVIEW: Implement using AVX - Thunk.MaxBackU(add, pmpcolrow, pmpcolindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), Ptr(val, pval), dst.Size, ccol); - } - } - } - - public static void RespNormBackOfSrc(bool add, float alpha, float beta, bool avgOverFullKernel, float offset, - int[] mpcolrow, int[] mpcolindices, int[] indices, - AlignedArray errors, AlignedArray errorsPrev, AlignedArray valuesPrev, int ccol) - { - Contracts.AssertNonEmpty(mpcolrow); - Contracts.Assert(mpcolrow.Length == ccol); - Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(errors)); - Contracts.Assert(Compat(errorsPrev)); - Contracts.Assert(Compat(valuesPrev)); - Contracts.Assert(0 < ccol && ccol <= errors.Size); - Contracts.Assert(errorsPrev.Size == valuesPrev.Size); - - unsafe - { - fixed (int* pmpcolrow = &mpcolrow[0]) - fixed (int* pmpcolindices = mpcolindices) - fixed (int* pindices = &indices[0]) - fixed (float* perr = &errors.Items[0]) - fixed (float* perrPrev = &errorsPrev.Items[0]) - fixed (float* pvalPrev = &valuesPrev.Items[0]) - { - // REVIEW: Implement using AVX - Thunk.RespNormBackU(add, alpha, beta, avgOverFullKernel, offset, pmpcolrow, pmpcolindices, pindices, - Ptr(errors, perr), Ptr(errorsPrev, perrPrev), Ptr(valuesPrev, pvalPrev), errorsPrev.Size, ccol); - } - } - } - - public static void AddXYTran(float a, AlignedArray x, AlignedArray y, AlignedArray mat, int crow, float decay) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - Contracts.Assert(Compat(mat)); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.Assert(x.Size * y.Size == mat.Size); - Contracts.Assert(decay >= 0); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (float* pmat = &mat.Items[0]) - Thunk.AddXYTranX(a, Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), crow, y.Size, decay); - } - } - - public static void AddXYTran(float a, AlignedArray x, int[] rgposY, AlignedArray valuesY, - int posMinY, int iposMinY, int iposLimY, AlignedArray mat, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(valuesY)); - Contracts.Assert(Compat(mat)); - Contracts.AssertNonEmpty(rgposY); - Contracts.Assert(0 <= iposMinY && iposMinY <= iposLimY && iposLimY <= rgposY.Length); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.Assert(x.Size * valuesY.Size == mat.Size); - - if (iposMinY >= iposLimY) - return; - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &valuesY.Items[0]) - fixed (int* pposy = &rgposY[0]) - fixed (float* pmat = &mat.Items[0]) - { - Thunk.AddXYTranPX(a, Ptr(x, px), pposy, Ptr(valuesY, py), posMinY, iposMinY, iposLimY, Ptr(mat, pmat), - crow, valuesY.Size); - } - } - } - - public static void AddXYTran(float a, AlignedArray x, AlignedArray y, - int[] starts, int[] indices, float[] coefs, int crow, float decay) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.AssertNonEmpty(starts); - 
Contracts.Assert(starts.Length == crow + 1); - Contracts.Assert(starts[0] == 0); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(starts[crow] == indices.Length); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(indices.Length == coefs.Length); - Contracts.Assert(crow * y.Size >= coefs.Length); - Contracts.Assert(decay >= 0); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (int* pstarts = &starts[0]) - fixed (int* pindices = &indices[0]) - fixed (float* pcoefs = &coefs[0]) - Thunk.AddXYTranRX(a, Ptr(x, px), Ptr(y, py), pstarts, pindices, pcoefs, crow, decay); - } - } - - public static void AddXYTran(float a, AlignedArray x, AlignedArray y, int[] mprowiv, - int[] mprowcol, int[] mprowrun, int[] runs, float[] coefs, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.AssertNonEmpty(mprowiv); - Contracts.Assert(mprowiv.Length == crow); - Contracts.AssertNonEmpty(mprowcol); - Contracts.Assert(mprowcol.Length == crow); - Contracts.Assert(mprowrun == null || mprowrun.Length == crow); - Contracts.AssertNonEmpty(runs); - Contracts.AssertNonEmpty(coefs); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (int* pmprowiv = &mprowiv[0]) - fixed (int* pmprowcol = &mprowcol[0]) - fixed (int* pruns = &runs[0]) - fixed (float* pcoefs = &coefs[0]) - { - if (mprowrun == null) - Thunk.AddXYTranCX(a, Ptr(x, px), Ptr(y, py), pmprowiv, pmprowcol, pruns, pcoefs, crow); - else - { - fixed (int* pmprowrun = mprowrun) - Thunk.AddXYTranDX(a, Ptr(x, px), Ptr(y, py), pmprowiv, pmprowcol, pmprowrun, pruns, pcoefs, crow); - } - } - } - } - - public static void AddXYTran(float a, AlignedArray x, AlignedArray y, AlignedArray mat, float momentum, AlignedArray delta, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(delta)); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.Assert(x.Size * y.Size == mat.Size); - Contracts.Assert(mat.Size == delta.Size); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (float* pmat = &mat.Items[0]) - fixed (float* pdel = &delta.Items[0]) - Thunk.AddXYTranMomX(a, Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), momentum, Ptr(delta, pdel), crow, y.Size); - } - } - - public static void AddXYTran(AlignedArray x, AlignedArray y, AlignedArray mat, AlignedArray accGrads, AlignedArray accUpdates, - float decay, float cond, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(accGrads)); - Contracts.Assert(Compat(accUpdates)); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.Assert(x.Size * y.Size == mat.Size); - Contracts.Assert(mat.Size == accGrads.Size); - Contracts.Assert(mat.Size == accUpdates.Size); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (float* pmat = &mat.Items[0]) - fixed (float* pag = &accGrads.Items[0]) - fixed (float* pau = &accUpdates.Items[0]) - Thunk.AddXYTranGradX(Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), Ptr(accGrads, pag), Ptr(accUpdates, pau), decay, cond, crow, y.Size); - } - } - - public static void AddXYTran(AlignedArray x, AlignedArray y, int[] starts, int[] indices, - float[] coefs, float[] accGrads, float[] accUpdates, float decay, float cond, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - 
Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.AssertNonEmpty(starts); - Contracts.Assert(starts.Length == crow + 1); - Contracts.Assert(starts[0] == 0); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(starts[crow] == indices.Length); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(indices.Length == coefs.Length); - Contracts.Assert(crow * y.Size >= coefs.Length); - Contracts.AssertNonEmpty(accGrads); - Contracts.Assert(coefs.Length == accGrads.Length); - Contracts.AssertNonEmpty(accUpdates); - Contracts.Assert(coefs.Length == accUpdates.Length); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (int* pstarts = &starts[0]) - fixed (int* pindices = &indices[0]) - fixed (float* pcoefs = &coefs[0]) - fixed (float* pag = &accGrads[0]) - fixed (float* pau = &accUpdates[0]) - Thunk.AddXYTranGradRX(Ptr(x, px), Ptr(y, py), pstarts, pindices, pcoefs, pag, pau, decay, cond, crow); - } - } - - public static void AddXYTran(AlignedArray x, int[] rgposY, AlignedArray valuesY, - int posMinY, int iposMinY, int iposLimY, AlignedArray mat, - AlignedArray accGrads, AlignedArray accUpdates, float decay, float cond, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.AssertNonEmpty(rgposY); - Contracts.Assert(Compat(valuesY)); - Contracts.Assert(Compat(mat)); - Contracts.Assert(0 <= iposMinY && iposMinY <= iposLimY && iposLimY <= rgposY.Length); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.Assert(x.Size * valuesY.Size == mat.Size); - Contracts.Assert(mat.Size == accGrads.Size); - Contracts.Assert(mat.Size == accUpdates.Size); - - if (iposMinY >= iposLimY) - return; - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &valuesY.Items[0]) - fixed (int* pposy = &rgposY[0]) - fixed (float* pmat = &mat.Items[0]) - fixed (float* pag = &accGrads.Items[0]) - fixed (float* pau = &accUpdates.Items[0]) - { - Thunk.AddXYTranGradPX(Ptr(x, px), pposy, Ptr(valuesY, py), posMinY, iposMinY, iposLimY, Ptr(mat, pmat), - Ptr(accGrads, pag), Ptr(accUpdates, pau), decay, cond, crow, valuesY.Size); - } - } - } - - public static void Scale(float a, AlignedArray dst) - { - Contracts.Assert(Compat(dst)); - - unsafe - { - fixed (float* pdst = &dst.Items[0]) - Thunk.ScaleX(a, Ptr(dst, pdst), dst.Size); - } - } - - public static void Scale(float a, float[] dst, int count) - { - Contracts.AssertNonEmpty(dst); - Contracts.Assert(0 < count && count <= dst.Length); - - unsafe - { - fixed (float* pd = &dst[0]) - Thunk.Scale(a, pd, count); - } - } - - public static void ScaleConvWeights(float a, int kernelSize, float[] dst) - { - Contracts.AssertValue(dst); - - // REVIEW: implement in SSE/AVX. - for (int istart = 0; istart < dst.Length; istart += kernelSize + 1) - { - for (int i = 0; i < kernelSize; i++) - dst[istart + i] *= a; - } - } - - public static void ScaleMaxNorm(bool tran, float maxNorm, AlignedArray mat, int crun, int runLenPhy) - { - // Called only with Avx alignment. 
- Contracts.Assert(Compat(mat)); - - unsafe - { - fixed (float* pmat = &mat.Items[0]) - { - if (!tran) - Thunk.ScaleMaxNormX(maxNorm, Ptr(mat, pmat), crun, runLenPhy); - else - Thunk.ScaleMaxNormTranU(maxNorm, Ptr(mat, pmat), crun, runLenPhy); - } - } - } - - public static void ScaleMaxNorm(float maxNorm, int[] starts, int[] indices, float[] mat) - { - Contracts.AssertNonEmpty(starts); - - int crow = starts.Length - 1; - Contracts.Assert(starts[0] == 0); - Contracts.AssertValue(indices); - Contracts.Assert(starts[crow] == indices.Length); - Contracts.AssertNonEmpty(mat); - - unsafe - { - fixed (int* pstarts = &starts[0]) - fixed (float* pmat = &mat[0]) - Thunk.ScaleMaxNormRU(maxNorm, pstarts, pmat, crow); - } - } - - public static void ScaleMaxNorm(float maxNorm, int kernCount, int kernSize, float[] mat) - { - Contracts.AssertNonEmpty(mat); - - unsafe - { - fixed (float* pmat = &mat[0]) - Thunk.ScaleMaxNormCU(maxNorm, kernCount, kernSize, pmat); - } - } - - public static void AddScale(float a, AlignedArray src, AlignedArray dst) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.AddScaleX(a, Ptr(src, psrc), Ptr(dst, pdst), dst.Size); - } - } - - public static void AddScale(float a, AlignedArray src, AlignedArray dst, float momentum, AlignedArray delta) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(Compat(delta)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(src.Size == delta.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - fixed (float* pdel = &delta.Items[0]) - Thunk.AddScaleMomX(a, Ptr(src, psrc), Ptr(dst, pdst), momentum, Ptr(delta, pdel), dst.Size); - } - } - - public static void AddScale(float a, float[] src, float[] dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count <= dst.Length); - - unsafe - { - fixed (float* psrc = &src[0]) - fixed (float* pdst = &dst[0]) - Thunk.AddScaleU(a, psrc, pdst, count); - } - } - - public static void AddScale(AlignedArray src, AlignedArray dst, - AlignedArray accGrads, AlignedArray accUpdates, float decay, float cond) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(Compat(accGrads)); - Contracts.Assert(Compat(accUpdates)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(src.Size == accGrads.Size); - Contracts.Assert(src.Size == accUpdates.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - fixed (float* pag = &accGrads.Items[0]) - fixed (float* pau = &accUpdates.Items[0]) - Thunk.AddScaleGradX(Ptr(src, psrc), Ptr(dst, pdst), Ptr(accGrads, pag), Ptr(accUpdates, pau), decay, cond, dst.Size); - } - } - - public static void Add(AlignedArray src, AlignedArray dst) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.AddX(Ptr(src, psrc), Ptr(dst, pdst), dst.Size); - } - } - - public static void Add(float[] src, float[] dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count <= dst.Length); - - unsafe - { - fixed (float* 
ps = &src[0]) - fixed (float* pd = &dst[0]) - Thunk.AddU(ps, pd, count); - } - } - - public static float Sum(AlignedArray src) - { - Contracts.Assert(Compat(src)); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - return Thunk.SumX(Ptr(src, psrc), src.Size); - } - } - - public static float Sum(float[] src, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - - unsafe - { - fixed (float* psrc = &src[0]) - return Thunk.SumU(psrc, count); - } - } - - public static float SumSq(float[] src, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - - unsafe - { - fixed (float* psrc = &src[0]) - return Thunk.SumSqU(psrc, count); - } - } - - public static float SumAbs(float[] src, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - - unsafe - { - fixed (float* psrc = &src[0]) - return Thunk.SumAbsU(psrc, count); - } - } - - public static float MaxAbs(float[] src, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - - unsafe - { - fixed (float* psrc = &src[0]) - return Thunk.MaxAbsU(psrc, count); - } - } - - public static float DotProductSparse(float[] a, float[] b, int[] indices, int count) - { - Contracts.AssertNonEmpty(a); - Contracts.AssertNonEmpty(b); - Contracts.Assert(0 < count); - Contracts.Assert(count < a.Length); - Contracts.Assert(count <= b.Length); - Contracts.Assert(count <= indices.Length); - - unsafe - { - fixed (float* pa = &a[0]) - fixed (float* pb = &b[0]) - fixed (int* pi = &indices[0]) - return Thunk.DotSU(pa, pb, pi, count); - } - } - - public static float DotProductSparse(float[] a, int offset, float[] b, int[] indices, int count) - { - Contracts.AssertNonEmpty(a); - Contracts.Assert(0 < count); - Contracts.Assert(0 <= offset && offset < a.Length); - Contracts.Assert(a.Length - offset > count); - Contracts.AssertNonEmpty(b); - Contracts.Assert(count <= b.Length); - Contracts.Assert(count <= indices.Length); - - unsafe - { - fixed (float* pa = &a[offset]) - fixed (float* pb = &b[0]) - fixed (int* pi = &indices[0]) - return Thunk.DotSU(pa, pb, pi, count); - } - } - - public static void ApplySigmoid(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplySigmoidX(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplySoftMax(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplySoftMaxX(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplyRectifiedLinear(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplyRectifiedLinearX(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplySquare(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - 
Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplySquareX(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplySqrt(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplySqrtX(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplySoftRectifiedLinear(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplySoftRectifiedLinearX(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplyAbs(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplyAbsX(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplyTanh(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplyTanhX(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplyBoundedRectifiedLinear(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 <= c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplyBoundedRectifiedLinearX(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplySigmoidDerivative(AlignedArray value, AlignedArray grad) - { - Contracts.Assert(Compat(value)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(value.Size == grad.Size); - - unsafe - { - fixed (float* pvalue = &value.Items[0]) - fixed (float* pgrad = &grad.Items[0]) - Thunk.ApplySigmoidDerivativeX(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size); - } - } - - public static void ApplyRectifiedLinearDerivative(AlignedArray value, AlignedArray grad) - { - Contracts.Assert(Compat(value)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(value.Size == grad.Size); - - unsafe - { - fixed (float* pvalue = &value.Items[0]) - fixed (float* pgrad = &grad.Items[0]) - Thunk.ApplyRectifiedLinearDerivativeX(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size); - } - } - - public static void ApplySquareDerivative(AlignedArray input, AlignedArray output, AlignedArray grad, bool drop) - { - Contracts.Assert(Compat(input)); - Contracts.Assert(Compat(output)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(output.Size == input.Size); - Contracts.Assert(output.Size == grad.Size); - - unsafe - { - fixed (float* px = &input.Items[0]) - fixed (float* py = &output.Items[0]) - fixed (float* pg = &grad.Items[0]) - Thunk.ApplySquareDerivativeX(Ptr(input, 
px), Ptr(output, py), Ptr(grad, pg), grad.Size, drop); - } - } - - public static void ApplySqrtDerivative(AlignedArray value, AlignedArray grad) - { - Contracts.Assert(Compat(value)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(value.Size == grad.Size); - - unsafe - { - fixed (float* pvalue = &value.Items[0]) - fixed (float* pgrad = &grad.Items[0]) - Thunk.ApplySqrtDerivativeX(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size); - } - } - - public static void ApplySoftRectifiedLinearDerivative(AlignedArray input, AlignedArray output, AlignedArray grad) - { - Contracts.Assert(Compat(input)); - Contracts.Assert(Compat(output)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(output.Size == input.Size); - Contracts.Assert(output.Size == grad.Size); - - unsafe - { - fixed (float* px = &input.Items[0]) - fixed (float* py = &output.Items[0]) - fixed (float* pg = &grad.Items[0]) - Thunk.ApplySoftRectifiedLinearDerivativeX(Ptr(input, px), Ptr(output, py), Ptr(grad, pg), grad.Size); - } - } - - public static void ApplyAbsDerivative(AlignedArray input, AlignedArray output, AlignedArray grad, bool drop) - { - Contracts.Assert(Compat(input)); - Contracts.Assert(Compat(output)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(output.Size == input.Size); - Contracts.Assert(output.Size == grad.Size); - - unsafe - { - fixed (float* px = &input.Items[0]) - fixed (float* py = &output.Items[0]) - fixed (float* pg = &grad.Items[0]) - Thunk.ApplyAbsDerivativeX(Ptr(input, px), Ptr(output, py), Ptr(grad, pg), grad.Size, drop); - } - } - - public static void ApplyTanhDerivative(AlignedArray value, AlignedArray grad) - { - Contracts.Assert(Compat(value)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(value.Size == grad.Size); - - unsafe - { - fixed (float* pvalue = &value.Items[0]) - fixed (float* pgrad = &grad.Items[0]) - Thunk.ApplyTanhDerivativeX(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size); - } - } - - public static void ApplyBoundedRectifiedLinearDerivative(AlignedArray value, AlignedArray grad) - { - Contracts.Assert(Compat(value)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(value.Size == grad.Size); - - unsafe - { - fixed (float* pvalue = &value.Items[0]) - fixed (float* pgrad = &grad.Items[0]) - Thunk.ApplyBoundedRectifiedLinearDerivativeX(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size); - } - } - - public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] indices) - { - Contracts.Assert(0 < ccol && ccol <= cfltRow); - - unsafe - { - fixed (float* pdst = &dst.Items[0]) - fixed (int* pi = &indices[0]) - { - if (ccol == cfltRow) - Thunk.ZeroItemsU(Ptr(dst, pdst), dst.Size, pi, indices.Length); - else - Thunk.ZeroMatrixItemsCore(Ptr(dst, pdst), dst.Size, ccol, cfltRow, pi, indices.Length); - } - } - } - - public static void ScaleAdadelta(float[] mat, float[] accGrads, float[] accUpdates, float decay, float cond, float[] grads) - { - Contracts.AssertNonEmpty(mat); - Contracts.AssertNonEmpty(accGrads); - Contracts.AssertNonEmpty(accUpdates); - Contracts.Assert(mat.Length == accGrads.Length); - Contracts.Assert(mat.Length == accUpdates.Length); - Contracts.Assert(mat.Length <= grads.Length); - - unsafe - { - fixed (float* pm = &mat[0]) - fixed (float* pag = &accGrads[0]) - fixed (float* pau = &accUpdates[0]) - fixed (float* pg = &grads[0]) - Thunk.ScaleAdadeltaX(pm, pag, pau, decay, cond, pg, mat.Length); - } - } - } -} \ No newline at end of file diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index 
d99a3b0570..c812b47687 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -46,25 +46,6 @@ internal static class AvxIntrinsics private static readonly Vector256 _absMask256 = Avx.StaticCast(Avx.SetAllVector256(0x7FFFFFFF)); - private const int Vector256Alignment = 32; - - [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] - private static bool HasCompatibleAlignment(AlignedArray alignedArray) - { - Contracts.AssertValue(alignedArray); - Contracts.Assert(alignedArray.Size > 0); - return (alignedArray.CbAlign % Vector256Alignment) == 0; - } - - [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] - private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase) - { - Contracts.AssertValue(alignedArray); - float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase); - Contracts.Assert(((long)alignedBase % Vector256Alignment) == 0); - return alignedBase; - } - [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] private static Vector128 GetHigh(in Vector256 x) => Avx.ExtractVector128(x, 1); @@ -170,19 +151,19 @@ private static Vector256 MultiplyAdd(Vector256 src1, Vector256 mat, ReadOnlySpan src, Span dst, int crow, int ccol) { - fixed (float* psrc = &src[0]) - fixed (float* pdst = &dst[0]) - fixed (float* pmat = &mat[0]) + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); + + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + fixed (float* pmat = &MemoryMarshal.GetReference(mat)) fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { @@ -312,24 +293,27 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int cro } // Partial sparse source vector. - public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArray src, - int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray src, + int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); + MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol); + } + + public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan src, + int posMin, int iposMin, int iposEnd, Span dst, int crow, int ccol) + { + Contracts.Assert(crow % 8 == 0); + Contracts.Assert(ccol % 8 == 0); // REVIEW: For extremely sparse inputs, interchanging the loops would // likely be more efficient. 
- fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) - fixed (int* pposSrc = &rgposSrc[0]) + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + fixed (float* pmat = &MemoryMarshal.GetReference(mat)) + fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc)) + fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) + fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); - int* pposMin = pposSrc + iposMin; int* pposEnd = pposSrc + iposEnd; float* pDstEnd = pdst + crow; @@ -337,7 +321,106 @@ public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArra float* pSrcCurrent = psrc - posMin; float* pDstCurrent = pdst; - while (pDstCurrent < pDstEnd) + nuint address = (nuint)(pDstCurrent); + int misalignment = (int)(address % 32); + int length = crow; + int remainder = 0; + + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow()); + pDstCurrent += 8; + pm0 += 8 * ccol; + } + } + else + { + if (misalignment != 0) + { + misalignment >>= 2; + misalignment = 8 - misalignment; + + Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); + + float* pm1 = pm0 + ccol; + float* pm2 = pm1 + ccol; + float* pm3 = pm2 + ccol; + Vector256 result = Avx.SetZeroVector256(); + + int* ppos = pposMin; + + while (ppos < pposEnd) + { + int col1 = *ppos; + int col2 = col1 + 4 * ccol; + Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2], + pm3[col1], pm2[col1], pm1[col1], pm0[col1]); + + x1 = Avx.And(mask, x1); + Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]); + result = MultiplyAdd(x2, x1, result); + ppos++; + } + + Avx.Store(pDstCurrent, result); + pDstCurrent += misalignment; + pm0 += misalignment * ccol; + length -= misalignment; + } + + if (length > 7) + { + remainder = length % 8; + while (pDstCurrent < pDstEnd) + { + Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow()); + pDstCurrent += 8; + pm0 += 8 * ccol; + } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pDstCurrent -= (8 - remainder); + pm0 -= (8 - remainder) * ccol; + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); + + float* pm1 = pm0 + ccol; + float* pm2 = pm1 + ccol; + float* pm3 = pm2 + ccol; + Vector256 result = Avx.SetZeroVector256(); + + int* ppos = pposMin; + + while (ppos < pposEnd) + { + int col1 = *ppos; + int col2 = col1 + 4 * ccol; + Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2], + pm3[col1], pm2[col1], pm1[col1], pm0[col1]); + x1 = Avx.And(x1, trailingMask); + + Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]); + result = MultiplyAdd(x2, x1, result); + ppos++; + } + + result = Avx.Add(result, Avx.And(leadingMask, Avx.LoadVector256(pDstCurrent))); + + Avx.Store(pDstCurrent, result); + pDstCurrent += 8; + pm0 += 8 * ccol; + } + } + + Vector256 SparseMultiplicationAcrossRow() { float* pm1 = pm0 + ccol; float* pm2 = pm1 + ccol; @@ -351,33 +434,30 @@ public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArra int col1 = *ppos; int col2 = col1 + 4 * ccol; 
Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2], - pm3[col1], pm2[col1], pm1[col1], pm0[col1]); + pm3[col1], pm2[col1], pm1[col1], pm0[col1]); Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]); result = MultiplyAdd(x2, x1, result); - ppos++; } - Avx.StoreAligned(pDstCurrent, result); - pDstCurrent += 8; - pm0 += 8 * ccol; + return result; } } } - public static unsafe void MatMulTranX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(crow % 4 == 0); - Contracts.Assert(ccol % 4 == 0); - - MatMulTranX(mat.Items, src.Items, dst.Items, crow, ccol); + MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); } - public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int crow, int ccol) + public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) { - fixed (float* psrc = &src[0]) - fixed (float* pdst = &dst[0]) - fixed (float* pmat = &mat[0]) + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); + + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + fixed (float* pmat = &MemoryMarshal.GetReference(mat)) fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs index aa8ff85bc6..a41ea6f703 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs @@ -34,12 +34,12 @@ public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, Al if (!tran) { Contracts.Assert(crun <= dst.Size); - AvxIntrinsics.MatMulX(mat, src, dst, crun, src.Size); + AvxIntrinsics.MatMul(mat, src, dst, crun, src.Size); } else { Contracts.Assert(crun <= src.Size); - AvxIntrinsics.MatMulTranX(mat, src, dst, dst.Size, crun); + AvxIntrinsics.MatMulTran(mat, src, dst, dst.Size, crun); } } else if (Sse.IsSupported) @@ -109,12 +109,12 @@ public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray sr if (Avx.IsSupported) { Contracts.Assert(crun <= dst.Size); - AvxIntrinsics.MatMulPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); + AvxIntrinsics.MatMulP(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); } else if (Sse.IsSupported) { Contracts.Assert(crun <= dst.Size); - SseIntrinsics.MatMulPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); + SseIntrinsics.MatMulP(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); } else { diff --git a/src/Microsoft.ML.CpuMath/Sse.cs b/src/Microsoft.ML.CpuMath/Sse.cs index 4281893852..a25c286d55 100644 --- a/src/Microsoft.ML.CpuMath/Sse.cs +++ b/src/Microsoft.ML.CpuMath/Sse.cs @@ -81,500 +81,7 @@ public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray sr fixed (int* ppossrc = &rgposSrc[0]) { Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulPA(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); - } - } - } - - public static void MatTimesSrc(bool add, int[] starts, int[] indices, float[] coefs, - AlignedArray src, AlignedArray dst, int crow) - { - Contracts.AssertNonEmpty(starts); - 
Contracts.Assert(starts.Length == crow + 1); - Contracts.Assert(starts[0] == 0); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(starts[crow] == indices.Length); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(indices.Length == coefs.Length); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < crow && crow <= dst.Size); - Contracts.Assert(crow * src.Size >= coefs.Length); - - unsafe - { - fixed (int* pstarts = &starts[0]) - fixed (int* pindices = &indices[0]) - fixed (float* pcoefs = &coefs[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.MatMulRU(add, pstarts, pindices, pcoefs, Ptr(src, psrc), Ptr(dst, pdst), crow); - } - } - - public static void MatTimesSrc(bool add, int[] mprowiv, int[] mprowcol, - int[] mprowrun, int[] runs, float[] coefs, - AlignedArray src, AlignedArray dst, int crow) - { - Contracts.AssertNonEmpty(mprowiv); - Contracts.Assert(mprowiv.Length == crow); - Contracts.AssertNonEmpty(mprowcol); - Contracts.Assert(mprowcol.Length == crow); - Contracts.Assert(mprowrun == null || mprowrun.Length == crow); - Contracts.AssertNonEmpty(runs); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < crow && crow <= dst.Size); - - unsafe - { - fixed (int* pmprowiv = &mprowiv[0]) - fixed (int* pmprowcol = &mprowcol[0]) - fixed (int* pruns = &runs[0]) - fixed (float* pcoefs = &coefs[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - if (mprowrun == null) - { - Thunk.MatMulCU(add, pmprowiv, pmprowcol, pruns, pcoefs, - Ptr(src, psrc), Ptr(dst, pdst), crow); - } - else - { - fixed (int* pmprowrun = &mprowrun[0]) - { - Thunk.MatMulDU(add, pmprowiv, pmprowcol, pmprowrun, pruns, pcoefs, - Ptr(src, psrc), Ptr(dst, pdst), crow); - } - } - } - } - } - - public static void MeanOfSrc(bool add, int[] mprowcol, int[] mprowindices, - int[] indices, AlignedArray src, AlignedArray dst, int crow) - { - Contracts.AssertNonEmpty(mprowcol); - Contracts.Assert(mprowcol.Length == crow); - Contracts.Assert(mprowindices == null || mprowindices.Length == crow); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < crow && crow <= dst.Size); - - unsafe - { - fixed (int* pmprowcol = &mprowcol[0]) - fixed (int* pmprowindices = mprowindices) - fixed (int* pindices = &indices[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.MeanU(add, pmprowcol, pmprowindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), crow); - } - } - - public static void MaxOfSrc(bool add, int[] mprowcol, int[] mprowindices, - int[] indices, AlignedArray src, AlignedArray dst, int crow) - { - Contracts.AssertNonEmpty(mprowcol); - Contracts.Assert(mprowcol.Length == crow); - Contracts.Assert(mprowindices == null || mprowindices.Length == crow); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < crow && crow <= dst.Size); - - unsafe - { - fixed (int* pmprowcol = &mprowcol[0]) - fixed (int* pmprowindices = mprowindices) - fixed (int* pindices = &indices[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.MaxU(add, pmprowcol, pmprowindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), crow); - } - } - - public static void RespNormOfSrc(bool add, float alpha, float beta, bool avgOverFullKernel, float offset, - int[] mprowcol, int[] mprowindices, 
int[] indices, - AlignedArray src, AlignedArray dst, int crow) - { - Contracts.AssertNonEmpty(mprowcol); - Contracts.Assert(mprowcol.Length == crow); - Contracts.Assert(mprowindices == null || mprowindices.Length == crow); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < crow && crow <= dst.Size); - - unsafe - { - fixed (int* pmprowcol = &mprowcol[0]) - fixed (int* pmprowindices = mprowindices) - fixed (int* pindices = &indices[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - Thunk.RespNormU(add, alpha, beta, avgOverFullKernel, offset, pmprowcol, pmprowindices, pindices, - Ptr(src, psrc), Ptr(dst, pdst), crow); - } - } - } - - public static void MatTranTimesSrc(bool add, int[] starts, int[] indices, float[] coefs, - AlignedArray src, AlignedArray dst, int ccol) - { - Contracts.AssertNonEmpty(starts); - Contracts.Assert(starts.Length == ccol + 1); - Contracts.Assert(starts[0] == 0); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(starts[ccol] == indices.Length); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(indices.Length == coefs.Length); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < ccol && ccol <= src.Size); - Contracts.Assert(dst.Size * ccol >= coefs.Length); - - unsafe - { - fixed (int* pstarts = &starts[0]) - fixed (int* pindices = &indices[0]) - fixed (float* pcoefs = &coefs[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.MatMulTranRU(add, pstarts, pindices, pcoefs, Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol); - } - } - - public static void MatTranTimesSrc(bool add, int[] mpcoliv, int[] mpcolrow, int[] mpcolrun, - int[] runs, float[] coefs, AlignedArray src, AlignedArray dst, int ccol) - { - Contracts.AssertNonEmpty(mpcoliv); - Contracts.Assert(mpcoliv.Length == ccol); - Contracts.AssertNonEmpty(mpcolrow); - Contracts.Assert(mpcolrow.Length == ccol); - Contracts.AssertNonEmpty(runs); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(mpcolrun == null || mpcolrun.Length == ccol); - Contracts.Assert(0 < ccol && ccol <= src.Size); - - unsafe - { - fixed (int* pmpcoliv = &mpcoliv[0]) - fixed (int* pmpcolrow = &mpcolrow[0]) - fixed (int* pruns = &runs[0]) - fixed (float* pcoefs = &coefs[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - if (mpcolrun == null) - { - Thunk.MatMulTranCU(add, pmpcoliv, pmpcolrow, pruns, pcoefs, - Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol); - } - else - { - fixed (int* pmpcolrun = &mpcolrun[0]) - { - Thunk.MatMulTranDU(add, pmpcoliv, pmpcolrow, pmpcolrun, pruns, pcoefs, - Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol); - } - } - } - } - } - - public static void MeanBackOfSrc(bool add, int[] mpcolrow, int[] mpcolindices, - int[] indices, AlignedArray src, AlignedArray dst, int ccol) - { - Contracts.AssertNonEmpty(mpcolrow); - Contracts.Assert(mpcolrow.Length == ccol); - Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 < ccol && ccol <= src.Size); - - unsafe - { - fixed (int* pmpcolrow = &mpcolrow[0]) - fixed (int* pmpcolindices = mpcolindices) - fixed (int* pindices = &indices[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.MeanBackU(add, pmpcolrow, 
pmpcolindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol); - } - } - - public static void MaxBackOfSrc(bool add, int[] mpcolrow, int[] mpcolindices, - int[] indices, AlignedArray src, AlignedArray dst, AlignedArray val, int ccol) - { - Contracts.AssertNonEmpty(mpcolrow); - Contracts.Assert(mpcolrow.Length == ccol); - Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(Compat(val)); - Contracts.Assert(0 < ccol && ccol <= src.Size); - Contracts.Assert(dst.Size == val.Size); - - unsafe - { - fixed (int* pmpcolrow = &mpcolrow[0]) - fixed (int* pmpcolindices = mpcolindices) - fixed (int* pindices = &indices[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - fixed (float* pval = &val.Items[0]) - Thunk.MaxBackU(add, pmpcolrow, pmpcolindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), Ptr(val, pval), dst.Size, ccol); - } - } - - public static void RespNormBackOfSrc(bool add, float alpha, float beta, bool avgOverFullKernel, float offset, - int[] mpcolrow, int[] mpcolindices, int[] indices, - AlignedArray errors, AlignedArray errorsPrev, AlignedArray valuesPrev, int ccol) - { - Contracts.AssertNonEmpty(mpcolrow); - Contracts.Assert(mpcolrow.Length == ccol); - Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(Compat(errors)); - Contracts.Assert(Compat(errorsPrev)); - Contracts.Assert(Compat(valuesPrev)); - Contracts.Assert(0 < ccol && ccol <= errors.Size); - Contracts.Assert(errorsPrev.Size == valuesPrev.Size); - - unsafe - { - fixed (int* pmpcolrow = &mpcolrow[0]) - fixed (int* pmpcolindices = mpcolindices) - fixed (int* pindices = &indices[0]) - fixed (float* perr = &errors.Items[0]) - fixed (float* perrPrev = &errorsPrev.Items[0]) - fixed (float* pvalPrev = &valuesPrev.Items[0]) - { - Thunk.RespNormBackU(add, alpha, beta, avgOverFullKernel, offset, pmpcolrow, pmpcolindices, pindices, - Ptr(errors, perr), Ptr(errorsPrev, perrPrev), Ptr(valuesPrev, pvalPrev), errorsPrev.Size, ccol); - } - } - } - - public static void AddXYTran(float a, AlignedArray x, AlignedArray y, AlignedArray mat, int crow, float decay) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - Contracts.Assert(Compat(mat)); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.Assert(x.Size * y.Size == mat.Size); - Contracts.Assert(decay >= 0); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (float* pmat = &mat.Items[0]) - Thunk.AddXYTranA(a, Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), crow, y.Size, decay); - } - } - - public static void AddXYTran(float a, AlignedArray x, int[] rgposY, AlignedArray valuesY, - int posMinY, int iposMinY, int iposLimY, AlignedArray mat, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(valuesY)); - Contracts.Assert(Compat(mat)); - Contracts.AssertNonEmpty(rgposY); - Contracts.Assert(0 <= iposMinY && iposMinY <= iposLimY && iposLimY <= rgposY.Length); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.Assert(x.Size * valuesY.Size == mat.Size); - - if (iposMinY >= iposLimY) - return; - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &valuesY.Items[0]) - fixed (int* pposy = &rgposY[0]) - fixed (float* pmat = &mat.Items[0]) - { - Thunk.AddXYTranPA(a, Ptr(x, px), pposy, Ptr(valuesY, py), posMinY, iposMinY, iposLimY, Ptr(mat, pmat), - crow, 
valuesY.Size); - } - } - } - - public static void AddXYTran(float a, AlignedArray x, AlignedArray y, - int[] starts, int[] indices, float[] coefs, int crow, float decay) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.AssertNonEmpty(starts); - Contracts.Assert(starts.Length == crow + 1); - Contracts.Assert(starts[0] == 0); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(starts[crow] == indices.Length); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(indices.Length == coefs.Length); - Contracts.Assert(crow * y.Size >= coefs.Length); - Contracts.Assert(decay >= 0); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (int* pstarts = &starts[0]) - fixed (int* pindices = &indices[0]) - fixed (float* pcoefs = &coefs[0]) - Thunk.AddXYTranRU(a, Ptr(x, px), Ptr(y, py), pstarts, pindices, pcoefs, crow, decay); - } - } - - public static void AddXYTran(float a, AlignedArray x, AlignedArray y, int[] mprowiv, - int[] mprowcol, int[] mprowrun, int[] runs, float[] coefs, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.AssertNonEmpty(mprowiv); - Contracts.Assert(mprowiv.Length == crow); - Contracts.AssertNonEmpty(mprowcol); - Contracts.Assert(mprowcol.Length == crow); - Contracts.Assert(mprowrun == null || mprowrun.Length == crow); - Contracts.AssertNonEmpty(runs); - Contracts.AssertNonEmpty(coefs); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (int* pmprowiv = &mprowiv[0]) - fixed (int* pmprowcol = &mprowcol[0]) - fixed (int* pruns = &runs[0]) - fixed (float* pcoefs = &coefs[0]) - { - if (mprowrun == null) - Thunk.AddXYTranCU(a, Ptr(x, px), Ptr(y, py), pmprowiv, pmprowcol, pruns, pcoefs, crow); - else - { - fixed (int* pmprowrun = &mprowrun[0]) - Thunk.AddXYTranDU(a, Ptr(x, px), Ptr(y, py), pmprowiv, pmprowcol, pmprowrun, pruns, pcoefs, crow); - } - } - } - } - - public static void AddXYTran(float a, AlignedArray x, AlignedArray y, AlignedArray mat, float momentum, AlignedArray delta, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(delta)); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.Assert(x.Size * y.Size == mat.Size); - Contracts.Assert(mat.Size == delta.Size); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (float* pmat = &mat.Items[0]) - fixed (float* pdel = &delta.Items[0]) - Thunk.AddXYTranMomA(a, Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), momentum, Ptr(delta, pdel), crow, y.Size); - } - } - - public static void AddXYTran(AlignedArray x, AlignedArray y, AlignedArray mat, AlignedArray accGrads, AlignedArray accUpdates, - float decay, float cond, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(accGrads)); - Contracts.Assert(Compat(accUpdates)); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.Assert(x.Size * y.Size == mat.Size); - Contracts.Assert(mat.Size == accGrads.Size); - Contracts.Assert(mat.Size == accUpdates.Size); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (float* pmat = &mat.Items[0]) - fixed (float* pag = &accGrads.Items[0]) - fixed (float* pau = &accUpdates.Items[0]) - Thunk.AddXYTranGradA(Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), Ptr(accGrads, 
pag), Ptr(accUpdates, pau), decay, cond, crow, y.Size); - } - } - - public static void AddXYTran(AlignedArray x, AlignedArray y, int[] starts, int[] indices, - float[] coefs, float[] accGrads, float[] accUpdates, float decay, float cond, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.Assert(Compat(y)); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.AssertNonEmpty(starts); - Contracts.Assert(starts.Length == crow + 1); - Contracts.Assert(starts[0] == 0); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(starts[crow] == indices.Length); - Contracts.AssertNonEmpty(coefs); - Contracts.Assert(indices.Length == coefs.Length); - Contracts.Assert(crow * y.Size >= coefs.Length); - Contracts.AssertNonEmpty(accGrads); - Contracts.Assert(coefs.Length == accGrads.Length); - Contracts.AssertNonEmpty(accUpdates); - Contracts.Assert(coefs.Length == accUpdates.Length); - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &y.Items[0]) - fixed (int* pstarts = &starts[0]) - fixed (int* pindices = &indices[0]) - fixed (float* pcoefs = &coefs[0]) - fixed (float* pag = &accGrads[0]) - fixed (float* pau = &accUpdates[0]) - Thunk.AddXYTranGradRU(Ptr(x, px), Ptr(y, py), pstarts, pindices, pcoefs, pag, pau, decay, cond, crow); - } - } - - public static void AddXYTran(AlignedArray x, int[] rgposY, AlignedArray valuesY, - int posMinY, int iposMinY, int iposLimY, AlignedArray mat, - AlignedArray accGrads, AlignedArray accUpdates, float decay, float cond, int crow) - { - Contracts.Assert(Compat(x)); - Contracts.AssertNonEmpty(rgposY); - Contracts.Assert(Compat(valuesY)); - Contracts.Assert(Compat(mat)); - Contracts.Assert(0 <= iposMinY && iposMinY <= iposLimY && iposLimY <= rgposY.Length); - Contracts.Assert(0 < crow && crow <= x.Size); - Contracts.Assert(x.Size * valuesY.Size == mat.Size); - Contracts.Assert(mat.Size == accGrads.Size); - Contracts.Assert(mat.Size == accUpdates.Size); - - if (iposMinY >= iposLimY) - return; - - unsafe - { - fixed (float* px = &x.Items[0]) - fixed (float* py = &valuesY.Items[0]) - fixed (int* pposy = &rgposY[0]) - fixed (float* pmat = &mat.Items[0]) - fixed (float* pag = &accGrads.Items[0]) - fixed (float* pau = &accUpdates.Items[0]) - { - Thunk.AddXYTranGradPA(Ptr(x, px), pposy, Ptr(valuesY, py), posMinY, iposMinY, iposLimY, Ptr(mat, pmat), - Ptr(accGrads, pag), Ptr(accUpdates, pau), decay, cond, crow, valuesY.Size); + Thunk.MatMulP(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); } } } @@ -591,17 +98,6 @@ public static void Add(float a, Span dst) } } - public static void Scale(float a, AlignedArray dst) - { - Contracts.Assert(Compat(dst)); - - unsafe - { - fixed (float* pdst = &dst.Items[0]) - Thunk.Scale(a, Ptr(dst, pdst), dst.Size); - } - } - public static void Scale(float a, Span dst) { Contracts.AssertNonEmpty(dst); @@ -613,19 +109,6 @@ public static void Scale(float a, Span dst) } } - public static void Scale(float a, float[] dst, int offset, int count) - { - Contracts.AssertNonEmpty(dst); - Contracts.Assert(0 < count); - Contracts.Assert(0 <= offset && offset < dst.Length - count); - - unsafe - { - fixed (float* pd = &dst[offset]) - Thunk.Scale(a, pd, count); - } - } - // dst = a * src public static void Scale(float a, ReadOnlySpan src, Span dst, int count) { @@ -656,98 +139,6 @@ public static void ScaleAdd(float a, float b, Span dst) } } - public static void ScaleConvWeights(float a, int kernelSize, float[] dst) - { - Contracts.AssertValue(dst); - - // REVIEW: implement 
in SSE/AVX. - for (int istart = 0; istart < dst.Length; istart += kernelSize + 1) - { - for (int i = 0; i < kernelSize; i++) - dst[istart + i] *= a; - } - } - - public static void ScaleMaxNorm(bool tran, float maxNorm, AlignedArray mat, int crun, int runLenPhy) - { - // Called also by MklMath which uses Avx alignment, which is a multiple of Sse alignment. - // Hence, Compat(mat) cannot be asserted here since it checks for exact Sse alignment (mat.CbAlign == CbAlign). - Contracts.AssertValue(mat); - Contracts.Assert(mat.Size > 0); - Contracts.Assert((mat.CbAlign % CbAlign) == 0); - - unsafe - { - fixed (float* pmat = &mat.Items[0]) - { - if (!tran) - Thunk.ScaleMaxNormA(maxNorm, Ptr(mat, pmat), crun, runLenPhy); - else - Thunk.ScaleMaxNormTranU(maxNorm, Ptr(mat, pmat), crun, runLenPhy); - } - } - } - - public static void ScaleMaxNorm(float maxNorm, int[] starts, int[] indices, float[] mat) - { - Contracts.AssertNonEmpty(starts); - - int crow = starts.Length - 1; - Contracts.Assert(starts[0] == 0); - Contracts.AssertValue(indices); - Contracts.Assert(starts[crow] == indices.Length); - Contracts.AssertNonEmpty(mat); - - unsafe - { - fixed (int* pstarts = &starts[0]) - fixed (float* pmat = &mat[0]) - Thunk.ScaleMaxNormRU(maxNorm, pstarts, pmat, crow); - } - } - - public static void ScaleMaxNorm(float maxNorm, int kernCount, int kernSize, float[] mat) - { - Contracts.AssertNonEmpty(mat); - - unsafe - { - fixed (float* pmat = &mat[0]) - Thunk.ScaleMaxNormCU(maxNorm, kernCount, kernSize, pmat); - } - } - - public static void AddScale(float a, AlignedArray src, AlignedArray dst) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.AddScaleA(a, Ptr(src, psrc), Ptr(dst, pdst), dst.Size); - } - } - - public static void AddScale(float a, AlignedArray src, AlignedArray dst, float momentum, AlignedArray delta) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(Compat(delta)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(src.Size == delta.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - fixed (float* pdel = &delta.Items[0]) - Thunk.AddScaleMomA(a, Ptr(src, psrc), Ptr(dst, pdst), momentum, Ptr(delta, pdel), dst.Size); - } - } - public static void AddScale(float a, ReadOnlySpan src, Span dst, int count) { Contracts.AssertNonEmpty(src); @@ -799,41 +190,6 @@ public static void AddScaleCopy(float a, ReadOnlySpan src, ReadOnlySpan src, Span dst, int count) { Contracts.AssertNonEmpty(src); @@ -883,36 +239,6 @@ public static void MulElementWise(ReadOnlySpan src1, ReadOnlySpan } } - public static void MulElementWise(float[] src1, float[] src2, int[] indices, float[] dst, int count) - { - Contracts.AssertNonEmpty(src1); - Contracts.Assert(0 < count && count <= src1.Length); - Contracts.AssertNonEmpty(src2); - Contracts.Assert(0 < count && count <= src2.Length); - Contracts.AssertNonEmpty(dst); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(count <= indices.Length); - unsafe - { - fixed (float* ps1 = &src1[0]) - fixed (float* ps2 = &src2[0]) - fixed (int* pi = &indices[0]) - fixed (float* pd = &dst[0]) - Thunk.MulElementWiseSU(ps1, ps2, pi, pd, count); - } - } - - public static float Sum(AlignedArray src) - { - Contracts.Assert(Compat(src)); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - return Thunk.SumA(Ptr(src, psrc), src.Size); - } - } 
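Throughout this change, AlignedArray overloads such as the Sum above are dropped in favor of span-based ones, and the rewritten MatMul/MatMulP methods later in the diff pin those spans with MemoryMarshal.GetReference instead of relying on AlignedArray's adjusted base pointer. The following is a minimal, self-contained sketch of that pinning pattern only; it uses a scalar loop as a stand-in for the vectorized or native kernel (for example the retained SumU export) that the real code presumably dispatches to, and is not the library's actual Sum implementation.

using System;
using System.Runtime.InteropServices;

internal static class SumSketch
{
    // Illustrative only: pins a ReadOnlySpan<float> the same way the span-based
    // CpuMath overloads in this diff do, then accumulates a plain scalar sum.
    public static unsafe float Sum(ReadOnlySpan<float> src)
    {
        if (src.IsEmpty)
            return 0;

        fixed (float* psrc = &MemoryMarshal.GetReference(src))
        {
            float sum = 0;
            for (int i = 0; i < src.Length; i++)
                sum += psrc[i];
            return sum;
        }
    }
}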
- public static float Sum(ReadOnlySpan src) { Contracts.AssertNonEmpty(src); @@ -1039,262 +365,6 @@ public static float L2DistSquared(ReadOnlySpan a, ReadOnlySpan b, } } - public static void ApplySigmoid(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplySigmoidA(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplySoftMax(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplySoftMaxA(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplyRectifiedLinear(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplyRectifiedLinearA(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplySquare(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplySquareA(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplySqrt(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplySqrtA(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplySoftRectifiedLinear(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplySoftRectifiedLinearA(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplyAbs(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplyAbsA(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplyTanh(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 < c && c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplyTanhA(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplyBoundedRectifiedLinear(AlignedArray src, AlignedArray dst, int c) - { - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(src.Size == dst.Size); - Contracts.Assert(0 <= c 
&& c <= dst.Size); - - unsafe - { - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - Thunk.ApplyBoundedRectifiedLinearA(Ptr(src, psrc), Ptr(dst, pdst), c); - } - } - - public static void ApplySigmoidDerivative(AlignedArray value, AlignedArray grad) - { - Contracts.Assert(Compat(value)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(value.Size == grad.Size); - - unsafe - { - fixed (float* pvalue = &value.Items[0]) - fixed (float* pgrad = &grad.Items[0]) - Thunk.ApplySigmoidDerivativeA(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size); - } - } - - public static void ApplyRectifiedLinearDerivative(AlignedArray value, AlignedArray grad) - { - Contracts.Assert(Compat(value)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(value.Size == grad.Size); - - unsafe - { - fixed (float* pvalue = &value.Items[0]) - fixed (float* pgrad = &grad.Items[0]) - Thunk.ApplyRectifiedLinearDerivativeA(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size); - } - } - - public static void ApplySquareDerivative(AlignedArray input, AlignedArray output, AlignedArray grad, bool drop) - { - Contracts.Assert(Compat(input)); - Contracts.Assert(Compat(output)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(output.Size == input.Size); - Contracts.Assert(output.Size == grad.Size); - - unsafe - { - fixed (float* px = &input.Items[0]) - fixed (float* py = &output.Items[0]) - fixed (float* pg = &grad.Items[0]) - Thunk.ApplySquareDerivativeA(Ptr(input, px), Ptr(output, py), Ptr(grad, pg), grad.Size, drop); - } - } - - public static void ApplySqrtDerivative(AlignedArray value, AlignedArray grad) - { - Contracts.Assert(Compat(value)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(value.Size == grad.Size); - - unsafe - { - fixed (float* pvalue = &value.Items[0]) - fixed (float* pgrad = &grad.Items[0]) - Thunk.ApplySqrtDerivativeA(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size); - } - } - - public static void ApplySoftRectifiedLinearDerivative(AlignedArray input, AlignedArray output, AlignedArray grad) - { - Contracts.Assert(Compat(input)); - Contracts.Assert(Compat(output)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(output.Size == input.Size); - Contracts.Assert(output.Size == grad.Size); - - unsafe - { - fixed (float* px = &input.Items[0]) - fixed (float* py = &output.Items[0]) - fixed (float* pg = &grad.Items[0]) - Thunk.ApplySoftRectifiedLinearDerivativeA(Ptr(input, px), Ptr(output, py), Ptr(grad, pg), grad.Size); - } - } - - public static void ApplyAbsDerivative(AlignedArray input, AlignedArray output, AlignedArray grad, bool drop) - { - Contracts.Assert(Compat(input)); - Contracts.Assert(Compat(output)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(output.Size == input.Size); - Contracts.Assert(output.Size == grad.Size); - - unsafe - { - fixed (float* px = &input.Items[0]) - fixed (float* py = &output.Items[0]) - fixed (float* pg = &grad.Items[0]) - Thunk.ApplyAbsDerivativeA(Ptr(input, px), Ptr(output, py), Ptr(grad, pg), grad.Size, drop); - } - } - - public static void ApplyTanhDerivative(AlignedArray value, AlignedArray grad) - { - Contracts.Assert(Compat(value)); - Contracts.Assert(Compat(grad)); - Contracts.Assert(value.Size == grad.Size); - - unsafe - { - fixed (float* pvalue = &value.Items[0]) - fixed (float* pgrad = &grad.Items[0]) - Thunk.ApplyTanhDerivativeA(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size); - } - } - - public static void ApplyBoundedRectifiedLinearDerivative(AlignedArray value, AlignedArray grad) - { - Contracts.Assert(Compat(value)); - 
Contracts.Assert(Compat(grad)); - Contracts.Assert(value.Size == grad.Size); - - unsafe - { - fixed (float* pvalue = &value.Items[0]) - fixed (float* pgrad = &grad.Items[0]) - Thunk.ApplyBoundedRectifiedLinearDerivativeA(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size); - } - } - public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] indices) { Contracts.Assert(0 < ccol && ccol <= cfltRow); @@ -1352,24 +422,5 @@ public static void SdcaL1UpdateSparse(float primalUpdate, int count, ReadOnlySpa Thunk.SdcaL1UpdateSU(primalUpdate, psrc, pi, threshold, pd1, pd2, count); } } - - public static void ScaleAdadelta(float[] mat, float[] accGrads, float[] accUpdates, float decay, float cond, float[] grads) - { - Contracts.AssertNonEmpty(mat); - Contracts.AssertNonEmpty(accGrads); - Contracts.AssertNonEmpty(accUpdates); - Contracts.Assert(mat.Length == accGrads.Length); - Contracts.Assert(mat.Length == accUpdates.Length); - Contracts.Assert(mat.Length <= grads.Length); - - unsafe - { - fixed (float* pm = &mat[0]) - fixed (float* pag = &accGrads[0]) - fixed (float* pau = &accUpdates[0]) - fixed (float* pg = &grads[0]) - Thunk.ScaleAdadeltaU(pm, pag, pau, decay, cond, pg, mat.Length); - } - } } } \ No newline at end of file diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index 1a99c5bd92..dbcd69aeef 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -44,26 +44,6 @@ internal static class SseIntrinsics Sse.StaticCast(Sse2.SetAllVector128(0x7FFFFFFF)) : Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF)); - // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray - private const int Vector128Alignment = 16; - - [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] - private static bool HasCompatibleAlignment(AlignedArray alignedArray) - { - Contracts.AssertValue(alignedArray); - Contracts.Assert(alignedArray.Size > 0); - return (alignedArray.CbAlign % Vector128Alignment) == 0; - } - - [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] - private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase) - { - Contracts.AssertValue(alignedArray); - float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase); - Contracts.Assert(((long)alignedBase & (Vector128Alignment - 1)) == 0); - return alignedBase; - } - [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] internal static unsafe Vector128 Load1(float* src, int* idx) => Sse.SetScalarVector128(src[idx[0]]); @@ -137,17 +117,17 @@ internal static Vector128 GetNewDst128(in Vector128 xDst1, in Vect // Multiply matrix times vector into vector. 
public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(crow % 4 == 0); - Contracts.Assert(ccol % 4 == 0); - MatMul(mat.Items, src.Items, dst.Items, crow, ccol); } - public static unsafe void MatMul(float[] mat, float[] src, float[] dst, int crow, int ccol) + public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) { - fixed (float* psrc = &src[0]) - fixed (float* pdst = &dst[0]) - fixed (float* pmat = &mat[0]) + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); + + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + fixed (float* pmat = &MemoryMarshal.GetReference(mat)) fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { @@ -285,24 +265,27 @@ public static unsafe void MatMul(float[] mat, float[] src, float[] dst, int crow } // Partial sparse source vector. - public static unsafe void MatMulPA(AlignedArray mat, int[] rgposSrc, AlignedArray src, - int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMulP(AlignedArray mat, int[] rgposSrc, AlignedArray src, + int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); + MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol); + } + + public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan src, + int posMin, int iposMin, int iposEnd, Span dst, int crow, int ccol) + { + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); // REVIEW: For extremely sparse inputs, interchanging the loops would // likely be more efficient. 
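The REVIEW note above contrasts the loop order used here (dense rows outer, sparse positions inner) with an interchanged order for extremely sparse inputs. The scalar sketch below shows both orders side by side. It assumes one particular reading of the indexing, namely that each position in rgposSrc addresses column (pos - posMin) of the row-major matrix and the same offset in src; treat that, and the code itself, as an illustration rather than the library's SIMD implementation.

using System;

internal static class MatMulPSketch
{
    // Rows outer, positions inner: mirrors the loop structure of the SIMD code.
    public static void RowsOuter(float[] mat, int[] rgposSrc, float[] src,
        int posMin, int iposMin, int iposEnd, float[] dst, int crow, int ccol)
    {
        for (int r = 0; r < crow; r++)
        {
            float sum = 0;
            for (int ipos = iposMin; ipos < iposEnd; ipos++)
            {
                int col = rgposSrc[ipos] - posMin;   // sparse column within this row
                sum += mat[r * ccol + col] * src[col];
            }
            dst[r] = sum;
        }
    }

    // Interchanged order suggested by the REVIEW comment: visit each non-zero
    // position once and accumulate its contribution into every row. Whether this
    // wins depends on how few positions there are relative to crow and on cache
    // behavior, so it is presented here only as the alternative being alluded to.
    public static void PositionsOuter(float[] mat, int[] rgposSrc, float[] src,
        int posMin, int iposMin, int iposEnd, float[] dst, int crow, int ccol)
    {
        Array.Clear(dst, 0, crow);
        for (int ipos = iposMin; ipos < iposEnd; ipos++)
        {
            int col = rgposSrc[ipos] - posMin;
            float x = src[col];
            for (int r = 0; r < crow; r++)
                dst[r] += mat[r * ccol + col] * x;
        }
    }
}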
- fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) - fixed (int* pposSrc = &rgposSrc[0]) + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + fixed (float* pmat = &MemoryMarshal.GetReference(mat)) + fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc)) + fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) + fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); - int* pposMin = pposSrc + iposMin; int* pposEnd = pposSrc + iposEnd; float* pDstEnd = pdst + crow; @@ -310,7 +293,105 @@ public static unsafe void MatMulPA(AlignedArray mat, int[] rgposSrc, AlignedArra float* pSrcCurrent = psrc - posMin; float* pDstCurrent = pdst; - while (pDstCurrent < pDstEnd) + nuint address = (nuint)(pDstCurrent); + int misalignment = (int)(address % 16); + + int length = crow; + int remainder = 0; + + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + Sse.Store(pDstCurrent, SparseMultiplicationAcrossRow()); + pDstCurrent += 4; + pm0 += 4 * ccol; + } + } + else + { + if (misalignment != 0) + { + misalignment >>= 2; + misalignment = 4 - misalignment; + + Vector128 mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); + + float* pm1 = pm0 + ccol; + float* pm2 = pm1 + ccol; + float* pm3 = pm2 + ccol; + Vector128 result = Sse.SetZeroVector128(); + + int* ppos = pposMin; + + while (ppos < pposEnd) + { + int col = *ppos; + Vector128 x1 = Sse.SetVector128(pm3[col], pm2[col], pm1[col], pm0[col]); + + x1 = Sse.And(mask, x1); + Vector128 x2 = Sse.SetAllVector128(pSrcCurrent[col]); + x2 = Sse.Multiply(x2, x1); + result = Sse.Add(result, x2); + ppos++; + } + + Sse.Store(pDstCurrent, result); + pDstCurrent += misalignment; + pm0 += misalignment * ccol; + length -= misalignment; + } + + if (length > 3) + { + remainder = length % 4; + while (pDstCurrent < pDstEnd) + { + Sse.Store(pDstCurrent, SparseMultiplicationAcrossRow()); + pDstCurrent += 4; + pm0 += 4 * ccol; + } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pDstCurrent -= (4 - remainder); + pm0 -= (4 - remainder) * ccol; + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); + + float* pm1 = pm0 + ccol; + float* pm2 = pm1 + ccol; + float* pm3 = pm2 + ccol; + Vector128 result = Sse.SetZeroVector128(); + + int* ppos = pposMin; + + while (ppos < pposEnd) + { + int col = *ppos; + Vector128 x1 = Sse.SetVector128(pm3[col], pm2[col], pm1[col], pm0[col]); + x1 = Sse.And(x1, trailingMask); + + Vector128 x2 = Sse.SetAllVector128(pSrcCurrent[col]); + x2 = Sse.Multiply(x2, x1); + result = Sse.Add(result, x2); + ppos++; + } + + result = Sse.Add(result, Sse.And(leadingMask, Sse.LoadVector128(pDstCurrent))); + + Sse.Store(pDstCurrent, result); + pDstCurrent += 4; + pm0 += 4 * ccol; + } + } + + Vector128 SparseMultiplicationAcrossRow() { float* pm1 = pm0 + ccol; float* pm2 = pm1 + ccol; @@ -326,29 +407,27 @@ public static unsafe void MatMulPA(AlignedArray mat, int[] rgposSrc, AlignedArra Vector128 x2 = Sse.SetAllVector128(pSrcCurrent[col]); x2 = Sse.Multiply(x2, x1); result = Sse.Add(result, x2); - ppos++; } - Sse.StoreAligned(pDstCurrent, result); - 
pDstCurrent += 4; - pm0 += 4 * ccol; + return result; } } } public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(crow % 4 == 0); - Contracts.Assert(ccol % 4 == 0); MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); } - public static unsafe void MatMulTran(float[] mat, float[] src, float[] dst, int crow, int ccol) + public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) { - fixed (float* psrc = &src[0]) - fixed (float* pdst = &dst[0]) - fixed (float* pmat = &mat[0]) + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); + + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + fixed (float* pmat = &MemoryMarshal.GetReference(mat)) fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { diff --git a/src/Microsoft.ML.CpuMath/Thunk.cs b/src/Microsoft.ML.CpuMath/Thunk.cs index 22938bb0d7..d505bf40a3 100644 --- a/src/Microsoft.ML.CpuMath/Thunk.cs +++ b/src/Microsoft.ML.CpuMath/Thunk.cs @@ -12,397 +12,79 @@ internal static unsafe class Thunk { internal const string NativePath = "CpuMathNative"; - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern bool ChkAvx(); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void MatMul(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulX(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulPA(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, + public static extern void MatMulP(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, int posMin, int iposMin, int iposLim, float* pdst, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulPX(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, - int posMin, int iposMin, int iposLim, float* pdst, int crow, int ccol); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulRU(bool add, /*const*/ int* pstarts, /*const*/ int* pindices, /*const*/ float* pcoefs, - /*const*/ float* psrc, float* pdst, int crow); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulRX(bool add, /*const*/ int* pstarts, /*const*/ int* pindices, /*const*/ float* pcoefs, - /*const*/ float* psrc, float* pdst, int crow); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulCU(bool add, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol, - /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulDU(bool add, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol, /*const*/ int* pmprowrun, - /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulCX(bool add, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol, - /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow); - 
[DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulDX(bool add, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol, /*const*/ int* pmprowrun, - /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MeanU(bool add, /*const*/ int* pmprowcol, /*const*/ int* pmprowindices, /*const*/ int* pindices, - /*const*/ float* psrc, float* pdst, int crow); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MaxU(bool add, /*const*/ int* pmprowcol, /*const*/ int* pmprowindices, /*const*/ int* pindices, - /*const*/ float* psrc, float* pdst, int crow); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void RespNormU(bool add, float alpha, float beta, bool avgOverFullKernel, float offset, - /*const*/ int* pmprowcol, /*const*/ int* pmprowindices, /*const*/ int* pindices, - /*const*/ float* psrc, float* pdst, int crow); // These treat pmat as if it is stored in column-major order. Thus, crow and ccol are the numbers of rows // and columns from that perspective. Alternatively, crow is the number of rows in the transpose of pmat // (thought of as row-major order). [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void MatMulTran(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranX(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranRU(bool add, /*const*/ int* pstarts, /*const*/ int* pindices, /*const*/ float* pcoefs, - /*const*/ float* psrc, float* pdst, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranRX(bool add, /*const*/ int* pstarts, /*const*/ int* pindices, /*const*/ float* pcoefs, - /*const*/ float* psrc, float* pdst, int crow, int ccol); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranCU(bool add, /*const*/ int* pmpcoliv, /*const*/ int* pmpcolrow, - /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranDU(bool add, /*const*/ int* pmpcoliv, /*const*/ int* pmpcolrow, /*const*/ int* pmpcolrun, - /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranCX(bool add, /*const*/ int* pmpcoliv, /*const*/ int* pmpcolrow, - /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranDX(bool add, /*const*/ int* pmpcoliv, /*const*/ int* pmpcolrow, /*const*/ int* pmpcolrun, - /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow, int ccol); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MeanBackU(bool add, /*const*/ int* pmpcolrow, /*const*/ int* pmpcolindices, /*const*/ int* pindices, - /*const*/ float* psrc, float* pdst, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public 
static extern void MaxBackU(bool add, /*const*/ int* pmpcolrow, /*const*/ int* pmpcolindices, /*const*/ int* pindices, - /*const*/ float* psrc, float* pdst, /*const*/ float* pval, int crow, int ccol); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void RespNormBackU(bool add, float alpha, float beta, bool avgOverFullKernel, float offset, - /*const*/ int* pmpcolrow, /*const*/ int* pmpcolindices, /*const*/ int* pindices, - /*const*/ float* perrors, float* perrorsPrev, /*const*/ float* pvaluesPrev, int crow, int ccol); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranA(float a, /*const*/ float* px, /*const*/ float* py, float* pmat, int crow, int ccol, float decay); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranX(float a, /*const*/ float* px, /*const*/ float* py, float* pmat, int crow, int ccol, float decay); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranPA(float a, /*const*/ float* px, /*const*/ int* pposY, /*const*/ float* pvaluesY, - int posMinY, int iposMinY, int iposLimY, float* pmat, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranPX(float a, /*const*/ float* px, /*const*/ int* pposY, /*const*/ float* pvaluesY, - int posMinY, int iposMinY, int iposLimY, float* pmat, int crow, int ccol); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranRU(float a, /*const*/ float* px, /*const*/ float* py, - /*const*/ int* pstarts, /*const*/ int* pindices, float* pcoefs, int crow, float decay); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranRX(float a, /*const*/ float* px, /*const*/ float* py, - /*const*/ int* pstarts, /*const*/ int* pindices, float* pcoefs, int crow, float decay); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranCU(float a, /*const*/ float* px, /*const*/ float* py, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol, - /*const*/ int* pruns, float* pcoefs, int crow); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranDU(float a, /*const*/ float* px, /*const*/ float* py, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol, - /*const*/ int* pmprowrun, /*const*/ int* pruns, float* pcoefs, int crow); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranCX(float a, /*const*/ float* px, /*const*/ float* py, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol, - /*const*/ int* pruns, float* pcoefs, int crow); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranDX(float a, /*const*/ float* px, /*const*/ float* py, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol, - /*const*/ int* pmprowrun, /*const*/ int* pruns, float* pcoefs, int crow); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranMomA(float a, /*const*/ float* px, /*const*/ float* py, float* pmat, float momentum, float* pdel, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranMomX(float a, /*const*/ float* px, /*const*/ float* py, float* pmat, float momentum, float* pdel, int crow, int ccol); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranGradA(/*const*/ float* px, /*const*/ float* py, 
float* pmat, float* paccGrads, float* paccUpdates, - float decay, float cond, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranGradX(/*const*/ float* px, /*const*/ float* py, float* pmat, float* paccGrads, float* paccUpdates, - float decay, float cond, int crow, int ccol); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranGradRU(/*const*/ float* px, /*const*/ float* py, /*const*/ int* pstarts, /*const*/ int* pindices, - float* pcoefs, float* paccGrads, float* paccUpdates, float decay, float cond, int crow); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranGradRX(/*const*/ float* px, /*const*/ float* py, /*const*/ int* pstarts, /*const*/ int* pindices, - float* pcoefs, float* paccGrads, float* paccUpdates, float decay, float cond, int crow); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranGradPA(/*const*/ float* px, /*const*/ int* pposY, /*const*/ float* pvaluesY, - int posMinY, int iposMinY, int iposLimY, float* pmat, float* paccGrads, float* paccUpdates, - float decay, float cond, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddXYTranGradPX(/*const*/ float* px, /*const*/ int* pposY, /*const*/ float* pvaluesY, - int posMinY, int iposMinY, int iposLimY, float* pmat, float* paccGrads, float* paccUpdates, - float decay, float cond, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void Scale(float a, float* pd, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ScaleX(float a, float* pd, int c); + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void ScaleSrcU(float a, /*const*/ float* ps, float* pd, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ScaleAddU(float a, float b, float* pd, int c); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ScaleMaxNormA(float maxNorm, float* pmat, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ScaleMaxNormX(float maxNorm, float* pmat, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ScaleMaxNormTranU(float maxNorm, float* pmat, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ScaleMaxNormRU(float maxNorm, /*const*/ int* pstarts, float* pmat, int crow); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ScaleMaxNormCU(float maxNorm, int kernCount, int kernSize, float* pmat); + public static extern void ScaleAddU(float a, float b, float* pd, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddScaleA(float a, /*const*/ float* ps, float* pd, int c); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void AddScaleU(float a, /*const*/ float* ps, float* pd, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddScaleX(float a, /*const*/ float* ps, float* pd, int c); + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void AddScaleSU(float a, /*const*/ float* ps, /*const*/ int* pi, float* pd, int c); + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern 
void AddScaleCopyU(float a, /*const*/ float* ps, /*const*/ float* pd, float* pr, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddScaleMomA(float a, /*const*/ float* ps, float* pd, float momentum, float* pdel, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddScaleMomX(float a, /*const*/ float* ps, float* pd, float momentum, float* pdel, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddScaleGradA(/*const*/ float* ps, float* pd, float* paccGrads, float* paccUpdates, - float decay, float cond, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddScaleGradX(/*const*/ float* ps, float* pd, float* paccGrads, float* paccUpdates, - float decay, float cond, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddScaleMultiA(int count, /*const*/ float* ps, float* pd, float* paccGrads, - float* paccUpdates, float decay, float cond, int size); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void AddScalarU(float a, float* pd, int c); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void AddU(/*const*/ float* ps, float* pd, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddA(/*const*/ float* ps, float* pd, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void AddX(/*const*/ float* ps, float* pd, int c); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void AddSU(/*const*/ float* ps, /*const*/ int* pi, float* pd, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern float SumA(/*const*/ float* ps, int c); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern float SumU(/*const*/ float* ps, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern float SumX(/*const*/ float* ps, int c); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern float SumSqU(/*const*/ float* ps, int c); + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern float SumSqDiffU(float mean, /*const*/ float* ps, int c); + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern float SumAbsU(/*const*/ float* ps, int c); + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern float SumAbsDiffU(float mean, /*const*/ float* ps, int c); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern float MulElementWiseU(/*const*/ float* ps1, /*const*/float* ps2, float* pd, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern float MulElementWiseSU(/*const*/ float* ps1, /*const*/float* ps2, /*const*/ int* pi, float* pd, int c); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern float MaxAbsU(/*const*/ float* ps, int c); + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern float MaxAbsDiffU(float mean, /*const*/ float* ps, int c); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern float DotU(/*const*/ float* pa, /*const*/ float* pb, int c); + [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern float DotSU(/*const*/ float* pa, /*const*/ float* pb, /*const*/ int* pi, int c); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static 
extern float Dist2(/*const*/ float* px, /*const*/ float* py, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplySigmoidA(/*const*/ float* ps, float* pd, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySigmoidX(/*const*/ float* ps, float* pd, int c) - { - ApplySigmoidA(ps, pd, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplySoftMaxU(float* ps, float* pd, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySoftMaxA(float* ps, float* pd, int c) - { - ApplySoftMaxU(ps, pd, c); - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySoftMaxX(float* ps, float* pd, int c) - { - ApplySoftMaxU(ps, pd, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplyRectifiedLinearA(float* ps, float* pd, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplyRectifiedLinearX(float* ps, float* pd, int c) - { - ApplyRectifiedLinearA(ps, pd, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplySquareA(float* ps, float* pd, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySquareX(float* ps, float* pd, int c) - { - ApplySquareA(ps, pd, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplySqrtA(float* ps, float* pd, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySqrtX(float* ps, float* pd, int c) - { - ApplySqrtA(ps, pd, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplySoftRectifiedLinearU(float* ps, float* pd, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySoftRectifiedLinearA(float* ps, float* pd, int c) - { - ApplySoftRectifiedLinearU(ps, pd, c); - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySoftRectifiedLinearX(float* ps, float* pd, int c) - { - ApplySoftRectifiedLinearU(ps, pd, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplyAbsA(float* ps, float* pd, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplyAbsX(float* ps, float* pd, int c) - { - ApplyAbsA(ps, pd, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplyTanhA(float* ps, float* pd, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplyTanhX(float* ps, float* pd, int c) - { - ApplyTanhA(ps, pd, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplyBoundedRectifiedLinearA(float* ps, float* pd, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplyBoundedRectifiedLinearX(float* ps, float* pd, int c) - { - ApplyBoundedRectifiedLinearA(ps, pd, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplySigmoidDerivativeA(/*const*/ float* pv, float* pg, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySigmoidDerivativeX(/*const*/ float* pv, float* pg, int c) - { - ApplySigmoidDerivativeA(pv, pg, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplyRectifiedLinearDerivativeA(/*const*/ float* pv, float* pg, int c); - 
[MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplyRectifiedLinearDerivativeX(/*const*/ float* pv, float* pg, int c) - { - ApplyRectifiedLinearDerivativeA(pv, pg, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplySquareDerivativeA(/*const*/ float* px, /*const*/ float* py, float* pg, int c, bool drop); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySquareDerivativeX(/*const*/ float* px, /*const*/ float* py, float* pg, int c, bool drop) - { - ApplySquareDerivativeA(px, py, pg, c, drop); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplySqrtDerivativeA(/*const*/ float* pv, float* pg, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySqrtDerivativeX(/*const*/ float* pv, float* pg, int c) - { - ApplySqrtDerivativeA(pv, pg, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplySoftRectifiedLinearDerivativeU(/*const*/ float* px, /*const*/ float* py, float* pg, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySoftRectifiedLinearDerivativeA(/*const*/ float* px, /*const*/ float* py, float* pg, int c) - { - ApplySoftRectifiedLinearDerivativeU(px, py, pg, c); - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplySoftRectifiedLinearDerivativeX(/*const*/ float* px, /*const*/ float* py, float* pg, int c) - { - ApplySoftRectifiedLinearDerivativeU(px, py, pg, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplyAbsDerivativeA(/*const*/ float* px, /*const*/ float* py, float* pg, int c, bool drop); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplyAbsDerivativeX(/*const*/ float* px, /*const*/ float* py, float* pg, int c, bool drop) - { - ApplyAbsDerivativeA(px, py, pg, c, drop); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplyTanhDerivativeA(/*const*/ float* pv, float* pg, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplyTanhDerivativeX(/*const*/ float* pv, float* pg, int c) - { - ApplyTanhDerivativeA(pv, pg, c); - } - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ApplyBoundedRectifiedLinearDerivativeA(/*const*/ float* pv, float* pg, int c); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ApplyBoundedRectifiedLinearDerivativeX(/*const*/ float* pv, float* pg, int c) - { - ApplyBoundedRectifiedLinearDerivativeA(pv, pg, c); - } - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void ZeroItemsU(float* pd, int c, /*const*/ int* pindices, int cindices); @@ -411,15 +93,9 @@ public static void ApplyBoundedRectifiedLinearDerivativeX(/*const*/ float* pv, f [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void SdcaL1UpdateU(float primalUpdate, /*const*/ float* ps, float threshold, float* pd1, float* pd2, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void SdcaL1UpdateSU(float primalUpdate, /*const*/ float* ps, /*const*/ int* pi, float threshold, float* pd1, float* pd2, int c); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ScaleAdadeltaU(float* mat, float* accGrads, float* accUpdates, float decay, float cond, float* grads, int size); - [DllImport(NativePath), 
SuppressUnmanagedCodeSecurity] - public static extern void ScaleAdadeltaA(float* mat, float* accGrads, float* accUpdates, float decay, float cond, float* grads, int size); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ScaleAdadeltaX(float* mat, float* accGrads, float* accUpdates, float decay, float cond, float* grads, int size); + public static extern void SdcaL1UpdateSU(float primalUpdate, /*const*/ float* ps, /*const*/ int* pi, float threshold, float* pd1, float* pd2, int c); #if !CORECLR // In CoreCLR we use Buffer.MemoryCopy directly instead of diff --git a/src/Native/CpuMathNative/Avx.cpp b/src/Native/CpuMathNative/Avx.cpp deleted file mode 100644 index b52e28cd80..0000000000 --- a/src/Native/CpuMathNative/Avx.cpp +++ /dev/null @@ -1,1285 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -// The exported function names need to be unique (can't be disambiguated based on signature), hence -// we introduce suffix letters to indicate the general patterns used. -// * A suffix means aligned and padded for SSE operations. -// * X suffix means aligned and padded for AVX operations. -// * U suffix means unaligned and unpadded. -// * S suffix means sparse (unaligned) vector. -// * P suffix means sparse (unaligned) partial vector - the vector is only part of a larger sparse vector. -// * R suffix means sparse matrix. -// * C suffix means convolution matrix. -// * D suffix means convolution matrix, with implicit source padding. -// * Tran means the matrix is transposed. -// -// Other notes: -// * AVX methods should end with _vleave() to avoid performance hit. See: -// https://stackoverflow.com/questions/7839925/using-avx-cpu-instructions-poor-performance-without-archavx. -// * Keep Avx.cpp in sync with Sse.cpp. Note that Avx.cpp is compiled with /arch:AVX, but Sse.cpp is not. - -// REVIEW: There is code below that mixes SSE and AVX instructions. Does compiling with /arch:AVX -// make that OK? Does the code need to be rewritten? 
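The suffix glossary in the header comment above is the contract the managed wrappers relied on: A-variants assume buffers that are 16-byte aligned and padded for SSE, X-variants assume 32-byte alignment and padding for AVX, and U-variants accept anything. The helper below is a small, hypothetical C# sketch of that precondition (it is not code from this repository, and "padded" is read here as the float count filling whole 16- or 32-byte vectors, which is an assumption).

internal static unsafe class AlignmentSketch
{
    // cbAlign is 16 for the SSE ("A") exports and 32 for the AVX ("X") exports.
    public static bool IsAlignedAndPadded(float* p, int count, int cbAlign)
    {
        bool aligned = ((nuint)p % (nuint)cbAlign) == 0;
        bool padded = (count * sizeof(float)) % cbAlign == 0;
        return aligned && padded;
    }
}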
- -#include "../Stdafx.h" -#include - -#ifndef _WIN32 -#define _mm256_set_m128(va, vb) _mm256_insertf128_ps(_mm256_castps128_ps256(vb), va, 1) -#endif - -#define _vleave _mm256_zeroupper - -#define _get_lo(x) _mm256_extractf128_ps(x, 0) -#define _get_hi(x) _mm256_extractf128_ps(x, 1) - -#define _load1(ps, pi) \ - _mm_set_ss(ps[pi[0]]) - -#define _load4(ps, pi) \ - _mm_setr_ps(ps[pi[0]], ps[pi[1]], ps[pi[2]], ps[pi[3]]) - -#define _load8(ps, pi) \ - _mm256_setr_ps(ps[pi[0]], ps[pi[1]], ps[pi[2]], ps[pi[3]], ps[pi[4]], ps[pi[5]], ps[pi[6]], ps[pi[7]]) - -#define _rotate(x) _mm_shuffle_ps(x, x, 0x39) - -#define _store1(x, pd, pi) \ - _mm_store_ss(pd + pi[0], x) - -//Warning: this operation changes the value of x => do not reuse x -#define _store4(x, pd, pi) \ - _mm_store_ss(pd + pi[0], x); \ - x = _rotate(x); _mm_store_ss(pd + pi[1], x); \ - x = _rotate(x); _mm_store_ss(pd + pi[2], x); \ - x = _rotate(x); _mm_store_ss(pd + pi[3], x) - -#define _store8(x, pd, pi) \ - __m128 tmp = _get_lo(x); _mm_store_ss(pd + pi[0], tmp); \ - tmp = _rotate(tmp); _mm_store_ss(pd + pi[1], tmp); \ - tmp = _rotate(tmp); _mm_store_ss(pd + pi[2], tmp); \ - tmp = _rotate(tmp); _mm_store_ss(pd + pi[3], tmp); \ - tmp = _get_hi(x); _mm_store_ss(pd + pi[4], tmp); \ - tmp = _rotate(tmp); _mm_store_ss(pd + pi[5], tmp); \ - tmp = _rotate(tmp); _mm_store_ss(pd + pi[6], tmp); \ - tmp = _rotate(tmp); _mm_store_ss(pd + pi[7], tmp) - -// Multiply matrix times vector into vector. -EXPORT_API(void) MatMulX(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) -{ - const float * psLim = psrc + ccol; - const float * pdLim = pdst + crow; - const float * pm = pmat; - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 3 * ccol) - { - const float * ps = psrc; - __m256 res0 = _mm256_setzero_ps(); - __m256 res1 = res0; - __m256 res2 = res0; - __m256 res3 = res0; - for (; ps < psLim; ps += 8, pm += 8) - { - const float * pmTmp; - __m256 x01 = _mm256_load_ps(pmTmp = pm); - __m256 x11 = _mm256_load_ps(pmTmp += ccol); - __m256 x21 = _mm256_load_ps(pmTmp += ccol); - __m256 x31 = _mm256_load_ps(pmTmp += ccol); - __m256 x02 = _mm256_load_ps(ps); - x01 = _mm256_mul_ps(x01, x02); - x11 = _mm256_mul_ps(x11, x02); - x21 = _mm256_mul_ps(x21, x02); - x31 = _mm256_mul_ps(x31, x02); - res0 = _mm256_add_ps(res0, x01); - res1 = _mm256_add_ps(res1, x11); - res2 = _mm256_add_ps(res2, x21); - res3 = _mm256_add_ps(res3, x31); - } - - // Add up the entries of each, with the 4x2 results in res0 - res0 = _mm256_hadd_ps(res0, res1); - res2 = _mm256_hadd_ps(res2, res3); - res0 = _mm256_hadd_ps(res0, res2); - - __m128 sum = _mm_add_ps(_get_lo(res0), _get_hi(res0)); - if (add) - sum = _mm_add_ps(sum, _mm_load_ps(pd)); - _mm_store_ps(pd, sum); - } - - _vleave(); -} - -// Partial sparse source vector. 
-EXPORT_API(void) MatMulPX(bool add, _In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, - int posMin, int iposMin, int iposLim, _Inout_ float * pdst, int crow, int ccol) -{ - const int * pposMin = pposSrc + iposMin; - const int * pposLim = pposSrc + iposLim; - const float * pdLim = pdst + crow; - const float * pm0 = pmat - posMin; - const float * ps = psrc - posMin; - for (float * pd = pdst; pd < pdLim; pd += 8, pm0 += 8 * ccol) - { - const float * pm1 = pm0 + ccol; - const float * pm2 = pm1 + ccol; - const float * pm3 = pm2 + ccol; - __m256 res = _mm256_setzero_ps(); - for (const int * ppos = pposMin; ppos < pposLim; ppos++) - { - int col1 = *ppos; - int col2 = col1 + 4 * ccol; - __m256 x1 = _mm256_setr_ps( - pm0[col1], pm1[col1], pm2[col1], pm3[col1], - pm0[col2], pm1[col2], pm2[col2], pm3[col2]); - __m256 x2 = _mm256_set1_ps(ps[col1]); - x2 = _mm256_mul_ps(x2, x1); - res = _mm256_add_ps(res, x2); - } - - if (add) - res = _mm256_add_ps(res, _mm256_load_ps(pd)); - _mm256_store_ps(pd, res); - } - - _vleave(); -} - -// Sparse matrix. -EXPORT_API(void) MatMulRX(bool add, _In_ const int * pstarts, _In_ const int * pindices, _In_ const float * pcoefs, - _In_ const float * ps, _Inout_ float * pdst, int crow) -{ - const int * pii = pstarts + 1; - const int * pi = pindices; - const float * pm = pcoefs; - const float * pdLim = pdst + crow; - for (float * pd = pdst; pd < pdLim; pd++) - { - const int * piLim = pindices + *pii++; - - __m256 res2 = _mm256_setzero_ps(); - for (; pi + 8 <= piLim; pi += 8, pm += 8) - { - __m256 x = _mm256_mul_ps(_load8(ps, pi), _mm256_loadu_ps(pm)); - res2 = _mm256_add_ps(res2, x); - } - __m128 res = _mm_add_ps(_get_lo(res2), _get_hi(res2)); - if (pi + 4 <= piLim) - { - __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm)); - res = _mm_add_ps(res, x); - pi += 4; pm += 4; - } - for (; pi < piLim; pi++, pm++) - { - __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm)); - res = _mm_add_ss(res, x); - } - res = _mm_hadd_ps(res, res); - res = _mm_hadd_ps(res, res); - - if (add) - res = _mm_add_ss(res, _mm_set_ss(*pd)); - _mm_store_ss(pd, res); - } - - _vleave(); -} - -// Unpadded convolution. -EXPORT_API(void) MatMulCX(bool add, _In_ const int * pmprowiv, _In_ const int * pmprowcol, - _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow) -{ - int size = pruns[1]; - const int * psupport = pruns + 2; - const int * piv = pmprowiv; - const int * pcol = pmprowcol; - const int * piLim = psupport + size; - const float * pdLim = pdst + crow; - - for (float * pd = pdst; pd < pdLim; pd++) - { - const float * pm = pcoefs + *piv++; - const float * ps = psrc + *pcol++; - const int * pi = psupport; - - __m256 res2 = _mm256_setzero_ps(); - for (; pi + 8 <= piLim; pi += 8, pm += 8) - { - __m256 x = _mm256_mul_ps(_load8(ps, pi), _mm256_loadu_ps(pm)); - res2 = _mm256_add_ps(res2, x); - } - __m128 res = _mm_add_ps(_get_lo(res2), _get_hi(res2)); - if (pi + 4 <= piLim) - { - __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm)); - res = _mm_add_ps(res, x); - pi += 4; pm += 4; - } - for (; pi < piLim; pi++, pm++) - { - __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm)); - res = _mm_add_ss(res, x); - } - res = _mm_hadd_ps(res, res); - res = _mm_hadd_ps(res, res); - - // Add the bias. - res = _mm_add_ss(res, _mm_set_ss(*pm)); - - if (add) - res = _mm_add_ss(res, _mm_set_ss(*pd)); - _mm_store_ss(pd, res); - } - - _vleave(); -} - -// Padded convolution. 
-EXPORT_API(void) MatMulDX(bool add, _In_ const int * pmprowiv, _In_ const int * pmprowcol, _In_ const int * pmprowrun, - _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow) -{ - const int * piv = pmprowiv; - const int * pcol = pmprowcol; - const float * pdLim = pdst + crow; - int kernelSize = pruns[1]; - - const int * pirun = pmprowrun; - for (float * pd = pdst; pd < pdLim; pd++) - { - const float * pm = pcoefs + *piv++; - const float * pmBias = pm + kernelSize; - const float * ps = psrc + *pcol++; - int irun = *pirun++; - - const int * pi = pruns + 2 + irun; - const int * piLim = pi + pi[-1]; - __m256 res2 = _mm256_setzero_ps(); - __m128 res; - if (irun == 0) - { - // No masking needed. - for (; pi + 8 <= piLim; pi += 8, pm += 8) - { - __m256 x = _mm256_mul_ps(_load8(ps, pi), _mm256_loadu_ps(pm)); - res2 = _mm256_add_ps(res2, x); - } - res = _mm_add_ps(_get_lo(res2), _get_hi(res2)); - if (pi + 4 <= piLim) - { - __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm)); - res = _mm_add_ps(res, x); - pi += 4; pm += 4; - } - for (; pi < piLim; pi++, pm++) - { - __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm)); - res = _mm_add_ss(res, x); - } - } - else - { - // Need masking. - pm += pi[-2]; - const float * pmask = reinterpret_cast(piLim); - for (; pi + 8 <= piLim; pi += 8, pm += 8, pmask += 8) - { - __m256 x = _mm256_mul_ps(_load8(ps, pi), _mm256_and_ps(_mm256_loadu_ps(pmask), _mm256_loadu_ps(pm))); - res2 = _mm256_add_ps(res2, x); - } - res = _mm_add_ps(_get_lo(res2), _get_hi(res2)); - if (pi + 4 <= piLim) - { - __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_and_ps(_mm_loadu_ps(pmask), _mm_loadu_ps(pm))); - res = _mm_add_ps(res, x); - pi += 4; pm += 4; pmask += 4; - } - for (; pi < piLim; pi++, pm++, pmask++) - { - __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_and_ps(_mm_set_ss(*pmask), _mm_set_ss(*pm))); - res = _mm_add_ss(res, x); - } - } - res = _mm_hadd_ps(res, res); - res = _mm_hadd_ps(res, res); - - res = _mm_add_ss(res, _mm_set_ss(*pmBias)); - if (add) - res = _mm_add_ss(res, _mm_set_ss(*pd)); - _mm_store_ss(pd, res); - } - - _vleave(); -} - -EXPORT_API(void) MatMulTranX(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) -{ - const float * psLim = psrc + ccol; - const float * pdLim = pdst + crow; - const float * pm = pmat; - const float * ps = psrc; - - // We do 4-way unrolling - if (!add) - { - __m128 h01 = _mm_load_ps(ps); - // Replicate each slot of x01 into its own register. 
- __m128 h11 = _mm_shuffle_ps(h01, h01, 0x55); - __m128 h21 = _mm_shuffle_ps(h01, h01, 0xAA); - __m128 h31 = _mm_shuffle_ps(h01, h01, 0xFF); - h01 = _mm_shuffle_ps(h01, h01, 0x00); - - __m256 x01 = _mm256_set_m128(h01, h01); - __m256 x11 = _mm256_set_m128(h11, h11); - __m256 x21 = _mm256_set_m128(h21, h21); - __m256 x31 = _mm256_set_m128(h31, h31); - ps += 4; - - for (float * pd = pdst; pd < pdLim; pd += 8, pm += 8) - { - const float * pmTmp; - __m256 x02 = _mm256_load_ps(pmTmp = pm); - __m256 x12 = _mm256_load_ps(pmTmp += crow); - __m256 x22 = _mm256_load_ps(pmTmp += crow); - __m256 x32 = _mm256_load_ps(pmTmp += crow); - x02 = _mm256_mul_ps(x01, x02); - x12 = _mm256_mul_ps(x11, x12); - x22 = _mm256_mul_ps(x21, x22); - x32 = _mm256_mul_ps(x31, x32); - x02 = _mm256_add_ps(x02, x12); - x22 = _mm256_add_ps(x22, x32); - x02 = _mm256_add_ps(x02, x22); - _mm256_store_ps(pd, x02); - } - - pm += 3 * crow; - } - - for (; ps < psLim; ps += 4) - { - __m128 h01 = _mm_load_ps(ps); - // Replicate each slot of x01 into its own register. - __m128 h11 = _mm_shuffle_ps(h01, h01, 0x55); - __m128 h21 = _mm_shuffle_ps(h01, h01, 0xAA); - __m128 h31 = _mm_shuffle_ps(h01, h01, 0xFF); - h01 = _mm_shuffle_ps(h01, h01, 0x00); - - __m256 x01 = _mm256_set_m128(h01, h01); - __m256 x11 = _mm256_set_m128(h11, h11); - __m256 x21 = _mm256_set_m128(h21, h21); - __m256 x31 = _mm256_set_m128(h31, h31); - - for (float * pd = pdst; pd < pdLim; pd += 8, pm += 8) - { - const float * pmTmp; - __m256 x02 = _mm256_load_ps(pmTmp = pm); - __m256 x12 = _mm256_load_ps(pmTmp += crow); - __m256 x22 = _mm256_load_ps(pmTmp += crow); - __m256 x32 = _mm256_load_ps(pmTmp += crow); - __m256 x3 = _mm256_load_ps(pd); - x02 = _mm256_mul_ps(x01, x02); - x12 = _mm256_mul_ps(x11, x12); - x22 = _mm256_mul_ps(x21, x22); - x32 = _mm256_mul_ps(x31, x32); - x02 = _mm256_add_ps(x02, x12); - x22 = _mm256_add_ps(x22, x32); - x02 = _mm256_add_ps(x02, x22); - x3 = _mm256_add_ps(x02, x3); - _mm256_store_ps(pd, x3); - } - - pm += 3 * crow; - } - - _vleave(); -} - -// Sparse matrix. -EXPORT_API(void) MatMulTranRX(bool add, _In_ const int * pstarts, _In_ const int * pindices, _In_ const float * pcoefs, - _In_ const float * psrc, _Inout_ float * pd, int crow, int ccol) -{ - if (!add) - memset(pd, 0, crow * sizeof(float)); - - const int * pii = pstarts + 1; - const int * pi = pindices; - const float * pm = pcoefs; - const float * psLim = psrc + ccol; - for (const float * ps = psrc; ps < psLim; ps++) - { - float x = *ps; - const int * piLim = pindices + *pii++; - - __m128 x0 = _mm_set1_ps(x); - __m256 x1 = _mm256_set_m128(x0, x0); - for (; pi + 8 <= piLim; pi += 8, pm += 8) - { - __m256 x2 = _mm256_mul_ps(x1, _mm256_loadu_ps(pm)); - x2 = _mm256_add_ps(x2, _load8(pd, pi)); - _store8(x2, pd, pi); - } - if (pi + 4 <= piLim) - { - __m128 x2 = _mm_mul_ps(x0, _mm_loadu_ps(pm)); - x2 = _mm_add_ps(x2, _load4(pd, pi)); - _store4(x2, pd, pi); - pi += 4; pm += 4; - } - for (; pi < piLim; pi++, pm++) - { - __m128 x2 = _mm_mul_ss(x0, _mm_set_ss(*pm)); - x2 = _mm_add_ss(x2, _load1(pd, pi)); - _store1(x2, pd, pi); - } - } - - _vleave(); -} - -// Unpadded convolution. 
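// Illustrative sketch, not from the original file: a scalar reference for what MatMulTranX
// computes, assuming the matrix is stored as ccol columns of crow floats each (element at
// row r, column c lives at pmat[c * crow + r]) and that the SIMD code's alignment and padding
// requirements are already satisfied.
static void MatMulTranRef(bool add, const float * pmat, const float * psrc, float * pdst, int crow, int ccol)
{
    for (int r = 0; r < crow; r++)
    {
        float sum = add ? pdst[r] : 0.0f;    // accumulate onto dst only when add is requested
        for (int c = 0; c < ccol; c++)
            sum += pmat[c * crow + r] * psrc[c];
        pdst[r] = sum;
    }
}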
-EXPORT_API(void) MatMulTranCX(bool add, _In_ const int * pmpcoliv, _In_ const int * pmpcolrow, - _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) -{ - if (!add) - memset(pdst, 0, crow * sizeof(float)); - - int size = pruns[1]; - const int * psupport = pruns + 2; - const int * piv = pmpcoliv; - const int * prow = pmpcolrow; - const int * piLim = psupport + size; - const float * psLim = psrc + ccol; - for (const float * ps = psrc; ps < psLim; ps++) - { - const float * pm = pcoefs + *piv++; - float * pd = pdst + *prow++; - const int * pi = psupport; - - float x = *ps; - __m128 x0 = _mm_set1_ps(x); - __m256 x1 = _mm256_set_m128(x0, x0); - for (; pi + 8 <= piLim; pi += 8, pm += 8) - { - __m256 x2 = _mm256_mul_ps(x1, _mm256_loadu_ps(pm)); - x2 = _mm256_add_ps(x2, _load8(pd, pi)); - _store8(x2, pd, pi); - } - if (pi + 4 <= piLim) - { - __m128 x2 = _mm_mul_ps(x0, _mm_loadu_ps(pm)); - x2 = _mm_add_ps(x2, _load4(pd, pi)); - _store4(x2, pd, pi); - pi += 4; pm += 4; - } - for (; pi < piLim; pi++, pm++) - { - __m128 x2 = _mm_mul_ss(x0, _mm_set_ss(*pm)); - x2 = _mm_add_ss(x2, _load1(pd, pi)); - _store1(x2, pd, pi); - } - } - - _vleave(); -} - -// Padded convolution. -EXPORT_API(void) MatMulTranDX(bool add, _In_ const int * pmpcoliv, _In_ const int * pmpcolrow, _In_ const int * pmpcolrun, - _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) -{ - if (!add) - memset(pdst, 0, crow * sizeof(float)); - - const int * piv = pmpcoliv; - const int * prow = pmpcolrow; - const float * psLim = psrc + ccol; - int kernelSize = pruns[1]; - - const int * pirun = pmpcolrun; - for (const float * ps = psrc; ps < psLim; ps++) - { - const float * pm = pcoefs + *piv++; - float * pd = pdst + *prow++; - int irun = *pirun++; - const int * pi = pruns + 2 + irun; - const int * piLim = pi + pi[-1]; - - float x = *ps; - __m128 x0 = _mm_set1_ps(x); - __m256 x1 = _mm256_set_m128(x0, x0); - if (irun == 0) - { - // No masking needed. - for (; pi + 8 <= piLim; pi += 8, pm += 8) - { - __m256 x2 = _mm256_mul_ps(x1, _mm256_loadu_ps(pm)); - x2 = _mm256_add_ps(x2, _load8(pd, pi)); - _store8(x2, pd, pi); - } - if (pi + 4 <= piLim) - { - __m128 x2 = _mm_mul_ps(x0, _mm_loadu_ps(pm)); - x2 = _mm_add_ps(x2, _load4(pd, pi)); - _store4(x2, pd, pi); - pi += 4; pm += 4; - } - for (; pi < piLim; pi++, pm++) - { - __m128 x2 = _mm_mul_ss(x0, _mm_set_ss(*pm)); - x2 = _mm_add_ss(x2, _load1(pd, pi)); - _store1(x2, pd, pi); - } - } - else - { - // Need masking. 
- pm += pi[-2]; - const float * pmask = reinterpret_cast(piLim); - for (; pi + 8 <= piLim; pi += 8, pm += 8, pmask += 8) - { - __m256 x2 = _mm256_mul_ps(_mm256_and_ps(_mm256_loadu_ps(pmask), x1), _mm256_loadu_ps(pm)); - x2 = _mm256_add_ps(x2, _load8(pd, pi)); - _store8(x2, pd, pi); - } - if (pi + 4 <= piLim) - { - __m128 x2 = _mm_mul_ps(_mm_and_ps(_mm_loadu_ps(pmask), x0), _mm_loadu_ps(pm)); - x2 = _mm_add_ps(x2, _load4(pd, pi)); - _store4(x2, pd, pi); - pi += 4; pm += 4; pmask += 4; - } - for (; pi < piLim; pi++, pm++, pmask++) - { - __m128 x2 = _mm_mul_ss(_mm_and_ps(_mm_set_ss(*pmask), x0), _mm_set_ss(*pm)); - x2 = _mm_add_ss(x2, _load1(pd, pi)); - _store1(x2, pd, pi); - } - } - } - - _vleave(); -} - -template -void AddXYTranXCore(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, int crow, int ccol, float decay) -{ - const float * pyBase = py; - const float * pxLim = px + crow; - const float * pyLim = py + ccol; - float * pm = pmat; - __m256 wd; - if (useDecay) - wd = _mm256_set1_ps(1 - decay); - for (; px < pxLim; px++) - { - float r = a * *px; - py = pyBase; - - __m256 x1 = _mm256_set1_ps(r); - for (; py + 32 <= pyLim; py += 32, pm += 32) - { - __m256 x02 = _mm256_load_ps(py); - __m256 x12 = _mm256_load_ps(py + 8); - __m256 x22 = _mm256_load_ps(py + 16); - __m256 x32 = _mm256_load_ps(py + 24); - __m256 x03 = _mm256_load_ps(pm); - __m256 x13 = _mm256_load_ps(pm + 8); - __m256 x23 = _mm256_load_ps(pm + 16); - __m256 x33 = _mm256_load_ps(pm + 24); - x02 = _mm256_mul_ps(x1, x02); - x12 = _mm256_mul_ps(x1, x12); - x22 = _mm256_mul_ps(x1, x22); - x32 = _mm256_mul_ps(x1, x32); - if (useDecay) - { - x03 = _mm256_mul_ps(wd, x03); - x13 = _mm256_mul_ps(wd, x13); - x23 = _mm256_mul_ps(wd, x23); - x33 = _mm256_mul_ps(wd, x33); - } - x03 = _mm256_add_ps(x02, x03); - x13 = _mm256_add_ps(x12, x13); - x23 = _mm256_add_ps(x22, x23); - x33 = _mm256_add_ps(x32, x33); - _mm256_store_ps(pm, x03); - _mm256_store_ps(pm + 8, x13); - _mm256_store_ps(pm + 16, x23); - _mm256_store_ps(pm + 24, x33); - } - for (; py < pyLim; py += 8, pm += 8) - { - __m256 x02 = _mm256_load_ps(py); - __m256 x03 = _mm256_load_ps(pm); - x02 = _mm256_mul_ps(x1, x02); - if (useDecay) - x03 = _mm256_mul_ps(wd, x03); - x03 = _mm256_add_ps(x02, x03); - _mm256_store_ps(pm, x03); - } - } - - _vleave(); -} - -EXPORT_API(void) AddXYTranX(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, int crow, int ccol, float decay) -{ - if (decay == 0) - AddXYTranXCore(a, px, py, pmat, crow, ccol, decay); - else - AddXYTranXCore(a, px, py, pmat, crow, ccol, decay); -} - -// Partial sparse source vector. 
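// Illustrative sketch, not from the original file: the scalar equivalent of AddXYTranX above,
// i.e. a rank-one update pmat += a * x * y^T with an optional multiplicative weight decay on the
// existing coefficients (pmat stored row-major, crow rows of ccol floats).
static void AddXYTranRef(float a, const float * px, const float * py, float * pmat, int crow, int ccol, float decay)
{
    for (int i = 0; i < crow; i++)
    {
        float r = a * px[i];
        for (int j = 0; j < ccol; j++)
            pmat[i * ccol + j] = (1 - decay) * pmat[i * ccol + j] + r * py[j];
    }
}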
-EXPORT_API(void) AddXYTranPX(float a, _In_ const float * px, _In_ const int * pposY, _In_ const float * pvaluesY, - int posMinY, int iposMinY, int iposLimY, _Inout_ float * pmat, int crow, int ccol) -{ - const int * pposMin = pposY + iposMinY; - const int * pposLim = pposY + iposLimY; - const float * pxLim = px + crow; - float * pm0 = pmat - posMinY; - const float * py = pvaluesY - posMinY; - - __m256 x0 = _mm256_set1_ps(a); - for (; px < pxLim; px += 8, pm0 += 8 * ccol) - { - float * pm1 = pm0 + ccol; - float * pm2 = pm1 + ccol; - float * pm3 = pm2 + ccol; - - __m256 x1 = _mm256_load_ps(px); - x1 = _mm256_mul_ps(x1, x0); - - for (const int * ppos = pposMin; ppos < pposLim; ppos++) - { - int col1 = *ppos; - int col2 = col1 + 4 * ccol; - __m256 x2 = _mm256_set1_ps(py[col1]); - __m256 x3 = _mm256_setr_ps( - pm0[col1], pm1[col1], pm2[col1], pm3[col1], - pm0[col2], pm1[col2], pm2[col2], pm3[col2]); - x2 = _mm256_mul_ps(x2, x1); - x3 = _mm256_add_ps(x3, x2); - - __m128 t1 = _get_lo(x3); - __m128 t2 = _get_hi(x3); - _mm_store_ss(pm0 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pm1 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pm2 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pm3 + col1, t1); - _mm_store_ss(pm0 + col2, t2); t2 = _rotate(t2); - _mm_store_ss(pm1 + col2, t2); t2 = _rotate(t2); - _mm_store_ss(pm2 + col2, t2); t2 = _rotate(t2); - _mm_store_ss(pm3 + col2, t2); - } - } - - _vleave(); -} - -template -void AddXYTranRXCore(float a, _In_ const float * px, _In_ const float * py, - _In_ const int * pstarts, _In_ const int * pindices, _Inout_ float * pcoefs, int crow, float decay) -{ - const int * pii = pstarts + 1; - const int * pi = pindices; - float * pm = pcoefs; - const float * pxLim = px + crow; - __m128 wd0; - __m256 wd1; - if (useDecay) - { - wd0 = _mm_set1_ps(1 - decay); - wd1 = _mm256_set_m128(wd0, wd0); - } - for (; px < pxLim; px++) - { - const int * piLim = pindices + *pii++; - float r = a * *px; - - __m128 x0 = _mm_set1_ps(r); - __m256 x1 = _mm256_set_m128(x0, x0); - for (; pi + 8 <= piLim; pi += 8, pm += 8) - { - __m256 x2 = _mm256_mul_ps(x1, _load8(py, pi)); - __m256 x3 = _mm256_loadu_ps(pm); - if (useDecay) - x3 = _mm256_mul_ps(x3, wd1); - x2 = _mm256_add_ps(x2, x3); - _mm256_storeu_ps(pm, x2); - } - if (pi + 4 <= piLim) - { - __m128 x2 = _mm_mul_ps(x0, _load4(py, pi)); - __m128 x3 = _mm_loadu_ps(pm); - if (useDecay) - x3 = _mm_mul_ps(x3, wd0); - x2 = _mm_add_ps(x2, x3); - _mm_storeu_ps(pm, x2); - pi += 4; pm += 4; - } - for (; pi < piLim; pi++, pm++) - *pm = (useDecay ? (*pm * (1 - decay)) : *pm) + py[*pi] * r; - } - - _vleave(); -} - -// Sparse matrix. -EXPORT_API(void) AddXYTranRX(float a, _In_ const float * px, _In_ const float * py, - _In_ const int * pstarts, _In_ const int * pindices, _Inout_ float * pcoefs, int crow, float decay) -{ - if (decay == 0) - AddXYTranRXCore(a, px, py, pstarts, pindices, pcoefs, crow, decay); - else - AddXYTranRXCore(a, px, py, pstarts, pindices, pcoefs, crow, decay); -} - -// Unpadded convolution. 
-EXPORT_API(void) AddXYTranCX(float a, _In_ const float * px, _In_ const float * py, _In_ const int * pmprowiv, _In_ const int * pmprowcol, - _In_ const int * pruns, _Inout_ float * pcoefs, int crow) -{ - int size = pruns[1]; - const int * psupport = pruns + 2; - const int * piv = pmprowiv; - const int * pcol = pmprowcol; - const float * pxLim = px + crow; - const int * piLim = psupport + size; - - for (; px < pxLim; px++) - { - float * pm = pcoefs + *piv++; - const float * ps = py + *pcol++; - const int * pi = psupport; - float r = a * *px; - - __m128 x0 = _mm_set1_ps(r); - __m256 x1 = _mm256_set_m128(x0, x0); - for (; pi + 8 <= piLim; pi += 8, pm += 8) - { - __m256 x2 = _mm256_mul_ps(x1, _load8(ps, pi)); - x2 = _mm256_add_ps(x2, _mm256_loadu_ps(pm)); - _mm256_storeu_ps(pm, x2); - } - if (pi + 4 <= piLim) - { - __m128 x2 = _mm_mul_ps(x0, _load4(ps, pi)); - x2 = _mm_add_ps(x2, _mm_loadu_ps(pm)); - _mm_storeu_ps(pm, x2); - pi += 4; pm += 4; - } - for (; pi < piLim; pi++, pm++) - *pm += ps[*pi] * r; - // Update the bias. - *pm += r; - } - - _vleave(); -} - -// Padded convolution. -EXPORT_API(void) AddXYTranDX(float a, _In_ const float * px, _In_ const float * py, _In_ const int * pmprowiv, _In_ const int * pmprowcol, - _In_ const int * pmprowrun, _In_ const int * pruns, _Inout_ float * pcoefs, int crow) -{ - const int * piv = pmprowiv; - const int * pcol = pmprowcol; - const float * pxLim = px + crow; - int kernelSize = pruns[1]; - - const int * pirun = pmprowrun; - for (; px < pxLim; px++) - { - float * pm = pcoefs + *piv++; - const float * ps = py + *pcol++; - int irun = *pirun++; - const int * pi = pruns + 2 + irun; - const int * piLim = pi + pi[-1]; - - float r = a * *px; - - // Update the bias. - pm[kernelSize] += r; - - __m128 x0 = _mm_set1_ps(r); - __m256 x1 = _mm256_set_m128(x0, x0); - if (irun == 0) - { - // No masking needed. - for (; pi + 8 <= piLim; pi += 8, pm += 8) - { - __m256 x2 = _mm256_mul_ps(x1, _load8(ps, pi)); - x2 = _mm256_add_ps(x2, _mm256_loadu_ps(pm)); - _mm256_storeu_ps(pm, x2); - } - if (pi + 4 <= piLim) - { - __m128 x2 = _mm_mul_ps(x0, _load4(ps, pi)); - x2 = _mm_add_ps(x2, _mm_loadu_ps(pm)); - _mm_storeu_ps(pm, x2); - pi += 4; pm += 4; - } - for (; pi < piLim; pi++, pm++) - *pm += ps[*pi] * r; - } - else - { - // Need masking. - pm += pi[-2]; - const float * pmask = reinterpret_cast(piLim); - for (; pi + 8 <= piLim; pi += 8, pm += 8, pmask += 8) - { - __m256 x2 = _mm256_mul_ps(_mm256_and_ps(_mm256_loadu_ps(pmask), x1), _load8(ps, pi)); - x2 = _mm256_add_ps(x2, _mm256_loadu_ps(pm)); - _mm256_storeu_ps(pm, x2); - } - if (pi + 4 <= piLim) - { - __m128 x2 = _mm_mul_ps(_mm_and_ps(_mm_loadu_ps(pmask), x0), _load4(ps, pi)); - x2 = _mm_add_ps(x2, _mm_loadu_ps(pm)); - _mm_storeu_ps(pm, x2); - pi += 4; pm += 4; pmask += 4; - } - for (; pi < piLim; pi++, pm++, pmask++) - { - __m128 x2 = _mm_mul_ss(_mm_and_ps(_mm_set_ss(*pmask), x0), _load1(ps, pi)); - x2 = _mm_add_ss(x2, _mm_set_ss(*pm)); - _mm_store_ss(pm, x2); - } - } - } - - _vleave(); -} - -// With momentum. 
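// Illustrative sketch, not from the original file: the "need masking" paths above read a
// per-lane mask (32-bit words of 0x00000000 or 0xFFFFFFFF stored right after the index run)
// and AND it with one multiplicand, so lanes that fall in the implicit padding contribute zero
// to the accumulated sum. The pointer names below are hypothetical.
static inline __m128 MaskedMulAdd(__m128 acc, const float * pmask, const float * pa, const float * pb)
{
    __m128 mask = _mm_loadu_ps(pmask);              // all-ones or all-zero bits per lane
    __m128 a = _mm_and_ps(mask, _mm_loadu_ps(pa));  // zero out padded lanes
    return _mm_add_ps(acc, _mm_mul_ps(a, _mm_loadu_ps(pb)));
}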
-EXPORT_API(void) AddXYTranMomX(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, float momentum, _Inout_ float * pdel, int crow, int ccol) -{ - const float * pyBase = py; - const float * pxLim = px + crow; - const float * pyLim = py + ccol; - float * pm = pmat; - float * pd = pdel; - - __m256 x0 = _mm256_set1_ps(momentum); - for (; px < pxLim; px++) - { - float r = a * *px; - - __m256 x1 = _mm256_set1_ps(r); - for (py = pyBase; py < pyLim; pm += 8, pd += 8, py += 8) - { - __m256 x2 = _mm256_load_ps(py); - __m256 x3 = _mm256_load_ps(pd); - __m256 x4 = _mm256_load_ps(pm); - x2 = _mm256_mul_ps(x1, x2); - x3 = _mm256_mul_ps(x0, x3); - x3 = _mm256_add_ps(x2, x3); - x4 = _mm256_add_ps(x3, x4); - - _mm256_store_ps(pd, x3); - _mm256_store_ps(pm, x4); - } - } - - _vleave(); -} - -// coef: coefs to update, ag: accumulated grads, au: accumulated updates, g: cur grads. -// Note: parameters coef, ag, au and g will be updated, do not reuse parameter g in calling code. -__forceinline void UpdateAdadelta(__m256& coef, __m256& ag, __m256& au, __m256& g, const __m256& dec, const __m256& decc, const __m256& c) -{ - __m256 x4 = _mm256_mul_ps(g, g); // x4 == g * g - x4 = _mm256_mul_ps(decc, x4); // x4 == (1 - decay) * g * g - ag = _mm256_mul_ps(dec, ag); // ag == decay * accG - ag = _mm256_add_ps(ag, x4); // ag == decay * accG + (1 - decay) * g * g - __m256 x41 = _mm256_add_ps(ag, c); // x41 == ag + cond - __m256 x51 = _mm256_add_ps(au, c); // x51 == accU + cond -#if 0 - // naive version: - x51 = _mm256_div_ps(x51, x41); - x41 = _mm256_sqrt_ps(x51); // x41 == rate -#else - // faster (approximate) version: - x41 = _mm256_rsqrt_ps(x41); - __m256 x52 = _mm256_rsqrt_ps(x51); - x51 = _mm256_mul_ps(x51, x52); - x41 = _mm256_mul_ps(x41, x51); // x41 == rate -#endif - g = _mm256_mul_ps(g, x41); // g - current update - coef = _mm256_add_ps(coef, g); - - g = _mm256_mul_ps(g, g); // g == newU * newU - g = _mm256_mul_ps(decc, g); // g == (1 - decay) * newU * newU - au = _mm256_mul_ps(dec, au); // au == decay * accU - au = _mm256_add_ps(au, g); // au == decay * accU + (1 - decay) * newU * newU -} - -// For Adadelta. -EXPORT_API(void) AddXYTranGradX(_In_ const float * px, _In_ const float * py, _Inout_ float * pmat, _Inout_ float * paccGrads, _Inout_ float * paccUpdates, - float decay, float cond, int crow, int ccol) -{ - const float * pyBase = py; - const float * pxLim = px + crow; - const float * pyLim = py + ccol; - float * pm = pmat; - float * pag = paccGrads; - float * pau = paccUpdates; - - __m256 dec = _mm256_set1_ps(decay); - __m256 decc = _mm256_set1_ps(1 - decay); - __m256 c = _mm256_set1_ps(cond); - for (; px < pxLim; px++) - { - float r = *px; - - __m256 x1 = _mm256_set1_ps(r); - for (py = pyBase; py < pyLim; pm += 8, pag += 8, pau += 8, py += 8) - { - __m256 x2 = _mm256_load_ps(py); - __m256 ag = _mm256_load_ps(pag); - __m256 au = _mm256_load_ps(pau); - __m256 coef = _mm256_load_ps(pm); - x2 = _mm256_mul_ps(x1, x2); // x2 == g - - UpdateAdadelta(coef, ag, au, x2, dec, decc, c); - - _mm256_store_ps(pm, coef); - _mm256_store_ps(pag, ag); - _mm256_store_ps(pau, au); - } - } - - _vleave(); -} - -// For Adadelta, sparse matrix. 
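// Illustrative sketch, not from the original file: the scalar form of the Adadelta step that
// UpdateAdadelta vectorizes, mirroring the remainder loops used elsewhere in this file; cond
// is the conditioning (epsilon-like) constant added to both accumulators before taking the ratio.
static inline void UpdateAdadeltaScalar(float & coef, float & accGrad, float & accUpd, float g, float decay, float cond)
{
    accGrad = decay * accGrad + (1 - decay) * g * g;              // decayed average of squared gradients
    float newUpd = sqrtf((accUpd + cond) / (accGrad + cond)) * g; // rescaled update
    coef += newUpd;
    accUpd = decay * accUpd + (1 - decay) * newUpd * newUpd;      // decayed average of squared updates
}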
-EXPORT_API(void) AddXYTranGradRX(_In_ const float * px, _In_ const float * py, _In_ const int * pstarts, _In_ const int * pindices, - _Inout_ float * pcoefs, _Inout_ float * paccGrads, _Inout_ float * paccUpdates, float decay, float cond, int crow) -{ - const int * pii = pstarts + 1; - const int * pi = pindices; - float * pm = pcoefs; - const float * pxLim = px + crow; - float * pag = paccGrads; - float * pau = paccUpdates; - - __m256 dec = _mm256_set1_ps(decay); - __m256 decc = _mm256_set1_ps(1 - decay); - __m256 c = _mm256_set1_ps(cond); - - for (; px < pxLim; px++) - { - const int * piLim = pindices + *pii++; - float r = *px; - - __m256 x1 = _mm256_set1_ps(r); - for (; pi + 8 <= piLim; pi += 8, pm += 8, pag += 8, pau += 8) - { - __m256 g = _mm256_mul_ps(x1, _load8(py, pi)); - __m256 ag = _mm256_loadu_ps(pag); - __m256 au = _mm256_loadu_ps(pau); - __m256 coef = _mm256_loadu_ps(pm); - - UpdateAdadelta(coef, ag, au, g, dec, decc, c); - - _mm256_storeu_ps(pm, coef); - _mm256_storeu_ps(pag, ag); - _mm256_storeu_ps(pau, au); - } - - // REVIEW: Why is this so different than the SSE version? - for (; pi < piLim; pi++, pm++, pag++, pau++) - { - float g = py[*pi] * r; - float accGrad = decay * *pag + (1 - decay) * g * g; - float accUpd = *pau; - float newUpd = sqrtf((accUpd + cond) / (accGrad + cond)) * g; - *pm += newUpd; - *pag = accGrad; - *pau = decay * accUpd + (1 - decay) * newUpd * newUpd; - } - } - - _vleave(); -} - -// For Adadelta, partial sparse source vector. -EXPORT_API(void) AddXYTranGradPX(_In_ const float * px, _In_ const int * pposY, _In_ const float * pvaluesY, - int posMinY, int iposMinY, int iposLimY, _Inout_ float * pmat, _Inout_ float * paccGrads, _Inout_ float * paccUpdates, - float decay, float cond, int crow, int ccol) -{ - const int * pposMin = pposY + iposMinY; - const int * pposLim = pposY + iposLimY; - const float * pxLim = px + crow; - const float * py = pvaluesY - posMinY; - float * pm0 = pmat - posMinY; - float * pag0 = paccGrads - posMinY; - float * pau0 = paccUpdates - posMinY; - - __m256 dec = _mm256_set1_ps(decay); - __m256 decc = _mm256_set1_ps(1 - decay); - __m256 c = _mm256_set1_ps(cond); - for (; px < pxLim; px += 8, pm0 += 8 * ccol, pag0 += 8 * ccol, pau0 += 8 * ccol) - { - float * pm1 = pm0 + ccol; - float * pm2 = pm1 + ccol; - float * pm3 = pm2 + ccol; - - float * pag1 = pag0 + ccol; - float * pag2 = pag1 + ccol; - float * pag3 = pag2 + ccol; - - float * pau1 = pau0 + ccol; - float * pau2 = pau1 + ccol; - float * pau3 = pau2 + ccol; - - __m256 x1 = _mm256_load_ps(px); - - for (const int * ppos = pposMin; ppos < pposLim; ppos++) - { - int col1 = *ppos; - int col2 = col1 + 4 * ccol; - __m256 x2 = _mm256_set1_ps(py[col1]); - __m256 ag = _mm256_setr_ps( - pag0[col1], pag1[col1], pag2[col1], pag3[col1], - pag0[col2], pag1[col2], pag2[col2], pag3[col2]); - __m256 au = _mm256_setr_ps( - pau0[col1], pau1[col1], pau2[col1], pau3[col1], - pau0[col2], pau1[col2], pau2[col2], pau3[col2]); - __m256 coef = _mm256_setr_ps( - pm0[col1], pm1[col1], pm2[col1], pm3[col1], - pm0[col2], pm1[col2], pm2[col2], pm3[col2]); - x2 = _mm256_mul_ps(x2, x1); - - UpdateAdadelta(coef, ag, au, x2, dec, decc, c); - - __m128 t1 = _get_lo(coef); - __m128 t2 = _get_hi(coef); - _mm_store_ss(pm0 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pm1 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pm2 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pm3 + col1, t1); - _mm_store_ss(pm0 + col2, t2); t2 = _rotate(t2); - _mm_store_ss(pm1 + col2, t2); t2 = _rotate(t2); - _mm_store_ss(pm2 + col2, t2); t2 = 
_rotate(t2); - _mm_store_ss(pm3 + col2, t2); - - t1 = _get_lo(ag); - t2 = _get_hi(ag); - _mm_store_ss(pag0 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pag1 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pag2 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pag3 + col1, t1); - _mm_store_ss(pag0 + col2, t2); t2 = _rotate(t2); - _mm_store_ss(pag1 + col2, t2); t2 = _rotate(t2); - _mm_store_ss(pag2 + col2, t2); t2 = _rotate(t2); - _mm_store_ss(pag3 + col2, t2); - - t1 = _get_lo(au); - t2 = _get_hi(au); - _mm_store_ss(pau0 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pau1 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pau2 + col1, t1); t1 = _rotate(t1); - _mm_store_ss(pau3 + col1, t1); - _mm_store_ss(pau0 + col2, t2); t2 = _rotate(t2); - _mm_store_ss(pau1 + col2, t2); t2 = _rotate(t2); - _mm_store_ss(pau2 + col2, t2); t2 = _rotate(t2); - _mm_store_ss(pau3 + col2, t2); - } - } - - _vleave(); -} - -EXPORT_API(void) ScaleX(float a, _Inout_ float * pd, int c) -{ - float * pdLim = pd + c; - - __m256 x1 = _mm256_set1_ps(a); - for (; pd < pdLim; pd += 8) - { - __m256 x2 = _mm256_load_ps(pd); - x2 = _mm256_mul_ps(x1, x2); - _mm256_store_ps(pd, x2); - } - - _vleave(); -} - -EXPORT_API(void) ScaleMaxNormX(float maxNorm, _Inout_ float * pmat, int crow, int ccol) -{ - float * pm = pmat; - float maxNormSq = maxNorm * maxNorm; - __m256 m = _mm256_set1_ps(maxNorm); - for (int irow = 0; irow < crow; irow++) - { - __m256 rowNorm = _mm256_set1_ps(0); - float * pms = pm; - float * pmLim = pm + ccol; - for (; pm < pmLim; pm += 8) - { - __m256 x1 = _mm256_load_ps(pm); - x1 = _mm256_mul_ps(x1, x1); - rowNorm = _mm256_add_ps(x1, rowNorm); - } - rowNorm = _mm256_hadd_ps(rowNorm, rowNorm); - rowNorm = _mm256_hadd_ps(rowNorm, rowNorm); - float rowNormRes = _mm_cvtss_f32(_mm_add_ss(_get_lo(rowNorm), _get_hi(rowNorm))); - if (rowNormRes > maxNormSq) - { - __m256 scale = _mm256_set1_ps(rowNormRes); -#if 0 - // REVIEW: this is faster but it uses approximation so results differ significantly from CLR. 
- scale = _mm256_rsqrt_ps(scale); - scale = _mm256_mul_ps(scale, m); -#else - scale = _mm256_sqrt_ps(scale); - scale = _mm256_div_ps(m, scale); -#endif - for (pm = pms; pm < pmLim; pm += 8) - { - __m256 x1 = _mm256_load_ps(pm); - x1 = _mm256_mul_ps(x1, scale); - _mm256_store_ps(pm, x1); - } - } - } - - _vleave(); -} - -EXPORT_API(void) AddScaleX(float a, _In_ const float * ps, _Inout_ float * pd, int c) -{ - float * pdLim = pd + c; - - __m256 x1 = _mm256_set1_ps(a); - for (; pd < pdLim; pd += 8, ps += 8) - { - __m256 x2 = _mm256_load_ps(ps); - __m256 x3 = _mm256_load_ps(pd); - x2 = _mm256_mul_ps(x1, x2); - x3 = _mm256_add_ps(x2, x3); - _mm256_store_ps(pd, x3); - } - - _vleave(); -} - -EXPORT_API(void) AddScaleMomX(float a, _In_ const float * ps, _Inout_ float * pd, float momentum, _Inout_ float * pe, int c) -{ - float * pdLim = pd + c; - - __m256 x0 = _mm256_set1_ps(momentum); - __m256 x1 = _mm256_set1_ps(a); - for (; pd < pdLim; pd += 8, pe += 8, ps += 8) - { - __m256 x2 = _mm256_load_ps(ps); - __m256 x3 = _mm256_load_ps(pe); - __m256 x4 = _mm256_load_ps(pd); - x2 = _mm256_mul_ps(x1, x2); - x3 = _mm256_mul_ps(x0, x3); - x3 = _mm256_add_ps(x2, x3); - x4 = _mm256_add_ps(x3, x4); - _mm256_store_ps(pe, x3); - _mm256_store_ps(pd, x4); - } - - _vleave(); -} - -EXPORT_API(void) AddScaleGradX(_In_ const float * ps, _Inout_ float * pd, _Inout_ float * paccGrads, _Inout_ float * paccUpdates, - float decay, float cond, int c) -{ - float * pdLim = pd + c; - - __m256 dec = _mm256_set1_ps(decay); - __m256 decc = _mm256_set1_ps(1 - decay); - __m256 cnd = _mm256_set1_ps(cond); - for (; pd < pdLim; pd += 8, ps += 8, paccGrads += 8, paccUpdates += 8) - { - __m256 g = _mm256_load_ps(ps); - __m256 ag = _mm256_load_ps(paccGrads); - __m256 au = _mm256_load_ps(paccUpdates); - __m256 coef = _mm256_load_ps(pd); - - UpdateAdadelta(coef, ag, au, g, dec, decc, cnd); - - _mm256_store_ps(pd, coef); - _mm256_store_ps(paccGrads, ag); - _mm256_store_ps(paccUpdates, au); - } - - _vleave(); -} - -EXPORT_API(void) AddX(_In_ const float * ps, _Inout_ float * pd, int c) -{ - float * pdLim = pd + c; - - for (; pd < pdLim; pd += 8, ps += 8) - { - __m256 x1 = _mm256_load_ps(ps); - __m256 x2 = _mm256_load_ps(pd); - x2 = _mm256_add_ps(x1, x2); - _mm256_store_ps(pd, x2); - } - - _vleave(); -} - -EXPORT_API(float) SumX(const float * ps, int c) -{ - const float * psLim = ps + c; - - __m256 res = _mm256_setzero_ps(); - for (; ps < psLim; ps += 8) - { - __m256 x1 = _mm256_load_ps(ps); - res = _mm256_add_ps(res, x1); - } - res = _mm256_hadd_ps(res, res); - res = _mm256_hadd_ps(res, res); - __m128 r = _mm_add_ss(_get_lo(res), _get_hi(res)); - - float ret = _mm_cvtss_f32(r); - _vleave(); - return ret; -} - -EXPORT_API(void) ScaleAdadeltaX(_Inout_ float * mat, _Inout_ float * accGrads, _Inout_ float * accUpdates, float decay, float cond, _In_ const float * grads, int size) -{ - float * pm = mat; - float * pmLim = pm + size; - float * pag = accGrads; - float * pau = accUpdates; - const float * pg = grads; - - __m256 dec = _mm256_set1_ps(decay); - __m256 decc = _mm256_set1_ps(1 - decay); - __m256 c = _mm256_set1_ps(cond); - - for (; pm + 8 <= pmLim; pm += 8, pag += 8, pau += 8, pg += 8) - { - __m256 g = _mm256_loadu_ps(pg); - __m256 ag = _mm256_loadu_ps(pag); - __m256 au = _mm256_loadu_ps(pau); - __m256 coef = _mm256_loadu_ps(pm); - - UpdateAdadelta(coef, ag, au, g, dec, decc, c); - - _mm256_storeu_ps(pm, coef); - _mm256_storeu_ps(pag, ag); - _mm256_storeu_ps(pau, au); - } - - for (; pm < pmLim; pm++, pag++, pau++, pg++) - { - float g = *pg; 
- float accGrad = decay * *pag + (1 - decay) * g * g; - float accUpd = *pau; - float newUpd = sqrtf((accUpd + cond) / (accGrad + cond)) * g; - *pm += newUpd; - *pag = accGrad; - *pau = decay * accUpd + (1 - decay) * newUpd * newUpd; - } - - _vleave(); -} diff --git a/src/Native/CpuMathNative/CMakeLists.txt b/src/Native/CpuMathNative/CMakeLists.txt index d4c9772421..1c66a7fcd4 100644 --- a/src/Native/CpuMathNative/CMakeLists.txt +++ b/src/Native/CpuMathNative/CMakeLists.txt @@ -2,15 +2,11 @@ project (CpuMathNative) set(SOURCES Sse.cpp - Avx.cpp MathLinux.S ) -if(WIN32) - set_property(SOURCE Avx.cpp APPEND_STRING PROPERTY COMPILE_FLAGS " /arch:AVX") -else() +if(NOT WIN32) set_property(SOURCE Sse.cpp APPEND_STRING PROPERTY COMPILE_FLAGS " -msse3") - set_property(SOURCE Avx.cpp APPEND_STRING PROPERTY COMPILE_FLAGS " -mavx") list(APPEND SOURCES ${VERSION_FILE_PATH}) endif() diff --git a/src/Native/CpuMathNative/Sse.cpp b/src/Native/CpuMathNative/Sse.cpp index 9fde43fb12..5f519a2492 100644 --- a/src/Native/CpuMathNative/Sse.cpp +++ b/src/Native/CpuMathNative/Sse.cpp @@ -9,15 +9,7 @@ // * U suffix means unaligned and unpadded. // * S suffix means sparse (unaligned) vector. // * P suffix means sparse (unaligned) partial vector - the vector is only part of a larger sparse vector. -// * R suffix means sparse matrix. -// * C suffix means convolution matrix. -// * D suffix means convolution matrix, with implicit source padding. // * Tran means the matrix is transposed. -// -// Other notes: -// * AVX methods should end with _vleave() to avoid performance hit. See: -// https://stackoverflow.com/questions/7839925/using-avx-cpu-instructions-poor-performance-without-archavx. -// * Keep Avx.cpp in sync with Sse.cpp. Note that Avx.cpp is compiled with /arch:AVX, but Sse.cpp is not. #include "../Stdafx.h" #include @@ -49,43 +41,6 @@ x = _rotate(x); _mm_store_ss(pd + pi[2], x); \ x = _rotate(x); _mm_store_ss(pd + pi[3], x) -#ifndef _WIN32 - -typedef unsigned int DWORD; // NOTE: diff from windows.h, for LP64 compat - -// getcpuid and xmmYmmStateSupport are taken from -// https://github.com/dotnet/coreclr/blob/b5f4d2df2e087401f2c3aab2c37021e326707915/src/vm/amd64/unixstubs.cpp#L14-L55 - -DWORD getcpuid(DWORD arg, unsigned char result[16]) -{ - DWORD eax; - __asm(" xor %%ecx, %%ecx\n" \ - " cpuid\n" \ - " mov %%eax, 0(%[result])\n" \ - " mov %%ebx, 4(%[result])\n" \ - " mov %%ecx, 8(%[result])\n" \ - " mov %%edx, 12(%[result])\n" \ - : "=a"(eax) /*output in eax*/\ - : "a"(arg), [result]"r"(result) /*inputs - arg in eax, result in any register*/\ - : "rbx", "ecx", "edx", "memory" /* registers that are clobbered, *result is clobbered */ - ); - return eax; -} - -DWORD xmmYmmStateSupport() -{ - DWORD eax; - __asm(" xgetbv\n" \ - : "=a"(eax) /*output in eax*/\ - : "c"(0) /*inputs - 0 in ecx*/\ - : "edx" /* registers that are clobbered*/ - ); - // check OS has enabled both XMM and YMM state support - return ((eax & 0x06) == 0x06) ? 1 : 0; -} - -#endif - const unsigned int LeadingAlignmentMask[16] = { 0x00000000, 0x00000000, 0x00000000, 0x00000000, @@ -102,25 +57,6 @@ const unsigned int TrailingAlignmentMask[16] = 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, }; -// Test whether Avx is available. 
-EXPORT_API(bool) ChkAvx() -{ -#ifdef _WIN32 - int cpuInfo[4]; - __cpuid(cpuInfo, 1); - - // 28th bit of second integer of Cpu Info denotes whether the Avx is supported in CPU or not - // Reference https://msdn.microsoft.com/en-us/library/hskdteyh(v=vs.100).aspx - return cpuInfo[2] & (1 << 28) || false; -#else - unsigned char buffer[16]; - (void) getcpuid(1, buffer); - - // taken from https://github.com/dotnet/coreclr/blob/b5f4d2df2e087401f2c3aab2c37021e326707915/src/vm/codeman.cpp#L1381 - return ((buffer[11] & 0x18) == 0x18) && (xmmYmmStateSupport() == 1); -#endif -} - // Multiply matrix times vector into vector. EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { @@ -255,332 +191,146 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout } // Partial sparse source vector. -EXPORT_API(void) MatMulPA(_In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, +EXPORT_API(void) MatMulP(_In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, int posMin, int iposMin, int iposLim, _Inout_ float * pdst, int crow, int ccol) { // REVIEW: For extremely sparse inputs, interchanging the loops would // likely be more efficient. const int * pposMin = pposSrc + iposMin; - const int * pposLim = pposSrc + iposLim; - const float * pdLim = pdst + crow; + const int * pposEnd = pposSrc + iposLim; + const float * pDstEnd = pdst + crow; const float * pm0 = pmat - posMin; - const float * ps = psrc - posMin; - for (float * pd = pdst; pd < pdLim; pd += 4, pm0 += 4 * ccol) + const float * pSrcCurrent = psrc - posMin; + float* pDstCurrent = pdst; + + uintptr_t address = (uintptr_t)(pDstCurrent); + uintptr_t misalignment = address % 16; + int length = crow; + int remainder = 0; + + if ((misalignment & 3) != 0) { - const float * pm1 = pm0 + ccol; - const float * pm2 = pm1 + ccol; - const float * pm3 = pm2 + ccol; - __m128 res = _mm_setzero_ps(); - for (const int * ppos = pposMin; ppos < pposLim; ppos++) + while (pDstCurrent < pDstEnd) { - int col = *ppos; - __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); - __m128 x2 = _mm_set1_ps(ps[col]); - x2 = _mm_mul_ps(x2, x1); - res = _mm_add_ps(res, x2); - } + const float* pm1 = pm0 + ccol; + const float* pm2 = pm1 + ccol; + const float* pm3 = pm2 + ccol; - _mm_store_ps(pd, res); - } -} + __m128 res = _mm_setzero_ps(); + const int* ppos = pposMin; -// Sparse matrix. 
-EXPORT_API(void) MatMulRU(bool add, _In_ const int * pstarts, _In_ const int * pindices, _In_ const float * pcoefs, - _In_ const float * ps, _Inout_ float * pdst, int crow) -{ - const int * pii = pstarts + 1; - const int * pi = pindices; - const float * pm = pcoefs; - const float * pdLim = pdst + crow; - for (float * pd = pdst; pd < pdLim; pd++) - { - const int * piLim = pindices + *pii++; + while (ppos < pposEnd) + { + int col = *ppos; + __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); + __m128 x2 = _mm_set1_ps(pSrcCurrent[col]); + x2 = _mm_mul_ps(x2, x1); + res = _mm_add_ps(res, x2); + ppos++; + } - __m128 res = _mm_setzero_ps(); - for (; pi + 4 <= piLim; pi += 4, pm += 4) - { - __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm)); - res = _mm_add_ps(res, x); - } - for (; pi < piLim; pi++, pm++) - { - __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm)); - res = _mm_add_ss(res, x); + _mm_storeu_ps(pDstCurrent, res); + pDstCurrent += 4; + pm0 += 4 * ccol; } - res = _mm_hadd_ps(res, res); - res = _mm_hadd_ps(res, res); - - if (add) - res = _mm_add_ss(res, _mm_set_ss(*pd)); - _mm_store_ss(pd, res); } -} - -// Unpadded convolution. -EXPORT_API(void) MatMulCU(bool add, _In_ const int * pmprowiv, _In_ const int * pmprowcol, - _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow) -{ - int size = pruns[1]; - const int * psupport = pruns + 2; - const int * piv = pmprowiv; - const int * pcol = pmprowcol; - const int * piLim = psupport + size; - const float * pdLim = pdst + crow; - - for (float * pd = pdst; pd < pdLim; pd++) + else { - const float * pm = pcoefs + *piv++; - const float * ps = psrc + *pcol++; - const int * pi = psupport; - - __m128 res = _mm_setzero_ps(); - for (; pi + 4 <= piLim; pi += 4, pm += 4) - { - __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm)); - res = _mm_add_ps(res, x); - } - for (; pi < piLim; pi++, pm++) + if (misalignment != 0) { - __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm)); - res = _mm_add_ss(res, x); - } - res = _mm_hadd_ps(res, res); - res = _mm_hadd_ps(res, res); + misalignment >>= 2; + misalignment = 4 - misalignment; - // Add the bias. - res = _mm_add_ss(res, _mm_set_ss(*pm)); + __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); - if (add) - res = _mm_add_ss(res, _mm_set_ss(*pd)); - _mm_store_ss(pd, res); - } -} + const float* pm1 = pm0 + ccol; + const float* pm2 = pm1 + ccol; + const float* pm3 = pm2 + ccol; -// Padded convolution. -EXPORT_API(void) MatMulDU(bool add, _In_ const int * pmprowiv, _In_ const int * pmprowcol, _In_ const int * pmprowrun, - _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow) -{ - const int * piv = pmprowiv; - const int * pcol = pmprowcol; - const float * pdLim = pdst + crow; - int kernelSize = pruns[1]; + __m128 res = _mm_setzero_ps(); + const int* ppos = pposMin; - const int * pirun = pmprowrun; - for (float * pd = pdst; pd < pdLim; pd++) - { - const float * pm = pcoefs + *piv++; - const float * pmBias = pm + kernelSize; - const float * ps = psrc + *pcol++; - int irun = *pirun++; - - const int * pi = pruns + 2 + irun; - const int * piLim = pi + pi[-1]; - __m128 res = _mm_setzero_ps(); - if (irun == 0) - { - // No masking needed. 
- for (; pi + 4 <= piLim; pi += 4, pm += 4) - { - __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm)); - res = _mm_add_ps(res, x); - } - for (; pi < piLim; pi++, pm++) + while (ppos < pposEnd) { - __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm)); - res = _mm_add_ss(res, x); + int col = *ppos; + __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); + x1 = _mm_and_ps(mask, x1); + + __m128 x2 = _mm_set1_ps(pSrcCurrent[col]); + x2 = _mm_mul_ps(x2, x1); + res = _mm_add_ps(res, x2); + ppos++; } + + _mm_storeu_ps(pDstCurrent, res); + pDstCurrent += misalignment; + pm0 += misalignment * ccol; + length -= misalignment; } - else + + if (length > 3) { - // Need masking. - pm += pi[-2]; - const float * pmask = reinterpret_cast(piLim); - for (; pi + 4 <= piLim; pi += 4, pm += 4, pmask += 4) - { - __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_and_ps(_mm_loadu_ps(pmask), _mm_loadu_ps(pm))); - res = _mm_add_ps(res, x); - } - for (; pi < piLim; pi++, pm++, pmask++) + remainder = length % 4; + while (pDstCurrent < pDstEnd) { - __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_and_ps(_mm_set_ss(*pmask), _mm_set_ss(*pm))); - res = _mm_add_ss(res, x); - } - } - res = _mm_hadd_ps(res, res); - res = _mm_hadd_ps(res, res); + const float* pm1 = pm0 + ccol; + const float* pm2 = pm1 + ccol; + const float* pm3 = pm2 + ccol; - res = _mm_add_ss(res, _mm_set_ss(*pmBias)); - if (add) - res = _mm_add_ss(res, _mm_set_ss(*pd)); - _mm_store_ss(pd, res); - } -} - -// Mean pooling. -EXPORT_API(void) MeanU(bool add, _In_ const int * pmprowcol, _In_opt_ const int * pmprowindices, _In_ const int * pindices, - _In_ const float * psrc, _Inout_ float * pdst, int crow) -{ - const int * pcol = pmprowcol; - const float * pdLim = pdst + crow; + const int* ppos = pposMin; + __m128 res = _mm_setzero_ps(); - if (pmprowindices == nullptr) - { - int size = pindices[0]; - __m128 x0 = _mm_set_ss((float)size); - const int * piLim = pindices + 1 + size; - for (float * pd = pdst; pd < pdLim; pd++) - { - const float * ps = psrc + *pcol++; - const int * pi = pindices + 1; + while (ppos < pposEnd) + { + int col = *ppos; + __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); + __m128 x2 = _mm_set1_ps(pSrcCurrent[col]); + x2 = _mm_mul_ps(x2, x1); + res = _mm_add_ps(res, x2); + ppos++; + } - __m128 res = _mm_setzero_ps(); - for (; pi + 4 <= piLim; pi += 4) - res = _mm_add_ps(res, _load4(ps, pi)); - for (; pi < piLim; pi++) - res = _mm_add_ss(res, _load1(ps, pi)); - res = _mm_hadd_ps(res, res); - res = _mm_hadd_ps(res, res); - - res = _mm_div_ss(res, x0); - if (add) - res = _mm_add_ss(res, _mm_set_ss(*pd)); - _mm_store_ss(pd, res); + _mm_store_ps(pDstCurrent, res); + pDstCurrent += 4; + pm0 += 4 * ccol; + } } - } - else - { - const int * pii = pmprowindices; - for (float * pd = pdst; pd < pdLim; pd++) + else { - const float * ps = psrc + *pcol++; - int ii = *pii++; - - const int * pi = pindices + ii; - int size = *pi++; - const int * piLim = pi + size; - __m128 res = _mm_setzero_ps(); - for (; pi + 4 <= piLim; pi += 4) - res = _mm_add_ps(res, _load4(ps, pi)); - for (; pi < piLim; pi++) - res = _mm_add_ss(res, _load1(ps, pi)); - res = _mm_hadd_ps(res, res); - res = _mm_hadd_ps(res, res); - - res = _mm_div_ss(res, _mm_set_ss((float)size)); - if (add) - res = _mm_add_ss(res, _mm_set_ss(*pd)); - _mm_store_ss(pd, res); + length = remainder; } - } -} - -// Max pooling. 
-EXPORT_API(void) MaxU(bool add, _In_ const int * pmprowcol, _In_opt_ const int * pmprowindices, _In_ const int * pindices, - _In_ const float * psrc, _Inout_ float * pdst, int crow) -{ - const int * pcol = pmprowcol; - const float * pdLim = pdst + crow; - __m128 min = _mm_set1_ps(-std::numeric_limits::infinity()); - if (pmprowindices == nullptr) - { - int size = pindices[0]; - const int * piLim = pindices + 1 + size; - for (float * pd = pdst; pd < pdLim; pd++) - { - const float * ps = psrc + *pcol++; - const int * pi = pindices + 1; - - __m128 res = min; - for (; pi + 4 <= piLim; pi += 4) - res = _mm_max_ps(res, _load4(ps, pi)); - for (; pi < piLim; pi++) - res = _mm_max_ss(res, _load1(ps, pi)); - __m128 x1 = _mm_shuffle_ps(res, res, 0xB1); - res = _mm_max_ps(res, x1); - x1 = _mm_shuffle_ps(res, res, 0x02); - res = _mm_max_ss(res, x1); - - if (add) - res = _mm_add_ss(res, _mm_set_ss(*pd)); - _mm_store_ss(pd, res); - } - } - else - { - const int * pii = pmprowindices; - for (float * pd = pdst; pd < pdLim; pd++) + if (remainder != 0) { - const float * ps = psrc + *pcol++; - int ii = *pii++; - - const int * pi = pindices + ii; - int size = *pi++; - const int * piLim = pi + size; - __m128 res = min; - for (; pi + 4 <= piLim; pi += 4) - res = _mm_max_ps(res, _load4(ps, pi)); - for (; pi < piLim; pi++) - res = _mm_max_ss(res, _load1(ps, pi)); - __m128 x1 = _mm_shuffle_ps(res, res, 0xB1); - res = _mm_max_ps(res, x1); - x1 = _mm_shuffle_ps(res, res, 0x02); - res = _mm_max_ss(res, x1); - - if (add) - res = _mm_add_ss(res, _mm_set_ss(*pd)); - _mm_store_ss(pd, res); - } - } -} + pDstCurrent -= (4 - remainder); + pm0 -= (4 - remainder) * ccol; -// REVIEW: Try out SSE/AVX after padding support is added. AVX math platform uses the same code below. -EXPORT_API(void) RespNormU(bool add, float alpha, float beta, bool avgOverFullKernel, float offset, - _In_ const int * pmprowcol, _In_opt_ const int * pmprowindices, _In_ const int * pindices, - _In_ const float * psrc, _Inout_ float * pdst, int crow) -{ - const int * pcol = pmprowcol; - const float * pdLim = pdst + crow; + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (( 4 - remainder) * 4)); + + const float* pm1 = pm0 + ccol; + const float* pm2 = pm1 + ccol; + const float* pm3 = pm2 + ccol; - if (pmprowindices == nullptr) - { - int size = pindices[0]; - float scale = alpha / size; - const int * piLim = pindices + 1 + size; - for (float * pd = pdst; pd < pdLim; pd++) - { - const float * ps = psrc + *pcol++; - const int * pi = pindices + 1; - float res = 0; - for (; pi < piLim; pi++) - { - float cur = ps[*pi]; - res += cur * cur; - } - res = ps[0] * powf(offset + scale * res, -beta); - *pd = add ? *pd + res : res; - } - } - else - { - int kernelSize = pindices[0]; - const int * pii = pmprowindices; - for (float * pd = pdst; pd < pdLim; pd++) - { - const float * ps = psrc + *pcol++; - int ii = *pii++; - const int * pi = pindices + ii; - int size = *pi++; - const int * piLim = pi + size; - float res = 0; - for (; pi < piLim; pi++) + const int* ppos = pposMin; + __m128 res = _mm_setzero_ps(); + + while (ppos < pposEnd) { - float cur = ps[*pi]; - res += cur * cur; + int col = *ppos; + __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); + x1 = _mm_and_ps(x1, trailingMask); + + __m128 x2 = _mm_set1_ps(pSrcCurrent[col]); + x2 = _mm_mul_ps(x2, x1); + res = _mm_add_ps(res, x2); + ppos++; } - int avgDenom = avgOverFullKernel ? 
kernelSize : size; - res = ps[0] * powf(offset + alpha / avgDenom * res, -beta); - *pd = add ? *pd + res : res; + + res = _mm_add_ps(res, _mm_and_ps(leadingMask, _mm_loadu_ps(pDstCurrent))); + _mm_storeu_ps(pDstCurrent, res); + pDstCurrent += 4; + pm0 += 4 * ccol; } } } @@ -866,1118 +616,166 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I } } -// Sparse matrix. -EXPORT_API(void) MatMulTranRU(bool add, _In_ const int * pstarts, _In_ const int * pindices, _In_ const float * pcoefs, - _In_ const float * psrc, _Inout_ float * pd, int crow, int ccol) +// pd[i] += a +EXPORT_API(void) AddScalarU(float a, _Inout_ float * pd, int c) { - if (!add) - memset(pd, 0, crow * sizeof(float)); - - const int * pii = pstarts + 1; - const int * pi = pindices; - const float * pm = pcoefs; - const float * psLim = psrc + ccol; - for (const float * ps = psrc; ps < psLim; ps++) - { - float x = *ps; - const int * piLim = pindices + *pii++; + float * pdLim = pd + c; - __m128 x1 = _mm_set1_ps(x); - for (; pi + 4 <= piLim; pi += 4, pm += 4) - { - __m128 x2 = _mm_mul_ps(x1, _mm_loadu_ps(pm)); - x2 = _mm_add_ps(x2, _load4(pd, pi)); - _store4(x2, pd, pi); - } - for (; pi < piLim; pi++, pm++) - { - __m128 x2 = _mm_mul_ss(x1, _mm_set_ss(*pm)); - x2 = _mm_add_ss(x2, _load1(pd, pi)); - _store1(x2, pd, pi); - } + __m128 x1 = _mm_set1_ps(a); + for (; pd + 4 <= pdLim; pd += 4) + { + __m128 x2 = _mm_loadu_ps(pd); + x2 = _mm_add_ps(x2, x1); + _mm_storeu_ps(pd, x2); } -} -// Unpadded convolution. -EXPORT_API(void) MatMulTranCU(bool add, _In_ const int * pmpcoliv, _In_ const int * pmpcolrow, - _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) -{ - if (!add) - memset(pdst, 0, crow * sizeof(float)); - - int size = pruns[1]; - const int * psupport = pruns + 2; - const int * piv = pmpcoliv; - const int * prow = pmpcolrow; - const int * piLim = psupport + size; - const float * psLim = psrc + ccol; - for (const float * ps = psrc; ps < psLim; ps++) + for (; pd < pdLim; pd++) { - const float * pm = pcoefs + *piv++; - float * pd = pdst + *prow++; - const int * pi = psupport; - - float x = *ps; - __m128 x1 = _mm_set1_ps(x); - for (; pi + 4 <= piLim; pm += 4, pi += 4) - { - __m128 x2 = _mm_mul_ps(x1, _mm_loadu_ps(pm)); - x2 = _mm_add_ps(x2, _load4(pd, pi)); - _store4(x2, pd, pi); - } - for (; pi < piLim; pi++, pm++) - { - __m128 x2 = _mm_mul_ss(x1, _mm_set_ss(*pm)); - x2 = _mm_add_ss(x2, _load1(pd, pi)); - _store1(x2, pd, pi); - } + __m128 x2 = _mm_load_ss(pd); + x2 = _mm_add_ss(x2, x1); + _mm_store_ss(pd, x2); } } -// Padded convolution. -EXPORT_API(void) MatMulTranDU(bool add, _In_ const int * pmpcoliv, _In_ const int * pmpcolrow, _In_ const int * pmpcolrun, - _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) +EXPORT_API(void) Scale(float a, _Inout_ float * pd, int c) { - if (!add) - memset(pdst, 0, crow * sizeof(float)); - - const int * piv = pmpcoliv; - const int * prow = pmpcolrow; - const float * psLim = psrc + ccol; - - const int * pirun = pmpcolrun; - for (const float * ps = psrc; ps < psLim; ps++) + __m128 x1 = _mm_set1_ps(a); + + if (c < 4) { - const float * pm = pcoefs + *piv++; - float * pd = pdst + *prow++; - int irun = *pirun++; - const int * pi = pruns + 2 + irun; - const int * piLim = pi + pi[-1]; - - float x = *ps; - __m128 x1 = _mm_set1_ps(x); - if (irun == 0) - { - // No masking needed. 
- for (; pi + 4 <= piLim; pi += 4, pm += 4) - { - __m128 x2 = _mm_mul_ps(x1, _mm_loadu_ps(pm)); - x2 = _mm_add_ps(x2, _load4(pd, pi)); - _store4(x2, pd, pi); - } - for (; pi < piLim; pi++, pm++) - { - __m128 x2 = _mm_mul_ss(x1, _mm_set_ss(*pm)); - x2 = _mm_add_ss(x2, _load1(pd, pi)); - _store1(x2, pd, pi); - } - } - else + switch(c) { - // Need masking. - pm += pi[-2]; - const float * pmask = reinterpret_cast(piLim); - for (; pi + 4 <= piLim; pi += 4, pm += 4, pmask += 4) - { - __m128 x2 = _mm_mul_ps(_mm_and_ps(_mm_loadu_ps(pmask), x1), _mm_loadu_ps(pm)); - x2 = _mm_add_ps(x2, _load4(pd, pi)); - _store4(x2, pd, pi); - } - for (; pi < piLim; pi++, pm++, pmask++) - { - __m128 x2 = _mm_mul_ss(_mm_and_ps(_mm_set_ss(*pmask), x1), _mm_set_ss(*pm)); - x2 = _mm_add_ss(x2, _load1(pd, pi)); - _store1(x2, pd, pi); - } + case 3: pd[2] *= a; + case 2: pd[1] *= a; + case 1: pd[0] *= a; } + return; } -} -// Mean pooling back prop. -EXPORT_API(void) MeanBackU(bool add, _In_ const int * pmpcolrow, _In_opt_ const int * pmpcolindices, _In_ const int * pindices, - _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) -{ - if (!add) - memset(pdst, 0, crow * sizeof(float)); + uintptr_t address = (uintptr_t)(pd); + uintptr_t misalignment = address % 16; + int remainder = 0; - const int * prow = pmpcolrow; - const float * psLim = psrc + ccol; - if (pmpcolindices == nullptr) + if ((misalignment & 3) != 0) { - int size = pindices[0]; - const int * piLim = pindices + 1 + size; - for (const float * ps = psrc; ps < psLim; ps++) + // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations + remainder = c % 4; + + for (const float* pEnd = pd + (c - remainder); pd < pEnd; pd += 4) { - float * pd = pdst + *prow++; - const int * pi = pindices + 1; - - float x = *ps / size; - __m128 x1 = _mm_set1_ps(x); - for (; pi + 4 <= piLim; pi += 4) - { - __m128 x2 = _mm_add_ps(x1, _load4(pd, pi)); - _store4(x2, pd, pi); - } - for (; pi < piLim; pi++) - { - __m128 x2 = _mm_add_ss(x1, _load1(pd, pi)); - _store1(x2, pd, pi); - } + __m128 x2 = _mm_loadu_ps(pd); + x2 = _mm_mul_ps(x1, x2); + _mm_storeu_ps(pd, x2); } } else { - const int * pii = pmpcolindices; - for (const float * ps = psrc; ps < psLim; ps++) + if (misalignment != 0) { - float * pd = pdst + *prow++; - int ii = *pii++; - - const int * pi = pindices + ii; - int size = *pi++; - const int * piLim = pi + size; + // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 4 - misalignment; + + __m128 result = _mm_loadu_ps(pd); + + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + ((4 - misalignment) * 4)); + + __m128 temp = _mm_and_ps(result, leadingMask); + result = _mm_and_ps(result, trailingMask); + + temp = _mm_mul_ps(temp, x1); + result = _mm_or_ps(temp, result); + + _mm_storeu_ps(pd, result); + + pd += misalignment; + c -= misalignment; + } - float x = *ps / size; - __m128 x1 = _mm_set1_ps(x); - for (; pi + 4 <= piLim; pi += 4) - { - __m128 x2 = _mm_add_ps(x1, _load4(pd, pi)); - _store4(x2, pd, pi); - } - for (; pi < piLim; pi++) + if (c > 3) + { + // Handle all the 128-bit blocks that we can now that we have offset to an aligned address + remainder = c % 4; + for (const float* pEnd = pd + (c - remainder); pd < pEnd; pd += 4) { - __m128 x2 = _mm_add_ss(x1, _load1(pd, pi)); 
- _store1(x2, pd, pi); + __m128 x2 = _mm_load_ps(pd); + x2 = _mm_mul_ps(x1, x2); + _mm_storeu_ps(pd, x2); } } + else + { + // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not + // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two + // unaligned loads where we mask the input each time. + remainder = c; + } + } + + if (remainder != 0) + { + // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next + // unaligned load will read to the end of the array and then mask out any elements already processed + + pd -= (4 - remainder); + __m128 result = _mm_loadu_ps(pd); + + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); + + __m128 temp = _mm_and_ps(result, trailingMask); + result = _mm_and_ps(result, leadingMask); + + temp = _mm_mul_ps(temp, x1); + result = _mm_or_ps(temp, result); + + _mm_storeu_ps(pd, result); } } -// Max pooling back prop. -EXPORT_API(void) MaxBackU(bool add, _In_ const int * pmpcolrow, _In_opt_ const int * pmpcolindices, _In_ const int * pindices, - _In_ const float * psrc, _Inout_ float * pdst, _In_ const float * pval, int crow, int ccol) +EXPORT_API(void) ScaleSrcU(float a, _In_ const float * ps, _Inout_ float * pd, int c) { - if (!add) - memset(pdst, 0, crow * sizeof(float)); + float * pdLim = pd + c; - const int * prow = pmpcolrow; - const float * psLim = psrc + ccol; - if (pmpcolindices == nullptr) + __m128 x1 = _mm_set1_ps(a); + for (; pd + 4 <= pdLim; pd += 4, ps += 4) { - const int * piLim = pindices + 1 + pindices[0]; - for (const float * ps = psrc; ps < psLim; ps++) - { - int rowBase = *prow++; - float * pd = pdst + rowBase; - const float * pv = pval + rowBase; - const int * pi = pindices + 1; - - int j = *pi++; - float m = pv[j]; - for (; pi < piLim; pi++) - { - if (m < pv[*pi]) - { - j = *pi; - m = pv[j]; - } - } - pd[j] += *ps; - } + __m128 x2 = _mm_loadu_ps(ps); + x2 = _mm_mul_ps(x2, x1); + _mm_storeu_ps(pd, x2); } - else + + for (; pd < pdLim; pd++, ps++) { - const int * pii = pmpcolindices; - for (const float * ps = psrc; ps < psLim; ps++) - { - int rowBase = *prow++; - int ii = *pii++; - float * pd = pdst + rowBase; - const float * pv = pval + rowBase; - const int * pi = pindices + ii + 1; - const int * piLim = pi + pi[-1]; - - int j = *pi++; - float m = pv[j]; - for (; pi < piLim; pi++) - { - if (m < pv[*pi]) - { - j = *pi; - m = pv[j]; - } - } - pd[j] += *ps; - } + __m128 x2 = _mm_load_ss(ps); + x2 = _mm_mul_ss(x2, x1); + _mm_store_ss(pd, x2); } } -// REVIEW: Try out SSE/AVX after padding support is added. AVX math platform uses the same code below. 
-EXPORT_API(void) RespNormBackU(bool add, float alpha, float beta, bool avgOverFullKernel, float offset, - _In_ const int * pmpcolrow, _In_opt_ const int * pmpcolindices, _In_ const int * pindices, - _In_ const float * perrors, _Inout_ float * perrorsPrev, _In_ const float * pvaluesPrev, int crow, int ccol) -{ - if (!add) - memset(perrorsPrev, 0, crow * sizeof(float)); - - const int * prow = pmpcolrow; - const float * psLim = perrors + ccol; - if (pmpcolindices == nullptr) - { - int size = pindices[0]; - float scale = alpha / size; - const int * piMin = pindices + 1; - const int * piLim = piMin + size; - for (const float * ps = perrors; ps < psLim; ps++) - { - int rowBase = *prow++; - // First compute denominator: denom = offset + scale * Sum(Xj^2) - float denom = 0; - const float * pv = pvaluesPrev + rowBase; - - for (const int * pi = piMin; pi < piLim; pi++) - { - float cur = pv[*pi]; - denom += cur * cur; - } - denom = offset + scale * denom; - float denomPow = powf(denom, -beta); - // The output. - float y = pv[0] * denomPow; - - // The update logic: - // srcError(*ps) X the derivative. - // derivative at i wrt center point = powf(denom, -beta) - 2* scale * beta * X[i] * y / denom. - // derivative at i wrt other points = - 2* scale * beta * X[i] * y / denom. - float commonUpdate = *ps * (-2 * scale * beta * y) / denom; - - float * pd = perrorsPrev + rowBase; - for (const int * pi = piMin; pi < piLim; pi++) - pd[*pi] += pv[*pi] * commonUpdate; - - // Additional update for the center point. - pd[0] += *ps * denomPow; - } - } - else - { - int kernelSize = pindices[0]; - const int * pii = pmpcolindices; - for (const float * ps = perrors; ps < psLim; ps++) - { - int rowBase = *prow++; - // First compute denominator: denom = 1 + scale * Sum(Xj^2) - float denom = 0; - const float * pv = pvaluesPrev + rowBase; - int ii = *pii++; - - const int * piMin = pindices + ii; - int size = *piMin++; - const int * piLim = piMin + size; - - for (const int * pi = piMin; pi < piLim; pi++) - { - float cur = pv[*pi]; - denom += cur * cur; - } - float scale = alpha / (avgOverFullKernel ? kernelSize : size); - denom = offset + scale * denom; - float denomPow = powf(denom, -beta); - // The output. - float y = pv[0] * denomPow; - - // The update logic: - // srcError(*ps) X the derivative. - // derivative at i wrt center point = powf(denom, -beta) - 2* scale * beta * X[i] * y / denom. - // derivative at i wrt other points = - 2* scale * beta * X[i] * y / denom. - float commonUpdate = *ps * (-2 * scale * beta * y) / denom; - - float * pd = perrorsPrev + rowBase; - for (const int * pi = piMin; pi < piLim; pi++) - pd[*pi] += pv[*pi] * commonUpdate; - - // Additional update for the center point. 
- pd[0] += *ps * denomPow; - } - } -} - -template -void AddXYTranACore(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, int crow, int ccol, float decay) -{ - const float * pyBase = py; - const float * pxLim = px + crow; - const float * pyLim = py + ccol; - float * pm = pmat; - __m128 wd; - if (useDecay) - wd = _mm_set1_ps(1 - decay); - for (; px < pxLim; px++) - { - float r = a * *px; - py = pyBase; - - __m128 x1 = _mm_set1_ps(r); - for (; py + 16 <= pyLim; py += 16, pm += 16) - { - __m128 x02 = _mm_load_ps(py); - __m128 x12 = _mm_load_ps(py + 4); - __m128 x22 = _mm_load_ps(py + 8); - __m128 x32 = _mm_load_ps(py + 12); - __m128 x03 = _mm_load_ps(pm); - __m128 x13 = _mm_load_ps(pm + 4); - __m128 x23 = _mm_load_ps(pm + 8); - __m128 x33 = _mm_load_ps(pm + 12); - x02 = _mm_mul_ps(x1, x02); - x12 = _mm_mul_ps(x1, x12); - x22 = _mm_mul_ps(x1, x22); - x32 = _mm_mul_ps(x1, x32); - if (useDecay) - { - x03 = _mm_mul_ps(wd, x03); - x13 = _mm_mul_ps(wd, x13); - x23 = _mm_mul_ps(wd, x23); - x33 = _mm_mul_ps(wd, x33); - } - x03 = _mm_add_ps(x02, x03); - x13 = _mm_add_ps(x12, x13); - x23 = _mm_add_ps(x22, x23); - x33 = _mm_add_ps(x32, x33); - _mm_store_ps(pm, x03); - _mm_store_ps(pm + 4, x13); - _mm_store_ps(pm + 8, x23); - _mm_store_ps(pm + 12, x33); - } - for (; py < pyLim; py += 4, pm += 4) - { - __m128 x02 = _mm_load_ps(py); - __m128 x03 = _mm_load_ps(pm); - x02 = _mm_mul_ps(x1, x02); - if (useDecay) - x03 = _mm_mul_ps(wd, x03); - x03 = _mm_add_ps(x02, x03); - _mm_store_ps(pm, x03); - } - } -} - -EXPORT_API(void) AddXYTranA(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, int crow, int ccol, float decay) -{ - if (decay == 0) - AddXYTranACore(a, px, py, pmat, crow, ccol, decay); - else - AddXYTranACore(a, px, py, pmat, crow, ccol, decay); -} - -// Partial sparse source vector. -EXPORT_API(void) AddXYTranPA(float a, _In_ const float * px, _In_ const int * pposY, _In_ const float * pvaluesY, - int posMinY, int iposMinY, int iposLimY, _Inout_ float * pmat, int crow, int ccol) -{ -#if 1 - // REVIEW: This is faster for MNIST, but the version below is faster for extremely sparse input. 
- const int * pposMin = pposY + iposMinY; - const int * pposLim = pposY + iposLimY; - const float * pxLim = px + crow; - float * pm0 = pmat - posMinY; - const float * py = pvaluesY - posMinY; - - __m128 x0 = _mm_set1_ps(a); - for (; px < pxLim; px += 4, pm0 += 4 * ccol) - { - float * pm1 = pm0 + ccol; - float * pm2 = pm1 + ccol; - float * pm3 = pm2 + ccol; - - __m128 x1 = _mm_load_ps(px); - x1 = _mm_mul_ps(x1, x0); - - for (const int * ppos = pposMin; ppos < pposLim; ppos++) - { - int col = *ppos; - __m128 x2 = _mm_set1_ps(py[col]); - __m128 x3 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); - x2 = _mm_mul_ps(x2, x1); - x3 = _mm_add_ps(x3, x2); - - _mm_store_ss(pm0 + col, x3); x3 = _rotate(x3); - _mm_store_ss(pm1 + col, x3); x3 = _rotate(x3); - _mm_store_ss(pm2 + col, x3); x3 = _rotate(x3); - _mm_store_ss(pm3 + col, x3); - } - } -#else - const int * pposMin = pposY + iposMinY; - const int * pposLim = pposY + iposLimY; - const float * pxLim = px + crow; - float * pm = pmat - posMinY; - const float * py = pvaluesY - posMinY; - - __m128 x0 = _mm_set1_ps(a); - int d1 = 1 * ccol; - int d2 = 2 * ccol; - int d3 = 3 * ccol; - int d4 = 4 * ccol; - for (const int * ppos = pposMin; ppos < pposLim; ppos++) - { - int col = *ppos; - __m128 x2 = _mm_set1_ps(py[col]); - x2 = _mm_mul_ps(x2, x0); - - float * pm0 = pm + col; - for (const float * px0 = px; px0 < pxLim; px0 += 4, pm0 += d4) - { - __m128 x1 = _mm_load_ps(px0); - __m128 x3 = _mm_setr_ps(pm0[0], pm0[d1], pm0[d2], pm0[d3]); - x1 = _mm_mul_ps(x1, x2); - x3 = _mm_add_ps(x3, x1); - - _mm_store_ss(pm0, x3); x3 = _rotate(x3); - _mm_store_ss(pm0 + d1, x3); x3 = _rotate(x3); - _mm_store_ss(pm0 + d2, x3); x3 = _rotate(x3); - _mm_store_ss(pm0 + d3, x3); - } - } -#endif -} - -template -void AddXYTranRUCore(float a, _In_ const float * px, _In_ const float * py, - _In_ const int * pstarts, _In_ const int * pindices, _Inout_ float * pcoefs, int crow, float decay) -{ - const int * pii = pstarts + 1; - const int * pi = pindices; - float * pm = pcoefs; - const float * pxLim = px + crow; - __m128 wd; - if (useDecay) - wd = _mm_set1_ps(1 - decay); - for (; px < pxLim; px++) - { - const int * piLim = pindices + *pii++; - float r = a * *px; - - __m128 x1 = _mm_set1_ps(r); - for (; pi + 4 <= piLim; pi += 4, pm += 4) - { - __m128 x2 = _mm_mul_ps(x1, _load4(py, pi)); - __m128 x3 = _mm_loadu_ps(pm); - if (useDecay) - x3 = _mm_mul_ps(x3, wd); - x2 = _mm_add_ps(x2, x3); - _mm_storeu_ps(pm, x2); - } - for (; pi < piLim; pi++, pm++) - *pm = (useDecay ? (*pm * (1 - decay)) : *pm) + py[*pi] * r; - } -} - -// Sparse matrix. -EXPORT_API(void) AddXYTranRU(float a, _In_ const float * px, _In_ const float * py, - _In_ const int * pstarts, _In_ const int * pindices, _Inout_ float * pcoefs, int crow, float decay) -{ - if (decay == 0) - AddXYTranRUCore(a, px, py, pstarts, pindices, pcoefs, crow, decay); - else - AddXYTranRUCore(a, px, py, pstarts, pindices, pcoefs, crow, decay); -} - -// Unpadded convolution. 
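For readers skimming the intrinsics, the unpadded-convolution kernel below (AddXYTranCU) reduces to the following scalar loop. This is only a sketch: the helper and parameter names are illustrative, and it assumes the layout the kernel itself uses, where each kernel's weights are followed immediately by its bias slot.

```cpp
// Scalar sketch of the unpadded convolution gradient accumulation performed by
// AddXYTranCU (illustrative names; not part of the exported API).
void AddXYTranCUSketch(float a, const float* error, const float* input,
                       const int* rowToKernel, const int* rowToCol,
                       const int* support, int supportSize,
                       float* coefs, int crow)
{
    for (int row = 0; row < crow; row++)
    {
        float r = a * error[row];
        float* kern = coefs + rowToKernel[row];       // weights of the kernel that produced this output
        const float* window = input + rowToCol[row];  // start of the input window for this output
        for (int i = 0; i < supportSize; i++)
            kern[i] += window[support[i]] * r;        // weight gradient: a * error * input
        kern[supportSize] += r;                       // bias gradient sits right after the weights
    }
}
```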
-EXPORT_API(void) AddXYTranCU(float a, _In_ const float * px, _In_ const float * py, _In_ const int * pmprowiv, - _In_ const int * pmprowcol, _In_ const int * pruns, _Inout_ float * pcoefs, int crow) -{ - int size = pruns[1]; - const int * psupport = pruns + 2; - const int * piv = pmprowiv; - const int * pcol = pmprowcol; - const float * pxLim = px + crow; - const int * piLim = psupport + size; - - for (; px < pxLim; px++) - { - float * pm = pcoefs + *piv++; - const float * ps = py + *pcol++; - const int * pi = psupport; - float r = a * *px; - - __m128 x1 = _mm_set1_ps(r); - for (; pi + 4 <= piLim; pi += 4, pm += 4) - { - __m128 x2 = _mm_mul_ps(x1, _load4(ps, pi)); - x2 = _mm_add_ps(x2, _mm_loadu_ps(pm)); - _mm_storeu_ps(pm, x2); - } - for (; pi < piLim; pi++, pm++) - *pm += ps[*pi] * r; - // Update the bias. - *pm += r; - } -} - -// Padded convolution. -EXPORT_API(void) AddXYTranDU(float a, _In_ const float * px, _In_ const float * py, _In_ const int * pmprowiv, - _In_ const int * pmprowcol, _In_ const int * pmprowrun, _In_ const int * pruns, _Inout_ float * pcoefs, int crow) -{ - const int * piv = pmprowiv; - const int * pcol = pmprowcol; - const float * pxLim = px + crow; - int kernelSize = pruns[1]; - - const int * pirun = pmprowrun; - for (; px < pxLim; px++) - { - float * pm = pcoefs + *piv++; - const float * ps = py + *pcol++; - int irun = *pirun++; - const int * pi = pruns + 2 + irun; - const int * piLim = pi + pi[-1]; - - float r = a * *px; - - // Update the bias. - pm[kernelSize] += r; - - __m128 x1 = _mm_set1_ps(r); - if (irun == 0) - { - // No masking needed. - for (; pi + 4 <= piLim; pi += 4, pm += 4) - { - __m128 x2 = _mm_mul_ps(x1, _load4(ps, pi)); - x2 = _mm_add_ps(x2, _mm_loadu_ps(pm)); - _mm_storeu_ps(pm, x2); - } - for (; pi < piLim; pi++, pm++) - *pm += ps[*pi] * r; - } - else - { - // Need masking. - pm += pi[-2]; - const float * pmask = reinterpret_cast(piLim); - for (; pi + 4 <= piLim; pi += 4, pm += 4, pmask += 4) - { - __m128 x2 = _mm_mul_ps(_mm_and_ps(_mm_loadu_ps(pmask), x1), _load4(ps, pi)); - x2 = _mm_add_ps(x2, _mm_loadu_ps(pm)); - _mm_storeu_ps(pm, x2); - } - for (; pi < piLim; pi++, pm++, pmask++) - { - __m128 x2 = _mm_mul_ss(_mm_and_ps(_mm_set_ss(*pmask), x1), _load1(ps, pi)); - x2 = _mm_add_ss(x2, _mm_set_ss(*pm)); - _mm_store_ss(pm, x2); - } - } - } -} - -// With momentum. -EXPORT_API(void) AddXYTranMomA(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, float momentum, _Inout_ float * pdel, int crow, int ccol) -{ - const float * pyBase = py; - const float * pxLim = px + crow; - const float * pyLim = py + ccol; - float * pm = pmat; - float * pd = pdel; - - __m128 x0 = _mm_set1_ps(momentum); - for (; px < pxLim; px++) - { - float r = a * *px; - - __m128 x1 = _mm_set1_ps(r); - for (py = pyBase; py < pyLim; pm += 4, pd += 4, py += 4) - { - __m128 x2 = _mm_load_ps(py); - __m128 x3 = _mm_load_ps(pd); - __m128 x4 = _mm_load_ps(pm); - x2 = _mm_mul_ps(x1, x2); - x3 = _mm_mul_ps(x0, x3); - x3 = _mm_add_ps(x2, x3); - x4 = _mm_add_ps(x3, x4); - - _mm_store_ps(pd, x3); - _mm_store_ps(pm, x4); - } - } -} - -// coef: coefs to update, ag: accumulated grads, au: accumulated updates, g: cur grads. -// Note: parameters coef, ag, au and g will be updated, do not reuse parameter g in calling code. 
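A scalar restatement of the Adadelta step may make the vectorized helper below easier to follow; it mirrors the per-element remainder loop used later in ScaleAdadeltaU. The approximate rsqrt path in the helper computes the same rate, since rsqrt(ag + c) * ((au + c) * rsqrt(au + c)) = sqrt((au + c) / (ag + c)). The function name here is illustrative, not part of the exported API.

```cpp
#include <cmath>

// Scalar sketch of one Adadelta step per coefficient (illustrative only).
// accG / accU are exponential moving averages of squared gradients and squared
// updates; cond is the conditioning constant added to both before the ratio.
inline void UpdateAdadeltaScalar(float& coef, float& accG, float& accU, float g,
                                 float decay, float cond)
{
    accG = decay * accG + (1 - decay) * g * g;              // accumulate squared gradient
    float rate = std::sqrt((accU + cond) / (accG + cond));  // adaptive step size
    float newU = rate * g;                                  // this step's update
    coef += newU;
    accU = decay * accU + (1 - decay) * newU * newU;        // accumulate squared update
}
```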
-__forceinline void UpdateAdadelta(__m128& coef, __m128& ag, __m128& au, __m128& g, const __m128& dec, const __m128& decc, const __m128& c) -{ - __m128 x4 = _mm_mul_ps(g, g); // x4 == g * g - x4 = _mm_mul_ps(decc, x4); // x4 == (1 - decay) * g * g - ag = _mm_mul_ps(dec, ag); // ag == decay * accG - ag = _mm_add_ps(ag, x4); // ag == decay * accG + (1 - decay) * g * g - __m128 x41 = _mm_add_ps(ag, c); // x41 == ag + cond - __m128 x51 = _mm_add_ps(au, c); // x51 == accU + cond -#if 0 - // naive version: - x51 = _mm_div_ps(x51, x41); - x41 = _mm_sqrt_ps(x51); // x41 == rate -#else - // faster (approximate) version: - x41 = _mm_rsqrt_ps(x41); - __m128 x52 = _mm_rsqrt_ps(x51); - x51 = _mm_mul_ps(x51, x52); - x41 = _mm_mul_ps(x41, x51); // x41 == rate -#endif - g = _mm_mul_ps(g, x41); // g - current update - coef = _mm_add_ps(coef, g); - - g = _mm_mul_ps(g, g); // g == newU * newU - g = _mm_mul_ps(decc, g); // g == (1 - decay) * newU * newU - au = _mm_mul_ps(dec, au); // au == decay * accU - au = _mm_add_ps(au, g); // au == decay * accU + (1 - decay) * newU * newU -} - -// For Adadelta. -EXPORT_API(void) AddXYTranGradA(_In_ const float * px, _In_ const float * py, _Inout_ float * pmat, _Inout_ float * paccGrads, _Inout_ float * paccUpdates, - float decay, float cond, int crow, int ccol) -{ - const float * pyBase = py; - const float * pxLim = px + crow; - const float * pyLim = py + ccol; - float * pm = pmat; - float * pag = paccGrads; - float * pau = paccUpdates; - - __m128 dec = _mm_set1_ps(decay); - __m128 decc = _mm_set1_ps(1 - decay); - __m128 c = _mm_set1_ps(cond); - for (; px < pxLim; px++) - { - float r = *px; - - __m128 x1 = _mm_set1_ps(r); - for (py = pyBase; py < pyLim; pm += 4, pag += 4, pau += 4, py += 4) - { - __m128 x2 = _mm_load_ps(py); - __m128 ag = _mm_load_ps(pag); - __m128 au = _mm_load_ps(pau); - __m128 coef = _mm_load_ps(pm); - x2 = _mm_mul_ps(x1, x2); // x2 == g - - UpdateAdadelta(coef, ag, au, x2, dec, decc, c); - - _mm_store_ps(pm, coef); - _mm_store_ps(pag, ag); - _mm_store_ps(pau, au); - } - } -} - -// For Adadelta, sparse matrix. 
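In scalar terms, the sparse-matrix variant below walks a CSR-style layout in which starts[row + 1] closes the run of (index, coefficient) entries owned by each row. A sketch under that assumption (including the usual starts[0] == 0 convention), reusing the UpdateAdadeltaScalar sketch above:

```cpp
// Scalar sketch of AddXYTranGradRU (illustrative; assumes starts[0] == 0 and the
// UpdateAdadeltaScalar sketch shown earlier).
void AddXYTranGradRUSketch(const float* x, const float* y,
                           const int* starts, const int* indices,
                           float* coefs, float* accGrads, float* accUpdates,
                           float decay, float cond, int crow)
{
    for (int row = 0; row < crow; row++)
    {
        for (int k = starts[row]; k < starts[row + 1]; k++)
        {
            float g = x[row] * y[indices[k]];   // gradient of coefs[k]
            UpdateAdadeltaScalar(coefs[k], accGrads[k], accUpdates[k], g, decay, cond);
        }
    }
}
```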
-EXPORT_API(void) AddXYTranGradRU(_In_ const float * px, _In_ const float * py, _In_ const int * pstarts, _In_ const int * pindices, - _Inout_ float * pcoefs, _Inout_ float * paccGrads, _Inout_ float * paccUpdates, float decay, float cond, int crow) -{ - const int * pii = pstarts + 1; - const int * pi = pindices; - float * pm = pcoefs; - const float * pxLim = px + crow; - float * pag = paccGrads; - float * pau = paccUpdates; - - __m128 dec = _mm_set1_ps(decay); - __m128 decc = _mm_set1_ps(1 - decay); - __m128 c = _mm_set1_ps(cond); - - for (; px < pxLim; px++) - { - const int * piLim = pindices + *pii++; - float r = *px; - - __m128 x1 = _mm_set1_ps(r); - for (; pi + 4 <= piLim; pi += 4, pm += 4, pag += 4, pau += 4) - { - __m128 g = _mm_mul_ps(x1, _load4(py, pi)); - __m128 ag = _mm_loadu_ps(pag); - __m128 au = _mm_loadu_ps(pau); - __m128 coef = _mm_loadu_ps(pm); - - UpdateAdadelta(coef, ag, au, g, dec, decc, c); - - _mm_storeu_ps(pm, coef); - _mm_storeu_ps(pag, ag); - _mm_storeu_ps(pau, au); - } - - if (pi < piLim) - { - size_t ctail = piLim - pi; - __m128 g = _mm_mul_ss(_load1(py, pi++), x1); - __m128 ag = _mm_load_ss(pag++); - __m128 au = _mm_load_ss(pau++); - __m128 coef = _mm_load_ss(pm++); - for (; pi < piLim; pi++, pm++, pag++, pau++) - { - g = _mm_or_ps(_mm_mul_ss(_load1(py, pi), x1), _rotate(g)); - ag = _mm_or_ps(_mm_load_ss(pag), _rotate(ag)); - au = _mm_or_ps(_mm_load_ss(pau), _rotate(au)); - coef = _mm_or_ps(_mm_load_ss(pm), _rotate(coef)); - } - UpdateAdadelta(coef, ag, au, g, dec, decc, c); - for (int i = 0; i < ctail; i++) - { - _mm_store_ss(pm - i - 1, coef); - coef = _rotate_reverse(coef); - _mm_store_ss(pag - i - 1, ag); - ag = _rotate_reverse(ag); - _mm_store_ss(pau - i - 1, au); - au = _rotate_reverse(au); - } - } - } -} - -// For Adadelta, partial sparse source vector. 
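The partial sparse source vector variant below touches only the columns listed in posY[iposMin..iposLim); their values live in valuesY offset by posMin, and the coefficient and accumulator matrices are row-major with stride ccol. A scalar sketch under those assumptions, again using the UpdateAdadeltaScalar sketch above (names illustrative):

```cpp
// Scalar sketch of AddXYTranGradPA (illustrative names only).
void AddXYTranGradPASketch(const float* x, const int* posY, const float* valuesY,
                           int posMin, int iposMin, int iposLim,
                           float* mat, float* accGrads, float* accUpdates,
                           float decay, float cond, int crow, int ccol)
{
    for (int row = 0; row < crow; row++)
    {
        for (int ipos = iposMin; ipos < iposLim; ipos++)
        {
            int col = posY[ipos];
            int k = row * ccol + (col - posMin);       // row-major cell for this column
            float g = x[row] * valuesY[col - posMin];  // gradient of mat[k]
            UpdateAdadeltaScalar(mat[k], accGrads[k], accUpdates[k], g, decay, cond);
        }
    }
}
```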
-EXPORT_API(void) AddXYTranGradPA(_In_ const float * px, _In_ const int * pposY, _In_ const float * pvaluesY, - int posMinY, int iposMinY, int iposLimY, _Inout_ float * pmat, _Inout_ float * paccGrads, _Inout_ float * paccUpdates, - float decay, float cond, int crow, int ccol) -{ - const int * pposMin = pposY + iposMinY; - const int * pposLim = pposY + iposLimY; - const float * pxLim = px + crow; - const float * py = pvaluesY - posMinY; - float * pm0 = pmat - posMinY; - float * pag0 = paccGrads - posMinY; - float * pau0 = paccUpdates - posMinY; - - __m128 dec = _mm_set1_ps(decay); - __m128 decc = _mm_set1_ps(1 - decay); - __m128 c = _mm_set1_ps(cond); - for (; px < pxLim; px += 4, pm0 += 4 * ccol, pag0 += 4 * ccol, pau0 += 4 * ccol) - { - float * pm1 = pm0 + ccol; - float * pm2 = pm1 + ccol; - float * pm3 = pm2 + ccol; - - float * pag1 = pag0 + ccol; - float * pag2 = pag1 + ccol; - float * pag3 = pag2 + ccol; - - float * pau1 = pau0 + ccol; - float * pau2 = pau1 + ccol; - float * pau3 = pau2 + ccol; - - __m128 x1 = _mm_load_ps(px); - - for (const int * ppos = pposMin; ppos < pposLim; ppos++) - { - int col = *ppos; - __m128 x2 = _mm_set1_ps(py[col]); - __m128 ag = _mm_setr_ps(pag0[col], pag1[col], pag2[col], pag3[col]); - __m128 au = _mm_setr_ps(pau0[col], pau1[col], pau2[col], pau3[col]); - __m128 coef = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); - x2 = _mm_mul_ps(x2, x1); - - UpdateAdadelta(coef, ag, au, x2, dec, decc, c); - - _mm_store_ss(pm0 + col, coef); coef = _rotate(coef); - _mm_store_ss(pm1 + col, coef); coef = _rotate(coef); - _mm_store_ss(pm2 + col, coef); coef = _rotate(coef); - _mm_store_ss(pm3 + col, coef); - - _mm_store_ss(pag0 + col, ag); ag = _rotate(ag); - _mm_store_ss(pag1 + col, ag); ag = _rotate(ag); - _mm_store_ss(pag2 + col, ag); ag = _rotate(ag); - _mm_store_ss(pag3 + col, ag); - - _mm_store_ss(pau0 + col, au); au = _rotate(au); - _mm_store_ss(pau1 + col, au); au = _rotate(au); - _mm_store_ss(pau2 + col, au); au = _rotate(au); - _mm_store_ss(pau3 + col, au); - } - } -} - -// pd[i] += a -EXPORT_API(void) AddScalarU(float a, _Inout_ float * pd, int c) -{ - float * pdLim = pd + c; - - __m128 x1 = _mm_set1_ps(a); - for (; pd + 4 <= pdLim; pd += 4) - { - __m128 x2 = _mm_loadu_ps(pd); - x2 = _mm_add_ps(x2, x1); - _mm_storeu_ps(pd, x2); - } - - for (; pd < pdLim; pd++) - { - __m128 x2 = _mm_load_ss(pd); - x2 = _mm_add_ss(x2, x1); - _mm_store_ss(pd, x2); - } -} - -EXPORT_API(void) Scale(float a, _Inout_ float * pd, int c) -{ - __m128 x1 = _mm_set1_ps(a); - - if (c < 4) - { - switch(c) - { - case 3: pd[2] *= a; - case 2: pd[1] *= a; - case 1: pd[0] *= a; - } - return; - } - - uintptr_t address = (uintptr_t)(pd); - uintptr_t misalignment = address % 16; - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - remainder = c % 4; - - for (const float* pEnd = pd + (c - remainder); pd < pEnd; pd += 4) - { - __m128 x2 = _mm_loadu_ps(pd); - x2 = _mm_mul_ps(x1, x2); - _mm_storeu_ps(pd, x2); - } - } - else - { - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - __m128 result = _mm_loadu_ps(pd); - - __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); - __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + ((4 - 
misalignment) * 4)); - - __m128 temp = _mm_and_ps(result, leadingMask); - result = _mm_and_ps(result, trailingMask); - - temp = _mm_mul_ps(temp, x1); - result = _mm_or_ps(temp, result); - - _mm_storeu_ps(pd, result); - - pd += misalignment; - c -= misalignment; - } - - if (c > 3) - { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = c % 4; - for (const float* pEnd = pd + (c - remainder); pd < pEnd; pd += 4) - { - __m128 x2 = _mm_load_ps(pd); - x2 = _mm_mul_ps(x1, x2); - _mm_storeu_ps(pd, x2); - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not - // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. - remainder = c; - } - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pd -= (4 - remainder); - __m128 result = _mm_loadu_ps(pd); - - __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); - - __m128 temp = _mm_and_ps(result, trailingMask); - result = _mm_and_ps(result, leadingMask); - - temp = _mm_mul_ps(temp, x1); - result = _mm_or_ps(temp, result); - - _mm_storeu_ps(pd, result); - } -} - -EXPORT_API(void) ScaleSrcU(float a, _In_ const float * ps, _Inout_ float * pd, int c) -{ - float * pdLim = pd + c; - - __m128 x1 = _mm_set1_ps(a); - for (; pd + 4 <= pdLim; pd += 4, ps += 4) - { - __m128 x2 = _mm_loadu_ps(ps); - x2 = _mm_mul_ps(x2, x1); - _mm_storeu_ps(pd, x2); - } - - for (; pd < pdLim; pd++, ps++) - { - __m128 x2 = _mm_load_ss(ps); - x2 = _mm_mul_ss(x2, x1); - _mm_store_ss(pd, x2); - } -} - -// pd[i] = a * (pd[i] + b) -EXPORT_API(void) ScaleAddU(float a, float b, _Inout_ float * pd, int c) -{ - float * pdLim = pd + c; - - __m128 x1 = _mm_set1_ps(a); - __m128 x2 = _mm_set1_ps(b); - for (; pd + 4 <= pdLim; pd += 4) - { - __m128 x3 = _mm_loadu_ps(pd); - x3 = _mm_add_ps(x3, x2); - x3 = _mm_mul_ps(x3, x1); - _mm_storeu_ps(pd, x3); - } - - for (; pd < pdLim; pd++) - { - __m128 x3 = _mm_load_ss(pd); - x3 = _mm_add_ss(x3, x2); - x3 = _mm_mul_ss(x3, x1); - _mm_store_ss(pd, x3); - } -} - -EXPORT_API(void) ScaleMaxNormA(float maxNorm, _Inout_ float * pmat, int crow, int ccol) -{ - float * pm = pmat; - float maxNormSq = maxNorm * maxNorm; - __m128 m = _mm_set1_ps(maxNorm); - for (int irow = 0; irow < crow; irow++) - { - __m128 rowNorm = _mm_set1_ps(0); - float * pms = pm; - float * pmLim = pm + ccol; - for (; pm < pmLim; pm += 4) - { - __m128 x1 = _mm_load_ps(pm); - x1 = _mm_mul_ps(x1, x1); - rowNorm = _mm_add_ps(x1, rowNorm); - } - rowNorm = _mm_hadd_ps(rowNorm, rowNorm); - rowNorm = _mm_hadd_ps(rowNorm, rowNorm); - float rowNormRes = _mm_cvtss_f32(rowNorm); - if (rowNormRes > maxNormSq) - { - __m128 scale = _mm_set1_ps(rowNormRes); -#if 0 - // REVIEW: this is faster but it uses approximation so results differ significantly from CLR. 
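// (_mm_rsqrt_ps is only guaranteed to roughly 12 bits of relative precision,
// which is why the exact sqrt/div path under #else is the one compiled in.)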
- scale = _mm_rsqrt_ps(scale); - scale = _mm_mul_ps(scale, m); -#else - scale = _mm_sqrt_ps(scale); - scale = _mm_div_ps(m, scale); -#endif - for (pm = pms; pm < pmLim; pm += 4) - { - __m128 x1 = _mm_load_ps(pm); - x1 = _mm_mul_ps(x1, scale); - _mm_store_ps(pm, x1); - } - } - } -} - -EXPORT_API(void) ScaleMaxNormTranU(float maxNorm, _Inout_ float * pmat, int crow, int ccol) -{ - for (int icol = 0; icol < ccol; icol++) - { - float * pm = pmat + icol; - float rowNorm = 0; - for (int irow = 0; irow < crow; irow++) - { - rowNorm += *pm * *pm; - pm += ccol; - } - if (rowNorm > maxNorm * maxNorm) - { - float scale = maxNorm / sqrtf(rowNorm); - pm = pmat + icol; - for (int irow = 0; irow < crow; irow++) - { - *pm *= scale; - pm += ccol; - } - } - } -} - -// Sparse matrix. -EXPORT_API(void) ScaleMaxNormRU(float maxNorm, _In_ const int * pstarts, _Inout_ float * pmat, int crow) -{ - for (int irow = 0; irow < crow; irow++) - { - float rowNorm = 0; - for (int idx = pstarts[irow]; idx < pstarts[irow + 1]; idx++) - { - rowNorm += pmat[idx] * pmat[idx]; - } - if (rowNorm > maxNorm * maxNorm) - { - float scale = maxNorm / sqrtf(rowNorm); - for (int idx = pstarts[irow]; idx < pstarts[irow + 1]; idx++) - { - pmat[idx] *= scale; - } - } - } -} - -// Convolution. -EXPORT_API(void) ScaleMaxNormCU(float maxNorm, int kernCount, int kernSize, _Inout_ float * pmat) -{ - float * pm = pmat; - for (int irow = 0; irow < kernCount; irow++) - { - float rowNorm = 0; - for (int icol = 0; icol < kernSize; icol++) - { - rowNorm += *pm * *pm; - pm++; - } - if (rowNorm > maxNorm * maxNorm) - { - float scale = maxNorm / sqrtf(rowNorm); - pm -= kernSize; - for (int icol = 0; icol < kernSize; icol++) - { - *pm *= scale; - pm++; - } - } - // Skip bias. - pm++; - } -} - -EXPORT_API(void) AddScaleA(float a, _In_ const float * ps, _Inout_ float * pd, int c) +// pd[i] = a * (pd[i] + b) +EXPORT_API(void) ScaleAddU(float a, float b, _Inout_ float * pd, int c) { float * pdLim = pd + c; __m128 x1 = _mm_set1_ps(a); - for (; pd < pdLim; pd += 4, ps += 4) + __m128 x2 = _mm_set1_ps(b); + for (; pd + 4 <= pdLim; pd += 4) { - __m128 x2 = _mm_load_ps(ps); - __m128 x3 = _mm_load_ps(pd); - x2 = _mm_mul_ps(x1, x2); - x3 = _mm_add_ps(x2, x3); - _mm_store_ps(pd, x3); + __m128 x3 = _mm_loadu_ps(pd); + x3 = _mm_add_ps(x3, x2); + x3 = _mm_mul_ps(x3, x1); + _mm_storeu_ps(pd, x3); + } + + for (; pd < pdLim; pd++) + { + __m128 x3 = _mm_load_ss(pd); + x3 = _mm_add_ss(x3, x2); + x3 = _mm_mul_ss(x3, x1); + _mm_store_ss(pd, x3); } } @@ -2047,97 +845,6 @@ EXPORT_API(void) AddScaleSU(float a, _In_ const float * ps, _In_ const int * pi, pd[*pi] += a * *ps; } -EXPORT_API(void) AddScaleMomA(float a, _In_ const float * ps, _Inout_ float * pd, float momentum, _Inout_ float * pe, int c) -{ - float * pdLim = pd + c; - - __m128 x0 = _mm_set1_ps(momentum); - __m128 x1 = _mm_set1_ps(a); - for (; pd < pdLim; pd += 4, pe += 4, ps += 4) - { - __m128 x2 = _mm_load_ps(ps); - __m128 x3 = _mm_load_ps(pe); - __m128 x4 = _mm_load_ps(pd); - x2 = _mm_mul_ps(x1, x2); - x3 = _mm_mul_ps(x0, x3); - x3 = _mm_add_ps(x2, x3); - x4 = _mm_add_ps(x3, x4); - _mm_store_ps(pe, x3); - _mm_store_ps(pd, x4); - } -} - -EXPORT_API(void) AddScaleGradA(_In_ const float * ps, _Inout_ float * pd, _Inout_ float * paccGrads, _Inout_ float * paccUpdates, - float decay, float cond, int c) -{ - float * pdLim = pd + c; - - __m128 dec = _mm_set1_ps(decay); - __m128 decc = _mm_set1_ps(1 - decay); - __m128 cnd = _mm_set1_ps(cond); - for (; pd < pdLim; pd += 4, ps += 4, paccGrads += 4, paccUpdates += 4) - { - 
__m128 g = _mm_load_ps(ps); - __m128 ag = _mm_load_ps(paccGrads); - __m128 au = _mm_load_ps(paccUpdates); - __m128 coef = _mm_load_ps(pd); - - UpdateAdadelta(coef, ag, au, g, dec, decc, cnd); - - _mm_store_ps(pd, coef); - _mm_store_ps(paccGrads, ag); - _mm_store_ps(paccUpdates, au); - } -} - -EXPORT_API(void) AddScaleMultiA(int count, _In_ const float * ps, _Inout_ float * pd, _Inout_ float * paccGrads, _Inout_ float * paccUpdates, - float decay, float cond, int size) -{ - if (1 == count) - AddScaleGradA(ps, pd, paccGrads, paccUpdates, decay, cond, size); - else - { - float * pdLim = pd + size; - - __m128 dec = _mm_set1_ps(decay); - __m128 decc = _mm_set1_ps(1 - decay); - __m128 cnd = _mm_set1_ps(cond); - for (; pd < pdLim; pd += 4, ps += 4, paccGrads += 4, paccUpdates += 4) - { - __m128 g = _mm_set1_ps(0); - const float * ps1 = ps; - // REVIEW: unroll? - for (int i = 0; i < count; i++, ps1 += size) - { - __m128 x1 = _mm_load_ps(ps1); - g = _mm_add_ps(x1, g); - } - __m128 ag = _mm_load_ps(paccGrads); - __m128 au = _mm_load_ps(paccUpdates); - __m128 coef = _mm_load_ps(pd); - - UpdateAdadelta(coef, ag, au, g, dec, decc, cnd); - - _mm_store_ps(pd, coef); - _mm_store_ps(paccGrads, ag); - _mm_store_ps(paccUpdates, au); - } - } -} - -EXPORT_API(void) AddA(_In_ const float * ps, _Inout_ float * pd, int c) -{ - float * pdLim = pd + c; - - for (; pd < pdLim; pd += 4, ps += 4) - { - __m128 x1 = _mm_load_ps(ps); - __m128 x2 = _mm_load_ps(pd); - x2 = _mm_add_ps(x1, x2); - _mm_store_ps(pd, x2); - } -} - EXPORT_API(void) AddU(_In_ const float * ps, _Inout_ float * pd, int c) { float * pdLim = pd + c; @@ -2196,36 +903,6 @@ EXPORT_API(void) MulElementWiseU(_In_ const float * ps1, _In_ const float * ps2, } } -EXPORT_API(void) MulElementWiseSU(_In_ const float * ps1, _In_ const float * ps2, _In_ const int * pi, _Inout_ float * pd, int c) -{ - const int * piLim = pi + c; - - for (; pi + 4 <= piLim; pi += 4) - { - __m128 x1 = _load4(ps1, pi); - __m128 x2 = _load4(ps2, pi); - x2 = _mm_mul_ps(x1, x2); - _store4(x2, pd, pi); - } - - for (; pi < piLim; pi++) - pd[*pi] = ps1[*pi] * ps2[*pi]; -} - -EXPORT_API(float) SumA(const float * ps, int c) -{ - const float * psLim = ps + c; - - __m128 res = _mm_setzero_ps(); - for (; ps < psLim; ps += 4) - res = _mm_add_ps(res, _mm_load_ps(ps)); - - res = _mm_hadd_ps(res, res); - res = _mm_hadd_ps(res, res); - - return _mm_cvtss_f32(res); -} - EXPORT_API(float) SumU(const float * ps, int c) { const float * psLim = ps + c; @@ -2448,449 +1125,6 @@ EXPORT_API(float) Dist2(const float * px, const float * py, int c) return norm2; } -// This is modeled after double-based SSE code - -// 1 / ln(2). -const float RecipLn2 = (float)1.44269504088896340735992468100; - -// Used for computing a 4th degree polynomial approximation of e^x. -const float Coef1 = (float)0.013555747234814917704030793; -const float Coef2 = (float)0.065588116243247810171479524; -const float Coef3 = (float)0.3069678791803394491901401; - -const float ExpInf = 128; -const int ExpBias = 127; -const int ExpShift = 23; - -float ExpFast(float arg) -{ - bool neg = false; - if (arg < 0) - { - arg = -arg; - neg = true; - } - - arg *= RecipLn2; - if (arg >= ExpInf) - return neg ? 0.0f : std::numeric_limits::infinity(); - - int exp = (int)arg; - arg -= exp; - exp += ExpBias; - exp <<= ExpShift; - - float res = (1 + arg) + (arg - 1) * arg * ((Coef1 * arg + Coef2) * arg + Coef3); - res *= *(float *)&exp; - - if (neg) - res = 1 / res; - return res; -} - -// Implements a fast approximation of sigmoid/tanh. 
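Before the vectorized body, it may help to see the scalar relationship it exploits, built directly on the ExpFast helper above; the tanh path folds the required doubling of the argument into the exponent scaling by doubling RecipLn2. This is only a sketch and the helper names are illustrative.

```cpp
// Scalar sketch of the fast sigmoid/tanh built on ExpFast (illustrative only).
inline float SigmoidFast(float x)
{
    return 1.0f / (1.0f + ExpFast(-x));
}

inline float TanhFast(float x)
{
    // tanh(x) = 2 * sigmoid(2x) - 1; the SSE code below achieves the doubled
    // argument by doubling the RecipLn2 constant (c0 = c0 + c0).
    return 2.0f * SigmoidFast(2.0f * x) - 1.0f;
}
```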
-template -void ApplySigmoidCoreA(_In_ const float * ps, _Inout_ float * pd, int c) -{ - float * pdLim = pd + c; - - __m128 cSign = _mm_set1_ps(-0.0f); - __m128 cZero = _mm_set1_ps(0.0f); - __m128 cOne = _mm_set1_ps(1.0f); - - __m128 cMax = _mm_set1_ps(ExpInf); - __m128i cBias = _mm_set1_epi32(ExpBias); - __m128 c0 = _mm_set1_ps(RecipLn2); - __m128 c1 = _mm_set1_ps(Coef1); - __m128 c2 = _mm_set1_ps(Coef2); - __m128 c3 = _mm_set1_ps(Coef3); - - if (isTanh) - c0 = _mm_add_ps(c0, c0); - - for (; pd < pdLim; ps += 4, pd += 4) - { - // Get the argument, capture its sign and take its absolute value. - __m128 xArg = _mm_load_ps(ps); - // maskNaN is set to zero if xArg is not NaN and set equal to xArg otherwise. - __m128 maskNaN = _mm_and_ps(_mm_cmpneq_ps(xArg, xArg), xArg); - __m128 xSign = _mm_and_ps(xArg, cSign); - xArg = _mm_xor_ps(xArg, xSign); - - // Multiply by 1/ln(2) and check for out of bounds. - xArg = _mm_mul_ps(xArg, c0); - __m128 xGood = _mm_cmplt_ps(xArg, cMax); - xArg = _mm_and_ps(xArg, xGood); - - // Get the integer and fractional parts. - __m128i xInt = _mm_cvttps_epi32(xArg); - xArg = _mm_sub_ps(xArg, _mm_cvtepi32_ps(xInt)); - - // Add the exponent bias to xInt, then convert to a floating point - // power of two by shifting past the mantissa bits. - xInt = _mm_add_epi32(xInt, cBias); - xInt = _mm_slli_epi32(xInt, ExpShift); - - // Approximate 2 raised to the fractional part. - // (1 + f) + (f - 1) * f * ((c1 * f + c2) * f + c3) - - // x1 = (c1 * f + c2) * f + c3 - __m128 x1 = _mm_mul_ps(c1, xArg); - x1 = _mm_add_ps(x1, c2); - x1 = _mm_mul_ps(x1, xArg); - x1 = _mm_add_ps(x1, c3); - - // x2 = f * (f - 1) - __m128 x2 = _mm_sub_ps(xArg, cOne); - x2 = _mm_mul_ps(xArg, x2); - - // Add (1 + f). Note that for tanh, we only add f, so we are approximating - // 2^f - 1. This is necessary to preserve precision near zero. In particular, - // near zero, tanh(x) ~ x. - x1 = _mm_mul_ps(x2, x1); - if (!isTanh) - xArg = _mm_add_ps(xArg, cOne); - x1 = _mm_add_ps(xArg, x1); - - // Multiply by 2^n, where n is the integer part. - __m128 x3 = _mm_castsi128_ps(xInt); - x1 = _mm_mul_ps(x1, x3); - - if (!isTanh) - { - // Add 1, and take the reciprocal. - x1 = _mm_add_ps(x1, cOne); - x1 = _mm_div_ps(cOne, x1); - - // Deal with out of bounds. - x1 = _mm_and_ps(x1, xGood); - // If the input was NaN, xGood is zero, so x1 is zero. So can simply or in maskNaN. - x1 = _mm_or_ps(x1, maskNaN); - - // Deal with the sign. Set: - // * x2 = x1 if xSign is -0 (0x80000000) - // * x2 = 1 - x1 if xSign is +0 (0x00000000). - x1 = _mm_or_ps(x1, xSign); - x2 = _mm_or_ps(xSign, cOne); - x2 = _mm_max_ps(x2, cZero); - x2 = _mm_sub_ps(x2, x1); - } - else - { - // [2^n(2^f - 1) + (2^n - 1)] / [2^n(2^f - 1) + (2^n + 1)] - x2 = _mm_add_ps(x1, _mm_sub_ps(x3, cOne)); - x1 = _mm_add_ps(x1, _mm_add_ps(x3, cOne)); - x2 = _mm_div_ps(x2, x1); - - // Deal with out of bounds: x2 = (x2 & xGood) | ((1 + maskNaN) & ~xGood) - x2 = _mm_and_ps(x2, xGood); - x1 = _mm_andnot_ps(xGood, _mm_add_ps(maskNaN, cOne)); - x2 = _mm_or_ps(x2, x1); - - // Deal with the sign. - x2 = _mm_or_ps(x2, xSign); - } - - _mm_store_ps(pd, x2); - } - - // If we overshot, back fill with zero! Since tanh(0) = 0, we only need to do this for sigmoid. - if (!isTanh) - { - while (pd > pdLim) - *--pd = 0.0f; - } -} - -EXPORT_API(void) ApplySigmoidA(_In_ const float * ps, _Inout_ float * pd, int c) -{ - ApplySigmoidCoreA(ps, pd, c); -} - -EXPORT_API(void) ApplySoftMaxU(_In_ const float * ps, _Inout_ float * pd, int c) -{ - // REVIEW: Use SSE - do 4 at a time. 
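// The running maximum is subtracted before exponentiating because softmax is
// shift invariant: softmax_i(x) = exp(x_i - m) / Sum_j exp(x_j - m) for any m.
// Choosing m = max_j x_j keeps every ExpFast argument non-positive, so the
// exponentials cannot overflow.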
- - const float * psLim = ps + c; - - // Compute max output. - float maxOut = -std::numeric_limits::infinity(); - for (const float * p = ps; p < psLim; p++) - { - float v = *p; - if (maxOut < v) - maxOut = v; - } - - // Compute exp and sum. - float sum = 0; - const float * p = ps; - for (float * q = pd; p < psLim; p++, q++) - { - float v = ExpFast(*p - maxOut); - *q = v; - sum += v; - } - - // Normalize. - for (float * q = pd; q < pd + c; q++) - *q /= sum; -} - -EXPORT_API(void) ApplyRectifiedLinearA(_In_ const float * ps, _Inout_ float * pd, int c) -{ - const float * psLim = ps + c; - - __m128 cZero = _mm_set1_ps(0.0f); - for (; ps < psLim; ps += 4, pd += 4) - { - __m128 x1 = _mm_load_ps(ps); - x1 = _mm_max_ps(x1, cZero); - _mm_store_ps(pd, x1); - } -} - -EXPORT_API(void) ApplySquareA(_In_ const float * ps, _Inout_ float * pd, int c) -{ - const float * psLim = ps + c; - - for (; ps < psLim; ps += 4, pd += 4) - { - __m128 x1 = _mm_load_ps(ps); - x1 = _mm_mul_ps(x1, x1); - _mm_store_ps(pd, x1); - } -} - -EXPORT_API(void) ApplySqrtA(_In_ const float * ps, _Inout_ float * pd, int c) -{ - const float * psLim = ps + c; - - __m128 cZero = _mm_set1_ps(0.0f); - for (; ps < psLim; ps += 4, pd += 4) - { - __m128 x1 = _mm_load_ps(ps); - x1 = _mm_max_ps(x1, cZero); - x1 = _mm_sqrt_ps(x1); - _mm_store_ps(pd, x1); - } -} - -EXPORT_API(void) ApplySoftRectifiedLinearU(_In_ const float * ps, _Inout_ float * pd, int c) -{ - const float * psLim = ps + c; - - // Apply: f(x) = log(1 + e^x). To avoid overflow for large x, we use the identity: f(x) = x + f(-x). - // REVIEW: Should we implement a "LogFast"? - // REVIEW: Do 4 at a time. - const float * p = ps; - for (float * q = pd; p < psLim; p++, q++) - { - float x = *p; - if (x > 0) - *q = x + log(1 + ExpFast(-x)); - else - *q = log(1 + ExpFast(x)); - } -} - -EXPORT_API(void) ApplyAbsA(_In_ const float * ps, _Inout_ float * pd, int c) -{ - const float * psLim = ps + c; - - __m128 mask = _mm_castsi128_ps(_mm_set1_epi32(0x7FFFFFFF)); - for (; ps < psLim; ps += 4, pd += 4) - { - __m128 x1 = _mm_load_ps(ps); - x1 = _mm_and_ps(x1, mask); - _mm_store_ps(pd, x1); - } -} - -EXPORT_API(void) ApplyTanhA(_In_ const float * ps, _Inout_ float * pd, int c) -{ - ApplySigmoidCoreA(ps, pd, c); -} - -EXPORT_API(void) ApplyBoundedRectifiedLinearA(_In_ const float * ps, _Inout_ float * pd, int c) -{ - const float * psLim = ps + c; - - __m128 cZero = _mm_set1_ps(0.0f); - __m128 cOne = _mm_set1_ps(1.0f); - for (; ps < psLim; ps += 4, pd += 4) - { - __m128 x1 = _mm_load_ps(ps); - x1 = _mm_max_ps(x1, cZero); - x1 = _mm_min_ps(x1, cOne); - _mm_store_ps(pd, x1); - } -} - -EXPORT_API(void) ApplySigmoidDerivativeA(_In_ const float * pv, _Inout_ float * pg, int c) -{ - float * pgLim = pg + c; - - // pg[i] *= pv[i] * (1 - pv[i]) - __m128 cOne = _mm_set1_ps(1.0f); - for (; pg < pgLim; pg += 4, pv += 4) - { - __m128 x1 = _mm_load_ps(pv); - __m128 x2 = _mm_load_ps(pg); - __m128 x3 = _mm_sub_ps(cOne, x1); - x1 = _mm_mul_ps(x1, x3); - x2 = _mm_mul_ps(x2, x1); - _mm_store_ps(pg, x2); - } -} - -EXPORT_API(void) ApplyRectifiedLinearDerivativeA(_In_ const float * pv, _Inout_ float * pg, int c) -{ - float * pgLim = pg + c; - - __m128 cZero = _mm_set1_ps(0.0f); - for (; pg < pgLim; pg += 4, pv += 4) - { - __m128 x1 = _mm_load_ps(pv); - __m128 x2 = _mm_load_ps(pg); - x1 = _mm_cmpgt_ps(x1, cZero); - x2 = _mm_and_ps(x2, x1); - _mm_store_ps(pg, x2); - } -} - -EXPORT_API(void) ApplySquareDerivativeA(_In_ const float * px, _In_opt_ const float * py, _Inout_ float * pg, int c, bool drop) -{ - float * pgLim = 
pg + c; - - if (drop) - { - __m128 cZero = _mm_set1_ps(0.0f); - for (; pg < pgLim; pg += 4, px += 4, py += 4) - { - __m128 x0 = _mm_cmpgt_ps(_mm_load_ps(py), cZero); - __m128 x1 = _mm_load_ps(px); - __m128 x2 = _mm_load_ps(pg); - x1 = _mm_add_ps(x1, x1); - x2 = _mm_mul_ps(x2, x1); - x2 = _mm_and_ps(x2, x0); - _mm_store_ps(pg, x2); - } - } - else - { - for (; pg < pgLim; pg += 4, px += 4) - { - __m128 x1 = _mm_load_ps(px); - __m128 x2 = _mm_load_ps(pg); - x1 = _mm_add_ps(x1, x1); - x2 = _mm_mul_ps(x2, x1); - _mm_store_ps(pg, x2); - } - } -} - -EXPORT_API(void) ApplySqrtDerivativeA(_In_ const float * pv, _Inout_ float * pg, int c) -{ - float * pgLim = pg + c; - static const float smallValue = 1e-10F; - - __m128 cZero = _mm_set1_ps(0.0f); - __m128 cSmall = _mm_set1_ps(smallValue); - for (; pg < pgLim; pg += 4, pv += 4) - { - __m128 x1 = _mm_load_ps(pv); - __m128 x2 = _mm_load_ps(pg); - __m128 x3 = _mm_cmpgt_ps(x1, cZero); - x1 = _mm_max_ps(x1, cSmall); - x1 = _mm_add_ps(x1, x1); - x2 = _mm_and_ps(x2, x3); - x2 = _mm_div_ps(x2, x1); - _mm_store_ps(pg, x2); - } -} - -EXPORT_API(void) ApplySoftRectifiedLinearDerivativeU(_In_opt_ const float * px, _In_ const float * py, _Inout_ float * pg, int c) -{ - UNUSED(px); - - float * pgLim = pg + c; - - // Use the identity: y' = 1 - e^(-y). This has a few nice properties: - // * If x is large enough that x == y (after rounding), we'll compute y' as 1. - // * If x is small enough that y == 0 (after rounding), we'll compute y' as 0. - // * If y is zero because of drop out, we'll compute y' as 0. - // REVIEW: Do 4 at a time. - for (; pg < pgLim; pg++, py++) - *pg *= 1 - ExpFast(-*py); -} - -EXPORT_API(void) ApplyAbsDerivativeA(_In_ const float * px, _In_opt_ const float * py, _Inout_ float * pg, int c, bool drop) -{ - float * pgLim = pg + c; - - __m128 cZero = _mm_set1_ps(0.0f); - __m128 cSign = _mm_set1_ps(-0.0f); - if (drop) - { - for (; pg < pgLim; pg += 4, px += 4, py += 4) - { - __m128 x1 = _mm_and_ps(_mm_load_ps(px), cSign); - __m128 x2 = _mm_cmpgt_ps(_mm_load_ps(py), cZero); - __m128 x3 = _mm_load_ps(pg); - x3 = _mm_xor_ps(x3, x1); - x3 = _mm_and_ps(x3, x2); - _mm_store_ps(pg, x3); - } - } - else - { - for (; pg < pgLim; pg += 4, px += 4) - { - __m128 x0 = _mm_load_ps(px); - __m128 x1 = _mm_and_ps(x0, cSign); - __m128 x2 = _mm_cmpneq_ps(x0, cZero); - __m128 x3 = _mm_load_ps(pg); - x3 = _mm_xor_ps(x3, x1); - x3 = _mm_and_ps(x3, x2); - _mm_store_ps(pg, x3); - } - } -} - -EXPORT_API(void) ApplyTanhDerivativeA(_In_ const float * pv, _Inout_ float * pg, int c) -{ - float * pgLim = pg + c; - - // pg[i] *= 1 - pv[i] * pv[i] - __m128 cOne = _mm_set1_ps(1.0f); - for (; pg < pgLim; pg += 4, pv += 4) - { - __m128 x1 = _mm_load_ps(pv); - __m128 x2 = _mm_load_ps(pg); - x1 = _mm_mul_ps(x1, x1); - x1 = _mm_sub_ps(cOne, x1); - x2 = _mm_mul_ps(x2, x1); - _mm_store_ps(pg, x2); - } -} - -EXPORT_API(void) ApplyBoundedRectifiedLinearDerivativeA(_In_ const float * pv, _Inout_ float * pg, int c) -{ - float * pgLim = pg + c; - - __m128 cZero = _mm_set1_ps(0.0f); - __m128 cOne = _mm_set1_ps(1.0f); - for (; pg < pgLim; pg += 4, pv += 4) - { - __m128 x1 = _mm_load_ps(pv); - __m128 x2 = _mm_load_ps(pg); - x2 = _mm_and_ps(x2, _mm_cmpgt_ps(x1, cZero)); - x2 = _mm_and_ps(x2, _mm_cmplt_ps(x1, cOne)); - _mm_store_ps(pg, x2); - } -} - EXPORT_API(void) ZeroItemsU(_Inout_ float * pd, int c, _In_ const int * pindices, int cindices) { DEBUG_ONLY(c); @@ -2992,69 +1226,3 @@ EXPORT_API(void) SdcaL1UpdateSU(float primalUpdate, _In_ const float * ps, _In_ pd2[i] = std::abs(d1) > threshold ? 
(d1 > 0 ? d1 - threshold : d1 + threshold) : 0; } } - -EXPORT_API(void) ScaleAdadeltaU(_Inout_ float * mat, _Inout_ float * accGrads, _Inout_ float * accUpdates, float decay, float cond, _In_ const float * grads, int size) -{ - float * pm = mat; - float * pmLim = pm + size; - float * pag = accGrads; - float * pau = accUpdates; - const float * pg = grads; - - __m128 dec = _mm_set1_ps(decay); - __m128 decc = _mm_set1_ps(1 - decay); - __m128 c = _mm_set1_ps(cond); - - for (; pm + 4 <= pmLim; pm += 4, pag += 4, pau += 4, pg += 4) - { - __m128 g = _mm_loadu_ps(pg); - __m128 ag = _mm_loadu_ps(pag); - __m128 au = _mm_loadu_ps(pau); - __m128 coef = _mm_loadu_ps(pm); - - UpdateAdadelta(coef, ag, au, g, dec, decc, c); - - _mm_storeu_ps(pm, coef); - _mm_storeu_ps(pag, ag); - _mm_storeu_ps(pau, au); - } - - for (; pm < pmLim; pm++, pag++, pau++, pg++) - { - float g = *pg; - float accGrad = decay * *pag + (1 - decay) * g * g; - float accUpd = *pau; - - float newUpd = sqrtf((accUpd + cond) / (accGrad + cond)) * g; - *pm += newUpd; - *pag = accGrad; - *pau = decay * accUpd + (1 - decay) * newUpd * newUpd; - } -} - -EXPORT_API(void) ScaleAdadeltaA(_Inout_ float * mat, _Inout_ float * accGrads, _Inout_ float * accUpdates, float decay, float cond, _Inout_ float * grads, int size) -{ - float * pm = mat; - float * pmLim = pm + size; - float * pag = accGrads; - float * pau = accUpdates; - float * pg = grads; - - __m128 dec = _mm_set1_ps(decay); - __m128 decc = _mm_set1_ps(1 - decay); - __m128 c = _mm_set1_ps(cond); - - for (; pm < pmLim; pm += 4, pag += 4, pau += 4, pg += 4) - { - __m128 g = _mm_load_ps(pg); - __m128 ag = _mm_load_ps(pag); - __m128 au = _mm_load_ps(pau); - __m128 coef = _mm_load_ps(pm); - - UpdateAdadelta(coef, ag, au, g, dec, decc, c); - - _mm_store_ps(pm, coef); - _mm_store_ps(pag, ag); - _mm_store_ps(pau, au); - } -} diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs index 7d7a851a48..e9282e79c3 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs @@ -108,13 +108,19 @@ public void SdcaL1UpdateU() [BenchmarkCategory("Fma")] public void SdcaL1UpdateSU() => AvxIntrinsics.SdcaL1UpdateSU(DefaultScale, IndexLength, src, idx, DefaultScale, dst, result); + [Benchmark] [BenchmarkCategory("Fma")] - public void MatMulX() - => AvxIntrinsics.MatMulX(src, src1, dst, 1000, 1000); + public void MatMul() + => AvxIntrinsics.MatMul(src, src1, dst, 1000, 1000); + + [Benchmark] + public void MatMulTran() + => AvxIntrinsics.MatMulTran(src, src1, dst, 1000, 1000); [Benchmark] - public void MatMulTranX() - => AvxIntrinsics.MatMulTranX(src, src1, dst, 1000, 1000); + [BenchmarkCategory("Fma")] + public void MatMulP() + => AvxIntrinsics.MatMulP(src, matrixIdx, src1, 0, 0, MatrixIndexLength, dst, 1000, 1000); } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs index 07958c19ed..778ce56c93 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs @@ -250,5 +250,17 @@ public unsafe void MatMulTran() Thunk.MatMulTran(psrc1, psrc, pdst, 1000, 1000); } } + + [Benchmark] + public unsafe void MatMulP() + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + fixed (float* psrc1 = &src1[0]) + fixed (int* pidx = 
&matrixIdx[0]) + { + Thunk.MatMulP(psrc1, pidx, psrc, 0, 0, MatrixIndexLength, pdst, 1000, 1000); + } + } } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs index cffa0af94a..6726603b5a 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs @@ -17,12 +17,14 @@ public abstract class PerformanceTests protected const int IndexLength = 1000003; protected const int Length = 1000003; - + protected const int MatrixIndexLength = 100; + private const int DefaultSeed = 253421; protected const float DefaultScale = 1.11f; protected float[] src, dst, original, src1, src2, result; protected int[] idx; + protected int[] matrixIdx; private int _seed = DefaultSeed; @@ -67,6 +69,7 @@ public void Setup() original = new float[Length]; result = new float[Length]; idx = new int[IndexLength]; + matrixIdx = new int[MatrixIndexLength]; _seed = GetSeed(); Random rand = new Random(_seed); @@ -85,6 +88,11 @@ public void Setup() { idx[i] = rand.Next(0, Length); } + + for (int i = 0; i < MatrixIndexLength; i++) + { + matrixIdx[i] = rand.Next(0, 1000); + } } [GlobalCleanup] diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs index 8e94dabb96..fdb8e8adcc 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs @@ -100,11 +100,15 @@ public void SdcaL1UpdateSU() => SseIntrinsics.SdcaL1UpdateSU(DefaultScale, IndexLength, src, idx, DefaultScale, dst, result); [Benchmark] - public void MatMulX() + public void MatMul() => SseIntrinsics.MatMul(src, src1, dst, 1000, 1000); [Benchmark] - public void MatMulTranX() + public void MatMulTran() => SseIntrinsics.MatMulTran(src, src1, dst, 1000, 1000); + + [Benchmark] + public void MatMulP() + => SseIntrinsics.MatMulP(src, matrixIdx, src1, 0, 0, MatrixIndexLength, dst, 1000, 1000); } } diff --git a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs index 8ce7878f83..68e26ac66e 100644 --- a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs +++ b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs @@ -116,7 +116,7 @@ public void MatMulTranTest(int matTest, int srcTest, int dstTest, float[] expect [InlineData(0, 0, 0, new float[] { 38.25002f, 38.25002f, 38.25002f, 38.25002f, 38.25002f, 38.25002f, 38.25002f, 38.25002f })] [InlineData(1, 1, 0, new float[] { 910f, 2190f, 3470f, 4750f, 6030f, 7310f, 8590f, 9870f })] [InlineData(1, 0, 1, new float[] { 95f, 231f, 367f, 503f, 639f, 775f, 911f, 1047f, 1183f, 1319f, 1455f, 1591f, 1727f, 1863f, 1999f, 2135f })] - public void MatMulPATest(int matTest, int srcTest, int dstTest, float[] expected) + public void MatTimesSrcSparseTest(int matTest, int srcTest, int dstTest, float[] expected) { AlignedArray mat = _testMatrices[matTest]; AlignedArray src = _testSrcVectors[srcTest];