diff --git a/src/Microsoft.ML.CpuMath/Avx.cs b/src/Microsoft.ML.CpuMath/Avx.cs
deleted file mode 100644
index f7769b3295..0000000000
--- a/src/Microsoft.ML.CpuMath/Avx.cs
+++ /dev/null
@@ -1,1165 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-
-namespace Microsoft.ML.Runtime.Internal.CpuMath
-{
- /// <summary>
- /// Keep Avx.cs in sync with Sse.cs. When making changes to one, use BeyondCompare or a similar tool
- /// to view diffs and propagate appropriate changes to the other.
- /// </summary>
- public static class AvxUtils
- {
- public const int CbAlign = 32;
-
- private static bool Compat(AlignedArray a)
- {
- Contracts.AssertValue(a);
- Contracts.Assert(a.Size > 0);
- return a.CbAlign == CbAlign;
- }
-
- private static unsafe float* Ptr(AlignedArray a, float* p)
- {
- Contracts.AssertValue(a);
- float* q = p + a.GetBase((long)p);
- Contracts.Assert(((long)q & (CbAlign - 1)) == 0);
- return q;
- }
-
- public static bool CheckAvx()
- {
- return Thunk.ChkAvx();
- }
-
- public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun)
- {
- Contracts.Assert(Compat(mat));
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(mat.Size == dst.Size * src.Size);
-
- unsafe
- {
- fixed (float* pmat = &mat.Items[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- {
- if (!tran)
- {
- Contracts.Assert(0 <= crun && crun <= dst.Size);
- Thunk.MatMulX(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size);
- }
- else
- {
- Contracts.Assert(0 <= crun && crun <= src.Size);
- Thunk.MatMulTranX(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun);
- }
- }
- }
- }
-
- public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues,
- int posMin, int iposMin, int iposLim, AlignedArray dst, int crun)
- {
- Contracts.Assert(Compat(mat));
- Contracts.Assert(Compat(srcValues));
- Contracts.Assert(Compat(dst));
- Contracts.AssertValue(rgposSrc);
- Contracts.Assert(0 <= iposMin && iposMin <= iposLim && iposLim <= rgposSrc.Length);
- Contracts.Assert(mat.Size == dst.Size * srcValues.Size);
-
- if (iposMin >= iposLim)
- {
- dst.ZeroItems();
- return;
- }
- Contracts.AssertNonEmpty(rgposSrc);
- unsafe
- {
- fixed (float* pdst = &dst.Items[0])
- fixed (float* pmat = &mat.Items[0])
- fixed (float* psrc = &srcValues.Items[0])
- fixed (int* ppossrc = &rgposSrc[0])
- {
- Contracts.Assert(0 <= crun && crun <= dst.Size);
- Thunk.MatMulPX(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size);
- }
- }
- }
-
- public static void MatTimesSrc(bool add, int[] starts, int[] indices, float[] coefs,
- AlignedArray src, AlignedArray dst, int crow)
- {
- Contracts.AssertNonEmpty(starts);
- Contracts.Assert(starts.Length == crow + 1);
- Contracts.Assert(starts[0] == 0);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(starts[crow] == indices.Length);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(indices.Length == coefs.Length);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < crow && crow <= dst.Size);
- Contracts.Assert(crow * src.Size >= coefs.Length);
-
- unsafe
- {
- fixed (int* pstarts = &starts[0])
- fixed (int* pindices = &indices[0])
- fixed (float* pcoefs = &coefs[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.MatMulRX(add, pstarts, pindices, pcoefs, Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- }
-
- public static void MatTimesSrc(bool add, int[] mprowiv, int[] mprowcol,
- int[] mprowrun, int[] runs, float[] coefs,
- AlignedArray src, AlignedArray dst, int crow)
- {
- Contracts.AssertNonEmpty(mprowiv);
- Contracts.Assert(mprowiv.Length == crow);
- Contracts.AssertNonEmpty(mprowcol);
- Contracts.Assert(mprowcol.Length == crow);
- Contracts.Assert(mprowrun == null || mprowrun.Length == crow);
- Contracts.AssertNonEmpty(runs);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < crow && crow <= dst.Size);
-
- unsafe
- {
- fixed (int* pmprowiv = &mprowiv[0])
- fixed (int* pmprowcol = &mprowcol[0])
- fixed (int* pruns = &runs[0])
- fixed (float* pcoefs = &coefs[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- {
- if (mprowrun == null)
- {
- Thunk.MatMulCX(add, pmprowiv, pmprowcol, pruns, pcoefs,
- Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- else
- {
- fixed (int* pmprowrun = &mprowrun[0])
- {
- Thunk.MatMulDX(add, pmprowiv, pmprowcol, pmprowrun, pruns, pcoefs,
- Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- }
- }
- }
- }
-
- public static void MeanOfSrc(bool add, int[] mprowcol, int[] mprowindices,
- int[] indices, AlignedArray src, AlignedArray dst, int crow)
- {
- Contracts.AssertNonEmpty(mprowcol);
- Contracts.Assert(mprowcol.Length == crow);
- Contracts.Assert(mprowindices == null || mprowindices.Length == crow);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < crow && crow <= dst.Size);
-
- unsafe
- {
- fixed (int* pmprowcol = &mprowcol[0])
- fixed (int* pmprowindices = mprowindices)
- fixed (int* pindices = &indices[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- {
- // REVIEW: Implement using AVX
- Thunk.MeanU(add, pmprowcol, pmprowindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- }
- }
-
- public static void MaxOfSrc(bool add, int[] mprowcol, int[] mprowindices,
- int[] indices, AlignedArray src, AlignedArray dst, int crow)
- {
- Contracts.AssertNonEmpty(mprowcol);
- Contracts.Assert(mprowcol.Length == crow);
- Contracts.Assert(mprowindices == null || mprowindices.Length == crow);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < crow && crow <= dst.Size);
-
- unsafe
- {
- fixed (int* pmprowcol = &mprowcol[0])
- fixed (int* pmprowindices = mprowindices)
- fixed (int* pindices = &indices[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- {
- // REVIEW: Implement using AVX
- Thunk.MaxU(add, pmprowcol, pmprowindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- }
- }
-
- public static void RespNormOfSrc(bool add, float alpha, float beta, bool avgOverFullKernel, float offset,
- int[] mprowcol, int[] mprowindices, int[] indices,
- AlignedArray src, AlignedArray dst, int crow)
- {
- Contracts.AssertNonEmpty(mprowcol);
- Contracts.Assert(mprowcol.Length == crow);
- Contracts.Assert(mprowindices == null || mprowindices.Length == crow);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < crow && crow <= dst.Size);
-
- unsafe
- {
- fixed (int* pmprowcol = &mprowcol[0])
- fixed (int* pmprowindices = mprowindices)
- fixed (int* pindices = &indices[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- {
- // REVIEW: Implement using AVX
- Thunk.RespNormU(add, alpha, beta, avgOverFullKernel, offset, pmprowcol, pmprowindices, pindices,
- Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- }
- }
-
- public static void MatTranTimesSrc(bool add, int[] starts, int[] indices, float[] coefs,
- AlignedArray src, AlignedArray dst, int ccol)
- {
- Contracts.AssertNonEmpty(starts);
- Contracts.Assert(starts.Length == ccol + 1);
- Contracts.Assert(starts[0] == 0);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(starts[ccol] == indices.Length);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(indices.Length == coefs.Length);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < ccol && ccol <= src.Size);
- Contracts.Assert(dst.Size * ccol >= coefs.Length);
-
- unsafe
- {
- fixed (int* pstarts = &starts[0])
- fixed (int* pindices = &indices[0])
- fixed (float* pcoefs = &coefs[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.MatMulTranRX(add, pstarts, pindices, pcoefs, Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol);
- }
- }
-
- public static void MatTranTimesSrc(bool add, int[] mpcoliv, int[] mpcolrow, int[] mpcolrun,
- int[] runs, float[] coefs, AlignedArray src, AlignedArray dst, int ccol)
- {
- Contracts.AssertNonEmpty(mpcoliv);
- Contracts.Assert(mpcoliv.Length == ccol);
- Contracts.AssertNonEmpty(mpcolrow);
- Contracts.Assert(mpcolrow.Length == ccol);
- Contracts.AssertNonEmpty(runs);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(mpcolrun == null || mpcolrun.Length == ccol);
- Contracts.Assert(0 < ccol && ccol <= src.Size);
-
- unsafe
- {
- fixed (int* pmpcoliv = &mpcoliv[0])
- fixed (int* pmpcolrow = &mpcolrow[0])
- fixed (int* pruns = &runs[0])
- fixed (float* pcoefs = &coefs[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- {
- if (mpcolrun == null)
- {
- Thunk.MatMulTranCX(add, pmpcoliv, pmpcolrow, pruns, pcoefs,
- Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol);
- }
- else
- {
- fixed (int* pmpcolrun = &mpcolrun[0])
- {
- Thunk.MatMulTranDX(add, pmpcoliv, pmpcolrow, pmpcolrun, pruns, pcoefs,
- Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol);
- }
- }
- }
- }
- }
-
- public static void MeanBackOfSrc(bool add, int[] mpcolrow, int[] mpcolindices,
- int[] indices, AlignedArray src, AlignedArray dst, int ccol)
- {
- Contracts.AssertNonEmpty(mpcolrow);
- Contracts.Assert(mpcolrow.Length == ccol);
- Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < ccol && ccol <= src.Size);
-
- unsafe
- {
- fixed (int* pmpcolrow = &mpcolrow[0])
- fixed (int* pmpcolindices = mpcolindices)
- fixed (int* pindices = &indices[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- {
- // REVIEW: Implement using AVX
- Thunk.MeanBackU(add, pmpcolrow, pmpcolindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol);
- }
- }
- }
-
- public static void MaxBackOfSrc(bool add, int[] mpcolrow, int[] mpcolindices,
- int[] indices, AlignedArray src, AlignedArray dst, AlignedArray val, int ccol)
- {
- Contracts.AssertNonEmpty(mpcolrow);
- Contracts.Assert(mpcolrow.Length == ccol);
- Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(Compat(val));
- Contracts.Assert(0 < ccol && ccol <= src.Size);
- Contracts.Assert(dst.Size == val.Size);
-
- unsafe
- {
- fixed (int* pmpcolrow = &mpcolrow[0])
- fixed (int* pmpcolindices = mpcolindices)
- fixed (int* pindices = &indices[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- fixed (float* pval = &val.Items[0])
- {
- // REVIEW: Implement using AVX
- Thunk.MaxBackU(add, pmpcolrow, pmpcolindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), Ptr(val, pval), dst.Size, ccol);
- }
- }
- }
-
- public static void RespNormBackOfSrc(bool add, float alpha, float beta, bool avgOverFullKernel, float offset,
- int[] mpcolrow, int[] mpcolindices, int[] indices,
- AlignedArray errors, AlignedArray errorsPrev, AlignedArray valuesPrev, int ccol)
- {
- Contracts.AssertNonEmpty(mpcolrow);
- Contracts.Assert(mpcolrow.Length == ccol);
- Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(errors));
- Contracts.Assert(Compat(errorsPrev));
- Contracts.Assert(Compat(valuesPrev));
- Contracts.Assert(0 < ccol && ccol <= errors.Size);
- Contracts.Assert(errorsPrev.Size == valuesPrev.Size);
-
- unsafe
- {
- fixed (int* pmpcolrow = &mpcolrow[0])
- fixed (int* pmpcolindices = mpcolindices)
- fixed (int* pindices = &indices[0])
- fixed (float* perr = &errors.Items[0])
- fixed (float* perrPrev = &errorsPrev.Items[0])
- fixed (float* pvalPrev = &valuesPrev.Items[0])
- {
- // REVIEW: Implement using AVX
- Thunk.RespNormBackU(add, alpha, beta, avgOverFullKernel, offset, pmpcolrow, pmpcolindices, pindices,
- Ptr(errors, perr), Ptr(errorsPrev, perrPrev), Ptr(valuesPrev, pvalPrev), errorsPrev.Size, ccol);
- }
- }
- }
-
- public static void AddXYTran(float a, AlignedArray x, AlignedArray y, AlignedArray mat, int crow, float decay)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(Compat(mat));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.Assert(x.Size * y.Size == mat.Size);
- Contracts.Assert(decay >= 0);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (float* pmat = &mat.Items[0])
- Thunk.AddXYTranX(a, Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), crow, y.Size, decay);
- }
- }
-
- public static void AddXYTran(float a, AlignedArray x, int[] rgposY, AlignedArray valuesY,
- int posMinY, int iposMinY, int iposLimY, AlignedArray mat, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(valuesY));
- Contracts.Assert(Compat(mat));
- Contracts.AssertNonEmpty(rgposY);
- Contracts.Assert(0 <= iposMinY && iposMinY <= iposLimY && iposLimY <= rgposY.Length);
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.Assert(x.Size * valuesY.Size == mat.Size);
-
- if (iposMinY >= iposLimY)
- return;
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &valuesY.Items[0])
- fixed (int* pposy = &rgposY[0])
- fixed (float* pmat = &mat.Items[0])
- {
- Thunk.AddXYTranPX(a, Ptr(x, px), pposy, Ptr(valuesY, py), posMinY, iposMinY, iposLimY, Ptr(mat, pmat),
- crow, valuesY.Size);
- }
- }
- }
-
- public static void AddXYTran(float a, AlignedArray x, AlignedArray y,
- int[] starts, int[] indices, float[] coefs, int crow, float decay)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.AssertNonEmpty(starts);
- Contracts.Assert(starts.Length == crow + 1);
- Contracts.Assert(starts[0] == 0);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(starts[crow] == indices.Length);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(indices.Length == coefs.Length);
- Contracts.Assert(crow * y.Size >= coefs.Length);
- Contracts.Assert(decay >= 0);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (int* pstarts = &starts[0])
- fixed (int* pindices = &indices[0])
- fixed (float* pcoefs = &coefs[0])
- Thunk.AddXYTranRX(a, Ptr(x, px), Ptr(y, py), pstarts, pindices, pcoefs, crow, decay);
- }
- }
-
- public static void AddXYTran(float a, AlignedArray x, AlignedArray y, int[] mprowiv,
- int[] mprowcol, int[] mprowrun, int[] runs, float[] coefs, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.AssertNonEmpty(mprowiv);
- Contracts.Assert(mprowiv.Length == crow);
- Contracts.AssertNonEmpty(mprowcol);
- Contracts.Assert(mprowcol.Length == crow);
- Contracts.Assert(mprowrun == null || mprowrun.Length == crow);
- Contracts.AssertNonEmpty(runs);
- Contracts.AssertNonEmpty(coefs);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (int* pmprowiv = &mprowiv[0])
- fixed (int* pmprowcol = &mprowcol[0])
- fixed (int* pruns = &runs[0])
- fixed (float* pcoefs = &coefs[0])
- {
- if (mprowrun == null)
- Thunk.AddXYTranCX(a, Ptr(x, px), Ptr(y, py), pmprowiv, pmprowcol, pruns, pcoefs, crow);
- else
- {
- fixed (int* pmprowrun = mprowrun)
- Thunk.AddXYTranDX(a, Ptr(x, px), Ptr(y, py), pmprowiv, pmprowcol, pmprowrun, pruns, pcoefs, crow);
- }
- }
- }
- }
-
- public static void AddXYTran(float a, AlignedArray x, AlignedArray y, AlignedArray mat, float momentum, AlignedArray delta, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(Compat(mat));
- Contracts.Assert(Compat(delta));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.Assert(x.Size * y.Size == mat.Size);
- Contracts.Assert(mat.Size == delta.Size);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (float* pmat = &mat.Items[0])
- fixed (float* pdel = &delta.Items[0])
- Thunk.AddXYTranMomX(a, Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), momentum, Ptr(delta, pdel), crow, y.Size);
- }
- }
-
- public static void AddXYTran(AlignedArray x, AlignedArray y, AlignedArray mat, AlignedArray accGrads, AlignedArray accUpdates,
- float decay, float cond, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(Compat(mat));
- Contracts.Assert(Compat(accGrads));
- Contracts.Assert(Compat(accUpdates));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.Assert(x.Size * y.Size == mat.Size);
- Contracts.Assert(mat.Size == accGrads.Size);
- Contracts.Assert(mat.Size == accUpdates.Size);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (float* pmat = &mat.Items[0])
- fixed (float* pag = &accGrads.Items[0])
- fixed (float* pau = &accUpdates.Items[0])
- Thunk.AddXYTranGradX(Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), Ptr(accGrads, pag), Ptr(accUpdates, pau), decay, cond, crow, y.Size);
- }
- }
-
- public static void AddXYTran(AlignedArray x, AlignedArray y, int[] starts, int[] indices,
- float[] coefs, float[] accGrads, float[] accUpdates, float decay, float cond, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.AssertNonEmpty(starts);
- Contracts.Assert(starts.Length == crow + 1);
- Contracts.Assert(starts[0] == 0);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(starts[crow] == indices.Length);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(indices.Length == coefs.Length);
- Contracts.Assert(crow * y.Size >= coefs.Length);
- Contracts.AssertNonEmpty(accGrads);
- Contracts.Assert(coefs.Length == accGrads.Length);
- Contracts.AssertNonEmpty(accUpdates);
- Contracts.Assert(coefs.Length == accUpdates.Length);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (int* pstarts = &starts[0])
- fixed (int* pindices = &indices[0])
- fixed (float* pcoefs = &coefs[0])
- fixed (float* pag = &accGrads[0])
- fixed (float* pau = &accUpdates[0])
- Thunk.AddXYTranGradRX(Ptr(x, px), Ptr(y, py), pstarts, pindices, pcoefs, pag, pau, decay, cond, crow);
- }
- }
-
- public static void AddXYTran(AlignedArray x, int[] rgposY, AlignedArray valuesY,
- int posMinY, int iposMinY, int iposLimY, AlignedArray mat,
- AlignedArray accGrads, AlignedArray accUpdates, float decay, float cond, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.AssertNonEmpty(rgposY);
- Contracts.Assert(Compat(valuesY));
- Contracts.Assert(Compat(mat));
- Contracts.Assert(0 <= iposMinY && iposMinY <= iposLimY && iposLimY <= rgposY.Length);
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.Assert(x.Size * valuesY.Size == mat.Size);
- Contracts.Assert(mat.Size == accGrads.Size);
- Contracts.Assert(mat.Size == accUpdates.Size);
-
- if (iposMinY >= iposLimY)
- return;
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &valuesY.Items[0])
- fixed (int* pposy = &rgposY[0])
- fixed (float* pmat = &mat.Items[0])
- fixed (float* pag = &accGrads.Items[0])
- fixed (float* pau = &accUpdates.Items[0])
- {
- Thunk.AddXYTranGradPX(Ptr(x, px), pposy, Ptr(valuesY, py), posMinY, iposMinY, iposLimY, Ptr(mat, pmat),
- Ptr(accGrads, pag), Ptr(accUpdates, pau), decay, cond, crow, valuesY.Size);
- }
- }
- }
-
- public static void Scale(float a, AlignedArray dst)
- {
- Contracts.Assert(Compat(dst));
-
- unsafe
- {
- fixed (float* pdst = &dst.Items[0])
- Thunk.ScaleX(a, Ptr(dst, pdst), dst.Size);
- }
- }
-
- public static void Scale(float a, float[] dst, int count)
- {
- Contracts.AssertNonEmpty(dst);
- Contracts.Assert(0 < count && count <= dst.Length);
-
- unsafe
- {
- fixed (float* pd = &dst[0])
- Thunk.Scale(a, pd, count);
- }
- }
-
- public static void ScaleConvWeights(float a, int kernelSize, float[] dst)
- {
- Contracts.AssertValue(dst);
-
- // REVIEW: implement in SSE/AVX.
- for (int istart = 0; istart < dst.Length; istart += kernelSize + 1)
- {
- for (int i = 0; i < kernelSize; i++)
- dst[istart + i] *= a;
- }
- }
-
- public static void ScaleMaxNorm(bool tran, float maxNorm, AlignedArray mat, int crun, int runLenPhy)
- {
- // Called only with Avx alignment.
- Contracts.Assert(Compat(mat));
-
- unsafe
- {
- fixed (float* pmat = &mat.Items[0])
- {
- if (!tran)
- Thunk.ScaleMaxNormX(maxNorm, Ptr(mat, pmat), crun, runLenPhy);
- else
- Thunk.ScaleMaxNormTranU(maxNorm, Ptr(mat, pmat), crun, runLenPhy);
- }
- }
- }
-
- public static void ScaleMaxNorm(float maxNorm, int[] starts, int[] indices, float[] mat)
- {
- Contracts.AssertNonEmpty(starts);
-
- int crow = starts.Length - 1;
- Contracts.Assert(starts[0] == 0);
- Contracts.AssertValue(indices);
- Contracts.Assert(starts[crow] == indices.Length);
- Contracts.AssertNonEmpty(mat);
-
- unsafe
- {
- fixed (int* pstarts = &starts[0])
- fixed (float* pmat = &mat[0])
- Thunk.ScaleMaxNormRU(maxNorm, pstarts, pmat, crow);
- }
- }
-
- public static void ScaleMaxNorm(float maxNorm, int kernCount, int kernSize, float[] mat)
- {
- Contracts.AssertNonEmpty(mat);
-
- unsafe
- {
- fixed (float* pmat = &mat[0])
- Thunk.ScaleMaxNormCU(maxNorm, kernCount, kernSize, pmat);
- }
- }
-
- public static void AddScale(float a, AlignedArray src, AlignedArray dst)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.AddScaleX(a, Ptr(src, psrc), Ptr(dst, pdst), dst.Size);
- }
- }
-
- public static void AddScale(float a, AlignedArray src, AlignedArray dst, float momentum, AlignedArray delta)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(Compat(delta));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(src.Size == delta.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- fixed (float* pdel = &delta.Items[0])
- Thunk.AddScaleMomX(a, Ptr(src, psrc), Ptr(dst, pdst), momentum, Ptr(delta, pdel), dst.Size);
- }
- }
-
- public static void AddScale(float a, float[] src, float[] dst, int count)
- {
- Contracts.AssertNonEmpty(src);
- Contracts.Assert(0 < count && count <= src.Length);
- Contracts.AssertNonEmpty(dst);
- Contracts.Assert(count <= dst.Length);
-
- unsafe
- {
- fixed (float* psrc = &src[0])
- fixed (float* pdst = &dst[0])
- Thunk.AddScaleU(a, psrc, pdst, count);
- }
- }
-
- public static void AddScale(AlignedArray src, AlignedArray dst,
- AlignedArray accGrads, AlignedArray accUpdates, float decay, float cond)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(Compat(accGrads));
- Contracts.Assert(Compat(accUpdates));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(src.Size == accGrads.Size);
- Contracts.Assert(src.Size == accUpdates.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- fixed (float* pag = &accGrads.Items[0])
- fixed (float* pau = &accUpdates.Items[0])
- Thunk.AddScaleGradX(Ptr(src, psrc), Ptr(dst, pdst), Ptr(accGrads, pag), Ptr(accUpdates, pau), decay, cond, dst.Size);
- }
- }
-
- public static void Add(AlignedArray src, AlignedArray dst)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.AddX(Ptr(src, psrc), Ptr(dst, pdst), dst.Size);
- }
- }
-
- public static void Add(float[] src, float[] dst, int count)
- {
- Contracts.AssertNonEmpty(src);
- Contracts.Assert(0 < count && count <= src.Length);
- Contracts.AssertNonEmpty(dst);
- Contracts.Assert(count <= dst.Length);
-
- unsafe
- {
- fixed (float* ps = &src[0])
- fixed (float* pd = &dst[0])
- Thunk.AddU(ps, pd, count);
- }
- }
-
- public static float Sum(AlignedArray src)
- {
- Contracts.Assert(Compat(src));
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- return Thunk.SumX(Ptr(src, psrc), src.Size);
- }
- }
-
- public static float Sum(float[] src, int count)
- {
- Contracts.AssertNonEmpty(src);
- Contracts.Assert(0 < count && count <= src.Length);
-
- unsafe
- {
- fixed (float* psrc = &src[0])
- return Thunk.SumU(psrc, count);
- }
- }
-
- public static float SumSq(float[] src, int count)
- {
- Contracts.AssertNonEmpty(src);
- Contracts.Assert(0 < count && count <= src.Length);
-
- unsafe
- {
- fixed (float* psrc = &src[0])
- return Thunk.SumSqU(psrc, count);
- }
- }
-
- public static float SumAbs(float[] src, int count)
- {
- Contracts.AssertNonEmpty(src);
- Contracts.Assert(0 < count && count <= src.Length);
-
- unsafe
- {
- fixed (float* psrc = &src[0])
- return Thunk.SumAbsU(psrc, count);
- }
- }
-
- public static float MaxAbs(float[] src, int count)
- {
- Contracts.AssertNonEmpty(src);
- Contracts.Assert(0 < count && count <= src.Length);
-
- unsafe
- {
- fixed (float* psrc = &src[0])
- return Thunk.MaxAbsU(psrc, count);
- }
- }
-
- public static float DotProductSparse(float[] a, float[] b, int[] indices, int count)
- {
- Contracts.AssertNonEmpty(a);
- Contracts.AssertNonEmpty(b);
- Contracts.Assert(0 < count);
- Contracts.Assert(count < a.Length);
- Contracts.Assert(count <= b.Length);
- Contracts.Assert(count <= indices.Length);
-
- unsafe
- {
- fixed (float* pa = &a[0])
- fixed (float* pb = &b[0])
- fixed (int* pi = &indices[0])
- return Thunk.DotSU(pa, pb, pi, count);
- }
- }
-
- public static float DotProductSparse(float[] a, int offset, float[] b, int[] indices, int count)
- {
- Contracts.AssertNonEmpty(a);
- Contracts.Assert(0 < count);
- Contracts.Assert(0 <= offset && offset < a.Length);
- Contracts.Assert(a.Length - offset > count);
- Contracts.AssertNonEmpty(b);
- Contracts.Assert(count <= b.Length);
- Contracts.Assert(count <= indices.Length);
-
- unsafe
- {
- fixed (float* pa = &a[offset])
- fixed (float* pb = &b[0])
- fixed (int* pi = &indices[0])
- return Thunk.DotSU(pa, pb, pi, count);
- }
- }
-
- public static void ApplySigmoid(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplySigmoidX(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplySoftMax(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplySoftMaxX(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplyRectifiedLinear(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplyRectifiedLinearX(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplySquare(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplySquareX(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplySqrt(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplySqrtX(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplySoftRectifiedLinear(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplySoftRectifiedLinearX(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplyAbs(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplyAbsX(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplyTanh(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplyTanhX(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplyBoundedRectifiedLinear(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 <= c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplyBoundedRectifiedLinearX(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplySigmoidDerivative(AlignedArray value, AlignedArray grad)
- {
- Contracts.Assert(Compat(value));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(value.Size == grad.Size);
-
- unsafe
- {
- fixed (float* pvalue = &value.Items[0])
- fixed (float* pgrad = &grad.Items[0])
- Thunk.ApplySigmoidDerivativeX(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size);
- }
- }
-
- public static void ApplyRectifiedLinearDerivative(AlignedArray value, AlignedArray grad)
- {
- Contracts.Assert(Compat(value));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(value.Size == grad.Size);
-
- unsafe
- {
- fixed (float* pvalue = &value.Items[0])
- fixed (float* pgrad = &grad.Items[0])
- Thunk.ApplyRectifiedLinearDerivativeX(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size);
- }
- }
-
- public static void ApplySquareDerivative(AlignedArray input, AlignedArray output, AlignedArray grad, bool drop)
- {
- Contracts.Assert(Compat(input));
- Contracts.Assert(Compat(output));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(output.Size == input.Size);
- Contracts.Assert(output.Size == grad.Size);
-
- unsafe
- {
- fixed (float* px = &input.Items[0])
- fixed (float* py = &output.Items[0])
- fixed (float* pg = &grad.Items[0])
- Thunk.ApplySquareDerivativeX(Ptr(input, px), Ptr(output, py), Ptr(grad, pg), grad.Size, drop);
- }
- }
-
- public static void ApplySqrtDerivative(AlignedArray value, AlignedArray grad)
- {
- Contracts.Assert(Compat(value));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(value.Size == grad.Size);
-
- unsafe
- {
- fixed (float* pvalue = &value.Items[0])
- fixed (float* pgrad = &grad.Items[0])
- Thunk.ApplySqrtDerivativeX(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size);
- }
- }
-
- public static void ApplySoftRectifiedLinearDerivative(AlignedArray input, AlignedArray output, AlignedArray grad)
- {
- Contracts.Assert(Compat(input));
- Contracts.Assert(Compat(output));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(output.Size == input.Size);
- Contracts.Assert(output.Size == grad.Size);
-
- unsafe
- {
- fixed (float* px = &input.Items[0])
- fixed (float* py = &output.Items[0])
- fixed (float* pg = &grad.Items[0])
- Thunk.ApplySoftRectifiedLinearDerivativeX(Ptr(input, px), Ptr(output, py), Ptr(grad, pg), grad.Size);
- }
- }
-
- public static void ApplyAbsDerivative(AlignedArray input, AlignedArray output, AlignedArray grad, bool drop)
- {
- Contracts.Assert(Compat(input));
- Contracts.Assert(Compat(output));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(output.Size == input.Size);
- Contracts.Assert(output.Size == grad.Size);
-
- unsafe
- {
- fixed (float* px = &input.Items[0])
- fixed (float* py = &output.Items[0])
- fixed (float* pg = &grad.Items[0])
- Thunk.ApplyAbsDerivativeX(Ptr(input, px), Ptr(output, py), Ptr(grad, pg), grad.Size, drop);
- }
- }
-
- public static void ApplyTanhDerivative(AlignedArray value, AlignedArray grad)
- {
- Contracts.Assert(Compat(value));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(value.Size == grad.Size);
-
- unsafe
- {
- fixed (float* pvalue = &value.Items[0])
- fixed (float* pgrad = &grad.Items[0])
- Thunk.ApplyTanhDerivativeX(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size);
- }
- }
-
- public static void ApplyBoundedRectifiedLinearDerivative(AlignedArray value, AlignedArray grad)
- {
- Contracts.Assert(Compat(value));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(value.Size == grad.Size);
-
- unsafe
- {
- fixed (float* pvalue = &value.Items[0])
- fixed (float* pgrad = &grad.Items[0])
- Thunk.ApplyBoundedRectifiedLinearDerivativeX(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size);
- }
- }
-
- public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] indices)
- {
- Contracts.Assert(0 < ccol && ccol <= cfltRow);
-
- unsafe
- {
- fixed (float* pdst = &dst.Items[0])
- fixed (int* pi = &indices[0])
- {
- if (ccol == cfltRow)
- Thunk.ZeroItemsU(Ptr(dst, pdst), dst.Size, pi, indices.Length);
- else
- Thunk.ZeroMatrixItemsCore(Ptr(dst, pdst), dst.Size, ccol, cfltRow, pi, indices.Length);
- }
- }
- }
-
- public static void ScaleAdadelta(float[] mat, float[] accGrads, float[] accUpdates, float decay, float cond, float[] grads)
- {
- Contracts.AssertNonEmpty(mat);
- Contracts.AssertNonEmpty(accGrads);
- Contracts.AssertNonEmpty(accUpdates);
- Contracts.Assert(mat.Length == accGrads.Length);
- Contracts.Assert(mat.Length == accUpdates.Length);
- Contracts.Assert(mat.Length <= grads.Length);
-
- unsafe
- {
- fixed (float* pm = &mat[0])
- fixed (float* pag = &accGrads[0])
- fixed (float* pau = &accUpdates[0])
- fixed (float* pg = &grads[0])
- Thunk.ScaleAdadeltaX(pm, pag, pau, decay, cond, pg, mat.Length);
- }
- }
- }
-}
\ No newline at end of file
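
Aside (not part of the patch): the deleted AvxUtils wrappers relied on AlignedArray plus the Ptr helper to hand the native Thunk entry points a 32-byte-aligned base pointer. A minimal sketch of that rounding-up idea, under the assumption that the helper name and layout here are illustrative rather than taken from the repo:

using System;

internal static class AlignmentSketch
{
    private const int CbAlign = 32; // AVX loads/stores prefer 32-byte alignment

    // Rounds a pinned pointer up to the next 32-byte boundary. AlignedArray
    // over-allocates its backing buffer and records this offset (GetBase) so
    // that the aligned region is still large enough for Size elements.
    private static unsafe float* Align(float* p)
    {
        long address = (long)p;
        long aligned = (address + (CbAlign - 1)) & ~(long)(CbAlign - 1);
        return (float*)aligned;
    }
}
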
diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
index d99a3b0570..c812b47687 100644
--- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
+++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
@@ -46,25 +46,6 @@ internal static class AvxIntrinsics
private static readonly Vector256<float> _absMask256 = Avx.StaticCast<int, float>(Avx.SetAllVector256(0x7FFFFFFF));
- private const int Vector256Alignment = 32;
-
- [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
- private static bool HasCompatibleAlignment(AlignedArray alignedArray)
- {
- Contracts.AssertValue(alignedArray);
- Contracts.Assert(alignedArray.Size > 0);
- return (alignedArray.CbAlign % Vector256Alignment) == 0;
- }
-
- [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
- private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase)
- {
- Contracts.AssertValue(alignedArray);
- float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase);
- Contracts.Assert(((long)alignedBase % Vector256Alignment) == 0);
- return alignedBase;
- }
-
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static Vector128<float> GetHigh(in Vector256<float> x)
=> Avx.ExtractVector128(x, 1);
@@ -170,19 +151,19 @@ private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<fl
+ public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
{
- fixed (float* psrc = &src[0])
- fixed (float* pdst = &dst[0])
- fixed (float* pmat = &mat[0])
+ Contracts.Assert(crow % 4 == 0);
+ Contracts.Assert(ccol % 4 == 0);
+
+ fixed (float* psrc = &MemoryMarshal.GetReference(src))
+ fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+ fixed (float* pmat = &MemoryMarshal.GetReference(mat))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
@@ -312,24 +293,27 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int cro
}
// Partial sparse source vector.
- public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArray src,
- int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
+ public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray src,
+ int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
{
- Contracts.Assert(HasCompatibleAlignment(mat));
- Contracts.Assert(HasCompatibleAlignment(src));
- Contracts.Assert(HasCompatibleAlignment(dst));
+ MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol);
+ }
+
+ public static unsafe void MatMulP(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> src,
+ int posMin, int iposMin, int iposEnd, Span<float> dst, int crow, int ccol)
+ {
+ Contracts.Assert(crow % 8 == 0);
+ Contracts.Assert(ccol % 8 == 0);
// REVIEW: For extremely sparse inputs, interchanging the loops would
// likely be more efficient.
- fixed (float* pSrcStart = &src.Items[0])
- fixed (float* pDstStart = &dst.Items[0])
- fixed (float* pMatStart = &mat.Items[0])
- fixed (int* pposSrc = &rgposSrc[0])
+ fixed (float* psrc = &MemoryMarshal.GetReference(src))
+ fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+ fixed (float* pmat = &MemoryMarshal.GetReference(mat))
+ fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc))
+ fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
+ fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
- float* psrc = GetAlignedBase(src, pSrcStart);
- float* pdst = GetAlignedBase(dst, pDstStart);
- float* pmat = GetAlignedBase(mat, pMatStart);
-
int* pposMin = pposSrc + iposMin;
int* pposEnd = pposSrc + iposEnd;
float* pDstEnd = pdst + crow;
@@ -337,7 +321,106 @@ public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArra
float* pSrcCurrent = psrc - posMin;
float* pDstCurrent = pdst;
- while (pDstCurrent < pDstEnd)
+ nuint address = (nuint)(pDstCurrent);
+ int misalignment = (int)(address % 32);
+ int length = crow;
+ int remainder = 0;
+
+ if ((misalignment & 3) != 0)
+ {
+ while (pDstCurrent < pDstEnd)
+ {
+ Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
+ pDstCurrent += 8;
+ pm0 += 8 * ccol;
+ }
+ }
+ else
+ {
+ if (misalignment != 0)
+ {
+ misalignment >>= 2;
+ misalignment = 8 - misalignment;
+
+ Vector256<float> mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
+
+ float* pm1 = pm0 + ccol;
+ float* pm2 = pm1 + ccol;
+ float* pm3 = pm2 + ccol;
+ Vector256<float> result = Avx.SetZeroVector256<float>();
+
+ int* ppos = pposMin;
+
+ while (ppos < pposEnd)
+ {
+ int col1 = *ppos;
+ int col2 = col1 + 4 * ccol;
+ Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
+ pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+
+ x1 = Avx.And(mask, x1);
+ Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
+ result = MultiplyAdd(x2, x1, result);
+ ppos++;
+ }
+
+ Avx.Store(pDstCurrent, result);
+ pDstCurrent += misalignment;
+ pm0 += misalignment * ccol;
+ length -= misalignment;
+ }
+
+ if (length > 7)
+ {
+ remainder = length % 8;
+ while (pDstCurrent < pDstEnd)
+ {
+ Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
+ pDstCurrent += 8;
+ pm0 += 8 * ccol;
+ }
+ }
+ else
+ {
+ remainder = length;
+ }
+
+ if (remainder != 0)
+ {
+ pDstCurrent -= (8 - remainder);
+ pm0 -= (8 - remainder) * ccol;
+ Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
+ Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));
+
+ float* pm1 = pm0 + ccol;
+ float* pm2 = pm1 + ccol;
+ float* pm3 = pm2 + ccol;
+ Vector256<float> result = Avx.SetZeroVector256<float>();
+
+ int* ppos = pposMin;
+
+ while (ppos < pposEnd)
+ {
+ int col1 = *ppos;
+ int col2 = col1 + 4 * ccol;
+ Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
+ pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+ x1 = Avx.And(x1, trailingMask);
+
+ Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
+ result = MultiplyAdd(x2, x1, result);
+ ppos++;
+ }
+
+ result = Avx.Add(result, Avx.And(leadingMask, Avx.LoadVector256(pDstCurrent)));
+
+ Avx.Store(pDstCurrent, result);
+ pDstCurrent += 8;
+ pm0 += 8 * ccol;
+ }
+ }
+
+ Vector256<float> SparseMultiplicationAcrossRow()
{
float* pm1 = pm0 + ccol;
float* pm2 = pm1 + ccol;
@@ -351,33 +434,30 @@ public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArra
int col1 = *ppos;
int col2 = col1 + 4 * ccol;
Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
- pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+ pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
result = MultiplyAdd(x2, x1, result);
-
ppos++;
}
- Avx.StoreAligned(pDstCurrent, result);
- pDstCurrent += 8;
- pm0 += 8 * ccol;
+ return result;
}
}
}
- public static unsafe void MatMulTranX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
+ public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
- Contracts.Assert(crow % 4 == 0);
- Contracts.Assert(ccol % 4 == 0);
-
- MatMulTranX(mat.Items, src.Items, dst.Items, crow, ccol);
+ MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol);
}
- public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int crow, int ccol)
+ public static unsafe void MatMulTran(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
{
- fixed (float* psrc = &src[0])
- fixed (float* pdst = &dst[0])
- fixed (float* pmat = &mat[0])
+ Contracts.Assert(crow % 4 == 0);
+ Contracts.Assert(ccol % 4 == 0);
+
+ fixed (float* psrc = &MemoryMarshal.GetReference(src))
+ fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+ fixed (float* pmat = &MemoryMarshal.GetReference(mat))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
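
Aside (not part of the patch): the hunks above replace array pinning of the form fixed (float* p = &arr[0]) with span pinning through MemoryMarshal.GetReference, which also works for empty spans and spans over native memory. A minimal sketch of that pattern, assuming nothing beyond System.Runtime.InteropServices:

using System;
using System.Runtime.InteropServices;

internal static class SpanPinningSketch
{
    // Pins the first element of a span the same way the new MatMul/MatMulP
    // overloads do, then walks the data with a raw pointer.
    public static unsafe float Sum(ReadOnlySpan<float> src)
    {
        float result = 0;
        fixed (float* psrc = &MemoryMarshal.GetReference(src))
        {
            for (int i = 0; i < src.Length; i++)
                result += psrc[i];
        }
        return result;
    }
}
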
diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs
index aa8ff85bc6..a41ea6f703 100644
--- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs
+++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs
@@ -34,12 +34,12 @@ public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, Al
if (!tran)
{
Contracts.Assert(crun <= dst.Size);
- AvxIntrinsics.MatMulX(mat, src, dst, crun, src.Size);
+ AvxIntrinsics.MatMul(mat, src, dst, crun, src.Size);
}
else
{
Contracts.Assert(crun <= src.Size);
- AvxIntrinsics.MatMulTranX(mat, src, dst, dst.Size, crun);
+ AvxIntrinsics.MatMulTran(mat, src, dst, dst.Size, crun);
}
}
else if (Sse.IsSupported)
@@ -109,12 +109,12 @@ public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray sr
if (Avx.IsSupported)
{
Contracts.Assert(crun <= dst.Size);
- AvxIntrinsics.MatMulPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
+ AvxIntrinsics.MatMulP(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
}
else if (Sse.IsSupported)
{
Contracts.Assert(crun <= dst.Size);
- SseIntrinsics.MatMulPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
+ SseIntrinsics.MatMulP(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
}
else
{
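
Aside (not part of the patch): CpuMathUtils.MatTimesSrc above dispatches on hardware capability, preferring AVX, then SSE, then a managed fallback. A hedged sketch of that shape; the *Path helpers below are placeholders, not ML.NET APIs:

using System.Runtime.Intrinsics.X86;

internal static class DispatchSketch
{
    // Illustrative only: pick the widest supported vector path at run time.
    public static void MatTimesSrc(float[] mat, float[] src, float[] dst, int crow, int ccol)
    {
        if (Avx.IsSupported)
            AvxPath(mat, src, dst, crow, ccol);
        else if (Sse.IsSupported)
            SsePath(mat, src, dst, crow, ccol);
        else
            ManagedPath(mat, src, dst, crow, ccol);
    }

    private static void AvxPath(float[] mat, float[] src, float[] dst, int crow, int ccol) { }
    private static void SsePath(float[] mat, float[] src, float[] dst, int crow, int ccol) { }

    private static void ManagedPath(float[] mat, float[] src, float[] dst, int crow, int ccol)
    {
        // Naive dense mat-vec product: dst[i] = sum over j of mat[i * ccol + j] * src[j].
        for (int i = 0; i < crow; i++)
        {
            float sum = 0;
            for (int j = 0; j < ccol; j++)
                sum += mat[i * ccol + j] * src[j];
            dst[i] = sum;
        }
    }
}
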
diff --git a/src/Microsoft.ML.CpuMath/Sse.cs b/src/Microsoft.ML.CpuMath/Sse.cs
index 4281893852..a25c286d55 100644
--- a/src/Microsoft.ML.CpuMath/Sse.cs
+++ b/src/Microsoft.ML.CpuMath/Sse.cs
@@ -81,500 +81,7 @@ public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray sr
fixed (int* ppossrc = &rgposSrc[0])
{
Contracts.Assert(0 <= crun && crun <= dst.Size);
- Thunk.MatMulPA(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size);
- }
- }
- }
-
- public static void MatTimesSrc(bool add, int[] starts, int[] indices, float[] coefs,
- AlignedArray src, AlignedArray dst, int crow)
- {
- Contracts.AssertNonEmpty(starts);
- Contracts.Assert(starts.Length == crow + 1);
- Contracts.Assert(starts[0] == 0);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(starts[crow] == indices.Length);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(indices.Length == coefs.Length);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < crow && crow <= dst.Size);
- Contracts.Assert(crow * src.Size >= coefs.Length);
-
- unsafe
- {
- fixed (int* pstarts = &starts[0])
- fixed (int* pindices = &indices[0])
- fixed (float* pcoefs = &coefs[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.MatMulRU(add, pstarts, pindices, pcoefs, Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- }
-
- public static void MatTimesSrc(bool add, int[] mprowiv, int[] mprowcol,
- int[] mprowrun, int[] runs, float[] coefs,
- AlignedArray src, AlignedArray dst, int crow)
- {
- Contracts.AssertNonEmpty(mprowiv);
- Contracts.Assert(mprowiv.Length == crow);
- Contracts.AssertNonEmpty(mprowcol);
- Contracts.Assert(mprowcol.Length == crow);
- Contracts.Assert(mprowrun == null || mprowrun.Length == crow);
- Contracts.AssertNonEmpty(runs);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < crow && crow <= dst.Size);
-
- unsafe
- {
- fixed (int* pmprowiv = &mprowiv[0])
- fixed (int* pmprowcol = &mprowcol[0])
- fixed (int* pruns = &runs[0])
- fixed (float* pcoefs = &coefs[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- {
- if (mprowrun == null)
- {
- Thunk.MatMulCU(add, pmprowiv, pmprowcol, pruns, pcoefs,
- Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- else
- {
- fixed (int* pmprowrun = &mprowrun[0])
- {
- Thunk.MatMulDU(add, pmprowiv, pmprowcol, pmprowrun, pruns, pcoefs,
- Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- }
- }
- }
- }
-
- public static void MeanOfSrc(bool add, int[] mprowcol, int[] mprowindices,
- int[] indices, AlignedArray src, AlignedArray dst, int crow)
- {
- Contracts.AssertNonEmpty(mprowcol);
- Contracts.Assert(mprowcol.Length == crow);
- Contracts.Assert(mprowindices == null || mprowindices.Length == crow);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < crow && crow <= dst.Size);
-
- unsafe
- {
- fixed (int* pmprowcol = &mprowcol[0])
- fixed (int* pmprowindices = mprowindices)
- fixed (int* pindices = &indices[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.MeanU(add, pmprowcol, pmprowindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- }
-
- public static void MaxOfSrc(bool add, int[] mprowcol, int[] mprowindices,
- int[] indices, AlignedArray src, AlignedArray dst, int crow)
- {
- Contracts.AssertNonEmpty(mprowcol);
- Contracts.Assert(mprowcol.Length == crow);
- Contracts.Assert(mprowindices == null || mprowindices.Length == crow);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < crow && crow <= dst.Size);
-
- unsafe
- {
- fixed (int* pmprowcol = &mprowcol[0])
- fixed (int* pmprowindices = mprowindices)
- fixed (int* pindices = &indices[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.MaxU(add, pmprowcol, pmprowindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- }
-
- public static void RespNormOfSrc(bool add, float alpha, float beta, bool avgOverFullKernel, float offset,
- int[] mprowcol, int[] mprowindices, int[] indices,
- AlignedArray src, AlignedArray dst, int crow)
- {
- Contracts.AssertNonEmpty(mprowcol);
- Contracts.Assert(mprowcol.Length == crow);
- Contracts.Assert(mprowindices == null || mprowindices.Length == crow);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < crow && crow <= dst.Size);
-
- unsafe
- {
- fixed (int* pmprowcol = &mprowcol[0])
- fixed (int* pmprowindices = mprowindices)
- fixed (int* pindices = &indices[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- {
- Thunk.RespNormU(add, alpha, beta, avgOverFullKernel, offset, pmprowcol, pmprowindices, pindices,
- Ptr(src, psrc), Ptr(dst, pdst), crow);
- }
- }
- }
-
- public static void MatTranTimesSrc(bool add, int[] starts, int[] indices, float[] coefs,
- AlignedArray src, AlignedArray dst, int ccol)
- {
- Contracts.AssertNonEmpty(starts);
- Contracts.Assert(starts.Length == ccol + 1);
- Contracts.Assert(starts[0] == 0);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(starts[ccol] == indices.Length);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(indices.Length == coefs.Length);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < ccol && ccol <= src.Size);
- Contracts.Assert(dst.Size * ccol >= coefs.Length);
-
- unsafe
- {
- fixed (int* pstarts = &starts[0])
- fixed (int* pindices = &indices[0])
- fixed (float* pcoefs = &coefs[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.MatMulTranRU(add, pstarts, pindices, pcoefs, Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol);
- }
- }
-
- public static void MatTranTimesSrc(bool add, int[] mpcoliv, int[] mpcolrow, int[] mpcolrun,
- int[] runs, float[] coefs, AlignedArray src, AlignedArray dst, int ccol)
- {
- Contracts.AssertNonEmpty(mpcoliv);
- Contracts.Assert(mpcoliv.Length == ccol);
- Contracts.AssertNonEmpty(mpcolrow);
- Contracts.Assert(mpcolrow.Length == ccol);
- Contracts.AssertNonEmpty(runs);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(mpcolrun == null || mpcolrun.Length == ccol);
- Contracts.Assert(0 < ccol && ccol <= src.Size);
-
- unsafe
- {
- fixed (int* pmpcoliv = &mpcoliv[0])
- fixed (int* pmpcolrow = &mpcolrow[0])
- fixed (int* pruns = &runs[0])
- fixed (float* pcoefs = &coefs[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- {
- if (mpcolrun == null)
- {
- Thunk.MatMulTranCU(add, pmpcoliv, pmpcolrow, pruns, pcoefs,
- Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol);
- }
- else
- {
- fixed (int* pmpcolrun = &mpcolrun[0])
- {
- Thunk.MatMulTranDU(add, pmpcoliv, pmpcolrow, pmpcolrun, pruns, pcoefs,
- Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol);
- }
- }
- }
- }
- }
-
- public static void MeanBackOfSrc(bool add, int[] mpcolrow, int[] mpcolindices,
- int[] indices, AlignedArray src, AlignedArray dst, int ccol)
- {
- Contracts.AssertNonEmpty(mpcolrow);
- Contracts.Assert(mpcolrow.Length == ccol);
- Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(0 < ccol && ccol <= src.Size);
-
- unsafe
- {
- fixed (int* pmpcolrow = &mpcolrow[0])
- fixed (int* pmpcolindices = mpcolindices)
- fixed (int* pindices = &indices[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.MeanBackU(add, pmpcolrow, pmpcolindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), dst.Size, ccol);
- }
- }
-
- public static void MaxBackOfSrc(bool add, int[] mpcolrow, int[] mpcolindices,
- int[] indices, AlignedArray src, AlignedArray dst, AlignedArray val, int ccol)
- {
- Contracts.AssertNonEmpty(mpcolrow);
- Contracts.Assert(mpcolrow.Length == ccol);
- Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(Compat(val));
- Contracts.Assert(0 < ccol && ccol <= src.Size);
- Contracts.Assert(dst.Size == val.Size);
-
- unsafe
- {
- fixed (int* pmpcolrow = &mpcolrow[0])
- fixed (int* pmpcolindices = mpcolindices)
- fixed (int* pindices = &indices[0])
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- fixed (float* pval = &val.Items[0])
- Thunk.MaxBackU(add, pmpcolrow, pmpcolindices, pindices, Ptr(src, psrc), Ptr(dst, pdst), Ptr(val, pval), dst.Size, ccol);
- }
- }
-
- public static void RespNormBackOfSrc(bool add, float alpha, float beta, bool avgOverFullKernel, float offset,
- int[] mpcolrow, int[] mpcolindices, int[] indices,
- AlignedArray errors, AlignedArray errorsPrev, AlignedArray valuesPrev, int ccol)
- {
- Contracts.AssertNonEmpty(mpcolrow);
- Contracts.Assert(mpcolrow.Length == ccol);
- Contracts.Assert(mpcolindices == null || mpcolindices.Length == ccol);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(Compat(errors));
- Contracts.Assert(Compat(errorsPrev));
- Contracts.Assert(Compat(valuesPrev));
- Contracts.Assert(0 < ccol && ccol <= errors.Size);
- Contracts.Assert(errorsPrev.Size == valuesPrev.Size);
-
- unsafe
- {
- fixed (int* pmpcolrow = &mpcolrow[0])
- fixed (int* pmpcolindices = mpcolindices)
- fixed (int* pindices = &indices[0])
- fixed (float* perr = &errors.Items[0])
- fixed (float* perrPrev = &errorsPrev.Items[0])
- fixed (float* pvalPrev = &valuesPrev.Items[0])
- {
- Thunk.RespNormBackU(add, alpha, beta, avgOverFullKernel, offset, pmpcolrow, pmpcolindices, pindices,
- Ptr(errors, perr), Ptr(errorsPrev, perrPrev), Ptr(valuesPrev, pvalPrev), errorsPrev.Size, ccol);
- }
- }
- }
-
- public static void AddXYTran(float a, AlignedArray x, AlignedArray y, AlignedArray mat, int crow, float decay)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(Compat(mat));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.Assert(x.Size * y.Size == mat.Size);
- Contracts.Assert(decay >= 0);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (float* pmat = &mat.Items[0])
- Thunk.AddXYTranA(a, Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), crow, y.Size, decay);
- }
- }
-
- public static void AddXYTran(float a, AlignedArray x, int[] rgposY, AlignedArray valuesY,
- int posMinY, int iposMinY, int iposLimY, AlignedArray mat, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(valuesY));
- Contracts.Assert(Compat(mat));
- Contracts.AssertNonEmpty(rgposY);
- Contracts.Assert(0 <= iposMinY && iposMinY <= iposLimY && iposLimY <= rgposY.Length);
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.Assert(x.Size * valuesY.Size == mat.Size);
-
- if (iposMinY >= iposLimY)
- return;
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &valuesY.Items[0])
- fixed (int* pposy = &rgposY[0])
- fixed (float* pmat = &mat.Items[0])
- {
- Thunk.AddXYTranPA(a, Ptr(x, px), pposy, Ptr(valuesY, py), posMinY, iposMinY, iposLimY, Ptr(mat, pmat),
- crow, valuesY.Size);
- }
- }
- }
-
- public static void AddXYTran(float a, AlignedArray x, AlignedArray y,
- int[] starts, int[] indices, float[] coefs, int crow, float decay)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.AssertNonEmpty(starts);
- Contracts.Assert(starts.Length == crow + 1);
- Contracts.Assert(starts[0] == 0);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(starts[crow] == indices.Length);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(indices.Length == coefs.Length);
- Contracts.Assert(crow * y.Size >= coefs.Length);
- Contracts.Assert(decay >= 0);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (int* pstarts = &starts[0])
- fixed (int* pindices = &indices[0])
- fixed (float* pcoefs = &coefs[0])
- Thunk.AddXYTranRU(a, Ptr(x, px), Ptr(y, py), pstarts, pindices, pcoefs, crow, decay);
- }
- }
-
- public static void AddXYTran(float a, AlignedArray x, AlignedArray y, int[] mprowiv,
- int[] mprowcol, int[] mprowrun, int[] runs, float[] coefs, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.AssertNonEmpty(mprowiv);
- Contracts.Assert(mprowiv.Length == crow);
- Contracts.AssertNonEmpty(mprowcol);
- Contracts.Assert(mprowcol.Length == crow);
- Contracts.Assert(mprowrun == null || mprowrun.Length == crow);
- Contracts.AssertNonEmpty(runs);
- Contracts.AssertNonEmpty(coefs);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (int* pmprowiv = &mprowiv[0])
- fixed (int* pmprowcol = &mprowcol[0])
- fixed (int* pruns = &runs[0])
- fixed (float* pcoefs = &coefs[0])
- {
- if (mprowrun == null)
- Thunk.AddXYTranCU(a, Ptr(x, px), Ptr(y, py), pmprowiv, pmprowcol, pruns, pcoefs, crow);
- else
- {
- fixed (int* pmprowrun = &mprowrun[0])
- Thunk.AddXYTranDU(a, Ptr(x, px), Ptr(y, py), pmprowiv, pmprowcol, pmprowrun, pruns, pcoefs, crow);
- }
- }
- }
- }
-
- public static void AddXYTran(float a, AlignedArray x, AlignedArray y, AlignedArray mat, float momentum, AlignedArray delta, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(Compat(mat));
- Contracts.Assert(Compat(delta));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.Assert(x.Size * y.Size == mat.Size);
- Contracts.Assert(mat.Size == delta.Size);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (float* pmat = &mat.Items[0])
- fixed (float* pdel = &delta.Items[0])
- Thunk.AddXYTranMomA(a, Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), momentum, Ptr(delta, pdel), crow, y.Size);
- }
- }
-
- public static void AddXYTran(AlignedArray x, AlignedArray y, AlignedArray mat, AlignedArray accGrads, AlignedArray accUpdates,
- float decay, float cond, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(Compat(mat));
- Contracts.Assert(Compat(accGrads));
- Contracts.Assert(Compat(accUpdates));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.Assert(x.Size * y.Size == mat.Size);
- Contracts.Assert(mat.Size == accGrads.Size);
- Contracts.Assert(mat.Size == accUpdates.Size);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (float* pmat = &mat.Items[0])
- fixed (float* pag = &accGrads.Items[0])
- fixed (float* pau = &accUpdates.Items[0])
- Thunk.AddXYTranGradA(Ptr(x, px), Ptr(y, py), Ptr(mat, pmat), Ptr(accGrads, pag), Ptr(accUpdates, pau), decay, cond, crow, y.Size);
- }
- }
-
- public static void AddXYTran(AlignedArray x, AlignedArray y, int[] starts, int[] indices,
- float[] coefs, float[] accGrads, float[] accUpdates, float decay, float cond, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.Assert(Compat(y));
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.AssertNonEmpty(starts);
- Contracts.Assert(starts.Length == crow + 1);
- Contracts.Assert(starts[0] == 0);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(starts[crow] == indices.Length);
- Contracts.AssertNonEmpty(coefs);
- Contracts.Assert(indices.Length == coefs.Length);
- Contracts.Assert(crow * y.Size >= coefs.Length);
- Contracts.AssertNonEmpty(accGrads);
- Contracts.Assert(coefs.Length == accGrads.Length);
- Contracts.AssertNonEmpty(accUpdates);
- Contracts.Assert(coefs.Length == accUpdates.Length);
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &y.Items[0])
- fixed (int* pstarts = &starts[0])
- fixed (int* pindices = &indices[0])
- fixed (float* pcoefs = &coefs[0])
- fixed (float* pag = &accGrads[0])
- fixed (float* pau = &accUpdates[0])
- Thunk.AddXYTranGradRU(Ptr(x, px), Ptr(y, py), pstarts, pindices, pcoefs, pag, pau, decay, cond, crow);
- }
- }
-
- public static void AddXYTran(AlignedArray x, int[] rgposY, AlignedArray valuesY,
- int posMinY, int iposMinY, int iposLimY, AlignedArray mat,
- AlignedArray accGrads, AlignedArray accUpdates, float decay, float cond, int crow)
- {
- Contracts.Assert(Compat(x));
- Contracts.AssertNonEmpty(rgposY);
- Contracts.Assert(Compat(valuesY));
- Contracts.Assert(Compat(mat));
- Contracts.Assert(0 <= iposMinY && iposMinY <= iposLimY && iposLimY <= rgposY.Length);
- Contracts.Assert(0 < crow && crow <= x.Size);
- Contracts.Assert(x.Size * valuesY.Size == mat.Size);
- Contracts.Assert(mat.Size == accGrads.Size);
- Contracts.Assert(mat.Size == accUpdates.Size);
-
- if (iposMinY >= iposLimY)
- return;
-
- unsafe
- {
- fixed (float* px = &x.Items[0])
- fixed (float* py = &valuesY.Items[0])
- fixed (int* pposy = &rgposY[0])
- fixed (float* pmat = &mat.Items[0])
- fixed (float* pag = &accGrads.Items[0])
- fixed (float* pau = &accUpdates.Items[0])
- {
- Thunk.AddXYTranGradPA(Ptr(x, px), pposy, Ptr(valuesY, py), posMinY, iposMinY, iposLimY, Ptr(mat, pmat),
- Ptr(accGrads, pag), Ptr(accUpdates, pau), decay, cond, crow, valuesY.Size);
+ Thunk.MatMulP(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size);
}
}
}
@@ -591,17 +98,6 @@ public static void Add(float a, Span<float> dst)
}
}
- public static void Scale(float a, AlignedArray dst)
- {
- Contracts.Assert(Compat(dst));
-
- unsafe
- {
- fixed (float* pdst = &dst.Items[0])
- Thunk.Scale(a, Ptr(dst, pdst), dst.Size);
- }
- }
-
public static void Scale(float a, Span<float> dst)
{
Contracts.AssertNonEmpty(dst);
@@ -613,19 +109,6 @@ public static void Scale(float a, Span<float> dst)
}
}
- public static void Scale(float a, float[] dst, int offset, int count)
- {
- Contracts.AssertNonEmpty(dst);
- Contracts.Assert(0 < count);
- Contracts.Assert(0 <= offset && offset < dst.Length - count);
-
- unsafe
- {
- fixed (float* pd = &dst[offset])
- Thunk.Scale(a, pd, count);
- }
- }
-
// dst = a * src
public static void Scale(float a, ReadOnlySpan<float> src, Span<float> dst, int count)
{
@@ -656,98 +139,6 @@ public static void ScaleAdd(float a, float b, Span<float> dst)
}
}
- public static void ScaleConvWeights(float a, int kernelSize, float[] dst)
- {
- Contracts.AssertValue(dst);
-
- // REVIEW: implement in SSE/AVX.
- for (int istart = 0; istart < dst.Length; istart += kernelSize + 1)
- {
- for (int i = 0; i < kernelSize; i++)
- dst[istart + i] *= a;
- }
- }
-
- public static void ScaleMaxNorm(bool tran, float maxNorm, AlignedArray mat, int crun, int runLenPhy)
- {
- // Called also by MklMath which uses Avx alignment, which is a multiple of Sse alignment.
- // Hence, Compat(mat) cannot be asserted here since it checks for exact Sse alignment (mat.CbAlign == CbAlign).
- Contracts.AssertValue(mat);
- Contracts.Assert(mat.Size > 0);
- Contracts.Assert((mat.CbAlign % CbAlign) == 0);
-
- unsafe
- {
- fixed (float* pmat = &mat.Items[0])
- {
- if (!tran)
- Thunk.ScaleMaxNormA(maxNorm, Ptr(mat, pmat), crun, runLenPhy);
- else
- Thunk.ScaleMaxNormTranU(maxNorm, Ptr(mat, pmat), crun, runLenPhy);
- }
- }
- }
-
- public static void ScaleMaxNorm(float maxNorm, int[] starts, int[] indices, float[] mat)
- {
- Contracts.AssertNonEmpty(starts);
-
- int crow = starts.Length - 1;
- Contracts.Assert(starts[0] == 0);
- Contracts.AssertValue(indices);
- Contracts.Assert(starts[crow] == indices.Length);
- Contracts.AssertNonEmpty(mat);
-
- unsafe
- {
- fixed (int* pstarts = &starts[0])
- fixed (float* pmat = &mat[0])
- Thunk.ScaleMaxNormRU(maxNorm, pstarts, pmat, crow);
- }
- }
-
- public static void ScaleMaxNorm(float maxNorm, int kernCount, int kernSize, float[] mat)
- {
- Contracts.AssertNonEmpty(mat);
-
- unsafe
- {
- fixed (float* pmat = &mat[0])
- Thunk.ScaleMaxNormCU(maxNorm, kernCount, kernSize, pmat);
- }
- }
-
- public static void AddScale(float a, AlignedArray src, AlignedArray dst)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.AddScaleA(a, Ptr(src, psrc), Ptr(dst, pdst), dst.Size);
- }
- }
-
- public static void AddScale(float a, AlignedArray src, AlignedArray dst, float momentum, AlignedArray delta)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(Compat(delta));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(src.Size == delta.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- fixed (float* pdel = &delta.Items[0])
- Thunk.AddScaleMomA(a, Ptr(src, psrc), Ptr(dst, pdst), momentum, Ptr(delta, pdel), dst.Size);
- }
- }
-
public static void AddScale(float a, ReadOnlySpan<float> src, Span<float> dst, int count)
{
Contracts.AssertNonEmpty(src);
@@ -799,41 +190,6 @@ public static void AddScaleCopy(float a, ReadOnlySpan<float> src, ReadOnlySpan<float> dst, Span<float> res, int count)
{
Contracts.AssertNonEmpty(src);
@@ -883,36 +239,6 @@ public static void MulElementWise(ReadOnlySpan<float> src1, ReadOnlySpan<float> src2, Span<float> dst, int count)
}
}
- public static void MulElementWise(float[] src1, float[] src2, int[] indices, float[] dst, int count)
- {
- Contracts.AssertNonEmpty(src1);
- Contracts.Assert(0 < count && count <= src1.Length);
- Contracts.AssertNonEmpty(src2);
- Contracts.Assert(0 < count && count <= src2.Length);
- Contracts.AssertNonEmpty(dst);
- Contracts.AssertNonEmpty(indices);
- Contracts.Assert(count <= indices.Length);
- unsafe
- {
- fixed (float* ps1 = &src1[0])
- fixed (float* ps2 = &src2[0])
- fixed (int* pi = &indices[0])
- fixed (float* pd = &dst[0])
- Thunk.MulElementWiseSU(ps1, ps2, pi, pd, count);
- }
- }
-
- public static float Sum(AlignedArray src)
- {
- Contracts.Assert(Compat(src));
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- return Thunk.SumA(Ptr(src, psrc), src.Size);
- }
- }
-
public static float Sum(ReadOnlySpan<float> src)
{
Contracts.AssertNonEmpty(src);
@@ -1039,262 +365,6 @@ public static float L2DistSquared(ReadOnlySpan<float> a, ReadOnlySpan<float> b, int count)
}
}
- public static void ApplySigmoid(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplySigmoidA(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplySoftMax(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplySoftMaxA(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplyRectifiedLinear(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplyRectifiedLinearA(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplySquare(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplySquareA(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplySqrt(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplySqrtA(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplySoftRectifiedLinear(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplySoftRectifiedLinearA(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplyAbs(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplyAbsA(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplyTanh(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 < c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplyTanhA(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplyBoundedRectifiedLinear(AlignedArray src, AlignedArray dst, int c)
- {
- Contracts.Assert(Compat(src));
- Contracts.Assert(Compat(dst));
- Contracts.Assert(src.Size == dst.Size);
- Contracts.Assert(0 <= c && c <= dst.Size);
-
- unsafe
- {
- fixed (float* psrc = &src.Items[0])
- fixed (float* pdst = &dst.Items[0])
- Thunk.ApplyBoundedRectifiedLinearA(Ptr(src, psrc), Ptr(dst, pdst), c);
- }
- }
-
- public static void ApplySigmoidDerivative(AlignedArray value, AlignedArray grad)
- {
- Contracts.Assert(Compat(value));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(value.Size == grad.Size);
-
- unsafe
- {
- fixed (float* pvalue = &value.Items[0])
- fixed (float* pgrad = &grad.Items[0])
- Thunk.ApplySigmoidDerivativeA(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size);
- }
- }
-
- public static void ApplyRectifiedLinearDerivative(AlignedArray value, AlignedArray grad)
- {
- Contracts.Assert(Compat(value));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(value.Size == grad.Size);
-
- unsafe
- {
- fixed (float* pvalue = &value.Items[0])
- fixed (float* pgrad = &grad.Items[0])
- Thunk.ApplyRectifiedLinearDerivativeA(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size);
- }
- }
-
- public static void ApplySquareDerivative(AlignedArray input, AlignedArray output, AlignedArray grad, bool drop)
- {
- Contracts.Assert(Compat(input));
- Contracts.Assert(Compat(output));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(output.Size == input.Size);
- Contracts.Assert(output.Size == grad.Size);
-
- unsafe
- {
- fixed (float* px = &input.Items[0])
- fixed (float* py = &output.Items[0])
- fixed (float* pg = &grad.Items[0])
- Thunk.ApplySquareDerivativeA(Ptr(input, px), Ptr(output, py), Ptr(grad, pg), grad.Size, drop);
- }
- }
-
- public static void ApplySqrtDerivative(AlignedArray value, AlignedArray grad)
- {
- Contracts.Assert(Compat(value));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(value.Size == grad.Size);
-
- unsafe
- {
- fixed (float* pvalue = &value.Items[0])
- fixed (float* pgrad = &grad.Items[0])
- Thunk.ApplySqrtDerivativeA(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size);
- }
- }
-
- public static void ApplySoftRectifiedLinearDerivative(AlignedArray input, AlignedArray output, AlignedArray grad)
- {
- Contracts.Assert(Compat(input));
- Contracts.Assert(Compat(output));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(output.Size == input.Size);
- Contracts.Assert(output.Size == grad.Size);
-
- unsafe
- {
- fixed (float* px = &input.Items[0])
- fixed (float* py = &output.Items[0])
- fixed (float* pg = &grad.Items[0])
- Thunk.ApplySoftRectifiedLinearDerivativeA(Ptr(input, px), Ptr(output, py), Ptr(grad, pg), grad.Size);
- }
- }
-
- public static void ApplyAbsDerivative(AlignedArray input, AlignedArray output, AlignedArray grad, bool drop)
- {
- Contracts.Assert(Compat(input));
- Contracts.Assert(Compat(output));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(output.Size == input.Size);
- Contracts.Assert(output.Size == grad.Size);
-
- unsafe
- {
- fixed (float* px = &input.Items[0])
- fixed (float* py = &output.Items[0])
- fixed (float* pg = &grad.Items[0])
- Thunk.ApplyAbsDerivativeA(Ptr(input, px), Ptr(output, py), Ptr(grad, pg), grad.Size, drop);
- }
- }
-
- public static void ApplyTanhDerivative(AlignedArray value, AlignedArray grad)
- {
- Contracts.Assert(Compat(value));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(value.Size == grad.Size);
-
- unsafe
- {
- fixed (float* pvalue = &value.Items[0])
- fixed (float* pgrad = &grad.Items[0])
- Thunk.ApplyTanhDerivativeA(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size);
- }
- }
-
- public static void ApplyBoundedRectifiedLinearDerivative(AlignedArray value, AlignedArray grad)
- {
- Contracts.Assert(Compat(value));
- Contracts.Assert(Compat(grad));
- Contracts.Assert(value.Size == grad.Size);
-
- unsafe
- {
- fixed (float* pvalue = &value.Items[0])
- fixed (float* pgrad = &grad.Items[0])
- Thunk.ApplyBoundedRectifiedLinearDerivativeA(Ptr(value, pvalue), Ptr(grad, pgrad), grad.Size);
- }
- }
-
public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] indices)
{
Contracts.Assert(0 < ccol && ccol <= cfltRow);
@@ -1352,24 +422,5 @@ public static void SdcaL1UpdateSparse(float primalUpdate, int count, ReadOnlySpa
Thunk.SdcaL1UpdateSU(primalUpdate, psrc, pi, threshold, pd1, pd2, count);
}
}
-
- public static void ScaleAdadelta(float[] mat, float[] accGrads, float[] accUpdates, float decay, float cond, float[] grads)
- {
- Contracts.AssertNonEmpty(mat);
- Contracts.AssertNonEmpty(accGrads);
- Contracts.AssertNonEmpty(accUpdates);
- Contracts.Assert(mat.Length == accGrads.Length);
- Contracts.Assert(mat.Length == accUpdates.Length);
- Contracts.Assert(mat.Length <= grads.Length);
-
- unsafe
- {
- fixed (float* pm = &mat[0])
- fixed (float* pag = &accGrads[0])
- fixed (float* pau = &accUpdates[0])
- fixed (float* pg = &grads[0])
- Thunk.ScaleAdadeltaU(pm, pag, pau, decay, cond, pg, mat.Length);
- }
- }
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs
index 1a99c5bd92..dbcd69aeef 100644
--- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs
+++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs
@@ -44,26 +44,6 @@ internal static class SseIntrinsics
Sse.StaticCast<int, float>(Sse2.SetAllVector128(0x7FFFFFFF)) :
Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF));
- // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray
- private const int Vector128Alignment = 16;
-
- [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
- private static bool HasCompatibleAlignment(AlignedArray alignedArray)
- {
- Contracts.AssertValue(alignedArray);
- Contracts.Assert(alignedArray.Size > 0);
- return (alignedArray.CbAlign % Vector128Alignment) == 0;
- }
-
- [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
- private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase)
- {
- Contracts.AssertValue(alignedArray);
- float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase);
- Contracts.Assert(((long)alignedBase & (Vector128Alignment - 1)) == 0);
- return alignedBase;
- }
-
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
internal static unsafe Vector128<float> Load1(float* src, int* idx)
=> Sse.SetScalarVector128(src[idx[0]]);
@@ -137,17 +117,17 @@ internal static Vector128<float> GetNewDst128(in Vector128<float> xDst1, in Vector128<float> xThreshold)
// Multiply matrix times vector into vector.
public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
- Contracts.Assert(crow % 4 == 0);
- Contracts.Assert(ccol % 4 == 0);
-
MatMul(mat.Items, src.Items, dst.Items, crow, ccol);
}
- public static unsafe void MatMul(float[] mat, float[] src, float[] dst, int crow, int ccol)
+ public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
{
- fixed (float* psrc = &src[0])
- fixed (float* pdst = &dst[0])
- fixed (float* pmat = &mat[0])
+ Contracts.Assert(crow % 4 == 0);
+ Contracts.Assert(ccol % 4 == 0);
+
+ fixed (float* psrc = &MemoryMarshal.GetReference(src))
+ fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+ fixed (float* pmat = &MemoryMarshal.GetReference(mat))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
@@ -285,24 +265,27 @@ public static unsafe void MatMul(float[] mat, float[] src, float[] dst, int crow
}
// Partial sparse source vector.
- public static unsafe void MatMulPA(AlignedArray mat, int[] rgposSrc, AlignedArray src,
- int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
+ public static unsafe void MatMulP(AlignedArray mat, int[] rgposSrc, AlignedArray src,
+ int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
{
- Contracts.Assert(HasCompatibleAlignment(mat));
- Contracts.Assert(HasCompatibleAlignment(src));
- Contracts.Assert(HasCompatibleAlignment(dst));
+ MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol);
+ }
+
+ public static unsafe void MatMulP(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> src,
+ int posMin, int iposMin, int iposEnd, Span<float> dst, int crow, int ccol)
+ {
+ Contracts.Assert(crow % 4 == 0);
+ Contracts.Assert(ccol % 4 == 0);
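+ // rgposSrc[iposMin..iposEnd) holds the column indices of the non-zero source entries;
+ // src stores their values, offset so that column posMin corresponds to src[0].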
// REVIEW: For extremely sparse inputs, interchanging the loops would
// likely be more efficient.
- fixed (float* pSrcStart = &src.Items[0])
- fixed (float* pDstStart = &dst.Items[0])
- fixed (float* pMatStart = &mat.Items[0])
- fixed (int* pposSrc = &rgposSrc[0])
+ fixed (float* psrc = &MemoryMarshal.GetReference(src))
+ fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+ fixed (float* pmat = &MemoryMarshal.GetReference(mat))
+ fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc))
+ fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
+ fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
- float* psrc = GetAlignedBase(src, pSrcStart);
- float* pdst = GetAlignedBase(dst, pDstStart);
- float* pmat = GetAlignedBase(mat, pMatStart);
-
int* pposMin = pposSrc + iposMin;
int* pposEnd = pposSrc + iposEnd;
float* pDstEnd = pdst + crow;
@@ -310,7 +293,105 @@ public static unsafe void MatMulPA(AlignedArray mat, int[] rgposSrc, AlignedArra
float* pSrcCurrent = psrc - posMin;
float* pDstCurrent = pdst;
- while (pDstCurrent < pDstEnd)
+ nuint address = (nuint)(pDstCurrent);
+ int misalignment = (int)(address % 16);
+
+ int length = crow;
+ int remainder = 0;
+
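+ // If pdst is not float-aligned, alignment cannot be fixed by peeling rows, so every block
+ // uses unaligned stores; otherwise process a masked head, the main 4-row loop, then a masked tail.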
+ if ((misalignment & 3) != 0)
+ {
+ while (pDstCurrent < pDstEnd)
+ {
+ Sse.Store(pDstCurrent, SparseMultiplicationAcrossRow());
+ pDstCurrent += 4;
+ pm0 += 4 * ccol;
+ }
+ }
+ else
+ {
+ if (misalignment != 0)
+ {
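+ // Handle the misaligned head: mask off the lanes past the rows that are needed and advance
+ // pDstCurrent so the main loop below starts on a 16-byte boundary.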
+ misalignment >>= 2;
+ misalignment = 4 - misalignment;
+
+ Vector128<float> mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4));
+
+ float* pm1 = pm0 + ccol;
+ float* pm2 = pm1 + ccol;
+ float* pm3 = pm2 + ccol;
+ Vector128<float> result = Sse.SetZeroVector128();
+
+ int* ppos = pposMin;
+
+ while (ppos < pposEnd)
+ {
+ int col = *ppos;
+ Vector128<float> x1 = Sse.SetVector128(pm3[col], pm2[col], pm1[col], pm0[col]);
+
+ x1 = Sse.And(mask, x1);
+ Vector128<float> x2 = Sse.SetAllVector128(pSrcCurrent[col]);
+ x2 = Sse.Multiply(x2, x1);
+ result = Sse.Add(result, x2);
+ ppos++;
+ }
+
+ Sse.Store(pDstCurrent, result);
+ pDstCurrent += misalignment;
+ pm0 += misalignment * ccol;
+ length -= misalignment;
+ }
+
+ if (length > 3)
+ {
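+ // Main loop: compute four output rows per iteration via the shared helper defined below.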
+ remainder = length % 4;
+ while (pDstCurrent < pDstEnd)
+ {
+ Sse.Store(pDstCurrent, SparseMultiplicationAcrossRow());
+ pDstCurrent += 4;
+ pm0 += 4 * ccol;
+ }
+ }
+ else
+ {
+ remainder = length;
+ }
+
+ if (remainder != 0)
+ {
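+ // Step back so the final vector overlaps the last full block; trailingMask keeps the lanes
+ // that still need values and leadingMask preserves what was already stored in pdst.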
+ pDstCurrent -= (4 - remainder);
+ pm0 -= (4 - remainder) * ccol;
+ Vector128<float> trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4));
+ Vector128<float> leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4));
+
+ float* pm1 = pm0 + ccol;
+ float* pm2 = pm1 + ccol;
+ float* pm3 = pm2 + ccol;
+ Vector128<float> result = Sse.SetZeroVector128();
+
+ int* ppos = pposMin;
+
+ while (ppos < pposEnd)
+ {
+ int col = *ppos;
+ Vector128<float> x1 = Sse.SetVector128(pm3[col], pm2[col], pm1[col], pm0[col]);
+ x1 = Sse.And(x1, trailingMask);
+
+ Vector128<float> x2 = Sse.SetAllVector128(pSrcCurrent[col]);
+ x2 = Sse.Multiply(x2, x1);
+ result = Sse.Add(result, x2);
+ ppos++;
+ }
+
+ result = Sse.Add(result, Sse.And(leadingMask, Sse.LoadVector128(pDstCurrent)));
+
+ Sse.Store(pDstCurrent, result);
+ pDstCurrent += 4;
+ pm0 += 4 * ccol;
+ }
+ }
+
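+ // Local helper shared by the loops above: accumulates the products of the four consecutive
+ // matrix rows starting at pm0 with the active sparse source entries.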
+ Vector128<float> SparseMultiplicationAcrossRow()
{
float* pm1 = pm0 + ccol;
float* pm2 = pm1 + ccol;
@@ -326,29 +407,27 @@ public static unsafe void MatMulPA(AlignedArray mat, int[] rgposSrc, AlignedArra
Vector128<float> x2 = Sse.SetAllVector128(pSrcCurrent[col]);
x2 = Sse.Multiply(x2, x1);
result = Sse.Add(result, x2);
-
ppos++;
}
- Sse.StoreAligned(pDstCurrent, result);
- pDstCurrent += 4;
- pm0 += 4 * ccol;
+ return result;
}
}
}
public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
- Contracts.Assert(crow % 4 == 0);
- Contracts.Assert(ccol % 4 == 0);
MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol);
}
- public static unsafe void MatMulTran(float[] mat, float[] src, float[] dst, int crow, int ccol)
+ public static unsafe void MatMulTran(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
{
- fixed (float* psrc = &src[0])
- fixed (float* pdst = &dst[0])
- fixed (float* pmat = &mat[0])
+ Contracts.Assert(crow % 4 == 0);
+ Contracts.Assert(ccol % 4 == 0);
+
+ fixed (float* psrc = &MemoryMarshal.GetReference(src))
+ fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+ fixed (float* pmat = &MemoryMarshal.GetReference(mat))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
diff --git a/src/Microsoft.ML.CpuMath/Thunk.cs b/src/Microsoft.ML.CpuMath/Thunk.cs
index 22938bb0d7..d505bf40a3 100644
--- a/src/Microsoft.ML.CpuMath/Thunk.cs
+++ b/src/Microsoft.ML.CpuMath/Thunk.cs
@@ -12,397 +12,79 @@ internal static unsafe class Thunk
{
internal const string NativePath = "CpuMathNative";
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern bool ChkAvx();
-
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void MatMul(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulX(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulPA(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc,
+ public static extern void MatMulP(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc,
int posMin, int iposMin, int iposLim, float* pdst, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulPX(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc,
- int posMin, int iposMin, int iposLim, float* pdst, int crow, int ccol);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulRU(bool add, /*const*/ int* pstarts, /*const*/ int* pindices, /*const*/ float* pcoefs,
- /*const*/ float* psrc, float* pdst, int crow);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulRX(bool add, /*const*/ int* pstarts, /*const*/ int* pindices, /*const*/ float* pcoefs,
- /*const*/ float* psrc, float* pdst, int crow);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulCU(bool add, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol,
- /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulDU(bool add, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol, /*const*/ int* pmprowrun,
- /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulCX(bool add, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol,
- /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulDX(bool add, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol, /*const*/ int* pmprowrun,
- /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MeanU(bool add, /*const*/ int* pmprowcol, /*const*/ int* pmprowindices, /*const*/ int* pindices,
- /*const*/ float* psrc, float* pdst, int crow);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MaxU(bool add, /*const*/ int* pmprowcol, /*const*/ int* pmprowindices, /*const*/ int* pindices,
- /*const*/ float* psrc, float* pdst, int crow);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void RespNormU(bool add, float alpha, float beta, bool avgOverFullKernel, float offset,
- /*const*/ int* pmprowcol, /*const*/ int* pmprowindices, /*const*/ int* pindices,
- /*const*/ float* psrc, float* pdst, int crow);
// These treat pmat as if it is stored in column-major order. Thus, crow and ccol are the numbers of rows
// and columns from that perspective. Alternatively, crow is the number of rows in the transpose of pmat
// (thought of as row-major order).
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void MatMulTran(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulTranX(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulTranRU(bool add, /*const*/ int* pstarts, /*const*/ int* pindices, /*const*/ float* pcoefs,
- /*const*/ float* psrc, float* pdst, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulTranRX(bool add, /*const*/ int* pstarts, /*const*/ int* pindices, /*const*/ float* pcoefs,
- /*const*/ float* psrc, float* pdst, int crow, int ccol);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulTranCU(bool add, /*const*/ int* pmpcoliv, /*const*/ int* pmpcolrow,
- /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulTranDU(bool add, /*const*/ int* pmpcoliv, /*const*/ int* pmpcolrow, /*const*/ int* pmpcolrun,
- /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulTranCX(bool add, /*const*/ int* pmpcoliv, /*const*/ int* pmpcolrow,
- /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MatMulTranDX(bool add, /*const*/ int* pmpcoliv, /*const*/ int* pmpcolrow, /*const*/ int* pmpcolrun,
- /*const*/ int* pruns, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow, int ccol);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MeanBackU(bool add, /*const*/ int* pmpcolrow, /*const*/ int* pmpcolindices, /*const*/ int* pindices,
- /*const*/ float* psrc, float* pdst, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void MaxBackU(bool add, /*const*/ int* pmpcolrow, /*const*/ int* pmpcolindices, /*const*/ int* pindices,
- /*const*/ float* psrc, float* pdst, /*const*/ float* pval, int crow, int ccol);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void RespNormBackU(bool add, float alpha, float beta, bool avgOverFullKernel, float offset,
- /*const*/ int* pmpcolrow, /*const*/ int* pmpcolindices, /*const*/ int* pindices,
- /*const*/ float* perrors, float* perrorsPrev, /*const*/ float* pvaluesPrev, int crow, int ccol);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranA(float a, /*const*/ float* px, /*const*/ float* py, float* pmat, int crow, int ccol, float decay);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranX(float a, /*const*/ float* px, /*const*/ float* py, float* pmat, int crow, int ccol, float decay);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranPA(float a, /*const*/ float* px, /*const*/ int* pposY, /*const*/ float* pvaluesY,
- int posMinY, int iposMinY, int iposLimY, float* pmat, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranPX(float a, /*const*/ float* px, /*const*/ int* pposY, /*const*/ float* pvaluesY,
- int posMinY, int iposMinY, int iposLimY, float* pmat, int crow, int ccol);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranRU(float a, /*const*/ float* px, /*const*/ float* py,
- /*const*/ int* pstarts, /*const*/ int* pindices, float* pcoefs, int crow, float decay);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranRX(float a, /*const*/ float* px, /*const*/ float* py,
- /*const*/ int* pstarts, /*const*/ int* pindices, float* pcoefs, int crow, float decay);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranCU(float a, /*const*/ float* px, /*const*/ float* py, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol,
- /*const*/ int* pruns, float* pcoefs, int crow);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranDU(float a, /*const*/ float* px, /*const*/ float* py, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol,
- /*const*/ int* pmprowrun, /*const*/ int* pruns, float* pcoefs, int crow);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranCX(float a, /*const*/ float* px, /*const*/ float* py, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol,
- /*const*/ int* pruns, float* pcoefs, int crow);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranDX(float a, /*const*/ float* px, /*const*/ float* py, /*const*/ int* pmprowiv, /*const*/ int* pmprowcol,
- /*const*/ int* pmprowrun, /*const*/ int* pruns, float* pcoefs, int crow);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranMomA(float a, /*const*/ float* px, /*const*/ float* py, float* pmat, float momentum, float* pdel, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranMomX(float a, /*const*/ float* px, /*const*/ float* py, float* pmat, float momentum, float* pdel, int crow, int ccol);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranGradA(/*const*/ float* px, /*const*/ float* py, float* pmat, float* paccGrads, float* paccUpdates,
- float decay, float cond, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranGradX(/*const*/ float* px, /*const*/ float* py, float* pmat, float* paccGrads, float* paccUpdates,
- float decay, float cond, int crow, int ccol);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranGradRU(/*const*/ float* px, /*const*/ float* py, /*const*/ int* pstarts, /*const*/ int* pindices,
- float* pcoefs, float* paccGrads, float* paccUpdates, float decay, float cond, int crow);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranGradRX(/*const*/ float* px, /*const*/ float* py, /*const*/ int* pstarts, /*const*/ int* pindices,
- float* pcoefs, float* paccGrads, float* paccUpdates, float decay, float cond, int crow);
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranGradPA(/*const*/ float* px, /*const*/ int* pposY, /*const*/ float* pvaluesY,
- int posMinY, int iposMinY, int iposLimY, float* pmat, float* paccGrads, float* paccUpdates,
- float decay, float cond, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddXYTranGradPX(/*const*/ float* px, /*const*/ int* pposY, /*const*/ float* pvaluesY,
- int posMinY, int iposMinY, int iposLimY, float* pmat, float* paccGrads, float* paccUpdates,
- float decay, float cond, int crow, int ccol);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void Scale(float a, float* pd, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ScaleX(float a, float* pd, int c);
+
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void ScaleSrcU(float a, /*const*/ float* ps, float* pd, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ScaleAddU(float a, float b, float* pd, int c);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ScaleMaxNormA(float maxNorm, float* pmat, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ScaleMaxNormX(float maxNorm, float* pmat, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ScaleMaxNormTranU(float maxNorm, float* pmat, int crow, int ccol);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ScaleMaxNormRU(float maxNorm, /*const*/ int* pstarts, float* pmat, int crow);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ScaleMaxNormCU(float maxNorm, int kernCount, int kernSize, float* pmat);
+ public static extern void ScaleAddU(float a, float b, float* pd, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddScaleA(float a, /*const*/ float* ps, float* pd, int c);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void AddScaleU(float a, /*const*/ float* ps, float* pd, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddScaleX(float a, /*const*/ float* ps, float* pd, int c);
+
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void AddScaleSU(float a, /*const*/ float* ps, /*const*/ int* pi, float* pd, int c);
+
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void AddScaleCopyU(float a, /*const*/ float* ps, /*const*/ float* pd, float* pr, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddScaleMomA(float a, /*const*/ float* ps, float* pd, float momentum, float* pdel, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddScaleMomX(float a, /*const*/ float* ps, float* pd, float momentum, float* pdel, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddScaleGradA(/*const*/ float* ps, float* pd, float* paccGrads, float* paccUpdates,
- float decay, float cond, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddScaleGradX(/*const*/ float* ps, float* pd, float* paccGrads, float* paccUpdates,
- float decay, float cond, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddScaleMultiA(int count, /*const*/ float* ps, float* pd, float* paccGrads,
- float* paccUpdates, float decay, float cond, int size);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void AddScalarU(float a, float* pd, int c);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void AddU(/*const*/ float* ps, float* pd, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddA(/*const*/ float* ps, float* pd, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void AddX(/*const*/ float* ps, float* pd, int c);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void AddSU(/*const*/ float* ps, /*const*/ int* pi, float* pd, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern float SumA(/*const*/ float* ps, int c);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float SumU(/*const*/ float* ps, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern float SumX(/*const*/ float* ps, int c);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float SumSqU(/*const*/ float* ps, int c);
+
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float SumSqDiffU(float mean, /*const*/ float* ps, int c);
+
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float SumAbsU(/*const*/ float* ps, int c);
+
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float SumAbsDiffU(float mean, /*const*/ float* ps, int c);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float MulElementWiseU(/*const*/ float* ps1, /*const*/float* ps2, float* pd, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern float MulElementWiseSU(/*const*/ float* ps1, /*const*/float* ps2, /*const*/ int* pi, float* pd, int c);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float MaxAbsU(/*const*/ float* ps, int c);
+
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float MaxAbsDiffU(float mean, /*const*/ float* ps, int c);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float DotU(/*const*/ float* pa, /*const*/ float* pb, int c);
+
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float DotSU(/*const*/ float* pa, /*const*/ float* pb, /*const*/ int* pi, int c);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern float Dist2(/*const*/ float* px, /*const*/ float* py, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplySigmoidA(/*const*/ float* ps, float* pd, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySigmoidX(/*const*/ float* ps, float* pd, int c)
- {
- ApplySigmoidA(ps, pd, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplySoftMaxU(float* ps, float* pd, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySoftMaxA(float* ps, float* pd, int c)
- {
- ApplySoftMaxU(ps, pd, c);
- }
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySoftMaxX(float* ps, float* pd, int c)
- {
- ApplySoftMaxU(ps, pd, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplyRectifiedLinearA(float* ps, float* pd, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplyRectifiedLinearX(float* ps, float* pd, int c)
- {
- ApplyRectifiedLinearA(ps, pd, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplySquareA(float* ps, float* pd, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySquareX(float* ps, float* pd, int c)
- {
- ApplySquareA(ps, pd, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplySqrtA(float* ps, float* pd, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySqrtX(float* ps, float* pd, int c)
- {
- ApplySqrtA(ps, pd, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplySoftRectifiedLinearU(float* ps, float* pd, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySoftRectifiedLinearA(float* ps, float* pd, int c)
- {
- ApplySoftRectifiedLinearU(ps, pd, c);
- }
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySoftRectifiedLinearX(float* ps, float* pd, int c)
- {
- ApplySoftRectifiedLinearU(ps, pd, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplyAbsA(float* ps, float* pd, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplyAbsX(float* ps, float* pd, int c)
- {
- ApplyAbsA(ps, pd, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplyTanhA(float* ps, float* pd, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplyTanhX(float* ps, float* pd, int c)
- {
- ApplyTanhA(ps, pd, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplyBoundedRectifiedLinearA(float* ps, float* pd, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplyBoundedRectifiedLinearX(float* ps, float* pd, int c)
- {
- ApplyBoundedRectifiedLinearA(ps, pd, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplySigmoidDerivativeA(/*const*/ float* pv, float* pg, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySigmoidDerivativeX(/*const*/ float* pv, float* pg, int c)
- {
- ApplySigmoidDerivativeA(pv, pg, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplyRectifiedLinearDerivativeA(/*const*/ float* pv, float* pg, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplyRectifiedLinearDerivativeX(/*const*/ float* pv, float* pg, int c)
- {
- ApplyRectifiedLinearDerivativeA(pv, pg, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplySquareDerivativeA(/*const*/ float* px, /*const*/ float* py, float* pg, int c, bool drop);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySquareDerivativeX(/*const*/ float* px, /*const*/ float* py, float* pg, int c, bool drop)
- {
- ApplySquareDerivativeA(px, py, pg, c, drop);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplySqrtDerivativeA(/*const*/ float* pv, float* pg, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySqrtDerivativeX(/*const*/ float* pv, float* pg, int c)
- {
- ApplySqrtDerivativeA(pv, pg, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplySoftRectifiedLinearDerivativeU(/*const*/ float* px, /*const*/ float* py, float* pg, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySoftRectifiedLinearDerivativeA(/*const*/ float* px, /*const*/ float* py, float* pg, int c)
- {
- ApplySoftRectifiedLinearDerivativeU(px, py, pg, c);
- }
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplySoftRectifiedLinearDerivativeX(/*const*/ float* px, /*const*/ float* py, float* pg, int c)
- {
- ApplySoftRectifiedLinearDerivativeU(px, py, pg, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplyAbsDerivativeA(/*const*/ float* px, /*const*/ float* py, float* pg, int c, bool drop);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplyAbsDerivativeX(/*const*/ float* px, /*const*/ float* py, float* pg, int c, bool drop)
- {
- ApplyAbsDerivativeA(px, py, pg, c, drop);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplyTanhDerivativeA(/*const*/ float* pv, float* pg, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplyTanhDerivativeX(/*const*/ float* pv, float* pg, int c)
- {
- ApplyTanhDerivativeA(pv, pg, c);
- }
-
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ApplyBoundedRectifiedLinearDerivativeA(/*const*/ float* pv, float* pg, int c);
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void ApplyBoundedRectifiedLinearDerivativeX(/*const*/ float* pv, float* pg, int c)
- {
- ApplyBoundedRectifiedLinearDerivativeA(pv, pg, c);
- }
-
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void ZeroItemsU(float* pd, int c, /*const*/ int* pindices, int cindices);
@@ -411,15 +93,9 @@ public static void ApplyBoundedRectifiedLinearDerivativeX(/*const*/ float* pv, f
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
public static extern void SdcaL1UpdateU(float primalUpdate, /*const*/ float* ps, float threshold, float* pd1, float* pd2, int c);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void SdcaL1UpdateSU(float primalUpdate, /*const*/ float* ps, /*const*/ int* pi, float threshold, float* pd1, float* pd2, int c);
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ScaleAdadeltaU(float* mat, float* accGrads, float* accUpdates, float decay, float cond, float* grads, int size);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ScaleAdadeltaA(float* mat, float* accGrads, float* accUpdates, float decay, float cond, float* grads, int size);
- [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
- public static extern void ScaleAdadeltaX(float* mat, float* accGrads, float* accUpdates, float decay, float cond, float* grads, int size);
+ public static extern void SdcaL1UpdateSU(float primalUpdate, /*const*/ float* ps, /*const*/ int* pi, float threshold, float* pd1, float* pd2, int c);
#if !CORECLR
// In CoreCLR we use Buffer.MemoryCopy directly instead of
diff --git a/src/Native/CpuMathNative/Avx.cpp b/src/Native/CpuMathNative/Avx.cpp
deleted file mode 100644
index b52e28cd80..0000000000
--- a/src/Native/CpuMathNative/Avx.cpp
+++ /dev/null
@@ -1,1285 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-// The exported function names need to be unique (can't be disambiguated based on signature), hence
-// we introduce suffix letters to indicate the general patterns used.
-// * A suffix means aligned and padded for SSE operations.
-// * X suffix means aligned and padded for AVX operations.
-// * U suffix means unaligned and unpadded.
-// * S suffix means sparse (unaligned) vector.
-// * P suffix means sparse (unaligned) partial vector - the vector is only part of a larger sparse vector.
-// * R suffix means sparse matrix.
-// * C suffix means convolution matrix.
-// * D suffix means convolution matrix, with implicit source padding.
-// * Tran means the matrix is transposed.
-//
-// Other notes:
-// * AVX methods should end with _vleave() to avoid performance hit. See:
-// https://stackoverflow.com/questions/7839925/using-avx-cpu-instructions-poor-performance-without-archavx.
-// * Keep Avx.cpp in sync with Sse.cpp. Note that Avx.cpp is compiled with /arch:AVX, but Sse.cpp is not.
-
-// REVIEW: There is code below that mixes SSE and AVX instructions. Does compiling with /arch:AVX
-// make that OK? Does the code need to be rewritten?
-
-#include "../Stdafx.h"
-#include <immintrin.h>
-
-#ifndef _WIN32
-#define _mm256_set_m128(va, vb) _mm256_insertf128_ps(_mm256_castps128_ps256(vb), va, 1)
-#endif
-
-#define _vleave _mm256_zeroupper
-
-#define _get_lo(x) _mm256_extractf128_ps(x, 0)
-#define _get_hi(x) _mm256_extractf128_ps(x, 1)
-
-#define _load1(ps, pi) \
- _mm_set_ss(ps[pi[0]])
-
-#define _load4(ps, pi) \
- _mm_setr_ps(ps[pi[0]], ps[pi[1]], ps[pi[2]], ps[pi[3]])
-
-#define _load8(ps, pi) \
- _mm256_setr_ps(ps[pi[0]], ps[pi[1]], ps[pi[2]], ps[pi[3]], ps[pi[4]], ps[pi[5]], ps[pi[6]], ps[pi[7]])
-
-#define _rotate(x) _mm_shuffle_ps(x, x, 0x39)
-
-#define _store1(x, pd, pi) \
- _mm_store_ss(pd + pi[0], x)
-
-//Warning: this operation changes the value of x => do not reuse x
-#define _store4(x, pd, pi) \
- _mm_store_ss(pd + pi[0], x); \
- x = _rotate(x); _mm_store_ss(pd + pi[1], x); \
- x = _rotate(x); _mm_store_ss(pd + pi[2], x); \
- x = _rotate(x); _mm_store_ss(pd + pi[3], x)
-
-#define _store8(x, pd, pi) \
- __m128 tmp = _get_lo(x); _mm_store_ss(pd + pi[0], tmp); \
- tmp = _rotate(tmp); _mm_store_ss(pd + pi[1], tmp); \
- tmp = _rotate(tmp); _mm_store_ss(pd + pi[2], tmp); \
- tmp = _rotate(tmp); _mm_store_ss(pd + pi[3], tmp); \
- tmp = _get_hi(x); _mm_store_ss(pd + pi[4], tmp); \
- tmp = _rotate(tmp); _mm_store_ss(pd + pi[5], tmp); \
- tmp = _rotate(tmp); _mm_store_ss(pd + pi[6], tmp); \
- tmp = _rotate(tmp); _mm_store_ss(pd + pi[7], tmp)
-
-// Multiply matrix times vector into vector.
-EXPORT_API(void) MatMulX(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol)
-{
- const float * psLim = psrc + ccol;
- const float * pdLim = pdst + crow;
- const float * pm = pmat;
- for (float * pd = pdst; pd < pdLim; pd += 4, pm += 3 * ccol)
- {
- const float * ps = psrc;
- __m256 res0 = _mm256_setzero_ps();
- __m256 res1 = res0;
- __m256 res2 = res0;
- __m256 res3 = res0;
- for (; ps < psLim; ps += 8, pm += 8)
- {
- const float * pmTmp;
- __m256 x01 = _mm256_load_ps(pmTmp = pm);
- __m256 x11 = _mm256_load_ps(pmTmp += ccol);
- __m256 x21 = _mm256_load_ps(pmTmp += ccol);
- __m256 x31 = _mm256_load_ps(pmTmp += ccol);
- __m256 x02 = _mm256_load_ps(ps);
- x01 = _mm256_mul_ps(x01, x02);
- x11 = _mm256_mul_ps(x11, x02);
- x21 = _mm256_mul_ps(x21, x02);
- x31 = _mm256_mul_ps(x31, x02);
- res0 = _mm256_add_ps(res0, x01);
- res1 = _mm256_add_ps(res1, x11);
- res2 = _mm256_add_ps(res2, x21);
- res3 = _mm256_add_ps(res3, x31);
- }
-
- // Add up the entries of each, with the 4x2 results in res0
- res0 = _mm256_hadd_ps(res0, res1);
- res2 = _mm256_hadd_ps(res2, res3);
- res0 = _mm256_hadd_ps(res0, res2);
-
- __m128 sum = _mm_add_ps(_get_lo(res0), _get_hi(res0));
- if (add)
- sum = _mm_add_ps(sum, _mm_load_ps(pd));
- _mm_store_ps(pd, sum);
- }
-
- _vleave();
-}
-
-// Partial sparse source vector.
-EXPORT_API(void) MatMulPX(bool add, _In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc,
- int posMin, int iposMin, int iposLim, _Inout_ float * pdst, int crow, int ccol)
-{
- const int * pposMin = pposSrc + iposMin;
- const int * pposLim = pposSrc + iposLim;
- const float * pdLim = pdst + crow;
- const float * pm0 = pmat - posMin;
- const float * ps = psrc - posMin;
- for (float * pd = pdst; pd < pdLim; pd += 8, pm0 += 8 * ccol)
- {
- const float * pm1 = pm0 + ccol;
- const float * pm2 = pm1 + ccol;
- const float * pm3 = pm2 + ccol;
- __m256 res = _mm256_setzero_ps();
- for (const int * ppos = pposMin; ppos < pposLim; ppos++)
- {
- int col1 = *ppos;
- int col2 = col1 + 4 * ccol;
- __m256 x1 = _mm256_setr_ps(
- pm0[col1], pm1[col1], pm2[col1], pm3[col1],
- pm0[col2], pm1[col2], pm2[col2], pm3[col2]);
- __m256 x2 = _mm256_set1_ps(ps[col1]);
- x2 = _mm256_mul_ps(x2, x1);
- res = _mm256_add_ps(res, x2);
- }
-
- if (add)
- res = _mm256_add_ps(res, _mm256_load_ps(pd));
- _mm256_store_ps(pd, res);
- }
-
- _vleave();
-}
-
-// Sparse matrix.
-EXPORT_API(void) MatMulRX(bool add, _In_ const int * pstarts, _In_ const int * pindices, _In_ const float * pcoefs,
- _In_ const float * ps, _Inout_ float * pdst, int crow)
-{
- const int * pii = pstarts + 1;
- const int * pi = pindices;
- const float * pm = pcoefs;
- const float * pdLim = pdst + crow;
- for (float * pd = pdst; pd < pdLim; pd++)
- {
- const int * piLim = pindices + *pii++;
-
- __m256 res2 = _mm256_setzero_ps();
- for (; pi + 8 <= piLim; pi += 8, pm += 8)
- {
- __m256 x = _mm256_mul_ps(_load8(ps, pi), _mm256_loadu_ps(pm));
- res2 = _mm256_add_ps(res2, x);
- }
- __m128 res = _mm_add_ps(_get_lo(res2), _get_hi(res2));
- if (pi + 4 <= piLim)
- {
- __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm));
- res = _mm_add_ps(res, x);
- pi += 4; pm += 4;
- }
- for (; pi < piLim; pi++, pm++)
- {
- __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm));
- res = _mm_add_ss(res, x);
- }
- res = _mm_hadd_ps(res, res);
- res = _mm_hadd_ps(res, res);
-
- if (add)
- res = _mm_add_ss(res, _mm_set_ss(*pd));
- _mm_store_ss(pd, res);
- }
-
- _vleave();
-}
-
-// Unpadded convolution.
-EXPORT_API(void) MatMulCX(bool add, _In_ const int * pmprowiv, _In_ const int * pmprowcol,
- _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow)
-{
- int size = pruns[1];
- const int * psupport = pruns + 2;
- const int * piv = pmprowiv;
- const int * pcol = pmprowcol;
- const int * piLim = psupport + size;
- const float * pdLim = pdst + crow;
-
- for (float * pd = pdst; pd < pdLim; pd++)
- {
- const float * pm = pcoefs + *piv++;
- const float * ps = psrc + *pcol++;
- const int * pi = psupport;
-
- __m256 res2 = _mm256_setzero_ps();
- for (; pi + 8 <= piLim; pi += 8, pm += 8)
- {
- __m256 x = _mm256_mul_ps(_load8(ps, pi), _mm256_loadu_ps(pm));
- res2 = _mm256_add_ps(res2, x);
- }
- __m128 res = _mm_add_ps(_get_lo(res2), _get_hi(res2));
- if (pi + 4 <= piLim)
- {
- __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm));
- res = _mm_add_ps(res, x);
- pi += 4; pm += 4;
- }
- for (; pi < piLim; pi++, pm++)
- {
- __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm));
- res = _mm_add_ss(res, x);
- }
- res = _mm_hadd_ps(res, res);
- res = _mm_hadd_ps(res, res);
-
- // Add the bias.
- res = _mm_add_ss(res, _mm_set_ss(*pm));
-
- if (add)
- res = _mm_add_ss(res, _mm_set_ss(*pd));
- _mm_store_ss(pd, res);
- }
-
- _vleave();
-}
-
-// Padded convolution.
-EXPORT_API(void) MatMulDX(bool add, _In_ const int * pmprowiv, _In_ const int * pmprowcol, _In_ const int * pmprowrun,
- _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow)
-{
- const int * piv = pmprowiv;
- const int * pcol = pmprowcol;
- const float * pdLim = pdst + crow;
- int kernelSize = pruns[1];
-
- const int * pirun = pmprowrun;
- for (float * pd = pdst; pd < pdLim; pd++)
- {
- const float * pm = pcoefs + *piv++;
- const float * pmBias = pm + kernelSize;
- const float * ps = psrc + *pcol++;
- int irun = *pirun++;
-
- const int * pi = pruns + 2 + irun;
- const int * piLim = pi + pi[-1];
- __m256 res2 = _mm256_setzero_ps();
- __m128 res;
- if (irun == 0)
- {
- // No masking needed.
- for (; pi + 8 <= piLim; pi += 8, pm += 8)
- {
- __m256 x = _mm256_mul_ps(_load8(ps, pi), _mm256_loadu_ps(pm));
- res2 = _mm256_add_ps(res2, x);
- }
- res = _mm_add_ps(_get_lo(res2), _get_hi(res2));
- if (pi + 4 <= piLim)
- {
- __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm));
- res = _mm_add_ps(res, x);
- pi += 4; pm += 4;
- }
- for (; pi < piLim; pi++, pm++)
- {
- __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm));
- res = _mm_add_ss(res, x);
- }
- }
- else
- {
- // Need masking.
- pm += pi[-2];
- const float * pmask = reinterpret_cast<const float *>(piLim);
- for (; pi + 8 <= piLim; pi += 8, pm += 8, pmask += 8)
- {
- __m256 x = _mm256_mul_ps(_load8(ps, pi), _mm256_and_ps(_mm256_loadu_ps(pmask), _mm256_loadu_ps(pm)));
- res2 = _mm256_add_ps(res2, x);
- }
- res = _mm_add_ps(_get_lo(res2), _get_hi(res2));
- if (pi + 4 <= piLim)
- {
- __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_and_ps(_mm_loadu_ps(pmask), _mm_loadu_ps(pm)));
- res = _mm_add_ps(res, x);
- pi += 4; pm += 4; pmask += 4;
- }
- for (; pi < piLim; pi++, pm++, pmask++)
- {
- __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_and_ps(_mm_set_ss(*pmask), _mm_set_ss(*pm)));
- res = _mm_add_ss(res, x);
- }
- }
- res = _mm_hadd_ps(res, res);
- res = _mm_hadd_ps(res, res);
-
- res = _mm_add_ss(res, _mm_set_ss(*pmBias));
- if (add)
- res = _mm_add_ss(res, _mm_set_ss(*pd));
- _mm_store_ss(pd, res);
- }
-
- _vleave();
-}
-
-EXPORT_API(void) MatMulTranX(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol)
-{
- const float * psLim = psrc + ccol;
- const float * pdLim = pdst + crow;
- const float * pm = pmat;
- const float * ps = psrc;
-
- // We do 4-way unrolling
- if (!add)
- {
- __m128 h01 = _mm_load_ps(ps);
- // Replicate each slot of x01 into its own register.
- __m128 h11 = _mm_shuffle_ps(h01, h01, 0x55);
- __m128 h21 = _mm_shuffle_ps(h01, h01, 0xAA);
- __m128 h31 = _mm_shuffle_ps(h01, h01, 0xFF);
- h01 = _mm_shuffle_ps(h01, h01, 0x00);
-
- __m256 x01 = _mm256_set_m128(h01, h01);
- __m256 x11 = _mm256_set_m128(h11, h11);
- __m256 x21 = _mm256_set_m128(h21, h21);
- __m256 x31 = _mm256_set_m128(h31, h31);
- ps += 4;
-
- for (float * pd = pdst; pd < pdLim; pd += 8, pm += 8)
- {
- const float * pmTmp;
- __m256 x02 = _mm256_load_ps(pmTmp = pm);
- __m256 x12 = _mm256_load_ps(pmTmp += crow);
- __m256 x22 = _mm256_load_ps(pmTmp += crow);
- __m256 x32 = _mm256_load_ps(pmTmp += crow);
- x02 = _mm256_mul_ps(x01, x02);
- x12 = _mm256_mul_ps(x11, x12);
- x22 = _mm256_mul_ps(x21, x22);
- x32 = _mm256_mul_ps(x31, x32);
- x02 = _mm256_add_ps(x02, x12);
- x22 = _mm256_add_ps(x22, x32);
- x02 = _mm256_add_ps(x02, x22);
- _mm256_store_ps(pd, x02);
- }
-
- pm += 3 * crow;
- }
-
- for (; ps < psLim; ps += 4)
- {
- __m128 h01 = _mm_load_ps(ps);
- // Replicate each slot of x01 into its own register.
- __m128 h11 = _mm_shuffle_ps(h01, h01, 0x55);
- __m128 h21 = _mm_shuffle_ps(h01, h01, 0xAA);
- __m128 h31 = _mm_shuffle_ps(h01, h01, 0xFF);
- h01 = _mm_shuffle_ps(h01, h01, 0x00);
-
- __m256 x01 = _mm256_set_m128(h01, h01);
- __m256 x11 = _mm256_set_m128(h11, h11);
- __m256 x21 = _mm256_set_m128(h21, h21);
- __m256 x31 = _mm256_set_m128(h31, h31);
-
- for (float * pd = pdst; pd < pdLim; pd += 8, pm += 8)
- {
- const float * pmTmp;
- __m256 x02 = _mm256_load_ps(pmTmp = pm);
- __m256 x12 = _mm256_load_ps(pmTmp += crow);
- __m256 x22 = _mm256_load_ps(pmTmp += crow);
- __m256 x32 = _mm256_load_ps(pmTmp += crow);
- __m256 x3 = _mm256_load_ps(pd);
- x02 = _mm256_mul_ps(x01, x02);
- x12 = _mm256_mul_ps(x11, x12);
- x22 = _mm256_mul_ps(x21, x22);
- x32 = _mm256_mul_ps(x31, x32);
- x02 = _mm256_add_ps(x02, x12);
- x22 = _mm256_add_ps(x22, x32);
- x02 = _mm256_add_ps(x02, x22);
- x3 = _mm256_add_ps(x02, x3);
- _mm256_store_ps(pd, x3);
- }
-
- pm += 3 * crow;
- }
-
- _vleave();
-}
-
-// Sparse matrix.
-EXPORT_API(void) MatMulTranRX(bool add, _In_ const int * pstarts, _In_ const int * pindices, _In_ const float * pcoefs,
- _In_ const float * psrc, _Inout_ float * pd, int crow, int ccol)
-{
- if (!add)
- memset(pd, 0, crow * sizeof(float));
-
- const int * pii = pstarts + 1;
- const int * pi = pindices;
- const float * pm = pcoefs;
- const float * psLim = psrc + ccol;
- for (const float * ps = psrc; ps < psLim; ps++)
- {
- float x = *ps;
- const int * piLim = pindices + *pii++;
-
- __m128 x0 = _mm_set1_ps(x);
- __m256 x1 = _mm256_set_m128(x0, x0);
- for (; pi + 8 <= piLim; pi += 8, pm += 8)
- {
- __m256 x2 = _mm256_mul_ps(x1, _mm256_loadu_ps(pm));
- x2 = _mm256_add_ps(x2, _load8(pd, pi));
- _store8(x2, pd, pi);
- }
- if (pi + 4 <= piLim)
- {
- __m128 x2 = _mm_mul_ps(x0, _mm_loadu_ps(pm));
- x2 = _mm_add_ps(x2, _load4(pd, pi));
- _store4(x2, pd, pi);
- pi += 4; pm += 4;
- }
- for (; pi < piLim; pi++, pm++)
- {
- __m128 x2 = _mm_mul_ss(x0, _mm_set_ss(*pm));
- x2 = _mm_add_ss(x2, _load1(pd, pi));
- _store1(x2, pd, pi);
- }
- }
-
- _vleave();
-}
-
-// Unpadded convolution.
-EXPORT_API(void) MatMulTranCX(bool add, _In_ const int * pmpcoliv, _In_ const int * pmpcolrow,
- _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol)
-{
- if (!add)
- memset(pdst, 0, crow * sizeof(float));
-
- int size = pruns[1];
- const int * psupport = pruns + 2;
- const int * piv = pmpcoliv;
- const int * prow = pmpcolrow;
- const int * piLim = psupport + size;
- const float * psLim = psrc + ccol;
- for (const float * ps = psrc; ps < psLim; ps++)
- {
- const float * pm = pcoefs + *piv++;
- float * pd = pdst + *prow++;
- const int * pi = psupport;
-
- float x = *ps;
- __m128 x0 = _mm_set1_ps(x);
- __m256 x1 = _mm256_set_m128(x0, x0);
- for (; pi + 8 <= piLim; pi += 8, pm += 8)
- {
- __m256 x2 = _mm256_mul_ps(x1, _mm256_loadu_ps(pm));
- x2 = _mm256_add_ps(x2, _load8(pd, pi));
- _store8(x2, pd, pi);
- }
- if (pi + 4 <= piLim)
- {
- __m128 x2 = _mm_mul_ps(x0, _mm_loadu_ps(pm));
- x2 = _mm_add_ps(x2, _load4(pd, pi));
- _store4(x2, pd, pi);
- pi += 4; pm += 4;
- }
- for (; pi < piLim; pi++, pm++)
- {
- __m128 x2 = _mm_mul_ss(x0, _mm_set_ss(*pm));
- x2 = _mm_add_ss(x2, _load1(pd, pi));
- _store1(x2, pd, pi);
- }
- }
-
- _vleave();
-}
-
-// Padded convolution.
-EXPORT_API(void) MatMulTranDX(bool add, _In_ const int * pmpcoliv, _In_ const int * pmpcolrow, _In_ const int * pmpcolrun,
- _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol)
-{
- if (!add)
- memset(pdst, 0, crow * sizeof(float));
-
- const int * piv = pmpcoliv;
- const int * prow = pmpcolrow;
- const float * psLim = psrc + ccol;
- int kernelSize = pruns[1];
-
- const int * pirun = pmpcolrun;
- for (const float * ps = psrc; ps < psLim; ps++)
- {
- const float * pm = pcoefs + *piv++;
- float * pd = pdst + *prow++;
- int irun = *pirun++;
- const int * pi = pruns + 2 + irun;
- const int * piLim = pi + pi[-1];
-
- float x = *ps;
- __m128 x0 = _mm_set1_ps(x);
- __m256 x1 = _mm256_set_m128(x0, x0);
- if (irun == 0)
- {
- // No masking needed.
- for (; pi + 8 <= piLim; pi += 8, pm += 8)
- {
- __m256 x2 = _mm256_mul_ps(x1, _mm256_loadu_ps(pm));
- x2 = _mm256_add_ps(x2, _load8(pd, pi));
- _store8(x2, pd, pi);
- }
- if (pi + 4 <= piLim)
- {
- __m128 x2 = _mm_mul_ps(x0, _mm_loadu_ps(pm));
- x2 = _mm_add_ps(x2, _load4(pd, pi));
- _store4(x2, pd, pi);
- pi += 4; pm += 4;
- }
- for (; pi < piLim; pi++, pm++)
- {
- __m128 x2 = _mm_mul_ss(x0, _mm_set_ss(*pm));
- x2 = _mm_add_ss(x2, _load1(pd, pi));
- _store1(x2, pd, pi);
- }
- }
- else
- {
- // Need masking.
- pm += pi[-2];
- const float * pmask = reinterpret_cast<const float *>(piLim);
- for (; pi + 8 <= piLim; pi += 8, pm += 8, pmask += 8)
- {
- __m256 x2 = _mm256_mul_ps(_mm256_and_ps(_mm256_loadu_ps(pmask), x1), _mm256_loadu_ps(pm));
- x2 = _mm256_add_ps(x2, _load8(pd, pi));
- _store8(x2, pd, pi);
- }
- if (pi + 4 <= piLim)
- {
- __m128 x2 = _mm_mul_ps(_mm_and_ps(_mm_loadu_ps(pmask), x0), _mm_loadu_ps(pm));
- x2 = _mm_add_ps(x2, _load4(pd, pi));
- _store4(x2, pd, pi);
- pi += 4; pm += 4; pmask += 4;
- }
- for (; pi < piLim; pi++, pm++, pmask++)
- {
- __m128 x2 = _mm_mul_ss(_mm_and_ps(_mm_set_ss(*pmask), x0), _mm_set_ss(*pm));
- x2 = _mm_add_ss(x2, _load1(pd, pi));
- _store1(x2, pd, pi);
- }
- }
- }
-
- _vleave();
-}
-
-template <bool useDecay>
-void AddXYTranXCore(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, int crow, int ccol, float decay)
-{
- const float * pyBase = py;
- const float * pxLim = px + crow;
- const float * pyLim = py + ccol;
- float * pm = pmat;
- __m256 wd;
- if (useDecay)
- wd = _mm256_set1_ps(1 - decay);
- for (; px < pxLim; px++)
- {
- float r = a * *px;
- py = pyBase;
-
- __m256 x1 = _mm256_set1_ps(r);
- for (; py + 32 <= pyLim; py += 32, pm += 32)
- {
- __m256 x02 = _mm256_load_ps(py);
- __m256 x12 = _mm256_load_ps(py + 8);
- __m256 x22 = _mm256_load_ps(py + 16);
- __m256 x32 = _mm256_load_ps(py + 24);
- __m256 x03 = _mm256_load_ps(pm);
- __m256 x13 = _mm256_load_ps(pm + 8);
- __m256 x23 = _mm256_load_ps(pm + 16);
- __m256 x33 = _mm256_load_ps(pm + 24);
- x02 = _mm256_mul_ps(x1, x02);
- x12 = _mm256_mul_ps(x1, x12);
- x22 = _mm256_mul_ps(x1, x22);
- x32 = _mm256_mul_ps(x1, x32);
- if (useDecay)
- {
- x03 = _mm256_mul_ps(wd, x03);
- x13 = _mm256_mul_ps(wd, x13);
- x23 = _mm256_mul_ps(wd, x23);
- x33 = _mm256_mul_ps(wd, x33);
- }
- x03 = _mm256_add_ps(x02, x03);
- x13 = _mm256_add_ps(x12, x13);
- x23 = _mm256_add_ps(x22, x23);
- x33 = _mm256_add_ps(x32, x33);
- _mm256_store_ps(pm, x03);
- _mm256_store_ps(pm + 8, x13);
- _mm256_store_ps(pm + 16, x23);
- _mm256_store_ps(pm + 24, x33);
- }
- for (; py < pyLim; py += 8, pm += 8)
- {
- __m256 x02 = _mm256_load_ps(py);
- __m256 x03 = _mm256_load_ps(pm);
- x02 = _mm256_mul_ps(x1, x02);
- if (useDecay)
- x03 = _mm256_mul_ps(wd, x03);
- x03 = _mm256_add_ps(x02, x03);
- _mm256_store_ps(pm, x03);
- }
- }
-
- _vleave();
-}
-
-EXPORT_API(void) AddXYTranX(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, int crow, int ccol, float decay)
-{
- if (decay == 0)
- AddXYTranXCore<false>(a, px, py, pmat, crow, ccol, decay);
- else
- AddXYTranXCore<true>(a, px, py, pmat, crow, ccol, decay);
-}
-
-// Partial sparse source vector.
-EXPORT_API(void) AddXYTranPX(float a, _In_ const float * px, _In_ const int * pposY, _In_ const float * pvaluesY,
- int posMinY, int iposMinY, int iposLimY, _Inout_ float * pmat, int crow, int ccol)
-{
- const int * pposMin = pposY + iposMinY;
- const int * pposLim = pposY + iposLimY;
- const float * pxLim = px + crow;
- float * pm0 = pmat - posMinY;
- const float * py = pvaluesY - posMinY;
-
- __m256 x0 = _mm256_set1_ps(a);
- for (; px < pxLim; px += 8, pm0 += 8 * ccol)
- {
- float * pm1 = pm0 + ccol;
- float * pm2 = pm1 + ccol;
- float * pm3 = pm2 + ccol;
-
- __m256 x1 = _mm256_load_ps(px);
- x1 = _mm256_mul_ps(x1, x0);
-
- for (const int * ppos = pposMin; ppos < pposLim; ppos++)
- {
- int col1 = *ppos;
- int col2 = col1 + 4 * ccol;
- __m256 x2 = _mm256_set1_ps(py[col1]);
- __m256 x3 = _mm256_setr_ps(
- pm0[col1], pm1[col1], pm2[col1], pm3[col1],
- pm0[col2], pm1[col2], pm2[col2], pm3[col2]);
- x2 = _mm256_mul_ps(x2, x1);
- x3 = _mm256_add_ps(x3, x2);
-
- __m128 t1 = _get_lo(x3);
- __m128 t2 = _get_hi(x3);
- _mm_store_ss(pm0 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pm1 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pm2 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pm3 + col1, t1);
- _mm_store_ss(pm0 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pm1 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pm2 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pm3 + col2, t2);
- }
- }
-
- _vleave();
-}
-
-template <bool useDecay>
-void AddXYTranRXCore(float a, _In_ const float * px, _In_ const float * py,
- _In_ const int * pstarts, _In_ const int * pindices, _Inout_ float * pcoefs, int crow, float decay)
-{
- const int * pii = pstarts + 1;
- const int * pi = pindices;
- float * pm = pcoefs;
- const float * pxLim = px + crow;
- __m128 wd0;
- __m256 wd1;
- if (useDecay)
- {
- wd0 = _mm_set1_ps(1 - decay);
- wd1 = _mm256_set_m128(wd0, wd0);
- }
- for (; px < pxLim; px++)
- {
- const int * piLim = pindices + *pii++;
- float r = a * *px;
-
- __m128 x0 = _mm_set1_ps(r);
- __m256 x1 = _mm256_set_m128(x0, x0);
- for (; pi + 8 <= piLim; pi += 8, pm += 8)
- {
- __m256 x2 = _mm256_mul_ps(x1, _load8(py, pi));
- __m256 x3 = _mm256_loadu_ps(pm);
- if (useDecay)
- x3 = _mm256_mul_ps(x3, wd1);
- x2 = _mm256_add_ps(x2, x3);
- _mm256_storeu_ps(pm, x2);
- }
- if (pi + 4 <= piLim)
- {
- __m128 x2 = _mm_mul_ps(x0, _load4(py, pi));
- __m128 x3 = _mm_loadu_ps(pm);
- if (useDecay)
- x3 = _mm_mul_ps(x3, wd0);
- x2 = _mm_add_ps(x2, x3);
- _mm_storeu_ps(pm, x2);
- pi += 4; pm += 4;
- }
- for (; pi < piLim; pi++, pm++)
- *pm = (useDecay ? (*pm * (1 - decay)) : *pm) + py[*pi] * r;
- }
-
- _vleave();
-}
-
-// Sparse matrix.
-EXPORT_API(void) AddXYTranRX(float a, _In_ const float * px, _In_ const float * py,
- _In_ const int * pstarts, _In_ const int * pindices, _Inout_ float * pcoefs, int crow, float decay)
-{
- if (decay == 0)
- AddXYTranRXCore<false>(a, px, py, pstarts, pindices, pcoefs, crow, decay);
- else
- AddXYTranRXCore<true>(a, px, py, pstarts, pindices, pcoefs, crow, decay);
-}
-
-// Unpadded convolution.
-EXPORT_API(void) AddXYTranCX(float a, _In_ const float * px, _In_ const float * py, _In_ const int * pmprowiv, _In_ const int * pmprowcol,
- _In_ const int * pruns, _Inout_ float * pcoefs, int crow)
-{
- int size = pruns[1];
- const int * psupport = pruns + 2;
- const int * piv = pmprowiv;
- const int * pcol = pmprowcol;
- const float * pxLim = px + crow;
- const int * piLim = psupport + size;
-
- for (; px < pxLim; px++)
- {
- float * pm = pcoefs + *piv++;
- const float * ps = py + *pcol++;
- const int * pi = psupport;
- float r = a * *px;
-
- __m128 x0 = _mm_set1_ps(r);
- __m256 x1 = _mm256_set_m128(x0, x0);
- for (; pi + 8 <= piLim; pi += 8, pm += 8)
- {
- __m256 x2 = _mm256_mul_ps(x1, _load8(ps, pi));
- x2 = _mm256_add_ps(x2, _mm256_loadu_ps(pm));
- _mm256_storeu_ps(pm, x2);
- }
- if (pi + 4 <= piLim)
- {
- __m128 x2 = _mm_mul_ps(x0, _load4(ps, pi));
- x2 = _mm_add_ps(x2, _mm_loadu_ps(pm));
- _mm_storeu_ps(pm, x2);
- pi += 4; pm += 4;
- }
- for (; pi < piLim; pi++, pm++)
- *pm += ps[*pi] * r;
- // Update the bias.
- *pm += r;
- }
-
- _vleave();
-}
-
-// Padded convolution.
-EXPORT_API(void) AddXYTranDX(float a, _In_ const float * px, _In_ const float * py, _In_ const int * pmprowiv, _In_ const int * pmprowcol,
- _In_ const int * pmprowrun, _In_ const int * pruns, _Inout_ float * pcoefs, int crow)
-{
- const int * piv = pmprowiv;
- const int * pcol = pmprowcol;
- const float * pxLim = px + crow;
- int kernelSize = pruns[1];
-
- const int * pirun = pmprowrun;
- for (; px < pxLim; px++)
- {
- float * pm = pcoefs + *piv++;
- const float * ps = py + *pcol++;
- int irun = *pirun++;
- const int * pi = pruns + 2 + irun;
- const int * piLim = pi + pi[-1];
-
- float r = a * *px;
-
- // Update the bias.
- pm[kernelSize] += r;
-
- __m128 x0 = _mm_set1_ps(r);
- __m256 x1 = _mm256_set_m128(x0, x0);
- if (irun == 0)
- {
- // No masking needed.
- for (; pi + 8 <= piLim; pi += 8, pm += 8)
- {
- __m256 x2 = _mm256_mul_ps(x1, _load8(ps, pi));
- x2 = _mm256_add_ps(x2, _mm256_loadu_ps(pm));
- _mm256_storeu_ps(pm, x2);
- }
- if (pi + 4 <= piLim)
- {
- __m128 x2 = _mm_mul_ps(x0, _load4(ps, pi));
- x2 = _mm_add_ps(x2, _mm_loadu_ps(pm));
- _mm_storeu_ps(pm, x2);
- pi += 4; pm += 4;
- }
- for (; pi < piLim; pi++, pm++)
- *pm += ps[*pi] * r;
- }
- else
- {
- // Need masking.
- pm += pi[-2];
- const float * pmask = reinterpret_cast<const float *>(piLim);
- for (; pi + 8 <= piLim; pi += 8, pm += 8, pmask += 8)
- {
- __m256 x2 = _mm256_mul_ps(_mm256_and_ps(_mm256_loadu_ps(pmask), x1), _load8(ps, pi));
- x2 = _mm256_add_ps(x2, _mm256_loadu_ps(pm));
- _mm256_storeu_ps(pm, x2);
- }
- if (pi + 4 <= piLim)
- {
- __m128 x2 = _mm_mul_ps(_mm_and_ps(_mm_loadu_ps(pmask), x0), _load4(ps, pi));
- x2 = _mm_add_ps(x2, _mm_loadu_ps(pm));
- _mm_storeu_ps(pm, x2);
- pi += 4; pm += 4; pmask += 4;
- }
- for (; pi < piLim; pi++, pm++, pmask++)
- {
- __m128 x2 = _mm_mul_ss(_mm_and_ps(_mm_set_ss(*pmask), x0), _load1(ps, pi));
- x2 = _mm_add_ss(x2, _mm_set_ss(*pm));
- _mm_store_ss(pm, x2);
- }
- }
- }
-
- _vleave();
-}
-
-// With momentum.
-EXPORT_API(void) AddXYTranMomX(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, float momentum, _Inout_ float * pdel, int crow, int ccol)
-{
- const float * pyBase = py;
- const float * pxLim = px + crow;
- const float * pyLim = py + ccol;
- float * pm = pmat;
- float * pd = pdel;
-
- __m256 x0 = _mm256_set1_ps(momentum);
- for (; px < pxLim; px++)
- {
- float r = a * *px;
-
- __m256 x1 = _mm256_set1_ps(r);
- for (py = pyBase; py < pyLim; pm += 8, pd += 8, py += 8)
- {
- __m256 x2 = _mm256_load_ps(py);
- __m256 x3 = _mm256_load_ps(pd);
- __m256 x4 = _mm256_load_ps(pm);
- x2 = _mm256_mul_ps(x1, x2);
- x3 = _mm256_mul_ps(x0, x3);
- x3 = _mm256_add_ps(x2, x3);
- x4 = _mm256_add_ps(x3, x4);
-
- _mm256_store_ps(pd, x3);
- _mm256_store_ps(pm, x4);
- }
- }
-
- _vleave();
-}
-
-// coef: coefs to update, ag: accumulated grads, au: accumulated updates, g: cur grads.
-// Note: parameters coef, ag, au and g will be updated, do not reuse parameter g in calling code.
-__forceinline void UpdateAdadelta(__m256& coef, __m256& ag, __m256& au, __m256& g, const __m256& dec, const __m256& decc, const __m256& c)
-{
- __m256 x4 = _mm256_mul_ps(g, g); // x4 == g * g
- x4 = _mm256_mul_ps(decc, x4); // x4 == (1 - decay) * g * g
- ag = _mm256_mul_ps(dec, ag); // ag == decay * accG
- ag = _mm256_add_ps(ag, x4); // ag == decay * accG + (1 - decay) * g * g
- __m256 x41 = _mm256_add_ps(ag, c); // x41 == ag + cond
- __m256 x51 = _mm256_add_ps(au, c); // x51 == accU + cond
-#if 0
- // naive version:
- x51 = _mm256_div_ps(x51, x41);
- x41 = _mm256_sqrt_ps(x51); // x41 == rate
-#else
- // faster (approximate) version:
- x41 = _mm256_rsqrt_ps(x41);
- __m256 x52 = _mm256_rsqrt_ps(x51);
- x51 = _mm256_mul_ps(x51, x52);
- x41 = _mm256_mul_ps(x41, x51); // x41 == rate
-#endif
- g = _mm256_mul_ps(g, x41); // g - current update
- coef = _mm256_add_ps(coef, g);
-
- g = _mm256_mul_ps(g, g); // g == newU * newU
- g = _mm256_mul_ps(decc, g); // g == (1 - decay) * newU * newU
- au = _mm256_mul_ps(dec, au); // au == decay * accU
- au = _mm256_add_ps(au, g); // au == decay * accU + (1 - decay) * newU * newU
-}
-
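For reference, a scalar sketch of the Adadelta step that UpdateAdadelta vectorizes (illustrative only; it mirrors the scalar tail loops later in this file, while the AVX path replaces the sqrt/divide with an _mm256_rsqrt_ps-based approximation):

    #include <cmath>

    static void UpdateAdadeltaScalar(float& coef, float& accG, float& accU, float g,
                                     float decay, float cond)
    {
        accG = decay * accG + (1 - decay) * g * g;              // decayed average of squared gradients
        float rate = std::sqrt((accU + cond) / (accG + cond));  // the "naive" rate above
        float upd = rate * g;                                   // current update
        coef += upd;
        accU = decay * accU + (1 - decay) * upd * upd;          // decayed average of squared updates
    }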
-// For Adadelta.
-EXPORT_API(void) AddXYTranGradX(_In_ const float * px, _In_ const float * py, _Inout_ float * pmat, _Inout_ float * paccGrads, _Inout_ float * paccUpdates,
- float decay, float cond, int crow, int ccol)
-{
- const float * pyBase = py;
- const float * pxLim = px + crow;
- const float * pyLim = py + ccol;
- float * pm = pmat;
- float * pag = paccGrads;
- float * pau = paccUpdates;
-
- __m256 dec = _mm256_set1_ps(decay);
- __m256 decc = _mm256_set1_ps(1 - decay);
- __m256 c = _mm256_set1_ps(cond);
- for (; px < pxLim; px++)
- {
- float r = *px;
-
- __m256 x1 = _mm256_set1_ps(r);
- for (py = pyBase; py < pyLim; pm += 8, pag += 8, pau += 8, py += 8)
- {
- __m256 x2 = _mm256_load_ps(py);
- __m256 ag = _mm256_load_ps(pag);
- __m256 au = _mm256_load_ps(pau);
- __m256 coef = _mm256_load_ps(pm);
- x2 = _mm256_mul_ps(x1, x2); // x2 == g
-
- UpdateAdadelta(coef, ag, au, x2, dec, decc, c);
-
- _mm256_store_ps(pm, coef);
- _mm256_store_ps(pag, ag);
- _mm256_store_ps(pau, au);
- }
- }
-
- _vleave();
-}
-
-// For Adadelta, sparse matrix.
-EXPORT_API(void) AddXYTranGradRX(_In_ const float * px, _In_ const float * py, _In_ const int * pstarts, _In_ const int * pindices,
- _Inout_ float * pcoefs, _Inout_ float * paccGrads, _Inout_ float * paccUpdates, float decay, float cond, int crow)
-{
- const int * pii = pstarts + 1;
- const int * pi = pindices;
- float * pm = pcoefs;
- const float * pxLim = px + crow;
- float * pag = paccGrads;
- float * pau = paccUpdates;
-
- __m256 dec = _mm256_set1_ps(decay);
- __m256 decc = _mm256_set1_ps(1 - decay);
- __m256 c = _mm256_set1_ps(cond);
-
- for (; px < pxLim; px++)
- {
- const int * piLim = pindices + *pii++;
- float r = *px;
-
- __m256 x1 = _mm256_set1_ps(r);
- for (; pi + 8 <= piLim; pi += 8, pm += 8, pag += 8, pau += 8)
- {
- __m256 g = _mm256_mul_ps(x1, _load8(py, pi));
- __m256 ag = _mm256_loadu_ps(pag);
- __m256 au = _mm256_loadu_ps(pau);
- __m256 coef = _mm256_loadu_ps(pm);
-
- UpdateAdadelta(coef, ag, au, g, dec, decc, c);
-
- _mm256_storeu_ps(pm, coef);
- _mm256_storeu_ps(pag, ag);
- _mm256_storeu_ps(pau, au);
- }
-
- // REVIEW: Why is this so different than the SSE version?
- for (; pi < piLim; pi++, pm++, pag++, pau++)
- {
- float g = py[*pi] * r;
- float accGrad = decay * *pag + (1 - decay) * g * g;
- float accUpd = *pau;
- float newUpd = sqrtf((accUpd + cond) / (accGrad + cond)) * g;
- *pm += newUpd;
- *pag = accGrad;
- *pau = decay * accUpd + (1 - decay) * newUpd * newUpd;
- }
- }
-
- _vleave();
-}
-
-// For Adadelta, partial sparse source vector.
-EXPORT_API(void) AddXYTranGradPX(_In_ const float * px, _In_ const int * pposY, _In_ const float * pvaluesY,
- int posMinY, int iposMinY, int iposLimY, _Inout_ float * pmat, _Inout_ float * paccGrads, _Inout_ float * paccUpdates,
- float decay, float cond, int crow, int ccol)
-{
- const int * pposMin = pposY + iposMinY;
- const int * pposLim = pposY + iposLimY;
- const float * pxLim = px + crow;
- const float * py = pvaluesY - posMinY;
- float * pm0 = pmat - posMinY;
- float * pag0 = paccGrads - posMinY;
- float * pau0 = paccUpdates - posMinY;
-
- __m256 dec = _mm256_set1_ps(decay);
- __m256 decc = _mm256_set1_ps(1 - decay);
- __m256 c = _mm256_set1_ps(cond);
- for (; px < pxLim; px += 8, pm0 += 8 * ccol, pag0 += 8 * ccol, pau0 += 8 * ccol)
- {
- float * pm1 = pm0 + ccol;
- float * pm2 = pm1 + ccol;
- float * pm3 = pm2 + ccol;
-
- float * pag1 = pag0 + ccol;
- float * pag2 = pag1 + ccol;
- float * pag3 = pag2 + ccol;
-
- float * pau1 = pau0 + ccol;
- float * pau2 = pau1 + ccol;
- float * pau3 = pau2 + ccol;
-
- __m256 x1 = _mm256_load_ps(px);
-
- for (const int * ppos = pposMin; ppos < pposLim; ppos++)
- {
- int col1 = *ppos;
- int col2 = col1 + 4 * ccol;
- __m256 x2 = _mm256_set1_ps(py[col1]);
- __m256 ag = _mm256_setr_ps(
- pag0[col1], pag1[col1], pag2[col1], pag3[col1],
- pag0[col2], pag1[col2], pag2[col2], pag3[col2]);
- __m256 au = _mm256_setr_ps(
- pau0[col1], pau1[col1], pau2[col1], pau3[col1],
- pau0[col2], pau1[col2], pau2[col2], pau3[col2]);
- __m256 coef = _mm256_setr_ps(
- pm0[col1], pm1[col1], pm2[col1], pm3[col1],
- pm0[col2], pm1[col2], pm2[col2], pm3[col2]);
- x2 = _mm256_mul_ps(x2, x1);
-
- UpdateAdadelta(coef, ag, au, x2, dec, decc, c);
-
- __m128 t1 = _get_lo(coef);
- __m128 t2 = _get_hi(coef);
- _mm_store_ss(pm0 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pm1 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pm2 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pm3 + col1, t1);
- _mm_store_ss(pm0 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pm1 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pm2 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pm3 + col2, t2);
-
- t1 = _get_lo(ag);
- t2 = _get_hi(ag);
- _mm_store_ss(pag0 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pag1 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pag2 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pag3 + col1, t1);
- _mm_store_ss(pag0 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pag1 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pag2 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pag3 + col2, t2);
-
- t1 = _get_lo(au);
- t2 = _get_hi(au);
- _mm_store_ss(pau0 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pau1 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pau2 + col1, t1); t1 = _rotate(t1);
- _mm_store_ss(pau3 + col1, t1);
- _mm_store_ss(pau0 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pau1 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pau2 + col2, t2); t2 = _rotate(t2);
- _mm_store_ss(pau3 + col2, t2);
- }
- }
-
- _vleave();
-}
-
-EXPORT_API(void) ScaleX(float a, _Inout_ float * pd, int c)
-{
- float * pdLim = pd + c;
-
- __m256 x1 = _mm256_set1_ps(a);
- for (; pd < pdLim; pd += 8)
- {
- __m256 x2 = _mm256_load_ps(pd);
- x2 = _mm256_mul_ps(x1, x2);
- _mm256_store_ps(pd, x2);
- }
-
- _vleave();
-}
-
-EXPORT_API(void) ScaleMaxNormX(float maxNorm, _Inout_ float * pmat, int crow, int ccol)
-{
- float * pm = pmat;
- float maxNormSq = maxNorm * maxNorm;
- __m256 m = _mm256_set1_ps(maxNorm);
- for (int irow = 0; irow < crow; irow++)
- {
- __m256 rowNorm = _mm256_set1_ps(0);
- float * pms = pm;
- float * pmLim = pm + ccol;
- for (; pm < pmLim; pm += 8)
- {
- __m256 x1 = _mm256_load_ps(pm);
- x1 = _mm256_mul_ps(x1, x1);
- rowNorm = _mm256_add_ps(x1, rowNorm);
- }
- rowNorm = _mm256_hadd_ps(rowNorm, rowNorm);
- rowNorm = _mm256_hadd_ps(rowNorm, rowNorm);
- float rowNormRes = _mm_cvtss_f32(_mm_add_ss(_get_lo(rowNorm), _get_hi(rowNorm)));
- if (rowNormRes > maxNormSq)
- {
- __m256 scale = _mm256_set1_ps(rowNormRes);
-#if 0
- // REVIEW: this is faster but it uses approximation so results differ significantly from CLR.
- scale = _mm256_rsqrt_ps(scale);
- scale = _mm256_mul_ps(scale, m);
-#else
- scale = _mm256_sqrt_ps(scale);
- scale = _mm256_div_ps(m, scale);
-#endif
- for (pm = pms; pm < pmLim; pm += 8)
- {
- __m256 x1 = _mm256_load_ps(pm);
- x1 = _mm256_mul_ps(x1, scale);
- _mm256_store_ps(pm, x1);
- }
- }
- }
-
- _vleave();
-}
-
-EXPORT_API(void) AddScaleX(float a, _In_ const float * ps, _Inout_ float * pd, int c)
-{
- float * pdLim = pd + c;
-
- __m256 x1 = _mm256_set1_ps(a);
- for (; pd < pdLim; pd += 8, ps += 8)
- {
- __m256 x2 = _mm256_load_ps(ps);
- __m256 x3 = _mm256_load_ps(pd);
- x2 = _mm256_mul_ps(x1, x2);
- x3 = _mm256_add_ps(x2, x3);
- _mm256_store_ps(pd, x3);
- }
-
- _vleave();
-}
-
-EXPORT_API(void) AddScaleMomX(float a, _In_ const float * ps, _Inout_ float * pd, float momentum, _Inout_ float * pe, int c)
-{
- float * pdLim = pd + c;
-
- __m256 x0 = _mm256_set1_ps(momentum);
- __m256 x1 = _mm256_set1_ps(a);
- for (; pd < pdLim; pd += 8, pe += 8, ps += 8)
- {
- __m256 x2 = _mm256_load_ps(ps);
- __m256 x3 = _mm256_load_ps(pe);
- __m256 x4 = _mm256_load_ps(pd);
- x2 = _mm256_mul_ps(x1, x2);
- x3 = _mm256_mul_ps(x0, x3);
- x3 = _mm256_add_ps(x2, x3);
- x4 = _mm256_add_ps(x3, x4);
- _mm256_store_ps(pe, x3);
- _mm256_store_ps(pd, x4);
- }
-
- _vleave();
-}
-
-EXPORT_API(void) AddScaleGradX(_In_ const float * ps, _Inout_ float * pd, _Inout_ float * paccGrads, _Inout_ float * paccUpdates,
- float decay, float cond, int c)
-{
- float * pdLim = pd + c;
-
- __m256 dec = _mm256_set1_ps(decay);
- __m256 decc = _mm256_set1_ps(1 - decay);
- __m256 cnd = _mm256_set1_ps(cond);
- for (; pd < pdLim; pd += 8, ps += 8, paccGrads += 8, paccUpdates += 8)
- {
- __m256 g = _mm256_load_ps(ps);
- __m256 ag = _mm256_load_ps(paccGrads);
- __m256 au = _mm256_load_ps(paccUpdates);
- __m256 coef = _mm256_load_ps(pd);
-
- UpdateAdadelta(coef, ag, au, g, dec, decc, cnd);
-
- _mm256_store_ps(pd, coef);
- _mm256_store_ps(paccGrads, ag);
- _mm256_store_ps(paccUpdates, au);
- }
-
- _vleave();
-}
-
-EXPORT_API(void) AddX(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- float * pdLim = pd + c;
-
- for (; pd < pdLim; pd += 8, ps += 8)
- {
- __m256 x1 = _mm256_load_ps(ps);
- __m256 x2 = _mm256_load_ps(pd);
- x2 = _mm256_add_ps(x1, x2);
- _mm256_store_ps(pd, x2);
- }
-
- _vleave();
-}
-
-EXPORT_API(float) SumX(const float * ps, int c)
-{
- const float * psLim = ps + c;
-
- __m256 res = _mm256_setzero_ps();
- for (; ps < psLim; ps += 8)
- {
- __m256 x1 = _mm256_load_ps(ps);
- res = _mm256_add_ps(res, x1);
- }
- res = _mm256_hadd_ps(res, res);
- res = _mm256_hadd_ps(res, res);
- __m128 r = _mm_add_ss(_get_lo(res), _get_hi(res));
-
- float ret = _mm_cvtss_f32(r);
- _vleave();
- return ret;
-}
-
-EXPORT_API(void) ScaleAdadeltaX(_Inout_ float * mat, _Inout_ float * accGrads, _Inout_ float * accUpdates, float decay, float cond, _In_ const float * grads, int size)
-{
- float * pm = mat;
- float * pmLim = pm + size;
- float * pag = accGrads;
- float * pau = accUpdates;
- const float * pg = grads;
-
- __m256 dec = _mm256_set1_ps(decay);
- __m256 decc = _mm256_set1_ps(1 - decay);
- __m256 c = _mm256_set1_ps(cond);
-
- for (; pm + 8 <= pmLim; pm += 8, pag += 8, pau += 8, pg += 8)
- {
- __m256 g = _mm256_loadu_ps(pg);
- __m256 ag = _mm256_loadu_ps(pag);
- __m256 au = _mm256_loadu_ps(pau);
- __m256 coef = _mm256_loadu_ps(pm);
-
- UpdateAdadelta(coef, ag, au, g, dec, decc, c);
-
- _mm256_storeu_ps(pm, coef);
- _mm256_storeu_ps(pag, ag);
- _mm256_storeu_ps(pau, au);
- }
-
- for (; pm < pmLim; pm++, pag++, pau++, pg++)
- {
- float g = *pg;
- float accGrad = decay * *pag + (1 - decay) * g * g;
- float accUpd = *pau;
- float newUpd = sqrtf((accUpd + cond) / (accGrad + cond)) * g;
- *pm += newUpd;
- *pag = accGrad;
- *pau = decay * accUpd + (1 - decay) * newUpd * newUpd;
- }
-
- _vleave();
-}
diff --git a/src/Native/CpuMathNative/CMakeLists.txt b/src/Native/CpuMathNative/CMakeLists.txt
index d4c9772421..1c66a7fcd4 100644
--- a/src/Native/CpuMathNative/CMakeLists.txt
+++ b/src/Native/CpuMathNative/CMakeLists.txt
@@ -2,15 +2,11 @@ project (CpuMathNative)
set(SOURCES
Sse.cpp
- Avx.cpp
MathLinux.S
)
-if(WIN32)
- set_property(SOURCE Avx.cpp APPEND_STRING PROPERTY COMPILE_FLAGS " /arch:AVX")
-else()
+if(NOT WIN32)
set_property(SOURCE Sse.cpp APPEND_STRING PROPERTY COMPILE_FLAGS " -msse3")
- set_property(SOURCE Avx.cpp APPEND_STRING PROPERTY COMPILE_FLAGS " -mavx")
list(APPEND SOURCES ${VERSION_FILE_PATH})
endif()
diff --git a/src/Native/CpuMathNative/Sse.cpp b/src/Native/CpuMathNative/Sse.cpp
index 9fde43fb12..5f519a2492 100644
--- a/src/Native/CpuMathNative/Sse.cpp
+++ b/src/Native/CpuMathNative/Sse.cpp
@@ -9,15 +9,7 @@
// * U suffix means unaligned and unpadded.
// * S suffix means sparse (unaligned) vector.
// * P suffix means sparse (unaligned) partial vector - the vector is only part of a larger sparse vector.
-// * R suffix means sparse matrix.
-// * C suffix means convolution matrix.
-// * D suffix means convolution matrix, with implicit source padding.
// * Tran means the matrix is transposed.
-//
-// Other notes:
-// * AVX methods should end with _vleave() to avoid performance hit. See:
-// https://stackoverflow.com/questions/7839925/using-avx-cpu-instructions-poor-performance-without-archavx.
-// * Keep Avx.cpp in sync with Sse.cpp. Note that Avx.cpp is compiled with /arch:AVX, but Sse.cpp is not.
#include "../Stdafx.h"
#include
@@ -49,43 +41,6 @@
x = _rotate(x); _mm_store_ss(pd + pi[2], x); \
x = _rotate(x); _mm_store_ss(pd + pi[3], x)
-#ifndef _WIN32
-
-typedef unsigned int DWORD; // NOTE: diff from windows.h, for LP64 compat
-
-// getcpuid and xmmYmmStateSupport are taken from
-// https://github.com/dotnet/coreclr/blob/b5f4d2df2e087401f2c3aab2c37021e326707915/src/vm/amd64/unixstubs.cpp#L14-L55
-
-DWORD getcpuid(DWORD arg, unsigned char result[16])
-{
- DWORD eax;
- __asm(" xor %%ecx, %%ecx\n" \
- " cpuid\n" \
- " mov %%eax, 0(%[result])\n" \
- " mov %%ebx, 4(%[result])\n" \
- " mov %%ecx, 8(%[result])\n" \
- " mov %%edx, 12(%[result])\n" \
- : "=a"(eax) /*output in eax*/\
- : "a"(arg), [result]"r"(result) /*inputs - arg in eax, result in any register*/\
- : "rbx", "ecx", "edx", "memory" /* registers that are clobbered, *result is clobbered */
- );
- return eax;
-}
-
-DWORD xmmYmmStateSupport()
-{
- DWORD eax;
- __asm(" xgetbv\n" \
- : "=a"(eax) /*output in eax*/\
- : "c"(0) /*inputs - 0 in ecx*/\
- : "edx" /* registers that are clobbered*/
- );
- // check OS has enabled both XMM and YMM state support
- return ((eax & 0x06) == 0x06) ? 1 : 0;
-}
-
-#endif
-
const unsigned int LeadingAlignmentMask[16] =
{
0x00000000, 0x00000000, 0x00000000, 0x00000000,
@@ -102,25 +57,6 @@ const unsigned int TrailingAlignmentMask[16] =
0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
};
-// Test whether Avx is available.
-EXPORT_API(bool) ChkAvx()
-{
-#ifdef _WIN32
- int cpuInfo[4];
- __cpuid(cpuInfo, 1);
-
- // 28th bit of second integer of Cpu Info denotes whether the Avx is supported in CPU or not
- // Reference https://msdn.microsoft.com/en-us/library/hskdteyh(v=vs.100).aspx
- return cpuInfo[2] & (1 << 28) || false;
-#else
- unsigned char buffer[16];
- (void) getcpuid(1, buffer);
-
- // taken from https://github.com/dotnet/coreclr/blob/b5f4d2df2e087401f2c3aab2c37021e326707915/src/vm/codeman.cpp#L1381
- return ((buffer[11] & 0x18) == 0x18) && (xmmYmmStateSupport() == 1);
-#endif
-}
-
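The deleted non-Windows check reads CPUID leaf 1 and XGETBV: (buffer[11] & 0x18) corresponds to ECX bits 27 (OSXSAVE) and 28 (AVX), and XGETBV must report that the OS preserves both XMM and YMM state. A plain restatement of those bit tests (hypothetical helper, not from this repo):

    static bool AvxUsable(unsigned int cpuidEcx, unsigned int xcr0)
    {
        bool cpuHasAvx  = (cpuidEcx & (1u << 28)) != 0;  // AVX supported by the CPU
        bool osXsave    = (cpuidEcx & (1u << 27)) != 0;  // OS uses XSAVE/XRSTOR
        bool ymmEnabled = (xcr0 & 0x6) == 0x6;           // XMM and YMM state enabled
        return cpuHasAvx && osXsave && ymmEnabled;
    }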
// Multiply matrix times vector into vector.
EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol)
{
@@ -255,332 +191,146 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout
}
// Partial sparse source vector.
-EXPORT_API(void) MatMulPA(_In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc,
+EXPORT_API(void) MatMulP(_In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc,
int posMin, int iposMin, int iposLim, _Inout_ float * pdst, int crow, int ccol)
{
// REVIEW: For extremely sparse inputs, interchanging the loops would
// likely be more efficient.
const int * pposMin = pposSrc + iposMin;
- const int * pposLim = pposSrc + iposLim;
- const float * pdLim = pdst + crow;
+ const int * pposEnd = pposSrc + iposLim;
+ const float * pDstEnd = pdst + crow;
const float * pm0 = pmat - posMin;
- const float * ps = psrc - posMin;
- for (float * pd = pdst; pd < pdLim; pd += 4, pm0 += 4 * ccol)
+ const float * pSrcCurrent = psrc - posMin;
+ float* pDstCurrent = pdst;
+
+ uintptr_t address = (uintptr_t)(pDstCurrent);
+ uintptr_t misalignment = address % 16;
+ int length = crow;
+ int remainder = 0;
+
+ if ((misalignment & 3) != 0)
{
- const float * pm1 = pm0 + ccol;
- const float * pm2 = pm1 + ccol;
- const float * pm3 = pm2 + ccol;
- __m128 res = _mm_setzero_ps();
- for (const int * ppos = pposMin; ppos < pposLim; ppos++)
+ while (pDstCurrent < pDstEnd)
{
- int col = *ppos;
- __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]);
- __m128 x2 = _mm_set1_ps(ps[col]);
- x2 = _mm_mul_ps(x2, x1);
- res = _mm_add_ps(res, x2);
- }
+ const float* pm1 = pm0 + ccol;
+ const float* pm2 = pm1 + ccol;
+ const float* pm3 = pm2 + ccol;
- _mm_store_ps(pd, res);
- }
-}
+ __m128 res = _mm_setzero_ps();
+ const int* ppos = pposMin;
-// Sparse matrix.
-EXPORT_API(void) MatMulRU(bool add, _In_ const int * pstarts, _In_ const int * pindices, _In_ const float * pcoefs,
- _In_ const float * ps, _Inout_ float * pdst, int crow)
-{
- const int * pii = pstarts + 1;
- const int * pi = pindices;
- const float * pm = pcoefs;
- const float * pdLim = pdst + crow;
- for (float * pd = pdst; pd < pdLim; pd++)
- {
- const int * piLim = pindices + *pii++;
+ while (ppos < pposEnd)
+ {
+ int col = *ppos;
+ __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]);
+ __m128 x2 = _mm_set1_ps(pSrcCurrent[col]);
+ x2 = _mm_mul_ps(x2, x1);
+ res = _mm_add_ps(res, x2);
+ ppos++;
+ }
- __m128 res = _mm_setzero_ps();
- for (; pi + 4 <= piLim; pi += 4, pm += 4)
- {
- __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm));
- res = _mm_add_ps(res, x);
- }
- for (; pi < piLim; pi++, pm++)
- {
- __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm));
- res = _mm_add_ss(res, x);
+ _mm_storeu_ps(pDstCurrent, res);
+ pDstCurrent += 4;
+ pm0 += 4 * ccol;
}
- res = _mm_hadd_ps(res, res);
- res = _mm_hadd_ps(res, res);
-
- if (add)
- res = _mm_add_ss(res, _mm_set_ss(*pd));
- _mm_store_ss(pd, res);
}
-}
-
-// Unpadded convolution.
-EXPORT_API(void) MatMulCU(bool add, _In_ const int * pmprowiv, _In_ const int * pmprowcol,
- _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow)
-{
- int size = pruns[1];
- const int * psupport = pruns + 2;
- const int * piv = pmprowiv;
- const int * pcol = pmprowcol;
- const int * piLim = psupport + size;
- const float * pdLim = pdst + crow;
-
- for (float * pd = pdst; pd < pdLim; pd++)
+ else
{
- const float * pm = pcoefs + *piv++;
- const float * ps = psrc + *pcol++;
- const int * pi = psupport;
-
- __m128 res = _mm_setzero_ps();
- for (; pi + 4 <= piLim; pi += 4, pm += 4)
- {
- __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm));
- res = _mm_add_ps(res, x);
- }
- for (; pi < piLim; pi++, pm++)
+ if (misalignment != 0)
{
- __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm));
- res = _mm_add_ss(res, x);
- }
- res = _mm_hadd_ps(res, res);
- res = _mm_hadd_ps(res, res);
+ misalignment >>= 2;
+ misalignment = 4 - misalignment;
- // Add the bias.
- res = _mm_add_ss(res, _mm_set_ss(*pm));
+ __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4));
- if (add)
- res = _mm_add_ss(res, _mm_set_ss(*pd));
- _mm_store_ss(pd, res);
- }
-}
+ const float* pm1 = pm0 + ccol;
+ const float* pm2 = pm1 + ccol;
+ const float* pm3 = pm2 + ccol;
-// Padded convolution.
-EXPORT_API(void) MatMulDU(bool add, _In_ const int * pmprowiv, _In_ const int * pmprowcol, _In_ const int * pmprowrun,
- _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow)
-{
- const int * piv = pmprowiv;
- const int * pcol = pmprowcol;
- const float * pdLim = pdst + crow;
- int kernelSize = pruns[1];
+ __m128 res = _mm_setzero_ps();
+ const int* ppos = pposMin;
- const int * pirun = pmprowrun;
- for (float * pd = pdst; pd < pdLim; pd++)
- {
- const float * pm = pcoefs + *piv++;
- const float * pmBias = pm + kernelSize;
- const float * ps = psrc + *pcol++;
- int irun = *pirun++;
-
- const int * pi = pruns + 2 + irun;
- const int * piLim = pi + pi[-1];
- __m128 res = _mm_setzero_ps();
- if (irun == 0)
- {
- // No masking needed.
- for (; pi + 4 <= piLim; pi += 4, pm += 4)
- {
- __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_loadu_ps(pm));
- res = _mm_add_ps(res, x);
- }
- for (; pi < piLim; pi++, pm++)
+ while (ppos < pposEnd)
{
- __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_set_ss(*pm));
- res = _mm_add_ss(res, x);
+ int col = *ppos;
+ __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]);
+ x1 = _mm_and_ps(mask, x1);
+
+ __m128 x2 = _mm_set1_ps(pSrcCurrent[col]);
+ x2 = _mm_mul_ps(x2, x1);
+ res = _mm_add_ps(res, x2);
+ ppos++;
}
+
+ _mm_storeu_ps(pDstCurrent, res);
+ pDstCurrent += misalignment;
+ pm0 += misalignment * ccol;
+ length -= misalignment;
}
- else
+
+ if (length > 3)
{
- // Need masking.
- pm += pi[-2];
- const float * pmask = reinterpret_cast<const float *>(piLim);
- for (; pi + 4 <= piLim; pi += 4, pm += 4, pmask += 4)
- {
- __m128 x = _mm_mul_ps(_load4(ps, pi), _mm_and_ps(_mm_loadu_ps(pmask), _mm_loadu_ps(pm)));
- res = _mm_add_ps(res, x);
- }
- for (; pi < piLim; pi++, pm++, pmask++)
+ remainder = length % 4;
+ while (pDstCurrent < pDstEnd)
{
- __m128 x = _mm_mul_ss(_load1(ps, pi), _mm_and_ps(_mm_set_ss(*pmask), _mm_set_ss(*pm)));
- res = _mm_add_ss(res, x);
- }
- }
- res = _mm_hadd_ps(res, res);
- res = _mm_hadd_ps(res, res);
+ const float* pm1 = pm0 + ccol;
+ const float* pm2 = pm1 + ccol;
+ const float* pm3 = pm2 + ccol;
- res = _mm_add_ss(res, _mm_set_ss(*pmBias));
- if (add)
- res = _mm_add_ss(res, _mm_set_ss(*pd));
- _mm_store_ss(pd, res);
- }
-}
-
-// Mean pooling.
-EXPORT_API(void) MeanU(bool add, _In_ const int * pmprowcol, _In_opt_ const int * pmprowindices, _In_ const int * pindices,
- _In_ const float * psrc, _Inout_ float * pdst, int crow)
-{
- const int * pcol = pmprowcol;
- const float * pdLim = pdst + crow;
+ const int* ppos = pposMin;
+ __m128 res = _mm_setzero_ps();
- if (pmprowindices == nullptr)
- {
- int size = pindices[0];
- __m128 x0 = _mm_set_ss((float)size);
- const int * piLim = pindices + 1 + size;
- for (float * pd = pdst; pd < pdLim; pd++)
- {
- const float * ps = psrc + *pcol++;
- const int * pi = pindices + 1;
+ while (ppos < pposEnd)
+ {
+ int col = *ppos;
+ __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]);
+ __m128 x2 = _mm_set1_ps(pSrcCurrent[col]);
+ x2 = _mm_mul_ps(x2, x1);
+ res = _mm_add_ps(res, x2);
+ ppos++;
+ }
- __m128 res = _mm_setzero_ps();
- for (; pi + 4 <= piLim; pi += 4)
- res = _mm_add_ps(res, _load4(ps, pi));
- for (; pi < piLim; pi++)
- res = _mm_add_ss(res, _load1(ps, pi));
- res = _mm_hadd_ps(res, res);
- res = _mm_hadd_ps(res, res);
-
- res = _mm_div_ss(res, x0);
- if (add)
- res = _mm_add_ss(res, _mm_set_ss(*pd));
- _mm_store_ss(pd, res);
+ _mm_store_ps(pDstCurrent, res);
+ pDstCurrent += 4;
+ pm0 += 4 * ccol;
+ }
}
- }
- else
- {
- const int * pii = pmprowindices;
- for (float * pd = pdst; pd < pdLim; pd++)
+ else
{
- const float * ps = psrc + *pcol++;
- int ii = *pii++;
-
- const int * pi = pindices + ii;
- int size = *pi++;
- const int * piLim = pi + size;
- __m128 res = _mm_setzero_ps();
- for (; pi + 4 <= piLim; pi += 4)
- res = _mm_add_ps(res, _load4(ps, pi));
- for (; pi < piLim; pi++)
- res = _mm_add_ss(res, _load1(ps, pi));
- res = _mm_hadd_ps(res, res);
- res = _mm_hadd_ps(res, res);
-
- res = _mm_div_ss(res, _mm_set_ss((float)size));
- if (add)
- res = _mm_add_ss(res, _mm_set_ss(*pd));
- _mm_store_ss(pd, res);
+ length = remainder;
}
- }
-}
-
-// Max pooling.
-EXPORT_API(void) MaxU(bool add, _In_ const int * pmprowcol, _In_opt_ const int * pmprowindices, _In_ const int * pindices,
- _In_ const float * psrc, _Inout_ float * pdst, int crow)
-{
- const int * pcol = pmprowcol;
- const float * pdLim = pdst + crow;
- __m128 min = _mm_set1_ps(-std::numeric_limits<float>::infinity());
- if (pmprowindices == nullptr)
- {
- int size = pindices[0];
- const int * piLim = pindices + 1 + size;
- for (float * pd = pdst; pd < pdLim; pd++)
- {
- const float * ps = psrc + *pcol++;
- const int * pi = pindices + 1;
-
- __m128 res = min;
- for (; pi + 4 <= piLim; pi += 4)
- res = _mm_max_ps(res, _load4(ps, pi));
- for (; pi < piLim; pi++)
- res = _mm_max_ss(res, _load1(ps, pi));
- __m128 x1 = _mm_shuffle_ps(res, res, 0xB1);
- res = _mm_max_ps(res, x1);
- x1 = _mm_shuffle_ps(res, res, 0x02);
- res = _mm_max_ss(res, x1);
-
- if (add)
- res = _mm_add_ss(res, _mm_set_ss(*pd));
- _mm_store_ss(pd, res);
- }
- }
- else
- {
- const int * pii = pmprowindices;
- for (float * pd = pdst; pd < pdLim; pd++)
+ if (remainder != 0)
{
- const float * ps = psrc + *pcol++;
- int ii = *pii++;
-
- const int * pi = pindices + ii;
- int size = *pi++;
- const int * piLim = pi + size;
- __m128 res = min;
- for (; pi + 4 <= piLim; pi += 4)
- res = _mm_max_ps(res, _load4(ps, pi));
- for (; pi < piLim; pi++)
- res = _mm_max_ss(res, _load1(ps, pi));
- __m128 x1 = _mm_shuffle_ps(res, res, 0xB1);
- res = _mm_max_ps(res, x1);
- x1 = _mm_shuffle_ps(res, res, 0x02);
- res = _mm_max_ss(res, x1);
-
- if (add)
- res = _mm_add_ss(res, _mm_set_ss(*pd));
- _mm_store_ss(pd, res);
- }
- }
-}
+ pDstCurrent -= (4 - remainder);
+ pm0 -= (4 - remainder) * ccol;
-// REVIEW: Try out SSE/AVX after padding support is added. AVX math platform uses the same code below.
-EXPORT_API(void) RespNormU(bool add, float alpha, float beta, bool avgOverFullKernel, float offset,
- _In_ const int * pmprowcol, _In_opt_ const int * pmprowindices, _In_ const int * pindices,
- _In_ const float * psrc, _Inout_ float * pdst, int crow)
-{
- const int * pcol = pmprowcol;
- const float * pdLim = pdst + crow;
+ __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4));
+ __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (( 4 - remainder) * 4));
+
+ const float* pm1 = pm0 + ccol;
+ const float* pm2 = pm1 + ccol;
+ const float* pm3 = pm2 + ccol;
- if (pmprowindices == nullptr)
- {
- int size = pindices[0];
- float scale = alpha / size;
- const int * piLim = pindices + 1 + size;
- for (float * pd = pdst; pd < pdLim; pd++)
- {
- const float * ps = psrc + *pcol++;
- const int * pi = pindices + 1;
- float res = 0;
- for (; pi < piLim; pi++)
- {
- float cur = ps[*pi];
- res += cur * cur;
- }
- res = ps[0] * powf(offset + scale * res, -beta);
- *pd = add ? *pd + res : res;
- }
- }
- else
- {
- int kernelSize = pindices[0];
- const int * pii = pmprowindices;
- for (float * pd = pdst; pd < pdLim; pd++)
- {
- const float * ps = psrc + *pcol++;
- int ii = *pii++;
- const int * pi = pindices + ii;
- int size = *pi++;
- const int * piLim = pi + size;
- float res = 0;
- for (; pi < piLim; pi++)
+ const int* ppos = pposMin;
+ __m128 res = _mm_setzero_ps();
+
+ while (ppos < pposEnd)
{
- float cur = ps[*pi];
- res += cur * cur;
+ int col = *ppos;
+ __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]);
+ x1 = _mm_and_ps(x1, trailingMask);
+
+ __m128 x2 = _mm_set1_ps(pSrcCurrent[col]);
+ x2 = _mm_mul_ps(x2, x1);
+ res = _mm_add_ps(res, x2);
+ ppos++;
}
- int avgDenom = avgOverFullKernel ? kernelSize : size;
- res = ps[0] * powf(offset + alpha / avgDenom * res, -beta);
- *pd = add ? *pd + res : res;
+
+ res = _mm_add_ps(res, _mm_and_ps(leadingMask, _mm_loadu_ps(pDstCurrent)));
+ _mm_storeu_ps(pDstCurrent, res);
+ pDstCurrent += 4;
+ pm0 += 4 * ccol;
}
}
}
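The REVIEW note above suggests interchanging the loops for extremely sparse inputs; a rough scalar sketch of that variant (hypothetical, not part of this change) walks the non-zero positions once and accumulates into every output row:

    #include <cstddef>
    #include <cstring>

    static void MatMulPInterchanged(const float* pmat, const int* pposSrc, const float* psrc,
                                    int posMin, int iposMin, int iposLim,
                                    float* pdst, int crow, int ccol)
    {
        std::memset(pdst, 0, crow * sizeof(float));
        for (int ipos = iposMin; ipos < iposLim; ipos++)
        {
            int col = pposSrc[ipos] - posMin;   // column of the non-zero source entry
            float v = psrc[col];
            for (int row = 0; row < crow; row++)
                pdst[row] += pmat[(size_t)row * ccol + col] * v;
        }
    }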
@@ -866,1118 +616,166 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I
}
}
-// Sparse matrix.
-EXPORT_API(void) MatMulTranRU(bool add, _In_ const int * pstarts, _In_ const int * pindices, _In_ const float * pcoefs,
- _In_ const float * psrc, _Inout_ float * pd, int crow, int ccol)
+// pd[i] += a
+EXPORT_API(void) AddScalarU(float a, _Inout_ float * pd, int c)
{
- if (!add)
- memset(pd, 0, crow * sizeof(float));
-
- const int * pii = pstarts + 1;
- const int * pi = pindices;
- const float * pm = pcoefs;
- const float * psLim = psrc + ccol;
- for (const float * ps = psrc; ps < psLim; ps++)
- {
- float x = *ps;
- const int * piLim = pindices + *pii++;
+ float * pdLim = pd + c;
- __m128 x1 = _mm_set1_ps(x);
- for (; pi + 4 <= piLim; pi += 4, pm += 4)
- {
- __m128 x2 = _mm_mul_ps(x1, _mm_loadu_ps(pm));
- x2 = _mm_add_ps(x2, _load4(pd, pi));
- _store4(x2, pd, pi);
- }
- for (; pi < piLim; pi++, pm++)
- {
- __m128 x2 = _mm_mul_ss(x1, _mm_set_ss(*pm));
- x2 = _mm_add_ss(x2, _load1(pd, pi));
- _store1(x2, pd, pi);
- }
+ __m128 x1 = _mm_set1_ps(a);
+ for (; pd + 4 <= pdLim; pd += 4)
+ {
+ __m128 x2 = _mm_loadu_ps(pd);
+ x2 = _mm_add_ps(x2, x1);
+ _mm_storeu_ps(pd, x2);
}
-}
-// Unpadded convolution.
-EXPORT_API(void) MatMulTranCU(bool add, _In_ const int * pmpcoliv, _In_ const int * pmpcolrow,
- _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol)
-{
- if (!add)
- memset(pdst, 0, crow * sizeof(float));
-
- int size = pruns[1];
- const int * psupport = pruns + 2;
- const int * piv = pmpcoliv;
- const int * prow = pmpcolrow;
- const int * piLim = psupport + size;
- const float * psLim = psrc + ccol;
- for (const float * ps = psrc; ps < psLim; ps++)
+ for (; pd < pdLim; pd++)
{
- const float * pm = pcoefs + *piv++;
- float * pd = pdst + *prow++;
- const int * pi = psupport;
-
- float x = *ps;
- __m128 x1 = _mm_set1_ps(x);
- for (; pi + 4 <= piLim; pm += 4, pi += 4)
- {
- __m128 x2 = _mm_mul_ps(x1, _mm_loadu_ps(pm));
- x2 = _mm_add_ps(x2, _load4(pd, pi));
- _store4(x2, pd, pi);
- }
- for (; pi < piLim; pi++, pm++)
- {
- __m128 x2 = _mm_mul_ss(x1, _mm_set_ss(*pm));
- x2 = _mm_add_ss(x2, _load1(pd, pi));
- _store1(x2, pd, pi);
- }
+ __m128 x2 = _mm_load_ss(pd);
+ x2 = _mm_add_ss(x2, x1);
+ _mm_store_ss(pd, x2);
}
}
-// Padded convolution.
-EXPORT_API(void) MatMulTranDU(bool add, _In_ const int * pmpcoliv, _In_ const int * pmpcolrow, _In_ const int * pmpcolrun,
- _In_ const int * pruns, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol)
+EXPORT_API(void) Scale(float a, _Inout_ float * pd, int c)
{
- if (!add)
- memset(pdst, 0, crow * sizeof(float));
-
- const int * piv = pmpcoliv;
- const int * prow = pmpcolrow;
- const float * psLim = psrc + ccol;
-
- const int * pirun = pmpcolrun;
- for (const float * ps = psrc; ps < psLim; ps++)
+ __m128 x1 = _mm_set1_ps(a);
+
+ if (c < 4)
{
- const float * pm = pcoefs + *piv++;
- float * pd = pdst + *prow++;
- int irun = *pirun++;
- const int * pi = pruns + 2 + irun;
- const int * piLim = pi + pi[-1];
-
- float x = *ps;
- __m128 x1 = _mm_set1_ps(x);
- if (irun == 0)
- {
- // No masking needed.
- for (; pi + 4 <= piLim; pi += 4, pm += 4)
- {
- __m128 x2 = _mm_mul_ps(x1, _mm_loadu_ps(pm));
- x2 = _mm_add_ps(x2, _load4(pd, pi));
- _store4(x2, pd, pi);
- }
- for (; pi < piLim; pi++, pm++)
- {
- __m128 x2 = _mm_mul_ss(x1, _mm_set_ss(*pm));
- x2 = _mm_add_ss(x2, _load1(pd, pi));
- _store1(x2, pd, pi);
- }
- }
- else
+ switch(c)
{
- // Need masking.
- pm += pi[-2];
- const float * pmask = reinterpret_cast<const float *>(piLim);
- for (; pi + 4 <= piLim; pi += 4, pm += 4, pmask += 4)
- {
- __m128 x2 = _mm_mul_ps(_mm_and_ps(_mm_loadu_ps(pmask), x1), _mm_loadu_ps(pm));
- x2 = _mm_add_ps(x2, _load4(pd, pi));
- _store4(x2, pd, pi);
- }
- for (; pi < piLim; pi++, pm++, pmask++)
- {
- __m128 x2 = _mm_mul_ss(_mm_and_ps(_mm_set_ss(*pmask), x1), _mm_set_ss(*pm));
- x2 = _mm_add_ss(x2, _load1(pd, pi));
- _store1(x2, pd, pi);
- }
+ case 3: pd[2] *= a;
+ case 2: pd[1] *= a;
+ case 1: pd[0] *= a;
}
+ return;
}
-}
-// Mean pooling back prop.
-EXPORT_API(void) MeanBackU(bool add, _In_ const int * pmpcolrow, _In_opt_ const int * pmpcolindices, _In_ const int * pindices,
- _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol)
-{
- if (!add)
- memset(pdst, 0, crow * sizeof(float));
+ uintptr_t address = (uintptr_t)(pd);
+ uintptr_t misalignment = address % 16;
+ int remainder = 0;
- const int * prow = pmpcolrow;
- const float * psLim = psrc + ccol;
- if (pmpcolindices == nullptr)
+ if ((misalignment & 3) != 0)
{
- int size = pindices[0];
- const int * piLim = pindices + 1 + size;
- for (const float * ps = psrc; ps < psLim; ps++)
+ // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
+ remainder = c % 4;
+
+ for (const float* pEnd = pd + (c - remainder); pd < pEnd; pd += 4)
{
- float * pd = pdst + *prow++;
- const int * pi = pindices + 1;
-
- float x = *ps / size;
- __m128 x1 = _mm_set1_ps(x);
- for (; pi + 4 <= piLim; pi += 4)
- {
- __m128 x2 = _mm_add_ps(x1, _load4(pd, pi));
- _store4(x2, pd, pi);
- }
- for (; pi < piLim; pi++)
- {
- __m128 x2 = _mm_add_ss(x1, _load1(pd, pi));
- _store1(x2, pd, pi);
- }
+ __m128 x2 = _mm_loadu_ps(pd);
+ x2 = _mm_mul_ps(x1, x2);
+ _mm_storeu_ps(pd, x2);
}
}
else
{
- const int * pii = pmpcolindices;
- for (const float * ps = psrc; ps < psLim; ps++)
+ if (misalignment != 0)
{
- float * pd = pdst + *prow++;
- int ii = *pii++;
-
- const int * pi = pindices + ii;
- int size = *pi++;
- const int * piLim = pi + size;
+ // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then
+ // masking any elements that will be included in the first aligned read
+ misalignment >>= 2;
+ misalignment = 4 - misalignment;
+
+ __m128 result = _mm_loadu_ps(pd);
+
+ __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4));
+ __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + ((4 - misalignment) * 4));
+
+ __m128 temp = _mm_and_ps(result, leadingMask);
+ result = _mm_and_ps(result, trailingMask);
+
+ temp = _mm_mul_ps(temp, x1);
+ result = _mm_or_ps(temp, result);
+
+ _mm_storeu_ps(pd, result);
+
+ pd += misalignment;
+ c -= misalignment;
+ }
- float x = *ps / size;
- __m128 x1 = _mm_set1_ps(x);
- for (; pi + 4 <= piLim; pi += 4)
- {
- __m128 x2 = _mm_add_ps(x1, _load4(pd, pi));
- _store4(x2, pd, pi);
- }
- for (; pi < piLim; pi++)
+ if (c > 3)
+ {
+ // Handle all the 128-bit blocks that we can now that we have offset to an aligned address
+ remainder = c % 4;
+ for (const float* pEnd = pd + (c - remainder); pd < pEnd; pd += 4)
{
- __m128 x2 = _mm_add_ss(x1, _load1(pd, pi));
- _store1(x2, pd, pi);
+ __m128 x2 = _mm_load_ps(pd);
+ x2 = _mm_mul_ps(x1, x2);
+ _mm_storeu_ps(pd, x2);
}
}
+ else
+ {
+ // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not
+ // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two
+ // unaligned loads where we mask the input each time.
+ remainder = c;
+ }
+ }
+
+ if (remainder != 0)
+ {
+ // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next
+ // unaligned load will read to the end of the array and then mask out any elements already processed
+
+ pd -= (4 - remainder);
+ __m128 result = _mm_loadu_ps(pd);
+
+ __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4));
+ __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4));
+
+ __m128 temp = _mm_and_ps(result, trailingMask);
+ result = _mm_and_ps(result, leadingMask);
+
+ temp = _mm_mul_ps(temp, x1);
+ result = _mm_or_ps(temp, result);
+
+ _mm_storeu_ps(pd, result);
}
}
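The leading/trailing masks above let a single unaligned load-blend-store touch a whole 4-float block while scaling only the lanes that belong to the unaligned head (or tail). A scalar model of that blend (illustrative sketch only; for the head, first = 0 and count = misalignment, for the tail, first = 4 - remainder and count = remainder):

    // Scale only `count` lanes of a 4-float block starting at `first`,
    // leaving the other lanes exactly as they were.
    static void ScaleMaskedBlock(float a, float block[4], int first, int count)
    {
        for (int i = 0; i < 4; i++)
        {
            if (i >= first && i < first + count)
                block[i] *= a;
        }
    }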
-// Max pooling back prop.
-EXPORT_API(void) MaxBackU(bool add, _In_ const int * pmpcolrow, _In_opt_ const int * pmpcolindices, _In_ const int * pindices,
- _In_ const float * psrc, _Inout_ float * pdst, _In_ const float * pval, int crow, int ccol)
+EXPORT_API(void) ScaleSrcU(float a, _In_ const float * ps, _Inout_ float * pd, int c)
{
- if (!add)
- memset(pdst, 0, crow * sizeof(float));
+ float * pdLim = pd + c;
- const int * prow = pmpcolrow;
- const float * psLim = psrc + ccol;
- if (pmpcolindices == nullptr)
+ __m128 x1 = _mm_set1_ps(a);
+ for (; pd + 4 <= pdLim; pd += 4, ps += 4)
{
- const int * piLim = pindices + 1 + pindices[0];
- for (const float * ps = psrc; ps < psLim; ps++)
- {
- int rowBase = *prow++;
- float * pd = pdst + rowBase;
- const float * pv = pval + rowBase;
- const int * pi = pindices + 1;
-
- int j = *pi++;
- float m = pv[j];
- for (; pi < piLim; pi++)
- {
- if (m < pv[*pi])
- {
- j = *pi;
- m = pv[j];
- }
- }
- pd[j] += *ps;
- }
+ __m128 x2 = _mm_loadu_ps(ps);
+ x2 = _mm_mul_ps(x2, x1);
+ _mm_storeu_ps(pd, x2);
}
- else
+
+ for (; pd < pdLim; pd++, ps++)
{
- const int * pii = pmpcolindices;
- for (const float * ps = psrc; ps < psLim; ps++)
- {
- int rowBase = *prow++;
- int ii = *pii++;
- float * pd = pdst + rowBase;
- const float * pv = pval + rowBase;
- const int * pi = pindices + ii + 1;
- const int * piLim = pi + pi[-1];
-
- int j = *pi++;
- float m = pv[j];
- for (; pi < piLim; pi++)
- {
- if (m < pv[*pi])
- {
- j = *pi;
- m = pv[j];
- }
- }
- pd[j] += *ps;
- }
+ __m128 x2 = _mm_load_ss(ps);
+ x2 = _mm_mul_ss(x2, x1);
+ _mm_store_ss(pd, x2);
}
}
-// REVIEW: Try out SSE/AVX after padding support is added. AVX math platform uses the same code below.
-EXPORT_API(void) RespNormBackU(bool add, float alpha, float beta, bool avgOverFullKernel, float offset,
- _In_ const int * pmpcolrow, _In_opt_ const int * pmpcolindices, _In_ const int * pindices,
- _In_ const float * perrors, _Inout_ float * perrorsPrev, _In_ const float * pvaluesPrev, int crow, int ccol)
-{
- if (!add)
- memset(perrorsPrev, 0, crow * sizeof(float));
-
- const int * prow = pmpcolrow;
- const float * psLim = perrors + ccol;
- if (pmpcolindices == nullptr)
- {
- int size = pindices[0];
- float scale = alpha / size;
- const int * piMin = pindices + 1;
- const int * piLim = piMin + size;
- for (const float * ps = perrors; ps < psLim; ps++)
- {
- int rowBase = *prow++;
- // First compute denominator: denom = offset + scale * Sum(Xj^2)
- float denom = 0;
- const float * pv = pvaluesPrev + rowBase;
-
- for (const int * pi = piMin; pi < piLim; pi++)
- {
- float cur = pv[*pi];
- denom += cur * cur;
- }
- denom = offset + scale * denom;
- float denomPow = powf(denom, -beta);
- // The output.
- float y = pv[0] * denomPow;
-
- // The update logic:
- // srcError(*ps) X the derivative.
- // derivative at i wrt center point = powf(denom, -beta) - 2* scale * beta * X[i] * y / denom.
- // derivative at i wrt other points = - 2* scale * beta * X[i] * y / denom.
- float commonUpdate = *ps * (-2 * scale * beta * y) / denom;
-
- float * pd = perrorsPrev + rowBase;
- for (const int * pi = piMin; pi < piLim; pi++)
- pd[*pi] += pv[*pi] * commonUpdate;
-
- // Additional update for the center point.
- pd[0] += *ps * denomPow;
- }
- }
- else
- {
- int kernelSize = pindices[0];
- const int * pii = pmpcolindices;
- for (const float * ps = perrors; ps < psLim; ps++)
- {
- int rowBase = *prow++;
- // First compute denominator: denom = 1 + scale * Sum(Xj^2)
- float denom = 0;
- const float * pv = pvaluesPrev + rowBase;
- int ii = *pii++;
-
- const int * piMin = pindices + ii;
- int size = *piMin++;
- const int * piLim = piMin + size;
-
- for (const int * pi = piMin; pi < piLim; pi++)
- {
- float cur = pv[*pi];
- denom += cur * cur;
- }
- float scale = alpha / (avgOverFullKernel ? kernelSize : size);
- denom = offset + scale * denom;
- float denomPow = powf(denom, -beta);
- // The output.
- float y = pv[0] * denomPow;
-
- // The update logic:
- // srcError(*ps) X the derivative.
- // derivative at i wrt center point = powf(denom, -beta) - 2* scale * beta * X[i] * y / denom.
- // derivative at i wrt other points = - 2* scale * beta * X[i] * y / denom.
- float commonUpdate = *ps * (-2 * scale * beta * y) / denom;
-
- float * pd = perrorsPrev + rowBase;
- for (const int * pi = piMin; pi < piLim; pi++)
- pd[*pi] += pv[*pi] * commonUpdate;
-
- // Additional update for the center point.
- pd[0] += *ps * denomPow;
- }
- }
-}
-
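For reference, the derivative quoted in the comments of RespNormBackU follows directly from the forward definition. With y = X[0] * powf(denom, -beta) and denom = offset + scale * Sum(Xj^2):

    dy / dX[i] = -2 * scale * beta * X[i] * X[0] * powf(denom, -beta - 1)
               = -2 * scale * beta * X[i] * y / denom            (i != 0)
    dy / dX[0] = powf(denom, -beta) - 2 * scale * beta * X[0] * y / denom

which is exactly the commonUpdate term applied to every kernel index, plus the extra denomPow contribution added for the center point.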
-template <bool useDecay>
-void AddXYTranACore(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, int crow, int ccol, float decay)
-{
- const float * pyBase = py;
- const float * pxLim = px + crow;
- const float * pyLim = py + ccol;
- float * pm = pmat;
- __m128 wd;
- if (useDecay)
- wd = _mm_set1_ps(1 - decay);
- for (; px < pxLim; px++)
- {
- float r = a * *px;
- py = pyBase;
-
- __m128 x1 = _mm_set1_ps(r);
- for (; py + 16 <= pyLim; py += 16, pm += 16)
- {
- __m128 x02 = _mm_load_ps(py);
- __m128 x12 = _mm_load_ps(py + 4);
- __m128 x22 = _mm_load_ps(py + 8);
- __m128 x32 = _mm_load_ps(py + 12);
- __m128 x03 = _mm_load_ps(pm);
- __m128 x13 = _mm_load_ps(pm + 4);
- __m128 x23 = _mm_load_ps(pm + 8);
- __m128 x33 = _mm_load_ps(pm + 12);
- x02 = _mm_mul_ps(x1, x02);
- x12 = _mm_mul_ps(x1, x12);
- x22 = _mm_mul_ps(x1, x22);
- x32 = _mm_mul_ps(x1, x32);
- if (useDecay)
- {
- x03 = _mm_mul_ps(wd, x03);
- x13 = _mm_mul_ps(wd, x13);
- x23 = _mm_mul_ps(wd, x23);
- x33 = _mm_mul_ps(wd, x33);
- }
- x03 = _mm_add_ps(x02, x03);
- x13 = _mm_add_ps(x12, x13);
- x23 = _mm_add_ps(x22, x23);
- x33 = _mm_add_ps(x32, x33);
- _mm_store_ps(pm, x03);
- _mm_store_ps(pm + 4, x13);
- _mm_store_ps(pm + 8, x23);
- _mm_store_ps(pm + 12, x33);
- }
- for (; py < pyLim; py += 4, pm += 4)
- {
- __m128 x02 = _mm_load_ps(py);
- __m128 x03 = _mm_load_ps(pm);
- x02 = _mm_mul_ps(x1, x02);
- if (useDecay)
- x03 = _mm_mul_ps(wd, x03);
- x03 = _mm_add_ps(x02, x03);
- _mm_store_ps(pm, x03);
- }
- }
-}
-
-EXPORT_API(void) AddXYTranA(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, int crow, int ccol, float decay)
-{
- if (decay == 0)
- AddXYTranACore<false>(a, px, py, pmat, crow, ccol, decay);
- else
- AddXYTranACore<true>(a, px, py, pmat, crow, ccol, decay);
-}
-
-// Partial sparse source vector.
-EXPORT_API(void) AddXYTranPA(float a, _In_ const float * px, _In_ const int * pposY, _In_ const float * pvaluesY,
- int posMinY, int iposMinY, int iposLimY, _Inout_ float * pmat, int crow, int ccol)
-{
-#if 1
- // REVIEW: This is faster for MNIST, but the version below is faster for extremely sparse input.
- const int * pposMin = pposY + iposMinY;
- const int * pposLim = pposY + iposLimY;
- const float * pxLim = px + crow;
- float * pm0 = pmat - posMinY;
- const float * py = pvaluesY - posMinY;
-
- __m128 x0 = _mm_set1_ps(a);
- for (; px < pxLim; px += 4, pm0 += 4 * ccol)
- {
- float * pm1 = pm0 + ccol;
- float * pm2 = pm1 + ccol;
- float * pm3 = pm2 + ccol;
-
- __m128 x1 = _mm_load_ps(px);
- x1 = _mm_mul_ps(x1, x0);
-
- for (const int * ppos = pposMin; ppos < pposLim; ppos++)
- {
- int col = *ppos;
- __m128 x2 = _mm_set1_ps(py[col]);
- __m128 x3 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]);
- x2 = _mm_mul_ps(x2, x1);
- x3 = _mm_add_ps(x3, x2);
-
- _mm_store_ss(pm0 + col, x3); x3 = _rotate(x3);
- _mm_store_ss(pm1 + col, x3); x3 = _rotate(x3);
- _mm_store_ss(pm2 + col, x3); x3 = _rotate(x3);
- _mm_store_ss(pm3 + col, x3);
- }
- }
-#else
- const int * pposMin = pposY + iposMinY;
- const int * pposLim = pposY + iposLimY;
- const float * pxLim = px + crow;
- float * pm = pmat - posMinY;
- const float * py = pvaluesY - posMinY;
-
- __m128 x0 = _mm_set1_ps(a);
- int d1 = 1 * ccol;
- int d2 = 2 * ccol;
- int d3 = 3 * ccol;
- int d4 = 4 * ccol;
- for (const int * ppos = pposMin; ppos < pposLim; ppos++)
- {
- int col = *ppos;
- __m128 x2 = _mm_set1_ps(py[col]);
- x2 = _mm_mul_ps(x2, x0);
-
- float * pm0 = pm + col;
- for (const float * px0 = px; px0 < pxLim; px0 += 4, pm0 += d4)
- {
- __m128 x1 = _mm_load_ps(px0);
- __m128 x3 = _mm_setr_ps(pm0[0], pm0[d1], pm0[d2], pm0[d3]);
- x1 = _mm_mul_ps(x1, x2);
- x3 = _mm_add_ps(x3, x1);
-
- _mm_store_ss(pm0, x3); x3 = _rotate(x3);
- _mm_store_ss(pm0 + d1, x3); x3 = _rotate(x3);
- _mm_store_ss(pm0 + d2, x3); x3 = _rotate(x3);
- _mm_store_ss(pm0 + d3, x3);
- }
- }
-#endif
-}
-
-template <bool useDecay>
-void AddXYTranRUCore(float a, _In_ const float * px, _In_ const float * py,
- _In_ const int * pstarts, _In_ const int * pindices, _Inout_ float * pcoefs, int crow, float decay)
-{
- const int * pii = pstarts + 1;
- const int * pi = pindices;
- float * pm = pcoefs;
- const float * pxLim = px + crow;
- __m128 wd;
- if (useDecay)
- wd = _mm_set1_ps(1 - decay);
- for (; px < pxLim; px++)
- {
- const int * piLim = pindices + *pii++;
- float r = a * *px;
-
- __m128 x1 = _mm_set1_ps(r);
- for (; pi + 4 <= piLim; pi += 4, pm += 4)
- {
- __m128 x2 = _mm_mul_ps(x1, _load4(py, pi));
- __m128 x3 = _mm_loadu_ps(pm);
- if (useDecay)
- x3 = _mm_mul_ps(x3, wd);
- x2 = _mm_add_ps(x2, x3);
- _mm_storeu_ps(pm, x2);
- }
- for (; pi < piLim; pi++, pm++)
- *pm = (useDecay ? (*pm * (1 - decay)) : *pm) + py[*pi] * r;
- }
-}
-
-// Sparse matrix.
-EXPORT_API(void) AddXYTranRU(float a, _In_ const float * px, _In_ const float * py,
- _In_ const int * pstarts, _In_ const int * pindices, _Inout_ float * pcoefs, int crow, float decay)
-{
- if (decay == 0)
- AddXYTranRUCore<false>(a, px, py, pstarts, pindices, pcoefs, crow, decay);
- else
- AddXYTranRUCore<true>(a, px, py, pstarts, pindices, pcoefs, crow, decay);
-}
-
-// Unpadded convolution.
-EXPORT_API(void) AddXYTranCU(float a, _In_ const float * px, _In_ const float * py, _In_ const int * pmprowiv,
- _In_ const int * pmprowcol, _In_ const int * pruns, _Inout_ float * pcoefs, int crow)
-{
- int size = pruns[1];
- const int * psupport = pruns + 2;
- const int * piv = pmprowiv;
- const int * pcol = pmprowcol;
- const float * pxLim = px + crow;
- const int * piLim = psupport + size;
-
- for (; px < pxLim; px++)
- {
- float * pm = pcoefs + *piv++;
- const float * ps = py + *pcol++;
- const int * pi = psupport;
- float r = a * *px;
-
- __m128 x1 = _mm_set1_ps(r);
- for (; pi + 4 <= piLim; pi += 4, pm += 4)
- {
- __m128 x2 = _mm_mul_ps(x1, _load4(ps, pi));
- x2 = _mm_add_ps(x2, _mm_loadu_ps(pm));
- _mm_storeu_ps(pm, x2);
- }
- for (; pi < piLim; pi++, pm++)
- *pm += ps[*pi] * r;
- // Update the bias.
- *pm += r;
- }
-}
-
-// Padded convolution.
-EXPORT_API(void) AddXYTranDU(float a, _In_ const float * px, _In_ const float * py, _In_ const int * pmprowiv,
- _In_ const int * pmprowcol, _In_ const int * pmprowrun, _In_ const int * pruns, _Inout_ float * pcoefs, int crow)
-{
- const int * piv = pmprowiv;
- const int * pcol = pmprowcol;
- const float * pxLim = px + crow;
- int kernelSize = pruns[1];
-
- const int * pirun = pmprowrun;
- for (; px < pxLim; px++)
- {
- float * pm = pcoefs + *piv++;
- const float * ps = py + *pcol++;
- int irun = *pirun++;
- const int * pi = pruns + 2 + irun;
- const int * piLim = pi + pi[-1];
-
- float r = a * *px;
-
- // Update the bias.
- pm[kernelSize] += r;
-
- __m128 x1 = _mm_set1_ps(r);
- if (irun == 0)
- {
- // No masking needed.
- for (; pi + 4 <= piLim; pi += 4, pm += 4)
- {
- __m128 x2 = _mm_mul_ps(x1, _load4(ps, pi));
- x2 = _mm_add_ps(x2, _mm_loadu_ps(pm));
- _mm_storeu_ps(pm, x2);
- }
- for (; pi < piLim; pi++, pm++)
- *pm += ps[*pi] * r;
- }
- else
- {
- // Need masking.
- pm += pi[-2];
- const float * pmask = reinterpret_cast<const float *>(piLim);
- for (; pi + 4 <= piLim; pi += 4, pm += 4, pmask += 4)
- {
- __m128 x2 = _mm_mul_ps(_mm_and_ps(_mm_loadu_ps(pmask), x1), _load4(ps, pi));
- x2 = _mm_add_ps(x2, _mm_loadu_ps(pm));
- _mm_storeu_ps(pm, x2);
- }
- for (; pi < piLim; pi++, pm++, pmask++)
- {
- __m128 x2 = _mm_mul_ss(_mm_and_ps(_mm_set_ss(*pmask), x1), _load1(ps, pi));
- x2 = _mm_add_ss(x2, _mm_set_ss(*pm));
- _mm_store_ss(pm, x2);
- }
- }
- }
-}
-
-// With momentum.
-EXPORT_API(void) AddXYTranMomA(float a, _In_ const float * px, _In_ const float * py, _Inout_ float * pmat, float momentum, _Inout_ float * pdel, int crow, int ccol)
-{
- const float * pyBase = py;
- const float * pxLim = px + crow;
- const float * pyLim = py + ccol;
- float * pm = pmat;
- float * pd = pdel;
-
- __m128 x0 = _mm_set1_ps(momentum);
- for (; px < pxLim; px++)
- {
- float r = a * *px;
-
- __m128 x1 = _mm_set1_ps(r);
- for (py = pyBase; py < pyLim; pm += 4, pd += 4, py += 4)
- {
- __m128 x2 = _mm_load_ps(py);
- __m128 x3 = _mm_load_ps(pd);
- __m128 x4 = _mm_load_ps(pm);
- x2 = _mm_mul_ps(x1, x2);
- x3 = _mm_mul_ps(x0, x3);
- x3 = _mm_add_ps(x2, x3);
- x4 = _mm_add_ps(x3, x4);
-
- _mm_store_ps(pd, x3);
- _mm_store_ps(pm, x4);
- }
- }
-}
-
-// coef: coefs to update, ag: accumulated grads, au: accumulated updates, g: cur grads.
-// Note: parameters coef, ag, au and g will be updated, do not reuse parameter g in calling code.
-__forceinline void UpdateAdadelta(__m128& coef, __m128& ag, __m128& au, __m128& g, const __m128& dec, const __m128& decc, const __m128& c)
-{
- __m128 x4 = _mm_mul_ps(g, g); // x4 == g * g
- x4 = _mm_mul_ps(decc, x4); // x4 == (1 - decay) * g * g
- ag = _mm_mul_ps(dec, ag); // ag == decay * accG
- ag = _mm_add_ps(ag, x4); // ag == decay * accG + (1 - decay) * g * g
- __m128 x41 = _mm_add_ps(ag, c); // x41 == ag + cond
- __m128 x51 = _mm_add_ps(au, c); // x51 == accU + cond
-#if 0
- // naive version:
- x51 = _mm_div_ps(x51, x41);
- x41 = _mm_sqrt_ps(x51); // x41 == rate
-#else
- // faster (approximate) version:
- x41 = _mm_rsqrt_ps(x41);
- __m128 x52 = _mm_rsqrt_ps(x51);
- x51 = _mm_mul_ps(x51, x52);
- x41 = _mm_mul_ps(x41, x51); // x41 == rate
-#endif
- g = _mm_mul_ps(g, x41); // g - current update
- coef = _mm_add_ps(coef, g);
-
- g = _mm_mul_ps(g, g); // g == newU * newU
- g = _mm_mul_ps(decc, g); // g == (1 - decay) * newU * newU
- au = _mm_mul_ps(dec, au); // au == decay * accU
- au = _mm_add_ps(au, g); // au == decay * accU + (1 - decay) * newU * newU
-}
-
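Note on the "faster (approximate) version" in UpdateAdadelta above: it reproduces the naive rate sqrt((accU + cond) / (accG + cond)) with two reciprocal square roots instead of a divide and a sqrt. In terms of the values being combined:

    rate = rsqrt(ag + c) * ((au + c) * rsqrt(au + c))
         = sqrt(au + c) / sqrt(ag + c)
         = sqrt((au + c) / (ag + c))

so the only difference from the #if 0 branch is the precision of _mm_rsqrt_ps (roughly 12 bits of mantissa).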
-// For Adadelta.
-EXPORT_API(void) AddXYTranGradA(_In_ const float * px, _In_ const float * py, _Inout_ float * pmat, _Inout_ float * paccGrads, _Inout_ float * paccUpdates,
- float decay, float cond, int crow, int ccol)
-{
- const float * pyBase = py;
- const float * pxLim = px + crow;
- const float * pyLim = py + ccol;
- float * pm = pmat;
- float * pag = paccGrads;
- float * pau = paccUpdates;
-
- __m128 dec = _mm_set1_ps(decay);
- __m128 decc = _mm_set1_ps(1 - decay);
- __m128 c = _mm_set1_ps(cond);
- for (; px < pxLim; px++)
- {
- float r = *px;
-
- __m128 x1 = _mm_set1_ps(r);
- for (py = pyBase; py < pyLim; pm += 4, pag += 4, pau += 4, py += 4)
- {
- __m128 x2 = _mm_load_ps(py);
- __m128 ag = _mm_load_ps(pag);
- __m128 au = _mm_load_ps(pau);
- __m128 coef = _mm_load_ps(pm);
- x2 = _mm_mul_ps(x1, x2); // x2 == g
-
- UpdateAdadelta(coef, ag, au, x2, dec, decc, c);
-
- _mm_store_ps(pm, coef);
- _mm_store_ps(pag, ag);
- _mm_store_ps(pau, au);
- }
- }
-}
-
-// For Adadelta, sparse matrix.
-EXPORT_API(void) AddXYTranGradRU(_In_ const float * px, _In_ const float * py, _In_ const int * pstarts, _In_ const int * pindices,
- _Inout_ float * pcoefs, _Inout_ float * paccGrads, _Inout_ float * paccUpdates, float decay, float cond, int crow)
-{
- const int * pii = pstarts + 1;
- const int * pi = pindices;
- float * pm = pcoefs;
- const float * pxLim = px + crow;
- float * pag = paccGrads;
- float * pau = paccUpdates;
-
- __m128 dec = _mm_set1_ps(decay);
- __m128 decc = _mm_set1_ps(1 - decay);
- __m128 c = _mm_set1_ps(cond);
-
- for (; px < pxLim; px++)
- {
- const int * piLim = pindices + *pii++;
- float r = *px;
-
- __m128 x1 = _mm_set1_ps(r);
- for (; pi + 4 <= piLim; pi += 4, pm += 4, pag += 4, pau += 4)
- {
- __m128 g = _mm_mul_ps(x1, _load4(py, pi));
- __m128 ag = _mm_loadu_ps(pag);
- __m128 au = _mm_loadu_ps(pau);
- __m128 coef = _mm_loadu_ps(pm);
-
- UpdateAdadelta(coef, ag, au, g, dec, decc, c);
-
- _mm_storeu_ps(pm, coef);
- _mm_storeu_ps(pag, ag);
- _mm_storeu_ps(pau, au);
- }
-
- if (pi < piLim)
- {
- size_t ctail = piLim - pi;
- __m128 g = _mm_mul_ss(_load1(py, pi++), x1);
- __m128 ag = _mm_load_ss(pag++);
- __m128 au = _mm_load_ss(pau++);
- __m128 coef = _mm_load_ss(pm++);
- for (; pi < piLim; pi++, pm++, pag++, pau++)
- {
- g = _mm_or_ps(_mm_mul_ss(_load1(py, pi), x1), _rotate(g));
- ag = _mm_or_ps(_mm_load_ss(pag), _rotate(ag));
- au = _mm_or_ps(_mm_load_ss(pau), _rotate(au));
- coef = _mm_or_ps(_mm_load_ss(pm), _rotate(coef));
- }
- UpdateAdadelta(coef, ag, au, g, dec, decc, c);
- for (int i = 0; i < ctail; i++)
- {
- _mm_store_ss(pm - i - 1, coef);
- coef = _rotate_reverse(coef);
- _mm_store_ss(pag - i - 1, ag);
- ag = _rotate_reverse(ag);
- _mm_store_ss(pau - i - 1, au);
- au = _rotate_reverse(au);
- }
- }
- }
-}
-
-// For Adadelta, partial sparse source vector.
-EXPORT_API(void) AddXYTranGradPA(_In_ const float * px, _In_ const int * pposY, _In_ const float * pvaluesY,
- int posMinY, int iposMinY, int iposLimY, _Inout_ float * pmat, _Inout_ float * paccGrads, _Inout_ float * paccUpdates,
- float decay, float cond, int crow, int ccol)
-{
- const int * pposMin = pposY + iposMinY;
- const int * pposLim = pposY + iposLimY;
- const float * pxLim = px + crow;
- const float * py = pvaluesY - posMinY;
- float * pm0 = pmat - posMinY;
- float * pag0 = paccGrads - posMinY;
- float * pau0 = paccUpdates - posMinY;
-
- __m128 dec = _mm_set1_ps(decay);
- __m128 decc = _mm_set1_ps(1 - decay);
- __m128 c = _mm_set1_ps(cond);
- for (; px < pxLim; px += 4, pm0 += 4 * ccol, pag0 += 4 * ccol, pau0 += 4 * ccol)
- {
- float * pm1 = pm0 + ccol;
- float * pm2 = pm1 + ccol;
- float * pm3 = pm2 + ccol;
-
- float * pag1 = pag0 + ccol;
- float * pag2 = pag1 + ccol;
- float * pag3 = pag2 + ccol;
-
- float * pau1 = pau0 + ccol;
- float * pau2 = pau1 + ccol;
- float * pau3 = pau2 + ccol;
-
- __m128 x1 = _mm_load_ps(px);
-
- for (const int * ppos = pposMin; ppos < pposLim; ppos++)
- {
- int col = *ppos;
- __m128 x2 = _mm_set1_ps(py[col]);
- __m128 ag = _mm_setr_ps(pag0[col], pag1[col], pag2[col], pag3[col]);
- __m128 au = _mm_setr_ps(pau0[col], pau1[col], pau2[col], pau3[col]);
- __m128 coef = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]);
- x2 = _mm_mul_ps(x2, x1);
-
- UpdateAdadelta(coef, ag, au, x2, dec, decc, c);
-
- _mm_store_ss(pm0 + col, coef); coef = _rotate(coef);
- _mm_store_ss(pm1 + col, coef); coef = _rotate(coef);
- _mm_store_ss(pm2 + col, coef); coef = _rotate(coef);
- _mm_store_ss(pm3 + col, coef);
-
- _mm_store_ss(pag0 + col, ag); ag = _rotate(ag);
- _mm_store_ss(pag1 + col, ag); ag = _rotate(ag);
- _mm_store_ss(pag2 + col, ag); ag = _rotate(ag);
- _mm_store_ss(pag3 + col, ag);
-
- _mm_store_ss(pau0 + col, au); au = _rotate(au);
- _mm_store_ss(pau1 + col, au); au = _rotate(au);
- _mm_store_ss(pau2 + col, au); au = _rotate(au);
- _mm_store_ss(pau3 + col, au);
- }
- }
-}
-
-// pd[i] += a
-EXPORT_API(void) AddScalarU(float a, _Inout_ float * pd, int c)
-{
- float * pdLim = pd + c;
-
- __m128 x1 = _mm_set1_ps(a);
- for (; pd + 4 <= pdLim; pd += 4)
- {
- __m128 x2 = _mm_loadu_ps(pd);
- x2 = _mm_add_ps(x2, x1);
- _mm_storeu_ps(pd, x2);
- }
-
- for (; pd < pdLim; pd++)
- {
- __m128 x2 = _mm_load_ss(pd);
- x2 = _mm_add_ss(x2, x1);
- _mm_store_ss(pd, x2);
- }
-}
-
-EXPORT_API(void) Scale(float a, _Inout_ float * pd, int c)
-{
- __m128 x1 = _mm_set1_ps(a);
-
- if (c < 4)
- {
- switch(c)
- {
- case 3: pd[2] *= a;
- case 2: pd[1] *= a;
- case 1: pd[0] *= a;
- }
- return;
- }
-
- uintptr_t address = (uintptr_t)(pd);
- uintptr_t misalignment = address % 16;
- int remainder = 0;
-
- if ((misalignment & 3) != 0)
- {
- // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
- remainder = c % 4;
-
- for (const float* pEnd = pd + (c - remainder); pd < pEnd; pd += 4)
- {
- __m128 x2 = _mm_loadu_ps(pd);
- x2 = _mm_mul_ps(x1, x2);
- _mm_storeu_ps(pd, x2);
- }
- }
- else
- {
- if (misalignment != 0)
- {
- // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then
- // masking any elements that will be included in the first aligned read
- misalignment >>= 2;
- misalignment = 4 - misalignment;
-
- __m128 result = _mm_loadu_ps(pd);
-
- __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4));
- __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + ((4 - misalignment) * 4));
-
- __m128 temp = _mm_and_ps(result, leadingMask);
- result = _mm_and_ps(result, trailingMask);
-
- temp = _mm_mul_ps(temp, x1);
- result = _mm_or_ps(temp, result);
-
- _mm_storeu_ps(pd, result);
-
- pd += misalignment;
- c -= misalignment;
- }
-
- if (c > 3)
- {
- // Handle all the 128-bit blocks that we can now that we have offset to an aligned address
- remainder = c % 4;
- for (const float* pEnd = pd + (c - remainder); pd < pEnd; pd += 4)
- {
- __m128 x2 = _mm_load_ps(pd);
- x2 = _mm_mul_ps(x1, x2);
- _mm_storeu_ps(pd, x2);
- }
- }
- else
- {
- // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not
- // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two
- // unaligned loads where we mask the input each time.
- remainder = c;
- }
- }
-
- if (remainder != 0)
- {
- // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next
- // unaligned load will read to the end of the array and then mask out any elements already processed
-
- pd -= (4 - remainder);
- __m128 result = _mm_loadu_ps(pd);
-
- __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4));
- __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4));
-
- __m128 temp = _mm_and_ps(result, trailingMask);
- result = _mm_and_ps(result, leadingMask);
-
- temp = _mm_mul_ps(temp, x1);
- result = _mm_or_ps(temp, result);
-
- _mm_storeu_ps(pd, result);
- }
-}
-
-EXPORT_API(void) ScaleSrcU(float a, _In_ const float * ps, _Inout_ float * pd, int c)
-{
- float * pdLim = pd + c;
-
- __m128 x1 = _mm_set1_ps(a);
- for (; pd + 4 <= pdLim; pd += 4, ps += 4)
- {
- __m128 x2 = _mm_loadu_ps(ps);
- x2 = _mm_mul_ps(x2, x1);
- _mm_storeu_ps(pd, x2);
- }
-
- for (; pd < pdLim; pd++, ps++)
- {
- __m128 x2 = _mm_load_ss(ps);
- x2 = _mm_mul_ss(x2, x1);
- _mm_store_ss(pd, x2);
- }
-}
-
-// pd[i] = a * (pd[i] + b)
-EXPORT_API(void) ScaleAddU(float a, float b, _Inout_ float * pd, int c)
-{
- float * pdLim = pd + c;
-
- __m128 x1 = _mm_set1_ps(a);
- __m128 x2 = _mm_set1_ps(b);
- for (; pd + 4 <= pdLim; pd += 4)
- {
- __m128 x3 = _mm_loadu_ps(pd);
- x3 = _mm_add_ps(x3, x2);
- x3 = _mm_mul_ps(x3, x1);
- _mm_storeu_ps(pd, x3);
- }
-
- for (; pd < pdLim; pd++)
- {
- __m128 x3 = _mm_load_ss(pd);
- x3 = _mm_add_ss(x3, x2);
- x3 = _mm_mul_ss(x3, x1);
- _mm_store_ss(pd, x3);
- }
-}
-
-EXPORT_API(void) ScaleMaxNormA(float maxNorm, _Inout_ float * pmat, int crow, int ccol)
-{
- float * pm = pmat;
- float maxNormSq = maxNorm * maxNorm;
- __m128 m = _mm_set1_ps(maxNorm);
- for (int irow = 0; irow < crow; irow++)
- {
- __m128 rowNorm = _mm_set1_ps(0);
- float * pms = pm;
- float * pmLim = pm + ccol;
- for (; pm < pmLim; pm += 4)
- {
- __m128 x1 = _mm_load_ps(pm);
- x1 = _mm_mul_ps(x1, x1);
- rowNorm = _mm_add_ps(x1, rowNorm);
- }
- rowNorm = _mm_hadd_ps(rowNorm, rowNorm);
- rowNorm = _mm_hadd_ps(rowNorm, rowNorm);
- float rowNormRes = _mm_cvtss_f32(rowNorm);
- if (rowNormRes > maxNormSq)
- {
- __m128 scale = _mm_set1_ps(rowNormRes);
-#if 0
- // REVIEW: this is faster but it uses approximation so results differ significantly from CLR.
- scale = _mm_rsqrt_ps(scale);
- scale = _mm_mul_ps(scale, m);
-#else
- scale = _mm_sqrt_ps(scale);
- scale = _mm_div_ps(m, scale);
-#endif
- for (pm = pms; pm < pmLim; pm += 4)
- {
- __m128 x1 = _mm_load_ps(pm);
- x1 = _mm_mul_ps(x1, scale);
- _mm_store_ps(pm, x1);
- }
- }
- }
-}
-
-EXPORT_API(void) ScaleMaxNormTranU(float maxNorm, _Inout_ float * pmat, int crow, int ccol)
-{
- for (int icol = 0; icol < ccol; icol++)
- {
- float * pm = pmat + icol;
- float rowNorm = 0;
- for (int irow = 0; irow < crow; irow++)
- {
- rowNorm += *pm * *pm;
- pm += ccol;
- }
- if (rowNorm > maxNorm * maxNorm)
- {
- float scale = maxNorm / sqrtf(rowNorm);
- pm = pmat + icol;
- for (int irow = 0; irow < crow; irow++)
- {
- *pm *= scale;
- pm += ccol;
- }
- }
- }
-}
-
-// Sparse matrix.
-EXPORT_API(void) ScaleMaxNormRU(float maxNorm, _In_ const int * pstarts, _Inout_ float * pmat, int crow)
-{
- for (int irow = 0; irow < crow; irow++)
- {
- float rowNorm = 0;
- for (int idx = pstarts[irow]; idx < pstarts[irow + 1]; idx++)
- {
- rowNorm += pmat[idx] * pmat[idx];
- }
- if (rowNorm > maxNorm * maxNorm)
- {
- float scale = maxNorm / sqrtf(rowNorm);
- for (int idx = pstarts[irow]; idx < pstarts[irow + 1]; idx++)
- {
- pmat[idx] *= scale;
- }
- }
- }
-}
-
-// Convolution.
-EXPORT_API(void) ScaleMaxNormCU(float maxNorm, int kernCount, int kernSize, _Inout_ float * pmat)
-{
- float * pm = pmat;
- for (int irow = 0; irow < kernCount; irow++)
- {
- float rowNorm = 0;
- for (int icol = 0; icol < kernSize; icol++)
- {
- rowNorm += *pm * *pm;
- pm++;
- }
- if (rowNorm > maxNorm * maxNorm)
- {
- float scale = maxNorm / sqrtf(rowNorm);
- pm -= kernSize;
- for (int icol = 0; icol < kernSize; icol++)
- {
- *pm *= scale;
- pm++;
- }
- }
- // Skip bias.
- pm++;
- }
-}
-
-EXPORT_API(void) AddScaleA(float a, _In_ const float * ps, _Inout_ float * pd, int c)
+// pd[i] = a * (pd[i] + b)
+EXPORT_API(void) ScaleAddU(float a, float b, _Inout_ float * pd, int c)
{
float * pdLim = pd + c;
__m128 x1 = _mm_set1_ps(a);
- for (; pd < pdLim; pd += 4, ps += 4)
+ __m128 x2 = _mm_set1_ps(b);
+ for (; pd + 4 <= pdLim; pd += 4)
{
- __m128 x2 = _mm_load_ps(ps);
- __m128 x3 = _mm_load_ps(pd);
- x2 = _mm_mul_ps(x1, x2);
- x3 = _mm_add_ps(x2, x3);
- _mm_store_ps(pd, x3);
+ __m128 x3 = _mm_loadu_ps(pd);
+ x3 = _mm_add_ps(x3, x2);
+ x3 = _mm_mul_ps(x3, x1);
+ _mm_storeu_ps(pd, x3);
+ }
+
+ for (; pd < pdLim; pd++)
+ {
+ __m128 x3 = _mm_load_ss(pd);
+ x3 = _mm_add_ss(x3, x2);
+ x3 = _mm_mul_ss(x3, x1);
+ _mm_store_ss(pd, x3);
}
}
@@ -2047,97 +845,6 @@ EXPORT_API(void) AddScaleSU(float a, _In_ const float * ps, _In_ const int * pi,
pd[*pi] += a * *ps;
}
-EXPORT_API(void) AddScaleMomA(float a, _In_ const float * ps, _Inout_ float * pd, float momentum, _Inout_ float * pe, int c)
-{
- float * pdLim = pd + c;
-
- __m128 x0 = _mm_set1_ps(momentum);
- __m128 x1 = _mm_set1_ps(a);
- for (; pd < pdLim; pd += 4, pe += 4, ps += 4)
- {
- __m128 x2 = _mm_load_ps(ps);
- __m128 x3 = _mm_load_ps(pe);
- __m128 x4 = _mm_load_ps(pd);
- x2 = _mm_mul_ps(x1, x2);
- x3 = _mm_mul_ps(x0, x3);
- x3 = _mm_add_ps(x2, x3);
- x4 = _mm_add_ps(x3, x4);
- _mm_store_ps(pe, x3);
- _mm_store_ps(pd, x4);
- }
-}
-
-EXPORT_API(void) AddScaleGradA(_In_ const float * ps, _Inout_ float * pd, _Inout_ float * paccGrads, _Inout_ float * paccUpdates,
- float decay, float cond, int c)
-{
- float * pdLim = pd + c;
-
- __m128 dec = _mm_set1_ps(decay);
- __m128 decc = _mm_set1_ps(1 - decay);
- __m128 cnd = _mm_set1_ps(cond);
- for (; pd < pdLim; pd += 4, ps += 4, paccGrads += 4, paccUpdates += 4)
- {
- __m128 g = _mm_load_ps(ps);
- __m128 ag = _mm_load_ps(paccGrads);
- __m128 au = _mm_load_ps(paccUpdates);
- __m128 coef = _mm_load_ps(pd);
-
- UpdateAdadelta(coef, ag, au, g, dec, decc, cnd);
-
- _mm_store_ps(pd, coef);
- _mm_store_ps(paccGrads, ag);
- _mm_store_ps(paccUpdates, au);
- }
-}
-
-EXPORT_API(void) AddScaleMultiA(int count, _In_ const float * ps, _Inout_ float * pd, _Inout_ float * paccGrads, _Inout_ float * paccUpdates,
- float decay, float cond, int size)
-{
- if (1 == count)
- AddScaleGradA(ps, pd, paccGrads, paccUpdates, decay, cond, size);
- else
- {
- float * pdLim = pd + size;
-
- __m128 dec = _mm_set1_ps(decay);
- __m128 decc = _mm_set1_ps(1 - decay);
- __m128 cnd = _mm_set1_ps(cond);
- for (; pd < pdLim; pd += 4, ps += 4, paccGrads += 4, paccUpdates += 4)
- {
- __m128 g = _mm_set1_ps(0);
- const float * ps1 = ps;
- // REVIEW: unroll?
- for (int i = 0; i < count; i++, ps1 += size)
- {
- __m128 x1 = _mm_load_ps(ps1);
- g = _mm_add_ps(x1, g);
- }
- __m128 ag = _mm_load_ps(paccGrads);
- __m128 au = _mm_load_ps(paccUpdates);
- __m128 coef = _mm_load_ps(pd);
-
- UpdateAdadelta(coef, ag, au, g, dec, decc, cnd);
-
- _mm_store_ps(pd, coef);
- _mm_store_ps(paccGrads, ag);
- _mm_store_ps(paccUpdates, au);
- }
- }
-}
-
-EXPORT_API(void) AddA(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- float * pdLim = pd + c;
-
- for (; pd < pdLim; pd += 4, ps += 4)
- {
- __m128 x1 = _mm_load_ps(ps);
- __m128 x2 = _mm_load_ps(pd);
- x2 = _mm_add_ps(x1, x2);
- _mm_store_ps(pd, x2);
- }
-}
-
EXPORT_API(void) AddU(_In_ const float * ps, _Inout_ float * pd, int c)
{
float * pdLim = pd + c;
@@ -2196,36 +903,6 @@ EXPORT_API(void) MulElementWiseU(_In_ const float * ps1, _In_ const float * ps2,
}
}
-EXPORT_API(void) MulElementWiseSU(_In_ const float * ps1, _In_ const float * ps2, _In_ const int * pi, _Inout_ float * pd, int c)
-{
- const int * piLim = pi + c;
-
- for (; pi + 4 <= piLim; pi += 4)
- {
- __m128 x1 = _load4(ps1, pi);
- __m128 x2 = _load4(ps2, pi);
- x2 = _mm_mul_ps(x1, x2);
- _store4(x2, pd, pi);
- }
-
- for (; pi < piLim; pi++)
- pd[*pi] = ps1[*pi] * ps2[*pi];
-}
-
-EXPORT_API(float) SumA(const float * ps, int c)
-{
- const float * psLim = ps + c;
-
- __m128 res = _mm_setzero_ps();
- for (; ps < psLim; ps += 4)
- res = _mm_add_ps(res, _mm_load_ps(ps));
-
- res = _mm_hadd_ps(res, res);
- res = _mm_hadd_ps(res, res);
-
- return _mm_cvtss_f32(res);
-}
-
EXPORT_API(float) SumU(const float * ps, int c)
{
const float * psLim = ps + c;
@@ -2448,449 +1125,6 @@ EXPORT_API(float) Dist2(const float * px, const float * py, int c)
return norm2;
}
-// This is modeled after double-based SSE code
-
-// 1 / ln(2).
-const float RecipLn2 = (float)1.44269504088896340735992468100;
-
-// Used for computing a 4th degree polynomial approximation of e^x.
-const float Coef1 = (float)0.013555747234814917704030793;
-const float Coef2 = (float)0.065588116243247810171479524;
-const float Coef3 = (float)0.3069678791803394491901401;
-
-const float ExpInf = 128;
-const int ExpBias = 127;
-const int ExpShift = 23;
-
-float ExpFast(float arg)
-{
- bool neg = false;
- if (arg < 0)
- {
- arg = -arg;
- neg = true;
- }
-
- arg *= RecipLn2;
- if (arg >= ExpInf)
- return neg ? 0.0f : std::numeric_limits<float>::infinity();
-
- int exp = (int)arg;
- arg -= exp;
- exp += ExpBias;
- exp <<= ExpShift;
-
- float res = (1 + arg) + (arg - 1) * arg * ((Coef1 * arg + Coef2) * arg + Coef3);
- res *= *(float *)&exp;
-
- if (neg)
- res = 1 / res;
- return res;
-}
-
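A hypothetical spot check (not part of the original sources) for the polynomial approximation above, comparing ExpFast against libm's expf on a few sample arguments; it assumes only the ExpFast definition that precedes it.

#include <cmath>
#include <cstdio>

void CheckExpFast()
{
    const float samples[] = { -5.0f, -1.0f, -0.1f, 0.0f, 0.5f, 3.0f, 10.0f };
    for (float x : samples)
        printf("x=% 7.2f  expf=%-12.6g  ExpFast=%-12.6g\n", x, expf(x), ExpFast(x));
}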
-// Implements a fast approximation of sigmoid/tanh.
-template <bool isTanh>
-void ApplySigmoidCoreA(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- float * pdLim = pd + c;
-
- __m128 cSign = _mm_set1_ps(-0.0f);
- __m128 cZero = _mm_set1_ps(0.0f);
- __m128 cOne = _mm_set1_ps(1.0f);
-
- __m128 cMax = _mm_set1_ps(ExpInf);
- __m128i cBias = _mm_set1_epi32(ExpBias);
- __m128 c0 = _mm_set1_ps(RecipLn2);
- __m128 c1 = _mm_set1_ps(Coef1);
- __m128 c2 = _mm_set1_ps(Coef2);
- __m128 c3 = _mm_set1_ps(Coef3);
-
- if (isTanh)
- c0 = _mm_add_ps(c0, c0);
-
- for (; pd < pdLim; ps += 4, pd += 4)
- {
- // Get the argument, capture its sign and take its absolute value.
- __m128 xArg = _mm_load_ps(ps);
- // maskNaN is set to zero if xArg is not NaN and set equal to xArg otherwise.
- __m128 maskNaN = _mm_and_ps(_mm_cmpneq_ps(xArg, xArg), xArg);
- __m128 xSign = _mm_and_ps(xArg, cSign);
- xArg = _mm_xor_ps(xArg, xSign);
-
- // Multiply by 1/ln(2) and check for out of bounds.
- xArg = _mm_mul_ps(xArg, c0);
- __m128 xGood = _mm_cmplt_ps(xArg, cMax);
- xArg = _mm_and_ps(xArg, xGood);
-
- // Get the integer and fractional parts.
- __m128i xInt = _mm_cvttps_epi32(xArg);
- xArg = _mm_sub_ps(xArg, _mm_cvtepi32_ps(xInt));
-
- // Add the exponent bias to xInt, then convert to a floating point
- // power of two by shifting past the mantissa bits.
- xInt = _mm_add_epi32(xInt, cBias);
- xInt = _mm_slli_epi32(xInt, ExpShift);
-
- // Approximate 2 raised to the fractional part.
- // (1 + f) + (f - 1) * f * ((c1 * f + c2) * f + c3)
-
- // x1 = (c1 * f + c2) * f + c3
- __m128 x1 = _mm_mul_ps(c1, xArg);
- x1 = _mm_add_ps(x1, c2);
- x1 = _mm_mul_ps(x1, xArg);
- x1 = _mm_add_ps(x1, c3);
-
- // x2 = f * (f - 1)
- __m128 x2 = _mm_sub_ps(xArg, cOne);
- x2 = _mm_mul_ps(xArg, x2);
-
- // Add (1 + f). Note that for tanh, we only add f, so we are approximating
- // 2^f - 1. This is necessary to preserve precision near zero. In particular,
- // near zero, tanh(x) ~ x.
- x1 = _mm_mul_ps(x2, x1);
- if (!isTanh)
- xArg = _mm_add_ps(xArg, cOne);
- x1 = _mm_add_ps(xArg, x1);
-
- // Multiply by 2^n, where n is the integer part.
- __m128 x3 = _mm_castsi128_ps(xInt);
- x1 = _mm_mul_ps(x1, x3);
-
- if (!isTanh)
- {
- // Add 1, and take the reciprocal.
- x1 = _mm_add_ps(x1, cOne);
- x1 = _mm_div_ps(cOne, x1);
-
- // Deal with out of bounds.
- x1 = _mm_and_ps(x1, xGood);
- // If the input was NaN, xGood is zero, so x1 is zero. So can simply or in maskNaN.
- x1 = _mm_or_ps(x1, maskNaN);
-
- // Deal with the sign. Set:
- // * x2 = x1 if xSign is -0 (0x80000000)
- // * x2 = 1 - x1 if xSign is +0 (0x00000000).
- x1 = _mm_or_ps(x1, xSign);
- x2 = _mm_or_ps(xSign, cOne);
- x2 = _mm_max_ps(x2, cZero);
- x2 = _mm_sub_ps(x2, x1);
- }
- else
- {
- // [2^n(2^f - 1) + (2^n - 1)] / [2^n(2^f - 1) + (2^n + 1)]
- x2 = _mm_add_ps(x1, _mm_sub_ps(x3, cOne));
- x1 = _mm_add_ps(x1, _mm_add_ps(x3, cOne));
- x2 = _mm_div_ps(x2, x1);
-
- // Deal with out of bounds: x2 = (x2 & xGood) | ((1 + maskNaN) & ~xGood)
- x2 = _mm_and_ps(x2, xGood);
- x1 = _mm_andnot_ps(xGood, _mm_add_ps(maskNaN, cOne));
- x2 = _mm_or_ps(x2, x1);
-
- // Deal with the sign.
- x2 = _mm_or_ps(x2, xSign);
- }
-
- _mm_store_ps(pd, x2);
- }
-
- // If we overshot, back fill with zero! Since tanh(0) = 0, we only need to do this for sigmoid.
- if (!isTanh)
- {
- while (pd > pdLim)
- *--pd = 0.0f;
- }
-}
-
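The template above computes both sigmoid and tanh from a single exponential of |x|. A scalar sketch of the same identities (illustration only, reusing the ExpFast helper defined earlier and ignoring the NaN/overflow handling the SIMD code folds into its masks):

#include <cmath>

// sigmoid(x) = 1 / (1 + e^-x), evaluated via sigmoid(-|x|) and mirrored.
float SigmoidFastScalar(float x)
{
    float e = ExpFast(-fabsf(x));       // e^(-|x|)
    float s = e / (1.0f + e);           // sigmoid(-|x|)
    return x < 0 ? s : 1.0f - s;        // sigmoid(x) = 1 - sigmoid(-x)
}

// tanh(x) = (e^(2x) - 1) / (e^(2x) + 1), evaluated on |x| and sign-restored.
float TanhFastScalar(float x)
{
    float e = ExpFast(2.0f * fabsf(x)); // e^(2|x|)
    float t = (e - 1.0f) / (e + 1.0f);  // tanh(|x|)
    return x < 0 ? -t : t;
}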
-EXPORT_API(void) ApplySigmoidA(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- ApplySigmoidCoreA<false>(ps, pd, c);
-}
-
-EXPORT_API(void) ApplySoftMaxU(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- // REVIEW: Use SSE - do 4 at a time.
-
- const float * psLim = ps + c;
-
- // Compute max output.
- float maxOut = -std::numeric_limits<float>::infinity();
- for (const float * p = ps; p < psLim; p++)
- {
- float v = *p;
- if (maxOut < v)
- maxOut = v;
- }
-
- // Compute exp and sum.
- float sum = 0;
- const float * p = ps;
- for (float * q = pd; p < psLim; p++, q++)
- {
- float v = ExpFast(*p - maxOut);
- *q = v;
- sum += v;
- }
-
- // Normalize.
- for (float * q = pd; q < pd + c; q++)
- *q /= sum;
-}
-
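The max-subtraction in ApplySoftMaxU is the usual overflow guard and does not change the result, since for any constant m

    exp(x_i - m) / Sum_j exp(x_j - m) = exp(x_i) / Sum_j exp(x_j),

and choosing m = max_j x_j keeps every exponent non-positive, so ExpFast never overflows.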
-EXPORT_API(void) ApplyRectifiedLinearA(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- const float * psLim = ps + c;
-
- __m128 cZero = _mm_set1_ps(0.0f);
- for (; ps < psLim; ps += 4, pd += 4)
- {
- __m128 x1 = _mm_load_ps(ps);
- x1 = _mm_max_ps(x1, cZero);
- _mm_store_ps(pd, x1);
- }
-}
-
-EXPORT_API(void) ApplySquareA(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- const float * psLim = ps + c;
-
- for (; ps < psLim; ps += 4, pd += 4)
- {
- __m128 x1 = _mm_load_ps(ps);
- x1 = _mm_mul_ps(x1, x1);
- _mm_store_ps(pd, x1);
- }
-}
-
-EXPORT_API(void) ApplySqrtA(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- const float * psLim = ps + c;
-
- __m128 cZero = _mm_set1_ps(0.0f);
- for (; ps < psLim; ps += 4, pd += 4)
- {
- __m128 x1 = _mm_load_ps(ps);
- x1 = _mm_max_ps(x1, cZero);
- x1 = _mm_sqrt_ps(x1);
- _mm_store_ps(pd, x1);
- }
-}
-
-EXPORT_API(void) ApplySoftRectifiedLinearU(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- const float * psLim = ps + c;
-
- // Apply: f(x) = log(1 + e^x). To avoid overflow for large x, we use the identity: f(x) = x + f(-x).
- // REVIEW: Should we implement a "LogFast"?
- // REVIEW: Do 4 at a time.
- const float * p = ps;
- for (float * q = pd; p < psLim; p++, q++)
- {
- float x = *p;
- if (x > 0)
- *q = x + log(1 + ExpFast(-x));
- else
- *q = log(1 + ExpFast(x));
- }
-}
-
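The overflow identity used in ApplySoftRectifiedLinearU is just factoring out e^x:

    log(1 + e^x) = log(e^x * (e^-x + 1)) = x + log(1 + e^-x),

so for large positive x the code evaluates ExpFast on a non-positive argument and never overflows.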
-EXPORT_API(void) ApplyAbsA(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- const float * psLim = ps + c;
-
- __m128 mask = _mm_castsi128_ps(_mm_set1_epi32(0x7FFFFFFF));
- for (; ps < psLim; ps += 4, pd += 4)
- {
- __m128 x1 = _mm_load_ps(ps);
- x1 = _mm_and_ps(x1, mask);
- _mm_store_ps(pd, x1);
- }
-}
-
-EXPORT_API(void) ApplyTanhA(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- ApplySigmoidCoreA<true>(ps, pd, c);
-}
-
-EXPORT_API(void) ApplyBoundedRectifiedLinearA(_In_ const float * ps, _Inout_ float * pd, int c)
-{
- const float * psLim = ps + c;
-
- __m128 cZero = _mm_set1_ps(0.0f);
- __m128 cOne = _mm_set1_ps(1.0f);
- for (; ps < psLim; ps += 4, pd += 4)
- {
- __m128 x1 = _mm_load_ps(ps);
- x1 = _mm_max_ps(x1, cZero);
- x1 = _mm_min_ps(x1, cOne);
- _mm_store_ps(pd, x1);
- }
-}
-
-EXPORT_API(void) ApplySigmoidDerivativeA(_In_ const float * pv, _Inout_ float * pg, int c)
-{
- float * pgLim = pg + c;
-
- // pg[i] *= pv[i] * (1 - pv[i])
- __m128 cOne = _mm_set1_ps(1.0f);
- for (; pg < pgLim; pg += 4, pv += 4)
- {
- __m128 x1 = _mm_load_ps(pv);
- __m128 x2 = _mm_load_ps(pg);
- __m128 x3 = _mm_sub_ps(cOne, x1);
- x1 = _mm_mul_ps(x1, x3);
- x2 = _mm_mul_ps(x2, x1);
- _mm_store_ps(pg, x2);
- }
-}
-
-EXPORT_API(void) ApplyRectifiedLinearDerivativeA(_In_ const float * pv, _Inout_ float * pg, int c)
-{
- float * pgLim = pg + c;
-
- __m128 cZero = _mm_set1_ps(0.0f);
- for (; pg < pgLim; pg += 4, pv += 4)
- {
- __m128 x1 = _mm_load_ps(pv);
- __m128 x2 = _mm_load_ps(pg);
- x1 = _mm_cmpgt_ps(x1, cZero);
- x2 = _mm_and_ps(x2, x1);
- _mm_store_ps(pg, x2);
- }
-}
-
-EXPORT_API(void) ApplySquareDerivativeA(_In_ const float * px, _In_opt_ const float * py, _Inout_ float * pg, int c, bool drop)
-{
- float * pgLim = pg + c;
-
- if (drop)
- {
- __m128 cZero = _mm_set1_ps(0.0f);
- for (; pg < pgLim; pg += 4, px += 4, py += 4)
- {
- __m128 x0 = _mm_cmpgt_ps(_mm_load_ps(py), cZero);
- __m128 x1 = _mm_load_ps(px);
- __m128 x2 = _mm_load_ps(pg);
- x1 = _mm_add_ps(x1, x1);
- x2 = _mm_mul_ps(x2, x1);
- x2 = _mm_and_ps(x2, x0);
- _mm_store_ps(pg, x2);
- }
- }
- else
- {
- for (; pg < pgLim; pg += 4, px += 4)
- {
- __m128 x1 = _mm_load_ps(px);
- __m128 x2 = _mm_load_ps(pg);
- x1 = _mm_add_ps(x1, x1);
- x2 = _mm_mul_ps(x2, x1);
- _mm_store_ps(pg, x2);
- }
- }
-}
-
-EXPORT_API(void) ApplySqrtDerivativeA(_In_ const float * pv, _Inout_ float * pg, int c)
-{
- float * pgLim = pg + c;
- static const float smallValue = 1e-10F;
-
- __m128 cZero = _mm_set1_ps(0.0f);
- __m128 cSmall = _mm_set1_ps(smallValue);
- for (; pg < pgLim; pg += 4, pv += 4)
- {
- __m128 x1 = _mm_load_ps(pv);
- __m128 x2 = _mm_load_ps(pg);
- __m128 x3 = _mm_cmpgt_ps(x1, cZero);
- x1 = _mm_max_ps(x1, cSmall);
- x1 = _mm_add_ps(x1, x1);
- x2 = _mm_and_ps(x2, x3);
- x2 = _mm_div_ps(x2, x1);
- _mm_store_ps(pg, x2);
- }
-}
-
-EXPORT_API(void) ApplySoftRectifiedLinearDerivativeU(_In_opt_ const float * px, _In_ const float * py, _Inout_ float * pg, int c)
-{
- UNUSED(px);
-
- float * pgLim = pg + c;
-
- // Use the identity: y' = 1 - e^(-y). This has a few nice properties:
- // * If x is large enough that x == y (after rounding), we'll compute y' as 1.
- // * If x is small enough that y == 0 (after rounding), we'll compute y' as 0.
- // * If y is zero because of drop out, we'll compute y' as 0.
- // REVIEW: Do 4 at a time.
- for (; pg < pgLim; pg++, py++)
- *pg *= 1 - ExpFast(-*py);
-}
-
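The derivative identity in ApplySoftRectifiedLinearDerivativeU follows from y = log(1 + e^x):

    dy/dx = e^x / (1 + e^x) = 1 - 1 / (1 + e^x) = 1 - e^-y,

since e^-y = 1 / (1 + e^x).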
-EXPORT_API(void) ApplyAbsDerivativeA(_In_ const float * px, _In_opt_ const float * py, _Inout_ float * pg, int c, bool drop)
-{
- float * pgLim = pg + c;
-
- __m128 cZero = _mm_set1_ps(0.0f);
- __m128 cSign = _mm_set1_ps(-0.0f);
- if (drop)
- {
- for (; pg < pgLim; pg += 4, px += 4, py += 4)
- {
- __m128 x1 = _mm_and_ps(_mm_load_ps(px), cSign);
- __m128 x2 = _mm_cmpgt_ps(_mm_load_ps(py), cZero);
- __m128 x3 = _mm_load_ps(pg);
- x3 = _mm_xor_ps(x3, x1);
- x3 = _mm_and_ps(x3, x2);
- _mm_store_ps(pg, x3);
- }
- }
- else
- {
- for (; pg < pgLim; pg += 4, px += 4)
- {
- __m128 x0 = _mm_load_ps(px);
- __m128 x1 = _mm_and_ps(x0, cSign);
- __m128 x2 = _mm_cmpneq_ps(x0, cZero);
- __m128 x3 = _mm_load_ps(pg);
- x3 = _mm_xor_ps(x3, x1);
- x3 = _mm_and_ps(x3, x2);
- _mm_store_ps(pg, x3);
- }
- }
-}
-
-EXPORT_API(void) ApplyTanhDerivativeA(_In_ const float * pv, _Inout_ float * pg, int c)
-{
- float * pgLim = pg + c;
-
- // pg[i] *= 1 - pv[i] * pv[i]
- __m128 cOne = _mm_set1_ps(1.0f);
- for (; pg < pgLim; pg += 4, pv += 4)
- {
- __m128 x1 = _mm_load_ps(pv);
- __m128 x2 = _mm_load_ps(pg);
- x1 = _mm_mul_ps(x1, x1);
- x1 = _mm_sub_ps(cOne, x1);
- x2 = _mm_mul_ps(x2, x1);
- _mm_store_ps(pg, x2);
- }
-}
-
-EXPORT_API(void) ApplyBoundedRectifiedLinearDerivativeA(_In_ const float * pv, _Inout_ float * pg, int c)
-{
- float * pgLim = pg + c;
-
- __m128 cZero = _mm_set1_ps(0.0f);
- __m128 cOne = _mm_set1_ps(1.0f);
- for (; pg < pgLim; pg += 4, pv += 4)
- {
- __m128 x1 = _mm_load_ps(pv);
- __m128 x2 = _mm_load_ps(pg);
- x2 = _mm_and_ps(x2, _mm_cmpgt_ps(x1, cZero));
- x2 = _mm_and_ps(x2, _mm_cmplt_ps(x1, cOne));
- _mm_store_ps(pg, x2);
- }
-}
-
EXPORT_API(void) ZeroItemsU(_Inout_ float * pd, int c, _In_ const int * pindices, int cindices)
{
DEBUG_ONLY(c);
@@ -2992,69 +1226,3 @@ EXPORT_API(void) SdcaL1UpdateSU(float primalUpdate, _In_ const float * ps, _In_
pd2[i] = std::abs(d1) > threshold ? (d1 > 0 ? d1 - threshold : d1 + threshold) : 0;
}
}
-
-EXPORT_API(void) ScaleAdadeltaU(_Inout_ float * mat, _Inout_ float * accGrads, _Inout_ float * accUpdates, float decay, float cond, _In_ const float * grads, int size)
-{
- float * pm = mat;
- float * pmLim = pm + size;
- float * pag = accGrads;
- float * pau = accUpdates;
- const float * pg = grads;
-
- __m128 dec = _mm_set1_ps(decay);
- __m128 decc = _mm_set1_ps(1 - decay);
- __m128 c = _mm_set1_ps(cond);
-
- for (; pm + 4 <= pmLim; pm += 4, pag += 4, pau += 4, pg += 4)
- {
- __m128 g = _mm_loadu_ps(pg);
- __m128 ag = _mm_loadu_ps(pag);
- __m128 au = _mm_loadu_ps(pau);
- __m128 coef = _mm_loadu_ps(pm);
-
- UpdateAdadelta(coef, ag, au, g, dec, decc, c);
-
- _mm_storeu_ps(pm, coef);
- _mm_storeu_ps(pag, ag);
- _mm_storeu_ps(pau, au);
- }
-
- for (; pm < pmLim; pm++, pag++, pau++, pg++)
- {
- float g = *pg;
- float accGrad = decay * *pag + (1 - decay) * g * g;
- float accUpd = *pau;
-
- float newUpd = sqrtf((accUpd + cond) / (accGrad + cond)) * g;
- *pm += newUpd;
- *pag = accGrad;
- *pau = decay * accUpd + (1 - decay) * newUpd * newUpd;
- }
-}
-
-EXPORT_API(void) ScaleAdadeltaA(_Inout_ float * mat, _Inout_ float * accGrads, _Inout_ float * accUpdates, float decay, float cond, _Inout_ float * grads, int size)
-{
- float * pm = mat;
- float * pmLim = pm + size;
- float * pag = accGrads;
- float * pau = accUpdates;
- float * pg = grads;
-
- __m128 dec = _mm_set1_ps(decay);
- __m128 decc = _mm_set1_ps(1 - decay);
- __m128 c = _mm_set1_ps(cond);
-
- for (; pm < pmLim; pm += 4, pag += 4, pau += 4, pg += 4)
- {
- __m128 g = _mm_load_ps(pg);
- __m128 ag = _mm_load_ps(pag);
- __m128 au = _mm_load_ps(pau);
- __m128 coef = _mm_load_ps(pm);
-
- UpdateAdadelta(coef, ag, au, g, dec, decc, c);
-
- _mm_store_ps(pm, coef);
- _mm_store_ps(pag, ag);
- _mm_store_ps(pau, au);
- }
-}
diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs
index 7d7a851a48..e9282e79c3 100644
--- a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs
+++ b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs
@@ -108,13 +108,19 @@ public void SdcaL1UpdateU()
[BenchmarkCategory("Fma")]
public void SdcaL1UpdateSU()
=> AvxIntrinsics.SdcaL1UpdateSU(DefaultScale, IndexLength, src, idx, DefaultScale, dst, result);
+
[Benchmark]
[BenchmarkCategory("Fma")]
- public void MatMulX()
- => AvxIntrinsics.MatMulX(src, src1, dst, 1000, 1000);
+ public void MatMul()
+ => AvxIntrinsics.MatMul(src, src1, dst, 1000, 1000);
+
+ [Benchmark]
+ public void MatMulTran()
+ => AvxIntrinsics.MatMulTran(src, src1, dst, 1000, 1000);
[Benchmark]
- public void MatMulTranX()
- => AvxIntrinsics.MatMulTranX(src, src1, dst, 1000, 1000);
+ [BenchmarkCategory("Fma")]
+ public void MatMulP()
+ => AvxIntrinsics.MatMulP(src, matrixIdx, src1, 0, 0, MatrixIndexLength, dst, 1000, 1000);
}
}
diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs
index 07958c19ed..778ce56c93 100644
--- a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs
+++ b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs
@@ -250,5 +250,17 @@ public unsafe void MatMulTran()
Thunk.MatMulTran(psrc1, psrc, pdst, 1000, 1000);
}
}
+
+ [Benchmark]
+ public unsafe void MatMulP()
+ {
+ fixed (float* psrc = &src[0])
+ fixed (float* pdst = &dst[0])
+ fixed (float* psrc1 = &src1[0])
+ fixed (int* pidx = &matrixIdx[0])
+ {
+ Thunk.MatMulP(psrc1, pidx, psrc, 0, 0, MatrixIndexLength, pdst, 1000, 1000);
+ }
+ }
}
}
diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs
index cffa0af94a..6726603b5a 100644
--- a/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs
+++ b/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs
@@ -17,12 +17,14 @@ public abstract class PerformanceTests
protected const int IndexLength = 1000003;
protected const int Length = 1000003;
-
+ protected const int MatrixIndexLength = 100;
+
private const int DefaultSeed = 253421;
protected const float DefaultScale = 1.11f;
protected float[] src, dst, original, src1, src2, result;
protected int[] idx;
+ protected int[] matrixIdx;
private int _seed = DefaultSeed;
@@ -67,6 +69,7 @@ public void Setup()
original = new float[Length];
result = new float[Length];
idx = new int[IndexLength];
+ matrixIdx = new int[MatrixIndexLength];
_seed = GetSeed();
Random rand = new Random(_seed);
@@ -85,6 +88,11 @@ public void Setup()
{
idx[i] = rand.Next(0, Length);
}
+
+ for (int i = 0; i < MatrixIndexLength; i++)
+ {
+ matrixIdx[i] = rand.Next(0, 1000);
+ }
}
[GlobalCleanup]
diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs
index 8e94dabb96..fdb8e8adcc 100644
--- a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs
+++ b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs
@@ -100,11 +100,15 @@ public void SdcaL1UpdateSU()
=> SseIntrinsics.SdcaL1UpdateSU(DefaultScale, IndexLength, src, idx, DefaultScale, dst, result);
[Benchmark]
- public void MatMulX()
+ public void MatMul()
=> SseIntrinsics.MatMul(src, src1, dst, 1000, 1000);
[Benchmark]
- public void MatMulTranX()
+ public void MatMulTran()
=> SseIntrinsics.MatMulTran(src, src1, dst, 1000, 1000);
+
+ [Benchmark]
+ public void MatMulP()
+ => SseIntrinsics.MatMulP(src, matrixIdx, src1, 0, 0, MatrixIndexLength, dst, 1000, 1000);
}
}
diff --git a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs
index 8ce7878f83..68e26ac66e 100644
--- a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs
+++ b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs
@@ -116,7 +116,7 @@ public void MatMulTranTest(int matTest, int srcTest, int dstTest, float[] expect
[InlineData(0, 0, 0, new float[] { 38.25002f, 38.25002f, 38.25002f, 38.25002f, 38.25002f, 38.25002f, 38.25002f, 38.25002f })]
[InlineData(1, 1, 0, new float[] { 910f, 2190f, 3470f, 4750f, 6030f, 7310f, 8590f, 9870f })]
[InlineData(1, 0, 1, new float[] { 95f, 231f, 367f, 503f, 639f, 775f, 911f, 1047f, 1183f, 1319f, 1455f, 1591f, 1727f, 1863f, 1999f, 2135f })]
- public void MatMulPATest(int matTest, int srcTest, int dstTest, float[] expected)
+ public void MatTimesSrcSparseTest(int matTest, int srcTest, int dstTest, float[] expected)
{
AlignedArray mat = _testMatrices[matTest];
AlignedArray src = _testSrcVectors[srcTest];