From ae55c4529a3ccfb6f80671790fbb2beb44745382 Mon Sep 17 00:00:00 2001 From: Anipik Date: Thu, 22 Nov 2018 15:45:42 -0800 Subject: [PATCH 1/5] Removing Aligned Array usage from rff and timeseries. Removing aligned matrix and cpumathaligned entirely. moving Aligned array to FactorAware where it is only being used --- src/Microsoft.ML.CpuMath/AlignedMatrix.cs | 681 ------------------ src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 443 ++++++++---- .../CpuAligenedMathUtils.cs | 148 ---- .../CpuMathUtils.netcoreapp.cs | 158 +--- .../CpuMathUtils.netstandard.cs | 15 +- src/Microsoft.ML.CpuMath/Sse.cs | 78 +- src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 443 ++++++++---- src/Microsoft.ML.CpuMath/Thunk.cs | 6 - .../FactorizationMachine}/AlignedArray.cs | 93 +-- ...AdaptiveSingularSpectrumSequenceModeler.cs | 68 +- .../RandomFourierFeaturizing.cs | 61 +- src/Native/CpuMathNative/Sse.cpp | 484 ++++++++----- .../UnitTests.cs | 191 +++-- .../DataPipe/TestDataPipeBase.cs | 6 +- .../TimeSeriesDirectApi.cs | 8 +- 15 files changed, 1167 insertions(+), 1716 deletions(-) delete mode 100644 src/Microsoft.ML.CpuMath/AlignedMatrix.cs delete mode 100644 src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs rename src/{Microsoft.ML.CpuMath => Microsoft.ML.StandardLearners/FactorizationMachine}/AlignedArray.cs (53%) diff --git a/src/Microsoft.ML.CpuMath/AlignedMatrix.cs b/src/Microsoft.ML.CpuMath/AlignedMatrix.cs deleted file mode 100644 index 6d550fc3fc..0000000000 --- a/src/Microsoft.ML.CpuMath/AlignedMatrix.cs +++ /dev/null @@ -1,681 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Float = System.Single; - -using Microsoft.ML.Runtime.Internal.CpuMath.Core; -using System; -using System.Collections; -using System.Collections.Generic; - -namespace Microsoft.ML.Runtime.Internal.CpuMath -{ - using Conditional = System.Diagnostics.ConditionalAttribute; - - /// - /// This implements a logical array of Floats that is automatically aligned for SSE/AVX operations. - /// This is a thin wrapper around the AlignedArray type implemented in C++. This simply couples - /// the AlignedArray with a logical size, which does not include padding, while the AlignedArray - /// size does include padding. - /// - [BestFriend] - internal sealed class CpuAlignedVector : ICpuVector - { - private readonly AlignedArray _items; - private readonly int _size; // The logical size. - - /// - /// The value count. - /// - public int ValueCount { get { return _size; } } - - /// - /// The logical size of the vector. - /// - public int VectorSize { get { return _size; } } - - // Round cflt up to a multiple of cfltAlign. - private static int RoundUp(int cflt, int cfltAlign) - { - Contracts.Assert(0 < cflt); - // cfltAlign should be a power of two. - Contracts.Assert(0 < cfltAlign && (cfltAlign & (cfltAlign - 1)) == 0); - - // Determine the number of "blobs" of size cfltAlign. - int cblob = (cflt + cfltAlign - 1) / cfltAlign; - return cblob * cfltAlign; - } - - /// - /// Allocate an aligned vector with the given alignment (in bytes). - /// The alignment must be a power of two and at least sizeof(Float). - /// - public CpuAlignedVector(int size, int cbAlign) - { - Contracts.Assert(0 < size); - // cbAlign should be a power of two. 
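            // Worked example for the two asserts that follow: (cbAlign & (cbAlign - 1)) == 0 holds only
            // for powers of two, e.g. 16 (SSE) or 32 (AVX), and cbAlign / sizeof(Float) converts the byte
            // alignment into float slots: 16 bytes -> 4 floats, 32 bytes -> 8 floats.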
- Contracts.Assert(sizeof(Float) <= cbAlign); - Contracts.Assert((cbAlign & (cbAlign - 1)) == 0); - - int cfltAlign = cbAlign / sizeof(Float); - int cflt = RoundUp(size, cfltAlign); - _items = new AlignedArray(cflt, cbAlign); - _size = size; - AssertValid(); - } - - public void Dispose() - { - } - - [Conditional("DEBUG")] - private void AssertValid() - { -#if DEBUG - Contracts.Assert(0 < _size && _size <= _items.Size); - - // The padding, [_size, _items.Size), should contain zeros. - for (int i = _size; i < _items.Size; i++) - Contracts.Assert(_items[i] == 0); -#endif - } - - /// - /// The physical AligenedArray items. - /// - public AlignedArray Items { get { return _items; } } - - /// - /// The alignment. - /// - public int CbAlign - { - get { return _items.CbAlign; } - } - - /// - /// Set and get the value of the vector at the given index. - /// - /// The index - /// The value at the given index - public Float this[int index] - { - get - { - Contracts.Assert(0 <= index && index < _size); - return _items[index]; - } - set - { - Contracts.Assert(0 <= index && index < _size); - _items[index] = value; - } - } - - /// - /// Get the value of the vector at the given index. - /// - /// The index - /// The value at the given index - public Float GetValue(int i) - { - Contracts.Assert(0 <= i && i < _size); - return _items[i]; - } - - /// - /// Assign randomized values to the vector elements via the input function. - /// - /// The input rand om function that takes no arguments and returns a float value - public void Randomize(Func rand) - { - Contracts.AssertValue(rand); - for (int i = 0; i < _size; i++) - _items[i] = rand(); - } - - /// - /// Assign zeros to the vector elements. - /// - public void Zero() - { - _items.ZeroItems(); - } - - /// - /// Copy the values into dst, starting at slot ivDst and advancing ivDst. - /// - /// The destination array - /// The starting index in the destination array - public void CopyTo(Float[] dst, ref int ivDst) - { - Contracts.AssertValue(dst); - Contracts.Assert(0 <= ivDst && ivDst <= dst.Length - _size); - _items.CopyTo(dst, ivDst, _size); - ivDst += _size; - } - - /// - /// Copy the values from this vector starting at slot ivSrc into dst, starting at slot ivDst. - /// The number of values that are copied is determined by count. - /// - /// The staring index in this vector - /// The destination array - /// The starting index in the destination array - /// The number of elements to be copied - public void CopyTo(int ivSrc, Float[] dst, int ivDst, int count) - { - Contracts.AssertValue(dst); - Contracts.Assert(0 <= count && count <= dst.Length); - Contracts.Assert(0 <= ivSrc && ivSrc <= _size - count); - Contracts.Assert(0 <= ivDst && ivDst <= dst.Length - count); - _items.CopyTo(ivSrc, dst, ivDst, count); - } - - /// - /// Copy the values from src, starting at slot index and advancing index, into this vector. - /// - /// The source array - /// The starting index in the source array - public void CopyFrom(Float[] src, ref int index) - { - Contracts.AssertValue(src); - Contracts.Assert(0 <= index && index <= src.Length - _size); - _items.CopyFrom(src.AsSpan(index, _size)); - index += _size; - } - - /// - /// Copy the values from src, starting at slot index and advancing index, into this vector, starting at slot ivDst. - /// The number of values that are copied is determined by count. 
- /// - /// The staring index in this vector - /// The source array - /// The starting index in the source array - /// The number of elements to be copied - public void CopyFrom(int ivDst, Float[] src, int ivSrc, int count) - { - Contracts.AssertValue(src); - Contracts.Assert(0 <= count && count <= src.Length); - Contracts.Assert(0 <= ivDst && ivDst <= _size - count); - Contracts.Assert(0 <= ivSrc && ivSrc <= src.Length - count); - _items.CopyFrom(ivDst, src.AsSpan(ivSrc, _size)); - } - - /// - /// Copy the values of src vector into this vector. The src vector must have the same size as this vector. - /// - /// The source vector - public void CopyFrom(CpuAlignedVector src) - { - Contracts.AssertValue(src); - Contracts.Assert(src._size == _size); - _items.CopyFrom(src._items); - } - - /// - /// Get the underlying AlignedArray as IEnumerator<Float>. - /// - public IEnumerator GetEnumerator() - { - for (int i = 0; i < _size; i++) - yield return _items[i]; - } - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - } - - /// - /// This implements a logical matrix of Floats that is automatically aligned for SSE/AVX operations. - /// The ctor takes an alignment value, which must be a power of two at least sizeof(Float). - /// - [BestFriend] - internal abstract class CpuAlignedMatrixBase - { - // _items includes "head" items filled with NaN, followed by RunLenPhy * RunCntPhy entries, followed by - // "tail" items, also filled with NaN. Note that RunLenPhy and RunCntPhy are divisible by the alignment - // specified in the ctor and are >= RunLen and RunCnt, respectively. It is illegal to access any slot - // outsize [_base, _base + RunLenPhy * RunCntPhy). The padding should all be zero (and maintained as such). - // The items are arranged in "runs" of length RunLen. There are RunCnt such runs. Each run ends with - // (RunLenPhy - RunLen) padding slots. There are an addition (RunCntPhy - RunCnt) padding runs of length - // RunLenPhy, which are entirely zero. Any native code should be able to assume and should maintain - // these invariants. - public AlignedArray Items { get; } - - protected readonly int FloatAlign; // The alignment. - - // Since FloatAlign is a power of two, shifting by Shift = log_2(FloatAlign) is the same as multiplying/dividing by FloatAlign. - protected readonly int Shift; - // Since FloatAlign is a power of two, bitwise and with Mask = FloatAlign - 1 will be the same as moding by FloatAlign. - protected readonly int Mask; - - // Logical length of runs (RunLen) and number of runs (RunCnt). - public readonly int RunLen; - public readonly int RunCnt; - - // Physical (padded) length and number of runs. - public readonly int RunLenPhy; - public readonly int RunCntPhy; - - /// - /// The logical number values in the matrix - /// - public int ValueCount => RunLen * RunCnt; - - /// - /// The logical number of rows - /// - public abstract int RowCount { get; } - - /// - /// The logical number of columns - /// - public abstract int ColCount { get; } - - /// - /// The physical number of rows - /// - public abstract int RowCountPhy { get; } - - /// - /// The pysical number of columns - /// - public abstract int ColCountPhy { get; } - - // Round cflt up to a multiple of cfltAlign. - protected static int RoundUp(int cflt, int cfltAlign) - { - Contracts.Assert(0 < cflt); - // cfltAlign should be a power of two. - Contracts.Assert(0 < cfltAlign && (cfltAlign & (cfltAlign - 1)) == 0); - - // Determine the number of "blobs" of size cfltAlign. 
- int cblob = (cflt + cfltAlign - 1) / cfltAlign; - return cblob * cfltAlign; - } - - /// - /// Allocate an aligned matrix with the given alignment (in bytes). - /// - protected CpuAlignedMatrixBase(int runLen, int runCnt, int cbAlign) - { - Contracts.Assert(0 < runLen); - Contracts.Assert(0 < runCnt); - // cbAlign should be a power of two. - Contracts.Assert(sizeof(Float) <= cbAlign); - Contracts.Assert((cbAlign & (cbAlign - 1)) == 0); - - RunLen = runLen; - RunCnt = runCnt; - - FloatAlign = cbAlign / sizeof(Float); - Shift = GeneralUtils.CbitLowZero((uint)FloatAlign); - Mask = FloatAlign - 1; - - RunLenPhy = RoundUp(runLen, FloatAlign); - RunCntPhy = RoundUp(runCnt, FloatAlign); - Items = new AlignedArray(RunLenPhy * RunCntPhy, cbAlign); - - AssertValid(); - } - - [Conditional("DEBUG")] - protected void AssertValid() - { -#if DEBUG - Contracts.Assert(0 < RunLen && RunLen <= RunLenPhy); - Contracts.Assert(0 < RunCnt && RunCnt <= RunCntPhy); - Contracts.Assert(RunLenPhy * RunCntPhy == Items.Size); - - // Assert that the padding at the end of each run contains zeros. - for (int i = 0; i < RunCnt; i++) - { - for (int j = RunLen; j < RunLenPhy; j++) - Contracts.Assert(Items[i * RunLenPhy + j] == 0); - } - - // Assert that the padding runs contain zeros. - for (int i = RunCnt; i < RunCntPhy; i++) - { - for (int j = 0; j < RunLenPhy; j++) - Contracts.Assert(Items[i * RunLenPhy + j] == 0); - } -#endif - } - - public void Dispose() - { - } - - /// - /// Assign randomized values to the matrix elements via the input function. - /// - /// The input rand om function that takes no arguments and returns a float value - public void Randomize(Func rand) - { - Contracts.AssertValue(rand); - for (int i = 0, k = 0; i < RunCnt; i++) - { - Contracts.Assert(k == i * RunLenPhy); - for (int j = 0; j < RunLen; j++) - Items[k + j] = rand(); - k += RunLenPhy; - } - } - - /// - /// Assign zeros to the matrix elements. - /// - public void Zero() - { - Items.ZeroItems(); - } - - /// - /// Copy the values of src matrix into this matrix. The src matrix must have the same physical and logical size as this matrix. - /// - /// The source matrix - public void CopyFrom(CpuAlignedMatrixBase src) - { - AssertValid(); - Contracts.AssertValue(src); - src.AssertValid(); - Contracts.Assert(src.RunLen == RunLen); - Contracts.Assert(src.RunCnt == RunCnt); - Contracts.Assert(src.RunLenPhy == RunLenPhy); - Contracts.Assert(src.RunCntPhy == RunCntPhy); - Items.CopyFrom(src.Items); - } - } - - /// - /// This implements a logical row-major matrix of Floats that is automatically aligned for SSE/AVX operations. - /// The ctor takes an alignment value, which must be a power of two at least sizeof(Float). - /// - [BestFriend] - internal abstract class CpuAlignedMatrixRowBase : CpuAlignedMatrixBase, ICpuBuffer - { - protected CpuAlignedMatrixRowBase(int crow, int ccol, int cbAlign) - : base(ccol, crow, cbAlign) - { - } - - /// - /// The logical number of rows - /// - public override int RowCount => RunCnt; - - /// - /// The logical number of columns - /// - public override int ColCount { get { return RunLen; } } - - /// - /// The physical number of rows - /// - public override int RowCountPhy { get { return RunCntPhy; } } - - /// - /// The physical number of columns - /// - public override int ColCountPhy { get { return RunLenPhy; } } - - /// - /// Copy the values into dst, starting at slot ivDst and advancing ivDst. 
- /// - /// The destination array - /// The starting index in the destination array - public void CopyTo(Float[] dst, ref int ivDst) - { - Contracts.AssertValue(dst); - Contracts.Assert(0 <= ivDst && ivDst <= dst.Length - ValueCount); - - if (ColCount == ColCountPhy) - { - // Can copy all at once. - Items.CopyTo(0, dst, ivDst, ValueCount); - ivDst += ValueCount; - } - else - { - // Copy each row. - int ivSrc = 0; - for (int row = 0; row < RowCount; row++) - { - Items.CopyTo(ivSrc, dst, ivDst, ColCount); - ivSrc += ColCountPhy; - ivDst += ColCount; - } - } - } - - /// - /// Copy the values from src, starting at slot ivSrc and advancing ivSrc. - /// - /// The source array - /// The starting index in the source array - public void CopyFrom(Float[] src, ref int ivSrc) - { - Contracts.AssertValue(src); - Contracts.Assert(0 <= ivSrc && ivSrc <= src.Length - ValueCount); - - if (ColCount == ColCountPhy) - { - Items.CopyFrom(src.AsSpan(ivSrc, ValueCount)); - ivSrc += ValueCount; - } - else - { - for (int row = 0; row < RowCount; row++) - { - Items.CopyFrom(row * ColCountPhy, src.AsSpan(ivSrc, ColCount)); - ivSrc += ColCount; - } - } - } - - /// - /// Get the underlying AlignedArray as IEnumerator<Float>. - /// - public IEnumerator GetEnumerator() - { - for (int row = 0; row < RowCount; row++) - { - int ivBase = row * ColCountPhy; - for (int col = 0; col < ColCount; col++) - yield return Items[ivBase + col]; - } - } - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - } - - /// - /// This implements a row-major matrix of Floats that is automatically aligned for SSE/AVX operations. - /// The ctor takes an alignment value, which must be a power of two at least sizeof(Float). - /// - [BestFriend] - internal sealed class CpuAlignedMatrixRow : CpuAlignedMatrixRowBase, ICpuFullMatrix - { - public CpuAlignedMatrixRow(int crow, int ccol, int cbAlign) - : base(crow, ccol, cbAlign) - { - } - - /// - /// The logical number of rows - /// - public override int RowCount { get { return RunCnt; } } - - /// - /// The logical number of columns - /// - public override int ColCount { get { return RunLen; } } - - /// - /// The physical number of rows - /// - public override int RowCountPhy { get { return RunCntPhy; } } - - /// - /// The physical number of columns - /// - public override int ColCountPhy { get { return RunLenPhy; } } - - /// - /// Copy the values from this matrix, starting from the row into dst, starting at slot ivDst and advancing ivDst. - /// - /// The starting row in this matrix - /// The destination array - /// The starting index in the destination array - public void CopyTo(int row, Float[] dst, ref int ivDst) - { - Contracts.AssertValue(dst); - Contracts.Assert(0 <= row && row < RowCount); - Contracts.Assert(0 <= ivDst && ivDst <= dst.Length - ColCount); - - Items.CopyTo(row * ColCountPhy, dst, ivDst, ColCount); - ivDst += ColCount; - } - - /// - /// Assign zeros to the values at the indices - /// - /// The indices - public void ZeroItems(int[] indices) - { - Contracts.AssertValue(indices); - - // REVIEW: Ideally, we'd adjust the indices once so we wouldn't need to - // repeatedly deal with padding adjustments. - CpuMathUtils.ZeroMatrixItems(Items, ColCount, ColCountPhy, indices); - } - } - - /// - /// This implements a logical matrix of Floats that is automatically aligned for SSE/AVX operations. - /// The ctor takes an alignment value, which must be a power of two at least sizeof(Float). 
- /// - [BestFriend] - internal sealed class CpuAlignedMatrixCol : CpuAlignedMatrixBase, ICpuFullMatrix - { - /// - /// Allocate an aligned matrix with the given alignment (in bytes). - /// - public CpuAlignedMatrixCol(int crow, int ccol, int cbAlign) - : base(crow, ccol, cbAlign) - { - } - - /// - /// The logical number of rows - /// - public override int RowCount { get { return RunCnt; } } - - /// - /// The logical number of columns - /// - public override int ColCount { get { return RunLen; } } - - /// - /// The physical number of rows - /// - public override int RowCountPhy { get { return RunCntPhy; } } - - /// - /// The physical number of columns - /// - public override int ColCountPhy { get { return RunLenPhy; } } - - /// - /// Copy the values into dst, starting at slot ivDst and advancing ivDst. - /// - /// The destination array - /// The starting index in the destination array - public void CopyTo(Float[] dst, ref int ivDst) - { - Contracts.AssertValue(dst); - Contracts.Assert(0 <= ivDst && ivDst <= dst.Length - ValueCount); - - for (int row = 0; row < RowCount; row++) - { - for (int col = 0; col < ColCount; col++) - dst[ivDst++] = Items[row + col * RowCountPhy]; - } - } - - /// - /// Copy the values from this matrix, starting from the row into dst, starting at slot ivDst and advancing ivDst. - /// - /// The starting row in this matrix - /// The destination array - /// The starting index in the destination array - public void CopyTo(int row, Float[] dst, ref int ivDst) - { - Contracts.AssertValue(dst); - Contracts.Assert(0 <= row && row < RowCount); - Contracts.Assert(0 <= ivDst && ivDst <= dst.Length - ColCount); - - for (int col = 0; col < ColCount; col++) - dst[ivDst++] = Items[row + col * RowCountPhy]; - } - - /// - /// Copy the values from src, starting at slot ivSrc and advancing ivSrc. - /// - /// The source array - /// The starting index in the source array - public void CopyFrom(Float[] src, ref int ivSrc) - { - Contracts.AssertValue(src); - Contracts.Assert(0 <= ivSrc && ivSrc <= src.Length - ValueCount); - for (int row = 0; row < RowCount; row++) - { - for (int col = 0; col < ColCount; col++) - Items[row + col * RowCountPhy] = src[ivSrc++]; - } - } - - /// - /// Assign zeros to the values at the indices - /// - /// The indices - public void ZeroItems(int[] indices) - { - Contracts.AssertValue(indices); - - // REVIEW: Ideally, we'd adjust the indices once so we wouldn't need to - // repeatedly deal with padding adjustments. - foreach (int iv in indices) - { - int row = iv / ColCount; - int col = iv % ColCount; - Items[row + col * ColCountPhy] = 0; - } - } - - /// - /// Get the underlying AlignedArray as IEnumerator<Float>. 
- /// - public IEnumerator GetEnumerator() - { - for (int row = 0; row < RowCount; row++) - { - for (int col = 0; col < ColCount; col++) - yield return Items[row + col * RowCountPhy]; - } - } - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - } -} \ No newline at end of file diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index 2156ddf5fa..b5e18ac0d9 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -152,11 +152,6 @@ private static Vector256 MultiplyAdd(Vector256 src1, Vector256 mat, ReadOnlySpan src, Span dst, int crow, int ccol) { fixed (float* psrc = &MemoryMarshal.GetReference(src)) @@ -170,12 +165,141 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pDstCurrent = pdst; float* pMatCurrent = pmat; + if (crow % 4 == 0) + { + while (pDstCurrent + 4 <= pDstEnd) + { + Vector256 res0 = Avx.SetZeroVector256(); + Vector256 res1 = Avx.SetZeroVector256(); + Vector256 res2 = Avx.SetZeroVector256(); + Vector256 res3 = Avx.SetZeroVector256(); + + int length = ccol; + float* pSrcCurrent = psrc; + + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 32); + int remainder = 0; + + if ((misalignment & 3) != 0) + { + // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations + while (pSrcCurrent < pSrcEnd) + { + Vector256 vector = Avx.LoadVector256(pSrcCurrent); + + float* pMatTemp = pMatCurrent; + res0 = MultiplyAdd(pMatTemp, vector, res0); + res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); + res2 = MultiplyAdd(pMatTemp += ccol, vector, res2); + res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); + + pSrcCurrent += 8; + pMatCurrent += 8; + } + } + else + { + if (misalignment != 0) + { + // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 8 - misalignment; + + Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); + + // We only align pMat since it has significantly more reads. + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); + + res0 = Avx.Multiply(x01, vector); + res1 = Avx.Multiply(x11, vector); + res2 = Avx.Multiply(x21, vector); + res3 = Avx.Multiply(x31, vector); + + pMatCurrent += misalignment; + pSrcCurrent += misalignment; + length -= misalignment; + } + + if (length > 7) + { + // Handle all the 256-bit blocks that we can now that we have offset to an aligned address + remainder = length % 8; + + while (pSrcCurrent + 8 <= pSrcEnd) + { + // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads + // (due to semantics of the legacy encoding). + // We don't need an assert, since the instruction will throw for unaligned inputs. 
+ Vector256 vector = Avx.LoadVector256(pSrcCurrent); + + float* pMatTemp = pMatCurrent; + res0 = MultiplyAdd(pMatTemp, vector, res0); + res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); + res2 = MultiplyAdd(pMatTemp += ccol, vector, res2); + res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); + + pSrcCurrent += 8; + pMatCurrent += 8; + } + } + else + { + // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not + // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two + // unaligned loads where we mask the input each time. + remainder = length; + } + + if (remainder != 0) + { + // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next + // unaligned load will read to the end of the array and then mask out any elements already processed + + pMatCurrent -= (8 - remainder); + pSrcCurrent -= (8 - remainder); + + Vector256 mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); + + res0 = MultiplyAdd(x01, vector, res0); + res1 = MultiplyAdd(x11, vector, res1); + res2 = MultiplyAdd(x21, vector, res2); + res3 = MultiplyAdd(x31, vector, res3); + + pMatCurrent += 8; + pSrcCurrent += 8; + } + } + + // Add up the entries of each, with the 4 results in res0 + res0 = Avx.HorizontalAdd(res0, res1); + res2 = Avx.HorizontalAdd(res2, res3); + res0 = Avx.HorizontalAdd(res0, res2); + + Vector128 sum = Sse.Add(Avx.GetLowerHalf(res0), GetHigh(in res0)); + Sse.Store(pDstCurrent, sum); + + pDstCurrent += 4; + pMatCurrent += 3 * ccol; + } + } + while (pDstCurrent < pDstEnd) { Vector256 res0 = Avx.SetZeroVector256(); - Vector256 res1 = Avx.SetZeroVector256(); - Vector256 res2 = Avx.SetZeroVector256(); - Vector256 res3 = Avx.SetZeroVector256(); int length = ccol; float* pSrcCurrent = psrc; @@ -193,9 +317,6 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pMatTemp = pMatCurrent; res0 = MultiplyAdd(pMatTemp, vector, res0); - res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); - res2 = MultiplyAdd(pMatTemp += ccol, vector, res2); - res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); pSrcCurrent += 8; pMatCurrent += 8; @@ -215,15 +336,9 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr // We only align pMat since it has significantly more reads. 
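                        // Worked example of the masked leading read (assuming pLeadingAlignmentMask row N
                        // enables the first N lanes): if pMatCurrent is 8 bytes past a 32-byte boundary,
                        // misalignment becomes 8 - (8 >> 2) = 6, so the unaligned load keeps the 6 floats
                        // before the next aligned address and zeroes the other two lanes.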
float* pMatTemp = pMatCurrent; Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); - Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); res0 = Avx.Multiply(x01, vector); - res1 = Avx.Multiply(x11, vector); - res2 = Avx.Multiply(x21, vector); - res3 = Avx.Multiply(x31, vector); pMatCurrent += misalignment; pSrcCurrent += misalignment; @@ -244,9 +359,6 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pMatTemp = pMatCurrent; res0 = MultiplyAdd(pMatTemp, vector, res0); - res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); - res2 = MultiplyAdd(pMatTemp += ccol, vector, res2); - res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); pSrcCurrent += 8; pMatCurrent += 8; @@ -272,42 +384,25 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pMatTemp = pMatCurrent; Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); - Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); res0 = MultiplyAdd(x01, vector, res0); - res1 = MultiplyAdd(x11, vector, res1); - res2 = MultiplyAdd(x21, vector, res2); - res3 = MultiplyAdd(x31, vector, res3); pMatCurrent += 8; pSrcCurrent += 8; } } - // Add up the entries of each, with the 4 results in res0 - res0 = Avx.HorizontalAdd(res0, res1); - res2 = Avx.HorizontalAdd(res2, res3); - res0 = Avx.HorizontalAdd(res0, res2); - - Vector128 sum = Sse.Add(Avx.GetLowerHalf(res0), GetHigh(in res0)); - Sse.Store(pDstCurrent, sum); + res0 = VectorSum256(in res0); + float sum = Sse.ConvertToSingle(Sse.AddScalar(Avx.GetLowerHalf(res0), GetHigh(res0))); + *pDstCurrent = sum; - pDstCurrent += 4; - pMatCurrent += 3 * ccol; + pDstCurrent += 1; } } } // Partial sparse source vector. - public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray src, - int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) - { - MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol); - } - public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan src, int posMin, int iposMin, int iposEnd, Span dst, int crow, int ccol) { @@ -461,11 +556,6 @@ Vector256 SparseMultiplicationAcrossRow() } } - public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) - { - MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); - } - public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) { fixed (float* psrc = &MemoryMarshal.GetReference(src)) @@ -484,17 +574,9 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan h01 = Sse.LoadVector128(pSrcCurrent); - // Replicate each slot of h01 (ABCD) into its own register. 
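                    // The permute control bytes broadcast one lane of h01 = (A, B, C, D) to all four lanes:
                    // 0x00 -> (A, A, A, A), 0x55 -> (B, B, B, B), 0xAA -> (C, C, C, C), 0xFF -> (D, D, D, D);
                    // SetHighLow(h, h) then widens each 128-bit broadcast to 256 bits.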
- Vector128 h11 = Avx.Permute(h01, 0x55); // B - Vector128 h21 = Avx.Permute(h01, 0xAA); // C - Vector128 h31 = Avx.Permute(h01, 0xFF); // D h01 = Avx.Permute(h01, 0x00); // A Vector256 x01 = Avx.SetHighLow(h01, h01); - Vector256 x11 = Avx.SetHighLow(h11, h11); - Vector256 x21 = Avx.SetHighLow(h21, h21); - Vector256 x31 = Avx.SetHighLow(h31, h31); - int length = crow; float* pDstCurrent = pdst; @@ -508,13 +590,6 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); Avx.Store(pDstCurrent, x02); pDstCurrent += 8; @@ -536,22 +611,11 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); + x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); Avx.Store(pDstCurrent, x02); pMatCurrent += misalignment; @@ -569,15 +633,7 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); Avx.Store(pDstCurrent, x02); pDstCurrent += 8; @@ -603,22 +659,11 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); + x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); Avx.Store(pDstCurrent, x02); pDstCurrent += 8; @@ -626,25 +671,164 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan h01 = Sse.LoadVector128(pSrcCurrent); + // Replicate each slot of h01 (ABCD) into its own register. 
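                        // This unrolled path consumes four source values (A, B, C, D) per outer iteration,
                        // accumulating dst[i] += A * mat[j * crow + i] + B * mat[(j + 1) * crow + i]
                        //                      + C * mat[(j + 2) * crow + i] + D * mat[(j + 3) * crow + i],
                        // mirroring the managed transpose fallback in CpuMathUtils. The inner loop walks one
                        // matrix row, and the trailing pMatCurrent += 3 * crow skips the three rows that were
                        // already read through the pMatTemp += crow offsets.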
+ Vector128 h11 = Avx.Permute(h01, 0x55); // B + Vector128 h21 = Avx.Permute(h01, 0xAA); // C + Vector128 h31 = Avx.Permute(h01, 0xFF); // D + h01 = Avx.Permute(h01, 0x00); // A + + Vector256 x01 = Avx.SetHighLow(h01, h01); + Vector256 x11 = Avx.SetHighLow(h11, h11); + Vector256 x21 = Avx.SetHighLow(h21, h21); + Vector256 x31 = Avx.SetHighLow(h31, h31); + + int length = crow; + float* pDstCurrent = pdst; + + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 32); + + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + } + else + { + int remainder = 0; + if (misalignment != 0) + { + // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 8 - misalignment; + + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); + + // We only align pMat since it has significantly more reads. + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); + + x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); + + Avx.Store(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + if (length > 7) + { + remainder = length % 8; + while (pDstCurrent + 8 <= pDstEnd) + { + float* pMatTemp = pMatCurrent; + + Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pMatCurrent -= (8 - remainder); + pDstCurrent -= (8 - remainder); + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = 
Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); + + x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + } + + pMatCurrent += 3 * crow; + pSrcCurrent += 4; + } + } + while (pSrcCurrent < pSrcEnd) { Vector128 h01 = Sse.LoadVector128(pSrcCurrent); // Replicate each slot of h01 (ABCD) into its own register. - Vector128 h11 = Avx.Permute(h01, 0x55); // B - Vector128 h21 = Avx.Permute(h01, 0xAA); // C - Vector128 h31 = Avx.Permute(h01, 0xFF); // D h01 = Avx.Permute(h01, 0x00); // A - Vector256 x01 = Avx.SetHighLow(h01, h01); - Vector256 x11 = Avx.SetHighLow(h11, h11); - Vector256 x21 = Avx.SetHighLow(h21, h21); - Vector256 x31 = Avx.SetHighLow(h31, h31); - int length = crow; float* pDstCurrent = pdst; @@ -657,14 +841,6 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); Avx.Store(pDstCurrent, x02); @@ -687,22 +863,11 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); + x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); @@ -716,17 +881,11 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); Avx.Store(pDstCurrent, x02); @@ -739,7 +898,7 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan 0) { pMatCurrent -= (8 - remainder); pDstCurrent -= (8 - remainder); @@ -747,22 +906,11 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 
= Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); + x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); @@ -770,10 +918,9 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan - where TMatrix : CpuAlignedMatrixBase, ICpuFullMatrix - { - /// - /// Assert the compatibility of the underlying AlignedArray for the input matrix in terms of alignment amount. - /// - /// The input matrix - public static void AssertCompatible(ICpuFullMatrix values) - { -#if DEBUG - var mat = values as TMatrix; - Contracts.AssertValue(mat); - Contracts.Assert((mat.Items.CbAlign % CpuMathUtils.GetVectorAlignment()) == 0); -#endif - } - - /// - /// Assert the compatibility of the underlying AlignedArray for the input vector in terms of alignment amount. - /// - /// The input vector - public static void AssertCompatible(ICpuVector values) - { -#if DEBUG - CpuAlignedVector vec = values as CpuAlignedVector; - Contracts.AssertValue(vec); - Contracts.Assert((vec.Items.CbAlign % CpuMathUtils.GetVectorAlignment()) == 0); -#endif - } - - private static TMatrix A(ICpuFullMatrix x) - { - AssertCompatible(x); - return (TMatrix)x; - } - - private static CpuAlignedVector A(ICpuVector x) - { - AssertCompatible(x); - return (CpuAlignedVector)x; - } - - private static void AssertCompatibleCore(ICpuMatrix mat, ICpuVector src, ICpuVector dst) - { - AssertCompatible(src); - AssertCompatible(dst); - Contracts.Assert(mat.ColCount == src.VectorSize); - Contracts.Assert(mat.RowCount == dst.VectorSize); - } - - /// - /// Asserts the following: - /// 1. The compatibility of the underlying AlignedArray for mat in terms of alignment amount. - /// 2. The compatibility of the underlying AlignedArray for src in terms of alignment amount. - /// 3. The compatibility of the underlying AlignedArray for dst in terms of alignment amount. - /// 4. The compatibility of the matrix-vector multiplication mat * src = dst. - /// - /// - /// - /// - public static void AssertCompatible(ICpuFullMatrix mat, ICpuVector src, ICpuVector dst) - { - // Also check the physical sizes. 
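            // Here "physical sizes" are the padded dimensions: the asserts below require ColCountPhy to
            // match the padded Size of the source AlignedArray and RowCountPhy to match the padded Size
            // of the destination, so the aligned kernels can always read whole runs.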
- AssertCompatible(mat); - AssertCompatibleCore(mat, src, dst); - var m = A(mat); - Contracts.Assert(m.ColCountPhy == A(src).Items.Size); - Contracts.Assert(m.RowCountPhy == A(dst).Items.Size); - } - - /// - /// Matrix multiplication: - /// dst = mat * src - /// - /// The multiplier matrix - /// The source vector - /// The destination vector - public static void MatTimesSrc(ICpuFullMatrix mat, ICpuVector src, ICpuVector dst) - { - bool colMajor = typeof(TMatrix) == typeof(CpuAlignedMatrixCol); - AssertCompatible(mat, src, dst); - var m = A(mat); - CpuMathUtils.MatrixTimesSource(colMajor, m.Items, A(src).Items, A(dst).Items, m.RunCnt); - } - - /// - /// Matrix transpose multiplication: - /// dst = mat' * src - /// - /// The multiplier matrix - /// The source vector - /// The destination vector - public static void MatTranTimesSrc(ICpuFullMatrix mat, ICpuVector src, ICpuVector dst) - { - bool colMajor = typeof(TMatrix) == typeof(CpuAlignedMatrixCol); - AssertCompatible(mat, dst, src); - var m = A(mat); - CpuMathUtils.MatrixTimesSource(!colMajor, m.Items, A(src).Items, A(dst).Items, m.RunCnt); - } - } - - public static class GeneralUtils - { - /// - /// Count the number of zero bits in the lonest string of zero's from the lowest significant bit of the input integer. - /// - /// The input integer - /// - public static int CbitLowZero(uint u) - { - if (u == 0) - return 32; - - int cbit = 0; - if ((u & 0x0000FFFF) == 0) - { - cbit += 16; - u >>= 16; - } - if ((u & 0x000000FF) == 0) - { - cbit += 8; - u >>= 8; - } - if ((u & 0x0000000F) == 0) - { - cbit += 4; - u >>= 4; - } - if ((u & 0x00000003) == 0) - { - cbit += 2; - u >>= 2; - } - if ((u & 0x00000001) == 0) - cbit += 1; - return cbit; - } - } -} diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs index d895e590a9..76afabd59e 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs @@ -11,76 +11,59 @@ namespace Microsoft.ML.Runtime.Internal.CpuMath { internal static partial class CpuMathUtils { - // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray - private const int Vector128Alignment = 16; - - // The count of bytes in Vector256, corresponding to _cbAlign in AlignedArray - private const int Vector256Alignment = 32; - - // The count of bytes in a 32-bit float, corresponding to _cbAlign in AlignedArray - private const int FloatAlignment = 4; - - // If neither AVX nor SSE is supported, return basic alignment for a 4-byte float. - [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] - public static int GetVectorAlignment() - => Avx.IsSupported ? Vector256Alignment : (Sse.IsSupported ? 
Vector128Alignment : FloatAlignment); - - public static void MatrixTimesSource(bool transpose, AlignedArray matrix, AlignedArray source, AlignedArray destination, int stride) + public static void MatrixTimesSource(bool transpose, float[] matrix, ReadOnlySpan source, Span destination, int stride) { - Contracts.Assert(matrix.Size == destination.Size * source.Size); + Contracts.Assert(matrix.Length == destination.Length * source.Length); Contracts.Assert(stride >= 0); - if (Avx.IsSupported) + if (!transpose) { - if (!transpose) - { - Contracts.Assert(stride <= destination.Size); - AvxIntrinsics.MatMul(matrix, source, destination, stride, source.Size); - } - else + if (Avx.IsSupported && source.Length >= 8) { - Contracts.Assert(stride <= source.Size); - AvxIntrinsics.MatMulTran(matrix, source, destination, destination.Size, stride); + Contracts.Assert(stride <= destination.Length); + AvxIntrinsics.MatMul(matrix, source, destination, stride, source.Length); } - } - else if (Sse.IsSupported) - { - if (!transpose) + else if (Sse.IsSupported && source.Length >= 4) { - Contracts.Assert(stride <= destination.Size); - SseIntrinsics.MatMul(matrix, source, destination, stride, source.Size); + Contracts.Assert(stride <= destination.Length); + SseIntrinsics.MatMul(matrix, source, destination, stride, source.Length); } else { - Contracts.Assert(stride <= source.Size); - SseIntrinsics.MatMulTran(matrix, source, destination, destination.Size, stride); - } - } - else - { - if (!transpose) - { - Contracts.Assert(stride <= destination.Size); + Contracts.Assert(stride <= destination.Length); for (int i = 0; i < stride; i++) { float dotProduct = 0; - for (int j = 0; j < source.Size; j++) + for (int j = 0; j < source.Length; j++) { - dotProduct += matrix[i * source.Size + j] * source[j]; + dotProduct += matrix[i * source.Length + j] * source[j]; } destination[i] = dotProduct; } } + } + else + { + if (Avx.IsSupported && destination.Length >= 8) + { + Contracts.Assert(stride <= source.Length); + AvxIntrinsics.MatMulTran(matrix, source, destination, destination.Length, stride); + } + else if (Sse.IsSupported && destination.Length >=4) + { + Contracts.Assert(stride <= source.Length); + SseIntrinsics.MatMulTran(matrix, source, destination, destination.Length, stride); + } else { - Contracts.Assert(stride <= source.Size); - for (int i = 0; i < destination.Size; i++) + Contracts.Assert(stride <= source.Length); + for (int i = 0; i < destination.Length; i++) { float dotProduct = 0; for (int j = 0; j < stride; j++) { - dotProduct += matrix[j * source.Size + i] * source[j]; + dotProduct += matrix[j * destination.Length + i] * source[j]; } destination[i] = dotProduct; @@ -89,17 +72,17 @@ public static void MatrixTimesSource(bool transpose, AlignedArray matrix, Aligne } } - public static void MatrixTimesSource(AlignedArray matrix, ReadOnlySpan rgposSrc, AlignedArray sourceValues, - int posMin, int iposMin, int iposLimit, AlignedArray destination, int stride) + public static void MatrixTimesSource(ReadOnlySpan matrix, ReadOnlySpan rgposSrc, ReadOnlySpan sourceValues, + int posMin, int iposMin, int iposLimit, Span destination, int stride) { Contracts.Assert(iposMin >= 0); Contracts.Assert(iposMin <= iposLimit); Contracts.Assert(iposLimit <= rgposSrc.Length); - Contracts.Assert(matrix.Size == destination.Size * sourceValues.Size); + Contracts.Assert(matrix.Length == destination.Length * sourceValues.Length); if (iposMin >= iposLimit) { - destination.ZeroItems(); + destination.Clear(); return; } @@ -108,24 +91,24 @@ public 
static void MatrixTimesSource(AlignedArray matrix, ReadOnlySpan rgpo if (Avx.IsSupported) { - Contracts.Assert(stride <= destination.Size); - AvxIntrinsics.MatMulP(matrix, rgposSrc, sourceValues, posMin, iposMin, iposLimit, destination, stride, sourceValues.Size); + Contracts.Assert(stride <= destination.Length); + AvxIntrinsics.MatMulP(matrix, rgposSrc, sourceValues, posMin, iposMin, iposLimit, destination, stride, sourceValues.Length); } else if (Sse.IsSupported) { - Contracts.Assert(stride <= destination.Size); - SseIntrinsics.MatMulP(matrix, rgposSrc, sourceValues, posMin, iposMin, iposLimit, destination, stride, sourceValues.Size); + Contracts.Assert(stride <= destination.Length); + SseIntrinsics.MatMulP(matrix, rgposSrc, sourceValues, posMin, iposMin, iposLimit, destination, stride, sourceValues.Length); } else { - Contracts.Assert(stride <= destination.Size); + Contracts.Assert(stride <= destination.Length); for (int i = 0; i < stride; i++) { float dotProduct = 0; for (int j = iposMin; j < iposLimit; j++) { int col = rgposSrc[j] - posMin; - dotProduct += matrix[i * sourceValues.Size + col] * sourceValues[col]; + dotProduct += matrix[i * sourceValues.Length + col] * sourceValues[col]; } destination[i] = dotProduct; } @@ -636,71 +619,6 @@ public static float L2DistSquared(ReadOnlySpan left, ReadOnlySpan } } - public static void ZeroMatrixItems(AlignedArray destination, int ccol, int cfltRow, int[] indices) - { - Contracts.Assert(ccol > 0); - Contracts.Assert(ccol <= cfltRow); - - if (ccol == cfltRow) - { - ZeroItemsU(destination, destination.Size, indices, indices.Length); - } - else - { - ZeroMatrixItemsCore(destination, destination.Size, ccol, cfltRow, indices, indices.Length); - } - } - - private static unsafe void ZeroItemsU(AlignedArray destination, int c, int[] indices, int cindices) - { - fixed (float* pdst = &destination.Items[0]) - fixed (int* pidx = &indices[0]) - { - for (int i = 0; i < cindices; ++i) - { - int index = pidx[i]; - Contracts.Assert(index >= 0); - Contracts.Assert(index < c); - pdst[index] = 0; - } - } - } - - private static unsafe void ZeroMatrixItemsCore(AlignedArray destination, int c, int ccol, int cfltRow, int[] indices, int cindices) - { - fixed (float* pdst = &destination.Items[0]) - fixed (int* pidx = &indices[0]) - { - int ivLogMin = 0; - int ivLogLim = ccol; - int ivPhyMin = 0; - - for (int i = 0; i < cindices; ++i) - { - int index = pidx[i]; - Contracts.Assert(index >= 0); - Contracts.Assert(index < c); - - int col = index - ivLogMin; - if ((uint)col >= (uint)ccol) - { - Contracts.Assert(ivLogMin > index || index >= ivLogLim); - - int row = index / ccol; - ivLogMin = row * ccol; - ivLogLim = ivLogMin + ccol; - ivPhyMin = row * cfltRow; - - Contracts.Assert(index >= ivLogMin); - Contracts.Assert(index < ivLogLim); - col = index - ivLogMin; - } - - pdst[ivPhyMin + col] = 0; - } - } - } - public static void SdcaL1UpdateDense(float primalUpdate, int count, ReadOnlySpan source, float threshold, Span v, Span w) { Contracts.AssertNonEmpty(source); diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs index 5ecbc62be1..400b70d651 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs @@ -11,17 +11,10 @@ namespace Microsoft.ML.Runtime.Internal.CpuMath [BestFriend] internal static partial class CpuMathUtils { - // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray - private const int 
Vector128Alignment = 16; + public static void MatrixTimesSource(bool transpose, ReadOnlySpan matrix, ReadOnlySpan source, Span destination, int stride) => SseUtils.MatTimesSrc(transpose, matrix, source, destination, stride); - [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] - public static int GetVectorAlignment() - => Vector128Alignment; - - public static void MatrixTimesSource(bool transpose, AlignedArray matrix, AlignedArray source, AlignedArray destination, int stride) => SseUtils.MatTimesSrc(transpose, matrix, source, destination, stride); - - public static void MatrixTimesSource(AlignedArray matrix, ReadOnlySpan rgposSrc, AlignedArray sourceValues, - int posMin, int iposMin, int iposLimit, AlignedArray destination, int stride) => SseUtils.MatTimesSrc(matrix, rgposSrc, sourceValues, posMin, iposMin, iposLimit, destination, stride); + public static void MatrixTimesSource(ReadOnlySpan matrix, ReadOnlySpan rgposSrc, ReadOnlySpan sourceValues, + int posMin, int iposMin, int iposLimit, Span destination, int stride) => SseUtils.MatTimesSrc(matrix, rgposSrc, sourceValues, posMin, iposMin, iposLimit, destination, stride); public static void Add(float value, Span destination) => SseUtils.Add(value, destination); @@ -63,8 +56,6 @@ public static void MatrixTimesSource(AlignedArray matrix, ReadOnlySpan rgpo public static float L2DistSquared(ReadOnlySpan left, ReadOnlySpan right, int count) => SseUtils.L2DistSquared(left, right, count); - public static void ZeroMatrixItems(AlignedArray destination, int ccol, int cfltRow, int[] indices) => SseUtils.ZeroMatrixItems(destination, ccol, cfltRow, indices); - public static void SdcaL1UpdateDense(float primalUpdate, int count, ReadOnlySpan source, float threshold, Span v, Span w) => SseUtils.SdcaL1UpdateDense(primalUpdate, count, source, threshold, v, w); diff --git a/src/Microsoft.ML.CpuMath/Sse.cs b/src/Microsoft.ML.CpuMath/Sse.cs index 8b1c4da70f..d57b400e9c 100644 --- a/src/Microsoft.ML.CpuMath/Sse.cs +++ b/src/Microsoft.ML.CpuMath/Sse.cs @@ -15,74 +15,53 @@ namespace Microsoft.ML.Runtime.Internal.CpuMath [BestFriend] internal static class SseUtils { - public const int CbAlign = 16; - - private static bool Compat(AlignedArray a) - { - Contracts.AssertValue(a); - Contracts.Assert(a.Size > 0); - return a.CbAlign == CbAlign; - } - - private static unsafe float* Ptr(AlignedArray a, float* p) - { - Contracts.AssertValue(a); - float* q = p + a.GetBase((long)p); - Contracts.Assert(((long)q & (CbAlign - 1)) == 0); - return q; - } - - public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) + public static void MatTimesSrc(bool tran, ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crun) { - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(mat.Size == dst.Size * src.Size); + Contracts.Assert(mat.Length == dst.Length * src.Length); unsafe { - fixed (float* pmat = &mat.Items[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) + fixed (float* pmat = &mat[0]) + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) { if (!tran) { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMul(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); + Contracts.Assert(0 <= crun && crun <= dst.Length); + Thunk.MatMul(pmat, psrc, pdst, crun, src.Length); } else { - Contracts.Assert(0 <= crun && crun <= src.Size); - Thunk.MatMulTran(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); 
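                        // Illustrative scalar reference for the transposed product (a sketch mirroring the
                        // managed fallback in CpuMathUtils.MatrixTimesSource, not part of this file):
                        // for (int i = 0; i < dst.Length; i++)
                        // {
                        //     float sum = 0;
                        //     for (int j = 0; j < crun; j++)
                        //         sum += mat[j * dst.Length + i] * src[j];
                        //     dst[i] = sum;
                        // }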
+ Contracts.Assert(0 <= crun && crun <= src.Length); + Thunk.MatMulTran(pmat, psrc, pdst, dst.Length, crun); } } } } - public static void MatTimesSrc(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray srcValues, - int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) + public static void MatTimesSrc(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan srcValues, + int posMin, int iposMin, int iposLim, Span dst, int crun) { - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(srcValues)); - Contracts.Assert(Compat(dst)); Contracts.Assert(0 <= iposMin && iposMin <= iposLim && iposLim <= rgposSrc.Length); - Contracts.Assert(mat.Size == dst.Size * srcValues.Size); + Contracts.Assert(mat.Length == dst.Length * srcValues.Length); if (iposMin >= iposLim) { - dst.ZeroItems(); + dst.Clear(); return; } + Contracts.AssertNonEmpty(rgposSrc); + unsafe { - fixed (float* pdst = &dst.Items[0]) - fixed (float* pmat = &mat.Items[0]) - fixed (float* psrc = &srcValues.Items[0]) + fixed (float* pdst = &dst[0]) + fixed (float* pmat = &mat[0]) + fixed (float* psrc = &srcValues[0]) fixed (int* ppossrc = &rgposSrc[0]) { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulP(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); + Contracts.Assert(0 <= crun && crun <= dst.Length); + Thunk.MatMulP(pmat, ppossrc, psrc, posMin, iposMin, iposLim, pdst, crun, srcValues.Length); } } } @@ -366,23 +345,6 @@ public static float L2DistSquared(ReadOnlySpan a, ReadOnlySpan b, } } - public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] indices) - { - Contracts.Assert(0 < ccol && ccol <= cfltRow); - - unsafe - { - fixed (float* pdst = &dst.Items[0]) - fixed (int* pi = &indices[0]) - { - if (ccol == cfltRow) - Thunk.ZeroItemsU(Ptr(dst, pdst), dst.Size, pi, indices.Length); - else - Thunk.ZeroMatrixItemsCore(Ptr(dst, pdst), dst.Size, ccol, cfltRow, pi, indices.Length); - } - } - } - public static void SdcaL1UpdateDense(float primalUpdate, int count, ReadOnlySpan src, float threshold, Span v, Span w) { Contracts.AssertNonEmpty(src); diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index b83fd6bbc6..ef0c076814 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -116,11 +116,6 @@ internal static Vector128 GetNewDst128(in Vector128 xDst1, in Vect } // Multiply matrix times vector into vector. 
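        // Illustrative scalar reference for the kernel below (a sketch matching the managed fallback in
        // CpuMathUtils.MatrixTimesSource):
        // for (int i = 0; i < crow; i++)
        // {
        //     float sum = 0;
        //     for (int j = 0; j < ccol; j++)
        //         sum += mat[i * ccol + j] * src[j];
        //     dst[i] = sum;
        // }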
- public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) - { - MatMul(mat.Items, src.Items, dst.Items, crow, ccol); - } - public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) { fixed (float* psrc = &MemoryMarshal.GetReference(src)) @@ -134,12 +129,149 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pDstCurrent = pdst; float* pMatCurrent = pmat; + if (ccol % 4 == 0) + { + while (pDstCurrent + 4 <= pDstEnd) + { + Vector128 res0 = Sse.SetZeroVector128(); + Vector128 res1 = Sse.SetZeroVector128(); + Vector128 res2 = Sse.SetZeroVector128(); + Vector128 res3 = Sse.SetZeroVector128(); + + int length = ccol; + float* pSrcCurrent = psrc; + + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 16); + int remainder = 0; + + if ((misalignment & 3) != 0) + { + // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations + while (pSrcCurrent < pSrcEnd) + { + Vector128 vector = Sse.LoadVector128(pSrcCurrent); + + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + + res0 = Sse.Add(res0, x01); + res1 = Sse.Add(res1, x11); + res2 = Sse.Add(res2, x21); + res3 = Sse.Add(res3, x31); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + if (misalignment != 0) + { + // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 4 - misalignment; + + Vector128 mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); + + // We only align pMat since it has significantly more reads. + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); + + res0 = Sse.Multiply(x01, vector); + res1 = Sse.Multiply(x11, vector); + res2 = Sse.Multiply(x21, vector); + res3 = Sse.Multiply(x31, vector); + + pMatCurrent += misalignment; + pSrcCurrent += misalignment; + length -= misalignment; + } + + if (length > 3) + { + // Handle all the 128-bit blocks that we can now that we have offset to an aligned address + remainder = length % 4; + + // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads + // (due to semantics of the legacy encoding). + // We don't need an assert, since the instruction will throw for unaligned inputs. 
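The misalignment bookkeeping above reduces to a small piece of address arithmetic. The sketch below restates it on its own, assuming the 16-byte (four-float) vector width used throughout this file; the class and method names are hypothetical.

    class AlignmentSketch
    {
        // How many floats the masked, unaligned lead-in must consume before the first
        // 16-byte-aligned load becomes possible. The caller has already ruled out
        // addresses that are not 4-byte aligned, so misalignmentBytes is 0, 4, 8 or 12.
        static int LeadingElementCount(nuint address)
        {
            int misalignmentBytes = (int)(address % 16);
            if (misalignmentBytes == 0)
                return 0;                          // already aligned, no masked lead-in needed
            return 4 - (misalignmentBytes >> 2);   // 1, 2 or 3 leading floats
        }

        static void Main()
        {
            System.Console.WriteLine(LeadingElementCount(0x1008)); // prints 2
        }
    }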
+ while (pSrcCurrent + 4 <= pSrcEnd) + { + Vector128 vector = Sse.LoadVector128(pSrcCurrent); + + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp)); + Vector128 x11 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); + + res0 = Sse.Add(res0, x01); + res1 = Sse.Add(res1, x11); + res2 = Sse.Add(res2, x21); + res3 = Sse.Add(res3, x31); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not + // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two + // unaligned loads where we mask the input each time. + remainder = length; + } + + if (remainder != 0) + { + // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next + // unaligned load will read to the end of the array and then mask out any elements already processed + + pMatCurrent -= (4 - remainder); + pSrcCurrent -= (4 - remainder); + + Vector128 mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); + + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); + + res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); + res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); + res2 = Sse.Add(res2, Sse.Multiply(x21, vector)); + res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); + + pMatCurrent += 4; + pSrcCurrent += 4; + } + } + + // Add up the entries of each, with the 4 results in res0 + res0 = Sse3.HorizontalAdd(res0, res1); + res2 = Sse3.HorizontalAdd(res2, res3); + res0 = Sse3.HorizontalAdd(res0, res2); + + Sse.Store(pDstCurrent, res0); + pDstCurrent += 4; + pMatCurrent += 3 * ccol; + } + } + while (pDstCurrent < pDstEnd) { Vector128 res0 = Sse.SetZeroVector128(); - Vector128 res1 = Sse.SetZeroVector128(); - Vector128 res2 = Sse.SetZeroVector128(); - Vector128 res3 = Sse.SetZeroVector128(); int length = ccol; float* pSrcCurrent = psrc; @@ -157,14 +289,7 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pMatTemp = pMatCurrent; Vector128 x01 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - res0 = Sse.Add(res0, x01); - res1 = Sse.Add(res1, x11); - res2 = Sse.Add(res2, x21); - res3 = Sse.Add(res3, x31); pSrcCurrent += 4; pMatCurrent += 4; @@ -184,15 +309,9 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr // We only align pMat since it has significantly more reads. 
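The trailing-remainder handling in the block above works the same way in reverse: the pointers step back by (4 - remainder) so the final 128-bit load ends exactly at the row boundary, and the lanes the aligned loop already covered are masked away. A small worked sketch of those counts (a hypothetical helper, not part of the library):

    class RemainderSketch
    {
        // For a row of `length` floats whose first `leading` floats were handled by the
        // masked lead-in: how many floats the aligned loop covers, how many remain, and
        // how far the pointers step back before the final masked load.
        static (int alignedCount, int remainder, int stepBack) Plan(int length, int leading)
        {
            int rest = length - leading;
            int remainder = rest % 4;
            return (rest - remainder, remainder, remainder == 0 ? 0 : 4 - remainder);
        }

        static void Main()
        {
            // length 11 with 1 leading float: 8 aligned, 2 remaining, step back 2 so the
            // last load reads the final 4 floats and masks off the 2 already processed.
            System.Console.WriteLine(Plan(11, 1));
        }
    }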
float* pMatTemp = pMatCurrent; Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); res0 = Sse.Multiply(x01, vector); - res1 = Sse.Multiply(x11, vector); - res2 = Sse.Multiply(x21, vector); - res3 = Sse.Multiply(x31, vector); pMatCurrent += misalignment; pSrcCurrent += misalignment; @@ -213,14 +332,7 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pMatTemp = pMatCurrent; Vector128 x01 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp)); - Vector128 x11 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - res0 = Sse.Add(res0, x01); - res1 = Sse.Add(res1, x11); - res2 = Sse.Add(res2, x21); - res3 = Sse.Add(res3, x31); pSrcCurrent += 4; pMatCurrent += 4; @@ -246,15 +358,9 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pMatTemp = pMatCurrent; Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); - res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); - res2 = Sse.Add(res2, Sse.Multiply(x21, vector)); - res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); pMatCurrent += 4; pSrcCurrent += 4; @@ -262,24 +368,17 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr } // Add up the entries of each, with the 4 results in res0 - res0 = Sse3.HorizontalAdd(res0, res1); - res2 = Sse3.HorizontalAdd(res2, res3); - res0 = Sse3.HorizontalAdd(res0, res2); + res0 = VectorSum128(in res0); + float sum = Sse.ConvertToSingle(res0); + *pDstCurrent = sum; Sse.Store(pDstCurrent, res0); - pDstCurrent += 4; - pMatCurrent += 3 * ccol; + pDstCurrent += 1; } } } // Partial sparse source vector. - public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray src, - int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) - { - MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol); - } - public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan src, int posMin, int iposMin, int iposEnd, Span dst, int crow, int ccol) { @@ -436,11 +535,6 @@ Vector128 SparseMultiplicationAcrossRow() } } - public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) - { - MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); - } - public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) { fixed (float* psrc = &MemoryMarshal.GetReference(src)) @@ -459,10 +553,6 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x01 = Sse.LoadVector128(pSrcCurrent); - // Replicate each 32-bit slot of x01 (ABCD) into its own register. 
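In the single-destination tail of MatMul above, the accumulator is now collapsed to a scalar via the existing VectorSum128 helper and Sse.ConvertToSingle before being written out. For reference, here is a stand-alone sketch of a four-lane reduction of the same kind, using the SSE3 horizontal adds that the unrolled path relies on; the class name is hypothetical and SSE3 support at runtime is assumed.

    using System;
    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.X86;

    class HorizontalSumSketch
    {
        // Collapses the four lanes of acc into their sum: after two horizontal adds
        // every lane holds a + b + c + d, and ConvertToSingle returns lane 0.
        static float Sum(Vector128<float> acc)
        {
            acc = Sse3.HorizontalAdd(acc, acc); // (a+b, c+d, a+b, c+d)
            acc = Sse3.HorizontalAdd(acc, acc); // (a+b+c+d, ...)
            return Sse.ConvertToSingle(acc);
        }

        static void Main()
        {
            Console.WriteLine(Sum(Vector128.Create(1f, 2f, 3f, 4f))); // prints 10
        }
    }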
- Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B - Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C - Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D x01 = Sse.Shuffle(x01, x01, 0x00); // A int length = crow; @@ -478,13 +568,6 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); Sse.Store(pDstCurrent, x02); pDstCurrent += 4; @@ -506,22 +589,11 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); + x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); Sse.Store(pDstCurrent, x02); pMatCurrent += misalignment; @@ -538,15 +610,7 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Sse.Multiply(x01, Sse.LoadAlignedVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadAlignedVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); Sse.Store(pDstCurrent, x02); pDstCurrent += 4; @@ -572,22 +636,11 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); + x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); Sse.Store(pDstCurrent, x02); pDstCurrent += 4; @@ -595,18 +648,155 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x01 = Sse.LoadVector128(pSrcCurrent); + // Replicate each 32-bit slot of x01 (ABCD) into its own register. 
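The shuffle controls below broadcast a single lane of the source vector across all four lanes, so that each block of matrix columns can be scaled by one source element. A stand-alone illustration of those controls (the class name is hypothetical; SSE support at runtime is assumed):

    using System;
    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.X86;

    class ShuffleBroadcastSketch
    {
        static void Main()
        {
            Vector128<float> x01 = Vector128.Create(1f, 2f, 3f, 4f); // (A, B, C, D)

            Vector128<float> a = Sse.Shuffle(x01, x01, 0x00); // (A, A, A, A)
            Vector128<float> b = Sse.Shuffle(x01, x01, 0x55); // (B, B, B, B)
            Vector128<float> c = Sse.Shuffle(x01, x01, 0xAA); // (C, C, C, C)
            Vector128<float> d = Sse.Shuffle(x01, x01, 0xFF); // (D, D, D, D)

            Console.WriteLine($"{a} {b} {c} {d}");
        }
    }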
+ Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B + Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C + Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D + x01 = Sse.Shuffle(x01, x01, 0x00); // A + + int length = crow; + float* pDstCurrent = pdst; + + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 16); + + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); + + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + int remainder = 0; + if (misalignment != 0) + { + // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 4 - misalignment; + + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); + + // We only align pMat since it has significantly more reads. + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); + + x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); + + Sse.Store(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + if (length > 3) + { + remainder = length % 4; + while (pDstCurrent + 4 <= pDstEnd) + { + float* pMatTemp = pMatCurrent; + + Vector128 x02 = Sse.Multiply(x01, Sse.LoadAlignedVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadAlignedVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadAlignedVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadAlignedVector128(pMatTemp += crow)); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); + + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += 
crow)); + + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); + + x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + } + + pMatCurrent += 3 * crow; + pSrcCurrent += 4; + } } - // We do 4-way unrolling while (pSrcCurrent < pSrcEnd) { Vector128 x01 = Sse.LoadVector128(pSrcCurrent); // Replicate each 32-bit slot of x01 (ABCD) into its own register. - Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B - Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C - Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D x01 = Sse.Shuffle(x01, x01, 0x00); // A int length = crow; @@ -621,13 +811,6 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); @@ -651,22 +834,11 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); + x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); @@ -680,18 +852,14 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Sse.Multiply(x01, Sse.LoadAlignedVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadAlignedVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); + Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); + Sse.Store(pDstCurrent, x02); pDstCurrent += 4; pMatCurrent += 4; @@ -702,7 +870,7 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan 0) { pMatCurrent -= (4 - remainder); pDstCurrent -= (4 - remainder); @@ -710,32 +878,21 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(trailingMask, 
Sse.LoadVector128(pMatTemp += crow)); - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); + x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); + Sse.Store(pDstCurrent, x02); pDstCurrent += 4; pMatCurrent += 4; } - } - pMatCurrent += 3 * crow; - pSrcCurrent += 4; + pSrcCurrent += 1; + } } } } diff --git a/src/Microsoft.ML.CpuMath/Thunk.cs b/src/Microsoft.ML.CpuMath/Thunk.cs index 8ff725b54a..9505db8766 100644 --- a/src/Microsoft.ML.CpuMath/Thunk.cs +++ b/src/Microsoft.ML.CpuMath/Thunk.cs @@ -86,12 +86,6 @@ public static extern void MatMulP(/*const*/ float* pmat, /*const*/ int* pposSrc, [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern float Dist2(/*const*/ float* px, /*const*/ float* py, int c); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ZeroItemsU(float* pd, int c, /*const*/ int* pindices, int cindices); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void ZeroMatrixItemsCore(float* pd, int c, int ccol, int cfltRow, /*const*/ int* pindices, int cindices); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void SdcaL1UpdateU(float primalUpdate, /*const*/ float* ps, float threshold, float* pd1, float* pd2, int c); diff --git a/src/Microsoft.ML.CpuMath/AlignedArray.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/AlignedArray.cs similarity index 53% rename from src/Microsoft.ML.CpuMath/AlignedArray.cs rename to src/Microsoft.ML.StandardLearners/FactorizationMachine/AlignedArray.cs index 0a631be0e9..f01da9fe28 100644 --- a/src/Microsoft.ML.CpuMath/AlignedArray.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/AlignedArray.cs @@ -5,7 +5,7 @@ using Microsoft.ML.Runtime.Internal.CpuMath.Core; using System; -namespace Microsoft.ML.Runtime.Internal.CpuMath +namespace Microsoft.ML.Runtime.FactorizationMachine { using Float = System.Single; @@ -17,7 +17,6 @@ namespace Microsoft.ML.Runtime.Internal.CpuMath /// /// The ctor takes an alignment value, which must be a power of two at least sizeof(Float). 
/// - [BestFriend] internal sealed class AlignedArray { // Items includes "head" items filled with NaN, followed by _size entries, followed by "tail" @@ -110,60 +109,6 @@ public Float this[int index] } } - public void CopyTo(Span dst, int index, int count) - { - Contracts.Assert(0 <= count && count <= _size); - Contracts.Assert(dst != null); - Contracts.Assert(0 <= index && index <= dst.Length - count); - Items.AsSpan(_base, count).CopyTo(dst.Slice(index)); - } - - public void CopyTo(int start, Span dst, int index, int count) - { - Contracts.Assert(0 <= count); - Contracts.Assert(0 <= start && start <= _size - count); - Contracts.Assert(dst != null); - Contracts.Assert(0 <= index && index <= dst.Length - count); - Items.AsSpan(start + _base, count).CopyTo(dst.Slice(index)); - } - - public void CopyFrom(ReadOnlySpan src) - { - Contracts.Assert(src.Length <= _size); - src.CopyTo(Items.AsSpan(_base)); - } - - public void CopyFrom(int start, ReadOnlySpan src) - { - Contracts.Assert(0 <= start && start <= _size - src.Length); - src.CopyTo(Items.AsSpan(start + _base)); - } - - // Copies values from a sparse vector. - // valuesSrc contains only the non-zero entries. Those are copied into their logical positions in the dense array. - // rgposSrc contains the logical positions + offset of the non-zero entries in the dense array. - // rgposSrc runs parallel to the valuesSrc array. - public void CopyFrom(ReadOnlySpan rgposSrc, ReadOnlySpan valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems) - { - Contracts.Assert(rgposSrc != null); - Contracts.Assert(valuesSrc != null); - Contracts.Assert(rgposSrc.Length <= valuesSrc.Length); - Contracts.Assert(0 <= iposMin && iposMin <= iposLim && iposLim <= rgposSrc.Length); - - // Zeroing-out and setting the values in one-pass does not seem to give any perf benefit. - // So explicitly zeroing and then setting the values. - if (zeroItems) - ZeroItems(); - - for (int ipos = iposMin; ipos < iposLim; ++ipos) - { - Contracts.Assert(posMin <= rgposSrc[ipos]); - int iv = _base + rgposSrc[ipos] - posMin; - Contracts.Assert(iv < _size + _base); - Items[iv] = valuesSrc[ipos]; - } - } - public void CopyFrom(AlignedArray src) { Contracts.Assert(src != null); @@ -171,41 +116,5 @@ public void CopyFrom(AlignedArray src) Contracts.Assert(src._cbAlign == _cbAlign); Array.Copy(src.Items, src._base, Items, _base, _size); } - - public void ZeroItems() - { - Array.Clear(Items, _base, _size); - } - - public void ZeroItems(int[] rgposSrc, int posMin, int iposMin, int iposLim) - { - Contracts.Assert(rgposSrc != null); - Contracts.Assert(0 <= iposMin && iposMin <= iposLim && iposLim <= rgposSrc.Length); - Contracts.Assert(iposLim - iposMin <= _size); - - int ivCur = 0; - for (int ipos = iposMin; ipos < iposLim; ++ipos) - { - int ivNextNonZero = rgposSrc[ipos] - posMin; - Contracts.Assert(ivCur <= ivNextNonZero && ivNextNonZero < _size); - while (ivCur < ivNextNonZero) - Items[_base + ivCur++] = 0; - Contracts.Assert(ivCur == ivNextNonZero); - // Skip the non-zero element at ivNextNonZero. - ivCur++; - } - - while (ivCur < _size) - Items[_base + ivCur++] = 0; - } - - // REVIEW: This is hackish and slightly dangerous. Perhaps we should wrap this in an - // IDisposable that "locks" this, prohibiting GetBase from being called, while the buffer - // is "checked out". 
- public void GetRawBuffer(out Float[] items, out int offset) - { - items = Items; - offset = _base; - } } } \ No newline at end of file diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs index a473ccec29..3d3baeaa2f 100644 --- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs +++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs @@ -174,8 +174,8 @@ public ModelInfo Info private Single[] _alpha; private Single[] _state; private readonly FixedSizeQueue _buffer; - private CpuAlignedVector _x; - private CpuAlignedVector _xSmooth; + private float[] _x; + private float[] _xSmooth; private int _windowSize; private readonly int _seriesLength; private readonly RankSelectionMethod _rankSelectionMethod; @@ -188,14 +188,14 @@ public ModelInfo Info private readonly IHost _host; - private CpuAlignedMatrixRow _wTrans; + private float[] _wTrans; private Single _observationNoiseVariance; private Single _observationNoiseMean; private Single _autoregressionNoiseVariance; private Single _autoregressionNoiseMean; private int _rank; - private CpuAlignedVector _y; + private float[] _y; private Single _nextPrediction; /// @@ -290,8 +290,8 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, int trainSi _alpha = new Single[windowSize - 1]; _state = new Single[windowSize - 1]; - _x = new CpuAlignedVector(windowSize, SseUtils.CbAlign); - _xSmooth = new CpuAlignedVector(windowSize, SseUtils.CbAlign); + _x = new float[windowSize]; + _xSmooth = new float[windowSize]; ShouldComputeForecastIntervals = shouldComputeForecastIntervals; _observationNoiseVariance = 0; @@ -345,14 +345,14 @@ private AdaptiveSingularSpectrumSequenceModeler(AdaptiveSingularSpectrumSequence _state = new Single[_windowSize - 1]; Array.Copy(model._state, _state, _windowSize - 1); - _x = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); - _xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); + _x = new float[_windowSize]; + _xSmooth = new float[_windowSize]; if (model._wTrans != null) { - _y = new CpuAlignedVector(_rank, SseUtils.CbAlign); - _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, SseUtils.CbAlign); - _wTrans.CopyFrom(model._wTrans); + _y = new float[_rank]; + _wTrans = new float[_rank * _windowSize]; + Array.Copy(model._wTrans, _wTrans, _rank * _windowSize); } } @@ -452,18 +452,16 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo { var tempArray = ctx.Reader.ReadFloatArray(); _host.CheckDecode(Utils.Size(tempArray) == _rank * _windowSize); - _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, SseUtils.CbAlign); - int i = 0; - _wTrans.CopyFrom(tempArray, ref i); + _wTrans = new float[_rank * _windowSize]; + Array.Copy(tempArray, _wTrans, tempArray.Length); tempArray = ctx.Reader.ReadFloatArray(); - i = 0; - _y = new CpuAlignedVector(_rank, SseUtils.CbAlign); - _y.CopyFrom(tempArray, ref i); + _y = new float[_rank]; + Array.Copy(tempArray, _y, tempArray.Length); } _buffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(ctx.Reader, _host); - _x = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); - _xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); + _x = new float[_windowSize]; + _xSmooth = new float[_windowSize]; } public override void Save(ModelSaveContext ctx) @@ -527,10 +525,8 @@ public override void Save(ModelSaveContext ctx) if (_wTrans != null) { - // REVIEW: this may not be the most efficient way 
for serializing an aligned matrix. var tempArray = new Single[_rank * _windowSize]; - int iv = 0; - _wTrans.CopyTo(tempArray, ref iv); + Array.Copy(_wTrans, tempArray, _wTrans.Length); ctx.Writer.WriteSingleArray(tempArray); tempArray = new float[_rank]; iv = 0; @@ -1130,15 +1126,14 @@ internal override void Consume(ref Single input, bool updateModel = false) if (_wTrans == null) { - _y = new CpuAlignedVector(_rank, SseUtils.CbAlign); - _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, SseUtils.CbAlign); + _y = new float[_rank]; + _wTrans = new float[_rank * _windowSize]; Single[] vecs = new Single[_rank * _windowSize]; for (i = 0; i < _rank; ++i) vecs[(_windowSize + 1) * i] = 1; - i = 0; - _wTrans.CopyFrom(vecs, ref i); + Array.Copy(_wTrans, vecs, _rank * _windowSize); } // Forming vector x @@ -1157,10 +1152,10 @@ internal override void Consume(ref Single input, bool updateModel = false) _x[_windowSize - 1] = input; // Computing y: Eq. (11) in https://hal-institut-mines-telecom.archives-ouvertes.fr/hal-00479772/file/twocolumns.pdf - CpuAligenedMathUtils.MatTimesSrc(_wTrans, _x, _y); + CpuMathUtils.MatrixTimesSource(transpose: false, _wTrans, _x, _y, _y.Length); // Updating the state vector - CpuAligenedMathUtils.MatTranTimesSrc(_wTrans, _y, _xSmooth); + CpuMathUtils.MatrixTimesSource(transpose: true, _wTrans, _y, _xSmooth, _y.Length); _nextPrediction = _autoregressionNoiseMean + _observationNoiseMean; for (i = 0; i < _windowSize - 2; ++i) @@ -1311,8 +1306,8 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) _maxRank = _windowSize / 2; _alpha = new Single[_windowSize - 1]; _state = new Single[_windowSize - 1]; - _x = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); - _xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); + _x = new float[_windowSize]; + _xSmooth = new float[_windowSize]; TrainCore(dataArray, originalSeriesLength); return; @@ -1349,12 +1344,11 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) } // Setting the the y vector - _y = new CpuAlignedVector(_rank, SseUtils.CbAlign); + _y = new float[_rank]; // Setting the weight matrix - _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, SseUtils.CbAlign); - i = 0; - _wTrans.CopyFrom(leftSingularVecs, ref i); + _wTrans = new float[_rank * _windowSize]; + Array.Copy(leftSingularVecs, _wTrans, _rank * _windowSize); // Setting alpha Single nu = 0; @@ -1364,7 +1358,7 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) nu += _y[i] * _y[i]; } - CpuAligenedMathUtils.MatTranTimesSrc(_wTrans, _y, _xSmooth); + CpuMathUtils.MatrixTimesSource(transpose: true, _wTrans, _y, _xSmooth, _y.Length); for (i = 0; i < _windowSize - 1; ++i) _alpha[i] = _xSmooth[i] / (1 - nu); @@ -1409,8 +1403,8 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) _x[i - originalSeriesLength + _windowSize] = dataArray[i]; } - CpuAligenedMathUtils.MatTimesSrc(_wTrans, _x, _y); - CpuAligenedMathUtils.MatTranTimesSrc(_wTrans, _y, _xSmooth); + CpuMathUtils.MatrixTimesSource(transpose: false, _wTrans, _x, _y, _y.Length); + CpuMathUtils.MatrixTimesSource(transpose: true, _wTrans, _y, _xSmooth, _y.Length); for (i = 1; i < _windowSize; ++i) { diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs index 0984f40f97..f89547471b 100644 --- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs +++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs @@ -95,10 +95,10 @@ private sealed class TransformInfo 
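Returning to the time-series changes above: with _wTrans now a flat float[_rank * _windowSize], the two MatrixTimesSource calls compute y = W*x and xSmooth = W'*y. A scalar reference sketch of those two products follows, assuming the row-major rank-by-windowSize layout implied by the flat allocation; the helper and class names are illustrative, not library code.

    class ProjectionSketch
    {
        // y = W * x, with W stored row-major as rank x windowSize.
        static void MatTimesSrc(float[] w, float[] x, float[] y, int rank, int windowSize)
        {
            for (int i = 0; i < rank; i++)
            {
                float sum = 0;
                for (int j = 0; j < windowSize; j++)
                    sum += w[i * windowSize + j] * x[j];
                y[i] = sum;
            }
        }

        // xSmooth = W' * y, the transposed product used for the smoothed state.
        static void MatTranTimesSrc(float[] w, float[] y, float[] xSmooth, int rank, int windowSize)
        {
            System.Array.Clear(xSmooth, 0, windowSize);
            for (int i = 0; i < rank; i++)
                for (int j = 0; j < windowSize; j++)
                    xSmooth[j] += w[i * windowSize + j] * y[i];
        }

        static void Main()
        {
            int rank = 2, windowSize = 3;
            float[] w = { 1, 0, 1,
                          0, 1, 1 };          // 2 x 3, row-major
            float[] x = { 2, 3, 4 };
            float[] y = new float[rank];
            float[] xSmooth = new float[windowSize];

            MatTimesSrc(w, x, y, rank, windowSize);            // y = (6, 7)
            MatTranTimesSrc(w, y, xSmooth, rank, windowSize);  // xSmooth = (6, 7, 13)
            System.Console.WriteLine(string.Join(", ", y) + " | " + string.Join(", ", xSmooth));
        }
    }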
public readonly int SrcDim; // the matrix containing the random fourier vectors - public readonly AlignedArray RndFourierVectors; + public readonly float[] RndFourierVectors; // the random rotations - public readonly AlignedArray RotationTerms; + public readonly float[] RotationTerms; private readonly IFourierDistributionSampler _matrixGenerator; private readonly bool _useSin; @@ -120,10 +120,10 @@ public TransformInfo(IHost host, ColumnInfo column, int d, float avgDist) var generator = column.Generator; _matrixGenerator = generator.CreateComponent(host, avgDist); - int roundedUpD = RoundUp(NewDim, _cfltAlign); - int roundedUpNumFeatures = RoundUp(SrcDim, _cfltAlign); - RndFourierVectors = new AlignedArray(roundedUpD * roundedUpNumFeatures, CpuMathUtils.GetVectorAlignment()); - RotationTerms = _useSin ? null : new AlignedArray(roundedUpD, CpuMathUtils.GetVectorAlignment()); + int roundedUpD = RoundUp(NewDim); + int roundedUpNumFeatures = RoundUp(SrcDim); + RndFourierVectors = new float[roundedUpD * roundedUpNumFeatures]; + RotationTerms = _useSin ? null : new float[roundedUpD]; InitializeFourierCoefficients(roundedUpNumFeatures, roundedUpD); } @@ -154,10 +154,10 @@ public TransformInfo(IHostEnvironment env, ModelLoadContext ctx, string director ctx.LoadModelOrNull(env, out _matrixGenerator, directoryName)); // initialize the transform matrix - int roundedUpD = RoundUp(NewDim, _cfltAlign); - int roundedUpNumFeatures = RoundUp(SrcDim, _cfltAlign); - RndFourierVectors = new AlignedArray(roundedUpD * roundedUpNumFeatures, CpuMathUtils.GetVectorAlignment()); - RotationTerms = _useSin ? null : new AlignedArray(roundedUpD, CpuMathUtils.GetVectorAlignment()); + int roundedUpD = RoundUp(NewDim); + int roundedUpNumFeatures = RoundUp(SrcDim); + RndFourierVectors = new float[roundedUpD * roundedUpNumFeatures]; + RotationTerms = _useSin ? null : new float[roundedUpD]; InitializeFourierCoefficients(roundedUpNumFeatures, roundedUpD); } @@ -225,8 +225,6 @@ private static VersionInfo GetVersionInfo() private readonly TransformInfo[] _transformInfos; - private static readonly int _cfltAlign = CpuMathUtils.GetVectorAlignment() / sizeof(float); - private static string TestColumnType(ColumnType type) { if (type.ItemType == NumberType.Float && type.IsKnownSizeVector) @@ -295,16 +293,11 @@ public RandomFourierFeaturizingTransformer(IHostEnvironment env, IDataView input } } - // Round cflt up to a multiple of cfltAlign. - private static int RoundUp(int cflt, int cfltAlign) + private static int RoundUp(int number) { - Contracts.Assert(0 < cflt); - // cfltAlign should be a power of two. - Contracts.Assert(0 < cfltAlign && (cfltAlign & (cfltAlign - 1)) == 0); - - // Determine the number of "blobs" of size cfltAlign. 
- int cblob = (cflt + cfltAlign - 1) / cfltAlign; - return cblob * cfltAlign; + Contracts.Assert(0 < number); + int multipleOf4 = (number + 3) / 4; + return multipleOf4 * 4; } private float[] GetAvgDistances(ColumnInfo[] columns, IDataView input) @@ -555,8 +548,8 @@ private ValueGetter> GetterFromVectorType(IRow input, int iinfo) var getSrc = input.GetGetter>(_srcCols[iinfo]); var src = default(VBuffer); - var featuresAligned = new AlignedArray(RoundUp(_srcTypes[iinfo].ValueCount, _cfltAlign), CpuMathUtils.GetVectorAlignment()); - var productAligned = new AlignedArray(RoundUp(_parent._transformInfos[iinfo].NewDim, _cfltAlign), CpuMathUtils.GetVectorAlignment()); + var featuresAligned = new float[RoundUp(_srcTypes[iinfo].ValueCount)]; + var productAligned = new float[RoundUp(_parent._transformInfos[iinfo].NewDim)]; return (ref VBuffer dst) => @@ -572,8 +565,8 @@ private ValueGetter> GetterFromFloatType(IRow input, int iinfo) var getSrc = input.GetGetter(_srcCols[iinfo]); var src = default(float); - var featuresAligned = new AlignedArray(RoundUp(1, _cfltAlign), CpuMathUtils.GetVectorAlignment()); - var productAligned = new AlignedArray(RoundUp(_parent._transformInfos[iinfo].NewDim, _cfltAlign), CpuMathUtils.GetVectorAlignment()); + var featuresAligned = new float[4]; + var productAligned = new float[RoundUp(_parent._transformInfos[iinfo].NewDim)]; var oneDimensionalVector = new VBuffer(1, new float[] { 0 }); @@ -587,7 +580,7 @@ private ValueGetter> GetterFromFloatType(IRow input, int iinfo) } private void TransformFeatures(in VBuffer src, ref VBuffer dst, TransformInfo transformInfo, - AlignedArray featuresAligned, AlignedArray productAligned) + float[] featuresAligned, float[] productAligned) { Host.Check(src.Length == transformInfo.SrcDim, "column does not have the expected dimensionality."); @@ -606,9 +599,9 @@ private void TransformFeatures(in VBuffer src, ref VBuffer dst, Tr if (src.IsDense) { - featuresAligned.CopyFrom(src.GetValues()); - CpuMathUtils.MatrixTimesSource(false, transformInfo.RndFourierVectors, featuresAligned, productAligned, - transformInfo.NewDim); + src.GetValues().CopyTo(featuresAligned); + CpuMathUtils.MatrixTimesSource(transpose: false, transformInfo.RndFourierVectors, featuresAligned, productAligned, + RoundUp(transformInfo.NewDim)); } else { @@ -616,9 +609,15 @@ private void TransformFeatures(in VBuffer src, ref VBuffer dst, Tr // no need to zero them out. var srcValues = src.GetValues(); var srcIndices = src.GetIndices(); - featuresAligned.CopyFrom(srcIndices, srcValues, 0, 0, srcValues.Length, zeroItems: false); + + for (int i = 0; i< srcValues.Length; i++) + { + int iv = srcIndices[0]; + featuresAligned[iv] = srcValues[i]; + } + CpuMathUtils.MatrixTimesSource(transformInfo.RndFourierVectors, srcIndices, featuresAligned, 0, 0, - srcValues.Length, productAligned, transformInfo.NewDim); + srcValues.Length, productAligned, RoundUp(transformInfo.NewDim)); } var dstEditor = VBufferEditor.Create(ref dst, newDstLength); diff --git a/src/Native/CpuMathNative/Sse.cpp b/src/Native/CpuMathNative/Sse.cpp index 607af332d1..45344ecf97 100644 --- a/src/Native/CpuMathNative/Sse.cpp +++ b/src/Native/CpuMathNative/Sse.cpp @@ -60,17 +60,166 @@ const unsigned int TrailingAlignmentMask[16] = // Multiply matrix times vector into vector. 
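Back in the featurizer above, both buffers are now plain float arrays whose lengths are rounded up to a whole number of four-float blocks by the new RoundUp, and the sparse path scatters the non-zero values into the zero-initialized padded buffer before the matrix product. A minimal sketch of that preparation step (the class name and input data are hypothetical):

    class SparseScatterSketch
    {
        static int RoundUp(int number) => ((number + 3) / 4) * 4;   // e.g. 10 -> 12

        static void Main()
        {
            // Hypothetical sparse input: three non-zeros of a 10-dimensional source vector.
            int srcDim = 10;
            int[] srcIndices = { 1, 4, 7 };
            float[] srcValues = { 0.5f, -2f, 3f };

            // Zero-initialized buffer padded to a whole number of 4-float blocks.
            var featuresAligned = new float[RoundUp(srcDim)];        // length 12
            for (int i = 0; i < srcValues.Length; i++)
                featuresAligned[srcIndices[i]] = srcValues[i];       // scatter each value to its index

            System.Console.WriteLine(string.Join(", ", featuresAligned));
        }
    }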
EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { + if (ccol < 4) + { + for (int i = 0 ; i < crow; i++) + { + float dotProduct = 0; + for (int j = 0; j < ccol; j++) + { + dotProduct += pmat[i * ccol + j] * psrc[j]; + } + pdst[i] = dotProduct; + } + return; + } + const float * pSrcEnd = psrc + ccol; const float * pDstEnd = pdst + crow; float* pDstCurrent = pdst; const float* pMatCurrent = pmat; + if (ccol % 4 == 0) + { + while (pDstCurrent + 4 <= pDstEnd) + { + __m128 res0 = _mm_setzero_ps(); + __m128 res1 = res0; + __m128 res2 = res0; + __m128 res3 = res0; + + int length = ccol; + const float* pSrcCurrent = psrc; + + uintptr_t address = (uintptr_t)(pMatCurrent); + uintptr_t misalignment = address % 16; + int remainder = 0; + + if ((misalignment & 3) != 0) + { + // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations + while (pSrcCurrent < pSrcEnd) + { + __m128 vector = _mm_loadu_ps(pSrcCurrent); + + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + + res0 = _mm_add_ps(res0, x01); + res1 = _mm_add_ps(res1, x11); + res2 = _mm_add_ps(res2, x21); + res3 = _mm_add_ps(res3, x31); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + if (misalignment != 0) + { + // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 4 - misalignment; + + __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); + + // We only align pMat since it has significantly more reads. + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); + + res0 = _mm_mul_ps(x01, vector); + res1 = _mm_mul_ps(x11, vector); + res2 = _mm_mul_ps(x21, vector); + res3 = _mm_mul_ps(x31, vector); + + pMatCurrent += misalignment; + pSrcCurrent += misalignment; + length -= misalignment; + } + + if (length > 3) + { + // Handle all the 128-bit blocks that we can now that we have offset to an aligned address + remainder = length % 4; + + while (pSrcCurrent + 4 <= pSrcEnd) + { + __m128 vector = _mm_loadu_ps(pSrcCurrent); + + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp)); + __m128 x11 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); + __m128 x21 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); + __m128 x31 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); + + res0 = _mm_add_ps(res0, x01); + res1 = _mm_add_ps(res1, x11); + res2 = _mm_add_ps(res2, x21); + res3 = _mm_add_ps(res3, x31); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not + // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two + // unaligned loads where we mask the input each time. 
+ remainder = length; + } + + if (remainder != 0) + { + // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next + // unaligned load will read to the end of the array and then mask out any elements already processed + + pMatCurrent -= (4 - remainder); + pSrcCurrent -= (4 - remainder); + + __m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); + + res0 = _mm_add_ps(res0, _mm_mul_ps(x01, vector)); + res1 = _mm_add_ps(res1, _mm_mul_ps(x11, vector)); + res2 = _mm_add_ps(res2, _mm_mul_ps(x21, vector)); + res3 = _mm_add_ps(res3, _mm_mul_ps(x31, vector)); + + pMatCurrent += 4; + pSrcCurrent += 4; + } + } + + // Add up the entries of each, with the 4 results in res0 + res0 = _mm_hadd_ps(res0, res1); + res2 = _mm_hadd_ps(res2, res3); + res0 = _mm_hadd_ps(res0, res2); + + _mm_storeu_ps(pDstCurrent, res0); + + pDstCurrent += 4; + pMatCurrent += 3 * ccol; + } + } + while (pDstCurrent < pDstEnd) { __m128 res0 = _mm_setzero_ps(); - __m128 res1 = res0; - __m128 res2 = res0; - __m128 res3 = res0; int length = ccol; const float* pSrcCurrent = psrc; @@ -88,14 +237,7 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout const float* pMatTemp = pMatCurrent; __m128 x01 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp)); - __m128 x11 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x21 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x31 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); - res0 = _mm_add_ps(res0, x01); - res1 = _mm_add_ps(res1, x11); - res2 = _mm_add_ps(res2, x21); - res3 = _mm_add_ps(res3, x31); pSrcCurrent += 4; pMatCurrent += 4; @@ -115,15 +257,8 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout // We only align pMat since it has significantly more reads. 
const float* pMatTemp = pMatCurrent; __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); - __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); - res0 = _mm_mul_ps(x01, vector); - res1 = _mm_mul_ps(x11, vector); - res2 = _mm_mul_ps(x21, vector); - res3 = _mm_mul_ps(x31, vector); pMatCurrent += misalignment; pSrcCurrent += misalignment; @@ -141,14 +276,7 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout const float* pMatTemp = pMatCurrent; __m128 x01 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp)); - __m128 x11 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); - __m128 x21 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); - __m128 x31 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); - res0 = _mm_add_ps(res0, x01); - res1 = _mm_add_ps(res1, x11); - res2 = _mm_add_ps(res2, x21); - res3 = _mm_add_ps(res3, x31); pSrcCurrent += 4; pMatCurrent += 4; @@ -174,30 +302,21 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout const float* pMatTemp = pMatCurrent; __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); - __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); res0 = _mm_add_ps(res0, _mm_mul_ps(x01, vector)); - res1 = _mm_add_ps(res1, _mm_mul_ps(x11, vector)); - res2 = _mm_add_ps(res2, _mm_mul_ps(x21, vector)); - res3 = _mm_add_ps(res3, _mm_mul_ps(x31, vector)); pMatCurrent += 4; pSrcCurrent += 4; } } - // Add up the entries of each, with the 4 results in res0 - res0 = _mm_hadd_ps(res0, res1); - res2 = _mm_hadd_ps(res2, res3); - res0 = _mm_hadd_ps(res0, res2); + // Sum all the elements together and return the result + res0 = _mm_hadd_ps(res0, res0); + res0 = _mm_hadd_ps(res0, res0); - _mm_storeu_ps(pDstCurrent, res0); - - pDstCurrent += 4; - pMatCurrent += 3 * ccol; + *pDstCurrent = _mm_cvtss_f32(res0); + pDstCurrent += 1; } } @@ -358,6 +477,20 @@ EXPORT_API(void) MatMulP(_In_ const float * pmat, _In_ const int * pposSrc, _In_ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { + if (crow < 4) + { + for (int i = 0 ; i < crow; i++) + { + float dotProduct = 0; + for (int j = 0; j < ccol; j++) + { + dotProduct += pmat[j * crow + i] * psrc[j]; + } + pdst[i] = dotProduct; + } + return; + } + const float * pSrcEnd = psrc + ccol; const float * pDstEnd = pdst + crow; @@ -367,10 +500,6 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I if (pSrcCurrent < pSrcEnd) { __m128 x01 = _mm_loadu_ps(pSrcCurrent); - // Replicate each slot of x01 into its own register. 
- __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); - __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); - __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); x01 = _mm_shuffle_ps(x01, x01, 0x00); int length = crow; @@ -387,13 +516,6 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I { const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_mul_ps(x11, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_mul_ps(x21, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_mul_ps(x31, _mm_loadu_ps(pMatTemp += crow)); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); _mm_storeu_ps(pDstCurrent, x02); pDstCurrent += 4; @@ -415,22 +537,11 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I // We only align pMat since it has significantly more reads. const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + ((4 - misalignment) * 4)); __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_or_ps(x02, _mm_and_ps(x3, trailingMask)); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); _mm_storeu_ps(pDstCurrent, x02); pMatCurrent += misalignment; @@ -447,13 +558,6 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I { const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_mul_ps(x01, _mm_load_ps(pMatTemp)); - __m128 x12 = _mm_mul_ps(x11, _mm_load_ps(pMatTemp += crow)); - __m128 x22 = _mm_mul_ps(x21, _mm_load_ps(pMatTemp += crow)); - __m128 x32 = _mm_mul_ps(x31, _mm_load_ps(pMatTemp += crow)); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); _mm_storeu_ps(pDstCurrent, x02); @@ -481,22 +585,11 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_or_ps(x02, _mm_and_ps(x3, leadingMask)); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); _mm_storeu_ps(pDstCurrent, x02); pMatCurrent += 4; @@ -504,17 +597,154 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I } } - pMatCurrent += 3 * crow; - pSrcCurrent += 4; + pSrcCurrent += 1; + } + + if (crow % 4 == 0) + { + while (pSrcCurrent + 4 <= pSrcEnd) + { + __m128 x01 = _mm_loadu_ps(pSrcCurrent); + // Replicate each slot of x01 into its own register. 
+ __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); + __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); + __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); + x01 = _mm_shuffle_ps(x01, x01, 0x00); + + int length = crow; + float* pDstCurrent = pdst; + + uintptr_t address = (uintptr_t)(pMatCurrent); + uintptr_t misalignment = address % 16; + int remainder = 0; + + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); + + _mm_storeu_ps(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + int remainder = 0; + if (misalignment != 0) + { + misalignment >>= 2; + misalignment = 4 - misalignment; + + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); + + // We only align pMat since it has significantly more reads. + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + ((4 - misalignment) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); + + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + + if (length > 3) + { + remainder = length % 4; + while (pDstCurrent + 4 <= pDstEnd) + { + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_mul_ps(x01, _mm_load_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_load_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_load_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_load_ps(pMatTemp += crow)); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); + + _mm_storeu_ps(pDstCurrent, x02); + + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); + + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + __m128 leadingMask = 
_mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); + + x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += 4; + pDstCurrent += 4; + } + } + + pMatCurrent += 3 * crow; + pSrcCurrent += 4; + } } while (pSrcCurrent < pSrcEnd) { __m128 x01 = _mm_loadu_ps(pSrcCurrent); - // Replicate each slot of x01 into its own register. - __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); - __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); - __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); x01 = _mm_shuffle_ps(x01, x01, 0x00); int length = crow; @@ -522,7 +752,6 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I uintptr_t address = (uintptr_t)(pMatCurrent); uintptr_t misalignment = address % 16; - int remainder = 0; if ((misalignment & 3) != 0) { @@ -530,14 +759,6 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I { const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_mul_ps(x11, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_mul_ps(x21, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_mul_ps(x31, _mm_loadu_ps(pMatTemp += crow)); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); _mm_storeu_ps(pDstCurrent, x02); @@ -558,22 +779,12 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I // We only align pMat since it has significantly more reads. const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + ((4 - misalignment) * 4)); __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_or_ps(x02, _mm_and_ps(x3, trailingMask)); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); _mm_storeu_ps(pDstCurrent, x02); @@ -589,14 +800,6 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I { const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_mul_ps(x01, _mm_load_ps(pMatTemp)); - __m128 x12 = _mm_mul_ps(x11, _mm_load_ps(pMatTemp += crow)); - __m128 x22 = _mm_mul_ps(x21, _mm_load_ps(pMatTemp += crow)); - __m128 x32 = _mm_mul_ps(x31, _mm_load_ps(pMatTemp += crow)); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); _mm_storeu_ps(pDstCurrent, x02); @@ -619,22 +822,11 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - 
x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_or_ps(x02, _mm_and_ps(x3, leadingMask)); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); _mm_storeu_ps(pDstCurrent, x02); @@ -643,8 +835,7 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I } } - pMatCurrent += 3 * crow; - pSrcCurrent += 4; + pSrcCurrent += 1; } } @@ -1238,43 +1429,6 @@ EXPORT_API(float) Dist2(const float * px, const float * py, int c) return norm2; } -EXPORT_API(void) ZeroItemsU(_Inout_ float * pd, int c, _In_ const int * pindices, int cindices) -{ - DEBUG_ONLY(c); - for (int i = 0; i < cindices; ++i) - { - int iv = pindices[i]; - assert(0 <= iv && iv < c); - pd[iv] = 0; - } -} - -EXPORT_API(void) ZeroMatrixItemsCore(_Inout_ float * pd, int c, int ccol, int cfltRow, _In_ const int * pindices, int cindices) -{ - DEBUG_ONLY(c); - int ivLogMin = 0; - int ivLogLim = ccol; - int ivPhyMin = 0; - for (int i = 0; i < cindices; ++i) - { - int iv = pindices[i]; - assert(0 <= iv && iv < c); - - int col = iv - ivLogMin; - if ((unsigned int)col >= (unsigned int)ccol) - { - assert(ivLogMin > iv || iv >= ivLogLim); - int row = iv / ccol; - ivLogMin = row * ccol; - ivLogLim = ivLogMin + ccol; - ivPhyMin = row * cfltRow; - assert(ivLogMin <= iv && iv < ivLogLim); - col = iv - ivLogMin; - } - pd[ivPhyMin + col] = 0; - } -} - EXPORT_API(void) SdcaL1UpdateU(float primalUpdate, _In_ const float * ps, float threshold, _Inout_ float *pd1, _Inout_ float * pd2, int c) { const float * psLim = ps + c; diff --git a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs index 25996ec42c..f8c856cfec 100644 --- a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs +++ b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs @@ -14,14 +14,14 @@ public class CpuMathUtilsUnitTests { private readonly float[][] _testArrays; private readonly int[] _testIndexArray; - private readonly AlignedArray[] _testMatrices; - private readonly AlignedArray[] _testSrcVectors; - private readonly AlignedArray[] _testDstVectors; - private readonly int _vectorAlignment = CpuMathUtils.GetVectorAlignment(); + private readonly float[][] _testMatrices; + private readonly float[][] _testSrcVectors; + private readonly float[][] _testDstVectors; private readonly FloatEqualityComparer _comparer; private readonly FloatEqualityComparerForMatMul _matMulComparer; private const float DefaultScale = 1.7f; + private const int DefaultSeed = 253421; public CpuMathUtilsUnitTests() { @@ -50,34 +50,19 @@ public CpuMathUtilsUnitTests() testMatrix2[i] = i + 1; } - AlignedArray testMatrixAligned1 = new AlignedArray(8 * 8, _vectorAlignment); - AlignedArray testMatrixAligned2 = new AlignedArray(8 * 16, _vectorAlignment); - testMatrixAligned1.CopyFrom(testMatrix1); - testMatrixAligned2.CopyFrom(testMatrix2); - - _testMatrices = new AlignedArray[] { testMatrixAligned1, testMatrixAligned2 }; + _testMatrices = new float[][] { testMatrix1, testMatrix2 }; // Padded source vectors whose dimensions are multiples of 8 float[] testSrcVector1 = new float[8] { 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f }; float[] testSrcVector2 = new float[16] { 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f, 16f 
}; - AlignedArray testSrcVectorAligned1 = new AlignedArray(8, _vectorAlignment); - AlignedArray testSrcVectorAligned2 = new AlignedArray(16, _vectorAlignment); - testSrcVectorAligned1.CopyFrom(testSrcVector1); - testSrcVectorAligned2.CopyFrom(testSrcVector2); - - _testSrcVectors = new AlignedArray[] { testSrcVectorAligned1, testSrcVectorAligned2 }; + _testSrcVectors = new float[][] { testSrcVector1, testSrcVector2 }; // Padded destination vectors whose dimensions are multiples of 8 float[] testDstVector1 = new float[8] { 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f }; float[] testDstVector2 = new float[16] { 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f }; - AlignedArray testDstVectorAligned1 = new AlignedArray(8, _vectorAlignment); - AlignedArray testDstVectorAligned2 = new AlignedArray(16, _vectorAlignment); - testDstVectorAligned1.CopyFrom(testDstVector1); - testDstVectorAligned2.CopyFrom(testDstVector2); - - _testDstVectors = new AlignedArray[] { testDstVectorAligned1, testDstVectorAligned2 }; + _testDstVectors = new float[][] { testDstVector1, testDstVector2 }; } [Theory] @@ -86,29 +71,127 @@ public CpuMathUtilsUnitTests() [InlineData(1, 0, 1, new float[] { 204f, 492f, 780f, 1068f, 1356f, 1644f, 1932f, 2220f, 2508f, 2796f, 3084f, 3372f, 3660f, 3948f, 4236f, 4524f })] public void MatMulTest(int matTest, int srcTest, int dstTest, float[] expected) { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; + float[] mat = _testMatrices[matTest]; + float[] src = _testSrcVectors[srcTest]; + float[] dst = _testDstVectors[dstTest]; - CpuMathUtils.MatrixTimesSource(false, mat, src, dst, dst.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); + CpuMathUtils.MatrixTimesSource(false, mat, src, dst, dst.Length); + float[] actual = new float[dst.Length]; + Array.Copy(dst, actual, dst.Length); Assert.Equal(expected, actual, _matMulComparer); } + [Theory] + [InlineData(10, 5)] + [InlineData(10, 8)] + [InlineData(10, 11)] + [InlineData(11, 8)] + [InlineData(8, 23)] + [InlineData(2, 8)] + [InlineData(2, 9)] + [InlineData(2, 3)] + [InlineData(2, 5)] + [InlineData(4, 5)] + [InlineData(4, 7)] + [InlineData(4, 9)] + [InlineData(5, 7)] + [InlineData(5, 9)] + + private void MatMulAnyDimensionTest(int col, int row) + { + Random rand = new Random(DefaultSeed); + float[] mat = new float[col * row]; + for (int i = 0; i < col * row; i++) + { + mat[i] = rand.Next(-10, 10); + } + + float[] src = new float[col]; + for (int i = 0; i < col; i++) + { + src[i] = rand.Next(-10, 10); + } + + float[] dst = new float[row]; + float[] expected = new float[row]; + + for (int i = 0; i < row; i++) + { + float dotProduct = 0; + for (int j = 0; j < src.Length; j++) + { + dotProduct += mat[i * src.Length + j] * src[j]; + } + + expected[i] = dotProduct; + } + + CpuMathUtils.MatrixTimesSource(false, mat, src, dst, dst.Length); + Assert.Equal(expected, dst, _matMulComparer); + } + + [Theory] + [InlineData(10, 5)] + [InlineData(10, 8)] + [InlineData(10, 11)] + [InlineData(11, 8)] + [InlineData(8, 23)] + [InlineData(2, 8)] + [InlineData(2, 9)] + [InlineData(2, 3)] + [InlineData(2, 5)] + [InlineData(4, 5)] + [InlineData(4, 7)] + [InlineData(4, 9)] + [InlineData(5, 7)] + [InlineData(5, 9)] + + private void MatMulTranAnyDimensionTest(int col, int row) + { + float[] mat = new float[col * row]; + Random rand = new Random(DefaultSeed); + for (int i = 0; i < col * row; i++) + { + mat[i] = rand.Next(0, 10); + } + + float[] 
src = new float[row]; + for (int i = 0; i < row; i++) + { + src[i] = rand.Next(0, 10); + } + + float[] dst = new float[col]; + float[] expected = new float[col]; + + for (int i = 0; i < dst.Length; i++) + { + float dotProduct = 0; + for (int j = 0; j < row; j++) + { + dotProduct += mat[j * dst.Length + i] * src[j]; + } + + expected[i] = dotProduct; + } + + CpuMathUtils.MatrixTimesSource(true, mat, src, dst, row); + Assert.Equal(expected, dst, _matMulComparer); + } + [Theory] [InlineData(0, 0, 0, new float[] { 70.56001f, -85.68f, -351.36f, 498.24f, -3829.32f, -969.48f, 1168.2f, 118.44f })] [InlineData(1, 0, 1, new float[] { 2724f, 2760f, 2796f, 2832f, 2868f, 2904f, 2940f, 2976f, 3012f, 3048f, 3084f, 3120f, 3156f, 3192f, 3228f, 3264f })] [InlineData(1, 1, 0, new float[] { 11016f, 11152f, 11288f, 11424f, 11560f, 11696f, 11832f, 11968f })] public void MatMulTranTest(int matTest, int srcTest, int dstTest, float[] expected) { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; + float[] mat = _testMatrices[matTest]; + float[] src = _testSrcVectors[srcTest]; + float[] dst = _testDstVectors[dstTest]; - CpuMathUtils.MatrixTimesSource(true, mat, src, dst, src.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); + CpuMathUtils.MatrixTimesSource(true, mat, src, dst, src.Length); + float[] actual = new float[dst.Length]; + Array.Copy(dst, actual, dst.Length); Assert.Equal(expected, actual, _matMulComparer); } @@ -118,14 +201,14 @@ public void MatMulTranTest(int matTest, int srcTest, int dstTest, float[] expect [InlineData(1, 0, 1, new float[] { 95f, 231f, 367f, 503f, 639f, 775f, 911f, 1047f, 1183f, 1319f, 1455f, 1591f, 1727f, 1863f, 1999f, 2135f })] public void MatTimesSrcSparseTest(int matTest, int srcTest, int dstTest, float[] expected) { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; + float[] mat = _testMatrices[matTest]; + float[] src = _testSrcVectors[srcTest]; + float[] dst = _testDstVectors[dstTest]; int[] idx = _testIndexArray; - CpuMathUtils.MatrixTimesSource(mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, dst.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); + CpuMathUtils.MatrixTimesSource(mat, idx, src, 0, 0, (srcTest == 0) ? 
4 : 9, dst, dst.Length); + float[] actual = new float[dst.Length]; + Array.Copy(dst, actual, dst.Length); Assert.Equal(expected, actual, _matMulComparer); } @@ -467,34 +550,6 @@ public void Dist2Test(int test, float expected) Assert.Equal(expected, actual, 0); } - [Theory] - [InlineData(0, new int[] { 0, 2, 5, 6 }, new float[] { 0f, 2f, 0f, 4f, 5f, 0f, 0f, 8f })] - [InlineData(1, new int[] { 0, 2, 5, 6, 8, 11, 12, 13, 14 }, new float[] { 0f, 2f, 0f, 4f, 5f, 0f, 0f, 8f, 0f, 10f, 11f, 0f, 0f, 0f, 0f, 16f })] - public void ZeroItemsUTest(int test, int[] idx, float[] expected) - { - AlignedArray src = new AlignedArray(8 + 8 * test, _vectorAlignment); - src.CopyFrom(_testSrcVectors[test]); - - CpuMathUtils.ZeroMatrixItems(src, src.Size, src.Size, idx); - float[] actual = new float[src.Size]; - src.CopyTo(actual, 0, src.Size); - Assert.Equal(expected, actual, _comparer); - } - - [Theory] - [InlineData(0, new int[] { 0, 2, 5 }, new float[] { 0f, 2f, 0f, 4f, 5f, 6f, 0f, 8f })] - [InlineData(1, new int[] { 0, 2, 5, 6, 8, 11, 12, 13 }, new float[] { 0f, 2f, 0f, 4f, 5f, 0f, 0f, 8f, 9f, 0f, 11f, 12f, 0f, 0f, 0f, 16f })] - public void ZeroMatrixItemsCoreTest(int test, int[] idx, float[] expected) - { - AlignedArray src = new AlignedArray(8 + 8 * test, _vectorAlignment); - src.CopyFrom(_testSrcVectors[test]); - - CpuMathUtils.ZeroMatrixItems(src, src.Size / 2 - 1, src.Size / 2, idx); - float[] actual = new float[src.Size]; - src.CopyTo(actual, 0, src.Size); - Assert.Equal(expected, actual, _comparer); - } - [Theory] [InlineData(0)] [InlineData(1)] diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index e2bc679830..3cd44add3f 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -934,7 +934,7 @@ protected bool CheckSameValues(IRowCursor curs1, IRowCursor curs2, bool exactTyp var comp = comps[col]; if (comp != null && !comp()) { - Fail("Different values in column {0} of row {1}", col, curs1.Position); + Fail($"Different values in column {col} of row {curs1.Position}"); return Failed(); } if (idComp != null && !idComp()) @@ -1158,12 +1158,12 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ throw Contracts.Except("Unknown type in GetColumnComparer: '{0}'", type); } - private const Double DoubleEps = 1e-9; + private const Double DoubleEps = 1e-5; private static bool EqualWithEpsDouble(Double x, Double y) { // bitwise comparison is needed because Abs(Inf-Inf) and Abs(NaN-NaN) are not 0s. 
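// [Editorial note: illustrative sketch, not part of the patch] The hunk below replaces the absolute
// tolerance |x - y| < DoubleEps with a relative one, |x - y| / |x| < DoubleEps, which is why
// DoubleEps is loosened from 1e-9 to 1e-5: the comparison now asks for roughly five matching
// significant digits regardless of magnitude. A minimal sketch of the two checks (standalone,
// hypothetical helper names):
//
//     static bool AbsoluteClose(double x, double y, double eps) => Math.Abs(x - y) < eps;
//     static bool RelativeClose(double x, double y, double eps) => Math.Abs(x - y) / Math.Abs(x) < eps;
//
//     // AbsoluteClose(12345.000001, 12345.000002, 1e-9) is false (the difference is about 1e-6),
//     // while RelativeClose(12345.000001, 12345.000002, 1e-5) is true (relative error ~8e-11).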
- return FloatUtils.GetBits(x) == FloatUtils.GetBits(y) || Math.Abs(x - y) < DoubleEps; + return FloatUtils.GetBits(x) == FloatUtils.GetBits(y) || Math.Abs(x - y) / Math.Abs(x) < DoubleEps; } private const float SingleEps = 1e-6f; diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs index 0bd4604daa..23c98fbb8c 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs @@ -130,10 +130,10 @@ public void ChangePointDetectionWithSeasonality() while (enumerator.MoveNext() && index < expectedValues.Count) { row = enumerator.Current; - Assert.Equal(expectedValues[index++], row.Change[0], precision: 7); // Alert - Assert.Equal(expectedValues[index++], row.Change[1], precision: 7); // Raw score - Assert.Equal(expectedValues[index++], row.Change[2], precision: 7); // P-Value score - Assert.Equal(expectedValues[index++], row.Change[3], precision: 7); // Martingale score + Assert.Equal(expectedValues[index++], row.Change[0], precision: 5); // Alert + Assert.Equal(expectedValues[index++], row.Change[1], precision: 5); // Raw score + Assert.Equal(expectedValues[index++], row.Change[2], precision: 5); // P-Value score + Assert.Equal(expectedValues[index++], row.Change[3], precision: 5); // Martingale score } } From a8daf9b12ed87cda2b8a1e56d5f20d1c5beb8bf9 Mon Sep 17 00:00:00 2001 From: Anipik Date: Thu, 22 Nov 2018 18:44:00 -0800 Subject: [PATCH 2/5] adding some asserts --- src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs index 76afabd59e..6546eeef31 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs @@ -11,8 +11,11 @@ namespace Microsoft.ML.Runtime.Internal.CpuMath { internal static partial class CpuMathUtils { - public static void MatrixTimesSource(bool transpose, float[] matrix, ReadOnlySpan source, Span destination, int stride) + public static void MatrixTimesSource(bool transpose, ReadOnlySpan matrix, ReadOnlySpan source, Span destination, int stride) { + Contracts.AssertNonEmpty(matrix); + Contracts.AssertNonEmpty(source); + Contracts.AssertNonEmpty(destination); Contracts.Assert(matrix.Length == destination.Length * source.Length); Contracts.Assert(stride >= 0); @@ -78,6 +81,11 @@ public static void MatrixTimesSource(ReadOnlySpan matrix, ReadOnlySpan= 0); Contracts.Assert(iposMin <= iposLimit); Contracts.Assert(iposLimit <= rgposSrc.Length); + Contracts.AssertNonEmpty(matrix); + Contracts.AssertNonEmpty(sourceValues); + Contracts.AssertNonEmpty(destination); + Contracts.AssertNonEmpty(rgposSrc); + Contracts.Assert(stride > 0); Contracts.Assert(matrix.Length == destination.Length * sourceValues.Length); if (iposMin >= iposLimit) From 346980b9e188cb0b6c64dd884017d4d22faf7822 Mon Sep 17 00:00:00 2001 From: Anipik Date: Wed, 28 Nov 2018 15:54:39 -0800 Subject: [PATCH 3/5] unrolling the loop, aligned removed from name, using new float comparision --- ...AdaptiveSingularSpectrumSequenceModeler.cs | 5 +-- .../RandomFourierFeaturizing.cs | 38 +++++++++---------- src/Native/CpuMathNative/Sse.cpp | 9 ++++- .../UnitTests.cs | 1 - .../DataPipe/TestDataPipeBase.cs | 8 ++-- .../TimeSeriesDirectApi.cs | 8 ++-- 6 files changed, 35 insertions(+), 34 deletions(-) diff --git 
a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs index 3d3baeaa2f..88cbc536dd 100644 --- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs +++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs @@ -529,8 +529,7 @@ public override void Save(ModelSaveContext ctx) Array.Copy(_wTrans, tempArray, _wTrans.Length); ctx.Writer.WriteSingleArray(tempArray); tempArray = new float[_rank]; - iv = 0; - _y.CopyTo(tempArray, ref iv); + Array.Copy(_y, tempArray, tempArray.Length); ctx.Writer.WriteSingleArray(tempArray); } @@ -1348,7 +1347,7 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) // Setting the weight matrix _wTrans = new float[_rank * _windowSize]; - Array.Copy(leftSingularVecs, _wTrans, _rank * _windowSize); + Array.Copy(leftSingularVecs, _wTrans, _wTrans.Length); // Setting alpha Single nu = 0; diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs index f89547471b..f629201dfa 100644 --- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs +++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs @@ -120,8 +120,8 @@ public TransformInfo(IHost host, ColumnInfo column, int d, float avgDist) var generator = column.Generator; _matrixGenerator = generator.CreateComponent(host, avgDist); - int roundedUpD = RoundUp(NewDim); - int roundedUpNumFeatures = RoundUp(SrcDim); + int roundedUpD = RoundToMultipleOf4(NewDim); + int roundedUpNumFeatures = RoundToMultipleOf4(SrcDim); RndFourierVectors = new float[roundedUpD * roundedUpNumFeatures]; RotationTerms = _useSin ? null : new float[roundedUpD]; @@ -154,8 +154,8 @@ public TransformInfo(IHostEnvironment env, ModelLoadContext ctx, string director ctx.LoadModelOrNull(env, out _matrixGenerator, directoryName)); // initialize the transform matrix - int roundedUpD = RoundUp(NewDim); - int roundedUpNumFeatures = RoundUp(SrcDim); + int roundedUpD = RoundToMultipleOf4(NewDim); + int roundedUpNumFeatures = RoundToMultipleOf4(SrcDim); RndFourierVectors = new float[roundedUpD * roundedUpNumFeatures]; RotationTerms = _useSin ? 
null : new float[roundedUpD]; InitializeFourierCoefficients(roundedUpNumFeatures, roundedUpD); @@ -293,7 +293,7 @@ public RandomFourierFeaturizingTransformer(IHostEnvironment env, IDataView input } } - private static int RoundUp(int number) + private static int RoundToMultipleOf4(int number) { Contracts.Assert(0 < number); int multipleOf4 = (number + 3) / 4; @@ -548,14 +548,14 @@ private ValueGetter> GetterFromVectorType(IRow input, int iinfo) var getSrc = input.GetGetter>(_srcCols[iinfo]); var src = default(VBuffer); - var featuresAligned = new float[RoundUp(_srcTypes[iinfo].ValueCount)]; - var productAligned = new float[RoundUp(_parent._transformInfos[iinfo].NewDim)]; + var features = new float[RoundToMultipleOf4(_srcTypes[iinfo].ValueCount)]; + var product = new float[RoundToMultipleOf4(_parent._transformInfos[iinfo].NewDim)]; return (ref VBuffer dst) => { getSrc(ref src); - TransformFeatures(in src, ref dst, _parent._transformInfos[iinfo], featuresAligned, productAligned); + TransformFeatures(in src, ref dst, _parent._transformInfos[iinfo], features, product); }; } @@ -566,7 +566,7 @@ private ValueGetter> GetterFromFloatType(IRow input, int iinfo) var src = default(float); var featuresAligned = new float[4]; - var productAligned = new float[RoundUp(_parent._transformInfos[iinfo].NewDim)]; + var productAligned = new float[RoundToMultipleOf4(_parent._transformInfos[iinfo].NewDim)]; var oneDimensionalVector = new VBuffer(1, new float[] { 0 }); @@ -580,7 +580,7 @@ private ValueGetter> GetterFromFloatType(IRow input, int iinfo) } private void TransformFeatures(in VBuffer src, ref VBuffer dst, TransformInfo transformInfo, - float[] featuresAligned, float[] productAligned) + float[] features, float[] product) { Host.Check(src.Length == transformInfo.SrcDim, "column does not have the expected dimensionality."); @@ -599,9 +599,9 @@ private void TransformFeatures(in VBuffer src, ref VBuffer dst, Tr if (src.IsDense) { - src.GetValues().CopyTo(featuresAligned); - CpuMathUtils.MatrixTimesSource(transpose: false, transformInfo.RndFourierVectors, featuresAligned, productAligned, - RoundUp(transformInfo.NewDim)); + src.GetValues().CopyTo(features); + CpuMathUtils.MatrixTimesSource(transpose: false, transformInfo.RndFourierVectors, features, product, + RoundToMultipleOf4(transformInfo.NewDim)); } else { @@ -610,20 +610,20 @@ private void TransformFeatures(in VBuffer src, ref VBuffer dst, Tr var srcValues = src.GetValues(); var srcIndices = src.GetIndices(); - for (int i = 0; i< srcValues.Length; i++) + for (int i = 0; i < srcValues.Length; i++) { - int iv = srcIndices[0]; - featuresAligned[iv] = srcValues[i]; + int iv = srcIndices[i]; + features[iv] = srcValues[i]; } - CpuMathUtils.MatrixTimesSource(transformInfo.RndFourierVectors, srcIndices, featuresAligned, 0, 0, - srcValues.Length, productAligned, RoundUp(transformInfo.NewDim)); + CpuMathUtils.MatrixTimesSource(transformInfo.RndFourierVectors, srcIndices, features, 0, 0, + srcValues.Length, product, RoundToMultipleOf4(transformInfo.NewDim)); } var dstEditor = VBufferEditor.Create(ref dst, newDstLength); for (int i = 0; i < transformInfo.NewDim; i++) { - var dotProduct = productAligned[i]; + var dotProduct = product[i]; if (transformInfo.RotationTerms != null) dstEditor.Values[i] = (float)MathUtils.Cos(dotProduct + transformInfo.RotationTerms[i]) * scale; else diff --git a/src/Native/CpuMathNative/Sse.cpp b/src/Native/CpuMathNative/Sse.cpp index 45344ecf97..74a915de70 100644 --- a/src/Native/CpuMathNative/Sse.cpp +++ 
b/src/Native/CpuMathNative/Sse.cpp @@ -65,9 +65,14 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout for (int i = 0 ; i < crow; i++) { float dotProduct = 0; - for (int j = 0; j < ccol; j++) + switch (ccol) { - dotProduct += pmat[i * ccol + j] * psrc[j]; + case 3: + dotProduct += pmat[i * ccol + 2] * psrc[2]; + case 2: + dotProduct += pmat[i * ccol + 1] * psrc[1]; + case 1: + dotProduct += pmat[i * ccol + 0] * psrc[0]; } pdst[i] = dotProduct; } diff --git a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs index f8c856cfec..c3d378751e 100644 --- a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs +++ b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs @@ -145,7 +145,6 @@ private void MatMulAnyDimensionTest(int col, int row) [InlineData(4, 9)] [InlineData(5, 7)] [InlineData(5, 9)] - private void MatMulTranAnyDimensionTest(int col, int row) { float[] mat = new float[col * row]; diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 3cd44add3f..b0aa994f27 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -17,6 +17,7 @@ using Microsoft.ML.Runtime.Model; using Microsoft.ML.TestFramework; using Xunit; +using Xunit.Abstractions; namespace Microsoft.ML.Runtime.RunTests { @@ -774,7 +775,6 @@ protected bool SaveLoadTransposed(IDataView view, IHostEnvironment env, string s public abstract partial class TestDataViewBase : BaseTestBaseline { - public class SentimentData { [ColumnName("Label")] @@ -1158,12 +1158,10 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ throw Contracts.Except("Unknown type in GetColumnComparer: '{0}'", type); } - private const Double DoubleEps = 1e-5; - - private static bool EqualWithEpsDouble(Double x, Double y) + private bool EqualWithEpsDouble(Double x, Double y) { // bitwise comparison is needed because Abs(Inf-Inf) and Abs(NaN-NaN) are not 0s. - return FloatUtils.GetBits(x) == FloatUtils.GetBits(y) || Math.Abs(x - y) / Math.Abs(x) < DoubleEps; + return FloatUtils.GetBits(x) == FloatUtils.GetBits(y) || CompareNumbersWithTolerance(x, y, null, 3); } private const float SingleEps = 1e-6f; diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs index 23c98fbb8c..06db1dd73f 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs @@ -193,7 +193,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() Assert.Equal(1.1661833524703979, prediction2.Change[1], precision: 5); // Raw score prediction2 = engine2.Predict(new Data(1)); //Raw score after second input. - Assert.Equal(0.12216401100158691, prediction2.Change[1], precision: 5); // Raw score + Assert.Equal(0.12216401100158691, prediction2.Change[1], precision: 4); // Raw score //Even though time series column is not requested it will // pass the observation through time series transform and update the state with the first input. @@ -210,7 +210,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() //and raw score should match the raw score obtained by passing the two input in the first model. 
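// [Editorial note: illustrative sketch, not part of the patch] The precision arguments in the
// surrounding asserts are relaxed (for example 5 -> 4 decimal places) to absorb small differences
// in how the rewritten SSE/AVX paths accumulate partial products: floating-point addition is not
// associative, so summing the same terms in a different order can legitimately change the
// low-order digits of the raw scores. A minimal standalone illustration of that order dependence:
//
//     float big = 1e8f, small = 1f;
//     float a = (big + small) - big;   // 0f: "small" is absorbed when added to "big" first
//     float b = (big - big) + small;   // 1f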
var engine3 = model3.CreateTimeSeriesPredictionFunction(ml); var prediction3 = engine3.Predict(new Data(1)); - Assert.Equal(0.12216401100158691, prediction2.Change[1], precision: 5); // Raw score + Assert.Equal(0.12216401100158691, prediction2.Change[1], precision: 4); // Raw score } [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] @@ -264,7 +264,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngine() //Model 1: Prediction #2 prediction = engine.Predict(new Data(1)); Assert.Equal(0, prediction.Change[0], precision: 7); // Alert - Assert.Equal(0.12216401100158691, prediction.Change[1], precision: 5); // Raw score + Assert.Equal(0.12216401100158691, prediction.Change[1], precision: 4); // Raw score Assert.Equal(0.14823824685192111, prediction.Change[2], precision: 5); // P-Value score Assert.Equal(1.5292508189989167E-07, prediction.Change[3], precision: 7); // Martingale score @@ -277,7 +277,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngine() engine = model2.CreateTimeSeriesPredictionFunction(ml); prediction = engine.Predict(new Data(1)); Assert.Equal(0, prediction.Change[0], precision: 7); // Alert - Assert.Equal(0.12216401100158691, prediction.Change[1], precision: 5); // Raw score + Assert.Equal(0.12216401100158691, prediction.Change[1], precision: 4); // Raw score Assert.Equal(0.14823824685192111, prediction.Change[2], precision: 5); // P-Value score Assert.Equal(1.5292508189989167E-07, prediction.Change[3], precision: 5); // Martingale score } From 04a75ff292f4aa58ae968a331a798f6d90660fe5 Mon Sep 17 00:00:00 2001 From: Anipik Date: Wed, 28 Nov 2018 17:28:26 -0800 Subject: [PATCH 4/5] enabling some more tests --- test/Microsoft.ML.TimeSeries.Tests/TimeSeries.cs | 2 +- test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs | 8 ++++---- .../TimeSeriesEstimatorTests.cs | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeries.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeries.cs index 39340d225b..90428934ba 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeries.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeries.cs @@ -72,7 +72,7 @@ public void SavePipeSsaSpike() Done(); } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline + [Fact] public void SavePipeSsaSpikeNoData() { string pathData = DeleteOutputPath("SavePipe", "SsaSpikeNoData.txt"); diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs index 06db1dd73f..274a229630 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs @@ -87,7 +87,7 @@ public void ChangeDetection() } } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline + [Fact] public void ChangePointDetectionWithSeasonality() { var env = new MLContext(conc: 1); @@ -137,7 +137,7 @@ public void ChangePointDetectionWithSeasonality() } } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] + [Fact] public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() { const int ChangeHistorySize = 10; @@ -190,7 +190,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() var engine2 = 
model2.CreateTimeSeriesPredictionFunction(ml); var prediction2 = engine2.Predict(new Data(1)); //Raw score after first input. - Assert.Equal(1.1661833524703979, prediction2.Change[1], precision: 5); // Raw score + Assert.Equal(1.1661833524703979, prediction2.Change[1], precision: 4); // Raw score prediction2 = engine2.Predict(new Data(1)); //Raw score after second input. Assert.Equal(0.12216401100158691, prediction2.Change[1], precision: 4); // Raw score @@ -213,7 +213,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() Assert.Equal(0.12216401100158691, prediction2.Change[1], precision: 4); // Raw score } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] + [Fact] public void ChangePointDetectionWithSeasonalityPredictionEngine() { const int ChangeHistorySize = 10; diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs index a7892701d7..0e6a5063bd 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs @@ -41,7 +41,7 @@ public TimeSeriesEstimatorTests(ITestOutputHelper output) : base(output) { } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline + [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCoreAnd64BitProcess))] // 32bit and netcore3.0 output differs from Baseline void TestSsaChangePointEstimator() { int Confidence = 95; @@ -75,7 +75,7 @@ void TestSsaChangePointEstimator() Done(); } - [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline + [Fact] void TestSsaSpikeEstimator() { int Confidence = 95; From 483efc5c820e9d7deb2ad50fab35625f12f241c1 Mon Sep 17 00:00:00 2001 From: Anipik Date: Mon, 3 Dec 2018 13:22:45 -0800 Subject: [PATCH 5/5] case comment added, tanners feedback --- src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 579 +++++++----------- src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 574 +++++++---------- src/Native/CpuMathNative/Sse.cpp | 574 +++++++---------- .../UnitTests.cs | 1 - 4 files changed, 689 insertions(+), 1039 deletions(-) diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index b5e18ac0d9..d6fb169094 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -164,142 +164,14 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pDstEnd = pdst + crow; float* pDstCurrent = pdst; float* pMatCurrent = pmat; + int numRows = crow; - if (crow % 4 == 0) - { - while (pDstCurrent + 4 <= pDstEnd) - { - Vector256 res0 = Avx.SetZeroVector256(); - Vector256 res1 = Avx.SetZeroVector256(); - Vector256 res2 = Avx.SetZeroVector256(); - Vector256 res3 = Avx.SetZeroVector256(); - - int length = ccol; - float* pSrcCurrent = psrc; - - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 32); - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent < pSrcEnd) - { - Vector256 vector = Avx.LoadVector256(pSrcCurrent); - - float* pMatTemp = pMatCurrent; - res0 = MultiplyAdd(pMatTemp, vector, res0); - res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); - res2 = 
MultiplyAdd(pMatTemp += ccol, vector, res2); - res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); - - pSrcCurrent += 8; - pMatCurrent += 8; - } - } - else - { - if (misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - - Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - - // We only align pMat since it has significantly more reads. - float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); - Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); - - res0 = Avx.Multiply(x01, vector); - res1 = Avx.Multiply(x11, vector); - res2 = Avx.Multiply(x21, vector); - res3 = Avx.Multiply(x31, vector); - - pMatCurrent += misalignment; - pSrcCurrent += misalignment; - length -= misalignment; - } - - if (length > 7) - { - // Handle all the 256-bit blocks that we can now that we have offset to an aligned address - remainder = length % 8; - - while (pSrcCurrent + 8 <= pSrcEnd) - { - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - Vector256 vector = Avx.LoadVector256(pSrcCurrent); - - float* pMatTemp = pMatCurrent; - res0 = MultiplyAdd(pMatTemp, vector, res0); - res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); - res2 = MultiplyAdd(pMatTemp += ccol, vector, res2); - res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); - - pSrcCurrent += 8; - pMatCurrent += 8; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not - // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. 
- remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (8 - remainder); - pSrcCurrent -= (8 - remainder); - - Vector256 mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - - float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); - Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); - - res0 = MultiplyAdd(x01, vector, res0); - res1 = MultiplyAdd(x11, vector, res1); - res2 = MultiplyAdd(x21, vector, res2); - res3 = MultiplyAdd(x31, vector, res3); - - pMatCurrent += 8; - pSrcCurrent += 8; - } - } - - // Add up the entries of each, with the 4 results in res0 - res0 = Avx.HorizontalAdd(res0, res1); - res2 = Avx.HorizontalAdd(res2, res3); - res0 = Avx.HorizontalAdd(res0, res2); - - Vector128 sum = Sse.Add(Avx.GetLowerHalf(res0), GetHigh(in res0)); - Sse.Store(pDstCurrent, sum); - - pDstCurrent += 4; - pMatCurrent += 3 * ccol; - } - } - - while (pDstCurrent < pDstEnd) + while (pDstCurrent + 4 <= pDstEnd) { Vector256 res0 = Avx.SetZeroVector256(); + Vector256 res1 = Avx.SetZeroVector256(); + Vector256 res2 = Avx.SetZeroVector256(); + Vector256 res3 = Avx.SetZeroVector256(); int length = ccol; float* pSrcCurrent = psrc; @@ -308,15 +180,19 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr int misalignment = (int)(address % 32); int remainder = 0; - if ((misalignment & 3) != 0) + if ((misalignment & 3) != 0 || (ccol % 8 != 0)) { // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent < pSrcEnd) + remainder = length % 8; + while (pSrcCurrent + 8 <= pSrcEnd) { Vector256 vector = Avx.LoadVector256(pSrcCurrent); float* pMatTemp = pMatCurrent; res0 = MultiplyAdd(pMatTemp, vector, res0); + res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); + res2 = MultiplyAdd(pMatTemp += ccol, vector, res2); + res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); pSrcCurrent += 8; pMatCurrent += 8; @@ -336,9 +212,15 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr // We only align pMat since it has significantly more reads. 
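// [Editorial note: illustrative sketch, not part of the patch] pLeadingAlignmentMask and
// pTrailingAlignmentMask point at tables of per-lane bit patterns (each 32-bit lane either all
// zeros or all ones). ANDing a full-width unaligned load with the selected mask zeroes the lanes
// that lie outside the range still to be processed, so those lanes contribute nothing to the
// running sums and no scalar tail loop is needed. A scalar model of the idea (illustrative only,
// mask[lane] assumed to be 1 for wanted lanes and 0 otherwise):
//
//     float partial = 0;
//     for (int lane = 0; lane < 8; lane++)
//         partial += mask[lane] * mat[lane] * src[lane];   // masked-off lanes add 0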
float* pMatTemp = pMatCurrent; Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); res0 = Avx.Multiply(x01, vector); + res1 = Avx.Multiply(x11, vector); + res2 = Avx.Multiply(x21, vector); + res3 = Avx.Multiply(x31, vector); pMatCurrent += misalignment; pSrcCurrent += misalignment; @@ -359,6 +241,9 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pMatTemp = pMatCurrent; res0 = MultiplyAdd(pMatTemp, vector, res0); + res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); + res2 = MultiplyAdd(pMatTemp += ccol, vector, res2); + res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); pSrcCurrent += 8; pMatCurrent += 8; @@ -371,37 +256,95 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr // unaligned loads where we mask the input each time. remainder = length; } + } - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed + if (remainder != 0) + { + // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next + // unaligned load will read to the end of the array and then mask out any elements already processed - pMatCurrent -= (8 - remainder); - pSrcCurrent -= (8 - remainder); + pMatCurrent -= (8 - remainder); + pSrcCurrent -= (8 - remainder); - Vector256 mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + Vector256 mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); - Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); - res0 = MultiplyAdd(x01, vector, res0); + res0 = MultiplyAdd(x01, vector, res0); + res1 = MultiplyAdd(x11, vector, res1); + res2 = MultiplyAdd(x21, vector, res2); + res3 = MultiplyAdd(x31, vector, res3); - pMatCurrent += 8; - pSrcCurrent += 8; - } + pMatCurrent += 8; + pSrcCurrent += 8; } - res0 = VectorSum256(in res0); - float sum = Sse.ConvertToSingle(Sse.AddScalar(Avx.GetLowerHalf(res0), GetHigh(res0))); - *pDstCurrent = sum; + // Add up the entries of each, with the 4 results in res0 + res0 = Avx.HorizontalAdd(res0, res1); + res2 = Avx.HorizontalAdd(res2, res3); + res0 = Avx.HorizontalAdd(res0, res2); + + Vector128 sum = Sse.Add(Avx.GetLowerHalf(res0), GetHigh(in res0)); + Sse.Store(pDstCurrent, sum); + + pDstCurrent += 4; + pMatCurrent += 3 * ccol; + numRows -= 4; + } - pDstCurrent += 1; + // falling through the case statements + switch (numRows) + { + case 3: + *(pDstCurrent + 2) = RowMultiply(pMatCurrent + 2 * ccol, psrc, pSrcEnd, ccol); + goto case 2; + case 2: + *(pDstCurrent + 1) = RowMultiply(pMatCurrent + ccol, psrc, pSrcEnd, ccol); + goto case 1; + case 1: + *pDstCurrent = RowMultiply(pMatCurrent, psrc, 
pSrcEnd, ccol); + break; } } } + private static unsafe float RowMultiply(float* pMatCurrent, float* pSrcCurrent, float* pSrcEnd, int ccol) + { + Vector256 res0 = Avx.SetZeroVector256(); + int remainder = ccol % 8; + while (pSrcCurrent + 8 <= pSrcEnd) + { + Vector256 vector = Avx.LoadVector256(pSrcCurrent); + + float* pMatTemp = pMatCurrent; + res0 = MultiplyAdd(pMatTemp, vector, res0); + + pSrcCurrent += 8; + pMatCurrent += 8; + } + + res0 = VectorSum256(in res0); + float sum = Sse.ConvertToSingle(Sse.AddScalar(Avx.GetLowerHalf(res0), GetHigh(res0))); + + // falling through the case statements + switch (remainder) + { + case 7: sum += *(pSrcCurrent + 6) * *(pMatCurrent + 6); goto case 6; + case 6: sum += *(pSrcCurrent + 5) * *(pMatCurrent + 5); goto case 5; + case 5: sum += *(pSrcCurrent + 4) * *(pMatCurrent + 4); goto case 4; + case 4: sum += *(pSrcCurrent + 3) * *(pMatCurrent + 3); goto case 3; + case 3: sum += *(pSrcCurrent + 2) * *(pMatCurrent + 2); goto case 2; + case 2: sum += *(pSrcCurrent + 1) * *(pMatCurrent + 1); goto case 1; + case 1: sum += *(pSrcCurrent) * *(pMatCurrent); break; + } + return sum; + } + // Partial sparse source vector. public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan src, int posMin, int iposMin, int iposEnd, Span dst, int crow, int ccol) @@ -568,6 +511,8 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan mat, ReadOnlySpan x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); @@ -598,7 +544,6 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan mat, ReadOnlySpan trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + pMatCurrent -= (8 - remainder); + pDstCurrent -= (8 - remainder); + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); - x02 = Avx.Multiply(x01, x02); + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); + x02 = Avx.Multiply(x01, x02); - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; } pSrcCurrent += 1; + numCol -= 1; } // We do 4-way unrolling - if (crow % 4 == 0) - { - while (pSrcCurrent + 4 <= pSrcEnd) - { - Vector128 h01 = Sse.LoadVector128(pSrcCurrent); - // Replicate each slot of h01 (ABCD) into its own register. 
- Vector128 h11 = Avx.Permute(h01, 0x55); // B - Vector128 h21 = Avx.Permute(h01, 0xAA); // C - Vector128 h31 = Avx.Permute(h01, 0xFF); // D - h01 = Avx.Permute(h01, 0x00); // A - - Vector256 x01 = Avx.SetHighLow(h01, h01); - Vector256 x11 = Avx.SetHighLow(h11, h11); - Vector256 x21 = Avx.SetHighLow(h21, h21); - Vector256 x31 = Avx.SetHighLow(h31, h31); - - int length = crow; - float* pDstCurrent = pdst; - - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 32); - - if ((misalignment & 3) != 0) - { - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - - // We only align pMat since it has significantly more reads. - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); - - x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); - - Avx.Store(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - if (length > 7) - { - remainder = length % 8; - while (pDstCurrent + 8 <= pDstEnd) - { - float* pMatTemp = pMatCurrent; - - Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - else - { - remainder = length; - } - - if (remainder != 0) - { - pMatCurrent -= (8 - remainder); - pDstCurrent -= (8 - remainder); - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = 
Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); - - x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - - pMatCurrent += 3 * crow; - pSrcCurrent += 4; - } - } - - while (pSrcCurrent < pSrcEnd) + while (pSrcCurrent + 4 <= pSrcEnd) { Vector128 h01 = Sse.LoadVector128(pSrcCurrent); // Replicate each slot of h01 (ABCD) into its own register. + Vector128 h11 = Avx.Permute(h01, 0x55); // B + Vector128 h21 = Avx.Permute(h01, 0xAA); // C + Vector128 h31 = Avx.Permute(h01, 0xFF); // D h01 = Avx.Permute(h01, 0x00); // A + Vector256 x01 = Avx.SetHighLow(h01, h01); + Vector256 x11 = Avx.SetHighLow(h11, h11); + Vector256 x21 = Avx.SetHighLow(h21, h21); + Vector256 x31 = Avx.SetHighLow(h31, h31); + int length = crow; float* pDstCurrent = pdst; nuint address = (nuint)(pMatCurrent); int misalignment = (int)(address % 32); - if ((misalignment & 3) != 0) + if ((misalignment & 3) != 0 || (crow % 8 != 0)) { - while (pDstCurrent < pDstEnd) + remainder = length % 8; + while (pDstCurrent + 4 <= pDstEnd) { float* pMatTemp = pMatCurrent; Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); Avx.Store(pDstCurrent, x02); @@ -850,7 +665,6 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan mat, ReadOnlySpan x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); Vector256 x3 = Avx.LoadVector256(pDstCurrent); @@ -881,11 +706,17 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); Avx.Store(pDstCurrent, x02); @@ -897,32 +728,90 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan 0) - { - pMatCurrent -= (8 - remainder); - pDstCurrent -= (8 - remainder); - Vector256 trailingMask = 
Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + if (remainder != 0) + { + pMatCurrent -= (8 - remainder); + pDstCurrent -= (8 - remainder); + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); - x02 = Avx.Multiply(x01, x02); + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); - x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); + + x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); - pSrcCurrent += 1; + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; } + + pMatCurrent += 3 * crow; + pSrcCurrent += 4; + numCol -= 4; } + + // falling through the case statements + switch (numCol) + { + case 3: ColumnMultiply(pMatCurrent + 2 * crow, pSrcCurrent + 2, pdst, pDstEnd, crow); goto case 2; + case 2: ColumnMultiply(pMatCurrent + crow, pSrcCurrent + 1, pdst, pDstEnd, crow); goto case 1; + case 1: ColumnMultiply(pMatCurrent, pSrcCurrent, pdst, pDstEnd, crow); break; + } + } + } + + private static unsafe void ColumnMultiply(float* pMatCurrent, float* pSrcCurrent, float* pdst, float* pDstEnd, int crow) + { + Vector128 h01 = Sse.LoadVector128(pSrcCurrent); + // Replicate each slot of h01 (ABCD) into its own register. + h01 = Avx.Permute(h01, 0x00); // A + Vector256 x01 = Avx.SetHighLow(h01, h01); + int remainder = crow % 8; + float* pDstCurrent = pdst; + + while (pDstCurrent + 8 <= pDstEnd) + { + // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads + // (due to semantics of the legacy encoding). + // We don't need an assert, since the instruction will throw for unaligned inputs. 
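// [Editorial note: illustrative sketch, not part of the patch] This loop, together with the
// fall-through switch after it, is what the new ColumnMultiply helper does for one leftover source
// column: broadcast a single source element, multiply it by eight matrix entries at a time, and
// accumulate into the destination. Its scalar equivalent is simply:
//
//     for (int i = 0; i < crow; i++)
//         pdst[i] += (*pSrcCurrent) * pMatCurrent[i];
//
// RowMultiply, added above for MatMul, is the row-wise counterpart: a dot product of one matrix
// row with the source vector, again finishing the sub-vector tail with a fall-through switch.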
+ float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + + // falling through the case statements + switch (remainder) + { + case 7: *(pDstCurrent + 6) += *(pSrcCurrent) * *(pMatCurrent + 6); goto case 6; + case 6: *(pDstCurrent + 5) += *(pSrcCurrent) * *(pMatCurrent + 5); goto case 5; + case 5: *(pDstCurrent + 4) += *(pSrcCurrent) * *(pMatCurrent + 4); goto case 4; + case 4: *(pDstCurrent + 3) += *(pSrcCurrent) * *(pMatCurrent + 3); goto case 3; + case 3: *(pDstCurrent + 2) += *(pSrcCurrent) * *(pMatCurrent + 2); goto case 2; + case 2: *(pDstCurrent + 1) += *(pSrcCurrent) * *(pMatCurrent + 1); goto case 1; + case 1: *pDstCurrent += *(pSrcCurrent) * *(pMatCurrent); break; } + return; } // dst[i] += scale diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index ef0c076814..98394f5d7d 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -128,150 +128,14 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pDstEnd = pdst + crow; float* pDstCurrent = pdst; float* pMatCurrent = pmat; + int numRows = crow; - if (ccol % 4 == 0) - { - while (pDstCurrent + 4 <= pDstEnd) - { - Vector128 res0 = Sse.SetZeroVector128(); - Vector128 res1 = Sse.SetZeroVector128(); - Vector128 res2 = Sse.SetZeroVector128(); - Vector128 res3 = Sse.SetZeroVector128(); - - int length = ccol; - float* pSrcCurrent = psrc; - - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 16); - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent < pSrcEnd) - { - Vector128 vector = Sse.LoadVector128(pSrcCurrent); - - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - - res0 = Sse.Add(res0, x01); - res1 = Sse.Add(res1, x11); - res2 = Sse.Add(res2, x21); - res3 = Sse.Add(res3, x31); - - pSrcCurrent += 4; - pMatCurrent += 4; - } - } - else - { - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - Vector128 mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. 
- float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); - - res0 = Sse.Multiply(x01, vector); - res1 = Sse.Multiply(x11, vector); - res2 = Sse.Multiply(x21, vector); - res3 = Sse.Multiply(x31, vector); - - pMatCurrent += misalignment; - pSrcCurrent += misalignment; - length -= misalignment; - } - - if (length > 3) - { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = length % 4; - - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - while (pSrcCurrent + 4 <= pSrcEnd) - { - Vector128 vector = Sse.LoadVector128(pSrcCurrent); - - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp)); - Vector128 x11 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - - res0 = Sse.Add(res0, x01); - res1 = Sse.Add(res1, x11); - res2 = Sse.Add(res2, x21); - res3 = Sse.Add(res3, x31); - - pSrcCurrent += 4; - pMatCurrent += 4; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not - // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. 
- remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (4 - remainder); - pSrcCurrent -= (4 - remainder); - - Vector128 mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); - - res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); - res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); - res2 = Sse.Add(res2, Sse.Multiply(x21, vector)); - res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); - - pMatCurrent += 4; - pSrcCurrent += 4; - } - } - - // Add up the entries of each, with the 4 results in res0 - res0 = Sse3.HorizontalAdd(res0, res1); - res2 = Sse3.HorizontalAdd(res2, res3); - res0 = Sse3.HorizontalAdd(res0, res2); - - Sse.Store(pDstCurrent, res0); - pDstCurrent += 4; - pMatCurrent += 3 * ccol; - } - } - - while (pDstCurrent < pDstEnd) + while (pDstCurrent + 4 <= pDstEnd) { Vector128 res0 = Sse.SetZeroVector128(); + Vector128 res1 = Sse.SetZeroVector128(); + Vector128 res2 = Sse.SetZeroVector128(); + Vector128 res3 = Sse.SetZeroVector128(); int length = ccol; float* pSrcCurrent = psrc; @@ -280,16 +144,24 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr int misalignment = (int)(address % 16); int remainder = 0; - if ((misalignment & 3) != 0) + if ((misalignment & 3) != 0 || (ccol % 4 != 0)) { // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent < pSrcEnd) + remainder = length % 4; + while (pSrcCurrent + 4 <= pSrcEnd) { Vector128 vector = Sse.LoadVector128(pSrcCurrent); float* pMatTemp = pMatCurrent; Vector128 x01 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + res0 = Sse.Add(res0, x01); + res1 = Sse.Add(res1, x11); + res2 = Sse.Add(res2, x21); + res3 = Sse.Add(res3, x31); pSrcCurrent += 4; pMatCurrent += 4; @@ -309,9 +181,15 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr // We only align pMat since it has significantly more reads. 
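+                        // The leading mask keeps only the elements before the 16-byte boundary; the zeroed lanes
+                        // are read again by the aligned loop (or by the trailing-remainder pass) below, so no
+                        // element is accumulated twice.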
float* pMatTemp = pMatCurrent; Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); res0 = Sse.Multiply(x01, vector); + res1 = Sse.Multiply(x11, vector); + res2 = Sse.Multiply(x21, vector); + res3 = Sse.Multiply(x31, vector); pMatCurrent += misalignment; pSrcCurrent += misalignment; @@ -332,7 +210,14 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr float* pMatTemp = pMatCurrent; Vector128 x01 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp)); + Vector128 x11 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); + res0 = Sse.Add(res0, x01); + res1 = Sse.Add(res1, x11); + res2 = Sse.Add(res2, x21); + res3 = Sse.Add(res3, x31); pSrcCurrent += 4; pMatCurrent += 4; @@ -345,37 +230,89 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr // unaligned loads where we mask the input each time. remainder = length; } + } - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed + if (remainder != 0) + { + // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next + // unaligned load will read to the end of the array and then mask out any elements already processed - pMatCurrent -= (4 - remainder); - pSrcCurrent -= (4 - remainder); + pMatCurrent -= (4 - remainder); + pSrcCurrent -= (4 - remainder); - Vector128 mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); + Vector128 mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); - Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); - res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); + res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); + res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); + res2 = Sse.Add(res2, Sse.Multiply(x21, vector)); + res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); - pMatCurrent += 4; - pSrcCurrent += 4; - } + pMatCurrent += 4; + pSrcCurrent += 4; } // Add up the entries of each, with the 4 results in res0 - res0 = VectorSum128(in res0); - float sum = Sse.ConvertToSingle(res0); - *pDstCurrent = sum; + res0 = Sse3.HorizontalAdd(res0, res1); + res2 = Sse3.HorizontalAdd(res2, res3); + res0 = Sse3.HorizontalAdd(res0, res2); Sse.Store(pDstCurrent, res0); - pDstCurrent += 1; + pDstCurrent += 4; + pMatCurrent += 3 * ccol; + numRows -= 4; } + + // falling through the case statements + switch (numRows) + { + case 3: + *(pDstCurrent + 2) = RowMultiply(pMatCurrent + 2 * ccol, psrc, pSrcEnd, ccol); + goto case 2; 
+ case 2: + *(pDstCurrent + 1) = RowMultiply(pMatCurrent + ccol, psrc, pSrcEnd, ccol); + goto case 1; + case 1: + *pDstCurrent = RowMultiply(pMatCurrent, psrc, pSrcEnd, ccol); + break; + } + } + } + + private static unsafe float RowMultiply(float* pMatCurrent, float* pSrcCurrent, float* pSrcEnd, int ccol) + { + Vector128 res0 = Sse.SetZeroVector128(); + int remainder = ccol % 4; + while (pSrcCurrent + 4 <= pSrcEnd) + { + Vector128 vector = Sse.LoadVector128(pSrcCurrent); + + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp)); + res0 = Sse.Add(res0, x01); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + + // Add up the entries of each, with the 4 results in res0 + res0 = VectorSum128(in res0); + float sum = Sse.ConvertToSingle(res0); + + // falling through the case statements + switch (remainder) + { + case 3: sum += *(pSrcCurrent + 2) * *(pMatCurrent + 2); goto case 2; + case 2: sum += *(pSrcCurrent + 1) * *(pMatCurrent + 1); goto case 1; + case 1: sum += *(pSrcCurrent) * *(pMatCurrent); break; } + return sum; } // Partial sparse source vector. @@ -547,6 +484,8 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan mat, ReadOnlySpan x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); @@ -576,7 +516,6 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan mat, ReadOnlySpan trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); - x02 = Sse.Multiply(x01, x02); - - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); - - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } } - pSrcCurrent += 1; - } - - if (crow % 4 == 0) - { - // We do 4-way unrolling - while (pSrcCurrent + 4 <= pSrcEnd) + if (remainder != 0) { - Vector128 x01 = Sse.LoadVector128(pSrcCurrent); - // Replicate each 32-bit slot of x01 (ABCD) into its own register. - Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B - Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C - Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D - x01 = Sse.Shuffle(x01, x01, 0x00); // A - - int length = crow; - float* pDstCurrent = pdst; - - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 16); - - if ((misalignment & 3) != 0) - { - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); - - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. 
- float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); - - x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); - - Sse.Store(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - if (length > 3) - { - remainder = length % 4; - while (pDstCurrent + 4 <= pDstEnd) - { - float* pMatTemp = pMatCurrent; - - Vector128 x02 = Sse.Multiply(x01, Sse.LoadAlignedVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadAlignedVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - remainder = length; - } - - if (remainder != 0) - { - pMatCurrent -= (4 - remainder); - pDstCurrent -= (4 - remainder); - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); + // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next + // unaligned load will read to the end of the array and then mask out any elements already processed + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); + x02 = Sse.Multiply(x01, x02); - x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); - pMatCurrent += 3 * crow; - pSrcCurrent += 4; + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; } + + numCol -= 1; + 
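+                    // The first source element has been consumed and dst holds its contribution. The 4-way
+                    // unrolled loop below accumulates the remaining columns, and the final switch hands any
+                    // leftover 1-3 columns to ColumnMultiply.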
pSrcCurrent += 1; } - while (pSrcCurrent < pSrcEnd) + // We do 4-way unrolling + while (pSrcCurrent + 4 <= pSrcEnd) { Vector128 x01 = Sse.LoadVector128(pSrcCurrent); // Replicate each 32-bit slot of x01 (ABCD) into its own register. + Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B + Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C + Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D x01 = Sse.Shuffle(x01, x01, 0x00); // A int length = crow; @@ -805,12 +607,20 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); @@ -821,7 +631,6 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan mat, ReadOnlySpan x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); Vector128 x3 = Sse.LoadVector128(pDstCurrent); @@ -852,14 +672,18 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); - x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); + Vector128 x02 = Sse.Multiply(x01, Sse.LoadAlignedVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadAlignedVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadAlignedVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadAlignedVector128(pMatTemp += crow)); + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); Sse.Store(pDstCurrent, x02); pDstCurrent += 4; pMatCurrent += 4; @@ -869,34 +693,86 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan 0) - { - pMatCurrent -= (4 - remainder); - pDstCurrent -= (4 - remainder); - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); - x02 = Sse.Multiply(x01, x02); + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Add(x02, Sse.And(x3, 
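+                        // The leading mask zeroes the matrix lanes at or past the 16-byte boundary so their
+                        // contributions are not added twice; the existing dst values in those lanes are kept
+                        // via the trailing mask applied to x3 below.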
leadingMask)); + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); - x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); - pSrcCurrent += 1; + x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; } + + pMatCurrent += 3 * crow; + pSrcCurrent += 4; + numCol -= 4; + } + + // falling through the case statements + switch (numCol) + { + case 3: ColumnMultiply(pMatCurrent + 2 * crow, pSrcCurrent + 2, pdst, pDstEnd, crow); goto case 2; + case 2: ColumnMultiply(pMatCurrent + crow, pSrcCurrent + 1, pdst, pDstEnd, crow); goto case 1; + case 1: ColumnMultiply(pMatCurrent, pSrcCurrent, pdst, pDstEnd, crow); break; } } } + private static unsafe void ColumnMultiply(float* pMatCurrent, float* pSrcCurrent, float* pdst, float* pDstEnd, int crow) + { + Vector128 x01 = Sse.LoadVector128(pSrcCurrent); + // Replicate each slot of h01 (ABCD) into its own register. + x01 = Sse.Shuffle(x01, x01, 0x00); // A + int remainder = crow % 4; + float* pDstCurrent = pdst; + + while (pDstCurrent + 4 <= pDstEnd) + { + // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads + // (due to semantics of the legacy encoding). + // We don't need an assert, since the instruction will throw for unaligned inputs. + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); + x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); + + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + + // falling through the case statements + switch (remainder) + { + case 3: *(pDstCurrent + 2) += *(pSrcCurrent) * *(pMatCurrent + 2); goto case 2; + case 2: *(pDstCurrent + 1) += *(pSrcCurrent) * *(pMatCurrent + 1); goto case 1; + case 1: *pDstCurrent += *(pSrcCurrent) * *(pMatCurrent); break; + } + return; + } + // dst[i] += scale public static unsafe void AddScalarU(float scalar, Span dst) { diff --git a/src/Native/CpuMathNative/Sse.cpp b/src/Native/CpuMathNative/Sse.cpp index 74a915de70..2240d630eb 100644 --- a/src/Native/CpuMathNative/Sse.cpp +++ b/src/Native/CpuMathNative/Sse.cpp @@ -57,6 +57,40 @@ const unsigned int TrailingAlignmentMask[16] = 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, }; +float RowMultiply(const float* pMatCurrent, const float* pSrcCurrent, const float* pSrcEnd, int ccol) +{ + __m128 res0 = _mm_setzero_ps(); + int remainder = ccol % 4; + + while (pSrcCurrent + 4 <= pSrcEnd) + { + __m128 vector = _mm_loadu_ps(pSrcCurrent); + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp)); + res0 = _mm_add_ps(res0, x01); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + + res0 = _mm_hadd_ps(res0, res0); + res0 = _mm_hadd_ps(res0, res0); + + float sum = _mm_cvtss_f32(res0); + + // falling through the case statements + switch (remainder) + { + case 3: + sum += *(pSrcCurrent + 2) * *(pMatCurrent + 2); + case 2: + sum += *(pSrcCurrent + 1) * *(pMatCurrent + 1); + case 1: + sum += *(pSrcCurrent) * *(pMatCurrent); + } + return sum; +} + // Multiply matrix times vector into vector. 
EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { @@ -83,148 +117,14 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout const float * pDstEnd = pdst + crow; float* pDstCurrent = pdst; const float* pMatCurrent = pmat; + int numRows = crow; - if (ccol % 4 == 0) - { - while (pDstCurrent + 4 <= pDstEnd) - { - __m128 res0 = _mm_setzero_ps(); - __m128 res1 = res0; - __m128 res2 = res0; - __m128 res3 = res0; - - int length = ccol; - const float* pSrcCurrent = psrc; - - uintptr_t address = (uintptr_t)(pMatCurrent); - uintptr_t misalignment = address % 16; - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent < pSrcEnd) - { - __m128 vector = _mm_loadu_ps(pSrcCurrent); - - const float* pMatTemp = pMatCurrent; - __m128 x01 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp)); - __m128 x11 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x21 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x31 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); - - res0 = _mm_add_ps(res0, x01); - res1 = _mm_add_ps(res1, x11); - res2 = _mm_add_ps(res2, x21); - res3 = _mm_add_ps(res3, x31); - - pSrcCurrent += 4; - pMatCurrent += 4; - } - } - else - { - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. - const float* pMatTemp = pMatCurrent; - __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); - __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); - - res0 = _mm_mul_ps(x01, vector); - res1 = _mm_mul_ps(x11, vector); - res2 = _mm_mul_ps(x21, vector); - res3 = _mm_mul_ps(x31, vector); - - pMatCurrent += misalignment; - pSrcCurrent += misalignment; - length -= misalignment; - } - - if (length > 3) - { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = length % 4; - - while (pSrcCurrent + 4 <= pSrcEnd) - { - __m128 vector = _mm_loadu_ps(pSrcCurrent); - - const float* pMatTemp = pMatCurrent; - __m128 x01 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp)); - __m128 x11 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); - __m128 x21 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); - __m128 x31 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); - - res0 = _mm_add_ps(res0, x01); - res1 = _mm_add_ps(res1, x11); - res2 = _mm_add_ps(res2, x21); - res3 = _mm_add_ps(res3, x31); - - pSrcCurrent += 4; - pMatCurrent += 4; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not - // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. 
- remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (4 - remainder); - pSrcCurrent -= (4 - remainder); - - __m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - - const float* pMatTemp = pMatCurrent; - __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); - __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); - - res0 = _mm_add_ps(res0, _mm_mul_ps(x01, vector)); - res1 = _mm_add_ps(res1, _mm_mul_ps(x11, vector)); - res2 = _mm_add_ps(res2, _mm_mul_ps(x21, vector)); - res3 = _mm_add_ps(res3, _mm_mul_ps(x31, vector)); - - pMatCurrent += 4; - pSrcCurrent += 4; - } - } - - // Add up the entries of each, with the 4 results in res0 - res0 = _mm_hadd_ps(res0, res1); - res2 = _mm_hadd_ps(res2, res3); - res0 = _mm_hadd_ps(res0, res2); - - _mm_storeu_ps(pDstCurrent, res0); - - pDstCurrent += 4; - pMatCurrent += 3 * ccol; - } - } - - while (pDstCurrent < pDstEnd) + while (pDstCurrent + 4 <= pDstEnd) { __m128 res0 = _mm_setzero_ps(); + __m128 res1 = res0; + __m128 res2 = res0; + __m128 res3 = res0; int length = ccol; const float* pSrcCurrent = psrc; @@ -233,16 +133,24 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout uintptr_t misalignment = address % 16; int remainder = 0; - if ((misalignment & 3) != 0) + if ((misalignment & 3) != 0 || (ccol % 4 != 0)) { // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent < pSrcEnd) + remainder = length % 4; + while (pSrcCurrent + 4 <= pSrcEnd) { __m128 vector = _mm_loadu_ps(pSrcCurrent); const float* pMatTemp = pMatCurrent; __m128 x01 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + res0 = _mm_add_ps(res0, x01); + res1 = _mm_add_ps(res1, x11); + res2 = _mm_add_ps(res2, x21); + res3 = _mm_add_ps(res3, x31); pSrcCurrent += 4; pMatCurrent += 4; @@ -262,8 +170,15 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout // We only align pMat since it has significantly more reads. 
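+                // pSrc is advanced by the same element count as pMat below so the matrix and source lanes
+                // stay paired; the source loads themselves remain unaligned (_mm_loadu_ps).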
const float* pMatTemp = pMatCurrent; __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); + res0 = _mm_mul_ps(x01, vector); + res1 = _mm_mul_ps(x11, vector); + res2 = _mm_mul_ps(x21, vector); + res3 = _mm_mul_ps(x31, vector); pMatCurrent += misalignment; pSrcCurrent += misalignment; @@ -281,7 +196,14 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout const float* pMatTemp = pMatCurrent; __m128 x01 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp)); + __m128 x11 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); + __m128 x21 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); + __m128 x31 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); + res0 = _mm_add_ps(res0, x01); + res1 = _mm_add_ps(res1, x11); + res2 = _mm_add_ps(res2, x21); + res3 = _mm_add_ps(res3, x31); pSrcCurrent += 4; pMatCurrent += 4; @@ -294,34 +216,56 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout // unaligned loads where we mask the input each time. remainder = length; } + } - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed + if (remainder != 0) + { + // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next + // unaligned load will read to the end of the array and then mask out any elements already processed - pMatCurrent -= (4 - remainder); - pSrcCurrent -= (4 - remainder); + pMatCurrent -= (4 - remainder); + pSrcCurrent -= (4 - remainder); - __m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + __m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - const float* pMatTemp = pMatCurrent; - __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); - __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); - res0 = _mm_add_ps(res0, _mm_mul_ps(x01, vector)); + res0 = _mm_add_ps(res0, _mm_mul_ps(x01, vector)); + res1 = _mm_add_ps(res1, _mm_mul_ps(x11, vector)); + res2 = _mm_add_ps(res2, _mm_mul_ps(x21, vector)); + res3 = _mm_add_ps(res3, _mm_mul_ps(x31, vector)); - pMatCurrent += 4; - pSrcCurrent += 4; - } + pMatCurrent += 4; + pSrcCurrent += 4; } - // Sum all the elements together and return the result - res0 = _mm_hadd_ps(res0, res0); - res0 = _mm_hadd_ps(res0, res0); + // Add up the entries of each, with the 4 results in res0 + res0 = _mm_hadd_ps(res0, res1); + res2 = _mm_hadd_ps(res2, res3); + res0 = _mm_hadd_ps(res0, res2); + + _mm_storeu_ps(pDstCurrent, res0); - *pDstCurrent = _mm_cvtss_f32(res0); - pDstCurrent += 1; + pDstCurrent += 4; + pMatCurrent += 3 * ccol; + numRows -= 4; + } + + // falling through the case statements + switch(numRows) + { + case 3: + *(pDstCurrent + 2) = RowMultiply(pMatCurrent + 2 * ccol, psrc, pSrcEnd, ccol); + case 2: + *(pDstCurrent + 1) = 
RowMultiply(pMatCurrent + 1 * ccol, psrc, pSrcEnd, ccol); + case 1: + *pDstCurrent = RowMultiply(pMatCurrent, psrc, pSrcEnd, ccol); + break; } } @@ -480,6 +424,35 @@ EXPORT_API(void) MatMulP(_In_ const float * pmat, _In_ const int * pposSrc, _In_ } } +void ColumnMultiply(const float* pMatCurrent, const float* pSrcCurrent, float* pdst, const float* pDstEnd, int crow) +{ + __m128 x01 = _mm_loadu_ps(pSrcCurrent); + x01 = _mm_shuffle_ps(x01, x01, 0x00); + float* pDstCurrent = pdst; + int remainder = crow % 4; + + while (pDstCurrent + 4 <= pDstEnd) + { + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); + x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); + + _mm_storeu_ps(pDstCurrent, x02); + + pDstCurrent += 4; + pMatCurrent += 4; + } + + // falling through the case statements + switch (remainder) + { + case 3: *(pDstCurrent + 2) += *(pSrcCurrent) * *(pMatCurrent + 2); + case 2: *(pDstCurrent + 1) += *(pSrcCurrent) * *(pMatCurrent + 1); + case 1: *pDstCurrent += *(pSrcCurrent) * *(pMatCurrent); break; + } + return; +} + EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { if (crow < 4) @@ -501,6 +474,8 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I const float* pMatCurrent = pmat; const float* pSrcCurrent = psrc; + int remainder = 0; + int numCol = ccol; if (pSrcCurrent < pSrcEnd) { @@ -514,10 +489,11 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I uintptr_t misalignment = address % 16; int remainder = 0; - if ((misalignment & 3) != 0) + if ((misalignment & 3) != 0 || (crow % 4 != 0)) { // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pDstCurrent < pDstEnd) + remainder = crow % 4; + while (pDstCurrent + 4 <= pDstEnd) { const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); @@ -529,7 +505,6 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I } else { - int remainder = 0; if (misalignment != 0) { // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then @@ -577,179 +552,42 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I // unaligned loads where we mask the input each time. remainder = length; } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (4 - remainder); - pDstCurrent -= (4 - remainder); - - __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); - x02 = _mm_mul_ps(x01, x02); - - __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); - __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); - - _mm_storeu_ps(pDstCurrent, x02); - pMatCurrent += 4; - pDstCurrent += 4; - } } - pSrcCurrent += 1; - } - - if (crow % 4 == 0) - { - while (pSrcCurrent + 4 <= pSrcEnd) + if (remainder != 0) { - __m128 x01 = _mm_loadu_ps(pSrcCurrent); - // Replicate each slot of x01 into its own register. 
- __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); - __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); - __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); - x01 = _mm_shuffle_ps(x01, x01, 0x00); - - int length = crow; - float* pDstCurrent = pdst; - - uintptr_t address = (uintptr_t)(pMatCurrent); - uintptr_t misalignment = address % 16; - int remainder = 0; - - if ((misalignment & 3) != 0) - { - while (pDstCurrent < pDstEnd) - { - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_mul_ps(x11, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_mul_ps(x21, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_mul_ps(x31, _mm_loadu_ps(pMatTemp += crow)); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - - x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); - - _mm_storeu_ps(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - misalignment >>= 2; - misalignment = 4 - misalignment; - - __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - - __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + ((4 - misalignment) * 4)); - __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); - x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); - - _mm_storeu_ps(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - - if (length > 3) - { - remainder = length % 4; - while (pDstCurrent + 4 <= pDstEnd) - { - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_mul_ps(x01, _mm_load_ps(pMatTemp)); - __m128 x12 = _mm_mul_ps(x11, _mm_load_ps(pMatTemp += crow)); - __m128 x22 = _mm_mul_ps(x21, _mm_load_ps(pMatTemp += crow)); - __m128 x32 = _mm_mul_ps(x31, _mm_load_ps(pMatTemp += crow)); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - - x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); - - _mm_storeu_ps(pDstCurrent, x02); - - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - remainder = length; - } - - if (remainder != 0) - { - pMatCurrent -= (4 - remainder); - pDstCurrent -= (4 - remainder); - - __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next + // unaligned load will read to the end of the array and then mask out any elements already processed - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - 
x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); - __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); + x02 = _mm_mul_ps(x01, x02); - x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); - _mm_storeu_ps(pDstCurrent, x02); - pMatCurrent += 4; - pDstCurrent += 4; - } - } + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); - pMatCurrent += 3 * crow; - pSrcCurrent += 4; + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += 4; + pDstCurrent += 4; } - } - while (pSrcCurrent < pSrcEnd) + numCol -= 1; + pSrcCurrent += 1; + } + + while (pSrcCurrent + 4 <= pSrcEnd) { __m128 x01 = _mm_loadu_ps(pSrcCurrent); + // Replicate each slot of x01 into its own register. + __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); + __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); + __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); x01 = _mm_shuffle_ps(x01, x01, 0x00); int length = crow; @@ -757,13 +595,23 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I uintptr_t address = (uintptr_t)(pMatCurrent); uintptr_t misalignment = address % 16; + int remainder = 0; - if ((misalignment & 3) != 0) + if ((misalignment & 3) != 0 || (crow % 4 != 0)) { - while (pDstCurrent < pDstEnd) + remainder = length % 4; + while (pDstCurrent + 4 <= pDstEnd) { const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); _mm_storeu_ps(pDstCurrent, x02); @@ -773,7 +621,6 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I } else { - int remainder = 0; if (misalignment != 0) { misalignment >>= 2; @@ -784,8 +631,18 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I // We only align pMat since it has significantly more reads. 
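+            // pDst is advanced together with pMat below; dst is always accessed with unaligned loads and
+            // stores (_mm_loadu_ps / _mm_storeu_ps), only the matrix loads become aligned.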
const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + ((4 - misalignment) * 4)); __m128 x3 = _mm_loadu_ps(pDstCurrent); @@ -805,6 +662,14 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I { const float* pMatTemp = pMatCurrent; __m128 x02 = _mm_mul_ps(x01, _mm_load_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_load_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_load_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_load_ps(pMatTemp += crow)); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); _mm_storeu_ps(pDstCurrent, x02); @@ -817,30 +682,51 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I { remainder = length; } + } - if (remainder != 0) - { - pMatCurrent -= (4 - remainder); - pDstCurrent -= (4 - remainder); + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); - __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); - x02 = _mm_mul_ps(x01, x02); + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); - __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); - x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); - _mm_storeu_ps(pDstCurrent, x02); - pMatCurrent += 4; - pDstCurrent += 4; - } + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); + + x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += 4; + pDstCurrent += 4; } - pSrcCurrent += 1; + numCol -= 4; + pMatCurrent += 3 * crow; + pSrcCurrent += 4; + } + + // falling through the case statements + switch (numCol) + { + case 3: ColumnMultiply(pMatCurrent + 2 * crow, pSrcCurrent + 2, pdst, pDstEnd, crow); + case 2: ColumnMultiply(pMatCurrent + crow, pSrcCurrent + 1, pdst, pDstEnd, crow); + case 1: ColumnMultiply(pMatCurrent, pSrcCurrent, pdst, pDstEnd, crow); break; } } diff --git a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs 
b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs index c3d378751e..5adb641b92 100644 --- a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs +++ b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs @@ -96,7 +96,6 @@ public void MatMulTest(int matTest, int srcTest, int dstTest, float[] expected) [InlineData(4, 9)] [InlineData(5, 7)] [InlineData(5, 9)] - private void MatMulAnyDimensionTest(int col, int row) { Random rand = new Random(DefaultSeed);